In [1]:
import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import re
import jieba
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

# 1.准备数据

In [2]:
# 构造数据集
class MyDataset(Dataset):
    def __init__(self, file_path, stopwords, sample=None):
        df = pd.read_csv(file_path).dropna().reset_index(drop=True)
        if sample:
            df = df.sample(sample).reset_index(drop=True)

        counter = Counter()
        sentences = []
        
        for title in tqdm(df['title']):   
            # 去除标点符号
            title = re.sub(r'[^\u4e00-\u9fa5]', '', title)
            tokens = [token for token in jieba.cut(title.strip()) if token not in stopwords]
            counter.update(tokens)
            sentences.append(tokens)
        self.vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>'])
        
        # 构造输入和输出，跳元模型，用当前字预测前一个字和后一个字
        self.inputs = []
        self.labels = []
        for sen in sentences:
            for i in range(1, len(sen) - 1):
                self.inputs.append([self.vocab[sen[i]]])
                self.labels.append([self.vocab[sen[i - 1]]])
                self.inputs.append([self.vocab[sen[i]]])
                self.labels.append([self.vocab[sen[i + 1]]])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])

In [6]:
file_path = '../../datasets/THUCNews/train.csv'
stopwords = [line.strip() for line in open('../stopwords/cn_stopwords.txt', 'r', encoding='utf-8').readlines()]

dataset = MyDataset(file_path, stopwords, sample=1000)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

  0%|          | 0/1000 [00:00<?, ?it/s]

# 2.构建模型

In [9]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # W和WT的形状是转置的
        self.W = nn.Embedding(vocab_size, embed_size)  # vocab_size -> embed_size
        self.WT = nn.Linear(embed_size, vocab_size, bias=False)  # embed_size -> vocab_size

    def forward(self, X):
        # X形状：batch_size * vocab_size
        hidden_layer = self.W(X)
        output_layer = self.WT(hidden_layer)
        return output_layer

In [10]:
# 初始化模型
model = Word2Vec(vocab_size=len(dataset.vocab), embed_size=512)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 查看模型
print(model)

Word2Vec(
  (W): Embedding(5000, 512)
  (WT): Linear(in_features=512, out_features=5000, bias=False)
)


# 3.训练模型

In [11]:
# 训练20个epoch
for epoch in range(20):
    for train_input, train_label in dataloader:
        output = model(train_input)
        loss = criterion(output.squeeze_(), train_label.squeeze_())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch:', epoch + 1, 'loss =', '{:.6f}'.format(loss))

epoch: 1 loss = 8.514538
epoch: 2 loss = 6.321123
epoch: 3 loss = 4.709826
epoch: 4 loss = 3.780671
epoch: 5 loss = 3.456005
epoch: 6 loss = 3.442016
epoch: 7 loss = 3.048970
epoch: 8 loss = 3.067402
epoch: 9 loss = 3.020471
epoch: 10 loss = 2.988716
epoch: 11 loss = 2.834623
epoch: 12 loss = 2.975353
epoch: 13 loss = 2.804723
epoch: 14 loss = 2.660689
epoch: 15 loss = 2.793330
epoch: 16 loss = 3.006624
epoch: 17 loss = 2.877311
epoch: 18 loss = 2.860091
epoch: 19 loss = 2.853536
epoch: 20 loss = 2.676005


# 4.查看模型参数（vector）

In [12]:
W, WT = model.parameters()

In [13]:
# W对应vocab中每个词的vector，这里是512维
print(W.shape, WT.shape)

torch.Size([5000, 512]) torch.Size([5000, 512])


In [14]:
print(W[0])

tensor([ 9.7986e-02,  8.0324e-01, -2.6938e-01,  1.2315e-01,  6.8222e-01,
        -1.6469e+00, -1.0018e+00, -1.1281e-01, -1.1169e+00,  5.1680e-01,
         2.8906e-02, -1.0431e+00,  8.9166e-01, -1.1216e+00,  1.3011e+00,
        -9.1865e-01, -1.7741e-01, -2.0525e-01,  7.3771e-01, -4.9387e-01,
        -1.1042e-01, -1.8730e+00,  5.6812e-01, -8.4762e-02,  1.8729e-01,
         8.0230e-01, -7.0468e-01,  7.1599e-01,  1.2728e+00,  4.8130e-01,
         5.6721e-01, -5.9125e-01, -1.3044e+00,  2.2909e+00,  3.2052e-01,
         2.5068e-01,  7.1812e-01, -1.2433e-01, -4.7760e-01, -4.6170e-01,
        -4.7226e-03,  4.7565e-01,  1.7193e+00,  1.2292e+00,  2.3644e+00,
        -9.8047e-01, -1.4936e+00, -4.0834e-01,  1.4220e-01, -3.2324e-01,
         6.9451e-03,  4.4440e-01, -1.0876e+00,  1.3307e+00,  5.3569e-01,
        -1.4927e+00, -5.8807e-01, -2.5607e+00, -1.2350e+00,  6.2898e-01,
        -4.7516e-01,  5.9551e-01, -3.9008e-01,  1.0815e+00,  1.7505e-01,
         1.5820e+00, -2.7845e+00,  2.0758e-01,  1.4