训练集标注可视化

In [None]:
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from glob import glob
import json
import numpy as np
import shutil
from pathlib import Path
from tqdm import tqdm


def img_label_visualize(src_img_path, pred_json_path, dst_img_path):
    """Draw table-cell annotations from a JSON file onto an image and save it.

    Each item of ``json_data["cells"]`` is drawn as a 1-px rectangle built
    from its ``bbox`` (x1, y1, x2, y2). It is labelled with its row/column
    span, e.g. ``"3,1"`` for a single cell or ``"2-4,0-1"`` for a merged cell.
    The box colour is picked from the cell's starting column, so adjacent
    columns are visually distinguishable.

    Args:
        src_img_path: Path of the image to annotate.
        pred_json_path: Path of the annotation JSON. It must contain a
            ``"cells"`` list whose items have ``bbox``, ``row_start_idx``,
            ``row_end_idx``, ``col_start_idx`` and ``col_end_idx``.
        dst_img_path: Path the annotated image is written to.

    Raises:
        FileNotFoundError: If ``src_img_path`` cannot be read as an image.
    """
    # Palette indexed by starting column: the 27 combinations of
    # {0, 100, 200} per channel, plus a fallback (255, 0, 0) for any column
    # index past the end of the palette (blue in OpenCV's BGR order).
    col2color = [
        (i * 100, j * 100, k * 100)
        for i in range(3)
        for j in range(3)
        for k in range(3)
    ]
    col2color.append((255, 0, 0))

    src_img = cv2.imread(src_img_path)
    if src_img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Cannot read image: {src_img_path}")

    with open(pred_json_path, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    # Font parameters depend only on the image size, so compute them once:
    # one integer scale step per 1024 px of the shorter side, minimum 0.5.
    height, width = src_img.shape[:2]
    font_scale = max(min(width, height) // 1024, 0.5)
    font_thickness = int(font_scale * 2)

    for cell in json_data["cells"]:
        x1, y1, x2, y2 = np.array(cell["bbox"]).astype(dtype=int).tolist()
        row_start, row_end = cell["row_start_idx"], cell["row_end_idx"]
        col_start, col_end = cell["col_start_idx"], cell["col_end_idx"]
        # Show "start" for a single index, "start-end" for a span.
        text = f"{row_start}" if row_start == row_end else f"{row_start}-{row_end}"
        text += f",{col_start}" if col_start == col_end else f",{col_start}-{col_end}"

        color = col2color[col_start] if col_start < len(col2color) else col2color[-1]
        # Rect overload of cv2.rectangle: (x, y, w, h).
        cv2.rectangle(src_img, (x1, y1, x2 - x1, y2 - y1), color, 1)
        # putText's origin is the text's bottom-left corner, so the label
        # sits just above the box's top-left corner.
        cv2.putText(src_img, text, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, color, font_thickness)

    cv2.imwrite(dst_img_path, src_img)

# Dataset name
DATASET = "train_jpg480max_wireless"
# Input directories: source images and their ground-truth JSON annotations
img_data_dir = f"/media/ubuntu/Date12/TableStruct/new_data/{DATASET}"
json_data_dir = f"/media/ubuntu/Date12/TableStruct/new_data/{DATASET}_gt_json"
# Output directory; it is deleted and recreated on every run
pred_visual_dir = f"./output/structure_result/{DATASET}-visualize/"
if os.path.exists(pred_visual_dir):
    shutil.rmtree(pred_visual_dir)
Path(pred_visual_dir).mkdir(parents=True, exist_ok=True)

imgs = sorted(glob(os.path.join(img_data_dir, "*.jpg")))
missing_json = []
for src_img_path in tqdm(imgs):
    src_img_file = os.path.basename(src_img_path)
    # "<name>.jpg" is annotated by "<name>-gt.json". Build the name from the
    # stem so only the extension is replaced, not every ".jpg" in the name.
    pred_json_path = os.path.join(json_data_dir, Path(src_img_file).stem + "-gt.json")
    dst_img_path = os.path.join(pred_visual_dir, src_img_file)
    if not os.path.exists(pred_json_path):
        missing_json.append(src_img_file)
        continue
    img_label_visualize(src_img_path, pred_json_path, dst_img_path)

# Report skipped images instead of dropping them silently.
if missing_json:
    print(f"Skipped {len(missing_json)} image(s) with no annotation JSON in {json_data_dir}")
