Visualize results and trained models.

**Used for models trained with `train.py`, e.g., PartSDF.**

In [None]:
# Notebook setup: draw matplotlib figures inline and enable IPython's autoreload
# extension (mode 2) so that edits to the `src` package are picked up without
# restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import os, os.path
import sys
import json

import numpy as np
import matplotlib.pyplot as plt
import trimesh
import torch

sys.path.insert(0, "../")
from src import visualization as viz
from src import workspace as ws
from src.loss import get_loss_recon
from src.mesh import create_mesh, create_parts, SdfGridFiller
from src.metric import chamfer_distance
from src.primitives import slerp_quaternion
from src.reconstruct import reconstruct_parts
from src.utils import get_color_parts

# Load the experiment

In [None]:
from src.utils import get_device, set_seed

# Experiment to inspect, and the seed used if seeding is enabled below.
seed = 0
expdir = "../experiments/car_partsdf/"

specs = ws.load_specs(expdir)
# A "Device" entry in the experiment specs takes precedence over the auto-detected one.
default_device = get_device()
device = specs.get("Device", default_device)

print(f"Experiment {expdir} (on {device})")
# Uncomment to seed the random generators with `set_seed`:
#set_seed(seed); print(f"Seeds initialized to {seed}.")

def load_gt_parts(instance):
    """Load the ground-truth part meshes of one shape.

    Part ``k`` (for ``k`` in ``range(n_parts)``) is read from
    ``<DataSource>/<first component of Parts.SamplesDir>/meshes/<instance>/part<k>.obj``.
    A part whose file does not exist is replaced by an empty ``trimesh.Trimesh``,
    so the returned list always has ``n_parts`` entries (uses the globals
    ``specs`` and ``n_parts``).
    """
    parts_root = specs["Parts"]["SamplesDir"].split("/")[0]
    mesh_dir = os.path.join(specs["DataSource"], parts_root, "meshes", instance)
    part_files = (os.path.join(mesh_dir, f"part{k}.obj") for k in range(n_parts))
    return [trimesh.load(fn) if os.path.isfile(fn) else trimesh.Trimesh() for fn in part_files]

# Experiment hyper-parameters read from the specs.
clampD = specs["ClampingDistance"]
latent_reg = specs["LatentRegLambda"]
part_latent_reg = specs["Parts"].get("LatentRegLambda", None)
n_parts = specs["Parts"]["NumParts"]
part_dim = specs["Parts"]["LatentDim"]
use_poses = specs["Parts"].get("UsePoses", False)
use_occ = specs.get("ImplicitField", "SDF").lower() in ["occ", "occupancy"]

logs = ws.load_history(expdir)

# 3x4 grid: the first two rows hold the training curves below, the last row is
# filled with the evaluation histograms in the next section. Unused axes stay hidden.
fig, axs = plt.subplots(3, 4, figsize=(13,12))
for ax in axs.flat:
    ax.axis('off')

