## MPCS 53113
## Final Project
## Conditioned Text Generation
## By: Andrew Comstock

## Project layout:
- 1. Create the Environment
- 2. Define the System
- 3. Train the System
- 4. Test the System

## 1. Create the Environment

In [39]:
import torch
import torch.utils.data as tud
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import Counter, defaultdict
from torchtext import data
from torchtext import datasets
import operator
import os, math
import numpy as np
import random
import copy
from torchtext.vocab import Vectors
from itertools import chain

# Set up the training environment
BATCH_SIZE = 32
MAX_VOCAB_SIZE = 20000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch and make cuDNN deterministic so runs are reproducible.
SEED = 9432
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Dataset used to train the generator and the discriminator. Text is
# tokenized with spaCy, lower-cased and wrapped in <init> ... <eos>;
# labels are stored as floats.
TEXT = data.Field(tokenize = 'spacy', lower=True, eos_token="<eos>", init_token="<init>")
LABEL = data.LabelField(dtype = torch.float)

disc_train_data, disc_test_data = datasets.IMDB.splits(TEXT, LABEL)

# random.seed() returns None, so passing its result as random_state only
# worked through the side effect of seeding the global RNG. Seed first and
# hand the resulting state to split() explicitly.
random.seed(SEED)
disc_train_data, disc_valid_data = disc_train_data.split(random_state = random.getstate())

TEXT.build_vocab(disc_train_data, max_size = MAX_VOCAB_SIZE)
LABEL.build_vocab(disc_train_data)
VOCAB_SIZE_DISC = len(TEXT.vocab)

# The generator and the discriminator share the same vocabulary.
VOCAB_SIZE = len(TEXT.vocab)

# Wrap the train / valid / test splits in batch iterators
disc_train_iterator, disc_valid_iterator, disc_test_iterator = data.BucketIterator.splits(
    (disc_train_data, disc_valid_data, disc_test_data), 
    batch_size = BATCH_SIZE, device = device)

# setup model parameters
INPUT_DIM = VOCAB_SIZE
EMBEDDING_DIM = 50
CONTEXT_DIM = 1
LATENT_DIM = 30
HIDDEN_DIM = 100
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
USE_CUDA = torch.cuda.is_available()
VAE_PATH = 'VAE-model.pt'

Setting up Vocab
Setting up Iterators
Done with setup!


In [0]:
# Define several commonly used functions
def lookup_indexes(x, vocab=TEXT.vocab):
    '''
    Map every token in ``x`` to its index in ``vocab``.

    x: iterable of tokens (strings)
    vocab: vocabulary object exposing ``stoi``. When it is None, the items
           of ``x`` are returned unchanged (as a new list). Handling of
           tokens missing from the vocabulary is whatever ``vocab.stoi``
           does on lookup.

    returns: list of integer indexes (or the original items if vocab is None)
    '''
    if vocab is None:
        return list(x)
    return [vocab.stoi[token] for token in x]
    
def lookup_words(x, vocab=TEXT.vocab):
    '''
    Map every integer index in ``x`` back to its token string.

    x: iterable of integer indexes
    vocab: vocabulary object exposing ``itos``. When it is None, the items
           of ``x`` are only converted with ``str``.

    returns: list of strings
    '''
    words = x if vocab is None else (vocab.itos[index] for index in x)
    return list(map(str, words))

def makeStarter(s, device, random=False, starter=True):
    '''
    Build a column tensor of vocabulary indexes in the layout the Encoder,
    Decoder and Generator consume: shape (sequence_length, 1), dtype long.

    s: the text to encode when ``random`` is False. Tokens are separated by
       whitespace (runs of whitespace count as one separator), so punctuation
       must already be split off: "Great!" should be written "Great !".
       When ``random`` is True, ``s`` is instead the number of random
       vocabulary indexes to generate.
    device: torch device the result is placed on (CPU or CUDA).
    random: if True, fill the tensor with ``s`` uniformly random indexes in
            [0, INPUT_DIM) instead of encoding a string.
    starter: if True, the first element is the "<init>" start-of-sentence index.

    returns: long tensor of shape (n, 1) on ``device``.
    '''
    ids = lookup_indexes(["<init>"]) if starter else []
    if random:
        # NOTE: the ``random`` parameter shadows the stdlib module inside this
        # function; only torch's RNG is used here.
        ids += torch.randint(INPUT_DIM, (s,)).tolist()
    else:
        # split() without an argument ignores repeated, leading and trailing
        # whitespace; split(" ") produced empty tokens that mapped to <unk>.
        ids += lookup_indexes([word.lower() for word in s.split()])
    # Build the tensor once instead of concatenating one row per token.
    return torch.tensor(ids, dtype=torch.long, device=device).reshape(-1, 1)

