In [1]:
import torch
from torch import nn
import random
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
import copy
from tqdm import tqdm
import sys
import numpy as np

In [2]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, 
                 n_layers, dropout=0.5, bidirectional=True):
        
        super(Encoder, self).__init__()
        
        self.n_layers = n_layers
        self.hid_dim = hid_dim
        self.bidirectional = bidirectional
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout,
                           bidirectional=bidirectional)
        
    def forward(self, x, lens, hidden, cell):
        
        # x [seq_len, batch_size]
        embedded = self.embedding(x)
        
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lens, enforce_sorted=False,
                                                   batch_first=False)
        
        _, (hidden, cell) = self.rnn(packed, (hidden, cell))
        
        return hidden, cell

In [3]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, 
                 n_layers, dropout=0.5, bidirectional=True):
    
        super(Decoder, self).__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout,
                           bidirectional=bidirectional)
        
        if bidirectional:
            self.fc = nn.Linear(hid_dim * 2, output_dim)
        else:
            self.fc = nn.Linear(hid_dim, output_dim)
            
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, hidden, cell):
        
        batch_size = x.shape[0]
        
        x = x.reshape(1, -1)
        # x [1, batch_size]
        
        embedded = self.dropout(self.embedding(x))
        
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        
        output = output.reshape(batch_size, -1)
        
        output = F.log_softmax(self.fc(output), dim=1)
        
        return output, hidden, cell

In [4]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, max_trg_len=100):
        
        super(Seq2Seq, self).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
        self.encoder.to(device)
        self.decoder.to(device)
        
        self.n_encoder_directions = 2 if encoder.bidirectional else 1
        
        self.max_trg_len = max_trg_len
        
    def forward(self, src, src_lens, trg, trg_lens, teacher_forcing_ratio=0.5):
        
        # src [seq_len, batch_size]
        # trg [seq_len, batch_size]
        
        batch_size = src.shape[1]
        
        hidden = torch.zeros(self.encoder.n_layers * self.n_encoder_directions, batch_size,
                             self.encoder.hid_dim, device=self.device)
        
        cell = torch.zeros(self.encoder.n_layers * self.n_encoder_directions, batch_size,
                           self.encoder.hid_dim, device=self.device)
        
        hidden, cell = self.encoder(src, src_lens, hidden, cell)
        
        x = trg[0, :]
        
        max_trg_len = max(trg_lens)
        
        outputs = torch.zeros(max_trg_len, batch_size, self.decoder.output_dim)
        
        for t in range(1, max_trg_len):
            
            output, hidden, cell = self.decoder(x.to(self.device), hidden, cell)
            
            outputs[t] = output
            
            top1 = output.argmax(dim=1).detach()
            
            teacher_forcing = random.random() < teacher_forcing_ratio
            
            x = trg[t, :] if teacher_forcing else top1
        
        return outputs
            
    def translate(self, src, src_lens, trg):
        
        batch_size = src.shape[1]
        
        hidden = torch.zeros(self.encoder.n_layers * self.n_encoder_directions, batch_size,
                             self.encoder.hid_dim, device=self.device)
        
        cell = torch.zeros(self.encoder.n_layers * self.n_encoder_directions, batch_size,
                           self.encoder.hid_dim, device=self.device)
        
        hidden, cell = self.encoder(src, src_lens, hidden, cell)
        
        x = trg[0, :]
        
        outputs = torch.zeros(self.max_trg_len, batch_size, self.decoder.output_dim)
        
        for t in range(0, self.max_trg_len):
            
            output, hidden, cell = self.decoder(x.to(self.device), hidden, cell)
            
            outputs[t] = output
            
            top1 = output.argmax(dim=1).detach()
            
            x = top1
        
        return outputs

In [5]:
class NMTDataset(Dataset):
    def __init__(self, src, trg):
        self.src = src
        self.trg = trg

    def __len__(self):
        return len(self.src)

    def __getitem__(self, index):
        return {
            'src': self.src[index]['sentence'],
            'src_len': self.src[index]['len'],
            'trg': self.trg[index]['sentence'],
            'trg_len': self.trg[index]['len']
        }

