In [None]:
from transformers import MBartTokenizer, MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from datasets import load_metric
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import re
import random

In [None]:
# Hyper params
batch_size = 1  # examples per batch; handed to both the train and eval DataLoader
epochs = 2  # number of full passes over the training DataLoader
learning_rate = 4e-5  # initial learning rate given to the AdamW optimizer

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def seed_data(seed_val=42):
    """Seed the random number generators this notebook relies on.

    The same ``seed_val`` is applied, in order, to Python's ``random``
    module, NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed_val: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_val)

# Simple EDA

In [None]:
# Load the English/French sentence pairs; the path is relative to the working directory.
df = pd.read_csv('data/eng_-french.csv')

In [None]:
# Preview the first rows of the loaded frame.
df.head()

In [None]:
# Preview the last rows of the loaded frame.
df.tail()

In [None]:
# Add a whitespace-delimited word count column for each language.
for length_col, text_col in (
    ('english_length', 'English words/sentences'),
    ('french_length', 'French words/sentences'),
):
    df[length_col] = df[text_col].apply(lambda sentence: len(sentence.split()))

In [None]:
# Longest and shortest English entry, measured in whitespace-separated words.
df['english_length'].max(), df['english_length'].min()

In [None]:
# Longest and shortest French entry, measured in whitespace-separated words.
# Both values must come from the French column; mixing in the English minimum
# would misreport the range.
df['french_length'].max(), df['french_length'].min()

# Clean data & Split

In [None]:
def remove_special_char(x):
    """Strip punctuation and symbols from ``x``, keeping letters and digits.

    Each run of characters that are neither letters nor digits is replaced
    by a single space when ``x`` holds more than one whitespace-separated
    token, and deleted outright otherwise (so a lone word stays one word).

    Matching is Unicode-aware, so accented letters are preserved. An
    ASCII-only character class would silently delete them and corrupt the
    French sentences. For pure ASCII input the result is the same as with
    an ASCII-only class.

    Args:
        x: Raw word or sentence.

    Returns:
        The cleaned string. It may still have leading or trailing spaces;
        whitespace normalization is a separate step.
    """
    replacement = ' ' if len(x.split()) > 1 else ''
    # [\W_] matches everything except Unicode letters and digits: \W alone
    # would let the underscore through because it counts as a word character.
    return re.sub(r'[\W_]+', replacement, x)
    
def remove_empty_last_space(x):
    """Normalize whitespace in ``x``.

    Leading and trailing whitespace is dropped and every internal run of
    whitespace is collapsed to a single space.

    Args:
        x: String to normalize.

    Returns:
        The tokens of ``x`` joined by single spaces ('' for blank input).
    """
    # str.split() with no separator never yields empty or whitespace-only
    # items, so no additional filtering of the tokens is needed.
    return ' '.join(x.split())

In [None]:
# Clean both text columns: strip special characters first, then normalize whitespace.
for text_col in ('English words/sentences', 'French words/sentences'):
    df[text_col] = (
        df[text_col]
        .apply(remove_special_char)
        .apply(remove_empty_last_space)
    )

In [None]:
# Hold out 20% of the rows, then halve the hold-out into eval and test sets.
train_df, holdout_df = train_test_split(df, test_size=0.2, random_state=2020)
eval_df, test_df = train_test_split(holdout_df, test_size=0.5, random_state=2021)

# Prepare dataset

In [None]:
# The target column holds French text, so the target language code must be
# fr_XX (es_XX is Spanish and would tag every target as the wrong language).
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="fr_XX")

In [None]:
class Seq2SeqDataset(Dataset):
    """Dataset that tokenizes one English/French sentence pair per item.

    Each item is read from the ``'English words/sentences'`` and
    ``'French words/sentences'`` columns of ``df`` and tokenized on the fly.

    Args:
        tokenizer: Tokenizer exposing ``prepare_seq2seq_batch``,
            ``pad_token_id`` and ``lang_code_to_id``.
        df: DataFrame holding the two text columns named above.
        max_length: Length every sequence is padded/truncated to.
        src_lang: Language code of the source text.
        tgt_lang: Language code of the target text. Defaults to French
            (``fr_XX``) because the target column is French.
    """

    def __init__(self, tokenizer, df, max_length=55, src_lang='en_XX', tgt_lang='fr_XX'):
        self.tokenizer = tokenizer
        self.df = df
        self.max_length = max_length
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        """Return the number of sentence pairs (rows of ``df``)."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Tokenize the pair at positional ``index``.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids``
            and ``labels``; the leading batch axis of size one produced by
            the tokenizer is removed from each tensor. Padding positions in
            ``labels`` are set to -100.
        """
        row = self.df.iloc[index]
        input_text = row['English words/sentences']
        target_text = row['French words/sentences']

        # Use the tokenizer handed to this dataset, not whatever global
        # happens to be called `tokenizer` in the notebook.
        tok = self.tokenizer
        tokenized_example = tok.prepare_seq2seq_batch(
            src_texts=[input_text],
            tgt_texts=[target_text],
            src_lang=self.src_lang,
            tgt_lang=self.tgt_lang,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
        )
        target_ids = tokenized_example["labels"]

        # Decoder inputs are built from a copy that still contains real pad
        # ids, before padding is replaced by -100 in the labels.
        decoder_input_ids = shift_tokens_right(
            target_ids.clone(), tok.pad_token_id, tok.lang_code_to_id[self.tgt_lang]
        )
        labels = target_ids.masked_fill(target_ids == tok.pad_token_id, -100)

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop the sequence axis if it ever had length one.
        return {
            'input_ids': tokenized_example['input_ids'].squeeze(0),
            'attention_mask': tokenized_example['attention_mask'].squeeze(0),
            'decoder_input_ids': decoder_input_ids.squeeze(0),
            'labels': labels.squeeze(0)
        }

In [None]:
# Training split: examples are drawn in random order.
train_dataset = Seq2SeqDataset(tokenizer=tokenizer, df=train_df)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Evaluation split: examples are visited in their stored order.
eval_dataset = Seq2SeqDataset(tokenizer=tokenizer, df=eval_df)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
    sampler=eval_sampler,
)

# Prepare model

In [None]:
# Load the pretrained 'facebook/mbart-large-50' checkpoint, then move the model to the selected device.
model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')
model.to(device)

In [None]:
# prepare metric for evaluation
metrics = load_metric('sacrebleu')

In [None]:
# AdamW over every model parameter.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)

# One scheduler step is taken per training batch, so the schedule spans
# (batches per epoch) * (epochs) steps, with no warmup phase.
total_training_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

In [None]:
seed_data(10)

# Mean training loss of each epoch, kept for plotting the learning curve.
loss_values = []

for i in tqdm(range(epochs)):
    # Training
    model.train()
    total_loss = 0

    for batch, data in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
        input_ids = data['input_ids'].to(device)
        attn_mask = data['attention_mask'].to(device)
        decoder_input_ids = data['decoder_input_ids'].to(device)
        labels = data['labels'].to(device)

        model.zero_grad()

        # Forward pass
        outputs = model(
            input_ids,
            attention_mask=attn_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels
        )

        loss, logits = outputs[:2]

        curr_loss = loss.item()
        loss.backward()

        # clip_grad_norm_ (trailing underscore) is the supported in-place
        # API; the un-suffixed name is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += curr_loss

        if batch % 300 == 0:
            # sacreBLEU scores text, not logits or token ids: take the most
            # likely token at each position, blank out the positions the loss
            # ignores (-100), and decode both sides back to strings. The
            # predictions are teacher-forced, so this is a rough progress
            # signal rather than a true generation score.
            ignored = labels == -100
            pad_id = tokenizer.pad_token_id
            pred_ids = logits.detach().argmax(dim=-1).masked_fill(ignored, pad_id)
            ref_ids = labels.masked_fill(ignored, pad_id)

            predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            # sacreBLEU expects a list of reference strings per prediction.
            references = [
                [ref] for ref in tokenizer.batch_decode(ref_ids, skip_special_tokens=True)
            ]

            metrics.add_batch(predictions=predictions, references=references)
            print('Score: {} '.format(metrics.compute()))

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.3f}".format(avg_train_loss))