# Train script for Advanced Hybrid VAE model

## Imports

In [1]:
%matplotlib inline

from torch.autograd import Variable
from torch import optim
from torch.utils.data import DataLoader
import torch

import sys
sys.path.append("../Modules")
sys.path.append("../ToyDatasets")

# local imports
from train import train
from helpers import kl_loss,nll_loss,mse_loss,kl_loss_multi
from models import AdvancedHybridVAE
from timeSeries import Sinusoids

## Define dataset loader

In [2]:
# Experiment-wide settings shared by the training and validation pipelines.
batch_size = 256       # samples per mini-batch
num_steps = 16         # first positional argument handed to Sinusoids
dataset_size = 5000    # passed to Sinusoids as virtual_size
num_classes = 10       # passed to Sinusoids as quantization; also reused as the model's output_size

# Training split: a freshly built Sinusoids dataset wrapped in a shuffling loader.
train_dataset = Sinusoids(
    num_steps,
    virtual_size=dataset_size,
    quantization=num_classes,
)
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation split: a second, independently constructed dataset with the same settings.
valid_dataset = Sinusoids(
    num_steps,
    virtual_size=dataset_size,
    quantization=num_classes,
)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

# Explicit iterators so that single batches can be pulled on demand.
batch_loader = iter(data_loader)
valid_batch_loader = iter(valid_data_loader)

## Define model

In [3]:
# Build the model. output_size is tied to num_classes, the same value that is
# passed to the datasets as their quantization setting.
model = AdvancedHybridVAE(input_size=1,conv_size=128,rnn_size=128,latent_size=32,output_size=num_classes,use_softmax=True)
print("Number of trainable parameters {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# test forward pass
# Use the builtin next(): iterators are not guaranteed to expose a .next()
# method (that is the Python 2 protocol), whereas next() works everywhere.
initial_batch = next(batch_loader)
# Cast to float and swap the first two axes before feeding the model.
# NOTE(review): presumably (batch, steps) -> (steps, batch) — confirm against Sinusoids.
x = Variable(initial_batch).type(torch.FloatTensor).transpose(1,0)
# The output must keep the first two input dimensions and end in num_classes.
expected_shape = (x.shape[0], x.shape[1], num_classes)

try:
    test,_,_ = model(x)
except Exception as err:
    # Report the real failure. The handler must not touch `test`, which is
    # unbound when the model call itself raised.
    print("Error in forward pass: {!r}".format(err))
else:
    # Only the first three dimensions are compared, as in the original check.
    if tuple(test.shape[:3]) == expected_shape:
        print("Forward pass succesfull")
    else:
        print("Error in forward pass. Output should have shape: {} but had {}".format(expected_shape,tuple(test.shape)))

Number of trainable parameters 269396
Forward pass succesfull


## Define optimizer and loss

In [4]:
# Adam over every model parameter, with weight decay applied by the optimizer.
learning_rate = 1e-4
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-3,
)

# Plateau scheduler in "min" mode: multiplies the learning rate by 0.5 when the
# monitored value stops decreasing, then holds off for 5 steps (cooldown)
# before it resumes monitoring. verbose=True logs each reduction.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    cooldown=5,
    verbose=True,
)

## Run trainer

In [5]:
# use at least 50 iterations to get good results
# NOTE(review): the training argument is the DataLoader itself, while the
# validation argument is an iterator created with iter(); confirm that this
# asymmetry is what train() expects.
train(
    data_loader,
    dataset_size,
    valid_batch_loader,
    model,
    optimizer,
    scheduler,
    nll_loss,       # passed positionally after the scheduler
    kl_loss_multi,  # passed positionally after nll_loss
    n_iters=50,
    use_softmax=True,
    print_every=1,
)


Train (1 2%) loss: 2.9225 r_loss: 2.2511 kl: 81.5021 aux_loss: 2.2380 beta 0.00e+00
Valid (1 2%) loss: 2.8605 r_loss: 2.2129 kl: 227.8517 aux_loss: 2.1588 beta 0.00e+00

Train (2 4%) loss: 2.9080 r_loss: 2.2010 kl: 90.7534 aux_loss: 2.1097 beta 8.16e-04
Valid (2 4%) loss: 2.8276 r_loss: 2.1951 kl: 13.6367 aux_loss: 2.0710 beta 8.16e-04

Train (3 6%) loss: 2.7886 r_loss: 2.1870 kl: 14.1670 aux_loss: 1.9282 beta 1.63e-03
Valid (3 6%) loss: 2.6775 r_loss: 2.1317 kl: 18.7978 aux_loss: 1.7168 beta 1.63e-03

Train (4 8%) loss: 2.6308 r_loss: 2.0763 kl: 19.6007 aux_loss: 1.6883 beta 2.45e-03
Valid (4 8%) loss: 2.4770 r_loss: 1.9447 kl: 20.7031 aux_loss: 1.6055 beta 2.45e-03


KeyboardInterrupt: 

In [None]:
#torch.save(model.state_dict(),"Saved_models/HybridVAE_nll_map.pt")