n_epochs = logs['epoch']
# Epochs are 1-based (validation values are logged after epochs VF, 2*VF, ..., n_epochs),
# so the training curves are drawn against 1..n_epochs to line up with them.
train_epochs = range(1, n_epochs + 1)
for i, name in enumerate(['loss', 'loss_part', 'loss_inter', 'loss_reg_part', 'lr', 'lat_norm']):
    if name not in logs:
        continue
    ax = axs[i // 4, i % 4]
    ax.axis('on')
    ax.set_title(name)
    ax.plot(train_epochs, logs[name])
    if name + "-val" in logs:
        valid_epochs = list(range(specs["ValidFrequency"], n_epochs + 1, specs["ValidFrequency"]))
        ax.plot(valid_epochs, logs[name + "-val"])
        ax.legend(['train', 'valid'])
    # The latent learning rate is drawn on the same axes, when it was logged.
    if name == 'lr' and 'lr_lat' in logs:
        ax.plot(train_epochs, logs['lr_lat'])
        ax.legend(['lr', 'lr_lat'])

# Evaluation: summarize the metrics of the full-shape evaluation ("<epoch>") and
# of the part-based one ("<epoch>_parts"), for whichever directories exist.
metric_names = ['chamfer', 'iou', 'ic']
for evaldir in [ws.get_eval_dir(expdir, logs['epoch']), 
                ws.get_eval_dir(expdir, f"{logs['epoch']}_parts")]:
    if not os.path.isdir(evaldir):
        continue
    print(f"\nLoading evaluation data from {evaldir}")

    # One JSON dict per metric (values per shape); a missing file counts as empty.
    metrics = {}
    for k in metric_names:
        filename = os.path.join(evaldir, f"{k}.json")
        if os.path.isfile(filename):
            with open(filename) as f:
                metrics[k] = json.load(f)
        else:
            metrics[k] = {}

    # Statistics ignore NaN values (shapes for which the metric is undefined).
    for metric in metric_names:
        all_values = [v for v in metrics[metric].values() if not np.isnan(v)]
        print(f"Average {metric} = {np.mean(all_values) if len(all_values) else np.nan}  ({len(all_values)} shapes)")
        print(f"Median  {metric} = {np.median(all_values) if len(all_values) else np.nan}  ({len(all_values)} shapes)")

    # Fig: one histogram per metric on the last row of the grid; when both
    # evaluation directories exist, their histograms are overlaid (alpha=0.5).
    for j, k in enumerate(metric_names):
        ax = axs[2, j]
        ax.axis('on')
        ax.set_title("Test " + k)
        # Drop NaN/inf values (e.g. an infinite Chamfer distance for an empty mesh):
        # they make the automatic histogram range non-finite and `hist` raise.
        finite_values = [v for v in metrics[k].values() if np.isfinite(v)]
        ax.hist(finite_values, bins=20, alpha=0.5)

fig.tight_layout();

## Data

In [None]:
n_samples = specs["SamplesPerScene"]


def _read_split(split_file):
    """Return the parsed content of the JSON split file `split_file`."""
    with open(split_file) as fp:
        return json.load(fp)


# The train split is mandatory; validation/test splits are optional
# (a missing or null entry means an empty split).
instances = _read_split(specs["TrainSplit"])
valid_split_file = specs.get("ValidSplit", None)
instances_v = [] if valid_split_file is None else _read_split(valid_split_file)
test_split_file = specs.get("TestSplit", None)
instances_t = [] if test_split_file is None else _read_split(test_split_file)

print(f"{len(instances)} shapes in train dataset.")
print(f"{len(instances_v)} shapes in valid dataset.")
print(f"{len(instances_t)} shapes in test dataset.")

## Model and latents

In [None]:
from src.model import get_model, get_part_latents, get_part_poses

cp_epoch = logs['epoch']
latent_dim = specs['LatentDim']
model = get_model(specs["Network"], **specs.get("NetworkSpecs", {}), n_parts=n_parts, part_dim=part_dim, use_occ=use_occ).to(device)
latents = get_part_latents(len(instances), n_parts, part_dim, specs.get("LatentBound", None), device=device)
if use_poses:
    # Pre-computed part poses, created with freeze=True and fill_nans=True
    # (presumably frozen and NaN-free — see src.model.get_part_poses).
    poses = get_part_poses(len(instances), n_parts, freeze=True, device=device, fill_nans=True)
    ws.load_poses(expdir, poses)

try:
    ws.load_model(expdir, model, cp_epoch)
    ws.load_latents(expdir, latents, cp_epoch)
    print(f"Loaded checkpoint of epoch={cp_epoch}")
except FileNotFoundError as err:
    # No per-epoch files: fall back to the checkpoint returned by ws.load_checkpoint.
    checkpoint = ws.load_checkpoint(expdir)
    model.load_state_dict(checkpoint['model_state_dict'])
    latents.load_state_dict(checkpoint['latents_state_dict'])
    print(f"File not found: {err.filename}.\nLoading checkpoint instead (epoch={checkpoint['epoch']}).")
    del checkpoint

# Freeze to avoid possible gradient computations
model.eval()
model.requires_grad_(False)
latents.requires_grad_(False)

print_model = False  # set to True to print the full architecture
if print_model:
    print("Model:", model)
print(f"Model has {sum(p.nelement() for p in model.parameters()):,} parameters.")
print(f"{latents.num_embeddings}x{latents.n_parts} latent vectors of size {latents.embedding_dim}.")
if use_poses:
    print("Using part poses (pre-computed).")

# Reconstruction

In [None]:
from src.primitives import standardize_quaternion


def get_mean_latents_poses(use_poses=use_poses):
    """Get the average latent and poses from training data.

    Args:
        use_poses: whether to also average the pre-computed part poses. Defaults to
            the experiment's `UsePoses` setting; the global `poses` only exists when
            that setting is enabled, so it is only read in that case.

    Returns:
        (latent, R, t, s): the mean training latent (leading dim kept, keepdim=True),
        and the mean part rotation (standardized quaternion), translation and scale.
        R, t and s are all None when `use_poses` is False.
    """
    latent = latents.weight.detach().clone().mean(0, keepdim=True)
    # latent = torch.ones(1, n_parts, latent_dim).normal_(0, 0.01).to(device)
    if not use_poses:
        return latent, None, None, None

    pose = poses.weight.detach().clone().mean(0, keepdim=True)
    # Pose layout along the last dim: quaternion [0:4], translation [4:7], scale [7:10].
    # NOTE(review): quaternions are averaged component-wise before standardization;
    # since q and -q are the same rotation, this assumes consistent signs — confirm.
    R = standardize_quaternion(pose[..., :4])
    t = pose[..., 4:7]
    s = pose[..., 7:10]
    return latent, R, t, s


grid_filler = SdfGridFiller(256, device)

In [None]:
# Decode the average training latent (with the average part poses, if used).
latent, R, t, s = get_mean_latents_poses()

# Settings shared by the mesh and the per-part extraction.
decode_kwargs = dict(grid_filler=grid_filler, verbose=True, R=R, t=t, s=s)
train_mesh = create_mesh(model, latent, 128, 32**3, **decode_kwargs)
mean_parts = create_parts(model, latent, 128, 32**3, **decode_kwargs)
train_parts = trimesh.util.concatenate(get_color_parts(mean_parts))

viz.plot_render([train_mesh, train_parts], use_texture=[False, True],
                titles=["Reconstruction", "Parts"], full_backfaces=True).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False, R=R, t=t, s=s).show()
