## Set-up environment

In [3]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

Run the cells below to set up the environment.

In [13]:
!pip3 install -q git+https://github.com/huggingface/transformers.git

In [14]:
!pip3 install -q datasets

## Load ChartQA dataset

## Create PyTorch Dataset

In [19]:
# Read the two training question/answer files into memory:
# `aug` holds the augmented entries, `human` the human-written ones.
with open("ChartQA Dataset/train/train_augmented.json") as augmented_file:
    aug = json.load(augmented_file)

with open("ChartQA Dataset/train/train_human.json") as human_file:
    human = json.load(human_file)

In [25]:
# Image names that appear in both the augmented and the human-written files.
{entry["imgname"] for entry in aug}.intersection(entry["imgname"] for entry in human)

{'multi_col_1000.png',
 'multi_col_100021.png',
 'multi_col_100053.png',
 'multi_col_100065.png',
 'multi_col_100066.png',
 'multi_col_100100.png',
 'multi_col_100101.png',
 'multi_col_100177.png',
 'multi_col_100179.png',
 'multi_col_100187.png',
 'multi_col_100229.png',
 'multi_col_100236.png',
 'multi_col_100277.png',
 'multi_col_100294.png',
 'multi_col_100312.png',
 'multi_col_100322.png',
 'multi_col_100344.png',
 'multi_col_100347.png',
 'multi_col_100349.png',
 'multi_col_100350.png',
 'multi_col_100354.png',
 'multi_col_100371.png',
 'multi_col_100397.png',
 'multi_col_100405.png',
 'multi_col_100409.png',
 'multi_col_100423.png',
 'multi_col_100424.png',
 'multi_col_100497.png',
 'multi_col_100516.png',
 'multi_col_100578.png',
 'multi_col_100579.png',
 'multi_col_100592.png',
 'multi_col_100727.png',
 'multi_col_100750.png',
 'multi_col_100761.png',
 'multi_col_100835.png',
 'multi_col_100853.png',
 'multi_col_100883.png',
 'multi_col_100904.png',
 'multi_col_100909.png',
 '

In [4]:
class ChartQADataset(Dataset):
    """Map-style dataset over one split of a ChartQA directory tree.

    The split directory ``<root_dir>/<split>/`` is expected to contain:

    * ``<split>_augmented.json`` and ``<split>_human.json`` (question/answer lists),
    * ``annotations/`` (one JSON file per chart),
    * ``png/`` (chart images),
    * ``tables/`` (one table file per chart, read with ``pandas.read_csv``).

    Length and indexing follow the *augmented* question list.

    NOTE(review): ``__getitem__`` reads the keys ``chart_id``, ``question``,
    ``answer`` and (optionally) ``human`` from each entry, while the entries
    inspected elsewhere in this notebook expose an ``imgname`` key. Confirm the
    JSON schema before relying on this class.
    """

    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir (string): Directory with all the ChartQA data.
            split (string): Which split to load ("train" or "val" or "test").
                Selects both the sub-directory and the JSON file prefix.
            transform (callable, optional): Applied to the sample dict in
                ``__getitem__``; its return value is returned instead.
        """
        self.root_dir = root_dir
        self.split = split
        # Previously never assigned, so __getitem__ raised AttributeError.
        self.transform = transform

        split_dir = os.path.join(self.root_dir, self.split)

        # Load questions and answers. The file prefix follows the split; it was
        # previously hard-coded to "train_" regardless of `split`.
        with open(os.path.join(split_dir, f'{self.split}_augmented.json'), 'r') as f:
            self.qa_augmented = json.load(f)
        with open(os.path.join(split_dir, f'{self.split}_human.json'), 'r') as f:
            self.qa_human = json.load(f)

        # Load image annotations, keyed by file name without extension.
        self.annotations = {}
        annotations_dir = os.path.join(split_dir, 'annotations')
        for filename in os.listdir(annotations_dir):
            with open(os.path.join(annotations_dir, filename), 'r') as f:
                self.annotations[os.path.splitext(filename)[0]] = json.load(f)

        # Index image and table files by file name without extension.
        self.image_paths = self._index_by_stem(os.path.join(split_dir, 'png'))
        self.table_paths = self._index_by_stem(os.path.join(split_dir, 'tables'))

    @staticmethod
    def _index_by_stem(directory):
        """Map each file name in `directory` (extension stripped) to its full path."""
        return {
            os.path.splitext(filename)[0]: os.path.join(directory, filename)
            for filename in os.listdir(directory)
        }

    def __len__(self):
        """Number of entries in the augmented question list."""
        return len(self.qa_augmented)

    def __getitem__(self, idx):
        """Return a dict with 'question', 'answer', 'chart_image', 'chart_table'.

        If the augmented entry has a 'human' key, its value is treated as a
        1-based index into the human question list and the question/answer are
        taken from that entry instead.
        """
        qa = self.qa_augmented[idx]
        if 'human' in qa:
            # NOTE(review): assumes 'human' is a 1-based index — confirm against the data.
            source = self.qa_human[int(qa['human']) - 1]
        else:
            source = qa
        question = source['question']
        answer = source['answer']

        # Get chart info. The annotations lookup is kept so a missing
        # annotation still raises KeyError, as before.
        chart_id = qa['chart_id']
        chart_annotations = self.annotations[chart_id]  # noqa: F841 (not part of the sample)
        chart_image_path = self.image_paths[chart_id]
        chart_table_path = self.table_paths[chart_id]

        # Load image and table. The context manager closes the file handle;
        # convert() returns a fully loaded copy.
        with Image.open(chart_image_path) as raw_image:
            chart_image = raw_image.convert('RGB')
        chart_table = pd.read_csv(chart_table_path)

        sample = {'question': question, 'answer': answer, 'chart_image': chart_image, 'chart_table': chart_table}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and creates `n_patches` aspect-ratio-preserving patches, then pads the resulting sequence with padding tokens to finally get `max_patches` patches. This argument appears to be quite crucial for training and evaluation, as the model is very sensitive to this parameter.

For the sake of our example, we will fine-tune a model with `max_patches=2048`.

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [37]:

MAX_PATCHES = 2048


class ChartQADataset(Dataset):
    """Wrap an indexable dataset of {"image", "text"} items for Pix2Struct training.

    Each item is passed through `processor` to obtain the image encoding; the
    raw "text" is attached unchanged so it can be tokenized later.

    NOTE: this definition replaces any earlier `ChartQADataset` class defined
    in the notebook.

    Args:
        dataset: Indexable collection supporting `len()`, whose items expose
            "image" and "text" keys.
        processor: Callable invoked as
            `processor(images=..., return_tensors="pt", add_special_tokens=True,
            max_patches=...)` and returning a mapping of tensors.
        max_patches: Patch budget forwarded to `processor`. Defaults to
            `MAX_PATCHES`.
    """

    def __init__(self, dataset, processor, max_patches=MAX_PATCHES):
        self.dataset = dataset
        self.processor = processor
        self.max_patches = max_patches

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"],
            return_tensors="pt",
            add_special_tokens=True,
            max_patches=self.max_patches,
        )

        # Drop only the leading batch axis. A bare `.squeeze()` would also
        # remove any other size-1 axis and change the tensor rank.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding

## Load model and processor

In [112]:

# The processor and the model weights are loaded from the same checkpoint.
CHECKPOINT = "google/matcha-base"

processor = AutoProcessor.from_pretrained(CHECKPOINT)
model = Pix2StructForConditionalGeneration.from_pretrained(CHECKPOINT)

Downloading (…)rocessor_config.json: 100%|██████████| 249/249 [00:00<00:00, 1.82MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 2.62k/2.62k [00:00<00:00, 30.9MB/s]
Downloading spiece.model: 100%|██████████| 851k/851k [00:00<00:00, 4.02MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 3.27M/3.27M [00:00<00:00, 10.5MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 2.20k/2.20k [00:00<00:00, 10.3MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 4.89k/4.89k [00:00<00:00, 9.09MB/s]
Downloading pytorch_model.bin: 100%|██████████| 1.13G/1.13G [00:27<00:00, 40.7MB/s]


Now that we have loaded the processor, let's load the dataset and the dataloader:

In [29]:
def collator(batch):
    """Collate dataset items into one training batch.

    Args:
        batch: List of items, each exposing "text", "flattened_patches" and
            "attention_mask".

    Returns:
        dict with:
            "flattened_patches": the per-item tensors stacked along a new first axis,
            "attention_mask": the per-item masks stacked along a new first axis,
            "labels": token ids of the texts, padded/truncated to 20 tokens.

    Relies on the notebook-level `processor` for tokenization.
    """
    texts = [item["text"] for item in batch]

    # `padding="max_length"` only pads. Without `truncation=True` a text longer
    # than `max_length` keeps its full length and the batch becomes ragged.
    text_inputs = processor(
        text=texts,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=20,
    )

    return {
        "flattened_patches": torch.stack([item["flattened_patches"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": text_inputs.input_ids,
    }

In [30]:
# `ImageCaptioningDataset` is not defined anywhere in this notebook. The class
# taking `(dataset, processor)` is the `ChartQADataset` defined above.
# NOTE(review): `dataset` is not created in any visible cell — it must be an
# indexable collection of {"image", "text"} items; confirm where it is loaded.
train_dataset = ChartQADataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1, collate_fn=collator)

## Train the model

Let's train the model! Simply run the cell below to train the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trial and error, as the model can easily enter "model collapse" (always predicting the same output, no matter the input) if the hyper-parameters are not chosen correctly. In this example, we found that using the `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

Let's also print the generation output of the model every 20 epochs!

Bear in mind that the model takes some time to converge; for instance, to get decent results we had to let the script run for ~1 hour.

In [7]:
import torch

EPOCHS = 5000

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(EPOCHS):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        # Move the three batch tensors to the training device.
        labels, flattened_patches, attention_mask = (
            batch.pop(key).to(device)
            for key in ("labels", "flattened_patches", "attention_mask")
        )

        # Passing `labels` makes the model return the training loss.
        outputs = model(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        print("Loss:", loss.item())

        # One optimisation step, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # In every 20th epoch, generate for each batch.
        if (epoch + 1) % 20 != 0:
            continue

        model.eval()
        predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask)
        print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))
        model.train()

Epoch: 0
Loss: 9.008894920349121


KeyboardInterrupt: 

## Inference

Let's check the results on our train dataset

In [None]:
# Take the first item of `dataset` and show its "image" entry as the cell output
example = dataset[0]
image = example["image"]
image

In [None]:
# prepare image for the model
model.eval()

# Encode with the same patch budget used to build the training samples
# (`MAX_PATCHES`). The model is very sensitive to `max_patches`, so a
# different value at inference time (previously 512) degrades the output.
inputs = processor(images=image, return_tensors="pt", max_patches=MAX_PATCHES).to(device)

flattened_patches = inputs.flattened_patches
attention_mask = inputs.attention_mask

# Generate at most 50 tokens and decode the single sequence in the batch.
generated_ids = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

## Load from the Hub

Once trained you can push the model and processor on the Hub to use them later. 
Meanwhile you can play with the model that we have fine-tuned!

In [None]:
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

# NOTE(review): the checkpoint name suggests a football-captioning fine-tune
# rather than a ChartQA one — confirm this is the intended model.
HUB_CHECKPOINT = "ybelkada/pix2struct-base-football"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights, then move the model to the selected device.
model = Pix2StructForConditionalGeneration.from_pretrained(HUB_CHECKPOINT)
model = model.to(device)
processor = AutoProcessor.from_pretrained(HUB_CHECKPOINT)

Let's check the results on our train dataset!

In [None]:
from matplotlib import pyplot as plt

# The figure is a fixed 2x3 grid, so at most 6 samples can be shown.
GRID_ROWS, GRID_COLS = 2, 3

fig = plt.figure(figsize=(18, 14))

# prepare image for the model
for i, example in enumerate(dataset):
    # Stop once the grid is full: `add_subplot` raises for indices beyond
    # GRID_ROWS * GRID_COLS, which happened for any dataset with > 6 items.
    if i >= GRID_ROWS * GRID_COLS:
        break

    image = example["image"]
    # NOTE(review): 1024 patches is kept as-is; confirm it matches the patch
    # budget the loaded checkpoint was trained with.
    inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
    flattened_patches = inputs.flattened_patches
    attention_mask = inputs.attention_mask

    generated_ids = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    fig.add_subplot(GRID_ROWS, GRID_COLS, i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Generated caption: {generated_caption}")