SWIN-GPT2 - Formula Image to Text

In [None]:
!pip install -q torch transformers datasets
!pip install -q --upgrade accelerate
!pip install -q evaluate sacrebleu jiwer rouge_score
!pip -q install xformers
!pip install wandb -Uqq

In [None]:
import requests
import torch
from PIL import Image
from transformers import *
from tqdm import tqdm
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Available device: ", device)

### Configuration class

In [None]:
class cfc:
  """Central configuration for the Swin-GPT2 formula-image-to-text notebook.

  Every setting is a plain class attribute and is read as ``cfc.<name>``.
  """

  # Directory that contains the formula images.
  IMG_DIR = "/data/images_formulas/"
  # Lower-case alias so that ``cfc.img_dir`` resolves to the same directory.
  img_dir = IMG_DIR

  data_files_dir = "/data/datafiles/"
  # NOTE(review): unlike the two paths above, this one is relative (no
  # leading "/") -- confirm that this is intended.
  test_file_path = "data/datafiles/test_data.json"

  # Name of the run/model and the directory the model is written to.
  model_name = "Swin-GPT2_image-to-text"
  model_dir = f"/content/drive/MyDrive/models/{model_name}"

  # Checkpoints used for the vision encoder and the text decoder.
  encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
  decoder_model = "gpt2"

  # Hyperparameters
  learning_rate = 4e-5
  batch_size = 16
  weight_decay = 0.01
  num_epochs = 12

  # Weights & Biases project and run name.
  wandb_project = "VLM"
  run_name = model_name

### Data Preprocessing

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and model paths used in
# this notebook live below /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
# Load the dataset that was saved to Google Drive. Later cells read its
# "train", "valid" and "test" splits.
from torch.utils.data import Dataset
from datasets import DatasetDict
data = DatasetDict.load_from_disk('/content/drive/MyDrive/data/my_corpus/formula2text-4k')

In [None]:
# Instantiate the image processor, the tokenizer and the encoder-decoder
# model from the checkpoints named in the configuration.
image_processor = ViTImageProcessor.from_pretrained(cfc.encoder_model)
tokenizer = GPT2TokenizerFast.from_pretrained(cfc.decoder_model)
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    cfc.encoder_model,
    cfc.decoder_model,
)
# Place the model on the selected device.
model = model.to(device)

In [None]:
# Align the model's special-token ids with the tokenizer.
if "gpt2" in cfc.decoder_model:
  # Reuse EOS as the padding token before pad_token_id is read below.
  tokenizer.pad_token = tokenizer.eos_token
  model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# GPT-2 decoders start from BOS, every other decoder from CLS.
model.config.decoder_start_token_id = (
    tokenizer.bos_token_id if "gpt2" in cfc.decoder_model else tokenizer.cls_token_id
)

In [None]:
def preprocess(items):
  """Turn raw examples into tensors for the encoder-decoder model.

  Reads ``items["image"]`` and ``items["label"]`` and returns a dict with
  ``pixel_values`` (output of the image processor) and ``labels`` (token ids
  padded/truncated to 50 tokens), both moved to ``device``.
  """
  # NOTE(review): padded label positions keep the pad token id instead of
  # -100 -- confirm that including them in the loss is intended.
  encoded_images = image_processor(items["image"], return_tensors="pt")
  encoded_labels = tokenizer(
      items["label"],
      max_length=50,
      padding="max_length",
      truncation=True,
      return_tensors="pt",
  )
  encoded_labels = encoded_labels.to(device)
  return {
      'pixel_values': encoded_images.pixel_values.to(device),
      'labels': encoded_labels["input_ids"],
  }

In [None]:
# Attach the on-the-fly preprocessing to each of the three splits.
train_dataset, valid_dataset, test_dataset = [
    data[split_name].with_transform(preprocess)
    for split_name in ("train", "valid", "test")
]

In [None]:
def collate_fn(batch):
    """Stack the per-example tensors of ``batch`` into batched tensors.

    ``batch`` is a sequence of dicts that each hold ``pixel_values`` and
    ``labels`` tensors; the result has the same two keys, with the tensors
    stacked along a new leading dimension.
    """
    stacked = {}
    for key in ('pixel_values', 'labels'):
        stacked[key] = torch.stack([example[key] for example in batch])
    return stacked

### Evaluation Metrics

In [None]:
import numpy as np
import evaluate

# Metric backends used by compute_metrics (loaded in the order rouge, bleu, ter).
rouge, bleu, ter = [
    evaluate.load(metric_name) for metric_name in ("rouge", "bleu", "ter")
]

