In [1]:
import sys
sys.path.append('..')

In [2]:
import os
import time
import math
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from multiprocessing import cpu_count

In [3]:
from ptb import PTB
from model import RNNVAE

In [4]:
# device configuration: use the first CUDA GPU when one is usable, else the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# Penn TreeBank (PTB) dataset: one PTB object per split, keyed by split name
data_path = '../data'
# not handed to PTB in this cell; it is passed to the model as `time_step`
max_len = 96
splits = ['train', 'valid', 'test']
datasets = dict((name, PTB(root=data_path, split=name)) for name in splits)

In [6]:
# data loader: one DataLoader per split; only the 'train' split is shuffled
batch_size = 32
dataloaders = dict(
    (name, DataLoader(dataset=datasets[name],
                      batch_size=batch_size,
                      shuffle=(name == 'train'),
                      pin_memory=torch.cuda.is_available(),
                      num_workers=cpu_count()))
    for name in splits)
# lookup taken from the training set; indexed by '<bos>', '<eos>' and '<pad>'
# when the model and the loss are built
symbols = datasets['train'].symbols

In [7]:
# RNNVAE model: hyper-parameters
embedding_size = 300
hidden_size = 256
latent_dim = 2
dropout_rate = 0.5

# build the network from the training-set vocabulary and special-token
# indices, then move it onto the selected device in one expression
model = RNNVAE(
    vocab_size=datasets['train'].vocab_size,
    embed_size=embedding_size,
    hidden_size=hidden_size,
    z_dim=latent_dim,
    time_step=max_len,
    dropout_rate=dropout_rate,
    pad_idx=symbols['<pad>'],
    bos_idx=symbols['<bos>'],
    eos_idx=symbols['<eos>'],
).to(device)


RNNVAE(
  (encoder): LSTMEncoder(
    (embedding): Embedding(10002, 300, padding_idx=0)
    (rnn): LSTM(300, 256, batch_first=True)
    (output): Linear(in_features=512, out_features=4, bias=True)
  )
  (embedding): Embedding(10002, 300, padding_idx=0)
  (init_h): Linear(in_features=2, out_features=256, bias=True)
  (init_c): Linear(in_features=2, out_features=256, bias=True)
  (rnn): LSTM(300, 256, batch_first=True)
  (output): Linear(in_features=256, out_features=10002, bias=True)
)


In [8]:
# folder to save model checkpoints
save_path = 'vanilla'
# exist_ok=True replaces the separate exists() check: no check-then-create
# race, and re-running the cell is a no-op when the folder is already there
os.makedirs(save_path, exist_ok=True)

