In [5]:
import os
import sys
import cv2
import json
import shutil
import numpy as np
from glob import glob
from tqdm import tqdm
from pathlib import Path
from sklearn.model_selection import KFold

from utils.table2label import table2layout, fuse_gt_info, judge_error
from utils.table_helper import correct_table
from utils.format_translate import segmentation_to_bbox



def table2label(table_dir, label_dir, error_file_path):
    """Convert wireless-table JSON annotations into per-table ground-truth files.

    Every ``*.json`` file directly inside ``table_dir`` is loaded. Tables whose
    ``is_wireless`` entry is falsy are skipped. Each remaining table is passed
    through ``table2layout`` -> ``fuse_gt_info`` -> ``judge_error``; a table
    that fails any stage is recorded in the error map instead of being written.

    Args:
        table_dir: Directory holding the source table ``*.json`` files.
        label_dir: Directory receiving one ``{id}-gt.json`` per valid table.
            It is created when it does not exist yet.
        error_file_path: Path of the JSON file that receives the mapping
            ``{table id: failure reason}`` for every rejected table.

    Returns:
        None. Results are written to ``label_dir`` and ``error_file_path``,
        and the number of rejected tables is printed.
    """
    if label_dir:
        os.makedirs(label_dir, exist_ok=True)

    table_error = {}
    json_files = sorted(glob(os.path.join(table_dir, '*.json')))
    for json_path in tqdm(json_files, total=len(json_files)):
        json_name = os.path.basename(json_path)
        # The table id is the file name up to its first dot.
        json_id = json_name.split('.')[0]
        with open(json_path, 'r') as f:
            table = json.load(f)

        # Only wireless tables are converted by this function.
        if not table['is_wireless']:
            continue

        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # interrupt a long conversion run instead of being logged as an error.
        try:
            gt_label = table2layout(table)
        except Exception:
            table_error[json_id] = 'table2layout error'
            continue

        # For wired tables the bbox obtained here is still the cell box,
        # not the text box.
        try:
            gt_label = fuse_gt_info(gt_label, table)
        except Exception:
            # Wired tables that contain only a single cell are filtered out here.
            table_error[json_id] = "fuse_gt_info error"
            continue

        valid, msg = judge_error(table, gt_label)
        if not valid:
            print(json_name, msg)
            table_error[json_id] = msg
            continue

        gt_json_path = os.path.join(label_dir, f'{json_id}-gt.json')
        with open(gt_json_path, 'w') as f:
            json.dump(gt_label, f, indent=4)

    with open(error_file_path, 'w') as f:
        json.dump(table_error, f, indent=4)

    print('table error: {}'.format(len(table_error)))

## STEP.1 gen_gt_labels

In [6]:
# Build a COCO-style text-line detection dataset from the per-image label JSONs.
#
# Inputs
#     {DATASET_ROOT}/{DATASET}/*.json       per-image labels (only the 'line' entry is read)
#     {DATASET_ROOT}/{DATASET}/*.jpg        images sharing the stem of their label file
#     {DATASET_ROOT}/{DATASET}_error.json   ids to exclude from the dataset
# Outputs (IFTABLE_LINE_ROOT is deleted and rebuilt on every run)
#     {IFTABLE_LINE_ROOT}/imgs/train/*.jpg, {IFTABLE_LINE_ROOT}/imgs/val/*.jpg
#     {IFTABLE_LINE_ROOT}/instances_train.json, {IFTABLE_LINE_ROOT}/instances_val.json
DATASET = "train_jpg480max"
DATASET_ROOT = '/media/ubuntu/Date12/TableStruct/new_data'
IFTABLE_LINE_ROOT = '/media/ubuntu/Date12/TableStruct/iftable_line'

SEED = 42                # shared by the fold split and the sub-sampling below
N_SPLITS = 10
N_TRAIN_SAMPLES = 4800   # upper bound on images drawn from the training fold
N_VAL_SAMPLES = 200      # upper bound on images drawn from the validation fold


