## Load deformer and visualize results

In [1]:
# load libraries
import trimesh
import torch
import json
import os
from types import SimpleNamespace
from shapenet_dataloader import ShapeNetMesh
from deepdeform.layers.deformation_layer import NeuralFlowDeformer

### Options

In [2]:
# choice of checkpoint to load
run_dir = "runs/run_4746_nnnr_nosign"
checkpoint = "checkpoint_latest.pth.tar_deepdeform_best.pth.tar"
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines without CUDA instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Setup

In [3]:
# load training args (context manager so the file handle is closed promptly)
with open(os.path.join(run_dir, 'params.json'), 'r') as params_fh:
    args = SimpleNamespace(**json.load(params_fh))

# Number of rows in the latent-code table. It must match the table size the
# checkpoint was trained with (presumably the "4746" in run_dir -- confirm).
n_shapes = 4746

# setup model
deformer = NeuralFlowDeformer(latent_size=args.lat_dims, f_width=args.deformer_nf, s_nlayers=2,
                              s_width=5, method=args.solver, nonlinearity=args.nonlin, arch='imnet',
                              adjoint=args.adjoint, rtol=args.rtol, atol=args.atol, via_hub=True,
                              no_sign_net=(not args.sign_net))
# Randomly initialised latent table; its values are presumably overwritten by
# load_state_dict below (TODO confirm it is part of "deformer_state_dict").
lat_params = torch.nn.Parameter(torch.randn(n_shapes, args.lat_dims)*1e-1, requires_grad=True)
deformer.add_lat_params(lat_params)
deformer.to(device)

# load checkpoint; map_location remaps saved tensors onto `device`, so a
# checkpoint written on a GPU can also be loaded on another device / the CPU.
resume_dict = torch.load(os.path.join(run_dir, checkpoint), map_location=device)
start_ep = resume_dict["epoch"]
global_step = resume_dict["global_step"]
tracked_stats = resume_dict["tracked_stats"]
deformer.load_state_dict(resume_dict["deformer_state_dict"])

# dataloader: read meshes from the simplified copy of the training data root
data_root = args.data_root.replace('shapenet_watertight', 'shapenet_simplified')
dset = ShapeNetMesh(data_root=data_root, split="train", category="chair",
                    normals=False)


### Test deformation between a pair

In [None]:
source_idx = 0  # choose between [0, 4745]
target_idx = 1  # choose between [0, 4745]

_, _, v_src, f_src, v_tar, f_tar = dset.get_pairs(source_idx, target_idx)
v_src = v_src.to(device)
v_tar = v_tar.to(device)

# get the latent codes corresponding to these shapes
l_src = deformer.get_lat_params(source_idx)
l_tar = deformer.get_lat_params(target_idx)
l_zero = torch.zeros_like(l_src)


def lat_path(l_src_, l_tar_):
    """Return the latent path source -> zero code -> target.

    Stacks ``l_src_``, an all-zero latent code and ``l_tar_`` along dim 1.
    The zero code is built from ``l_src_`` on every call, so it always matches
    the inputs' shape, dtype and device (not only those of ``l_src``). The zero
    code presumably acts as the hub of the ``via_hub=True`` deformer -- confirm.
    """
    return torch.stack([l_src_, torch.zeros_like(l_src_), l_tar_], dim=1)


# interpolation between source and target
steps = 11
# source to target
