# Diffumatch: demo

We show a simple pairwise-matching example with DiffuMatch, observe the effects of the mask, and inspect the results after zero-shot optimization.

First, we load the necessary libraries.

In [None]:
import numpy as np
import os
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import potpourri3d as pp3d
import torch
from pathlib import Path
import scipy
import random
import importlib

In [None]:
from utils.geometry import compute_operators, load_operators
from utils.mesh import load_mesh 
from utils.surfaces import opt_rot_surf, Surface, centroid
from utils.utils_func import convert_dict
from utils.torch_fmap import torch_zoomout, knnsearch, extract_p2p_torch_fmap
import meshplot as mp
from utils.eval import accuracy

In [None]:
import notebook_helpers as helper
import utils.meshplot as plot_helper

# Use the first GPU when available; fall back to CPU so the demo still runs
# (slowly) on machines without CUDA instead of crashing on the first .to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
helper.device = device  # presumably read by notebook_helpers when placing tensors — TODO confirm

# Seed every RNG the notebook imports so the SNK optimization below is reproducible
# across Restart & Run All (GPU kernels may still introduce small nondeterminism).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Directory for the per-mesh cache files passed to helper.load_data below.
cache_dir = "cache/fmaps_demo"
os.makedirs(cache_dir, exist_ok=True)

In [None]:
# Pick up edits to the helper modules without restarting the kernel.
importlib.reload(helper)
importlib.reload(plot_helper)
# reload() re-executes the module body, which may reset the `device` attribute
# assigned in the configuration cell; re-apply it so the helpers keep using `device`.
helper.device = device

Loading data.
The first cell below is intentionally commented out, because this demo shows results on a new category. You can try examples from the paper's datasets in that cell, or add new examples by modifying `notebook_helpers.py`.

In [None]:
# Alternative inputs from the paper's datasets. To use them, uncomment one
# `demo_dataset` line plus the if/else block, and skip the cactus-loading cell
# below, which would otherwise overwrite shape_surf / target_surf / *_dict.
# NOTE(review): the DT4D branch reads from "cache/fmaps/DT4D_ori", while the SCAPE
# branch uses a path under cache_dir — confirm this difference is intended.
# demo_dataset = "SCAPE_r_ori"
# demo_dataset = "DT4D_r_ori"
# if demo_dataset == "SCAPE_r_ori":
#     cache_data = os.path.join(cache_dir, "SCAPE_ori")
#     id_1, id_2 = 52, 53
#     shape_surf, target_surf, shape_dict, target_dict, map_info = helper.load_pair(cache_data, id_1, id_2, "", "", demo_dataset)
# else:
#     cache_data = os.path.join("cache/fmaps/DT4D_ori")
#     id_1, id_2 = 0, 0
#     name_1, name_2 = "mannequin/Running047", "crypto/Standing2HMagicAttack01034"
#     shape_surf, target_surf, shape_dict, target_dict, map_info = helper.load_pair(cache_data, id_1, id_2, name_1, name_2, demo_dataset)

In [None]:
def load_example(path, tag):
    """Load the mesh at `path` through helper.load_data and wrap it in a Surface.

    The cache file sits in `cache_dir` and is named after the mesh's file stem,
    e.g. "example/cactus.off" -> "<cache_dir>/cactus.npz". `tag` is forwarded to
    helper.load_data unchanged.

    Returns:
        (data_dict, surface): the first of the two values returned by
        helper.load_data, and a Surface built from the same file.
    """
    cache_file = os.path.join(cache_dir, Path(path).stem + ".npz")
    data_dict, _ = helper.load_data(path, cache_file, tag)
    return data_dict, Surface(filename=path)


file_source = "example/cactus.off"
shape_dict, shape_surf = load_example(file_source, "source")

file_target = "example/cactus_deformed.off"
target_dict, target_surf = load_example(file_target, "target")

The target mesh (uncolored) is shown on the left and the source mesh (colored) on the right. We also show the two meshes overlaid.

