In [None]:
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from mosaic_sdf import MosaicSDF
from shape_sampler import ShapeSampler
from optimizer import MosaicSDFOptimizer
from mosaic_sdf_visualizer import MosaicSDFVisualizer
from utils_vis import compare_shapes


In [None]:
# Mesh that is loaded by ShapeSampler and recorded in the config as 'shape_path'.
sdf_shape_path = 'data/chain.obj'

# Auxiliary meshes: the cube wireframe is handed to the visualizer as its
# template mesh; the sphere path is not referenced in the cells shown here.
cube_wireframe_path = 'data/cube_wireframe.obj'
sphere_mesh_path = 'data/sphere.obj'

In [None]:
# Use the first CUDA device when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
# Load the target mesh on the `device` chosen earlier in the notebook instead of
# a hard-coded 'cuda', so this cell also works on machines without a GPU and
# stays consistent with the device passed to the optimizer config and visualizer.
shape_sampler = ShapeSampler.from_file(
    sdf_shape_path,
    device=device,
    normalize_shape=True,
    make_watertight=True,
)

# Settings dictionary handed to MosaicSDFOptimizer.
config = {
    # --- runtime ---
    'device': device,
    'shape_path': sdf_shape_path,  # adjust accordingly

    # --- MosaicSDF params ---
    'grid_resolution': 7,
    'n_grids': 256,  # the most important parameter (alternative value: 1024)
    'points_random_spread': 0.03,
    'val_points_random_spread': 0.03,
    'mosaic_scale_multiplier': 3,

    # --- optimizer params ---
    'lr': 1e-4,
    'weight_decay': 0.0,
    'b1': 0.9,
    'b2': 0.999,
    'lambda_val': 0.1,

    # --- optimization params ---
    'points_in_epoch': 4096,
    'points_sample_size': 32,
    'gradient_accumulation_steps': 4,
    'eval_every_nth_points': 1024,
    'val_size': 1024,
    'points_sample_size_eval_scaler': 4,  # can sample faster during eval

    # --- logging ---
    'project_name': 'mosaicSDF_select',
    'log_to_wandb': False,

    # --- other debug stuff ---
    'output_graph': False,
    'points_random_sampling': False,
}

# Build the optimizer from the `config` settings dictionary.
optimizer = MosaicSDFOptimizer(config)

# Update the model's SDF values using the shape sampler loaded in this cell.
# NOTE(review): presumably this queries the ground-truth SDF at the mosaic grid
# points — confirm against MosaicSDF.update_sdf_values.
optimizer.model.update_sdf_values(shape_sampler)


In [None]:

# Visualizer built on the optimizer's own shape sampler, with the cube
# wireframe mesh supplied as the template mesh.
visualizer = MosaicSDFVisualizer(
    optimizer.shape_sampler,
    device,
    template_mesh_path=cube_wireframe_path,
)

In [None]:
# Render the comparison before the `optimizer.step()` cell is run.
# `compare_shapes` is imported in the notebook's top import cell.
# NOTE(review): the show_* flags presumably toggle the ground-truth mesh, the
# ground-truth SDF and the MosaicSDF reconstruction — confirm in utils_vis.
compare_shapes(
    shape_sampler, visualizer, optimizer.model,
    resolution=80,
    show_mosaic_grids=False,
    show_gt_mesh=True, show_gt_sdf=True, show_mosaic_sdf=True,
    region_span=[1, .6, .2],
)

In [None]:
# Call optimizer.step() four times in a row.
# NOTE(review): how much work a single step() performs (one batch vs. a whole
# epoch) is decided inside MosaicSDFOptimizer — confirm there.
for i in range(4):
    optimizer.step()

In [None]:

# Same comparison view as the earlier compare_shapes cell, rendered again after
# the optimizer.step() loop so the two outputs can be compared.
compare_shapes(
    shape_sampler,
    visualizer,
    optimizer.model,
    resolution=80,
    show_mosaic_grids=False,
    show_gt_mesh=True,
    show_gt_sdf=True,
    show_mosaic_sdf=True,
    region_span=[1, 0.6, 0.2],
)