In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import renderers
import sdf
import csg

In [2]:
# Prefer the first CUDA GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# ---------------- load target ---------------- #
from pytorch3d.io import load_ply
from pytorch3d.structures import Meshes
import meshplot as mp

def load_target(path):
    """Load a PLY mesh and normalise it into the first-octant unit cube.

    The vertices are translated so the axis-aligned bounding box starts at the
    origin, then uniformly scaled so the longest bounding-box side has length 1.
    The aspect ratio is preserved, so the shorter sides end up shorter than 1.

    Parameters
    ----------
    path : str
        Path to a ``.ply`` file readable by ``pytorch3d.io.load_ply``.

    Returns
    -------
    verts : torch.Tensor, (n_verts, 3)
        Normalised vertex positions, every coordinate in [0, 1].
    faces
        Face indices exactly as returned by ``load_ply``.

    Raises
    ------
    ValueError
        If the mesh has no vertices.
    """
    verts, faces = load_ply(path)

    # Convert vertices to a tensor if not already
    if not isinstance(verts, torch.Tensor):
        verts = torch.tensor(verts, dtype=torch.float32)

    if verts.numel() == 0:
        raise ValueError(f"mesh at {path!r} has no vertices")

    # Axis-aligned bounding box of the mesh.
    min_coords = verts.min(dim=0).values
    max_coords = verts.max(dim=0).values

    # Uniform scale: the longest bbox side maps to length 1.
    extent = (max_coords - min_coords).max()
    # A single point (or all vertices coincident) has zero extent; dividing by it
    # would make every coordinate 0 * inf = NaN. Leave such meshes unscaled.
    if extent <= 0:
        extent = torch.ones_like(extent)

    # Centre -> scale -> shift-so-min-is-0 collapses to one shift by the bbox
    # minimum followed by the scale; the result lies in the first octant.
    verts_final = (verts - min_coords) / extent

    return verts_final, faces

from pytorch3d.ops import knn_points

def multi_indexing(index: torch.Tensor, shape: torch.Size, dim=-2):
    """Expand ``index`` into a ``torch.gather``-ready index for a tensor of ``shape``.

    Singleton axes are appended to ``index`` until it has ``len(shape)``
    dimensions. Every axis except ``dim`` is then broadcast to the matching size
    in ``shape``; the size along ``dim`` is kept from ``index``.
    """
    target_shape = list(shape)
    target_shape[dim] = -1  # -1 keeps index's own length on the gathered axis
    n_missing = len(target_shape) - index.ndim
    # Trailing `None`s append singleton axes; with n_missing <= 0 this is a plain view.
    padded = index[(..., *([None] * n_missing))]
    return padded.expand(*target_shape)


def multi_gather(values: torch.Tensor, index: torch.Tensor, dim=-2):
    """Gather slices of ``values`` along ``dim`` with a (possibly batched) index.

    ``index`` is padded with trailing singleton axes up to ``values.ndim``. It is
    then broadcast to ``values.shape`` on every axis except ``dim``. As a result,
    each index selects a whole trailing slice (e.g. a full xyz row when
    ``dim=-2``), like linear indexing, while leading batch axes pass through.
    """
    expand_to = list(values.shape)
    expand_to[dim] = -1
    pad = (None,) * (values.ndim - index.ndim)  # empty when index already has enough axes
    return values.gather(dim, index[(...,) + pad].expand(*expand_to))


