In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable as V
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
import numpy as np
import matplotlib.pyplot as plt

#torch text
import torchtext.data as data
import torchtext.datasets as datasets
from torchtext.vocab import GloVe

In [None]:
# Sentence-boundary markers the TEXT field wraps around every example.
BOS_WORD = '<s>'
EOS_WORD = '</s>'

# Sentences: lower-cased, wrapped in BOS/EOS, fixed to 30 tokens.
TEXT = data.Field(
    lower=True,
    init_token=BOS_WORD,
    eos_token=EOS_WORD,
    fix_length=30,
)
# Sentiment labels are single tokens, not sequences.
LABEL = data.Field(sequential=False)

# SST train / validation / test splits.
train, val, test = datasets.SST.splits(text_field=TEXT, label_field=LABEL)

# Vocabularies come from the training split only; TEXT also receives 100-d GloVe vectors.
LABEL.build_vocab(train)
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=100))

In [None]:
# Inspect the fields of one raw training example; the printed tag now names the index actually shown.
example_idx = 9
print('vars(train[{}])'.format(example_idx), vars(train[example_idx]))

In [None]:
# GRU Encoder / Inference Network q(z | x)
class Encoder(nn.Module):
    """Inference network.

    It embeds a token sequence, applies dropout and runs a single-layer GRU.
    The final hidden state is then projected to the mean and log-variance of q(z | x).

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: embedding width (must match the pretrained vectors' width).
        hidden_size: GRU hidden size.
        latent_dim: dimensionality of z.
        pretrained_vectors: optional (vocab_size, embedding_size) tensor used to
            initialise the embeddings. Defaults to ``TEXT.vocab.vectors``.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, latent_dim, pretrained_vectors=None):
        super(Encoder, self).__init__()

        # relevant sizes
        self.vocab_size = vocab_size
        self.embed_size = embedding_size
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(0.5)
        # nn.Embedding takes (num_embeddings, embedding_dim), i.e. (vocab, width).
        self.emb_layer = nn.Embedding(self.vocab_size, self.embed_size)
        if pretrained_vectors is None:
            pretrained_vectors = TEXT.vocab.vectors
        # copy_ checks shapes, so a vocab/width mismatch fails here instead of later.
        self.emb_layer.weight.data.copy_(pretrained_vectors)
        self.enc_layer = nn.GRU(self.embed_size, self.hidden_size)

        self.mu_layer = nn.Linear(self.hidden_size, self.latent_dim)
        self.logvar_layer = nn.Linear(self.hidden_size, self.latent_dim)

    def forward(self, input_seq):
        """Encode a (seq_len, batch) LongTensor of token ids.

        Returns:
            ``(mu, logvar, emb)``.
            - mu and logvar are computed from the GRU's final hidden state,
              shape (1, batch, latent_dim).
            - emb holds the dropped-out embeddings, shape (seq_len, batch, embedding_size).
        """
        emb = self.dropout(self.emb_layer(input_seq))
        _, hidden = self.enc_layer(emb)
        mu = self.mu_layer(hidden)
        logvar = self.logvar_layer(hidden)
        return mu, logvar, emb

In [None]:
# Autoregressive GRU generative model p(x | z)
class Decoder(nn.Module):
    """Generative network that scores a token sequence conditioned on z.

    z is projected to the GRU's initial hidden state. The GRU then reads the
    (teacher-forced) input embeddings, and a linear layer yields per-token
    log-probabilities over the vocabulary.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, latent_dim):
        super(Decoder, self).__init__()
        # important sizes
        self.vocab_size = vocab_size
        self.embed_size = embedding_size
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size

        self.layer1 = nn.Linear(latent_dim, self.hidden_size)
        self.decode_layer = nn.GRU(self.embed_size, self.hidden_size)
        self.layer2 = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, decoder_input, latent):
        """Return log-probabilities of shape (seq_len, batch, vocab_size).

        Args:
            decoder_input: (seq_len, batch, embedding_size) input embeddings.
            latent: z, either (1, batch, latent_dim) or (batch, latent_dim).
        """
        hidden = self.layer1(latent)
        # The GRU needs a (num_layers, batch, hidden) initial state for batched input.
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0)

        output, _ = self.decode_layer(decoder_input, hidden)
        projection = self.layer2(output)

        return F.log_softmax(projection, dim=-1)

