In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import renderers
import sdf
import csg

KeyboardInterrupt: 

In [2]:
# Prefer the second GPU (the original hard-coded target), but fall back to the
# first GPU and then to CPU so the notebook still runs on other machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
def compute_rotation_matrix(axes, angles):
    """Rotation matrices for rotating by ``angles`` about ``axes`` (Rodrigues' formula).

    R = cos(t) I + (1 - cos(t)) n n^T + sin(t) [n]_x, written out element-wise.

    Parameters
    ----------
    axes : torch.Tensor, (..., 3)
        Rotation axes; assumed to be unit length (they are not normalized here).
    angles : torch.Tensor
        Rotation angles in radians, broadcastable against ``axes[..., 0]``.

    Returns
    -------
    torch.Tensor, (..., 3, 3)
    """
    nx, ny, nz = axes.unbind(dim=-1)
    cos_t = torch.cos(angles)
    sin_t = torch.sin(angles)
    k = 1.0 - cos_t  # factor shared by every n_i * n_j term

    row_0 = (nx * nx * k + cos_t, ny * nx * k - nz * sin_t, nz * nx * k + ny * sin_t)
    row_1 = (nx * ny * k + nz * sin_t, ny * ny * k + cos_t, nz * ny * k - nx * sin_t)
    row_2 = (nx * nz * k - ny * sin_t, ny * nz * k + nx * sin_t, nz * nz * k + cos_t)

    return torch.stack([torch.stack(row, dim=-1) for row in (row_0, row_1, row_2)], dim=-2)