def makeTarget(target, device, vocab_len, pad_idx=1, max_len=None):
    '''
    Turn a batch of token indexes into the next-token targets used to train
    the VAE: every row moves up by one position along dim 0 and a row of
    ``pad_idx`` is appended, so target[t] == input[t + 1].

    target: 2-D tensor of token indexes (dim 0 is shifted, dim 1 is the batch)
    device: torch device for the padding row
    vocab_len: accepted for interface compatibility; not used
    pad_idx: vocabulary index of the pad token
    max_len: accepted for interface compatibility; not used

    returns: long tensor with the same shape as ``target``.
    '''
    n_cols = target.shape[1]
    pad_row = torch.full((1, n_cols), pad_idx, dtype=torch.long, device=device)
    shifted = torch.cat((target[1:], pad_row), dim=0)
    return shifted.long()

## 2. Define the System

Encoder module:
    Used both as the VAE encoder and as the discriminator.

In [0]:
class RNNBinaryEncoder(nn.Module):
    '''
    RNNBinaryEncoder is an encoder module that produces context and latent
    vectors for encoding and sentiment analysis.

    This module is used for both the Discriminator and the Encoder.

    For Discriminator operation, the latent dimension is set to zero.
    '''
    def __init__(self, output_dim, embedding_dim, hidden_dim, context_dim, latent_dim, device):
        super().__init__()
        
        self.output_dim = output_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.latent_dim = latent_dim        
        
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        
        # Linear heads for the context, the latent mean and the latent log-variance
        self.fcout = nn.Linear(hidden_dim, context_dim)
        self.fcmu = nn.Linear(hidden_dim, latent_dim)
        self.fcsig = nn.Linear(hidden_dim, latent_dim)
        self.device = device
        
    def forward(self, text, usedAsDiscriminator=False):
        '''
        Encode a batch of token indexes.

        text: long tensor of token indexes; dim 0 is averaged away, so it is
              presumably the sequence dimension, i.e. (seq_len, batch)
        usedAsDiscriminator: if True, threshold the context at 0.5 so every
              entry becomes 0 or 1, and skip the latent computation.

        returns: (context, latent, mu, sig). latent, mu and sig are None when
                 usedAsDiscriminator is True.
        '''
        # Embed every token and average over dim 0, so the RNN runs over a
        # single pooled step per batch element.
        embedded = self.embedding(text)
        meaned = torch.mean(embedded, dim=0, keepdim=True)
        
        _, hidden = self.rnn(meaned)
        hidden = hidden.squeeze(0)
        
        context = self.fcout(hidden)
        latent = None
        mu = None
        sig = None
        
        if usedAsDiscriminator:
            # Binarise the output: every entry becomes 0 or 1.
            context[context < 0.5] = 0
            context[context >= 0.5] = 1
        else:
            # abs() forces both mu and sig to be non-negative.
            mu = torch.abs(self.fcmu(hidden))
            sig = torch.abs(self.fcsig(hidden))
            latent = self.make_latent(mu, sig)

        # mu and sig are returned for computing the KL loss
        return context, latent, mu, sig
    
    def make_latent(self, mu, sig):
        '''
        Reparameterised sample: |mu + exp(sig / 2) * eps| with eps ~ N(0, I).

        sig is used as a log-variance, so exp(sig / 2) is the standard
        deviation. eps is drawn independently for every element of mu, so
        each sample in the batch gets its own noise rather than sharing a
        single noise vector.
        '''
        gauss = torch.randn_like(mu)
        return torch.abs(mu + torch.exp(sig / 2) * gauss)

Decoder module: used by the generator (the VAE decoder).

In [0]:
class GRUDecoder(nn.Module):
    '''
    GRU decoder for the VAE.

    At every time step the token embedding is concatenated with the
    (context, latent) conditioning vector before entering the GRU, and the
    GRU output is projected to vocabulary logits.
    '''
    def __init__(self, output_dim, embedding_dim, context_dim, latent_dim, hidden_dim, num_layers=1):
        super().__init__()

        # Hyper-parameters, kept for introspection.
        self.output_dim = output_dim
        self.embedding_dim = embedding_dim
        self.context_dim = context_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Sub-modules (created in this order: embedding, GRU, projection).
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        gru_input_size = context_dim + latent_dim + embedding_dim
        self.rnn = nn.GRU(gru_input_size, hidden_dim, num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input, context, latent, hidden=None):
        '''
        input: long tensor of token indexes; dim 0 is the time axis
        context: conditioning tensor, presumably (batch, context_dim)
        latent: latent tensor, presumably (batch, latent_dim)
        hidden: optional initial GRU hidden state

        returns: (vocabulary logits for every time step, final hidden state)
        '''
        steps = input.shape[0]
        # One conditioning row per batch element, tiled over every time step.
        condition = torch.cat((context.unsqueeze(0), latent.unsqueeze(0)), dim=2)
        rnn_in = torch.cat((self.embedding(input), condition.repeat(steps, 1, 1)), dim=2)
        rnn_out, hidden = self.rnn(rnn_in, hidden)
        return self.fc(rnn_out), hidden


