In [1]:
import torch
import torch.nn as nn
import numpy as np

from einops import rearrange

from mosaic_sdf import MosaicSDF
from shape_sampler import ShapeSampler
from optimizer import MosaicSDFOptimizer
from mosaic_sdf_visualizer import MosaicSDFVisualizer

from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene

In [2]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device = 'cpu'
k = 3
n_grids = 1
volume_centers = torch.tensor([
    [0,0,0]
    ], dtype=torch.float32)
volume_scales = torch.tensor([1], dtype=torch.float32)

model = MosaicSDF(
    grid_resolution=k,
    n_grids=n_grids,
    volume_centers=volume_centers,
    volume_scales=volume_scales
).to(device)

In [9]:
def get_sdf(p):
    # print(p.shape)
    s = torch.linalg.norm(p, ord=2, dim=-1)
    # print(s.shape)
    return s

model.update_sdf_values(get_sdf)

In [10]:
model.mosaic_sdf_values

tensor([[[[-1.7321, -1.4142, -1.7321],
          [-1.4142, -1.0000, -1.4142],
          [-1.7321, -1.4142, -1.7321]],

         [[-1.4142, -1.0000, -1.4142],
          [-1.0000, -0.0000, -1.0000],
          [-1.4142, -1.0000, -1.4142]],

         [[-1.7321, -1.4142, -1.7321],
          [-1.4142, -1.0000, -1.4142],
          [-1.7321, -1.4142, -1.7321]]]], grad_fn=<ViewBackward0>)

In [11]:
p = torch.tensor([
    [-.99,-.99,-.99],
    [0,0,0],
    [.5,.5,.5]
], dtype=torch.float32)
sdfs = model(p)
sdfs


tensor([-1.6860,  0.0000, -1.1218], grad_fn=<NanToNumBackward0>)

In [8]:
# k = 3
# device = 'cpu'

rel_positions = torch.linspace(-1, 1, steps=k, device=device)
grid_coords = torch.stack(
    torch.meshgrid(rel_positions, rel_positions, rel_positions, indexing='ij'), 
    dim=-1).reshape((-1, 3))
# grid_coords

In [9]:
sdfs = model(grid_coords * .3).reshape((-1, 3))
sdfs


