In [None]:
import jieba
import pandas as pd
import torch
import numpy as np
from torch import nn
from transformers import TrainingArguments, Trainer, TrainerCallback

In [None]:
class Vocab:
    """Word-level vocabulary built by segmenting a one-column CSV with jieba.

    Attributes:
        stop_words_file: Set of stop words. Despite the name this holds the
            loaded words, not a path; it is empty when no file was given.
        words: Every kept token in corpus order, duplicates included.
        word2idx: Mapping from token to integer id.
        idx2word: Mapping from integer id back to token.
        word_size: Total number of kept tokens, i.e. ``len(words)``.
        vocab_size: Number of distinct kept tokens.
    """

    def __init__(self, vocab_file, stop_words_file=None):
        """Load the optional stop words first, then tokenize ``vocab_file``.

        Args:
            vocab_file: Path to a header-less, GBK-encoded CSV; only its first
                column is read.
            stop_words_file: Optional path to a text file with one stop word
                per line. ``None`` disables stop-word filtering.
        """
        # Stop words must be loaded before load_vocab, which consults them.
        self.stop_words_file = self.load_stop_words(stop_words_file)
        self.idx2word, self.word2idx, self.words = self.load_vocab(vocab_file)
        self.word_size = len(self.words)
        self.vocab_size = len(self.idx2word)

    def load_vocab(self, vocab_file):
        """Tokenize the first CSV column and build the id mappings.

        Tokens found in ``self.stop_words_file`` are dropped. Ids are assigned
        in order of first appearance, so the same corpus always produces the
        same mapping (iterating a ``set`` of strings does not guarantee that
        across interpreter runs).

        Args:
            vocab_file: Path to a header-less, GBK-encoded CSV.

        Returns:
            A tuple ``(idx2word, word2idx, words)``.
        """
        contents = pd.read_csv(vocab_file, encoding="GBK", header=None)
        stop_words = self.stop_words_file

        words = []
        for line in contents[0]:
            tokens = jieba.cut(line)
            if stop_words:
                # Filter only when stop words were actually supplied.
                tokens = (token for token in tokens if token not in stop_words)
            words.extend(tokens)

        # dict.fromkeys de-duplicates while preserving first-seen order.
        word2idx = {word: idx for idx, word in enumerate(dict.fromkeys(words))}
        idx2word = {idx: word for word, idx in word2idx.items()}
        return idx2word, word2idx, words

    def load_stop_words(self, stop_words_file):
        """Return the stop words listed one per line in ``stop_words_file``.

        Args:
            stop_words_file: Path to the stop-word list, or ``None``.

        Returns:
            A set of lines from the file, or an empty set when the path is
            ``None``.
        """
        if stop_words_file is None:
            return set()
        # NOTE(review): no encoding is given, so the platform default is used;
        # confirm that matches how the stop-word file was saved.
        with open(stop_words_file, "r") as f:
            return set(f.read().splitlines())

    def get_idx(self, word):
        """Return the id of ``word``; raises ``KeyError`` if it is unknown."""
        return self.word2idx[word]

    def get_word(self, idx):
        """Return the token for ``idx``; raises ``KeyError`` if it is unknown."""
        return self.idx2word[idx]

In [None]:
# Build the vocabulary from the raw corpus, passing the stop-word list as well.
vocab = Vocab(
    vocab_file="./assets/数学原始数据.csv",
    stop_words_file="./assets/stopwords.txt",
)

In [None]:
# Display (total token count, number of distinct tokens) for the loaded corpus.
vocab.word_size, vocab.vocab_size

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Sliding-window samples over ``vocab.words`` for CBOW-style training.

    Each sample is a window of ``2 * ngram + 1`` consecutive tokens: the middle
    token is the label and the ``ngram`` tokens on either side are the inputs.
    """

    def __init__(self, ngram: int, vocab: "Vocab"):
        """Store the window half-width and the vocabulary to draw tokens from.

        Args:
            ngram: Number of context tokens taken on each side of the target.
            vocab: Object exposing ``words``, ``word_size``, ``vocab_size`` and
                ``get_idx(word)``.
        """
        self.ngram = ngram
        self.vocab = vocab
        self.word_size = vocab.word_size
        self.vocab_size = vocab.vocab_size

    def __len__(self):
        """Return the number of complete windows that fit in the corpus."""
        # A corpus of W tokens holds W - (2 * ngram + 1) + 1 full windows.
        # Clamp at zero so a too-short corpus yields an empty dataset instead
        # of a negative length, which len() rejects.
        return max(0, self.word_size - 2 * self.ngram)

    def __getitem__(self, idx):
        """Return the sample whose window starts at token position ``idx``.

        Args:
            idx: Sample index; negative values count from the end.

        Returns:
            A dict with ``"inputs"``, a float32 tensor of shape
            ``(2 * ngram, vocab_size)`` holding one one-hot row per context
            token, and ``"labels"``, a scalar long tensor with the target id.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Without this check an out-of-range index would silently yield a
            # truncated context window.
            raise IndexError("MyDataset index out of range")

        corpus = self.vocab.words
        center = idx + self.ngram
        context = corpus[idx:center] + corpus[center + 1 : center + self.ngram + 1]

        label = self.vocab.get_idx(corpus[center])
        context_ids = torch.tensor(
            [self.vocab.get_idx(word) for word in context], dtype=torch.long
        )
        # Build the one-hot rows directly as a tensor rather than via a dense
        # float64 NumPy array that then has to be copied.
        context_onehot = torch.nn.functional.one_hot(
            context_ids, num_classes=self.vocab_size
        ).to(torch.float32)

        return {
            "inputs": context_onehot,
            "labels": torch.tensor(label, dtype=torch.long),
        }

