In [None]:
# Install transformers straight from the GitHub repository; no version or commit is pinned here.
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
# Remaining third-party dependencies: lightning and wandb, polyleven, and datasets.
# None of them is version-pinned.
!pip install -q lightning wandb
!pip install polyleven
!pip install -q datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


**Import libraries**

In [None]:
import json
import os
import re
from collections import Counter
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from datasets import Image as ds_img
from PIL import Image
from polyleven import levenshtein # a faster version of levenshtein
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
    get_scheduler
)

**Mount Google Drive**

In [None]:
# Mount Google Drive at /content/drive inside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip /content/drive/MyDrive/miniProject/benetech-making-graphs-accessible.zip -d dataset

Archive:  /content/drive/MyDrive/miniProject/benetech-making-graphs-accessible.zip
replace dataset/sample_submission.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace dataset/test/images/000b92c3b098.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace dataset/test/images/007a18eb4e09.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

**Config**

In [None]:
# Locations of the training split: chart images and their JSON annotations.
data_dir = Path("/content/dataset/train")
images_path = data_dir / "images"
# glob() yields entries in arbitrary, filesystem-dependent order. Sorting keeps the
# list stable between runs, so anything that later slices the data by position
# (e.g. a positional train/val/test split) is reproducible.
train_json_files = sorted((data_dir / "annotations").glob("*.json"))

class CFG:
    # Notebook-wide settings, read as plain class attributes (e.g. CFG.batch_size).

    # General
    debug = False  # when True, the train and validation sets are cut to their first 100 examples
    num_proc = 2  # process count for Dataset.from_generator and Dataset.map
    num_workers = 2  # worker count for every DataLoader
    gpus = 1  # divisor used when computing num_training_steps

    # Data
    max_length = 512  # token length the target sequences are padded / truncated to
    image_height = 256  # resize target (pixels)
    image_width = 256  # resize target (pixels)
    max_patch = 1024  # passed to the processor as max_patches

    # Training
    epochs = 2  # passed to the Trainer as max_epochs
    val_check_interval = 1.0  # passed to the Trainer; with 1.0 validation runs at the end of each training epoch
    check_val_every_n_epoch = 1
    gradient_clip_val = 1.0
    lr = 3e-5  # optimizer learning rate
    lr_scheduler_type = "cosine"  # NOTE(review): not referenced by the cells in view
    num_warmup_steps = 100  # warmup length handed to the cosine schedule
    seed = 42  # NOTE(review): not referenced by the cells in view, so nothing is actually seeded there
    warmup_steps = 300  # NOTE(review): easy to confuse with num_warmup_steps above; confirm which one the scheduler setup should read
    output_path = "/content/output"  # NOTE(review): not referenced by the cells in view
    log_steps = 200  # NOTE(review): not referenced by the cells in view
    batch_size = 2  # batch size for every DataLoader
    use_wandb = False  # NOTE(review): not referenced by the cells in view

In [None]:
# Delimiters that frame the whole target sequence and its x / y value lists.
PROMPT_TOKEN, PROMPT_END_TOKEN = "<|PROMPT|>", "</|PROMPT|>"
X_START, X_END = "<x_start>", "<x_end>"
Y_START, Y_END = "<y_start>", "<y_end>"

SEPARATOR_TOKENS = [PROMPT_TOKEN, X_START, X_END, Y_START, Y_END, PROMPT_END_TOKEN]

# One token per chart type. "<horizontal_bar>" and "<scatter>" are intentionally
# not defined: those two chart types are disabled in this notebook.
LINE_TOKEN = "<line>"
VERTICAL_BAR_TOKEN = "<vertical_bar>"
DOT_TOKEN = "<dot>"

CHART_TYPE_TOKENS = [LINE_TOKEN, VERTICAL_BAR_TOKEN, DOT_TOKEN]

# Full list of custom tokens, separators first, then chart types.
new_tokens = [*SEPARATOR_TOKENS, *CHART_TYPE_TOKENS]