In [6]:
spacy_en = spacy.load('en')
def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings (tokens)
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

In [7]:
with open('cmn.txt', 'r', encoding='utf-8') as f:
        data = f.read()

data = data.strip().split('\n')

print('number of examples: ', len(data))

en_data = [line.split('\t')[0] for line in data]
zh_data = [line.split('\t')[1] for line in data]

assert len(en_data) == len(zh_data)

number of examples:  23610


In [8]:
zh_words = set()
en_words = set()

for i in tqdm(range(len(zh_data))):
    en_seg = tokenize_en(en_data[i])
    zh_seg = list(zh_data[i])

    zh_words.update(zh_seg)
    en_words.update(en_seg)

100%|██████████| 23610/23610 [00:02<00:00, 11239.16it/s]


In [9]:
zh_word2idx = {value: index + 4 for index, value in enumerate(zh_words)}

zh_word2idx['<pad>'] = 0
zh_word2idx['<sos>'] = 1
zh_word2idx['<eos>'] = 2
zh_word2idx['<unk>'] = 3

zh_idx2word = {zh_word2idx[k]: k for k in zh_word2idx.keys()}

In [10]:
en_word2idx = {value: index + 4 for index, value in enumerate(en_words)}

en_word2idx['<pad>'] = 0
en_word2idx['<sos>'] = 1
en_word2idx['<eos>'] = 2
en_word2idx['<unk>'] = 3

en_idx2word = {en_word2idx[k]: k for k in en_word2idx.keys()}

In [11]:
zh = []
en = []

for i in tqdm(range(len(zh_data))):
    en_seg = tokenize_en(en_data[i])
    zh_seg = list(zh_data[i])

    en_sentence = [en_word2idx['<sos>']] + [en_word2idx[w] for w in en_seg] + [en_word2idx['<eos>']]
    zh_sentence = [zh_word2idx['<sos>']] + [zh_word2idx[w] for w in zh_seg] + [zh_word2idx['<eos>']]

    en_len = len(en_sentence)
    zh_len = len(zh_sentence)

    zh.append({
         'sentence': zh_sentence,
         'len': zh_len
    })
    en.append({
        'sentence': en_sentence,
        'len': en_len
    })

100%|██████████| 23610/23610 [00:01<00:00, 15004.52it/s]


In [12]:
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
INPUT_DIM = len(en_word2idx)
OUTPUT_DIM = len(zh_word2idx)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
HID_DIM = 512
EMB_DIM = 256
N_LAYERS = 2
EPOCH = 200

In [13]:
def padding_batch(batch):

    src_lens = [d["src_len"] for d in batch]
    trg_lens = [d["trg_len"] for d in batch]

    src_max = max([d["src_len"] for d in batch])
    trg_max = max([d["trg_len"] for d in batch])

    srcs = []
    trgs = []

    for d in batch:
        src = copy.deepcopy(d['src'])
        trg = copy.deepcopy(d['trg'])

        src.extend([en_word2idx["<pad>"]]*(src_max-d["src_len"]))
        trg.extend([zh_word2idx["<pad>"]]*(trg_max-d["trg_len"]))

        srcs.append(src)
        trgs.append(trg)

    srcs = torch.tensor(srcs, dtype=torch.long, device=DEVICE)
    trgs = torch.tensor(trgs, dtype=torch.long, device=DEVICE)

    batch = {"src":srcs.T, "src_lens":src_lens,
             "trg":trgs.T, "trg_lens":trg_lens}
    return batch

In [14]:
encoder = Encoder(INPUT_DIM, emb_dim=EMB_DIM, hid_dim=HID_DIM, n_layers=N_LAYERS)
decoder = Decoder(OUTPUT_DIM, emb_dim=EMB_DIM, hid_dim=HID_DIM, n_layers=N_LAYERS)
model = Seq2Seq(encoder, decoder, DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.NLLLoss(ignore_index=0)

In [15]:
dataset = NMTDataset(en, zh)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=padding_batch, shuffle=True)

In [16]:
best_valid_loss = sys.maxsize
print_every = 10

