In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import time
import os

import numpy as np

from sklearn.feature_extraction.text import CountVectorizer

from google.colab import drive

In [3]:
drive.mount('/content/gdrive')

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 2
MIN_FREQ = 10

UNK_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
PAD_IDX = 3
SPEC_TOKENS = ['<UNK>', '<BOS>', '<EOS>', '<PAD>']

lr = 1e-3
lr_decay_every = 1000000
epochs = 5
log_interval = 10

Mounted at /content/gdrive


In [4]:
class WordDataset:
    def __init__(self):
        self.tokenizer = get_tokenizer('basic_english')

        train_dataset, test_dataset = iter(IMDB(split=('train', 'test')))
        train_dataset, test_dataset = to_map_style_dataset(train_dataset), to_map_style_dataset(test_dataset)

        self.vocab = build_vocab_from_iterator(self.build_vocab([train_dataset, test_dataset]), specials=SPEC_TOKENS)
        self.vocab.set_default_index(self.vocab['<UNK>'])
        self.vocab_length = len(self.vocab.get_itos())

        self.vectorizer = CountVectorizer(vocabulary=self.vocab.get_itos(), tokenizer=self.tokenizer)

        self.text_transform = lambda x: self.vocab(self.tokenizer(x))

        self.train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=self.vectorize_batch)
        self.test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=self.vectorize_batch)

    def build_vocab(self, datasets):
        for dataset in datasets:
            for _, text in dataset:
                yield self.tokenizer(text)

    def vectorize_batch(self, batch):
        label_list, text_list, offsets = [], [], []
        for Y, X in batch:
            print(Y, X)
            label_list.append(self.text_transform(Y))
            tmp_X = torch.tensor(self.text_transform(X), dtype=torch.int64)
            text_list.append(torch.cat([torch.tensor([BOS_IDX]), tmp_X, torch.tensor([EOS_IDX])]))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        text_list = torch.cat(text_list)
        return label_list.to(DEVICE), text_list.to(DEVICE)

    def text_translate(self, x):
        return ' '.join([self.vocab.get_itos()[i] for i in x])



In [5]:
class cVAE(nn.Module):
    def __init__(self, vocab_len, h_dim, z_dim, c_dim, p_word_dropout=0.3, max_sent_len=15):
        super(cVAE, self).__init__()

        self.MAX_SENT_LEN = max_sent_len

        self.vocab_len = vocab_len
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.p_word_dropout = p_word_dropout

        # Word embeddings layer
        self.emb_dim = h_dim
        self.word_emb = nn.Embedding(vocab_len, h_dim, PAD_IDX)

        # Encoder
        self.encoder = nn.GRU(self.emb_dim, h_dim)
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_var = nn.Linear(h_dim, z_dim)

        # Decoder
        self.decoder = nn.GRU(self.emb_dim + z_dim + c_dim, z_dim + c_dim, dropout=0.3)
        self.decoder_fc = nn.Linear(z_dim + c_dim, vocab_len)

    def forward(self, sentence, label):
        self.train()

        encoder_inputs = sentence
        decoder_inputs = sentence
        decoder_targets = torch.cat([sentence[1:], torch.tensor([PAD_IDX]).to(DEVICE)], dim=0).to(DEVICE)
        
        mu, log_var = self.encoder_fn(encoder_inputs)
        z = self.reparameterize(mu, log_var)
        
        y = self.decoder_fn(decoder_inputs, z, label)

        recon_loss = F.cross_entropy(
            y.view(-1, self.vocab_len),
            decoder_targets.view(-1),
            size_average=True
        )

        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(log_var) + mu ** 2 - 1 - log_var, 1))

        return recon_loss, kl_loss

    def encoder_fn(self, inputs):
        inputs = self.word_emb(inputs).unsqueeze(0)

        _, h = self.encoder(inputs, None)
        h = h.view(-1, self.h_dim)

        mu = self.mu(h)
        log_var = self.log_var(h)

        return mu, log_var

    def decoder_fn(self, inputs, z, label):
        decoder_inputs = self.word_dropout(inputs)

        seq_len = decoder_inputs.size(0)
        
        init_h = torch.cat([z.unsqueeze(0), torch.transpose(label, 0, 1).expand(seq_len, -1).unsqueeze(0)], dim=2)

        inputs_emb = self.word_emb(decoder_inputs).unsqueeze(0)
        inputs_emb = torch.cat([inputs_emb, init_h], dim=2).to(DEVICE)

        outputs, a = self.decoder(inputs_emb, init_h)
        _, seq_len, sentence_size = outputs.size()

        y = self.decoder_fc(outputs)

        return y

    def word_dropout(self, inputs):
        data = inputs.clone()

        mask = torch.from_numpy(np.random.binomial(1, p=self.p_word_dropout, size=tuple(data.size())).astype('bool'))

        data[mask] = UNK_IDX

        return data

    def reparameterize(self, mu, log_var):
        eps = torch.randn(self.z_dim).to(DEVICE)
        return mu + torch.exp(log_var / 2).to(DEVICE) * eps

    def sample_z(self, sentence_size):
        z = torch.randn(sentence_size, self.z_dim).to(DEVICE)
        return z

    def sample_sentence(self, z, c, temp=1):
        self.eval()

        word = torch.LongTensor([BOS_IDX]).to(DEVICE)

        z = z.view(1, 1, -1)
        c = c.view(1, 1, -1)

        h = torch.cat([z, c], dim=2).to(DEVICE)

        outputs = []

        for i in range(self.MAX_SENT_LEN):
            emb = self.word_emb(word).view(1, 1, -1)
            emb = torch.cat([emb, z, c], dim=2)

            output, h = self.decoder(emb, h)
            y = self.decoder_fc(output).view(-1)
            y = F.softmax(y / temp, dim=0)

            idx = torch.multinomial(y, 1)

            word = torch.LongTensor([int(idx)]).to(DEVICE)

            outputs.append(idx)

        self.train()

        outputs = torch.LongTensor(outputs)
        return outputs

    def generate_sentences(self, seed, label):
        s = time.time()
        z = self.word_emb(seed)
        c = torch.FloatTensor([1, 0]).to(DEVICE) if label == 'pos' else torch.FloatTensor([1, 0]).to(DEVICE)
        X_gen = self.sample_sentence(z, c)
        print('TIME: ', time.time() - s)

        return X_gen

