<a target="_blank" href="https://colab.research.google.com/github/cswamy/pytorch/blob/main/notebooks/Translation_finetuned_marian_kde4.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### **Notes**

Notebook to finetune a MarianMT model for translation from English to French using the KDE4 dataset.

App: https://huggingface.co/spaces/cswamy/english_to_french_translator

Resources:

Hugging face checkpoint: https://huggingface.co/Helsinki-NLP/opus-mt-en-fr

Original Marian paper: https://arxiv.org/abs/1804.00344

KDE4 dataset: https://huggingface.co/datasets/kde4

Inspired by hugging face tutorial: https://huggingface.co/learn/nlp-course/chapter7/4?fw=pt

### **Setup**

In [1]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [2]:
try:
  import transformers
  print("[INFO] Hugging face transformers imported successfully!")
except:
  !pip install -q transformers
  import transformers
  print("[INFO] Hugging face transformers installed and imported successfully!")

[INFO] Hugging face transformers imported successfully!


In [3]:
try:
  import datasets
  print("[INFO] Hugging face datasets imported successfully!")
except:
  !pip install -q datasets
  import datasets
  print("[INFO] Hugging face datasets installed and imported successfully!")

[INFO] Hugging face datasets imported successfully!


### **Download dataset**

#### Get data

In [4]:
from datasets import load_dataset

raw_datasets = load_dataset('kde4', lang1='en', lang2='fr')

In [5]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 210173
    })
})

In [6]:
raw_datasets['train'][:5]

{'id': ['0', '1', '2', '3', '4'],
 'translation': [{'en': 'Lauri Watts', 'fr': 'Lauri Watts'},
  {'en': '& Lauri. Watts. mail;', 'fr': '& Lauri. Watts. mail;'},
  {'en': 'ROLES_OF_TRANSLATORS', 'fr': '& traducteurJeromeBlanc;'},
  {'en': '2006-02-26 3.5.1', 'fr': '2006-02-26 3.5.1'},
  {'en': 'The Babel & konqueror; plugin gives you quick access to the Babelfish translation service.',
   'fr': 'Le module externe Babel pour & konqueror; vous donne un accès rapide au service de traduction Babelfish.'}]}

#### Clean data before downsampling

In [7]:
# Function to remove sentences with less than 5 words
def clean_examples(example):
  # print(f"Example: {example['translation']['en']}")
  return (len(example['translation']['en'].split()) >= 5)

In [8]:
cleaned_datasets = raw_datasets.filter(clean_examples)

In [9]:
cleaned_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 70236
    })
})

In [10]:
cleaned_datasets['train'][:5]

{'id': ['4', '11', '13', '14', '15'],
 'translation': [{'en': 'The Babel & konqueror; plugin gives you quick access to the Babelfish translation service.',
   'fr': 'Le module externe Babel pour & konqueror; vous donne un accès rapide au service de traduction Babelfish.'},
  {'en': 'The Babel & konqueror; plugin',
   'fr': 'Le module externe Babel pour & konqueror;'},
  {'en': 'Babelfish is a machine translation service provided by AltaVista.',
   'fr': 'Babelfish est un service de traduction automatique fourni par AltaVista.'},
  {'en': 'The plugin allows you to automatically translate web pages between several languages.',
   'fr': 'Le module externe vous permet de traduire automatiquement les pages web dans plusieurs langues.'},
  {'en': 'The Babelfish plugin can be accessed in the & konqueror; menubar under Tools Translate Web Page. Select from the list that drops down the language to translate from and the language to translate to.',
   'fr': 'Vous pouvez accéder au module externe

#### Create train and val splits

In [11]:
# Downsample for faster training
train_size = int(0.3 * cleaned_datasets['train'].num_rows)
val_size = int(0.1 * cleaned_datasets['train'].num_rows)

dataset_splits = cleaned_datasets['train'].train_test_split(train_size=train_size,
                                                            test_size=val_size,
                                                            seed=42)
dataset_splits['validation'] = dataset_splits.pop('test')
dataset_splits

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 21070
    })
    validation: Dataset({
        features: ['id', 'translation'],
        num_rows: 7023
    })
})

### **Tokenize data**

#### Download tokenizer

In [12]:
!pip install sentencepiece
import sentencepiece



In [13]:
from transformers import AutoTokenizer

checkpoint = 'Helsinki-NLP/opus-mt-en-fr'
tokenizer = AutoTokenizer.from_pretrained(checkpoint, return_tensors='pt')

