In [1]:
import numpy as np 
import os
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import math
from matplotlib import pyplot as plt

import sys
sys.path.append('../utils')

from text.utils import load_text_state, split_sentences
from model.utils import get_device, load_model_state
from text_model import TextModel

### Verify if CUDA is available

In [2]:
device = get_device()
print(device)

cuda


### Loading vocabulary

In [3]:
CHECKPOINT_BASE = 'checkpoints'
AUTHORS = ['Ernest Hemingway', 'Friedrich Nietzsche', 'Oscar Wilde']
SENTENCES_LENGTHS = [15, 27, 16]

In [4]:
vocabs = []

for author in AUTHORS:
    vocab, _, _ = load_text_state(os.path.join(CHECKPOINT_BASE, author + '_text.pk'))    
    vocabs.append(vocab)
    
    print('Text checkpoint loaded for {}'.format(author))

Text checkpoint loaded for Ernest Hemingway
Text checkpoint loaded for Friedrich Nietzsche
Text checkpoint loaded for Oscar Wilde


### Loading components

In [5]:
LEARNING_RATE = 1e-4
BETA_1 = 0.1
BETA_2 = 0.999
EPOCHS = 50
NUM_HEADS = 8
ENCODER_LAYERS = 1
DECODER_LAYERS = 1
EMBEDDING_SIZE = 512
FF_DIM = 1024
DROPOUT=0.1
STEP_LR_DECAY = 15
LR_FACTOR_DECAY = 0.5

In [6]:
models = []

for idx, author in enumerate(AUTHORS):
    model = TextModel(
        vocab_size=len(vocabs[idx]),
        embedding_size=EMBEDDING_SIZE, 
        num_heads=NUM_HEADS, 
        encoder_layers=ENCODER_LAYERS, 
        decoder_layers=DECODER_LAYERS, 
        dim_feedforward=FF_DIM,
        dropout=DROPOUT
    )

    optimizer = torch.optim.Adam(
        list(model.parameters()), 
        lr=LEARNING_RATE,
        betas=(BETA_1, BETA_2)
    )
    
    model, _, _, _, _, _ = load_model_state(
        os.path.join(CHECKPOINT_BASE, author + '_best.pt'), 
        model, 
        optimizer
    )

    model.to(device)

    for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
                    
    models.append(model)
    print('Model checkpoint loaded for {}'.format(author))

Model checkpoint loaded for Ernest Hemingway
Model checkpoint loaded for Friedrich Nietzsche
Model checkpoint loaded for Oscar Wilde


### Predicting

In [7]:
text = [
    'The man said he was angry and wet',
    'I don\'t believe it Mr. Grey',
    'random said believe man'
]

In [8]:
batches = []

for idx, length in enumerate(SENTENCES_LENGTHS):
    sentences = []
    for sentence in text:
        # split = [vocab.START_STR]
        split = []
        split.extend(split_sentences(sentence)[0][:length])
        split.extend([vocabs[idx].PAD_STR] * ( length - len(split) ))
        split.append(vocabs[idx].END_STR)
        split = [word.lower().strip() for word in split]
        indices = vocabs[idx].to_indices(split)    
        sentences.append(indices)
        
    batches.append(sentences)

In [9]:
model.eval()

for idx, author in enumerate(AUTHORS):
    
    print('**********************************************************************')
    print('Author: {}'.format(author))
    
    sentences = batches[idx]
    target = np.array([[vocabs[idx].START_TOKEN] * (SENTENCES_LENGTHS[idx] + 1)]  * len(sentences))
    
    for word_idx in range(SENTENCES_LENGTHS[idx]):
        with torch.no_grad():
            predicted = model(torch.LongTensor(sentences).to(device), torch.LongTensor(target).to(device))
            predicted = torch.argmax(predicted, dim=2).cpu().numpy()

        target[:, 1:] = predicted[:, :-1]

    for jdx, sentence in enumerate(predicted):
        print('----------------------------------------------------------------------')
        print('Original: {}'.format(text[jdx]))
        print('Input: {}'.format(' '.join(vocabs[idx].to_words(sentences[jdx]))))
        print('Output: {}'.format(' '.join(vocabs[idx].to_words(sentence))))

**********************************************************************
Author: Ernest Hemingway
----------------------------------------------------------------------
Original: The man said he was angry and wet
Input: the man said he was <unk> and wet <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
Output: the <unk> said he was <unk> and wet <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
----------------------------------------------------------------------
Original: I don't believe it Mr. Grey
Input: i don ' t believe it <unk> <unk> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
Output: i don ' <unk> believe it <unk> <unk> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
----------------------------------------------------------------------
Original: random said believe man
Input: <unk> said believe man <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
Output: <unk> said believe <unk> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <eos>
******