In [None]:
def is_nan(value: Union[int, float, str]) -> bool:
    """Return True only for a float whose string form is ``"nan"``.

    Ints and strings (including the literal string ``"nan"``) are never
    reported as NaN.
    """
    if not isinstance(value, float):
        return False
    return str(value) == "nan"

def round_float(value: Union[int, float, str]) -> Union[str, float]:
    """Truncate a float to a short decimal string; pass other types through.

    The decimal part is cut (not rounded) to 1 digit when the absolute integer
    part is greater than 1, and to 4 digits otherwise. Floats that Python
    prints in scientific notation keep their exponent: the truncation is
    applied to the mantissa only, so ``1.5e-05`` stays ``"1.5e-05"`` instead of
    collapsing to the unrelated value ``"1.5e-0"``.

    Args:
        value: The raw value from an annotation.

    Returns:
        A string for float input (``"nan"`` / ``"inf"`` are returned as-is),
        otherwise ``value`` unchanged.
    """
    if not isinstance(value, float):
        return value

    text = str(value)
    # Separate an optional exponent ("e-05") so it can never be truncated away.
    mantissa, exp_marker, exponent = text.partition("e")
    if "." not in mantissa:
        # e.g. "nan", "inf" or "1e-05": there is no decimal part to shorten.
        return text

    integer, decimal = mantissa.split(".")
    keep = 1 if abs(float(integer)) > 1 else 4
    return integer + "." + decimal[:keep] + exp_marker + exponent

def get_gt_string_and_xy(filepath: Union[str, os.PathLike]) -> Optional[Dict[str, str]]:
    """Turn one annotation JSON file into a training record.

    Args:
        filepath: Path to an annotation file containing the keys
            ``"data-series"``, ``"chart-type"`` and ``"source"``.

    Returns:
        ``None`` for ``horizontal_bar`` and ``scatter`` charts, which are
        skipped. Otherwise a dict with:
            ground_truth: the tokenised target string
                (prompt token, chart-type token, x list, y list, end token),
            x, y: JSON-encoded lists of the kept x / y values,
            chart-type, source: copied from the annotation,
            id: the file name without its extension.
    """
    filepath = Path(filepath)

    with open(filepath, encoding="utf-8") as fp:
        data = json.load(fp)

    # Decide whether the chart is supported before doing any per-point work.
    if data["chart-type"] in ["horizontal_bar", "scatter"]:
        return None

    all_x, all_y = [], []

    for d in data["data-series"]:
        x = d["x"]
        y = d["y"]

        # Ignore nan values. The check must see the raw values: round_float()
        # converts floats to strings, and is_nan() only recognises floats, so
        # testing after rounding would let every NaN through.
        if is_nan(x) or is_nan(y):
            continue

        all_x.append(round_float(x))
        all_y.append(round_float(y))

    chart_type = f"<{data['chart-type']}>"
    x_str = X_START + ";".join(map(str, all_x)) + X_END
    y_str = Y_START + ";".join(map(str, all_y)) + Y_END

    gt_string = PROMPT_TOKEN + chart_type + x_str + y_str + PROMPT_END_TOKEN

    return {
        "ground_truth": gt_string,
        "x": json.dumps(all_x),
        "y": json.dumps(all_y),
        "chart-type": data["chart-type"],
        "id": filepath.stem,
        "source": data["source"],
    }

In [None]:
# Run the parser on a single annotation file and display the resulting record.
get_gt_string_and_xy(data_dir / "annotations" / "000d269c8e26.json")

{'ground_truth': '<|PROMPT|><line><x_start>0;2;4;6;8;10;12<x_end><y_start>45.8;45.9;46.3;46.1;46.1;47.0;47.4<y_end></|PROMPT|>',
 'x': '["0", "2", "4", "6", "8", "10", "12"]',
 'y': '["45.8", "45.9", "46.3", "46.1", "46.1", "47.0", "47.4"]',
 'chart-type': 'line',
 'id': '000d269c8e26',
 'source': 'generated'}

In [None]:
from tqdm import tqdm
from glob import glob

