In [1]:
import json
import os

PUBTABLES_ANN = '../../pubtables-1m/Donut_Annotations_PT/minimal/Val/'
PUBTABLES_IMG = '../../pubtables-1m/PubTables-1M-Structure/Val/'

PUBTAB_ANN = '../../pubtabnet/anns/train/'
PUBTAB_IMG = '../../pubtabnet/imgs/train/'

ANN_PATH = PUBTAB_ANN
IMAGE_PATH = PUBTAB_IMG
IMG_FORMAT = '.png'

with open(ANN_PATH[:-1] + '_trunc_filelist.json') as file:
    json_list = file.read().splitlines()
    
json_list = [item for item in json_list]

In [2]:
aux_list = []

for json_item in json_list:
    aux_list.append(json_item[:-5])

json_list = aux_list

In [3]:
json_list[:10]

['PMC6051241_006_00',
 'PMC3398398_014_00',
 'PMC2682793_006_00',
 'PMC6082926_006_00',
 'PMC2809502_009_00',
 'PMC3890642_004_00',
 'PMC3926592_006_01',
 'PMC548940_003_00',
 'PMC4394591_006_00',
 'PMC1906827_006_00']

In [4]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")

image_size = [750, 750]
max_length = 1300#config.decoder.max_position_embeddings


config.encoder.image_size = image_size


processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config, ignore_mismatched_sizes=True)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [5]:
model.decoder.model.decoder.embed_positions

MBartLearnedPositionalEmbedding(1538, 1024)

In [6]:
import torch
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):

    def __init__(self,  max_len: int, d_model: int, dropout: float = 0.05):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.offset = 2
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.pos_enc = pe.squeeze(1)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        
        pos_start = past_key_values_length
        pos_end = past_key_values_length+seq_len
        
        positions = self.pos_enc[pos_start:pos_end].expand(bsz, -1, -1)
        
        return self.dropout(positions)

In [7]:
pos_enc = PositionalEncoding(10, 4)

In [8]:
model.decoder.model.decoder.embed_positions = PositionalEncoding(4098, 1024)

In [9]:
from transformers import AddedToken

new_tokens = ["<table_extraction>", "<table>", "<row>", "<cell>", "<row_and_col_header>", "<row_header>", "<col_header>"]
new_tokens += ["1", "≥","≤"]
new_tokens += ["<b>", "</b>", "<i>", "</i>", "<sup>", "</sup>","<sub>", "</sub>"]

for i in range(2):
    for j in range(2):
        for k in range(2):
            new_tokens.append("<span_type=0" + str(i) + str(j) + str(k) + ">")
            new_tokens.append("<span_type=1" + str(i) + str(j) + str(k) + ">")

new_tokens = [AddedToken(tag, rstrip = True, lstrip=True, normalized=False) for tag in new_tokens]

processor.tokenizer.add_tokens(new_tokens, special_tokens = False)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

Embedding(57559, 1024)

In [10]:
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False


In [11]:
with open("msg.json", 'w') as out:
        json.dump({'outputs': []}, out, ensure_ascii=False, indent=4)

def write_msg(msg):
    with open("msg.json", encoding="utf-8") as f:
        json_data = json.load(f)
    
    with open("msg.json", 'w') as out:
        json_data['outputs'].append(msg)
        json.dump(json_data, out, ensure_ascii=False, indent=4)

In [12]:
def cel2token(cell):
    if cell['span_type'][10:] != '0000':
        sequence = "<" + cell['span_type'] + ">"
    else:
        sequence = ""
    
    if cell['content_holder']:
        if cell['row_header'] and cell['col_header']:
            sequence += "<row_and_col_header>"
        elif cell['col_header']:
            sequence += "<col_header>"
        elif cell['row_header']:
            sequence += "<row_header>"
        else:
            sequence += "<cell>"
        sequence += cell['content']
        
    return sequence

def row2token(row):
    sequence = "<row>"
    for cell in row:
        sequence += cel2token(cell)
    
    return sequence


def table2token(table):
    sequence = "<table>"
    for row in table:
        sequence += row2token(row)
    
    return sequence


def json2token(json):
    sequence = ""
    if('tables' in json):
        for table in json['tables']:
            sequence += table2token(table)

    return sequence

In [13]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        image_size,
        max_length,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations = annotations
        
        
        self.image_size = image_size
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations[idx]
        
        with open(ANN_PATH + file_name + ".json", encoding="utf-8") as f:
            annotation = json.load(f)
        
        image = Image.open(IMAGE_PATH + file_name + IMG_FORMAT)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        target_sequence = "<s>"+json2token(annotation)+"</s>"
        
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=True,
            max_length= max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id
        
        
        encoding = dict(pixel_values=pixel_values,
                        labels=labels)
        
        return encoding

In [14]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<table_extraction>"])[0] 


In [15]:
train_dataset = DonutTableDataset(json_list,
                             max_length = max_length,
                             image_size = image_size)

In [16]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [17]:
#model = torch.load("model_epoch_checkpoint_.bin")
start_epoch = 0
avg_size = 1000

In [18]:
import torch
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)


for epoch in range(start_epoch, 1):
    print("Epoch:", epoch+1)
    mean_loss = 0
    mean_smpl_loss = 0 
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader)):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        
        
        outputs = model(pixel_values=pixel_values, labels=labels)
        
        
        loss = outputs.loss
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        mean_loss += loss.item()   
        mean_smpl_loss += loss.item() 
        if i % avg_size == 0:
            print(str(i) + " Loss: ", mean_smpl_loss/avg_size)
            write_msg("batch " + str(i) +" loss: "+ str(mean_smpl_loss/avg_size))
            mean_smpl_loss = 0 
        
        if i % 10000 == 0:
             model.save_pretrained("../../pubtabnet/TML-lre-5-checkpoint")
    
    model.save_pretrained("../../pubtabnet/TML-lre-5-checkpoint")
    print("Epoch's mean loss: ", mean_loss/len(train_dataloader))
    
    write_msg("Epoch checkpointed: " + str(epoch+1) +" \n"+
              "Epoch's mean Loss: " + str(mean_loss/len(train_dataloader)))

Epoch: 1


  0%|          | 0/125195 [00:00<?, ?it/s]

0 Loss:  0.016866207122802734


KeyboardInterrupt: 

In [None]:
len(train_dataset)

In [None]:
model.save_pretrained("../../pubtabnet/model-minimal-epoch1")

In [None]:
processor.save_pretrained("../../pubtabnet/Donut_PubTables_TML_Processor")