In [None]:
# Text VAE: q(z | x) from the encoder, a reparameterised z, then p(x | z) from the decoder.
class NormalVAE(nn.Module):
    """Combine an inference network and a generative network into a VAE.

    ``forward`` draws z with ``Normal.rsample`` (the reparameterisation trick),
    so the sample stays differentiable with respect to mu and logvar.
    """

    def __init__(self, encoder, decoder):
        super(NormalVAE, self).__init__()
        # phi: produces the variational parameters (mu, logvar) of q(z | x)
        self.encoder = encoder
        # theta: the generative model p(x | z)
        self.decoder = decoder

    def forward(self, x_src):
        """Return ``(decoder_output, mu, logvar)`` for the token batch ``x_src``."""
        mu, logvar, embedded = self.encoder(x_src)
        std = torch.exp(0.5 * logvar)
        posterior = Normal(loc=mu, scale=std)
        z = posterior.rsample()
        decoded = self.decoder(embedded, z)
        return decoded, mu, logvar

This part is slow to run on CPU, but it shows the setup for a Miao et al. (2016)-style VAE over text. The inference network is a powerful recurrent encoder (a GRU). Miao's setup pairs such an encoder with a very simple generative model that predicts the set of words in the sentence (a binary bag-of-words representation). Note that the `Decoder` defined above is instead an autoregressive GRU decoder. The aim is for the latent variable to learn something akin to a topic over the words themselves.

In [None]:
use_cuda = torch.cuda.is_available()
# Passed to BucketIterator as `device` below; -1 presumably selects the CPU in this torchtext version.
cuda_device = 0 if use_cuda else -1

if use_cuda:
    # Only select a GPU when one exists; set_device on a CPU-only machine raises.
    torch.cuda.set_device(cuda_device)
print(cuda_device)

print(torch.backends.cudnn.version())

In [None]:
# Vocabulary ids of the special tokens.
# Looked up through the same constants the TEXT field was built with, so they
# cannot drift apart. NOTE(review): torchtext's stoi presumably falls back to
# the unk id for unknown keys, so a mismatched literal would fail silently.
PAD_token = TEXT.vocab.stoi[TEXT.pad_token]
SOS_token = TEXT.vocab.stoi[BOS_WORD]
EOS_token = TEXT.vocab.stoi[EOS_WORD]
print(PAD_token, SOS_token, EOS_token)

In [None]:
def _logistic_kl_weight(step, x0, k):
    """Logistic KL-annealing weight 1 / (1 + exp(-k * (step - x0))), in (0, 1)."""
    return float(1.0 / (1.0 + np.exp(-k * (step - x0))))


