In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch as to
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('../')
import probtorch
from probtorch.util import expand_inputs
print('probtorch:', probtorch.__version__, 
      'torch:', to.__version__, 
      'cuda:', to.cuda.is_available())
from util_data import *
from amorgibbs_v import *
from smc_v import *
from util_plots import *
from util_lstm import *

probtorch: 0.0+f9f5c9b torch: 0.4.1 cuda: True


# Tests

In [5]:
# --- Model / training hyper-parameters ---
# K sets the size of the K*K-wide hidden and target layers below.
K = 10
INPUT_DIM = 4
HIDDEN_DIM = K * K
TARGET_DIM = K * K
BATCH_SIZE = 20
DATASET_SIZE = 20 * BATCH_SIZE
NUM_EPOCHS = 200
LEARNING_RATE = 1e-3

# --- Runtime / persistence settings ---
CUDA = to.cuda.is_available()
RESTORE = False
MODEL_NAME = 'lstm'
DATA_PATH = './data'
WEIGHTS_PATH = './lstm_weights'

# --- Parameters for data generation ---
boundary = 32
noise_cov = 0.5 * np.eye(2)  # 2x2 diagonal matrix with 0.5 on the diagonal
T_min, T_max = 30, 50
dt = 10
# init_velocity comes from a star-import above; its semantics are not visible here.
init_v = init_velocity(dt)

In [6]:
def kl_dirichlet(alpha1, alpha2):
    """KL divergence KL(Dir(alpha1) || Dir(alpha2)) between Dirichlet distributions.

    The concentration parameters are read along the last dimension; any
    leading dimensions are treated as batch dimensions and are preserved in
    the result. For 1-D inputs this returns a 0-dim tensor, exactly as the
    previous all-elements reduction did.

    Args:
        alpha1: tensor of positive concentrations, shape (..., D).
        alpha2: tensor of positive concentrations, broadcastable to alpha1.

    Returns:
        Tensor of shape (...) holding one KL value per distribution pair.
    """
    # Reduce only over the concentration axis so batched inputs are not
    # collapsed into a single (meaningless) scalar.
    alpha1_total = alpha1.sum(-1)
    alpha2_total = alpha2.sum(-1)

    # log B(alpha2) - log B(alpha1), split into the "total" and "per-entry" parts.
    log_gamma_totals = to.lgamma(alpha1_total) - to.lgamma(alpha2_total)
    log_gamma_parts = (to.lgamma(alpha1) - to.lgamma(alpha2)).sum(-1)

    # E_{Dir(alpha1)}[log x_j] = digamma(alpha1_j) - digamma(sum_j alpha1_j)
    expected_log = to.digamma(alpha1) - to.digamma(alpha1_total).unsqueeze(-1)
    cross_term = ((alpha1 - alpha2) * expected_log).sum(-1)

    return log_gamma_totals - log_gamma_parts + cross_term

def MKL(alphas_pred, alpha_true):
    """Sum Dirichlet KL terms over every row and every K-wide slice, divided by K.

    Each row of `alphas_pred` (and `alpha_true` itself) is cut into K
    consecutive slices of length K; `kl_dirichlet` is evaluated on each
    predicted slice against the matching slice of `alpha_true`.

    Args:
        alphas_pred: 2-D indexable tensor; the first axis is iterated row by row.
        alpha_true: 1-D tensor sliced with the same K-wide windows.

    Returns:
        The accumulated KL over all rows and slices, divided by K
        (note: not divided by the number of rows).
    """
    num_rows = alphas_pred.size()[0]
    windows = [(g * K, (g + 1) * K) for g in range(K)]
    total = sum(
        kl_dirichlet(alphas_pred[row, lo:hi], alpha_true[lo:hi])
        for row in range(num_rows)
        for lo, hi in windows
    )
    return total / K

# Reference concentrations: a length K*K vector whose entries all equal 1/K.
alpha_true = to.ones(K * K) / K


def loss_fn(alpha):
    """Training loss: MKL of the predicted concentrations against `alpha_true`."""
    return MKL(alpha, alpha_true)

In [7]:
import time
import os
from random import random

# Build the training data. genSeq / PackSeq come from star-imports above;
# presumably they generate DATASET_SIZE sequences and group them into
# BATCH_SIZE-sized batches -- verify against their definitions.
Seq, Len = genSeq(T_min, T_max, dt, init_v, noise_cov, boundary, DATASET_SIZE)
batches = PackSeq(Seq, Len, BATCH_SIZE)

model = LSTM(INPUT_DIM, HIDDEN_DIM, BATCH_SIZE, TARGET_DIM)
# Pass the configured learning rate explicitly; previously LEARNING_RATE was
# defined in the config cell but never reached the optimizer.
optimizer = to.optim.Adam(model.parameters(), lr=LEARNING_RATE)


if not RESTORE:
    # NOTE(review): `mask` is not used in this cell; kept in case a later
    # cell reads it -- confirm and remove if unused.
    mask = {}
    for e in range(NUM_EPOCHS):
        train_start = time.time()
        train_loss = train_epoch(batches, model, optimizer, loss_fn)
        train_end = time.time()
        print('[Epoch %d] Train: Loss %.4e (%ds)' % (e, train_loss, train_end - train_start))

    # Create the output directory if needed; exist_ok avoids the
    # check-then-create race of a separate isdir() test.
    os.makedirs(WEIGHTS_PATH, exist_ok=True)
    # File name encodes the model name and the probtorch / torch versions.
    to.save(model.state_dict(),
            '%s/%s-%s-%s-enc.rar' % (WEIGHTS_PATH, MODEL_NAME, probtorch.__version__, to.__version__))

[Epoch 0] Train: Loss 9.3842e+01 (0s)
[Epoch 1] Train: Loss 7.0244e+01 (0s)
[Epoch 2] Train: Loss 5.6478e+01 (0s)
[Epoch 3] Train: Loss 5.2409e+01 (0s)
[Epoch 4] Train: Loss 5.0418e+01 (0s)
[Epoch 5] Train: Loss 4.9678e+01 (0s)
[Epoch 6] Train: Loss 4.8531e+01 (0s)
[Epoch 7] Train: Loss 4.7808e+01 (0s)
[Epoch 8] Train: Loss 4.7414e+01 (0s)
[Epoch 9] Train: Loss 4.6987e+01 (0s)
[Epoch 10] Train: Loss 4.6659e+01 (0s)
[Epoch 11] Train: Loss 4.6386e+01 (0s)
[Epoch 12] Train: Loss 4.5939e+01 (0s)
[Epoch 13] Train: Loss 4.5810e+01 (0s)
[Epoch 14] Train: Loss 4.5793e+01 (0s)
[Epoch 15] Train: Loss 4.5501e+01 (0s)
[Epoch 16] Train: Loss 4.5229e+01 (0s)
[Epoch 17] Train: Loss 4.5221e+01 (0s)
[Epoch 18] Train: Loss 4.5216e+01 (0s)
[Epoch 19] Train: Loss 4.5213e+01 (0s)
[Epoch 20] Train: Loss 4.5210e+01 (0s)
[Epoch 21] Train: Loss 4.5207e+01 (0s)
[Epoch 22] Train: Loss 4.5205e+01 (0s)
[Epoch 23] Train: Loss 4.5203e+01 (0s)
[Epoch 24] Train: Loss 4.5201e+01 (0s)
[Epoch 25] Train: Loss 4.5199e+01 (

KeyboardInterrupt: 