## Set up the environment

In [None]:
!pip3 install evaluate



In [None]:
!pip3 install -q git+https://github.com/huggingface/transformers.git

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
import json
from PIL import Image
import os
import torch
from evaluate import load
from itertools import cycle

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image: it splits the image into `n_patches` aspect-ratio-preserving patches, then pads the resulting sequence with padding tokens up to a total of `max_patches` patches. This argument appears to be crucial for both training and evaluation, as the model is very sensitive to it.

For the sake of our example, we will fine-tune a model with `max_patches=2048` (the `MAX_PATCHES` constant defined below).

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

## Load model and processor

In [None]:
# Processor and model are both loaded from the same pretrained MatCha checkpoint.
checkpoint_id = "google/matcha-base"
processor = AutoProcessor.from_pretrained(checkpoint_id)
model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint_id)

# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load ChartQA dataset

In [None]:
MAX_PATCHES = 2048


class ChartQADataset(Dataset):
    """ChartQA question/answer records paired with their chart images.

    Records come from ``<root_dir>/<split>/<split>_augmented.json`` and
    ``<root_dir>/<split>/<split>_human.json``. Each item returned by
    ``__getitem__`` is a copy of one record plus an ``"image"`` key holding the
    chart (read from ``<root_dir>/<split>/png/<imgname>``) converted to RGB.
    """

    _SPLIT2_CHOICES = ("both", "augmented", "human")

    def __init__(self, processor, root_dir="ChartQA Dataset", split='train', split2="both"):
        """
        Args:
            processor: Stored on the instance as ``self.processor``; the dataset
                itself does not use it.
            root_dir (string): Directory with all the ChartQA data.
            split (string): Which split to load ("train" or "val" or "test").
            split2 (string): Which subset to use within ``split``:
                "both", "augmented" or "human".

        Raises:
            ValueError: If ``split2`` is not one of the supported values.
        """
        # Validate first: an unknown value used to leave ``self.data`` unset and
        # fail much later with an AttributeError.
        if split2 not in self._SPLIT2_CHOICES:
            raise ValueError(
                f"split2 must be one of {self._SPLIT2_CHOICES}, got {split2!r}"
            )
        self.processor = processor
        self.root_dir = root_dir
        self.split = split
        self.split2 = split2
        self.image_dir = os.path.join(self.root_dir, self.split, 'png')

        # Load questions and answers (both files are always read, as before).
        self.qa_augmented = self._load_json(f'{split}_augmented.json')
        self.qa_human = self._load_json(f'{split}_human.json')

        if split2 == "both":
            self.data = self.qa_augmented + self.qa_human
        elif split2 == "augmented":
            self.data = self.qa_augmented
        else:  # "human"
            self.data = self.qa_human

    def _load_json(self, filename):
        """Read and parse one annotation file from ``<root_dir>/<split>/``."""
        path = os.path.join(self.root_dir, self.split, filename)
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self):
        # Must reflect the subset chosen by ``split2``. Counting both files here
        # let the sampler request indices past the end of ``self.data``.
        return len(self.data)

    def __getitem__(self, idx):
        qa = self.data[idx]
        image_path = os.path.join(self.image_dir, qa['imgname'])
        # convert() yields a fully decoded copy, so the source file can be
        # closed as soon as the ``with`` block exits.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # Return a new dict rather than storing the image on the shared record,
        # which would keep every decoded image alive for the dataset's lifetime.
        return {**qa, "image": image}

In [None]:
def collator(batch):
    """Collate ChartQA items into tensors for Pix2StructForConditionalGeneration.

    Args:
        batch: list of dataset items, each a dict with "image", "query" and
            "label" keys.

    Returns:
        dict with "flattened_patches" and "attention_mask" from the processor,
        and "labels": answer token ids padded/truncated to 128 tokens, with
        padding positions set to -100.
    """
    images = [item["image"] for item in batch]
    header_texts = [item["query"] for item in batch]
    label_texts = [item["label"] for item in batch]

    # Pass MAX_PATCHES explicitly so the patch budget is controlled by the
    # notebook's constant rather than the processor's own default.
    # NOTE(review): the question only reaches the model if the processor runs
    # in VQA mode (rendered as a header into the image); otherwise transformers
    # returns it as decoder inputs, which are not forwarded below. Confirm
    # processor.image_processor.is_vqa for this checkpoint.
    inputs = processor(
        images=images,
        text=header_texts,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=128,
        max_patches=MAX_PATCHES,
    )
    # truncation=True keeps every row at exactly 128 tokens, so a long answer
    # cannot break tensor construction.
    labels = processor.tokenizer(
        label_texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    ).input_ids
    # -100 is the Hugging Face "ignore" label: padding is excluded from the loss
    # instead of dominating it for short answers.
    labels[labels == processor.tokenizer.pad_token_id] = -100

    return {
        "flattened_patches": inputs["flattened_patches"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [None]:
# Examples per batch, shared by the train, validation and test dataloaders.
batch_size: int = 4

In [None]:
# Training data: the "train" split with both augmented and human questions (default split2).
train_dataset = ChartQADataset(processor, split='train')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collator,
)

In [None]:
# Validation uses only the human-written questions of the "val" split.
# The order does not change an averaged loss, so the loader is not shuffled;
# this also keeps validation runs reproducible.
val_dataset = ChartQADataset(processor, split='val', split2="human")
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, collate_fn=collator)

