In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
from src.infra import config

path = '../../config/vec.yaml'
opt = config.load_config(path)
opt.path = path

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Dataset & dataloader

In [5]:
from pathlib import Path

data_root = Path('/home/knpob/Documents/Hinton/data/shape-corr/FAUST_r')

In [6]:
from src.dataset.shape_cor_fast import PairFaustDatasetFast

dataset = PairFaustDatasetFast(
    data_root=data_root,
    phase='train',
    return_faces=True,
    return_L=True,
    return_mass=True,
    num_evecs=200,
    return_evecs=True,
    return_grad=True,
    return_corr=True,
    return_dist=True,
)
len(dataset)

6400

In [7]:
from src.dataloader.shape_cor_batch import BatchShapePairDataLoader

dataloader = BatchShapePairDataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)
len(dataloader)

800

In [8]:
from src.utils.tensor import to_device

for idx, data in enumerate(dataloader):
    data = to_device(data, device)
    if idx == 1:
        break

data['first']['name'], data['second']['name'], 

(['tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000'],
 ['tr_reg_008',
  'tr_reg_009',
  'tr_reg_010',
  'tr_reg_011',
  'tr_reg_012',
  'tr_reg_013',
  'tr_reg_014',
  'tr_reg_015'])

## Create point-to-point maps

In [10]:
from src.utils.fmap import corr2pointmap_vectorized

p2p = corr2pointmap_vectorized(
    data['first']['corr'],
    data['second']['corr'],
    max(data['second']['num_verts']),
)
p2p.shape, p2p

(torch.Size([8, 5001]),
 tensor([[   0,   28,    2,  ...,   -1,   -1,   -1],
         [   0, 2383,   96,  ...,   -1,   -1,   -1],
         [   0,   28,    2,  ..., 1949,   -1,   -1],
         ...,
         [  71,   16,    2,  ...,   -1,   -1,   -1],
         [   0,    1,  118,  ..., 1949,   -1,   -1],
         [  70,   28,  118,  ...,   -1,   -1,   -1]], device='cuda:0'))

### Missing points

In [None]:
batch_no_corr, vtx_no_corr = torch.where(p2p == -1)
batch_no_corr, vtx_no_corr

(tensor([0, 0, 0,  ..., 7, 7, 7], device='cuda:0'),
 tensor([   4,   18,   35,  ..., 4998, 4999, 5000], device='cuda:0'))

In [None]:
B, V, _ = data['second']['verts'].shape

corr_miss = torch.ones((V, V))
corr_miss[batch_no_corr, vtx_no_corr] = 0
corr_miss

tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 1., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [None]:
import pyvista as pv
pv.set_jupyter_backend('html')

idx = 0

name_x, name_y = data['first']['name'], data['second']['name']
verts_num_x = data['first']['num_verts']
verts_num_y = data['second']['num_verts']
faces_num_x = data['first']['num_faces']
faces_num_y = data['second']['num_faces']

output_path = Path('output/gt-texture-transfer/missing')
output_path.mkdir(parents=True, exist_ok=True)

for idx in range(B):
    mesh_y = pv.read(data_root / 'off' / f'{name_y[idx]}.off')
    mesh_y['missing'] = corr_miss[idx, :verts_num_y[idx]].cpu().numpy()
    
    pl = pv.Plotter(off_screen=True)
    pl.add_mesh(
        mesh=mesh_y,
        scalars='missing',
    )
    pl.camera_position = 'xy'
    pl.screenshot(output_path / f'{idx}', window_size=[1024, 1024], return_img=False)
    pl.close()

### Single incomplete map filling

Treat the incomplete ground-truth map as a set of incomplete correspondences labeling. Then construct corresponding point-indicator functions for functional maps solver

In [31]:
idx = 1
vtx_y_ls = torch.where(p2p[idx] != -1)
vtx_x_ls = p2p[idx][vtx_y_ls]
vtx_y_ls, vtx_x_ls

((tensor([   0,    1,    2,  ..., 4994, 4995, 4997], device='cuda:0'),),
 tensor([   0, 2383,   96,  ..., 4997, 2760, 1949], device='cuda:0'))

In [32]:
B, V_x, _ = data['first']['verts'].shape
_, V_y, _ = data['second']['verts'].shape

delta_x = torch.eye(V_x, V_x).to(device=p2p.device)
delta_y = torch.eye(V_y, V_y).to(device=p2p.device)

feat_x = delta_x[vtx_x_ls].T # [V, C]
feat_y = delta_y[vtx_y_ls].T # [V, C]
feat_x.shape, feat_y.shape

(torch.Size([4999, 3503]), torch.Size([5001, 3503]))

In [66]:
from src.module.fmap import RegularizedFMNet_vectorized