In [None]:
# ngram=2: two context tokens on each side of every target token.
data = MyDataset(ngram=2, vocab=vocab)

In [None]:
# Shuffled mini-batch iterator over the dataset, for manual inspection or loops.
data_iter = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=32,
    shuffle=True,
)

In [None]:
class Net(nn.Module):
    """Two stacked linear layers scoring a target token from one-hot context.

    Every context row is projected to ``embedding_size`` and back to
    ``vocab_size``; the per-row logits are then averaged over the context.
    """

    def __init__(self, vocab_size, embedding_size):
        """Create the two projections.

        Args:
            vocab_size: Width of the one-hot inputs and of the output logits.
            embedding_size: Width of the hidden (embedding) layer.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        to_embedding = nn.Linear(vocab_size, embedding_size, bias=True)
        to_vocab = nn.Linear(embedding_size, vocab_size, bias=True)
        self.model = nn.Sequential(to_embedding, to_vocab)

    def forward(self, inputs, labels=None):
        """Score the target token for each sample.

        Args:
            inputs: Tensor of shape ``[batch_size, context_size, vocab_size]``.
            labels: Optional tensor of target ids, shape ``[batch_size]``.

        Returns:
            A dict with ``"logits"`` of shape ``[batch_size, vocab_size]`` and,
            when ``labels`` is given, ``"loss"`` (mean cross-entropy).
        """
        batch_size, context_size, vocab_size = inputs.shape

        # [batch_size, context_size, vocab_size] -> [batch_size * context_size, vocab_size]
        flat_inputs = inputs.reshape(-1, self.vocab_size)
        # Per-context-token logits, still flattened over batch and context.
        flat_logits = self.model(flat_inputs)
        # Restore the context axis, then average it away:
        # [batch_size, context_size, vocab_size] -> [batch_size, vocab_size]
        logits = flat_logits.reshape(batch_size, context_size, vocab_size).mean(dim=1)

        outputs = {"logits": logits}
        if labels is not None:
            # logits: [batch_size, vocab_size]; labels: [batch_size] class ids.
            outputs["loss"] = nn.functional.cross_entropy(logits, labels)
        return outputs

In [None]:
# One input/output unit per distinct token, with a 100-dimensional hidden layer.
model = Net(vocab_size=vocab.vocab_size, embedding_size=100)

In [None]:
class MyCallBacks(TrainerCallback):
    """Callback that prints a short notice on train start, train end and save."""

    @staticmethod
    def _announce(message):
        """Write ``message`` to standard output."""
        print(message)

    def on_train_begin(self, args, state, control, **kwargs):
        """Print a notice for the train-begin event."""
        self._announce("\nStarting training")

    def on_train_end(self, args, state, control, **kwargs):
        """Print a notice for the train-end event."""
        self._announce("\nEnding training")

    def on_save(self, args, state, control, **kwargs):
        """Print a notice for the save event."""
        self._announce("\nSaving model")

In [None]:
training_args = TrainingArguments(
    # Run length and device selection.
    num_train_epochs=3,
    use_cpu=False,
    # Logging cadence.
    logging_strategy="steps",
    # Checkpointing: where to write, how often, and how many to keep.
    output_dir="./word2vec",
    save_strategy="epoch",
    save_total_limit=3,
)

In [None]:
# Plain SGD over all model parameters; no learning-rate scheduler is supplied.
sgd_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data,
    callbacks=[MyCallBacks],
    optimizers=(sgd_optimizer, None),
)

In [None]:
# Run training with the trainer built in the previous cell.
trainer.train()

In [None]:
# Persist only the model's state dict (weights and biases), not the vocabulary.
torch.save(obj=model.state_dict(), f="./word2vec.pth")