## ***GIT - A Generative Image-to-Text Transformer***

In [None]:
!pip install -q transformers datasets
!pip install -q evaluate sacrebleu rouge_score jiwer
!pip install -q wandb

In [None]:
from datasets import Features
from datasets import load_dataset, Dataset, Image
from datasets import load_dataset, Image
from PIL import Image
import torch

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Select the compute device once: the first CUDA GPU when present, else the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")
if not gpu_available:
    print("No GPU available, using the CPU instead.")
else:
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name(0))

### ***Load data***

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only). The fine-tuned checkpoint
# is later loaded from /content/drive/MyDrive/models/.
drive.mount('/content/drive')

In [None]:
# Locations of the formula images and of the per-split CSV files.
img_dir = "/data/images_formulas/"
data_files_dir = "/data/datafiles/"

# One CSV per split. Note: `data` is replaced by the pre-processed
# DatasetDict loaded from disk in the following cell.
data_files = dict(
    train=data_files_dir + "train_data.csv",
    valid=data_files_dir + "valid_data.csv",
    test=data_files_dir + "test_data.csv",
)
data = load_dataset("csv", data_files=data_files)

In [None]:
# Load the pre-processed dataset from directory. This rebinds `data`, replacing
# the DatasetDict built from the CSV files in the previous cell.
from datasets import DatasetDict
data = DatasetDict.load_from_disk('/data/formula2text-4k-pad')

In [None]:
# Drop columns the model does not consume from every split.
# 'image_name' is intentionally kept: ImageCaptioningDataset.__getitem__ reads
# item["image_name"] to open each image, so removing it causes a KeyError.
columns_to_delete = ['formula', 'label_list']
for split in ('train', 'valid', 'test'):
    data[split] = data[split].remove_columns(columns_to_delete)

### ***Prepare data for model***

In [None]:
from torch.utils.data import Dataset
from PIL import Image

class ImageCaptioningDataset(Dataset):
    """Map-style dataset pairing formula images with their text labels.

    Each record of ``dataset`` must provide ``"image_name"`` (a file name
    relative to ``img_path``) and ``"label"`` (the target text). Records are
    encoded with ``processor`` into model inputs.

    Args:
        dataset: Indexable collection of records (e.g. a Hugging Face split).
        processor: Callable as ``processor(images=..., text=..., padding=...,
            return_tensors="pt")`` returning a mapping of tensors.
        img_path: Directory prefix prepended to each ``image_name``. It is
            concatenated directly, so it should end with a path separator.
    """

    def __init__(self, dataset, processor, img_path):
        self.dataset = dataset
        self.processor = processor
        self.image_path = img_path

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(self.image_path + item["image_name"]) as img:
            image = img.convert('RGB')

        encoding = self.processor(images=image, text=item["label"], padding="max_length", return_tensors="pt")

        # Remove only the leading batch dimension added by return_tensors="pt".
        # Tensors stay on the CPU: the DataLoader collates them and the training
        # loop moves each batch to the device. Creating CUDA tensors inside
        # __getitem__ breaks multi-worker DataLoaders and pin_memory.
        return {k: v.squeeze(0) for k, v in encoding.items()}

In [None]:
from transformers import AutoProcessor

# Hugging Face Hub checkpoint (per its name, GIT-large trained on TextCaps).
# The same checkpoint is reloaded later for inference.
checkpoint = "microsoft/git-large-textcaps"
# Processor = image preprocessing + tokenizer; used to encode images and labels.
processor = AutoProcessor.from_pretrained(checkpoint)

In [None]:
# Wrap each split. ImageCaptioningDataset requires the image directory as its
# third argument; images live in `img_dir` (defined in the data-loading cell).
train_dataset = ImageCaptioningDataset(data["train"], processor, img_dir)
valid_dataset = ImageCaptioningDataset(data["valid"], processor, img_dir)
test_dataset = ImageCaptioningDataset(data["test"], processor, img_dir)
print(train_dataset)
print(valid_dataset)
print(test_dataset)

### ***Create PyTorch DataLoaders***

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training data; validation/test batches keep a fixed,
# reproducible order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=16)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=16)

### ***Fine-tuning the model***

In [None]:
from transformers import AutoModelForCausalLM

# Load the pre-trained model with its causal-LM head; it is fine-tuned below.
model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [None]:
from tqdm.notebook import tqdm
import torch

num_epochs = 8
lr = 5e-5

# Create an optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Check if GPU or CPU available and move the model there.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device: ", device )
model.to(device)

print(f"Number of epochs: {num_epochs}")
num_training_steps = num_epochs * len(train_dataloader)
print(f"Number of training steps: {num_training_steps}")
progress_bar = tqdm(range(num_training_steps))


# Train the model
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        # Labels are padded to max_length. Padded positions get -100, the
        # cross-entropy ignore_index, so the loss does not train the model to
        # emit padding tokens.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        # Show the running loss on the progress bar instead of printing a line per step.
        progress_bar.set_postfix(epoch=epoch, loss=loss.item())

In [None]:
from pathlib import Path