Downloading (…)olve/main/source.spm:   0%|          | 0.00/778k [00:00<?, ?B/s]

Downloading (…)olve/main/target.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]



#### Define tokenize function

In [14]:
def tokenize_func(examples):
  # Set max_length
  max_length = 128

  inputs = [ex['en'] for ex in examples['translation']]
  targets = [ex['fr'] for ex in examples['translation']]

  model_inputs = tokenizer(inputs,
                           text_target=targets,
                           max_length=max_length,
                           truncation=True)

  return model_inputs

#### Tokenize datasets

In [15]:
tokenized_datasets = dataset_splits.map(tokenize_func,
                                        batched=True,
                                        remove_columns=dataset_splits['train'].column_names)

tokenized_datasets

Map:   0%|          | 0/21070 [00:00<?, ? examples/s]

Map:   0%|          | 0/7023 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21070
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 7023
    })
})

### **Prepare dataloaders**

#### Setup data collator

In [16]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

tmp_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                       model=tmp_model)

Downloading pytorch_model.bin:   0%|          | 0.00/301M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

#### Create dataloaders

In [17]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset=tokenized_datasets['train'],
                              batch_size=32,
                              shuffle=True,
                              collate_fn=data_collator)

val_dataloader = DataLoader(dataset=tokenized_datasets['validation'],
                            batch_size=32,
                            shuffle=False,
                            collate_fn=data_collator)

len(train_dataloader), len(val_dataloader)

(659, 220)

### **Prepare Sacre BLEU metric**

#### Define function to process data for metric

In [18]:
import numpy as np

def process_for_bleu(preds:torch.tensor,
                     labels:torch.tensor):
  """
  Function to process preds and labels for Sacre BLEU score.
  Args:
    preds(torch.tensor): batch of predicted tokens from model.
    labels(torch.tensor): batch of truth label tokens.
  Returns:
    Lists of decoded preds and labels for batch.
  """

  # Convert tensor to numpy
  preds = preds.cpu().numpy()
  labels = labels.cpu().numpy()

  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

  # Remove -100 from labels and decode
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  # Remove white spaces
  decoded_preds = [pred.strip() for pred in decoded_preds]
  decoded_labels = [[label.strip()] for label in decoded_labels]

  return decoded_preds, decoded_labels

In [19]:
try:
  from torchmetrics.text import SacreBLEUScore
  sacre_bleu_fn = SacreBLEUScore().to(device)
