In [1]:
import numpy as np 
import os
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import math
from matplotlib import pyplot as plt

import sys
sys.path.append('../utils')

from text.utils import load_text_state, split_sentences, Vocabulary
from model.utils import get_device, load_model_state
from multitext_model import MultiTextModel

import torchtext

### Check whether CUDA is available

In [2]:
# Pick the torch device through the project helper; the model and the input
# tensors are moved onto this device in the cells below.
device = get_device()
print(device)

cuda


### Loading vocabulary

In [3]:
# Directory that holds the text and model checkpoints loaded by the cells below.
CHECKPOINT_BASE = 'checkpoints'
# Authors handled by the model; the per-author outputs printed at the end are
# indexed in this same order.
AUTHORS = ['Friedrich Nietzsche', 'Ernest Hemingway', 'Oscar Wilde']
# Number of word tokens kept per sentence; inputs are truncated/padded to this
# length and then one END token is appended.
SENTENCES_LENGTH = 22

In [4]:
# Per-author vocabularies, collected in the same order as AUTHORS.
# NOTE(review): `vocabs` is not referenced by the later cells visible here.
vocabs = []

for author in AUTHORS:
    text_checkpoint = os.path.join(CHECKPOINT_BASE, '{}_text.pk'.format(author))
    # load_text_state yields three values; only the first (the vocabulary) is kept.
    author_vocab, _, _ = load_text_state(text_checkpoint)
    vocabs.append(author_vocab)

    print('Text checkpoint loaded for {}'.format(author))

Text checkpoint loaded for Friedrich Nietzsche
Text checkpoint loaded for Ernest Hemingway
Text checkpoint loaded for Oscar Wilde


### Loading components

In [5]:
# Adam settings; used only to build the optimizer that is handed to
# load_model_state together with the model.
LEARNING_RATE = 1e-4
BETA_1 = 0.9
BETA_2 = 0.999
# NOTE(review): EPOCHS, EMBEDDING_SIZE, STEP_LR_DECAY and LR_FACTOR_DECAY are
# not referenced by any cell visible here (the model is built with
# embedding_size=fasttext.dim, not EMBEDDING_SIZE) -- presumably carried over
# from the training notebook; confirm before removing.
EPOCHS = 50
# Architecture settings passed to the model constructor below.
# NOTE(review): presumably these must match the values used when 'best.pt'
# was trained -- confirm.
NUM_HEADS = 15
ENCODER_LAYERS = 2
DECODER_LAYERS = 1
EMBEDDING_SIZE = 512
FF_DIM = 1024
DROPOUT=0.2
STEP_LR_DECAY = 15
LR_FACTOR_DECAY = 0.7


# Upper bound used when loading FastText vectors: the loader is asked for
# VECTORS_LOADED minus the number of entries already in the vocabulary.
VECTORS_LOADED = 40000

In [6]:
# Shared vocabulary used by the model.
vocab = Vocabulary('Multitext')

# Ask FastText for VECTORS_LOADED vectors minus whatever the fresh vocabulary
# already holds (presumably its special tokens -- TODO confirm).
fasttext = torchtext.vocab.FastText(
    language='en',
    max_vectors=VECTORS_LOADED - len(vocab),
    cache='../.vector_cache',
)

# Register every FastText word with the vocabulary, in the loader's own order.
for token in fasttext.stoi:
    vocab.add_word(token)

In [7]:
class MultiTextModelPrediction(MultiTextModel):
    """Variant of ``MultiTextModel`` whose ``forward`` is used for prediction.

    The source batch is embedded and encoded once, and every per-author
    decoder then consumes that same encoder output.
    """

    def forward(self, source, target):
        """Run every decoder over the shared encoded source.

        Args:
            source: Token-index tensor for the input sentences. It is embedded
                and its first two dims are swapped before positional encoding,
                so it is presumably batch-first -- confirm against the base
                class.
            target: Sequence holding one token-index tensor per decoder; it is
                indexed as ``target[idx]`` for ``idx`` in
                ``range(self.num_decoders)``.

        Returns:
            list: one tensor per decoder, each being ``self.linears[idx]``
            applied to that decoder's output with its first two dims swapped
            back.
        """
        srcs = self.embedding(source)
        srcs = self.pos_encoder(srcs.transpose(0, 1))

        # The encoder input is the same for every decoder, so encode it once
        # instead of once per decoder. In eval mode this should give the same
        # values as re-encoding each time (assuming the encoder is
        # deterministic there); with dropout active the decoders now share one
        # dropout sample instead of drawing one each.
        memory = self.encoder(srcs)

        outputs = []
        for idx in range(self.num_decoders):
            tgts = self.embedding(target[idx])
            tgts = self.pos_encoder(tgts.transpose(0, 1))

            decoded = self.decoders[idx](tgts, memory)
            outputs.append(self.linears[idx](decoded.transpose(0, 1)))

        return outputs
    

In [8]:
# Build the prediction model with the architecture settings from the config cell.
# The embedding size comes from the loaded FastText vectors.
model = MultiTextModelPrediction(
    authors=AUTHORS,
    vocab_size=len(vocab),
    embedding_size=fasttext.dim,
    num_heads=NUM_HEADS,
    encoder_layers=ENCODER_LAYERS,
    decoder_layers=DECODER_LAYERS,
    dim_feedforward=FF_DIM,
    dropout=DROPOUT
)

