In [1]:
from transformers import DonutProcessor
import sys
sys.path.insert(1, 'PubTabNet/src/')

from metric import TEDS
import json
import os


SPLIT = "train/"

PUBTABNET = 'data/anns/'

ANN_PATH = PUBTABNET

json_list = os.listdir(ANN_PATH + SPLIT)

aux_list = []
html_list = []

for f in json_list:
    if(f[-6:-5] == "L"):
        html_list.append(f)
    else:
        aux_list.append(f)
json_list = aux_list
old_processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AddedToken

new_tokens = ["<table_extraction>", "<table>", "<row>", "<cell>", "<row_and_col_header>", "<row_header>", "<col_header>"]
new_tokens += ["<content_row_and_col_header>", "<content_row_header>", "<content_col_header>", "<content>"]

for i in range(2):
    for j in range(2):
        for k in range(2):
            new_tokens.append("<span_type=0" + str(i) + str(j) + str(k) + ">")
            new_tokens.append("<span_type=1" + str(i) + str(j) + str(k) + ">")

old_processor.tokenizer.add_tokens(new_tokens, special_tokens = False)

27

In [3]:
json_list[:10]

['PMC3549936_002_00.json',
 'PMC3112512_009_00.json',
 'PMC3772909_007_00.json',
 'PMC5294720_004_00.json',
 'PMC3446317_008_00.json',
 'PMC5071235_005_00.json',
 'PMC1421390_001_00.json',
 'PMC5251219_004_01.json',
 'PMC3973966_003_00.json',
 'PMC4147206_002_00.json']

In [17]:
def cel2token(cell):
    if cell['span_type'][10:] != '0000':
        sequence = "<" + cell['span_type'] + ">"
        if cell['content_holder']:
            if cell['row_header'] and cell['col_header']:
                sequence += "<content_row_and_col_header>"
            elif cell['col_header']:
                sequence += "<content_col_header>"
            elif cell['row_header']:
                sequence += "<content_row_header>"
            else:
                sequence += "<content>"
            sequence += cell['content']
    else:
        sequence = ""
        if cell['content_holder']:
            if cell['row_header'] and cell['col_header']:
                sequence += "<row_and_col_header>"
            elif cell['col_header']:
                sequence += "<col_header>"
            elif cell['row_header']:
                sequence += "<row_header>"
            else:
                sequence += "<cell>"
            sequence += cell['content']
        
    return sequence

def row2token(row):
    sequence = "<row>"
    for cell in row:
        sequence += cel2token(cell)
    
    return sequence


def table2token(table):
    sequence = "<table>"
    for row in table:
        sequence += row2token(row)
    
    return sequence


def json2token(json):
    sequence = ""
    if('tables' in json):
        for table in json['tables']:
            sequence += table2token(table)

    return sequence

In [38]:
cell_types = ["<cell>", "<col_header>", "<row_header>", "<row_and_col_header>"]
content_types = ["<content_row_and_col_header>", "<content_row_header>", "<content_col_header>", "<content>"]


added_content = ["<b>", "<i>", "<sup>", "<sub>","</b>", "</i>", "</sup>", "</sub>"]

for i in range(2):
    for j in range(2):
        for k in range(2):
            cell_types.append("<span_type=0" + str(i) + str(j) + str(k) + ">")
            cell_types.append("<span_type=1" + str(i) + str(j) + str(k) + ">")


def token2cell(processor, seq, i, cell_coord, cell_type):
    if cell_type[:5] == "<span": 
        span_type = cell_type[1:-1]
        cell_type = processor.decode(seq[i])
        if cell_type[:5] == "<cont":
            i += 1
            if cell_type != "<content>":
                cell_type = "<"+cell_type[9:]
            else:
                cell_type = "<cell>"
    else:
        span_type = "span_type=0000"
    
    x, y = cell_coord
    aux_cell = {
        "row": x,
        "col": y,
        "col_header": cell_type in ["<col_header>", "<row_and_col_header>"],
        "row_header": cell_type in ["<row_header>", "<row_and_col_header>"],
        "colspan": -1 if span_type != "span_type=0000"  else 1,
        "rowspan": -1 if span_type != "span_type=0000"  else 1,
        "span_type": span_type,
        "content": ""
    }
    start_cont = i
    while processor.decode(seq[i]) not in ["<row>", "</s>", "<table>"] + cell_types:
        i += 1
        
    aux_cell['content'] += processor.decode(seq[start_cont: i])
    
    return aux_cell, i

def token2row(processor, seq, i, row_id):
    cells = []
    while processor.decode(seq[i]) not in ["<row>", "</s>", "<table>"]:
        if processor.decode(seq[i]) in cell_types:
            aux_cell, i = token2cell(processor, seq, i+1, (row_id, len(cells)), processor.decode(seq[i]))
            cells.append(aux_cell)
        else:
            i += 1
    return cells, i