# Persist the fine-tuned weights ("../" matches the metrics path used later;
# the original ".../" was a typo). Create the directory so torch.save does not fail.
# NOTE(review): the fine-tuned model is later reloaded from
# /content/drive/MyDrive/models/GIT-Large-Formula-to-Text_pad.pth. Make sure the
# saved file ends up there, or update one of the two paths.
model_path = Path("../models/GIT-Large-Image-to-Text.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)

## ***Testing***

In [None]:
from google.colab import files

In [None]:
!cp /utils/cf_custom_functions.py /content

In [None]:
import cf_custom_functions as cf
import pandas as pd
import numpy as np

In [None]:
# Load test data through the project helper cf.load_test_data.
# NOTE(review): presumably returns a pandas DataFrame with an "image_name"
# column, as generate_VLM_predictions expects. Confirm.
df_test = cf.load_test_data("/data/datafiles/test_data.json")

### ***Generate Predictions with Pre-trained model***

In [None]:
# Directory holding the formula images (same path as `img_dir` used for training).
# Image names are concatenated to it directly, so keep the trailing "/".
IMG_DIR = "/data/images_formulas/"

In [None]:
# Reload the original pre-trained checkpoint for the baseline predictions.
# This rebinds `model`, so the fine-tuned weights held in memory are discarded.
model = AutoModelForCausalLM.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

In [None]:
def generate_VLM_predictions(test_data: pd.DataFrame, model: object, processor: object, IMG_DIR: str,
                             max_length: int = 50) -> pd.DataFrame:
    """Caption every image in ``test_data`` and return a copy with the predictions.

    Args:
        test_data: DataFrame with an ``"image_name"`` column. Each name is
            appended to ``IMG_DIR`` to locate the image file. It is not modified.
        model: Image-to-text model exposing ``generate(pixel_values=..., max_length=...)``.
        processor: Processor turning a PIL image into ``pixel_values`` and
            decoding generated ids via ``batch_decode``.
        IMG_DIR: Directory prefix for the image files. It is concatenated
            directly, so it should end with a path separator.
        max_length: Maximum generation length passed to ``model.generate``
            (defaults to 50, the previously hard-coded value).

    Returns:
        A copy of ``test_data`` with an added ``"prediction"`` column holding
        one decoded caption per row, in row order.
    """
    df = test_data.copy()
    # Inputs must live on the model's device (e.g. a GPU-resident model);
    # objects without a `device` attribute get CPU tensors as before.
    model_device = getattr(model, "device", None)
    y_preds = []

    for image_name in df["image_name"]:
        # Context manager closes the file; convert() returns a loaded copy.
        with Image.open(IMG_DIR + image_name) as img:
            image = img.convert('RGB')
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        if model_device is not None:
            pixel_values = pixel_values.to(model_device)
        generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
        y_preds.append(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

    df["prediction"] = np.array(y_preds)
    return df

In [None]:
# Predictions of test_data with the pre-trained model, then post-processing via
# the project helper cf.post_processing_multi_predictions.
# NOTE(review): the helper presumably adds the "clean_prediction" column scored
# in the next cell. Confirm.
df_preds_pt = generate_VLM_predictions(df_test,model,processor, IMG_DIR)
df_preds_pt_clean = cf.post_processing_multi_predictions(df_preds_pt)

In [None]:
# Score the "clean_prediction" column of the pre-trained predictions and hand the
# metrics to cf.save_evaluation_metrics under the "GIT-Large-Image-to-Text_pretrained"
# key. Presumably this writes them to ../metrics/VLM_metrics.json; verify in cf_custom_functions.
metrics_pt = cf.compute_evaluation_metrics(df_preds_pt_clean,"clean_prediction")
cf.save_evaluation_metrics("GIT-Large-Image-to-Text_pretrained",metrics_pt,"../metrics/VLM_metrics.json")

### ***Generate Predictions with Fine-tuned model***

In [None]:
# Load the fine-tuned model: build the base checkpoint's architecture, then
# overwrite its weights with the saved state dict.
model_ft = AutoModelForCausalLM.from_pretrained(checkpoint)
processor_ft = AutoProcessor.from_pretrained(checkpoint)
# map_location="cpu" lets a checkpoint saved from a GPU model load on a CPU-only
# runtime (model_ft itself is on the CPU after from_pretrained).
state_dict = torch.load("/content/drive/MyDrive/models/GIT-Large-Formula-to-Text_pad.pth", map_location="cpu")
model_ft.load_state_dict(state_dict)
model_ft.eval()

In [None]:
# Predictions of test_data with the fine-tuned model, using the same
# generation and post-processing pipeline as for the pre-trained baseline.
df_preds_ft = generate_VLM_predictions(df_test,model_ft,processor_ft, IMG_DIR)
df_preds_ft_clean = cf.post_processing_multi_predictions(df_preds_ft)

In [None]:
# Score the "clean_prediction" column of the fine-tuned predictions and hand the
# metrics to cf.save_evaluation_metrics under the "GIT-Large-Image-to-Text_finetuned"
# key. Presumably this writes them to ../metrics/VLM_metrics.json; verify in cf_custom_functions.
metrics_ft = cf.compute_evaluation_metrics(df_preds_ft_clean,"clean_prediction")
cf.save_evaluation_metrics("GIT-Large-Image-to-Text_finetuned",metrics_ft,"../metrics/VLM_metrics.json")