train_parts.show()

## Train shape

In [None]:
# Pick a random training shape and decode its optimized latent (and poses).
idx = np.random.randint(len(instances))
print(f"Shape {idx}: {instances[idx]}")
shape_index = torch.tensor([idx]).to(device)
latent = latents(shape_index)
if use_poses:
    R, t, s = poses(shape_index)
else:
    R, t, s = None, None, None

decode_kwargs = dict(grid_filler=grid_filler, verbose=True, R=R, t=t, s=s)
train_mesh = create_mesh(model, latent, 128, 32**3, **decode_kwargs)
gt_mesh_path = os.path.join(specs["DataSource"], "meshes", instances[idx] + ".obj")
gt_mesh = trimesh.load(gt_mesh_path)
shape_parts = create_parts(model, latent, 128, 32**3, **decode_kwargs)
train_parts = trimesh.util.concatenate(get_color_parts(shape_parts))

viz.plot_render([gt_mesh, train_mesh, train_parts], use_texture=[False, False, True],
                titles=["GT", "Reconstruction", "Parts"], full_backfaces=True).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False, R=R, t=t, s=s).show()
train_parts.show()

## Interpolation between train shapes

In [None]:
# Pick two random training shapes and decode the blend of their latents:
# linear for latents/translations/scales, slerp for the part rotations.
idx = np.random.randint(len(instances), size=2).tolist()
t = 0.5  # interpolation factor
print(f"Shapes {idx}: {instances[idx[0]]}, {instances[idx[1]]} (t={t:.2f})")
pair_index = torch.tensor(idx).to(device)
latent = latents(pair_index)
# NOTE(review): indexing latent[0]/latent[1] drops the leading batch dim that the other
# cells keep (keepdim=True / latents([idx])); presumably create_mesh accepts both — confirm.
latent = (1. - t) * latent[0] + t * latent[1]
if use_poses:
    R, tr, s = poses(pair_index)
    R = slerp_quaternion(R[0], R[1], t)
    tr = (1. - t) * tr[0] + t * tr[1]
    s = (1. - t) * s[0] + t * s[1]
