In [None]:
# Here we fine-tune mBART-50 for machine translation and supply an additional context to the decoder to improve intertextuality.

# we will use the Hugging Face transformers library
from transformers import MBartTokenizer, MBartForConditionalGeneration, Trainer, TrainingArguments
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import pipeline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
import os
import random
import json
from typing import List, Dict, Any, Tuple

# Reproducibility setup: pin the cuDNN flags, then seed every RNG used here
# (PyTorch CPU/CUDA, NumPy, and Python's own `random`) with the same value.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

seed = 42
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Load the parallel corpus (presumably English-French, judging by the file
# name) and clean it in one pass: drop rows with missing values, drop exact
# duplicate rows, then renumber the index from 0.
data = (
    pd.read_csv('data/eng_fra.csv')
    .dropna()
    .drop_duplicates()
    .reset_index(drop=True)
)
# Preview of the first rows; the value is not assigned or printed here.
data.head()

# Hold out 10% of the rows as the test set (seeded so the split is
# repeatable) and give each split a fresh 0-based index.
from sklearn.model_selection import train_test_split
train_data, test_data = (
    split.reset_index(drop=True)
    for split in train_test_split(data, test_size=0.1, random_state=seed)
)


# need to add a callback to the trainer to supply the context vector to the decoder
class ContextCallback:
    """Inject a fixed textual context into a batch as ``decoder_input_ids``.

    The instance keeps a tokenizer and a list of context strings. Calling the
    instance tokenizes that context and stores the resulting ids in the batch
    dictionary under the ``decoder_input_ids`` key.
    """

    def __init__(self, tokenizer: "MBartTokenizer", context: List[str]):
        """Store the tokenizer and the context strings.

        Args:
            tokenizer: Tokenizer that is invoked as
                ``tokenizer(context, return_tensors='pt', padding=True)`` and
                whose result exposes an ``input_ids`` attribute.
            context: Context strings; they are tokenized on every call.
        """
        self.tokenizer = tokenizer
        self.context = context

    def __call__(self, model, inputs):
        """Add the tokenized context to ``inputs`` and return it.

        Args:
            model: The model the batch is destined for. Only its ``device``
                attribute is read, and only if it has one.
            inputs: Mutable mapping holding the batch; it is updated in place.

        Returns:
            The same ``inputs`` object, now containing ``decoder_input_ids``.
        """
        encoded = self.tokenizer(self.context, return_tensors='pt', padding=True)
        decoder_ids = encoded.input_ids

        # The tokenizer output is not placed on any particular device. Move it
        # next to the model so a GPU-resident model does not hit a device
        # mismatch. Objects without a ``device`` attribute (e.g. a plain
        # nn.Module or None) keep the previous behaviour.
        target_device = getattr(model, 'device', None)
        if target_device is not None:
            decoder_ids = decoder_ids.to(target_device)

        # NOTE(review): the number of context strings presumably has to match
        # the batch size of ``inputs`` -- confirm against the caller.
        inputs['decoder_input_ids'] = decoder_ids
        return inputs