def token2table(processor, seq, i):
    rows = []
    while processor.decode(seq[i]) not in ["</s>", "<table>"]:
        if processor.decode(seq[i]) == "<row>":
            aux_row, i = token2row(processor, seq, i+1, len(rows))
            rows.append(aux_row)
        else:
            i += 1
            
    return rows, i


def crop_empty_left(table):
    for row in table:
        empty_left = []
        for cell in row:
            if(cell['content'] == "" and cell['colspan'] == 1):
                empty_left.append(cell)
            else:
                empty_left.clear()
        
        for cell in empty_left:
            cell['colspan'] = 0
            cell['rowspan'] = 0

def token2ann(processor, seq, i):
    tables = []
    while processor.decode(seq[i]) != "</s>":
        if processor.decode(seq[i] == "<table>"):
            aux_table, i = token2table(processor, seq, i+1)
            crop_empty_left(aux_table)
            tables.append(aux_table)
        else:
            i += 1
        
    return {'tables': tables}

In [39]:
def update_vals(cell, content, col_header, row_header):
    cell['colspan'] = 0
    cell['rowspan'] = 0
    
    if cell['content'] != "":
        content = cell['content']
    
    col_header = col_header or cell['col_header']
    row_header = row_header or cell['row_header']
    
    return content, col_header, row_header


def define_by_path(cell, table):
    col_header, row_header = False, False
    i, j = cell['row'], cell['col']
    first_j = j

    content = ''
    
    while True:
        while True:
            content, col_header, row_header = update_vals(cell, content, col_header, row_header)
            
            if(cell['span_type'][-4:-3] != '1'):
                break
            j += 1
            cell = table[i][j]
            
        if(cell['span_type'][-3:-2] != '1'):
            break
        i += 1
        j = first_j
        cell = table[i][j]
    
    return (i, j), content, col_header, row_header

    
def define_spannings(table):
    for row in table:
        for cell in row:
            if(cell['colspan'] == -1):
                end_coord = (cell['row'], cell['col'])
                content = cell['content']
             
                end_coord, content, col_header, row_header = define_by_path(cell, table)
                
                
                cell['rowspan'] = end_coord[0] - cell['row'] + 1
                cell['colspan'] = end_coord[1] - cell['col'] + 1
                cell['col_header'] = col_header
                cell['row_header'] = row_header
                cell['content'] = content
                

In [40]:
def cell2html(cell):
    seq = "<td"
    if(cell['rowspan'] > 1):
        seq += ' rowspan="' + str(cell['rowspan']) +'"'
    if(cell['colspan'] > 1):
        seq += ' colspan="' + str(cell['colspan']) +'"'
    seq += ">" + cell["content"] + "</td>"
    
    return seq

def table2html(table):
    seq = ""
    head = False
    body = False
    for row in table:
        row_seq = ""
        count_header = 0
        row_len = len(row)
        for cell in row:
            if(cell['colspan'] > 0):
                count_header += cell['col_header']
                row_seq += cell2html(cell)
            else:
                row_len -= 1
            
        if(count_header  > row_len/2):
            if body:
                seq += "</tbody>"
                body = False
                
            if not head:
                seq += "<thead>"
                head = True
            seq +=  "<tr>" + row_seq + "</tr>" 
        else:
            if head:
                seq += "</thead>"
                head = False
                
            if not body:
                seq += "<tbody>"
                body = True
                
            seq +=  "<tr>" + row_seq + "</tr>"
    
    if head:
        seq += "</thead>"
        
    if body:
        seq += "</tbody>"
    
    return seq

In [41]:
gt_list = []
pred_list = []

teds = TEDS(n_jobs=4, structure_only = False)

In [42]:
processor.tokenizer("</s>")

{'input_ids': [0, 2, 2], 'attention_mask': [1, 1, 1]}

In [43]:
import random
import torch
from difflib import SequenceMatcher
from tqdm.auto import tqdm

removed_tags = []
soma_ratios = 0

