In [1]:
# autoreload modules on every saved changes
%load_ext autoreload
%autoreload 2

In [3]:
from core.dataset import Dataset

Dataset.download_all()

parallel-corpora-en-id-master.zip: 18.4MiB [00:01, 10.8MiB/s]
Indonesian-English-Bilingual-Corpus-master.zip: 3.00MiB [00:00, 6.21MiB/s]
TALPCo-master.zip: 717kiB [00:01, 553kiB/s]  


In [2]:
from core.dataset import Dataset

SRC_LANG = 'en'
TGT_LANG = 'id'

UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3

dataset = Dataset.load_dataset(init_vocab=['<unk>', '<pad>', '<sos>', '<eos>'])

Tokenizing data (en): 100%|██████████| 315508/315508 [00:32<00:00, 9729.01it/s] 
Tokenizing data (id): 100%|██████████| 315508/315508 [00:39<00:00, 7997.76it/s] 


In [9]:
import pandas as pd
import numpy as np

pd.DataFrame(np.transpose([dataset.data[lang].sent[-10:] for lang in dataset.langs]), columns=dataset.langs).head(10)

Unnamed: 0,en,id
0,"the opposite of ""heavy"" is ""light"".","lawan kata ""berat"" adalah ""ringan""."
1,you cannot go into this room.,tidak boleh masuk kamar ini.
2,"if you press here, the sound will get louder.","kalau ditekan di sini, bunyinya akan menjadi b..."
3,mr. lee and i are in the same group.,saya sekelompok dengan bapak lee.
4,i always read a book for only thirty minutes b...,saya selalu membaca buku tiga puluh menit saja...
5,i ate only one snack.,saya makan kue sebuah saja.
6,"b: one, two, three, four...","b: satu, dua, tiga, empat...."
7,a: how many are there?,a: ada berapa?
8,there are seven eggs in the refrigerator.,di dalam kulkas ada tujuh butir telur.
9,i always go to work at nine in the morning.,saya pergi ke perusahaan setiap pagi pukul sem...


In [11]:
for lang in dataset.langs:
    print('vocab ({}): {}'.format(lang, len(dataset.data[lang].vocab)))

vocab (en): 6667439
vocab (id): 5841565


In [5]:
import torch
import torch.nn as nn
import torch.functional as F
import math

from torch import Tensor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, max_length: int=1000):
        super(PositionalEncoding, self).__init__()

        pos = torch.arange(0, max_length).reshape(max_length, 1)
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)

        pos_emb = torch.zeros((max_length, emb_size))
        pos_emb[:, 0::2] = torch.sin(pos * den)
        pos_emb[:, 1::2] = torch.cos(pos * den)
        pos_emb = pos_emb.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_emb', pos_emb)

    def forward(self, token_emb):
        return self.dropout(token_emb + self.pos_emb[:token_emb.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()

        self.emb = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.emb(tokens.long()) * math.sqrt(self.emb_size)

class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                encoder_layers: int,
                decode_layers: int,
                emb_size: int,
                nhead: int,
                source_vocab_size: int,
                target_vocab_size: int,
                dim_feedforward: int,
                dropout: float):
                
        super(Seq2SeqTransformer, self).__init__()

        self.pos_encoding = PositionalEncoding(emb_size, dropout=dropout)

        self.source_token_emb = TokenEmbedding(source_vocab_size, emb_size)
        self.target_token_emb = TokenEmbedding(target_vocab_size, emb_size)
        
        self.transformer = nn.Transformer(d_model=emb_size,
                                    nhead=nhead,
                                    num_encoder_layers=encoder_layers,
                                    num_decoder_layers=decode_layers,
                                    dim_feedforward=dim_feedforward,
                                    dropout=dropout)

        self.generator = nn.Linear(emb_size, target_vocab_size)
    
    def forward(self,
                source: Tensor,
                target: Tensor,
                source_mask: Tensor,
                target_mask: Tensor,
                source_padding_mask: Tensor,
                target_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):

        source_emb = self.pos_encoding(self.source_token_emb(source))
        target_emb = self.pos_encoding(self.target_token_emb(target))

        out_seq = self.transformer(source_emb, target_emb, source_mask, target_mask,
                                None, source_padding_mask, target_padding_mask, memory_key_padding_mask)

        return self.generator(out_seq)

    def encode(self, source: Tensor, source_mask: Tensor):
        return self.transformer.encoder(self.pos_encoding(self.source_token_emb(source)), source_mask)
    
    def decode(self, target: Tensor, memory: Tensor, target_mask: Tensor):
        return self.transformer.decoder(self.pos_encoding(self.target_token_emb(target)), memory, target_mask)
  

In [9]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(source, target):
    source_seq_len = source.shape[0]
    target_seq_len = target.shape[0]

    target_mask = generate_square_subsequent_mask(target_seq_len)
    source_mask = torch.zeros((source_seq_len, source_seq_len),device=device).type(torch.bool)

    source_padding_mask = (source == PAD_IDX).transpose(0, 1)
    target_padding_mask = (target == PAD_IDX).transpose(0, 1)
    return source_mask, target_mask, source_padding_mask, target_padding_mask

In [23]:
transformer = Seq2SeqTransformer(3, 3, 512, 8,
                            len(dataset.data[SRC_LANG].vocab), len(dataset.data[TGT_LANG].vocab), 512, 0.1)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(device)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [11]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, DataLoader

def vocab_tf(vocab, tokens):
    return [SOS_IDX] + [vocab[tok] if tok in vocab.keys() else UNK_IDX for tok in tokens] + [EOS_IDX]

def feature_tensors(vocab, sent_tokens):
    tensors = [torch.Tensor(vocab_tf(vocab, tokens)) for tokens in sent_tokens]

    return pad_sequence(tensors, padding_value=PAD_IDX)

source = feature_tensors(dataset.data[SRC_LANG].vocab, dataset.data[SRC_LANG].sent_tokens)
target = feature_tensors(dataset.data[TGT_LANG].vocab, dataset.data[TGT_LANG].sent_tokens)

tensor_dataset = TensorDataset(source, target)
dataloader = DataLoader(tensor_dataset)

AssertionError: Size mismatch between tensors

In [None]:
from torch.utils.data import TensorDataset, DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(val_dataloader)

In [None]:
num_epochs = 18

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}"))

# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys

# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")



In [None]:
# Plot loss and Accuracy
