In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from collections import Counter
import os
from argparse import Namespace
import logging

logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

LOG_FILE = "lstm-3.log"
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Re-running this cell in a live kernel would otherwise attach another
# FileHandler for the same file, and every message would be logged twice.
if not any(isinstance(h, logging.FileHandler)
           and h.baseFilename == os.path.abspath(LOG_FILE)
           for h in logger.handlers):
    handler = logging.FileHandler(LOG_FILE)
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# All training / inference hyper-parameters in one place.
flags = Namespace(
    train_file_path='dict_no_stop_jieba',  # directory of whitespace-tokenised training files
    seq_size=32,                           # tokens per training window
    batch_size=64,                         # rows the token stream is reshaped into
    embedding_size=128,                    # word-embedding dimensionality
    lstm_size=128,                         # LSTM hidden/cell size
    gradients_norm=5,                      # max gradient norm passed to clip_grad_norm_
    predict_top_k=5,                       # intended top-k for [MASK] prediction
    num_epochs=20,
    checkpoint_path='checkpoint',          # directory for model-*.pth files
)
logger.info(str(flags))

Some references on LSTMs:
* Christopher Olah (colah), Understanding LSTM Networks, <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>
* PyTorch tutorial, Sequence Models and Long Short-Term Memory Networks, <https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html>
* Text Generation With Pytorch, <https://machinetalk.org/2019/02/08/text-generation-with-pytorch/>
* Language Modelling and Text Generation using LSTMs — Deep Learning for NLP, <https://medium.com/@shivambansal36/language-modelling-text-generation-using-lstms-deep-learning-for-nlp-ed36b224b275>

In [2]:
def get_data_from_file(file_path, batch_size, seq_size):
    """Build the vocabulary and aligned input/target id matrices from a corpus directory.

    Every regular file directly inside ``file_path`` is read as UTF-8 and split
    on whitespace. The tokens are concatenated in ``os.listdir`` order.

    NOTE: ids of equally frequent words depend on that listing order, which the
    OS does not guarantee. A checkpoint is only valid with the same mapping, so
    the corpus directory must not change between training and inference.

    Args:
        file_path: directory containing the (pre-segmented) training text files.
        batch_size: number of rows of the returned matrices.
        seq_size: length of one training window.

    Returns:
        ``(int_to_vocab, vocab_to_int, n_vocab, in_text, out_text)``.
        ``in_text`` and ``out_text`` are integer arrays of shape
        ``(batch_size, num_batches * seq_size)``. Before the reshape,
        ``out_text`` is ``in_text`` shifted left by one token, with the very
        last target wrapping around to the first input token.

    Raises:
        ValueError: if the corpus has fewer than ``batch_size * seq_size``
            tokens, so not even one full batch can be formed.
    """
    text = []
    for file_name in os.listdir(file_path):
        full_path = os.path.join(file_path, file_name)
        if not os.path.isfile(full_path):
            continue  # skip sub-directories such as .ipynb_checkpoints
        with open(full_path, "r", encoding="utf-8") as infile:
            for line in infile:
                text += line.split()

    word_counts = Counter(text)
    # Most frequent word gets id 0; ties keep first-seen order (sorted is stable).
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = dict(enumerate(sorted_vocab))
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    print("Int_to_vocab", list(int_to_vocab.items())[:10])
    print('Vocabulary size', n_vocab)

    int_text = [vocab_to_int[w] for w in text]
    num_batches = len(int_text) // (seq_size * batch_size)
    if num_batches == 0:
        raise ValueError(
            "corpus in {!r} has {} tokens, fewer than one batch of "
            "batch_size * seq_size = {}".format(
                file_path, len(int_text), batch_size * seq_size))

    # Drop the tail that does not fill a whole (batch_size x seq_size) batch.
    in_text = int_text[:num_batches * batch_size * seq_size]
    # Next-token targets; the final target wraps to the first token.
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    print(in_text[:10, :10])
    print(out_text[:10, :10])
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text

In [3]:
def get_batches(in_text, out_text, batch_size, seq_size):
    """Yield successive ``(input, target)`` column windows of width ``seq_size``.

    ``in_text`` and ``out_text`` are the 2-D id matrices from get_data_from_file.
    The number of windows is ``in_text.size // (seq_size * batch_size)``, so
    (assuming ``in_text`` has ``batch_size`` rows) a trailing partial window is
    never produced.
    """
    total_tokens = np.prod(in_text.shape)
    n_windows = total_tokens // (seq_size * batch_size)
    for k in range(n_windows):
        start = k * seq_size
        stop = start + seq_size
        yield in_text[:, start:stop], out_text[:, start:stop]

