In [None]:
import sys
sys.path.append('../utils')
import fvdb
import fvdb_utils
import mesh_tools as mt
import igl
import numpy as np
from ssfid import calculate_activation_statistics, calculate_frechet_distance
from patch_utils import pairwise_IoU_dist
from classifier3D import classifier
import torch
import os
import glob
from tqdm import tqdm
import trimesh 

# Torch device used for the classifier weights and for the voxel grids built below.
device = 'cuda'

In [None]:
# Grid size handed to the classifier; get_dense_array_tight also uses it as the
# length of the longest axis of the pooled occupancy grid.
voxel_size = 256
model = classifier(voxel_size=voxel_size)

# The checkpoint is expected in the notebook's working directory.
weights_path = 'Clsshapenet_'+str(voxel_size)+'.pth'
if not os.path.exists(weights_path):
    raise RuntimeError(
        f"'{weights_path}' not exists. Please download it from https://drive.google.com/file/d/1HjnDudrXsNY4CYhIGhH4Q0r3-NBnBaiC/view?usp=sharing.")
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from (e.g. a 'cuda:1' that does not exist here); the weights are moved
# to `device` right after by model.to(device).
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load a checkpoint you trust; switch to
# weights_only=True if the file turns out to be a plain tensor state dict.
model.load_state_dict(torch.load(weights_path, map_location='cpu', weights_only=False))
model.to(device)
model.eval()

# Raw directory entries of the ground-truth folder (file extensions included).
# `names` is reassigned with extension-less names before the evaluation loop.
names = os.listdir("../../data/GT_WT")
names.sort()

In [None]:
def sample_grid_points_aabb(aabb, resolution):
    """Sample the cell centres of a regular grid filling an axis-aligned box.

    Args:
        aabb: NumPy array of 6 values, ``[min_x, min_y, min_z, max_x, max_y, max_z]``.
        resolution: Number of cells along the longest side of the box. The
            other sides get a proportional number of cells, truncated to an
            integer.

    Returns:
        Array of shape ``(nx, ny, nz, 3)`` with the xyz position of every cell
        centre, laid out in ``'ij'`` (x-major) order.
    """
    lower, upper = aabb[:3], aabb[3:]
    extent = upper - lower
    counts = (resolution * extent / extent.max()).astype(np.int32)

    # Per axis: centres of `n` equal cells, i.e. (k + 0.5) / n of the side
    # length, offset from the lower corner.
    axes = [
        np.linspace(0.5, n - 0.5, n) / n * side + start
        for n, side, start in zip(counts, extent, lower)
    ]
    return np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)

def normalize_aabb(v, reso, enlarge_scale=1.03, mult=8):
    """Derive a normalizing transform and a grid-aligned box from a point set.

    Args:
        v: ``(N, 3)`` array of points (reduced over axis 0).
        reso: Number of cells wanted along the longest side of the box.
        enlarge_scale: Padding factor applied to the longest side.
        mult: Per-axis cell counts are rounded up to a multiple of this value.

    Returns:
        Tuple ``(aabb, translation, scale)``:
          * ``aabb``: 6-vector ``[-h, h]`` with ``h = cells / cells.max()``,
            a box centred at the origin whose longest side spans ``[-1, 1]``.
          * ``translation``: negated centre of the input bounding box.
          * ``scale``: ``2 / (longest_side * enlarge_scale)``.
        Apply them to points as ``(p + translation) * scale``.
    """
    lo = np.min(v, axis=0)
    hi = np.max(v, axis=0)
    center = (hi + lo) / 2
    bbox_size = (hi - lo).max() * enlarge_scale

    # Only the difference of these two corners is used below, so the centre
    # term cancels and `extent` is the padded side lengths in normalized units.
    lo_n = (lo * enlarge_scale - center) / bbox_size * 2
    hi_n = (hi * enlarge_scale - center) / bbox_size * 2
    extent = hi_n - lo_n

    # Cells per axis (truncated), then rounded up to a multiple of `mult`.
    cells = (reso * extent / extent.max()).astype(np.int32)
    cells = (cells + mult - 1) // mult * mult

    half = cells / cells.max()
    return np.concatenate([-half, half], axis=0), -center, 1.0 / bbox_size * 2

In [None]:
def get_dense_array_tight(shape_path, pts=None, translation=None, scale=None):
    """Voxelize a mesh into a dense occupancy grid on ``device``.

    When ``pts`` is omitted the mesh defines the sampling frame: its vertices
    go through ``2 * mt.NDCnormalize``, ``normalize_aabb`` supplies the
    translation/scale and the box, and ``sample_grid_points_aabb`` builds the
    query points (both at a fixed resolution of 256). When ``pts`` is given,
    the supplied ``translation`` and ``scale`` are applied to the raw vertices
    and the supplied points are queried instead.

    Args:
        shape_path: Path of the mesh file, loaded with ``trimesh.load``.
        pts: Optional ``(nx, ny, nz, 3)`` array of query points.
        translation: Offset added to the vertices (required with ``pts``).
        scale: Factor applied after the translation (required with ``pts``).

    Returns:
        A float tensor holding 1.0 in occupied voxels and 0.0 elsewhere. When
        ``pts`` was omitted, the tuple ``(grid, pts, translation, scale)`` is
        returned instead so the same frame can be reused for other meshes.
    """
    mesh = trimesh.load(shape_path)
    verts, faces = mesh.vertices, mesh.faces

    defines_frame = pts is None
    if defines_frame:
        verts = 2 * mt.NDCnormalize(verts)
        aabb, translation, scale = normalize_aabb(verts, 256)
        pts = sample_grid_points_aabb(aabb, 256)
    verts = (verts + translation) * scale

    # A query point counts as inside when its winding number exceeds 0.5.
    inside = igl.fast_winding_number_for_meshes(verts, faces, pts.reshape(-1, 3)) > .5
    occupancy = inside.reshape(pts.shape[:-1])

    # Target shape: longest axis becomes `voxel_size`, the others scale with it.
    longest = max(occupancy.shape)
    pooled_shape = [int(n * voxel_size / longest) for n in occupancy.shape]

    # Max-pooling keeps a coarse voxel occupied if any fine sample inside it is.
    occupancy = torch.tensor(occupancy[None, None], device=device)
    pooled = torch.nn.functional.adaptive_max_pool3d(1. * occupancy, pooled_shape)[0, 0] > 0
    grid = 1. * pooled

    if defines_frame:
        return grid, pts, translation, scale
    return grid


