# Train script for Advanced RNN VAE model

### imports

In [1]:
%matplotlib inline

from torch.autograd import Variable
from torch import optim
from torch.utils.data import DataLoader
import torch

import sys
sys.path.append("../Modules")
sys.path.append("../ToyDatasets")

# local imports
from train import train
from helpers import kl_loss,nll_loss,mse_loss,kl_loss_multi
from models import AdvancedRNNVAE
from timeSeries import Sinusoids

## Define training and validation data loaders

In [2]:
batch_size = 256
num_steps = 16
dataset_size = 5000
num_classes = 10
seed = 0

# Seed torch before building the loaders and the model so DataLoader shuffling
# and weight initialisation are reproducible across Restart & Run All.
# NOTE(review): if Sinusoids samples its curves with numpy or `random`, those
# RNGs need seeding as well — TODO confirm in ../ToyDatasets/timeSeries.py.
torch.manual_seed(seed)


def make_sinusoid_loader():
    """Build a shuffled DataLoader over a fresh virtual Sinusoids dataset.

    The training and validation loaders are constructed identically, so both go
    through this helper to keep their settings in sync.
    """
    dataset = Sinusoids(num_steps, virtual_size=dataset_size, quantization=num_classes)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


data_loader = make_sinusoid_loader()
valid_data_loader = make_sinusoid_loader()

# `batch_loader` is only used for the forward-pass sanity check below;
# `valid_batch_loader` is what gets handed to `train`.
batch_loader = iter(data_loader)
valid_batch_loader = iter(valid_data_loader)

## Define model

In [3]:
model = AdvancedRNNVAE(input_size=1,rnn_size=256,latent_size=64,output_size=num_classes,use_softmax=True,bidirectional=True)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters {}".format(n_trainable))

# Sanity check: push one batch through the untrained model and verify that the
# output agrees with the input on its first two dimensions.
# The transpose swaps dims 0 and 1 — presumably (batch, time, ...) ->
# (time, batch, ...) as the RNN expects; TODO confirm against AdvancedRNNVAE.
try:
    initial_batch = next(batch_loader)
    x = Variable(initial_batch).type(torch.FloatTensor).transpose(1,0)
    test,_ = model(x)
except Exception as err:
    # Report rather than raise so the notebook keeps running, but surface the
    # real error instead of hiding it (the old bare `except:` also referenced
    # `x`/`test`, which are unbound when the loader or forward pass fails).
    print("Error in forward pass: {}: {}".format(type(err).__name__, err))
else:
    if test.shape[:2] == x.shape[:2]:
        print("Forward pass succesfull")
    else:
        # Report exactly the dimensions that were compared.
        print("Error in forward pass. Output should have shape: {} but had {}".format(tuple(x.shape[:2]), tuple(test.shape[:2])))

Number of trainable parameters 878218
Forward pass succesfull


## Define optimizer and loss

In [4]:
learning_rate = 1e-4

# Adam over every model parameter, with a weight-decay (L2) penalty of 1e-3.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-3,
)

# Halve the learning rate whenever the monitored quantity stops decreasing
# (mode="min"). After each cut, wait 5 scheduler steps before watching for the
# next plateau. verbose=True prints a message on every reduction.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    cooldown=5,
    verbose=True,
)

## Train the model

In [5]:
# Run the training loop. At least 50 iterations are needed for good results.
# NOTE(review): validation is passed as an *iterator* (`valid_batch_loader`),
# not as the DataLoader itself. If `train` draws batches from it with next(), it
# may run dry after one pass over the validation set — confirm in ../Modules/train.py.
train(
    data_loader,
    dataset_size,
    valid_batch_loader,
    model,
    optimizer,
    scheduler,
    nll_loss,       # presumably the reconstruction term (r_loss in the log)
    kl_loss_multi,  # presumably the KL term (kl in the log)
    n_iters=50,
    use_softmax=True,
    print_every=1,
)


Train (1 2%) loss: 2.2408 r_loss: 2.2408 kl: 14.3251 aux_loss: 0.0000 beta 0.00e+00
Valid (1 2%) loss: 2.1608 r_loss: 2.1608 kl: 27.5486 aux_loss: 0.0000 beta 0.00e+00

Train (2 4%) loss: 2.1198 r_loss: 2.0904 kl: 36.0589 aux_loss: 0.0000 beta 8.16e-04
Valid (2 4%) loss: 2.0467 r_loss: 2.0169 kl: 36.5102 aux_loss: 0.0000 beta 8.16e-04

Train (3 6%) loss: 1.9886 r_loss: 1.9379 kl: 31.0342 aux_loss: 0.0000 beta 1.63e-03
Valid (3 6%) loss: 1.8473 r_loss: 1.8040 kl: 26.5238 aux_loss: 0.0000 beta 1.63e-03

Train (4 8%) loss: 1.7724 r_loss: 1.7077 kl: 26.4011 aux_loss: 0.0000 beta 2.45e-03
Valid (4 8%) loss: 1.6358 r_loss: 1.5632 kl: 29.6703 aux_loss: 0.0000 beta 2.45e-03


KeyboardInterrupt: 

In [None]:
# Persist the trained weights. Off by default, so Run All never overwrites a
# saved checkpoint by accident; set SAVE_MODEL = True to write the file.
# NOTE(review): the filename says "RNNVAE" although this notebook trains an
# AdvancedRNNVAE — confirm the intended name.
SAVE_MODEL = False
MODEL_PATH = "Saved_models/RNNVAE_nll_map.pt"
if SAVE_MODEL:
    import os
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)