In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import torchtext
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import re
from collections import Counter
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 能用gpu则用gpu

# 1.准备数据

In [2]:
class MyDataset(Dataset):
    def __init__(self, file_path, tokenizer, stopwords, debug=True):
        df = pd.read_csv(file_path)
        df = df.dropna().reset_index(drop=True)
        if debug:
            df = df.sample(2000).reset_index(drop=True)
        counter = Counter()
        sentences = []
        for title in tqdm(df['title']):
            # 去除标点符号
            title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
            tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
            counter.update(tokens)
            sentences.append(tokens)
        self.vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>'])
        self.vocab.set_default_index(self.vocab['<unk>'])

        self.inputs = [self.vocab.lookup_indices(tokens) for tokens in sentences]
        self.labels = [[label] for label in df['label'].values.tolist()]
        self.n_class = len(df['label'].unique())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])

In [3]:
file_path = '../data/THUCNews/train.csv'
tokenizer = torchtext.data.utils.get_tokenizer('spacy', language='zh_core_web_sm')
stopwords = [line.strip() for line in open('../stopwords/cn_stopwords.txt', 'r', encoding='utf-8').readlines()]
dataset = MyDataset(file_path, tokenizer, stopwords, debug=False)

  0%|          | 0/501644 [00:00<?, ?it/s]

In [9]:
def collate_fn(batch_data):
    return pad_sequence([x for x, y in batch_data], padding_value=1), torch.tensor([y for x, y in batch_data]).unsqueeze(1)

dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, collate_fn=collate_fn)

# 2.构建模型

In [10]:
class TextRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, n_class, n_hidden):
        super().__init__()
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # rnn层，深度为1
        self.rnn = nn.RNN(embed_size, n_hidden, 1, batch_first=True)
#         self.rnn = nn.LSTM(embed_size, n_hidden, 1, batch_first=True)  # LSTM
#         self.rnn = nn.LSTM(embed_size, n_hidden, 2, batch_first=True)  # 双向LSTM
#         self.rnn = nn.GRU(embed_size, n_hidden, 1, batch_first=True)  # GRU
#         self.rnn = nn.GRU(embed_size, n_hidden, 2, batch_first=True)  # 双向GRU
        
        # 激活函数
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        # 输出头
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, x):  # x: [batch_size * 句子长度]
        x = x.permute(1, 0)
        x = self.embedding(x)  # [batch_size * 句子长度 * embed_size]
        out, _ = self.rnn(x)  # [batch_size * 句子长度 * n_hidden]
        out = self.relu(out)
        out = self.dropout(out)
        logits = self.fc(out[:,-1,:])  # 全连接输出头，[batch_size * n_class]
        return logits

In [11]:
model = TextRNN(vocab_size=len(dataset.vocab), embed_size=256,
                n_class=dataset.n_class, n_hidden=256).to(device)
print(model)

TextRNN(
  (embedding): Embedding(332814, 256)
  (rnn): RNN(256, 256, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=256, out_features=14, bias=True)
)


# 3.训练模型

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)
model.train()
for epoch in range(50):
    for feature, target in dataloader:
        feature = feature.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        logits = model(feature)
        loss = criterion(logits, target.squeeze())
        loss.backward()
        optimizer.step()
    print('epoch:', epoch + 1, ', loss:', loss.item())

epoch: 1 , loss: 1.5687791109085083
epoch: 2 , loss: 0.9600813388824463
epoch: 3 , loss: 0.7435045838356018
epoch: 4 , loss: 0.6274538636207581
epoch: 5 , loss: 0.45718684792518616
epoch: 6 , loss: 0.4046489894390106
epoch: 7 , loss: 0.2518417537212372
epoch: 8 , loss: 0.2638021409511566
epoch: 9 , loss: 0.30624470114707947
epoch: 10 , loss: 0.2156132608652115
epoch: 11 , loss: 0.1962645798921585
epoch: 12 , loss: 0.23841294646263123
epoch: 13 , loss: 0.19347882270812988
epoch: 14 , loss: 0.17331674695014954
epoch: 15 , loss: 0.1403488665819168
epoch: 16 , loss: 0.10971462726593018
epoch: 17 , loss: 0.09726649522781372
epoch: 18 , loss: 0.13067583739757538
epoch: 19 , loss: 0.10973517596721649
epoch: 20 , loss: 0.09472400695085526
epoch: 21 , loss: 0.06777341663837433
epoch: 22 , loss: 0.1027136892080307
epoch: 23 , loss: 0.05104175582528114
epoch: 24 , loss: 0.09576085209846497
epoch: 25 , loss: 0.05780002102255821
epoch: 26 , loss: 0.0844801589846611
epoch: 27 , loss: 0.0906575992703

# 4.预测

In [13]:
model.eval()
df_train = pd.read_csv('../data/THUCNews/train.csv')
df_test = pd.read_csv('../data/THUCNews/test.csv')

In [15]:
for i, row in df_test.sample(10).iterrows():
    title = row['title']
    actual = row['class']
    title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
    tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
    inputs = dataset.vocab.lookup_indices(tokens)
    inputs = torch.LongTensor(inputs).unsqueeze(1).to(device)
    predict = model(inputs)
    predict_class = dict(zip(df_train['label'], df_train['class']))[predict.max(1)[1].item()]    
    print(' '.join(tokens), '||| actual:', actual, ', predict:', predict_class)

沪指站 稳日线 两 市 仅 只股 下跌 ||| actual: 股票 , predict: 股票
组图 不务正业 明星 周慧敏 代言 塑身 中心 ||| actual: 娱乐 , predict: 娱乐
家用 新机 富士 套装 促销 元 ||| actual: 科技 , predict: 科技
减肥 过来 八 种 食物 刮油 组图 ||| actual: 时尚 , predict: 时尚
考核 学术 潜力 考研 面试 导师 爱问 开放式 问题 ||| actual: 教育 , predict: 教育
西甲 最佳 防线 岁 组合 创造 岁 领袖 敬业 精神 超罗 ||| actual: 体育 , predict: 体育
基金节 前 净 申购 亿份 创下 半年 新高 ||| actual: 财经 , predict: 财经
托福 口语 考试 中 四 种 欠佳 表现 建议 ||| actual: 教育 , predict: 教育
观望 氛围 渐 浓股 市场 早盘 震荡 下跌 ||| actual: 股票 , predict: 股票
米兰 第三 位 新援 闪电 加盟 全能 后卫 火速 飞抵 体检 签约 ||| actual: 体育 , predict: 体育
