In [1]:
from transformers import AutoTokenizer, XLMRobertaTokenizer
import json
import pandas as pd
import pandas_profiling as pp
import re
from bs4 import BeautifulSoup

In [2]:
# Non-greedy match of one "<...>" run. '.' does not match a newline here, so a
# tag that is split across lines is not recognised.
tag_regex = re.compile('<.*?>')


def extract_contents(text):
    """Return ``text`` with every substring matched by ``tag_regex`` removed."""
    return tag_regex.sub('', text)


def extract_tags(text):
    """Return the substrings of ``text`` matched by ``tag_regex``, in order of appearance."""
    return [match.group(0) for match in tag_regex.finditer(text)]


In [3]:
# Notebook configuration.
# Tokenizer selection: True loads through AutoTokenizer, False through XLMRobertaTokenizer.
use_fast_tokenizer = False
# Local directory handed to from_pretrained() when the tokenizer is loaded.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
tokenizer_dir = "D:/dataset/table_ocr/pubtabnet/tokenizer"
# Results file in JSON-lines form: each line is parsed with json.loads as one record.
result_file = "D:/dataset/table_ocr/pubtabnet/results.jsonl"
# Destinations for the profiling report and the pickled DataFrame; the cells that
# reference them are currently commented out.
eda_output_path = "./baseline_epoch5_eda_total.html"
df_output_path = "./baseline_epoch5_df.pkl"

In [4]:
# Pick the loader class from the configuration flag, then load the tokenizer
# from the local directory with a single from_pretrained() call.
tokenizer_class = AutoTokenizer if use_fast_tokenizer else XLMRobertaTokenizer
tokenizer = tokenizer_class.from_pretrained(tokenizer_dir)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBartTokenizer'. 
The class this function is called from is 'XLMRobertaTokenizer'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Open the JSON-lines results file for reading as UTF-8 text.
# NOTE(review): this cell never closes the handle; prefer opening the file in a
# `with` block where it is consumed.
data_file = open(result_file, encoding="utf-8")

In [6]:
def parse_html(html):
    """Collect row, cell and span statistics from an HTML table string.

    Only ``<tr>`` and ``<td>`` elements are inspected.

    Args:
        html: HTML markup, parsed with BeautifulSoup's ``html.parser``.

    Returns:
        A 9-tuple, in this order:
            number of ``<tr>`` elements,
            largest number of ``<td>`` elements found in a single ``<tr>``,
            largest ``rowspan`` value (0 when no ``<td>`` has the attribute),
            largest ``colspan`` value (0 when no ``<td>`` has the attribute),
            total number of ``<td>`` elements found inside ``<tr>`` elements,
            number of ``<td>`` elements with a ``rowspan`` attribute,
            number of ``<td>`` elements with a ``colspan`` attribute,
            sum of all ``rowspan`` values,
            sum of all ``colspan`` values.

    Raises:
        ValueError: if a ``rowspan``/``colspan`` attribute is not an integer literal.
    """
    soup = BeautifulSoup(html, 'html.parser')
    rows = soup.find_all("tr")
    cells_per_row = [len(row.find_all("td")) for row in rows]

    row_span_values = [int(cell['rowspan']) for cell in soup.select('td[rowspan]')]
    col_span_values = [int(cell['colspan']) for cell in soup.select('td[colspan]')]

    # A leading 0 keeps each maximum at 0 for empty input, exactly like a running
    # maximum that starts from 0.
    return (
        len(rows),
        max([0, *cells_per_row]),
        max([0, *row_span_values]),
        max([0, *col_span_values]),
        sum(cells_per_row),
        len(row_span_values),
        len(col_span_values),
        sum(row_span_values),
        sum(col_span_values),
    )

