In [None]:
!pip install -q transformers
!pip install -q datasets
!pip install rouge_score sacrebleu jiwer
!pip install evaluate
!pip install --upgrade accelerate

## ***Load Data***

In [None]:
from google.colab import drive

# Mount Google Drive; the training output directory defined further down
# lives under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
from datasets import load_dataset

# Folder containing the formula images. The dataset class builds file paths by
# plain string concatenation, so the trailing slash is required.
IMG_DIR = "/data/images_formulas/"

# One CSV file per split; the split names become the keys of `data`.
data_files = dict(
    train="/data/datafiles/train_data.csv",
    valid="/data/datafiles/valid_data.csv",
)
data = load_dataset("csv", data_files=data_files)

In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image

In [None]:
class Im2LatexDataset(Dataset):
    """Torch dataset pairing formula images with tokenized LaTeX targets.

    Args:
        root_dir: Directory prefix of the image files. It is concatenated
            directly with the file name, so it must end with a path separator.
        df: Table-like object supporting ``len(df)`` and column-then-row
            access (``df['image_name'][idx]``, ``df['formula'][idx]``).
        processor: Object that is callable on a PIL image (returning an object
            with ``pixel_values``) and exposes a ``tokenizer`` attribute.
        max_target_length: Exact length of every label sequence; shorter
            formulas are padded, longer ones are truncated.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def _encode_labels(self, text):
        """Tokenize ``text`` into exactly ``max_target_length`` label ids.

        Padding positions are replaced by -100 (PyTorch's default
        ``ignore_index`` for cross-entropy) so they are presumably excluded
        from the training loss.
        """
        tokenizer = self.processor.tokenizer
        # truncation=True is essential: without it a long formula produces more
        # than max_target_length ids, the samples of a batch end up with
        # different label lengths, and stacking them in the collator fails.
        token_ids = tokenizer(text,
                              padding="max_length",
                              truncation=True,
                              max_length=self.max_target_length).input_ids
        pad_id = tokenizer.pad_token_id
        return [token if token != pad_id else -100 for token in token_ids]

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for row ``idx``."""
        # NOTE(review): column-then-row access may load the whole column on every
        # call for some table types -- confirm whether this matters for `df`.
        file_name = self.df['image_name'][idx]
        text = self.df['formula'][idx]

        # Use a context manager so the file handle is released right after the
        # pixel data has been copied into the RGB image.
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")

        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        labels = self._encode_labels(text)
        # squeeze() drops size-1 dimensions (presumably the batch dimension the
        # processor adds for a single image).
        return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}

In [None]:
from transformers import TrOCRProcessor

# The processor supplies both the image preprocessing and the tokenizer that
# the dataset class uses to build its samples.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

# Build one dataset per CSV split, sharing the image folder and processor.
train_dataset, eval_dataset = (
    Im2LatexDataset(root_dir=IMG_DIR, df=data[split], processor=processor)
    for split in ("train", "valid")
)

### ***Fine-tuning of the model***

In [None]:
from transformers import VisionEncoderDecoderModel

# Checkpoint the fine-tuning run starts from.
PRETRAINED_CHECKPOINT = "microsoft/trocr-base-stage1"
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_CHECKPOINT)

In [None]:
# Short aliases for the two objects this cell wires together.
decoder_tokenizer = processor.tokenizer
model_config = model.config

# Special tokens and vocabulary size are taken from the processor's tokenizer
# and from the decoder sub-config respectively.
model_config.decoder_start_token_id = decoder_tokenizer.cls_token_id
model_config.pad_token_id = decoder_tokenizer.pad_token_id
model_config.vocab_size = model_config.decoder.vocab_size

# Generation settings stored on the model config.
# NOTE(review): max_length is 64 here while the dataset's default
# max_target_length is 128 -- confirm that the shorter generation limit is intended.
model_config.eos_token_id = decoder_tokenizer.sep_token_id
model_config.max_length = 64
model_config.early_stopping = True
model_config.no_repeat_ngram_size = 3
model_config.length_penalty = 2.0
model_config.num_beams = 4

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Run-level settings: batch size and where checkpoints are written on Drive.
batch_size = 8
model_name = "TrOCR-Base-Image-to-Latex"
model_dir = f"/content/drive/MyDrive/models/{model_name}"

