In [1]:
import os
import time
import string
import itertools
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torchtext.datasets import Multi30k

In [2]:
from modules.data import Vocabulary, PolEngDS, get_loader
from modules.model import Model

In [3]:
@dataclass
class Config:
    """Hyper-parameters for the translation experiment in this notebook.

    A single default-constructed instance is created right below and read
    by the model, optimizer and training-loop cells.
    """

    # --- optimisation ---
    epochs: int = 20              # number of passes run by the training loop
    learning_rate: float = 3e-4   # learning rate handed to the optimizer
    batch_size: int = 4
    limit: int = 100000
    # --- model shape (forwarded to Model(...)) ---
    max_length: int = 50
    embed_size: int = 256
    num_layers: int = 3
    heads: int = 8
    forward_expansion: int = 4
    # Dropout is a probability, so the annotation is float (it was
    # mistakenly declared as int while holding 0.15).
    dropout: float = 0.15


config = Config()

In [4]:
class DeEnDS(Dataset):
    """German/English sentence-pair dataset backed by Multi30k.

    On construction the raw pairs are loaded, normalised with
    ``_text_prep`` and stored in a DataFrame with the columns ``deutsch``
    and ``english``; one ``Vocabulary`` per language is built from the
    whitespace-separated tokens of the corresponding column.

    Indexing returns a ``(de, en)`` pair of ``torch.IntTensor`` holding the
    vocabulary ids of the sentence wrapped in ``<sos>`` ... ``<eos>``.
    """

    # Translation table deleting every ASCII punctuation character.
    # Built once here instead of on every _text_prep call.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self):
        self.data = list(Multi30k('./data_exp/multi30k-dataset/task1')[0])

        self.preprocessing()

    def __getitem__(self, index):
        """Return the encoded (German, English) tensors for row ``index``."""
        # Column order of the frame is (deutsch, english).
        de_text, en_text = self.data.iloc[index].values

        de = self._encode(self.vocab_de, de_text.split())
        en = self._encode(self.vocab_en, en_text.split())

        return de, en

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.data)

    def preprocessing(self):
        """Normalise all raw pairs and build both vocabularies.

        Replaces ``self.data`` (an iterable of ``(de, en)`` string pairs)
        with a DataFrame of cleaned text and sets ``self.vocab_de`` /
        ``self.vocab_en``.
        """
        raw_pairs = self.data

        self.data = pd.DataFrame({
            'deutsch': [self._text_prep(de) for de, _ in raw_pairs],
            'english': [self._text_prep(en) for _, en in raw_pairs],
        })

        self.vocab_de = Vocabulary(self._flat_list(self.data['deutsch']))
        self.vocab_en = Vocabulary(self._flat_list(self.data['english']))

    @staticmethod
    def _encode(vocab, words):
        """Map ``words`` to ids via ``vocab`` and wrap them in <sos>/<eos>."""
        return torch.IntTensor([vocab['<sos>'], *[vocab[word] for word in words], vocab['<eos>']])

    @staticmethod
    def _text_prep(text):
        """Strip ASCII punctuation and surrounding whitespace, then lower-case.

        A former ``text.split('/n')`` statement was removed: its result was
        discarded (and ``'/n'`` is not the newline escape), so it never had
        any effect. Surrounding newlines are removed by ``strip()`` and any
        inner ones by the later whitespace ``split()`` calls.
        """
        text = text.translate(DeEnDS._PUNCT_TABLE)
        return text.strip().lower()

    @staticmethod
    def _flat_list(data):
        """Split every text in ``data`` on whitespace and return one flat token list."""
        return list(itertools.chain.from_iterable(text.split() for text in data))

In [5]:
# Build the dataset once; construction also fits both vocabularies.
train_data = DeEnDS()

# Module-level aliases used by the model / loss cells further down.
vocab_de, vocab_en = train_data.vocab_de, train_data.vocab_en