In [7]:
def add_table_features(item, source='gt'):
    """Add size and structure statistics for ``item[source]`` to ``item`` in place.

    Every key that is added carries the suffix ``'_' + source`` (for example
    ``num_tokens_gt`` or ``num_rows_gt``), so the same routine can describe the
    ``'pred'`` HTML of a record as well as the ``'gt'`` HTML.

    Relies on the notebook-level ``tokenizer``, ``extract_tags``,
    ``extract_contents`` and ``parse_html``.

    Args:
        item: One record parsed from the results file. Must hold the HTML string
            under ``source`` and numeric ``image_width`` / ``image_height``.
        source: Key of the HTML string to analyse.

    Returns:
        The same ``item`` dict, to allow use inside an expression.
    """
    html = item[source]
    suffix = '_' + source

    # Length features: tokenizer tokens, raw characters, number of tags, and
    # tokenizer tokens of the text that is left once the tags are removed.
    item['num_tokens' + suffix] = len(tokenizer.encode(html))
    item['len_str' + suffix] = len(html)
    item['num_tag_tokens' + suffix] = len(extract_tags(html))
    item['num_content_tokens' + suffix] = len(tokenizer.encode(extract_contents(html)))

    (num_rows, num_cols, max_row_span, max_col_span, num_td,
     num_rowspan_tags, num_colspan_tags,
     sum_row_span, sum_col_span) = parse_html(html)

    item['num_cols' + suffix] = num_cols
    item['num_rows' + suffix] = num_rows

    # 'num_cells' is the bounding grid (rows x widest row); 'num_cells_span' is
    # the number of <td> elements that were actually found.
    item['num_cells' + suffix] = num_cols * num_rows
    item['num_cells_span' + suffix] = num_td

    # How many cells carry a rowspan / colspan attribute.
    item['num_spans' + suffix] = num_rowspan_tags + num_colspan_tags
    item['num_row_spans' + suffix] = num_rowspan_tags
    item['num_col_spans' + suffix] = num_colspan_tags

    # Sum and maximum of the span attribute values.
    item['sum_spans' + suffix] = sum_row_span + sum_col_span
    item['sum_row_spans' + suffix] = sum_row_span
    item['sum_col_spans' + suffix] = sum_col_span
    item['max_spans' + suffix] = max(max_row_span, max_col_span)
    item['max_row_spans' + suffix] = max_row_span
    item['max_col_spans' + suffix] = max_col_span

    # A span only merges cells when its value exceeds 1.
    item['has_span' + suffix] = max_row_span > 1 or max_col_span > 1
    item['has_row_span' + suffix] = max_row_span > 1
    item['has_col_span' + suffix] = max_col_span > 1

    # Image size divided by the column / row count. A table without any <tr> or
    # <td> has no meaningful cell size, so store NaN instead of dividing by zero.
    item['cell_width' + suffix] = item['image_width'] / num_cols if num_cols else float('nan')
    item['cell_height' + suffix] = item['image_height'] / num_rows if num_rows else float('nan')
    return item


items = []
# The file is opened inside this cell so that re-running the cell re-reads the
# data (an already-consumed handle would silently produce zero records) and the
# handle is always closed when the loop ends.
with open(result_file, encoding="utf-8") as data_file:
    for line in data_file:
        if not line.strip():
            continue  # tolerate blank lines such as a trailing newline
        # Only the ground-truth HTML is analysed; call
        # add_table_features(item, 'pred') as well to describe the predictions.
        items.append(add_table_features(json.loads(line)))
print(len(items))

9112


In [8]:
# Sanity check: number of records collected in `items`.
print(len(items))
# One row per record; the columns are the keys of the item dicts.
df = pd.DataFrame.from_records(items)
# Trailing expression so the notebook renders the frame.
df

9112


