## T5-Base - Machine Translation

In [None]:
!pip install transformers -qq
!pip install datasets jiwer -qq
!pip install evaluate -qq
!pip install sacrebleu rouge_score -qq
!pip install --upgrade accelerate
!pip install wandb -Uqq

In [3]:
from datasets import load_dataset
from transformers import create_optimizer, AutoTokenizer, DataCollatorForSeq2Seq,TFAutoModelForSeq2SeqLM
from datasets import DatasetDict, Dataset, load_metric
import numpy as np
import sacrebleu

### Load data

In [None]:
# Mount Google Drive at /content/drive; the model directory and test file
# configured in `cfc` are paths below this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Folder that contains one CSV file per split.
DATA_FILES_DIR = "/data/datafiles/"

# Split name -> CSV path, in the form `load_dataset` accepts for `data_files`.
data_files = dict(
    train=DATA_FILES_DIR + "train_data.csv",
    valid=DATA_FILES_DIR + "valid_data.csv",
    test=DATA_FILES_DIR + "test_data.csv",
)

# Read the CSV files with the "csv" loader; splits are addressed by the names above.
raw_datasets = load_dataset("csv", data_files=data_files)

### Configuration class

In [8]:
class cfc:
    """Central configuration shared by the training and evaluation cells."""

    # Base checkpoint the tokenizer and model are loaded from.
    checkpoint = "t5-base"

    # Name of the fine-tuned model and the locations derived from / used with it.
    model_name = "T5-Base-finetuned-latex-to-text-tuned"
    model_dir = f"/content/drive/MyDrive/models/{model_name}"
    test_file_path = "/content/drive/MyDrive/data/my_corpus/test_data_all_cleaned.json"

    # Weights & Biases project and run name (the run reuses the model name).
    wandb_project = "NLG"
    run_name = model_name

    # Earlier hyperparameter values, superseded by the tuned ones below:
    #   lr_rate = 5e-4, batch_size = 32, epochs = 4, weight_decay = 0.01

    # Tuned hyperparameters.
    lr_rate = 0.0002668
    batch_size = 4
    epochs = 2
    weight_decay = 0.01

### Data Preprocessing

In [None]:
# Load the tokenizer that ships with the base checkpoint named in the config.
tokenizer = AutoTokenizer.from_pretrained(cfc.checkpoint)

In [10]:
# Prefix the input with a prompt so T5 knows this is our translation task.
prefix = "translate Latex to Text: "

def preprocess_function(examples):
    """Tokenize a batch of source formulas together with their target texts.

    Args:
        examples: Batch mapping with a "formula" column (source strings) and
            a "label" column (target strings).

    Returns:
        The tokenizer output for the prefixed formulas, including a "labels"
        entry that holds the token ids of the targets.
    """
    inputs = [prefix + formula for formula in examples["formula"]]
    targets = list(examples["label"])

    # `text_target` tokenizes the targets in the same call (with the same
    # truncation setting) and stores their input ids under "labels". It
    # replaces the deprecated `tokenizer.as_target_tokenizer()` context manager.
    return tokenizer(inputs, text_target=targets, truncation=True)

In [None]:
# Run the preprocessing over every split in batches and drop the listed raw
# columns, so the tokenizer outputs returned by preprocess_function are kept.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=["image_name","formula","label","label_list"])
print(tokenized_datasets)

### Fine-Tuning the model

In [None]:
# Authenticate with Weights & Biases and open the run that the fine-tuning
# below reports to; project and run name come from the config class.
import wandb
wandb.login()

wandb.init(
    project=cfc.wandb_project,
    name = cfc.run_name,
    config={"architecture": cfc.model_name, "dataset": "Formula2Text-4k"}
    )

In [None]:
import numpy as np
import evaluate