In [None]:
# Per-vertex colors for the source mesh, divided by 255 to land in [0, 1]
# (visu_pts presumably returns 8-bit color values).
cmap1 = plot_helper.visu_pts(shape_surf) / 255.0
# Left: the target mesh with no colors; right: the source mesh with its colormap.
plot_helper.double_plot_surf(target_surf, shape_surf, None, cmap1)

In [None]:
# Overlay the source and target meshes in a single viewer.
src_verts, src_faces = shape_dict["vertices"], shape_surf.faces
tgt_verts, tgt_faces = target_dict["vertices"], target_surf.faces
scene, _ = plot_helper.overlay_surf(src_verts, src_faces, tgt_verts, tgt_faces)
scene

Create the `zero_shot.Matcher` helper, which matches the two shapes with SNK.

In [None]:
import zero_shot

# SNK zero-shot matcher, configured from its YAML file.
SNK_CONFIG_PATH = "config/matching/snk.yaml"
cfg = OmegaConf.load(SNK_CONFIG_PATH)
matcher = zero_shot.Matcher(cfg)

Launching the SNK optimization loop.

In [None]:
# Unit-length normals for the target mesh, built from target_surf.surfel
# (presumably per-face area-weighted normals — TODO confirm). Norms are clamped
# away from zero so degenerate faces yield zero vectors instead of NaNs that
# would propagate into the optimization.
surfel = target_surf.surfel
surfel_norm = np.linalg.norm(surfel, axis=-1, keepdims=True)
target_normals = torch.from_numpy(surfel / np.maximum(surfel_norm, 1e-12)).float().to(device)

# Run the SNK optimization. Judging by the names, the outputs are the optimized
# functional map, final and initial point-to-point maps, SNK's reconstruction,
# and the recorded losses.
C12_new, p2p, p2p_init, snk_rec, loss_save = matcher.optimize(shape_dict, target_dict, target_normals)

In [None]:
# Turn the optimized functional map into a point-to-point map.
p2p_new, _ = extract_p2p_torch_fmap(
    C12_new,
    shape_dict["evecs"],
    target_dict["evecs"],
)
# Pull the source colors through the map; cmap2 colors the target mesh below
# (so p2p_new presumably gives a source vertex index per target vertex).
cmap2 = cmap1[p2p_new]
plot_helper.double_plot_surf(target_surf, shape_surf, cmap2, cmap1)

We can refine the computed map with ZoomOut.

In [None]:
# ZoomOut refinement of the SNK map. It starts from the leading
# ZO_INIT_DIM x ZO_INIT_DIM block of C12_new; ZO_FINAL_DIM is passed as
# torch_zoomout's last argument (presumably the final spectral size).
ZO_INIT_DIM = 15
ZO_FINAL_DIM = 150  # NOTE(review): was hinted to come from matcher.cfg.sds_conf.zoomout

evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
# Mass-weighted transpose evecs2^T @ diag(mass). Broadcasting the 1-D vertex masses
# over the columns gives the same result without materializing a dense N x N matrix.
evecs_2trans = evecs2.t() * target_dict["mass"]

C12_end_zo = torch_zoomout(
    evecs1, evecs2, evecs_2trans,
    C12_new.squeeze()[:ZO_INIT_DIM, :ZO_INIT_DIM],
    ZO_FINAL_DIM,
)
p2p_zo, _ = extract_p2p_torch_fmap(C12_end_zo, evecs1, evecs2)
cmap2 = cmap1[p2p_zo]
plot_helper.double_plot_surf(target_surf, shape_surf, cmap2, cmap1)

SNK also includes a reconstruction module (trained with a decoder loss). Let's look at SNK's reconstruction, overlaid on the source mesh!

In [None]:
def to_numpy(tensor):
    """Squeeze singleton dims, detach from autograd, and return a host NumPy array."""
    return tensor.squeeze().detach().cpu().numpy()


# SNK's reconstruction is paired with the target faces, so it presumably has one
# vertex per target vertex — TODO confirm.
rec_surf = Surface(FV=[target_surf.faces, to_numpy(snk_rec)])
# Show the reconstruction overlaid on the source mesh.
scene, _ = plot_helper.overlay_surf(to_numpy(snk_rec), target_surf.faces, shape_dict["vertices"], shape_surf.faces)
scene