except:
  !pip install -q torchmetrics
  from torchmetrics.text import SacreBLEUScore
  sacre_bleu_fn = SacreBLEUScore().to(device)

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/764.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.1/764.8 kB[0m [31m5.1 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━[0m [32m563.2/764.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m764.8/764.8 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25h

### **Train and eval**

#### Setup training

In [20]:
from transformers import AutoModelForSeq2SeqLM

# Initialise model
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# Setup optimiser
optimiser = torch.optim.AdamW(params=model.parameters(),
                              lr=2e-5)

# Setup scheduler
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimiser)

#### Train loop

In [21]:
from tqdm.auto import tqdm

EPOCHS = 10
train_loss, train_bleu = 0, 0

for epoch in tqdm(range(EPOCHS)):
  for batch_num, batch in enumerate(train_dataloader):
    # Send data to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # Forward pass
    outputs = model(**batch)

    # Accumulate loss
    loss = outputs.loss
    train_loss += loss

    # Zero grad optimiser
    optimiser.zero_grad()

    # Backpropagate loss
    loss.backward()

    # Step optimiser and scheduler
    optimiser.step()
    lr_scheduler.step()

  # Average loss and bleu score across batches
  train_loss /= len(train_dataloader)
  train_bleu /= len(train_dataloader)

  # Print loss abd bleu for entire batch
  print(f"Epoch: {epoch+1} | Training loss: {train_loss:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | Training loss: 0.9687
Epoch: 2 | Training loss: 0.7767
Epoch: 3 | Training loss: 0.6604
Epoch: 4 | Training loss: 0.5623
Epoch: 5 | Training loss: 0.4728
Epoch: 6 | Training loss: 0.3888
Epoch: 7 | Training loss: 0.3084
Epoch: 8 | Training loss: 0.2374
Epoch: 9 | Training loss: 0.1758
Epoch: 10 | Training loss: 0.1268


#### Eval loop

In [62]:
val_loss, val_bleu = 0, 0

model.eval()
with torch.inference_mode():
  for batch in val_dataloader:
    # Send batch to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # Forwards pass
    loss = model(**batch).loss

    # Accumulate loss
    val_loss += loss

    # Calculate Sacre BLEU
    preds = model.generate(**batch)
    processed_preds, processed_labels = process_for_bleu(preds, batch['labels'])
    batch_bleu = sacre_bleu_fn(processed_preds, processed_labels)
    val_bleu += batch_bleu.item()

  # Average loss and bleu across batches
  val_loss /= len(val_dataloader)
  val_bleu /= len(val_dataloader)

# Print output
print(f"Validation loss: {val_loss:.4f} | Sacre BLEU: {val_bleu:.4f}")

Validation loss: 1.1557 | Sacre BLEU: 0.5632


### **Save model**

In [90]:
!git clone https://github.com/cswamy/pytorch

Cloning into 'pytorch'...
remote: Enumerating objects: 80, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (72/72), done.[K
remote: Total 80 (delta 25), reused 18 (delta 3), pack-reused 0[K
Receiving objects: 100% (80/80), 55.95 KiB | 868.00 KiB/s, done.
Resolving deltas: 100% (25/25), done.


In [92]:
from pytorch.scripts import utils

utils.save_model(model=model,
                 target_dir='models',
                 model_name='marian_finetuned_kde4_enfr.pth')

[INFO] Saving model to: models/marian_finetuned_kde4_enfr.pth


### **Predict / Generate**

In [111]:
text = "I love music"

input = tokenizer(text,
                  max_length=128,
                  truncation=True,
                  return_tensors='pt')
input = {k: v.to(device) for k, v in input.items()}
output = model.generate(**input)
tokenizer.batch_decode(output, skip_special_tokens=True)

["J'aime la musique"]

### **Deploy to hugging face**

In [104]:
from pathlib import Path

# Create folders
demo_path = Path("demos/marian_en_fr")
demo_path.mkdir(parents=True, exist_ok=True)

In [105]:
# Move model to demo folder
!mv models/marian_finetuned_kde4_enfr.pth demos/marian_en_fr

In [107]:
%%writefile demos/marian_en_fr/model.py
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def create_marian_enfr():
  """
  Initializes model and tokenizer.
  """
  checkpoint = 'Helsinki-NLP/opus-mt-en-fr'
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, return_tensors='pt')
  model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

  return model, tokenizer

Writing demos/marian_en_fr/model.py


In [112]:
%%writefile demos/marian_en_fr/app.py
import torch
import gradio as gr

from model import create_marian_enfr

# Setup model and tokenizer
model, tokenizer = create_marian_enfr()

# Load state dict from model
model.load_state_dict(
    torch.load(
        f="marian_finetuned_kde4_enfr.pth",
        map_location=torch.device("cpu")
    ))

# Predict function
def predict(text:str):

  # Tokenize inputs and get model outputs
  input = tokenizer(text,
                    max_length=128,
                    truncation=True,
                    return_tensors="pt")
  output_tokens = model.generate(**input)
  output_text = tokenizer.batch_decode(output_tokens,
                                       skip_special_tokens=True)

  return output_text

# Create examples list
examples_list = ['What a beautiful day',
                 'I love music']

# Create gradio app
title = "English to French translator"
description = "Marian model finetuned for english to french translation on the kde4 dataset."

demo = gr.Interface(fn=predict,
                    inputs=gr.inputs.Textbox(label="Input",
                                             placeholder="Enter sentence here..."),
                    outputs="text",
                    examples=examples_list,
                    title=title,
                    description=description)

# Launch gradio
demo.launch()

Writing demos/marian_en_fr/app.py


In [113]:
%%writefile demos/marian_en_fr/requirements.txt
torch==1.12.0
gradio==3.1.4
transformers==4.33.1
sentencepiece==0.1.99

Writing demos/marian_en_fr/requirements.txt


In [114]:
!cd demos/marian_en_fr && zip -r ../marian_en_fr.zip *

  adding: app.py (deflated 57%)
  adding: marian_finetuned_kde4_enfr.pth (deflated 7%)
  adding: model.py (deflated 41%)
  adding: requirements.txt (stored 0%)


In [115]:
try:
  from google.colab import files
  files.download("demos/marian_en_fr.zip")
except:
  print(f"Download failed!")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>