# Deform a hemisphere into a cow's head

---


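This notebook compares three approaches to deforming a hemisphere mesh into a cow's head mesh: Coherent Point Drift (CPD) registration, a single neural network, and a neural network trained over multiple stages.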


Import the required modules.


In [None]:
import numpy as np
import pyvista as pv
import pycpd as cpd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numerical_geometry as ng

---

## 1 - Create source and target meshes


Create a mesh of a hemisphere using the `pv.Sphere()` function.


In [None]:
hemisphere_mesh = pv.Sphere(
    radius=1, theta_resolution=100, phi_resolution=100, end_phi=90
)
hemisphere_mesh.plot(show_edges=True)

Download a mesh of a cow's head from the PyVista examples. The cow's head is a complicated geometry that is topologically equivalent to the hemisphere mesh, which should make it a good benchmark for different deformation algorithms.


In [None]:
cow_head_filename = pv.examples.download_cow_head(load=False)
cow_head_mesh = pv.get_reader(cow_head_filename).read()
cow_head_mesh.plot(show_edges=True)

Rotate, translate and scale the cow's head mesh so that it lines up nicely with the hemisphere mesh.


In [None]:
cow_head_mesh = cow_head_mesh.rotate_x(angle=90)
cow_head_mesh = cow_head_mesh.rotate_y(angle=-62)
cow_head_mesh = cow_head_mesh.translate((-0.8, 0, -3.7))
cow_head_mesh.points = cow_head_mesh.points / np.mean(
    np.linalg.norm(cow_head_mesh.points, axis=1)
)

print(cow_head_mesh.points.shape)
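
The scaling step above divides every point by the mean distance of the points from the origin, so that the average point ends up at distance 1, matching the hemisphere's radius of 1. A minimal sketch of the same normalization on synthetic points (the random point cloud is an assumption used only for illustration, not the cow's head data):

```python
import numpy as np

# Synthetic stand-in for a point cloud (illustration only).
rng = np.random.default_rng(0)
points = rng.normal(loc=2.0, scale=3.0, size=(1000, 3))

# Mean Euclidean distance of the points from the origin.
mean_distance = np.mean(np.linalg.norm(points, axis=1))

# After dividing by it, the mean distance from the origin is 1.
scaled_points = points / mean_distance
scaled_mean_distance = np.mean(np.linalg.norm(scaled_points, axis=1))
print(scaled_mean_distance)
```

Note that this only normalizes the overall size; the rotation and translation before it are what position the cow's head relative to the hemisphere.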

Plot the source and target mesh.


In [None]:
plotter = pv.Plotter()
plotter.add_mesh(hemisphere_mesh, color="lightblue", opacity=0.5)
plotter.add_mesh(cow_head_mesh, color="brown", opacity=0.5, show_edges=True)
plotter.show()

---

## 2 - Use `pycpd` to deform the hemisphere to the cow's head


First, we will try using the `pycpd` [1] package to deform the hemisphere mesh to the cow's head mesh. `pycpd` is a NumPy implementation of the Coherent Point Drift (CPD) algorithm by Myronenko and Song [2].

After running 5 iterations, the general shape of the cow's head is captured fairly well. However, the ears and horns of the cow are not distinct, and the hole for the neck is in completely the wrong place!


In [None]:
NUM_ITERATIONS = 5

hemisphere_mesh_deformed = hemisphere_mesh.copy()

for i in range(NUM_ITERATIONS):
    # Create a deformable registration object.
    deformable_registration = cpd.DeformableRegistration(
        alpha=0.05,
        X=np.array(cow_head_mesh.points),
        Y=np.array(hemisphere_mesh_deformed.points),
    )

    # Run the deformable registration, and update the hemisphere mesh.
    transformed_hemisphere_points, _ = deformable_registration.register()
    hemisphere_mesh_deformed.points = transformed_hemisphere_points

# Plot the meshes in separate subplots.
pl = pv.Plotter(shape=(1, 2))

pl.subplot(0, 0)
pl.add_mesh(hemisphere_mesh_deformed, color="lightblue")
pl.add_text("Deformed hemisphere", font_size=10)

pl.subplot(0, 1)
pl.add_mesh(cow_head_mesh, color="orange")
pl.add_text("Cow's head", font_size=10)

pl.show()

In [None]:
pl = pv.Plotter()
pl.add_mesh(hemisphere_mesh_deformed, color="lightblue")
pl.show()

---

## 3 - Use a neural network to deform the hemisphere to the cow's head


We will now try to solve the problem using a neural network. We will aim to learn a function $f: \mathbb{R}^3 \rightarrow \mathbb{R}^3$ that maps each point $x$ of the source mesh to its displacement vector, so that the deformed position of the point is $x + f(x)$. One benefit of this approach is that we can learn the function on a low-resolution representation (i.e. sparse source and target meshes), and then apply the function to a dense source mesh.
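
To make the idea concrete, here is a minimal sketch of such a function as a small MLP in PyTorch. The class name `DisplacementMLP`, its layer sizes and the ReLU activations are assumptions for illustration; this is not the `ng.core.NeuralNetwork` class used later in this notebook. Because the network acts on each point independently, the same (here untrained) model can be evaluated on a sparse or a dense set of points.

```python
import torch
from torch import nn


class DisplacementMLP(nn.Module):
    """Hypothetical MLP mapping 3D points to 3D displacement vectors."""

    def __init__(self, hidden_size=128, hidden_layers=4):
        super().__init__()
        layers = [nn.Linear(3, hidden_size), nn.ReLU()]
        for _ in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        layers.append(nn.Linear(hidden_size, 3))
        self.network = nn.Sequential(*layers)

    def forward(self, points):
        # points has shape (N, 3); the output has the same shape.
        return self.network(points)


model = DisplacementMLP()

# Random stand-ins for a sparse and a dense source mesh.
sparse_points = torch.rand(100, 3)
dense_points = torch.rand(10_000, 3)

with torch.no_grad():
    # Deformed position = original position + predicted displacement.
    sparse_deformed = sparse_points + model(sparse_points)
    dense_deformed = dense_points + model(dense_points)

print(sparse_deformed.shape, dense_deformed.shape)
```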


Set the device to MPS if available (you will need to modify this block of code if you aren't using an Apple silicon device).


In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using {device} device.")
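
If you are not on Apple silicon, one possible modification is to check for a CUDA GPU first and fall back to MPS and then the CPU. The helper name `pick_device` is an assumption for illustration; the rest of the notebook only needs the resulting `device` variable.

```python
import torch


def pick_device():
    """Return the first available device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # `torch.backends.mps` is missing in older PyTorch releases.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f"Using {device} device.")
```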

### 3.1 - Loss function


In order to optimize the parameters, we need a loss function. To start with, the loss function will have two terms: a Chamfer distance term, and a term which penalizes large deformations. The Chamfer distance provides a measure of how close two point clouds are.


In [None]:
lambda_chamfer = 1
lambda_deformation = 1e-6


def loss_function(source, target, deformation):
    """
    Loss function
    =============

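    Computes the weighted sum of the Chamfer distance between the deformed
    source and the target, and the average deformation.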
    """

    chamfer_distance = ng.get_chamfer_distance(source + deformation, target)
    average_deformation = ng.get_average_deformation(deformation)

    loss = (lambda_chamfer * chamfer_distance) + (
        lambda_deformation * average_deformation
    )
    return loss
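
For reference, a symmetric Chamfer distance between two point clouds can be sketched in NumPy as below. This is only an illustration of the idea; `ng.get_chamfer_distance` may differ in its details (for example, by using squared distances or a different normalization), and the function name `chamfer_distance` here is an assumption.

```python
import numpy as np


def chamfer_distance(points_a, points_b):
    """Symmetric Chamfer distance between two point clouds (illustrative)."""
    # Pairwise Euclidean distances, shape (len(points_a), len(points_b)).
    differences = points_a[:, None, :] - points_b[None, :, :]
    distances = np.linalg.norm(differences, axis=-1)

    # Mean distance from each point to its nearest neighbour in the other cloud.
    a_to_b = distances.min(axis=1).mean()
    b_to_a = distances.min(axis=0).mean()
    return a_to_b + b_to_a


triangle = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
shifted = triangle + np.array([0.0, 0.0, 0.5])

print(chamfer_distance(triangle, triangle))  # identical clouds give 0.0
print(chamfer_distance(triangle, shifted))   # 0.5 in each direction, so 1.0
```

The pairwise distance matrix grows with the product of the two point counts, which is one reason to evaluate the loss on batches of source and target points, as the `batch_size` and `target_batch_size` arguments used below suggest.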

### 3.2 - Create and train the model


We can initialize a multilayer perceptron (MLP) using the `ng.core.NeuralNetwork` class, and train it with stochastic gradient descent (SGD) using its `train_model()` method.


In [None]:
model = ng.core.NeuralNetwork(parameters=128, layers=4).to(device)

model.train_model(
    source_mesh=hemisphere_mesh,
    target_mesh=cow_head_mesh,
    device=device,
    loss_function=loss_function,
    optimizer_type="SGD",
    epochs=100,
    batch_size=512,
    target_batch_size=1000,
    learning_rate=1e-1,
)

### 3.3 - Testing


Now that we have trained the model, we can use it to calculate a deformation field.


In [None]:
deformation = model.get_deformation_field(hemisphere_mesh, device)

We can now apply the deformation field to the source.


In [None]:
ng.animate_deformation(
    source_mesh=hemisphere_mesh, target_mesh=cow_head_mesh, deformation=deformation
)

---

## 4 - Use a neural network with multiple stages


The single-network approach from the previous section didn't work very well, so we will try training over multiple stages to see whether that improves the result.


Set the device to MPS if available (you will need to modify this block of code if you aren't using an Apple silicon device).


In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using {device} device.")

### 4.1 - Loss function


We use the same loss function as in section 3.1, with two terms: a Chamfer distance term, which measures how close the deformed source and the target point clouds are, and a term which penalizes large deformations.


In [None]:
lambda_chamfer = 1
lambda_deformation = 1e-6


def loss_function(source, target, deformation):
    """
    Loss function
    =============

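    Computes the weighted sum of the Chamfer distance between the deformed
    source and the target, and the average deformation.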
    """

    chamfer_distance = ng.get_chamfer_distance(source + deformation, target)
    average_deformation = ng.get_average_deformation(deformation)

    loss = (lambda_chamfer * chamfer_distance) + (
        lambda_deformation * average_deformation
    )
    return loss

### 4.2 - Create and train the model over multiple stages


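We now have everything we need to train the model. At each stage, a new network is trained with `train_model()` on the current source points, and the deformation it predicts is applied to those points before the next stage begins.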


In [None]:
source = torch.from_numpy(hemisphere_mesh.points).float().to(device)

stages = 5
deformation_history = []

for stage in range(stages):
    if stage > 0:
        current_source = current_source + current_deformation
    else:
        current_source = source

    print(f"\nStage {stage+1}:\n--------")

    model = ng.core.NeuralNetwork(parameters=128, layers=4).to(device)

    model.train_model(
        source_mesh=current_source,
        target_mesh=cow_head_mesh,
        device=device,
        loss_function=loss_function,
        optimizer_type="SGD",
        epochs=100,
        batch_size=512,
        target_batch_size=1000,
        learning_rate=1e-1,
    )

    current_deformation = model(current_source).detach()
    deformation_history.append(current_deformation)

### 4.3 - Testing


We can calculate the full deformation field by summing the per-stage deformations stored in `deformation_history`.


In [None]:
total_deformation = torch.zeros_like(deformation_history[0])

for deformation in deformation_history:
    total_deformation += deformation

total_deformation = total_deformation.to("cpu").detach().numpy()
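
Summing works because each stage's deformation is evaluated at the points produced by the previous stage: if $x_k = x_{k-1} + d_k$, then $x_K = x_0 + \sum_{k=1}^{K} d_k$. The NumPy sketch below checks this on synthetic data; `stage_displacement` is a hypothetical stand-in for a trained network, not part of `numerical_geometry`.

```python
import numpy as np

rng = np.random.default_rng(0)
source = rng.normal(size=(5, 3))


def stage_displacement(points, stage):
    """Hypothetical stand-in for the deformation predicted at one stage."""
    return 0.1 * (stage + 1) * np.tanh(points)


current = source.copy()
history = []

for stage in range(3):
    # Each stage sees the points as deformed by all previous stages.
    displacement = stage_displacement(current, stage)
    history.append(displacement)
    current = current + displacement

# The summed displacements take the original source to the final positions.
total = np.sum(history, axis=0)
print(np.allclose(source + total, current))
```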

We can now apply the deformation field to the source.


In [None]:
ng.animate_deformation(
    source_mesh=hemisphere_mesh,
    target_mesh=cow_head_mesh,
    deformation=total_deformation,
)

---

## References


[1] - Khallaghi, S. siavashk/pycpd. (2025) (https://github.com/siavashk/pycpd).

[2] - Myronenko, A. & Song, X. Point-Set Registration: Coherent Point Drift. IEEE Trans. Pattern Anal. Mach. Intell. 32, 2262–2275 (2010) (https://doi.org/10.1109/TPAMI.2010.46).
