In [None]:

# Install required packages if running in Colab
import sys
import os

if 'google.colab' in sys.modules:
    # update pip and setuptools
    !pip install --upgrade pip setuptools wheel
    # Download requirements.txt from the repo if not present
    if not os.path.exists('requirements_colab.txt'):
        !wget https://raw.githubusercontent.com/gattia/ISB-2025-Shape-Modeling/main/requirements_colab.txt
    # Install requirements
    !pip install -r requirements_colab.txt


In [None]:
import pymskt as mskt
import glob
import os
from itkwidgets import view
import pyvista as pv
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math

import json
import sys

import matplotlib.pyplot as plt

# itkwidgets needs Colab's custom widget manager to render inside Colab. The
# `google.colab` package only exists there, so guard the import to keep the
# notebook runnable in a local Jupyter kernel.
if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

In [None]:
import urllib.request

# Determine if running in Colab (remote) or local
is_colab = 'google.colab' in sys.modules

# Raw-content root of the workshop repository on GitHub.
REPO_RAW_URL = "https://raw.githubusercontent.com/gattia/ISB-2025-Shape-Modeling/main"

# Path to the JSON file listing all mesh filenames
json_path = 'list_meshes.json'

# If running in Colab, download list_meshes.json if not present
if is_colab and not os.path.exists(json_path):
    urllib.request.urlretrieve(f"{REPO_RAW_URL}/list_meshes.json", json_path)

# Load the list of mesh filenames from the JSON file
with open(json_path, 'r') as f:
    mesh_list = json.load(f)  # mesh_list is a list of filenames

# Build the location of every listed mesh:
#   - Colab: URL under the repo's `data/` directory (the same directory the
#     loading cell below downloads from; the previous `.data/` URL was stale).
#   - Local: path under `.data/`, keeping only files that actually exist.
list_tib_paths = []
for mesh_filename in mesh_list:
    if is_colab:
        list_tib_paths.append(f"{REPO_RAW_URL}/data/{mesh_filename}")
    else:
        local_path = os.path.join('.data', mesh_filename)
        if os.path.exists(local_path):
            list_tib_paths.append(local_path)


In [None]:
def normalize_bone(bone, buffer=0.2):
    """
    Center a mesh on its vertex centroid and scale it into the unit sphere, in place.

    After normalization the vertex farthest from the centroid lies at distance
    ``1 - buffer`` from the origin. This leaves a margin of ``buffer`` inside the
    [-1, 1] domain that later cells sample on.

    Parameters
    ----------
    bone : object with a float ``points`` array of shape (n_points, 3)
        Mesh to normalize (e.g. ``mskt.mesh.Mesh`` or ``pyvista.PolyData``).
        Its ``points`` are modified in place.
    buffer : float
        Fractional margin in [0, 1) between the farthest vertex and the unit sphere.

    Returns
    -------
    bone
        The same object, returned for convenience so calls can be chained.

    Raises
    ------
    ValueError
        If ``buffer`` is outside [0, 1), the mesh has no points, or all points
        coincide (the scale factor would be a division by zero).
    """
    if not 0 <= buffer < 1:
        raise ValueError(f"buffer must be in [0, 1), got {buffer}")
    if len(bone.points) == 0:
        raise ValueError("cannot normalize a mesh with no points")

    centroid = np.mean(bone.points, axis=0)
    bone.points -= centroid

    max_norm = np.max(np.linalg.norm(bone.points, axis=1))
    if max_norm == 0:
        raise ValueError("cannot normalize a mesh whose points all coincide")

    # Divide by max_norm / (1 - buffer) so the farthest vertex lands at radius 1 - buffer.
    bone.points /= (max_norm / (1 - buffer))
    return bone
    

In [None]:
# Load up to `n` tibia meshes, normalize each one, and rigidly register every mesh
# after the first successfully loaded one onto it (the reference).
n = 5
n_to_load = min(n, len(mesh_list))  # avoid IndexError when fewer meshes are listed

if is_colab:
    import urllib.request
    data_url = "https://raw.githubusercontent.com/gattia/ISB-2025-Shape-Modeling/main/data"

