In [None]:
import torch
import torch.nn as nn
import numpy as np

from einops import rearrange

from mosaic_sdf import MosaicSDF
from shape_sampler import ShapeSampler
from optimizer import MosaicSDFOptimizer
from mosaic_sdf_visualizer import MosaicSDFVisualizer

from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene

In [None]:

def get_sdf(p, radius=1.0):
    """Signed distance from points to an origin-centred sphere.

    Args:
        p: array-like of points; the coordinates live on the last axis
            (the Euclidean norm is taken over ``axis=-1``).
        radius: sphere radius (default 1.0, i.e. the unit sphere).

    Returns:
        ndarray with the last axis of ``p`` reduced away: negative inside
        the sphere, zero on its surface, positive outside.
    """
    return np.linalg.norm(p, ord=2, axis=-1) - radius

In [None]:
# Compute device for the model (CPU only for now).
device = 'cpu'

# Mosaic configuration: a single grid, centred at the origin with unit scale.
k = 7  # passed to MosaicSDF as grid_resolution
n_grids = 1
volume_centers = torch.zeros((n_grids, 3), dtype=torch.float32)
volume_scales = torch.ones(n_grids, dtype=torch.float32)

# Mesh assets.
cube_wireframe_path = 'data/cube_wireframe.obj'
cow_mesh_path = 'data/cow_mesh/cow.obj'

shape_sampler = ShapeSampler.from_file(cow_mesh_path, sdf_func=get_sdf)

model = MosaicSDF(
    grid_resolution=k,
    n_grids=n_grids,
    volume_centers=volume_centers,
    volume_scales=volume_scales,
)
model = model.to(device)

# Fill the model's SDF values from the sampler.
model.update_sdf_values(shape_sampler)

In [None]:
# Inspect the model's stored SDF values (presumably filled in by
# `model.update_sdf_values(...)` — verify against MosaicSDF).
model.mosaic_sdf_values

In [None]:

# Visualizer for the model and sampled shape; the cube wireframe is passed
# as the template mesh.
visualizer = MosaicSDFVisualizer(
    model,
    shape_sampler,
    device,
    template_mesh_path=cube_wireframe_path,
)

In [None]:
# Query the model at a single point. The tensor is created on `device` (where
# the model was moved) so the cell keeps working if the device is changed.
input_points = torch.tensor([[-0.5, -0.5, 0.0]], dtype=torch.float32, device=device)

model(input_points)

In [None]:
# Evaluate the model on a resolution**3 lattice spanning [-1, 1] on each axis.
# Points are created on `device` so they live alongside the model.
resolution = 5
lattice_axis = torch.linspace(-1, 1, resolution, device=device)
grid_points = torch.stack(
    torch.meshgrid(lattice_axis, lattice_axis, lattice_axis, indexing='ij'),
    dim=-1,
).reshape(-1, 3)

# Shown as a (resolution, resolution, resolution) block, assuming the model
# returns one value per query point.
model(grid_points).reshape(-1, resolution, resolution)

In [None]:
# Render the mosaic grids (target mesh hidden) as a plotly scene.
with torch.no_grad():
    meshes = visualizer.create_state_meshes(
        show_mosaic_grids=True,
        show_target_mesh=False,
        resolution=16,
    )

    scene = {"subplot1": {"mesh": meshes}}
    fig = plot_scene(scene)
    fig.show()

# Observation: the rendering shows a hard boundary along the diagonal.

In [None]:
# Probe the SDF at three points on the main diagonal x = y = z.
p = torch.tensor(
    [[-0.99, -0.99, -0.99],
     [0.0, 0.0, 0.0],
     [0.5, 0.5, 0.5]],
    dtype=torch.float32,
)
sdfs = model(p)
sdfs


In [None]:
# All k**3 lattice nodes with coordinates from linspace(-1, 1, k), one (x, y, z)
# row per node, ordered with the last coordinate varying fastest.
rel_positions = torch.linspace(-1, 1, steps=k, device=device)
grid_coords = torch.cartesian_prod(rel_positions, rel_positions, rel_positions)

In [None]:
# SDF at the lattice nodes scaled by 0.3 toward the origin, laid out as a
# (k, k, k) block (assumes one model output per query point). The previous
# reshape((-1, 3)) only worked when k**3 was divisible by 3 (e.g. k = 3) and
# fails for k = 7 (343 values).
sdfs = model(grid_coords * .3).reshape((-1, k, k))
sdfs


In [None]:
# Expected: torch.Size([k**3, 3])
grid_coords.size()

In [None]:
# Random initialisation for a 6-grid layout. A dedicated, seeded generator
# makes the draw reproducible without touching torch's global RNG state.
# NOTE: this rebinds `n_grids` / `volume_centers` from the configuration cell;
# the model built there keeps the objects it was given.
INIT_SEED = 0
init_generator = torch.Generator().manual_seed(INIT_SEED)

n_grids = 6
# Centres drawn uniformly from [-1, 1) in each coordinate.
volume_centers = nn.Parameter(torch.rand((n_grids, 3), generator=init_generator) * 2 - 1)
print(volume_centers.shape)

# Scales drawn uniformly from [0, 1). NOTE(review): a scale can be arbitrarily
# close to 0 — confirm whether a lower bound is needed.
scales = nn.Parameter(torch.rand((n_grids,), generator=init_generator))
print(scales.shape)