In [1]:
import json
import os

import sys
sys.path.insert(0, '../src')
from transformers import DonutProcessor
from modeling_tabeleiro import TabeleiroModel
from processing_tabeleiro import TabeleiroProcessor

ANN_PATH = '../../aux/data/anns/train/'
IMAGE_PATH = '../../aux/data/imgs/train/'

PROCESSORS_PATH = "../../aux/processors/"
MODELS_PATH = "../../aux/models/"

IMG_FORMAT = '.png'

json_list = os.listdir(ANN_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
aux_list = []

for json_item in json_list:
    if json_item[-6] != "L":
        aux_list.append(json_item[:-5])

json_list = aux_list

In [3]:
json_list = json_list
json_list[:10]

['PMC3549936_002_00',
 'PMC3112512_009_00',
 'PMC3772909_007_00',
 'PMC5294720_004_00',
 'PMC3446317_008_00',
 'PMC5071235_005_00',
 'PMC1421390_001_00',
 'PMC5251219_004_01',
 'PMC3973966_003_00',
 'PMC4147206_002_00']

In [4]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig
import torch

config = VisionEncoderDecoderConfig.from_pretrained(MODELS_PATH+"donut-base")

image_size = [640, 640]
max_length = 960


config.encoder.image_size = image_size
processor = TabeleiroProcessor.from_pretrained(PROCESSORS_PATH+"Donut_PubTables_TML_Processor8k")

In [5]:
new_tokens = ["<table_extraction>", "<table>", "<row>", "<cell>", "<row_and_col_header>", "<row_header>", "<col_header>"]
new_tokens += ["<content_row_and_col_header>", "<content_row_header>", "<content_col_header>", "<content>"]

for i in range(2):
    for j in range(2):
        for k in range(2):
            new_tokens.append("<span_type=0" + str(i) + str(j) + str(k) + ">")
            new_tokens.append("<span_type=1" + str(i) + str(j) + str(k) + ">")


processor.tokenizer.add_tokens(new_tokens, special_tokens = False)

0

In [6]:
cell_types = ["<cell>", "<row_and_col_header>", "<row_header>", "<col_header>"]
for i in range(2):
    for j in range(2):
        for k in range(2):
            cell_types.append("<span_type=0" + str(i) + str(j) + str(k) + ">")
            cell_types.append("<span_type=1" + str(i) + str(j) + str(k) + ">")


cell_tokens = [processor.tokenizer.convert_tokens_to_ids([cell_type])[0] for cell_type in cell_types]
row_tokens = [processor.tokenizer.convert_tokens_to_ids([row_type])[0] for row_type in ['<row>']]

In [7]:
processor.tokenizer.decode(row_tokens)

'<row>'

In [11]:
model = TabeleiroModel.from_pretrained("naver-clova-ix/donut-base",
                                       from_donut=True,
                                       decoder_extra_config={"pos_counters":[cell_tokens, row_tokens]},
                                       donut_config = config,
                                       ignore_mismatched_sizes=True)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

Config of the decoder: <class 'modeling_dimbart.DiMBartForCausalLM'> is overwritten by shared decoder config: DiMBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_final_layer_norm": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "dim_max_position_embeddings": [
    512,
    120,
    120
  ],
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "init_std": 0.02,
  "is_decoder": true,
  "is_encoder_decoder": false,
  "max_position_embeddings": 1536,
  "model_type": "dimbart",
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "pos_counters": [
    [
      15,
      18,
      17,
      16,
      19,
      20,
      21,
      22,
      2

Embedding(8000, 1024)

In [9]:
model.config

VisionEncoderDecoderConfig {
  "_name_or_path": "naver-clova-ix/donut-base",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "dim_max_position_embeddings": [
      512,
      120,
      120
    ],
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 

In [10]:
processor.image_processor.size = image_size[::-1] # should be (width, height)
#processor.image_processor.size = {'width': image_size[::-1][0], 'height': image_size[::-1][1]}
processor.image_processor.do_align_long_axis = False


In [11]:
with open("msg.json", 'w') as out:
        json.dump({'outputs': []}, out, ensure_ascii=False, indent=4)

def write_msg(msg):
    with open("msg.json", encoding="utf-8") as f:
        json_data = json.load(f)
    
    with open("msg.json", 'w') as out:
        json_data['outputs'].append(msg)
        json.dump(json_data, out, ensure_ascii=False, indent=4)

In [12]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        image_size,
        max_length,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations = annotations
        
        
        self.image_size = image_size
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations[idx]
        
        with open(ANN_PATH + file_name + ".json", encoding="utf-8") as f:
            annotation = json.load(f)
        
        image = Image.open(IMAGE_PATH + file_name + IMG_FORMAT)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        target_sequence = processor.json2token(annotation)+"</s>"
        
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length= max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id
        
        
        encoding = dict(pixel_values=pixel_values,
                        labels=labels,
                        target = target_sequence,
                       filename = file_name)
        
        return encoding

In [13]:
train_dataset = DonutTableDataset(json_list,
                             max_length = max_length,
                             image_size = image_size)

In [14]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=1, shuffle=True)

In [15]:
checkpointed = False
if checkpointed:
    model = TabeleiroModel.from_pretrained("../../pubtabnet/checkpoints/TML-5lre-5-checkpoint")

start_epoch = 0
avg_size = 1000

In [16]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<table_extraction>"])[0]

In [17]:
import torch
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model.to(device) 
optimizer = torch.optim.AdamW(params=model.parameters(), lr=8e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader)//(27*4), gamma=(0.125)**(1/27))

if checkpointed:
    optimizer.load_state_dict(torch.load("../../pubtabnet/checkpoints/optim-checkpoint"))

num_steps = 0   

for epoch in range(start_epoch, 2):
    
    print("Epoch:", epoch+1)
    mean_loss = 0
    mean_smpl_loss = 0 
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader)):
            
        batch = {k: v.to(device) if k not in ["target", "filename"] else v for k, v in batch.items()}
        
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        
        outputs = model(pixel_values=pixel_values, labels=labels)
        
        
        loss = outputs.loss
        mean_loss += loss.item()   
        mean_smpl_loss += loss.item()
        
        loss /= 4
        loss.backward()
        
        if (i+1)%4 == 0 or i+1 == len(train_dataloader):
            optimizer.step()
            optimizer.zero_grad()
            num_steps += 1
            if num_steps%10000 == 0 :
                model.save_pretrained("../../aux/modelos/by_step/model-3D-STEP_"+str(num_steps))
            
            if  scheduler.get_last_lr()[0] > 7.5e-6:
                scheduler.step() 
                
        if i % avg_size == 0:
            print(str(i) + " Loss: ", mean_smpl_loss/avg_size)
            write_msg("batch " + str(i) +" loss: "+ str(mean_smpl_loss/avg_size))
            mean_smpl_loss = 0 
        
        
    
        
    model.save_pretrained("../../aux/modelos/checkpoints/TML-5lre-5-8kproc-checkpoint-epoch"+str(epoch))
    print("Epoch's mean loss: ", mean_loss/len(train_dataloader))
    
    write_msg("Epoch checkpointed: " + str(epoch+1) +" \n"+
              "Epoch's mean Loss: " + str(mean_loss/len(train_dataloader)))

Epoch: 1


  0%|                                     | 1/500777 [00:00<96:16:32,  1.44it/s]

0 Loss:  0.016291946411132813


  0%|                                  | 1001/500777 [04:42<39:09:12,  3.55it/s]

1000 Loss:  6.619077335357666


  0%|                                  | 1323/500777 [06:12<39:05:11,  3.55it/s]


KeyboardInterrupt: 

In [None]:
processor.tokenizer.decode([57525, 57526, 57527])

In [None]:
model.save_pretrained("../../aux/modelos/model-3D")

#### processor.save_pretrained("../../pubtabnet/processors/processor-GTML-8kproc")