Generator model:

In [0]:
class EncDec(nn.Module):
    '''
    EncDec is the main model. It bundles the encoder, decoder and
    discriminator, and provides text-sampling helpers (generate,
    generate_beam).
    '''
    def __init__(self, device, discriminator, vocab_size, embedding_dim, hidden_dim, context_dim, latent_dim, num_layers=1, unk_idx=0, pad_idx=1, start_idx=2, eos_idx=3):
        super().__init__()
        self.device = device
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.context_dim = context_dim
        self.num_layers = num_layers
        
        # Vocabulary indexes used while generating. The defaults presumably
        # match torchtext's <unk>=0, <pad>=1 plus the <init>/<eos> specials
        # added by TEXT; confirm against TEXT.vocab if the Field changes.
        self.unk_idx = unk_idx
        self.pad_idx = pad_idx
        self.start_idx = start_idx
        self.eos_idx = eos_idx
        
        self.encoder = RNNBinaryEncoder(vocab_size, embedding_dim, hidden_dim, context_dim, latent_dim, device).to(device)
        
        self.decoder = GRUDecoder(vocab_size, embedding_dim, context_dim, latent_dim, hidden_dim, num_layers=num_layers).to(device)
        
        self.discriminator = discriminator.to(device)
        
        # Parameter groups for the optimizers and gradient clipping used during
        # training. They are stored as lists: a one-shot iterator (e.g.
        # itertools.chain) is exhausted by its first consumer (the optimizer),
        # which silently turned every later clip_grad_norm_(param_*) into a no-op.
        self.param_enc = list(self.encoder.parameters())
        self.param_dec = list(self.decoder.parameters())
        self.param_vae = self.param_enc + self.param_dec
        self.param_disc = list(self.discriminator.parameters())
        
    def forward(self, text, length=None, teacherForcing=False, use_discriminator=True, tfr=0.5, loss=False):
        '''
        Encode ``text``, decode it, and optionally score the greedy decoding
        with the discriminator.

        text: long tensor of token indexes, (seq_len, batch)
        length: maximum number of decoding steps (capped at seq_len)
        teacherForcing: decode step by step; at each step feed the ground-truth
                        next token with probability ``tfr``, otherwise the
                        greedy prediction. If False, all of ``text`` is decoded
                        in one pass.
        use_discriminator: run the discriminator on the greedy decoding
        tfr: teacher forcing ratio
        loss: also return mu and sig for the KL term

        returns: (outputs, chat, cstar), or (outputs, chat, cstar, mu, sig)
                 when ``loss`` is True. outputs are decoder logits, chat is the
                 encoder context, cstar the discriminator's 0/1 context (None
                 when use_discriminator is False).
        '''
        length = text.shape[0] if length is None else min(text.shape[0], length)
        batch_size = text.shape[1]
        
        # Encode the text
        chat, latent, mu, sig = self.encoder(text)
        
        # Decode the text
        if teacherForcing:
            outputs = torch.zeros(length, batch_size, self.vocab_size, device=self.device)
            decText = text[0:1]
            hidden = None
            for timestep in range(length):
                output, hidden = self.decoder(decText, chat, latent, hidden=hidden)
                outputs[timestep] = output[0]
                if timestep + 1 >= length:
                    break
                # outputs[timestep] predicts text[timestep + 1] (see makeTarget),
                # so teacher forcing must feed that ground-truth next token.
                if random.random() < tfr:
                    decText = text[timestep + 1:timestep + 2]
                else:
                    decText = output.argmax(dim=2)
        else:
            outputs, _ = self.decoder(text, chat, latent)
            
        # Discriminate the greedy decoding of the first `length` steps
        cstar = None
        if use_discriminator:
            discInputs = outputs[:length].argmax(dim=2)
            cstar, _, _, _ = self.discriminator(discInputs, usedAsDiscriminator=True)
        
        if loss:
            return outputs, chat, cstar, mu, sig
        return outputs, chat, cstar
      
    def generate_beam(self, text=None, context=None, latent=None, maxlen=50, beam=3, adaptive=False, unk=True, batch_size=1, min_len=None, continueStarter=False, rand=False):
        '''
        Beam-search generation.

        text: starting text, (seq_len, batch); defaults to the start token
        context: context vector
        latent: latent vector
        maxlen: maximum number of generation steps
        beam: beam size
        adaptive: derank undesirable output (immediate repeats)
        unk: derank <unk>/<pad> tokens in the output
        batch_size: size of batch
        min_len: suggested minimum length of output
        continueStarter: generate the context from ``text`` (see NOTE below
                         for the latent)
        rand: randomly generate the context from a normal distribution
        
        returns: list of beams, best first, where each beam is
        (ongoing, probability, output tokens, input text, hidden, prev)
        '''
        # If not defined, starting text is just the start token
        if text is None:
            text = torch.full((1, batch_size), self.start_idx, dtype=torch.long, device=self.device)
            
        # Generate the context vector
        if context is None:
            if continueStarter:
                context, _, _, _ = self.encoder(text)
            elif rand:
                context = torch.randn((1, 1)).to(self.device)
            else:
                context = torch.randint(1+1, (1, batch_size), dtype=torch.float).to(self.device)
        
        # NOTE(review): this branch is the reverse of the context branch above
        # (continueStarter -> random latent, otherwise encode ``text``), while
        # the docstring suggests both come from ``text`` -- confirm the intent.
        if latent is None:
            if continueStarter:
                latent = self.make_latent(batch_size)
            else:
                _, latent, _, _ = self.encoder(text)
                
        # The first beam just contains the starting text
        # (ongoing, probability, output tokens, input text, hidden, prev)
        beams = [(True, 1, list(text.cpu().flatten()), text, None, -1)]
            
        first = False  # never set to True, so the `first and adaptive` penalty is inactive
        sm = nn.Softmax(dim=1)
        ongoing = beam
        timestep = 0
        # While there is at least one ongoing beam
        while ongoing > 0 and timestep < maxlen:
            timestep += 1
            newBeams = []
            
            for con, prob, outputs, text, hidden, prev in beams:
                if not con:
                    newBeams.append((con, prob, outputs, text, hidden, prev))
                    continue

                # First step: feed the whole starter. Afterwards the hidden
                # state already summarises the prefix, so only the newest token
                # is fed (re-feeding the prefix ran it through the GRU twice).
                step_input = text if hidden is None else text[-1:]
                output, newHidden = self.decoder(step_input, context, latent, hidden=hidden)
                output = sm(output[-1])
                vals, idxs = torch.topk(output, k=beam, dim=1)

                for k in range(idxs.shape[1]):
                    idx = idxs[0][k]
                    newprob = prob * vals[0][k]
                    
                    # Derank output based on generated value
                    if unk and (idx == self.unk_idx or idx == self.pad_idx):
                        newprob = newprob * 0.50
                    if first and adaptive and k > 0:
                        newprob = newprob * 0.80
                    if adaptive and prev == idx:
                        newprob = newprob * 0.1  # discourage immediate repeats
                    if min_len is not None and text.shape[0] < min_len and (idx == self.eos_idx or idx == self.pad_idx):
                        newprob = newprob * 0.01  # discourage ending before min_len
                    
                    nextWord = torch.full((1, batch_size), int(idx), dtype=torch.long, device=self.device)
                    newtext = torch.cat([text, nextWord], dim=0)
                    newoutputs = copy.copy(outputs)
                    newoutputs.append(nextWord)
                    
                    # Generating the end-of-sentence or pad token ends the beam
                    newcon = not (idx == self.eos_idx or idx == self.pad_idx)
                    newBeams.append((newcon, newprob, newoutputs, newtext, newHidden, idx))
            
            # Keep the `beam` most probable candidates. sorted() is stable, so
            # equal probabilities keep their generation order; unlike the old
            # hand-written insertion sort it cannot emit duplicate beams.
            beams = sorted(newBeams, key=lambda b: float(b[1]), reverse=True)[:beam]
            ongoing = sum(1 for b in beams if b[0])
                    
        return beams
        
    def generate(self, text=None, context=None, latent=None, max_len=50, batch_size=1):
        '''
        Generate sentences greedily (without beam search).

        text: starting text, (seq_len, batch); only its first row is used
        context: desired context vector, (batch, context_dim)
        latent: desired latent vector, (batch, latent_dim)
        max_len: number of rows in the output (row 0 is the starting token)
        batch_size: size of batch
        
        returns: (long tensor of generated indexes (max_len, batch), context, latent)
        '''
        # If not provided, set up the starting word
        if text is None:
            text = torch.full((1, batch_size), self.start_idx, dtype=torch.long, device=self.device)

        # If not provided, set up a random 0/1 context
        if context is None:
            context = torch.randint(1+1, (batch_size, self.context_dim), dtype=torch.float).to(self.device)
        
        # If not provided, set up a random latent
        if latent is None:
            latent = self.make_latent(batch_size)
        
        outputs = torch.zeros(max_len, batch_size, dtype=torch.long, device=self.device)
        decText = text[0:1]
        hidden = None
        if max_len > 0:
            # Row 0 holds the starting token so generated sequences begin the
            # same way real data does; it used to be left as index 0.
            outputs[0] = decText[0]
        
        # Greedily generate the remaining words
        for timestep in range(1, max_len):
            output, hidden = self.decoder(decText, context, latent, hidden=hidden)
            decText = output.argmax(dim=2)
            outputs[timestep] = decText[0]
            
        return outputs, context, latent
    
    def make_latent(self, batch=1):
        '''
        Sample a random latent vector |N(0, I)| of shape (batch, latent_dim)
        on this model's device.
        '''
        return torch.abs(torch.randn(batch, self.latent_dim).to(self.device))

