In [None]:
import os

import jax
import k3d
import numpy as np
import equinox as eqx
import mediapy

from gecco_jax import load_config

In [None]:
# Folder that holds all released checkpoints.
CHECKPOINT_ROOT = '../../release-checkpoints'
# Checkpoint to load. Other released options:
#   'taskonomy', 'shapenet-unconditional/chair',
#   'shapenet-unconditional/car', 'shapenet-unconditional/airplane'
CHECKPOINT_NAME = 'shapenet-vol'

path = f'{CHECKPOINT_ROOT}/{CHECKPOINT_NAME}'
config_path = os.path.join(path, 'config.py')
config = load_config(config_path)

In [None]:
# The unpack enforces that `make_val_loader` returns exactly one loader.
(dataloader,) = config.make_val_loader()
dataset = dataloader.dataset

In [None]:
# Build a model only to provide the pytree structure (`like=`), then replace its
# leaves with the weights stored in checkpoint/ema.eqx (presumably the EMA copy).
model = config.make_model(key=jax.random.PRNGKey(42))
save_path = os.path.join(f'{path}/checkpoint', 'ema.eqx')
model = eqx.tree_deserialise_leaves(save_path, like=model)

# Use 128 solver steps by default when sampling.
model = eqx.tree_at(
    lambda m: m.schedule.n_solver_steps,
    model,
    replace=128,
)
# Switch the model to eval (inference) mode.
model = eqx.tree_inference(model, True)

In [None]:
def sample_and_display(index: int, key=44):
    '''
    Conditionally sample a point cloud for dataset example `index` and compare it with the ground truth.

    Shows the example's conditioning image (when it has a context), draws one
    sample from the global `model`, and renders both clouds with k3d:
    the sample in red, the ground truth in green.

    Args:
        index: position of the example in the global `dataset`.
        key: integer seed for `jax.random.PRNGKey`, making the sample reproducible.
    '''
    # `jax.tree_util.tree_map` replaces the deprecated (and later removed) `jax.tree_map`.
    # NOTE(review): assumes `discard_extras()` yields (points, raw context, extras) — TODO confirm.
    xyz, ctx_raw, _ = jax.tree_util.tree_map(np.asarray, dataset[index].discard_extras())

    # Truthiness check: skips both a missing (None) and an empty context.
    if ctx_raw:
        mediapy.show_image(ctx_raw.image)

    sample = model.sample(xyz.shape, ctx_raw, n=1, key=jax.random.PRNGKey(key))
    sample = np.asarray(sample)

    kw = dict(point_size=0.05, shader='3d')

    plot = k3d.plot()
    plot += k3d.points(sample[0], color=0xff0000, **kw)  # show the sample in red
    plot += k3d.points(xyz, color=0x00ff00, **kw)  # show the ground truth in green
    plot.display()

In [None]:
# Index of the validation example visualised in this notebook.
example_id = 100

sample_and_display(index=example_id)

