训练集标注可视化

In [7]:
%config InlineBackend.figure_format = 'retina'

In [8]:
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from glob import glob
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm



# Colour palette for drawing annotations, one colour per index (e.g. per row).
# The first 27 entries enumerate every combination of the channel levels
# 0 / 100 / 200; the final entry (255, 0, 0) is the one the drawing helpers
# fall back to when an index runs past the end of the palette.
idx2color = np.array(
    [
        (i * 100, j * 100, k * 100)
        for i in range(3)
        for j in range(3)
        for k in range(3)
    ]
    + [(255, 0, 0)]
)


def table_visualize(src_img, table, key='row', only_last=False):
    """Outline and number the polygons in ``table[key]`` and display the result.

    Every polygon is drawn as a closed polyline in the palette colour for its
    index (falling back to the last palette entry when the index exceeds the
    palette) and labelled with that index at the mean of its vertices.

    Args:
        src_img: Image array to annotate. It is copied before drawing, so the
            caller's array is left untouched. Assumed to be in OpenCV's BGR
            channel order, since it is converted BGR -> RGB for display.
        table: Mapping whose ``key`` entry is a sequence of polygons.
            Presumably each polygon is a sequence of (x, y) points -- the
            code unpacks the per-axis vertex mean into exactly two values.
        key: Which entry of ``table`` to draw ('row' by default).
        only_last: When False, one extra figure is shown per polygon, drawn
            alone on a fresh copy of ``src_img``. When True, only the final
            figure containing all polygons is shown.
    """
    combined = src_img.copy()
    for idx, polygon in enumerate(table[key]):
        color = (idx2color[idx] if idx < len(idx2color) else idx2color[-1]).tolist()
        pts = np.array(polygon, dtype=np.int32)
        cx, cy = pts.mean(axis=0)
        anchor = (int(cx), int(cy))

        if not only_last:
            # Show this polygon in isolation on an unannotated copy.
            single = src_img.copy()
            cv2.polylines(single, [pts], True, color, 1)
            cv2.putText(single, str(idx), anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            plt.imshow(cv2.cvtColor(single, cv2.COLOR_BGR2RGB))
            plt.show()

        cv2.polylines(combined, [pts], True, color, 1)
        cv2.putText(combined, str(idx), anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Convert exactly once, after all drawing is done. Converting inside the
    # loop flipped the accumulated image between BGR and RGB on every
    # iteration, so the final colours depended on the number of polygons.
    plt.imshow(cv2.cvtColor(combined, cv2.COLOR_BGR2RGB))
    plt.show()

def label_visualize(src_img, label):
    """Draw every cell's bounding box and row/column indices, then show the image.

    Each cell is outlined with a rectangle in the palette colour of its
    starting row (falling back to the last palette entry when the row index
    exceeds the palette). A text tag is placed at the box's first corner:
    the row index or ``start-end`` row span, followed by ``,col`` for a
    single column or `` start-end`` for a column span.

    Args:
        src_img: Image array to annotate. Drawing happens on a copy so that
            repeated calls do not accumulate annotations on the caller's
            array. Assumed to be in OpenCV's BGR channel order (as
            ``table_visualize`` assumes); it is converted to RGB for display.
        label: Mapping with a ``'cells'`` sequence. Each cell must provide
            ``row_start_idx``, ``row_end_idx``, ``col_start_idx``,
            ``col_end_idx`` and a four-element ``bbox`` (x0, y0, x1, y1).
    """
    canvas = src_img.copy()
    for cell in label['cells']:
        row_start, row_end = cell['row_start_idx'], cell['row_end_idx']
        col_start, col_end = cell['col_start_idx'], cell['col_end_idx']
        color = (idx2color[row_start] if row_start < len(idx2color) else idx2color[-1]).tolist()

        x0, y0, x1, y1 = (int(v) for v in cell['bbox'])
        cv2.rectangle(canvas, (x0, y0), (x1, y1), color, 1)

        show_text = f"{row_start}" if row_start == row_end else f"{row_start}-{row_end}"
        show_text += f",{col_start}" if col_start == col_end else f" {col_start}-{col_end}"
        cv2.putText(canvas, show_text, (x0, y0), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Render once after all cells are drawn instead of once per cell.
    plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    plt.show()

def rotate_label_visualize(src_img, label, angle=-1):
    """Rotate the image about its centre and draw the cell boxes rotated with it.

    The image is warped with the rotation matrix for ``angle`` (scale 1.0,
    replicated borders, output size unchanged). For every cell the two bbox
    corners are mapped through the same matrix and an axis-aligned rectangle
    is drawn between the mapped corners, so the box is only an approximation
    of the rotated cell. The tag text and colour follow the same rules as
    ``label_visualize``.

    Args:
        src_img: Image array; it is not modified (the warp produces a new
            array). Assumed to be in OpenCV's BGR channel order; the result
            is converted to RGB for display.
        label: Mapping with a ``'cells'`` sequence. Each cell must provide
            ``row_start_idx``, ``row_end_idx``, ``col_start_idx``,
            ``col_end_idx`` and a four-element ``bbox`` (x0, y0, x1, y1).
        angle: Rotation angle passed to ``cv2.getRotationMatrix2D``.
            Defaults to -1, the value that used to be hard-coded.
    """
    h, w = src_img.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    canvas = cv2.warpAffine(src_img, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

    for cell in label['cells']:
        row_start, row_end = cell['row_start_idx'], cell['row_end_idx']
        col_start, col_end = cell['col_start_idx'], cell['col_end_idx']
        color = (idx2color[row_start] if row_start < len(idx2color) else idx2color[-1]).tolist()

        x0, y0, x1, y1 = cell['bbox']
        # Build the corner array as float64 explicitly: integer bbox values
        # would otherwise yield an int64 array, which OpenCV bindings are
        # expected to reject -- NOTE(review): confirm on the installed version.
        corners = np.array([[[x0, y0], [x1, y1]]], dtype=np.float64)
        (x0, y0), (x1, y1) = cv2.transform(corners, M)[0].astype(np.int32)

        cv2.rectangle(canvas, (int(x0), int(y0)), (int(x1), int(y1)), color, 1)

        show_text = f"{row_start}" if row_start == row_end else f"{row_start}-{row_end}"
        show_text += f",{col_start}" if col_start == col_end else f" {col_start}-{col_end}"
        cv2.putText(canvas, show_text, (int(x0), int(y0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    # Render once after all cells are drawn instead of once per cell.
    plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    plt.show()





In [9]:
import utils.table2label as table2label
# Visualise the annotations of a single training image.
#
# Only the last of the original chain of `img_id = ...` assignments ever took
# effect; the IDs inspected earlier are kept here for reference:
#   00849 00838 00741 02078 00075 00169 00425 00761 00856
#   00001 00009 00023 00027 00029 00033
img_id = '00034'
jpg_path  = f"/media/ubuntu/Date12/TableStruct/new_data/train_jpg480max/{img_id}.jpg"
json_path = f"/media/ubuntu/Date12/TableStruct/new_data/train_jpg480max/{img_id}.json"

src_img = cv2.imread(jpg_path)
if src_img is None:
    # cv2.imread reports an unreadable/missing file by returning None, which
    # would otherwise surface later as an opaque AttributeError.
    raise FileNotFoundError(f"cannot read image: {jpg_path}")

# Use a context manager so the annotation file is closed deterministically.
with open(json_path, 'r', encoding='utf-8') as fp:
    table = json.load(fp)

# Optional checks / views; uncomment the ones needed while inspecting:
# if not table2label.table_valid(table):
#     print("table_valid error!", img_id)
# # table = table2label.fix_table_error(table)

# table_visualize(src_img, table, 'row', only_last=False)
# table_visualize(src_img, table, 'col', only_last=False)
# table_visualize(src_img, table, 'line', only_last=True)
label = table2label.table2label(table)
# label_visualize(src_img, label)
print(label['layout'])
rotate_label_visualize(src_img, label)

# Batch consistency check over images 00000-00099: converts each table to a
# label and reports row/column count mismatches between the two.
# The original cell disabled this loop with an immediate `break`, leaving the
# body unreachable; the flag below keeps it disabled but makes that explicit.
RUN_BATCH_CHECK = False

if RUN_BATCH_CHECK:
    for i in range(100):
        img_id = f'{i:05d}'
        jpg_path  = f"/media/ubuntu/Date12/TableStruct/data/train_jpg/{img_id}.jpg"
        json_path = f"/media/ubuntu/Date12/TableStruct/data/train_jpg/{img_id}.json"

        # Only needed by the optional visualisation calls commented out below.
        src_img = cv2.imread(jpg_path)
        with open(json_path, 'r', encoding='utf-8') as fp:
            table = json.load(fp)

        table = table2label.fix_table_error(table)

        if not table2label.table_valid(table):
            print("table_valid error!", img_id)
            continue

        # table_visualize(src_img, table, 'row')
        label = table2label.table2label(table)
        if table2label.table2label_valid(table, label):
            # label_visualize(src_img, label)
            continue

        num_row = len(table['row'])
        num_col = len(table['col'])
        layout = label['layout']
        num_row_layout = len(layout)
        # An empty layout has no last row to measure; report zero columns.
        num_col_layout = len(layout[-1]) if layout else 0
        if num_row != num_row_layout:
            print(f"{img_id}:", "table row: ", num_row, "label row:", num_row_layout)
        if num_col != num_col_layout:
            print(f"{img_id}", "table col: ", num_col, "label col:", num_col_layout)
        # print("not valid!!!", img_id)
        # label_visualize(src_img, label)



[[0, 1, 2, 3, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 8, 9], [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 8, 9], [19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], [19, 20, 21, 22, 23, 35, 36, 37, 38, 39, 40, 41, 42, 43, 33, 34], [19, 20, 21, 22, 23, 44, 45, 46, 47, 48, 49, 50, 51, 52, 33, 34], [19, 20, 21, 22, 23, 53, 54, 55, 56, 57, 58, 59, 60, 61, 33, 34], [19, 20, 21, 22, 23, 62, 63, 64, 65, 66, 67, 68, 69, 70, 33, 34], [19, 20, 21, 22, 23, 71, 72, 73, 74, 75, 76, 77, 78, 79, 33, 34], [19, 20, 21, 22, 23, 80, 81, 82, 83, 84, 85, 86, 87, 88, 33, 34], [89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104], [105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120], [121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136], [121, 122, 123, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149], [121, 122, 123, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 148, 161], [121, 122, 123, 162, 163, 164