def _build_coco_split(split, label_paths):
    """Write the COCO annotation file and copy the images for one split.

    Args:
        split: Split name, ``"train"`` or ``"val"``; used in output paths and
            in the ``file_name`` / ``segm_file`` fields of each image entry.
        label_paths: Sequence of label JSON paths belonging to the split.

    Returns:
        ``(coco, json_path)``: the annotation dict and the path it was dumped to.

    Raises:
        FileNotFoundError: If an image next to a label file cannot be read.
    """
    coco = dict(images=[], annotations=[], categories=[dict(id=1, name="text")])
    # NOTE(review): image and annotation ids start at 0, as before; some COCO
    # tooling is reported to treat id 0 specially - confirm with the consumer.
    anno_id = 0
    for image_id, label_path in tqdm(enumerate(label_paths), total=len(label_paths)):
        img_id = os.path.basename(label_path).split(".")[0]
        img_path = os.path.join(DATASET_ROOT, DATASET, f"{img_id}.jpg")

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"cannot read image: {img_path}")
        height, width = image.shape[:2]
        coco["images"].append(dict(file_name=f"{split}/{img_id}.jpg",
                                   height=height, width=width,
                                   segm_file=f"{split}/{img_id}.txt", id=image_id))

        with open(label_path, 'r') as f:
            rc_label = json.load(f)
        for line in rc_label['line']:
            # The four values are used as (left, top, right, bottom) corners.
            x0, y0, x1, y1 = (int(v) for v in segmentation_to_bbox([line]))
            box_w, box_h = x1 - x0, y1 - y0
            coco["annotations"].append(dict(id=anno_id,
                                            image_id=image_id,
                                            category_id=1,
                                            # axis-aligned rectangle as one polygon
                                            segmentation=[[x0, y0, x0, y1, x1, y1, x1, y0]],
                                            iscrowd=0,
                                            area=box_w * box_h,
                                            bbox=[x0, y0, box_w, box_h]))
            anno_id += 1

        shutil.copy2(img_path, os.path.join(IFTABLE_LINE_ROOT, "imgs", split, f"{img_id}.jpg"))

    out_path = os.path.join(IFTABLE_LINE_ROOT, f"instances_{split}.json")
    with open(out_path, 'w') as f:
        json.dump(coco, f, indent=4)
    return coco, out_path


# Start from a clean output tree so stale images from earlier runs cannot linger.
if os.path.exists(IFTABLE_LINE_ROOT):
    shutil.rmtree(IFTABLE_LINE_ROOT)
os.makedirs(os.path.join(IFTABLE_LINE_ROOT, "imgs", "train"), exist_ok=True)
os.makedirs(os.path.join(IFTABLE_LINE_ROOT, "imgs", "val"), exist_ok=True)

error_json_path = os.path.join(DATASET_ROOT, f"{DATASET}_error.json")
with open(error_json_path, 'r') as f:
    error_json = json.load(f)

rc_label_dir = os.path.join(DATASET_ROOT, DATASET)

# Keep only the label files whose id is not listed in the error file.
json_paths = np.array([
    json_path
    for json_path in sorted(glob(os.path.join(rc_label_dir, '*.json')))
    if os.path.basename(json_path).split(".")[0] not in error_json
])

# Only the first of the N_SPLITS folds is used.
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
train_idx, val_idx = next(iter(kf.split(json_paths)))

# A dedicated seeded generator makes the sub-sampling reproducible; the sample
# size is capped so a fold smaller than the target no longer raises.
rng = np.random.RandomState(SEED)
train_set = rng.choice(json_paths[train_idx],
                       min(N_TRAIN_SAMPLES, len(train_idx)), replace=False)
val_set = rng.choice(json_paths[val_idx],
                     min(N_VAL_SAMPLES, len(val_idx)), replace=False)

train_label, train_json_path = _build_coco_split("train", train_set)
val_label, val_json_path = _build_coco_split("val", val_set)


100%|██████████| 4800/4800 [03:21<00:00, 23.81it/s] 
100%|██████████| 200/200 [00:07<00:00, 26.20it/s]
