Visualize results and trained baseline models.

**Used for models trained with `train_baseline.py`, e.g., LatentModulated (:= PartSDF-1Part).**

In [None]:
# Notebook setup: draw matplotlib figures inline, and automatically reload
# imported modules (e.g. the project's `src` package) when they are edited.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os, os.path
import sys
import json

import numpy as np
import matplotlib.pyplot as plt
import trimesh
import torch

sys.path.insert(0, "../")
from src import visualization as viz
from src import workspace as ws
from src.loss import get_loss_recon
from src.mesh import create_mesh, SdfGridFiller
from src.metric import chamfer_distance
from src.reconstruct import reconstruct

# Load the experiment

In [None]:
from src.utils import set_seed, get_device

seed = 0
expdir = "../experiments/car_baseline/"
specs = ws.load_specs(expdir)
# A "Device" entry in the specs takes precedence over the automatically detected device
device = specs.get("Device", get_device())

print(f"Experiment {expdir} (on {device})")
#set_seed(seed); print(f"Seeds initialized to {seed}.")

clampD = specs["ClampingDistance"]
latent_reg = specs["LatentRegLambda"]

logs = ws.load_history(expdir)

# 3x4 grid: rows 0-1 hold the training curves, row 2 the test-metric histograms.
# All axes start hidden and are only shown once something is drawn on them.
fig, axs = plt.subplots(3, 4, figsize=(13,12))
for ax in axs.flat:
    ax.axis('off')

for i, name in enumerate(['loss', 'part_loss', 'loss_inter', 'loss_reg', 'lr', 'lat_norm']):
    if name not in logs:
        continue
    ax = axs[i // 4, i % 4]
    ax.axis('on')
    ax.set_title(name)
    ax.plot(range(logs['epoch']), logs[name])
    if name+"-val" in logs:
        # Validation values are logged every "ValidFrequency" epochs
        ax.plot(list(range(specs["ValidFrequency"], logs['epoch']+1, specs["ValidFrequency"])), logs[name+"-val"])
        ax.legend(['train', 'valid'])
    if name == 'lr' and 'lr_lat' in logs:
        # Learning rate of the latent vectors, shown next to the network's
        ax.plot(range(logs['epoch']), logs['lr_lat'])
        ax.legend(['lr', 'lr_lat'])

# Evaluation: results of the last epoch, then of its "_parts" variant, when present on disk
metric_names = ['chamfer', 'iou', 'ic']
for evaldir in [ws.get_eval_dir(expdir, logs['epoch']),
                ws.get_eval_dir(expdir, f"{logs['epoch']}_parts")]:
    if not os.path.isdir(evaldir):
        continue
    print(f"\nLoading evaluation data from {evaldir}")
    # One JSON file per metric; a missing file yields an empty result
    metrics = {}
    for k in metric_names:
        filename = os.path.join(evaldir, f"{k}.json")
        if os.path.isfile(filename):
            with open(filename) as f:
                metrics[k] = json.load(f)
        else:
            metrics[k] = {}
    for metric in metric_names:
        all_values = [v for v in metrics[metric].values() if not np.isnan(v)]
        print(f"Average {metric} = {np.mean(all_values) if len(all_values) else np.nan}  ({len(all_values)} shapes)")
        print(f"Median  {metric} = {np.median(all_values) if len(all_values) else np.nan}  ({len(all_values)} shapes)")

    # Fig: one histogram per metric on the last row, overlaid across evaluation directories
    for col, k in enumerate(metric_names):
        ax = axs[2, col]
        ax.axis('on')
        ax.set_title("Test " + k)
        # np.histogram cannot auto-range NaN/inf values (e.g. failed reconstructions): plot finite ones only
        finite_values = [v for v in metrics[k].values() if np.isfinite(v)]
        ax.hist(finite_values, bins=20, alpha=0.5, label=os.path.basename(os.path.normpath(evaldir)))
        ax.legend()

fig.tight_layout();

## Data

In [None]:
n_samples = specs["SamplesPerScene"]


def _read_optional_split(split_file):
    """Return the instance list stored in the JSON file `split_file`, or [] if it is None."""
    if split_file is None:
        return []
    with open(split_file) as f:
        return json.load(f)


# The train split is mandatory; validation and test splits are optional in the specs
with open(specs["TrainSplit"]) as f:
    instances = json.load(f)
instances_v = _read_optional_split(specs.get("ValidSplit"))
instances_t = _read_optional_split(specs.get("TestSplit"))

for split_name, split_instances in (("train", instances), ("valid", instances_v), ("test", instances_t)):
    print(f"{len(split_instances)} shapes in {split_name} dataset.")

## Model and latents

In [None]:
from src.model import get_model, get_latents

cp_epoch = logs['epoch']
latent_dim = specs['LatentDim']
model = get_model(specs["Network"], **specs.get("NetworkSpecs", {}), latent_dim=latent_dim).to(device)
# One latent vector per training shape, indexed by its position in `instances`
latents = get_latents(len(instances), latent_dim, specs.get("LatentBound", None), device=device)

try:
    ws.load_model(expdir, model, cp_epoch)
    ws.load_latents(expdir, latents, cp_epoch)
    print(f"Loaded checkpoint of epoch={cp_epoch}")
except FileNotFoundError as err:
    # No per-epoch files: fall back to the experiment checkpoint
    checkpoint = ws.load_checkpoint(expdir)
    model.load_state_dict(checkpoint['model_state_dict'])
    latents.load_state_dict(checkpoint['latents_state_dict'])
    # Keep cp_epoch in sync with the weights actually loaded: it also names
    # the directories where test reconstructions are stored.
    cp_epoch = checkpoint['epoch']
    print(f"File not found: {err.filename}.\nLoading checkpoint instead (epoch={checkpoint['epoch']}).")
    del checkpoint

# Freeze to avoid possible gradient computations
model.eval()
for p in model.parameters():
    p.requires_grad_(False)
latents.requires_grad_(False)

print_model = False  # set to True to print the full architecture
if print_model:
    print("Model:", model)
print(f"Model has {sum([x.nelement() for x in model.parameters()]):,} parameters.")
print(f"{latents.num_embeddings} latent vectors of size {latents.embedding_dim}.")

# Reconstruction

In [None]:
# Shared grid helper passed to every `create_mesh` call below.
# NOTE(review): 256 presumably matches the resolution passed to `create_mesh` (also 256) — confirm.
grid_filler = SdfGridFiller(256, device)

## Train shape

In [None]:
# Decode a random training shape from its learned latent and compare it with the GT mesh
idx = np.random.randint(len(instances))
print(f"Shape {idx}: {instances[idx]}")
latent = latents(torch.tensor([idx], device=device))

train_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=grid_filler, verbose=True)
gt_path = os.path.join(specs["DataSource"], "meshes", instances[idx] + ".obj")
gt_mesh = trimesh.load(gt_path)

