In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch as to
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('../')
import probtorch
from probtorch.util import expand_inputs
print('probtorch:', probtorch.__version__, 
      'torch:', to.__version__, 
      'cuda:', to.cuda.is_available())
from util_data import *
from amorgibbs_v import *
from smc_v import *
from util_plots import *
from util_lstm import *

probtorch: 0.0+f9f5c9b torch: 0.4.1 cuda: True


# Tests

In [2]:
INPUT_DIM = 4
K = 10
HIDDEN_DIM = K*K
BATCH_SIZE = 20
NUM_EPOCHS = 200
DATASET_SIZE = BATCH_SIZE * 20

CUDA = to.cuda.is_available()
RESTORE = False
MODEL_NAME = 'lstm'
DATA_PATH = './data'
WEIGHTS_PATH = './lstm_weights'

LEARNING_RATE = 1e-3


# Parameters for data generation
boundary = 32
noise_cov = np.eye(2)*0.5
T_min = 30
T_max = 50
dt = 10
init_v = init_velocity(dt)

In [3]:
def kl_dirichlet(alpha1, alpha2):
    A = to.lgamma(alpha1.sum()) - to.lgamma(alpha2.sum())
    B = (to.lgamma(alpha1) - to.lgamma(alpha2)).sum()
    C = (to.mul(alpha1 - alpha2, to.digamma(alpha1) - to.digamma(alpha1.sum()))).sum()
    kl = A - B + C
    return kl

def MKL(alphas_pred, alpha_true):
    kkl = 0
    for i in range(alphas_pred.size()[0]):
        for k in range(K): 
            kkl += kl_dirichlet(alphas_pred[i, k*K:(k+1)*K], alpha_true[k*K:(k+1)*K])
            #print(alphas_pred[i, k*K:(k+1)*K], alpha_true[k*K:(k+1)*K], 'asd')
    return kkl/K

alpha_true = to.ones(K*K) / K
loss_fn = lambda alpha: MKL(alpha, alpha_true)

In [4]:
import time
import os
from random import random

Seq, Len = genSeq(T_min, T_max, dt, init_v, noise_cov, boundary, DATASET_SIZE)
batches = PackSeq(Seq, Len, BATCH_SIZE) 

model = LSTM(INPUT_DIM, HIDDEN_DIM, BATCH_SIZE)
optimizer = to.optim.Adam(model.parameters(), )


if not RESTORE:
    mask = {}
    for e in range(NUM_EPOCHS):
        train_start = time.time()
        train_loss = train_epoch(batches, model, optimizer, loss_fn)
        train_end = time.time()
        #test_start = time.time()
        #test_loss = test(test_loader, model)
        #test_end = time.time()
        print('[Epoch %d] Train: Loss %.4e (%ds)' % (e, train_loss, train_end - train_start))

    if not os.path.isdir(WEIGHTS_PATH):
        os.mkdir(WEIGHTS_PATH)
    to.save(model.state_dict(),
               '%s/%s-%s-%s-enc.rar' % (WEIGHTS_PATH, MODEL_NAME, probtorch.__version__, to.__version__))

[Epoch 0] Train: Loss 9.1755e+01 (1s)
[Epoch 1] Train: Loss 6.5567e+01 (1s)
[Epoch 2] Train: Loss 5.3313e+01 (1s)
[Epoch 3] Train: Loss 4.9926e+01 (1s)
[Epoch 4] Train: Loss 4.8061e+01 (0s)
[Epoch 5] Train: Loss 4.7370e+01 (1s)
[Epoch 6] Train: Loss 4.7104e+01 (1s)
[Epoch 7] Train: Loss 4.6994e+01 (1s)
[Epoch 8] Train: Loss 4.6502e+01 (1s)
[Epoch 9] Train: Loss 4.5967e+01 (1s)
[Epoch 10] Train: Loss 4.5323e+01 (1s)
[Epoch 11] Train: Loss 4.5276e+01 (1s)
[Epoch 12] Train: Loss 4.5251e+01 (1s)
[Epoch 13] Train: Loss 4.5165e+01 (1s)
[Epoch 14] Train: Loss 4.4775e+01 (1s)
[Epoch 15] Train: Loss 4.4682e+01 (1s)
[Epoch 16] Train: Loss 4.4672e+01 (1s)
[Epoch 17] Train: Loss 4.4666e+01 (1s)
[Epoch 18] Train: Loss 4.4661e+01 (1s)
[Epoch 19] Train: Loss 4.4657e+01 (1s)
[Epoch 20] Train: Loss 4.4653e+01 (1s)
[Epoch 21] Train: Loss 4.4650e+01 (1s)
[Epoch 22] Train: Loss 4.4647e+01 (1s)
[Epoch 23] Train: Loss 4.4645e+01 (1s)
[Epoch 24] Train: Loss 4.4643e+01 (1s)
[Epoch 25] Train: Loss 4.4641e+01 (

KeyboardInterrupt: 