In [6]:
import numpy as np
import cv2
import torch


# Read image and segmented area

In [11]:
# image = cv2.imread("path/to/image.jpg")

image = np.array([
    [[255, 0, 0], [0, 255, 0], [0, 0, 255]],        # 红色、绿色、蓝色
    [[255, 255, 0], [255, 0, 255], [0, 255, 255]],  # 黄色、紫色、青色
    [[255, 255, 255], [128, 128, 128], [0, 0, 0]]   # 白色、灰色、黑色
], dtype=np.uint8)


segmentation_output = {
    "instances": {
        "pred_boxes": (torch.tensor([[10, 10, 100, 100], [20, 20, 200, 200]])),
        "scores": torch.tensor([0.9, 0.95]),
        "pred_classes": torch.tensor([1, 2]),
        "pred_masks": torch.tensor([[[1, 1], [0, 0]], [[1, 1], [1, 1]]]),
    },
}
# input class to category name mapping
category_name_map = {1: "other", 2: "main_beam"}


# Select main area

In [12]:
def select_main_area(segmentation_output, category_name_map):
    instances = segmentation_output["instances"]
    masks = instances["pred_masks"]
    classes = instances["pred_classes"]

    # max area index
    areas = masks.sum(dim=(1, 2))
    max_area_index = areas.argmax().item()
    
    # main_beam class index
    main_beam_indices = [i for i, cls in enumerate(classes) if category_name_map[cls.item()] == "main_beam"]
    if main_beam_indices:
        main_beam_areas = areas[main_beam_indices]
        main_beam_index = main_beam_indices[main_beam_areas.argmax().item()]
    else:
        main_beam_index = max_area_index

    # compare the size of the max area and the main beam area
    if areas[max_area_index] > areas[main_beam_index]:
        selected_index = max_area_index
    else:
        selected_index = main_beam_index

    return masks[selected_index]

selected_mask = select_main_area(segmentation_output, category_name_map)


# Abstract complex polygon from the main area to simple polygon

In [13]:
def simplify_contours(mask, epsilon_factor=0.01):
    # get contours
    mask_np = mask.cpu().numpy().astype(np.uint8)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # simplify contours
    simplified_contours = []
    for contour in contours:
        epsilon = epsilon_factor * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        simplified_contours.append(approx)
    
    return simplified_contours

simplified_contours = simplify_contours(selected_mask)


# Reshape the simple polygon to parallel pairs/correct perspective

In [14]:
def perspective_transform(contours, image_shape):
    # find the longest edge of the contour
    def find_longest_edge(contour):
        max_dist = 0
        p1, p2 = None, None
        for i in range(len(contour)):
            for j in range(i+1, len(contour)):
                dist = np.linalg.norm(contour[i] - contour[j])
                if dist > max_dist:
                    max_dist = dist
                    p1, p2 = contour[i][0], contour[j][0]
        return p1, p2

    # get all the points of the contour
    contour = contours[0].squeeze()
    p1, p2 = find_longest_edge(contour)
    remaining_points = [pt for pt in contour if not np.array_equal(pt, p1) and not np.array_equal(pt, p2)]
    p3, p4 = remaining_points[:2]
    
    src_pts = np.float32([p1, p2, p3, p4])
    width = np.linalg.norm(p1 - p2)
    height = np.linalg.norm(p3 - p4)
    dst_pts = np.float32([[0, 0], [width, 0], [0, height], [width, height]])
    
    # calculate the perspective transform matrix
    M = cv2.getPerspectiveTransform(src_pts, dst_pts)
    transformed_image = cv2.warpPerspective(image, M, (image_shape[1], image_shape[0]), borderValue=(255, 255, 255))
    
    # apply the transform to the contour
    transformed_contour = cv2.perspectiveTransform(np.array([contour], dtype=np.float32), M)[0]

    return transformed_image, transformed_contour

transformed_image, transformed_contour = perspective_transform(simplified_contours, image.shape)


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (4,) + inhomogeneous part.

In [None]:
# 逼近轮廓
def myApprox(con, epsilon_factor=0.01):
    epsilon = epsilon_factor * cv2.arcLength(con, True)
    approx = cv2.approxPolyDP(con, epsilon, True)
    return approx

# 多边形矫正
def Polygon_correction(img):
    if len(img.shape) > 2:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    contours, _ = cv2.findContours(img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    ori_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    ori_h, ori_w = ori_img.shape[:2]
    print('ori_w, ori_h:', ori_w, ori_h)

    cv2.imshow('binary', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    if len(contours) > 0:
        cnts = sorted(contours, key=cv2.contourArea, reverse=True)[:2]
        docCnt = None

        for cnt in cnts:
            approx = myApprox(cnt)
            if len(approx) >= 4:
                docCnt = approx
                break

        if docCnt is not None:
            rect = cv2.minAreaRect(docCnt)
            box = cv2.boxPoints(rect)
            box = np.int0(box)

            src_points = np.float32(box)
            dst_points = np.float32([[0, 0], [0, ori_h - 1], [ori_w - 1, ori_h - 1], [ori_w - 1, 0]])

            M = cv2.getPerspectiveTransform(src_points, dst_points)
            result = cv2.warpPerspective(ori_img, M, (ori_w, ori_h))

            for point in box:
                cv2.circle(result, tuple(point), 5, (0, 255, 0), 2)

            print("原始点:", src_points)
            print("变换后的点：", dst_points)

            cv2.polylines(result, [box], True, (255, 255, 0), 2)
            cv2.imshow('result', result)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
            return result
        else:
            return img
    else:
        return img

# 测试代码
# img = cv2.imread('your_image_path_here')
# Polygon_correction(img)


# Export results

In [None]:
import matplotlib.pyplot as plt

def export_results(original_image, transformed_image, transformed_contour):
    plt.figure(figsize=(10, 5))
    
    # # show the original image
    # plt.subplot(1, 2, 1)
    # plt.title("Original Image")
    # plt.imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
    
    # # show the transformed image
    # plt.subplot(1, 2, 2)
    # plt.title("Transformed Image")
    # plt.imshow(cv2.cvtColor(transformed_image, cv2.COLOR_BGR2RGB))
    
    # plot the contour
    for pt in transformed_contour:
        plt.scatter(pt[0], pt[1], c='red')
    
    plt.show()

export_results(image, transformed_image, transformed_contour)


# 