## 3. Train the System

In [0]:
class TrainVAE():
    '''
    Training driver for the whole system: discriminator, VAE (encoder +
    decoder) and generator losses, each updated by its own Adam optimizer.
    '''
    def __init__(self, device, model, lmbda=1, blmbda=0.1, clmbda=0.4, clip=5, pad_idx=1):
        '''
        device: torch device batches are moved to
        model: module exposing encoder, decoder, discriminator and the
               parameter groups param_enc / param_dec / param_vae / param_disc
        lmbda: weight of the encoder-context vs discriminator-context BCE term
        blmbda: weight of the generator-side discriminator/latent terms
        clmbda: weight of the KL term
        clip: max gradient norm passed to clip_grad_norm_
        pad_idx: vocabulary index used to pad targets
        '''
        self.device = device
        self.model = model
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.criterion = nn.CrossEntropyLoss()

        # Materialise the parameter groups once. If the model exposes them as
        # one-shot iterators (itertools.chain), the optimizer consumes them and
        # every later clip_grad_norm_ call would silently clip nothing.
        self.param_enc = list(self.model.param_enc)
        self.param_dec = list(self.model.param_dec)
        self.param_vae = list(self.model.param_vae)
        self.param_disc = list(self.model.param_disc)

        # Optimizers for different parts of the model
        self.optim_enc  = optim.Adam(self.param_enc)
        self.optim_dec  = optim.Adam(self.param_dec)
        self.optim_vae  = optim.Adam(self.param_vae)
        self.optim_disc = optim.Adam(self.param_disc)
        
        self.clip = clip
        self.lmbda = lmbda
        self.blmbda = blmbda
        self.clmbda = clmbda
        self.pad_idx = pad_idx

    @staticmethod
    def _ppl(loss):
        '''Return exp(loss), or inf when that would overflow a float.'''
        try:
            return math.exp(loss)
        except OverflowError:
            return float('inf')

    def _kl(self, mu, sig):
        '''Weighted KL term: clmbda * mean_batch(0.5 * sum(exp(sig) + mu^2 - 1 - sig)).'''
        return self.clmbda * torch.mean(0.5 * torch.sum(torch.exp(sig) + mu**2 - 1 - sig, dim=1))
        
    def train(self, train_iterator, valid_iterator, epocs=1, trainDisc=True, trainVAE=True, trainGen=True, max_len=100, valid=True):
        '''
        Train for ``epocs`` epochs, printing the losses after each one.

        trainDisc: boolean, train the discriminator
        trainVAE: boolean, train the VAE
        trainGen: boolean, train the Generator
        max_len: integer, maximum length of a sentence chunk
        valid: boolean, evaluate on the valid set
        '''
        for epoch in range(epocs):
            train_loss = self.train_epoc(train_iterator, epoch, trainDisc, trainVAE, trainGen, max_len)
            if valid:
                valid_loss = self.evaluate(valid_iterator)

            print(f'Epoch: {epoch+1:02}')
            print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {self._ppl(train_loss):7.3f}')
            if valid:
                print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {self._ppl(valid_loss):7.3f}')
        
    def train_disc(self, data, label):
        '''
        One discriminator update on a real batch.

        returns: the discriminator loss as a float
        '''
        # Clear every gradient (not only the discriminator's) so stale
        # gradients left by other update steps cannot leak into this one.
        self.model.zero_grad()
        
        context, _, _, _ = self.model.discriminator(data)
        loss = self.loss_fn(context.squeeze(1), label)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.param_disc, self.clip)
        self.optim_disc.step()
        self.optim_disc.zero_grad()
        
        return loss.item()
        
    def train_vae(self, data, max_len=100, target=None):
        '''
        Train the VAE on ``data`` split into chunks of at most ``max_len``
        time steps, one optimizer step per chunk.

        returns: the summed loss over all chunks
        '''
        loss = 0
        
        # If not given, build the next-token target
        if target is None:
            target = makeTarget(data, device=self.device, vocab_len=VOCAB_SIZE, pad_idx=self.pad_idx, max_len=max_len)
            
        self.optim_vae.zero_grad()
        seq_len = data.shape[0]
        for start in range(0, seq_len, max_len):
            end = min(start + max_len, seq_len)
            loss += self.train_vae_iter(data[start:end], target[start:end], max_len)
            torch.cuda.empty_cache()
                
        return loss
        
    def train_vae_iter(self, data, target, max_len):
        '''
        One VAE update (reconstruction + KL + context agreement) on a chunk,
        decoded with teacher forcing.
        '''
        self.model.zero_grad()
        output, chat, context, mu, sig = self.model(data, length=max_len, teacherForcing=True, loss=True)
        lossDec = self.criterion(output.reshape(-1, output.shape[-1]), target.reshape(-1))
        lossEnc = self.loss_fn(chat, context) * self.lmbda
        lossGen = lossDec + self._kl(mu, sig) + lossEnc
        lossGen.backward()

        torch.nn.utils.clip_grad_norm_(self.param_vae, self.clip)
        self.optim_vae.step()
        self.optim_vae.zero_grad()
        
        return lossGen.item()
        
    def train_gen(self, data, label, max_len=100, target=None):
        '''
        Generator training: discriminator step (gen_1), encoder step (gen_3),
        then one decoder step per chunk of ``max_len`` time steps (gen_2).

        returns: the summed loss of all steps
        '''
        if target is None:
            target = makeTarget(data, device=self.device, vocab_len=VOCAB_SIZE, pad_idx=self.pad_idx, max_len=max_len)
            
        seq_len = data.shape[0]
        max_len = min(max_len, seq_len)
            
        # Generator loss on the discriminator
        loss = self.gen_1(data, label, max_len)
        
        # Generator loss on the encoder
        loss += self.gen_3(data, label, max_len)
        torch.cuda.empty_cache()
        
        # Generator loss on the decoder
        for start in range(0, seq_len, max_len):
            end = min(start + max_len, seq_len)
            loss += self.gen_2(data[start:end], target[start:end], max_len)
            torch.cuda.empty_cache()
        
        return loss
    
    def gen_1(self, data, label, max_len):
        '''
        Discriminator update against real data and generated samples.
        '''
        batch_size = data.shape[1]
        self.model.zero_grad()

        x, chat, _ = self.model.generate(batch_size=batch_size, max_len=max_len)
        cstar_gen, _, _, _ = self.model.discriminator(x)
        cstar_real, _, _, _ = self.model.discriminator(data)
        
        # NOTE(review): when the discriminator's context dimension is 1 (as
        # configured in this notebook), log_softmax over dim 1 is identically
        # 0, so `ent` contributes nothing to the loss.
        lo = F.log_softmax(cstar_gen, dim=1)
        ent = -lo.mean()
        lossDisc = self.loss_fn(cstar_real.squeeze(1), label) + (self.loss_fn(cstar_gen, chat) + self.blmbda*ent) * self.blmbda
        lossDisc.backward()

        torch.nn.utils.clip_grad_norm_(self.param_disc, self.clip)
        self.optim_disc.step()
        self.optim_disc.zero_grad()

        return lossDisc.item()
    
    def gen_2(self, data, target, max_len):
        '''
        Decoder update: VAE loss plus the generated samples' context and
        latent reconstruction losses.
        '''
        batch_size = data.shape[1]
        self.model.zero_grad()
        
        output, chat, context, mu, sig = self.model(data, length=max_len, teacherForcing=False, loss=True)
        lo = torch.abs(self._kl(mu, sig))
        
        x, cgen, lat = self.model.generate(batch_size=batch_size, max_len=max_len)
        _, zenc, _, _ = self.model.encoder(x)
        cdisc, _, _, _ = self.model.discriminator(x)
        
        lossA = self.criterion(output.reshape(-1, output.shape[-1]), target.reshape(-1))
        lossB = self.loss_fn(chat, context) * self.lmbda
        lossVAE = lossA + lo + lossB
        
        lossC = self.loss_fn(cdisc, cgen)
        lossL = F.mse_loss(zenc, lat)
        
        lossGen = lossVAE + lossC*self.blmbda + lossL*self.blmbda
        lossGen.backward()
        # Use this trainer's model (the original referenced a global `model`).
        torch.nn.utils.clip_grad_norm_(self.param_vae, self.clip)
        self.optim_dec.step()
        self.optim_dec.zero_grad()
        
        return lossGen.item()
        
    def gen_3(self, data, label, max_len):
        '''
        Encoder update: KL term, agreement with the discriminator's context,
        and agreement of the encoder context with the true label.
        '''
        self.model.zero_grad()
        _, chat, cstar, mu, sig = self.model(data, length=max_len, teacherForcing=False, loss=True)
        lo = torch.abs(self._kl(mu, sig))
        
        context, _, _, _ = self.model.encoder(data)
        context = context.squeeze(1)
        
        lossEnc = lo + self.loss_fn(chat, cstar) + self.loss_fn(context, label)
        lossEnc.backward()
        torch.nn.utils.clip_grad_norm_(self.param_enc, self.clip)
        self.optim_enc.step()
        self.optim_enc.zero_grad()
        return lossEnc.item()
        
    def train_epoc(self, iterator, epoch, trainDisc=True, trainVAE=True, trainGen=True, max_len=100):
        '''
        Train one epoch. Every 100 batches the losses are printed and the
        model state is saved to VAE_PATH.

        returns: mean per-batch loss
        '''
        self.model.train()
        epoch_loss = 0
        
        for i, batch in enumerate(iterator):
            torch.cuda.empty_cache()
            data, label = batch.text.to(self.device), batch.label.to(self.device)
            
            target = makeTarget(data, device=self.device, vocab_len=VOCAB_SIZE, pad_idx=self.pad_idx, max_len=max_len)
            lossD = 0
            lossV = 0
            lossG = 0
            
            # Train the discriminator
            if trainDisc:
                lossD = self.train_disc(data, label)
            # Train the VAE using teacher forcing
            if trainVAE:
                lossV += self.train_vae(data, max_len=max_len, target=target)
            # Train the Generator
            if trainGen:
                lossG += self.train_gen(data, label, max_len=max_len, target=target)
                    
            loss = lossD + lossV + lossG
            epoch_loss += loss
                
            if i % 100 == 0:
                print("epoch", epoch, "iter", i, "loss", loss)
                if trainDisc:
                    print("DISC loss", lossD)
                if trainVAE:
                    print("VAE loss", lossV)
                if trainGen:
                    print("Gen loss", lossG)
                torch.save(self.model.state_dict(), VAE_PATH)
                
        return epoch_loss / len(iterator)
        
    def evaluate(self, iterator):
        '''
        Evaluate the discriminator, VAE and generator losses without
        updating any parameters.

        returns: mean per-batch sum of the three losses
        '''
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for i, batch in enumerate(iterator):
                data, label = batch.text.to(self.device), batch.label.to(self.device)
                max_len = data.shape[0]
                batch_size = data.shape[1]
                target = makeTarget(data, device=self.device, vocab_len=VOCAB_SIZE, pad_idx=self.pad_idx, max_len=max_len)

                # Discriminator loss
                context, _, _, _ = self.model.discriminator(data)
                lossD = self.loss_fn(context.squeeze(1), label).item()
                
                # VAE loss
                output, chat, context, mu, sig = self.model(data, teacherForcing=False, loss=True)
                lo = torch.abs(self._kl(mu, sig))
                lossDec = self.criterion(output.reshape(-1, output.shape[-1]), target.reshape(-1))
                lossEnc = self.loss_fn(chat, context) * self.lmbda
                lossG = (lossDec + lo + lossEnc).item()

                # Generator loss
                output, chat, context, mu, sig = self.model(data, teacherForcing=False, loss=True)
                lo = torch.abs(self._kl(mu, sig))
                x, cgen, lat = self.model.generate(batch_size=batch_size, max_len=max_len)
                _, zenc, _, _ = self.model.encoder(x)
                cdisc, _, _, _ = self.model.discriminator(x)

                lossA = self.criterion(output.reshape(-1, output.shape[-1]), target.reshape(-1))
                lossB = self.loss_fn(chat, context) * self.lmbda
                lossVAE = lossA + lo + lossB
                lossC = self.loss_fn(cdisc, cgen)
                lossL = F.mse_loss(zenc, lat)
                lossGen = (lossVAE + lossC*self.blmbda + lossL*self.blmbda).item()

                # Count and display the loss
                loss = lossD + lossG + lossGen
                epoch_loss += loss
                
                if i % 100 == 0:
                    print("iter", i, "Total loss", loss, "Discriminator loss", lossD, "VAE loss", lossG, "Generator loss", lossGen)
              
        return epoch_loss / len(iterator)
    

