In [2]:

import os
import numpy as np
import json
from utils.teds_utils import html_to_table, format_table, format_tokens, format_layout, \
    remove_empty_bboxes, get_html
from utils.format_translate import segmentation_to_bbox
import shutil
from tqdm import tqdm



def format_table_1(table):
    """
    Fill the ``-1`` holes of a table layout with placeholder cells.

    The layout grid is scanned row by row. Every position holding ``-1`` is
    given a fresh cell id (continuing after the largest id already in the
    layout) and an empty placeholder cell is created for it. Every other id
    contributes its entry from ``table['cells']`` the first time it is seen,
    so the resulting cell list follows the order of first appearance in the
    scan.

    NOTE: ``table['layout']`` is modified in place (the holes are overwritten)
    before being handed to ``format_layout``.
    NOTE(review): the ``layout[i, j]`` / ``.max()`` usage suggests a 2-D numpy
    array of integer ids -- confirm against ``html_to_table``.

    :param table: [dict]. must provide 'layout' and 'cells'.
    :return: [dict]. a new dict with the re-formatted 'layout' and the
        rebuilt 'cells' list.
    """
    grid = table['layout']
    first_new_id = grid.max() + 1
    filled = 0
    seen_ids = set()
    ordered_cells = []

    for r, line in enumerate(grid):
        for c, cid in enumerate(line):
            if cid != -1:
                # Existing cell: keep it once, at its first occurrence.
                if cid in seen_ids:
                    continue
                seen_ids.add(cid)
                ordered_cells.append(table['cells'][cid])
                continue

            # Hole: assign the next unused id and add a 1x1 empty cell for it.
            grid[r, c] = first_new_id + filled
            filled += 1
            ordered_cells.append(
                dict(
                    col_start_idx=c,
                    row_start_idx=r,
                    col_end_idx=c,
                    row_end_idx=r,
                    transcript='',
                    bbox=[0, 0, 0, 0],
                    segmentation=[[[0, 0], [0, 0], [0, 0], [0, 0]]],
                )
            )

    new_layout = format_layout(grid)
    # Sanity check: one cell per id in the re-formatted layout.
    assert len(ordered_cells) == new_layout.max() + 1

    return dict(layout=new_layout, cells=ordered_cells)


def merge_token(token_list):
    """
    Merge tokens that always appear together, to reduce the sequence length.
    eg. merge '<td>' and '</td>' to '<td></td>'.

    The scan stops at the first '</tbody>', which is treated as the last
    token: it is kept, and anything after it is dropped.

    A '<td>' is merged only when the token right after it really is '</td>'.
    Any other token following '<td>' is kept as a separate token, so a
    malformed sequence can no longer swallow the '</tbody>' terminator.

    :param token_list: [list]. the raw tokens from the json line file.
    :return: [list]. the merged tokens, always ending with '</tbody>'.
    :raises IndexError: if token_list contains no '</tbody>' token.
    """
    end_token = '</tbody>'
    merged = []
    total = len(token_list)
    pos = 0
    while pos < total:
        token = token_list[pos]
        if token == end_token:
            merged.append(end_token)
            return merged
        if token == '<td>' and pos + 1 < total and token_list[pos + 1] == '</td>':
            merged.append('<td></td>')
            pos += 2
        else:
            merged.append(token)
            pos += 1
    # IndexError (rather than ValueError) keeps the exception type that the
    # unbounded scan used to raise on a missing terminator.
    raise IndexError("token_list has no '</tbody>' terminator")

# Root directory of the external dataset (source images + ground-truth text).
EXT_DATA_ROOT = "/media/ubuntu/Date12/TableStruct/ext_data"

# Rebuild the output directory from scratch: any previous "train" directory is
# deleted, then the "img2" image directory is copied into its place. The JSON
# annotations generated below are written into that same directory.
ext_data_save_dir = os.path.join(EXT_DATA_ROOT, "train")
if os.path.exists(ext_data_save_dir):
    shutil.rmtree(ext_data_save_dir)
shutil.copytree(os.path.join(EXT_DATA_ROOT, "img2"), ext_data_save_dir)

# One annotation record per line; read them all up front so tqdm knows the
# total, and close the file handle right away.
ext_gt_txt_file = os.path.join(EXT_DATA_ROOT, "gt2.txt")
with open(ext_gt_txt_file, 'r', encoding='utf-8') as gt_file:
    gt_data = gt_file.readlines()

for data in tqdm(gt_data):
    # A blank line (e.g. a trailing newline at the end of the file) carries no
    # record and would make eval() fail.
    if not data.strip():
        continue

    # NOTE(review): eval() executes arbitrary code contained in the line. This
    # is only acceptable while gt2.txt is a trusted, locally produced file;
    # consider ast.literal_eval / json.loads if the record format allows it.
    data = eval(data)

    imgname = os.path.basename(data['filename'])

    # Convert every cell's 'bbox' entry with segmentation_to_bbox, then let
    # remove_empty_bboxes filter the result.
    bboxes = [segmentation_to_bbox(cell['bbox']) for cell in data['html']['cells']]
    bboxes = remove_empty_bboxes(bboxes)

    tokens = data['html']['structure']['tokens']
    tokens = format_tokens(','.join(tokens))
    tokens = merge_token(tokens)

    html = get_html(tokens, bboxes, bbox_format='xyxy')
    table = html_to_table(html)

    try:
        table_new = format_table(table)
    except Exception:
        # Catch Exception rather than everything, so Ctrl-C / SystemExit still
        # stop the run.
        print("format_table error")
        # Remove the -1 entries from the layout, then retry.
        table = format_table_1(table)
        table_new = format_table(table)

    img_id = imgname.split(".")[0]
    save_path = os.path.join(ext_data_save_dir, f'{img_id}.json')

    # ensure_ascii=False writes non-ASCII text verbatim, so the encoding is
    # pinned to UTF-8; the context manager closes each output file.
    with open(save_path, 'w', encoding='utf-8') as json_file:
        json.dump(table_new, json_file, indent=4, ensure_ascii=False)


100%|██████████| 3000/3000 [00:24<00:00, 122.78it/s]
