In [1]:
import torch
import os
import sys
sys.path.insert(0, '../src')

from transformers import DonutProcessor
from modeling_tabeleiro import TabeleiroModel
from PIL import Image


# Directory holding the evaluation images; dataset items resolve their file names against it.
IMG_PATH = "../../pubtabnet/imgs/final_eval/"

# Names of every entry found in IMG_PATH.
filelist = os.listdir(IMG_PATH)

In [2]:
# Load the saved DonutProcessor; later cells use it both to turn images into
# pixel_values and, through processor.tokenizer / processor.decode, to map token ids to text.
processor = DonutProcessor.from_pretrained("../../pubtabnet/processors/processor-GTML-8kproc")

2024-04-01 11:41:19.512532: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-01 11:41:19.512555: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-01 11:41:19.512573: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Load the TabeleiroModel checkpoint from disk.
model = TabeleiroModel.from_pretrained("../../pubtabnet/modelos/model-minimal-3D-epoch3-lr5e-5-8kproc")
# Take the padding id and the decoder start id ("<table_extraction>") from the
# processor's tokenizer so the model config agrees with the vocabulary in use.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<table_extraction>"])[0]
# Display the module tree.
model

You are using a model of type dimbart to instantiate a model of type mbart. This is not supported for all configurations of models and can yield errors.


TabeleiroModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
    

In [4]:
# Display the full model configuration for inspection.
model.config