def render(signed_distance_function, include_ground=False):
    """Sphere-trace, shade and composite a 1024x1024 image of an SDF scene.

    The pinhole camera (fx = fy = 1024, principal point (512, 512)) sits at
    distance 5 with azimuth = elevation = pi/4 and looks at (0, -1, 0).
    Rays that converge are Phong-shaded with a fixed directional light and a
    slightly randomized material (``torch.rand``; not seeded here), then
    halved in brightness where ``renderers.compute_shadows`` reports a
    shadow. Rays that do not converge are painted 1.0 (white).

    Parameters
    ----------
    signed_distance_function : callable
        Scene SDF, evaluated by the ``renderers`` helpers on tensors of
        points (presumably shaped (..., 3) -- TODO confirm against renderers).
    include_ground : bool, optional
        If True, the plane ``sdf.plane([0, -1, 0], 0)`` is unioned into the
        scene and pixels lying on it are painted flat 0.9. If False (the
        default, matching the previously commented-out union), the plane is
        not part of the scene, so no pixel is recoloured as "ground".

    Returns
    -------
    torch.Tensor
        Composited image built from ``renderers.phong_shading`` output
        (presumably (1024, 1024, 3); values may exceed 1, so clamp before
        converting to 8 bit).
    """

    # ---------------- camera intrinsics ---------------- #

    fx = fy = 1024
    cx = cy = 512
    camera_matrix = torch.tensor([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], device=device).float()

    # ---------------- camera position (orbit around the origin) ---------------- #

    distance = 5.0
    azimuth = np.pi / 4.0
    elevation = np.pi / 4.0

    camera_position = torch.tensor([
        +np.cos(elevation) * np.sin(azimuth),
        -np.sin(elevation),
        -np.cos(elevation) * np.cos(azimuth),
    ], device=device).float() * distance

    # ---------------- camera rotation (look-at) ---------------- #

    target_position = torch.tensor([0.0, -1.0, 0.0], device=device).float()
    up_direction = torch.tensor([0.0, 1.0, 0.0], device=device).float()

    camera_z_axis = target_position - camera_position
    camera_x_axis = torch.cross(up_direction, camera_z_axis, dim=-1)
    camera_y_axis = torch.cross(camera_z_axis, camera_x_axis, dim=-1)
    # The axes become the columns of the rotation; normalizing along dim=-2
    # scales each column to unit length.
    camera_rotation = torch.stack((camera_x_axis, camera_y_axis, camera_z_axis), dim=-1)
    camera_rotation = nn.functional.normalize(camera_rotation, dim=-2)

    # ---------------- directional light ---------------- #

    # NOTE(review): not unit length; confirm the renderers helpers normalize it.
    light_directions = torch.tensor([1.0, -0.5, 0.0], device=device)

    # ---------------- primary rays ---------------- #

    # One ray per integer pixel coordinate (no half-pixel offset).
    y_positions = torch.arange(cy * 2, dtype=camera_matrix.dtype, device=device).float()
    x_positions = torch.arange(cx * 2, dtype=camera_matrix.dtype, device=device).float()
    y_positions, x_positions = torch.meshgrid(y_positions, x_positions, indexing='ij')
    z_positions = torch.ones_like(y_positions).float()
    ray_positions = torch.stack((x_positions, y_positions, z_positions), dim=-1)
    # Back-project onto the z = 1 plane in camera space, then move to world space.
    ray_positions = torch.einsum("mn,...n->...m", torch.inverse(camera_matrix), ray_positions)
    ray_positions = torch.einsum("mn,...n->...m", camera_rotation, ray_positions) + camera_position
    ray_directions = nn.functional.normalize(ray_positions - camera_position, dim=-1)

    # ---------------- scene ---------------- #

    ground = sdf.plane(torch.tensor([0.0, -1.0, 0.0], device=device), 0.0)
    if include_ground:
        signed_distance_function = csg.union(signed_distance_function, ground)

    num_iterations = 100
    convergence_threshold = 1e-1

    # ---------------- sphere tracing ---------------- #

    surface_positions, converged = renderers.sphere_tracing(
        signed_distance_function=signed_distance_function,
        ray_positions=ray_positions,
        ray_directions=ray_directions,
        num_iterations=num_iterations,
        convergence_threshold=convergence_threshold,
    )

    # Zero out positions of rays that did not converge; those pixels are
    # overwritten with the background colour at the end.
    surface_positions = torch.where(converged, surface_positions, torch.zeros_like(surface_positions))

    surface_normals = renderers.compute_normal(
        signed_distance_function=signed_distance_function,
        surface_positions=surface_positions,
    )
    surface_normals = torch.where(converged, surface_normals, torch.zeros_like(surface_normals))

    # ---------------- shading ---------------- #

    image = renderers.phong_shading(
        surface_normals=surface_normals,
        view_directions=camera_position - surface_positions,
        light_directions=light_directions,
        light_ambient_color=torch.ones(1, 1, 3, device=device),
        light_diffuse_color=torch.ones(1, 1, 3, device=device),
        light_specular_color=torch.ones(1, 1, 3, device=device),
        material_ambient_color=torch.full((1, 1, 3), 0.2, device=device) + (torch.rand(1, 1, 3, device=device) * 2 - 1) * 0.1,
        material_diffuse_color=torch.full((1, 1, 3), 0.7, device=device) + (torch.rand(1, 1, 3, device=device) * 2 - 1) * 0.1,
        material_specular_color=torch.full((1, 1, 3), 0.1, device=device),
        material_emission_color=torch.zeros(1, 1, 3, device=device),
        material_shininess=64.0,
    )

    if include_ground:
        # Paint pixels lying on the ground plane a flat light grey. This is only
        # done when the plane is part of the scene; otherwise it would recolour
        # the object's own surface wherever it passes within
        # `convergence_threshold` of the plane.
        grounded = torch.abs(ground(surface_positions)) < convergence_threshold
        image = torch.where(grounded, torch.full_like(image, 0.9), image)

    shadowed = renderers.compute_shadows(
        signed_distance_function=signed_distance_function,
        surface_positions=surface_positions,
        surface_normals=surface_normals,
        light_directions=light_directions,
        num_iterations=num_iterations,
        convergence_threshold=convergence_threshold,
        foreground_masks=converged,
    )
    image = torch.where(shadowed, image * 0.5, image)

    # White background for rays that never hit a surface.
    image = torch.where(converged, image, torch.ones_like(image))

    return image

In [4]:
# ---------------- load target ---------------- #
from pytorch3d.io import load_ply
from pytorch3d.structures import Meshes
import meshplot as mp

