# Inference and visualization

This notebook shows how to run inference and visualize data. We use the awesome `k3d` package for interactive plots in Jupyter.

In [None]:
import os
import torch
import k3d
import matplotlib.pyplot as plt

import gecco_torch
from gecco_torch.diffusion import Diffusion
from gecco_torch.structs import Example

Load the config file and get the model definition.

In [None]:
# Directory of the released checkpoint; note the trailing slash.
root = '../release-checkpoints/taskonomy/'
# os.path.join avoids the doubled '/' that string concatenation would produce
# because `root` already ends with a separator.
config = gecco_torch.load_config(os.path.join(root, 'config.py'))
model: Diffusion = config.model

Prepare data. This can be skipped entirely if the model is unconditional and we're only interested in inference (generation).

In [None]:
# PyTorch Lightning data modules must be set up before their dataloaders
# can be used.
data = config.data
data.setup()

In [None]:
checkpoints_path = os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints')
checkpoints = os.listdir(checkpoints_path)

# os.listdir returns entries in arbitrary order; sort so the chosen checkpoint
# is deterministic when the directory holds more than one.
checkpoint_names = sorted(c for c in checkpoints if c.endswith('.ckpt'))
if not checkpoint_names:
    raise FileNotFoundError(f'No .ckpt files found in {checkpoints_path}')
checkpoint_name = checkpoint_names[0]
checkpoint_path = os.path.join(checkpoints_path, checkpoint_name)
print(f'Using checkpoint {checkpoint_name}')

# map_location='cpu' lets the checkpoint load on machines without a GPU.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if this checkpoint
# stores non-tensor objects the load may fail. Only pass weights_only=False
# for checkpoints you trust.
checkpoint_state_dict = torch.load(checkpoint_path, map_location='cpu')
# Load the checkpoint's 'ema_state_dict' entry (presumably the EMA-averaged weights).
model_state_dict = checkpoint_state_dict['ema_state_dict']
model.load_state_dict(model_state_dict)

Grab a batch of validation data to get access to the conditioning images and the camera intrinsics matrices.

In [None]:
# Only the first validation batch is needed for conditioning and comparison.
val_loader = data.val_dataloader()
batch: Example = next(iter(val_loader))
print(batch)  # inspect what the batch contains

In [None]:
# Choose the compute backend: use the GPU when CUDA is available, otherwise
# leave objects where they are (on the CPU).
def map_device(x):
    """Return `x` moved to CUDA if a GPU is available, else `x` unchanged."""
    if not torch.cuda.is_available():
        return x
    return x.to(device='cuda')


model: Diffusion = map_device(model).eval()
# Apply the same device mapping to the tensors of the conditioning context.
context = batch.ctx.apply_to_tensors(map_device)

Sample the examples. Since the batch has 48 items, we will sample 48 point clouds, one per conditioning example.

In [None]:
N_POINTS = 2048  # points per generated cloud
# One sample per conditioning example; derived from the batch rather than
# hard-coded so it always matches `context`.
# Assumes batch.data is batch-first (batch, points, 3) -- TODO confirm.
batch_size = len(batch.data)

# Use float16 autocast only when a GPU is present: torch.autocast('cuda')
# otherwise warns on CPU-only machines and disables itself anyway.
with torch.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
    samples = model.sample_stochastic(
        (batch_size, N_POINTS, 3),
        context=context,
        with_pbar=True,
    )

Visualize the input image

In [None]:
example_id = 5  # index within the batch; must be smaller than the batch size

# The image tensor is channels-first; permute to (H, W, C) for matplotlib.
cond_image = batch.ctx.image[example_id].permute(1, 2, 0).cpu().numpy()

fig, ax = plt.subplots()
ax.imshow(cond_image)
ax.set_title(f'Conditioning image (example {example_id})')
ax.axis('off')
plt.show()

Visualize the point cloud in 3d. Green - ground truth, red - GECCO sample.

In [None]:
# Red: GECCO sample; green: ground truth. Both are cast to float32, since the
# samples may be half precision after running under autocast.
sample_points = samples[example_id].cpu().numpy().astype('float32')
gt_points = batch.data[example_id].cpu().numpy().astype('float32')

plot = k3d.plot()
plot += k3d.points(sample_points, point_size=0.01, color=0xff0000)
plot += k3d.points(gt_points, point_size=0.01, color=0x00ff00)
plot.display()

## Bonus: upsampling
GECCO is trained with a specific number of points in each point cloud and should be used in a similar regime at inference time. There is, however, a trick that allows sampling multiple new points conditionally on an already existing point cloud. By repeating this multiple times and concatenating the results, we achieve upsampling: we create many new points that are independent of each other **conditionally on the input data**.

In [None]:
def pick_id(t):
    """Slice out the current `example_id`, keeping a batch dimension of size 1.

    Only one example is upsampled to avoid running out of memory.
    """
    return t[example_id:example_id + 1]


sample_to_upsample = map_device(pick_id(samples))

N_NEW_POINTS = 100_000  # number of new points to generate
UPSAMPLE_STEPS = 32     # passed to model.upsample as num_steps

# Use float16 autocast only when a GPU is present: torch.autocast('cuda')
# otherwise warns on CPU-only machines and disables itself anyway.
with torch.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
    upsampled = model.upsample(
        n_new=N_NEW_POINTS,
        data=sample_to_upsample,
        context=context.apply_to_tensors(pick_id),
        with_pbar=True,
        num_steps=UPSAMPLE_STEPS,
    )

Visualize the upsampled point cloud: green - original, red - upsampled.

In [None]:
# Green: the original sample; red: the upsampled points.
# Cast to float32 for consistency with the sample/ground-truth plot: both
# tensors were produced under float16 autocast and may be half precision.
original_points = pick_id(samples).squeeze(0).cpu().numpy().astype('float32')
upsampled_points = upsampled.squeeze(0).cpu().numpy().astype('float32')

plot = k3d.plot()
plot += k3d.points(original_points, point_size=0.01, color=0x00ff00)
plot += k3d.points(upsampled_points, point_size=0.01, color=0xff0000)
plot.display()