The goal of this notebook is to reproduce results from work: https://github.com/vsitzmann/siren on SDF task.

Citation from paper:
- Data: In paper the show results on Thai statue from the The Stanford 3D Scanning Repository.
- Architecture: We use the same 5-layer SIREN MLP for all experiments on SDF, using 256 units in each layer for the statue and 1024 units in each layer for the room.
- Hyperparameters: We train for 50,000 iterations, and **at each iteration fit on every voxel** of the volume. We use the Adam optimizer with a learning rate of 1 × 10−4 for all experiments. 
- We train for 50,000 iterations requiring approximately 6h hours to fit and evaluate a SIREN.
- SIREN converge already very well after around 5,000-7,000 iterations.

![image.png](screenshots/image_2023-04-26_14-14-03.png)

In [1]:
!nvidia-smi

Wed Apr 26 12:25:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    20W /  70W |      2MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            Off  | 00000000:00:05.0 Off |                    0 |
| N/A   77C    P0    72W /  70W |  14047MiB / 15360MiB |     61%      Default |
|       

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from pathlib import Path
Path.ls = lambda x: list(x.iterdir())

try:
    import lovely_tensors as lt
except:
    ! pip install --upgrade lovely-tensors
    import lovely_tensors as lt
    
lt.monkey_patch()

In [6]:
# !rm -rf /app/notebooks/siren_sdf/checkpoints/*

In [7]:
config = {
    "device": torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
    "wandb_project": "siren_sdf",
    "experiment_name": "thai_statue_baseline",
    "logging": True,
    "point_cloud_path": "data/thai_statue.xyz",
    "batch_size": 25_000,
    "clip_grad": True,
    "checkpoint_dir": Path('checkpoints/'),
    "save_ckpt_freq": 10_000,
    "vis_freq": 2_500,
    "epochs": 10000,
    
    "lr": 1e-4,
    "iteration_on_stop": 50_000,
    "hidden_features": 256,
    "num_hidden_layers": 5,
    "net_type": 'sine',
}

config["checkpoint_dir"].mkdir(exist_ok=True, parents=True)

# Train SDF

In [8]:
# !pip install scikit-image==0.20.0 scikit-video==1.1.11 opencv-python==4.7.0.72 cmapy==0.6.6 ConfigArgParse==1.5.3 plyfile==0.9 -q
# !pip uninstall scipy -y; pip install scipy

In [9]:
import src.dataio as dataio
from torch.utils.data import DataLoader
from src.utils import get_sdf_summary
from src.sdf_meshing import create_mesh

In [10]:
sdf_dataset = dataio.PointCloud(config['point_cloud_path'], on_surface_points=config['batch_size'])
dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)

Loading point cloud
Finished loading point cloud


In [11]:
from src.nn_modules import SingleBVPNet
    
model = SingleBVPNet(type=config['net_type'], in_features=3, hidden_features=config['hidden_features'], num_hidden_layers=config['num_hidden_layers']).to(config['device'])
# if config.get('load_from_checkpoint_path') is not None and Path(config['load_from_checkpoint_path']).exists():
#     model.load_state_dict(torch.load(config['load_from_checkpoint_path']))

In [12]:
from src.loss_functions import sdf

optimizer = torch.optim.Adam(lr=config['lr'], params=model.parameters())

In [13]:
import wandb
if config['logging']:
    run = wandb.init(project=config["wandb_project"], name=config["experiment_name"], config=config)

[34m[1mwandb[0m: Currently logged in as: [33mnerlfield[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from tqdm.auto import tqdm

iteration = 0
for epoch in (pbar := tqdm(range(int(config['epochs'])))):
    for step, (model_input, gt) in enumerate(dataloader):
        model_input = {key: value.to(config['device']) for key, value in model_input.items()}
        gt = {key: value.to(config['device']) for key, value in gt.items()}
        
        model_output = model(model_input)
        losses = sdf(model_output, gt)
        
        train_loss = 0.
        for loss_name, loss in losses.items():
            single_loss = loss.mean()
            train_loss += single_loss
                
        if config['logging']:
            wandb.log({
                "sdf": losses['sdf'].item(),
                "inter": losses['inter'].item(),
                "normal_constraint": losses['normal_constraint'].item(),
                "grad_constraint": losses['grad_constraint'].item(),
                "train_loss": train_loss.item()
            })
            
        optimizer.zero_grad()
        train_loss.backward()

        if config['clip_grad']:
            if isinstance(config['clip_grad'], bool):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['clip_grad'])

        optimizer.step()
        
        if iteration % config['vis_freq'] == 0 and config['logging']:
            sdf_summary = get_sdf_summary(model, model_input, gt, model_output)
            wandb.log({
                "xy_sdf_slice": wandb.Image(sdf_summary['xy_sdf_slice']),
                "xz_sdf_slice": wandb.Image(sdf_summary['xz_sdf_slice']),
                "yz_sdf_slice": wandb.Image(sdf_summary['yz_sdf_slice'])
            })
            
        if iteration % config['save_ckpt_freq'] == 0:
            torch.save(model.state_dict(), config['checkpoint_dir'] / f"{config['wandb_project']}_{config['experiment_name']}_{epoch}_{iteration}.pth")
        
        pbar.set_description(f' => Loss: {train_loss.item():.3f}')
        
        iteration += 1
        
        if iteration > config['iteration_on_stop']:
            break
    if iteration > config['iteration_on_stop']:
            break

  0%|          | 0/10000 [00:00<?, ?it/s]

  ax.contour(sample, levels=[0], colors='k', linewidths=0.3)


In [15]:
torch.save(model.state_dict(), config['checkpoint_dir'] / f"{config['wandb_project']}_{config['experiment_name']}_{epoch}_{iteration}.pth")

In [31]:
[i for i in config['checkpoint_dir'].ls()]

[Path('checkpoints/siren_sdf_thai_statue_baseline_251_50001.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_with_10_mult_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_100_20000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_150_30000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_with_10mult_normal_init_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_201_40000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_50_10000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_with_5mult_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_statue_baseline_251_50000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_with_5mult_50_10000.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_with_normal_init_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_statue_splitact_baseline_0_0.pth'),
 Path('checkpoints/siren_sdf_thai_sta

# Save to mesh

In [None]:
break

In [34]:
from src.nn_modules import SingleBVPNet

class SDFDecoder(torch.nn.Module):
    def __init__(self, config, ckpt_path=None):
        super().__init__()
        # Define the model.
        self.model = SingleBVPNet(type=config['net_type'], in_features=3, hidden_features=config['hidden_features'], num_hidden_layers=config['num_hidden_layers'])
        if ckpt_path is not None:
            self.model.load_state_dict(torch.load(ckpt_path))
        self.model = self.model.to(config['device'])

    def forward(self, coords):
        model_in = {'coords': coords}
        return self.model(model_in)['model_out']

In [35]:
decoder = SDFDecoder(config, Path('checkpoints/siren_sdf_thai_statue_baseline_50_10000.pth'))

In [36]:
create_mesh(decoder, config['experiment_name'], device=config['device'], N=200)

0
262144
524288
786432
1048576
1310720
1572864
1835008
2097152
2359296
2621440
2883584
3145728
3407872
3670016
3932160
4194304
4456448
4718592
4980736
5242880
5505024
5767168
6029312
6291456
6553600
6815744
7077888
7340032
7602176
7864320
sampling takes: 5.443153
torch.Size([200, 200, 200])