def load_target(path):
    """Load a PLY mesh and normalize it to an origin-centred unit cube.

    The mesh is translated so its axis-aligned bounding box is centred at the
    origin, then uniformly scaled so the longest bounding-box side has length
    1. All coordinates therefore end up in [-0.5, 0.5]; the mesh is NOT
    shifted into the first octant.

    Parameters
    ----------
    path : str or path-like
        Path to the ``.ply`` file, read with ``pytorch3d.io.load_ply``.

    Returns
    -------
    verts : torch.Tensor
        Normalized vertex positions (assumed (V, 3)).
    faces : torch.Tensor
        Face vertex indices exactly as returned by ``load_ply``.
    """
    verts, faces = load_ply(path)

    # Convert vertices to a tensor if they are not one already.
    if not isinstance(verts, torch.Tensor):
        verts = torch.tensor(verts, dtype=torch.float32)

    # Axis-aligned bounding box and its centre.
    min_coords = verts.min(dim=0).values
    max_coords = verts.max(dim=0).values
    center = (min_coords + max_coords) / 2.0

    # Uniform scale so the longest side becomes 1. A degenerate mesh (all
    # vertices at one point) has zero extent; only centre it in that case
    # instead of producing inf/NaN coordinates.
    extent = (max_coords - min_coords).max()
    scale_factor = 1.0 / extent if extent > 0 else 1.0

    verts_normalized = (verts - center) * scale_factor
    return verts_normalized, faces

In [5]:
from pytorch3d.ops import knn_points

def multi_indexing(index: torch.Tensor, shape: torch.Size, dim=-2):
    """Broadcast ``index`` so it can drive ``Tensor.gather`` on a tensor of ``shape``.

    Trailing singleton dimensions are appended to ``index`` until it has
    ``len(shape)`` dimensions. It is then expanded to ``shape`` in every
    dimension except ``dim``, which keeps the index's own size.
    """
    target = list(shape)
    missing = len(target) - index.ndim
    if missing > 0:
        # Equivalent to calling unsqueeze(-1) `missing` times (returns a view).
        index = index[(Ellipsis,) + (None,) * missing]
    target[dim] = -1
    return index.expand(*target)


def multi_gather(values: torch.Tensor, index: torch.Tensor, dim=-2):
    """Batched gather along ``dim``, broadcasting ``index`` over trailing dimensions.

    Leading (batch) dimensions of ``index`` line up with those of ``values``;
    ``index``'s last dimension selects positions along ``dim``. For example,
    with ``values`` of shape (N, V, C), ``index`` of shape (N, F) and the
    default ``dim=-2``, the result has shape (N, F, C) and
    ``out[n, f, c] == values[n, index[n, f], c]``.
    """
    broadcast_index = multi_indexing(index, values.shape, dim)
    return torch.gather(values, dim, broadcast_index)


