In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import sys

# Make the project package (``textgen``) importable from the parent directory.
# The path is resolved to an absolute path so the entry stays valid after a
# later ``os.chdir``. It is added only if absent, so re-running this cell does
# not keep growing ``sys.path``.
PROJECT_ROOT = os.path.abspath(os.pardir)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
import pickle
from functools import partial
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook as tqdm

In [4]:
import textgen

## Load vocabularies and model

In [5]:
# Both vocabularies are stored together in one pickle: posts first, then comments.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
vocab_path = textgen.get_data_path("vocab.pkl")
with open(vocab_path, mode="rb") as vocab_file:
    posts_vocab, comments_vocab = pickle.load(vocab_file)

# Display both vocabulary summaries.
posts_vocab, comments_vocab

(<Vocabulary: 27074 items, emb dim: (27074, 300)>, <Vocabulary: 1468 items>)

In [6]:
# Everything runs on the CPU: the model is constructed with the CPU device, and
# the checkpoint's tensors are remapped onto it when loading.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
device = torch.device("cpu")
model_path = textgen.get_data_path("toxic-model-max-30-200.pt")
model = textgen.ToxicCommentModel(posts_vocab, comments_vocab, device)
model.load_state_dict(torch.load(model_path, map_location=device))
# eval() switches the model out of training mode and returns it, so its repr is displayed.
model.eval()

ToxicCommentModel(
  (embed): Embedding(27074, 300, padding_idx=27072)
  (tgt_embed): Embedding(1468, 200, padding_idx=1466)
  (avg): AdaptiveMaxPool2d(output_size=(1, 150))
  (fc1): Linear(in_features=150, out_features=30, bias=True)
  (rnn): GRU(200, 30, batch_first=True)
  (fc2): Linear(in_features=30, out_features=1468, bias=True)
)

## Model predictions on a sample batch

In [7]:
# Resolve the dataset file via textgen.get_data_path, load it, and display its
# summary (the repr reports the sample count).
toxic_data_path = textgen.get_data_path(
    "toxic_data.pkl"
)
toxic_data = textgen.ToxicDataset.load(toxic_data_path)
toxic_data

<ToxicDataset: 240 samples>

In [8]:
# Bind textgen.collate_fn to each vocabulary's "<PAD>" id: src_pad comes from the
# posts vocabulary and tgt_pad from the comments vocabulary. These are presumably
# used to pad source/target sequences within a batch (confirm in textgen).
collate_fn = partial(
    textgen.collate_fn,
    src_pad=posts_vocab.stoi["<PAD>"],
    tgt_pad=comments_vocab.stoi["<PAD>"],
)
# Take a single small batch to probe the model with. DataLoader does not shuffle
# by default, so this is always the first two samples of toxic_data.
xx, yy = next(iter(DataLoader(toxic_data, batch_size=2, collate_fn=collate_fn)))

In [9]:
# Hand-written probe input: three sentences of Traditional Chinese text.
intext = """
郭樹清在發言稿中強調，貿易戰不能解決任何問題，損人不利己且危害全世界。
從中國來看，美國固然可以把關稅加到極限水平，但是這對中國經濟的影響將非常有限。
國際上有觀點認為，中國經濟的快速發展是實行「國家壟斷資本主義」的結果，這種說法毫無根據。"""
# Encode the text with the posts vocabulary and stack two identical copies into a
# batch of 2.
# NOTE(review): presumably this mirrors the batch_size=2 probe batch; confirm.
text_vec = np.vstack([textgen.convert_text(intext, posts_vocab)] * 2)
# Use torch.tensor with an explicit dtype instead of the legacy typed constructor
# torch.LongTensor(...). torch.tensor always copies, so in_tensor never aliases
# text_vec.
in_tensor = torch.tensor(text_vec, dtype=torch.long)
# NOTE(review): the prediction cells that follow feed `xx` from the dataset, not
# in_tensor. Confirm which input was intended.

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\seantyh\AppData\Local\Temp\jieba.cache
Loading model cost 0.645 seconds.
Prefix dict has been built succesfully.


In [10]:
# Display the model's target start id.
# NOTE(review): presumably the <SOS> index in comments_vocab; confirm.
model.tgt_soi

1464

In [24]:
# Forward pass with the reference targets `yy` supplied (presumably teacher forcing).
# This is inference only, so run under no_grad. model.eval() alone does not
# disable autograd, and without no_grad every call builds a graph.
with torch.no_grad():
    out = model(xx, yy)
# NOTE(review): the `tensor(...)` line printed when this cell runs is not produced
# by this cell. It looks like a leftover debug print inside textgen; consider removing it.
# For the first sample, take the highest-scoring token id at each position and
# map the ids back to comment tokens.
textgen.inverse_convert_comment(out[0].argmax(-1).tolist(), comments_vocab)

tensor(1464)


['<SOS>',
 '老蔡脯，沒多久',
 '可以',
 '不要再',
 '干預股市了嗎',
 '<EOS>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>']

In [22]:
# Forward pass without targets (presumably free-running generation). Inference
# only, so skip building the autograd graph.
with torch.no_grad():
    out = model(xx)
# NOTE(review): the recorded result of this cell is twenty '<SOS>' tokens, i.e.
# the target-free path never produces anything but the start token. This suggests
# the decoder is not fed its own predictions (or keeps re-feeding tgt_soi).
# TODO: investigate ToxicCommentModel.forward.
# The tensor printed above the result also comes from inside the model, not this cell.
textgen.inverse_convert_comment(out[0].argmax(-1).tolist(), comments_vocab)

tensor([-1.0000, -1.0000, -1.0000,  ..., -0.9972, -1.0000, -1.0000],
       grad_fn=<SelectBackward>)


['<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>',
 '<SOS>']