def train_model(train_set, n_epochs, batch_size, criterion, optim, vae_model, enc_model, dec_model,
                anneal_x0=500, anneal_k=0.10, max_grad_norm=1.0, pad_token=None):
    """Train the text VAE with a logistically annealed KL term.

    Each step minimises ``(NLL + alpha * KL) / batch_size``. The decoder output is
    scored against the input shifted left by one token, with a PAD appended.

    Args:
        train_set: iterable of batches exposing ``.text`` (seq_len, batch) and ``.label``.
        n_epochs: number of passes over ``train_set``.
        batch_size: expected batch size; batches of any other size are skipped.
        criterion: token loss on (N, vocab) log-probs vs (N,) targets. It is
            expected to sum over tokens, so dividing by batch_size gives a
            per-sentence value.
        optim: optimiser over ``vae_model``'s parameters.
        vae_model: module returning ``(log_probs, mu, logvar)``.
        enc_model, dec_model: sub-models whose gradients are clipped separately.
        anneal_x0: step at which the KL weight reaches 0.5.
        anneal_k: steepness of the logistic KL schedule.
        max_grad_norm: gradient-norm clipping threshold applied to each sub-model.
        pad_token: id appended to the shifted target; defaults to the global ``PAD_token``.

    Returns:
        Per-step lists ``(KL, NLL, ELBO, ALPHA)``.
        - KL and NLL are per sentence.
        - ELBO is the optimised (annealed, negated) ELBO per sentence.
        - ALPHA is the KL weight used at that step.
    """
    pad_idx = PAD_token if pad_token is None else pad_token
    step = 0
    KL, NLL, ELBO, ALPHA = [], [], [], []
    vae_model.train()  # the encoder's dropout must be active while training
    for epoch in range(n_epochs):
        total_loss = 0.0
        total_kl = 0.0
        total = 0
        for batch in train_set:
            if batch.label.size(0) != batch_size:
                continue  # skip ragged batches
            vae_model.zero_grad()
            x = batch.text
            # Target = input shifted left by one, padded back to the same length.
            pad = x.new_full((1, x.size(1)), pad_idx)
            target = torch.cat((x[1:, :], pad), dim=0)

            out, mu, logvar = vae_model(x)

            # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dims.
            kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            alpha = _logistic_kl_weight(step, anneal_x0, anneal_k)

            nll_loss = criterion(out.view(-1, out.size(-1)), target.view(-1))
            loss = (nll_loss + alpha * kl) / batch_size

            KL.append(float(kl) / batch_size)
            NLL.append(float(nll_loss) / batch_size)
            ELBO.append(float(loss))
            ALPHA.append(alpha)

            # `loss` is already per sentence; do not divide by batch_size again.
            total_loss += float(loss)
            total_kl += float(kl) / batch_size
            total += 1
            step += 1

            loss.backward()
            # Clip between backward() and step(); clipping after step() has no effect.
            torch.nn.utils.clip_grad_norm_(enc_model.parameters(), max_grad_norm)
            torch.nn.utils.clip_grad_norm_(dec_model.parameters(), max_grad_norm)
            optim.step()
        denom = max(total, 1)  # guard against an epoch where every batch was skipped
        print(epoch, total_loss / denom, total_kl / denom)
    return KL, NLL, ELBO, ALPHA

In [None]:
# Hyperparameters
BATCH_SIZE = 64
WORD_DIM = 100      # must equal the GloVe width loaded into TEXT.vocab.vectors (dim=100)
HIDDEN_DIM = 256
LATENT_DIM = 32
NUM_EPOCHS = 10
learning_rate = 0.01
num_embeddings = len(TEXT.vocab)

train_iter, test_iter = data.BucketIterator.splits(
    (train, test), batch_size=BATCH_SIZE, device=cuda_device, repeat=False)

# Summed token-level NLL with PAD targets ignored. Named distinctly from the
# per-step `NLL` history list that train_model returns at the end of this cell.
nll_criterion = torch.nn.NLLLoss(size_average=False, ignore_index=PAD_token)

encoder = Encoder(num_embeddings, WORD_DIM, HIDDEN_DIM, LATENT_DIM)
decoder = Decoder(num_embeddings, WORD_DIM, HIDDEN_DIM, LATENT_DIM)
vae = NormalVAE(encoder, decoder)
# Module.cuda() moves parameters in place (including sub-modules), so
# `encoder` / `decoder` remain valid handles to the moved modules.
vae = vae.cuda() if use_cuda else vae

optim = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Standard-normal prior p(z) over one batch of latent codes.
# NOTE(review): train_model uses the closed-form KL and does not read `p`.
prior_loc = torch.zeros(BATCH_SIZE, LATENT_DIM)
prior_scale = torch.ones(BATCH_SIZE, LATENT_DIM)
if use_cuda:
    prior_loc, prior_scale = prior_loc.cuda(), prior_scale.cuda()
p = Normal(V(prior_loc), V(prior_scale))

KL, NLL, ELBO, ALPHA = train_model(train_set=train_iter, n_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
                                   criterion=nll_criterion, optim=optim, vae_model=vae,
                                   enc_model=encoder, dec_model=decoder)

# Post Training Analysis 

## Loss analysis

In [None]:
fig, ax1 = plt.subplots()

ax1.plot(range(len(KL)), KL, '-g')
ax1.set_ylabel('KL Loss')
# Label x on ax1: twinx() hides the twin's own x-axis, and plt.xlabel() would
# target the current axes (the twin), leaving the label invisible.
ax1.set_xlabel('Step')

ax2 = ax1.twinx()
ax2.plot(range(len(ALPHA)), ALPHA, '-b')
ax2.set_ylabel('KL Term Weight')

ax1.set_title('KL loss with logistic annealing schedule')
plt.show()

In [None]:
import scipy.signal

SKIP_STEPS = 5  # hide the first few very noisy steps

