Check the TODO comment in cell 17 (adapt the GPU ID used in the experiments)

## Create PyTorch Dataset

In [10]:
import transformers
print(transformers.__version__)

4.29.1


In [11]:
import pandas as pd
from glob import  glob
import json
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import StratifiedKFold
import torch
import numpy as np
import random
import os
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

In [12]:
image_path = glob('./datasets/default/train/images/*')

In [13]:
label_path = glob('./datasets/default/train/annotations/*')

In [14]:
assert len(image_path) == 60578

In [15]:
assert len(label_path)== 60578

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and create `n_patches` aspect-ratio preserving patches, and concatenates the remaining sequence with padding tokens to finally get `max_patches` patches. It appears that this argument is quite crucial for training and evaluation, as the model becomes very sensitive to this parameter.

For the sake of our example, we will fine-tune a model with `max_patches=1024`.

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
MAX_PATCHES = 2048

class ImageCaptioningDataset(Dataset):
    def __init__(self, df, processor):
        self.dataset = df
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx, :]
        image = Image.open(row.image_path)
        encoding = self.processor(images=image,
                                  text="Generate underlying data table of the figure below:",
                                  font_path="arial.ttf",
                                  return_tensors="pt",
                                  add_special_tokens=True, max_patches=MAX_PATCHES)
        
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["text"] = row.label 
        return encoding

## Load model and processor

In [9]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration, AutoProcessor
model = Pix2StructForConditionalGeneration.from_pretrained("./deplot")
processor = AutoProcessor.from_pretrained("./deplot")
model.load_state_dict(torch.load("./weights/deplot_v4/deplot_v4_v4_0.bin"))

<All keys matched successfully>

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [16]:
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]
    text_inputs = processor.tokenizer(text=texts,
                                      padding="max_length",
                                      return_tensors="pt",
                                      add_special_tokens=True,
                                      max_length=512,
                                      truncation=True
                                      )

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

In [17]:
df = pd.read_csv('./datasets/default/train_with_fold.csv')
print(len(df))
df.reset_index(drop=True, inplace=True)
train_df = df
train_df

60578


Unnamed: 0,image_path,types,label
0,./datasets/default/train/images/2616369a5acb.jpg,vertical_bar,<x_start>Guyana;Haiti;High income...;Honduras;...
1,./datasets/default/train/images/05a4e1dd189f.jpg,vertical_bar,<x_start>Kent;Sweetwater;Lackawanna;Merrimack;...
2,./datasets/default/train/images/5a6cca560ac5.jpg,line,<x_start>1961;1965;1970;1975;1980;1985;1990;19...
3,./datasets/default/train/images/1d6b573f6615.jpg,vertical_bar,<x_start>ISingapore;Soloman Islands;South Afri...
4,./datasets/default/train/images/8f8be0bb2e11.jpg,line,<x_start>2002;2003;2004;2005;2006;2007;2008;20...
...,...,...,...
60573,./datasets/default/train/images/a98bfe7505fe.jpg,vertical_bar,<x_start>Namibia;Nepal;Netherlands;New Zealand...
60574,./datasets/default/train/images/0929103983d3.jpg,line,<x_start>Jan;Feb;Mar;Apr;May;Jun;Jul;Aug;Sep;O...
60575,./datasets/default/train/images/af0675ba4b7c.jpg,line,<x_start>0;2;4;6;8;10<x_end> <y_start>29.09381...
60576,./datasets/default/train/images/91245c2c262b.jpg,vertical_bar,<x_start>Taiwan;Tajikistan;Tanzania;Thailand;T...


In [18]:
class CFG:
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True
    num_cycles = 0.5  # 1.5
    num_warmup_steps = 0.2
    max_input_length = 130
    epochs = 1  # 5
    encoder_lr = 1e-5
    decoder_lr = 1e-5
    min_lr = 1e-7
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 1e-6
    num_fold = 5
    batch_size = 2
    seed = 1006
    num_workers = 2
    device='cuda'
    print_freq = 100

In [19]:
train_dataset = ImageCaptioningDataset(train_df, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=CFG.batch_size, collate_fn=collator, pin_memory=True,
                                  prefetch_factor=40, num_workers=2)

In [20]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(CFG.seed)