else:
    R, tr, s = None, None, None

interp_mesh = create_mesh(model, latent, 128, 32**3, grid_filler=grid_filler, verbose=True, R=R, t=tr, s=s)
gt_mesh0 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[0]]+".obj"))
gt_mesh1 = trimesh.load(os.path.join(specs["DataSource"], "meshes", instances[idx[1]]+".obj"))
gt_parts0 = trimesh.util.concatenate(get_color_parts(load_gt_parts(instances[idx[0]])))
gt_parts1 = trimesh.util.concatenate(get_color_parts(load_gt_parts(instances[idx[1]])))
interp_parts = create_parts(model, latent, 128, 32**3, grid_filler=grid_filler, verbose=True, R=R, t=tr, s=s)
interp_parts = trimesh.util.concatenate(get_color_parts(interp_parts))
# Panel titles use the same 0/1 indices as the two selected shapes.
viz.plot_render([gt_mesh0, interp_mesh, gt_mesh1, gt_parts0, interp_parts, gt_parts1], 
                use_texture=[False, False, False, True, True, True], full_backfaces=True,
                titles=["GT 0", f"Reconstruction (t={t:.2f})", "GT 1", "GT Parts 0", f"Parts (t={t:.2f})", "GT Parts 1"]).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False, R=R, t=tr, s=s).show()
interp_mesh.show()

## Test shape
First, try to load an already reconstructed shape. If none is found, optimize a latent and save the results (existing files are never overwritten).

In [None]:
# Reconstruction
always_reconstruct = False  # True to force reconstruction (do not overwrite existing files)
idx = np.random.choice(len(instances_t))

instance = instances_t[idx]
print(f"Reconstructing test shape {idx} ({instance})")
inter_lambda = specs["Parts"].get("IntersectionLambda", None)
inter_temp = specs["Parts"].get("IntersectionTemp", 1.)
if inter_lambda is not None:
    print(f"Using intersection loss with lambda={inter_lambda:.2f} and temperature={inter_temp}")
# Default part poses (None). `scales` must be initialized as well: it is passed to
# reconstruct_parts/create_mesh/create_parts below even when poses are not used.
rotations, translations, scales = None, None, None

# Output locations of this test reconstruction (per-epoch "_parts" subdirectories).
cp_epoch_str = str(cp_epoch) + "_parts"
latent_subdir = ws.get_recon_latent_subdir(expdir, cp_epoch_str)
mesh_subdir = ws.get_recon_mesh_subdir(expdir, cp_epoch_str)
os.makedirs(latent_subdir, exist_ok=True)
os.makedirs(mesh_subdir, exist_ok=True)
latent_fn = os.path.join(latent_subdir, instance + ".pth")
mesh_fn = os.path.join(mesh_subdir, instance + ".obj")
parts_subdir = ws.get_recon_parts_subdir(expdir, cp_epoch_str)
parts_fn = os.path.join(parts_subdir, instance + ".obj")

loss_recon = get_loss_recon("L1-Hard", reduction='none')

@torch.no_grad()
def fill_nans_average(tensor, average):
    """Overwrite, in place, every row of `tensor` containing a NaN with `average`.

    A "row" is a slice along the last dimension: it is replaced as a whole as soon
    as any of its entries is NaN. `average` is indexed with the same mask, so it
    must share `tensor`'s leading shape; the copied values are cast to `tensor`'s
    dtype and device. Returns None.
    """
    invalid_rows = tensor.isnan().any(dim=-1)
    replacement = average[invalid_rows].detach().clone()
    tensor[invalid_rows] = replacement.to(device=tensor.device, dtype=tensor.dtype)