In [None]:
def upsample(model, context, point_cloud, n_fold: int=50, key=None):
    '''
    Uses `model` to upsample `point_cloud` by `n_fold` times, using `context` as conditioning.

    The cloud is split into two halves. Each half serves as the known part of an
    inpainting problem in which `model.sample_inpaint` produces `n_fold`
    completions, each inpainting as many points as the half contains. The
    outputs for both halves are flattened into one point cloud.

    Args:
        model: object exposing `sample_inpaint` (called with the keywords below).
        context: raw conditioning forwarded as `raw_ctx` (may be None).
        point_cloud: array whose total size must be divisible by 6, because it is
            reshaped to (2, -1, 3), i.e. an even number of 3D points.
        n_fold: number of completions generated per half.
        key: JAX PRNG key. Defaults to `jax.random.PRNGKey(42)`. It is split so
            the two halves are inpainted with independent randomness.

    Returns:
        The outputs of `sample_inpaint` for both halves, reshaped to (-1, 3).

    Raises:
        ValueError: if `point_cloud` cannot be split into two halves of 3D points.
    '''
    if point_cloud.size % 6 != 0:
        raise ValueError(
            f'point_cloud of shape {point_cloud.shape} cannot be split into two halves of 3D points'
        )
    if key is None:
        key = jax.random.PRNGKey(42)

    # split the input point cloud into two halves
    halves = point_cloud.reshape(2, -1, 3)
    # one independent key per half (sharing one key would give both halves identical noise)
    half_keys = jax.random.split(key, 2)

    # inpaint one half with its own key
    def inpaint_one(half, half_key):
        return model.sample_inpaint(
            half,
            raw_ctx=context,
            n_completions=n_fold,
            s_churn=0.5,
            n_substeps=4,
            m_to_inpaint=half.shape[0],
            key=half_key,
        )

    # apply the function to both halves in parallel
    samples = jax.vmap(inpaint_one)(halves, half_keys)

    # fold the results into a single point cloud
    return samples.reshape(-1, 3)

def sample_and_upsample(model, key, context=None, n_fold=49, n_points: int = 2048):
    '''
    Samples a low-resolution point cloud, upsamples it, and returns both.

    Args:
        model: object exposing `sample_stochastic` and `sample_inpaint`.
        key: integer seed. It seeds the low-resolution sample directly, and a key
            derived from it seeds the upsampling, so both stages depend on `key`.
        context: raw conditioning forwarded to the model (may be None).
        n_fold: upsampling factor forwarded to `upsample`.
        n_points: number of points in the low-resolution sample (must be even).

    Returns:
        (sample_low, sample_high) as NumPy arrays.
    '''
    base_key = jax.random.PRNGKey(key)

    # sample a low-resolution point cloud; `.squeeze(0)` drops the leading sample axis
    sample_low = model.sample_stochastic(
        (n_points, 3),
        context,
        key=base_key,
        s_churn=0.5,
    ).squeeze(0)

    # upsample it with a key derived from `key`, distinct from the one used above
    upsample_key = jax.random.fold_in(base_key, 1)
    sample_high = upsample(model, context, sample_low, n_fold=n_fold, key=upsample_key)

    # return both
    return np.asarray(sample_low), np.asarray(sample_high)

In [None]:
def show_upsampling(index: int, key=44, n_fold: int = 10, show_ground_truth: bool = False):
    '''
    Conditionally generate a low-res sample for dataset example `index`, upsample it, and plot both.

    Shows the example's conditioning image (when it has a context), then renders
    the low-res sample in blue and the upsampled cloud in red with k3d.

    Args:
        index: position of the example in the global `dataset`.
        key: integer seed forwarded to `sample_and_upsample`.
        n_fold: upsampling factor forwarded to `sample_and_upsample`.
        show_ground_truth: if True, also plot the dataset points in green.
    '''
    # `jax.tree_util.tree_map` replaces the deprecated (and later removed) `jax.tree_map`.
    # NOTE(review): assumes `discard_extras()` yields (points, raw context, extras) — TODO confirm.
    xyz, ctx_raw, _ = jax.tree_util.tree_map(np.asarray, dataset[index].discard_extras())

    # Truthiness check: skips both a missing (None) and an empty context.
    if ctx_raw:
        mediapy.show_image(ctx_raw.image)

    low_sample, high_sample = sample_and_upsample(model, key=key, context=ctx_raw, n_fold=n_fold)

    kw = dict(point_size=0.02, shader='3d')

    plot = k3d.plot()
    plot += k3d.points(low_sample, color=0x0000ff, **kw)  # low res in blue
    plot += k3d.points(high_sample, color=0xff0000, **kw)  # high res in red
    if show_ground_truth:
        plot += k3d.points(xyz, color=0x00ff00, **kw)  # ground truth in green
    plot.display()

In [None]:
show_upsampling(index=example_id)  # same example as the plain-sampling cell above