## Create PyTorch Dataset

In [1]:
import transformers
print(transformers.__version__)

4.30.0.dev0


In [2]:
import pandas as pd
from glob import  glob
import json
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import StratifiedKFold
import torch
import numpy as np
import random
import os
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

In [3]:
image_path = glob('./train/images/*')

In [4]:
label_path = glob('./train/annotations/*')

In [5]:
assert len(image_path) == 60578

In [6]:
assert len(label_path)== 60578

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and create `n_patches` aspect-ratio preserving patches, and concatenates the remaining sequence with padding tokens to finally get `max_patches` patches. It appears that this argument is quite crucial for training and evaluation, as the model becomes very sensitive to this parameter.

For the sake of our example, we will fine-tune a model with `max_patches=1024`.

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
MAX_PATCHES = 2048

class ImageCaptioningDataset(Dataset):
    def __init__(self, df, processor):
        self.dataset = df
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx, :]
        image = Image.open(row.image_path)
        #display(image)
        encoding = self.processor(images=image,
                                  text="Generate underlying data table of the figure below:",
                                  font_path="arial.ttf",
                                  return_tensors="pt",
                                  add_special_tokens=True, max_patches=MAX_PATCHES)
        
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["text"] = row.label 
        return encoding

## Load model and processor

In [10]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

processor = Pix2StructProcessor.from_pretrained("google/matcha-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-plotqa-v2")

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [11]:
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]
    # print(texts)
    text_inputs = processor.tokenizer(text=texts,
                                      padding="max_length",
                                      return_tensors="pt",
                                      add_special_tokens=True,
                                      max_length=512,
                                      truncation=True
                                      )

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

In [12]:
df = pd.read_csv('train_with_fold.csv')
print(len(df))
train_df = df[df['fold'] != 0]
train_df.reset_index(drop=True, inplace=True)

60578


In [14]:
class CFG:
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True
    num_cycles = 0.5  # 1.5
    num_warmup_steps = 0.2
    max_input_length = 130
    epochs = 10  # 5
    encoder_lr = 10e-6
    decoder_lr = 10e-6
    min_lr = 0.5e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0
    num_fold = 5
    batch_size = 4
    seed = 1006
    num_workers = 2
    device='cuda'
    print_freq = 100

In [13]:
train_dataset = ImageCaptioningDataset(train_df, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=CFG.batch_size, collate_fn=collator, pin_memory=True,
                                  prefetch_factor=40, num_workers=2)

In [15]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(CFG.seed)

In [16]:
def get_scheduler(cfg, optimizer, num_train_steps):
    cfg.num_warmup_steps = cfg.num_warmup_steps * num_train_steps
    if cfg.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
        )
    elif cfg.scheduler == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps,
            num_cycles=cfg.num_cycles
        )
    return scheduler

num_train_steps = int(len(train_dataset) / CFG.batch_size * CFG.epochs)


## Train the model

Let's train the model! Run the simply the cell below for training the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trials and errors, as the model can easily enter in "collapse-model" (always predicting the same output, no matter the input) if the HP are not chosen correctly. In this example, we found out that using `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

Let's also print the generation output of the model each 20 epochs!

Bear in mind that the model took some time to converge, for instance to get decent results we had to let the script run for ~1hour.

In [17]:
EPOCHS = 10

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_scheduler(CFG, optimizer, num_train_steps)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
scaler = torch.cuda.amp.GradScaler()
model.train()

cuda