list_tibs = []
ref_tibia = None  # set to the first mesh that loads; later meshes register onto it

for idx in range(n_to_load):
    mesh_filename = mesh_list[idx]
    print(f'Loading: {idx}/{n_to_load}')
    if is_colab:
        # Remote: download once into /content (re-runs reuse the file), then read it
        url = f"{data_url}/{mesh_filename}"
        local_path = os.path.join('/content', mesh_filename)
        if not os.path.exists(local_path):
            print(f"Attempting to download: {url}")
            urllib.request.urlretrieve(url, local_path)
        tibia_mesh = pv.read(local_path)
    else:
        # Local: load from .data directory, skipping files that are not present
        local_path = os.path.join('.data', mesh_filename)
        if not os.path.exists(local_path):
            continue
        tibia_mesh = mskt.mesh.Mesh(local_path)

    tibia = mskt.mesh.Mesh(tibia_mesh)
    normalize_bone(tibia)

    if ref_tibia is None:
        # Choose the reference by "first loaded" rather than idx == 0 so a missing
        # first file does not leave ref_tibia undefined for the registrations below.
        ref_tibia = tibia
        list_tibs.append(ref_tibia)
        print('Using first as reference... not doing registration. ')
        continue

    # NOTE(review): the return value is discarded, so this relies on
    # rigidly_register transforming `tibia` in place; confirm against the
    # installed pymskt version.
    tibia.rigidly_register(ref_tibia, return_transformed_mesh=True)
    list_tibs.append(tibia)

if ref_tibia is None:
    raise RuntimeError("No tibia meshes could be loaded; check mesh_list and the data directory.")

In [None]:

def generate_sdf_points_on_slice(
    mesh,
    N=20_000,
    close_sd=0.01,
    far_sd=0.075,
    slice_axis='y',
    verbose=True
):
    """
    Generate noisy points on a slice of a mesh and compute their SDF values.

    The mesh is sliced with a plane normal to ``slice_axis`` (pyvista places the
    plane through the mesh's bounding-box center). For each noise level, ``N``
    slice vertices are drawn at random and jittered with Gaussian noise along the
    two in-plane axes only. Every sample therefore lies on the slice plane, and
    its SDF describes the displayed cross-section.

    Parameters
    ----------
    mesh : mskt.mesh.Mesh or pyvista.PolyData
        The (already normalized) mesh to slice and compute SDFs against. It is not modified.
    N : int
        Number of points per noise level (total points will be 2*N).
    close_sd : float
        Standard deviation of noise for the 'close' set.
    far_sd : float
        Standard deviation of noise for the 'far' set.
    slice_axis : str
        Axis normal to the slice plane ('x', 'y', or 'z').
    verbose : bool
        If True, print progress.

    Returns
    -------
    slice_ : pyvista.PolyData
        The sliced mesh.
    pts_ : pyvista.PolyData
        The 2*N generated points (close set first, then far set), with SDF values
        in the 'implicit_distance' array.

    Raises
    ------
    ValueError
        If ``slice_axis`` is not 'x', 'y' or 'z', or the slice contains no points.
    """
    axis_to_index = {'x': 0, 'y': 1, 'z': 2}
    if slice_axis not in axis_to_index:
        raise ValueError(f"slice_axis must be 'x', 'y' or 'z', got {slice_axis!r}")
    normal_idx = axis_to_index[slice_axis]
    in_plane = [i for i in range(3) if i != normal_idx]

    slice_ = mesh.slice(slice_axis)
    slice_pts = np.asarray(slice_.points, dtype=float)
    if slice_pts.shape[0] == 0:
        raise ValueError("the slice does not intersect the mesh; nothing to sample from")

    point_sets = []
    for i, sd in enumerate((close_sd, far_sd)):
        if verbose:
            print(f"Generating points set {i} with SD={sd}")
        # Fancy indexing copies, so the jitter below never touches slice_pts.
        anchors = slice_pts[np.random.randint(0, slice_pts.shape[0], size=N)]
        # Jitter only the in-plane coordinates; the normal coordinate stays on the plane.
        anchors[:, in_plane] += np.random.normal(loc=0, scale=sd, size=(N, 2))
        point_sets.append(anchors)
    pts = np.concatenate(point_sets, axis=0)

    pts_ = pv.PolyData(pts)
    pts_.compute_implicit_distance(mesh, inplace=True)
    return slice_, pts_

