# Deform a sphere into a cow's head

---


This notebook deforms a sphere mesh into a cow's head mesh by training a neural network to predict the deformation field.


Import the required modules.


In [None]:
import numpy as np
import pyvista as pv
import pycpd as cpd
import torch
import numerical_geometry as ng

Set the device to MPS if available, falling back to the CPU otherwise (you will need to modify this block of code if you want to use a different accelerator, such as a CUDA GPU).


In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using {device} device.")
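
If you are not on Apple silicon, a slightly more general selection could look like the sketch below, which prefers CUDA, then MPS, then the CPU. The helper name `select_device` is hypothetical and not part of this notebook or of `numerical_geometry`.

```python
import torch


def select_device() -> torch.device:
    """Pick the first available backend: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is available in PyTorch 1.12 and later.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = select_device()
print(f"Using {device} device.")
```

The rest of the notebook only relies on the `device` variable, so swapping in a helper like this should not require other changes.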

---

## 1 - Create source and target meshes


We will create a mesh of a sphere using the `pv.Sphere()` function.


In [None]:
sphere_mesh = pv.Sphere(radius=1, theta_resolution=100, phi_resolution=100)

Import a mesh of a cow's head from PyVista's example datasets; it makes a good benchmark for comparing deformation algorithms. We rotate the mesh, translate it so that its center sits at the origin, and scale it so that the mean distance of its points from the origin is 1, which lines it up with the unit sphere. Vertex normals, when needed, can be computed with PyVista's `compute_normals()` filter.


In [None]:
cow_head_filename = pv.examples.download_cow_head(load=False)
cow_head_mesh = pv.get_reader(cow_head_filename).read()

cow_head_mesh = cow_head_mesh.rotate_x(angle=90)
cow_head_mesh = cow_head_mesh.rotate_y(angle=-62)
cow_head_mesh = cow_head_mesh.translate(-np.array(cow_head_mesh.center))
cow_head_mesh.points = cow_head_mesh.points / np.mean(
    np.linalg.norm(cow_head_mesh.points, axis=1)
)
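
To see what the centering and scaling steps do, the same operations can be applied to a synthetic point cloud in NumPy. This is only an illustration: it assumes that PyVista's `center` attribute is the center of the bounding box, it uses random points rather than the cow's head, and it skips the rotations.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.uniform(low=2.0, high=10.0, size=(1000, 3))

# Bounding-box center, mirroring what `mesh.center` is assumed to return.
center = (points.min(axis=0) + points.max(axis=0)) / 2
points = points - center

# Scale so that the mean distance of the points from the origin is 1.
points = points / np.mean(np.linalg.norm(points, axis=1))

mean_radius = np.mean(np.linalg.norm(points, axis=1))
print(mean_radius)
```

After these two steps the point cloud has roughly the same size and position as a unit sphere, which gives the network a reasonable starting point.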

Plot the source and target meshes side by side, and overlaid in the third panel.


In [None]:
pl = pv.Plotter(shape=(1, 3))

pl.subplot(0, 0)
pl.add_mesh(sphere_mesh, color="lightblue")
pl.add_text(f"Points: {sphere_mesh.n_points}", font_size=12, position="upper_right")

pl.subplot(0, 1)
pl.add_mesh(cow_head_mesh, color="orange")
pl.add_text(f"Points: {cow_head_mesh.n_points}", font_size=12, position="upper_right")

pl.subplot(0, 2)
pl.add_mesh(sphere_mesh, color="lightblue", opacity=0.5)
pl.add_mesh(cow_head_mesh, color="orange", opacity=0.5)


pl.show()

---

## 2 - Use a neural network to learn the deformation


We will now try to solve the problem using a neural network. The aim is to learn a function $f: \mathbb{R}^6 \rightarrow \mathbb{R}^3$ that maps each point of the source mesh, together with its vertex normal, to a displacement vector; the displacements of all points form the deformation field. One benefit of this approach is that we can learn the function on low-resolution (sparse) source and target meshes, and then apply it to a dense source mesh.


### 2.1 - Loss function


In order to optimize the parameters, we need a loss function. Here the loss consists of a single term: the Chamfer distance between the deformed source and the target, which measures how close two point clouds are.


In [None]:
def loss_function(source, target, deformation):
    """
    Loss function
    =============

    Computes the loss.
    """
    return ng.chamfer_distance(source + deformation, target)
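
The implementation of `ng.chamfer_distance` is not shown in this notebook. As a rough guide, one common definition of the symmetric Chamfer distance can be sketched in PyTorch as follows; the function name `chamfer_distance_sketch` and the use of squared distances averaged over each cloud are assumptions, and the library may use a different variant.

```python
import torch


def chamfer_distance_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance between point clouds a (N, 3) and b (M, 3)."""
    # Pairwise Euclidean distances, shape (N, M).
    distances = torch.cdist(a, b)
    # Mean squared distance from each point in a to its nearest point in b...
    a_to_b = distances.min(dim=1).values.pow(2).mean()
    # ...and from each point in b to its nearest point in a.
    b_to_a = distances.min(dim=0).values.pow(2).mean()
    return a_to_b + b_to_a


a = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
b = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])

same = chamfer_distance_sketch(a, a).item()
different = chamfer_distance_sketch(a, b).item()
print(same, different)
```

Identical clouds give a distance of zero, and the extra point in `b` raises the distance because its nearest neighbor in `a` is two units away.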

### 2.2 - Create and train the model


We can initialize a multilayer perceptron (MLP) using `ng.NeuralNetwork()`.


In [None]:
model = ng.NeuralNetwork(layers=5, input_dim=6, hidden_dim=64, output_dim=3).to(device)
print(f"Total number of parameters: {model.num_parameters}")
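
The internals of `ng.NeuralNetwork` are not shown in this notebook. As a rough mental model only, an MLP with the same input, hidden and output sizes could be written in plain PyTorch as below. The class name `MLPSketch`, the ReLU activations, and the reading of `layers` as the number of linear layers are all assumptions.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Plain MLP: (layers - 1) hidden linear layers with ReLU, then a linear output layer."""

    def __init__(self, layers=5, input_dim=6, hidden_dim=64, output_dim=3):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (layers - 1) + [output_dim]
        modules = []
        for i in range(layers):
            modules.append(nn.Linear(dims[i], dims[i + 1]))
            # No activation after the final layer, so outputs are unbounded displacements.
            if i < layers - 1:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

    @property
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())


sketch = MLPSketch()
features = torch.randn(10, 6)  # 10 points, each with xyz coordinates and a normal
displacements = sketch(features)
print(displacements.shape, sketch.num_parameters)
```

Each input row holds a point and its normal, and each output row is the predicted displacement for that point.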

We can estimate the optimal learning rate using the `find_optimal_lr()` method.


In [None]:
model.find_optimal_lr(
    source_mesh=sphere_mesh,
    target_mesh=cow_head_mesh,
    device=device,
    loss_function=loss_function,
    optimizer_type="Adam",
    early_stopping=True,
    loss_threshold=0.1,
)

We can train the model using the `train_model()` method, and then save its parameters to the `Parameters/` directory.


In [None]:
model.train_model(
    source_mesh=sphere_mesh,
    target_mesh=cow_head_mesh,
    device=device,
    loss_function=loss_function,
    epochs=1000,
    optimizer_type="Adam",
    lr=1e-3,
    source_batch_size=1024,
    target_batch_size=1024,
    early_stopping=True,
    patience=100,
    min_delta=1e-6,
)
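import os

# torch.save() does not create missing directories, so make sure the folder exists.
os.makedirs("Parameters", exist_ok=True)
torch.save(model.state_dict(), "Parameters/model_parameters.pth")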
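
The `patience` and `min_delta` arguments suggest a standard early-stopping rule: stop once the loss has failed to improve by at least `min_delta` for `patience` consecutive epochs. How `train_model()` implements this is not shown here, so the following is only a sketch of that common rule, using a hypothetical `EarlyStopper` class.

```python
class EarlyStopper:
    """Signal a stop after `patience` epochs without an improvement of at least `min_delta`."""

    def __init__(self, patience=100, min_delta=1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, loss):
        if loss < self.best_loss - self.min_delta:
            # Meaningful improvement: record it and reset the counter.
            self.best_loss = loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopper(patience=3, min_delta=0.01)
losses = [1.0, 0.5, 0.495, 0.494, 0.493, 0.2]

stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)
```

In this toy run the small improvements after the second epoch are below `min_delta`, so training stops before the final, larger improvement is reached. That trade-off is why `patience` is usually set generously.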

### 2.3 - Testing


Now that we have trained the model, we can evaluate it on a denser sphere mesh to calculate a deformation field, and then apply the deformation field to that source.


In [None]:
sphere_mesh_2 = pv.Sphere(radius=0.75, theta_resolution=200, phi_resolution=200)
deformation = model.evaluate_model(sphere_mesh_2, device)

ng.plot_deformation(
    source_mesh=sphere_mesh_2,
    target_mesh=cow_head_mesh,
    deformation=deformation,
    show_edges=False,
)

---

## Appendix A - Use a neural network with multiple stages (curriculum training)


We will try training over multiple stages (curriculum training) and see whether that improves the result.


### A.1 - Loss function


In order to optimize the parameters, we need a loss function. Here the loss function has three terms: a Chamfer distance term, a term that penalizes large deformations, and a Laplacian regularization term. The Chamfer distance measures how close two point clouds are; the two regularization terms are weighted by `lambda_deformation` and `lambda_laplacian`.


In [None]:
lambda_deformation = 1e-3
lambda_laplacian = 1e-3


def loss_function(source, target, deformation):
    """
    Loss function
    =============

    Computes the loss.
    """
    chamfer_distance = ng.chamfer_distance(source + deformation, target)
    deformation_loss = ng.deformation_loss(deformation)
    laplacian_loss = ng.laplacian_loss(source, deformation)

    loss = (
        chamfer_distance
        + (lambda_deformation * deformation_loss)
        + (lambda_laplacian * laplacian_loss)
    )
    return loss
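
`ng.deformation_loss` and `ng.laplacian_loss` are not defined in this notebook. One plausible reading, sketched below, is a mean squared displacement magnitude and a nearest-neighbor smoothness term that penalizes each point's displacement for differing from the mean displacement of its neighbors. The function names, the choice of `k`, and the use of nearest neighbors instead of mesh connectivity are all assumptions.

```python
import torch


def deformation_loss_sketch(deformation):
    """Mean squared displacement magnitude; penalizes large deformations."""
    return deformation.pow(2).sum(dim=1).mean()


def laplacian_loss_sketch(source, deformation, k=4):
    """Penalize displacements that differ from the mean over the k nearest source points."""
    distances = torch.cdist(source, source)
    # Each point is its own nearest neighbor, so take k + 1 and drop the first column.
    neighbors = distances.topk(k + 1, largest=False).indices[:, 1:]
    neighbor_mean = deformation[neighbors].mean(dim=1)
    return (deformation - neighbor_mean).pow(2).sum(dim=1).mean()


torch.manual_seed(0)
source = torch.randn(20, 3)
uniform_shift = torch.ones(20, 3)  # every point moves by the same vector
noisy_shift = torch.randn(20, 3)  # every point moves independently

magnitude = deformation_loss_sketch(uniform_shift).item()
smooth = laplacian_loss_sketch(source, uniform_shift).item()
rough = laplacian_loss_sketch(source, noisy_shift).item()
print(magnitude, smooth, rough)
```

A uniform shift is perfectly smooth, so only the magnitude term penalizes it, while independent random displacements are penalized by the smoothness term as well.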

### A.2 - Create and train the model over multiple stages


We now have everything we need to train the model. Training is handled by the `train_model()` method. We will train over 5 stages, creating a new model at each stage and using the previous stage's deformed points as the new source. You may want to try modifying this code to vary the hyperparameters or the loss function at each training stage.


In [None]:
source = torch.from_numpy(sphere_mesh.points).float().to(device)

stages = 5
deformation_history = []

for stage in range(stages):
    print(f"\nStage {stage+1}:\n--------")

    # Find the current source.
    if stage > 0:
        current_source = current_source + current_deformation
    else:
        current_source = source

    model = ng.NeuralNetwork(parameters=128, layers=4).to(device)

    model.train_model(
        source_mesh=current_source,
        target_mesh=cow_head_mesh,
        device=device,
        loss_function=loss_function,
        optimizer_type="SGD",
        epochs=40,
        batch_size=512,
        target_batch_size=1000,
        learning_rate=1e-1,
    )

    current_deformation = model(current_source).detach()
    deformation_history.append(current_deformation)
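
Because each stage takes the previous stage's output as its source, the final positions equal the original source plus the sum of the per-stage deformations. The toy example below checks this bookkeeping with NumPy, using a hypothetical `toy_stage` function in place of a trained model.

```python
import numpy as np


def toy_stage(points, target, step=0.5):
    """Stand-in for a trained stage: move each point part of the way to its target."""
    return step * (target - points)


rng = np.random.default_rng(0)
source = rng.normal(size=(8, 3))
target = rng.normal(size=(8, 3))

current = source.copy()
history = []
for _ in range(5):
    deformation = toy_stage(current, target)
    history.append(deformation)
    # The deformed points become the source of the next stage.
    current = current + deformation

total = np.sum(history, axis=0)
print(np.allclose(source + total, current))
```

This is why summing the entries of `deformation_history` gives a single deformation field that can be applied directly to the original sphere.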

### A.3 - Testing


We can calculate the full deformation field using `deformation_history`.


In [None]:
# Sum the per-stage deformations into a single deformation field.
total_deformation = torch.stack(deformation_history).sum(dim=0)

# Move to the CPU and convert to NumPy for plotting.
total_deformation = total_deformation.detach().cpu().numpy()

We can now apply the deformation field to the source.


In [None]:
ng.plot_deformation(
    source_mesh=sphere_mesh,
    deformation=total_deformation,
    show_edges=False,
)