In [None]:
import os
import gc
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sentencepiece import SentencePieceTrainer, SentencePieceProcessor
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')

from tqdm.autonotebook import tqdm

import wandb
from pathlib import Path
Path('/kaggle/working/model').mkdir(parents=True, exist_ok=True)

In [None]:
# Optional Weights & Biases login, left disabled.
# NOTE(review): if re-enabled, read the key from a secret / environment
# variable rather than committing it in the notebook.
# key = ''
# wandb.login(key=key)

In [None]:
# Ids of the special tokens used throughout the notebook.
# PAD_IDX is passed to SentencePieceTrainer as `pad_id` and to the loss as
# `ignore_index`; BOS_IDX / EOS_IDX wrap every encoded sentence.
# NOTE(review): 0/1/2 presumably match SentencePiece's default unk/bos/eos
# ids — confirm against the trained tokenizer models.
UNK_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
PAD_IDX = 3

### Data (loaders, preprocessing)

In [None]:
def sequential_transforms(*transforms):
    """Compose ``transforms`` into one callable applied left to right.

    Args:
        *transforms: callables; each receives the previous one's output.

    Returns:
        A function taking the raw input and returning the result of feeding
        it through every transform in order (the input itself if no
        transforms were given).
    """
    def apply_all(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result

    return apply_all


def yield_tokens(tokenizer, texts):
    """Lazily yield every token of every text, showing a progress bar.

    Args:
        tokenizer: callable mapping one text sample to an iterable of tokens.
        texts: iterable of text samples.

    Yields:
        The tokens of each sample, in order, flattened across samples.
    """
    progress = tqdm(texts, desc="Building vocab")
    for sample in progress:
        yield from tokenizer(sample)



class TextDataset(Dataset):
    """Line-aligned German -> English text dataset tokenized with SentencePiece.

    Each item is a tuple of 1-D int64 tensors ``(src_ids,)`` for the ``'test'``
    split or ``(src_ids, tgt_ids)`` otherwise; every sequence is wrapped in
    ``BOS_IDX`` / ``EOS_IDX``.

    Tokenizer models are stored as ``tokenizer_src.model`` /
    ``tokenizer_tgt.model`` in the current working directory. They are trained
    on this dataset's files only when the model file does not exist yet, so a
    later split (or a different ``vocab_size``) silently reuses the cached model.

    Args:
        type: split name; ``'test'`` means there is no target side.
        path_src: path to the source (German) text file, one sentence per line.
        path_tgt: path to the target (English) text file; ignored for ``'test'``.
        path_vocab: accepted for backward compatibility; currently unused.
        vocab_size: vocabulary size used when a tokenizer has to be trained.
    """

    def __init__(self, type, path_src, path_tgt='', path_vocab='vocabs', vocab_size=50000):
        self.type = type
        self.path_src = path_src
        self.lang_src = 'de'
        self.path_tgt = path_tgt
        self.lang_tgt = 'en'

        self.vocab_size = vocab_size

        self.texts = {'de': [], 'en': []}
        self.texts[self.lang_src] = self._read_lines(self.path_src)
        if self.type != 'test':
            self.texts[self.lang_tgt] = self._read_lines(self.path_tgt)

        self.token_transform = {}
        self.token_transform[self.lang_src] = self._load_tokenizer('tokenizer_src', self.path_src)
        if self.type != 'test':
            self.token_transform[self.lang_tgt] = self._load_tokenizer('tokenizer_tgt', self.path_tgt)

        # text -> token ids -> tensor wrapped in BOS/EOS
        self.text_transform = {}
        self.text_transform[self.lang_src] = sequential_transforms(
            self.token_transform[self.lang_src].encode, TextDataset.tensor_transform)
        if self.type != 'test':
            self.text_transform[self.lang_tgt] = sequential_transforms(
                self.token_transform[self.lang_tgt].encode, TextDataset.tensor_transform)

    @staticmethod
    def _read_lines(path):
        """Read a UTF-8 text file into a list of right-stripped lines."""
        with open(path, encoding="utf-8") as f:
            return [line.rstrip() for line in f]

    def _load_tokenizer(self, prefix, corpus_path):
        """Load ``<prefix>.model``, training it on ``corpus_path`` if it is missing."""
        model_file = prefix + '.model'
        if not os.path.isfile(model_file):
            SentencePieceTrainer.train(
                input=corpus_path, vocab_size=self.vocab_size,
                model_type='word', model_prefix=prefix,
                normalization_rule_name='nmt_nfkc_cf',
                pad_id=PAD_IDX,
                unk_surface='<unk>'
            )
        return SentencePieceProcessor(model_file=model_file)

    @staticmethod
    def tensor_transform(token_ids):
        """Wrap a list of token ids in BOS/EOS and return it as an int64 tensor."""
        # Explicit dtype: torch.tensor([]) would be float32 for an empty line.
        return torch.cat((torch.tensor([BOS_IDX]),
                          torch.tensor(token_ids, dtype=torch.long),
                          torch.tensor([EOS_IDX])))

    def __len__(self):
        return len(self.texts[self.lang_src])

    def __getitem__(self, index):
        src = self.text_transform[self.lang_src](self.texts[self.lang_src][index])
        if self.type == 'test':
            return (src,)
        tgt = self.text_transform[self.lang_tgt](self.texts[self.lang_tgt][index])
        return (src, tgt)

    def collate(self, batch):
        """Pad a list of items into ``(maxlen, batch)`` int64 tensors.

        Source and target are padded with ``PAD_IDX`` to the same ``maxlen``
        (the longest sequence on either side of the batch).

        Returns:
            ``src_batch`` for the ``'test'`` split, else ``(src_batch, tgt_batch)``.
        """
        has_target = self.type != 'test'

        lengths = [sample[0].shape[0] for sample in batch]
        if has_target:
            lengths += [sample[1].shape[0] for sample in batch]
        maxlen = max(lengths)

        src_batch = self._pad_column(batch, 0, maxlen)
        if not has_target:
            return src_batch
        return src_batch, self._pad_column(batch, 1, maxlen)

    @staticmethod
    def _pad_column(batch, column, maxlen):
        """Stack ``sample[column]`` of every sample as columns of a padded tensor."""
        padded = torch.full((maxlen, len(batch)), PAD_IDX, dtype=torch.long)
        for j, sample in enumerate(batch):
            seq = sample[column]
            padded[:seq.shape[0], j] = seq
        return padded


### Model architecture

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to token embeddings, then dropout.

    The table is stored as the non-trainable buffer ``pos_embedding`` with
    shape ``(maxlen, 1, emb_size)`` so it broadcasts over the batch axis of a
    ``(seq_len, batch, emb_size)`` input.

    Args:
        emb_size: embedding dimension (even or odd).
        dropout: dropout probability applied after adding the encoding.
        maxlen: number of positions precomputed; longer inputs are not supported.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # One frequency per (sin, cos) pair: 10000^(-2i / emb_size).
        den = torch.exp(- torch.arange(0, emb_size, 2) * np.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # For odd emb_size there is one cos column fewer than sin columns, so
        # drop the last frequency; for even sizes the slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(pos * den[: emb_size // 2])
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        """Return ``dropout(token_embedding + encoding)``; the first axis is the position."""
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        """Cast ``tokens`` to int64, look them up and scale the vectors."""
        scale = np.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` with token embeddings and an output projection.

    Source and target have separate token embeddings but share one
    ``PositionalEncoding`` module; ``generator`` projects decoder outputs to
    target-vocabulary logits.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        transformer_kwargs = dict(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(**transformer_kwargs)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def _embed_source(self, src):
        # token embedding followed by positional encoding (source side)
        return self.positional_encoding(self.src_tok_emb(src))

    def _embed_target(self, tgt):
        # token embedding followed by positional encoding (target side)
        return self.positional_encoding(self.tgt_tok_emb(tgt))

    def forward(self,
                src,
                trg,
                src_mask,
                tgt_mask,
                src_padding_mask,
                tgt_padding_mask,
                memory_key_padding_mask):
        """Run the full transformer and return target-vocabulary logits."""
        src_emb = self._embed_source(src)
        tgt_emb = self._embed_target(trg)
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src, src_mask):
        """Encode ``src`` into the memory consumed by :meth:`decode`."""
        embedded = self._embed_source(src)
        return self.transformer.encoder(embedded, src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Decode ``tgt`` against ``memory``; returns decoder states, not logits."""
        embedded = self._embed_target(tgt)
        return self.transformer.decoder(embedded, memory, tgt_mask)

def attention_mask(size, device):
    """Build an additive causal mask of shape ``(size, size)`` on ``device``.

    Entry ``[i, j]`` is ``0.0`` for ``j <= i`` and ``-inf`` for ``j > i``, so
    position ``i`` cannot attend to later positions.
    """
    blocked = torch.full((size, size), float('-inf'), device=device)
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt, device):
    """Build the attention and padding masks for one ``(src, tgt)`` batch.

    ``src`` and ``tgt`` are indexed ``[position, batch]``.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False ``(S, S)`` bool tensor, ``tgt_mask`` is
        the causal mask from ``attention_mask`` and the padding masks are
        ``(batch, length)`` bool tensors that are True at ``PAD_IDX``.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # The source may attend everywhere; the target only to earlier positions.
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)
    tgt_mask = attention_mask(tgt_len, device)

    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


### Train and validate

In [None]:
class Trainer:
    """Runs training / validation epochs for a ``Seq2SeqTransformer``-style model.

    Args:
        model: module called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``.
        device: device the model and every batch are moved to.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits_2d, targets_1d)``.
        model_saver: callable ``(val_loss, epoch, model, optimizer, model_path)``
            invoked after every epoch.
    """

    def __init__(self, model, device, optimizer, criterion, model_saver):
        self.model = model.to(device)
        self.device = device

        self.optimizer = optimizer
        self.criterion = criterion

        self.model_saver = model_saver

    def _batch_loss(self, src, tgt):
        """Teacher-forced loss for one ``[position, batch]`` batch."""
        src = src.to(self.device, dtype=torch.long)
        tgt = tgt.to(self.device, dtype=torch.long)

        # The decoder sees tokens [0, T-1) and must predict tokens [1, T).
        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, self.device)
        logits = self.model(src, tgt_input, src_mask, tgt_mask,
                            src_padding_mask, tgt_padding_mask, src_padding_mask)
        return self.criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

    def train_epoch(self, dataloader, desc):
        """Train for one pass over ``dataloader``; return the mean per-batch loss.

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        for src, tgt in tqdm(dataloader, desc):
            self.optimizer.zero_grad()
            loss = self._batch_loss(src, tgt)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Count batches while iterating: len(list(dataloader)) would tokenize
        # and collate the whole dataset a second time just to get a number.
        return total_loss / n_batches

    def evaluate(self, dataloader, desc):
        """Return the mean per-batch loss over ``dataloader`` without updating the model.

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        # No autograd graph is needed for validation; saves memory and time.
        with torch.no_grad():
            for src, tgt in tqdm(dataloader, desc):
                total_loss += self._batch_loss(src, tgt).item()
                n_batches += 1

        return total_loss / n_batches

    def train(self, train_loader, val_loader, n_epochs, start_epoch=0, continue_training=False, model_path='model/saved_model.pth', log=True):
        """Train for epochs ``start_epoch .. n_epochs - 1``, validating and saving after each.

        Args:
            train_loader: batches for :meth:`train_epoch`.
            val_loader: batches for :meth:`evaluate`.
            n_epochs: exclusive upper bound of the epoch counter.
            start_epoch: first epoch index; overridden when ``continue_training``.
            continue_training: restore model/optimizer from ``model_path`` and
                resume with the epoch after the one stored in the checkpoint.
            model_path: checkpoint path used for loading and passed to the saver.
            log: when True, log both losses to wandb after every epoch.
        """
        if continue_training:
            # The checkpoint stores the index of the epoch that was *finished*
            # when it was written, so training resumes with the next one.
            start_epoch = self.load_model(model_path) + 1

        for epoch in range(start_epoch, n_epochs):
            torch.cuda.empty_cache()
            gc.collect()

            train_loss = self.train_epoch(train_loader, f'Training epoch {epoch}/{n_epochs}')
            val_loss = self.evaluate(val_loader, f'Validating epoch {epoch}/{n_epochs}')

            print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")
            self.model_saver(val_loss, epoch, self.model, self.optimizer, model_path)
            if log:
                wandb.log({'train_loss': train_loss, 'val_loss': val_loss})

    def load_model(self, model_path='model/saved_model.pth'):
        """Restore model and optimizer state; return the epoch stored in the checkpoint."""
        # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']
        
        
        
        
class SaveModel:
    """Checkpoint writer called once per epoch.

    Args:
        save_best: if True, write a checkpoint only when the validation loss
            improves on the best seen so far; if False, write one on every call.
        best_val_loss: initial value of the best validation loss.

    ``best_val_loss`` always holds the lowest validation loss seen, in both modes.
    """

    def __init__(self, save_best=False, best_val_loss=torch.inf):
        self.save_best = save_best
        self.best_val_loss = best_val_loss

    def __call__(self, val_loss, epoch, model, optimizer, model_path='/kaggle/working/saved_model.pth'):
        """Save ``epoch`` and the model/optimizer state dicts to ``model_path`` if due."""
        is_best = val_loss < self.best_val_loss
        if is_best:
            self.best_val_loss = val_loss

        if not (is_best or not self.save_best):
            return

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, model_path)
        # Only claim "best" when the loss actually improved.
        if is_best:
            print('New best model with loss {:.5f} is saved'.format(val_loss))
        else:
            print('Model with loss {:.5f} is saved'.format(val_loss))


### Translate

In [None]:
@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len, start_symbol, device):
    """Greedily decode one source sentence.

    Args:
        model: object exposing ``encode``, ``decode`` and ``generator``.
        src: source token ids, indexed ``[position, batch]`` with batch size 1.
        src_mask: attention mask passed to ``model.encode``.
        max_len: maximum length of the returned sequence, start symbol included.
        start_symbol: id of the first target token.
        device: device used for all tensors.

    Returns:
        An int64 tensor of shape ``(length, 1)`` starting with ``start_symbol``;
        it ends with ``EOS_IDX`` if that was produced before ``max_len``.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)

    # The encoder output does not change during decoding, so move it once.
    memory = model.encode(src, src_mask).to(device)
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        tgt_mask = attention_mask(ys.size(0), device).type(torch.bool)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        logits = model.generator(out[:, -1])
        _, next_word = torch.max(logits, dim=1)
        next_word = next_word.item()

        # Build the new token with ys's own dtype/device: using src's dtype
        # (possibly float) would promote the whole sequence to float.
        next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
        ys = torch.cat([ys, next_token], dim=0)
        if next_word == EOS_IDX:
            break
    return ys

@torch.no_grad()
def translate_1(model, src, device, vocab_transform_tgt):
    """Translate a single source sentence with greedy decoding.

    Args:
        model: translation model; switched to eval mode here.
        src: source token ids whose first axis is the token position.
        device: device used for decoding.
        vocab_transform_tgt: callable turning a list of target token ids into
            text (or into an iterable of strings, which are concatenated).

    Returns:
        The decoded text with any literal ``<bos>`` / ``<eos>`` markers removed.
    """
    model.eval()
    num_tokens = src.shape[0]
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
    # Allow the translation to be a few tokens longer than the source.
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX, device=device).flatten()

    ready_tgt_tokens = [int(token) for token in tgt_tokens.tolist()]
    return "".join(vocab_transform_tgt(ready_tgt_tokens)).replace("<bos>", "").replace("<eos>", "")


### Run

In [None]:
# Location of the parallel corpus. File names are built below as
# data_dir + <split prefix> + <language code>.
data_dir = '/kaggle/input/bhw-2-translation-dataset/bhw2-data/data/'

# Per-split file-name prefixes; the language code is appended to each.
train_pth = 'train.de-en.'
val_pth = 'val.de-en.'
test_pth = 'test1.de-en.'

# Translation direction: German source, English target.
src_lang = 'de'
tgt_lang = 'en'

In [None]:
# Model size (passed to Seq2SeqTransformer).
num_encoder_layers = 3
num_decoder_layers = 3
emb_size = 512
nhead = 8
dim_feedforward = 1024

# Data / training settings. vocab_size is used both for tokenizer training
# and as the embedding / output size of the model.
batch_size = 128
vocab_size = 40000
n_epochs = 15

In [None]:
# Seed torch and numpy, then free any cached GPU memory before the run.
SEED = 11
torch.manual_seed(SEED)
# NOTE(review): deterministic=False with benchmark=True presumably favours
# speed over bit-exact reproducibility, so runs may differ despite the seed.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
np.random.seed(SEED)
torch.cuda.empty_cache()
gc.collect()


In [None]:


# Training split (source = German file, target = English file).
train_dataset = TextDataset('train', data_dir+train_pth+src_lang, data_dir+train_pth+tgt_lang, vocab_size=vocab_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate)

# Validation split. The target must be the English file: passing the German
# file twice would score the model against German text run through the
# English tokenizer and make the validation loss meaningless.
val_dataset = TextDataset('val', data_dir+val_pth+src_lang, data_dir+val_pth+tgt_lang, vocab_size=vocab_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=val_dataset.collate)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


transformer = Seq2SeqTransformer(num_encoder_layers, num_decoder_layers, emb_size, nhead, vocab_size, vocab_size, dim_feedforward)

# Xavier-initialise every weight matrix; 1-D parameters keep their defaults.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(device)

optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
# Padding positions do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX, label_smoothing=0.1)

model_saver = SaveModel()
trainer = Trainer(transformer, device, optimizer, criterion, model_saver)



In [None]:
# Start a wandb run, let it watch the model, then train.
run = wandb.init(project='bhw-2')
run.watch(transformer)

# Uses Trainer.train's defaults: no checkpoint resume, wandb logging enabled.
trainer.train(train_loader, val_loader, n_epochs)

### Test

In [None]:
# The test split has no target side; batch_size=1 because translate_1
# decodes a single sentence at a time.
test_dataset = TextDataset('test', data_dir+test_pth+src_lang, vocab_size=vocab_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=test_dataset.collate)

translated = []

for src in tqdm(test_loader):
    # The target-side tokenizer only exists on datasets built with a target file.
    sentence = translate_1(trainer.model, src, device, train_dataset.token_transform[tgt_lang].decode)
    translated.append(sentence)

# Write UTF-8 explicitly, matching how the input files are read, so the
# output does not depend on the platform's default encoding.
with open('/kaggle/working/output_1.txt', "w", encoding="utf-8") as f:
    for sentence in translated:
        f.write(sentence + "\n")

In [None]:
# Final cleanup: collect garbage and close the wandb run.
gc.collect()
wandb.finish()