fmap_solver = RegularizedFMNet_vectorized()
Cxy = fmap_solver(
    feat_x=feat_x.unsqueeze(0),
    feat_y=feat_y.unsqueeze(0),
    evals_x=data['first']['evals'][idx].unsqueeze(0),
    evals_y=data['second']['evals'][idx].unsqueeze(0),
    evecs_trans_x=data['first']['evecs_trans'][idx].unsqueeze(0),
    evecs_trans_y=data['second']['evecs_trans'][idx].unsqueeze(0),
)[0]
Cxy

tensor([[[-8.7579e-01,  3.0815e-05,  4.0791e-06,  ...,  5.5985e-08,
          -4.2601e-08,  1.0078e-07],
         [-6.3886e-06,  4.3833e-01, -1.6123e-04,  ...,  8.7883e-08,
           6.0843e-08, -3.3812e-08],
         [-4.0919e-06,  4.3851e-04,  8.6647e-01,  ...,  1.0924e-08,
          -8.6053e-09,  9.2009e-10],
         ...,
         [ 2.8654e-08, -1.5024e-08,  9.5501e-08,  ...,  1.3947e-01,
          -1.3023e-01,  6.2209e-03],
         [-6.8533e-08, -8.2695e-08,  2.4661e-08,  ..., -3.5313e-02,
           1.8958e-01,  3.3373e-02],
         [ 5.1464e-08,  7.2539e-08, -4.6112e-09,  ...,  1.2346e-01,
           6.5107e-02,  5.3503e-02]]], device='cuda:0')

In [None]:
from src.utils.fmap import fmap2pointmap_vectorized

evecs_x = data['first']['evecs'][idx].unsqueeze(0)  # [1, V_y, k]
evecs_y = data['second']['evecs'][idx].unsqueeze(0)  # [1, V_y, k]

p2p_fill = fmap2pointmap_vectorized(
    Cxy=Cxy,
    evecs_x=evecs_x,
    evecs_y=evecs_y,
    verts_mask_x=torch.ones((1, V_x), dtype=torch.bool, device=device),
    verts_mask_y=torch.ones((1, V_y), dtype=torch.bool, device=device),
)
p2p_fill, (p2p_fill == -1).sum()

(tensor([[   0,   16,  145,  ..., 1908, 1908, 3540]], device='cuda:0'),
 tensor(0, device='cuda:0'))

In [None]:
p2p_fill, (p2p_fill == -1).sum()

(tensor([[   0,   16,  145,  ..., 1908, 1908, 3540]], device='cuda:0'),
 tensor(0, device='cuda:0'))

In [None]:
from src.utils.fmap import pointmap2Pyx_smooth_vectorized, pointmap2Pyx_vectorized

evecs_trans_x = data['first']['evecs_trans'][idx].unsqueeze(0)  # [1, V_x, k]
evecs_trans_y = data['second']['evecs_trans'][idx].unsqueeze(0)  # [1, V_y, k]

Pyx_fill = pointmap2Pyx_smooth_vectorized(
    p2p=p2p_fill,
    evecs_x=evecs_x,
    evecs_y=evecs_y,
    evecs_trans_x=evecs_trans_x,
    evecs_trans_y=evecs_trans_y,
)

In [None]:
Pyx_fill.sum(axis=-1), Pyx_fill.sum(axis=-2)

(tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000]],
        device='cuda:0'),
 tensor([[1.2589, 1.2118, 1.4185,  ..., 0.7747, 1.4900, 1.8969]],
        device='cuda:0'))

Validate that padding zero-channels to `feat_x` `feat_y` does not change the functional map solver output:

In [None]:
V = 5000
feat_x_ = torch.zeros((V_x, V), device=feat_x.device)
feat_y_ = torch.zeros((V_y, V), device=feat_y.device)
feat_x_[:feat_x.shape[0], :feat_x.shape[1]] = feat_x
feat_y_[:feat_y.shape[0], :feat_y.shape[1]] = feat_y

In [69]:
from src.module.fmap import RegularizedFMNet_vectorized

fmap_solver = RegularizedFMNet_vectorized()
Cxy_ = fmap_solver(
    feat_x=feat_x_.unsqueeze(0),
    feat_y=feat_y_.unsqueeze(0),
    evals_x=data['first']['evals'][idx].unsqueeze(0),
    evals_y=data['second']['evals'][idx].unsqueeze(0),
    evecs_trans_x=data['first']['evecs_trans'][idx].unsqueeze(0),
    evecs_trans_y=data['second']['evecs_trans'][idx].unsqueeze(0),
)[0]
Cxy_