## Train the model

In [None]:
# Number of optimizer steps the training loop runs before stopping.
training_steps = 10000
# Validate and checkpoint every 200 * 256 / batch_size steps (12800 for
# batch_size=4); presumably 200 steps at an effective batch size of 256.
# Integer arithmetic keeps this an int so `step % checkpoint_steps` works as
# intended. The interval is clamped to [1, training_steps] because the
# unclamped value exceeds training_steps, so no checkpoint would ever be saved.
checkpoint_steps = min(max(1, 256 * 200 // batch_size), training_steps)

In [None]:
# Move the model before creating the optimizer so the optimizer is built over
# the parameters on their final device (PyTorch's recommended order).
model.to(device)
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
# The training loop stops after `training_steps` optimizer steps, so the cosine
# schedule spans that same horizon. The previous (256/batch_size)*10000 horizon
# was 64x longer than the run, so the learning rate barely decayed.
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=training_steps)

In [None]:
checkpoint_dir = './checkpoints'

# Create the checkpoint directory if it is missing. exist_ok=True replaces the
# separate existence check and the race between checking and creating.
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
def _endless(dataloader):
    """Yield batches from `dataloader` forever, re-iterating it once per epoch.

    Re-iterating the loader, unlike ``itertools.cycle``, reshuffles on every
    epoch and does not keep every batch of the first epoch alive in memory.

    Raises:
        ValueError: If a full pass over the dataloader yields no batches.
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("train_dataloader produced no batches")


def _forward(model, batch, device):
    """Run `model` on one collated batch and return its outputs (including `.loss`)."""
    return model(flattened_patches=batch["flattened_patches"].to(device),
                 attention_mask=batch["attention_mask"].to(device),
                 labels=batch["labels"].to(device))


def _validation_loss(model, dataloader, device):
    """Return the per-batch losses over `dataloader`, averaged with batch-size weights.

    Runs without gradient tracking; the caller sets train/eval mode.
    Returns ``nan`` when the dataloader yields no examples.
    """
    weighted_sum = 0.0
    n_examples = 0
    with torch.no_grad():
        for val_batch in dataloader:
            n = val_batch["labels"].size(0)
            weighted_sum += _forward(model, val_batch, device).loss.item() * n
            n_examples += n
    return weighted_sum / n_examples if n_examples else float("nan")


model.train()

# `step` counts optimizer steps across epochs (1-based). It does not restart at
# each epoch boundary.
for step, batch in enumerate(_endless(train_dataloader), start=1):
    outputs = _forward(model, batch, device)
    loss = outputs.loss

    loss.backward()

    print(f"Step {step}/{training_steps} - Loss: {loss.item()}")

    optimizer.step()
    # Advance the warmup/cosine schedule once per optimizer step; without this
    # call it never moves past its initial (step-0) learning rate.
    scheduler.step()
    optimizer.zero_grad()

    # Parenthesised modulo: the old `idx+1 % checkpoint_steps` parsed as
    # `idx + (1 % checkpoint_steps)` and was never zero.
    if step % checkpoint_steps == 0:
        model.eval()
        val_loss_average = _validation_loss(model, val_dataloader, device)
        print(f"Validation Loss: {val_loss_average}")
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, f'{checkpoint_dir}/checkpoint_training_step_{step}_val_loss_{val_loss_average}.pth')
        model.train()

    if step >= training_steps:
        break

Epoch: 0
Loss: 3.9150619506835938


KeyboardInterrupt: 

## Inference

Let's check the results on our test dataset.

In [None]:
# Test data: the "test" split with both augmented and human questions (default split2).
test_dataset = ChartQADataset(processor, split='test')
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collator,
)

In [None]:
# Exact-match metric from the `evaluate` library, used to score decoded answers.
exact_match_metric = load(path="exact_match")

In [None]:
model.eval()

# Number of test batches to score; 1 reproduces the original quick check.
max_eval_batches = 1
pad_token_id = processor.tokenizer.pad_token_id

predictions, references = [], []
for idx, batch in enumerate(test_dataloader):
    labels = batch.pop("labels").to(device)
    flattened_patches = batch.pop("flattened_patches").to(device)
    attention_mask = batch.pop("attention_mask").to(device)

    generated_ids = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=128)
    batch_predictions = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    print(batch_predictions)

    # -100 marks ignored label positions and is not a vocabulary id; map it back
    # to the pad token so the references can be decoded.
    labels = labels.masked_fill(labels == -100, pad_token_id)
    predictions.extend(batch_predictions)
    references.extend(processor.tokenizer.batch_decode(labels, skip_special_tokens=True))

    if idx + 1 >= max_eval_batches:
        break

# exact_match compares strings, so score the decoded texts, not raw token tensors.
if predictions:
    metric = exact_match_metric.compute(predictions=predictions, references=references)
    print(metric)