# load_model_state takes an optimizer alongside the model, so one is created
# here with the configured Adam settings even though no training happens in
# the cells visible in this notebook.
optimizer = torch.optim.Adam(
    list(model.parameters()),
    lr=LEARNING_RATE,
    betas=(BETA_1, BETA_2)
)

# Restore the checkpoint. Of the nine returned values only the model, the
# validation-loss history and the best validation loss are used afterwards.
model, _, _, _, val_loss_history, least_validation_loss, _, _, _ = load_model_state(
    os.path.join(CHECKPOINT_BASE, 'best.pt'),
    model,
    optimizer
)

model.to(device)

# Move every tensor held in the optimizer state onto the same device as the model.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [9]:
# Report the best validation loss returned by load_model_state for 'best.pt'.
print('Validation loss: {}'.format(least_validation_loss))

Validation loss: 0.007192154897278586


In [10]:
# Number of entries in the validation-loss history returned with the checkpoint.
len(val_loss_history)

49

### Predicting

In [11]:
# Vocabulary indices of the special tokens, looked up once. START_TOKEN seeds
# the decoder input, and all three are filtered out when sentences are printed.
PAD_TOKEN = vocab.word2index[vocab.PAD_STR]
START_TOKEN = vocab.word2index[vocab.START_STR]
END_TOKEN = vocab.word2index[vocab.END_STR]

In [12]:
def simplify_sentence(vocab, sentence):
    """Render a sequence of token indices as a plain, space-separated string.

    Entries equal to PAD_TOKEN, START_TOKEN or END_TOKEN are dropped; the
    remaining indices are converted with ``vocab.to_words`` and joined with
    single spaces.
    """
    special_tokens = (PAD_TOKEN, START_TOKEN, END_TOKEN)

    kept = []
    for token in sentence:
        if token in special_tokens:
            continue
        kept.append(token)

    return ' '.join(vocab.to_words(kept))

In [40]:
# Sample sentences that are fed through every author's decoder below.
text = [
    "No medicine cures what happiness cannot",
    "It's enough for me to be sure that you and I exist at this moment",
    "sex is the consolation you have when you can't have love",
    "Nobody deserves your tears, but whoever deserves them will not make you cry",
]

In [41]:
# Index-encoded inputs: one list of SENTENCES_LENGTH + 1 vocabulary indices per sentence.
sentences = []

for sentence in text:
    # Keep at most SENTENCES_LENGTH tokens of the first split result, pad up to
    # that length, then close the sequence with the END token.
    tokens = list(split_sentences(sentence)[0][:SENTENCES_LENGTH])
    padding = [vocab.PAD_STR] * (SENTENCES_LENGTH - len(tokens))
    tokens = tokens + padding + [vocab.END_STR]

    # NOTE(review): the special PAD/END strings are lower-cased and stripped
    # together with the words -- presumably they are already in that form; confirm.
    normalized = [token.lower().strip() for token in tokens]
    sentences.append(vocab.to_indices(normalized))

In [42]:
# Switch the model to evaluation mode before decoding.
model.eval()

source = torch.LongTensor(sentences).to(device)

# Decoder-input template: one row per sentence, SENTENCES_LENGTH + 1 START tokens.
# It is built directly on the target device so no host round trip is needed.
start_grid = torch.full(
    (len(sentences), SENTENCES_LENGTH + 1),
    int(START_TOKEN),
    dtype=torch.long,
    device=device,
)
# The template is never written to (each step works on a clone), so all
# authors can share the same tensor as their initial target.
target = [start_grid] * len(AUTHORS)

# Iterative greedy decoding: at every step the whole target is re-predicted and
# the predictions, shifted right by one, become the next decoder input.
with torch.no_grad():
    for _ in range(SENTENCES_LENGTH):
        predictions = model(source, target)

        target = []
        for prediction in predictions:
            # Highest-scoring token per position. Softmax is monotonic, so
            # taking argmax of the raw scores selects the same token.
            predicted = torch.argmax(prediction, dim=2)

            # Position 0 stays START; position i + 1 receives the token
            # predicted at position i.
            sent_author = start_grid.clone()
            sent_author[:, 1:] = predicted[:, :-1]
            target.append(sent_author)

# Print each input next to every author's decoded version.
for idx, sentence in enumerate(text):
    print('----------------------------------------------------------------------')
    print('Original: {}'.format(sentence))
    print('Input: {}'.format(simplify_sentence(vocab, sentences[idx])))

    for jdx, author in enumerate(AUTHORS):
        print('{}: {}'.format(author, simplify_sentence(vocab, target[jdx][idx].cpu().numpy())))

----------------------------------------------------------------------
Original: No medicine cures what happiness cannot
Input: no medicine cures what happiness cannot
Friedrich Nietzsche: no medicine cures what happiness cannot
Ernest Hemingway: no medicine tasting what happiness cannot
Oscar Wilde: no academy garlic what happiness cannot
----------------------------------------------------------------------
Original: It's enough for me to be sure that you and I exist at this moment
Input: it ' s enough for me to be sure that you and i exist at this moment
Friedrich Nietzsche: it ' s enough for me to be sure that you and i exist at this moment
Ernest Hemingway: it ' s enough for me to be sure that you and i matter at this moment
Oscar Wilde: it ' s enough for me to be sure that you and i exist at this moment
----------------------------------------------------------------------
Original: sex is the consolation you have when you can't have love
Input: sex is the consolation you have wh