In [1]:
from data_generators import get_iterator, get_dataset
from classifiers import theme_classifier
import torch
import torch.nn as nn
from torchtext.vocab import GloVe 
GLOVE_EMBEDDING = GloVe(name="6B", dim=300)

In [2]:
train_dataset, val_dataset, test_dataset, review_text_FIELD, theme_FIELD = get_dataset(vectors = GLOVE_EMBEDDING)

In [3]:
batch_size = 20
train_iter = get_iterator(train_dataset, batch_size, train=True, shuffle=True, repeat=False)
val_iter = get_iterator(val_dataset, batch_size, train=False, shuffle=True, repeat=False)
test_iter = get_iterator(test_dataset, batch_size, train=False, shuffle=True, repeat=False)

In [4]:
test_list = list(test_iter)

In [5]:
from baseline_model import BaseModel, repackage_hidden

In [6]:
with open('./baseline/best_model_base_model_ppl_67.5461793533482.model', 'rb') as file:
    model = torch.load(file)
    model.eval()

In [7]:
vocab_size = 12304

In [8]:
import numpy as np

In [9]:
BOS_WORD = '<sos>'
EOS_WORD = '<eos>'

In [11]:
def generate_next_n_words(self, first_word, label, max_len = 3, # ngram=3
                            target_vocab = review_text_FIELD, top_k = 10):

    self.eval()
    with torch.no_grad():      

        labels = torch.tensor([label]).cuda().long()
        source = torch.tensor([first_word]).unsqueeze(1).cuda().long()
        
        eos_idx = target_vocab.vocab.stoi[EOS_WORD]
        y_hat = source[:1,:]
        y_hat = y_hat.unsqueeze(-1)

        prediction_idxs = [[y_hat[0]]]
        prediction_probs = [0]

        next_hidden = [None]
        
        for t in range(max_len):
            current_idxs = []
            current_probs = []
            current_hidd = [] 

            for k in range(y_hat.shape[-1]):
                y_hat_k = y_hat[:,:,k]
                current_hidd_k = next_hidden[k]
                output, hidden = self(y_hat_k, labels, current_hidd_k)
                current_hidd += [hidden]*top_k
                top_probs, top_idx = output.topk(top_k, -1)
                current_idxs.append(top_idx.squeeze(0).squeeze(0))
                current_probs.append(top_probs.squeeze(0).squeeze(0))
            current_idxs = torch.cat(current_idxs)
            current_probs = torch.cat(current_probs)

            next_args = torch.sort(current_probs, -1, True)[1][:top_k] 

            next_idxs, next_probs, next_hidden = [], [], []
            tmp_pred_idxs, tmp_pred_probs = [], []
            for idx in next_args:
                chain_head = idx//top_k

                next_idx = current_idxs[idx].int()
                next_prob = current_probs[idx]
                sentence_so_far = prediction_idxs[chain_head]
                
                if len(sentence_so_far)>0 and sentence_so_far[-1] == eos_idx:
                    tmp_pred_idxs.append(sentence_so_far + [eos_idx])
                    tmp_pred_probs.append(prediction_probs[chain_head])
                else:
                    tmp_pred_idxs.append(sentence_so_far + [next_idx])
                    tmp_pred_probs.append(prediction_probs[chain_head] + next_prob)

                next_idxs.append(next_idx.unsqueeze(0))
                next_hidden.append(current_hidd[idx])

            prediction_idxs = tmp_pred_idxs
            prediction_probs = tmp_pred_probs
            next_idxs = torch.cat(next_idxs)

            y_hat = next_idxs.unsqueeze(0).unsqueeze(0).long()
            
        prediction_probs = torch.tensor(prediction_probs)
        prediction_idxs = torch.tensor(prediction_idxs)
        correct_order = torch.argsort(prediction_probs, descending=True).long()
        
        prediction_idxs = torch.index_select(prediction_idxs, 0, correct_order)

        return prediction_idxs, labels

In [12]:
test_list[100].review_text.shape

torch.Size([5, 20])

In [13]:
def transtaltion2string(raw_translations, target_vocab = review_text_FIELD, max_words=30000):
    string_translations = []
    for raw_sentence in raw_translations:
        string_sentence = []
        for i, word_idx in enumerate(raw_sentence):
            if i == max_words: break
            word = target_vocab.vocab.itos[word_idx]
            if word == '<eos>':
                break
            if word != '<sos>':
                string_sentence.append(word)
        string_translations.append(string_sentence)

    return string_translations

(tensor([[  7,  14,  46, 187],
         [  7,  14,  46, 154],
         [  7,  11,  76,  26],
         [  7,  11,  76,  27],
         [  7,  11,  76,  36],
         [  7,  14,  36,  10],
         [  7,  11,  62,   8],
         [  7,  14,  20,  26],
         [  7,  14,  20, 737],
         [  7,  14, 165,   4]]), tensor([3], device='cuda:0'))

In [33]:
import pandas as pd

In [26]:
#review_text_FIELD.vocab.itos

In [24]:
theme_FIELD.vocab.itos

['<unk>', 'other', 'plot', 'acting', 'effect', 'production']

In [39]:
results = []
for i in range(30):
    pred, label = generate_next_n_words(model, i+4, 4, max_len=10)
    pred = transtaltion2string(pred[:1])
    string = ' '.join(pred[0])
    predicted_class = theme_classifier(string)
    results.append([theme_FIELD.vocab.itos[label+1], string, predicted_class])
pd.DataFrame(results, columns=['true_label', 'generation', 'predicted'])

Unnamed: 0,true_label,generation,predicted
0,production,the story is a little too long for the sake of,plot
1,production,. i really did n't care about any of my favorite,other
2,production,", i ' ve never been a huge fan of the",other
3,production,and i ' ve never been a huge fan of the,other
4,production,a little slow and the director could ' ve easi...,production
5,production,of an old director who does n't seem to have a,production
6,production,"to the point of the production , i would n't have",production
7,production,it could ' ve easily been a better director .,production
8,production,"is one of my favorite films of the same name ,",other
9,production,"in the end , it would ' ve easily been a",other


In [212]:
#transtaltion2string(test_list[200].review_text[1,:].tolist())