In [1]:
import torch
from pathlib import Path

try:
    import lovely_tensors as lt
except:
    ! pip install --upgrade lovely-tensors
    import lovely_tensors as lt
    
lt.monkey_patch()

In [13]:
!rm -rf /app/notebooks/siren_sdf/checkpoints/*

wandb: Network error (ReadTimeout), entering retry loop.


In [4]:
config = {
    "device": torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu"),
    "wandb_project": "siren_sdf",
    "experiment_name": "siren_sdf_baseline",
    "logging": True,
    "point_cloud_path": "data/interior_room.xyz",
    "batch_size": 1400,
    "lr": 1e-4,
    "clip_grad": True,
    "checkpoint_dir": Path('checkpoints/'),
    "save_ckpt_freq": 10,
    "epochs": 10000,
}

config["checkpoint_dir"].mkdir(exist_ok=True, parents=True)

# Train SDF

In [5]:
import src.dataio as dataio
from torch.utils.data import DataLoader

sdf_dataset = dataio.PointCloud(config['point_cloud_path'], on_surface_points=config['batch_size'])
dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)

Loading point cloud
Finished loading point cloud


In [6]:
from src.nn_modules import SingleBVPNet

model = SingleBVPNet(type='sine', in_features=3).to(config['device'])

In [7]:
from src.loss_functions import sdf

optimizer = torch.optim.Adam(lr=config['lr'], params=model.parameters())

In [8]:
import wandb
if config['logging']:
    run = wandb.init(project=config["wandb_project"], name=config["experiment_name"], config=config)

[34m[1mwandb[0m: Currently logged in as: [33mnerlfield[0m. Use [1m`wandb login --relogin`[0m to force relogin
2023-04-18 14:31:08.947484: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-04-18 14:31:08.947538: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [9]:
from tqdm.auto import tqdm

iteration = 0
for epoch in (pbar := tqdm(range(int(config['epochs'])))):
    for step, (model_input, gt) in enumerate(dataloader):
        model_input = {key: value.to(config['device']) for key, value in model_input.items()}
        gt = {key: value.to(config['device']) for key, value in gt.items()}
        
        model_output = model(model_input)
        losses = sdf(model_output, gt)
        
        train_loss = 0.
        for loss_name, loss in losses.items():
            single_loss = loss.mean()
            train_loss += single_loss
            
            if config['logging']:
                wandb.log({loss_name: single_loss.item()})
                
        if config['logging']:
            wandb.log({"train_loss": train_loss.item()})
            
        optimizer.zero_grad()
        train_loss.backward()

        if config['clip_grad']:
            if isinstance(config['clip_grad'], bool):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['clip_grad'])

        optimizer.step()
        
        if epoch % config['save_ckpt_freq'] == 0 and epoch > 0:
            torch.save(model.state_dict(), config['checkpoint_dir'] / f"{config['wandb_project']}_{config['experiment_name']}_{epoch}_{iteration}.pth")
        
        pbar.set_description(f' => Loss: {train_loss.item():.3f}')
        
        iteration += 1

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [11]:
config['save_ckpt_freq']

10

In [10]:
epoch % config['save_ckpt_freq']

3

# Save to mesh

In [3]:
from src.nn_modules import SingleBVPNet

class SDFDecoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Define the model.
        self.model = SingleBVPNet(type='siren', final_layer_factor=1, in_features=3)
        self.model.load_state_dict(torch.load(config['checkpoint_dir'] / f"{config['wandb_project']}_{config['experiment_name']}_{epoch}_{iteration}.pth"))
        self.model = self.model.to(config['device'])

    def forward(self, coords):
        model_in = {'coords': coords}
        return self.model(model_in)['model_out']

In [4]:
decoder = SDFDecoder(config)

In [5]:
from src.sdf_meshing import create_mesh

In [6]:
create_mesh(decoder, config['experiment_name'], device=config['device'])

0
262144
524288
786432
1048576
1310720
1572864
1835008
2097152
2359296
2621440
2883584
3145728
3407872
3670016
3932160
4194304
4456448
4718592
4980736
5242880
5505024
5767168
6029312
6291456
6553600
6815744
7077888
7340032
7602176
7864320
8126464
8388608
8650752
8912896
9175040
9437184
9699328
9961472
10223616
10485760
10747904
11010048
11272192
11534336
11796480
12058624
12320768
12582912
12845056
13107200
13369344
13631488
13893632
14155776
14417920
14680064
14942208
15204352
15466496
15728640
15990784
16252928
16515072
sampling takes: 5.221504
torch.Size([256, 256, 256])
