训练集标注可视化

In [2]:
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from glob import glob
import json
import numpy as np
from pathlib import Path


def table_img_visualize(src_img_path, pred_json_path, dst_img_path):
    """Draw the annotated table cells of one image and save the result.

    The annotation file is a JSON document with a top-level ``"cells"`` list.
    Each cell provides ``"bbox"`` (four numbers unpacked as x1, y1, x2, y2),
    ``"row_start_idx"``, ``"row_end_idx"``, ``"col_start_idx"`` and
    ``"col_end_idx"``.  Every cell is outlined and labelled with
    ``"<row_start>-<row_end>, <col_start>-<col_end>"``.

    Args:
        src_img_path: Path of the image to annotate.
        pred_json_path: Path of the JSON annotation file for that image.
        dst_img_path: Path the annotated image is written to.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``src_img_path``.
    """
    # One colour per row index: all 27 combinations of 0/100/200 per channel,
    # so cells on the same row share a colour.
    levels = (0, 100, 200)
    row2color = np.array(
        [(c0, c1, c2) for c0 in levels for c1 in levels for c2 in levels]
    )
    # Used when a row index does not fit in the palette (red in BGR order).
    fallback_color = np.array((0, 0, 255))

    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside a drawing call.
    src_img = cv2.imread(src_img_path)
    if src_img is None:
        raise FileNotFoundError(f"cannot read image: {src_img_path}")

    # Binary mode lets the json module detect the encoding (UTF-8/16/32,
    # with or without BOM) regardless of the locale; the context manager
    # guarantees the handle is closed.
    with open(pred_json_path, "rb") as json_file:
        json_data = json.load(json_file)

    for cell in json_data["cells"]:
        # Coordinates may be stored as floats; drawing needs plain ints.
        x1, y1, x2, y2 = np.array(cell['bbox']).astype(dtype=int).tolist()
        row_start, row_end = cell["row_start_idx"], cell["row_end_idx"]
        col_start, col_end = cell["col_start_idx"], cell["col_end_idx"]
        text = f"{row_start}-{row_end}, {col_start}-{col_end}"

        # Both indices are used for lookup, so both must fit in the palette.
        if max(row_start, row_end) >= row2color.shape[0]:
            color = fallback_color
        else:
            # Cells spanning several rows get the mean of the first and last
            # row colours.
            color = (row2color[row_start] + row2color[row_end]) // 2
        bgr = color.tolist()

        # The 4-tuple is interpreted as a rectangle (x, y, width, height).
        cv2.rectangle(src_img, (x1, y1, x2 - x1, y2 - y1), bgr, 2)
        cv2.putText(src_img, text, (x1, y1 + 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, bgr, 2)

    cv2.imwrite(dst_img_path, src_img)

# Name of the dataset split to visualize.
dataset = "train"
# Input directories: table images and their ground-truth JSON annotations.
img_data_dir    = f"/media/ubuntu/Date12/TableStruct/data/{dataset}/"
json_data_dir    = f"/media/ubuntu/Date12/TableStruct/data/{dataset}_gt_json/"
# Output directory for the annotated images (created if missing).
pred_visual_dir = f"./output/structure_result/{dataset}-visualize/"
Path(pred_visual_dir).mkdir(parents=True, exist_ok=True)

# sorted() makes the processing order reproducible; glob order is arbitrary.
for src_img_path in sorted(glob(os.path.join(img_data_dir, "*.png"))):
    # Strip only the trailing extension. str.replace(".png", ...) would also
    # rewrite a ".png" occurring in the middle of the file name.
    stem = os.path.splitext(os.path.basename(src_img_path))[0]
    pred_json_path = os.path.join(json_data_dir, f"{stem}-gt.json")
    dst_img_path = os.path.join(pred_visual_dir, f"{stem}.jpg")
    table_img_visualize(src_img_path, pred_json_path, dst_img_path)