VisionEncoderDecoderConfig {
  "decoder": {
    "_name_or_path": "../../pubtabnet/modelos/model-minimal-3D-epoch3-lr5e-5-8kproc/dimbart_decoder",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": [
      "DiMBartForCausalLM"
    ],
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "dim_max_position_embeddings": [
      514,
      120,
      120
    ],
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop":

In [5]:
# Scratch check: feed a (2, 4) integer tensor to the decoder's positional-embedding
# module and look at the shape of the first element of what it returns.
model.decoder.model.decoder.embed_positions(torch.IntTensor([[0, 1, 3, 4],
                                             [5, 7, 8, 6]]))[0].shape

torch.Size([2, 4, 1024])

In [6]:
# Scratch check: build an int32 tensor directly on the first CUDA device.
torch.as_tensor([37, 37, 120, 1320, 8], device = "cuda:0", dtype = torch.int)

tensor([  37,   37,  120, 1320,    8], device='cuda:0', dtype=torch.int32)

In [7]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    """Dataset yielding one table image and its ground-truth HTML per item.

    Depends on two notebook-level globals: ``IMG_PATH`` (directory containing
    the images) and ``processor`` (callable that converts a PIL image into
    ``pixel_values``).

    Args:
        annotations: Mapping from image file name to a dict that contains at
            least an ``'html'`` entry with the ground-truth markup.
        max_length: Stored on the instance; not used by this class itself.
        ignore_id: Stored on the instance; not used by this class itself.
        prompt_end_token: Accepted for interface compatibility; currently ignored.
    """

    def __init__(
        self,
        annotations,
        max_length,
        ignore_id = -100,
        prompt_end_token = None,
    ):
        # Freeze the key order once so that integer indices map to stable file names.
        self.annotations_files = list(annotations.keys())
        self.annotations = annotations

        self.max_length = max_length
        self.ignore_id = ignore_id

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``file_name``, ``pixel_values`` and ``gt`` for the idx-th image.

        Raises whatever ``Image.open`` raises when the image file cannot be read.
        """
        file_name = self.annotations_files[idx]
        gt = self.annotations[file_name]['html']

        # Image.open is lazy and keeps the file handle open; convert() forces the
        # pixel data to be read and returns an independent image, so the handle
        # can be released immediately instead of waiting for garbage collection.
        with Image.open(os.path.join(IMG_PATH, file_name)) as image:
            rgb_image = image.convert("RGB")

        # A single squeeze() already drops every singleton dimension; the
        # original second call was a no-op.
        pixel_values = processor(rgb_image, random_padding=False, return_tensors="pt").pixel_values.squeeze()

        return dict(file_name = file_name,
                    pixel_values = pixel_values,
                    gt = gt)

In [8]:
import json
# Ground-truth annotations keyed by image file name. The encoding is pinned to
# UTF-8 so the read does not depend on the platform's default locale.
with open('../../pubtabnet/final_eval.json', encoding="utf-8") as fp:
    annotations = json.load(fp)

# 4096 is passed as max_length; the dataset only stores the value.
test_set = DonutTableDataset(annotations, 4096)

In [9]:
# Tokens that open a table cell, one per header role.
cell_types = ["<cell>", "<col_header>", "<row_header>", "<row_and_col_header>"]
# Content-role tokens; token2cell recognises them by their "<cont" prefix.
content_types = ["<content_row_and_col_header>", "<content_row_header>", "<content_col_header>", "<content>"]


# Inline formatting tags (opening forms first, then closing forms).
added_content = ["<b>", "<i>", "<sup>", "<sub>","</b>", "</i>", "</sup>", "</sub>"]

# Add the sixteen "<span_type=dddd>" tokens to cell_types. The three trailing
# digits vary slowest-to-fastest left to right, and for each combination the
# variant with leading digit 0 comes before the one with leading digit 1.
cell_types.extend(
    "<span_type=" + lead + str(i) + str(j) + str(k) + ">"
    for i in range(2)
    for j in range(2)
    for k in range(2)
    for lead in ("0", "1")
)


def token2cell(processor, seq, i, cell_coord, cell_type):
    """Parse one cell whose opening token has already been consumed.

    Args:
        processor: Object whose ``decode`` turns a token (or a slice of
            tokens) from ``seq`` into text.
        seq: Token sequence being parsed.
        i: Index of the first token after the cell's opening token.
        cell_coord: ``(row, col)`` position recorded on the cell.
        cell_type: Decoded text of the opening token.

    Returns:
        ``(cell, next_i)`` where ``cell`` is the cell dict and ``next_i`` is the
        index of the token that ended the cell's content.

    The scan has no bounds check: ``seq`` must contain a terminating token
    ("<row>", "</s>", "<table>" or any entry of the global ``cell_types``).
    """
    unspanned = "span_type=0000"
    span_type = unspanned

    if cell_type.startswith("<span"):
        # Span tokens carry their flags in the token text (minus the brackets);
        # the header role, if any, comes from the token that follows.
        span_type = cell_type[1:-1]
        cell_type = processor.decode(seq[i])
        if cell_type.startswith("<cont"):
            i += 1
            # "<content>" maps to a plain cell; "<content_X>" maps to "<X>".
            cell_type = "<cell>" if cell_type == "<content>" else "<" + cell_type[9:]

    row, col = cell_coord
    # -1 marks a span whose real extent has still to be resolved.
    extent = 1 if span_type == unspanned else -1

    terminators = ["<row>", "</s>", "<table>"] + cell_types
    end = i
    while processor.decode(seq[end]) not in terminators:
        end += 1

    parsed = {
        "row": row,
        "col": col,
        "col_header": cell_type in ["<col_header>", "<row_and_col_header>"],
        "row_header": cell_type in ["<row_header>", "<row_and_col_header>"],
        "colspan": extent,
        "rowspan": extent,
        "span_type": span_type,
        "content": processor.decode(seq[i:end]),
    }
    return parsed, end

def token2row(processor, seq, i, row_id):
    """Parse the cells of one row, starting just after its "<row>" token.

    Tokens that are not in the global ``cell_types`` are skipped. Parsing stops
    at the next "<row>", "</s>" or "<table>" token, which is left unconsumed.

    Returns:
        ``(cells, next_i)``: the parsed cell dicts in column order and the index
        of the token that ended the row.
    """
    row_cells = []
    while True:
        token = processor.decode(seq[i])
        if token in ("<row>", "</s>", "<table>"):
            break
        if token not in cell_types:
            i += 1
            continue
        # The column index of a cell is the number of cells parsed so far.
        parsed, i = token2cell(processor, seq, i + 1, (row_id, len(row_cells)), token)
        row_cells.append(parsed)
    return row_cells, i


def token2table(processor, seq, i):
    """Parse the rows of one table, starting at index ``i``.

    Tokens before a "<row>" are skipped. Parsing stops at "</s>" or at the next
    "<table>" token, which is left unconsumed.

    Returns:
        ``(rows, next_i)``: a list of rows (each a list of cell dicts) and the
        index of the token that ended the table.
    """
    parsed_rows = []
    while True:
        token = processor.decode(seq[i])
        if token in ("</s>", "<table>"):
            break
        if token != "<row>":
            i += 1
            continue
        # The row id is the number of rows parsed so far.
        row, i = token2row(processor, seq, i + 1, len(parsed_rows))
        parsed_rows.append(row)

    return parsed_rows, i


def crop_empty_left(table):
    """Zero the spans of the run of empty, unspanned cells that ends each row.

    For every row, the longest run of cells at the END of the row that have
    empty content and ``colspan == 1`` gets ``colspan`` and ``rowspan`` set to
    0. The table is modified in place; nothing is returned.
    """
    for row in table:
        # Walk backwards and stop at the first cell that has content or a span.
        trailing_empty = []
        for cell in reversed(row):
            if cell['content'] != "" or cell['colspan'] != 1:
                break
            trailing_empty.append(cell)

        for cell in trailing_empty:
            cell['colspan'] = 0
            cell['rowspan'] = 0

def token2ann(processor, seq, i):
    """Parse a generated token sequence into a list of tables.

    Scanning starts at index ``i``. Every "<table>" token opens a table that is
    parsed by ``token2table``; any other token outside a table is skipped.
    Scanning ends at the first "</s>".

    Args:
        processor: Object whose ``decode`` maps a token from ``seq`` to text.
        seq: Token sequence; it must contain "</s>", since no bounds check is made.
        i: Index at which to start scanning.

    Returns:
        ``{'tables': [...]}`` with one entry per "<table>" token found. The list
        is empty when the sequence contains no "<table>" token.
    """
    tables = []
    while True:
        token = processor.decode(seq[i])
        if token == "</s>":
            break
        # Compare the DECODED token with "<table>". The original code decoded the
        # result of `seq[i] == "<table>"` instead, so the condition did not depend
        # on which token was actually at position i.
        if token == "<table>":
            parsed_table, i = token2table(processor, seq, i + 1)
            # crop_empty_left(parsed_table) was disabled in the original and stays disabled.
            tables.append(parsed_table)
        else:
            i += 1

    return {'tables': tables}

In [10]:
def update_vals(cell, content, col_header, row_header):
    """Absorb ``cell`` into the span being accumulated.

    The cell's ``colspan`` and ``rowspan`` are set to 0 in place. The returned
    content is the cell's own content when that is non-empty, otherwise the
    ``content`` passed in. Each header flag is OR-ed with the cell's own flag.

    Returns:
        ``(content, col_header, row_header)`` after taking ``cell`` into account.
    """
    cell.update(colspan=0, rowspan=0)

    own_content = cell['content']
    merged_content = own_content if own_content != "" else content

    return (
        merged_content,
        col_header or cell['col_header'],
        row_header or cell['row_header'],
    )


def define_by_path(cell, table):
    """Follow a span from ``cell`` and report where it ends.

    Starting at ``cell``, the walk moves right while the current cell's
    ``span_type[-4:-3]`` is '1'. When a rightward walk stops, the cell it
    stopped on decides whether to continue: if its ``span_type[-3:-2]`` is '1',
    the walk drops one row, returns to the starting column and walks right
    again. Every visited cell (the starting cell included) is passed to
    ``update_vals``, which zeroes its ``colspan``/``rowspan`` and folds its
    content and header flags into the running values.

    Args:
        cell: Cell dict that has ``'row'``, ``'col'`` and ``'span_type'`` entries.
        table: List of rows (lists of cell dicts), indexed as ``table[row][col]``.

    Returns:
        ``((i, j), content, col_header, row_header)``: the coordinates of the
        last visited cell plus the accumulated content and header flags.

    NOTE(review): ``table[i][j]`` is indexed without bounds checks, so a span
    flag pointing past the last column or row will raise IndexError.
    """
    col_header, row_header = False, False
    i, j = cell['row'], cell['col']
    # Column to return to each time the walk drops to the next row.
    first_j = j

    content = ''
    
    while True:
        # Rightward walk along the current row.
        while True:
            content, col_header, row_header = update_vals(cell, content, col_header, row_header)
            
            if(cell['span_type'][-4:-3] != '1'):
                break
            j += 1
            cell = table[i][j]
            
        # `cell` is now the cell on which the rightward walk stopped.
        if(cell['span_type'][-3:-2] != '1'):
            break
        i += 1
        j = first_j
        cell = table[i][j]
    
    return (i, j), content, col_header, row_header

    
def define_spannings(table):
    """Resolve every unresolved span (``colspan == -1``) in ``table`` in place.

    For each such cell, ``define_by_path`` walks the span and returns its end
    coordinates together with the merged content and header flags. The anchor
    cell then receives the real ``rowspan``/``colspan`` and the merged values.
    The other cells of the span are left with spans of 0 by ``define_by_path``,
    so they are no longer ``-1`` when the iteration reaches them.

    Args:
        table: List of rows, each a list of cell dicts.
    """
    for row in table:
        for cell in row:
            if cell['colspan'] != -1:
                continue

            (end_row, end_col), content, col_header, row_header = define_by_path(cell, table)

            # Extents are inclusive, hence the +1.
            cell['rowspan'] = end_row - cell['row'] + 1
            cell['colspan'] = end_col - cell['col'] + 1
            cell['col_header'] = col_header
            cell['row_header'] = row_header
            cell['content'] = content
                

In [11]:
def cell2html(cell):
    """Render one cell dict as a ``<td>`` element.

    ``rowspan`` and ``colspan`` attributes are written, in that order, only
    when the value is greater than 1. The cell content is inserted verbatim.
    """
    attributes = "".join(
        f' {name}="{cell[name]}"'
        for name in ("rowspan", "colspan")
        if cell[name] > 1
    )
    return f"<td{attributes}>" + cell["content"] + "</td>"

def table2html(table):
    """Render a parsed table as ``<thead>``/``<tbody>`` markup.

    Cells with ``colspan <= 0`` are omitted. A row goes into ``<thead>`` when
    more than half of its rendered cells have ``col_header`` set, otherwise
    into ``<tbody>``. Consecutive rows of the same kind share one section;
    a change of kind closes the open section and opens the other.

    Returns:
        The concatenated section markup, without an enclosing ``<table>`` tag.
    """
    html = ""
    open_section = None  # "thead", "tbody", or None before the first row
    for row in table:
        rendered = [cell for cell in row if cell['colspan'] > 0]

        header_count = 0
        cells_html = ""
        for cell in rendered:
            header_count += cell['col_header']
            cells_html += cell2html(cell)

        wanted = "thead" if header_count > len(rendered) / 2 else "tbody"
        if wanted != open_section:
            if open_section is not None:
                html += "</" + open_section + ">"
            html += "<" + wanted + ">"
            open_section = wanted

        html += "<tr>" + cells_html + "</tr>"

    if open_section is not None:
        html += "</" + open_section + ">"

    return html

In [12]:
# Check which string the tokenizer's EOS id decodes to; the token2* parsers stop on "</s>".
processor.decode(processor.tokenizer.eos_token_id)

'</s>'

In [13]:
# Make the PubTabNet metric module importable.
sys.path.insert(1, '../../pubtabnet/PubTabNet/src/')
from metric import TEDS

# NOTE(review): presumably structure_only=True scores table structure without cell
# text and ignore_nodes=["b"] drops bold tags before comparison — confirm against
# the TEDS implementation in the PubTabNet metric module.
teds = TEDS(n_jobs=4, structure_only = True, ignore_nodes = ["b"])

In [14]:
from torch.utils.data import DataLoader

# NOTE(review): despite its name this loader wraps test_set (the evaluation data).
# shuffle=True means the sample order, and therefore the running mean printed
# by the evaluation loop, differs from run to run.
train_dataloader = DataLoader(test_set, batch_size=1, shuffle=True)

In [15]:
import torch
from tqdm.auto import tqdm

# Predicted HTML per file name, and the running sum of TEDS scores.
out_dics = {}
sum_score = 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

# Take the EOS id from the tokenizer instead of hard-coding it.
eos_token_id = processor.tokenizer.eos_token_id

for i, batch in enumerate(tqdm(train_dataloader)):

    pixel_values = batch["pixel_values"].to(device)
    filename = batch["file_name"][0]
    # Strip the spaces around tags so ground truth and prediction are normalised alike.
    gt = batch['gt'][0].replace("> ", ">").replace(" <", "<")

    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        max_length= 1600,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=eos_token_id,
        use_cache=True,
        num_beams= 3,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
        )

    # Append two EOS ids so the token2* parsers always reach "</s>", even when
    # generation stopped at max_length without emitting one.
    eos_guard = outputs.sequences.new_full((2,), eos_token_id)
    seq = torch.cat((outputs.sequences[0], eos_guard), 0)

    # Index 0 is skipped and parsing starts at index 1.
    parsed = token2ann(processor, seq, 1)
    # An output without any table is scored as an empty table instead of raising.
    predicted_table = parsed['tables'][0] if parsed['tables'] else []
    try:
        define_spannings(predicted_table)
    except Exception as exc:
        # Best effort: report the failure and still score what was parsed.
        # Cells whose span stayed unresolved (colspan <= 0) are omitted by table2html.
        print(f"define_spannings failed for {filename}: {exc!r}")
        print(predicted_table)
    table_html = "<html><body><table>" + table2html(predicted_table).replace("> ", ">").replace(" <", "<") + "</table></body></html>"


    score = teds.evaluate(table_html, gt)
    sum_score += score

    # Score of this sample -- running mean so far.
    print(score,"--",sum_score/(i+1))

    out_dics[filename] = table_html

  0%|          | 0/9064 [00:00<?, ?it/s]



1.0 -- 1.0
0.9695121951219512 -- 0.9847560975609756
1.0 -- 0.9898373983739838
1.0 -- 0.9923780487804879
1.0 -- 0.9939024390243902
1.0 -- 0.9949186991869919
1.0 -- 0.9956445993031359
1.0 -- 0.9961890243902439
1.0 -- 0.9966124661246611
0.9901960784313726 -- 0.9959708273553323
1.0 -- 0.9963371157775748
0.8709677419354839 -- 0.9858896679574006
1.0 -- 0.9869750781145237
1.0 -- 0.987905429677772
1.0 -- 0.9887117343659205
0.9594594594594594 -- 0.9868834671842667
1.0 -- 0.9876550279381332
1.0 -- 0.988340859719348
1.0 -- 0.9889544986814877
1.0 -- 0.9895067737474132
0.8907103825136612 -- 0.9848021836886631
1.0 -- 0.9854929935209966
1.0 -- 0.9861237329331272
0.4135338345864662 -- 0.9622658205020164
1.0 -- 0.9637751876819357
1.0 -- 0.9651684496941689
0.956989247311828 -- 0.9648655162726008
0.8666666666666667 -- 0.9613584145009603
0.9259259259259259 -- 0.9601366045500971
1.0 -- 0.9614653843984271
1.0 -- 0.9627084365146069
1.0 -- 0.9638737978735255
0.853448275862069 -- 0.9605275699337843
0.910256410

KeyboardInterrupt: 

In [None]:
# Decode the most recently generated sequence back to text for manual inspection.
processor.tokenizer.decode(outputs.sequences[0])

In [None]:
# Keep a shallow copy of the predictions so that out_dics can be rebuilt from it.
cp_dic = out_dics.copy()

In [None]:
# Rebuild out_dics from the saved copy, making sure each prediction is a full
# HTML document. The evaluation loop already stores wrapped documents, so an
# entry is wrapped only when the wrapper is missing; wrapping unconditionally
# would nest <html><body><table> twice.
HTML_PREFIX = "<html><body><table>"
HTML_SUFFIX = "</table></body></html>"
out_dics = {
    filename: html if html.startswith(HTML_PREFIX) else HTML_PREFIX + html + HTML_SUFFIX
    for filename, html in cp_dic.items()
}

In [None]:
import json
# ensure_ascii=False writes non-ASCII characters as-is, so the file encoding is
# pinned to UTF-8 instead of relying on the platform default.
with open("../../pubtabnet/TML_pred_dic-EPOCH1.json", 'w', encoding="utf-8") as out:
    json.dump(out_dics, out, ensure_ascii=False)

In [None]:
# Print the most recently generated token ids with one EOS id appended.
print(torch.cat((outputs.sequences[0], torch.tensor([processor.tokenizer.eos_token_id]).to(device)), 0))

In [None]:
import json
with open("../../pubtabnet/final_eval.json", encoding="utf-8") as f:
    gts = json.load(f)

# Print the ground-truth HTML of one specific image with bold tags removed.
# gts is keyed by file name, so a direct lookup replaces the scan over every key.
sample_name = "663f4502ef940b47563185fb6dd16307b43b895fdb4fe1bbe8e514e6ad2bf6f2.png"
if sample_name in gts:
    print(gts[sample_name]['html'].replace("<b>", "").replace("</b>", ""))

In [None]:
# Drop the notebook's reference to the model; cells that use `model` cannot run after this.
del model

In [None]:
# Display the processor's repr.
processor