# Metric objects used by compute_metrics: BLEU, TER and ROUGE.
bleu = evaluate.load("bleu")
ter = evaluate.load("ter")
rouge = evaluate.load("rouge")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace and shape the texts for the metric calls.

    Returns:
        A pair ``(predictions, references)``: the stripped predictions as a
        flat list, and the stripped labels with each one wrapped in its own
        single-element list.
    """
    stripped_predictions = [text.strip() for text in preds]
    wrapped_references = [[text.strip()] for text in labels]
    return stripped_predictions, wrapped_references

def compute_metrics(eval_preds):
    """Decode predictions and labels and score them with BLEU, TER and ROUGE.

    Args:
        eval_preds: Pair ``(preds, labels)`` of token-id arrays. ``preds``
            may itself be a tuple, in which case its first element is used.

    Returns:
        Dict with the keys "BLEU", "TER", "TER-ACC", "ROUGE-1", "ROUGE-2"
        and "ROUGE-L".
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is an ignore/padding marker rather than a vocabulary id, so it
    # cannot be decoded. Labels contain it for ignored positions, and the
    # predictions may be padded with it as well when batches are gathered,
    # so both arrays get the same replacement before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace and wrap each label as a one-element reference list.
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu_res = bleu.compute(predictions=decoded_preds, references=decoded_labels)
    ter_res = ter.compute(predictions=decoded_preds, references=decoded_labels)
    rouge_res = rouge.compute(predictions=decoded_preds, references=decoded_labels)

    # Complement of the TER score after rescaling it from percent to a fraction.
    ter_acc = 1 - (ter_res["score"] / 100)

    return {
        "BLEU": bleu_res["bleu"],
        "TER": ter_res["score"],
        "TER-ACC": ter_acc,
        "ROUGE-1": rouge_res["rouge1"],
        "ROUGE-2": rouge_res["rouge2"],
        "ROUGE-L": rouge_res["rougeL"],
    }

In [None]:
# Loading the pre-trained model for fine-tuning
# NOTE(review): the Trainer cells below are given `model_init` rather than
# this `model` object, so this instance does not appear to be the one trained.
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
model = AutoModelForSeq2SeqLM.from_pretrained(cfc.checkpoint)

In [None]:
def model_init():
    """Return a newly loaded model from the base checkpoint.

    Passed to the trainers as ``model_init`` so each of them builds its own
    model instance.
    """
    fresh_model = AutoModelForSeq2SeqLM.from_pretrained(cfc.checkpoint)
    return fresh_model

In [None]:
# Checkpoints and the final model are written to the configured model directory.
model_dir = cfc.model_dir

# Training arguments for the main fine-tuning run. Hyperparameters come from
# the config class; evaluation, logging and checkpointing all use the same
# 200-step interval, and metrics are reported to Weights & Biases.
args = Seq2SeqTrainingArguments(
    model_dir,
    report_to = "wandb",
    predict_with_generate=True,  # evaluate on generated sequences so compute_metrics can decode them
    num_train_epochs= cfc.epochs,
    learning_rate= cfc.lr_rate,
    weight_decay=cfc.weight_decay,
    per_device_train_batch_size=cfc.batch_size,
    per_device_eval_batch_size=cfc.batch_size,
    evaluation_strategy = "steps",
    eval_steps=200,
    logging_strategy="steps",
    logging_steps=200,
    save_strategy = "steps",
    save_steps = 200,
    fp16=False,
    save_total_limit=1,  # cap on the number of checkpoints kept on disk
    load_best_model_at_end=True
    )

In [None]:
# Batch collator for the seq2seq examples, built from the tokenizer; shared by
# the main trainer and the sweep trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [None]:
# Fine-tune on the train split, evaluating on the validation split with
# compute_metrics. The model is created through `model_init`.
trainer = Seq2SeqTrainer(
    model_init=model_init,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    )

trainer.train()
# Persist the trained model to the output directory given in `args`.
trainer.save_model()

In [None]:
# Evaluate the trained model on the tokenized test split.
trainer.evaluate(tokenized_datasets["test"])

In [None]:
# Close the Weights & Biases run opened for the fine-tuning.
wandb.finish()

### Hyperparameter Tuning

#### Weights and Biases Setup

In [None]:
# Authenticate with Weights & Biases before configuring the sweep.
import wandb
wandb.login()

In [None]:
%env WANDB_PROJECT=NLG_Sweeps
%env WANDB_LOG_MODEL=true

In [None]:
import pprint

# Objective of the sweep: maximise the metric logged as "BLEU".
metric = {"name": "BLEU", "goal": "maximize"}

