In [1]:
import csv
from collections import Counter

import pandas as pd
from nltk.tokenize import word_tokenize
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torchtext
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset

# `is_available` must be *called*: the bare function object is always truthy,
# which would force 'cuda' even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def build_vocab(df):
    """Build English and Chinese vocabularies from a parallel corpus.

    English sentences in ``df['en']`` are tokenized with NLTK ``word_tokenize``;
    Chinese sentences in ``df['zh']`` are split into individual characters.
    Both vocabularies start with the specials ``<unk>, <pad>, <bos>, <eos>``
    (torchtext's default ``special_first=True`` puts them at indices 0-3), and
    unknown tokens are mapped to ``<unk>`` instead of raising on lookup.

    Args:
        df: DataFrame with string columns ``'en'`` and ``'zh'``.

    Returns:
        dict with keys ``'en'`` and ``'zh'`` mapping to torchtext ``Vocab`` objects.
    """
    specials = ['<unk>', '<pad>', '<bos>', '<eos>']

    def _make_vocab(token_lists):
        # Count tokens in first-seen order so vocab indices are deterministic.
        counter = Counter()
        for tokens in token_lists:
            counter.update(tokens)
        lang_vocab = torchtext.vocab.vocab(counter, specials=specials)
        # Without a default index, lookup of any unseen token raises at runtime.
        lang_vocab.set_default_index(lang_vocab['<unk>'])
        return lang_vocab

    en_vocab = _make_vocab(word_tokenize(en) for en in df['en'])
    zh_vocab = _make_vocab(list(zh) for zh in df['zh'])
    return {'en': en_vocab, 'zh': zh_vocab}

In [3]:
# Tab-separated sentence pairs. Sentences may contain literal double quotes,
# so CSV quote handling is disabled to keep them as plain text (otherwise a
# field starting with `"` is parsed as a quoted field and can swallow rows).
# usecols keeps only the first two columns in case the file has extra ones.
df = pd.read_csv('../../datasets/cmn.txt', sep='\t', header=None, names=['en', 'zh'],
                 usecols=[0, 1], quoting=csv.QUOTE_NONE)
vocab = build_vocab(df)

In [4]:
class Zh2En(Dataset):
    """Chinese-to-English parallel dataset that yields index tensors.

    Chinese sentences are split into characters and English sentences are
    tokenized with NLTK ``word_tokenize`` once, up front. Each item is framed
    with ``<bos>`` / ``<eos>`` and mapped to indices through ``vocab``.

    Args:
        df: DataFrame with ``'zh'`` and ``'en'`` string columns.
        vocab: dict with ``'zh'`` and ``'en'`` vocabularies (as built by ``build_vocab``).
    """

    def __init__(self, df, vocab):
        super().__init__()
        self.vocab = vocab
        # Pre-tokenize once so __getitem__ only performs index lookups.
        self.source = [list(sentence) for sentence in df['zh']]
        self.target = [word_tokenize(sentence) for sentence in df['en']]

    def __len__(self):
        return len(self.target)

    def _encode(self, lang, tokens):
        """Map `tokens` to indices in vocab[lang], wrapped in <bos> ... <eos>."""
        lang_vocab = self.vocab[lang]
        ids = [lang_vocab['<bos>'], *lang_vocab.lookup_indices(tokens), lang_vocab['<eos>']]
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the (zh_ids, en_ids) pair of 1-D int64 tensors for example `idx`."""
        source_ids = self._encode('zh', self.source[idx])
        target_ids = self._encode('en', self.target[idx])
        return source_ids, target_ids
                

In [5]:
# Fixed random_state so the 70/30 train/valid split is reproducible across runs.
df_train, df_valid = train_test_split(df, test_size=0.3, random_state=42)
train_dataset = Zh2En(df_train, vocab)
valid_dataset = Zh2En(df_valid, vocab)

In [6]:
def collate_fn(batch, pad_idx=1):
    """Pad a list of (zh, en) 1-D index tensors into two batch-first tensors.

    Args:
        batch: iterable of ``(zh_ids, en_ids)`` pairs, as returned by ``Zh2En``.
        pad_idx: index used to right-pad shorter sequences. The default 1 is the
            position of ``<pad>`` when specials are placed first in the vocab.

    Returns:
        ``(zh_batch, en_batch)``, each of shape (batch, max_len_in_batch).
    """
    zh_seqs = [zh for zh, _ in batch]
    en_seqs = [en for _, en in batch]
    zh_batch = pad_sequence(zh_seqs, padding_value=pad_idx, batch_first=True)
    en_batch = pad_sequence(en_seqs, padding_value=pad_idx, batch_first=True)
    return zh_batch, en_batch

In [7]:
BATCH_SIZE = 128

# Training drops the ragged final batch so every optimizer step sees a full
# batch; validation keeps it so every held-out example is evaluated.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                              shuffle=False, drop_last=False, collate_fn=collate_fn)

In [8]:
class Encoder(nn.Module):
    """Single-layer LSTM encoder over source token ids.

    Args:
        vocab_size: size of the source vocabulary.
        embedding_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden/cell state size.
        pad_idx: padding token index. When not None, the LSTM is run on a packed
            sequence so right-padding does not leak into the final (h, c).
            Pass None to run the LSTM over the raw padded tensor.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, pad_idx=1):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout()

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: (batch, seq_len) integer tensor (batch_first). Assumes padding,
                if any, is only at the end of each row, as ``pad_sequence`` produces.

        Returns:
            ``(output, (h, c))``: per-step outputs of shape (batch, seq_len, hidden),
            zero at padded positions when packing is used, with dropout applied;
            and the LSTM's final hidden/cell states (dropout is not applied to them).
        """
        embedded = self.embedding(x)
        if self.pad_idx is None:
            output, (h, c) = self.lstm(embedded)
        else:
            # Real length = number of non-pad tokens; clamp so an all-pad row
            # doesn't make pack_padded_sequence reject a zero length.
            lengths = (x != self.pad_idx).sum(dim=1).clamp(min=1).cpu()
            packed = pack_padded_sequence(embedded, lengths, batch_first=True,
                                          enforce_sorted=False)
            packed_out, (h, c) = self.lstm(packed)
            # Restore the original padded width so callers see the same shape.
            output, _ = pad_packed_sequence(packed_out, batch_first=True,
                                            total_length=x.size(1))
        output = self.dropout(output)

        return output, (h, c)
    
class Decoder(nn.Module):
    """Single-layer LSTM decoder that produces vocabulary logits.

    Args:
        vocab_size: size of the target vocabulary (also the logit width).
        embedding_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden/cell state size.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h, c):
        """Run the decoder over token ids `x` starting from state (h, c).

        `x` is expected as (batch, steps) ids (the LSTM is batch_first).
        Returns logits of shape (batch, steps, vocab_size) and the new (h, c).
        """
        states = (h, c)
        rnn_out, states = self.lstm(self.embedding(x), states)
        logits = self.fc(self.dropout(rnn_out))
        return logits, states
    
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes one step at a time.

    The encoder's final (h, c) initialises the decoder. Decoding starts from
    ``target[:, 0]`` (expected to be the <bos> column) and runs
    ``target.shape[1] - 1`` steps. By default each step feeds the decoder its own
    previous argmax prediction; ``teacher_forcing_ratio`` optionally feeds the
    ground-truth token instead.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_forcing_ratio=0.0):
        """Decode `target.shape[1] - 1` steps conditioned on `source`.

        Args:
            source: source token ids passed straight to the encoder.
            target: (batch, target_len) token ids; column 0 seeds decoding and the
                remaining columns are used only for length and teacher forcing.
            teacher_forcing_ratio: probability, drawn once per step for the whole
                batch, of feeding ``target[:, t]`` instead of the prediction.
                0.0 (the default) always feeds the model's own prediction.

        Returns:
            Tensor of shape (batch, target_len, decoder.vocab_size). Position 0 is
            left as zeros because no prediction is made for the seed token.
        """
        _, (h, c) = self.encoder(source)
        batch_size, num_step = target.shape[0], target.shape[1]
        # Allocate on the inputs' device rather than a module-level global so
        # the model runs wherever its inputs live.
        output = torch.zeros(batch_size, num_step, self.decoder.vocab_size,
                             device=target.device)
        y_t = target[:, 0]

        for t in range(1, num_step):
            # Out-of-place (un)squeeze: avoid mutating views of `target` and
            # autograd outputs in place.
            logits, (h, c) = self.decoder(y_t.unsqueeze(1), h, c)
            logits = logits.squeeze(1)
            output[:, t, :] = logits
            # Only draw a random number when teacher forcing is enabled, so the
            # default path consumes no extra RNG state.
            use_target = (teacher_forcing_ratio > 0
                          and torch.rand(()).item() < teacher_forcing_ratio)
            y_t = target[:, t] if use_target else logits.argmax(1)

        return output

