In [93]:
import sys

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np




class TextCNN(nn.Module):
    """Convolutional sentence classifier with one conv branch per kernel height.

    Token ids are embedded, each convolution slides over the token axis with a
    window spanning the full embedding width, every feature map is max-pooled
    over time, and the pooled features of all branches are concatenated and
    fed through dropout and a linear layer.

    Args:
        param: configuration dict with the keys
            ``vocab_size``  - number of rows in the embedding table,
            ``embed_dim``   - embedding width,
            ``class_num``   - number of output classes,
            ``kernel_num``  - output channels of every convolution,
            ``kernel_size`` - sequence of window heights (any length >= 1),
            ``dropout``     - dropout probability before the linear layer,
            ``padding_idx`` - optional, id of the padding token (default 1).
    """

    def __init__(self, param: dict):
        super().__init__()
        in_channels = 1  # the embedded sentence is treated as a 1-channel image
        kernel_num = param['kernel_num']  # output channel size of each branch
        kernel_size = param['kernel_size']
        vocab_size = param['vocab_size']
        embed_dim = param['embed_dim']
        dropout = param['dropout']
        class_num = param['class_num']
        padding_idx = param.get('padding_idx', 1)
        self.param = param
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # One branch per entry of kernel_size, so the number of branches always
        # agrees with the input width of fc1 below. Branches are registered as
        # conv11, conv12, conv13, ... so the attribute names and state_dict
        # keys of the three-kernel layout are unchanged.
        self._conv_names = []
        for position, height in enumerate(kernel_size, start=1):
            name = 'conv1{}'.format(position)
            self.add_module(name, nn.Conv2d(in_channels, kernel_num, (height, embed_dim)))
            self._conv_names.append(name)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)

    @staticmethod
    def conv_and_pool(x, conv):
        """Apply ``conv`` + ReLU and max-pool over the whole sequence axis.

        Args:
            x: tensor of shape (batch, 1, sentence_length, embed_dim).
            conv: a Conv2d whose kernel width equals embed_dim.

        Returns:
            Tensor of shape (batch, kernel_num).
        """
        # x: (batch, 1, sentence_length, embed_dim)
        x = conv(x)
        # x: (batch, kernel_num, H_out, 1)
        x = F.relu(x.squeeze(3))
        # x: (batch, kernel_num, H_out)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        # x: (batch, kernel_num)
        return x

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer tensor of shape (batch, sentence_length).
                NOTE(review): sentence_length presumably has to be at least
                max(kernel_size) for every convolution to fit - confirm.

        Returns:
            Log-probabilities of shape (batch, class_num) (log_softmax output).
        """
        # x: (batch, sentence_length)
        x = self.embed(x)
        # x: (batch, sentence_length, embed_dim)
        # TODO init embed matrix with pre-trained
        x = x.unsqueeze(1)
        # x: (batch, 1, sentence_length, embed_dim)
        pooled = [self.conv_and_pool(x, getattr(self, name)) for name in self._conv_names]
        x = torch.cat(pooled, 1)  # (batch, len(kernel_size) * kernel_num)
        x = self.dropout(x)
        logit = F.log_softmax(self.fc1(x), dim=1)
        return logit

In [2]:
def read_data(path):
    """Read a space-separated corpus with one labelled sentence per line.

    The first space-separated field of every line is the label, the remaining
    fields are the tokens of the sentence.

    Args:
        path: path of the text file to read.
            NOTE(review): the file is decoded as UTF-8 - confirm this matches
            the corpus encoding.

    Returns:
        A pair ``(sentences, labels)`` where ``sentences`` is a list of token
        lists and ``labels`` is the list of label strings, in file order.
        Blank lines are skipped.
    """
    sentences = []
    labels = []
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip the newline before splitting so that a line holding only a
            # label does not keep the trailing "\n" as part of the label.
            line = line.strip("\n")
            if not line.strip():
                continue
            label, *tokens = line.split(" ")
            labels.append(label)
            sentences.append(tokens)
    return sentences, labels

In [3]:
# Directory containing the train/test corpus files. Kept in one place so the
# notebook can be pointed at another checkout by editing a single line.
# NOTE(review): still an absolute, machine-specific path - consider making it
# relative to the repository.
DATA_DIR = '/Users/zhouzhirui/project/Task-Oriented-Chatbot/corpus/intent/fastText'

train_x, train_y = read_data(DATA_DIR + '/demo.train.txt')
test_x, test_y = read_data(DATA_DIR + '/demo.test.txt')

In [4]:
# Vocabulary: ids 0 and 1 are reserved for the unknown and padding tokens,
# every other training word gets the next free id in order of first appearance.
word2id = {"UNK": 0, "PAD": 1}
idx = len(word2id)
for sentence in train_x:
    for word in sentence:
        if word not in word2id:
            word2id[word] = idx
            idx += 1

# Label map: ids are assigned from 0 in order of first appearance in train_y.
label2id = {}
label_idx = 0
for label in train_y:
    if label in label2id:
        continue
    label2id[label] = label_idx
    label_idx += 1

In [5]:
def convert_text(sentences, d, max_length=10):
    """Encode tokenised sentences as a fixed-width matrix of token ids.

    Each sentence is truncated to its first ``max_length`` tokens and padded on
    the left with the id of ``"PAD"``. Tokens missing from ``d`` (including
    ``"PAD"`` itself, if absent) are encoded as 0.

    Args:
        sentences: iterable of token lists.
        d: mapping from token to integer id.
        max_length: number of columns of the result.

    Returns:
        An int64 array of shape ``(len(sentences), max_length)``; the shape and
        dtype are the same for empty input.
    """
    pad_id = d.get("PAD", 0)
    rows = []
    for tokens in sentences:
        ids = [d.get(w, 0) for w in tokens[:max_length]]
        rows.append([pad_id] * (max_length - len(ids)) + ids)
    # Explicit dtype/shape: np.array([]) would otherwise be float64 of shape
    # (0,), and the default integer width differs between platforms.
    return np.array(rows, dtype=np.int64).reshape(len(rows), max_length)

def conver_label(labels, d):
    """Translate labels into their ids and return them as a NumPy array.

    Args:
        labels: iterable of labels.
        d: mapping from label to id; a label missing from it raises KeyError.

    Returns:
        1-D array with one id per input label, in input order.
    """
    ids = []
    for name in labels:
        ids.append(d[name])
    return np.array(ids)

In [6]:
# Replace the raw token lists / label strings with their integer encodings.
# The names are rebound in place, so this cell must run exactly once after the
# data has been loaded.
train_x, train_y = convert_text(train_x, word2id), conver_label(train_y, label2id)
test_x, test_y = convert_text(test_x, word2id), conver_label(test_y, label2id)

In [7]:
def get_batch(x, y, batch=100, drop_last=False):
    """Yield shuffled mini-batches of ``(x, y)``.

    A fresh random order is drawn (via ``np.random.shuffle``) each time the
    generator is started, and ``x`` / ``y`` are sliced with the same indices so
    that rows stay paired.

    Args:
        x: array indexed along axis 0.
        y: array with the same length along axis 0 as ``x``.
        batch: maximum number of rows per batch; must be positive.
        drop_last: when True, a trailing batch smaller than ``batch`` is
            discarded; by default it is yielded so no sample is skipped and
            datasets smaller than ``batch`` still produce one batch.

    Yields:
        Pairs ``(x_batch, y_batch)``.

    Raises:
        ValueError: if ``batch`` is not positive.
    """
    assert x.shape[0] == y.shape[0]
    if batch <= 0:
        raise ValueError("batch must be a positive integer")
    size = x.shape[0]
    order = np.arange(size)
    np.random.shuffle(order)
    for start in range(0, size, batch):
        picked = order[start:start + batch]
        if drop_last and len(picked) < batch:
            break
        # Fancy indexing already returns a copy, one batch at a time.
        yield x[picked], y[picked]

In [8]:
# Convert the held-out arrays to tensors once so that they can be passed to
# the model directly.
test_x, test_y = map(torch.tensor, (test_x, test_y))

In [138]:
def train(model, lr, epochs):
    """Train ``model`` with Adam on the notebook-level training arrays.

    Reads the globals ``train_x`` / ``train_y`` (batched through ``get_batch``)
    and ``test_x`` / ``test_y``. Every 100 steps the loss on the whole test set
    is computed and printed next to the current training loss. The model is
    left in training mode when the function returns.

    Args:
        model: module mapping a batch of token ids to per-class scores.
        lr: learning rate for a freshly created Adam optimizer.
        epochs: number of passes over the training data.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for step, (x, y) in enumerate(get_batch(train_x, train_y)):
            x = torch.tensor(x)
            y = torch.tensor(y)
            optimizer.zero_grad()
            logit = model(x)
            # cross_entropy applies log_softmax itself; log_softmax is
            # idempotent, so this gives the same value as nll_loss when the
            # model already returns log-probabilities.
            loss = F.cross_entropy(logit, y)
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                model.eval()
                # No gradients are needed for monitoring; without no_grad an
                # autograd graph over the full test set is built every time.
                with torch.no_grad():
                    eval_logit = model(test_x)
                    eval_loss = F.cross_entropy(eval_logit, test_y)
                model.train()
                print("epoch: {:>2}, step: {:>4} ,train loss: {:.6f}, eval loss: {:.6f}".format(
                    epoch, step, loss.item(), eval_loss.item()))