def compute_metrics(eval_pred):
  """Compute BLEU, TER and ROUGE scores for generated predictions.

  Args:
    eval_pred: object exposing ``predictions`` and ``label_ids`` (arrays of
      token ids). ``predictions`` may also be a tuple whose first element
      holds the token ids.

  Returns:
    dict with the keys ``BLEU``, ``TER``, ``TER-ACC``, ``ROUGE-1``,
    ``ROUGE-2`` and ``ROUGE-L``.
  """
  pred_ids = eval_pred.predictions
  if isinstance(pred_ids, tuple):
    pred_ids = pred_ids[0]
  label_ids = eval_pred.label_ids

  # Token id that stands in for the -100 ignore index before decoding.
  if tokenizer.pad_token_id is not None:
    fill_id = tokenizer.pad_token_id
  else:
    fill_id = tokenizer.eos_token_id

  # -100 is not a valid token id and cannot be decoded. It can occur in the
  # predictions as well as the labels, so both are cleaned. np.where builds
  # new arrays, which leaves the caller's arrays untouched.
  pred_ids = np.where(np.asarray(pred_ids) != -100, pred_ids, fill_id)
  label_ids = np.where(np.asarray(label_ids) != -100, label_ids, fill_id)

  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

  bleu_res = bleu.compute(predictions=pred_str, references=label_str)
  ter_res = ter.compute(predictions=pred_str, references=label_str)
  rouge_res = rouge.compute(predictions=pred_str, references=label_str)
  # TER score is divided by 100 and turned into an accuracy-like value.
  ter_acc = (1-(ter_res["score"]/100))

  metrics = {
      "BLEU": bleu_res["bleu"],
      "TER" : ter_res["score"],
      "TER-ACC" : ter_acc,
      "ROUGE-1" : rouge_res["rouge1"],
      "ROUGE-2" : rouge_res["rouge2"],
      "ROUGE-L" : rouge_res["rougeL"],
      }

  return metrics

### Fine-tuning the model

In [None]:
import wandb
# Log in to Weights & Biases before a run is started.
wandb.login()

In [None]:
# Start the Weights & Biases run; project and run name come from the config class.
wandb.init(
    project=cfc.wandb_project,
    name=cfc.run_name,
    config=dict(architecture="Swin-GPT2", dataset="Formula2Text-4k"),
)

In [None]:
from transformers import TrainingArguments, Trainer, Seq2SeqTrainingArguments

# Training configuration. Evaluation, logging and checkpointing all happen
# every 200 steps so that the best checkpoint can be restored at the end.
training_args = Seq2SeqTrainingArguments(
    cfc.model_dir,
    report_to = "wandb",
    predict_with_generate=True,
    # Let evaluation generate up to the label length used in preprocess
    # (max_length=50); otherwise the model's default generation length applies
    # and longer references cannot be matched.
    generation_max_length=50,
    num_train_epochs=cfc.num_epochs,
    learning_rate = cfc.learning_rate,
    per_device_train_batch_size=cfc.batch_size,
    per_device_eval_batch_size=cfc.batch_size,
    weight_decay=cfc.weight_decay,
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=200,
    save_steps=200,
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [None]:
# Assemble the trainer from the model, the arguments, the data and the metrics.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # The image processor is handed over through the `tokenizer` argument.
    tokenizer=image_processor,
)

In [None]:
from torch.utils.data import DataLoader

def get_eval_loader(eval_dataset=None):
  """Build the evaluation DataLoader.

  Args:
    eval_dataset: dataset to evaluate on. When ``None`` (the default) the
      module-level ``valid_dataset`` is used.

  Returns:
    A DataLoader using ``collate_fn`` and ``cfc.batch_size``.
  """
  # Honour an explicitly passed dataset instead of silently ignoring it.
  dataset = valid_dataset if eval_dataset is None else eval_dataset
  return DataLoader(dataset, collate_fn=collate_fn, batch_size=cfc.batch_size)

def get_test_loader(eval_dataset=None):
  """Build the test DataLoader.

  Args:
    eval_dataset: dataset to predict on. When ``None`` (the default) the
      module-level ``test_dataset`` is used.

  Returns:
    A DataLoader using ``collate_fn`` and ``cfc.batch_size``.
  """
  # Honour an explicitly passed dataset instead of silently ignoring it.
  dataset = test_dataset if eval_dataset is None else eval_dataset
  return DataLoader(dataset, collate_fn=collate_fn, batch_size=cfc.batch_size)

# Route the trainer through the custom DataLoaders.
# shuffle=True: a plain DataLoader iterates in dataset order by default, so
# without it every epoch would see the training examples in the same order.
trainer.get_train_dataloader = lambda: DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=cfc.batch_size,
    shuffle=True,
)
trainer.get_eval_dataloader = get_eval_loader
trainer.get_test_dataloader = get_test_loader

In [None]:
# Run fine-tuning, then save the model through the trainer. The output
# directory given to the training arguments is cfc.model_dir.
trainer.train()
trainer.save_model()

In [None]:
# Close the Weights & Biases run started with wandb.init.
wandb.finish()

### Model Evaluation on the Test Set

In [None]:
from google.colab import files