for i, file_name in enumerate(tqdm(json_list)):

    with open(ANN_PATH + SPLIT + file_name, encoding="utf-8") as f:
        annotation = json.load(f)
    with open(ANN_PATH + SPLIT + file_name[:-5]+"-HTML.json", encoding="utf-8") as f:
        gt_html = "<html><body><table>"+str(json.load(f))+"</table></body></html>"
    
    for tag in removed_tags:
        gt_html = "<html><body><table>"+gt_html.replace(tag, "")+"</table></body></html>"
    
    target_sequence = "<s>"+json2token(annotation)+"</s>"
    input_ids = processor.tokenizer(
        target_sequence,
        add_special_tokens=False,
        max_length= 4096,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)
    
    
    input_ids = torch.cat((input_ids, torch.Tensor([2, 2]).int()), 0)
    
    table = token2ann(processor, input_ids, 1)
    table_aux  = table.copy()
    
    
    
    define_spannings(table['tables'][0])
    pred_html = "<html><body><table>"+table2html(table['tables'][0])+"</table></body></html>"
    
    pred_html = pred_html
    gt_html = gt_html
    
    pred_list.append(pred_html)
    gt_list.append(gt_html)
    
    ratio = teds.evaluate(str(pred_html), str(gt_html))
    if(ratio < .9):
        print(processor.decode(input_ids[1]))
        #print(target_sequence)
        #print(table['tables'][0])
        print(pred_html)
        print()
        print(gt_html)
        print(ratio)
    if i% 10 == 0:   
        print(str(i)+'->', soma_ratios/max(1, i))
    
    soma_ratios += ratio
print(soma_ratios/len(json_list))

  0%|                                    | 1/500777 [00:03<457:13:19,  3.29s/it]

<table>
<html><body><table><thead><tr><td colspan="2"><b> </b></td><td><b>Your Score</b></td><td colspan="3"><b> </b></td><td><b>Your Score</b></td></tr></thead><tbody><tr><td colspan="2"><b><unk>. Your age group?</b></td><td></td><td colspan="3"><b>6. Are you currently taking medication for high blood pressure?</b></td></tr><tr><td>Under 35 years</td><td>0 points</td><td></td><td colspan="2">No</td><td>0 points</td></tr><tr><td>35 – 44 years</td><td>2 points</td><td></td><td colspan="2">Yes</td><td>2 points</td><td>___</td></tr><tr><td>45 – 54 years</td><td>4 points</td><td></td><td colspan="2"></td></tr><tr><td>55 – 64 years</td><td>6 points</td><td></td><td colspan="2"></td></tr><tr><td>65 years or over</td><td>8 points</td><td>___</td><td colspan="2"></td></tr><tr><td colspan="3"><b>2. Your gender?</b></td><td colspan="3"><b>7. Do you currently smoke cigarettes or any other tobacco products on a daily basis?</b></td></tr><tr><td>Female</td><td>0 points</td><td></td><td colspan="2">

  0%|                                    | 2/500777 [00:05<367:46:38,  2.64s/it]


KeyboardInterrupt: 

In [None]:
input_ids

In [None]:
json.dump(pred_list, open("teste-pred.json", 'w'))
json.dump(gt_list, open("teste-gt.json", 'w'))

In [None]:
a = '<thead><tr><td rowspan="2"><b>Cutoff (IRS)</b></td><td rowspan="2"><b>Number</b></td><td colspan="2"><b>Overall survival prognosis</b></td></tr><tr><td><b>Log-rank <i>P</i></b></td><td><b>Corrected <i>P</i><sup><i>∗</i></sup></b></td></tr></thead><tbody><tr><td>≥1</td><td>99/94</td><td>5.56</td><td>0.0184</td></tr><tr><td>≥2</td><td>129/64</td><td>4.37</td><td>0.0377</td></tr><tr><td>≥3</td><td>150/43</td><td>8.37</td><td>0.0038</td></tr><tr><td>≥4</td><td>161/32</td><td>6.35</td><td>0.0117</td></tr><tr><td>≥5 </td><td>177/16</td><td>4.50</td><td>0.0338</td></tr></tbody>'

for tag in ["<i>", "</i>","<b>","</b>"]:
    a = a.replace(tag, "")
    
a

In [None]:
target_sequence = "Capítulo Primeiro Do título Uma noite destas, vindo da cidade para o Engenho Novo, encontrei no[ 1 ] trem da Central um rapaz aqui do bairro, que eu conheço de vista e de cha- péu. Comprimentou-me, sentou-se ao pé de mim, falou da lua e dos mi- nistros, e acabou recitando-me versos. A viagem era curta, e os versos pode ser que não fossem inteiramente maus. Succedeu, porém, que, como eu es- tava cansado, fechei os olhos três ou quatro vezes; tanto bastou para que ele interrompesse a leitura e metesse os versos no bolso. — Continue, disse eu acordando.[ 2 ] — Já acabei, murmurou ele.[ 3 ] — São muito bonitos.[ 4 ] Vi-lhe fazer um gesto para tirá-los outra vez do bolso, mas não passou[ 5 ] do gesto; estava amuado. No dia seguinte entrou a dizer de mim nomes feios, e acabou alcunhando-me Dom Casmurro."

input_ids = processor.tokenizer(
        target_sequence,
        add_special_tokens=True,
        max_length= 1536,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

In [None]:
input_ids

In [None]:
for token in input_ids:
    print(processor.decode(token))

In [None]:
processor.decode(input_ids)