In [None]:
import sys

# Make the parent directory importable so the project-local modules imported
# below can be resolved. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import torch
from random import randint

In [None]:
from ptb import PTB
from linear.model import RNNVAE
from utils import transform, interpolate

In [None]:
# device configuration: use the first CUDA GPU when one is available,
# otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Penn TreeBank (PTB) dataset
# max_len is not consumed by the dataset here; it is handed to the model as
# its `time_step` argument further down.
max_len = 64
data_path = '../data'
dataset = PTB(root=data_path, split='train')
# vocabulary lookups exposed by the dataset, used when decoding index
# sequences and when building the model
idx_to_word, symbols = dataset.idx_to_word, dataset.symbols

In [None]:
# model setting: hyper-parameters forwarded to the RNNVAE constructor
# (presumably these have to match the values the checkpoint was trained
# with -- confirm against the training configuration)
embedding_size, hidden_size = 300, 256
latent_dim = 16     # passed to the model as z_dim
dropout_rate = 0.5

In [None]:
# load the trained annealing model
annealing_vae = RNNVAE(vocab_size=dataset.vocab_size,
                       embed_size=embedding_size,
                       time_step=max_len,
                       hidden_size=hidden_size,
                       z_dim=latent_dim,
                       dropout_rate=dropout_rate,
                       bos_idx=symbols['<bos>'],
                       eos_idx=symbols['<eos>'],
                       pad_idx=symbols['<pad>'])
annealing_vae_checkpoint_path = 'linear/E19.pkl'
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint that was saved from a GPU can still be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
annealing_vae_state = torch.load(annealing_vae_checkpoint_path,
                                 map_location=device)
annealing_vae.load_state_dict(annealing_vae_state)
annealing_vae = annealing_vae.to(device)
# switch to evaluation mode before running the model
annealing_vae.eval()
print("Annealing VAE loaded from %s" % annealing_vae_checkpoint_path)

In [None]:
# show interpolation function
def show(begin, end, interpolation):
    """Print an interpolation framed by its two endpoints.

    Output layout: `begin`, an 80-character dashed rule, every item that
    `transform` yields for `interpolation` (one per line), another rule,
    then `end`.

    Args:
        begin: value printed above the first rule.
        end: value printed below the second rule.
        interpolation: tensor of generated sequences; it is moved to the CPU
            and converted to a NumPy array before being passed to
            `transform`.
    """
    rule = '-' * 80
    print(begin)
    print(rule)
    # decode only after the header has been printed, mirroring the order in
    # which output is produced
    decoded = transform(interpolation.cpu().numpy(),
                        idx_to_word=idx_to_word,
                        eos_idx=symbols['<eos>'])
    print(*decoded, sep='\n')
    print(rule)
    print(end)

In [None]:
# randomly sample from data
# number of interpolation steps (the spelling of the name is kept because the
# interpolation cell refers to it)
num_sampels = 5
# random.randint includes BOTH endpoints, so the upper bound has to be the
# last valid index; randint(0, len(dataset)) could return len(dataset) and
# index one past the end of the dataset.
idx1, idx2 = randint(0, len(dataset) - 1), randint(0, len(dataset) - 1)
print("idx1 = %d and idx2 = %d" %(idx1, idx2))
# each dataset item unpacks into four fields; the third one is not needed here
enc_seq1, dec_seq1, _, len1 = dataset[idx1]
enc_seq2, dec_seq2, _, len2 = dataset[idx2]
# stack the two samples into a batch of size 2 on the target device
enc_seqs = torch.LongTensor([enc_seq1, enc_seq2]).to(device)
dec_seqs = torch.LongTensor([dec_seq1, dec_seq2]).to(device)
lens = torch.LongTensor([len1, len2]).to(device)
# NOTE(review): the <pad> index is passed as eos_idx here, presumably so the
# readable text is cut at the first padding token -- confirm against transform
begin_seq, end_seq = transform(enc_seqs.cpu().numpy(),
                               idx_to_word=idx_to_word,
                               eos_idx=symbols['<pad>'])

In [None]:
# annealing VAE latent space interpolation
# Nothing here is trained, so run with autograd disabled: no graph is built
# for the forward pass or the decoding step.
with torch.no_grad():
    # only the second value returned by the model (the latent codes) is used
    _, z, _ = annealing_vae(enc_seqs, dec_seqs, lens)
    # split the batch of two latent codes into one tensor per sentence
    seq1_z, seq2_z = torch.chunk(z.detach().cpu(), 2)
    seq1_z, seq2_z = seq1_z.squeeze().numpy(), seq2_z.squeeze().numpy()
    # latent codes between the two endpoints, moved back to the model device
    z = torch.Tensor(interpolate(seq1_z, seq2_z, num_sampels)).to(device)
    samples = annealing_vae.inference(z)
show(begin_seq, end_seq, samples)