In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:


class ShapeSampler(nn.Module):
    def __init__(self, shape_type):       
        super(ShapeSampler, self).__init__()

        self.shape_type = shape_type
        
    def forward(self, points):
        """
        This is a stub function that simulates sampling the SDF for different shapes.
        :param shape_type: A string indicating the type of shape ('sphere', 'cube', 'pyramid').
        :param points: A tensor of points in space (N, 3) we want to sample the SDF for.
        :return: A tensor representing the SDF values at the provided points.
        """
        if self.shape_type == "sphere":
            return self.sample_sphere(points)
        elif self.shape_type == "cube":
            return self.sample_cube(points)
        # elif shape_type == "pyramid":
        #     return self.sample_pyramid(points)
        else:
            raise ValueError("Unknown shape type.")
        
    def sample_sphere(self, points, radius=1.0):
        """
        Simulate SDF sampling for a sphere centered at the origin.
        :param points: Points at which to sample the SDF (N, 3).
        :param radius: Radius of the sphere.
        :return: SDF values at the provided points.
        """
        return torch.norm(points, dim=1) - radius

    def sample_cube(self, points, side_length=2.0):
        """
        Simulate SDF sampling for a cube centered at the origin.
        :param points: Points at which to sample the SDF (N, 3).
        :param side_length: Side length of the cube.
        :return: SDF values at the provided points.
        """
        half_side = side_length / 2
        max_dist = torch.max(torch.abs(points) - half_side, dim=1)[0]
        return max_dist
    
    # def sample_pyramid(self, points, height=2.0, base=2.0):
    #     """
    #     Simulate SDF sampling for a pyramid centered at the origin.
    #     :param points: Points at which to sample the SDF (N, 3).
    #     :param height: Height of the pyramid.
    #     :param base: Base length of the pyramid.
    #     :return: SDF values at the provided points.
    #     """
    #     # This is a simplified stub for the pyramid SDF.
    #     # A true SDF for a pyramid would involve more complex calculations.
    #     x_dist = torch.abs(points[:, 0]) - base/2
    #     y_dist = torch.abs(points[:, 1]) - base/2
    #     z_dist = points[:, 2] - height/2
    #     return torch.max(torch.max(x_dist, y_dist), z_dist)


In [3]:

# Example usage

points = torch.tensor([
    [0.5, 0.5, 0.5], 
    [1, 0, 0], 
    [-1.0, -1.0, -1.0], 
    [0.0, 0.0, 1.5]
])

# Sample SDF for different shapes at the specified points
sphere_sdf = ShapeSampler(shape_type="sphere").forward(points)
cube_sdf = ShapeSampler("cube")(points)
# pyramid_sdf = shape_sampler("pyramid", points)

sphere_sdf, cube_sdf #, pyramid_sdf


(tensor([-0.1340,  0.0000,  0.7321,  0.5000]),
 tensor([-0.5000,  0.0000,  0.0000,  0.5000]))

In [4]:

def initialize_mosaic_sdf(shape_sampler, n_grids, grid_resolution, grid_scale):
    """
    Initialize the Mosaic-SDF representation for a given shape.
    :param shape_sampler: 3D shape sampler.
    :param n_grids: Number of local grids.
    :param grid_resolution: Resolution of each grid (k x k x k).
    :param grid_scale: Scale (size) of each grid.
    :return: A representation of the shape as a set of local grids.
    """
    # Step 1: Normalize the shape to fit within a designated volume.
    
    # Step 2: Sample boundary points and initialize volume centers using farthest point sampling.
    volume_centers = sample_volume_centers(shape_sampler, n_grids)
    
    # Step 3: Initialize all grids with a uniform scale and empty SDF values.
    # TODO select more meaningful initial scale ?
    scales = torch.full((n_grids,), fill_value=grid_scale)
    
    sdf_values = torch.zeros((n_grids, grid_resolution, grid_resolution, grid_resolution))
    
    # Step 4: For each grid, compute the local SDF values (simplified here as zeros).
    for i in range(n_grids):
        # In practice, compute and store the SDF values for points within the grid.
        sdf_values[i] = compute_local_sdf(shape_sampler, volume_centers[i], grid_scale)
    
    # Each grid's representation could be expanded to include its location, scale, and SDF values.
    grids_representation = (volume_centers, scales, sdf_values)
    
    # TODO concat grid_representation into single row?

    return grids_representation