In [9]:
EMBEDDING_SIZE = 128
HIDDEN_SIZE = 128

# Seed before building the model so weight init, dropout masks and DataLoader
# shuffling are reproducible.
torch.manual_seed(0)

encoder = Encoder(vocab_size=len(vocab['zh']), embedding_size=EMBEDDING_SIZE, hidden_size=HIDDEN_SIZE)
decoder = Decoder(vocab_size=len(vocab['en']), embedding_size=EMBEDDING_SIZE, hidden_size=HIDDEN_SIZE)
model = Seq2Seq(encoder, decoder).to(device)
print(model)

# Skip padded target positions; look the index up rather than hard-coding 1.
criterion = nn.CrossEntropyLoss(ignore_index=vocab['en']['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(3441, 128)
    (lstm): LSTM(128, 128, batch_first=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(6976, 128)
    (lstm): LSTM(128, 128, batch_first=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (fc): Linear(in_features=128, out_features=6976, bias=True)
  )
)


In [10]:
import wandb
wandb.init(project='seq2seq')

NUM_EPOCHS = 50


def sequence_loss(output, target):
    """Token-level cross-entropy, skipping position 0 (the seed token, never predicted)."""
    return criterion(output[:, 1:, :].reshape(-1, output.shape[-1]),
                     target[:, 1:].reshape(-1))


def run_epoch(dataloader, train):
    """Run one pass over `dataloader`; update weights if `train`.

    Returns the mean of the per-batch losses. `criterion` already averages over
    tokens within a batch, so the sum is divided by the number of batches, not
    the number of examples.
    """
    model.train(train)
    total_loss = 0.0
    # No autograd graph is needed for validation.
    with torch.set_grad_enabled(train):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x, y)
            loss = sequence_loss(output, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / max(len(dataloader), 1)


for epoch in tqdm(range(NUM_EPOCHS)):
    epoch_train_loss = run_epoch(train_dataloader, train=True)
    epoch_valid_loss = run_epoch(valid_dataloader, train=False)
    wandb.log({'train loss': epoch_train_loss, 'valid loss': epoch_valid_loss})

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgechengze[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 50/50 [01:42<00:00,  2.04s/it]
