In [1]:
from architecture.seq2seq_module import Seq2Seq, seq2seq_encoder, seq2seq_decoder
from tokenization.tokenizer_module import language_set, spec_tokens, translation_tokenizer
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import torch
import os
import gc
import copy
from timeit import default_timer as timer

seed = 1996
torch.manual_seed(seed)
torch.mps.manual_seed(seed)
random.seed(seed)

In [2]:
def get_data(file_path, encoding="utf-8"):
    """Read a text corpus and return its non-blank lines, stripped, as a NumPy array.

    Args:
        file_path: path to a plain-text file with one sentence per line.
        encoding: text encoding used to decode the file. Defaults to UTF-8 so that
            accented characters decode identically on every platform (a bare
            ``open`` would use the locale's preferred encoding instead).

    Returns:
        ``np.ndarray`` of the stripped, non-empty lines, in file order.

    NOTE(review): blank lines are dropped per file. When two parallel corpora are
    loaded with separate calls, a blank line present in only one of them shifts
    every later pair out of alignment — confirm blank lines only occur at matching
    positions (e.g. a trailing newline) in both files.
    """
    with open(file_path, encoding=encoding) as f:
        lines = f.read().split("\n")
    stripped = (line.strip() for line in lines)
    return np.array([line for line in stripped if line])


def split_train_test_val(it_data, fr_data, perc_test = 0.1, n_val_cases = 100):
    """Randomly partition a parallel corpus into train / test / validation splits.

    The same shuffled indices are applied to both languages, so row ``k`` of
    ``it_data`` stays paired with row ``k`` of ``fr_data``.

    Args:
        it_data: NumPy array of source (Italian) sentences.
        fr_data: NumPy array of target (French) sentences, aligned with ``it_data``.
        perc_test: fraction of the non-validation rows assigned to the test split.
        n_val_cases: number of rows held out for the validation split.

    Returns:
        ``(train, test, val)``, each a ``language_set(source=..., target=...)``.

    Raises:
        AssertionError: if the corpora differ in length or ``n_val_cases`` is not
            in ``[0, len(it_data))``.
    """
    assert len(it_data) == len(fr_data), "source and target corpora must be aligned line by line"
    assert 0 <= n_val_cases < len(it_data)
    n = len(it_data)
    n_train_test = n - n_val_cases
    n_train = int(n_train_test * (1 - perc_test))
    # A full random permutation drawn from Python's global `random` state.
    indices = random.sample(range(n), n)
    # Explicit end points: slicing the test split as `[n_train:-n_val_cases]`
    # would produce an empty test split when n_val_cases == 0.
    train_indices = indices[:n_train]
    test_indices = indices[n_train:n_train_test]
    val_indices = indices[n_train_test:]
    return (language_set(source=it_data[train_indices], target=fr_data[train_indices]),
            language_set(source=it_data[test_indices], target=fr_data[test_indices]),
            language_set(source=it_data[val_indices], target=fr_data[val_indices]))


root = "data/it_fr/"

# Italian and French sides of the Tatoeba corpus; pairs are matched by row index below.
it_path = os.path.join(root, "Tatoeba.fr-it.it")
fr_path = os.path.join(root, "Tatoeba.fr-it.fr")
it_data, fr_data = get_data(it_path), get_data(fr_path)

# Random split driven by Python's global `random` state.
train_data, test_data, val_data = split_train_test_val(it_data, fr_data)

In [3]:
class translation_dataset(Dataset):
    """Map-style ``Dataset`` over one split of the parallel corpus.

    Item ``idx`` is the ``(source, target)`` sentence pair stored at that position
    of the wrapped ``language_set``; the length is the number of source sentences.
    """

    def __init__(self, data:language_set):
        # Keep a reference to the split; pairs are looked up on demand.
        self.data = data

    def __len__(self):
        source_sentences = self.data.source
        return len(source_sentences)

    def __getitem__(self, idx):
        corpus = self.data
        return corpus.source[idx], corpus.target[idx]

# Wrap each split (train, test, val — in that order) in a Dataset.
train_set, test_set, val_set = (
    translation_dataset(split) for split in (train_data, test_data, val_data)
)

### Tokenizer

In [None]:
VOCAB_SIZE = 10_000
MAX_SEQUENCE_LEN = 20

# Fit the source/target tokenizers on the training split only. test_data is used
# later to select the best checkpoint and val_data for the final check, so learning
# the vocabulary from them would leak held-out text into the model.
tokenizer = translation_tokenizer(VOCAB_SIZE, MAX_SEQUENCE_LEN)
tokenizer.set_tokenizers(language_set(source=list(train_data.source), target=list(train_data.target)))

In [5]:
MODEL_DIR = "models"
# Create the output directory first, in case `save_tokenizer` does not create it itself.
os.makedirs(MODEL_DIR, exist_ok=True)
tokenizer.save_tokenizer(MODEL_DIR)

In [None]:
# Show the vocabulary size and a small sample instead of dumping every entry.
# NOTE(review): assumes `src_wrap.vocab` is a token -> id mapping (as exposed by
# Hugging Face fast tokenizers); confirm against tokenization.tokenizer_module.
src_vocab = tokenizer.src_wrap.vocab
print(f"Source vocabulary size: {len(src_vocab)}")
print(sorted(src_vocab.items(), key=lambda item: item[1])[:20])