In [6]:
def pad_seq(batch, padding_de=1, padding_en=1):
    """Collate (de, en) tensor pairs into two right-padded batch tensors.

    Args:
        batch: iterable of ``(de_tensor, en_tensor)`` pairs of 1-D tensors.
        padding_de: fill value for the German batch.
        padding_en: fill value for the English batch.

    Returns:
        ``(de, en)`` tensors, batch dimension first, each padded to the
        longest sequence of its side.
    """
    # NOTE(review): the default fill value 1 presumably equals vocab['<pad>']
    # for both languages - confirm against Vocabulary.
    pairs = list(batch)
    de_seqs = [de_text for de_text, _ in pairs]
    en_seqs = [en_text for _, en_text in pairs]

    return (
        pad_sequence(de_seqs, batch_first=True, padding_value=padding_de),
        pad_sequence(en_seqs, batch_first=True, padding_value=padding_en),
    )

In [7]:
# Shuffled mini-batches, padded per batch by pad_seq.
# NOTE(review): batch_size is hard-coded to 32 here; Config.batch_size (4) is
# not used - confirm which value is intended.
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    collate_fn=pad_seq,
)

In [8]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
# Translation model: German is the source side, English the target side.
model = Model(
    # vocabulary sizes and padding ids come from the fitted vocabularies
    src_vocab_size=len(vocab_de), trg_vocab_size=len(vocab_en),
    src_pad_idx=vocab_de['<pad>'], trg_pad_idx=vocab_en['<pad>'],
    # architecture hyper-parameters come from the notebook config
    embed_size=config.embed_size, num_layers=config.num_layers,
    heads=config.heads, forward_expansion=config.forward_expansion,
    dropout=config.dropout, max_length=config.max_length,
    device=device,
)

In [10]:
# Padding positions in the target are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab_en['<pad>'])

optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)

# Multiply the learning rate by 0.1 once the monitored value has stopped
# improving for more than 5 scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)

In [11]:
def train_epoch(model, loader, epoch, device=device):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Relies on the module-level ``criterion``, ``optimizer`` and
    ``scheduler`` objects. The scheduler is stepped once at the end of the
    epoch with the mean training loss.

    Args:
        model: called as ``model(src, trg[:, :-1])``.
        loader: iterable yielding ``(src, trg)`` batches and supporting ``len()``.
        epoch: zero-based epoch index; printed as ``epoch + 1``.
        device: device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a detached 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.train()

    losses = []
    t0 = time.time()
    t_batch = t0

    # Report roughly five times per epoch. max(1, ...) keeps the modulo below
    # valid when the loader has fewer than five batches (previously this
    # raised ZeroDivisionError).
    log_every = max(1, len(loader) // 5)

    for batch_idx, (src, trg) in enumerate(loader):
        src = src.to(device)
        trg = trg.to(device)

        # Teacher forcing: feed the target without its last token and predict
        # the target shifted left by one.
        scores = model(src, trg[:, :-1])

        # Flatten to (N, classes) vs (N,); dim 2 of scores is used as the class
        # dimension - presumably (batch, seq, vocab), confirm against Model.
        loss = criterion(
            scores.reshape(-1, scores.shape[2]),
            trg[:, 1:].reshape(-1).type(torch.long)
        )

        # Keep only the value for bookkeeping. Storing the graph-attached loss
        # would keep every batch's autograd history alive for the whole epoch.
        losses.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()

        if (batch_idx + 1) % log_every == 0:
            print('Epoch: {epoch}, batch: {batch_idx}/{no_batches}, loss: {loss:.3f}, time: {t:.2f}'.format(
                epoch=epoch+1,
                batch_idx=batch_idx,
                no_batches=len(loader),
                loss=sum(losses)/len(losses),
                t=time.time()-t_batch
                )
            )

            t_batch = time.time()

    loss = sum(losses) / len(losses)

    scheduler.step(loss)

    print('Epoch: {epoch}, loss: {loss:.3f}, time: {t:.2f}'.format(
        epoch=epoch+1,
        loss=loss,
        t=time.time()-t0
        )
    )

    return loss

In [12]:
# Per-epoch mean training losses, in epoch order.
loss = []

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first (expensive) epoch finishes.
os.makedirs('./models', exist_ok=True)

for epoch in range(config.epochs):
    epoch_loss = train_epoch(model, train_loader, epoch)
    loss.append(epoch_loss)

    # One checkpoint per epoch: enough state to resume training, plus the
    # loss history collected so far.
    torch.save({
        'epoch': epoch,
        'model_sd': model.state_dict(),
        'optimizer_sd': optimizer.state_dict(),
        'loss': loss
       }, f'./models/checkpoint-{epoch}.pt')