### Create and train the model

In [0]:
# Build the discriminator and the conditioned encoder/decoder, then warm-start
# them from the checkpoint previously written to VAE_PATH.
discriminator = RNNBinaryEncoder(VOCAB_SIZE_DISC, EMBEDDING_DIM, HIDDEN_DIM, CONTEXT_DIM, 0, device).to(device)
model = EncDec(device, discriminator, INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, CONTEXT_DIM, LATENT_DIM, pad_idx=PAD_IDX).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(VAE_PATH, map_location=device))

trainer = TrainVAE(device, model, lmbda=0.1, blmbda=0.1)

# Settings shared by every training phase below.
TRAIN_MAX_LEN = 30
TRAIN_EPOCHS = 3

# Phase 1: discriminator phase (trainDisc=True, other phases off).
print("Starting to train the Discriminator")
trainer.train(disc_train_iterator, disc_valid_iterator,
              trainDisc=True, trainVAE=False, trainGen=False,
              max_len=TRAIN_MAX_LEN, epocs=TRAIN_EPOCHS, valid=False)

# Phase 2: VAE phase (trainVAE=True, other phases off).
print("Starting to train the VAE")
trainer.train(disc_train_iterator, disc_valid_iterator,
              trainDisc=False, trainVAE=True, trainGen=False,
              max_len=TRAIN_MAX_LEN, epocs=TRAIN_EPOCHS, valid=False)

