In [58]:
import torch
import gecco_torch
from gecco_torch.scatter import plot, combine_plots
from gecco_torch.data.neurons import TorchNeuronNet
from gecco_torch.utils import apply_transform

In [59]:
# Load the conditional GECCO model definition and its EMA weights.
config_root = '/groups/turaga/home/troidlj/gecco/gecco-torch/example_configs'
config = gecco_torch.load_config(f'{config_root}/neuron_conditional.py')  # load the model definition
model = config.model
# Map the checkpoint to the CPU: `load_state_dict` copies the stored values into the
# model's existing parameters (on whatever device those already live), so mapping the
# whole checkpoint to CUDA only spends GPU memory and fails on machines without a GPU.
state_dict = torch.load(
    f'{config_root}/lightning_logs/version_38/arxiv/epoch=9-step=50000.ckpt',
    map_location='cpu',
)
model.load_state_dict(state_dict['ema_state_dict'])  # use the EMA copy of the weights
model = model.eval()

In [60]:
# Dataset: cached neuron point clouds, training split.
# n_points is forwarded to TorchNeuronNet (presumably the number of points per cloud — confirm).
root = "/nrs/turaga/jakob/autoproof_data/flywire_cave_post/cache_t5"
split, n_points = "train", 2048
data = TorchNeuronNet(root, split, n_points)

[4, 4, 40]


In [61]:
idx = 111
# Fetch the sample exactly once: every `data[idx]` re-runs `__getitem__`, which repeats
# the loading work and, if point sampling is random, could return fields that do not
# belong to the same draw.
sample = data[idx]
partial = sample.partial.unsqueeze(0)  # partial point cloud for conditioning, with a batch dim
points = sample.points.unsqueeze(0)  # full point cloud for comparison, with a batch dim
T = sample.T
T_i = sample.T_i  # presumably the inverse of T — TODO confirm

In [62]:
# Sanity check that T_i inverts T. The sum alone only shows the entries add up to the
# matrix size (many non-identity matrices do), so also report the worst-case
# elementwise deviation of T_i @ T from the identity, which should be ~0.
T_product = T_i @ T
identity = torch.eye(*T_product.shape[-2:], dtype=T_product.dtype, device=T_product.device)
print(torch.sum(T_product))
print((T_product - identity).abs().max())

tensor(4.0005)


In [63]:
# Seed the stochastic sampler so the generated cloud is reproducible on Restart & Run All.
torch.manual_seed(0)
samples = model.sample_stochastic(
    (partial.shape[0], n_points, 3),  # one example per partial input, n_points 3-dimensional points each
    context=partial,  # the model is conditional: the partial cloud is the context
    with_pbar=True,  # shows a tqdm progress bar for sampling
)

  0%|          | 0/64 [00:00<?, ?step/s]

In [64]:
# Drop only the leading batch dimension; a bare `.squeeze()` would also collapse any
# other size-1 axis. Note that `partial` is rebound here without its batch dim, so
# re-run the sample-selection cell before re-running the sampling cell.
out = samples[0]  # first generated sample
partial = partial.squeeze(0)  # partial point cloud without the batch dim
gt = points.squeeze(0)  # ground-truth point cloud without the batch dim

In [65]:
# Overlay the partial input on the ground truth to show what the model is conditioned on.
combined = torch.cat([gt, partial], dim=0)  # stack ground-truth and partial point clouds
combined_labels = torch.cat([
    torch.zeros(gt.shape[0]),  # label 0: ground-truth points
    torch.ones(partial.shape[0]),  # label 1: partial-input points
], dim=0)

plot(combined.cpu().numpy(), combined_labels.cpu().numpy(), label_map={0: 'Ground Truth', 1: 'Partial Input'})

In [66]:
# Overlay the generated completion on the ground truth.
# The sample comes from the model, so move both clouds to the CPU before concatenating
# (they may live on different devices) and detach it in case sampling tracked gradients,
# which would make `.numpy()` raise.
combined = torch.cat([gt.cpu(), out.detach().cpu()], dim=0)  # stack ground-truth and generated point clouds
combined_labels = torch.cat([
    torch.zeros(gt.shape[0]),  # label 0: ground-truth points
    torch.ones(out.shape[0]),  # label 1: generated points
], dim=0)

plot(combined.numpy(), combined_labels.numpy(), label_map={0: 'Ground Truth', 1: 'Generated'})