training_args = Seq2SeqTrainingArguments(
    output_dir=model_dir,
    # --- optimisation ---
    num_train_epochs=3,
    learning_rate=4e-5,
    weight_decay=0.01,
    fp16=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # --- evaluation (uses generate() instead of teacher forcing) ---
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=200,
    # --- checkpointing ---
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    load_best_model_at_end=True,
    # --- logging ---
    logging_steps=2,
)

In [None]:
from transformers import default_data_collator

# Collect the trainer inputs in one place, then build, train and persist.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    tokenizer=processor.feature_extractor,
)
trainer = Seq2SeqTrainer(**trainer_kwargs)

trainer.train()
# With no argument, save_model() presumably writes to the output directory
# configured in training_args.
trainer.save_model()

### ***Testing***

In [6]:
from google.colab import files

In [7]:
# Copy the project helper module into /content so that the following
# `import cf_custom_functions` can find it (presumably /content is the
# notebook's working directory -- confirm).
!cp /utils/cf_custom_functions.py /content

In [8]:
import cf_custom_functions as cf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# Load the test set through the project helper. The result is passed to
# generate_OCR_predictions below, which expects a DataFrame with an
# `image_name` column.
df_test = cf.load_test_data("/data/test_data.json")

### ***Load the pre-trained model***

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Baseline for comparison: the same checkpoint the fine-tuning run starts
# from, evaluated without any further training.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model_pt = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

In [None]:
def generate_OCR_predictions(test_data:pd.DataFrame, model:object, processor:object, IMG_DIR:str,
                             max_length:int=50, verbose:bool=True) -> pd.DataFrame:
  """Run the model on every image listed in ``test_data`` and collect the decoded text.

  Args:
    test_data: DataFrame with an ``image_name`` column; it is not modified.
    model: Object exposing ``generate(pixel_values=..., max_length=...)``.
    processor: Object that is callable on a PIL image (returning an object with
      ``pixel_values``) and provides ``batch_decode``.
    IMG_DIR: Directory prefix of the images. It is concatenated directly with
      the file name, so it must end with a path separator.
    max_length: Maximum generated sequence length passed to ``model.generate``.
    verbose: When True (the default), print each image path and its prediction.

  Returns:
    A copy of ``test_data`` with an added ``prediction`` column.
  """
  df = test_data.copy()
  # Inputs must live on the same device as the model; objects without a
  # `device` attribute are used as-is.
  device = getattr(model, "device", None)
  y_preds = []

  for image_name in df["image_name"]:
    image_file = IMG_DIR + image_name
    if verbose:
      print(image_file)
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(image_file) as raw_image:
      image = raw_image.convert('RGB')
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    if device is not None:
      pixel_values = pixel_values.to(device)
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    # One image in, one sequence out: take the single decoded string.
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    if verbose:
      print(generated_caption)
    y_preds.append(generated_caption)

  df["prediction"] = np.array(y_preds)
  return df

In [None]:
# Predictions and metrics for the baseline (not fine-tuned) model.
df_preds_pt = generate_OCR_predictions(
    test_data=df_test,
    model=model_pt,
    processor=processor,
    IMG_DIR=IMG_DIR,
)
cf.compute_OCR_evaluation_metrics(df_preds_pt, "prediction")

### ***Generate Predictions for Fine-tuned model***

In [None]:
# Location of the fine-tuned weights.
# NOTE(review): training writes to /content/drive/MyDrive/models/TrOCR-Base-Image-to-Latex;
# confirm this relative path resolves to the same directory.
FINETUNED_MODEL_DIR = "../models/TrOCR-Base-Image-to-Latex"

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model_ft = VisionEncoderDecoderModel.from_pretrained(FINETUNED_MODEL_DIR)

In [None]:
# Predictions and metrics for the fine-tuned model, on the same test set.
df_preds_ft = generate_OCR_predictions(
    test_data=df_test,
    model=model_ft,
    processor=processor,
    IMG_DIR=IMG_DIR,
)
cf.compute_OCR_evaluation_metrics(df_preds_ft, "prediction")