# Phase 3: generator phase (trainGen=True, other phases off).
print("Starting to train the Generator/Encoder and the Discriminator")
trainer.train(disc_train_iterator, disc_valid_iterator,
              trainDisc=False, trainVAE=False, trainGen=True,
              max_len=TRAIN_MAX_LEN, epocs=TRAIN_EPOCHS, valid=False)

# Report the final model's loss on the held-out IMDB test split.
print("Finding loss of final model")
test_loss = trainer.evaluate(disc_test_iterator)

print(f'Test Loss: {test_loss:.3f}')
print("Done!")

# Save the model's parameters to file (overwrites the checkpoint loaded above).
torch.save(model.state_dict(), VAE_PATH)

## 4. Test the system

In [0]:
# Load the model from file.
# Rebuild the networks with the same hyperparameters (including pad_idx) that
# the training cell uses, so the reloaded model matches the trained one.
discriminator = RNNBinaryEncoder(VOCAB_SIZE_DISC, EMBEDDING_DIM, HIDDEN_DIM, CONTEXT_DIM, 0, device).to(device)
model = EncDec(device, discriminator, INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, CONTEXT_DIM, LATENT_DIM, pad_idx=PAD_IDX).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(VAE_PATH, map_location=device))

# Put both modules in evaluation mode for generation (affects layers such as
# dropout).
discriminator.eval()
model.eval()

