In [1]:
import torch
import torchtext
from torch import nn
from torch.utils.data import Dataset, DataLoader
# Pick the GPU only when CUDA is actually usable. `is_available` must be *called*:
# the bare function object is always truthy and would select 'cuda' on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import jieba
import re
import pandas as pd
from collections import Counter
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

import wandb
# Start a Weights & Biases run in the 'NNLM' project; the training loop reports its losses via wandb.log.
wandb.init(project='NNLM')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgechengze[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
def build_vocab(df, stopwords):
    """Build a torchtext vocabulary from the titles in ``df``.

    Each title is reduced to the characters in U+4E00-U+9FFF, segmented with
    jieba, and filtered against ``stopwords``; the remaining tokens are counted
    and handed to ``torchtext.vocab.vocab``.

    Args:
        df: Table-like object whose ``'title'`` column yields title strings.
        stopwords: Iterable of tokens to leave out of the vocabulary.

    Returns:
        A torchtext ``Vocab`` holding the specials ``'<unk>'`` and ``'<pad>'``
        plus every counted token. Lookups of tokens that are not in the
        vocabulary resolve to the ``'<unk>'`` index.
    """
    # Set membership is O(1); a list would be rescanned for every token.
    stopwords = set(stopwords)
    counter = Counter()

    for title in tqdm(df['title'], desc='building vocab'):
        title = re.sub(r'[^\u4e00-\u9fff]', '', title)
        counter.update(token for token in jieba.cut(title.strip()) if token not in stopwords)

    vocab = torchtext.vocab.vocab(counter, specials=['<unk>', '<pad>'])
    # Without a default index, looking up an unseen token raises RuntimeError.
    vocab.set_default_index(vocab['<unk>'])
    return vocab

class THUCNews(Dataset):
    """N-gram language-model dataset built from titles.

    Every title is reduced to the characters in U+4E00-U+9FFF, segmented with
    jieba and filtered against the stopwords. Each run of ``context_size``
    consecutive tokens becomes one input, and the token that follows it becomes
    the corresponding label.
    """

    def __init__(self, df, vocab, stopwords, context_size=3):
        """Tokenize the titles and materialise every (context, next-token) pair.

        Args:
            df: Table-like object whose ``'title'`` column yields title strings.
            vocab: Object providing ``lookup_indices(tokens)`` to map tokens to
                integer indices.
            stopwords: Iterable of tokens to drop before windowing.
            context_size: Number of preceding tokens used as input. Defaults to
                3, the value that was previously hard-coded.

        Raises:
            ValueError: If ``context_size`` is smaller than 1.
        """
        if context_size < 1:
            raise ValueError('context_size must be at least 1, got {}'.format(context_size))

        # Set membership is O(1); a list would be rescanned for every token.
        stopwords = set(stopwords)
        self.context_size = context_size
        self.inputs = []
        self.labels = []

        for title in tqdm(df['title'], desc='building dataset'):
            title = re.sub(r'[^\u4e00-\u9fff]', '', title)
            tokens = [token for token in jieba.cut(title.strip()) if token not in stopwords]
            # Titles too short to yield a window never touch the vocabulary.
            if len(tokens) <= context_size:
                continue
            # Convert the title once instead of once per sliding window.
            indices = vocab.lookup_indices(tokens)
            for start in range(len(indices) - context_size):
                self.inputs.append(indices[start: start + context_size])
                # Labels stay wrapped in a one-element list, so each item is a shape-(1,) tensor.
                self.labels.append([indices[start + context_size]])

    def __len__(self):
        """Return the number of (context, next-token) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return pair ``idx`` as ``(context LongTensor, label LongTensor)``."""
        return torch.LongTensor(self.inputs[idx]), torch.LongTensor(self.labels[idx])

In [3]:
# Fixed seed so the 100k-title subsample (and the vocabulary built from it) is reproducible.
df = pd.read_csv('../../datasets/THUCNews/title.csv').sample(100000, random_state=42)

# Decode the stopword list as UTF-8 explicitly rather than relying on the platform default,
# and keep it as a set so per-token membership checks are O(1).
with open('../stopwords/cn_stopwords.txt', encoding='utf-8') as f:
    stopwords = {line.strip() for line in f}

vocab = build_vocab(df, stopwords)

  0%|          | 0/100000 [00:00<?, ?it/s]

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.411 seconds.
Prefix dict has been built successfully.


In [4]:
# 70/30 train/validation split; the seed keeps the partition stable across kernel restarts.
df_train, df_valid = train_test_split(df, test_size=0.3, random_state=42)

In [5]:
# Build one n-gram dataset per split (train first, then validation),
# sharing the same vocabulary and stopword collection.
train_datasets, valid_datasets = [
    THUCNews(split, vocab, stopwords) for split in (df_train, df_valid)
]

  0%|          | 0/70000 [00:00<?, ?it/s]

  0%|          | 0/30000 [00:00<?, ?it/s]

In [6]:
# Batches of 512 samples. Only the training loader shuffles;
# both loaders discard an incomplete final batch.
train_dataloader = DataLoader(
    dataset=train_datasets,
    batch_size=512,
    shuffle=True,
    drop_last=True,
)
valid_dataloader = DataLoader(
    dataset=valid_datasets,
    batch_size=512,
    shuffle=False,
    drop_last=True,
)

In [7]:
class NNLM(nn.Module):
    """Feed-forward neural language model over a fixed-size token context.

    The context token indices are embedded, the embeddings are concatenated,
    and a two-layer MLP (Linear -> ReLU -> Linear) maps them to one logit per
    vocabulary entry.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, context_size=3):
        """Create the embedding table and the MLP head.

        Args:
            vocab_size: Number of embedding rows and of output logits.
            embed_size: Dimension of each token embedding.
            hidden_size: Width of the hidden layer.
            context_size: Number of context tokens per sample. Defaults to 3,
                the value that was previously hard-coded.
        """
        super().__init__()
        self.context_size = context_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.fc = nn.Sequential(
            # Input is the concatenation of all context embeddings.
            nn.Linear(embed_size * context_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, vocab_size)
        )

    def forward(self, x):
        """Return vocabulary logits for a batch of contexts.

        Args:
            x: Integer tensor of token indices, shaped ``(batch, context_size)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` with unnormalised scores.
        """
        embedded = self.embedding(x)
        # Collapse (batch, context, embed) into (batch, context * embed).
        embedded = embedded.flatten(start_dim=1)
        return self.fc(embedded)

In [8]:
# Size the model to the vocabulary, with 256-dimensional embeddings and hidden layer,
# and place it on the selected device.
model = NNLM(vocab_size=len(vocab), embed_size=256, hidden_size=256).to(device)

# Adam at a 1e-3 learning rate, trained against a cross-entropy objective.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [9]:
# Train for 20 epochs; after each epoch, evaluate on the validation loader and
# log the per-sample average losses to Weights & Biases.
for epoch in range(20):
    # ---- training pass ----
    model.train()
    total_train_loss = 0.0
    num_train = 0
    pbar = tqdm(train_dataloader, desc='epoch' + str(epoch))
    for x, y in pbar:
        x = x.to(device)
        # Labels arrive as (batch, 1); drop only that axis so a batch of one
        # is not collapsed to a scalar, and avoid mutating the tensor in place.
        y = y.to(device).squeeze(1)
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `criterion` returns the batch mean; weight it by the batch size so the
        # running total divided by the sample count is a true per-sample average.
        batch_size = x.shape[0]
        total_train_loss += loss.item() * batch_size
        num_train += batch_size
        pbar.set_postfix({'loss': format(total_train_loss / num_train, '.6f')})

    # ---- validation pass ----
    model.eval()
    total_valid_loss = 0.0
    num_valid = 0
    pbar = tqdm(valid_dataloader, desc='valid')
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for x, y in pbar:
            x = x.to(device)
            y = y.to(device).squeeze(1)
            output = model(x)
            loss = criterion(output, y)

            batch_size = x.shape[0]
            total_valid_loss += loss.item() * batch_size
            num_valid += batch_size
            pbar.set_postfix({'loss': format(total_valid_loss / num_valid, '.6f')})

    # Each loss is normalised by its own sample count (the counters are kept
    # separate so the train loss is not divided by the validation count).
    # max(..., 1) guards against an empty loader.
    wandb.log({"Train Loss": total_train_loss / max(num_train, 1),
               "Valid Loss": total_valid_loss / max(num_valid, 1)})


  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]

  0%|          | 0/615 [00:00<?, ?it/s]

  0%|          | 0/263 [00:00<?, ?it/s]