In [4]:
class RNNModule(nn.Module):
    """Word-level LSTM language model: embedding -> single-layer LSTM -> vocab logits.

    Args:
        n_vocab: vocabulary size (embedding rows and output-layer width).
        seq_size: stored as ``self.seq_size``; not used by ``forward``.
        embedding_size: dimensionality of the word embeddings.
        lstm_size: hidden/cell state size of the LSTM.
    """

    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super().__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        # Sub-modules are created in this order so parameter initialisation
        # consumes the RNG identically and state_dict keys stay the same.
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run a window of token ids through the network.

        Args:
            x: integer token ids, ``(batch, seq)`` since the LSTM is batch_first.
            prev_state: ``(h, c)`` tuple, e.g. from ``zero_state`` or a previous call.

        Returns:
            ``(logits, state)``: logits of shape ``(batch, seq, n_vocab)`` and the new ``(h, c)``.
        """
        hidden_seq, new_state = self.lstm(self.embedding(x), prev_state)
        return self.dense(hidden_seq), new_state

    def zero_state(self, batch_size):
        """Return an all-zero ``(h, c)`` pair, each ``(1, batch_size, lstm_size)``, on the CPU."""
        shape = (1, batch_size, self.lstm_size)
        return torch.zeros(*shape), torch.zeros(*shape)

In [5]:
def get_loss_and_train_op(net, lr=0.001):
    """Create the training criterion and optimiser for ``net``.

    Returns:
        ``(criterion, optimizer)``: a cross-entropy loss and an Adam optimiser
        over all of ``net``'s parameters with learning rate ``lr``.
    """
    return nn.CrossEntropyLoss(), torch.optim.Adam(net.parameters(), lr=lr)

In [6]:
import time

def main():
    """Train an RNNModule on the corpus described by the global ``flags`` and return it.

    - Every 100 iterations the current loss is printed and logged.
    - Every 1000 iterations ``model-<iteration>.pth`` is written to
      ``flags.checkpoint_path``, which is created if missing.
    - The wall-clock training time is printed at the end.

    Returns:
        The trained RNNModule, still on the training device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, _, n_vocab, in_text, out_text = get_data_from_file(
        flags.train_file_path, flags.batch_size, flags.seq_size)

    net = RNNModule(n_vocab, flags.seq_size,
                    flags.embedding_size, flags.lstm_size)
    net = net.to(device)

    # NOTE: 0.01 deliberately overrides get_loss_and_train_op's default of 0.001.
    criterion, optimizer = get_loss_and_train_op(net, 0.01)

    # torch.save does not create directories; without this the first periodic
    # checkpoint (iteration 1000) would crash an otherwise healthy run.
    os.makedirs(flags.checkpoint_path, exist_ok=True)

    net.train()  # nothing below switches to eval mode, so set it once
    iteration = 0
    start_time = time.time()
    for e in range(flags.num_epochs):
        batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
        # Zero state at the start of each epoch; within an epoch the state is
        # carried from one window to the next (detached below).
        state_h, state_c = net.zero_state(flags.batch_size)
        state_h = state_h.to(device)
        state_c = state_c.to(device)
        for x, y in batches:
            iteration += 1
            optimizer.zero_grad()

            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)

            logits, (state_h, state_c) = net(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension second: (batch, n_vocab, seq).
            loss = criterion(logits.transpose(1, 2), y)
            loss_value = loss.item()
            loss.backward()

            # Cut the graph so the next window does not backpropagate into this one.
            state_h = state_h.detach()
            state_c = state_c.detach()

            torch.nn.utils.clip_grad_norm_(net.parameters(), flags.gradients_norm)
            optimizer.step()

            if iteration % 100 == 0:
                print('Epoch: {}/{}'.format(e+1, flags.num_epochs),
                      'Iteration: {}'.format(iteration),
                      'Loss: {}'.format(loss_value))
                logger.info('Epoch: {}/{} Iteration: {} Loss: {}'.format(e+1, flags.num_epochs, iteration, loss_value))

            if iteration % 1000 == 0:
                torch.save(net.state_dict(),
                           '{}/model-{}.pth'.format(flags.checkpoint_path, iteration))
    print("Time:{}s".format(time.time()-start_time))
    return net

In [21]:
import jieba

# Stop-word list, one word per line. rstrip instead of [:-1] so a final line
# without a trailing newline keeps its last character; `with` closes the file.
with open("stopwords.txt", "r", encoding="utf-8") as stopword_file:
    stopwords = [line.rstrip("\n") for line in stopword_file]

def predict(device, net, question, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    """Return the ``top_k`` most likely words for the ``[MASK]`` slot of ``question``.

    Only the text before ``[MASK]`` is used, because the model reads
    left-to-right. That text is:

    1. segmented with jieba and prefixed with ``<BOS>``;
    2. filtered against the global ``stopwords`` list and ``"\\n"``;
    3. fed token by token through ``net``.

    The scores after the last token rank the candidates.

    Args:
        device: torch device the model lives on.
        net: a trained RNNModule.
        question: question text containing the literal marker ``[MASK]``.
        n_vocab: unused; kept for call-site compatibility.
        vocab_to_int: word -> id map. Out-of-vocabulary words are fed as the ``<BOS>`` id.
        int_to_vocab: id -> word map.
        top_k: number of candidates to return.

    Returns:
        List of ``top_k`` words, most likely first.

    Raises:
        ValueError: if ``question`` does not contain ``[MASK]``.
    """
    net.eval()
    # Read the ``question`` argument, not a global, so predict works on any input.
    q_index = question.index("[MASK]")
    seg_pre = jieba.lcut(question[:q_index], cut_all=False)
    seg_pre.insert(0, "<BOS>")
    words = [w for w in seg_pre if w not in stopwords and w != "\n"]
    if not words:
        # Guarantee at least one forward step so ``output`` is always defined.
        words = ["<BOS>"]

    bos_index = vocab_to_int["<BOS>"]
    with torch.no_grad():  # inference only; no autograd graph needed
        state_h, state_c = net.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)
        for w in words:
            ix = torch.tensor([[vocab_to_int.get(w, bos_index)]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))

    # output is (1, 1, n_vocab): scores for the word following the last token.
    _, top_ix = torch.topk(output[0, -1], k=top_k)
    return [int_to_vocab[x] for x in top_ix.tolist()]

In [8]:
if __name__ == "__main__":
    net = main()
    # torch.save does not create directories, so make sure the target exists.
    os.makedirs(flags.checkpoint_path, exist_ok=True)
    torch.save(net.state_dict(), '{}/model-{}.pth'.format(flags.checkpoint_path, "final"))

Int_to_vocab [(0, '<BOS>'), (1, '<EOS>'), (2, '月'), (3, '公司'), (4, '年'), (5, '日'), (6, '中'), (7, '中国'), (8, '5G'), (9, '新')]
Vocabulary size 48469
[[    0     1     0    19     2     5    31   785   777    23]
 [   62   155    79  1300  1233 18149  4255 18150   923   110]
 [ 3381     1     0  2267 13263    64    39    59  4872  2303]
 [ 8972    77    73    34     6  6492  1662   366  1662  1056]
 [ 5838  6042     1     0   941  6042 15493  1646  3927     4]
 [    5  1191  1259   731   197     4     2  1191 15617 13578]
 [  132    29   638   250  1244   489   530   532   440     1]
 [    6   138 33373     1     0     1     0 15890   258  2215]
 [ 2658  9604     1     0     1     0    99  2054  5683  4500]
 [ 3081   137 24479 16072  1573   137 24480  2328 24481 10018]]
[[    1     0    19     2     5    31   785   777    23    42]
 [  155    79  1300  1233 18149  4255 18150   923   110   148]
 [    1     0  2267 13263    64    39    59  4872  2303  2455]
 [   77    73    34     6  6492  

In [19]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Rebuild the vocabulary as during training. The checkpoint's embedding and
# output layers are indexed by these ids, so the corpus directory must be
# unchanged since training (tie order may depend on the directory listing).
int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
    flags.train_file_path, flags.batch_size, flags.seq_size)
net = RNNModule(n_vocab, flags.seq_size, flags.embedding_size, flags.lstm_size)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
checkpoint_file = '{}/model-{}.pth'.format(flags.checkpoint_path, "final")
net.load_state_dict(torch.load(checkpoint_file, map_location=device))
net.to(device)

Int_to_vocab [(0, '<BOS>'), (1, '<EOS>'), (2, '月'), (3, '公司'), (4, '年'), (5, '日'), (6, '中'), (7, '中国'), (8, '5G'), (9, '新')]
Vocabulary size 48469
[[    0     1     0    19     2     5    31   785   777    23]
 [   62   155    79  1300  1233 18149  4255 18150   923   110]
 [ 3381     1     0  2267 13263    64    39    59  4872  2303]
 [ 8972    77    73    34     6  6492  1662   366  1662  1056]
 [ 5838  6042     1     0   941  6042 15493  1646  3927     4]
 [    5  1191  1259   731   197     4     2  1191 15617 13578]
 [  132    29   638   250  1244   489   530   532   440     1]
 [    6   138 33373     1     0     1     0 15890   258  2215]
 [ 2658  9604     1     0     1     0    99  2054  5683  4500]
 [ 3081   137 24479 16072  1573   137 24480  2328 24481 10018]]
[[    1     0    19     2     5    31   785   777    23    42]
 [  155    79  1300  1233 18149  4255 18150   923   110   148]
 [    1     0  2267 13263    64    39    59  4872  2303  2455]
 [   77    73    34     6  6492  

RNNModule(
  (embedding): Embedding(48469, 128)
  (lstm): LSTM(128, 128, batch_first=True)
  (dense): Linear(in_features=128, out_features=48469, bias=True)
)

In [22]:
# Evaluate on questions.txt. A question counts as correct when the ground-truth
# word is among the top-k candidates returned by predict (top-k accuracy).
# The top-1 candidate of every question is written to myanswer.txt.
with open("answer.txt", "r", encoding="utf-8") as answer_file:
    # rstrip (not [:-1]) so a final line without a newline is kept intact.
    groundtrue = [line.rstrip("\n") for line in answer_file]
acc = 0
total = 0

print("Use LSTM model to predict")

# `with` guarantees myanswer.txt is flushed and closed even if a prediction fails.
with open("questions.txt", "r", encoding="utf-8") as question_file, \
        open("myanswer.txt", "w", encoding="utf-8") as myanswer:
    for i, question_str in enumerate(question_file, 1):
        pred = predict(device, net, question_str, n_vocab, vocab_to_int, int_to_vocab,
                       top_k=flags.predict_top_k)
        total += 1
        hit = groundtrue[i-1] in pred
        if hit:
            acc += 1
        mark = "√" if hit else ""
        print("{}{} [MASK] = {} - {}".format(i, mark, pred, groundtrue[i-1]), flush=True)
        myanswer.write("{}\n".format(pred[0]))

# acc is a hit count; convert it to a percentage of the questions evaluated.
print("Accuracy: {:.2f}%".format(100.0 * acc / total if total else 0.0))

Use LSTM model to predict
1 [MASK] = ['色情', '不想', '<EOS>', '发生', '很难'] - 座椅
2 [MASK] = ['推测', '现状', '企业', '头部', '市'] - 汽车
3 [MASK] = ['B', '采访', '量产', 'APP', '每经'] - 欧洲
4 [MASK] = ['设计', '音效', '低价', '元素', '体系'] - 玻璃
5 [MASK] = ['就活', '增加', '<EOS>', '支付', '显得'] - 利润
6√ [MASK] = ['绑定', '故宫', '消费', '相关', '查看'] - 故宫
7 [MASK] = ['新能源', '固定', '销售', '客户', '品牌'] - 颁发
8 [MASK] = ['<EOS>', '新', '专注', '希望', '达成'] - 广阔
9 [MASK] = ['<EOS>', '这是', '计划', '推出', '包括'] - 对话
10 [MASK] = ['全球', '时间', '产品', '操作系统', '5G'] - 手机
11 [MASK] = ['未来', '<EOS>', '中', '最终', '推进'] - 神经网络
12 [MASK] = ['中', '技术', '公司', '实验室', '未成年人'] - 机器人
13 [MASK] = ['预期', '影响力', '深远', '优势', '短期'] - 应用
14√ [MASK] = ['增加', '增长', '提升', '提高', '经营'] - 增长
15 [MASK] = ['操作', '兆瓦', '运行', '算法', '识别'] - 学习
16 [MASK] = ['技术', '创新', 'AI', '支持', '带来'] - 门店
17 [MASK] = ['<EOS>', '密密麻麻', '进一步', '提供', '更'] - 干燥
18 [MASK] = ['棒棒糖', '身份验证', '电信公司', '情报机构', '努力'] - 政府
19√ [MASK] = ['厂商', '业务', '行业', '产品', '工作'] - 厂商
20 [MASK] = ['多达', '减少', '这项', '波音'