**Import**

In [1]:
import random
import pandas as pd
import numpy as np
import os
import re
import glob
import math
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
    get_scheduler
)

import wandb

from tqdm.auto import tqdm

import warnings
warnings.filterwarnings(action='ignore') 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device.type

'cuda'

**Hyperparameter Settings**

In [3]:
CFG = {
    'SEED':42,
    'NUM_WORKERS':4,
    'NUM_PROC':1,
    'IMG_HEIGHT':0,
    'IMG_WIDTH':0,
    'MAX_LEN':10000
}

**Fix Seeds**

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED'])

**Data Pre-processing**

In [5]:
type_dict = {0:"unk", 1:"nm", 2:"ing", 3:"exp", 4:"how", 5:"des", 9:"etc"}

In [6]:
def get_gt_strings(ct):
    gt_string = ""
    flag = 1
    tp = -1
    for i, item in enumerate(ct):
        if flag:
            gt_string = gt_string + f'<{type_dict[item[0]]}>'
            tp = item[0]
            flag = 0
            gt_string = gt_string + f'{item[1]}'
        
        elif not flag:
            gt_string = gt_string + f' {item[1]}'
        
        if i == len(ct)-1 or ct[i+1][0] != tp:
            gt_string = gt_string + f'</{type_dict[item[0]]}>'
            flag = 1
    
    return gt_string

----------Save Above------------

In [7]:
def gen_data(df, split='train'):
    
    img_path = Path("../data/train") if split == 'train' else Path("../data/test")

    for index, row in df.iterrows():
         yield {
              "id": row['image_path'].split('.')[0],
              "groud_truth": get_gt_strings(eval(row['texts'])),
              "image_path": img_path / row['image_path'], 
              "image_height": row['image_size'][0],
              "image_width": row['image_size'][1],
              "image_class": row['image_class'],
        }


In [8]:
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [9]:
train_df = pd.read_csv("../dataframes/train_annot_df.csv")

In [11]:
temp = get_gt_strings(eval(train_df.loc[0, 'texts']))

In [18]:
ids = processor.tokenizer(temp).input_ids
tokens = [processor.tokenizer.tokenize(x, add_special_tokens=True) for x in temp]

In [20]:
unk_tokens = []
for example_ids, example_tokens in zip(ids, tokens):
    example_unk_tokens = []
    for i in range(len(example_ids)):
        if example_ids[i] == processor.tokenizer.unk_token_id:
            example_unk_tokens.append(example_tokens[i])

    unk_tokens.append(example_unk_tokens)

unk_tokens

TypeError: object of type 'int' has no len()

In [16]:
decoded = []

for id in ids:
    decoded.append(processor.decode([id]))

In [17]:
decoded

['<s>',
 '<',
 'n',
 'm',
 '>',
 'Do',
 've',
 'she',
 'a',
 'but',
 'ter',
 '시',
 '어',
 '버',
 '터',
 '',
 '<unk>',
 '티',
 '바',
 'beauty',
 'cream',
 'bar',
 '</',
 'n',
 'm',
 '><',
 'des',
 '>',
 'B',
 'eau',
 'ty',
 'bar',
 'with',
 'This',
 'beauty',
 'bar',
 'la',
 'ther',
 'that',
 'will',
 'with',
 'leave',
 'your',
 'skin',
 'feeling',
 'she',
 'a',
 'but',
 'ter',
 'and',
 'warm',
 '1',
 '<unk>',
 '4',
 'mois',
 'tur',
 'izing',
 'van',
 'illa',
 'scen',
 't',
 '.',
 'cream',
 'create',
 's',
 'a',
 'cream',
 'y',
 'resto',
 'red',
 '.',
 '시',
 '어',
 '버',
 '터',
 '와',
 '1',
 '<unk>',
 '4',
 '부',
 '드',
 '<unk>',
 '게',
 '가',
 '꿔',
 '줍니다',
 '.',
 '모',
 '이',
 '스',
 '처',
 '라이',
 '징',
 '달',
 '<unk>',
 '한',
 '바',
 '<unk>',
 '라',
 '향',
 '의',
 '',
 '<unk>',
 '티',
 '바',
 '크',
 '림',
 '을',
 '함',
 '유',
 '한',
 '',
 '<unk>',
 '티',
 '바',
 '의',
 '크',
 '리',
 '미',
 '한',
 '거',
 '품',
 '이',
 '피',
 '부를',
 '',
 '촉',
 '촉',
 '하고',
 '</',
 'des',
 '><',
 'exp',
 '>',
 '20',
 '24.',
 '04.',
 '01',
 '',
 '까

In [None]:
added_tokens = []

class DonutDataset(Dataset):

    def __init__(
        self,
        dataframe_path: str,
        max_length: int,
        processor: DonutProcessor,
        split: str = "train",
        ignore_id: int = -100,
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.dataframe = pd.read_csv("../dataframes/train_annot_df.csv") if self.split == "train" else  pd.read_csv("../dataframes/test_annot_df.csv")
        self.dataframe_length = len(self.dataframe)
        self.gt_token_sequences = []
        
        for idx, sample in self.dataframe.iterrows():
            ground_truth = get_gt_strings(eval(sample['texts']))
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    + processor.tokenizer.eos_token
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"", fr""])
                    output += (
                        fr""
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr""
                    )
                return output
        elif type(obj) == list:
            return r"".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj
    
    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)
    
    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels
        Convert gt data into input_ids (tokenized string)
        Returns:
            input_tensor : preprocessed image
            input_ids : tokenized gt_data
            labels : masked labels (model doesn't need to predict prompt and pad token)
        """
        sample = self.dataset[idx]

        # inputs
        pixel_values = processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # targets
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
        return pixel_values, labels, target_sequence