In [8]:
import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import pandas as pd
import re
import jieba
import numpy as np
from tqdm.notebook import tqdm
from collections import Counter

# 1.准备数据

In [6]:
# 构造数据集
class MyDataset(Dataset):
    def __init__(self, file_path, tokenizer, stopwords, debug=True):
        df = pd.read_csv(file_path)
        df = df.dropna().reset_index(drop=True)
        if debug:
            df = df.sample(2000).reset_index(drop=True)
        else:
            df = df.sample(50000).reset_index(drop=True)
        # 读取常用停用词
        counter = Counter()
        sentences = []
        for title in tqdm(df['title']):   
            # 去除标点符号
            title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
            tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
            counter.update(tokens)
            sentences.append(tokens)
        self.vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>'])
        
        # 构造输入和输出，跳元模型，用当前字预测前一个字和后一个字
        self.inputs = []
        self.labels = []
        for sen in sentences:
            for i in range(1, len(sen) - 1):
                self.inputs.append([self.vocab[sen[i]]])
                self.labels.append([self.vocab[sen[i - 1]]])
                self.inputs.append([self.vocab[sen[i]]])
                self.labels.append([self.vocab[sen[i + 1]]])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])

In [9]:
file_path = '../data/THUCNews/train.csv'
tokenizer = torchtext.data.utils.get_tokenizer('spacy', language='zh_core_web_sm')
stopwords = [line.strip() for line in open('../stopwords/cn_stopwords.txt', 'r', encoding='utf-8').readlines()]
dataset = MyDataset(file_path, tokenizer, stopwords)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

  0%|          | 0/2000 [00:00<?, ?it/s]

# 2.构建模型

In [10]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # W和WT的形状是转置的
        self.W = nn.Embedding(vocab_size, embed_size)  # vocab_size -> embed_size
        self.WT = nn.Linear(embed_size, vocab_size, bias=False)  # embed_size -> vocab_size

    def forward(self, X):
        # X形状：batch_size * vocab_size
        hidden_layer = self.W(X)
        output_layer = self.WT(hidden_layer)
        return output_layer

In [12]:
# 初始化模型
model = Word2Vec(vocab_size=len(dataset.vocab), embed_size=512)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 查看模型
print(model)

Word2Vec(
  (W): Embedding(8456, 512)
  (WT): Linear(in_features=512, out_features=8456, bias=False)
)


# 3.训练模型

In [13]:
# 训练20个epoch
for epoch in range(20):
    for train_input, train_label in dataloader:
        output = model(train_input)
        loss = criterion(output.squeeze_(), train_label.squeeze_())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch:', epoch + 1, 'loss =', '{:.6f}'.format(loss))

epoch: 1 loss = 8.912559
epoch: 2 loss = 6.400843
epoch: 3 loss = 5.419957
epoch: 4 loss = 4.417579
epoch: 5 loss = 4.015151
epoch: 6 loss = 3.830552
epoch: 7 loss = 3.863740
epoch: 8 loss = 3.856899
epoch: 9 loss = 3.621994
epoch: 10 loss = 3.606398
epoch: 11 loss = 3.689243
epoch: 12 loss = 3.309719
epoch: 13 loss = 3.388777
epoch: 14 loss = 3.631728
epoch: 15 loss = 3.625928
epoch: 16 loss = 3.460016
epoch: 17 loss = 3.599987
epoch: 18 loss = 3.469801
epoch: 19 loss = 3.327806
epoch: 20 loss = 3.369408


# 4.查看模型参数（vector）

In [14]:
W, WT = model.parameters()

In [20]:
# W对应vocab中每个词的vector，这里是512维
print(W.shape, WT.shape)

torch.Size([8456, 512]) torch.Size([8456, 512])


In [19]:
print(W[0])

tensor([ 8.2287e-01,  1.7147e+00,  6.3301e-02, -8.2303e-01,  2.6472e-04,
        -8.6783e-01,  7.7127e-01, -7.2511e-02, -7.2578e-02, -7.4657e-01,
         3.7388e-01,  2.5191e-01, -7.3627e-01, -1.6520e+00, -5.1887e-01,
         1.0842e+00,  7.9234e-01,  2.3700e-01,  5.0631e-01,  1.6158e+00,
         1.3460e+00, -5.6373e-01, -1.3444e+00, -1.4603e-03,  2.8087e-02,
         7.8602e-01, -7.2021e-01,  1.3479e+00, -7.6845e-01,  1.3702e+00,
        -3.0853e-02, -5.1576e-01,  9.2350e-01, -1.4483e+00,  3.8712e-02,
        -2.7035e-01,  1.2906e+00, -3.3333e-01, -1.1882e-02, -6.2058e-01,
        -1.3944e+00, -1.1244e-01, -4.3255e-01, -2.8653e+00,  1.7908e+00,
        -3.3256e-01,  1.4750e+00,  4.1109e-01,  7.2152e-01,  8.5703e-02,
        -1.3381e+00,  1.3475e-01,  6.8010e-01,  5.6846e-01,  3.7432e-01,
        -5.4027e-01,  9.2805e-02, -4.3790e-01, -8.1059e-01,  2.2791e-01,
         7.0500e-01,  6.0567e-01,  5.0068e-01, -3.4578e-02, -4.5616e-01,
         1.3117e+00, -3.9012e-01,  7.3155e-01,  1.9