Pix2StructForConditionalGeneration(
  (encoder): Pix2StructVisionModel(
    (embeddings): Pix2StructVisionEmbeddings(
      (patch_projection): Linear(in_features=768, out_features=768, bias=True)
      (row_embedder): Embedding(4096, 768)
      (column_embedder): Embedding(4096, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Pix2StructVisionEncoder(
      (layer): ModuleList(
        (0): Pix2StructVisionLayer(
          (attention): Pix2StructVisionAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
            (output): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): Pix2StructVisionMlp(
            (wi_0): Linear(in_features=768, out_features=2048, bias=False)
            (wi_1): Linear(in_features=768, out_features=2048, bias=False)
         

In [None]:
for epoch in range(CFG.epochs):
    print("Epoch:", epoch)
    for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(flattened_patches=flattened_patches,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss

        scaler.scale(loss).backward()
        #grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1000)
        # Unscales gradients and calls
        # or skips optimizer.step()
        scaler.step(optimizer)
        # Updates the scale for next iteration
        scaler.update()
        scheduler.step()
        if idx % 100 == 0:
            print("Loss:", loss.item(), f'lr : {scheduler.get_lr()[0]:.6f} ', sep=' ')
        

    if (epoch + 1) % 1 == 0:
        torch.save(model.state_dict(), f'./matcha_v1/matcha_{epoch}.bin')

Epoch: 0


  0%|          | 0/15145 [00:00<?, ?it/s]



Loss: 18.981365203857422 lr : 0.0000 
Loss: 16.097797393798828 lr : 0.0000 
Loss: 18.269908905029297 lr : 0.0000 
Loss: 16.054994583129883 lr : 0.0000 
Loss: 11.140259742736816 lr : 0.0000 
Loss: 8.733059883117676 lr : 0.0000 
Loss: 5.940934181213379 lr : 0.0000 
Loss: 5.100738048553467 lr : 0.0000 
Loss: 5.102210998535156 lr : 0.0000 
Loss: 4.010885238647461 lr : 0.0000 
Loss: 4.282005310058594 lr : 0.0000 
Loss: 4.042594909667969 lr : 0.0000 
Loss: 3.5747976303100586 lr : 0.0000 
Loss: 5.140348434448242 lr : 0.0000 
Loss: 4.2400617599487305 lr : 0.0000 
Loss: 4.306550979614258 lr : 0.0000 
Loss: 4.28499698638916 lr : 0.0000 
Loss: 4.107604026794434 lr : 0.0000 
Loss: 4.192779064178467 lr : 0.0000 
Loss: 4.6402363777160645 lr : 0.0000 
Loss: 4.045029640197754 lr : 0.0000 
Loss: 4.375026226043701 lr : 0.0000 
Loss: 3.2620224952697754 lr : 0.0000 
Loss: 4.083436489105225 lr : 0.0000 
Loss: 3.6515517234802246 lr : 0.0000 
Loss: 3.3203001022338867 lr : 0.0000 
Loss: 3.744830846786499 lr :

  0%|          | 0/15145 [00:00<?, ?it/s]

Loss: 0.20481933653354645 lr : 0.0000 
Loss: 0.43141621351242065 lr : 0.0000 
Loss: 0.0845700055360794 lr : 0.0000 
Loss: 0.12753713130950928 lr : 0.0000 
Loss: 0.1718612015247345 lr : 0.0000 
Loss: 0.1573198437690735 lr : 0.0000 
Loss: 0.4603939354419708 lr : 0.0000 
Loss: 0.07166243344545364 lr : 0.0000 
Loss: 0.08136807382106781 lr : 0.0000 
Loss: 0.1638917773962021 lr : 0.0000 
Loss: 0.25288960337638855 lr : 0.0000 
Loss: 0.11330229043960571 lr : 0.0000 
Loss: 0.11167331784963608 lr : 0.0000 
Loss: 0.14511267840862274 lr : 0.0000 
Loss: 0.2542201578617096 lr : 0.0000 
Loss: 0.2906930148601532 lr : 0.0000 
Loss: 0.18470777571201324 lr : 0.0000 
Loss: 0.24274948239326477 lr : 0.0000 
Loss: 0.17966996133327484 lr : 0.0000 
Loss: 0.23360802233219147 lr : 0.0000 
Loss: 0.22135958075523376 lr : 0.0000 
Loss: 0.17021240293979645 lr : 0.0000 
Loss: 0.20791390538215637 lr : 0.0000 
Loss: 0.2486371397972107 lr : 0.0000 
Loss: 0.3508821129798889 lr : 0.0000 
Loss: 0.18464551866054535 lr : 0.0

  0%|          | 0/15145 [00:00<?, ?it/s]

Loss: 0.3277851939201355 lr : 0.0000 
Loss: 0.2065010517835617 lr : 0.0000 
Loss: 0.20300701260566711 lr : 0.0000 
Loss: 0.14089538156986237 lr : 0.0000 
Loss: 0.1640571802854538 lr : 0.0000 
Loss: 0.1218208596110344 lr : 0.0000 
Loss: 0.24343055486679077 lr : 0.0000 
Loss: 0.06384728103876114 lr : 0.0000 
Loss: 0.18722495436668396 lr : 0.0000 
Loss: 0.30027249455451965 lr : 0.0000 
Loss: 0.13525807857513428 lr : 0.0000 
Loss: 0.22028736770153046 lr : 0.0000 
Loss: 0.1588953733444214 lr : 0.0000 
Loss: 0.2761460542678833 lr : 0.0000 
Loss: 0.17602002620697021 lr : 0.0000 
Loss: 0.3080819845199585 lr : 0.0000 
Loss: 0.14960549771785736 lr : 0.0000 
Loss: 0.08443856239318848 lr : 0.0000 
Loss: 0.1802663803100586 lr : 0.0000 
Loss: 0.20693688094615936 lr : 0.0000 
Loss: 0.14269421994686127 lr : 0.0000 
Loss: 0.31147393584251404 lr : 0.0000 
Loss: 0.17172573506832123 lr : 0.0000 
Loss: 0.05020740628242493 lr : 0.0000 
Loss: 0.1946120411157608 lr : 0.0000 
Loss: 0.0982380360364914 lr : 0.00