def sample_volume_centers(shape_sampler, n_grids):
    # Placeholder for actual implementation of farthest point sampling.
    return torch.rand((n_grids, 3))  # Randomly initialized for illustration.


def compute_local_sdf(shape_sampler, center, scale, grid_resolution=7):
    x = torch.linspace(center[0] - scale / 2, center[0] + scale / 2, grid_resolution)
    y = torch.linspace(center[1] - scale / 2, center[1] + scale / 2, grid_resolution)
    z = torch.linspace(center[2] - scale / 2, center[2] + scale / 2, grid_resolution)
    x, y, z = torch.meshgrid(x, y, z, indexing='ij')
    grid_points = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
    sdf_values = shape_sampler(grid_points).reshape(grid_resolution, grid_resolution, grid_resolution)
    return sdf_values


In [5]:
# Using the updated ShapeSampler class
shape_sampler = ShapeSampler("sphere")
center = torch.tensor([0.0, 0.0, 0.0])
scale = 2.0

# Compute local SDF values
local_sdf_values = compute_local_sdf(shape_sampler, center, scale, grid_resolution=3)
print(local_sdf_values.shape) 

torch.Size([3, 3, 3])


In [6]:
local_sdf_values

tensor([[[ 0.7321,  0.4142,  0.7321],
         [ 0.4142,  0.0000,  0.4142],
         [ 0.7321,  0.4142,  0.7321]],

        [[ 0.4142,  0.0000,  0.4142],
         [ 0.0000, -1.0000,  0.0000],
         [ 0.4142,  0.0000,  0.4142]],

        [[ 0.7321,  0.4142,  0.7321],
         [ 0.4142,  0.0000,  0.4142],
         [ 0.7321,  0.4142,  0.7321]]])

In [7]:

# def trilinear_interpolation(sdf_values, relative_pos):
#     # Placeholder function for trilinear interpolation
#     # sdf_values: tensor of SDF values in a grid
#     # relative_pos: relative position within the grid, scaled to [0, 1]
#     # This function should interpolate sdf_values based on relative_pos
#     return torch.tensor(0.0)  # Simplified return value for illustration

def trilinear_interpolation_pytorch(grid, relative_pos):
    return None


def scalar_weight_function(distance, scale):
    # Placeholder function for scalar weight based on distance and scale
    return 1.0  # Simplified return value for illustration


In [8]:
class MosaicSDF(nn.Module):
    def __init__(self, shape_sampler: ShapeSampler, grid_resolution=7, n_grids=1024):
        """
        Initialize the MosaicSDF representation.
        
        :param shape_sampler: Shape Sampler facade.
        :param n_grids: Number of local grids.
        :param grid_resolution: Resolution of each grid (assumed cubic for simplicity).
        :param grid_scale: Initial scale (size) for each grid.
        """
        super(MosaicSDF, self).__init__()
        
        self.shape_sampler = shape_sampler

        # Assuming volume_centers, scales, and sdf_values are learnable parameters
        self.volume_centers = nn.Parameter(torch.rand((n_grids, 3)) * 2 - 1)  # Initialize randomly within [-1, 1]
        
        min_rand_scale = .01
        max_rand_scale = 1.
        self.scales = nn.Parameter(torch.rand((n_grids,)) * (max_rand_scale - min_rand_scale) + min_rand_scale)

        self.sdf_values = nn.Parameter(torch.randn(n_grids, grid_resolution, grid_resolution, grid_resolution))
        
    def forward(self, points):
        """
        Compute the SDF values at given points using the Mosaic-SDF representation.
        
        :param points: Tensor of points where SDF values are to be computed (N, 3).
        :return: SDF values at the provided points.
        """
        
        points_expanded_to_grids = points[:, None, :]
        grids_expanded_to_points = self.volume_centers[None, ...]
        scales_expanded_to_points = self.scales[None, ..., None]
        
        points_grid_centers_diff = points_expanded_to_grids - grids_expanded_to_points
        scaled_points_grid_centers_diff = points_grid_centers_diff / scales_expanded_to_points
        
        sdf_values_expanded_to_points = self.sdf_values[None, ...]
        points_interpolated_sdf = trilinear_interpolation(sdf_values_expanded_to_points, scaled_points_grid_centers_diff)
        
        distances_from_point_to_grid_centers = torch.linalg.vector_norm(points_grid_centers_diff, dim=-1)
        
        points_sdf_weight = scalar_weight_function(distances_from_point_to_grid_centers, scales_expanded_to_points)

        points_sdf = torch.sum(points_interpolated_sdf * points_sdf_weight, dim=-1).reshape(points.shape[0], 1)
        return points_sdf


