# Sample script for Ladder VAE models

### imports

In [1]:
%matplotlib inline

from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append("../../Modules")
sys.path.append("../../Datasets")

# local imports
from visualize import printText
from models import LadderVAE
from babelDatasets.sentiment140 import Sentiment140
from babelDatasets.utils import padding_merge

## Define dataset loader

In [2]:
# Seed torch's global RNG so the shuffled batch order and the latent codes drawn
# in sampleCompare are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 32

# Test split only (train=False); max_sentence_size caps sentence length.
dataset = Sentiment140(data_directory="../../Datasets/Data", train=False, max_sentence_size=32)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=padding_merge)

# Output vocabulary size, used as LadderVAE's output_size when building the models.
# NOTE(review): presumably the number of symbols known to dataset.encoderDecoder —
# TODO confirm and derive it from the dataset instead of hard-coding it.
num_classes = 82

batch_loader = iter(data_loader)
# NOTE(review): sampleCompare turns the first element of each batch into the model
# input via unsqueeze(2).transpose(1, 0), i.e. [batch_size, seq_len, 1] -> seq-first;
# this assumes padding_merge yields [batch_size, seq_len] — TODO confirm.

# N sents: 52990  train: False  sentences_path: ../../Datasets/Data/test_sentences.txt


## Define and load models

In [3]:
def load_ladder_vae(hidden_sizes, latent_sizes, checkpoint_path, recon_hidden_size=256):
    """Build a LadderVAE, load its saved weights and switch it to eval mode.

    Args:
        hidden_sizes: per-level hidden sizes passed to LadderVAE.
        latent_sizes: per-level latent sizes passed to LadderVAE.
        checkpoint_path: path to a saved state_dict for that architecture.
        recon_hidden_size: reconstruction hidden size (both checkpoints use 256).

    Returns:
        The loaded model, in eval mode.
    """
    vae = LadderVAE(input_size=1, hidden_sizes=hidden_sizes, latent_sizes=latent_sizes,
                    recon_hidden_size=recon_hidden_size, output_size=num_classes,
                    use_softmax=True)
    # Nothing in this notebook moves tensors to a GPU, so remap every storage to the
    # CPU; otherwise a checkpoint saved from CUDA tensors fails to load here.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    vae.load_state_dict(state_dict)
    vae.eval()
    return vae


# 3-level model.
model = load_ladder_vae([512, 256, 128], [64, 32, 16],
                        "../../Saved_models/Sentiment140/LVAE_nll_map.pt")

# 5-level model.
model_big = load_ladder_vae([512, 256, 128, 64, 32], [128, 64, 32, 16, 8],
                            "../../Saved_models/Sentiment140/LVAE_nll_map_5l.pt")

## Collect some samples and print them next to real sentences

In [12]:
def sampleCompare(batch_loader, model, data_decoder, latent_scale=2, loader=None):
    """Print real sentences from one batch next to sentences sampled from `model`.

    The real batch only fixes the sequence length and batch size. The samples are
    decoded from random latent codes, not reconstructed from those sentences.

    Args:
        batch_loader: iterator whose items have the encoded sentences as their first
            element (presumably [batch_size, seq_len] — TODO confirm).
        model: object exposing `sample_size` and `sample(z)`.
        data_decoder: decoder handed to `printText` to turn indices back into text.
        latent_scale: multiplier on the standard-normal latent codes. The default of
            2 keeps the original behaviour (z drawn with standard deviation 2).
        loader: iterable used to get a fresh batch when `batch_loader` is exhausted.
            Defaults to the notebook-level `data_loader`.
    """
    try:
        batch = next(batch_loader)[0]
    except StopIteration:
        # A fresh iterator is created locally only; the caller's `batch_loader` stays
        # exhausted, so later calls will come through here again. Pass a new iterator
        # (e.g. iter(data_loader)) to avoid that.
        restart_source = data_loader if loader is None else loader
        batch = next(iter(restart_source))[0]
    x = Variable(torch.FloatTensor(batch)).unsqueeze(2).transpose(1, 0)
    # Latent codes share x's first two dims (seq-first after the transpose).
    size = (x.size(0), x.size(1), model.sample_size)
    z = latent_scale * Variable(torch.FloatTensor(*size).normal_(), requires_grad=False)
    recon_x = model.sample(z)
    # Greedy decoding: keep the top-scoring class index at every position.
    _, topi = recon_x.data.topk(1)
    pred_x = Variable(topi)
    sampled_data = {"x": x.squeeze(2), "recon_x": pred_x.squeeze(2)}
    printText(sampled_data, data_decoder, recon_x_text="Sample")

In [13]:
# Real test sentences ("True") next to samples drawn from the 3-level model.
sampleCompare(batch_loader, model=model, data_decoder=dataset.encoderDecoder)


True:
 @user thanks!******************
Sample:
  yu d Iyhryi *ro*it ***y*is'*y?

True:
 @user Morning hun**************
Sample:
 wh T *yyy m ****xu*****lrr*****

True:
 @user how precious! congrats***
Sample:
 *s ww yo ii t   *w *yat ae *oeo

True:
 @user Happy Birthday Alice!****
Sample:
 syayy  *iu lyn ***y*w*w*yay*y**

True:
 LOVED Up!! especially in 3D!***
Sample:
 w. ievsad a ry f **sr*y***ye***

True:
 @user you're silly*************
Sample:
 @x  Ce s n slr o  wy       uavy

True:
 Watching the mtv movie awards**
Sample:
 *yyo w y y h  sayek worour *yhe

True:
 @user looking forward to it****
Sample:
 I Au gr woi     forrt  t wlr ly

True:
 up and about!******************
Sample:
 siy iseyy**i y l sw ***y*******

True:
 says i had fun last night  @url
Sample:
 ww i  syyc wwor  @o yiy il ya y

True:
 I have candy*******************
Sample:
 **wii y ryi ******t **tl yu Iui


In [14]:
# Same comparison, this time sampling from the 5-level model.
sampleCompare(batch_loader, model=model_big, data_decoder=dataset.encoderDecoder)


True:
 Back from beach, good day******
Sample:
 *h Tc Twpa*A*T**T*TOT**o *iO***

True:
 @user South America loves you**
Sample:
 *ia********uu*a *aeea*****Iayar

True:
 @user of course****************
Sample:
 Itm   euai  t   T  ma  @ ro   a

True:
 The nyc skyline is unreal******
Sample:
 *b  Te *I*T****h*hOOy*******T**

True:
 Damn. You are cool*************
Sample:
      h c T*TET I **hs*A*HO*Te**

True:
 Apprentice and BB night tonight
Sample:
 M   aO*I Ta **i**hn*T wa*T **Tw

True:
 @user how short****************
Sample:
 TeTT *D*Ieh T**T Ntckan*areS*I 

True:
 thank god. itunes is now sorted
Sample:
  tc shato***I'th *ae*Te*****a**

True:
 @user what did you think*******
Sample:
 *ie  **T hre*w *To wayehek w **

True:
 Watching my favorite show @user
Sample:
 TrPvi.*I a********T******TO*Hi 

True:
 @user agreed*******************
Sample:
 **s ****uir****Tus*T*Toa****TTy