Unnamed: 0,file_name,gt,pred,teds_all,teds_struct,teds_content,image_width,image_height,num_tokens_gt,len_str_gt,...,sum_row_spans_gt,sum_col_spans_gt,max_spans_gt,max_row_spans_gt,max_col_spans_gt,has_span_gt,has_row_span_gt,has_col_span_gt,cell_width_gt,cell_height_gt
0,PMC2915972_003_00.png,"<table><tr><td rowspan=""2"">Reactive astroglioi...","<table><tr><td rowspan=""2"">Reactive astroglioi...",0.722161,0.740741,0.828959,238,287,650,1668,...,25,2,6,6,2,True,True,True,59.500000,22.076923
1,PMC2915972_003_00.png,<table><tr><td>Name of algorithm</td><td>Notab...,<table><tr><td>Name of algorithm</td><td>Notab...,1.000000,1.000000,1.000000,238,287,203,567,...,0,0,0,0,0,False,False,False,119.000000,47.833333
2,PMC2915972_003_00.png,<table><tr><td></td><td>HC (N = 20)</td><td>FA...,<table><tr><td></td><td>HC (N = 20)</td><td>FA...,1.000000,1.000000,1.000000,238,287,170,340,...,0,0,0,0,0,False,False,False,79.333333,57.400000
3,PMC2915972_003_00.png,<table><tr><td></td><td>No of patients</td></t...,"<table><tr><td></td><td colspan=""2"">No of pati...",0.709845,0.724138,0.993056,238,287,427,888,...,0,2,2,0,2,True,False,True,119.000000,12.478261
4,PMC5451934_004_00.png,<table><tr><td>miRNA</td><td>Change relative t...,<table><tr><td>miRNA</td><td>Change relative t...,1.000000,1.000000,1.000000,389,56,1130,2091,...,0,0,0,0,0,False,False,False,48.625000,3.500000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9107,PMC3195381_002_00.png,"<table><tr><td colspan=""2"">ASE subscales item ...",<table><tr><td></td><td>ASE subscales item num...,0.938990,0.966942,0.967817,500,629,1211,3628,...,0,4,2,0,2,True,False,True,250.000000,15.341463
9108,PMC4226882_007_00.png,"<table><tr><td></td><td colspan=""6"">After chro...","<table><tr><td></td><td colspan=""6"">After chro...",0.942212,0.952381,0.920000,238,105,335,698,...,0,12,6,0,6,True,False,True,34.000000,17.500000
9109,PMC4226882_007_00.png,<table><tr><td>Factor Name</td><td>Individual ...,<table><tr><td>Factor Name</td><td>Individual ...,0.925729,0.961538,0.871272,238,105,1364,3767,...,0,0,0,0,0,False,False,False,59.500000,4.038462
9110,PMC4226882_007_00.png,<table><tr><td>Tissue</td><td>RMH EC</td><td>R...,<table><tr><td>Tissue</td><td>RMH EC</td><td>R...,1.000000,1.000000,1.000000,238,105,167,318,...,0,0,0,0,0,False,False,False,47.600000,26.250000


In [9]:
# profile = pp.ProfileReport(df)
# profile.to_file(eda_output_path)

In [10]:
# df.to_pickle(df_output_path)

In [11]:
# Feature column -> number of bins handed to pd.cut when the TEDS scores are
# summarised per bin. The boolean has_* flags get 2 bins, every other feature 10.
key_dict = {
    'num_tokens_gt' : 10,
    'num_tag_tokens_gt' : 10,
    'num_content_tokens_gt' : 10,
    'num_spans_gt' : 10,
    'num_row_spans_gt' : 10,
    'num_col_spans_gt' : 10,
    'sum_spans_gt' : 10,
    'sum_row_spans_gt' : 10,
    'sum_col_spans_gt' : 10,
    'has_span_gt' : 2,
    'has_row_span_gt' : 2,
    'has_col_span_gt' : 2,
    'num_cells_gt' : 10,
    'num_cells_span_gt' : 10,
    'num_cols_gt' : 10,
    'num_rows_gt' : 10,
    'cell_width_gt' : 10,
    'cell_height_gt' : 10,
    'len_str_gt': 10
}

