In [None]:
import os
import sys
import json
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte, io
import cv2
import glob
from typing import Optional, Tuple, Union

# pytorch ops
from pytorch3d.ops import knn_gather, knn_points

In [None]:
def compute_sampling_metrics_2d(pred_points, gt_points, thresholds, eps=1e-8):
    """
    Compare two batches of sampled point sets with nearest-neighbour metrics.

    Reported metrics:
    - "Chamfer-L2": mean squared pred->GT distance plus mean squared GT->pred distance
    - "Precision@t": percentage (0-100) of predicted points closer than t to a GT point
    - "Recall@t": percentage (0-100) of GT points closer than t to a predicted point
    - "F1@t": harmonic combination of the two percentages above

    Inputs:
        - pred_points: Tensor of shape (N, S, D) with the sampled predicted points
          (the 2D caller in this notebook passes D = 2)
        - gt_points: Tensor of shape (N, S, D) with the sampled ground-truth points
        - thresholds: iterable of distance thresholds; each one is formatted with
          "%f" to build the metric keys (e.g. "F1@0.100000")
        - eps: small constant added to the F1 denominator so that
          precision + recall == 0 does not divide by zero
    Returns:
        - dict mapping metric name -> CPU Tensor of shape (N,), one value per batch item
    """

    def _full_lengths(points):
        # Every cloud in the batch uses all of its points, so lengths == shape[1].
        batch_size, n_points = points.shape[0], points.shape[1]
        return torch.full((batch_size,), n_points, dtype=torch.int64, device=points.device)

    n_pred = _full_lengths(pred_points)
    n_gt = _full_lengths(gt_points)

    # Nearest GT neighbour of every predicted point; `.dists` is treated as the
    # squared distance, the plain distance is its square root.
    sq_pred_to_gt = knn_points(
        pred_points, gt_points, lengths1=n_pred, lengths2=n_gt, K=1
    ).dists[..., 0]  # (N, S)
    # Nearest predicted neighbour of every GT point.
    sq_gt_to_pred = knn_points(
        gt_points, pred_points, lengths1=n_gt, lengths2=n_pred, K=1
    ).dists[..., 0]  # (N, S)

    d_pred_to_gt = sq_pred_to_gt.sqrt()  # (N, S)
    d_gt_to_pred = sq_gt_to_pred.sqrt()  # (N, S)

    # Symmetric Chamfer distance on squared distances.
    results = {"Chamfer-L2": sq_pred_to_gt.mean(dim=1) + sq_gt_to_pred.mean(dim=1)}

    # Threshold-based precision / recall / F1 on the (non-squared) distances.
    for thr in thresholds:
        precision = 100.0 * (d_pred_to_gt < thr).float().mean(dim=1)
        recall = 100.0 * (d_gt_to_pred < thr).float().mean(dim=1)
        results["Precision@%f" % thr] = precision
        results["Recall@%f" % thr] = recall
        results["F1@%f" % thr] = (2.0 * precision * recall) / (precision + recall + eps)

    # Hand everything back on the CPU regardless of the input device.
    return {name: value.cpu() for name, value in results.items()}

def plt_imshow(title, image):
    """Show a BGR image with matplotlib under `title`, with the grid hidden."""
    # matplotlib expects RGB channel order, so swap the channels before drawing
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_image)
    plt.title(title)
    plt.grid(False)
    plt.show()
    
def get_grayscale_img(img_path):
    """
    Load an image and return a blurred single-channel (grayscale) version.

    Inputs:
        - img_path: path of the image file, read with skimage.io.imread
    Returns:
        - 2D array: the grayscale image after a 5x5 Gaussian blur. If the
          top-left pixel is not pure white (255) the image is inverted first,
          which turns a black background into a white one.
    """
    image = io.imread(img_path)

    # skimage.io.imread yields RGB / RGBA channel order (not OpenCV's BGR), so
    # the matching conversion code must be used or the red and blue luminance
    # weights are swapped. Files that are already single-channel are used as-is.
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Background detection looks only at the top-left pixel: anything other
    # than pure white is taken to mean "dark background" and triggers inversion.
    if gray[0][0] != 255:
        gray = 255 - gray

    # Light smoothing before thresholding.
    gray = cv2.GaussianBlur(gray, (5, 5), 0)
    return gray