render_fig = viz.plot_render([gt_mesh, train_mesh], titles=["GT", "Reconstruction"], full_backfaces=True)
render_fig.show()
slices_fig = viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False)
slices_fig.show()
train_mesh.show()

## Interpolation between train shapes

In [None]:
# Two *distinct* random training shapes: randint(size=2) could draw the same
# index twice, which would make the interpolation trivial.
idx = np.random.choice(len(instances), size=2, replace=False).tolist()
t = 0.5  # interpolation factor: 0 -> shape idx[0], 1 -> shape idx[1]
print(f"Shapes {idx}: {instances[idx[0]]}, {instances[idx[1]]} (t={t:.2f})")
latent = latents(torch.tensor(idx).to(device))
# Linear interpolation in latent space
latent = (1. - t) * latent[0] + t * latent[1]

interp_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=grid_filler, verbose=True)
gt_mesh0 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[0]]+".obj"))
gt_mesh1 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[1]]+".obj"))
viz.plot_render([gt_mesh0, interp_mesh, gt_mesh1],
                titles=["GT 0", f"Reconstruction (t={t:.2f})", "GT 1"], full_backfaces=True).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()
interp_mesh.show()

## Test shape
First, try to load an already reconstructed shape. If none exists, optimize a latent vector for the shape and save the results (existing files are never overwritten).

In [None]:
# Reconstruction
always_reconstruct = False  # True to force reconstruction (do not overwrite existing files)
if not instances_t:
    raise RuntimeError("No test shapes: 'TestSplit' is not set in the specs or the split is empty.")
idx = np.random.choice(len(instances_t))

instance = instances_t[idx]
print(f"Reconstructing test shape {idx} ({instance})")

# Reconstructed latents/meshes are stored in per-epoch sub-directories of the experiment
cp_epoch_str = str(cp_epoch)
latent_subdir = ws.get_recon_latent_subdir(expdir, cp_epoch_str)
mesh_subdir = ws.get_recon_mesh_subdir(expdir, cp_epoch_str)
os.makedirs(latent_subdir, exist_ok=True)
os.makedirs(mesh_subdir, exist_ok=True)
latent_fn = os.path.join(latent_subdir, instance + ".pth")
mesh_fn = os.path.join(mesh_subdir, instance + ".obj")

loss_recon = get_loss_recon("L1-Hard", reduction='none')

# Latent: load existing or reconstruct
if not always_reconstruct and os.path.isfile(latent_fn):
    # map_location: a latent saved from another device (e.g. a GPU run) still loads onto `device`
    latent = torch.load(latent_fn, map_location=device)
    print(f"Latent norm = {latent.norm():.4f} (existing)")
else:
    npz = np.load(os.path.join(specs["DataSource"], specs["SamplesDir"], instance, specs["SamplesFile"]))
    # NOTE(review): the positional arguments 400, 8000 and 5e-3 look like optimization steps,
    # samples per step and learning rate — confirm against src.reconstruct.reconstruct
    err, latent = reconstruct(model, npz, 400, 8000, 5e-3, loss_recon, latent_reg, clampD, None, latent_dim, 
                              verbose=True)
    print(f"Final loss: {err:.6f}, latent norm = {latent.norm():.4f}")
    if not os.path.isfile(latent_fn):  # save reconstruction (never overwrite)
        torch.save(latent, latent_fn)
# Mesh: load existing or reconstruct
if not always_reconstruct and os.path.isfile(mesh_fn):
    test_mesh = trimesh.load(mesh_fn)
else:
    test_mesh = create_mesh(model, latent, 256, 32**3, grid_filler=grid_filler, verbose=True)
    if not os.path.isfile(mesh_fn):  # save reconstruction (never overwrite)
        test_mesh.export(mesh_fn)
gt_mesh = trimesh.load(os.path.join(specs["DataSource"], "meshes", instance+".obj"))

# Chamfer distance between surface samples of GT and reconstruction (inf for an empty reconstruction)
chamfer_samples = 30_000
if test_mesh.is_empty:
    chamfer_val = float('inf')
else:
    chamfer_val = chamfer_distance(gt_mesh.sample(chamfer_samples), test_mesh.sample(chamfer_samples))
print(f"Chamfer-distance (x10^4) = {chamfer_val * 1e4:.6f}")

viz.plot_render([gt_mesh, test_mesh], titles=["GT", "Reconstruction"], full_backfaces=True).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False).show()

test_mesh.show()