Visualize results and trained models.

In [None]:
# Notebook setup: render matplotlib figures inline, and let IPython's autoreload
# extension (mode 2) reload imported modules before each cell runs, so edits to
# the `src` package are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os, os.path
import sys
import json
import time

import numpy as np
import matplotlib.pyplot as plt
import trimesh
import torch

sys.path.insert(0, "../")
from src import visualization as viz
from src import workspace as ws
from src.loss import get_loss_recon
from src.mesh import create_mesh
from src.metric import chamfer_distance
from src.reconstruct import reconstruct

# Single experiment

## Load exp

In [None]:
from src.utils import set_seed

seed = 0
use_seed = False  # set to True to seed RNGs via src.utils.set_seed (reproducible runs)
expdir = "../experiments/src_test/"
specs = ws.load_specs(expdir)

print(f"Experiment {expdir}")
if use_seed:
    set_seed(seed)
    print(f"Seeds initialized to {seed}.")

# Values from the experiment specs reused by the reconstruction/visualization cells.
clampD = specs["ClampingDistance"]
latent_reg = specs["LatentRegLambda"]

logs = ws.load_history(expdir)

# Training curves: one panel per logged quantity; a validation curve is
# overlaid whenever the history also contains a "<name>_val" entry.
epochs = range(logs['epoch'])
fig, axs = plt.subplots(1, 4, figsize=(13,4))

for ax, name in zip(axs, ['loss', 'loss_reg', 'lr', 'lat_norm']):
    ax.set_title(name)
    ax.plot(epochs, logs[name])
    if name+"_val" in logs:
        ax.plot(epochs, logs[name+"_val"])
        ax.legend(['train', 'valid'])
# The learning-rate panel also shows the latent learning rate.
axs[2].plot(epochs, logs['lr_lat'])
axs[2].legend(['lr', 'lr_lat'])

fig.tight_layout();

## Data

In [None]:
n_samples = specs["SamplesPerScene"]


def _load_split(split_file):
    """Load an optional JSON split file.

    Returns the decoded JSON content (presumably a list of instance names),
    or an empty list when `split_file` is None.
    """
    if split_file is None:
        return []
    with open(split_file, encoding="utf-8") as fh:
        return json.load(fh)


# The train split is mandatory (KeyError if absent); valid/test splits are optional.
with open(specs["TrainSplit"], encoding="utf-8") as f:
    instances = json.load(f)
instances_v = _load_split(specs.get("ValidSplit"))
instances_t = _load_split(specs.get("TestSplit"))

print(f"{len(instances)} shapes in train dataset.")
print(f"{len(instances_v)} shapes in valid dataset.")
print(f"{len(instances_t)} shapes in test dataset.")

## Model and latents

In [None]:
from src.model import get_model, get_latents

cp_epoch = logs['epoch']
latent_dim = specs['LatentDim']
model = get_model(specs["Network"], **specs.get("NetworkSpecs", {}), latent_dim=latent_dim).cuda()
latents = get_latents(len(instances), latent_dim, specs.get("LatentBound", None))

# Prefer the per-epoch model/latent files; if either is missing, fall back to
# the experiment checkpoint.
try:
    ws.load_model(expdir, model, cp_epoch)
    ws.load_latents(expdir, latents, cp_epoch)
    print(f"Loaded checkpoint of epoch={cp_epoch}")
except FileNotFoundError as err:
    checkpoint = ws.load_checkpoint(expdir)
    model.load_state_dict(checkpoint['model_state_dict'])
    latents.load_state_dict(checkpoint['latents_state_dict'])
    print(f"File not found: {err.filename}.\nLoading checkpoint instead (epoch={checkpoint['epoch']}).")
    # NOTE(review): cp_epoch keeps the value logs['epoch'] here although the
    # weights come from checkpoint['epoch']; confirm the two always match.
    del checkpoint  # drop the loaded state dicts once copied into the modules

# Freeze to avoid possible gradient computations
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

print_model = False  # set to True to print the full architecture
if print_model:
    print("Model:", model)
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
print(f"{latents.num_embeddings} latent vectors of size {latents.embedding_dim}.")

## Reconstruction

In [None]:
from sklearn.decomposition import PCA

from src.utils import sample_latents as _sample_latents

# Whitened PCA (components scaled to unit variance) fitted once on a detached
# CPU copy of the training latent codes.
_train_codes = latents.weight.detach().cpu().numpy()
_pca = PCA(whiten=True)
_pca.fit(_train_codes)
def sample_latents(n=1, expvar=None):
    """Draw `n` latent code(s) using the PCA fitted on the training latents.

    Thin wrapper around `src.utils.sample_latents`; `expvar` is forwarded as-is.
    """
    return _sample_latents(latents, pca=_pca, n_samples=n, expvar=expvar)

### Train

In [None]:
idx = np.random.randint(len(instances))
# Set to an epoch number to load that epoch's model and latents. This global is
# also read by the Valid/Test cell when choosing its output directories.
cp_epoch = None
print(f"Shape {idx}: {instances[idx]}")
if cp_epoch is not None:
    ws.load_model(expdir, model, cp_epoch)
    ws.load_latents(expdir, latents, cp_epoch)
    print(f"Loaded checkpoint of epoch={cp_epoch}")
latent = latents(torch.tensor([idx]).cuda())

train_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=True, verbose=True)
gt_mesh = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx]+".obj"))
# Side-by-side comparison with the ground truth (as in the other cells).
viz.plot_render([gt_mesh, train_mesh], titles=["GT", "Reconstruction"]).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()
train_mesh.show()

