In [None]:
import sys
sys.path.append('..')
from train.train_utils.latent_sampler import LatentSampler
from util.visualization.utils_mesh import get_watertight_mesh_for_latent
import torch
import numpy as np
import k3d
from tqdm import tqdm
from util.checkpointing import load_yaml_and_drop_keys
from util.misc import get_model
from models.net_w_partials import NetWithPartials

#### Load model and create latents

In [2]:
# Notebook-wide settings. `n_shapes` is not referenced by the cells shown here.
mc_resolution = 248
device = 'cpu'
n_shapes = 9
torch.set_default_device(device)

# Parameters
config = load_yaml_and_drop_keys('../checkpoints/GINN-config.yml', keys_to_drop=[])
# Bounds are stored as a NumPy array on disk and cast to float32 here.
bounds = torch.from_numpy(np.load('../GINN/simJEB/data/bounds.npy')).float()

## MODEL
model = get_model(**config['model'], use_legacy_gabor=True)
# map_location: remap the stored tensors onto `device`, so a checkpoint that was
#   saved from a GPU still loads on a CPU-only machine.
# weights_only: restrict unpickling to tensors and plain containers instead of
#   arbitrary Python objects (presumably also what the warning recorded in this
#   cell's output was about -- confirm). Assumes the file holds a plain state dict.
state_dict = torch.load('../checkpoints/GINN-model.pt', map_location=device, weights_only=True)
model.load_state_dict(state_dict=state_dict)

# to handle derivatives of the model we created an abstraction called `NetWithPartials`
netp = NetWithPartials.create_from_model(model, config['nz'], config['nx'])
# Latent codes to visualize, as produced by the sampler's `val_z()`.
lat_sampler = LatentSampler(**config['latent_sampling'])
z_latents = lat_sampler.val_z()
print(f'z_latents: {z_latents}')

  model.load_state_dict(state_dict=torch.load('../checkpoints/GINN-model.pt'))


z_latents: tensor([[0.0000, 0.0000],
        [0.0000, 0.0500],
        [0.0000, 0.1000],
        [0.0500, 0.0000],
        [0.0500, 0.0500],
        [0.0500, 0.1000],
        [0.1000, 0.0000],
        [0.1000, 0.0500],
        [0.1000, 0.1000]])


#### Extract meshes

In [11]:
## extract one mesh per latent code z (one marching-cubes run per latent)
mesh_kwargs = dict(chunks=1, level=0, surpress_watertight=True)  # options shared by every call

meshes = []
for z in tqdm(z_latents):
    mesh = get_watertight_mesh_for_latent(
        netp.f_, netp.params, z, bounds, mc_resolution, device, **mesh_kwargs
    )
    meshes.append(mesh)

100%|██████████| 9/9 [00:45<00:00,  5.04s/it]


#### Visualize meshes

In [None]:
## grid layout: neighbouring shapes are offset by 1.5x the bounding-box extent along each axis
bbox_extent = bounds[:, 1] - bounds[:, 0]
spacing = 1.5 * bbox_extent
n_rows, n_cols = config['latent_sampling']['val_plot_grid']

fig = k3d.plot(height=800)
for i_shape, mesh in enumerate(meshes):
    # row-major placement: columns advance along axis 1, rows along axis 2
    i_row, i_col = divmod(i_shape, n_cols)
    translation = [0, spacing[1]*i_col, spacing[2]*i_row]
    fig += k3d.mesh(*mesh, color=0xdddcdc, side='double', translation=translation)

fig.display()

Output()