In [12]:
# For every feature in key_dict: split its value range into equal-width bins and
# print (upper bin edge, mean teds_all, number of samples) for each bin.
for key in key_dict:
    num_bins = key_dict[key]
    # The first pass only derives the bin edges. They are reused for the second
    # pass, so the range is not re-scanned and each bin is labelled with its own
    # upper edge.
    bin_edges = pd.cut(df[key], num_bins, retbins=True)[1]
    df[key + '_class'] = pd.cut(df[key], bin_edges, labels=bin_edges[1:])
    # The grouping column is categorical. observed=False is spelled out so empty
    # bins stay in the report (count 0, mean NaN); leaving it implicit is
    # deprecated in recent pandas and its default is announced to change.
    gb = df.groupby(key + '_class', observed=False)['teds_all']
    gbm = gb.mean()
    gbc = gb.count()
    print("")
    print(key)
    for v in zip(gbm.index.values, gbm, gbc):
        print(v)



num_tokens_gt
(573.4, 0.9450900751773313, 4559)
(1082.8, 0.9216335341999353, 2697)
(1592.1999999999998, 0.8855480480416218, 1102)
(2101.6, 0.7870289315273937, 436)
(2611.0, 0.6222968155295836, 183)
(3120.3999999999996, 0.5211435273137541, 77)
(3629.7999999999997, 0.42382451994861425, 29)
(4139.2, 0.3439093099240552, 17)
(4648.599999999999, 0.3834080574607154, 10)
(5158.0, 0.25515168057413534, 2)

num_tag_tokens_gt
(152.2, 0.9346641664681362, 5213)
(294.4, 0.9184099161756044, 2421)
(436.59999999999997, 0.8751164686607688, 898)
(578.8, 0.7700014192138301, 337)
(721.0, 0.6738280964920469, 136)
(863.1999999999999, 0.5740515168144583, 62)
(1005.3999999999999, 0.5191878863648344, 27)
(1147.6, 0.3678145909547292, 11)
(1289.8, 0.45468830066143995, 4)
(1432.0, 0.4952012980534581, 3)

