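Check the TODO comment in the cell that wraps the model in `torch.nn.DataParallel` (under "Train the model"): adapt the GPU IDs in `device_ids` to your hardware before running the experiments.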

## Create PyTorch Dataset

In [1]:
import transformers
# Show which transformers release is installed in this kernel.
print(transformers.__version__)

4.29.1


In [2]:
import pandas as pd
from glob import  glob
import json
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import StratifiedKFold
import torch
import numpy as np
import random
import os
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

In [3]:
# image_path = glob('./datasets/default/train/images/*')

In [4]:
# label_path = glob('./datasets/default/train/annotations/*')

In [5]:
# assert len(image_path) == 60578

In [6]:
# assert len(label_path)== 60578

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and creates `n_patches` aspect-ratio-preserving patches, then pads the resulting sequence with padding tokens to reach a total of `max_patches` patches. This argument appears to be crucial for training and evaluation, as the model is very sensitive to it.

For the sake of our example, we will fine-tune a model with `max_patches=2048` (the `MAX_PATCHES` constant defined in the next cell).

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [3]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
# Value passed as `max_patches` to the processor for every sample.
MAX_PATCHES = 2048

class ImageCaptioningDataset(Dataset):
    """Map-style dataset that turns one DataFrame row into model inputs.

    Args:
        df: pandas DataFrame read positionally with ``iloc``; each row must
            expose ``image_path`` (file to open) and ``label`` (target text).
        processor: callable invoked with ``images``, ``text``, ``font_path``,
            ``return_tensors``, ``add_special_tokens`` and ``max_patches``;
            it must return a mapping of tensors that carry a leading batch
            axis of size 1 (presumably a Pix2Struct processor - confirm).
    """

    def __init__(self, df, processor):
        self.dataset = df
        self.processor = processor

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load the image at row ``idx`` and return its encoded features.

        Returns:
            dict with every tensor produced by the processor (batch axis
            removed) plus the raw target string under the key ``"text"``.
        """
        row = self.dataset.iloc[idx, :]
        # Image.open is lazy and keeps the file open; decode inside the
        # context manager so the handle is released deterministically instead
        # of whenever the garbage collector gets to it.
        with Image.open(row.image_path) as source:
            image = source.convert("RGB")
        encoding = self.processor(images=image,
                                  text="Generate underlying data table of the figure below:",
                                  font_path="arial.ttf",
                                  return_tensors="pt",
                                  add_special_tokens=True, max_patches=MAX_PATCHES)

        # Drop only the leading batch axis; a bare squeeze() would also remove
        # any other axis that happens to have size 1.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["text"] = row.label
        return encoding
    

## Load model and processor

In [4]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration, AutoProcessor
# Build the model and its processor from the local "./deplot" directory.
model = Pix2StructForConditionalGeneration.from_pretrained("./deplot")
processor = AutoProcessor.from_pretrained("./deplot")
# Overwrite the weights with a previously saved state dict.
# map_location="cpu" deserializes the tensors on the CPU regardless of the
# device they were saved from, so the cell also works on a host without that
# GPU; the model is moved to the training device later.
model.load_state_dict(
    torch.load("./weights/deplot_v4/deplot_v4v4v4v5_0.bin", map_location="cpu")
)

<All keys matched successfully>

Now that we have loaded the model and the processor, let's define the collate function, then load the dataset and build the dataloader:

In [5]:
def collator(batch):
    """Merge a list of dataset samples into one training batch.

    Each sample must provide ``"text"``, ``"flattened_patches"`` and
    ``"attention_mask"``. The target strings are tokenized with the
    module-level ``processor`` (padded/truncated to 512 tokens) and returned
    as ``"labels"``; the per-sample tensors are stacked along a new first axis.
    """
    tokenized = processor.tokenizer(
        text=[sample["text"] for sample in batch],
        padding="max_length",
        return_tensors="pt",
        add_special_tokens=True,
        max_length=512,
        truncation=True,
    )

    patches = [sample["flattened_patches"] for sample in batch]
    masks = [sample["attention_mask"] for sample in batch]

    return {
        "flattened_patches": torch.stack(patches),
        "attention_mask": torch.stack(masks),
        "labels": tokenized.input_ids,
    }

In [6]:
# Read the training table; the displayed columns include image_path, types and label.
df = pd.read_csv('./datasets/default/train_with_fold.csv')
print(df.shape[0])
# Guarantee a clean 0..n-1 index so positional lookups line up with row numbers.
df = df.reset_index(drop=True)
# The whole table is used for training in this notebook (no hold-out split here).
train_df = df
train_df

60578


Unnamed: 0,image_path,types,label
0,./datasets/default/train/images/2616369a5acb.jpg,vertical_bar,<x_start>Guyana;Haiti;High income...;Honduras;...
1,./datasets/default/train/images/05a4e1dd189f.jpg,vertical_bar,<x_start>Kent;Sweetwater;Lackawanna;Merrimack;...
2,./datasets/default/train/images/5a6cca560ac5.jpg,line,<x_start>1961;1965;1970;1975;1980;1985;1990;19...
3,./datasets/default/train/images/1d6b573f6615.jpg,vertical_bar,<x_start>ISingapore;Soloman Islands;South Afri...
4,./datasets/default/train/images/8f8be0bb2e11.jpg,line,<x_start>2002;2003;2004;2005;2006;2007;2008;20...
...,...,...,...
60573,./datasets/default/train/images/a98bfe7505fe.jpg,vertical_bar,<x_start>Namibia;Nepal;Netherlands;New Zealand...
60574,./datasets/default/train/images/0929103983d3.jpg,line,<x_start>Jan;Feb;Mar;Apr;May;Jun;Jul;Aug;Sep;O...
60575,./datasets/default/train/images/af0675ba4b7c.jpg,line,<x_start>0;2;4;6;8;10<x_end> <y_start>29.09381...
60576,./datasets/default/train/images/91245c2c262b.jpg,vertical_bar,<x_start>Taiwan;Tajikistan;Tanzania;Thailand;T...


In [7]:
class CFG:
    """Static container for the run's hyper-parameters.

    Only some fields are read by the cells visible in this notebook
    (``scheduler``, ``num_cycles``, ``num_warmup_steps``, ``epochs``,
    ``batch_size``, ``seed``); the rest are kept as declared.
    """

    # --- learning-rate schedule ---
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True
    num_cycles = 0.5  # 1.5
    num_warmup_steps = 0.2  # multiplied by the total step count in get_scheduler

    # --- optimizer ---
    encoder_lr = 1e-6
    decoder_lr = 1e-6
    min_lr = 0.1e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 1e-6

    # --- data / loop ---
    max_input_length = 130
    epochs = 1  # 5
    num_fold = 5
    batch_size = 4
    num_workers = 2
    print_freq = 100

    # --- runtime ---
    seed = 1006
    device = 'cuda'

In [8]:
# Wrap the training table in the dataset and batch it with the custom collator.
train_dataset = ImageCaptioningDataset(train_df, processor)
# num_workers now comes from CFG (same value, 2) so the config stays the single
# source of truth. prefetch_factor requires num_workers > 0; each worker keeps
# up to 40 batches queued, which is memory-hungry - lower it if host RAM is tight.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=CFG.batch_size,
    collate_fn=collator,
    pin_memory=True,
    prefetch_factor=40,
    num_workers=CFG.num_workers,
)

In [9]:
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The individual seeders are independent of each other, so apply them in one pass.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
# Apply the configured seed (CFG.seed) before the scheduler/training cells below run.
seed_everything(CFG.seed)

In [10]:
def get_scheduler(cfg, optimizer, num_train_steps):
    """Build the warmup learning-rate scheduler selected by ``cfg.scheduler``.

    Args:
        cfg: config object exposing ``scheduler`` (``'linear'`` or
            ``'cosine'``), ``num_warmup_steps`` (fraction of
            ``num_train_steps`` spent warming up) and, for the cosine
            schedule, ``num_cycles``.
        optimizer: optimizer whose learning rate is scheduled.
        num_train_steps: total number of scheduler steps in the run.

    Returns:
        The scheduler created by ``get_linear_schedule_with_warmup`` or
        ``get_cosine_schedule_with_warmup``.

    Raises:
        ValueError: if ``cfg.scheduler`` is not a supported name.

    Note:
        ``cfg`` is left untouched. Writing the absolute warmup count back
        into ``cfg.num_warmup_steps`` would make a second call (e.g.
        re-running the cell) multiply it by ``num_train_steps`` again.
    """
    # The schedule helpers document an integer step count.
    num_warmup_steps = int(cfg.num_warmup_steps * num_train_steps)

    if cfg.scheduler == 'linear':
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps
        )
    if cfg.scheduler == 'cosine':
        return get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps,
            num_cycles=cfg.num_cycles
        )
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(
        f"Unsupported scheduler {cfg.scheduler!r}; expected 'linear' or 'cosine'."
    )

# The training loop steps the scheduler once per batch, so count batches from
# the dataloader itself. int(len(dataset) / batch_size * epochs) truncates the
# final partial batch (15144 instead of 15145 here), which lets the schedule run
# past its end.
num_train_steps = len(train_dataloader) * CFG.epochs


## Train the model

Let's train the model! Simply run the cell below to train it. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trial and error, as the model can easily fall into "model collapse" (always predicting the same output, no matter the input) if the hyper-parameters are not chosen correctly. In this example, we found that using the `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

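The training loop below logs the loss and the current learning rate every 100 iterations (to the notebook and to a text file) and saves a checkpoint at the end of every epoch.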

Bear in mind that the model took some time to converge, for instance to get decent results we had to let the script run for ~1hour.

In [11]:
# Wrap the model for multi-GPU training exactly once. Without the guard,
# re-running this cell nests DataParallel inside DataParallel, and
# `model.module.state_dict()` in the training loop would then save keys with an
# extra "module." prefix.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model, device_ids=[0,1])  # TODO: adapt to your hardware
# The learning rate is set here directly; CFG.encoder_lr / CFG.decoder_lr are not read.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-6)
scheduler = get_scheduler(CFG, optimizer, num_train_steps)
# Primary device that receives the input batches; falls back to CPU without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()
model.train()

cuda:0


DataParallel(
  (module): Pix2StructForConditionalGeneration(
    (encoder): Pix2StructVisionModel(
      (embeddings): Pix2StructVisionEmbeddings(
        (patch_projection): Linear(in_features=768, out_features=768, bias=True)
        (row_embedder): Embedding(4096, 768)
        (column_embedder): Embedding(4096, 768)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): Pix2StructVisionEncoder(
        (layer): ModuleList(
          (0): Pix2StructVisionLayer(
            (attention): Pix2StructVisionAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (output): Linear(in_features=768, out_features=768, bias=False)
            )
            (mlp): Pix2StructVisionMlp(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): 

In [12]:
import matplotlib.pyplot as plt

# Make sure the output locations exist before a long training run depends on them.
os.makedirs("./loss", exist_ok=True)
os.makedirs("./weights/deplot_v4", exist_ok=True)

# The context manager closes the loss log even when training is interrupted
# (for example by a KeyboardInterrupt from the notebook's stop button).
with open("./loss/loss_deplot_v10.txt", "w") as loss_file:
    for epoch in range(CFG.epochs):
        print("Epoch:", epoch)
        torch.cuda.empty_cache()
        for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            labels = batch.pop("labels").to(device)
            flattened_patches = batch.pop("flattened_patches").to(device)
            attention_mask = batch.pop("attention_mask").to(device)
            optimizer.zero_grad()
            # Mixed-precision forward pass.
            with torch.cuda.amp.autocast():
                outputs = model(flattened_patches=flattened_patches,
                                attention_mask=attention_mask,
                                labels=labels)

            # Reduce to a scalar; presumably DataParallel returns one loss per replica.
            loss = outputs.loss.mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            if idx % CFG.print_freq == 0:
                # Read the loss once (each .item() forces a host sync) and use
                # get_last_lr(), the supported accessor outside scheduler.step().
                loss_value = loss.item()
                print("Epoch:", epoch, "Loss:", loss_value, f'lr : {scheduler.get_last_lr()[0]:.9f} ', sep=' ')
                loss_file.write(f"Epoch: {epoch}, Iteration: {idx}, Loss: {loss_value}\n")
                loss_file.flush()
        # Checkpoint after every epoch; save the unwrapped module so the keys
        # carry no DataParallel "module." prefix.
        torch.save(model.module.state_dict(), f'./weights/deplot_v4/deplot_v10_{epoch}.bin')

Epoch: 0


  0%|          | 0/15145 [00:00<?, ?it/s]



Epoch: 0 Loss: 0.608335018157959 lr : 0.000000003 
Epoch: 0 Loss: 0.45563560724258423 lr : 0.000000333 
Epoch: 0 Loss: 0.7266737222671509 lr : 0.000000664 
Epoch: 0 Loss: 0.4343388080596924 lr : 0.000000994 
Epoch: 0 Loss: 0.9325616359710693 lr : 0.000001324 
Epoch: 0 Loss: 0.39354532957077026 lr : 0.000001654 
Epoch: 0 Loss: 0.9676792621612549 lr : 0.000001984 
Epoch: 0 Loss: 0.6883355379104614 lr : 0.000002645 
Epoch: 0 Loss: 0.7004773616790771 lr : 0.000002975 
Epoch: 0 Loss: 0.5748456120491028 lr : 0.000003305 
Epoch: 0 Loss: 0.896929144859314 lr : 0.000003635 
Epoch: 0 Loss: 0.7628474235534668 lr : 0.000003965 
Epoch: 0 Loss: 1.118822455406189 lr : 0.000004295 
Epoch: 0 Loss: 0.9545638561248779 lr : 0.000004626 
Epoch: 0 Loss: 0.6373425722122192 lr : 0.000004956 
Epoch: 0 Loss: 0.697874903678894 lr : 0.000005286 
Epoch: 0 Loss: 0.49559277296066284 lr : 0.000005616 
Epoch: 0 Loss: 1.3608222007751465 lr : 0.000005946 
Epoch: 0 Loss: 0.8845911026000977 lr : 0.000006276 
Epoch: 0 Loss

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch: 0 Loss: 0.8615395426750183 lr : 0.000006513 
Epoch: 0 Loss: 1.308750867843628 lr : 0.000006389 
Epoch: 0 Loss: 0.5766546726226807 lr : 0.000006264 
Epoch: 0 Loss: 0.41859665513038635 lr : 0.000006138 
Epoch: 0 Loss: 0.7585686445236206 lr : 0.000006011 
Epoch: 0 Loss: 0.6086609363555908 lr : 0.000005884 
Epoch: 0 Loss: 0.7975767850875854 lr : 0.000005756 
Epoch: 0 Loss: 0.5260027647018433 lr : 0.000005628 
Epoch: 0 Loss: 0.5464826822280884 lr : 0.000005499 
Epoch: 0 Loss: 0.6549850702285767 lr : 0.000005370 
Epoch: 0 Loss: 0.534896731376648 lr : 0.000005240 
Epoch: 0 Loss: 1.1025620698928833 lr : 0.000005111 
Epoch: 0 Loss: 0.9155890345573425 lr : 0.000004981 
Epoch: 0 Loss: 0.5710170269012451 lr : 0.000004851 
Epoch: 0 Loss: 0.647935152053833 lr : 0.000004722 
Epoch: 0 Loss: 0.8183948993682861 lr : 0.000004593 
Epoch: 0 Loss: 0.6705788373947144 lr : 0.000004463 
Epoch: 0 Loss: 0.9018584489822388 lr : 0.000004335 
Epoch: 0 Loss: 1.0156874656677246 lr : 0.000004207 
Epoch: 0 Loss: