In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
from src.infra import config

path = '../../config/vec.yaml'
opt = config.load_config(path)
opt.path = path

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Dataset & dataloader

In [5]:
from pathlib import Path

data_root = Path('/home/knpob/Documents/Hinton/data/shape-corr/FAUST_r')

In [38]:
from src.dataset.shape_cor_fast import PairFaustDatasetFast

dataset = PairFaustDatasetFast(
    data_root=data_root,
    phase='train',
    return_faces=True,
    return_L=True,
    return_mass=True,
    num_evecs=200,
    return_evecs=True,
    return_grad=True,
    return_corr=True,
    return_dist=True,
)
len(dataset)

6400

In [39]:
from src.dataloader.shape_cor_batch import BatchShapePairDataLoader

dataloader = BatchShapePairDataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)
len(dataloader)

800

In [40]:
from src.utils.tensor import to_device

for idx, data in enumerate(dataloader):
    data = to_device(data, device)
    if idx == 1:
        break

data['first']['name'], data['second']['name'], 

(['tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000'],
 ['tr_reg_008',
  'tr_reg_009',
  'tr_reg_010',
  'tr_reg_011',
  'tr_reg_012',
  'tr_reg_013',
  'tr_reg_014',
  'tr_reg_015'])

## `corr2pointmap`

In [44]:
from src.utils.fmap import corr2pointmap_vectorized

p2p = corr2pointmap_vectorized(
    data['first']['corr'],
    data['second']['corr'],
    max(data['second']['num_verts']),
)
p2p.shape, p2p

(torch.Size([8, 5001]),
 tensor([[   0,   28,    2,  ...,   -1,   -1,   -1],
         [   0, 2383,   96,  ...,   -1,   -1,   -1],
         [   0,   28,    2,  ..., 1949,   -1,   -1],
         ...,
         [  71,   16,    2,  ...,   -1,   -1,   -1],
         [   0,    1,  118,  ..., 1949,   -1,   -1],
         [  70,   28,  118,  ...,   -1,   -1,   -1]], device='cuda:0'))

Surprisingly, the ground-truth `p2p` has a ~4.5 geodesic error on the remeshed FAUST dataset!

In [45]:
from src.metric.geodist import GeodesicDist_vectorized

geodist = GeodesicDist_vectorized()
err = geodist.geodesic_error(
    dist_x=data['first']['dist'],
    corr_x=data['first']['corr'],
    corr_y=data['second']['corr'],
    p2p=p2p,
)
err.shape, err, err.mean(axis=1)

(torch.Size([8, 5000]),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0169, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0114],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0103]],
        device='cuda:0'),
 tensor([0.0046, 0.0045, 0.0045, 0.0045, 0.0046, 0.0046, 0.0045, 0.0045],
        device='cuda:0'))

In [52]:
from src.metric.geodist import GeodesicDist

geodist = GeodesicDist()
err = geodist.calculate_geodesic_error(
    dist_x=data['first']['dist'][0].cpu().numpy(),
    corr_x=data['first']['corr'][0].cpu().numpy(),
    corr_y=data['second']['corr'][0].cpu().numpy(),
    p2p=p2p[0].cpu().numpy(),
)
err.shape, err, err.mean()

((), np.float32(0.004646266), np.float32(0.004646266))

## LBO

In [None]:
import pyvista as pv
pv.set_jupyter_backend('html')

In [None]:
lbo_x, lbo_y = data['first']['evecs'], data['second']['evecs']
lbo_x.shape, lbo_y.shape

(torch.Size([8, 4999, 200]), torch.Size([8, 5001, 200]))

In [None]:
import numpy as np

batch_interval = 10
k = 10
shape_disp = np.array([1, 0, 0])
pair_disp = np.array([0, 2, 0])
cammera_position = 'xy'
window_size = [1024, 512]
output_folder = Path('output/lbo-samples')
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
for idx in range(len(lbo_x)):
    pl = pv.Plotter()

    name_x, name_y = data['first']['name'][idx], data['second']['name'][idx]
    num_verts_x = data['first']['num_verts']
    num_verts_y = data['second']['num_verts']
    mesh_x = pv.read(data_root / 'off' / f"{name_x}.off")
    mesh_y = pv.read(data_root / 'off' / f"{name_y}.off")
    lbo_min = min(lbo_x[idx, :, :k].min(), lbo_y[idx, :, :k].min()).cpu().numpy()
    lbo_max = max(lbo_x[idx, :, :k].max(), lbo_y[idx, :, :k].max()).cpu().numpy()

    for dim in range(k):
        mesh_x[f'lbo-{dim}'] = (lbo_x[idx, :num_verts_x[idx], dim]).cpu().numpy()
        mesh_y[f'lbo-{dim}'] = (lbo_y[idx, :num_verts_y[idx], dim]).cpu().numpy()

        pl.add_mesh(
            mesh=mesh_x.translate(dim * shape_disp),
            scalars=f'lbo-{dim}',
            cmap='coolwarm',
            clim=[lbo_min, lbo_max],
            show_scalar_bar=False,
        )
        pl.add_mesh(
            mesh=mesh_y.translate(dim * shape_disp).translate(pair_disp),
            scalars=f'lbo-{dim}',
            cmap='coolwarm',
            clim=[lbo_min, lbo_max],
            show_scalar_bar=False,
        )

    pl.camera_position = 'xy'
    pl.zoom_camera(2)
    pl.screenshot(output_folder / f'lbo-{idx}.png', window_size=window_size, return_img=False)

[0m[33m2025-07-22 14:04:56.965 (   2.019s) [    7DC8B828E440]vtkXOpenGLRenderWindow.:1416  WARN| bad X server connection. DISPLAY=[0m
