In [1]:
import json
import os

# Split sub-directory to process. The trailing "/" matters: paths in this
# notebook are built by plain string concatenation (ANN_PATH + SPLIT + name).
SPLIT = "val/"

# Annotation roots for the two datasets; ANN_PATH selects the active one.
PUBTABLES = '../../pubtables-1m/Donut_Annotations_PT/minimal/'
PUBTABNET = '../../pubtabnet/anns/'
ANN_PATH = PUBTABNET

# Raw directory listing of the selected split (split further in the next cell).
json_list = os.listdir(f"{ANN_PATH}{SPLIT}")

In [2]:
# Partition the listing by the character just before ".json": names where it is
# "L" (presumably "...HTML.json" companions — TODO confirm) go to html_list,
# everything else becomes the working json_list.
html_list = [name for name in json_list if name[-6:-5] == "L"]
aux_list = [name for name in json_list if name[-6:-5] != "L"]
json_list = aux_list

In [3]:
# Peek at the first ten annotation file names that will be processed.
json_list[:10]

['PMC3752962_006_00.json',
 'PMC5285298_009_01.json',
 'PMC4721713_013_01.json',
 'PMC3189161_006_00.json',
 'PMC5921262_006_00.json',
 'PMC3940998_002_00.json',
 'PMC4211402_009_00.json',
 'PMC2796464_003_00.json',
 'PMC4714473_005_00.json',
 'PMC4900244_009_01.json']

In [4]:
from transformers import DonutProcessor

# Load the custom 8k-vocabulary Donut processor.
processor = DonutProcessor.from_pretrained("../../pubtabnet/Donut_PubTables_TML_Processor8k")

# Tags emitted by the cel2token/row2token/table2token serializers, plus
# "<table_extraction>" (presumably the task prompt token — not emitted here).
structure_tokens = [
    "<table_extraction>", "<table>", "<row>", "<cell>",
    "<row_and_col_header>", "<row_header>", "<col_header>",
    "<content_row_and_col_header>", "<content_row_header>",
    "<content_col_header>", "<content>",
]

# One "<span_type=dddd>" token per 4-bit code. The leading digit varies fastest,
# giving the order 0000, 1000, 0001, 1001, 0010, ...
span_tokens = [
    f"<span_type={lead}{i}{j}{k}>"
    for i in range(2)
    for j in range(2)
    for k in range(2)
    for lead in "01"
]
new_tokens = structure_tokens + span_tokens

# NOTE(review): the next cell reloads `processor` from disk, which discards
# these in-memory additions (nothing in this notebook saves them).
processor.tokenizer.add_tokens(new_tokens, special_tokens = False)

2023-12-06 19:36:32.294616: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-06 19:36:32.294642: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-06 19:36:32.294659: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


27

In [5]:
# Reload the processor from disk.
# NOTE(review): this rebinds `processor`, so the tokens added in the previous
# cell via `add_tokens` are discarded (nothing in this notebook saves them).
# The sanity check below suggests the saved vocabulary already encodes these
# tags as single ids — confirm whether the previous cell is still needed.
processor = DonutProcessor.from_pretrained("../../pubtabnet/Donut_PubTables_TML_Processor8k")

In [6]:
# Sanity check: a structural tag should map to a single token id, wrapped by
# the tokenizer's added special ids (presumably BOS/EOS).
processor.tokenizer("<content_col_header>")

{'input_ids': [0, 44, 2], 'attention_mask': [1, 1, 1]}

In [7]:
def write_msg(msg):
    """Write ``msg`` to ``msg.json`` in the current directory as ``{"msg": msg}``.

    The JSON is pretty-printed (indent=4) and keeps non-ASCII characters
    verbatim (``ensure_ascii=False``). The file is therefore opened as UTF-8
    explicitly, instead of relying on the platform default encoding, which may
    not be able to represent those characters. This matches how annotations
    are read elsewhere in the notebook.

    Args:
        msg: any JSON-serializable value.
    """
    with open("msg.json", 'w', encoding="utf-8") as out:
        json.dump({'msg': msg}, out, ensure_ascii=False, indent=4)

In [8]:
def cel2token(cell):
    """Serialize one cell dict into its token string.

    ``cell['span_type']`` presumably looks like ``"span_type=dddd"``. Characters
    from index 10 on are the 4-digit code, and ``"0000"`` means "not spanning".

    * Spanning cell: emits ``"<" + span_type + ">"``. If the cell is a content
      holder, a ``<content_*>`` tag and the cell text follow.
    * Non-spanning cell: emits nothing unless it is a content holder. In that
      case it emits a ``<cell>``/``<*_header>`` tag followed by the text.

    The header tag is chosen from the truthiness of ``row_header`` and
    ``col_header``. Both set takes precedence, then column-only, then row-only.
    """
    spans = cell['span_type'][10:] != '0000'
    prefix = "<" + cell['span_type'] + ">" if spans else ""

    # Cells that hold no content contribute only their span marker (if any).
    if not cell['content_holder']:
        return prefix

    # Pick (tag inside a spanning cell, tag for an ordinary cell).
    if cell['row_header'] and cell['col_header']:
        span_tag, plain_tag = "<content_row_and_col_header>", "<row_and_col_header>"
    elif cell['col_header']:
        span_tag, plain_tag = "<content_col_header>", "<col_header>"
    elif cell['row_header']:
        span_tag, plain_tag = "<content_row_header>", "<row_header>"
    else:
        span_tag, plain_tag = "<content>", "<cell>"

    return prefix + (span_tag if spans else plain_tag) + cell['content']

