1. 对"*-pred.json"预测文件进行后处理（补全空bbox，并将bbox/segmentation坐标除以对应图片的scale）
2. 将submit文件夹压缩打包

In [5]:
from glob import glob
from tqdm import tqdm
import os
import json
import cv2
import numpy as np
from pathlib import Path
import shutil
# from prettyprinter import cpprint, set_default_style



def post_process(table):
    """Fill in missing cell bounding boxes from the cells sharing their rows/columns.

    A cell whose ``bbox`` is ``[0, 0, 0, 0]`` is treated as "no box predicted".
    Its new box is assembled from edge averages of non-empty cells:

    * ``x0`` -- mean left edge of single-column cells in column ``col_start_idx``
    * ``x1`` -- mean right edge of single-column cells in column ``col_end_idx``
    * ``y0`` -- mean top edge of single-row cells in row ``row_start_idx``
    * ``y1`` -- mean bottom edge of single-row cells in row ``row_end_idx``

    Each mean is truncated with ``int``. If any of the four edges has no
    contributing cell, the empty cell is left unchanged. A filled cell also
    gets its ``segmentation`` replaced by the rectangle of the new box.

    Args:
        table: dict with a ``'cells'`` list; every cell has ``'bbox'``
            (``[x0, y0, x1, y1]``) and the integer keys ``'row_start_idx'``,
            ``'row_end_idx'``, ``'col_start_idx'`` and ``'col_end_idx'``.

    Returns:
        The same ``table`` object, modified in place.
    """
    empty_bbox = (0, 0, 0, 0)

    # Edge samples keyed by row / column index. Dicts (instead of lists sized
    # from table['layout']) tolerate an empty layout and out-of-range indices.
    # Only cells spanning exactly one row (resp. column) contribute, so merged
    # cells do not skew the per-row (per-column) edge estimates.
    row_top, row_bottom = {}, {}
    col_left, col_right = {}, {}

    for cell in table['cells']:
        x0, y0, x1, y1 = cell['bbox']
        if (x0, y0, x1, y1) == empty_bbox:
            continue
        row_start, row_end = cell['row_start_idx'], cell['row_end_idx']
        col_start, col_end = cell['col_start_idx'], cell['col_end_idx']
        if row_start == row_end:
            row_top.setdefault(row_start, []).append(y0)
            row_bottom.setdefault(row_start, []).append(y1)
        if col_start == col_end:
            col_left.setdefault(col_start, []).append(x0)
            col_right.setdefault(col_start, []).append(x1)

    # Fill empty bboxes from the collected edge statistics.
    for cell in table['cells']:
        x0, y0, x1, y1 = cell['bbox']
        if (x0, y0, x1, y1) != empty_bbox:
            continue
        left = col_left.get(cell['col_start_idx'])
        right = col_right.get(cell['col_end_idx'])
        top = row_top.get(cell['row_start_idx'])
        bottom = row_bottom.get(cell['row_end_idx'])
        if not (left and right and top and bottom):
            # Not enough evidence to place this cell; keep it as predicted.
            continue

        x0, x1 = int(np.mean(left)), int(np.mean(right))
        y0, y1 = int(np.mean(top)), int(np.mean(bottom))
        cell['bbox'] = [x0, y0, x1, y1]
        cell['segmentation'] = [[[x0, y0], [x1, y0], [x1, y1], [x0, y1]]]

    return table


def check_table_format(table):
    """Return True when ``table`` has the keys required by the submission format.

    The table must contain ``'layout'`` and ``'cells'``. Every cell must carry
    the four span indices (``row/col_start_idx``, ``row/col_end_idx``), plus
    ``'bbox'``, ``'segmentation'`` and ``'transcript'``. Only key presence is
    checked, not value types.
    """
    required_cell_keys = (
        'col_start_idx', 'col_end_idx', 'row_start_idx', 'row_end_idx',
        'bbox', 'segmentation',
        'transcript',
    )

    has_top_level_keys = 'layout' in table and 'cells' in table
    if not has_top_level_keys:
        return False

    return all(
        all(key in cell for key in required_cell_keys)
        for cell in table['cells']
    )


In [6]:

DATASET = "test_A_jpg480max"

# Directory holding the "<img_id>-pred.json" prediction files (despite the .json suffix).
pred_json_dir = f"./output/structure_result/{DATASET}-pred.json"
# Per-image metadata; the loop below reads jpg_info[img_id]["scale"].
jpg_json_path = f"/media/ubuntu/Date12/TableStruct/new_data/{DATASET}.json"
with open(jpg_json_path, 'r', encoding='utf-8') as f:
    jpg_info = json.load(f)

# Start from an empty folder so stale files from a previous run never leak into the submission.
submit_dir = f"./output/structure_result/{DATASET}_submit"
if os.path.exists(submit_dir):
    shutil.rmtree(submit_dir)
Path(submit_dir).mkdir(parents=True, exist_ok=True)

for pred_json_path in tqdm(glob(os.path.join(pred_json_dir, "*-pred.json"))):
    pred_json_name = os.path.basename(pred_json_path)
    save_json_path = os.path.join(submit_dir, pred_json_name)

    with open(pred_json_path, "r", encoding="utf-8") as f:
        table = json.load(f)

    # Validate before touching the table: post_process and the rescaling below
    # index these keys directly and would otherwise die with a bare KeyError.
    # Stop at the first malformed file (fail fast, as before).
    if not check_table_format(table):
        print(f"check_table_format error: {pred_json_path}")
        break

    table = post_process(table)

    # Strip only the trailing "-pred.json" so image ids containing dots survive.
    img_id = pred_json_name[:-len("-pred.json")]
    scale = jpg_info[img_id]["scale"]
    for cell in table['cells']:
        # Divide coordinates by the image's scale, truncating to int.
        cell["bbox"] = (np.array(cell["bbox"]) / scale).astype(np.int32).tolist()
        cell["segmentation"] = [
            (np.array(seg) / scale).astype(np.int32).tolist()
            for seg in cell["segmentation"]
        ]

    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding.
    with open(save_json_path, "w", encoding="utf-8") as f:
        json.dump(table, f, indent=4, ensure_ascii=False)

# NOTE: make_archive(submit_dir, 'zip', submit_dir) archives only the *contents*
# of submit_dir, so the zip lacks the top-level folder (expected layout per the
# original note: submit.zip/submit/...). Passing root_dir=os.path.dirname(submit_dir)
# and base_dir=os.path.basename(submit_dir) keeps the folder inside the archive;
# confirm the folder name the evaluator expects before enabling.
# shutil.make_archive(submit_dir, 'zip', submit_dir)

100%|██████████| 5187/5187 [00:31<00:00, 165.44it/s]
