In [4]:
import os

import torch

from VGDLData.ptb import PTB
from loss import VAE_Loss
from model import LSTM_VAE
from train import Trainer

from settings import global_setting, model_setting, training_setting

from utils import  interpolate, plot_elbo, get_latent_codes, visualize_latent_codes

In [9]:
# Pick the compute device and seed torch's global RNG before anything
# random (model initialisation, loader shuffling) happens.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(global_setting["seed"])

# Hyperparameters, pulled out of the settings dictionaries once so the
# rest of the cell can refer to them by name.
batch_size = training_setting["batch_size"]
bptt = training_setting["bptt"]
lr = training_setting["lr"]
embed_size = model_setting["embed_size"]
hidden_size = model_setting["hidden_size"]
latent_size = model_setting["latent_size"]

# Load the data: one PTB dataset per split, all read from the same
# directory with sequences capped at `bptt` tokens.
train_data = PTB(
    data_dir="./VGDLDataGeneralized",
    split="train",
    create_data=False,
    max_sequence_length=bptt,
)
test_data = PTB(
    data_dir="./VGDLDataGeneralized",
    split="test",
    create_data=False,
    max_sequence_length=bptt,
)
valid_data = PTB(
    data_dir="./VGDLDataGeneralized",
    split="valid",
    create_data=False,
    max_sequence_length=bptt,
)

# Batchify the data. Every loader, including the evaluation ones, is
# created with shuffle=True.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data, batch_size=batch_size, shuffle=True
)
# NOTE(review): valid_loader is built here but is not handed to the
# Trainer below; only the train and test loaders are.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_data, batch_size=batch_size, shuffle=True
)

# The model's vocabulary size comes from the training split.
vocab_size = train_data.vocab_size
model = LSTM_VAE(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_size,
    latent_size=latent_size,
).to(device)

# Loss function and optimizer (Adam over all model parameters).
Loss = VAE_Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

trainer = Trainer(train_loader, test_loader, model, Loss, optimizer)

In [14]:
# Train for the configured number of epochs, evaluating on the test
# loader after every epoch, saving periodic checkpoints, and finally
# plotting both loss histories and saving the final weights.

# An intermediate checkpoint is written whenever the epoch index is a
# multiple of this value (epoch 0 included).
CHECKPOINT_INTERVAL = 50

# torch.save does not create missing parent directories, and the first
# checkpoint is attempted right after epoch 0 -- make sure the target
# directory exists so a fresh checkout does not fail at that point.
os.makedirs("models", exist_ok=True)

train_losses = []
test_losses = []
for epoch in range(training_setting["epochs"]):
    print("Epoch: ", epoch)
    print("Training.......")
    # The running loss list is passed in and replaced by the return value
    # (presumably the same history extended with this epoch -- confirm in train.py).
    train_losses = trainer.train(train_losses, epoch, training_setting["batch_size"], training_setting["clip"])
    print("Testing.......")
    test_losses = trainer.test(test_losses, epoch, training_setting["batch_size"])
    if epoch % CHECKPOINT_INTERVAL == 0:
        torch.save(model.state_dict(), "models/VGDL_VAE_GENERALIZED_" + str(epoch) + ".pt")


plot_elbo(train_losses, "train")
plot_elbo(test_losses, "test")

# Final weights after the last epoch.
torch.save(model.state_dict(), "models/VGDL_VAE_GENERALIZED.pt")


Epoch:  0
Training.......
| epoch   0 | elbo_loss 3.976975 | kl_loss 0.000194 | recons_loss 3.976781 
| epoch   0 | elbo_loss 4.357484 | kl_loss 0.000216 | recons_loss 4.357268 
| epoch   0 | elbo_loss 3.994875 | kl_loss 0.000182 | recons_loss 3.994692 
| epoch   0 | elbo_loss 4.106413 | kl_loss 0.000206 | recons_loss 4.106207 
Testing.......
Epoch:  1
Training.......
| epoch   1 | elbo_loss 3.908797 | kl_loss 0.000529 | recons_loss 3.908268 
| epoch   1 | elbo_loss 4.095924 | kl_loss 0.000175 | recons_loss 4.095749 
| epoch   1 | elbo_loss 4.379000 | kl_loss 0.000228 | recons_loss 4.378771 
| epoch   1 | elbo_loss 4.214752 | kl_loss 0.000206 | recons_loss 4.214546 
Testing.......
Epoch:  2
Training.......
| epoch   2 | elbo_loss 4.035322 | kl_loss 0.000152 | recons_loss 4.035170 
| epoch   2 | elbo_loss 3.975257 | kl_loss 0.000438 | recons_loss 3.974818 
| epoch   2 | elbo_loss 4.343894 | kl_loss 0.000173 | recons_loss 4.343722 
| epoch   2 | elbo_loss 4.006879 | kl_loss 0.000146 | re

: 

: 