def row2token(row):
    """Serialize one table row: a ``<row>`` marker followed by every cell's tokens.

    Args:
        row: iterable of cell dicts, each serialized with ``cel2token``.
    """
    return "<row>" + "".join(cel2token(cell) for cell in row)


def table2token(table):
    """Serialize a table: a ``<table>`` marker followed by every row's tokens.

    Args:
        table: iterable of rows, each serialized with ``row2token``.
    """
    return "<table>" + "".join(row2token(row) for row in table)


def json2token(json):
    """Serialize every table in an annotation dict into one token string.

    Args:
        json: parsed annotation. If it has a ``'tables'`` entry, each table in it
            is serialized with ``table2token`` and the results are concatenated.
            Otherwise the result is ``""``. (The parameter name shadows the
            ``json`` module inside this function; it is kept for interface
            compatibility.)
    """
    tables = json['tables'] if 'tables' in json else ()
    return "".join(table2token(table) for table in tables)

In [9]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    """Dataset that turns annotation JSON files into tokenized label sequences.

    Each item opens ``ANN_PATH + SPLIT + file_name`` (module-level globals) and
    serializes the annotation with ``json2token``. The result is tokenized with
    the module-level ``processor``, padded/truncated to ``max_length``. Item
    shape: ``{'labels': ids with padding replaced by ignore_id, 'file_name': ...}``.

    Args:
        annotations: list of annotation file names, relative to ``ANN_PATH + SPLIT``.
        shuffle: unused; kept for interface compatibility.
        split: stored as ``self.split`` but not used to build paths. The
            module-level ``SPLIT`` is used instead. NOTE(review): confirm intent.
        ignore_id: label value substituted for padding token ids.
        prompt_end_token: unused; kept for interface compatibility.
        max_length: tokenizer padding/truncation length (defaults to the
            previously hard-coded 6000).
    """

    def __init__(
        self,
        annotations,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
        max_length = 6000,
    ):
        self.annotations = annotations
        self.split = split
        self.ignore_id = ignore_id
        self.max_length = max_length

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        file_name = self.annotations[idx]

        with open(ANN_PATH + SPLIT + file_name, encoding="utf-8") as f:
            annotation = json.load(f)

        try:
            target_sequence = json2token(annotation)
        except Exception:
            # Say which file failed, then propagate the ORIGINAL error.
            # (Previously `raise cu` replaced it with a NameError.)
            print(file_name)
            raise

        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        # Mask padding so it can be distinguished (and ignored) downstream.
        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id

        return dict(labels=labels, file_name=file_name)

In [10]:
# Wrap the filtered file list. Despite the name, this holds the SPLIT selected
# in the config cell (not necessarily the training split).
train_dataset = DonutTableDataset(annotations=json_list)

In [11]:
from torch.utils.data import DataLoader

# This pass only measures per-file label lengths, so shuffling buys nothing.
# A fixed order makes the saved *_lens.json reproducible and diffable across runs.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)

In [12]:
len_dics = []

def add_lens(file_name, tokens, ignore_id=-100):
    """Append the unpadded length of each label sequence to ``len_dics``.

    Args:
        file_name: sequence of file names, indexed in step with ``tokens``.
        tokens: batch of label sequences (a 2-D tensor or list of lists) whose
            padding positions hold ``ignore_id``.
        ignore_id: padding marker; counting stops at its first occurrence.

    Each row appends ``{'file': file_name[i], 'len': n}``. ``n`` is the number
    of tokens before the first ``ignore_id``, or the full row length if there is
    none. (Previously the first padding token was also counted, so padded rows
    were reported one longer than rows that filled the whole length.)
    """
    for i, seq in enumerate(tokens):
        length = 0
        for token in seq:
            if token == ignore_id:
                break
            length += 1
        len_dics.append({'file': file_name[i], 'len': length})

In [13]:
import torch
from tqdm.auto import tqdm


# Tokenize every annotation once and record its unpadded label length.
# (The former dict copy of `batch`, the unused loop index and the throwaway
# `labels`/`filename` globals were no-ops and have been dropped.)
for batch in tqdm(train_dataloader):
    add_lens(batch["file_name"], batch["labels"])

  0%|          | 0/9115 [00:00<?, ?it/s]

In [14]:
import json
# Save per-file lengths as "<ANN_PATH><split>_lens.json" (SPLIT[:-1] drops the
# trailing "/"). UTF-8 is explicit because ensure_ascii=False keeps non-ASCII text.
with open(ANN_PATH + SPLIT[:-1] + "_lens.json", 'w', encoding="utf-8") as out:
    json.dump(len_dics, out, ensure_ascii=False, indent=4)