In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from scipy.spatial.distance import cdist

from gecco_jax.models.reparam import UVLReparam, GaussianReparam
from gecco_jax import load_config

In [None]:
# Relative path to the released `shapenet-vol` checkpoint directory.
path = '../../release-checkpoints/shapenet-vol'
# The checkpoint directory is expected to contain a `config.py`, which
# `load_config` turns into the `config` object used by the cells below.
config_path = os.path.join(path, 'config.py')
config = load_config(config_path)

In [None]:
# Build the training-data loader described by the checkpoint config.
# NOTE(review): the trailing `[0]` index is commented out — presumably an
# older `make_train_loader` returned a tuple; confirm before re-enabling.
loader = config.make_train_loader()#[0]

In [None]:
# Reference reparametrization with zero mean and unit std on all three axes.
# Its output statistics are measured below and used to fit an adjusted reparam.
# NOTE(review): presumably this makes GaussianReparam an identity-like
# normalisation — confirm against gecco_jax.models.reparam.
reference_reparam = GaussianReparam(
    mean=jnp.zeros(3),
    std=jnp.ones(3),
)

# Alternative reference (disabled): UVL reparametrization with default arguments.
# reference_reparam = UVLReparam(
#     # uvl_mean=jnp.zeros(3),
#     # uvl_std=jnp.ones(3),
# )

In [None]:
def apply_reparam(reparam, examples):
    """Map every example's points into diffusion space and stack the results.

    Args:
        reparam: object exposing ``data_to_diffusion(points, ctx)``.
        examples: iterable of objects with ``points`` and ``ctx`` attributes.

    Returns:
        A single NumPy array: the per-example results concatenated along axis 0.
    """

    def _to_diffusion(example):
        # Map over axis 1 of `points` while passing `ctx` through unmapped;
        # the mapped axis is put back at position 1 in the output.
        mapped = jax.vmap(reparam.data_to_diffusion, in_axes=(1, None), out_axes=1)
        return np.asarray(mapped(example.points, example.ctx))

    return np.concatenate([_to_diffusion(example) for example in examples], axis=0)

In [None]:
# Collect the first 10 batches from the loader and place them on the JAX device.
examples = []
for i, example in enumerate(loader):
    if i == 10:
        break

    # Leaves exposing `.numpy()` are converted to NumPy first; all other leaves
    # are passed to `device_put` unchanged.
    # `jax.tree_util.tree_map` is used instead of the deprecated top-level
    # `jax.tree_map` alias, which newer JAX releases no longer provide.
    example = jax.tree_util.tree_map(
        lambda tensor: jax.device_put(tensor.numpy() if hasattr(tensor, 'numpy') else tensor),
        example,
    )
    examples.append(example)

In [None]:
def plot_stats(reparametrized):
    """Plot per-coordinate histograms and return per-coordinate statistics.

    All leading axes of ``reparametrized`` are flattened together; the last
    axis indexes the coordinates. Histograms are drawn for at most the first
    three coordinates (labelled x, y, z) using 100 shared bin edges spanning
    the global min/max. The figure title shows the std over all values.

    Args:
        reparametrized: array whose last axis holds the coordinates.

    Returns:
        ``(mean, std)``: arrays with one entry per coordinate.
    """
    fig, ax = plt.subplots()
    kw = dict(histtype='step', bins=np.linspace(reparametrized.min(), reparametrized.max(), 100))

    # One row per coordinate after the transpose.
    reparametrized_flat = reparametrized.reshape(-1, reparametrized.shape[-1]).T
    for data, label in zip(reparametrized_flat, ('x', 'y', 'z')):
        mean = data.mean()
        std = data.std()

        # Raw f-strings: `\m` and `\s` are invalid escape sequences in a normal
        # string literal and trigger a warning on recent Python versions.
        label = rf'{label}: $\mu=${mean:0.2f}, $\sigma=${std:0.2f}'
        ax.hist(data, label=label, **kw)

    # Overall spread across every coordinate, shown in the title.
    std = reparametrized_flat.std()
    ax.set_title(rf'$\sigma={std:.2f}$')
    fig.legend()

    mean = reparametrized_flat.mean(axis=1)
    std = reparametrized_flat.std(axis=1)

    return mean, std

In [None]:
# Push the sampled batches through the reference reparam and plot the result.
# The returned per-coordinate mean/std are used to build the adjusted reparam.
reparametrized_reference = apply_reparam(reference_reparam, examples)
mean, std = plot_stats(reparametrized_reference)

In [None]:
# Rebuild the Gaussian reparam from the statistics measured under the reference.
# NOTE(review): presumably this yields roughly zero-mean, unit-std data per
# coordinate — confirm against the histogram plotted below.
adjusted_reparam = GaussianReparam(mean=mean, std=std)

print(f'mean={adjusted_reparam.mean}, std={adjusted_reparam.std}')

# Re-run the same batches through the adjusted reparam; the returned statistics
# are not needed here, only the plot.
reparametrized_adjusted = apply_reparam(adjusted_reparam, examples)
_ = plot_stats(reparametrized_adjusted)

In [None]:
# Flatten each entry along axis 0 into one vector, then compute all pairwise
# distances between those vectors (cdist's default metric).
reparametrized_adjusted_flat = reparametrized_adjusted.reshape(len(reparametrized_adjusted), -1)
p_distances = cdist(reparametrized_adjusted_flat, reparametrized_adjusted_flat)
# Self-distances are not interesting: set the diagonal to -inf so it never
# passes a "greater than" threshold.
ixs = np.arange(len(p_distances))
p_distances[ixs, ixs] = -np.inf

In [None]:
# Log-scale histogram of all pairwise distances. The bin edges start at 0, so
# the -inf diagonal entries fall outside every bin and are not counted.
_ = plt.hist(p_distances.flatten(), bins=np.linspace(0, np.nanmax(p_distances), 100), log=True)

In [None]:
# Select the index pairs whose distance is above the 99th percentile of all
# entries, then shuffle them so that position 0 holds a random far-apart pair.
threshold = np.quantile(p_distances, 0.99)
xs, ys = np.nonzero(p_distances > threshold)
permutation = np.random.permutation(len(xs))
xs, ys = xs[permutation], ys[permutation]

In [None]:
import k3d

# Show the first shuffled far-apart pair in one 3D scene: the `xs` member in
# red and the `ys` member in green.
ps = 0.1
plot = k3d.plot()
for index, colour in ((xs[0], 0xff0000), (ys[0], 0x00ff00)):
    plot += k3d.points(reparametrized_adjusted[index], point_size=ps, color=colour)
plot.display()