num_content_tokens_gt
(278.4, 0.942801998984034, 5670)
(540.8, 0.9176136323657867, 2270)
(803.1999999999999, 0.8187175788817241, 800)
(1065.6, 0.6175139411355821, 248)
(1328.0, 0.47038292857663216, 76)
(1590.39999

In [13]:
# For every feature in key_dict: split its value range into equal-width bins and
# print (upper bin edge, mean teds_struct, number of samples) for each bin.
for key in key_dict:
    num_bins = key_dict[key]
    # The first pass only derives the bin edges. They are reused for the second
    # pass, so the range is not re-scanned and each bin is labelled with its own
    # upper edge.
    bin_edges = pd.cut(df[key], num_bins, retbins=True)[1]
    df[key + '_class'] = pd.cut(df[key], bin_edges, labels=bin_edges[1:])
    # The grouping column is categorical. observed=False is spelled out so empty
    # bins stay in the report (count 0, mean NaN); leaving it implicit is
    # deprecated in recent pandas and its default is announced to change.
    gb = df.groupby(key + '_class', observed=False)['teds_struct']
    gbm = gb.mean()
    gbc = gb.count()
    print("")
    print(key)
    for v in zip(gbm.index.values, gbm, gbc):
        print(v)



num_tokens_gt
(573.4, 0.9571847719747154, 4559)
(1082.8, 0.9376495605770054, 2697)
(1592.1999999999998, 0.9082775928852155, 1102)
(2101.6, 0.8221651373213332, 436)
(2611.0, 0.6632297497508852, 183)
(3120.3999999999996, 0.5653602996262391, 77)
(3629.7999999999997, 0.47003141307333557, 29)
(4139.2, 0.3888060423847612, 17)
(4648.599999999999, 0.4584013595859225, 10)
(5158.0, 0.3263309803404143, 2)

num_tag_tokens_gt
(152.2, 0.9488332227156228, 5213)
(294.4, 0.9341767920336396, 2421)
(436.59999999999997, 0.898532055446964, 898)
(578.8, 0.8003432327051682, 337)
(721.0, 0.7120767610919557, 136)
(863.1999999999999, 0.6122586200942582, 62)
(1005.3999999999999, 0.5563020728281617, 27)
(1147.6, 0.39030947515317005, 11)
(1289.8, 0.5348401810884654, 4)
(1432.0, 0.5502568192041876, 3)

num_content_tokens_gt
(278.4, 0.9547882607380892, 5670)
(540.8, 0.9347270245908246, 2270)
(803.1999999999999, 0.8526740484720037, 800)
(1065.6, 0.6668674035067036, 248)
(1328.0, 0.5147184685670512, 76)
(1590.3999999

In [14]:
# Mean teds_all and sample count for every distinct value of the three
# maximum-span features, printed as (span value, mean, count).
keys = ['max_spans_gt', 'max_row_spans_gt', 'max_col_spans_gt']
for span_key in keys:
    scores_by_span = df.groupby(span_key)['teds_all']
    span_means = scores_by_span.mean()
    span_counts = scores_by_span.count()
    print(span_key)
    for row in zip(span_means.index.values, span_means, span_counts):
        print(row)


max_spans_gt
(0, 0.9426347040120482, 4650)
(2, 0.8974139891732827, 1489)
(3, 0.8907138811266726, 1029)
(4, 0.8704443385904445, 704)
(5, 0.8840186642554205, 418)
(6, 0.8411315866237623, 334)
(7, 0.8341000034291245, 199)
(8, 0.7934361826983839, 130)
(9, 0.7804842888927935, 91)
(10, 0.7406963498668229, 68)
max_row_spans_gt
(0, 0.9273710525060468, 6927)
(2, 0.8852056727976281, 1370)
(3, 0.8369431793643449, 433)
(4, 0.7828071895419646, 140)
(5, 0.8115344086807222, 86)
(6, 0.7271556244879961, 67)
(7, 0.674923602335092, 27)
(8, 0.7282429096706965, 30)
(9, 0.7503053486904719, 17)
(10, 0.6575148383746453, 15)
max_col_spans_gt
(0, 0.9275808183370873, 5286)
(2, 0.8953079034203503, 1378)
(3, 0.9037143890848199, 818)
(4, 0.8851437866130959, 600)
(5, 0.8955063757855445, 351)
(6, 0.8665899196421013, 273)
(7, 0.8502354025642015, 178)
(8, 0.8129941646066903, 100)
(9, 0.7874172886690025, 74)
(10, 0.7547566178257019, 54)


In [15]:
# Mean teds_struct and sample count for every distinct value of the three
# maximum-span features, printed as (span value, mean, count).
keys = ['max_spans_gt', 'max_row_spans_gt', 'max_col_spans_gt']
for span_key in keys:
    scores_by_span = df.groupby(span_key)['teds_struct']
    span_means = scores_by_span.mean()
    span_counts = scores_by_span.count()
    print(span_key)
    for row in zip(span_means.index.values, span_means, span_counts):
        print(row)


max_spans_gt
(0, 0.9573811441510244, 4650)
(2, 0.9127543771881699, 1489)
(3, 0.9073240266471583, 1029)
(4, 0.8889640371319861, 704)
(5, 0.9022275366686404, 418)
(6, 0.8636621280127728, 334)
(7, 0.8610719321922408, 199)
(8, 0.8278297450680835, 130)
(9, 0.8145748809499519, 91)
(10, 0.7844231464817262, 68)
max_row_spans_gt
(0, 0.9429197213505769, 6927)
(2, 0.9027537645754846, 1370)
(3, 0.8578532713544139, 433)
(4, 0.8121650568054533, 140)
(5, 0.8369170587852565, 86)
(6, 0.7587994382726944, 67)
(7, 0.7268089269861969, 27)
(8, 0.7582448123953022, 30)
(9, 0.779415244615401, 17)
(10, 0.6964272253682801, 15)
max_col_spans_gt
(0, 0.9433709609982338, 5286)
(2, 0.910248472152802, 1378)
(3, 0.9199314799590108, 818)
(4, 0.9013840299044629, 600)
(5, 0.9136195354033639, 351)
(6, 0.8869356843477022, 273)
(7, 0.876502917029627, 178)
(8, 0.8487052248699181, 100)
(9, 0.8226520947024839, 74)
(10, 0.8059141774117256, 54)


In [22]:
# Overall mean of both TEDS score columns: a heading line followed by the value.
for heading, score_column in (("teds struct", 'teds_struct'), ("teds all", 'teds_all')):
    print(heading)
    print(df[score_column].mean())

teds struct
0.9265159376725594
teds all
0.9097695324502346


In [16]:
# Index label of the row with the largest len_str_gt value.
df['len_str_gt'].idxmax()

2957

In [17]:
# Largest len_str_gt value in the frame.
df['len_str_gt'].max()

10578

In [18]:
# Look up the row with the largest len_str_gt instead of hard-coding its index
# label, so the cell stays correct when the results file changes.
longest_idx = df['len_str_gt'].idxmax()
# Single .loc lookup (row label, column) rather than chained indexing.
longest_gt = df.loc[longest_idx, 'gt']
# Literal substring counts: tags that carry attributes, such as
# <td rowspan="2">, are not included.
print(longest_gt.count("<td>"))
print(longest_gt.count("<tr>"))

650
65


In [19]:
import re
# Compiled once for the whole cell: convert() strips tags once per table cell,
# so compiling the pattern (and importing `re`) inside the function repeated
# that work on every call.
_html_tag_re = re.compile('<.*?>')


def remove_html_tags(text):
    """Remove html tags from a string.

    Args:
        text: String that may contain ``<...>`` tags.

    Returns:
        ``text`` with every non-greedy ``<...>`` match removed.
    """
    return _html_tag_re.sub('', text)


# Two or more consecutive space characters.
multiple_space_re = re.compile(r'[ ]{2,}')


def convert(text):
    """Flatten table markup into a single-row, single-cell HTML fragment.

    The input is split at every ``</td>``, the remaining tags are stripped from
    each piece, and the pieces are joined with single spaces. Runs of spaces are
    then collapsed and the ends trimmed.

    Args:
        text: HTML table markup.

    Returns:
        ``'<tr><td>' + flattened text + '</td></tr>'``.
    """
    cell_texts = (remove_html_tags(piece) for piece in text.split("</td>"))
    flattened = multiple_space_re.sub(' ', " ".join(cell_texts))
    return "<tr><td>{}</td></tr>".format(flattened.strip())

In [20]:
# Flatten the longest ground-truth table. The row is looked up through idxmax()
# instead of a hard-coded index label, so the cell still targets the longest
# entry when the results file changes.
convert(df.loc[df['len_str_gt'].idxmax(), 'gt'])

'<tr><td>No. ETHNIC ISO639-3 FAMILY SUB-FAMILY BRANCH POPULATION COUNTRY PROVINCE COUNTY D1 Bolyu ply Austro-Asiatic Mon-Khmer Palyu 10,000 China Guangxi Longlin D2 Yerong yrn Daic Kadai Bu-Rong 400 China Guangxi Napo D3 Qau gio Daic Kadai Ge-Chi 3,000 China Guizhou Bijie D4 Blue-Gelao giq Daic Kadai Ge-Chi 1,700 China Guangxi Longlin D5 Lachi lbt Daic Kadai Ge-Chi 9,016 China Yunnan Maguan D6 Mollao Daic Kadai Ge-Chi 30,000 China Guizhou Majiang D7 Red-Gelao gir Daic Kadai Ge-Chi 1,500 China Guizhou Dafang D8 White-Gelao giw Daic Kadai Ge-Chi 1,200 China Yunnan Malipo D9 Hlai-Qi lic Daic Kadai Hlai 747,000 China Hainan Tongza D10 Jiamao jio Daic Kadai Hlai 52,300 China Hainan Baoting D11 Buyang byu Daic Kadai Yang-Biao 3,000 China Yunnan Guangnan D12 Cun cuq Daic Kadai Yang-Biao 70,000 China Hainan Dongfang D13 Laqua laq Daic Kadai Yang-Biao 307 China Yunnan Malipo D14 Man-Caolan mlc Daic Kam-Tai Be-Tai 114,000 China Guangxi Fangcheng D15 Zhuang-N ccx Daic Kam-Tai Be-Tai 10,000,000 Ch