In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import PIL
import os
import pandas as pd
from sklearn.cluster import SpectralClustering
from tqdm import tqdm

In [2]:
root_path = "/mnt/pentagon/xul076/"
test_meta = pd.read_csv("./sketchy_test_meta.csv", index_col=0)

In [3]:
test_meta

Unnamed: 0,class,photo,sketch,photo_path,sketch_path,validity,split,sketch_idx,photo_idx
0,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0
1,airplane,n02691156_10151,n02691156_10151-3,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,1,0
2,airplane,n02691156_10151,n02691156_10151-4,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,2,0
3,airplane,n02691156_10151,n02691156_10151-5,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,3,0
4,airplane,n02691156_10151,n02691156_10151-6,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,4,0
...,...,...,...,...,...,...,...,...,...
6245,zebra,n02391049_9960,n02391049_9960-1,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,0,9
6246,zebra,n02391049_9960,n02391049_9960-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,1,9
6247,zebra,n02391049_9960,n02391049_9960-3,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,2,9
6248,zebra,n02391049_9960,n02391049_9960-4,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,3,9


In [4]:
sketch_files = list(test_meta["sketch_path"])
sketch_files = [os.path.join(root_path, x) for x in sketch_files]

In [5]:
def close(thresh):
    contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    fill = cv2.fillPoly(thresh, contours, 255)
    thresh = cv2.dilate(thresh, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9)))
    return thresh

In [6]:
centerss = []
np.random.seed(0)
cluster = SpectralClustering(8, affinity="nearest_neighbors", random_state=0)
for sketch_file in tqdm(sketch_files):
    img = np.array(PIL.Image.open(sketch_file))
    # fig, axes = plt.subplots(1, 4, figsize=(15, 3))
    # axes[0].imshow(img)

    img_ = 255 - img[:, :, 0]
    ret, thresh = cv2.threshold(img_, 127, 255, 0)
    closure = False
    while not closure:
        thresh = close(thresh)
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        if len(contours) == 1:
            closure = True
    thresh = thresh / 255
    # axes[1].imshow(thresh)

    x, y = np.where(thresh == 1)
    pins = np.stack([x, y], axis=0).T

    if len(pins) < 2048:
        continue

    select = np.random.choice(len(pins), 2048, replace=False)
    pins = pins[select]

    splits = cluster.fit_predict(pins)
    centers = []

    for i in range(8):
        # img_ = img.copy()
        pin_split = pins[splits==i]
        # axes[2].scatter(pin_split[:, 1], pin_split[:, 0], s=5)
        center = pin_split.mean(axis=0)
#         img_ = cv2.circle(img_, (int(center[1]), int(center[0])), 3, (0, 0, 255), thickness=-1, lineType=cv2.LINE_AA)

#         plt.imshow(img_)
#         plt.show()

        centers.append(center)
    # axes[2].set_ylim(0, 256)
    # axes[2].set_xlim(0, 256)
    # axes[2].invert_yaxis()
    # centers = np.array(centers)
    # axes[3].imshow(img)
    # axes[3].scatter(centers[:, 1], centers[:, 0], c="red")
    # axes[3].set_ylim(0, 256)
    # axes[3].set_xlim(0, 256)
    # axes[3].invert_yaxis()
    # plt.show()
    centerss.append(centers)

100%|███████████████████████| 6250/6250 [25:21<00:00,  4.11it/s]


In [7]:
center_list = np.array(centerss)
center_list.shape

(6250, 8, 2)

In [8]:
keypoint_meta = pd.DataFrame(columns = [*test_meta.columns, "keypoint_x", "keypoint_y", "keypoint_idx"])
for i in tqdm(range(len(center_list))):
    kps = center_list[i]
    row = list(test_meta.loc[i])
    for j in range(len(kps)):
        kp = kps[j]
        kp_row = row + [kp[0], kp[1], j]
        keypoint_meta.loc[i * 8 + j] = kp_row

100%|███████████████████████| 6250/6250 [01:37<00:00, 64.42it/s]


In [9]:
keypoint_meta.to_csv(os.path.join(root_path, "sketchy/sketchy_test_keypoint_meta.csv"))
keypoint_meta.to_csv(os.path.join("./sketchy_test_keypoint_meta.csv"))

In [10]:
keypoint_meta

Unnamed: 0,class,photo,sketch,photo_path,sketch_path,validity,split,sketch_idx,photo_idx,keypoint_x,keypoint_y,keypoint_idx
0,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0,67.324818,87.003650,0
1,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0,96.371681,39.119469,1
2,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0,177.586667,172.555556,2
3,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0,72.923423,163.130631,3
4,airplane,n02691156_10151,n02691156_10151-2,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,ambiguous,test,0,0,110.099631,212.926199,4
...,...,...,...,...,...,...,...,...,...,...,...,...
49995,zebra,n02391049_9960,n02391049_9960-5,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,4,9,61.701389,67.541667,3
49996,zebra,n02391049_9960,n02391049_9960-5,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,4,9,114.846154,201.070769,4
49997,zebra,n02391049_9960,n02391049_9960-5,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,4,9,185.506757,206.581081,5
49998,zebra,n02391049_9960,n02391049_9960-5,sketchy/rendered_256x256/256x256/photo/tx_0001...,sketchy/rendered_256x256/256x256/sketch/tx_000...,valid,test,4,9,200.708571,53.508571,6
