In [2]:
import torch
import torchtext
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import re
from collections import Counter
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 能用gpu则用gpu

# 1.准备数据

In [3]:
class MyDataset(Dataset):
    def __init__(self, file_path, tokenizer, stopwords, sample=None):
        df = pd.read_csv(file_path)
        df = df.dropna().reset_index(drop=True)
        if sample:
            df = df.sample(sample).reset_index(drop=True)
            
        counter = Counter()
        sentences = []
        for title in tqdm(df['title']):
            # 去除标点符号
            title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
            tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
            counter.update(tokens)
            sentences.append(tokens)
            
        self.vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>'])
        self.vocab.set_default_index(self.vocab['<unk>'])

        self.inputs = [self.vocab.lookup_indices(tokens) for tokens in sentences]
        self.labels = [[label] for label in df['label'].values.tolist()]
        self.n_class = len(df['label'].unique())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])

In [5]:
file_path = '../../datasets/THUCNews/train.csv'
tokenizer = torchtext.data.utils.get_tokenizer('spacy', language='zh_core_web_sm')
stopwords = [line.strip() for line in open('../stopwords/cn_stopwords.txt', 'r', encoding='utf-8').readlines()]
dataset = MyDataset(file_path, tokenizer, stopwords, sample=10000)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [8]:
def collate_fn(batch_data):
    return pad_sequence([x for x, y in batch_data], padding_value=1), torch.tensor([y for x, y in batch_data]).unsqueeze(1)

dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, collate_fn=collate_fn)

# 2.构建模型

In [17]:
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, n_class):
        super().__init__()
        # 嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 输入通道为1，卷成16通道输出，卷积核大小为(3*embed_size)，3类似于n-gram，可以换
        self.conv = nn.Conv2d(1, 16, (3, embed_size))
        self.dropout = nn.Dropout(0.2)
        # 输出头
        self.fc = nn.Linear(16, n_class)

    def forward(self, x):  # x: [batch_size * 句子长度]
        x = x.permute(1, 0)
        x = self.embedding(x)  # [batch_size * 句子长度 * embed_size]
        x = x.unsqueeze(1)  # [batch_size * 1 * 句子长度 * embed_size]，加一个维度，用于卷积层的输入
        x = self.conv(x)  # [batch_size * 16(卷积层输出通道数) * 8(卷积后的宽) * 1(卷积后的高)]
        x = x.squeeze(3)  # [batch_size * 16(卷积层输出通道数) * 8(卷积后的宽)] 压缩大小为1的维度
        x = torch.relu(x)  # 激活函数，尺寸不变
        x = torch.max_pool1d(x, x.size(2))  # 在每个通道做最大池化，[batch_size * 16(卷积层输出通道数) * 1]
        x = x.squeeze(2)  # 压缩维度2，[batch_size * 16(卷积层输出通道数)]
        x = self.dropout(x)  # dropout，尺寸不变
        logits = self.fc(x)  # 全连接输出头，[batch_size * n_class]
        return logits


model = TextCNN(vocab_size=len(dataset.vocab), embed_size=256, n_class=dataset.n_class).to(device)
model

TextCNN(
  (embedding): Embedding(25642, 256)
  (conv): Conv2d(1, 16, kernel_size=(3, 256), stride=(1, 1))
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=16, out_features=14, bias=True)
)

# 3.训练模型

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
    for feature, target in dataloader:
        feature = feature.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        logits = model(feature)
        loss = criterion(logits, target.squeeze())
        loss.backward()
        optimizer.step()
    print('epoch:', epoch + 1, ', loss:', format(loss.item(), '.6f'))

epoch: 1 , loss: 0.897087
epoch: 2 , loss: 0.864144
epoch: 3 , loss: 0.922457
epoch: 4 , loss: 0.805261
epoch: 5 , loss: 0.712432
epoch: 6 , loss: 0.658637
epoch: 7 , loss: 0.646216
epoch: 8 , loss: 0.629971
epoch: 9 , loss: 0.589703
epoch: 10 , loss: 0.528293
epoch: 11 , loss: 0.613047
epoch: 12 , loss: 0.533437
epoch: 13 , loss: 0.518082
epoch: 14 , loss: 0.405103
epoch: 15 , loss: 0.461087
epoch: 16 , loss: 0.462371
epoch: 17 , loss: 0.402825
epoch: 18 , loss: 0.395251
epoch: 19 , loss: 0.369918
epoch: 20 , loss: 0.395495
epoch: 21 , loss: 0.359664
epoch: 22 , loss: 0.390282
epoch: 23 , loss: 0.338754
epoch: 24 , loss: 0.330922
epoch: 25 , loss: 0.325939
epoch: 26 , loss: 0.283240
epoch: 27 , loss: 0.264535
epoch: 28 , loss: 0.267931
epoch: 29 , loss: 0.271422
epoch: 30 , loss: 0.267239
epoch: 31 , loss: 0.242171
epoch: 32 , loss: 0.230579
epoch: 33 , loss: 0.203122
epoch: 34 , loss: 0.248100
epoch: 35 , loss: 0.179979
epoch: 36 , loss: 0.220583
epoch: 37 , loss: 0.163618
epoch: 38 

# 4.预测

In [14]:
model.eval()
df_train = pd.read_csv('../../datasets/THUCNews/train.csv')
df_test = pd.read_csv('../../datasets/THUCNews/test.csv')

In [16]:
for i, row in df_test.sample(10).iterrows():
    title = row['title']
    actual = row['class']
    title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
    tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
    inputs = dataset.vocab.lookup_indices(tokens)
    inputs = torch.LongTensor(inputs).unsqueeze(1).to(device)
    predict = model(inputs)
    predict_class = dict(zip(df_train['label'], df_train['class']))[predict.max(1)[1].item()]    
    print(''.join(tokens), '|| actual:', actual, '|| predict:', predict_class)

诺基亚领衔节热门降价手机盘点 || actual: 科技 || predict: 科技
国足主帅竞聘成下岗工聚会票选超级女生专业 || actual: 体育 || predict: 体育
员工发牢骚称厂垮要求离职组图 || actual: 社会 || predict: 娱乐
高开低走逢低吸纳待涨升 || actual: 股票 || predict: 股票
搞笑诺贝尔奖虫瓶交配憋尿会失策 || actual: 科技 || predict: 时政
冲刺高考辽源名师教抢夺关键分图 || actual: 教育 || predict: 教育
男子持枪劫持人质要挟警方寻找女网友 || actual: 社会 || predict: 时政
组图邦女郎盖玛阿特登泳装秀 || actual: 娱乐 || predict: 房产
中考分数保密高考上线率惹祸 || actual: 教育 || predict: 教育
房地产房价决定政策进度资金决定房价走向 || actual: 股票 || predict: 股票
