# Deform a sphere into a cow's head


---

## Introduction


This notebook attempts to deform a sphere into a cow's head by training a small neural network to predict a displacement for each vertex of the sphere.


Import the required modules.


In [None]:
import numpy as np
import pyvista as pv
import pyvista.examples  # makes pv.examples available for the cow's head download
import torch
import numerical_geometry as ng

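Use the Metal Performance Shaders (MPS) backend on Apple silicon if it is available, and fall back to the CPU otherwise. If you have an NVIDIA GPU, modify this cell to check `torch.cuda.is_available()` and use `torch.device("cuda")`.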


In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using {device} device.")
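A slightly more portable version of the device cell checks for CUDA before MPS. This is a sketch using only standard PyTorch calls; `select_device` is a hypothetical helper, not part of `numerical_geometry`.

```python
import torch


def select_device() -> torch.device:
    """Return the best available PyTorch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        # NVIDIA GPU
        return torch.device("cuda")
    # torch.backends.mps only exists in recent PyTorch versions, so guard the lookup
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = select_device()
print(f"Using {device} device.")
```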

---

## 1 - Create source and target meshes


We will create a mesh of a sphere using the `pv.Sphere()` function.


In [None]:
sphere_mesh = pv.Sphere(radius=1, theta_resolution=500, phi_resolution=500)

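Load the cow's head example mesh from PyVista. Its non-convex features make it a more demanding target than a convex shape, so it is a useful benchmark for comparing deformation algorithms. We rotate, translate and scale the mesh so that it lines up with the sphere: it is rotated about the x- and y-axes, its bounding-box center is moved to the origin, and its points are scaled so that their mean distance from the origin is 1 (the radius of the sphere). Finally, we compute the vertex normals using the `compute_normals()` method.


In [None]:
cow_head_filename = pv.examples.download_cow_head(load=False)
cow_head_mesh = pv.get_reader(cow_head_filename).read()

# Rotate the cow's head into the desired orientation
cow_head_mesh = cow_head_mesh.rotate_x(angle=90)
cow_head_mesh = cow_head_mesh.rotate_y(angle=-62)
# Move the bounding-box center to the origin
cow_head_mesh = cow_head_mesh.translate(-np.array(cow_head_mesh.center))
# Scale so the mean vertex distance from the origin is 1, matching the unit sphere
cow_head_mesh.points = cow_head_mesh.points / np.mean(
    np.linalg.norm(cow_head_mesh.points, axis=1)
)
# Compute vertex normals after the transforms so they match the final geometry
cow_head_mesh = cow_head_mesh.compute_normals()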
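The preprocessing above amounts to two rotations, a translation and a uniform scale. The NumPy sketch below mirrors these steps on a raw `(N, 3)` point array so that each step can be checked in isolation. It assumes PyVista's `rotate_x`/`rotate_y` follow the standard right-handed convention and that `center` is the bounding-box center. The helper names and the random stand-in points are hypothetical.

```python
import numpy as np


def rotation_x(degrees: float) -> np.ndarray:
    """Right-handed rotation matrix about the x-axis."""
    t = np.radians(degrees)
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rotation_y(degrees: float) -> np.ndarray:
    """Right-handed rotation matrix about the y-axis."""
    t = np.radians(degrees)
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def normalize_points(points: np.ndarray, x_deg: float, y_deg: float) -> np.ndarray:
    """Rotate about x then y, move the bounding-box center to the origin,
    and scale so the mean distance from the origin is 1."""
    # Points are row vectors, so apply the transposed matrices on the right
    p = points @ rotation_x(x_deg).T @ rotation_y(y_deg).T
    # Bounding-box center
    center = (p.min(axis=0) + p.max(axis=0)) / 2
    p = p - center
    # Uniform scale about the origin; the bounding-box center stays at the origin
    return p / np.mean(np.linalg.norm(p, axis=1))


rng = np.random.default_rng(0)
raw = rng.normal(loc=5.0, scale=3.0, size=(1000, 3))  # stand-in for mesh vertices
normalized = normalize_points(raw, x_deg=90, y_deg=-62)
print(np.mean(np.linalg.norm(normalized, axis=1)))  # 1 up to rounding
```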

Plot the source and target meshes.


In [None]:
ng.plot_source_and_target(sphere_mesh, cow_head_mesh)

---

## 2 - Use a neural network to learn the deformation


We aim to learn a function $f: \mathbb{R}^6 \rightarrow \mathbb{R}^3$ that maps each source point, together with its vertex normal, to a displacement vector; the displacements of all the vertices form the deformation field. Because $f$ acts on each point independently, it does not depend on the mesh resolution: we can learn it on a coarse (sparse) source mesh, or on random batches of points, and then apply it to a dense source mesh.
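The pointwise structure can be sketched in NumPy: each vertex $\mathbf{x}_i$ with unit normal $\mathbf{n}_i$ is mapped to $\mathbf{y}_i = \mathbf{x}_i + f(\mathbf{x}_i, \mathbf{n}_i)$, so the same weights apply to a point set of any size. The weights below are random and untrained, and the 6 → 64 → 3 layout with a tanh activation is an assumption for illustration, not the internals of `ng.NeuralNetwork`. All helper names are hypothetical.

```python
import numpy as np


def unit_sphere_points(n: int, rng: np.random.Generator) -> np.ndarray:
    """Sample n points uniformly on the unit sphere."""
    p = rng.normal(size=(n, 3))
    return p / np.linalg.norm(p, axis=1, keepdims=True)


def make_features(points: np.ndarray, normals: np.ndarray) -> np.ndarray:
    """Stack each point with its unit normal into a 6-D input vector."""
    return np.concatenate([points, normals], axis=1)  # shape (N, 6)


def pointwise_mlp(x: np.ndarray, params: list) -> np.ndarray:
    """Apply the same small MLP independently to every row of x."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = np.tanh(x)  # non-linearity on hidden layers only
    return x


rng = np.random.default_rng(0)
# Randomly initialized weights for a 6 -> 64 -> 3 network (illustrative only)
params = [
    (rng.normal(scale=0.1, size=(6, 64)), np.zeros(64)),
    (rng.normal(scale=0.1, size=(64, 3)), np.zeros(3)),
]

for n in (500, 20_000):  # a coarse and a fine discretization of the source
    pts = unit_sphere_points(n, rng)
    nrm = pts.copy()  # on a unit sphere centered at the origin, the outward normal is the point itself
    disp = pointwise_mlp(make_features(pts, nrm), params)
    deformed = pts + disp  # y_i = x_i + f(x_i, n_i)
    print(n, disp.shape)
```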


We can initialize a multilayer perceptron (MLP) using `ng.NeuralNetwork()`, with `input_dim=6` (a point and its normal) and `output_dim=3` (a displacement vector).


In [None]:
model = ng.NeuralNetwork(layers=2, input_dim=6, hidden_dim=64, output_dim=3).to(device)
print(f"Total number of parameters: {model.num_parameters}")
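For comparison, here is a plain-PyTorch MLP with the same dimensions. It assumes `layers=2` means two hidden layers with ReLU activations and a linear output layer; `ng.NeuralNetwork` may use a different layout or activation, in which case its parameter count will differ. `make_mlp` is a hypothetical helper.

```python
import torch
from torch import nn


def make_mlp(layers: int = 2, input_dim: int = 6, hidden_dim: int = 64, output_dim: int = 3) -> nn.Sequential:
    """Build an MLP with `layers` hidden layers (an assumed interpretation of `layers`)."""
    modules = []
    in_features = input_dim
    for _ in range(layers):
        modules += [nn.Linear(in_features, hidden_dim), nn.ReLU()]
        in_features = hidden_dim
    # Linear output layer: displacements can take any sign
    modules.append(nn.Linear(in_features, output_dim))
    return nn.Sequential(*modules)


mlp = make_mlp()
num_parameters = sum(p.numel() for p in mlp.parameters())
print(f"Total number of parameters: {num_parameters}")  # 4803 for this layout
```

The count is (6·64 + 64) + (64·64 + 64) + (64·3 + 3) = 448 + 4160 + 195 = 4803, small enough that the network cannot represent a very complicated deformation.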

To optimize the parameters, we need a loss function. We use only the chamfer distance between the deformed source and the target, with no regularization terms. Since nothing in the loss penalizes wild deformations, we rely instead on keeping the network small, so that it cannot represent a highly non-linear deformation.


In [None]:
def loss_function(source, target, deformation):
    """
    Loss function
    =============

    Chamfer distance between the deformed source points
    (``source + deformation``) and the target points.
    """
    return ng.chamfer_distance(source + deformation, target)
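A common form of the chamfer distance between point clouds $X$ and $Y$ averages squared nearest-neighbor distances in both directions:

$$d_{\mathrm{CD}}(X, Y) = \frac{1}{|X|}\sum_{\mathbf{x}\in X}\min_{\mathbf{y}\in Y}\|\mathbf{x}-\mathbf{y}\|^2 + \frac{1}{|Y|}\sum_{\mathbf{y}\in Y}\min_{\mathbf{x}\in X}\|\mathbf{x}-\mathbf{y}\|^2$$

A minimal PyTorch sketch of this formula follows. `ng.chamfer_distance` may use a different variant (for example, unsquared distances or sums instead of means), so treat this as an illustration rather than its implementation.

```python
import torch


def chamfer_distance(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Symmetric chamfer distance between point clouds of shape (N, 3) and (M, 3),
    using mean squared nearest-neighbor distances in both directions."""
    # Pairwise squared Euclidean distances, shape (N, M); memory grows as N * M
    d2 = ((source[:, None, :] - target[None, :, :]) ** 2).sum(dim=-1)
    # Nearest target point for each source point, and vice versa
    return d2.min(dim=1).values.mean() + d2.min(dim=0).values.mean()


a = torch.tensor([[0.0, 0.0, 0.0]])
b = torch.tensor([[1.0, 0.0, 0.0]])
print(chamfer_distance(a, b))  # tensor(2.)
```

Both terms matter: the first pulls every deformed source point onto the target, and the second ensures every part of the target is covered by some source point.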

We can estimate a suitable learning rate using the `find_optimal_lr()` method.


In [None]:
model.find_optimal_lr(
    source_mesh=sphere_mesh,
    target_mesh=cow_head_mesh,
    device=device,
    loss_function=loss_function,
    optimizer_type="Adam",
    early_stopping=True,
    loss_threshold=1,
)
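The internals of `find_optimal_lr()` are not shown in this notebook, and the exact meaning of its `early_stopping` and `loss_threshold` arguments is not documented here. As one simple stand-in technique (not necessarily what `ng` does), a learning-rate sweep trains a fresh copy of the model briefly at each rate on a log-spaced grid and keeps the rate with the lowest final loss. The sketch below runs it on a toy linear-regression problem; `sweep_learning_rates` is a hypothetical name.

```python
import copy

import numpy as np
import torch
from torch import nn


def sweep_learning_rates(model, loss_fn, lrs, steps=50):
    """Train a fresh copy of `model` for `steps` Adam steps at each learning
    rate in `lrs`; return the rate with the lowest final loss and all losses."""
    final_losses = []
    for lr in lrs:
        trial = copy.deepcopy(model)  # every rate starts from the same weights
        opt = torch.optim.Adam(trial.parameters(), lr=float(lr))
        for _ in range(steps):
            opt.zero_grad()
            loss = loss_fn(trial)
            loss.backward()
            opt.step()
        with torch.no_grad():
            final = loss_fn(trial).item()
        # Treat a diverged run (NaN or inf) as the worst possible result
        final_losses.append(final if np.isfinite(final) else np.inf)
    best = int(np.argmin(final_losses))
    return float(lrs[best]), final_losses


# Toy problem: fit y = 3x with a single linear layer
torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * x
net = nn.Linear(1, 1)
lrs = np.geomspace(1e-4, 1.0, 9)
best_lr, losses = sweep_learning_rates(net, lambda m: ((m(x) - y) ** 2).mean(), lrs)
print(f"Suggested learning rate: {best_lr:.0e}")
```

Picking the lowest final loss is only one heuristic. The learning-rate range test, for example, increases the rate during a single run and picks a value where the loss decreases fastest.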

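We train the model using the `train_model()` method. Here the learning rate is set by hand to `lr=1e-3`, which can be informed by the output of `find_optimal_lr()`. To save the trained parameters to the `Parameters/` directory, uncomment the last line of the cell (the directory must already exist).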


In [None]:
model.train_model(
    source_mesh=sphere_mesh,
    target_mesh=cow_head_mesh,
    device=device,
    loss_function=loss_function,
    epochs=100,
    optimizer_type="Adam",
    lr=1e-3,
    source_batch_size=1024,
    target_batch_size=1024,
    early_stopping=True,
    validation_fraction=0.01,
    patience=10,
)
# torch.save(model.state_dict(), "Parameters/model_parameters.pth")
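The internals of `train_model()` are not shown here. As a rough sketch of what its arguments could mean, assume that each step draws random batches of `source_batch_size` source vertices and `target_batch_size` target points, that `validation_fraction` holds out that share of the source vertices, and that `patience` is the number of epochs without validation improvement before stopping. A plain-PyTorch loop under those assumptions might look like this. The function name `train`, the `steps_per_epoch` parameter and the ellipsoid target are hypothetical stand-ins.

```python
import math

import torch
from torch import nn


def chamfer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean squared nearest-neighbor distance, summed over both directions."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
    return d2.min(dim=1).values.mean() + d2.min(dim=0).values.mean()


def train(model, source_feats, target_pts, epochs=100, lr=1e-3,
          source_batch_size=1024, target_batch_size=1024,
          validation_fraction=0.01, patience=10, steps_per_epoch=10, seed=0):
    """Fit `model` so that points + model(features) matches `target_pts`.

    `source_feats` has shape (N, 6): columns 0-2 are points, 3-5 are normals.
    Returns the validation loss recorded after each epoch.
    """
    gen = torch.Generator().manual_seed(seed)
    # Hold out a random fraction of the source vertices for validation
    perm = torch.randperm(len(source_feats), generator=gen)
    n_val = max(1, int(validation_fraction * len(source_feats)))
    val_feats, train_feats = source_feats[perm[:n_val]], source_feats[perm[n_val:]]
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, best_state, bad_epochs, history = math.inf, None, 0, []
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            # Independent random mini-batches from the source and the target
            s = train_feats[torch.randint(len(train_feats), (source_batch_size,), generator=gen)]
            t = target_pts[torch.randint(len(target_pts), (target_batch_size,), generator=gen)]
            opt.zero_grad()
            loss = chamfer(s[:, :3] + model(s), t)
            loss.backward()
            opt.step()
        with torch.no_grad():
            # Validation: held-out source vertices against the whole target
            val = chamfer(val_feats[:, :3] + model(val_feats), target_pts).item()
        if not math.isfinite(val):
            break  # training diverged
        history.append(val)
        if val < best_val:
            best_val, bad_epochs = val, 0
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # early stopping: no improvement for `patience` epochs
    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best weights seen
    return history


torch.manual_seed(0)
src = torch.randn(5000, 3)
src = src / src.norm(dim=1, keepdim=True)  # points on the unit sphere
toy_feats = torch.cat([src, src], dim=1)  # on a unit sphere the normal equals the point
toy_target = src * torch.tensor([1.5, 1.0, 0.7])  # an ellipsoid stands in for the cow's head
toy_model = nn.Sequential(nn.Linear(6, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 3))
history = train(toy_model, toy_feats, toy_target, epochs=20, patience=5,
                source_batch_size=512, target_batch_size=512)
print(f"epochs run: {len(history)}, best validation loss: {min(history):.4f}")
```

Random batches keep the chamfer distance's pairwise-distance matrix at (batch size × batch size) instead of (number of source vertices × number of target points), which matters for a dense source mesh with hundreds of thousands of vertices.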

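Now that we have trained the model, we can use it to compute a deformation field for a new source mesh and apply it. The new sphere is smaller (radius 0.75) and coarser (200 × 200) than the training sphere, which tests how well the learned function generalizes beyond the mesh it was trained on.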


In [None]:
sphere_mesh_2 = pv.Sphere(radius=0.75, theta_resolution=200, phi_resolution=200)
deformation = model.evaluate_model(sphere_mesh_2, device)

ng.plot_deformation(
    source_mesh=sphere_mesh_2,
    target_mesh=cow_head_mesh,
    deformation=deformation,
    show_edges=False,
)