# Special-token indices as (1, 1) long tensors.
pad = torch.Tensor(lookup_indexes(["<pad>"])).long().reshape((1,1)).to(device)
unk = torch.Tensor(lookup_indexes(["<unk>"])).long().reshape((1,1)).to(device)

# Hand-picked scalar context values (CONTEXT_DIM is 1):
# positive, negative, neutral, and half-strength positive/negative.
pos = 1
neg = -1
neu = 0
ptep = 0.5
ntep = -0.5

contextP = torch.Tensor([pos]).reshape((1,1)).to(device)
contextN = torch.Tensor([neg]).reshape((1,1)).to(device)
contextNEU = torch.Tensor([neu]).reshape((1,1)).to(device)
contextPT = torch.Tensor([ptep]).reshape((1,1)).to(device)
contextNT = torch.Tensor([ntep]).reshape((1,1)).to(device)

# A context drawn from a standard normal distribution.
contextRand = torch.randn((1,1)).to(device)

# Starter tensors: a uniformly random vocabulary index, or the <init> token.
starter = torch.randint(INPUT_DIM, (1,1), dtype=torch.long).to(device)
starterInit = torch.Tensor(lookup_indexes(["<init>"])).long().reshape((1,1)).to(device)

# Generate context and latent vectors by encoding example sentences.
contextR1, latentR1, _, _ = model.encoder(makeStarter("i went to see this movie last week .", device))
contextR2, latentR2, _, _ = model.encoder(makeStarter("i loved this movie !", device))
contextNE, latentNE, _, _ = model.encoder(makeStarter("this was the abolute worst movie ever do not see it i hate this movie it was terrible", device))
contextPE, latentPE, _, _ = model.encoder(makeStarter("good i love this movie . it was fantastic . the acting was excellent , the score was great", device))