def winding_number(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """
    Parallel implementation of the Generalized Winding Number of points w.r.t. a triangle mesh.
    O(n_points * n_faces) memory usage, parallelized execution.

    1. Project every vertex onto the unit sphere centred at each query point
    2. Compute the signed solid angle of each triangle for each point (L'Huilier's theorem)
    3. Sum the solid angles over all triangles and divide by 4*pi

    Parameters
    ----------
    pts    : torch.Tensor, (n_points, 3)  query points
    verts  : torch.Tensor, (n_verts, 3)   mesh vertex positions
    faces  : torch.Tensor, (n_faces, 3)   vertex indices of each triangle

    Returns
    -------
    torch.Tensor, (n_points,)
        Generalized winding number of each point. For a closed mesh whose
        triangles are wound so the right-hand-rule normal points outward,
        this is ~1 inside and ~0 outside.

    This implementation is also able to take a/multiple batch dimension.
    Query points lying exactly on a vertex or an edge stay finite: the
    projection norm and the arcsin argument are clamped to avoid 0/0 and asin(>1).
    """
    # projection onto unit sphere: verts implementation gives a little bit more performance
    uv = verts[..., None, :, :] - pts[..., :, None, :]  # n_points, n_verts, 3
    # Clamp the norm so a query point sitting exactly on a vertex yields a zero
    # vector instead of 0/0 = NaN; for every other point this is a no-op.
    uv = uv / uv.norm(dim=-1, keepdim=True).clamp_min(1e-12)  # n_points, n_verts, 3

    # gather from the computed vertices (will result in a copy for sure)
    expanded_faces = faces[..., None, :, :].expand(*faces.shape[:-2], pts.shape[-2], *faces.shape[-2:])  # n_points, n_faces, 3

    u0 = multi_gather(uv, expanded_faces[..., 0])  # n, f, 3
    u1 = multi_gather(uv, expanded_faces[..., 1])  # n, f, 3
    u2 = multi_gather(uv, expanded_faces[..., 2])  # n, f, 3

    e0 = u1 - u0  # n, f, 3
    e1 = u2 - u1  # n, f, 3
    del u1  # drop references early to limit peak memory (tensors are n_points x n_faces x 3)

    # compute solid angle signs: (e0 x e1) . u2 equals det[u0, u1, u2], i.e. the
    # orientation of the triangle as seen from the query point
    sign = (torch.cross(e0, e1, dim=-1) * u2).sum(dim=-1).sign()

    e2 = u0 - u2
    del u0, u2

    l0 = e0.norm(dim=-1)
    del e0

    l1 = e1.norm(dim=-1)
    del e1

    l2 = e2.norm(dim=-1)
    del e2

    # chord lengths between the projected vertices: n_points, n_faces, 3
    l = torch.stack([l0, l1, l2], dim=-1)

    # chord -> great-circle arc length: arc = 2 * asin(chord / 2). Chords of unit
    # vectors are <= 2, but rounding can push chord / 2 just above 1 (query point
    # on an edge), where asin is NaN, so clamp first.
    l = 2 * (l / 2).clamp(max=1.0).arcsin()  # n_points, n_faces, 3

    # L'Huilier: tan(E/4) = sqrt(tan(s/2) tan((s-a)/2) tan((s-b)/2) tan((s-c)/2)),
    # with s the spherical semi-perimeter and E the solid angle: n_points, n_faces
    s = l.sum(dim=-1) / 2
    s0 = s - l[..., 0]
    s1 = s - l[..., 1]
    s2 = s - l[..., 2]

    # compute solid angle: and generalized winding number: n_points, n_faces
    eps = 1e-10  # NOTE: will cause nan if not bigger than 1e-10
    solid = 4 * (((s/2).tan() * (s0/2).tan() * (s1/2).tan() * (s2/2).tan()).abs() + eps).sqrt().arctan()
    signed_solid = solid * sign  # n_points, n_faces

    winding = signed_solid.sum(dim=-1) / (4 * torch.pi)  # n_points

    return winding

# ---------------- mesh to SDF ---------------- #
def compute_sdf(verts, faces, points):
    """Approximate signed distance from ``points`` to the mesh (negative inside).

    The magnitude is the Euclidean distance to the nearest mesh *vertex*, found
    with ``knn_points``. This is an upper bound on the true distance to the
    surface and equals it only where the closest surface point is a vertex.
    The sign comes from the generalized winding number: a point counts as
    inside when the winding number exceeds 0.5.

    Parameters
    ----------
    verts : torch.Tensor, (n_verts, 3)
        Mesh vertex positions.
    faces : torch.Tensor, (n_faces, 3)
        Vertex indices of each triangle.
    points : torch.Tensor, (..., 3)
        Query points with any number of leading dimensions.

    Returns
    -------
    torch.Tensor, (..., 1)
        Signed distances with the leading shape of ``points`` plus a trailing
        singleton axis, on the device of ``points``.
    """
    original_shape = points.shape[:-1]
    # reshape (not view) so non-contiguous inputs, e.g. permuted grids, also work
    points_flat = points.reshape(-1, 3)  # (N, 3)

    # Compute on the query points' device instead of a notebook-global `device`,
    # so the distance and sign tensors are guaranteed to live on the same device.
    target_device = points_flat.device
    verts = verts.to(target_device)
    faces = faces.to(target_device)

    # knn_points returns squared distances to the nearest vertex: (1, N, 1)
    dists, _, _ = knn_points(points_flat.unsqueeze(0), verts.unsqueeze(0), K=1)
    unsigned = dists.squeeze(0).squeeze(-1).sqrt()  # (N,)

    # Inside (winding number > 0.5) -> negative. Memory is O(N * n_faces).
    winding = winding_number(points_flat, verts, faces)  # (N,)
    signs = torch.where(winding > 0.5, -1.0, 1.0)

    return (unsigned * signs).reshape(original_shape).unsqueeze(-1)


def mesh_sdf_wrapper(verts, faces):
    """Bind a mesh to ``compute_sdf`` and return a one-argument SDF callable.

    The returned function maps query points ``p`` to ``compute_sdf(verts, faces, p)``.
    The closure is named ``_mesh_sdf`` so it is not confused with the imported
    ``sdf`` module.
    """
    def _mesh_sdf(p):
        return compute_sdf(verts, faces, p)

    return _mesh_sdf

In [4]:
# Load the target mesh (normalised into the unit cube) and expose it as an SDF callable.
verts, faces = load_target('../data/sphere.ply')
mesh_sdf = mesh_sdf_wrapper(
    verts.to(device),
    faces.to(device),
)

In [5]:
import torch.optim as optim

# Fix the RNG so the random initial parameters (and hence the whole fit) are reproducible.
torch.manual_seed(0)

# Learnable parameters of the fitted primitive (see the training loop):
# tensor_1 (size 1) is passed to sdf.sphere (presumably the radius - confirm),
# tensor_3 (size 3) is passed to sdf.translation.
tensor_1 = torch.randn(1, requires_grad=True, device=device)
tensor_3 = torch.randn(3, requires_grad=True, device=device)

# Plain SGD over both parameter tensors.
optimizer = optim.SGD([tensor_1, tensor_3], lr=0.01)

In [6]:
# Regular n_samples^3 sampling grid over the unit cube [0, 1]^3 (raise n_samples for finer resolution).
n_samples = 10

# Identical, independent 1-D sample positions along each axis.
x, y, z = (torch.linspace(0, 1, n_samples) for _axis in range(3))

# 'ij' indexing: after flattening, x varies slowest and z fastest.
x_grid, y_grid, z_grid = torch.meshgrid(x, y, z, indexing='ij')

# Flatten into an (n_samples**3, 3) list of xyz positions on the compute device.
grid_positions = (
    torch.stack((x_grid, y_grid, z_grid), dim=-1)
    .reshape(-1, 3)
    .to(device)
)

In [None]:
# Target: the mesh SDF sampled on the grid. It is fixed, so no autograd graph is needed.
with torch.no_grad():
    target_sdf = mesh_sdf(grid_positions).to(device)

n_epochs = 100_000  # total optimisation steps
log_every = 100     # print progress every `log_every` steps

# Training loop
for epoch in range(n_epochs):
    optimizer.zero_grad()  # Zero the gradients

    # Evaluate the current primitive: sdf.sphere(tensor_1) translated by tensor_3.
    current_sdf = sdf.translation(sdf.sphere(tensor_1), tensor_3)(grid_positions).to(device)

    # MSE between predicted and target SDF values. Both are flattened so that an
    # (N,) vs (N, 1) shape mismatch cannot silently broadcast to an (N, N) comparison.
    loss_1 = torch.nn.functional.mse_loss(current_sdf.reshape(-1), target_sdf.reshape(-1))

    # Total loss (add further weighted terms here if needed)
    total_loss = loss_1

    # Backward pass and optimize
    total_loss.backward()
    optimizer.step()

    if epoch % log_every == 0:
        print(f'Epoch {epoch}: Loss = {total_loss.item()}')
        print(f'Tensor 1: {tensor_1.detach()}')
        print(f'Tensor 3: {tensor_3.detach()}')

print("Optimization completed.")

In [None]:
# Reference SDF: sdf.sphere(1.5) translated by (10, -10, 10), evaluated on the same grid.
sphere = sdf.translation(
    sdf.sphere(1.5),
    torch.tensor([10.0, -10.0, 10.0], device=device),
)
eval_sdf = sphere(grid_positions)
eval_sdf

In [None]:
# Compare the mesh SDF against the analytic reference on the same grid.
# `eval_mesh` is computed here so the cell runs on a fresh kernel. Both operands are
# flattened so an (N, 1) vs (N,) mismatch cannot broadcast into an (N, N) difference.
# NOTE: a signed mean can cancel out; use `.abs().mean()` for an error magnitude.
eval_mesh = mesh_sdf(grid_positions)
(eval_mesh.reshape(-1) - eval_sdf.reshape(-1)).mean()