# Latent: load existing or reconstruct
# NOTE(review): latent_subdir/mesh_subdir are created with os.makedirs earlier in this
# section, but nothing below writes the reconstructed latent, poses, mesh or parts to disk.
if not always_reconstruct and os.path.isfile(latent_fn):
    # map_location lets tensors saved from another device (e.g. a GPU) load on `device`.
    latent = torch.load(latent_fn, map_location=device)
    print(f"Latent norm = {latent.norm():.4f} (existing)")
    if use_poses:
        poses_fn = os.path.join(ws.get_recon_poses_subdir(expdir, cp_epoch_str), instance + ".pth")
        _poses = torch.load(poses_fn, map_location=device)
        # Stored layout along the last dim: quaternion [0:4], translation [4:7], scale [7:10].
        rotations, translations, scales = _poses[..., :4], _poses[..., 4:7], _poses[..., 7:10]
else:
    npz = np.load(os.path.join(specs["DataSource"], specs["SamplesDir"], instance, specs["SamplesFile"]))
    npz_parts = np.load(os.path.join(specs["DataSource"], specs["Parts"]["SamplesDir"], instance, specs["SamplesFile"]))
    if use_poses:
        # GT part poses, stored per instance as separate .npy files.
        params_dir = os.path.join(specs["DataSource"], specs["Parts"]["ParametersDir"], instance)
        rotations = torch.tensor(np.load(os.path.join(params_dir, "quaternions.npy")))
        translations = torch.tensor(np.load(os.path.join(params_dir, "translations.npy")))
        scales = torch.tensor(np.load(os.path.join(params_dir, "scales.npy")))
        print("Using part poses (GT).")
        # Remove possible Nans (empty parts) by the average training poses
        _, _R, _t, _s = get_mean_latents_poses()
        fill_nans_average(rotations, _R.squeeze(0))
        fill_nans_average(translations, _t.squeeze(0))
        fill_nans_average(scales, _s.squeeze(0))
    out = reconstruct_parts(model, npz, npz_parts, 400, 8000, 5e-3, loss_recon, specs["ReconLossLambda"], loss_recon,
                            specs["Parts"]["ReconLossLambda"], part_latent_reg, clampD, None, part_dim, n_parts=n_parts,
                            is_part_sdfnet=True, inter_lambda=inter_lambda, inter_temp=inter_temp, 
                            rotations=rotations, translations=translations, scales=scales, verbose=True, device=device)
    # With poses, the optimized poses are returned alongside the loss and latent.
    if use_poses:
        err, latent, rotations, translations, scales = out
    else:
        err, latent = out
    print(f"Final loss: {err:.6f}, latent norm = {latent.norm():.4f}")

# Mesh: load existing or reconstruct
if not always_reconstruct and os.path.isfile(mesh_fn):
    test_mesh = trimesh.load(mesh_fn)
else:
    test_mesh = create_mesh(model, latent, 128, 32**3, grid_filler=grid_filler, verbose=True, R=rotations, t=translations, s=scales)
gt_mesh = trimesh.load(os.path.join(specs["DataSource"], "meshes", instance+".obj"))
gt_parts = load_gt_parts(instance)
gt_parts = trimesh.util.concatenate(get_color_parts(gt_parts))

# Chamfer (an empty reconstruction counts as infinitely far)
chamfer_samples = 30_000
if test_mesh.is_empty:
    chamfer_val = float('inf')
else:
    chamfer_val = chamfer_distance(gt_mesh.sample(chamfer_samples), test_mesh.sample(chamfer_samples))
print(f"Chamfer-distance (x10^4) = {chamfer_val * 1e4:.6f}")

# Parts: load existing or reconstruct
if not always_reconstruct and os.path.isfile(parts_fn):
    test_parts = trimesh.load(parts_fn)
else:
    test_parts = create_parts(model, latent, 128, 32**3, grid_filler=grid_filler, verbose=True, R=rotations, t=translations, s=scales)
    test_parts = trimesh.util.concatenate(get_color_parts(test_parts))
viz.plot_render([gt_mesh, test_mesh, gt_parts, test_parts], use_texture=[False, False, True, True], max_cols=2,
                titles=["GT", "Reconstruction", "GT Parts", "Parts"], full_backfaces=True).show()
viz.plot_sdf_slices(model, latent, clampD=clampD, contour=False, R=rotations, t=translations, s=scales).show()

test_parts.show()