In [1]:
import sys


import json
import torch
import os
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig
from PIL import Image


IMG_PATH = "../../pubtabnet/imgs/final_eval/"
IMG_FORMAT = ".png"

with open('../../pubtabnet/final_eval.json') as fp:
    annotations = json.load(fp)

In [2]:
model = VisionEncoderDecoderModel.from_pretrained("../../pubtabnet/modelos/HTML-lre-5-epoch3")

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at ../../pubtabnet/modelos/HTML-lre-5-epoch3 and are newly initialized: ['decoder.model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model.config

VisionEncoderDecoderConfig {
  "_name_or_path": "../../pubtabnet/modelos/HTML-lre-5-epoch3",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 12,
    "encoder_no_repeat_ngram_

In [4]:
import torch
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):

    def __init__(self,  max_len: int, d_model: int, dropout: float = 0.05):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.offset = 2
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.pos_enc = pe.squeeze(1)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        
        pos_start = past_key_values_length
        pos_end = past_key_values_length+seq_len
        
        positions = self.pos_enc[pos_start:pos_end].expand(bsz, -1, -1)
        
        return self.dropout(positions)

In [5]:
model.decoder.model.decoder.embed_positions = PositionalEncoding(4098, 1024)

In [6]:
processor = DonutProcessor.from_pretrained("../../pubtabnet/Donut_PubTables_HTML_Processor")

2023-12-07 19:42:25.369206: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-07 19:42:25.369234: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-07 19:42:25.369252: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
model

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [8]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        max_length,
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations_files = list(annotations.keys())
        self.annotations = annotations
        
        self.max_length = max_length
        self.ignore_id = ignore_id        
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations_files[idx]
        
        gt = self.annotations[file_name]['html']
        
        image = Image.open(IMG_PATH + file_name)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=False, return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        encoding = dict(file_name = file_name,
                        pixel_values=pixel_values,
                        gt = gt)
        
        return encoding

In [9]:
test_set = DonutTableDataset(annotations, 4096)

In [10]:
from torch.utils.data import DataLoader
sys.path.insert(1, '../../pubtabnet/PubTabNet/src/')
from metric import TEDS

teds = TEDS(n_jobs=4, structure_only = False)

train_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

In [11]:
processor.tokenizer("<i></i><sub></sub>")

{'input_ids': [0, 57555, 57556, 57561, 57562, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [13]:
import torch
from tqdm.auto import tqdm

out_dics = {}
sum_score = 0 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)

for i, batch in enumerate(tqdm(train_dataloader)):
    
    pixel_values = batch["pixel_values"].to(device)
    filename = batch["file_name"][0]
    gt = batch["gt"][0].replace("> ", ">")
    
    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        early_stopping = True,
        max_length= 4096,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams= 3,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True
        )
    
    table_html = "<html><body><table>"+processor.tokenizer.decode(outputs.sequences[0][2:-1]).replace("> ", ">")+ "</table></body></html>"
    
    score = teds.evaluate(table_html, gt)
    sum_score += score
    if(len(outputs.sequences[0]) > 5300):
        print(table_html, "\n\n", gt)
    print(sum_score/(i+1), score, len(outputs.sequences[0]))

    
    out_dics[filename] = table_html

  0%|          | 0/9064 [00:00<?, ?it/s]

0.9518190572487842 0.9518190572487842 336
0.899509964469726 0.8472008716906676 542
0.8992983296881706 0.8988750601250601 139
0.8847507032161608 0.8411078238001315 783
0.8457340090825592 0.6896672325481537 1107
0.8537679832796635 0.8939378542651849 191
0.8636118303188267 0.9226749125538053 211
0.8651601652932276 0.8759985101140336 661
0.861969812197048 0.8364469874276113 205
0.8519314802007889 0.7615864922344565 412
0.8521136507758108 0.8539353565260301 506
0.8544580893428472 0.8802469135802469 170
0.8541792011148363 0.8508325423787048 100
0.8575498609718862 0.9013684391135371 389
0.8574771099183119 0.8564585951682726 222
0.858627163455995 0.8758779665212414 77
0.8593171754821071 0.870357367899901 181
0.8634195246422869 0.9331594603653427 164
0.8624596702991599 0.8451822921228778 437
0.8586690568050681 0.7866474004173236 1590
0.8573673556873664 0.8313333333333334 572
0.8563270085112612 0.8344797178130512 810
0.8562129792976877 0.8537043365990734 719
0.8558473112728895 0.8474369467025284

KeyboardInterrupt: 

In [None]:
import json
with open("../../pubtabnet/HTML_pred_dic.json", 'w') as out:
    json.dump(out_dics, out, ensure_ascii=False)

In [None]:
print(torch.cat((outputs.sequences[0], torch.tensor([processor.tokenizer.eos_token_id]).to(device)), 0))

In [None]:
import json
with open("../../pubtabnet/final_eval.json", encoding="utf-8") as f:
    gts = json.load(f)

for gt in gts:
    if(gt == "663f4502ef940b47563185fb6dd16307b43b895fdb4fe1bbe8e514e6ad2bf6f2.png"):
        print(gts[gt]['html'].replace("<b>", "").replace("</b>", ""))