# How to learn features for functional maps

In this notebook, we show how to use deep functional maps to learn features for 3D shape matching.

In [1]:
# Set the geomstats backend to PyTorch before geomfum is imported (comment this line out for the GitHub tests)
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import torch
from torch.utils.data import Subset, random_split

from geomfum.convert import P2pFromFmConverter
from geomfum.dataset.torch import PairsDataset, ShapeDataset
from geomfum.descriptor.learned import FeatureExtractor
from geomfum.forward_functional_map import ForwardFunctionalMap
from geomfum.learning.losses import (
    BijectivityLoss,
    GeodesicError,
    LaplacianCommutativityLoss,
    LossManager,
    OrthonormalityLoss,
)
from geomfum.learning.models import FMNet
from geomfum.learning.trainer import DeepFunctionalMapTrainer


First, we define our model. It can be assembled by hand from a feature extractor and a forward functional-map module; we also provide some classic frameworks, such as FMNet.

In [2]:
# Create the three components first, then wire them together in FMNet.

# Functional-map layer.
# NOTE(review): the positional arguments (1e3, 1, True) are passed through
# unchanged; check the ForwardFunctionalMap signature for their meaning.
fmap_module = ForwardFunctionalMap(1e3, 1, True)

# Feature extractor looked up in the registry under the "diffusionnet" key.
extractor_options = {"which": "diffusionnet", "device": "cuda", "k_eig": 200}
feature_extractor = FeatureExtractor.from_registry(**extractor_options)

# Converter handed to FMNet together with the two modules above.
p2p_converter = P2pFromFmConverter()

functional_map_model = FMNet(
    converter=p2p_converter,
    fmap_module=fmap_module,
    feature_extractor=feature_extractor,
)


Then, we instantiate the shape dataset and split it into training and validation subsets.

In [3]:
# Load the shapes stored under TRAIN_SET_PATH (spectral data on, distances off).
TRAIN_SET_PATH = "../../../datasets/smal/test_set/"
shape_options = dict(spectral=True, distances=False, device="cuda", k=30)
dataset = ShapeDataset(TRAIN_SET_PATH, **shape_options)

# 80 / 20 random split; the validation part takes whatever the truncated
# 80 % leaves over, so the two sizes always add up to the dataset length.
n_shapes = len(dataset)
train_size = int(0.8 * n_shapes)
val_size = n_shapes - train_size
split_sizes = [train_size, val_size]
train_shapes, validation_shapes = random_split(dataset, split_sizes)

# Wrap each subset in a PairsDataset using pair_mode="all".
train_dataset = PairsDataset(train_shapes, pair_mode="all")
validation_dataset = PairsDataset(validation_shapes, pair_mode="all")


  from .autonotebook import tqdm as notebook_tqdm
  return _torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=array.shape)


Sometimes the distance computation is useful only at validation time. In that case we load the dataset twice (with and without distances) and use the following trick to take each split from the appropriate copy.

In [None]:
# Training does not need distances or correspondences, but the validation
# metric does. The same folder is therefore loaded twice: a light variant for
# training and a full variant for validation. A single random partition is
# drawn once and its validation indices are re-applied to the full variant.
TRAIN_SET_PATH = "../../../datasets/smal/test_set/"

# Light variant: distances and correspondences disabled.
dataset1 = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=False,
    correspondences=False,
    device="cuda",
    k=30,
)
# Full variant: distances and correspondences enabled.
dataset2 = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=True,
    correspondences=True,
    device="cuda",
    k=30,
)

# Indices drawn on dataset1 are reused on dataset2 below, so both variants
# must at least contain the same number of shapes.
# NOTE(review): this also assumes both variants list the shapes in the same
# order, which should hold since they read the same folder -- confirm.
assert len(dataset1) == len(dataset2), "dataset variants differ in length"

# Size the split from the dataset that is actually being split, so that this
# cell does not depend on `dataset` from the previous cell.
train_size = int(0.8 * len(dataset1))
val_size = len(dataset1) - train_size

# random_split draws the random partition; train_shapes is already a Subset
# of dataset1 and can be used directly.
train_shapes, validation_shapes = random_split(dataset1, [train_size, val_size])

# Re-point the validation indices at the full variant.
val_indices = validation_shapes.indices
validation_shapes = Subset(dataset2, val_indices)

train_dataset = PairsDataset(
    train_shapes,
    pair_mode="all",
)

validation_dataset = PairsDataset(
    validation_shapes,
    pair_mode="all",
)


Next, we build the optimizer.

In [5]:
# Adam over every parameter exposed by the model.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(
    params=functional_map_model.parameters(),
    lr=LEARNING_RATE,
)

Now we define the losses that we will consider. Again, we can define our own losses; however, we provide some classic functional map energies, like the orthonormality loss.

In [6]:
# Losses handed to the training LossManager, each with its own weight.
train_losses = [
    OrthonormalityLoss(weight=1.0),
    BijectivityLoss(weight=1.0),
    LaplacianCommutativityLoss(weight=1e-3),
]
loss_manager = LossManager(train_losses)

# Validation uses a separate manager holding only the geodesic error.
# The two lists get distinct names so that neither silently replaces the
# other when the cell is read or edited.
val_losses = [
    GeodesicError(),
]
val_loss_manager = LossManager(val_losses)

For simplicity, we provide a trainer that takes the model, the loss managers, the training and validation datasets and the optimizer as input, and manages the training loops.

In [7]:
# Gather everything built in the previous cells into a single trainer object.
NUM_EPOCHS = 10

trainer = DeepFunctionalMapTrainer(
    model=functional_map_model,
    optimizer=optimizer,
    train_set=train_dataset,
    val_set=validation_dataset,
    train_loss_manager=loss_manager,
    val_loss_manager=val_loss_manager,
    epochs=NUM_EPOCHS,
    device="cuda",
)

In [8]:
# Run a validation pass before any training call in this notebook; the value
# returned by validate() is shown as the cell output.
# NOTE(review): presumably this is the validation loss averaged over the
# batches -- confirm against DeepFunctionalMapTrainer.validate.
trainer.validate()

Validation: 100%|██████████| 12/12 [00:17<00:00,  1.49s/batch, Loss=0.4931]


0.528020355568196