In [43]:

volume_centers = torch.tensor([
    [0.1, 0, 0.5], 
    [0.5, 0.2, 0], 
    [1, 1, .5], 
])

scales = torch.tensor([1,1,1])

points = torch.tensor([
    [0.5, 0.5, 0], 
    [1, 0, .5], 
])



In [44]:
points_expanded_to_grids = points[:, None, :]
grids_expanded_to_points = volume_centers[None, ...]
scales_expanded_to_points = scales[None, ..., None]

points_grid_centers_diff = points_expanded_to_grids - grids_expanded_to_points
scaled_points_grid_centers_diff = points_grid_centers_diff / scales_expanded_to_points

In [45]:
k = 7

sdf_values = torch.rand((scales.shape[0], k, k ,k))

In [46]:

k = 7

def trilinear_interpolation_pytorch(grid, relative_pos):
    """
    Performs trilinear interpolation within a 7x7x7 SDF grid using PyTorch, for a batch of points.

    Args:
        grid (torch.Tensor): Tensor containing the SDF values for multiple grids.
                             Shape: (n_grids, 7, 7, 7).
        relative_pos (torch.Tensor): Batch of 3D coordinates (x, y, z) of points within the grids,
                                     values should be in the range [-1, 1]. Shape: (N, 3).

    Returns:
        torch.Tensor: The interpolated SDF values at the given points. Shape: (N,).
    """
    # TODO
    # if relative_pos is outside -1,1 - return 0 SDF value

    
    # Convert relative positions from [-1, 1] to [0, 6]
    relative_pos_scaled = (relative_pos + 1) * 3
    base_indices = torch.floor(relative_pos_scaled).long()
    upper_indices = torch.ceil(relative_pos_scaled).long()
    
    # Calculate the fractional part for interpolation
    fractional_part = relative_pos_scaled - base_indices.float()

    # Ensure indices are within bounds
    base_indices = torch.clamp(base_indices, 0, 6)
    upper_indices = torch.clamp(upper_indices, 0, 6)

    # Preparing to gather corner values for each grid and each point
    B, N, _ = relative_pos.shape
    idx = torch.arange(B)[:, None]

    # Calculate weights for trilinear interpolation
    fx, fy, fz = fractional_part.unbind(-1)
    weights = torch.stack([
        (1 - fx) * (1 - fy) * (1 - fz),
        (1 - fx) * (1 - fy) * fz,
        (1 - fx) * fy * (1 - fz),
        (1 - fx) * fy * fz,
        fx * (1 - fy) * (1 - fz),
        fx * (1 - fy) * fz,
        fx * fy * (1 - fz),
        fx * fy * fz
    ], dim=-1)  # Shape: (B, N, 8)
    
    # Gather the values at the 8 corners for each point
    corner_values = torch.zeros(N, 8)
    for i in range(8):
        bx, by, bz = [b[:, i] for b in torch.meshgrid(base_indices[:, 0], base_indices[:, 1], base_indices[:, 2], indexing='ij')]
        ux, uy, uz = [u[:, i] for u in torch.meshgrid(upper_indices[:, 0], upper_indices[:, 1], upper_indices[:, 2], indexing='ij')]
        corner_values[:, i] = grid[idx, bx, by, bz]
    
    # Compute the interpolated SDF values
    interpolated_sdf = torch.sum(weights * corner_values, dim=-1)

    return interpolated_sdf


In [47]:
trilinear_interpolation_pytorch(sdf_values, scaled_points_grid_centers_diff)

