## Evaluation code
Note that users should adapt the code (the deformed-coordinate extraction part) to their own setup.

We provide an example dataset for running the evaluation.


In [1]:
import os
import torch
import numpy as np

# Select the GPU to run on. A value already exported in the environment is
# respected; otherwise default to GPU 7 (edit this to match your machine).
# This only takes effect if it runs before CUDA is first initialized.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '7')

## 1. Save deformed coordinates (and labels)

In [None]:
eval_type = 'kp'  # 'kp' (keypoint transfer) or 'partlabel' (part label transfer)
category = 'chair'
spath = 'your path to data split (train split for evaluation)'
# Normalization radius: bbox diagonal of each 3D shape. This depends on how
# you preprocess the dataset; we use 1.03 for our model and 1.0 for baselines.
radius = 1.03
dpath = 'your path to dataset'

# One shape name per line. Strip whitespace (including '\r' from Windows line
# endings) and skip blank lines, e.g. the one produced by a trailing newline.
with open(spath, 'r') as file:
    all_names = [line.strip() for line in file if line.strip()]

In [None]:
"""
load evaluation dataset (ex. part label, keypoint, ...)
evaluate with train_split (unsupervised, surrogate tasks for shape correspondence)
"""

from eval_dataset import PartDataset, KeyPointDataset

# Normalize point clouds with the `radius` configured above instead of a
# hard-coded constant, so changing the config actually takes effect.
if eval_type == 'kp':
    dataset = KeyPointDataset(category=category, split_path=spath, data_root=dpath, normalize_pc_max=radius)
elif eval_type == 'partlabel':
    dataset = PartDataset(category=category, split_path=spath, data_root=dpath, normalize_pc_max=radius)
else:
    # Fail here rather than with a NameError on `dataset` in a later cell.
    raise ValueError("unknown eval_type {!r}; expected 'kp' or 'partlabel'".format(eval_type))

In [None]:
"""
extract deformed coords with dataset (example: kp dataset)
note that for part label task, we use 2 part labels for table category.
i.e., label[np.where(label == 49)[0]] = 48

The names `split_models`, `model`, `E`, `f` and `save_folder` must be provided
by the user (they are not defined in this notebook). `split_models` is
presumably the list of shape names in the order of the latent-code table;
verify against your training setup.
"""

for data in dataset:
    coords = torch.from_numpy(data[0]).unsqueeze(0).cuda()
    label = data[1]
    name = data[2]

    ### get latent & embedding ###
    instance_idx = split_models.index(name)
    # Shape-(1,) long index tensor on the same device as `coords`.
    sidx = torch.tensor([instance_idx], dtype=torch.long, device=coords.device)
    embed = model.shape_latent_code(sidx)

    with torch.no_grad():
        latent_z = E(coords)
    deformed_pts = model.get_template_coords(coords, latent_z, embed, f)

    # Original and deformed coordinates side by side along dim 1
    # (assumes coords / deformed_pts are (1, N, 3) -> saved as (N, 6); confirm).
    save_file = torch.cat([coords.squeeze(0).detach().cpu(), deformed_pts.squeeze(0).detach().cpu()], dim=1)
    folder_dir = os.path.join(save_folder, name)
    os.makedirs(folder_dir, exist_ok=True)

    torch.save(save_file, os.path.join(folder_dir, 'coords.pt'))  ## coords
    torch.save(torch.from_numpy(label), os.path.join(folder_dir, 'kp_label.pt'))
    torch.save(embed.detach().cpu(), os.path.join(folder_dir, 'latent.pt'))
    

## 2. Evaluation (1): keypoint transfer

In [3]:
# example - chair
# Source (reference) shape ids; every other shape found under `path_kp`
# is used as a transfer target.
src_id = [
    '26a5761e22185ab367d783b4714d4324',
    '10d174a00639990492d9da2668ec34c',
    '104256e5bb73b0b719fb4103277a6b93',
]
path_kp = './examples/chair_kp'
idlist = sorted(os.listdir(path_kp))
# Directory entries are unique, so a sorted set difference equals the
# order-preserving filter of the already-sorted listing.
tgt_id = sorted(set(idlist).difference(src_id))

In [4]:
from utils_eval import get_center_keypoints, get_deformed_points

# Use the GPU when available; fall back to CPU so the example still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

kp_T = get_center_keypoints(path_kp, src_id)

tgt_o, tgt_d, tgt_gt_kp, tgt_name = get_deformed_points(path_kp, tgt_id)

tgt_o = np.stack(tgt_o)
tgt_d = np.stack(tgt_d)
tgt_gt_kp = np.stack(tgt_gt_kp)