tensor([[[-8.7579e-01,  3.0815e-05,  4.0791e-06,  ...,  5.5985e-08,
          -4.2601e-08,  1.0078e-07],
         [-6.3886e-06,  4.3833e-01, -1.6122e-04,  ...,  8.7883e-08,
           6.0843e-08, -3.3812e-08],
         [-4.0920e-06,  4.3851e-04,  8.6647e-01,  ...,  1.0924e-08,
          -8.6053e-09,  9.2007e-10],
         ...,
         [ 2.8654e-08, -1.5024e-08,  9.5501e-08,  ...,  1.3947e-01,
          -1.3023e-01,  6.2209e-03],
         [-6.8533e-08, -8.2695e-08,  2.4661e-08,  ..., -3.5313e-02,
           1.8958e-01,  3.3373e-02],
         [ 5.1464e-08,  7.2539e-08, -4.6112e-09,  ...,  1.2346e-01,
           6.5107e-02,  5.3503e-02]]], device='cuda:0')

In [70]:
(Cxy - Cxy_).norm()

tensor(5.9954e-07, device='cuda:0')

### Batched incomplete maps filling

In [88]:
B, V_x, _ = data['first']['verts'].shape
_, V_y, _ = data['second']['verts'].shape
_, V_t = data['first']['corr'].shape
B, V_x, V_y, V_t

(8, 4999, 5001, 5000)

In [76]:
delta_x = torch.eye(V_x, V_x).repeat(B, 1, 1).to(device=p2p.device) # [B, V_x, V_x]
delta_y = torch.eye(V_y, V_y).repeat(B, 1, 1).to(device=p2p.device) # [B, V_x, V_x]
delta_x.shape, delta_y.shape

(torch.Size([8, 4999, 4999]), torch.Size([8, 5001, 5001]))

In [95]:
corr_x = data['first']['corr']  # [B, V_t]
corr_y = data['second']['corr']  # [B, V_t]
batch_idx = torch.arange(B, device=corr_y.device).unsqueeze(1).expand(B, V_t)  # [B, V_t]

feat_x = delta_x[batch_idx, corr_x].transpose(1, 2)  # [B, V_t, V_x] -> [B, V_x, V_t]
feat_y = delta_y[batch_idx, corr_y].transpose(1, 2)  # [B, V_t, V_y] -> [B, V_y, V_t]

feat_x.shape, feat_y.shape # p.s. now we get V_t channels of corresponding point indicator functions

(torch.Size([8, 4999, 5000]), torch.Size([8, 5001, 5000]))

In [98]:
from src.module.fmap import RegularizedFMNet_vectorized

fmap_solver = RegularizedFMNet_vectorized()
Cxy, _ = fmap_solver(
    feat_x=feat_x,
    feat_y=feat_y,
    evals_x=data['first']['evals'],
    evals_y=data['second']['evals'],
    evecs_trans_x=data['first']['evecs_trans'],
    evecs_trans_y=data['second']['evecs_trans'],
)
Cxy.shape

torch.Size([8, 200, 200])

In [100]:
from src.utils.fmap import fmap2pointmap_vectorized

p2p_fill = fmap2pointmap_vectorized(
    Cxy=Cxy,
    evecs_x=data['first']['evecs'],
    evecs_y=data['second']['evecs'],
    verts_mask_x=data['first']['verts_mask'],
    verts_mask_y=data['second']['verts_mask'],
)
p2p_fill, (p2p_fill == -1).sum()

(tensor([[ 109,    1,    2,  ..., 1945, 1945,   -1],
         [   0,   16,  145,  ..., 1908, 1908,   -1],
         [  16, 2383,    6,  ..., 2001, 2001,   -1],
         ...,
         [2383, 2383,  107,  ..., 4322, 4322,   -1],
         [  20, 2383,    6,  ..., 2101, 2101, 2101],
         [  16, 2383,  107,  ..., 4629, 4629,   -1]], device='cuda:0'),
 tensor(9, device='cuda:0'))

In [101]:
Pyx_fill = pointmap2Pyx_smooth_vectorized(
    p2p=p2p_fill,
    evecs_x=data['first']['evecs'],
    evecs_y=data['second']['evecs'],
    evecs_trans_x=data['first']['evecs_trans'],
    evecs_trans_y=data['second']['evecs_trans'],
)

Pyx_fill.sum(axis=-1), Pyx_fill.sum(axis=-2)

(tensor([[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.0000]],
        device='cuda:0'),
 tensor([[0.8896, 0.8124, 1.2605,  ..., 0.2357, 0.7723, 1.6971],
         [1.2698, 1.2414, 1.4175,  ..., 0.7900, 1.4760, 1.7978],
         [1.2106, 1.1001, 1.7506,  ..., 1.4366, 1.2905, 1.2460],
         ...,
         [1.7351, 1.8086, 1.3986,  ..., 0.6672, 1.2421, 0.1283],
         [0.9146, 0.7839, 1.5355,  ..., 1.2225, 1.5520, 2.0346],
         [1.4698, 1.4486, 1.7543,  ..., 2.7880, 1.7022, 0.8091]],
        device='cuda:0'))