RuntimeError: torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got:  4  4  1
 5  3  3
[ torch.LongTensor{2,3} ]

In [48]:
print(scaled_points_grid_centers_diff.shape)
scaled_points_grid_centers_diff

torch.Size([2, 3, 3])


tensor([[[ 0.4000,  0.5000, -0.5000],
         [ 0.0000,  0.3000,  0.0000],
         [-0.5000, -0.5000, -0.5000]],

        [[ 0.9000,  0.0000,  0.0000],
         [ 0.5000, -0.2000,  0.5000],
         [ 0.0000, -1.0000,  0.0000]]])

In [49]:
grid = sdf_values
relative_pos = scaled_points_grid_centers_diff

In [50]:

# Convert relative positions from [-1, 1] to [0, 6]
relative_pos_scaled = (relative_pos + 1) * 3
base_indices = torch.floor(relative_pos_scaled).long()
upper_indices = torch.ceil(relative_pos_scaled).long()

# Calculate the fractional part for interpolation
fractional_part = relative_pos_scaled - base_indices.float()

# Ensure indices are within bounds
base_indices = torch.clamp(base_indices, 0, 6)
upper_indices = torch.clamp(upper_indices, 0, 6)

# Preparing to gather corner values for each grid and each point
B, N, _ = relative_pos.shape
idx = torch.arange(B)[:, None]

In [51]:

# Calculate weights for trilinear interpolation
fx, fy, fz = fractional_part.unbind(-1)
weights = torch.stack([
    (1 - fx) * (1 - fy) * (1 - fz),
    (1 - fx) * (1 - fy) * fz,
    (1 - fx) * fy * (1 - fz),
    (1 - fx) * fy * fz,
    fx * (1 - fy) * (1 - fz),
    fx * (1 - fy) * fz,
    fx * fy * (1 - fz),
    fx * fy * fz
], dim=-1)  # Shape: (B, N, 8)

In [56]:
weights.shape

torch.Size([2, 3, 8])

In [52]:
fractional_part

tensor([[[0.2000, 0.5000, 0.5000],
         [0.0000, 0.9000, 0.0000],
         [0.5000, 0.5000, 0.5000]],

        [[0.7000, 0.0000, 0.0000],
         [0.5000, 0.4000, 0.5000],
         [0.0000, 0.0000, 0.0000]]])

In [55]:
fractional_part[:,:,0]

tensor([[0.2000, 0.0000, 0.5000],
        [0.7000, 0.5000, 0.0000]])

In [53]:
fractional_part.unbind(-1)

(tensor([[0.2000, 0.0000, 0.5000],
         [0.7000, 0.5000, 0.0000]]),
 tensor([[0.5000, 0.9000, 0.5000],
         [0.0000, 0.4000, 0.0000]]),
 tensor([[0.5000, 0.0000, 0.5000],
         [0.0000, 0.5000, 0.0000]]))

In [37]:
idx.shape

torch.Size([2, 1])

In [32]:
relative_pos

tensor([[[ 0.5000,  0.5000, -0.5000],
         [ 0.0000,  0.0000,  0.0000],
         [-0.5000, -0.5000, -0.5000]],

        [[ 1.0000,  0.0000,  0.0000],
         [ 0.5000, -0.5000,  0.5000],
         [ 0.0000, -1.0000,  0.0000]]])

In [31]:
relative_pos_scaled

tensor([[[4.5000, 4.5000, 1.5000],
         [3.0000, 3.0000, 3.0000],
         [1.5000, 1.5000, 1.5000]],

        [[6.0000, 3.0000, 3.0000],
         [4.5000, 1.5000, 4.5000],
         [3.0000, 0.0000, 3.0000]]])

In [29]:
base_indices

tensor([[[4, 4, 1],
         [3, 3, 3],
         [1, 1, 1]],

        [[6, 3, 3],
         [4, 1, 4],
         [3, 0, 3]]])

In [30]:
upper_indices

tensor([[[5, 5, 2],
         [3, 3, 3],
         [2, 2, 2]],

        [[6, 3, 3],
         [5, 2, 5],
         [3, 0, 3]]])