# Example: sample SDF points on a y-slice of the reference tibia, then show the
# slice, the full reference mesh, and the sampled points in one viewer.
example_kwargs = dict(N=20_000, close_sd=0.01, far_sd=0.075, slice_axis='y')
slice_, pts_ = generate_sdf_points_on_slice(ref_tibia, **example_kwargs)

view(
    geometries=[slice_, ref_tibia],
    point_sets=pts_,
    point_size=2,
)


In [None]:
class SimpleMLP(nn.Module):
    """
    Small MLP decoder.

    It concatenates input coordinates with a latent code and maps the result
    through two ReLU hidden layers to ``output_dim`` values (e.g. one SDF value).
    """

    def __init__(self, input_dim=2, latent_dim=8, hidden_dim=64, output_dim=1):
        """
        input_dim: dimension of input features (e.g., 2 for xz)
        latent_dim: dimension of latent vector to concatenate
        hidden_dim: hidden layer size
        output_dim: output size (e.g., 1 for SDF)
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x, latent):
        """
        x: (..., input_dim) coordinates.
        latent: (..., latent_dim) whose leading dims match x, or any shape that
            broadcasts to them. Examples are a single shared code of shape
            (latent_dim,), or per-shape codes of shape (batch, 1, latent_dim)
            for x of shape (batch, n_points, input_dim).
        Returns: (..., output_dim).
        """
        # torch.cat does not broadcast, so expand the latent to x's leading
        # shape explicitly (a no-op when the shapes already match).
        latent = latent.expand(*x.shape[:-1], latent.shape[-1])
        x_cat = torch.cat([x, latent], dim=-1)
        return self.net(x_cat)

class PointsSDFDataset(Dataset):
    """
    Per-mesh dataset of in-plane slice coordinates and clamped SDF values.

    Item ``idx`` is ``(idx, coords, sdf)``:
    - ``coords`` is an ``(n_sample, 2)`` float32 tensor of randomly chosen sample
      locations, expressed in the two axes that span the slice plane.
    - ``sdf`` is the matching ``(n_sample, 1)`` tensor of clamped signed distances.
    - ``idx`` is returned so the training loop can look up the mesh's latent code.
    """

    # Coordinate columns that span the slice plane, keyed by the slice normal axis.
    _IN_PLANE_COLUMNS = {'x': [1, 2], 'y': [0, 2], 'z': [0, 1]}

    def __init__(self, mesh_list, n_sample=500, N=20_000, close_sd=0.01, far_sd=0.075, max_sdf=0.1, slice_axis='y', verbose=False):
        """
        mesh_list: list of meshes (already normalized) to sample from
        n_sample: number of points returned per item (drawn at random each access)
        N: number of points per noise level (total points per mesh will be 2*N)
        close_sd: standard deviation for 'close' noise
        far_sd: standard deviation for 'far' noise
        max_sdf: clamp SDF values to [-max_sdf, max_sdf]
        slice_axis: axis to slice along ('x', 'y', or 'z')
        verbose: print progress

        Raises ValueError if slice_axis is invalid or mesh_list is empty.
        """
        if slice_axis not in self._IN_PLANE_COLUMNS:
            raise ValueError(f"slice_axis must be 'x', 'y' or 'z', got {slice_axis!r}")
        columns = self._IN_PLANE_COLUMNS[slice_axis]

        # `xz` keeps its historical name; for slice_axis='y' the columns are x and z.
        self.xz = []
        self.sdf = []
        self.n_sample = n_sample
        for mesh in mesh_list:
            _, pts_ = generate_sdf_points_on_slice(
                mesh, N=N, close_sd=close_sd, far_sd=far_sd, slice_axis=slice_axis, verbose=verbose
            )
            pts = np.asarray(pts_.points)                    # (2*N, 3)
            sdf = np.asarray(pts_['implicit_distance'])      # (2*N,)
            xz_tensor = torch.tensor(pts[:, columns], dtype=torch.float32)
            sdf_tensor = torch.tensor(sdf, dtype=torch.float32).unsqueeze(1)
            sdf_tensor = torch.clamp(sdf_tensor, min=-max_sdf, max=max_sdf)
            self.xz.append(xz_tensor)
            self.sdf.append(sdf_tensor)

        if not self.xz:
            raise ValueError("mesh_list is empty; PointsSDFDataset needs at least one mesh")
        self.point_batch_size = self.xz[0].shape[0]  # assumes all have same number of points

    def __len__(self):
        # Number of meshes
        return len(self.xz)

    def __getitem__(self, idx):
        # Returns (idx, coords, sdf) with n_sample randomly chosen rows of the idx-th mesh.
        n_points = self.xz[idx].shape[0]
        # Sample with replacement only when more rows are requested than exist, so
        # every item has exactly n_sample rows and the DataLoader can stack batches.
        indices = np.random.choice(n_points, size=self.n_sample, replace=self.n_sample > n_points)
        return idx, self.xz[idx][indices], self.sdf[idx][indices]

# Hyperparameters
num_epochs = 40_000
n_samples = 500          # SDF samples drawn per mesh per step
batch_size = 2           # meshes per step
latent_dim = 32
latent_init_std = 0.1
code_reg_lambda = 1e-4   # weight of the latent-code norm penalty (DeepSDF-style)
log_every = 1_000        # print the epoch loss every `log_every` epochs
slice_axis = 'y'
SEED = 0

# Seed both RNGs used below (torch: latent/MLP init; numpy: SDF point sampling).
torch.manual_seed(SEED)
np.random.seed(SEED)

# One trainable latent code per training mesh (auto-decoder).
lat_vecs = torch.nn.Embedding(len(list_tibs), latent_dim, max_norm=10.0)
torch.nn.init.normal_(
    lat_vecs.weight.data,
    0.0,
    latent_init_std / math.sqrt(latent_dim),
)

dataset = PointsSDFDataset(
    list_tibs, n_sample=n_samples, N=20_000,
    close_sd=0.01, far_sd=0.075,
    max_sdf=0.1, slice_axis=slice_axis,
    verbose=True
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = SimpleMLP(input_dim=2, latent_dim=latent_dim, hidden_dim=32, output_dim=1)
# Optimize the latent codes together with the decoder. Without lat_vecs in the
# optimizer the codes stay at their random init and the norm penalty does nothing.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(lat_vecs.parameters()), lr=1e-4
)
criterion = nn.L1Loss()

for epoch in range(num_epochs):
    running_loss = 0.0
    for idx, xz, sdf in dataloader:
        optimizer.zero_grad()
        batch_latents = lat_vecs(idx)  # (B, latent_dim)
        # Repeat each mesh's code for every one of its sampled points.
        latents = batch_latents.unsqueeze(1).expand(-1, xz.size(1), -1)
        pred = model(xz, latents)
        sdf_loss = criterion(pred, sdf)
        # Penalize the mean code norm per mesh (not summed per point), so the
        # regularizer does not scale with n_samples and swamp the SDF term.
        lat_loss = code_reg_lambda * torch.norm(batch_latents, dim=-1).mean()
        loss = sdf_loss + lat_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * xz.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    if epoch == 0 or (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.6f}")



In [None]:
LATENT_IDX = 0

# Evaluate the trained decoder for mesh LATENT_IDX on a regular (x, z) grid that
# covers [-1, 1] with 0.01 spacing (linspace avoids float-step arange drift).
grid_1d = np.linspace(-1, 1, 201)
xx, zz = np.meshgrid(grid_1d, grid_1d)
xz_grid = np.stack([xx.ravel(), zz.ravel()], axis=1)
xz_tensor = torch.from_numpy(xz_grid).float()

with torch.no_grad():
    latent = lat_vecs(torch.tensor(LATENT_IDX)).expand(xz_tensor.size(0), -1)
    sdf_pred = model(xz_tensor, latent).cpu().numpy().flatten()

# Slice the selected mesh here instead of reusing a `slice_` left by an earlier
# cell. Lift the grid onto that slice's plane so the field overlays the cross-section.
slice_ = list_tibs[LATENT_IDX].slice('y')
plane_y = float(np.mean(slice_.points[:, 1]))
xyz = np.column_stack([xz_grid[:, 0], np.full(len(xz_grid), plane_y), xz_grid[:, 1]])

xyz_ = pv.PolyData(xyz)
xyz_['sdf'] = sdf_pred

view(geometries=[slice_, list_tibs[LATENT_IDX]], point_sets=xyz_, point_size=2)

In [None]:
# Generate a 3D SDF image using the generative model.
# The decoder only sees (x, z), so it is evaluated on a two-voxel-thick slab around
# the slice plane. Both y-layers get identical SDF values, and marching cubes
# yields a thin extrusion of the learned 2D zero level set.
slice_ = list_tibs[LATENT_IDX].slice('y')
plane_y = float(np.mean(slice_.points[:, 1]))  # y of the cutting plane

n = 100
half_thickness = 0.01
x_min, y_min, z_min = -1, plane_y - half_thickness, -1
x_max, y_max, z_max = 1, plane_y + half_thickness, 1
spacing = (
    (x_max - x_min) / (n - 1),
    (y_max - y_min),  # two samples along y span exactly [y_min, y_max]
    (z_max - z_min) / (n - 1),
)
grid = pv.ImageData(
    dimensions=(n, 2, n),
    spacing=spacing,
    origin=(x_min, y_min, z_min),
)
x, y, z = grid.points.T

# Prepare input for the generative model: (x, z) and latent
xz_grid_3d = np.stack([x, z], axis=1)
xz_tensor_3d = torch.from_numpy(xz_grid_3d).float()

with torch.no_grad():
    latent = lat_vecs(torch.tensor(LATENT_IDX)).expand(xz_tensor_3d.size(0), -1)
    sdf_pred_3d = model(xz_tensor_3d, latent).cpu().numpy().flatten()

# Marching cubes: extract the zero level set
mesh = grid.contour([0], sdf_pred_3d, method='marching_cubes')
if mesh.n_points > 0:
    # compute_normals returns a new mesh by default; keep the result.
    mesh = mesh.compute_normals()

# Visualize the mesh
view(geometries=[mesh, slice_])


In [None]:
# Generate a 3D SDF image using the generative model and compare it with the
# true cross-section of mesh LATENT_IDX.
# The 2D decoder ignores y, so a two-voxel slab centred on the slice plane is
# evaluated and contoured at SDF = 0.
slice_ = list_tibs[LATENT_IDX].slice('y')
plane_y = float(np.mean(slice_.points[:, 1]))  # y of the cutting plane

n = 100
half_thickness = 0.01
x_min, y_min, z_min = -1, plane_y - half_thickness, -1
x_max, y_max, z_max = 1, plane_y + half_thickness, 1
spacing = (
    (x_max - x_min) / (n - 1),
    (y_max - y_min),  # two samples along y span exactly [y_min, y_max]
    (z_max - z_min) / (n - 1),
)
grid = pv.ImageData(
    dimensions=(n, 2, n),
    spacing=spacing,
    origin=(x_min, y_min, z_min),
)
x, y, z = grid.points.T

# Prepare input for the generative model: (x, z) and latent
xz_grid_3d = np.stack([x, z], axis=1)
xz_tensor_3d = torch.from_numpy(xz_grid_3d).float()

with torch.no_grad():
    latent = lat_vecs(torch.tensor(LATENT_IDX)).expand(xz_tensor_3d.size(0), -1)
    sdf_pred_3d = model(xz_tensor_3d, latent).cpu().numpy().flatten()

# Marching cubes: extract the zero level set
mesh = grid.contour([0], sdf_pred_3d, method='marching_cubes')
if mesh.n_points > 0:
    # compute_normals returns a new mesh by default; keep the result.
    mesh = mesh.compute_normals()

# Visualize the reconstructed contour against the true slice
view(geometries=[mesh, slice_])
