# Predict deformations for the DrivAerML dataset


---

## Introduction


This notebook trains a neural network to deform one mesh from the DrivAerML dataset into another.


Import the required modules.


In [None]:
import os
import glob
from pathlib import Path
import numpy as np
import pyvista as pv
import pycpd as cpd
import torch
import numerical_geometry as ng

Use Apple's Metal Performance Shaders (MPS) backend if it is available, and fall back to the CPU otherwise. (To use an NVIDIA GPU instead, modify this cell to check `torch.cuda.is_available()` and select `"cuda"`.)


In [None]:
# Prefer Apple's MPS backend when present; otherwise run everything on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device.")

---

## 1 - Create source and target meshes


Find all of the STL files in the dataset, and read them into PyVista meshes.


In [None]:
# Collect every STL file in the dataset directory (sorted for a stable order)
# and read each one into a PyVista mesh. `files[i]` and `meshes[i]` refer to
# the same file.
data_dir = Path("../drivaer_data/")
files = sorted(data_dir.glob("*.stl"))
if not files:
    # Fail early with a clear message; otherwise the random sampling below
    # fails with an opaque NumPy error on an empty list.
    raise FileNotFoundError(f"No STL files found in {data_dir.resolve()}.")
meshes = [pv.read(file) for file in files]

print(f"{len(meshes)} meshes successfully imported from {len(files)} files.")

Pick 9 random meshes and plot them.


In [None]:
# Plot a random selection of meshes in a 3x3 grid, each labelled with its file
# name. Sampling is without replacement, so the sample size is capped at the
# number of meshes available (choice(..., replace=False) raises otherwise).
n_plots = min(9, len(meshes))
pl = pv.Plotter(shape=(3, 3), window_size=[1200, 1200])

indices = np.random.choice(len(meshes), size=n_plots, replace=False)
for i, index in enumerate(indices):
    # Fill the grid row by row.
    pl.subplot(i // 3, i % 3)
    pl.add_mesh(meshes[index])
    pl.add_text(
        files[index].name,
        font_size=12,
        position="upper_right",
    )

pl.show()

Pick two distinct meshes at random to serve as the source and target meshes, and plot them.


In [None]:
# Draw two distinct mesh indices: the first is the source, the second the target.
indices = np.random.choice(len(meshes), size=2, replace=False)
source_index, target_index = indices
source_mesh, target_mesh = meshes[source_index], meshes[target_index]
ng.plot_source_and_target(source_mesh, target_mesh)
# Index `files` first, then take `.name` of the resulting Path. The original
# `files[indices[0].name]` called `.name` on a NumPy integer (AttributeError).
print(f"Source mesh: {files[source_index].name}. Target mesh: {files[target_index].name}.")

---

## 2 - Predict the deformation using a neural network


We aim to learn a function $f: \mathbb{R}^6 \rightarrow \mathbb{R}^3$ that maps each source point, together with its vertex normal, to a displacement vector; evaluated over all source points, these displacements form the deformation field.


We can initialize a neural network using `ng.NeuralNetwork()`.


In [None]:
# Network mapping 6 input features per point to a 3-component output.
# Hyperparameters are passed straight through to ng.NeuralNetwork.
model = ng.NeuralNetwork(
    input_dim=6,
    layers=2,
    hidden_dim=1000,
    output_dim=3,
    dropout_prob=0.1,
    verbose=True,
)
# Move the parameters to the device chosen earlier in the notebook.
model = model.to(device)

To optimize the parameters we need a loss function. We use only the Chamfer distance between the deformed source and the target, with no regularization terms; instead, we guard against wild deformations by keeping the network shallow, which limits how non-linear it can be.


In [None]:
def loss_function(source, target, deformation):
    """
    Chamfer loss between the deformed source and the target
    ========================================================

    Displaces ``source`` by ``deformation`` and returns
    ``ng.chamfer_distance`` between the displaced points and ``target``.
    No regularization term is added.
    """
    deformed_source = source + deformation
    return ng.chamfer_distance(deformed_source, target)

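We can estimate a suitable learning rate using the `find_optimal_lr()` method. Its return value is not used below: inspect its output and set the `lr` passed to `train_model()` accordingly.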


In [None]:
# Learning-rate search, run with the same source/target pair and loss used
# for training.
# NOTE(review): the search itself lives in ng; judging only by the argument
# names it presumably sweeps from initial_lr to final_lr and may stop early
# once the loss passes loss_threshold. Confirm against
# ng.NeuralNetwork.find_optimal_lr.
# The result is assigned to `_` (which also stops Jupyter from echoing it) and
# is not used: the `lr` passed to train_model() below is set by hand.
_ = model.find_optimal_lr(
    source=source_mesh,
    target=target_mesh,
    device=device,
    loss_function=loss_function,
    optimizer_type="Adam",
    initial_lr=1e-4,
    final_lr=1e-1,
    early_stopping=True,
    loss_threshold=1,
)

We can train the model using the `train_model()` method. The trained parameters can optionally be saved to the `Parameters/` directory.


In [None]:
# Fit the network so that source + predicted deformation matches the target
# under `loss_function`.
# NOTE(review): lr is hand-set rather than taken from find_optimal_lr() above;
# revisit it after inspecting that output.
model.train_model(
    source=source_mesh,
    target=target_mesh,
    device=device,
    loss_function=loss_function,
    epochs=50,
    optimizer_type="Adam",
    lr=1e-3,
    batch_size=5000,
    early_stopping=True,
    validation_fraction=0.1,
    patience=5,
    min_epochs=5,
)

# Set to True to persist the trained weights (off by default, as before).
# torch.save does not create missing directories, so create Parameters/ first.
SAVE_PARAMETERS = False
if SAVE_PARAMETERS:
    parameters_dir = Path("Parameters")
    parameters_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), parameters_dir / "model_parameters.pth")

Now that we have trained the model, we can use it to calculate a deformation field, and then apply the deformation field to the source.


In [None]:
# Predict the deformation of the source mesh with the trained model.
# NOTE(review): assumes the result is an array broadcastable to
# source_mesh.points (one 3-vector per point); the animation cell below adds it
# directly to source_mesh.points. Confirm against evaluate_model.
deformation = model.evaluate_model(source_mesh, device)
# Visualize the result with ng's helpers; both take the source, the target and
# the predicted deformation.
ng.plot_deformation(source_mesh, target_mesh, deformation)
ng.animate_deformation(source_mesh, target_mesh, deformation)

In [None]:
import pyvista as pv
import imageio

# Animate the deformation from source_mesh (t = 0) to the fully deformed
# source (t = 1), render it off screen, and save it as a GIF and an MP4.

n_frames = 30
fps = 20

# Fully deformed source. Kept as module-level names so later cells can use them.
# NOTE(review): assumes `deformation` is an array broadcastable to
# source_mesh.points, i.e. one 3-vector per point. Confirm against evaluate_model.
deformed_points = source_mesh.points + deformation
deformed_mesh = source_mesh.copy()
deformed_mesh.points = deformed_points

plotter = pv.Plotter(off_screen=True)

# One mesh whose points are updated in place every frame. Deep-copying the
# (large) source mesh for each frame is unnecessary.
animated_mesh = source_mesh.copy()
plotter.add_mesh(animated_mesh, color="purple")

# Fix the camera once, framed on both ends of the deformation. Setting
# camera_position = "xy" inside the loop refits the view to each intermediate
# mesh, so the zoom drifts between frames. The end-state mesh is added only so
# that it contributes to the framing bounds, then removed again.
framing_actor = plotter.add_mesh(deformed_mesh, opacity=0.0)
plotter.camera_position = "xy"
plotter.remove_actor(framing_actor)

frames = []
try:
    for t in np.linspace(0, 1, n_frames):
        animated_mesh.points = source_mesh.points + deformation * t
        plotter.render()
        frames.append(plotter.screenshot(return_img=True))
finally:
    # Release the off-screen render window even if a frame fails.
    plotter.close()

# Save as gif.
# NOTE(review): the unit of `duration` depends on the imageio version/plugin
# (seconds in the legacy GIF writer, milliseconds in the newer pillow plugin).
# 0.05 s per frame matches the MP4's 20 fps; verify the resulting GIF speed.
imageio.mimsave("deformation_animation.gif", frames, duration=0.05)

# Save as mp4 (requires the imageio-ffmpeg plugin).
imageio.mimsave("deformation_animation.mp4", frames, fps=fps)