# Deform a sphere into a cube

---


This notebook deforms a sphere into a cube using two different methods: ray tracing and a neural network.


Import the required modules.


In [None]:
import numpy as np
import pyvista as pv
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import utils

Define helper functions.


In [None]:
def animate_sphere_to_cube(
    sphere_mesh,
    cube_mesh,
    deformation,
    t_min=0,
    t_max=1,
    show_cube=False,
    show_edges=True,
):
    """
    Animate sphere to cube
    ======================

    Adds an inline plot of the deformed sphere in the notebook, with a slider controlling the deformation parameter.
    """
    sphere_vertices = sphere_mesh.points.copy()

    # Create plotter with custom window size
    pl = pv.Plotter(window_size=[1000, 700])

    # Add reference cube
    if show_cube:
        cube_actor = pl.add_mesh(cube_mesh, color="orange", opacity=0.3)

    # Add initial sphere
    sphere_actor = pl.add_mesh(sphere_mesh, color="lightblue", show_edges=show_edges)

    def update_deformation(t):
        # Calculate deformed vertices
        deformed_vertices = sphere_vertices + deformation * t
        deformed_mesh = pv.PolyData(deformed_vertices, sphere_mesh.faces)

        # Update the mesh
        sphere_actor.GetMapper().SetInputData(deformed_mesh)

        # Render
        pl.render()

    # Add slider
    pl.add_slider_widget(
        update_deformation,
        rng=[t_min, t_max],
        value=0,
        title="t",
        pointa=(0.05, 0.8),
        pointb=(0.25, 0.8),
        style="modern",
    )

    pl.show()

---

## 1 - Create source and target meshes


We can create the sphere (source) mesh with `utils.sphere()` and the cube (target) mesh with `utils.cube()`. How these meshes are generated is discussed in the `sphere.ipynb` and `cube.ipynb` notebooks.
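`utils.sphere()` isn't shown in this notebook. As a rough picture of what its two parameters control, here is a hypothetical latitude-longitude sampling of a unit sphere centred at the origin. The function name `lat_long_sphere_points` and the exact sampling are assumptions, not the `utils` implementation. The property that matters later is that on a sphere centred at the origin, the outward unit normal at a point is the point divided by the radius.

```python
import numpy as np


def lat_long_sphere_points(num_polar_angles, num_azimuthal_angles, radius=1.0):
    """Hypothetical sketch: sample a sphere on a polar x azimuthal angle grid.

    Returns (points, normals), each of shape (num_polar_angles * num_azimuthal_angles, 3).
    The poles are repeated once per azimuthal angle; a real mesh generator would merge them.
    """
    # Polar angle from the +z pole (0) to the -z pole (pi), inclusive.
    theta = np.linspace(0.0, np.pi, num_polar_angles)
    # Azimuthal angle around the z axis; 2*pi is dropped because it duplicates 0.
    phi = np.linspace(0.0, 2.0 * np.pi, num_azimuthal_angles, endpoint=False)
    theta, phi = np.meshgrid(theta, phi, indexing="ij")

    points = radius * np.stack(
        [np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)],
        axis=-1,
    ).reshape(-1, 3)

    # For a sphere centred at the origin, the outward unit normal is point / radius.
    normals = points / radius
    return points, normals


points, normals = lat_long_sphere_points(num_polar_angles=100, num_azimuthal_angles=200)
print(points.shape)  # (20000, 3)
```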


In [None]:
sphere_mesh = utils.sphere(num_polar_angles=100, num_azimuthal_angles=200)
cube_mesh = utils.cube(num_points_per_side=50)

Plot the sphere (source) mesh and the cube (target) mesh.


In [None]:
pl = pv.Plotter(shape=(1, 2))

pl.subplot(0, 0)
pl.add_mesh(sphere_mesh, show_edges=True)
pl.add_text("Sphere (source)", font_size=12)

pl.subplot(0, 1)
pl.add_mesh(cube_mesh, show_edges=True)
pl.add_text("Cube (target)", font_size=12)

pl.show()

---

## 2 - Deform a sphere into a cube using ray tracing


The first step is to find, for each sphere vertex, where the ray cast along its vertex normal meets the faces of the cube. This is a ray tracing problem, which PyVista's built-in ray tracing functionality can solve. We will define a function, `get_intersection_points()`, which calls PyVista's `ray_trace()` method once per vertex to find where the sphere's vertex normals intersect the faces of the cube.


In [None]:
def get_intersection_points(sphere_mesh, cube_mesh, ray_length=10):
    """
    Get intersection points
    =======================

    Calculates where the vertex normals to the sphere mesh intersect the faces of the cube mesh.
    The intersection points are calculated using the PyVista `ray_trace()` method, casting one
    ray of length `ray_length` from each vertex along its normal, so `ray_length` must be long
    enough for every ray to reach the cube. Rays that miss the cube are skipped.
    """
    intersection_points = []
    intersection_rays = []
    intersection_cells = []

    # Extract vertices and normals from the sphere mesh.
    sphere_vertices = sphere_mesh.points
    sphere_normals = sphere_mesh.point_data["normals"]

    # Cast one ray per vertex, from the vertex along its normal.
    for i, (origin, normal) in enumerate(zip(sphere_vertices, sphere_normals)):
        end_point = origin + normal * ray_length
        try:
            points, cells = cube_mesh.ray_trace(origin, end_point, first_point=True)
        except Exception:
            # Treat a failed trace as a miss rather than aborting the whole loop.
            continue

        # With first_point=True, `points` holds a single point when the ray hits the cube.
        if len(points) > 0:
            intersection_points.append(points)
            intersection_rays.append(i)
            intersection_cells.append(cells)

    # Convert results to numpy arrays.
    intersection_points = (
        np.array(intersection_points) if intersection_points else np.empty((0, 3))
    )
    intersection_rays = (
        np.array(intersection_rays) if intersection_rays else np.empty((0,), dtype=int)
    )
    intersection_cells = (
        np.array(intersection_cells)
        if intersection_cells
        else np.empty((0,), dtype=int)
    )

    return intersection_points, intersection_rays, intersection_cells
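When the sphere and the cube are both centred at the origin, the ray from a sphere vertex $\vec{x}$ along its outward normal is radial, so it leaves an axis-aligned cube $[-a, a]^3$ at $a\,\vec{x} / \lVert \vec{x} \rVert_\infty$: scale $\vec{x}$ until its largest coordinate reaches a face. This gives a closed-form check on `get_intersection_points()`. The sketch below assumes both meshes are centred at the origin, the cube is axis-aligned with half-width `half_width`, and every vertex lies inside the cube; the extent that `utils.cube()` actually produces isn't stated here, so `half_width` has to be set by hand. The name `radial_cube_intersection` is hypothetical.

```python
import numpy as np


def radial_cube_intersection(points, half_width=1.0):
    """Closed-form exit point of the ray from each point directly away from the origin.

    Assumes an axis-aligned cube [-half_width, half_width]^3 centred at the origin and
    non-zero points. Returns an array with the same shape as `points`.
    """
    points = np.asarray(points, dtype=float)
    # Largest absolute coordinate of each point (the infinity norm).
    inf_norm = np.max(np.abs(points), axis=1, keepdims=True)
    # Scaling by half_width / inf_norm pushes the largest coordinate onto a cube face.
    return half_width * points / inf_norm


# A unit-sphere point in the (1, 1, 0) direction exits at the edge where x = 1 and y = 1.
p = np.array([[1.0, 1.0, 0.0]]) / np.sqrt(2.0)
print(radial_cube_intersection(p))  # [[1. 1. 0.]]
```

After running the next cell, `np.abs(intersection_points - radial_cube_intersection(sphere_mesh.points[intersection_rays], half_width)).max()` should be close to zero if these assumptions hold. The same formula bounds the required ray length: for a unit sphere inside $[-1, 1]^3$ the farthest exit point is $\sqrt{3} - 1 \approx 0.73$ away (along the cube diagonals), so `ray_length=10` is ample.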

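The deformation field is the displacement from each sphere vertex to the point where its normal ray meets the cube, so we can compute it from the output of `get_intersection_points()`.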


In [None]:
intersection_points, intersection_rays, intersection_cells = get_intersection_points(
    sphere_mesh,
    cube_mesh,
)

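# Rays that miss the cube keep a zero deformation, so there is one row per sphere vertex.
deformation = np.zeros((sphere_mesh.n_points, 3))
deformation[intersection_rays] = (
    intersection_points - sphere_mesh.points[intersection_rays]
)

print(f"{len(intersection_rays)} of {sphere_mesh.n_points} rays intersect the cube.")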

To deform the sphere, we scale the deformation field, $\vec{d}$, by a scalar, $t \in [0, 1]$, and displace each vertex by the scaled field: $t = 0$ gives the sphere and $t = 1$ gives the cube. The `animate_sphere_to_cube()` function does this interactively, with a slider controlling $t$.
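Written per vertex, with $\vec{x}_i$ the $i$-th sphere vertex and $\vec{p}_i$ the point where its normal ray meets the cube, each vertex moves along a straight line:

$$
\vec{d}_i = \vec{p}_i - \vec{x}_i, \qquad
\vec{x}_i(t) = \vec{x}_i + t\,\vec{d}_i = (1 - t)\,\vec{x}_i + t\,\vec{p}_i, \qquad t \in [0, 1],
$$

so $\vec{x}_i(0) = \vec{x}_i$ lies on the sphere and $\vec{x}_i(1) = \vec{p}_i$ lies on the cube. This is the expression `sphere_vertices + deformation * t` inside `update_deformation()`.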


In [None]:
animate_sphere_to_cube(
    sphere_mesh,
    cube_mesh,
    deformation,
)

---

## 3 - Deform a sphere into a cube using a neural network


We will now try to solve the problem using a neural network. The aim is to learn a function $f: \mathbb{R}^3 \rightarrow \mathbb{R}^3$ that maps each point of the source mesh to its deformation vector. One benefit of this approach is that we can learn the function on a coarse representation (i.e. sparse source and target meshes) and then evaluate it on a dense source mesh.


Select a GPU if one is available: MPS on Apple silicon, CUDA on NVIDIA hardware, otherwise fall back to the CPU.


In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device.")

Define hyperparameters.


In [None]:
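epochs = 100  # Passes over the source point cloud.
batch_size = 512  # Source points per mini-batch.
target_batch_size = 1000  # Target points sampled for each mini-batch.
learning_rate = 1e-1  # SGD step size.
lambda_chd = 1.0  # Weight of the Chamfer distance term.
lambda_deform = 1e-6  # Weight of the deformation penalty.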

### 3.1 - Prepare the point clouds


Convert `sphere_mesh.points` and `cube_mesh.points` to PyTorch tensors, `source` and `target` respectively, and move them to `device`.


In [None]:
source = torch.from_numpy(sphere_mesh.points).float().to(device)
target = torch.from_numpy(cube_mesh.points).float().to(device)

Clamp `batch_size` and `target_batch_size` so that neither exceeds the number of points available.


In [None]:
batch_size = min(batch_size, source.shape[0])
target_batch_size = min(target_batch_size, target.shape[0])

Create a dataset and dataloader for `source`.


In [None]:
source_dataset = TensorDataset(source)

# batch_size has already been clamped to the number of source points.
source_dataloader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)

### 3.2 - Create the model


To start with, we will use a simple multilayer perceptron (MLP). We can increase the complexity of the model later if required.


In [None]:
class DeformationNetwork(nn.Module):
    def __init__(self, hidden=128, layers=4):
        super().__init__()
        mods = []
        in_dim = 3
        for _ in range(layers):
            mods += [nn.Linear(in_dim, hidden), nn.ReLU(inplace=True)]
            in_dim = hidden
        mods += [nn.Linear(in_dim, 3)]
        self.net = nn.Sequential(*mods)

    def forward(self, x):
        return self.net(x)


model = DeformationNetwork(hidden=128, layers=4).to(device)

### 3.3 - Loss function


In order to optimize the parameters, we need a loss function. To start with, the loss function will have two terms: a Chamfer distance term, and a term which penalizes large deformations. The Chamfer distance provides a measure of how close two point clouds are.
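Concretely, the code below uses squared Euclidean distances. For a deformed source $X = \{\vec{x}_i + \vec{d}_i\}_{i=1}^{N}$ and a target $Y$ with $M$ points, the two terms and the loss are

$$
\mathrm{CD}(X, Y) = \frac{1}{N} \sum_{\vec{x} \in X} \min_{\vec{y} \in Y} \lVert \vec{x} - \vec{y} \rVert^2
+ \frac{1}{M} \sum_{\vec{y} \in Y} \min_{\vec{x} \in X} \lVert \vec{x} - \vec{y} \rVert^2,
$$

$$
\mathcal{L} = \lambda_{\mathrm{chd}}\,\mathrm{CD}(X, Y) + \lambda_{\mathrm{deform}} \sqrt{\frac{1}{3N} \sum_{i=1}^{N} \lVert \vec{d}_i \rVert^2}.
$$

The second term is the root mean square over all $3N$ components of the deformation, which is what `get_average_deformation()` computes.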


In [None]:
def get_chamfer_distance(x, y):
    """
    Get Chamfer distance
    ====================

    Computes the Chamfer distance between two point clouds, x (N, D) and y (M, D), using squared Euclidean distances.
    """
    # Compute pairwise squared distances between points in x and points in y.
    x = x.unsqueeze(1)  # (N, 1, D)
    y = y.unsqueeze(0)  # (1, M, D)
    dist = torch.sum((x - y) ** 2, dim=2)  # (N, M)

    # For each point in x, find nearest neighbor in y.
    min_dist_x, _ = torch.min(dist, dim=1)

    # For each point in y, find nearest neighbor in x.
    min_dist_y, _ = torch.min(dist, dim=0)

    # Sum the mean nearest-neighbor distances in both directions.
    return torch.mean(min_dist_x) + torch.mean(min_dist_y)


def get_average_deformation(deformation):
    """
    Get average deformation
    =======================

    Computes the root mean square of a deformation field over all of its components.
    """
    return torch.sqrt((deformation**2).mean())


def loss_function(source, target, deformation, lambda_chd, lambda_deform):
    """
    Loss function
    =============

    Computes the loss.
    """

    chamfer_distance = get_chamfer_distance(source + deformation, target)
    average_deformation = get_average_deformation(deformation)
    loss = (lambda_chd * chamfer_distance) + (lambda_deform * average_deformation)
    return loss
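A NumPy reference implementation is a cheap way to sanity-check the PyTorch loss terms on tiny inputs where the answer can be worked out by hand. The functions below mirror `get_chamfer_distance()` and `get_average_deformation()`; the names `chamfer_distance_np` and `rms_deformation_np` are hypothetical.

```python
import numpy as np


def chamfer_distance_np(x, y):
    """Symmetric Chamfer distance with squared distances, mirroring get_chamfer_distance()."""
    # Pairwise squared distances, shape (N, M).
    dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=2)
    # Mean nearest-neighbor distance from x to y, plus from y to x.
    return dist.min(axis=1).mean() + dist.min(axis=0).mean()


def rms_deformation_np(deformation):
    """Root mean square over all 3N components, mirroring get_average_deformation()."""
    return np.sqrt(np.mean(deformation**2))


x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
y = np.array([[0.0, 0.0, 0.0]])
# x -> y nearest distances are [0, 1] (mean 0.5); y -> x is [0] (mean 0), so CD = 0.5.
print(chamfer_distance_np(x, y))  # 0.5

d = np.array([[3.0, 4.0, 0.0]])
# The displacement has length 5, but the RMS over its 3 components is 5 / sqrt(3).
print(rms_deformation_np(d))
```

Because the deformation term averages over coordinates rather than over displacement vectors, it equals the RMS displacement length divided by $\sqrt{3}$. That constant factor can be absorbed into $\lambda_{\mathrm{deform}}$, so it does not change which deformations the loss prefers.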

### 3.4 - Optimizer


The optimizer updates the parameters based on the gradient of the loss function. We will start by using the stochastic gradient descent (SGD) optimization algorithm; we can change the optimization algorithm later if required.
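In its plain form (no momentum or weight decay, which is how `torch.optim.SGD` behaves with its default arguments), each call to `optimizer.step()` moves every parameter $\theta$ against the gradient of the mini-batch loss:

$$
\theta \leftarrow \theta - \eta\, \nabla_{\theta} \mathcal{L}_{\mathrm{batch}}(\theta),
$$

where $\eta$ is `learning_rate`. Here $\mathcal{L}_{\mathrm{batch}}$ is computed from one source batch and a random subset of the target, not from the full point clouds.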


In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

### 3.5 - Training


We now have everything we need to set up the training loop and train the model.


In [None]:
loss_history = []
epoch_history = []

# Training loop (a plain Python range keeps `epoch` an int for printing and plotting).
for epoch in range(1, epochs + 1):
    epoch_loss = 0.0
    num_batches = 0

    for (batch_source,) in source_dataloader:
        # Sample a random subset of the target.
        target_indices = torch.randperm(target.size(0))[:target_batch_size]
        batch_target = target[target_indices]

        # Compute the deformation for this batch.
        batch_deformation = model(batch_source)

        # Compute the loss for this batch.
        batch_loss = loss_function(
            batch_source, batch_target, batch_deformation, lambda_chd, lambda_deform
        )

        # Clear old gradients, backpropagate the loss, and update the parameters.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()
        num_batches += 1

    # Average loss for this epoch.
    average_epoch_loss = epoch_loss / num_batches

    # Store loss values for plotting.
    loss_history.append(average_epoch_loss)
    epoch_history.append(epoch)

    if epoch % 10 == 0:
        print(f"[{epoch}/{epochs}]: loss = {average_epoch_loss:.6f}")


fig, ax = plt.subplots(figsize=(10, 6))
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True, alpha=0.3)
ax.scatter(epoch_history, loss_history, s=10)
plt.show()

### 3.6 - Test the model


Now that we have trained the model, we can use it to calculate a deformation field.


In [None]:
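# Evaluate without tracking gradients, then move the result to the CPU for PyVista.
model.eval()
with torch.no_grad():
    deformation = model(source).cpu().numpy()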

We can now apply the deformation field to the source.


In [None]:
animate_sphere_to_cube(
    sphere_mesh,
    cube_mesh,
    deformation,
)

---

## 4 - Questions


1. Why does the loss vs epoch graph have these oscillations?

![Training loss vs epoch](Images/loss_vs_epoch.png)

2. Which terms would be worth experimenting with in the loss function? I would like to add one term that depends on the curvature of the deformed source and another that depends on the local density of the deformed source point cloud; what is the best way to compute these quantities?
