In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Reload imported modules (e.g. edits to the project's `src` package) before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os, os.path
import sys
import json
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import trimesh

# Make the parent directory (relative to the current working directory)
# importable so the `src` package imports below resolve. The membership check
# keeps re-runs of this cell from adding duplicate sys.path entries.
if "../" not in sys.path:
    sys.path.insert(0, "../")
from src import visualization as viz
from src.data import samples_from_tensor
from src.loss import get_loss_recon
# NOTE(review): get_model, activation and features are not referenced in the
# cells of this notebook; presumably kept for interactive experimentation.
from src.model import get_model, activation, features
from src.utils import set_seed, clamp_sdf

# Initialization

In [None]:
# Fix the random state through the project helper so sample draws and weight
# initialisation are repeatable (which generators are seeded is up to
# src.utils.set_seed).
seed = 0
set_seed(seed)
print("Seeds initialized to {}.".format(seed))

# Data

In [None]:
idx = 0

datasource = "/scratch/cvlab/datasets/datasets_talabot/shapenet_disn/1_normalized/cars/"
trainsplit = "/scratch/cvlab/datasets/datasets_talabot/shapenet_disn/1_normalized/cars/splits/cars_train100.json"
# Set to True to additionally supervise on surface points (SDF = 0), split
# evenly between the 'pos' and 'neg' sample sets.
add_surface_samples = False

with open(trainsplit) as f:
    split = json.load(f)
instance = split[idx]
print(f"Shape {idx}: {instance}")

# Load shape and samples
mesh_gt = trimesh.load(os.path.join(datasource, "meshes", instance + ".obj"))

# np.load on an .npz returns an NpzFile that keeps the underlying file open;
# use it as a context manager so the handle is released once arrays are read.
with np.load(os.path.join(datasource, "samples", instance, "deepsdf.npz")) as npz:
    samples_gt = {k: torch.from_numpy(npz[k]).float().cuda() for k in ['pos', 'neg']}

if add_surface_samples:
    with np.load(os.path.join(datasource, "samples", instance, "surface.npz")) as npz:
        surf_gt = npz['all'][:, :3]
    surf_gt = torch.from_numpy(surf_gt).float().cuda()
    # Append an SDF column of zeros (surface points lie on the zero level set).
    # Presumably pos/neg rows are (x, y, z, sdf) — the concatenation below
    # requires them to have 4 columns.
    surf_samples = torch.cat([surf_gt, torch.zeros_like(surf_gt[:, 0:1])], dim=-1)
    half = len(surf_samples) // 2
    samples_gt = {
        "pos": torch.cat([samples_gt["pos"], surf_samples[:half]], dim=0),
        "neg": torch.cat([samples_gt["neg"], surf_samples[half:]], dim=0),
    }
    print("Surf samples:", surf_gt.shape)

print("Pos and neg samples:", samples_gt['pos'].shape, samples_gt['neg'].shape)

mesh_gt.show()

# Model

In [None]:
from src.model.deepsdf import DeepSDF

model = DeepSDF(
    latent_dim=0,
    hidden_dim=256,
    n_layers=6,
    in_insert=[3],
    dropout=0.0,
    weight_norm=True,
    activation="relu",
    features=None
).cuda()

# Compare against None rather than relying on truthiness: container modules
# such as nn.Sequential define __len__, so an empty-but-present feature module
# would otherwise be silently skipped.
if model.features is not None:
    print(model.features)
    print(f"{sum(p.nelement() for p in model.features.parameters()):,} parameters in features.")
print(f"{sum(p.nelement() for p in model.parameters()):,} parameters in generator.")

# Dummy empty latent (latent_dim=0 above); allocated directly on the GPU
# instead of creating it on the CPU and copying it over.
latent = torch.zeros(1, model.latent_dim, device="cuda")

# Re-initialize training history
history = {'epoch': 0, 'loss': []}

# Training

In [None]:
n_iters = 1000
n_samples = 16384
clampD = 0.1

loss_recon = get_loss_recon('L1-Hard', 'mean')

def _train(n_iters):
    """Fit the global `model` to the single shape stored in `samples_gt`.

    Runs `n_iters` Adam steps (lr=1e-3; the LR is multiplied by 0.35 at 80% and
    90% of the run) on batches of `n_samples` points drawn each iteration via
    `samples_from_tensor`. Each iteration's loss is appended to the global
    `history['loss']` and `history['epoch']` is incremented (it counts
    iterations). The work lives inside a function so per-iteration tensors do
    not linger in the notebook namespace and clog GPU memory.

    Args:
        n_iters: number of optimisation steps to run.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [int(n_iters * 0.8), int(n_iters * 0.9)], 0.35)
    log_every = max(1, n_iters // 20)

    model.train()
    # If a previous run of this cell was interrupted between backward() and
    # zero_grad(), stale gradients remain on the parameters; clear them so they
    # do not leak into the first step of this run.
    optimizer.zero_grad()
    start, start_it = time.time(), -1
    for it in range(n_iters):
        xyz, sdf_gt = samples_from_tensor(samples_gt['pos'], samples_gt['neg'], n_samples)

        sdf_pred = model(xyz)
        if clampD is not None and clampD > 0.:
            # The prediction is clamped using the still-unclamped target as
            # `ref`, so it must be computed before the target is clamped.
            sdf_pred = clamp_sdf(sdf_pred, clampD, ref=sdf_gt)
            sdf_gt = clamp_sdf(sdf_gt, clampD)

        loss = loss_recon(sdf_pred, sdf_gt).mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        history['epoch'] += 1
        history['loss'].append(loss.item())
        if (it + 1) % log_every == 0:
            # Average wall-clock time per iteration since the last log line.
            ms_per_iter = (time.time() - start) / (it - start_it) * 1000
            print(f"Iter {it+1: 5d}/{n_iters}: loss={history['loss'][-1]:.6f}"
                  f"  ({ms_per_iter: 3.0f}ms/iter)")
            start, start_it = time.time(), it

    # The cells after training only run inference; leave the model in eval mode
    # so e.g. a non-zero dropout would not make the results stochastic.
    model.eval()

_train(n_iters)

In [None]:
# Optional y-range for the loss curve, e.g. (0., 0.1) to zoom in; None = auto.
loss_ylim = None

fig, ax = plt.subplots(1, 1, figsize=(4,4))
ax.plot(history['loss'])
ax.set_title("Loss")
ax.set_xlabel('Iter')
if loss_ylim is not None:
    ax.set_ylim(*loss_ylim)
# Use pyplot.show() rather than Figure.show(): with the non-GUI inline backend,
# Figure.show() may only emit a "cannot show the figure" warning.
plt.show()

# Results

In [None]:
# SDF visualization
# Plot slices of the learned SDF for the (empty) latent code. clampD is passed
# presumably so the colour range matches the training clamp — verify in
# src.visualization.plot_sdf_slices.
viz.plot_sdf_slices(model, latent, clampD).show()

In [None]:
# Mesh
# Extract a mesh from the trained SDF and render it next to the ground truth.
from src.mesh import create_mesh

# N=256 is presumably the grid resolution per axis used for extraction; see
# src.mesh.create_mesh for N and grid_filler semantics.
mesh_pred = create_mesh(model, latent, N=256, grid_filler=True)

viz.plot_render([mesh_gt, mesh_pred], titles=["GT", "Reconstruction"]).show()

In [None]:
# Display the reconstructed mesh for interactive inspection.
mesh_pred.show()

In [None]:
# Display the ground-truth mesh again, for direct comparison with the
# reconstruction shown in the previous cell.
mesh_gt.show()