1. 将"*-pred.json"文件进行后处理
2. 将submit文件压缩打包

In [3]:
from glob import glob
from tqdm import tqdm
import os
import json
import cv2
import numpy as np
from pathlib import Path
import shutil
# from prettyprinter import cpprint, set_default_style
from utils.format_translate import table_to_html, html_to_table, format_html
from utils.cal_f1 import table_to_relations, evaluate_f1
from utils.metric import TEDSMetric
 


def cal_metric(pred, label, f1_weight=0.0):
    """Score how closely the table ``pred`` matches the table ``label``.

    Both arguments are table dicts in the format consumed by the project's
    ``utils`` helpers. Each must have a ``'layout'`` entry that ``np.array``
    can convert (e.g. the nested list loaded from a ``*.json`` prediction).
    The input dicts are NOT modified: the numpy conversion the helpers need
    is applied to shallow copies.

    Args:
        pred: table treated as the prediction.
        label: table treated as the reference.
        f1_weight: weight of the cell-relation F1 score in the final metric.
            With the default ``0.0`` the result is the plain TEDS score and
            the (costly) F1 computation is skipped entirely. ``0.5`` gives an
            equal F1/TEDS blend.

    Returns:
        Mean over the compared table pairs (a single pair here) of
        ``f1_weight * F1 + (1 - f1_weight) * TEDS``, or of the TEDS score
        alone when ``f1_weight`` is 0.
    """
    # Shallow copies: the helpers below want an ndarray layout, but the
    # caller's dicts must keep their JSON-serialisable layout even if a
    # helper raises part-way through.
    pred = dict(pred, layout=np.array(pred['layout']))
    label = dict(label, layout=np.array(label['layout']))

    f1 = None
    if f1_weight:
        # F1 over cell relations; note evaluate_f1 takes the labels first.
        pred_relations = table_to_relations(pred)
        label_relations = table_to_relations(label)
        f1 = evaluate_f1([label_relations], [pred_relations], num_workers=1)

    # TEDS on the HTML rendering of each table (content included, since
    # structure_only=False).
    pred_html = format_html(table_to_html(pred))
    label_html = format_html(table_to_html(label))
    teds_metric = TEDSMetric(num_workers=1, structure_only=False)
    teds_info = teds_metric([pred_html], [label_html])

    if f1 is None:
        scores = list(teds_info)
    else:
        # Assumes each f1 entry ends with the F1 value (f1_row[-1]) — TODO
        # confirm against utils.cal_f1.evaluate_f1.
        scores = [
            f1_weight * f1_row[-1] + (1 - f1_weight) * teds
            for f1_row, teds in zip(f1, teds_info)
        ]

    # Macro average over the compared pairs.
    return sum(scores) / len(scores)


In [4]:
# The three prediction sets (one per fold checkpoint) to be fused.
pred_data_root = "./output/structure_result/"
predA = "10fold2_epoch_29_val92.87_test92.21"
predB = "10fold1_epoch_30_val92.90_test92.20"
predC = "10fold0_epoch_18_val93.98_test92.19"

predAdir = os.path.join(pred_data_root, predA)
predBdir = os.path.join(pred_data_root, predB)
predCdir = os.path.join(pred_data_root, predC)

fusion_dir = os.path.join(pred_data_root, "FINAL_FUSION")
Path(fusion_dir).mkdir(parents=True, exist_ok=True)


def _load_json(path):
    """Read and parse one UTF-8 JSON file, closing the handle afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


# Iterate over A's prediction files. A file with the same name must exist
# in B's and C's directories (otherwise open() raises FileNotFoundError).
for pred_path in sorted(glob(os.path.join(predAdir, '*.json'))):
    pred_file = os.path.basename(pred_path)

    predA_path = os.path.join(predAdir, pred_file)
    predB_path = os.path.join(predBdir, pred_file)
    predC_path = os.path.join(predCdir, pred_file)

    preda = _load_json(predA_path)
    predb = _load_json(predB_path)
    predc = _load_json(predC_path)

    # Pairwise agreement between the three models.
    metric_ab = cal_metric(preda, predb)
    metric_bc = cal_metric(predb, predc)
    metric_ac = cal_metric(preda, predc)

    # Keep the first model of the most-agreeing pair (ab -> A, bc -> B,
    # ac -> A); ties resolve in the order ab, bc. C therefore only votes and
    # is never emitted itself.
    if metric_ab >= metric_bc and metric_ab >= metric_ac:
        fin_pred = preda
    elif metric_bc >= metric_ab and metric_bc >= metric_ac:
        fin_pred = predb
    else:
        fin_pred = preda

    fusion_path = os.path.join(fusion_dir, pred_file)
    # ensure_ascii=False emits raw non-ASCII text, so the file encoding must
    # be pinned to UTF-8 rather than left to the platform locale.
    with open(fusion_path, 'w', encoding='utf-8') as f:
        json.dump(fin_pred, f, indent=4, ensure_ascii=False)



# Q. This call seems to misbehave? A. The archive lacks a level: it does not
# produce submit.zip/submit/... . Passing root_dir/base_dir keeps the folder:
# shutil.make_archive(submit_dir, 'zip', root_dir=os.path.dirname(submit_dir),
#                     base_dir=os.path.basename(submit_dir))
# NOTE(review): submit_dir is not defined in this cell.

100%|██████████| 1/1 [00:00<00:00, 12.29it/s]
100%|██████████| 1/1 [00:00<00:00, 12.12it/s]
100%|██████████| 1/1 [00:00<00:00, 11.22it/s]
100%|██████████| 1/1 [00:00<00:00, 72.29it/s]
100%|██████████| 1/1 [00:00<00:00, 85.17it/s]
100%|██████████| 1/1 [00:00<00:00, 84.50it/s]
100%|██████████| 1/1 [00:00<00:00,  7.97it/s]
100%|██████████| 1/1 [00:00<00:00,  7.54it/s]
100%|██████████| 1/1 [00:00<00:00,  7.24it/s]
100%|██████████| 1/1 [00:00<00:00,  4.69it/s]
100%|██████████| 1/1 [00:00<00:00,  4.99it/s]
100%|██████████| 1/1 [00:00<00:00,  4.45it/s]
100%|██████████| 1/1 [00:00<00:00,  9.64it/s]
100%|██████████| 1/1 [00:00<00:00,  9.54it/s]
100%|██████████| 1/1 [00:00<00:00,  9.48it/s]
100%|██████████| 1/1 [00:00<00:00,  7.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.21it/s]
100%|██████████| 1/1 [00:00<00:00,  8.92it/s]
100%|██████████| 1/1 [00:00<00:00, 27.03it/s]
100%|██████████| 1/1 [00:00<00:00, 27.15it/s]
100%|██████████| 1/1 [00:00<00:00, 13.96it/s]
100%|██████████| 1/1 [00:00<00:00,