# Fine-tuning a Machine Translation model

In this notebook we take the multilingual model MBart and fine-tune it for legal translation using a French-English corpus of legal documents. We then use the BLEU score to check whether fine-tuning improves translation quality.


In [None]:
!pip install evaluate

In [None]:
!pip install datasets

# MBart

MBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for translation. A special language ID token (for example `fr_XX` or `en_XX`) is added to both the source and the target text, marking the source language and the target language respectively.
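For the MBart-50 checkpoints, the Hugging Face documentation describes the text format as `[lang_code] X [eos]`, where the language code is the source language for the source text and the target language for the target text. The sketch below reproduces that layout on already-split tokens. It is a toy illustration only: the helper names and token strings are hypothetical, and the real tokenizer works on subword IDs and inserts these tokens itself once `src_lang` and `tgt_lang` are set.

```python
from typing import List, Tuple

EOS = "</s>"  # MBart's end-of-sequence token string


def mbart50_format(tokens: List[str], lang_code: str) -> List[str]:
    """Hypothetical helper: prefix the language code and append EOS."""
    return [lang_code] + tokens + [EOS]


def make_pair(src_tokens: List[str], tgt_tokens: List[str],
              src_lang: str = "fr_XX", tgt_lang: str = "en_XX") -> Tuple[List[str], List[str]]:
    """Build the (source, target) sequences for one French-to-English pair."""
    return mbart50_format(src_tokens, src_lang), mbart50_format(tgt_tokens, tgt_lang)


src, tgt = make_pair(["La", "loi"], ["The", "act"])
print(src)  # ['fr_XX', 'La', 'loi', '</s>']
print(tgt)  # ['en_XX', 'The', 'act', '</s>']
```

At generation time there is no target text to prefix, so the target language code is instead forced as the first generated token through the `forced_bos_token_id` argument of `generate`.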

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, DataCollatorForSeq2Seq, TrainingArguments, Trainer
import pandas as pd
from datasets import Dataset, DatasetDict

In [None]:
device="cuda"
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# Cadlaws corpus

This is an English–French corpus built from Canadian legal documents. The corpus contains over 16 million words in each language and is composed of documents that are legally equivalent in both languages but not the result of a translation. The corpus is built upon enactments co-drafted by two jurists to ensure legal equality of each version and to reflect the concepts, terms and institutions of two legal traditions.

For more information, see:

https://www.researchgate.net/publication/353471306_Cadlaws_-_An_English-French_Parallel_Corpus_of_Legally_Equivalent_Documents

We will use it to fine-tune MBart for legal translation.


In [None]:
!gdown 1CPYP5JNzKzBqlKZMiGdfXD7y_0g-ruus
!unzip cadlaws-fr-en.txt.zip


In [None]:
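# Each line holds a French sentence and its English counterpart, separated by a tab
df = pd.read_csv("cadlaws-fr-en.txt", sep="\t", header=None)
df.columns = ["fr", "en"]
df = df.dropna()      # drop pairs with a missing side
df = df.head(10000)   # keep the first 10,000 pairs to limit training and evaluation time
ds = Dataset.from_pandas(df, preserve_index=False)  # don't carry the pandas index over as a column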

In [None]:
train_test = ds.train_test_split(test_size=500, seed=99)  # hold out exactly 500 pairs for testing

In [None]:
ds_test=train_test["test"]

In [None]:
tokenizer.src_lang, tokenizer.tgt_lang = "fr_XX", "en_XX"  # sets the language ID tokens added to source and target
# Encode French sources and English targets in one call; the targets come back as "labels"
enc = tokenizer(train_test["train"]["fr"], text_target=train_test["train"]["en"], max_length=1024, truncation=True)
ds_pt = Dataset.from_dict({"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"], "labels": enc["labels"]})
columns = ["input_ids", "labels", "attention_mask"]
ds_pt.set_format(type="torch", columns=columns)
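The tokenized examples have different lengths, so at training time `DataCollatorForSeq2Seq` pads them batch by batch: `input_ids` are padded with the tokenizer's pad token, and `labels` are padded by default with -100, the value PyTorch's cross-entropy loss ignores. The plain-Python sketch below mimics that padding for illustration only. The helper `pad_batch` and the token IDs are hypothetical, the pad ID of 1 is an assumption (check `tokenizer.pad_token_id`), and the real collator also builds decoder inputs from the labels and returns tensors.

```python
from typing import Dict, List


def pad_batch(features: List[Dict[str, List[int]]], pad_id: int = 1,
              label_pad_id: int = -100) -> Dict[str, List[List[int]]]:
    """Right-pad inputs and labels to the longest member of the batch.

    pad_id=1 is assumed to be the tokenizer's pad token; label_pad_id=-100
    matches the default ignore_index of PyTorch's cross-entropy loss.
    """
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in = max_in - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * n_in)
        # 1 marks real tokens, 0 marks padding the encoder should ignore
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_in)
        n_lab = max_lab - len(f["labels"])
        # labels are padded separately, to the longest label sequence
        batch["labels"].append(f["labels"] + [label_pad_id] * n_lab)
    return batch


# Made-up token IDs, for illustration only
batch = pad_batch([
    {"input_ids": [5, 10, 11, 2], "labels": [6, 20, 2]},
    {"input_ids": [5, 12, 2], "labels": [6, 21, 22, 23, 2]},
])
print(batch["labels"])  # [[6, 20, 2, -100, -100], [6, 21, 22, 23, 2]]
```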

In [None]:
train_valid = ds_pt.train_test_split(test_size=500, seed=99)  # 500 validation pairs (stored under "test"); the rest is for training
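The two splits carve 500 test pairs out of the 10,000 rows and then 500 validation pairs out of the remaining 9,500, leaving 9,000 pairs for training (assuming the file has at least 10,000 non-empty rows). The sketch below reproduces that arithmetic with a seeded shuffle. It uses a hypothetical helper built on `random.Random`, not the `datasets` shuffling, so the selected indices differ from those of `train_test_split`; only the sizes and the disjointness carry over.

```python
import random
from typing import List, Sequence, Tuple


def seeded_split(items: Sequence[int], n_test: int, seed: int) -> Tuple[List[int], List[int]]:
    """Shuffle with a fixed seed and hold out n_test items (illustrative only)."""
    idx = list(items)
    random.Random(seed).shuffle(idx)
    return idx[n_test:], idx[:n_test]


rows = range(10_000)                                    # df.head(10000)
rest, test = seeded_split(rows, n_test=500, seed=99)    # held-out test set
train, valid = seeded_split(rest, n_test=500, seed=99)  # validation set
print(len(train), len(valid), len(test))  # 9000 500 500
```

Passing an integer `test_size` also sidesteps floating-point surprises: a fraction such as `500/9500` multiplied back by 9,500 may not land exactly on 500.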

First, we evaluate the off-the-shelf MBart model (before any fine-tuning) on the 500 held-out Cadlaws test pairs using BLEU.

In [None]:
import torch
torch.cuda.empty_cache()
model.to(device)
tokenizer.src_lang = "fr_XX"  # the source text is French
predictions = []
for sentence in ds_test["fr"]:
  # Encode one sentence with the tokenizer's __call__ (replaces legacy batch_encode_plus / pad_to_max_length)
  inputs = tokenizer(sentence, max_length=1024, truncation=True, return_tensors="pt").to(device)
  responses_ft = model.generate(**inputs,
                         # MBart-50 needs the target language code forced as the first generated token
                         forced_bos_token_id=tokenizer.convert_tokens_to_ids("en_XX"),
                         num_beams=100,           # very wide beam: slow and memory-hungry
                         no_repeat_ngram_size=2,  # forbids any repeated bigram, which can hurt long legal sentences
                         early_stopping=True,
                         num_return_sequences=1,
                         max_length=1024,
                          )
  predictions.extend(tokenizer.batch_decode(responses_ft, skip_special_tokens=True))

In [None]:
predictions_vanilla = predictions

In [None]:
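import evaluate
# One list of references per prediction (here a single reference each)
references = [[ref] for ref in ds_test["en"]]
bleu = evaluate.load("bleu")
# Pass the lists themselves: str() would collapse each list into one long string
bleu.compute(predictions=predictions_vanilla, references=references)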
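The `bleu` metric computes corpus-level BLEU: the clipped n-gram precisions p_1 to p_4 are combined by a geometric mean and multiplied by a brevity penalty BP = min(1, exp(1 - r/c)), where c is the total candidate length and r the total effective reference length. Each prediction is scored against its own list of references, which is why `references` holds one list per sentence. Below is an unsmoothed plain-Python sketch of that formula, for illustration only: the function `corpus_bleu` is hypothetical and splits on whitespace, whereas `evaluate`'s `bleu` applies its own tokenizer, so the numbers will not match the library exactly.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count the n-grams of a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(predictions: List[str], references: List[List[str]], max_n: int = 4) -> float:
    """Unsmoothed corpus BLEU with whitespace tokenization (illustrative sketch)."""
    matches = [0] * max_n  # clipped n-gram matches, summed over the corpus
    totals = [0] * max_n   # candidate n-gram counts, summed over the corpus
    cand_len = ref_len = 0
    for pred, refs in zip(predictions, references):
        cand = pred.split()
        ref_toks = [r.split() for r in refs]
        cand_len += len(cand)
        # effective reference length: the reference closest in length (shorter on ties)
        ref_len += min((abs(len(r) - len(cand)), len(r)) for r in ref_toks)[1]
        for n in range(1, max_n + 1):
            cand_counts = ngrams(cand, n)
            max_ref: Counter = Counter()
            for r in ref_toks:
                max_ref |= ngrams(r, n)  # element-wise max of counts over references
            matches[n - 1] += sum(min(c, max_ref[g]) for g, c in cand_counts.items())
            totals[n - 1] += sum(cand_counts.values())
    if cand_len == 0 or min(matches) == 0:
        return 0.0  # no smoothing: any zero precision gives BLEU = 0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    brevity_penalty = min(1.0, math.exp(1 - ref_len / cand_len))
    return brevity_penalty * math.exp(log_precision)


score = corpus_bleu(["the act applies to all provinces"], [["the act applies to all provinces"]])
print(score)  # 1.0
```

A candidate that matches only a prefix of its reference keeps perfect precision but is scaled down by the brevity penalty; one with no matching 4-gram scores 0 in this unsmoothed version.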

Next, we fine-tune MBart on the training portion of Cadlaws. This takes many hours, so a downloadable fine-tuned model is provided further below. To run the fine-tuning yourself, uncomment the next few cells.

In [None]:
#seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

#training_args = TrainingArguments(
#    output_dir='fr-en', num_train_epochs=1, warmup_steps=500,
#    per_device_train_batch_size=1, per_device_eval_batch_size=1,
#    weight_decay=0.01, logging_steps=10, push_to_hub=False,
#    # effective batch size = 1 x 128; newer transformers releases rename evaluation_strategy to eval_strategy
#    evaluation_strategy='steps', eval_steps=30, save_steps=1_000_000, gradient_accumulation_steps=128)

#trainer = Trainer(model=model, args=training_args,
#                  tokenizer=tokenizer, data_collator=seq2seq_data_collator,
#                  train_dataset=train_valid["train"],
#                  eval_dataset=train_valid["test"])  # train_test_split names the held-out part "test"
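With `per_device_train_batch_size=1` and `gradient_accumulation_steps=128`, each optimizer step sees 128 examples. Assuming the 9,000 training pairs left by the splits above and a single GPU, one epoch is only about 70 optimizer steps, fewer than `warmup_steps=500`. Under the default linear schedule (warmup from 0 to the peak learning rate, then linear decay to 0), the learning rate would therefore never leave the warmup phase. The sketch below reproduces that arithmetic with hypothetical helpers; it mirrors the linear-with-warmup formula rather than calling the transformers implementation, and the steps-per-epoch rounding may differ by one from a given library version.

```python
def optimizer_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int = 1) -> int:
    """Approximate optimizer steps: micro-batches per epoch // accumulation steps."""
    micro_batches = -(-n_examples // batch_size)  # ceiling division
    return max(micro_batches // grad_accum, 1) * epochs


def linear_warmup_factor(step: int, warmup: int, total: int) -> float:
    """Multiplier on the peak LR under linear warmup followed by linear decay."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))


total = optimizer_steps(9_000, batch_size=1, grad_accum=128)
peak_fraction = max(linear_warmup_factor(s, warmup=500, total=total) for s in range(total))
print(total)          # 70
print(peak_fraction)  # 0.138
```

So the learning rate would top out at roughly 14% of its configured peak. A smaller `warmup_steps` (or `warmup_ratio`) would let the schedule reach its peak; whether that improves BLEU here has not been tested.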

In [None]:
#!pip install wandb

In [None]:
#import wandb
#from huggingface_hub import notebook_login

#notebook_login()
#wandb.init(mode="disabled")

In [None]:
# hide_output
#torch.cuda.empty_cache()
#trainer.train()
# To save your fine-tuned model:
#trainer.save_model("en-fr-legal-mbart")

To download and load the already fine-tuned model instead, run the next two cells.

In [None]:
!gdown 1V5Ib_QwDqfw_Qbjyg1lZxcuHjX4wZlLP
!gunzip en-fr-legal-mbart.tar.gz
!tar xf en-fr-legal-mbart.tar

In [None]:
device="cuda"
model = MBartForConditionalGeneration.from_pretrained("./en-fr-legal-mbart").to(device)
tokenizer = MBart50TokenizerFast.from_pretrained("./en-fr-legal-mbart")


We now evaluate the fine-tuned model with BLEU on the same 500 test pairs.

In [None]:
import torch
torch.cuda.empty_cache()
model.to(device)
tokenizer.src_lang = "fr_XX"  # the source text is French
predictions = []
for sentence in ds_test["fr"]:
  # Encode one sentence with the tokenizer's __call__ (replaces legacy batch_encode_plus / pad_to_max_length)
  inputs = tokenizer(sentence, max_length=1024, truncation=True, return_tensors="pt").to(device)
  responses_ft = model.generate(**inputs,
                         # MBart-50 needs the target language code forced as the first generated token
                         forced_bos_token_id=tokenizer.convert_tokens_to_ids("en_XX"),
                         num_beams=100,           # very wide beam: slow and memory-hungry
                         no_repeat_ngram_size=2,  # forbids any repeated bigram, which can hurt long legal sentences
                         early_stopping=True,
                         num_return_sequences=1,
                         max_length=1024,
                          )
  predictions.extend(tokenizer.batch_decode(responses_ft, skip_special_tokens=True))

In [None]:
predictions

In [None]:
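import evaluate
# One list of references per prediction (here a single reference each)
references = [[ref] for ref in ds_test["en"]]
bleu = evaluate.load("bleu")
# Pass the lists themselves: str() would collapse each list into one long string
bleu.compute(predictions=predictions, references=references)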