# Parse every annotation file and keep one record per supported chart
# (get_gt_string_and_xy returns None for the chart types it skips).
ANNOTATION = "/content/dataset/train/annotations/*.json"
rows = []
for file_name in tqdm(glob(ANNOTATION)):
    row = get_gt_string_and_xy(file_name)
    if row is not None:
        rows.append(row)

# Build the frame once from all records: one row per chart. Concatenating inside
# the loop is quadratic, and doing so along axis=1 appended columns, not rows.
# `stone` stays None when nothing was parsed, as before.
stone = pd.DataFrame(rows) if rows else None


In [None]:
def gen_data(files: List[Union[str, os.PathLike]]) -> Iterator[Dict[str, str]]:
    """Yield one dataset record per supported annotation file.

    Args:
        files: Annotation file paths, given as ``str`` or path-like objects.

    Yields:
        The record built by ``get_gt_string_and_xy`` plus an ``"image_path"``
        entry pointing at ``<images_path>/<annotation stem>.jpg``. Files for
        which the parser returns ``None`` are skipped.
    """
    for f in files:
        res = get_gt_string_and_xy(f)
        if res is None:
            continue
        # Path(f) makes plain string paths work too; a str has no `.stem`.
        yield {
            **res,
            "image_path": str(images_path / f"{Path(f).stem}.jpg"),
        }


# Materialise the records produced by gen_data into a Dataset, using CFG.num_proc processes.
ds = Dataset.from_generator(
    generator=gen_data,
    gen_kwargs=dict(files=train_json_files),
    num_proc=CFG.num_proc,
)



In [None]:
def add_image_sizes(examples: Dict[str, Union[str, os.PathLike]]) -> Dict[str, List[int]]:
    """Look up the pixel size of every image in a batch.

    Args:
        examples: Batch dict whose ``"image_path"`` entry is a list of image
            file paths.

    Returns:
        ``{"width": [...], "height": [...]}``, aligned with the input order.
        An empty batch produces two empty lists.
    """
    widths, heights = [], []
    for path in examples["image_path"]:
        # The context manager closes the file handle as soon as the size is read;
        # a bare Image.open() would leave one open handle per image behind.
        with Image.open(path) as img:
            width, height = img.size
        widths.append(width)
        heights.append(height)

    return {
        "width": widths,
        "height": heights,
    }


# Add "width" and "height" columns, computed batch-wise with CFG.num_proc processes.
ds = ds.map(
    function=add_image_sizes,
    batched=True,
    num_proc=CFG.num_proc,
)

Map (num_proc=2):   0%|          | 0/49262 [00:00<?, ? examples/s]

**Load Model**

In [None]:
from transformers import Pix2StructForConditionalGeneration, AutoProcessor

# Hub checkpoint that both the processor and the model weights are loaded from.
repo_id = "hoangphu7122002ai/pix2struct_v0"
# NOTE(review): is_vqa=False presumably turns off the processor's VQA mode — confirm in the Pix2Struct processor docs.
processor = AutoProcessor.from_pretrained(repo_id, is_vqa=False)
model = Pix2StructForConditionalGeneration.from_pretrained(repo_id)

Downloading (…)rocessor_config.json:   0%|          | 0.00/303 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.45k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/851k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/3.27M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.99k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/164 [00:00<?, ?B/s]

In [None]:
# Set the image processor's size entry to the configured height and width.
processor.image_processor.size = {
    "height": CFG.image_height,
    "width": CFG.image_width,
}

# Register the custom tokens first, then resize the model's token embeddings:
# the resize reads the tokenizer length, so the order of these two lines matters.
processor.tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(processor.tokenizer))

Embedding(50353, 768)