In [None]:
import random

# Spot-check both tokenizers on a few random training pairs: tokens, ids, and the
# decoded round trip. `i` (the last sampled index) remains defined after the loop.
for _ in range(5):
    # randrange excludes its upper bound; randint(0, len) could return len and raise IndexError.
    i = random.randrange(len(train_data.source))
    src_text, trg_text = train_data.source[i], train_data.target[i]
    src_ids = tokenizer.src_wrap.encode(src_text)
    trg_ids = tokenizer.trg_wrap.encode(trg_text)

    print(tokenizer.src_wrap(src_text).tokens(),
          tokenizer.trg_wrap(trg_text).tokens())

    print(src_ids, trg_ids)

    print(tokenizer.src_wrap.decode(src_ids),
          tokenizer.trg_wrap.decode(trg_ids))
    print()

In [None]:
# Use an explicit example index rather than the `i` left over from the sampling
# loop in another cell, so this cell gives the same result on Restart & Run All.
word_ids_idx = 0
print(tokenizer.src_wrap(train_data.source[word_ids_idx]).word_ids())

### Define model

In [9]:
DEVICE = "mps"
BATCH_SIZE = 256

# Vocabulary sizes are taken from the fitted tokenizers.
SRC_VOCAB_SIZE, TRG_VOCAB_SIZE = len(tokenizer.src_wrap), len(tokenizer.trg_wrap)

ENCODER_EMBEDDING_DIM = DECODER_EMBEDDING_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 2
ENCODER_DROPOUT = DECODER_DROPOUT = 0.5

encoder = seq2seq_encoder(
    SRC_VOCAB_SIZE, ENCODER_EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, ENCODER_DROPOUT,
    bidirectional=True,
)

# The encoder is bidirectional, hence enc_hidden_dim = 2 * HIDDEN_DIM
# (presumably forward and backward states concatenated — confirm in seq2seq_module).
decoder = seq2seq_decoder(
    TRG_VOCAB_SIZE, DECODER_EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS, DECODER_DROPOUT,
    attention=True, enc_hidden_dim=2 * HIDDEN_DIM,
)

model = Seq2Seq(encoder, decoder, DEVICE).to(DEVICE)
model.init_weights()

In [10]:
# Re-seed right before building the loaders and training so the run does not depend
# on how many random draws the exploratory cells above consumed. Seed the same
# generators as the first cell: PyTorch, MPS, and Python's `random`.
seed = 1996
torch.manual_seed(seed)
torch.mps.manual_seed(seed)
random.seed(seed)

def collate_func(batch):
    """Turn a list of ``(source, target)`` sentence pairs into two LongTensors.

    Uses the module-level ``tokenizer``, called as ``tokenizer(sources, targets)``
    and returning the source and target token ids. Those ids are presumably padded
    to a common length so they form rectangular tensors — confirm in the tokenizer module.
    """
    sources = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]
    src_tokens, trg_tokens = tokenizer(sources, targets)
    src_batch = torch.tensor(src_tokens, dtype=torch.long)
    target_batch = torch.tensor(trg_tokens, dtype=torch.long)
    return src_batch, target_batch


In [11]:
# Reshuffle the training pairs every epoch so consecutive epochs don't see identical
# batches; the test loader keeps a fixed order so validation losses stay comparable.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_func)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_func)

In [12]:
# Positions holding the padding id are excluded from the loss.
pad_token_id = tokenizer.pad_id()
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Adam with betas/eps set explicitly instead of PyTorch's defaults.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-9,
)

### Training functions

