In [1]:
import sys
sys.path.insert(1, '../../pubtabnet/PubTabNet/src/')
from metric import TEDS

import json
import torch
import os
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig
from PIL import Image


IMG_PATH = "../../pubtabnet/imgs/final_eval/"
IMG_FORMAT = ".png"

with open('../../pubtabnet/final_eval.json') as fp:
    annotations = json.load(fp)

In [2]:
model = VisionEncoderDecoderModel.from_pretrained("../../pubtabnet/HTML-lre-5-epoch1")

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at ../../pubtabnet/HTML-lre-5-epoch1 and are newly initialized: ['decoder.model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model.config

VisionEncoderDecoderConfig {
  "_commit_hash": null,
  "_name_or_path": "../../pubtabnet/HTML-lre-5-epoch1",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 12,
    "encoder_

In [4]:
import torch
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):

    def __init__(self,  max_len: int, d_model: int, dropout: float = 0.05):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.offset = 2
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.pos_enc = pe.squeeze(1)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        
        pos_start = past_key_values_length
        pos_end = past_key_values_length+seq_len
        
        positions = self.pos_enc[pos_start:pos_end].expand(bsz, -1, -1)
        
        return self.dropout(positions)

In [5]:
model.decoder.model.decoder.embed_positions = PositionalEncoding(4098, 1024)

In [6]:
processor = DonutProcessor.from_pretrained("../../pubtabnet/Donut_PubTables_HTML_Processor")

In [7]:
model

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [8]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        max_length,
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations_files = list(annotations.keys())
        self.annotations = annotations
        
        self.max_length = max_length
        self.ignore_id = ignore_id        
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations_files[idx]
        
        gt = self.annotations[file_name]['html']
        
        image = Image.open(IMG_PATH + file_name)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=False, return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        encoding = dict(file_name = file_name,
                        pixel_values=pixel_values,
                        gt = gt)
        
        return encoding

In [9]:
test_set = DonutTableDataset(annotations, 4096)

In [10]:
from torch.utils.data import DataLoader

teds = TEDS(n_jobs=4, structure_only = False)

train_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

In [11]:
processor.tokenizer("<i></i><sub></sub>")

{'input_ids': [0, 57555, 57556, 57561, 57562, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [None]:
import torch
from tqdm.auto import tqdm

out_dics = {}
sum_score = 0 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)

# prepare decoder inputs
task_prompt = "<table_extraction>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)

for i, batch in enumerate(tqdm(train_dataloader)):
    
    pixel_values = batch["pixel_values"].to(device)
    filename = batch["file_name"][0]
    gt = batch["gt"][0].replace("> ", ">")
    
    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        early_stopping = True,
        repetition_penalty = 0.5,
        max_length= 1600,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=False,
        num_beams= 3,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True
        )
    
    table_html = "<html><body><table>"+processor.tokenizer.decode(outputs.sequences[0][2:-1]).replace("> ", ">").replace("1 ", "1") + "</table></body></html>"
    
    score = teds.evaluate(table_html, gt)
    sum_score += score
    if(len(outputs.sequences[0]) > 1300):
        print(table_html, "\n\n", gt)
    print(sum_score/(i+1), score, len(outputs.sequences[0]))

    
    out_dics[filename] = table_html

  0%|          | 0/9064 [00:00<?, ?it/s]



0.9620280299410734 0.9620280299410734 379
0.9069103691832174 0.8517927084253615 580
0.8783400777719766 0.821199494949495 159
0.8695559294663922 0.8432034845496383 848
<html><body><table><thead><tr><td></td><td><b>Variable</b></td><td><b>OR<sup>a</sup>(IC95%)</b></td><td><b><i>p</i><sup>a</sup>value</b></td><td><b>OR<sup>a</sup>(IC95%)</b></td><td><b><i>p</i><sup>b</sup>value</b></td></tr></thead><tbody><tr><td>PSVI</td><td>Part1</td><td></td><td></td><td></td><td></td></tr><tr><td></td><td>Low PSVI</td><td>1</td><td><0.01</td><td></td><td></td></tr><tr><td></td><td>Middle PSVI</td><td>1.21(1.21;1.21)</td><td><0.01</td><td></td><td></td></tr><tr><td></td><td>High PSVI</td><td>2.50 (2.49; 2.51)</td><td><0.01</td><td></td><td></td></tr><tr><td></td><td>Part 2</td><td></td><td></td><td></td><td></td></tr><tr><td>Gender</td><td>Female</td><td>2.05 (2.04; 2.05)</td><td><0.01</td><td>1.78 (1.59; 2.00)</td><td><0.01</td></tr><tr><td></td><td>Male</td><td>1</td><td><0.01</td><td>1</td><td></td>

0.8478117313221282 0.9029958252796776 197
0.8603620569649854 0.9356640108221291 229
0.8622063047437949 0.8751160391954615 695
0.8605518878439291 0.8473165526450026 214
0.8436121294473102 0.6911543038777398 413


In [None]:
import json
with open("../../pubtabnet/HTML_pred_dic.json", 'w') as out:
    json.dump(out_dics, out, ensure_ascii=False)

In [None]:
print(torch.cat((outputs.sequences[0], torch.tensor([processor.tokenizer.eos_token_id]).to(device)), 0))

In [None]:
import json
with open("../../pubtabnet/final_eval.json", encoding="utf-8") as f:
    gts = json.load(f)

for gt in gts:
    if(gt == "663f4502ef940b47563185fb6dd16307b43b895fdb4fe1bbe8e514e6ad2bf6f2.png"):
        print(gts[gt]['html'].replace("<b>", "").replace("</b>", ""))