def winding_number(pts: torch.Tensor, verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """
    Parallel implementation of the generalized winding number of points with respect to a triangle mesh.
    O(n_points * n_faces) memory usage, parallelized execution.

    1. Project the mesh vertices onto the unit sphere centred at every query point.
    2. Compute the signed solid angle of each triangle for each point (L'Huilier's theorem).
    3. Sum the solid angles over all triangles and divide by 4*pi.

    Each triangle's contribution is signed by det(u0, u1, u2). For a closed mesh whose
    triangles are counter-clockwise when seen from outside, the result is therefore ~1
    for interior points and ~0 for exterior points.

    Parameters
    ----------
    pts    : torch.Tensor, (n_points, 3)
    verts  : torch.Tensor, (n_verts, 3)
    faces  : torch.Tensor, (n_faces, 3), integer vertex indices

    Returns
    -------
    torch.Tensor, (n_points,)

    This implementation also accepts one or more leading batch dimensions.
    """
    # Unit vectors from every point to every vertex: (..., n_points, n_verts, 3).
    # Normalizing per vertex (and gathering per face afterwards) is cheaper than per face corner.
    uv = verts[..., None, :, :] - pts[..., :, None, :]
    # clamp_min avoids 0/0 = NaN when a point coincides with a vertex. The resulting zero
    # vector gives the faces touching that vertex a zero sign (no contribution).
    uv = uv / uv.norm(dim=-1, keepdim=True).clamp_min(1e-12)

    # Per-point copy of the face indices (materialized by the gathers below): (..., n_points, n_faces, 3).
    expanded_faces = faces[..., None, :, :].expand(*faces.shape[:-2], pts.shape[-2], *faces.shape[-2:])

    # Projected corner directions of every face, per point: (..., n_points, n_faces, 3) each.
    u0 = multi_gather(uv, expanded_faces[..., 0])
    u1 = multi_gather(uv, expanded_faces[..., 1])
    u2 = multi_gather(uv, expanded_faces[..., 2])

    e0 = u1 - u0
    e1 = u2 - u1
    del u1

    # (e0 x e1) . u2 == det(u0, u1, u2): the face's orientation as seen from the point.
    sign = (torch.cross(e0, e1, dim=-1) * u2).sum(dim=-1).sign()

    e2 = u0 - u2
    del u0, u2

    # Chord lengths between projected corners; intermediates are freed early to cap peak memory.
    l0 = e0.norm(dim=-1)
    del e0

    l1 = e1.norm(dim=-1)
    del e1

    l2 = e2.norm(dim=-1)
    del e2

    l = torch.stack([l0, l1, l2], dim=-1)  # (..., n_points, n_faces, 3)

    # Chord -> great-circle arc length. Chords between unit vectors are <= 2, but rounding
    # can push l/2 slightly above 1, where arcsin returns NaN, so clamp first.
    l = 2 * (l / 2).clamp(max=1.0).arcsin()

    # L'Huilier: tan(E/4) = sqrt(tan(s/2) tan((s-a)/2) tan((s-b)/2) tan((s-c)/2)),
    # with s the semi-perimeter and E the spherical excess (= solid angle).
    s = l.sum(dim=-1) / 2
    s0 = s - l[..., 0]
    s1 = s - l[..., 1]
    s2 = s - l[..., 2]

    # eps keeps the sqrt argument strictly positive; NaNs were observed without it
    # (NOTE: presumably from the infinite sqrt gradient at 0 during backward).
    eps = 1e-10
    solid = 4 * (((s / 2).tan() * (s0 / 2).tan() * (s1 / 2).tan() * (s2 / 2).tan()).abs() + eps).sqrt().arctan()
    signed_solid = solid * sign  # (..., n_points, n_faces)

    winding = signed_solid.sum(dim=-1) / (4 * torch.pi)  # (..., n_points)

    return winding

# ---------------- mesh to SDF ---------------- #
def compute_sdf(verts, faces, points, chunk_size=4096):
    """Approximate signed distance from ``points`` to a triangle mesh.

    The magnitude is the Euclidean distance to the nearest mesh *vertex*
    (via ``knn_points``), not to the nearest surface point. NOTE: this is an
    upper bound on the true surface distance, so a sphere tracer driven by it
    can step past the surface where vertices are sparse.
    The value is negative where the generalized winding number exceeds 0.5
    (inside) and positive otherwise.

    Parameters
    ----------
    verts : torch.Tensor, (V, 3)
    faces : torch.Tensor, (F, 3)
    points : torch.Tensor, (..., 3)
    chunk_size : int or None, default 4096
        Number of points per ``winding_number`` call. ``winding_number`` uses
        O(n_points * n_faces) memory, so evaluating a whole image at once can
        exhaust memory. Chunking bounds that cost without changing the result.
        ``None`` (or 0) evaluates all points in one call.

    Returns
    -------
    torch.Tensor, (..., 1)
    """
    original_shape = points.shape[:-1]
    # reshape (unlike view) also accepts non-contiguous inputs.
    points_flat = points.reshape(-1, 3)  # (N, 3)
    # Keep every tensor on the points' device rather than a notebook-global device.
    verts = verts.to(points_flat.device)
    faces = faces.to(points_flat.device)

    # Unsigned part: knn_points returns squared distances to the nearest vertex.
    dists, _, _ = knn_points(points_flat.unsqueeze(0), verts.unsqueeze(0), K=1)  # (1, N, 1)
    distances = dists.squeeze(0).squeeze(-1).sqrt()  # (N,)

    # Sign: inside/outside test via the generalized winding number.
    if not chunk_size:
        winding = winding_number(points_flat, verts, faces)
    else:
        winding = torch.cat([
            winding_number(chunk, verts, faces)
            for chunk in points_flat.split(chunk_size)
        ])
    signs = torch.where(winding > 0.5, -1.0, 1.0)  # (N,)

    return (distances * signs).view(*original_shape, 1)


def mesh_sdf_wrapper(verts, faces):
    """Bind a mesh into a single-argument SDF callable.

    Returns a closure ``f(p)`` that evaluates ``compute_sdf(verts, faces, p)``,
    so the mesh can be passed anywhere a point-to-distance function is
    expected (e.g. ``render``).
    """
    # Named mesh_sdf rather than sdf so it does not shadow the `sdf` module.
    def mesh_sdf(p):
        return compute_sdf(verts, faces, p)

    return mesh_sdf

In [6]:
# ---------------- test ---------------- #


In [7]:
# Reference SDF: sdf.sphere(0.5) translated by a zero offset
# (presumably a radius-0.5 sphere centred at the origin).
signed_distance_functions = sdf.translation(
    sdf.sphere(0.5),
    torch.tensor([0.0, 0.0, 0.0], device=device),
)

In [8]:
# Evaluation grid: four probe points, one per column of the lists below.
# Assuming sdf.sphere(0.5) is a radius-0.5 sphere, they are roughly: on the
# surface (0, 0.5, 0), inside (0.2, 0, 0), outside (0, 0, 0.6), and the centre.
x_positions, y_positions, z_positions = (
    torch.tensor(coords, device=device).float()
    for coords in (
        [0.0, 0.2, 0.0, 0.0],
        [0.5, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.6, 0.0],
    )
)

ray_positions = torch.stack((x_positions, y_positions, z_positions), dim=1)

In [None]:
# Evaluate the reference SDF at the four probe points
signed_distance_functions(ray_positions)

In [10]:
# Render the reference SDF (debug view)
image = render(signed_distance_function=signed_distance_functions)

In [None]:
# Convert the rendered image to 8 bit and display it.
# Clamp to [0, 1] first: shaded values may exceed 1.0, and casting
# out-of-range floats to uint8 does not saturate (values wrap or are undefined).
# New names keep `image` a tensor, so this cell can be re-run.
image_uint8 = (image.clamp(0.0, 1.0) * 255).byte()
image_pil = Image.fromarray(image_uint8.cpu().numpy())

plt.imshow(image_pil)
plt.axis('off')  # Hide the axis
plt.show()

In [12]:
verts, faces = load_target(path='../data/sphere.ply')  # normalized to an origin-centred unit cube

In [13]:
# Create the mesh SDF callable from device copies; `verts`/`faces` themselves stay where they are.
mesh_sdf = mesh_sdf_wrapper(verts=verts.to(device), faces=faces.to(device))

In [None]:
# Evaluate the mesh SDF at the same probe points. If sphere.ply is a sphere,
# load_target's unit-cube scaling should make this roughly match the
# analytic sdf.sphere(0.5) output above (TODO confirm).
mesh_sdf(ray_positions)

In [None]:
# Render the mesh SDF (debug view). NOTE: every SDF evaluation here runs a
# winding-number pass over all pixels, so expect this to be slow.
image_mesh = render(signed_distance_function=mesh_sdf)

In [None]:
# Convert the rendered mesh image to 8 bit and display it.
# Clamp to [0, 1] first: shaded values may exceed 1.0, and casting
# out-of-range floats to uint8 does not saturate (values wrap or are undefined).
# New names keep `image_mesh` a tensor, so this cell can be re-run.
image_mesh_uint8 = (image_mesh.clamp(0.0, 1.0) * 255).byte()
image_mesh_pil = Image.fromarray(image_mesh_uint8.cpu().numpy())

plt.imshow(image_mesh_pil)
plt.axis('off')  # Hide the axis
plt.show()

In [None]:
# Interactive view of the normalized target mesh. `verts`/`faces` were not
# reassigned by the `.to(device)` calls above, so `.numpy()` works as long as
# load_ply returned CPU tensors.
mp.plot(verts.numpy(), faces.numpy())