In [None]:
# Store the tokenizer's pad id and the ids of the prompt start / end tokens on the model config.
# decoder_end_token_id is later passed to generate() as eos_token_id during validation.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([PROMPT_TOKEN])[0]
model.config.decoder_end_token_id = processor.tokenizer.convert_tokens_to_ids([PROMPT_END_TOKEN])[0]

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def augments():
    """Build the image pipeline: resize, normalise, convert to a tensor.

    The resize target comes from ``CFG.image_width`` / ``CFG.image_height``.
    Normalisation uses mean 0, std 1 and ``max_pixel_value=255``.
    """
    steps = [
        A.Resize(width=CFG.image_width, height=CFG.image_height),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
from functools import partial

def preprocess(examples, processor: Any, CFG: CFG):
    """Batch transform: images to patch tensors, target strings to label ids.

    Args:
        examples: Batch dict with ``"image_path"`` (a list of images, iterated
            and converted with ``np.array``) and ``"ground_truth"`` (the target
            strings). The dict is not modified.
        processor: Processor called with ``images=...`` / ``max_patches=...``;
            its ``tokenizer`` encodes the targets.
        CFG: Settings object providing ``max_patch`` and ``max_length``.

    Returns:
        Dict with ``flattened_patches``, ``attention_mask``, ``text`` (the
        untouched target strings) and ``labels`` (padded token ids in which
        padding positions are replaced by -100).
    """
    # Build the augmentation pipeline once per batch instead of once per image.
    transform = augments()
    images = [transform(image=np.array(image))["image"] for image in examples["image_path"]]

    # The batch dimension is deliberately kept on every tensor. Squeezing would
    # remove it whenever the batch holds a single example, and the per-example
    # slices taken from the result would then have the wrong rank.
    encoding = processor(images=images, max_patches=CFG.max_patch, return_tensors="pt")

    # prepare targets
    target_sequence = examples["ground_truth"]
    labels = processor.tokenizer(
        target_sequence,
        max_length=CFG.max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.clone()

    # The model doesn't need to predict pad tokens: -100 is the index that
    # cross-entropy ignores. (Assigning the pad id to itself masked nothing.)
    ignore_index = -100
    labels[labels == processor.tokenizer.pad_token_id] = ignore_index

    return {
        'flattened_patches' : encoding['flattened_patches'],
        'attention_mask' : encoding['attention_mask'],
        'text' : target_sequence,
        'labels' : labels
    }

In [None]:
# Cast the path column to the datasets Image feature, then attach preprocess as the
# transform (processor and CFG are bound in advance with functools.partial).
image_ds = ds.cast_column("image_path", ds_img())
image_ds.set_transform(partial(preprocess, processor=processor, CFG=CFG))

In [None]:
# Display the first transformed example as a quick check of the preprocessing.
image_ds[0]

In [None]:
# Positional split boundaries: [0, train_len) is train, [train_len, test_len) is test,
# [test_len, val_len) is validation.
n_examples = len(ds)
train_len = int(n_examples * 0.9)
test_len = int(n_examples * 0.95)
val_len = int(n_examples)


def _split(start, stop):
    """Select rows [start, stop) of `ds`, decode images and attach the preprocessing transform."""
    subset = ds.select(range(start, stop)).cast_column("image_path", ds_img())
    subset.set_transform(partial(preprocess, processor=processor, CFG=CFG))
    return subset


train_ds = _split(0, train_len)
test_ds = _split(train_len, test_len)
val_ds = _split(test_len, val_len)

In [None]:
# Display the second transformed training example.
train_ds[1]

In [None]:
def collate_fn(batch):
    """Merge per-example dicts into one batch dict.

    The target strings are tokenised again here (padded / truncated to
    ``CFG.max_length``) and become ``labels``; any ``labels`` entry on the
    individual examples is not used. Patches and attention masks are stacked
    along a new leading dimension.
    """
    captions = [sample["text"] for sample in batch]
    tokenized = processor(
        text=captions,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=CFG.max_length,
    )

    patches = torch.stack([sample["flattened_patches"] for sample in batch])
    masks = torch.stack([sample["attention_mask"] for sample in batch])

    return {
        "flattened_patches": patches,
        "attention_mask": masks,
        "text": list(captions),
        "labels": tokenized.input_ids,
    }

In [None]:
# Debug mode: keep only the first 100 training and validation examples.
if CFG.debug:
    train_ds = train_ds.select(range(100))
    val_ds = val_ds.select(range(100))


def _make_loader(dataset, shuffle):
    """DataLoader with the shared batch size, collate function and worker count."""
    return DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=CFG.num_workers,
    )


# Only the training loader shuffles.
train_dataloader = _make_loader(train_ds, shuffle=True)
val_dataloader = _make_loader(val_ds, shuffle=False)
test_dataloader = _make_loader(test_ds, shuffle=False)

num_training_steps = len(train_dataloader) * CFG.epochs // CFG.gpus

# Pull one batch to inspect its keys and contents.
batch = next(iter(train_dataloader))

batch.keys(), list(batch.items())

(dict_keys(['flattened_patches', 'attention_mask', 'text', 'labels']),
 [('flattened_patches',
   tensor([[[ 1.0000,  1.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368],
            [ 1.0000,  2.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368],
            [ 1.0000,  3.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368],
            ...,
            [32.0000, 30.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368],
            [32.0000, 31.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368],
            [32.0000, 32.0000,  0.3368,  ...,  0.3368,  0.3368,  0.3368]],
   
           [[ 1.0000,  1.0000, -0.4650,  ..., -0.2553, -0.2553, -0.2553],
            [ 1.0000,  2.0000, -0.4650,  ..., -0.2553, -0.2553, -0.2553],
            [ 1.0000,  3.0000, -0.4650,  ..., -0.2553, -0.2553, -0.2553],
            ...,
            [32.0000, 30.0000, -0.2553,  ..., -0.2553, -0.2553, -0.2553],
            [32.0000, 31.0000, -0.2553,  ..., -0.2553, -0.2553, -0.2553],
            [32.0000, 32.0000, -0.2553,  ..., -0.412

In [None]:
# Display the label ids of the sampled batch.
batch['labels']

tensor([[50344, 50351, 50345,  ...,     0,     0,     0],
        [50344, 50350, 50345,  ...,     0,     0,     0]])

In [None]:
# Decode the first label row one token at a time, leaving out id 0
# (presumably the padding id, given the trailing zeros in the labels tensor).
first_label_ids = batch["labels"][0].squeeze().tolist()
print(processor.batch_decode([token_id for token_id in first_label_ids if token_id != 0]))

['<|PROMPT|>', '<vertical_bar>', '<x_start>', 'Bur', 'und', 'i', ';', 'C', 'ambo', 'dia', ';', 'Ca', 'mer', 'oon', ';', 'Canada', ';', 'Cap', 'e', 'Verde', ';', 'Central', 'African', 'Republic', ';', 'C', 'had', ';', 'Chi', 'le', ';', 'China', ';', 'Col', 'om', 'bia', '<x_end>', '<y_start>', '', '4', '7', '.', '4', ';', '7', '0', '.', '7', ';', '1', '0', '0', '.', '6', ';', '4', '0', '3', '.', '4', ';', '1', '0', '3', '.', '9', ';', '7', '4', '.', '0', ';', '4', '7', '.', '4', ';', '1', '4', '7', '.', '2', ';', '3', '0', '6', '.', '9', ';', '1', '2', '3', '.', '9', '<y_end>', '</|PROMPT|>', '</s>']


In [None]:
# Flag the text sub-config as a decoder.
# NOTE(review): confirm this is still required by the installed transformers version.
model.config.text_config.is_decoder=True

In [None]:

from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import wandb

import torch

from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

import pytorch_lightning as pl
from torch.cuda.amp import GradScaler, autocast

class Pix2Struct(pl.LightningModule):
    """Lightning module that fine-tunes a Pix2Struct model on chart target strings.

    Args:
        config: Settings object; ``lr`` and ``num_warmup_steps`` are read from it.
        processor: Processor whose tokenizer decodes the generated sequences.
        model: Model called with ``flattened_patches`` / ``attention_mask`` /
            ``labels`` for training and via ``generate`` for validation.
    """

    def __init__(self, config, processor, model):
        super().__init__()
        self.CFG = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        """Run a forward pass with labels and return the model's loss."""
        outputs = self.model(
            flattened_patches=batch["flattened_patches"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate a sequence per image and score it against the target text.

        Returns:
            One normalised edit distance per example
            (``edit_distance / max(len(pred), len(answer))``); their mean is
            logged as ``val_edit_distance``.
        """
        answers = batch["text"]
        flattened_patches, attention_mask = batch["flattened_patches"], batch["attention_mask"]

        outputs = self.model.generate(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            pad_token_id=self.model.config.pad_token_id,
            eos_token_id=self.model.config.decoder_end_token_id,
            max_new_tokens=512,
            return_dict_in_generate=True,
        )

        eos_token = self.processor.tokenizer.eos_token
        pad_token = self.processor.tokenizer.pad_token
        decoded = self.processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        predictions = [seq.replace(eos_token, "").replace(pad_token, "") for seq in decoded]

        scores = []
        for pred, answer in zip(predictions, answers):
            answer = answer.replace(eos_token, "")
            longest = max(len(pred), len(answer))
            # Two empty strings are identical; score 0 instead of dividing by zero.
            scores.append(edit_distance(pred, answer) / longest if longest else 0.0)

            # Print only the first example of each batch as a qualitative sample.
            if len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        """Adafactor with a cosine schedule that is stepped after every optimizer step."""
        optimizer = Adafactor(
            self.parameters(),
            scale_parameter=False,
            relative_step=False,
            lr=self.CFG.lr,
            weight_decay=1e-05,
        )
        # The schedule has to span the whole run. Handing it a fixed, small step
        # count makes the cosine factor climb back up once that count is exceeded.
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.CFG.num_warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        # Lightning steps schedulers once per *epoch* unless told otherwise. The
        # warmup is expressed in optimizer steps, so request a per-step interval;
        # with the default, the first epoch would run at the initial warmup factor of 0.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def train_dataloader(self):
        """Return the notebook-level training DataLoader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the notebook-level validation DataLoader."""
        return val_dataloader

In [None]:
# Wrap the settings, processor and model in the Lightning module used for training.
pl_module = Pix2Struct(CFG, processor, model)

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, EarlyStopping

loggers = []

# if CFG.use_wandb:
import wandb

In [None]:
# Finish any wandb run that is still active (presumably one left over from an earlier
# execution of this cell), then create the logger handed to the Trainer.
wandb.finish()
wandb_logger = WandbLogger(project="Pix2Struct", name="demo-run-pix2struct-adafactor-colab")

class PushToHubCallback(Callback):
    """Upload the model and processor to the Hub after each epoch and after training."""

    # Repository both artifacts are pushed to.
    _REPO_ID = "hoangphu7122002ai/donutAxis_v1"

    def on_train_epoch_end(self, trainer, pl_module):
        """Push the model, then the processor, with an in-progress commit message."""
        epoch = trainer.current_epoch
        print(f"Pushing model to the hub, epoch {epoch}")
        message = f"Training in progress, epoch {epoch}"
        for artifact in (pl_module.model, pl_module.processor):
            artifact.push_to_hub(self._REPO_ID, commit_message=message)

    def on_train_end(self, trainer, pl_module):
        """Push the processor, then the model, with the final commit message."""
        print("Pushing model to the hub after training")
        for artifact in (pl_module.processor, pl_module.model):
            artifact.push_to_hub(self._REPO_ID, commit_message="Training done")

# Stop training when val_edit_distance has not improved for 3 consecutive validation runs.
early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=3, verbose=False, mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=CFG.gpus,
    max_epochs=CFG.epochs,
    val_check_interval=CFG.val_check_interval,
    check_val_every_n_epoch=CFG.check_val_every_n_epoch,
    gradient_clip_val=CFG.gradient_clip_val,
    precision='16-mixed', # if you have tensor cores (t4, v100, a100, etc.) training will be 2x faster
    num_sanity_val_steps=2,
    # The early-stopping callback was created but never registered, so it had no
    # effect. PushToHubCallback() stays disabled; add it to this list to upload checkpoints.
    callbacks=[early_stop_callback],
    logger=wandb_logger,
)


trainer.fit(pl_module, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                               | Params
-------------------------------------------------------------
0 | model | Pix2StructForConditionalGeneration | 282 M 
-------------------------------------------------------------
282 M     Trainable params
0         Non-trainable params
282 M     Tota

Sanity Checking: 0it [00:00, ?it/s]

Prediction: <|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT|><|PROMPT