预测结果可视化

In [1]:
# Stage switches for this notebook: set a flag to True to (re)run that step.
TRAINSET_PNG2JPG = False   # Step1: train-set PNG -> resized JPG
TESTSET_PNG2JPG = True     # Step2: test-set PNG -> resized JPG
TRAINSET_JPG_JSON = False  # Step3: rescale train-set JSON annotations
TESTSET_JPG_JSON = False   # not read by the visible cells (presumably a later test-set JSON step)

### Step1. 训练集PNG转JPG，短边缩放至512（宽、高均大于512时缩小，均小于512时放大，其余尺寸保持不变）

In [2]:
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from glob import glob
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm



# Step1 paths. `dataset` is the split name; the outputs live next to it under
# dataset_root as "<dataset>_jpg/" (resized images) and "<dataset>_jpg.json"
# (per-image resize info).
dataset, dataset_root = "train", "/media/ubuntu/Date12/TableStruct/data/"

train_png_dir = os.path.join(dataset_root, dataset)                # input PNG directory
train_jpg_dir = os.path.join(dataset_root, dataset + "_jpg")       # output JPG directory
jpg_json_path = os.path.join(dataset_root, dataset + "_jpg.json")  # output resize-info JSON
os.makedirs(train_jpg_dir, exist_ok=True)


if TRAINSET_PNG2JPG:
    # The short side of each image is normalized to this many pixels.
    target_side = 512
    jpg_infos = dict()  # img_id -> original size, resized size and scale factor
    # glob() order is arbitrary; sorting makes runs and the written JSON reproducible.
    for img_path in tqdm(sorted(glob(os.path.join(train_png_dir, "*.png")))):
        img_file = os.path.basename(img_path)
        # Id = text before the first "."; must stay in sync with how annotation ids are derived.
        img_id = img_file.split(".")[0]
        # Swap only the final extension (str.replace would hit every ".png" in the name).
        jpg_img_path = os.path.join(train_jpg_dir, os.path.splitext(img_file)[0] + ".jpg")
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None, not by raising.
            raise IOError(f"cv2.imread could not read {img_path}")

        oriwidth, oriheight = width, height = img.shape[1], img.shape[0]
        scale = 1.0
        if min(width, height) > target_side:
            interpolation = cv2.INTER_AREA    # both sides too large: shrink
        elif max(width, height) < target_side:
            interpolation = cv2.INTER_LINEAR  # both sides too small: enlarge
        else:
            # One side >= target and the other <= target: keep the original size.
            interpolation = None

        if interpolation is not None:
            # Scale so the short side becomes exactly target_side. NOTE: extreme
            # aspect ratios make the long side correspondingly large.
            if height <= width:
                scale = target_side / height
                height, width = target_side, round(width * scale)
            else:
                scale = target_side / width
                width, height = target_side, round(height * scale)
            img = cv2.resize(img, (width, height), interpolation=interpolation)

        jpg_infos[img_id] = dict(oriwidth=oriwidth, oriheight=oriheight, width=width, height=height, scale=scale)
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(jpg_img_path, img):
            raise IOError(f"cv2.imwrite could not write {jpg_img_path}")

    with open(jpg_json_path, "w") as f:
        json.dump(jpg_infos, f, indent=4)


### Step2. 测试集PNG转JPG，短边缩放至512（宽、高均大于512时缩小，均小于512时放大，其余尺寸保持不变）

In [3]:
# Step2 paths. `dataset` is the split name; the outputs live next to it under
# dataset_root as "<dataset>_jpg/" (resized images) and "<dataset>_jpg.json"
# (per-image resize info).
dataset, dataset_root = "test_A", "/media/ubuntu/Date12/TableStruct/data/"

test_png_dir = os.path.join(dataset_root, dataset)                 # input PNG directory
test_jpg_dir = os.path.join(dataset_root, dataset + "_jpg")        # output JPG directory
jpg_json_path = os.path.join(dataset_root, dataset + "_jpg.json")  # output resize-info JSON
os.makedirs(test_jpg_dir, exist_ok=True)


