In [1]:
from data_generators import get_iterator, get_dataset
from classifiers import theme_classifier
import torch
import torch.nn as nn
from torchtext.vocab import GloVe 
GLOVE_EMBEDDING = GloVe(name="6B", dim=300)

In [2]:
train_dataset, val_dataset, test_dataset, review_text_FIELD, theme_FIELD = get_dataset(vectors = GLOVE_EMBEDDING)

In [3]:
batch_size = 20
train_iter = get_iterator(train_dataset, batch_size, train=True, shuffle=True, repeat=False)
val_iter = get_iterator(val_dataset, batch_size, train=False, shuffle=True, repeat=False)
test_iter = get_iterator(test_dataset, batch_size, train=False, shuffle=True, repeat=False)

In [4]:
test_list = list(test_iter)

In [5]:
def evaluate(model, data_source, criterion):
    model.eval()
    total_loss_e = 0
    total_number_of_words = 0

    with torch.no_grad():
        for i, batch in enumerate(data_source):
            labels = batch.theme.cuda().long() - 1
            batch = batch.review_text.cuda().long()
            hidden = None
            if batch.shape[0] > 3:
                data, targets = batch[1:-1,:], batch[2:,:]
                
                output, hidden = model(data, labels, hidden)
                output_flat = output.contiguous().view(-1, vocab_size)
                target_flat = targets.contiguous().view(-1)
                batch_loss = criterion(output_flat, target_flat).detach().item()

                number_of_words = data.shape[0] * data.shape[1]
                total_loss_e += batch_loss * number_of_words
                total_number_of_words += number_of_words

                repackage_hidden(hidden)
            
    return total_loss_e / total_number_of_words

In [6]:
from baseline_model import BaseModel, repackage_hidden

In [7]:
with open('./baseline/best_model_base_model_ppl_67.5461793533482.model', 'rb') as file:
    model = torch.load(file)
    model.eval()

In [8]:
criterion = nn.CrossEntropyLoss(reduction='mean',
                       ignore_index=train_dataset.fields["review_text"].vocab.stoi['<pad>']).cuda()

In [9]:
vocab_size = 12304
test_error = evaluate(model, test_iter, criterion)

In [10]:
import numpy as np
test_error, np.exp(test_error)

(4.206123341756414, 67.09592700457719)

In [11]:
BOS_WORD = '<sos>'
EOS_WORD = '<eos>'

In [257]:
def beam_search_translation(self, source, idx, max_len = MAX_LEN, # ngram=3
                            target_vocab = review_text_FIELD, top_k = 5):

    self.eval()
    with torch.no_grad():      

        labels = source.theme.cuda().long() - 1
        source = source.review_text.cuda().long()
        
        labels = labels[idx].unsqueeze(0)
        source = source[:,idx].unsqueeze(1)
        
        batch_size = source.shape[1] 
        assert batch_size == 1
        
        eos_idx = target_vocab.vocab.stoi[EOS_WORD]
        y_hat = source[:1,:]
        y_hat = y_hat.unsqueeze(-1)

        #print(y_hat[0])
        prediction_idxs = [[y_hat[0]]] #[[]] #bos_idx
        prediction_probs = [0]

        if source.shape[0] > 3:
            data, targets = source[1:-1,:], source[2:,:]
        else:
            return prediction_idxs
        
        #encoder_output, hidden = self.encoder(source)
        #hidden = (hidden[0][-1:], hidden[1][-1:])
        next_hidden = [None]
        
        for t in range(max_len):
            current_idxs = []
            current_probs = []
            current_hidd = [] 

            for k in range(y_hat.shape[-1]):
                y_hat_k = y_hat[:,:,k]
                current_hidd_k = next_hidden[k]
                output, hidden = self(y_hat_k, labels, current_hidd_k)
                current_hidd += [hidden]*top_k
                top_probs, top_idx = output.topk(top_k, -1)
                current_idxs.append(top_idx.squeeze(0).squeeze(0))
                current_probs.append(top_probs.squeeze(0).squeeze(0))
            current_idxs = torch.cat(current_idxs)
            current_probs = torch.cat(current_probs)

            next_args = torch.sort(current_probs, -1, True)[1][:top_k] 

            next_idxs, next_probs, next_hidden = [], [], []
            tmp_pred_idxs, tmp_pred_probs = [], []
            for idx in next_args:
                chain_head = idx//top_k

                next_idx = current_idxs[idx].int()
                next_prob = current_probs[idx]
                sentence_so_far = prediction_idxs[chain_head]
                
                if len(sentence_so_far)>0 and sentence_so_far[-1] == eos_idx:
                    tmp_pred_idxs.append(sentence_so_far + [eos_idx])
                    tmp_pred_probs.append(prediction_probs[chain_head])
                else:
                    tmp_pred_idxs.append(sentence_so_far + [next_idx])
                    tmp_pred_probs.append(prediction_probs[chain_head] + next_prob)

                next_idxs.append(next_idx.unsqueeze(0))
                next_hidden.append(current_hidd[idx])

            prediction_idxs = tmp_pred_idxs
            prediction_probs = tmp_pred_probs
            next_idxs = torch.cat(next_idxs)

            y_hat = next_idxs.unsqueeze(0).unsqueeze(0).long()
            
        prediction_probs = torch.tensor(prediction_probs)
        prediction_idxs = torch.tensor(prediction_idxs)
        correct_order = torch.argsort(prediction_probs, descending=True).long()
        
        prediction_idxs = torch.index_select(prediction_idxs, 0, correct_order)

        return prediction_idxs, labels

In [258]:
test_list[100].review_text.shape

torch.Size([5, 20])

In [259]:
MAX_LEN = 30

In [260]:
def transtaltion2string(raw_translations, target_vocab = review_text_FIELD, max_words=30000):
    string_translations = []
    for raw_sentence in raw_translations:
        string_sentence = []
        for i, word_idx in enumerate(raw_sentence):
            if i == max_words: break
            word = target_vocab.vocab.itos[word_idx]
            if word == '<eos>':
                break
            if word != '<sos>':
                string_sentence.append(word)
        string_translations.append(string_sentence)

    return string_translations

In [265]:
beam_search_translation(model, 7, 3, max_len=20, top_k=10)

[[tensor([[7]])]]

In [246]:
for i in range(20):
    pred, label = beam_search_translation(model, test_list[200], idx = i, max_len=20, top_k=10)
    print(pred[:1])
    pred = transtaltion2string(pred[:1])
    string = ' '.join(pred[0])
    predicted_class = theme_classifier(string)
    print(theme_FIELD.vocab.itos[label+1], '|', string, '|', predicted_class)

tensor([1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0],
       device='cuda:0') torch.Size([6, 20])
tensor([1], device='cuda:0') torch.Size([6, 1])


ZeroDivisionError: division by zero

In [212]:
#transtaltion2string(test_list[200].review_text[1,:].tolist())