In [21]:
def get_scheduler(cfg, optimizer, num_train_steps):
    cfg.num_warmup_steps = cfg.num_warmup_steps * num_train_steps
    if cfg.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
        )
    elif cfg.scheduler == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps,
            num_cycles=cfg.num_cycles
        )
    return scheduler

num_train_steps = int(len(train_dataset) / CFG.batch_size * CFG.epochs)


## Train the model

Let's train the model! Run the simply the cell below for training the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trials and errors, as the model can easily enter in "collapse-model" (always predicting the same output, no matter the input) if the HP are not chosen correctly. In this example, we found out that using `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

Let's also print the generation output of the model each 10 epochs!

Bear in mind that the model took some time to converge, for instance to get decent results we had to let the script run for ~1hour.

In [22]:
model = torch.nn.DataParallel(model, device_ids=[2,3])  # TODO: adapt to your hardware
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-6)
scheduler = get_scheduler(CFG, optimizer, num_train_steps)
device = "cuda:2" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
scaler = torch.cuda.amp.GradScaler()
model.train()

cuda:2


DataParallel(
  (module): Pix2StructForConditionalGeneration(
    (encoder): Pix2StructVisionModel(
      (embeddings): Pix2StructVisionEmbeddings(
        (patch_projection): Linear(in_features=768, out_features=768, bias=True)
        (row_embedder): Embedding(4096, 768)
        (column_embedder): Embedding(4096, 768)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Pix2StructVisionEncoder(
        (layer): ModuleList(
          (0): Pix2StructVisionLayer(
            (attention): Pix2StructVisionAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (output): Linear(in_features=768, out_features=768, bias=False)
            )
            (mlp): Pix2StructVisionMlp(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): 

In [23]:
import matplotlib.pyplot as plt

loss_file = open("./loss/loss_deplota_final.txt","w")

for epoch in range(CFG.epochs):
    print("Epoch:", epoch)
    for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(flattened_patches=flattened_patches,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss.mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        if idx % 100 == 0:
            print("Epoch:", epoch, "Loss:", loss.item(), f'lr : {scheduler.get_lr()[0]:.9f} ', sep=' ')
            loss_file.write(f"Epoch: {epoch}, Iteration: {idx}, Loss: {loss.item()}, lr : {scheduler.get_lr()[0]:.9f}\n")
            loss_file.flush()
        
    if (epoch + 1) % 1 == 0:
        torch.save(model.module.state_dict(), f'./weights/deplot_v6/deplot_v6{epoch}.bin')

loss_file.close()

Epoch: 0


  0%|          | 0/30289 [00:00<?, ?it/s]



Epoch: 0 Loss: 0.7125048637390137 lr : 0.000000002 
Epoch: 0 Loss: 1.0355751514434814 lr : 0.000000167 
Epoch: 0 Loss: 0.29985955357551575 lr : 0.000000497 
Epoch: 0 Loss: 0.8635852932929993 lr : 0.000000662 
Epoch: 0 Loss: 0.43086129426956177 lr : 0.000000827 
Epoch: 0 Loss: 0.6352189779281616 lr : 0.000000992 
Epoch: 0 Loss: 0.5029826760292053 lr : 0.000001157 
Epoch: 0 Loss: 0.8607119917869568 lr : 0.000001322 
Epoch: 0 Loss: 1.1083998680114746 lr : 0.000001487 
Epoch: 0 Loss: 0.4324033856391907 lr : 0.000001652 
Epoch: 0 Loss: 0.6643177270889282 lr : 0.000001817 
Epoch: 0 Loss: 1.061600685119629 lr : 0.000001983 
Epoch: 0 Loss: 0.16949674487113953 lr : 0.000002148 
Epoch: 0 Loss: 0.34243592619895935 lr : 0.000002313 
Epoch: 0 Loss: 0.5405922532081604 lr : 0.000002478 
Epoch: 0 Loss: 1.056748628616333 lr : 0.000002643 
Epoch: 0 Loss: 0.8286745548248291 lr : 0.000002808 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch: 0 Loss: 0.7291737794876099 lr : 0.000009887 
Epoch: 0 Loss: 0.7909120917320251 lr : 0.000009873 
Epoch: 0 Loss: 0.8310853838920593 lr : 0.000009858 
Epoch: 0 Loss: 0.8715543150901794 lr : 0.000009842 
Epoch: 0 Loss: 0.7366686463356018 lr : 0.000009826 
Epoch: 0 Loss: 1.5487995147705078 lr : 0.000009808 
Epoch: 0 Loss: 0.40721139311790466 lr : 0.000009790 
Epoch: 0 Loss: 0.6000350713729858 lr : 0.000009771 
Epoch: 0 Loss: 0.8124750256538391 lr : 0.000009751 
Epoch: 0 Loss: 0.7051783800125122 lr : 0.000009731 
Epoch: 0 Loss: 1.1867467164993286 lr : 0.000009709 
Epoch: 0 Loss: 0.7914816737174988 lr : 0.000009687 
Epoch: 0 Loss: 1.2311625480651855 lr : 0.000009664 
Epoch: 0 Loss: 0.7817955017089844 lr : 0.000009640 
Epoch: 0 Loss: 1.4209864139556885 lr : 0.000009616 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch: 0 Loss: 0.352177232503891 lr : 0.000008969 
Epoch: 0 Loss: 0.8648884892463684 lr : 0.000008929 
Epoch: 0 Loss: 0.8045511245727539 lr : 0.000008889 
Epoch: 0 Loss: 0.33150506019592285 lr : 0.000008847 
Epoch: 0 Loss: 0.3041873276233673 lr : 0.000008806 
Epoch: 0 Loss: 0.8247061967849731 lr : 0.000008763 
Epoch: 0 Loss: 0.40094509720802307 lr : 0.000008720 
Epoch: 0 Loss: 0.8243778944015503 lr : 0.000008677 
Epoch: 0 Loss: 1.5518616437911987 lr : 0.000008633 
Epoch: 0 Loss: 0.31360748410224915 lr : 0.000008588 
Epoch: 0 Loss: 1.2452291250228882 lr : 0.000008542 
Epoch: 0 Loss: 0.6742597818374634 lr : 0.000008496 
Epoch: 0 Loss: 0.34211230278015137 lr : 0.000008450 
Epoch: 0 Loss: 0.6707308292388916 lr : 0.000008402 
Epoch: 0 Loss: 1.0115817785263062 lr : 0.000008355 
Epoch: 0 Loss: 0.9568164348602295 lr : 0.000008306 
Epoch: 0 Loss: 0.34076717495918274 lr : 0.000008257 
Epoch: 0 Loss: 0.6478431224822998 lr : 0.000008208 
Epoch: 0 Loss: 0.74178147315979 lr : 0.000008158 
Epoch: 0 L

Epoch: 0 Loss: 0.79490065574646 lr : 0.000000475 
Epoch: 0 Loss: 0.9485747814178467 lr : 0.000000447 
Epoch: 0 Loss: 1.5893464088439941 lr : 0.000000421 
Epoch: 0 Loss: 0.4751659035682678 lr : 0.000000395 
Epoch: 0 Loss: 1.1286671161651611 lr : 0.000000371 
Epoch: 0 Loss: 1.0332812070846558 lr : 0.000000346 
Epoch: 0 Loss: 0.5004827976226807 lr : 0.000000323 
Epoch: 0 Loss: 0.1714707612991333 lr : 0.000000301 
Epoch: 0 Loss: 1.2647364139556885 lr : 0.000000279 
Epoch: 0 Loss: 0.3197959065437317 lr : 0.000000258 
Epoch: 0 Loss: 1.2978699207305908 lr : 0.000000238 
Epoch: 0 Loss: 0.9293451309204102 lr : 0.000000218 
Epoch: 0 Loss: 0.8262723684310913 lr : 0.000000200 
Epoch: 0 Loss: 1.3899335861206055 lr : 0.000000182 
Epoch: 0 Loss: 1.0441644191741943 lr : 0.000000165 
Epoch: 0 Loss: 0.7729865312576294 lr : 0.000000149 
Epoch: 0 Loss: 0.6349879503250122 lr : 0.000000134 
Epoch: 0 Loss: 0.642643928527832 lr : 0.000000119 
Epoch: 0 Loss: 1.2339662313461304 lr : 0.000000106 
Epoch: 0 Loss: 