def binarize_img(gray_img, threshold):
    """
    Turn a grayscale image into an inverted binary mask.

    Pixels brighter than `threshold` become 0 and the rest become 255
    (cv2.THRESH_BINARY_INV). The mask is then eroded twice and dilated twice,
    in that order, with the default kernel.
    """
    binary = cv2.threshold(gray_img, threshold, 255, cv2.THRESH_BINARY_INV)[1]
    # erosion followed by dilation, two iterations each
    for morph_op in (cv2.erode, cv2.dilate):
        binary = morph_op(binary, None, iterations=2)
    return binary

def get_contour(img):
    """
    Return the points of the largest external contour in a binary image.

    Inputs:
        - img: single-channel binary image (non-zero pixels are foreground)
    Returns:
        - integer array of shape [N x 2] with the (x, y) contour points
    Raises:
        - ValueError: if no contour is found (e.g. the mask is empty)
    """
    # CHAIN_APPROX_NONE keeps every boundary pixel, giving a fuller list of points.
    found = cv2.findContours(img.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # OpenCV 4 returns (contours, hierarchy) while OpenCV 3 returns
    # (image, contours, hierarchy); the contours are second-to-last in both.
    contours = found[-2]
    if len(contours) == 0:
        raise ValueError('no contour found in image')

    # findContours does not order its output by size, so index 0 is not
    # guaranteed to be the main shape; pick the largest one by enclosed area.
    largest = max(contours, key=cv2.contourArea)

    # Contours come back as (N, 1, 2); reshape rather than squeeze so that a
    # single-point contour stays [1 x 2] instead of collapsing to shape (2,).
    return largest.reshape(-1, 2)

def draw_sampled_points(gray, points):
    """
    Draw a small circle at every sampled point on a copy of `gray`.

    Inputs:
        - gray: image to annotate (left untouched; a copy is drawn on)
        - points: array of shape [N x 2] with (x, y) pixel coordinates; N may be 0
    Returns:
        - the annotated copy (identical to `gray` when there are no points)
    """
    radius = 1
    thickness = 1
    contour_img = gray.copy()
    for x, y in np.asarray(points).reshape(-1, 2):
        # Plain Python ints: some OpenCV builds refuse numpy scalars as the centre.
        # NOTE(review): on a single-channel image presumably only the first colour
        # component (0) is used, so the markers come out black - confirm.
        cv2.circle(contour_img, (int(x), int(y)), radius, [0, 255, 0], thickness)
    # cv2.circle draws in place, so the copy is always the result - including
    # the empty-points case, where the original code hit an unbound variable.
    return contour_img

def get_bbox(mask):
    """
    Bounding box of the foreground in a binary mask.

    'mask' should be the binary image where 0 is background and non-zero is
    foreground. If this is switched on your binary image then pass (1 - mask).

    Returns:
        - [row_min, row_max, col_min, col_max]: first and last row index, then
          first and last column index, whose sum is non-zero (both ends inclusive)
    """
    # indices of rows / columns that contain at least some foreground
    occupied_rows = np.nonzero(np.sum(mask, axis=1))[0]
    occupied_cols = np.nonzero(np.sum(mask, axis=0))[0]
    return [occupied_rows[0], occupied_rows[-1], occupied_cols[0], occupied_cols[-1]]

def bbox_crop_img(img,bbox):
    """
    Crop `img` to a bounding box.

    Inputs:
        - img: array indexed as [row, col, ...]
        - bbox: [row_min, row_max, col_min, col_max] with inclusive bounds, the
          format returned by get_bbox
    Returns:
        - the cropped view of `img`, including the row_max row and col_max column
    """
    row_min, row_max, col_min, col_max = bbox
    # Slices exclude their end index, so add 1 to keep the last foreground
    # row/column (otherwise a one-pixel-thick object crops to nothing).
    return img[row_min:row_max + 1, col_min:col_max + 1]


def make_square_2d(img,pad_val=0):
    """
    Pad a 2D image with a constant value so that it becomes square.

    The shorter axis is padded on both sides; when the amount to add is odd,
    the extra pixel goes on the trailing side (right or bottom).

    Inputs:
        - img: 2D array
        - pad_val: constant used for the padded pixels
    Returns:
        - `img` itself if it is already square, otherwise a new padded array
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    if n_rows == n_cols:
        return img

    missing = abs(n_rows - n_cols)
    leading = missing // 2
    trailing = missing - leading  # equals leading, or leading + 1 when odd

    if n_cols < n_rows:
        # too narrow: grow the column axis
        pad_widths = ((0, 0), (leading, trailing))
    else:
        # too short: grow the row axis
        pad_widths = ((leading, trailing), (0, 0))

    squared = np.pad(img, pad_widths, mode='constant', constant_values=pad_val)
    assert squared.shape[0] == squared.shape[1], 'failure in make_square'
    return squared
from skimage import io

### set relevant parameters (replace with path to your files)

In [None]:
import glob

# Pairwise 2D shape comparison.
# Every "*_lim.png" image in the working directory is compared against the other
# images sharing its filename prefix (the text before the first "_"). For each
# pair the Chamfer distance and F-scores are printed and appended to results.csv.

def _show_image(image, title, **imshow_kwargs):
    """Render one image with a title (avoids repeating the same three plt calls)."""
    plt.imshow(image, **imshow_kwargs)
    plt.title(title)
    plt.show()


thresholds = [0.1, 0.2, 0.5]   # distance thresholds for the F-scores (CSV columns)
f1_keys = ['F1@%f' % t for t in thresholds]  # keys produced by compute_sampling_metrics_2d
thresh_val = 250               # grayscale cutoff handed to binarize_img
resize_dim = 100               # both crops are resized to resize_dim x resize_dim
max_points_to_sample = 100     # upper bound on contour points sampled per image
rng = np.random.default_rng(0) # fixed seed so that the sampled points are reproducible

# `with` guarantees the CSV is flushed and closed even if a pair fails midway.
with open('results.csv', 'w') as f:
    # write csv header
    f.write('object1,object2,CD,F@0.1,F@0.2,F@0.5\n')

    # sorted() makes the processing order (and hence the CSV) deterministic
    for img_path1 in sorted(glob.glob("*_lim.png")):
        prefix = img_path1.split("_")[0]
        # escape the prefix so characters such as "[" in a file name are matched literally
        candidates = sorted(glob.glob(glob.escape(prefix) + "*"))
        other_imgs = [imgpath for imgpath in candidates if 'lim' not in imgpath]

        for img_path2 in other_imgs:
            # file name without directory or extension
            obj1 = os.path.splitext(os.path.basename(img_path1))[0]
            obj2 = os.path.splitext(os.path.basename(img_path2))[0]
            print(obj1, obj2)

            # how many points to sample (if this is greater than the number of points in
            # either contour it is overridden with the size of the smaller contour)
            n_points_to_sample = max_points_to_sample

            gray1 = get_grayscale_img(img_path1)
            gray2 = get_grayscale_img(img_path2)
            _show_image(gray1, "grayscale image 1", cmap='gray', vmin=0, vmax=255)
            _show_image(gray2, "grayscale image 2", cmap='gray', vmin=0, vmax=255)

            # binarize, crop to the foreground bounding box and pad to a square
            mask1 = binarize_img(gray1, thresh_val)
            bbox1 = get_bbox(mask1)
            mask1 = make_square_2d(bbox_crop_img(mask1, bbox1), 0)
            gray1 = make_square_2d(bbox_crop_img(gray1, bbox1), 255)

            mask2 = binarize_img(gray2, thresh_val)
            bbox2 = get_bbox(mask2)
            mask2 = make_square_2d(bbox_crop_img(mask2, bbox2), 0)
            gray2 = make_square_2d(bbox_crop_img(gray2, bbox2), 255)

            # bring both shapes to a common resolution
            mask1 = cv2.resize(mask1, (resize_dim, resize_dim))
            mask2 = cv2.resize(mask2, (resize_dim, resize_dim))
            gray1 = cv2.resize(gray1, (resize_dim, resize_dim))
            gray2 = cv2.resize(gray2, (resize_dim, resize_dim))

            _show_image(gray1, "grayscale image 1", cmap='gray')
            _show_image(gray2, "grayscale image 2", cmap='gray')

            # show extra details
            print(np.unique(gray1))
            print(bbox1, bbox2)
            _show_image(mask1, "binary mask 1")
            _show_image(mask2, "binary mask 2")

            assert mask1.shape == mask2.shape, 'IMAGES MUST BE SAME DIMESIONS'
            assert mask1.shape[0] == mask1.shape[1], 'IMAGES MUST BE SQUARE'

            # Contour coordinates live in the resized image, so that is the size
            # to normalize by (not the size of the file that was loaded).
            img_dim = mask1.shape[0]

            contour1 = get_contour(mask1)
            contour2 = get_contour(mask2)

            n_points1 = contour1.shape[0]
            n_points2 = contour2.shape[0]
            print('found contour on img 1 with', n_points1, 'points')
            print('found contour on img 2 with', n_points2, 'points')
            print('sampling', n_points_to_sample, 'of them')
            if n_points_to_sample > min(n_points1, n_points2):
                print('not enough points to sample! overriding n_points_to_sample..')
                n_points_to_sample = min(n_points1, n_points2)
                print('sampling', n_points_to_sample, 'now')

            # sample the same number of contour points (without repeats) from each shape
            sampled_idxes = rng.choice(n_points1, n_points_to_sample, replace=False)
            sampled_points1 = contour1[sampled_idxes]  # size: [n_points_to_sample x 2]
            sampled_idxes = rng.choice(n_points2, n_points_to_sample, replace=False)
            sampled_points2 = contour2[sampled_idxes]  # size: [n_points_to_sample x 2]

            vis_points1 = draw_sampled_points(gray1, sampled_points1)
            vis_points2 = draw_sampled_points(gray2, sampled_points2)
            _show_image(vis_points1, 'sampled points 1', cmap='gray', vmin=0, vmax=255)
            _show_image(vis_points2, 'sampled points 2', cmap='gray', vmin=0, vmax=255)

            # explicit float dtype: contour points are integers and the metrics need floats
            pts1 = torch.tensor(sampled_points1, dtype=torch.float32).unsqueeze(0) / img_dim  # normalize pts to 0-1 range
            pts2 = torch.tensor(sampled_points2, dtype=torch.float32).unsqueeze(0) / img_dim  # normalize pts to 0-1 range
            # scale by 10 to be more comparable with meshes (i think?)
            pts1 = pts1 * 10.0
            pts2 = pts2 * 10.0

            result = compute_sampling_metrics_2d(pts1, pts2, thresholds)
            chamfer = result['Chamfer-L2'].item()
            f_scores = [result[key].item() for key in f1_keys]

            print('Chamfer distance:', chamfer)  # lower is better
            for t, score in zip(thresholds, f_scores):
                print('F-score @ %s:' % t, score)  # higher is better

            # write to file
            f.write('%s,%s,%f,%f,%f,%f\n' % (obj1, obj2, chamfer, f_scores[0], f_scores[1], f_scores[2]))

In [None]:
# Quick sanity check on a single file: print its array shape and the distinct
# values found in its last channel.
img_path = "#3Superellipse_Dsh.png"
img = io.imread(img_path)
print(img.shape)
last_channel_values = np.unique(img[:, :, -1])
print(last_channel_values)

### load image (must be square and same size for the rest to make sense!)

### binarize images

### find contour points

### calculate distance metrics