In [1]:
import sys


import json
import torch
import os
from transformers import DonutProcessor
from transformers import VisionEncoderDecoderConfig

sys.path.insert(0, '../src')
from modeling_pos_donut import PosDonutModel

from PIL import Image


# Directory holding the evaluation images that DonutTableDataset reads from.
IMG_PATH = "../../aux/data/imgs/final_eval/"

# Names of the entries found in the image directory.
# NOTE(review): `filelist` is not referenced by any other cell visible here.
filelist = os.listdir(IMG_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the fine-tuned checkpoint: its configuration first, then the weights.
# NOTE(review): the output recorded for this cell reports that
# 'decoder.lm_head.weight' and 'decoder.model.decoder.embed_positions.weight'
# were newly initialized instead of being read from the checkpoint -- confirm
# the checkpoint matches PosDonutModel before trusting its predictions.
checkpoint_dir = "../../aux/models/model-Pos-1_EPOCHS"
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint_dir)
model = PosDonutModel.from_pretrained(checkpoint_dir, config=config)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at ../../aux/models/model-Pos-1_EPOCHS and are newly initialized: ['decoder.lm_head.weight', 'decoder.model.decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Display the configuration attached to the loaded model.
model.config

VisionEncoderDecoderConfig {
  "_name_or_path": "../../aux/models/model-Pos-1_EPOCHS",
  "architectures": [
    "VisionEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": true,
    "add_final_layer_norm": true,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.0,
    "cross_attention_hidden_size": null,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 4,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 12,
    "encoder_no_repeat_ngram_size":

In [14]:
# Preprocessing pipeline: start from the base Donut processor and align its
# target image size with the encoder of the loaded model.
processor = DonutProcessor.from_pretrained("../../aux/processors/donut-base")
encoder_image_size = model.encoder.config.image_size
processor.image_processor.size = encoder_image_size[::-1]  # should be (width, height)
processor.image_processor.do_align_long_axis = False

# Extra vocabulary for table HTML. The list order is kept exactly as in the
# original cell; presumably the tokenizer assigns new ids in list order, so
# rearranging it could change the ids -- verify before editing.
span_attributes = [
    f'{attribute}="{count}"'
    for count in range(1, 11)
    for attribute in ("colspan", "rowspan")
]
new_tokens = [
    # task prompt
    "<table_extraction>",
    # table sections
    "<thead>", "</thead>", "<tbody>", "</tbody>",
    # rows and plain cells
    "<tr>", "</tr>", "<td>", "</td>",
    # opening/closing pieces of a cell tag that carries attributes
    "<td ", ">",
    # colspan="1", rowspan="1", ..., colspan="10", rowspan="10"
    *span_attributes,
    # inline formatting
    "<i>", "</i>", "<b>", "</b>", "<sup>", "</sup>", "<sub>", "</sub>",
]

# Last expression of the cell, so the value returned by add_tokens is shown.
processor.tokenizer.add_tokens(new_tokens, special_tokens=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


39

In [18]:
# Sanity check: show the text the tokenizer maps token id 57525 to.
processor.tokenizer.decode(57525)

'<table_extraction>'

In [5]:
# Display the module tree of the loaded model.
model

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [6]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    """Evaluation dataset pairing table images with their ground-truth HTML.

    The class relies on three notebook-level names being defined before items
    are fetched: ``IMG_PATH`` (image directory), ``processor`` (callable image
    preprocessor) and ``Image`` (PIL).

    Args:
        annotations: Mapping from image file name to a record that has an
            ``'html'`` entry with the ground truth for that image.
        max_length: Stored on the instance; not used by the class itself.
        ignore_id: Stored on the instance; not used by the class itself.
        prompt_end_token: Accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        annotations,
        max_length,
        ignore_id = -100,
        prompt_end_token = None,
    ):
        # Freeze the key order once so an integer index maps to a stable file.
        self.annotations_files = list(annotations.keys())
        self.annotations = annotations

        self.max_length = max_length
        self.ignore_id = ignore_id

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations_files)

    def __getitem__(self, idx):
        """Return ``file_name``, ``pixel_values`` and ``gt`` for sample ``idx``."""
        file_name = self.annotations_files[idx]
        gt = self.annotations[file_name]['html']

        # Image.open keeps the underlying file open; the context manager
        # releases it as soon as the RGB copy has been made, so iterating over
        # thousands of samples does not accumulate open file handles.
        with Image.open(os.path.join(IMG_PATH, file_name)) as image:
            rgb_image = image.convert("RGB")

        pixel_values = processor(
            rgb_image, random_padding=False, return_tensors="pt"
        ).pixel_values
        # Drop only the leading batch axis added by the processor; a bare
        # squeeze() would also remove any other axis that happens to be 1.
        pixel_values = pixel_values.squeeze(0)

        return dict(file_name=file_name, pixel_values=pixel_values, gt=gt)

In [7]:
import json
# Ground-truth annotations for the evaluation split. The file is opened as
# UTF-8 explicitly so decoding does not depend on the platform default locale.
with open('../../aux/data/anns/test/final_eval.json', encoding="utf-8") as fp:
    annotations = json.load(fp)

# 4096 is passed as the dataset's max_length argument.
test_set = DonutTableDataset(annotations, 4096)

In [21]:
from torch.utils.data import DataLoader
# Put the PubTabNet source directory on the import path, then import the TEDS
# metric from its `metric` module.
sys.path.insert(1, '../../aux/PubTabNet/src/')
from metric import TEDS

# Scorer used by the evaluation loop.
# NOTE(review): structure_only=False presumably means cell content is scored
# as well as structure -- confirm against metric.py.
teds = TEDS(n_jobs=4, structure_only = False)

# One sample per batch, visited in dataset order.
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

In [22]:
# Sanity check: show the ids the tokenizer produces for a few formatting tags.
processor.tokenizer("<i></i><sub></sub>")

{'input_ids': [0, 57555, 57556, 57561, 57562, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [23]:
import torch
from tqdm.auto import tqdm

# Predicted HTML per image file name, and the running sum of TEDS scores.
out_dics = {}
sum_score = 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
model.to(device)

eos_token_id = processor.tokenizer.eos_token_id

# Task prompt fed to the decoder. It is the same for every sample, so it is
# tokenized and moved to the device once instead of on every iteration.
decoder_input_ids = processor.tokenizer("<table_extraction>", add_special_tokens=False, return_tensors="pt").input_ids
prompt_ids = decoder_input_ids.to(device)


for i, batch in enumerate(tqdm(test_dataloader)):

    pixel_values = batch["pixel_values"].to(device)
    filename = batch["file_name"][0]
    gt = batch["gt"][0].replace("> ", ">")

    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=prompt_ids,
        early_stopping=True,
        max_length=1600,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=eos_token_id,
        use_cache=True,
        num_beams=3,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True
        )

    sequence = outputs.sequences[0]
    print(processor.tokenizer.decode(sequence))

    # The first two positions are skipped as in the original cell (the recorded
    # output shows sequences starting with "<s><table_extraction>"). The last
    # token is dropped only when it really is EOS: a generation cut off at
    # max_length has no EOS, and slicing with [2:-1] would discard its last
    # real token.
    finished = sequence[-1].item() == eos_token_id
    body_ids = sequence[2:-1] if finished else sequence[2:]
    table_html = "<html><body><table>" + processor.tokenizer.decode(body_ids).replace("> ", ">") + "</table></body></html>"

    score = teds.evaluate(table_html, gt)
    sum_score += score
    # Dump prediction and ground truth for generations that did not terminate.
    # The previous `len(...) > 5300` check could never fire with max_length=1600.
    if not finished:
        print(table_html, "\n\n", gt)
    # running mean TEDS, this sample's TEDS, generated sequence length
    print(sum_score/(i+1), score, len(sequence))

    out_dics[filename] = table_html

  0%|                                       | 1/9064 [00:07<19:42:05,  7.83s/it]

<s><table_extraction><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

  0%|                                       | 1/9064 [00:15<37:50:45, 15.03s/it]


KeyboardInterrupt: 

In [None]:
import json
# Persist the predicted HTML per file name. ensure_ascii=False makes json.dump
# emit non-ASCII characters verbatim, so the file is opened as UTF-8 explicitly
# rather than with the platform default encoding.
with open("../../pubtabnet/HTML_pred_dic.json", 'w', encoding="utf-8") as out:
    json.dump(out_dics, out, ensure_ascii=False)

In [None]:
# Debug view: the first sequence of the most recent generate() result with the
# tokenizer's EOS token id appended at the end.
eos_suffix = torch.tensor([processor.tokenizer.eos_token_id]).to(device)
print(torch.cat((outputs.sequences[0], eos_suffix), 0))

In [None]:
import json
with open("../../pubtabnet/final_eval.json", encoding="utf-8") as f:
    gts = json.load(f)

# Print one sample's ground-truth HTML with the bold tags removed. The entry is
# looked up by key directly instead of scanning every key of the mapping.
sample_name = "663f4502ef940b47563185fb6dd16307b43b895fdb4fe1bbe8e514e6ad2bf6f2.png"
if sample_name in gts:
    print(gts[sample_name]['html'].replace("<b>", "").replace("</b>", ""))