In [13]:
def train_fn(model, data_loader, optimizer, device, teacher_ratio, clip=1, criterion=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: seq2seq model called as ``model(src, trg, teacher_ratio)``; the last
            dimension of its returned logits is taken as the number of target classes.
        data_loader: sized iterable of ``(src, trg)`` token-id batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        teacher_ratio: teacher-forcing ratio forwarded to the model.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.
        criterion: loss function; defaults to the notebook-level ``loss_fn``.

    Returns:
        Sum of batch losses divided by ``len(data_loader)``.
    """
    criterion = loss_fn if criterion is None else criterion
    model.train()
    epoch_loss = 0
    for src, trg in data_loader:
        src, trg = src.to(device), trg.to(device)
        logits = model(src, trg, teacher_ratio)
        # Step 0 (presumably the start-of-sequence position) is not scored.
        n_classes = logits.size(-1)
        loss = criterion(logits[:, 1:, :].reshape(-1, n_classes), trg[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

def evaluate_fn(model, data_loader, device, criterion=None):
    """Compute the mean per-batch loss over ``data_loader`` without updating ``model``.

    Args:
        model: seq2seq model called as ``model(src, trg, 0)``; the last dimension
            of its returned logits is taken as the number of target classes.
        data_loader: sized iterable of ``(src, trg)`` token-id batches.
        device: device each batch is moved to.
        criterion: loss function; defaults to the notebook-level ``loss_fn``.

    Returns:
        Sum of batch losses divided by ``len(data_loader)``.
    """
    criterion = loss_fn if criterion is None else criterion
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for src, trg in data_loader:
            src, trg = src.to(device), trg.to(device)
            # Teacher ratio 0 — presumably the decoder is fed its own predictions, as at inference.
            logits = model(src, trg, 0)
            n_classes = logits.size(-1)
            # Step 0 is not scored, matching the training loss.
            loss = criterion(logits[:, 1:, :].reshape(-1, n_classes), trg[:, 1:].reshape(-1))
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

def translate_sentence( sentence:str, model:Seq2Seq, max_len = 25) -> str:
    """Greedily decode a translation of ``sentence``, one token per step.

    Args:
        sentence: source-language text, encoded with the module-level ``tokenizer``.
        model: ``Seq2Seq`` whose ``encoder`` and ``decoder`` are called directly.
        max_len: maximum number of tokens generated after the start token.

    Returns:
        The string produced by passing the generated id sequence to ``tokenizer.decode``
        (start token included, plus the end token if one was produced) and joining it.

    NOTE(review): tensors are placed on the module-level ``DEVICE``, not on the
    model's own device.
    """
    model.eval()
    with torch.no_grad():
        # Only the source side matters here: "" is a placeholder target and [0] keeps the
        # source ids; unsqueeze(0) presumably adds a batch dimension of 1.
        tensor = torch.tensor( tokenizer(sentence, "")[0] ).unsqueeze(0).to(DEVICE)
        encoder_output, hidden, cell_state = model.encoder(tensor)
        # NOTE(review): the decoder starts from all-zero states and the encoder's final
        # `hidden` / `cell_state` are discarded (the cell is also sized from `hidden`).
        # Confirm this matches how Seq2Seq.forward initializes the decoder in training;
        # if forward passes the encoder state on, inference diverges from training.
        decoder_hidden = torch.zeros_like(hidden)  # Initialize decoder hidden state
        decoder_cell = torch.zeros_like(hidden)
        # Generated sequence, seeded with the start-of-sequence id; shape (1, 1).
        out_tokens = torch.ones(1, 1).fill_(tokenizer.sos_id()).type(torch.long).to(DEVICE)
        for i in range(max_len):
            # Feed only the most recent token (column i is the last one) with the encoder outputs.
            output, decoder_hidden, decoder_cell  = model.decoder(out_tokens[:,i].unsqueeze(0), decoder_hidden, decoder_cell, encoder_output)
            # Greedy choice: index of the highest-probability token.
            _, predicted_token = torch.max(torch.softmax(output, dim=-1), dim=1)
            predicted_token = predicted_token[-1].item()
            # Append the predicted id as a new column.
            out_tokens = torch.cat([out_tokens, torch.ones(1, 1).type_as(out_tokens.data).fill_(predicted_token).to(DEVICE)], dim=1)
            if predicted_token == tokenizer.eos_id():
                break
        return "".join( tokenizer.decode( out_tokens.cpu().tolist()[0]) )

In [None]:
# Smoke test of the decoding path with the freshly initialized model.
first_test_sentence = test_data.source[0]
translate_sentence(first_test_sentence, model)

In [None]:
# Inspect the decoder's `rnn` submodule (shown as the cell's output).
decoder.rnn

In [None]:
NUM_EPOCHS=20
CHECKPOINT_DIR = "models"

gc.collect()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

train_loss_list, valid_loss_list = [], []
torch.mps.empty_cache()

# torch.save does not create missing parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

best_valid_loss = float("inf")
best_model = copy.deepcopy(model)
sentence = test_data.source[0]
for epoch in range(NUM_EPOCHS):
    start_time = timer()
    train_loss = train_fn( model, train_loader, optimizer, DEVICE, teacher_ratio=0.5)
    # The test split doubles as the validation set used to pick the best checkpoint.
    valid_loss = evaluate_fn( model, test_loader, DEVICE )
    # Record per-epoch losses so the learning curves can be inspected after training.
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, "checkpoint.pt"))
    end_time = timer()
    print(f"Epoch: {epoch+1:02}\t time = {(end_time - start_time):.3f}s")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")
    print(f"Original text: {sentence}")
    print(f"Translated text: {translate_sentence(sentence, best_model)}")
    print()

In [19]:
# Save the whole best model object (a pickled module, not just its state_dict).
# NOTE(review): loading this file requires the `architecture.seq2seq_module` classes
# to be importable and, on PyTorch versions where torch.load defaults to
# weights_only=True, an explicit torch.load(..., weights_only=False). Consider
# saving best_model.state_dict() instead if downstream loaders can be updated.
torch.save(best_model, 'models/model.pth')

In [None]:
# Qualitative check on a validation sentence (not used to fit the tokenizer, train, or select the model).
n_val = 0
sentence = val_data.source[n_val]
translation = translate_sentence(sentence, best_model)

for label, text in (("Original text", sentence), ("Translated text", translation)):
    print(f"{label}: {text}")

In [None]:
# Reference French translation of the same validation example, for comparison.
reference_translation = val_data.target[n_val]
print(reference_translation)