fig, ax1 = plt.subplots()

# x values are the true step indices, so both curves line up with other plots.
ax1.plot(range(SKIP_STEPS, len(KL)), KL[SKIP_STEPS:])
ax1.set_ylabel('KL Loss')
# twinx() hides the twin's x-axis, so the x label must live on ax1.
ax1.set_xlabel('Step')

# Savitzky-Golay needs an odd window, no longer than the series and > polyorder (2).
n_elbo = len(ELBO)
window = min(101, n_elbo if n_elbo % 2 else n_elbo - 1)
if window > 2:
    filtered_ELBO = scipy.signal.savgol_filter(ELBO, window, 2)
else:
    filtered_ELBO = np.asarray(ELBO)  # too short to smooth

ax2 = ax1.twinx()
ax2.plot(range(SKIP_STEPS, n_elbo), filtered_ELBO[SKIP_STEPS:], 'g-')
ax2.set_ylabel('ELBO')

ax1.set_title('KL loss and filtered ELBO with logistic annealing schedule')
plt.show()

## Generating Sentences

In [None]:
def inference(z, enc_model, dec_model, max_len, sos_token=None, pad_token=None):
    """Greedily decode one sentence per latent code.

    Starting from SOS, each step does three things:
    - embeds the previous token with the encoder's embedding table (the one the
      decoder was trained on);
    - advances the decoder GRU by one step;
    - takes the highest-scoring next token.
    Decoding always runs for ``max_len`` steps; it does not stop at EOS.

    Args:
        z: latent codes, shape (batch, latent_dim).
        enc_model: object exposing ``emb_layer``.
        dec_model: object exposing ``layer1``, ``decode_layer`` and ``layer2``.
        max_len: number of tokens to generate per sequence.
        sos_token: start-token id; defaults to the global ``SOS_token``.
        pad_token: initial fill value of the output buffer; defaults to ``PAD_token``.

    Returns:
        A CPU LongTensor of shape (max_len, batch) holding the generated ids.
    """
    sos = SOS_token if sos_token is None else sos_token
    pad = PAD_token if pad_token is None else pad_token
    b_size = z.size(0)
    generations = torch.full((max_len, b_size), pad, dtype=torch.long)

    with torch.no_grad():
        # (1, batch, hidden) initial GRU state derived from z
        hidden = dec_model.layer1(z).unsqueeze(0)
        # (1, batch) token ids; kept 2-D throughout so a batch of size 1 also works
        input_seq = torch.full((1, b_size), sos, dtype=torch.long, device=z.device)
        for t in range(max_len):
            input_emb = enc_model.emb_layer(input_seq)
            output, hidden = dec_model.decode_layer(input_emb, hidden)
            # log_softmax is monotone per row, so the greedy pick on raw scores is the same
            _, next_tok = dec_model.layer2(output).topk(1, dim=-1)
            input_seq = next_tok.view(1, b_size)
            generations[t] = input_seq[0].cpu()
    return generations

In [None]:
# Decode NUM_SAMPLES sentences from latent codes drawn from the prior p(z) = N(0, I).
NUM_SAMPLES = 10
m = Normal(torch.zeros(LATENT_DIM), torch.ones(LATENT_DIM))
sample = V(m.sample((NUM_SAMPLES, 1)).squeeze(1))
print(sample.shape)
if use_cuda:
    sample = sample.cuda()

gen = inference(sample, encoder, decoder, 15)
for col in range(gen.size()[1]):
    print(" ".join(TEXT.vocab.itos[tok] for tok in gen[:, col]))

In [None]:
# Interpolate between two prior samples: alpha=0 decodes z_2, alpha=1 decodes z_1.
NUM_SAMPLES = 1
m = Normal(torch.zeros(LATENT_DIM), torch.ones(LATENT_DIM))
z_1 = m.sample((NUM_SAMPLES, 1))
z_2 = m.sample((NUM_SAMPLES, 1))
alphas = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
# One row per interpolation weight; derived from the list rather than hard-coded.
z_interpol = torch.zeros(len(alphas), LATENT_DIM)
for i, alpha in enumerate(alphas):
    # z_1 / z_2 are (1, 1, LATENT_DIM); view(-1) flattens the mix into one row.
    z_interpol[i] = (alpha * z_1 + (1 - alpha) * z_2).view(-1)