# An all-zeros latent vector of shape (1, LATENT_DIM).
latentZ = torch.zeros(1, LATENT_DIM).to(device)

# For a demonstration use this cell:

In [0]:
# NOTE: punctuation needs to be separated from words by a space.
# EXAMPLE: "I loved it!" should be "I loved it !"

# For the demo, use any of the following contexts, latents, and starters.
# Positive contexts
# contextP, contextPT, contextPE, contextR2

# Negative contexts
# contextN, contextNT, contextNE

# Neutral / random contexts
# contextNEU, contextRand

# Positive latent vectors
# latentPE, latentR2

# Negative latent vectors
# latentNE

# Other latent vectors
# latentR1, latentZ (all zeros)

# Empty starter
# starterInit

# Custom/random starters
# starterCustom, starterRand
starterCustom = makeStarter("i went to see", device)
starterRand = makeStarter(1, device, random=True, starter=True)

# generate_beam inputs:
# generate_beam(starting text, context vector, latent vector, max length, beam size,
#               use adaptive search, derank "unk", suggested min length, random context)
# Note: All values are optional

# Generate a beam of results.
# latentPE is the positive latent vector (there is no `latentP`).
output = model.generate_beam(text=starterInit, context=contextP, latent=latentPE, maxlen=20, beam=15, adaptive=True, unk=True, min_len=5, rand=False)
#output = model.generate_beam(text=starterInit, context=None, latent=None, maxlen=20, beam=15, adaptive=True, unk=True, min_len=5)

# Display the beam of results; element [2] of each beam entry holds the
# token indices that are converted back to words.
for beam in output:
    print(" ".join(lookup_words(beam[2])))