In [9]:
# objective function
learning_rate = 0.001
# Token-level NLL summed (not averaged) over the batch; targets equal to the
# <pad> index are ignored. reduction='sum' is the supported replacement for
# the deprecated size_average=False and yields the same value.
criterion = nn.NLLLoss(reduction='sum', ignore_index=symbols['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# negative log likelihood
def NLL(logp, target, length):
    """Apply the module-level `criterion` to flattened predictions and targets.

    `target` is cut along dim 1 to the largest value found in `length` and
    flattened to 1-D; `logp` is flattened to 2-D, keeping its last dimension.
    The result is whatever `criterion` returns for that pair.
    """
    longest = length.max().item()
    flat_target = target[:, :longest].contiguous().view(-1)
    last_dim = logp.size(-1)
    flat_logp = logp.view(-1, last_dim)
    return criterion(flat_logp, flat_target)

# KL divergence
def KL_div(mu, logvar):
    """Return -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) over all elements.

    This is the closed-form KL term between a diagonal Gaussian described by
    (`mu`, `logvar`) and a standard normal, summed into a single scalar.
    """
    elementwise = 1. + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * elementwise.sum()

In [None]:
# training setting: number of epochs, and batches between progress prints
epoch, print_every = 20, 50

In [None]:
# training interface
# Every epoch visits each split in `splits`. Only the 'train' split updates
# the weights; the other splits are scored with gradient tracking disabled.
# A checkpoint of the model's state_dict is written after every epoch.
step = 0
tracker = {'ELBO': [], 'NLL': [], 'KL': [], 'KL_weight': []}
start_time = time.time()
for ep in range(epoch):
    # learning rate decay: halve on every even epoch from epoch 10 onwards
    if ep >= 10 and ep % 2 == 0:
        learning_rate = learning_rate * 0.5
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for split in splits:
        dataloader = dataloaders[split]
        is_train = split == 'train'
        model.train(is_train)  # train mode for 'train', eval mode otherwise
        totals = {'ELBO': 0., 'NLL': 0., 'KL': 0., 'words': 0}

        for itr, (enc_inputs, dec_inputs, targets, lengths) in enumerate(dataloader):
            bsize = enc_inputs.size(0)
            enc_inputs = enc_inputs.to(device)
            dec_inputs = dec_inputs.to(device)
            targets = targets.to(device)
            lengths = lengths.to(device)

            # Forward pass and loss. Autograd is only enabled for the training
            # split, so valid/test passes do not build (and hold on to) a graph.
            with torch.set_grad_enabled(is_train):
                # forward
                logp, mu, logvar = model(enc_inputs, dec_inputs, lengths)

                # calculate loss (both terms are sums over the batch; the
                # optimised loss is their per-sample average)
                NLL_loss = NLL(logp, targets, lengths + 1)
                KL_loss = KL_div(mu, logvar)
                KL_weight = 1.0
                loss = (NLL_loss + KL_weight * KL_loss) / bsize

            # cumulate
            totals['ELBO'] += loss.item() * bsize
            totals['NLL'] += NLL_loss.item()
            totals['KL'] += KL_loss.item()
            # NOTE(review): NLL is computed on targets sliced to lengths + 1,
            # while 'words' sums plain lengths -- confirm both count the same
            # tokens, since their ratio is reported as perplexity below.
            totals['words'] += torch.sum(lengths).item()

            # backward and optimize
            if is_train:
                step += 1
                optimizer.zero_grad()
                loss.backward()
                # clip the global gradient norm at 5 before the update
                nn.utils.clip_grad_norm_(model.parameters(), 5)
                optimizer.step()

                # track (per-sample values for this batch)
                tracker['ELBO'].append(loss.item())
                tracker['NLL'].append(NLL_loss.item() / bsize)
                tracker['KL'].append(KL_loss.item() / bsize)
                tracker['KL_weight'].append(KL_weight)

                # print statistics every `print_every` batches and on the last one
                if itr % print_every == 0 or itr + 1 == len(dataloader):
                    print("%s Batch %04d/%04d, ELBO-Loss %.4f, "
                          "NLL-Loss %.4f, KL-Loss %.4f, KL-Weight %.4f"
                          % (split.upper(), itr, len(dataloader),
                             tracker['ELBO'][-1], tracker['NLL'][-1],
                             tracker['KL'][-1], tracker['KL_weight'][-1]))

        # per-split epoch summary, averaged over the number of samples
        samples = len(datasets[split])
        print("%s Epoch %02d/%02d, ELBO %.4f, NLL %.4f, KL %.4f, PPL %.4f"
              % (split.upper(), ep, epoch, totals['ELBO'] / samples,
                 totals['NLL'] / samples, totals['KL'] / samples,
                 math.exp(totals['NLL'] / totals['words'])))

    # save checkpoint
    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved at %s\n" % checkpoint_path)
end_time = time.time()
# Format the elapsed time by hand: strftime("%H") on gmtime() wraps around
# after 24 hours, which would under-report long runs.
elapsed_hr, elapsed_rem = divmod(int(end_time - start_time), 3600)
elapsed_min, elapsed_sec = divmod(elapsed_rem, 60)
print('Total cost time',
      "%02d hr %02d min %02d sec" % (elapsed_hr, elapsed_min, elapsed_sec))

TRAIN Batch 0000/1315, ELBO-Loss 176.2793, NLL-Loss 176.2487, KL-Loss 0.0306, KL-Weight 1.0000
TRAIN Batch 0050/1315, ELBO-Loss 137.3026, NLL-Loss 137.2905, KL-Loss 0.0121, KL-Weight 1.0000
TRAIN Batch 0100/1315, ELBO-Loss 131.1157, NLL-Loss 131.1116, KL-Loss 0.0041, KL-Weight 1.0000
TRAIN Batch 0150/1315, ELBO-Loss 145.3204, NLL-Loss 145.3178, KL-Loss 0.0026, KL-Weight 1.0000
TRAIN Batch 0200/1315, ELBO-Loss 130.7767, NLL-Loss 130.7741, KL-Loss 0.0026, KL-Weight 1.0000
TRAIN Batch 0250/1315, ELBO-Loss 142.1345, NLL-Loss 142.1327, KL-Loss 0.0019, KL-Weight 1.0000
TRAIN Batch 0300/1315, ELBO-Loss 133.4554, NLL-Loss 133.4541, KL-Loss 0.0013, KL-Weight 1.0000
TRAIN Batch 0350/1315, ELBO-Loss 139.0087, NLL-Loss 139.0070, KL-Loss 0.0017, KL-Weight 1.0000
TRAIN Batch 0400/1315, ELBO-Loss 132.7210, NLL-Loss 132.7196, KL-Loss 0.0015, KL-Weight 1.0000
TRAIN Batch 0450/1315, ELBO-Loss 131.0713, NLL-Loss 131.0698, KL-Loss 0.0016, KL-Weight 1.0000
TRAIN Batch 0500/1315, ELBO-Loss 123.7692, NLL-Los

TRAIN Batch 1300/1315, ELBO-Loss 111.3552, NLL-Loss 111.3546, KL-Loss 0.0007, KL-Weight 1.0000
TRAIN Batch 1314/1315, ELBO-Loss 110.6116, NLL-Loss 110.6113, KL-Loss 0.0003, KL-Weight 1.0000
TRAIN Epoch 02/20, ELBO 111.8283, NLL 111.8276, KL 0.0007, PPL 200.4506
VALID Epoch 02/20, ELBO 111.5242, NLL 111.5237, KL 0.0005, PPL 208.3713
TEST Epoch 02/20, ELBO 110.3886, NLL 110.3881, KL 0.0005, PPL 195.8649
Model saved at vanilla/E02.pkl

TRAIN Batch 0000/1315, ELBO-Loss 108.3616, NLL-Loss 108.3612, KL-Loss 0.0004, KL-Weight 1.0000
TRAIN Batch 0050/1315, ELBO-Loss 112.6370, NLL-Loss 112.6361, KL-Loss 0.0009, KL-Weight 1.0000
TRAIN Batch 0100/1315, ELBO-Loss 116.0280, NLL-Loss 116.0275, KL-Loss 0.0005, KL-Weight 1.0000
TRAIN Batch 0150/1315, ELBO-Loss 114.9816, NLL-Loss 114.9812, KL-Loss 0.0004, KL-Weight 1.0000
TRAIN Batch 0200/1315, ELBO-Loss 99.7936, NLL-Loss 99.7930, KL-Loss 0.0006, KL-Weight 1.0000
TRAIN Batch 0250/1315, ELBO-Loss 106.5426, NLL-Loss 106.5420, KL-Loss 0.0006, KL-Weight 1.

TRAIN Batch 1050/1315, ELBO-Loss 98.9857, NLL-Loss 98.9848, KL-Loss 0.0009, KL-Weight 1.0000
TRAIN Batch 1100/1315, ELBO-Loss 95.9790, NLL-Loss 95.9777, KL-Loss 0.0013, KL-Weight 1.0000
TRAIN Batch 1150/1315, ELBO-Loss 110.1893, NLL-Loss 110.1884, KL-Loss 0.0009, KL-Weight 1.0000
TRAIN Batch 1200/1315, ELBO-Loss 117.2894, NLL-Loss 117.2880, KL-Loss 0.0013, KL-Weight 1.0000
TRAIN Batch 1250/1315, ELBO-Loss 102.8723, NLL-Loss 102.8713, KL-Loss 0.0010, KL-Weight 1.0000
TRAIN Batch 1300/1315, ELBO-Loss 90.3957, NLL-Loss 90.3949, KL-Loss 0.0008, KL-Weight 1.0000
TRAIN Batch 1314/1315, ELBO-Loss 115.8133, NLL-Loss 115.8128, KL-Loss 0.0006, KL-Weight 1.0000
TRAIN Epoch 05/20, ELBO 104.1027, NLL 104.1016, KL 0.0011, PPL 138.9837
VALID Epoch 05/20, ELBO 108.8868, NLL 108.8859, KL 0.0009, PPL 183.6503
TEST Epoch 05/20, ELBO 107.5780, NLL 107.5771, KL 0.0009, PPL 171.2349
Model saved at vanilla/E05.pkl

TRAIN Batch 0000/1315, ELBO-Loss 112.3991, NLL-Loss 112.3984, KL-Loss 0.0007, KL-Weight 1.0000

In [None]:
# plot KL curve: KL weight on the left axis (blue) and the tracked KL value
# on a twin right axis (red), both against the training step
fig, ax1 = plt.subplots()
weight_lines = ax1.plot(tracker['KL_weight'], 'b', label='KL term weight')
ax1.set_ylim([-0.05, 1.05])
ax1.set_xlabel('Step')
ax1.set_ylabel('KL term weight')

ax2 = ax1.twinx()
value_lines = ax2.plot(tracker['KL'], 'r', label='KL term value')
ax2.set_ylabel('KL term value')

# a single legend above the plot that covers the lines of both axes
handles = weight_lines + value_lines
ax1.legend(handles,
           [handle.get_label() for handle in handles],
           bbox_to_anchor=(0., 1.02, 1., .102),
           ncol=2,
           mode="expand",
           borderaxespad=0.)
plt.show()

In [None]:
# latent space visualization
# Run the test split through the model, collect the `mu` output of every
# sample, project the collected rows with t-SNE and scatter-plot the result.
features = np.empty([len(datasets['test']), latent_dim])
# Set eval mode explicitly instead of relying on whatever mode the training
# cell happened to leave the model in.
model.eval()
offset = 0
with torch.no_grad():  # inference only: no autograd graph is needed
    for enc_inputs, dec_inputs, _, lengths in dataloaders['test']:
        enc_inputs = enc_inputs.to(device)
        dec_inputs = dec_inputs.to(device)
        lengths = lengths.to(device)
        _, mu, _ = model(enc_inputs, dec_inputs, lengths)
        # Place rows by the batch's real size rather than assuming every
        # batch holds exactly `batch_size` samples.
        rows = mu.size(0)
        features[offset:offset + rows] = mu.cpu().numpy()
        offset += rows
tsne_z = TSNE(n_components=2).fit_transform(features)
# keep the embedding with the other tracked results
tracker['z'] = tsne_z

plt.figure()
plt.scatter(tsne_z[:, 0], tsne_z[:, 1], s=25, alpha=0.5)
plt.show()

In [None]:
# save learning results: write the whole `tracker` dict to a MATLAB .mat file
sio.savemat(file_name="vanilla.mat", mdict=tracker)