# Euclidean distance from every template keypoint to every deformed target
# point: (B, 1, N, 3) - (1, K, 1, 3) -> (B, K, N).
# (Assumes tgt_d is (B, N, 3) and kp_T is (K, 3); confirm against utils_eval.)
diff = torch.from_numpy(tgt_d[:, None, :, :]).to(device) - torch.from_numpy(kp_T[None, :, None, :]).to(device)
distances = torch.sqrt(torch.sum(diff ** 2, dim=-1)).cpu().numpy()
# For each template keypoint, the index of the nearest deformed target point.
pred_kp = np.argmin(distances, axis=-1)

# Look the indices up in the ORIGINAL (undeformed) target coordinates.
# A (B, 1) batch index broadcasts against the (B, K) point indices.
batch_idx = np.arange(tgt_o.shape[0])[:, None]
pred_kp_pts = tgt_o[batch_idx, pred_kp]
gt_kp_pts = tgt_o[batch_idx, tgt_gt_kp]

print(pred_kp_pts.shape)
print(gt_kp_pts.shape)

(97, 21, 3)
(97, 21, 3)


In [5]:
# PCK (percentage of correct keypoints) at increasing distance thresholds.
# Thresholds are fractions 0.00 .. 0.50 of a normalization scale; the actual
# distance cutoff is fraction * pck_scale. 1.03 is the radius used for our
# model in section 1 (1.0 for baselines).
pck_scale = 1.03
n_data = pred_kp_pts.shape[0]
threshold_list = [0.01 * i * pck_scale for i in range(51)]
pcks = []
for b in range(n_data):
    # Keypoints with a negative ground-truth index are invalid and skipped.
    valid = tgt_gt_kp[b] >= 0
    n_valid = int(np.count_nonzero(valid))
    if n_valid == 0:
        # Nothing to score on this shape; dividing by zero would make the
        # averaged PCK NaN for every threshold.
        continue
    pred = pred_kp_pts[b][valid]
    gt = gt_kp_pts[b][valid]
    dists = np.sqrt(np.sum((gt - pred) ** 2, axis=-1))
    pcks.append(np.array([np.sum(dists < th) / n_valid for th in threshold_list]))

if not pcks:
    raise ValueError('no target shape has any valid ground-truth keypoint')

pcks = np.stack(pcks)
res = np.mean(pcks, 0)

for i, score in enumerate(res):
    # round() only removes float noise such as 0.5700000000000001 from the label.
    print("threshold {} : {}".format(round(0.01 * i, 2), score))

threshold 0.0 : 0.0
threshold 0.01 : 0.21537761813686615
threshold 0.02 : 0.2702900192590914
threshold 0.03 : 0.3478045302181141
threshold 0.04 : 0.4504962492530716
threshold 0.05 : 0.5266289703827606
threshold 0.06 : 0.5931201505185738
threshold 0.07 : 0.6493973481541706
threshold 0.08 : 0.6962829065012204
threshold 0.09 : 0.7372290232508544
threshold 0.1 : 0.767757237708723
threshold 0.11 : 0.787306215110945
threshold 0.12 : 0.8022156840531611
threshold 0.13 : 0.8203998649420116
threshold 0.14 : 0.8339665287209251
threshold 0.15 : 0.8480489896196381
threshold 0.16 : 0.854399158553191
threshold 0.17 : 0.8605775061809019
threshold 0.18 : 0.8669450016326907
threshold 0.19 : 0.8728360178329855
threshold 0.2 : 0.8847826521568181
threshold 0.21 : 0.8942082780772895
threshold 0.22 : 0.904959382642827
threshold 0.23 : 0.9107031234381142
threshold 0.24 : 0.9139431823482763
threshold 0.25 : 0.9201287493585856
threshold 0.26 : 0.9252833885338435
threshold 0.27 : 0.9301434768990864
threshold 0.2

## 3. Evaluation (2): part label transfer

In [7]:
# chair
# Source (reference) shape ids; every other shape found under `path_ptl`
# is used as a transfer target.
src_id = [
    '11d9817e65d7ead6b87028a4b477349f',
    '10a1783f635b3fc181dff5c2e57ad46e',
    '10dc303144fe5d668d1b9a1d97e2846',
]
path_ptl = './examples/chair_ptl'
idlist = sorted(os.listdir(path_ptl))
# Directory entries are unique, so a sorted set difference equals the
# order-preserving filter of the already-sorted listing.
tgt_id = sorted(set(idlist).difference(src_id))

In [8]:
import statistics as stat
from utils_eval import part_label_transfer

# Transfer part labels from the source shapes to every target shape.
res = part_label_transfer(
    path_ptl,
    src_id,
    tgt_id,
    'chair',
    bsize=5,
)
# NOTE(review): res[-2] is presumably a per-shape score list (e.g. IoU);
# confirm against utils_eval.part_label_transfer.
per_shape_scores = res[-2]
print(stat.mean(per_shape_scores))

0.8297882430332223