if TESTSET_PNG2JPG:
    # The short side of each image is normalized to this many pixels.
    target_side = 512
    jpg_infos = dict()  # img_id -> original size, resized size and scale factor
    # glob() order is arbitrary; sorting makes runs and the written JSON reproducible.
    for img_path in tqdm(sorted(glob(os.path.join(test_png_dir, "*.png")))):
        img_file = os.path.basename(img_path)
        # Id = text before the first "."; must stay in sync with how annotation ids are derived.
        img_id = img_file.split(".")[0]
        # Swap only the final extension (str.replace would hit every ".png" in the name).
        jpg_img_path = os.path.join(test_jpg_dir, os.path.splitext(img_file)[0] + ".jpg")
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None, not by raising.
            raise IOError(f"cv2.imread could not read {img_path}")

        oriwidth, oriheight = width, height = img.shape[1], img.shape[0]
        scale = 1.0
        if min(width, height) > target_side:
            interpolation = cv2.INTER_AREA    # both sides too large: shrink
        elif max(width, height) < target_side:
            interpolation = cv2.INTER_LINEAR  # both sides too small: enlarge
        else:
            # One side >= target and the other <= target: keep the original size.
            interpolation = None

        if interpolation is not None:
            # Scale so the short side becomes exactly target_side. NOTE: extreme
            # aspect ratios make the long side correspondingly large.
            if height <= width:
                scale = target_side / height
                height, width = target_side, round(width * scale)
            else:
                scale = target_side / width
                width, height = target_side, round(height * scale)
            img = cv2.resize(img, (width, height), interpolation=interpolation)

        jpg_infos[img_id] = dict(oriwidth=oriwidth, oriheight=oriheight, width=width, height=height, scale=scale)
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(jpg_img_path, img):
            raise IOError(f"cv2.imwrite could not write {jpg_img_path}")

    with open(jpg_json_path, "w") as f:
        json.dump(jpg_infos, f, indent=4)

100%|██████████| 5187/5187 [08:43<00:00,  9.91it/s]


### Step3. 训练集JSON坐标标注按比例缩放

In [4]:

# Step3 paths. The annotation JSONs sit in the same split directory as the PNGs.
# The rescaled annotations go into "<dataset>_jpg/" next to the resized images.
# "<dataset>_jpg.json" is the resize-info file written by the PNG->JPG step.
dataset, dataset_root = "train", "/media/ubuntu/Date12/TableStruct/data/"

train_json_dir = os.path.join(dataset_root, dataset)               # input annotation JSONs
train_jpg_dir = os.path.join(dataset_root, dataset + "_jpg")       # output annotation directory
jpg_json_path = os.path.join(dataset_root, dataset + "_jpg.json")  # resize info (read)
os.makedirs(train_jpg_dir, exist_ok=True)

def segmentation_scale(segmentations, scale):
    """Multiply every coordinate of each segmentation by ``scale``.

    Each element of ``segmentations`` goes through ``np.array``. It may therefore be
    a flat coordinate list or a rectangular nested list; ragged nesting is not
    supported. NOTE(review): the annotation schema is not visible here, so confirm it.

    Args:
        segmentations: iterable of coordinate lists.
        scale: multiplicative resize factor applied to every coordinate.

    Returns:
        A new list with the same nesting. The coordinates are rounded to the
        nearest integer (``np.rint``; exact halves round to even).
    """
    ret = []
    for seg in segmentations:
        # Round before casting. astype(int32) alone truncates toward zero, which
        # shifts every point toward the origin and turns float results such as
        # 63.99999999999999 into 63.
        scaled = np.rint(np.array(seg) * scale)
        ret.append(scaled.astype(np.int32).tolist())
    return ret

if TRAINSET_JPG_JSON:
    # Per-image resize info (img_id -> {..., "scale": ...}) from the PNG->JPG step.
    with open(jpg_json_path, "r") as f:
        img_infos = json.load(f)
    # glob() order is arbitrary; sorting makes runs reproducible.
    for json_path in tqdm(sorted(glob(os.path.join(train_json_dir, "*.json")))):
        json_file = os.path.basename(json_path)
        # Same id rule as the PNG->JPG step: text before the first ".".
        img_id = json_file.split(".")[0]
        # Separate name so the resize-info path `jpg_json_path` is not clobbered.
        out_json_path = os.path.join(train_jpg_dir, json_file)

        with open(json_path, "r") as f:
            info = json.load(f)
        if img_id not in img_infos:
            raise KeyError(f"no resize info for {img_id!r} in {jpg_json_path}")
        scale = img_infos[img_id]["scale"]

        # Every coordinate-bearing annotation field gets the same factor as its image.
        for key in ("row", "col", "line", "cell"):
            info[key] = segmentation_scale(info[key], scale)

        with open(out_json_path, "w") as f:
            json.dump(info, f)