sample = V(z_interpol)
sample = sample.cuda() if use_cuda else sample

gen = inference(sample, encoder, decoder, 15)
for i in range(gen.size(1)):
    print(" ".join(TEXT.vocab.itos[d] for d in gen[:, i].tolist()))

## Plotting variational means by class label

In [None]:
# Collect test-set posterior means grouped by gold label.
# `datum.label - 1` turns LABEL ids into 0-based class ids (id 0 is presumably
# LABEL's <unk>), so class id i is named LABEL.vocab.itos[i + 1].
class_names = LABEL.vocab.itos[1:]
n_classes = len(class_names)
label_x = [[] for _ in range(n_classes)]
label_y = [[] for _ in range(n_classes)]
full_arr = [[] for _ in range(n_classes)]

was_training = encoder.training
encoder.eval()  # disable the encoder's dropout so the means are deterministic
with torch.no_grad():
    for datum in test_iter:
        labels = (datum.label - 1).cpu().tolist()  # no in-place change to the batch
        _, mu, _ = encoder(datum.text)
        mu = mu.squeeze(0).cpu().numpy()  # (batch, latent_dim)
        for cls, mu_row in zip(labels, mu):
            label_x[cls].append(mu_row[0])
            label_y[cls].append(mu_row[1])
            full_arr[cls].append(mu_row)
encoder.train(was_training)  # restore whatever mode the encoder was in

fig = plt.figure()
ax1 = fig.add_subplot(111)
# Legend names come from LABEL's vocabulary rather than an assumed order.
for i, name in enumerate(class_names):
    ax1.scatter(label_x[i], label_y[i], s=5, label=name)

plt.legend(loc='upper left')
plt.show()

In [None]:
# Mean posterior location per class.
# Centroids are picked by label name rather than by position: class id i
# corresponds to LABEL.vocab.itos[i + 1] (the grouping cell subtracts 1).
# NOTE(review): torchtext orders a built vocab by frequency, so position 0/1 is
# not guaranteed to be negative/positive. The label strings 'negative' and
# 'positive' are assumed to be the SST label names; verify for this torchtext version.
class_names = LABEL.vocab.itos[1:]
print([np.mean(np.asarray(label_x[i])) for i in range(len(class_names))])
print([np.mean(np.asarray(label_y[i])) for i in range(len(class_names))])
z_neg = np.mean(np.asarray(full_arr[class_names.index('negative')]), axis=0)
print(z_neg)
z_pos = np.mean(np.asarray(full_arr[class_names.index('positive')]), axis=0)
print(z_pos)

In [None]:
# Decode along the line between class centroids: alpha=0 gives z_pos, alpha=1 gives z_neg.
alphas = [0, 0.3, 0.5, 0.7, 1.0]
# New names keep z_neg / z_pos untouched, so this cell can be re-run safely.
z_neg_t = torch.Tensor(z_neg)
z_pos_t = torch.Tensor(z_pos)
# Exactly one row per alpha, so no unused all-zero code gets decoded.
z_interpol = torch.zeros(len(alphas), LATENT_DIM)
for i, alpha in enumerate(alphas):
    z_interpol[i] = alpha * z_neg_t + (1 - alpha) * z_pos_t

sample = V(z_interpol)
sample = sample.cuda() if use_cuda else sample

gen = inference(sample, encoder, decoder, 15)
for i in range(gen.size(1)):
    print(" ".join(TEXT.vocab.itos[d] for d in gen[:, i].tolist()))

In [None]:
# Hand-advanced stream of test batches; the next cell pulls one batch at a time with next().
feed = (batch for batch in test_iter)

In [None]:
# Show the first N_SHOW examples of the next test batch together with their labels.
N_SHOW = 10
next_batch = next(feed)
print(next_batch.label[:N_SHOW])
texts = next_batch.text[:, :N_SHOW]
print(texts.shape)
# Bound the loop by `texts` itself (not a leftover variable from another cell),
# so it can never index past the sliced columns.
for i in range(texts.size(1)):
    print(" ".join(TEXT.vocab.itos[d] for d in texts[:, i].tolist()))

In [None]:
# Spot-check a single vocabulary entry by its index.
vocab_index = 5
print(TEXT.vocab.itos[vocab_index])