# Search space: discrete choices for epochs, batch size and weight decay,
# and a log-uniform range for the learning rate.
parameters_dict = {
    "epochs": {"values": [2, 5, 10, 15, 25]},
    "batch_size": {"values": [4, 8, 16, 32, 64]},
    "learning_rate": {
        "distribution": "log_uniform_values",
        "min": 1e-5,
        "max": 1e-3,
    },
    "weight_decay": {"values": [0.0, 0.1, 0.2, 0.3]},
}

# SWEEP CONFIGURATION: random search over the space above.
sweep_config = {
    "method": "random",
    "metric": metric,
    "parameters": parameters_dict,
}
pprint.pprint(sweep_config)

In [None]:
# Initialize the sweep
# Registers the configuration in the "NLG_Sweeps" project; the returned id is
# what the agent is started with.
sweep_id = wandb.sweep(sweep_config, project="NLG_Sweeps")

In [None]:
def train(config=None):
    """Run one sweep trial: fine-tune with the hyperparameters of this run.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. The
            values actually used are read back from ``wandb.config``.
    """
    with wandb.init(config=config):
        config = wandb.config

        args = Seq2SeqTrainingArguments(
            # NOTE(review): the directory name looks carried over from another
            # project; kept unchanged so existing checkpoints stay where they are.
            output_dir="vit-sweeps",
            report_to="wandb",
            run_name="T5-Base",
            predict_with_generate=True,
            num_train_epochs=config.epochs,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=16,
            evaluation_strategy="steps",
            eval_steps=200,
            logging_strategy="steps",
            logging_steps=200,
            save_strategy="steps",
            save_steps=200,
            fp16=False,
            save_total_limit=1,
            load_best_model_at_end=True,
        )

        trainer = Seq2SeqTrainer(
            model_init=model_init,
            args=args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["valid"],
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
        )
        trainer.train()

        # The sweep optimises a metric named "BLEU", so log the numeric
        # validation score under exactly that key. `trainer.evaluate()`
        # returns the compute_metrics values with an "eval_" prefix.
        eval_results = trainer.evaluate()
        wandb.log({"BLEU": eval_results["eval_BLEU"]})

        trainer.save_model()

In [None]:
# Start a sweep agent that calls `train` for 5 runs of the sweep.
wandb.agent(sweep_id, train, count=5)

In [None]:
# Close any Weights & Biases run that is still open after the sweep.
wandb.finish()

### Evaluation on Test Set

In [None]:
from google.colab import files

In [None]:
!cp /utils/cf_custom_functions.py /content

In [None]:
import pandas as pd
from transformers import AutoModelForSeq2SeqLM
import evaluate
import cf_custom_functions as cf

### Pre-trained model evaluation

In [None]:
# Baseline: tokenizer and model loaded straight from the base checkpoint,
# without any fine-tuning.
pt_tokenizer = AutoTokenizer.from_pretrained(cfc.checkpoint)
pt_model = AutoModelForSeq2SeqLM.from_pretrained(cfc.checkpoint)

In [None]:
# Same prompt prefix as the one used when preprocessing the training data.
prefix = "translate Latex to Text: "
# Evaluate the baseline on the test file with the project helper, then print
# the metrics and store them under the "<model name> Pre-trained" label.
pt_metrics, pt_preds = cf.model_evaluation_on_testset(cfc.test_file_path, pt_model, pt_tokenizer, prefix)
print(pt_metrics)
cf.save_evaluation_metrics(cfc.model_name+" Pre-trained",pt_metrics,"../metrics/NLG_metrics_new.json")

### Fine-tuned model evaluation

In [None]:
# Fine-tuned tokenizer and model, loaded from the directory the trainer saved to.
ft_tokenizer = AutoTokenizer.from_pretrained(cfc.model_dir)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(cfc.model_dir)

In [None]:
# Evaluate the fine-tuned model on the same test file and with the same prompt
# prefix as the pre-trained baseline.
ft_metrics, ft_preds = cf.model_evaluation_on_testset(cfc.test_file_path, ft_model, ft_tokenizer, prefix)
print(ft_metrics)
# Separate the model name from the variant with a space, matching the
# " Pre-trained" label used for the baseline entry in the same metrics file.
cf.save_evaluation_metrics(cfc.model_name + " Fine-tuned", ft_metrics, "../metrics/NLG_metrics_new.json")