In [1]:

import os
import numpy as np
import json
from utils.format_translate import segmentation_to_bbox, html_to_table
from utils.utils import format_table, format_table_1, format_layout, format_tokens, remove_empty_bboxes, get_html
import shutil
from tqdm import tqdm


def merge_token(token_list):
    """
    Merge adjacent '<td>' and '</td>' tokens into a single '<td></td>' token.

    The two tags of an empty, attribute-free cell always appear together, so
    collapsing them shortens the structure token sequence.

    Scanning stops at the first '</tbody>', which is kept as the final token;
    anything after it is discarded. If the input contains no '</tbody>', the
    whole list is processed and no closing tag is invented.

    :param token_list: [list]. the raw tokens from the json line file.
    :return: [list]. a new list of merged tokens; the input is not modified.
    """
    merge_token_list = []
    total = len(token_list)
    pointer = 0
    # Bound the scan by the list length so an empty or unterminated input
    # cannot index past the end.
    while pointer < total:
        token = token_list[pointer]
        if token == '</tbody>':
            # </tbody> is the last token str.
            merge_token_list.append(token)
            break
        # Merge only a genuine '<td>', '</td>' pair; a '<td>' followed by any
        # other token (or by nothing) is passed through unchanged so the
        # following token -- possibly '</tbody>' itself -- is not swallowed.
        if token == '<td>' and pointer + 1 < total and token_list[pointer + 1] == '</td>':
            merge_token_list.append('<td></td>')
            pointer += 2
        else:
            merge_token_list.append(token)
            pointer += 1
    return merge_token_list


In [2]:

# Root directory of the external dataset.
# NOTE(review): this is a machine-specific absolute path; consider moving it to
# a config cell or an environment variable so the notebook runs elsewhere.
EXT_DATA_ROOT = "/media/ubuntu/Date12/TableStruct/ext_data"

# Rebuild the merged "train" directory from scratch so re-running the cell
# always starts from a clean state.
ext_data_save_dir = os.path.join(EXT_DATA_ROOT, "train")
if os.path.exists(ext_data_save_dir):
    shutil.rmtree(ext_data_save_dir)
# Copy img2 first, then overlay img1 into the same directory
# (dirs_exist_ok=True lets the second copy write into the existing tree).
shutil.copytree(os.path.join(EXT_DATA_ROOT, "img2"), ext_data_save_dir)
shutil.copytree(os.path.join(EXT_DATA_ROOT, "img1"), ext_data_save_dir, dirs_exist_ok=True)


# Convert every ground-truth record into one per-image JSON file stored next
# to the copied images.
for gt_file in ["gt1.txt", "gt2.txt"]:
    ext_gt_txt_file = os.path.join(EXT_DATA_ROOT, gt_file)
    # Read inside a context manager so the file handle is closed promptly.
    with open(ext_gt_txt_file, 'r', encoding='utf-8') as gt_fp:
        gt_data = gt_fp.readlines()
    for data in tqdm(gt_data):
        # Skip blank lines (e.g. stray empty lines in the file), which eval
        # would reject with a SyntaxError.
        if not data.strip():
            continue
        # SECURITY NOTE(review): eval executes arbitrary code found in the
        # ground-truth file. Only run this on trusted data; if each line is a
        # plain literal, ast.literal_eval (or json.loads) would be safer.
        data = eval(data)

        imgname = os.path.basename(data['filename'])
        # NOTE(review): split(".")[0] keeps only the text before the FIRST dot,
        # so "a.b.jpg" becomes "a"; confirm image names never contain extra dots.
        img_id = imgname.split(".")[0]

        cells = []
        for idx, cell in enumerate(data['html']['cells']):
            bbox = segmentation_to_bbox(cell['bbox'])
            # Cell text is replaced by a placeholder: empty string for an empty
            # cell, otherwise the cell's index as a string.
            tokens = '' if cell['tokens'] == [] else f'{idx}'
            cells.append(dict(bbox=bbox, tokens=tokens))

        tokens = data['html']['structure']['tokens']
        # Wrap the structure in <tbody>...</tbody> when it has neither a
        # <thead> nor a <tbody> section.
        if '<thead>' not in tokens and '<tbody>' not in tokens:
            tokens.insert(0, '<tbody>')
            tokens.append('</tbody>')
        # Optional token post-processing, currently disabled:
        # tokens = format_tokens(','.join(tokens))
        # tokens = merge_token(tokens)

        html = dict(html=dict(structure=dict(tokens=tokens), cells=cells))

        table = html_to_table(html)
        table_new = format_table(table)

        save_path = os.path.join(ext_data_save_dir, f'{img_id}.json')
        # ensure_ascii=False may emit non-ASCII characters, so write explicitly
        # as UTF-8, and close the handle deterministically.
        with open(save_path, 'w', encoding='utf-8') as json_fp:
            json.dump(table_new, json_fp, indent=4, ensure_ascii=False)

100%|██████████| 5000/5000 [00:18<00:00, 275.14it/s]
100%|██████████| 3000/3000 [00:22<00:00, 131.14it/s]
