Check the TODO comment in cell 17 (adapt the GPU ID used in the experiments)

## Create PyTorch Dataset

In [1]:
import transformers
print(transformers.__version__)

4.29.1


In [2]:
import pandas as pd
from glob import  glob
import json
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import StratifiedKFold
import torch
import numpy as np
import random
import os
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

In [3]:
image_path = glob('./train/images/*')

In [4]:
label_path = glob('./train/annotations/*')

In [5]:
assert len(image_path) == 60578

In [6]:
assert len(label_path)== 60578

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and create `n_patches` aspect-ratio preserving patches, and concatenates the remaining sequence with padding tokens to finally get `max_patches` patches. It appears that this argument is quite crucial for training and evaluation, as the model becomes very sensitive to this parameter.

For the sake of our example, we will fine-tune a model with `max_patches=1024`.

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
MAX_PATCHES = 2048

class ImageCaptioningDataset(Dataset):
    def __init__(self, df, processor):
        self.dataset = df
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx, :]
        image = Image.open(row.image_path)
        #display(image)
        encoding = self.processor(images=image,
                                  text="Generate underlying data table of the figure below:",
                                  font_path="arial.ttf",
                                  return_tensors="pt",
                                  add_special_tokens=True, max_patches=MAX_PATCHES)
        
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["text"] = row.label 
        return encoding

## Load model and processor

In [8]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

processor = Pix2StructProcessor.from_pretrained("google/matcha-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-plotqa-v2")

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [9]:
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]
    # print(texts)
    text_inputs = processor.tokenizer(text=texts,
                                      padding="max_length",
                                      return_tensors="pt",
                                      add_special_tokens=True,
                                      max_length=512,
                                      truncation=True
                                      )

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

In [10]:
df = pd.read_csv('train_with_fold.csv')
print(len(df))
train_df = df[df['fold'] != 0]
train_df.reset_index(drop=True, inplace=True)

60578


In [13]:
class CFG:
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True
    num_cycles = 0.5  # 1.5
    num_warmup_steps = 0.2
    max_input_length = 130
    epochs = 10  # 5
    encoder_lr = 10e-6
    decoder_lr = 10e-6
    min_lr = 0.5e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0
    num_fold = 5
    batch_size = 2
    seed = 1006
    num_workers = 2
    device='cuda:1'
    print_freq = 100

In [14]:
train_dataset = ImageCaptioningDataset(train_df, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=CFG.batch_size, collate_fn=collator, pin_memory=True,
                                  prefetch_factor=40, num_workers=2)

In [15]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(CFG.seed)

In [16]:
def get_scheduler(cfg, optimizer, num_train_steps):
    cfg.num_warmup_steps = cfg.num_warmup_steps * num_train_steps
    if cfg.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
        )
    elif cfg.scheduler == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps,
            num_cycles=cfg.num_cycles
        )
    return scheduler

num_train_steps = int(len(train_dataset) / CFG.batch_size * CFG.epochs)


## Train the model

Let's train the model! Run the simply the cell below for training the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trials and errors, as the model can easily enter in "collapse-model" (always predicting the same output, no matter the input) if the HP are not chosen correctly. In this example, we found out that using `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

Let's also print the generation output of the model each 10 epochs!

Bear in mind that the model took some time to converge, for instance to get decent results we had to let the script run for ~1hour.

In [17]:
EPOCHS = 10

model = torch.nn.DataParallel(model, device_ids=[1,2])  # TODO: adapt to your hardware
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_scheduler(CFG, optimizer, num_train_steps)
device = "cuda:1" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
scaler = torch.cuda.amp.GradScaler()
model.train()

cuda:1


Pix2StructForConditionalGeneration(
  (encoder): Pix2StructVisionModel(
    (embeddings): Pix2StructVisionEmbeddings(
      (patch_projection): Linear(in_features=768, out_features=768, bias=True)
      (row_embedder): Embedding(4096, 768)
      (column_embedder): Embedding(4096, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Pix2StructVisionEncoder(
      (layer): ModuleList(
        (0): Pix2StructVisionLayer(
          (attention): Pix2StructVisionAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
            (output): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): Pix2StructVisionMlp(
            (wi_0): Linear(in_features=768, out_features=2048, bias=False)
            (wi_1): Linear(in_features=768, out_features=2048, bias=False)
         

In [None]:
loss_file = open("loss.txt","w")

for epoch in range(CFG.epochs):
    print("Epoch:", epoch)
    for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(flattened_patches=flattened_patches,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss

        scaler.scale(loss).backward()
        #grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1000)
        # Unscales gradients and calls
        # or skips optimizer.step()
        scaler.step(optimizer)
        # Updates the scale for next iteration
        scaler.update()
        scheduler.step()
        if idx % 100 == 0:
            print("Loss:", loss.item(), f'lr : {scheduler.get_lr()[0]:.6f} ', sep=' ')
            loss_file.write(f"Epoch: {epoch}, Iteration: {idx}, Loss: {loss.item()}\n")
            loss_file.flush()
        

    if (epoch + 1) % 1 == 0:
        # torch.save(model.state_dict(), f'./matcha_v1/matcha_{epoch}.bin')
        torch.save(model.module.state_dict(), f'./matcha_v1/matcha_{epoch}.bin')

loss_file.close()
        

Epoch: 0


  0%|          | 0/24231 [00:00<?, ?it/s]



Loss: 15.795980453491211 lr : 0.000000 
Loss: 19.013874053955078 lr : 0.000000 
Loss: 11.124963760375977 lr : 0.000000 
Loss: 9.472804069519043 lr : 0.000000 
Loss: 9.150559425354004 lr : 0.000000 
Loss: 14.26933479309082 lr : 0.000000 
Loss: 11.317273139953613 lr : 0.000000 
Loss: 9.331180572509766 lr : 0.000000 