### Interpolation

In [None]:
# Pick two distinct training shapes when possible; fall back to sampling with
# replacement only if the split has fewer than two shapes.
idx = np.random.choice(len(instances), size=2, replace=len(instances) < 2).tolist()
t = 0.5  # interpolation weight: 0 -> shape idx[0], 1 -> shape idx[1]
# Set to an epoch number to load that epoch's model and latents. This global is
# also read by the Valid/Test cell when choosing its output directories.
cp_epoch = None
print(f"Shapes {idx}: {instances[idx[0]]}, {instances[idx[1]]} (t={t:.2f})")
if cp_epoch is not None:
    ws.load_model(expdir, model, cp_epoch)
    ws.load_latents(expdir, latents, cp_epoch)
    print(f"Loaded checkpoint of epoch={cp_epoch}")
latent = latents(torch.tensor(idx).cuda())
# Linear interpolation between the two latent codes.
latent = (1. - t) * latent[0] + t * latent[1]

interp_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=True, verbose=True)
gt_mesh0 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[0]]+".obj"))
gt_mesh1 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[1]]+".obj"))
viz.plot_render(
    [gt_mesh0, interp_mesh, gt_mesh1],
    titles=["GT 0", f"Reconstruction (t={t:.2f})", "GT 1"]
).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()
interp_mesh.show()

### Random

In [None]:
# Set to an epoch number to load that epoch's model and latents. This global is
# also read by the Valid/Test cell when choosing its output directories.
cp_epoch = None
if cp_epoch is not None:
    ws.load_model(expdir, model, cp_epoch)
    # `sample_latents` draws from `_pca`, which was fitted on the latents loaded
    # earlier. Reload the latents of the same epoch and refit `_pca` so the
    # sampling distribution matches the model that was just loaded.
    ws.load_latents(expdir, latents, cp_epoch)
    _pca = PCA(whiten=True).fit(latents.weight.detach().cpu().numpy())
    print(f"Loaded checkpoint of epoch={cp_epoch}")
latent = sample_latents()
print(f"Latent norm = {latent.norm().item():.6f}")

rand_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=True, verbose=True)
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()
rand_mesh.show()

### Valid/Test
First, try to load an already-reconstructed shape. If none exists, optimize a latent code and save the results (existing files are never overwritten).

In [None]:
# Reconstruction
_instances = instances_t  # valid or test
always_reconstruct = False  # True to force reconstruction (do not overwrite existing files)
if not _instances:
    # np.random.choice(0) would otherwise fail with an unhelpful ValueError
    # (e.g. when the specs define no TestSplit, instances_t is []).
    raise ValueError("The selected split is empty: set `_instances` to a non-empty split (instances_v or instances_t).")
idx = np.random.choice(len(_instances))
instance = _instances[idx]
print(f"Reconstructing test shape {idx} ({instance})")

# NOTE(review): `cp_epoch` is whatever global value an earlier cell left behind
# (logs['epoch'] from the model cell, or the value set in the Train/Interpolation/
# Random cells), so these directories depend on which cells ran last — confirm intended.
latent_subdir = ws.get_recon_latent_subdir(expdir, cp_epoch)
mesh_subdir = ws.get_recon_mesh_subdir(expdir, cp_epoch)
os.makedirs(latent_subdir, exist_ok=True)
os.makedirs(mesh_subdir, exist_ok=True)
latent_fn = os.path.join(latent_subdir, instance + ".pth")
mesh_fn = os.path.join(mesh_subdir, instance + ".obj")

loss_recon = get_loss_recon("L1-Hard", reduction='none')

# Latent: load existing or reconstruct
if not always_reconstruct and os.path.isfile(latent_fn):
    latent = torch.load(latent_fn)
    print(f"Latent norm = {latent.norm():.4f} (existing)")
else:
    npz = np.load(os.path.join(specs["DataSource"], specs["SamplesDir"], instance, specs["SamplesFile"]))
    # NOTE(review): 400, 8000, 5e-3 are positional hyper-parameters of
    # src.reconstruct.reconstruct (presumably iterations, samples, lr) — confirm.
    err, latent = reconstruct(model, npz, 400, 8000, 5e-3, loss_recon, latent_reg, clampD, None, latent_dim, verbose=True)
    print(f"Final loss: {err:.6f}, Latent norm = {latent.norm():.4f}")
    if not os.path.isfile(latent_fn):  # save reconstruction, never overwrite
        torch.save(latent, latent_fn)
# Mesh: load existing or reconstruct
if not always_reconstruct and os.path.isfile(mesh_fn):
    test_mesh = trimesh.load(mesh_fn)
else:
    test_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=True, verbose=True)
    if not os.path.isfile(mesh_fn):  # save reconstruction, never overwrite
        test_mesh.export(mesh_fn)
gt_mesh = trimesh.load(os.path.join(specs["DataSource"], "meshes", instance+".obj"))

# Chamfer distance between point clouds sampled on the GT and reconstructed surfaces.
chamfer_samples = 30_000
chamfer_val = chamfer_distance(gt_mesh.sample(chamfer_samples), test_mesh.sample(chamfer_samples))
print(f"Chamfer-distance (x10^4) = {chamfer_val * 1e4:.6f}")

viz.plot_render([gt_mesh, test_mesh], titles=["GT", "Reconstruction"]).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()

test_mesh.show()

In [None]:
# Stop here in case "Run All" has been used.
# Raising deliberately halts execution so the miscellaneous cells below are
# only run by hand.
raise RuntimeError("Stop here.")

# Misc.
Misc. code to test various things.