in_grid_weights: tensor([[[[[0.7000, 0.7000, 0.0000],
           [0.7000, 0.7000, 0.0000],
           [0.0000, 0.0000, 0.0000]],

          [[0.7000, 0.7000, 0.0000],
           [0.7000, 0.3000, 0.0000],
           [0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000]]]],



        [[[[1.0000, 0.7000, 1.0000],
           [1.0000, 0.7000, 1.0000],
           [0.0000, 0.0000, 0.0000]],

          [[1.0000, 0.7000, 1.0000],
           [1.0000, 0.3000, 1.0000],
           [0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000]]]],



        [[[[0.0000, 0.7000, 0.7000],
           [0.0000, 0.7000, 0.7000],
           [0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.7000, 0.7000],
           [0.0000, 0.3000, 0.7000],
           [0.0000, 0.0000, 0.0000]],

          [[0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000

tensor([[1.2081, 1.2991, 1.2081],
        [1.2991, 1.3495, 1.2991],
        [1.2081, 1.2991, 1.2081],
        [1.2991, 1.3495, 1.2991],
        [1.3495, 1.4164, 1.3495],
        [1.2991, 1.3495, 1.2991],
        [1.2081, 1.2991, 1.2081],
        [1.2991, 1.3495, 1.2991],
        [1.2081, 1.2991, 1.2081]], grad_fn=<ViewBackward0>)

In [36]:
grid_coords.shape

torch.Size([3, 3, 3, 3])

In [11]:
n_grids = 6
volume_centers = nn.Parameter(torch.rand((n_grids, 3)) * 2 - 1)  # Initialize randomly within [-1, 1]
print(volume_centers.shape)

scales = nn.Parameter(torch.rand((n_grids,)))
print(scales.shape)

torch.Size([6, 3])
torch.Size([6])


In [12]:
volume_centers * scales[:, None]

tensor([[ 0.1015,  0.0511, -0.0690],
        [ 0.0241, -0.4002,  0.1895],
        [-0.1182,  0.0408, -0.4072],
        [-0.4465, -0.1329,  0.0760],
        [ 0.2473,  0.0935, -0.0185],
        [-0.0339, -0.0463, -0.0189]], grad_fn=<MulBackward0>)

In [13]:
volume_centers

Parameter containing:
tensor([[ 0.9275,  0.4668, -0.6302],
        [ 0.0537, -0.8938,  0.4233],
        [-0.1877,  0.0648, -0.6466],
        [-0.5596, -0.1665,  0.0953],
        [ 0.3060,  0.1157, -0.0229],
        [-0.5544, -0.7560, -0.3082]], requires_grad=True)

In [14]:
scales

Parameter containing:
tensor([0.1095, 0.4478, 0.6298, 0.7979, 0.8083, 0.0612], requires_grad=True)

In [39]:
grid_resolution = 3
n_grids = 4  # Reduced for simplicity
mosaic_sdf = MosaicSDF(grid_resolution=grid_resolution, n_grids=n_grids)
# Customize initialization for testing
# mosaic_sdf.volume_centers = nn.Parameter(torch.tensor([[0., 0., 0.], [1., 1., 1.], [-1., -1., -1.], [2., 2., 2.]]))
volume_centers = torch.tensor([
    [0.5, 0.5, 0.5], 
    # [-0.5, -0.5, -0.5], 
    # [0.5, -0.5, 0.5], 
    # [-0.5, 0.5, -0.5]
    ])
mosaic_sdf.volume_centers = nn.Parameter(volume_centers)

scales = torch.ones((volume_centers.shape[0])) * 1
mosaic_sdf.scales = nn.Parameter(scales)



In [57]:

rel_positions = torch.linspace(-1, 1, steps=k, device=device)
grid_coords = torch.stack(
    torch.meshgrid(rel_positions, rel_positions, rel_positions, indexing='ij'), 
    dim=-1)

# torch.norm(grid_coords, dim=-1, p=2)


In [58]:
grid_coords[0]

tensor([[[-1., -1., -1.],
         [-1., -1.,  0.],
         [-1., -1.,  1.]],

        [[-1.,  0., -1.],
         [-1.,  0.,  0.],
         [-1.,  0.,  1.]],

        [[-1.,  1., -1.],
         [-1.,  1.,  0.],
         [-1.,  1.,  1.]]])

In [41]:
sdf_values = torch.zeros((k,k,k))
sdf_values[0, ...] = 0
sdf_values[1, ...] = 1
sdf_values[2, ...] = 2

mosaic_sdf.register_buffer('mosaic_sdf_values', sdf_values)

sdf_values

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]]])

In [52]:
# Test the forward function with a known input and check the output
input_points = torch.tensor([
    # [-1, -1, -1], 

    # [1., 1., 1.]
    # [0, 0, 0], 
    # [2, 2, 2], 
    [0.5, .5, 0.5], 
    # [0.5, 0.5, 0.5], 
    # [-0.5, -0.5, -0.5], 
    ])
expected_sdf_values = torch.tensor(
    [
        # 1,
          1

          ], 
    dtype=torch.float32)  # Fill in based on expected logic
actual_sdf_values = mosaic_sdf(input_points)
assert torch.allclose(expected_sdf_values, actual_sdf_values, atol=1e-6),  actual_sdf_values


In [None]:

init_mosaic_sdf_values = torch.zeros((volume_centers.shape[0], grid_resolution, grid_resolution, grid_resolution))
init_mosaic_sdf_values[0] = torch.tensor([
    []
])
# init_mosaic_sdf_values[1] += 2  # Second grid, SDF values of 2, and so on
# init_mosaic_sdf_values[2] += 3
# init_mosaic_sdf_values[3] += 4
mosaic_sdf.register_buffer('mosaic_sdf_values', init_mosaic_sdf_values)



In [None]:
# Test the forward function with a known input and check the output
input_points = torch.tensor([
    # [0., 0., 0.], 
    # [1., 1., 1.]
    [0.5, 0.5, 0.5], 
    [-0.5, -0.5, -0.5], 
    ])
expected_sdf_values = torch.tensor(
    [1, 0], 
    dtype=torch.float32)  # Fill in based on expected logic
actual_sdf_values = mosaic_sdf(input_points)
assertTrue(
    torch.allclose(expected_sdf_values, actual_sdf_values, atol=1e-6), 
    actual_sdf_values)

In [61]:
in_grid_offsets = torch.linspace(-1, 1, k)

x, y, z = torch.meshgrid(in_grid_offsets, in_grid_offsets, in_grid_offsets, indexing='ij')

grid_offsets = (
    torch.stack([x, y, z], dim=-1)
        .reshape((-1, 3))
        .to(device)
)
# grid_offsets

In [62]:

x, y, z = torch.meshgrid(in_grid_offsets, in_grid_offsets, in_grid_offsets)

grid_offsets2 = (
    torch.stack([x, y, z], dim=-1)
        .reshape((-1, 3))
        .to(device)
)

grid_offsets2 == grid_offsets

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])