In [None]:

def get_metrics(globe, mesh_name):
    """Score the generated meshes of one shape against its ground-truth mesh.

    The ground-truth mesh ``../../data/GT_WT/<mesh_name>.obj`` is voxelized
    first and fixes the sampling frame; every generated mesh is voxelized in
    that same frame and compared to it.

    Args:
        globe: Glob pattern with one ``{}`` placeholder, filled with
            ``mesh_name`` to locate the generated meshes.
        mesh_name: Stem of the ground-truth ``.obj`` file.

    Returns:
        Tuple ``(quality, diversity)``: the mean of
        ``calculate_frechet_distance`` over all generated meshes, and
        ``pairwise_IoU_dist`` of the stacked generated grids, both rounded to
        4 decimals.

    Raises:
        RuntimeError: If the pattern matches no file.
    """
    # NOTE: glob.glob is called without recursive=True, so a '**' component in
    # the pattern behaves like '*' and matches exactly one directory level.
    pattern = globe.format(mesh_name)
    shapes = sorted(glob.glob(pattern))
    if not shapes:
        # Fail before the expensive ground-truth voxelization instead of
        # hitting an opaque torch.stack error on an empty list afterwards.
        raise RuntimeError(f"No generated meshes match '{pattern}'.")

    or_sdfgrid, pts, translation, scale = get_dense_array_tight("../../data/GT_WT/{}.obj".format(mesh_name))
    mu_r, sigma_r = calculate_activation_statistics(or_sdfgrid, model)

    ssfid_values = []
    grids = []
    for shape in shapes:
        # Reuse the ground-truth frame so all grids are directly comparable.
        dense_grid = get_dense_array_tight(shape, pts, translation, scale)
        mu_f, sigma_f = calculate_activation_statistics(dense_grid, model)
        ssfid = calculate_frechet_distance(mu_r, sigma_r, mu_f, sigma_f)
        print(ssfid)
        ssfid_values.append(ssfid)
        grids.append(dense_grid)
    return np.mean(ssfid_values).round(4), pairwise_IoU_dist(torch.stack(grids, dim=0)).round(4)



### Ours

In [None]:
# Shape names are the stems of the ground-truth .obj files. Filtering on the
# extension (instead of blindly cutting the last 4 characters of every entry)
# keeps stray entries such as .mtl files, hidden files or sub-directories from
# turning into bogus or duplicated names; get_metrics loads '<name>.obj'.
gt_dir = '../../data/GT_WT'
names = sorted(
    os.path.splitext(entry)[0]
    for entry in os.listdir(gt_dir)
    if os.path.splitext(entry)[1].lower() == '.obj'
)
# '{}' is replaced by the shape name inside get_metrics.
str_path = '../../meshed_output/{}/**/large_mesh.ply'

# Per-shape scores, in the same order as `names`.
ours_ssfid = []
ours_iou = []
for name in tqdm(names):
    ssfid, io = get_metrics(str_path, name)
    ours_ssfid.append(ssfid)
    ours_iou.append(io)
  

In [None]:
# Per-shape scores followed by their mean as the last entry.
# (The table-printing cell recomputes the same two lists.)
ssfid_data = [*ours_ssfid, np.asarray(ours_ssfid).mean()]
iou_data = [*ours_iou, np.asarray(ours_iou).mean()]

In [None]:
# Per-shape scores with the across-shape mean appended as the last column.
ssfid_data = [*ours_ssfid, np.asarray(ours_ssfid).mean()]
iou_data = [*ours_iou, np.asarray(ours_iou).mean()]

# Pre-render the variable parts of the table: one centred column per shape,
# the header cells, and the two rows of values with 2 decimals.
column_spec = ' c ' * len(names)
header_cells = ' & '.join(names)
ssfid_cells = '   &    '.join('{:.2f}'.format(value) for value in ssfid_data)
iou_cells = '   &    '.join('{:.2f}'.format(value) for value in iou_data)

# Emit the LaTeX tabular: metric label, method label, one column per shape, mean.
table_lines = [
    r"\begin{tabular}{ l |c| " + column_spec + r" | c }",
    r'Metric & Method               & ' + header_cells + r' & mean \\',
    r'\hline',
    r'G-Qual. $\downarrow$ & ours & ' + ssfid_cells + r'\\',
    r'\hline',
    r'G-Div. $\uparrow$ & ours    & ' + iou_cells + r'\\',
    r"\end{tabular}",
]
for table_line in table_lines:
    print(table_line)