In [6]:
dataset = WordDataset()

100%|██████████| 84.1M/84.1M [00:02<00:00, 32.0MB/s]


In [7]:
model = cVAE(
    vocab_len=dataset.vocab_length,
    h_dim=64,
    z_dim=64,
    c_dim=2
).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction='mean')
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None

  "num_layers={}".format(dropout, num_layers))


=== Training script ===

In [None]:
kld_start_inc = 3000
kld_weight = 0.01
kld_max = 0.15
kld_inc = (kld_max - kld_weight) / (epochs - kld_start_inc)

model = cVAE(
    vocab_len=dataset.vocab_length,
    h_dim=64,
    z_dim=64,
    c_dim=2
).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(reduction='mean')
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None

for e in range(epochs):
    print('Current epoch: ', e)
    for idx, (label, text) in enumerate(dataset.train_loader):
        recon_loss, kl_loss = model.forward(text, label)
        loss = recon_loss + kld_weight * kl_loss
        
        if e > kld_start_inc and kld_weight < kld_max:
            kld_weight += kld_inc
 
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()
        optimizer.zero_grad()

        if e % log_interval == 0:
            z = model.sample_z(1)

            sample_idxs = model.sample_sentence(z, label)
            sample_sent = dataset.text_translate(sample_idxs)

            print('Iter-{}\n\tLoss:\t\t{:.4f}\n\tRecon:\t\t{:.4f}\n\tKL:\t\t{:.4f}\n\tGrad_norm:\t{:.4f}'
                  .format(e, loss.data, recon_loss.data, kl_loss.data, grad_norm))

            print('Sample: "{}"\n'.format(sample_sent))

torch.save(model.state_dict(), '/content/gdrive/My Drive/DYPLOM/models/cVAE.bin')

=== Testing script ===

In [None]:
torch.manual_seed(int(time.time()))

MODEL_NAME = 'cVAE_50'
samples = 5
seed = torch.Tensor(dataset.text_transform('The')).to(torch.int64).to(DEVICE)

model = cVAE(
    vocab_len=dataset.vocab_length,
    h_dim=64,
    z_dim=64,
    c_dim=2
).to(DEVICE)

model.load_state_dict(torch.load('/content/gdrive/My Drive/DYPLOM/models/{}.bin'.format(MODEL_NAME), 
                                 map_location=lambda storage, 
                                 loc: storage))
for i in range(samples):
  print('\n=== ', i + 1, ' ===')
  pred_pos = model.generate_sentences(seed, 'pos')
  print('POS: ', dataset.text_translate(pred_pos))

  pred_neg = model.generate_sentences(seed, 'neg')
  print('NEG: ', dataset.text_translate(pred_neg))

In [None]:
!pip install torchtext==0.11.0

In [None]:
!nvidia-smi