In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import pandas as pd
import jieba
import re

In [2]:
class MyDataset(Dataset):
    def __init__(self, max_len, debug=True):
        super().__init__()
        df = pd.read_csv('../../datasets/THUCNews/train.csv')
        df = df.dropna()
        if debug:
            df = df.sample(2000).reset_index(drop=True)
        else:
            df = df.sample(50000).reset_index(drop=True)
        # 读取常用停用词
        stopwords = [line.strip() for line in open('../../stopwords/cn_stopwords.txt', 'r', encoding='utf-8').readlines()]
        sentences = []
        for title in df['title']:
            # 去除标点符号
            title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
            # jieba分词
            sentence_seged = jieba.cut(title.strip())
            outstr = ''
            for word in sentence_seged:
                if word != '\t' and word not in stopwords:
                    outstr += word
                    outstr += ' '
            if outstr != '':
                sentences.append(outstr)
        # 获取所有词(token), <pad>用来填充不满足max_len的句子
        token_list = ['<pad>'] + list(set(' '.join(sentences).split()))
        # token和index互转字典
        self.token2idx = {token: i for i, token in enumerate(token_list)}
        self.idx2token = {i: token for i, token in enumerate(token_list)}
        self.vocab_size = len(self.token2idx)

        self.inputs = []
        for sentence in sentences:
            tokens = sentence.split()
            input_ = [self.token2idx[token] for token in tokens]
            if len(input_) < max_len:
                self.inputs.append(input_ + [self.token2idx['<pad>']] * (max_len - len(input_)))
            else:
                self.inputs.append(input_[: max_len])

        self.labels = [[label] for label in df['label'].values.tolist()]
        self.n_class = len(df['label'].unique())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])


class TextRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, n_class, n_hidden):
        super().__init__()
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # rnn层，深度为1
        self.rnn = nn.RNN(embed_size, n_hidden, 1, batch_first=True)
        # 激活函数
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        # 输出头
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, x):  # x: [batch_size * 句子长度]
        x = self.embedding(x)  # [batch_size * 句子长度 * embed_size]
        out, _ = self.rnn(x)  # [batch_size * 句子长度 * n_hidden]
        out = self.relu(out)
        out = self.dropout(out)
        logits = self.fc(out[:,-1,:])  # 全连接输出头，[batch_size * n_class]
        return logits

In [3]:
dataset = MyDataset(max_len=10)  # 构造长度为10的句子输入，超过10的句子切掉，不足10的补<pad>
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
model = TextRNN(vocab_size=dataset.vocab_size, embed_size=128,
                n_class=dataset.n_class, n_hidden=256)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    for feature, target in dataloader:
        optimizer.zero_grad()
        logits = model(feature)
        loss = criterion(logits, target.squeeze())
        loss.backward()
        optimizer.step()
    print('epoch:', epoch + 1, ', loss:', loss.item())

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/d1/4_gsqv2176z583_7rmpm27lh0000gn/T/jieba.cache
Loading model cost 0.412 seconds.
Prefix dict has been built successfully.


TextRNN(
  (embedding): Embedding(8202, 128)
  (rnn): RNN(128, 256, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=256, out_features=14, bias=True)
)
epoch: 1 , loss: 2.2921788692474365
epoch: 2 , loss: 2.694180727005005
epoch: 3 , loss: 1.8544557094573975
epoch: 4 , loss: 1.861046552658081
epoch: 5 , loss: 1.4497593641281128
epoch: 6 , loss: 1.3731626272201538
epoch: 7 , loss: 1.2065224647521973
epoch: 8 , loss: 0.8219239115715027
epoch: 9 , loss: 0.6023156642913818
epoch: 10 , loss: 0.6815493106842041
epoch: 11 , loss: 0.22835366427898407
epoch: 12 , loss: 0.24305564165115356
epoch: 13 , loss: 0.5704612731933594
epoch: 14 , loss: 0.028425700962543488
epoch: 15 , loss: 0.039246343076229095
epoch: 16 , loss: 0.011273359879851341
epoch: 17 , loss: 0.012752080336213112
epoch: 18 , loss: 0.004150383174419403
epoch: 19 , loss: 0.006946096196770668
epoch: 20 , loss: 0.009439648129045963
epoch: 21 , loss: 0.007732153404504061
epoch: 2

In [4]:
df = pd.read_csv('../../datasets/THUCNews/train.csv')
predict = model(feature).max(1)[1].tolist()
for i in range(len(feature.tolist())):
    print(' '.join([dataset.idx2token[idx] for idx in feature.tolist()[i]]),
          '---> label:',
          dict(zip(df['label'], df['class']))[predict[i]])

喜欢 田园 风情 清爽 柔美 图 <pad> <pad> <pad> <pad> ---> label: 家居
山西 完成 罕见 超级 转会 卫冕冠军 宏远 三 集体 加盟 ---> label: 体育
斯达康 通信 技术 公司 出售 旗下 资产 <pad> <pad> <pad> ---> label: 科技
沪 基指 半日 上涨 两市 近九成 封基 飘红 <pad> <pad> ---> label: 财经
常昊评 韩联 崔哲瀚 胜 李昌镐 崔毒 见 灭 之势 谱 ---> label: 体育
马苏 亮相 碧海 雄心 启动 仪式 现场 变 身 追星族 ---> label: 娱乐
超 两成 个股 跌破 增发 价 盐湖 钾肥 四 公司 ---> label: 股票
一毛 没花 巨人 仙途衡 时代 手工 装备 <pad> <pad> <pad> ---> label: 游戏
美国 月 费城 联储 制造业 指数 <pad> <pad> <pad> <pad> ---> label: 股票
中国 休闲 食品 拟下 周四 参加 港交所 聆讯 <pad> <pad> ---> label: 股票
张韶涵 打拼 十年 决战 唱响 之巅 图 <pad> <pad> <pad> ---> label: 娱乐
厂商 敌个 布朗 巴顿 世界冠军 诞生 <pad> <pad> <pad> <pad> ---> label: 体育
通用 食言 不卖 欧宝 德国 生气 <pad> <pad> <pad> <pad> ---> label: 股票
退市 前狂 爆 性价比 尼康 套 机仅 <pad> <pad> <pad> ---> label: 科技
联游 网络 濒临 退市 唐骏 强硬 裁员 指 杀鸡取卵 <pad> ---> label: 科技
靓丽 流线 造型 昂达 仅售元 <pad> <pad> <pad> <pad> <pad> ---> label: 科技