In [None]:
for epoch in range(EPOCH):
    
    print_loss_total = 0
    
    model.train()
    for index, batch in enumerate(dataloader):
        
        src = batch['src']
        src_lens = batch['src_lens']
        trg = batch['trg']
        trg_lens = batch['trg_lens']
        
        optimizer.zero_grad()
        
        outputs = model(src, src_lens, trg, trg_lens)
        
        outputs = outputs[1:, :, :]
        
        outputs = outputs.reshape(-1, OUTPUT_DIM)
        
        trg = trg[1:, :]
        
        trg = trg.reshape(-1).cpu()
        
        loss = criterion(outputs, trg)
        
        print_loss_total += loss.item()
        
        loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), 1)

        optimizer.step()
        
        if (index + 1) % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            
            info = 'Train Epoch [{}/{}], Avg Loss: {:.4f}'. \
                    format(epoch + 1, EPOCH, print_loss_avg)
            print(info)
    
    valid_loss = 0
    model.eval()
    
    with torch.no_grad():
        for index, batch in enumerate(dataloader):
            
            src = batch['src']
            src_lens = batch['src_lens']
            trg = batch['trg']
            trg_lens = batch['trg_lens']
            
            output = model(src, src_lens, trg, trg_lens, teacher_forcing_ratio=0)
            
            outputs = model(src, src_lens, trg, trg_lens)
        
            outputs = outputs[1:, :, :]

            outputs = outputs.reshape(-1, OUTPUT_DIM)

            trg = trg[1:, :]

            trg = trg.reshape(-1).cpu()

            loss = criterion(outputs, trg)
            
            valid_loss += loss.item()
    
    valid_loss = valid_loss / len(dataloader)
    info = 'Train Epoch [{}/{}], Valid Loss: {:.4f}'. \
            format(epoch + 1, EPOCH, valid_loss)
    print(info)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'seq2seq.pth')

In [49]:
print('best valid loss: {}'.format(best_valid_loss))

best valid loss: 0.036710509526971226


In [50]:
model.load_state_dict(torch.load('seq2seq.pth'))

<All keys matched successfully>

In [53]:
def translate(model, en):
    en = tokenize_en(en)
    en.append('<eos>')
    
    en_len = len(en)

    en_data = []

    for w in en:
        if w in en_word2idx.keys():
            en_data.append(en_word2idx[w])
        else:
            en_data.append(en_word2idx['<unk>'])
    en = en_data
    
    en = [en]

    en = torch.LongTensor(en).permute(1, 0)
    
    model.eval()
    with torch.no_grad():
        outputs = model.translate(en.to(DEVICE), [en_len], torch.LongTensor([[1]]))

        outputs = outputs.permute(1, 0, 2).cpu().detach().numpy()

        outputs = np.argmax(outputs, axis=2)

        outputs = outputs[0]

        zh_data = []

        for w in outputs:
            if w in zh_idx2word.keys():
                zh_data.append(zh_idx2word[w])
                if zh_idx2word[w] == '<eos>':
                    break
            else:
                zh_data.append('<unk>')

        print(zh_data)

In [55]:
# 他們有危險。
translate(model, 'They are in danger.')

['他', '有', '危', '險', '。', '<eos>']


In [56]:
translate(model, 'I love you.')

['我', '爱', '您', '。', '<eos>']


In [61]:
# 我們要問問湯姆，看看他怎麼想。
translate(model, 'We\'ll ask Tom and see what he thinks.')

['我', '們', '問', '問', '湯', '姆', '看', '看', '他', '他', '什', '麼', '。', '<eos>']


In [62]:
# 他受不了咖啡的苦味。
translate(model, 'He couldn\'t stand the bitterness of the coffee.')

['他', '受', '不', '了', '咖', '啡', '的', '苦', '。', '<eos>']


In [63]:
# 他扔一塊石頭到池塘裡。
translate(model, 'He threw a stone into the pond.')

['他', '扔', '一', '塊', '石', '塘', '到', '池', '塘', '。', '<eos>']


In [65]:
translate(model, 'I like eating bread.')

['我', '喜', '歡', '面', '。', '<eos>']


In [66]:
# 她一週之內會回來。
translate(model, 'She will be back within a week.')

['一', '週', '之', '內', '就', '回', '家', '。', '<eos>']
