In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from net import Net
from utils import *
%matplotlib inline

In [2]:
# Settings forwarded to load_sst.
batch_size = 64
# Passed to load_sst; presumably caps the vocabulary size (the printed
# vocab_size equals this value) — TODO confirm against utils.load_sst.
max_vocab = 18280

# Size settings forwarded to the Net constructor.
hidden_dim = 128
num_layers = 2

# Checkpoint file read with torch.load when the pre-trained model is restored.
ckpt_path = "./checkpoint/sst_20.pth"
# NOTE(review): data_root is not referenced by any visible cell — confirm it is still needed.
data_root = "./data"

In [3]:
# prepare SST dataset
# load_sst comes from the star import of utils; it returns the three split
# iterators plus an info dict from which the keys below are read.
train_iter, val_iter, test_iter, sst_info = load_sst(batch_size, max_vocab)
vocab_size = sst_info["vocab_size"]
num_class  = sst_info["num_class"]
# TEXT is handed to Net and later used (via TEXT.vocab.itos) to decode token ids.
TEXT = sst_info["TEXT"]
print("[!] vocab_size: {}, num_class: {}".format(vocab_size, num_class))

[!] vocab_size: 18280, num_class: 5


In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net(TEXT, 
          hidden_dim,
          num_layers, num_class).to(device)

# load pre-trained model
# map_location remaps tensors that were saved from a CUDA session onto
# `device`; without it the checkpoint fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state_dict = torch.load(ckpt_path, map_location=device)
net.load_state_dict(state_dict)

In [5]:
# Put the network in evaluation mode and skip autograd bookkeeping for inference.
net.eval()
with torch.no_grad():
    for step, inputs in enumerate(test_iter):
        X, y = inputs.text.to(device), inputs.label.to(device)

        # torch.max along dim 1 returns (values, indices); only the indices
        # (the predicted class ids) are kept.
        pred_y = torch.max(net(X).detach(), 1)[1]

        break # run only first batch

  return Variable(arr, volatile=not train)


In [6]:
def get_label_str(label, fine_grained=True):
    """Translate an integer class id into a readable sentiment label.

    Args:
        label: class id (0-4) or None.
        fine_grained: when True the two extreme classes (0 and 4) are
            prefixed with "very"; when False they collapse to plain
            "negative" / "positive".

    Returns:
        The label string, or None when ``label`` is None.

    Raises:
        KeyError: if ``label`` is not one of 0-4 or None.
    """
    # The separating space belongs to the prefix: otherwise the coarse setting
    # would produce " negative" / " positive" with a stray leading space.
    pre = "very " if fine_grained else ""
    return {0: pre + "negative", 1: "negative", 2: "neutral",
            3: "positive", 4: pre + "positive", None: None}[label]


def indices_to_string(indices, TEXT):
    """Decode a sequence of vocabulary indices into a readable sentence.

    Args:
        indices: iterable of integer-like token ids (plain ints, NumPy
            integer scalars or 0-dim tensors all work).
        TEXT: object exposing ``TEXT.vocab.itos``, an index-to-token lookup.

    Returns:
        The tokens joined by single spaces. ``<pad>`` tokens are dropped, and
        tokens starting with an apostrophe or sentence punctuation are glued
        onto the preceding token ("It" + "'s" -> "It's").
    """
    # Tokens beginning with one of these attach to the previous token.
    attach_to_previous = ("'", ".", "?", "!", ",")

    sentence = []
    for idx in indices:
        token = TEXT.vocab.itos[int(idx)]

        # ignore <pad> symbol
        if token == "<pad>":
            continue

        # no needs of space between the special symbols; when there is no
        # previous token yet (sentence starts with punctuation) the token
        # simply opens the sentence instead of raising IndexError.
        # str.startswith is also safe for an empty token, unlike token[0].
        if sentence and token.startswith(attach_to_previous):
            sentence[-1] += token
        else:
            sentence.append(token)

    return " ".join(sentence)

In [7]:
# Two fixed-width columns (gold label, prediction) followed by the decoded text.
pprint = "{0:15} {1:15} {2}"
print(pprint.format("Label", "Predict", "Text"))
print("="*60)

# NOTE(review): get_label_str assumes class ids run from very negative (0) to
# very positive (4). The captured output pairs clearly positive sentences with
# negative gold labels, so the ids produced by load_sst may follow a different
# order — TODO confirm against the label vocabulary.

# Iterate over the examples actually present in the batch rather than the
# configured batch_size: a batch may hold fewer examples, and indexing past
# its end would raise IndexError.
for i in range(y.size(0)):
    # X is indexed by example along dim 1, so X[:, i] is the i-th token sequence.
    indices = X[:,i].cpu().numpy()
    text_repr = indices_to_string(indices, TEXT)
    
    label_repr = get_label_str(y[i].item())
    pred_repr = get_label_str(pred_y[i].item())
    
    print(pprint.format(label_repr, pred_repr, text_repr))

Label           Predict         Text
neutral         negative        Some actors steal scenes.
very positive   very positive   A very bad sign.
very positive   very positive   A dreadful live-action movie.
negative        negative        But I was n't.
neutral         negative        Largely a <unk> <unk>.
neutral         negative        Both awful and appealing.
neutral         very positive   It's a trifle.
negative        negative        <unk> this is not.
neutral         neutral         So I just did.
positive        neutral         Much <unk> for all.
positive        positive        well worth the time.
neutral         very negative   Boisterous and daft documentary.
very negative   positive        A vivid cinematic portrait.
positive        positive        <unk> and ingenious entertainment.
very negative   very negative   A pleasant romantic comedy.
very negative   neutral         He's Super Spy!
neutral         neutral         Bourne, Jason Bourne.
positive        positive      