In [139]:
# Seed both RNGs used by this notebook (torch for weight initialisation and
# dropout, numpy for batch shuffling) so that a fresh run repeats.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Hyper-parameters of the classifier; vocabulary and class counts come from
# the mappings built from the training data.
textCNN_param = {
    "vocab_size": len(word2id),
    "embed_dim": 40,
    "class_num": len(label2id),
    "kernel_num": 16,
    "kernel_size": [3, 4, 5],
    "dropout": 0.5,
}

model = TextCNN(textCNN_param)

In [140]:
# Staged schedule: each (learning rate, epochs) pair is one call to train().
# Lower-rate stages currently disabled: (0.0001, 4), (0.00001, 2).
for stage_lr, stage_epochs in ((0.01, 4), (0.001, 6)):
    train(model, stage_lr, stage_epochs)

epoch:  0, step:    0 ,train loss: 4.477167, eval loss: 4.021335
epoch:  0, step:  100 ,train loss: 1.443769, eval loss: 0.983097
epoch:  0, step:  200 ,train loss: 0.954692, eval loss: 0.652551
epoch:  0, step:  300 ,train loss: 0.638155, eval loss: 0.524966
epoch:  1, step:    0 ,train loss: 0.514146, eval loss: 0.460144
epoch:  1, step:  100 ,train loss: 0.721711, eval loss: 0.420159
epoch:  1, step:  200 ,train loss: 0.815014, eval loss: 0.392086
epoch:  1, step:  300 ,train loss: 0.485019, eval loss: 0.383653
epoch:  2, step:    0 ,train loss: 0.676553, eval loss: 0.354957
epoch:  2, step:  100 ,train loss: 0.292083, eval loss: 0.349564
epoch:  2, step:  200 ,train loss: 0.656120, eval loss: 0.352064
epoch:  2, step:  300 ,train loss: 0.283528, eval loss: 0.335862
epoch:  3, step:    0 ,train loss: 0.396631, eval loss: 0.334483
epoch:  3, step:  100 ,train loss: 0.588300, eval loss: 0.329411
epoch:  3, step:  200 ,train loss: 0.537329, eval loss: 0.336878
epoch:  3, step:  300 ,tr

In [141]:
# Accuracy on the held-out set: fraction of rows whose arg-max class equals
# the reference label.
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    y_pred = model(test_x)
correct = (torch.argmax(y_pred, 1) == test_y).sum().item()
print(correct / len(test_y))

0.9419343131917982