## Texture transfer

In [105]:
import pyvista as pv
pv.set_jupyter_backend('html')

from src.utils.fmap import pointmap2Pyx_vectorized, pointmap2Pyx_smooth_vectorized
from src.utils.texture import write_obj_pair
from src.utils.tensor import to_numpy

output_path = Path('output/gt-texture-transfer')
output_path.mkdir(parents=True, exist_ok=True)

Pyx = pointmap2Pyx_smooth_vectorized(
    p2p=p2p_fill,
    evecs_x = data['first']['evecs'],
    evecs_y = data['second']['evecs'],
    evecs_trans_x = data['first']['evecs_trans'],
    evecs_trans_y = data['second']['evecs_trans'],
)

# Pyx = pointmap2Pyx_vectorized(
#     p2p=p2p_fill,
#     num_verts_x = data['first']['num_verts'],
#     num_verts_y = data['second']['num_verts'],
# )

name_x, name_y = data['first']['name'], data['second']['name']
verts_num_x = data['first']['num_verts']
verts_num_y = data['second']['num_verts']
faces_num_x = data['first']['num_faces']
faces_num_y = data['second']['num_faces']

for idx in range(len(data['first']['name'])):
    # export mesh w/ texture
    write_obj_pair(
        file_name1=str(output_path / f'{name_x[idx]}.obj'),
        file_name2=str(output_path / f'{name_x[idx]}--{name_y[idx]}.obj'),
        faces1=to_numpy(data['first']['faces'][idx, :faces_num_x[idx]]),
        verts1=to_numpy(data['first']['verts'][idx, :verts_num_x[idx]]),
        verts2=to_numpy(data['second']['verts'][idx, :verts_num_y[idx]]),
        faces2=to_numpy(data['second']['faces'][idx, :faces_num_y[idx]]),
        Pyx=to_numpy(Pyx[idx, :verts_num_y[idx], :verts_num_x[idx]]),
        texture_file=str(output_path / 'texture.png'),
    )

    # render texture transfer
    pl = pv.Plotter(off_screen=True)
    pl.add_mesh(
        mesh=pv.read(output_path / f'{name_x[idx]}.obj'),
        texture=pv.read_texture(output_path / 'texture.png'),
    )
    pl.add_mesh(
        mesh=pv.read(output_path / f'{name_x[idx]}--{name_y[idx]}.obj').translate([1, 0, 0]),
        texture=pv.read_texture(output_path / 'texture.png'),
    )
    pl.camera_position = 'xy'
    pl.screenshot(output_path / f'{idx}', window_size=[1024, 1024], return_img=False)
    pl.close()

    # remove the mesh files if not needed
    (output_path / f'{name_x[idx]}.obj').unlink(missing_ok=True)
    (output_path / f'{name_x[idx]}.mtl').unlink(missing_ok=True)
    (output_path / f'{name_x[idx]}--{name_y[idx]}.obj').unlink(missing_ok=True)
    (output_path / f'{name_x[idx]}--{name_y[idx]}.mtl').unlink(missing_ok=True)

## Benchmarking ground-truth correspondence

Surprisingly, the ground-truth `p2p` has a ~4.5 geodesic error on the remeshed FAUST dataset!

In [10]:
from src.metric.geodist import GeodesicDist_vectorized

geodist = GeodesicDist_vectorized()
err = geodist.geodesic_error(
    dist_x=data['first']['dist'],
    corr_x=data['first']['corr'],
    corr_y=data['second']['corr'],
    p2p=p2p,
)
err.shape, err, err.mean(axis=1)

(torch.Size([8, 5000]),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0169, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0114],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0070, 0.0114],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0103]],
        device='cuda:0'),
 tensor([0.0046, 0.0045, 0.0045, 0.0045, 0.0046, 0.0047, 0.0046, 0.0044],
        device='cuda:0'))

In [11]:
from src.metric.geodist import GeodesicDist

geodist = GeodesicDist()
err = geodist.calculate_geodesic_error(
    dist_x=data['first']['dist'][0].cpu().numpy(),
    corr_x=data['first']['corr'][0].cpu().numpy(),
    corr_y=data['second']['corr'][0].cpu().numpy(),
    p2p=p2p[0].cpu().numpy(),
)
err.shape, err, err.mean()

((), np.float32(0.0046353377), np.float32(0.0046353377))