In [None]:
# Copy the helper module from Google Drive into /content so that it can be imported.
!cp /content/drive/MyDrive/cf_module/cf_custom_functions.py /content

In [None]:
import cf_custom_functions as cf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# Load the test set from the path configured in cfc.test_file_path.
df_test = cf.load_test_data(cfc.test_file_path)

### Load pre-trained model

In [None]:
# Baseline model assembled from the pre-trained encoder and decoder
# checkpoints (not the fine-tuned weights stored in cfc.model_dir).
model_pt = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    cfc.encoder_model,
    cfc.decoder_model,
)
model_pt = model_pt.to(device)

# Align the baseline model's special-token ids with the tokenizer.
if "gpt2" not in cfc.decoder_model:
  model_pt.config.decoder_start_token_id = tokenizer.cls_token_id
  model_pt.config.pad_token_id = tokenizer.pad_token_id
else:
  # Reuse EOS as the padding token before pad_token_id is read.
  tokenizer.pad_token = tokenizer.eos_token
  model_pt.config.eos_token_id = tokenizer.eos_token_id
  model_pt.config.pad_token_id = tokenizer.pad_token_id
  model_pt.config.decoder_start_token_id = tokenizer.bos_token_id

In [None]:
def generate_VLM_predictions(test_data, model, image_processor, tokenizer, IMG_DIR, **generate_kwargs) -> pd.DataFrame:
  """Generate one text prediction per image listed in ``test_data``.

  Args:
    test_data: DataFrame with an ``image_name`` column; it is not modified.
    model: model exposing ``generate`` and ``device``.
    image_processor: callable that turns a PIL image into model inputs.
    tokenizer: tokenizer used to decode the generated token ids.
    IMG_DIR: directory that contains the image files.
    **generate_kwargs: optional extra keyword arguments forwarded to
      ``model.generate`` (for example ``max_new_tokens``). None by default.

  Returns:
    A copy of ``test_data`` with an added ``prediction`` column that holds
    the decoded text for each row.
  """
  import os  # local import keeps the function self-contained

  df = test_data.copy()
  y_preds = []

  for image_name in df["image_name"]:
    # The context manager closes the file handle; convert() returns a new
    # image that stays usable after the file is closed.
    with Image.open(os.path.join(IMG_DIR, image_name)) as raw_image:
      image = raw_image.convert('RGB')
    # Put the inputs on the same device as the model that consumes them.
    img = image_processor(image, return_tensors="pt").to(model.device)

    output = model.generate(**img, **generate_kwargs)
    caption = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    y_preds.append(caption)

  df["prediction"] = y_preds
  return df

In [None]:
# Generate predictions with the baseline model. The image directory is read
# from cfc.IMG_DIR, the attribute defined in the configuration class.
df_preds_pt = generate_VLM_predictions(df_test,model_pt,image_processor,tokenizer,cfc.IMG_DIR)
df_preds_pt_clean = cf.post_processing_multi_predictions(df_preds_pt)

In [None]:
# Score the "clean_prediction" column of the baseline predictions and store
# the result under the "<model_name>_pretrained" key.
metrics_pt = cf.compute_evaluation_metrics(df_preds_pt_clean,"clean_prediction")
cf.save_evaluation_metrics(f"{cfc.model_name}_pretrained",metrics_pt,"../metrics/VLM_metrics.json")

### Load fine-tuned model

In [None]:
# Fine-tuned model, loaded from the directory the trainer saved to.
model_ft = VisionEncoderDecoderModel.from_pretrained(cfc.model_dir)
model_ft = model_ft.to(device)

# Align the fine-tuned model's special-token ids with the tokenizer.
if "gpt2" in cfc.decoder_model:
  # Reuse EOS as the padding token before pad_token_id is read below.
  tokenizer.pad_token = tokenizer.eos_token
  model_ft.config.eos_token_id = tokenizer.eos_token_id
model_ft.config.pad_token_id = tokenizer.pad_token_id
# GPT-2 decoders start from BOS, every other decoder from CLS.
model_ft.config.decoder_start_token_id = (
    tokenizer.bos_token_id if "gpt2" in cfc.decoder_model else tokenizer.cls_token_id
)

In [None]:
# Generate predictions with the fine-tuned model. The image directory is read
# from cfc.IMG_DIR, the attribute defined in the configuration class.
df_preds_ft = generate_VLM_predictions(df_test,model_ft,image_processor,tokenizer,cfc.IMG_DIR)
df_preds_ft_clean = cf.post_processing_multi_predictions(df_preds_ft)

In [None]:
# Score the "clean_prediction" column of the fine-tuned predictions and store
# the result under the "<model_name>_finetuned" key.
metrics_ft = cf.compute_evaluation_metrics(df_preds_ft_clean,"clean_prediction")
cf.save_evaluation_metrics(f"{cfc.model_name}_finetuned",metrics_ft,"../metrics/VLM_metrics.json")