In [18]:
import torch
import gecco_torch
from gecco_torch.scatter import plot, combine_plots
from gecco_torch.data.neurons import TorchNeuronNet
from gecco_torch.utils import apply_transform
from chamferdist import ChamferDistance

In [2]:
# Directory holding the example configs and the Lightning training logs.
config_root = '/groups/turaga/home/troidlj/gecco/gecco-torch/example_configs'

# The config module carries the model definition; the weights are restored below.
config = gecco_torch.load_config(f'{config_root}/neuron_conditional.py')
model = config.model

# Restore the weights stored under the checkpoint's 'ema_state_dict' key and
# switch the model to evaluation mode.
checkpoint_file = f'{config_root}/lightning_logs/version_39/arxiv/epoch=49-step=250000.ckpt'
state_dict = torch.load(checkpoint_file, map_location='cuda')
model.load_state_dict(state_dict['ema_state_dict'])
model = model.eval()



In [3]:
# Dataset settings: which split to read and how many points each cloud holds.
split = "test"
n_points = 2048
root = "/nrs/turaga/jakob/autoproof_data/flywire_cave_post/cache_t5"

# Build the dataset the examples below are drawn from.
data = TorchNeuronNet(root, split, n_points)

[4, 4, 40]


In [4]:
idx = 700

# Fetch the example once. Indexing the dataset separately for every field
# re-runs __getitem__ each time; if the dataset does any random sampling or
# augmentation (TODO confirm), the fields could come from different draws.
example = data[idx]

partial = example.partial.unsqueeze(0)  # partial point cloud used for conditioning, with a leading batch axis
points = example.points.unsqueeze(0)  # full point cloud kept for comparison, with a leading batch axis

# Presumably a transform and its inverse (the following cell multiplies
# T_i @ T as a sanity check) -- verify against TorchNeuronNet.
T = example.T
T_i = example.T_i

In [5]:
# Coarse sanity check: prints the sum of all entries of T_i @ T. An identity
# matrix of size 4 sums to 4, but so do other matrices, so this does not by
# itself prove that T_i inverts T.
print(torch.matmul(T_i, T).sum())

tensor(4.)


In [6]:
# Draw a batch of stochastic samples conditioned on the partial point cloud.
# The point count reuses `n_points` from the dataset cell so that generated
# clouds stay the same size as the ground-truth clouds they are compared with.
samples = model.sample_stochastic(
    (10, n_points, 3),  # 10 samples, each with n_points 3-dimensional points
    context=partial,  # conditioning input (assuming a conditional model)
    with_pbar=True,  # shows a tqdm progress bar for sampling
)

  0%|          | 0/64 [00:00<?, ?step/s]

In [19]:
# Pick one generated cloud (index 8 of the sampled batch) for inspection.
out = samples[8].squeeze()

# Drop the batch axis from the conditioning and ground-truth clouds.
# NOTE: this rebinds `partial`; re-running the sampling cell afterwards would
# pass the unbatched tensor as `context`.
partial = partial.squeeze()
gt = points.squeeze()

In [26]:
# Show shape and dtype of the generated and ground-truth clouds.
print(out.shape, out.dtype)
print(gt.shape, gt.dtype)

# Add a leading batch axis and cast everything to float32 so the inputs to
# the Chamfer module share one dtype.
out_loss = out[None].float()
gt_loss = gt[None].float()
partial_loss = partial[None].float()

# Chamfer distance with the partial cloud as first argument and the generated
# cloud as second.
chamfer_loss = ChamferDistance()
loss = chamfer_loss(partial_loss, out_loss)
print(f"Chamfer Loss: {loss.item()}")

torch.Size([2048, 3]) torch.float64
torch.Size([2048, 3]) torch.float32
Chamfer Loss: 0.0035175506491214037


In [12]:
# Overlay the ground-truth cloud (label 0) with the partial conditioning
# input (label 1) in a single scatter plot.
combined = torch.cat((gt, partial), dim=0)
combined_labels = torch.cat((
    torch.zeros(gt.size(0)),  # 0 -> ground-truth points
    torch.ones(partial.size(0)),  # 1 -> partial-input points
), dim=0)

plot(
    combined.cpu().numpy(),
    combined_labels.cpu().numpy(),
    label_map={0: 'Ground Truth', 1: 'Partial Input'},
)

In [27]:
# Overlay the partial conditioning input (label 0) with the selected
# generated cloud (label 1) in a single scatter plot.
combined = torch.cat((partial, out), dim=0)
combined_labels = torch.cat((
    torch.zeros(partial.size(0)),  # 0 -> partial-input points
    torch.ones(out.size(0)),  # 1 -> generated points
), dim=0)

plot(
    combined.cpu().numpy(),
    combined_labels.cpu().numpy(),
    label_map={0: 'Partial', 1: 'Generated'},
)

In [14]:
import os
from tqdm import tqdm

# Output directory for the rendered comparisons, keyed by the dataset index.
pth = "/nrs/turaga/jakob/autoproof_data/flywire_cave_post/matrix/{}".format(idx)

# exist_ok=True replaces the check-then-create pattern: it is race-free and
# lets the cell be re-run without special-casing an existing directory.
os.makedirs(pth, exist_ok=True)

# Render every generated sample against the ground truth and save one PNG each.
for i in tqdm(range(samples.shape[0])):
    out1 = samples[i, ...].squeeze()  # i-th generated sample
    c1 = torch.cat([out1, gt], dim=0)  # generated points followed by ground-truth points

    labels = torch.cat([
        torch.zeros(out1.shape[0]),  # 0 -> generated points
        torch.ones(gt.shape[0]),  # 1 -> ground-truth points
    ], dim=0)

    fig1 = plot(c1.cpu().numpy(), labels.cpu().numpy(), label_map={0: 'Generated', 1: 'Ground Truth'})
    pth1 = os.path.join(pth, '{}.png'.format(i))  # file name is the sample index
    fig1.write_image(pth1)


100%|██████████| 10/10 [00:29<00:00,  2.91s/it]
