In [38]:
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Theme switch: passed to both rho_plus setup calls below and read by plot() to pick a colormap.
is_dark = False
# mpl_setup returns two values, kept here as `theme` and `cs`.
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [39]:
# NOTE(review): hard-coded absolute path; the config path 'configs/diled.toml' used below is relative to this directory.
%cd /home/nicholas/programs/avid

/home/nicholas/programs/avid


In [40]:
from avid.config import MainConfig
from pyrallis import load

# Load 'configs/diled.toml' (relative to the working directory) as a MainConfig via pyrallis' load.
config = load(MainConfig, 'configs/diled.toml')

In [41]:
from avid.dataset import load_file

# First entry of the `density` attribute of whatever load_file returns for index 0.
# NOTE(review): the axis layout of `dens` is not visible here; plot() sums over its last axis.
dens = load_file(config, 0).density[0]

In [49]:
import functools as ft
from eins import EinsOp

# Same load as in the previous cell, repeated here.
dens = load_file(config, 0).density[0]

# Grid resolution, read from the voxelizer section of the config.
N_GRID = config.voxelizer.n_grid

# N_GRID evenly spaced values in [0, 1): the endpoint 1 is generated and then dropped.
grid_vals = jnp.linspace(0, 1, N_GRID + 1)[:-1]

# NOTE(review): jnp.meshgrid defaults to indexing='xy', which swaps the first two axes
# relative to the 'ij' layout that jnp.mgrid (used in plot) produces — confirm this
# matches the axis order of `dens`.
xx, yy, zz = jnp.meshgrid(grid_vals, grid_vals, grid_vals)
# Stack the three coordinate grids and flatten them to one 3-component row per grid point.
xyz = EinsOp('d n1 n2 n3 -> (n1 n2 n3) d')(jnp.array([xx, yy, zz]))

# Shift every coordinate by 0.2 and wrap the result back into [0, 1).
xyz = (xyz + 0.2) % 1


def plot(dens, skip=2, thresh=0.01):
    """Scatter-plot the voxels of a density grid in 3D with Plotly.

    Parameters
    ----------
    dens
        Density values on the ``N_GRID``-cubed grid. Either one value per grid
        point (any shape holding exactly ``N_GRID ** 3`` elements, e.g. a flat
        array) or an array with an extra trailing axis, which is summed out
        before plotting.
    skip
        Keep only grid points whose three grid indices are all multiples of
        ``skip``.
    thresh
        Hide points whose absolute density is not greater than this value.
        The default is 0.01, the value that was previously hard-coded inside
        the function (which silently overrode whatever the caller passed).

    Returns
    -------
    The figure object returned by ``px.scatter_3d``.
    """
    cmap = rp.list_aquaria if is_dark else rp.list_cabana

    # Only reduce over the trailing axis when there is one: a per-point array
    # (such as a flat reconstruction) must not be collapsed to a scalar.
    dens = jnp.asarray(dens)
    if dens.size != N_GRID**3:
        dens = dens.sum(axis=-1)
    dens = dens.reshape(-1)

    # A point survives subsampling only if all three of its grid indices are
    # divisible by `skip`.
    grid_idx = jnp.mgrid[0:N_GRID, 0:N_GRID, 0:N_GRID]
    skip_mask = (grid_idx % skip == 0).reshape(3, -1).all(axis=0)
    # Drop near-zero "background" points so they do not clutter the plot.
    bg_mask = jnp.abs(dens) > thresh
    mask = skip_mask & bg_mask

    return px.scatter_3d(
        x=xyz[mask, 0],
        y=xyz[mask, 1],
        z=xyz[mask, 2],
        color=dens[mask].astype(xyz.dtype),
        color_continuous_scale=cmap,
        range_x=[0, 1], range_y=[0, 1], range_z=[0, 1],
        opacity=1,
    )


plot(dens)

In [None]:
from jaxtyping import Array, Float
from jax.scipy import special

@jax.jit
def choose(x, y):
    """Binomial coefficient ``C(x, y)``, evaluated through log-gamma.

    Computes ``exp(gammaln(x + 1) - gammaln(y + 1) - gammaln(x - y + 1))``, so
    the result is a floating-point approximation and broadcasts over array
    inputs.
    """
    return jnp.exp(special.gammaln(x + 1) - special.gammaln(y + 1) - special.gammaln(x - y + 1))


@jax.jit
def legendre_poly_coef(n, k):
    """Coefficient of ``((x - 1) / 2) ** k`` in the degree-``n`` Legendre polynomial.

    Equal to ``C(n, k) * C(n + k, k)``; broadcasts over array ``k``.
    """
    return choose(n, k) * choose(n + k, k)


# `n` fixes the length of the coefficient vector, so it has to be a static
# (compile-time) argument; the function is recompiled once per distinct degree.
@ft.partial(jax.jit, static_argnames=('n',))
def legendre_poly(x: "Float[Array, '...']", n: int):
    """Evaluate the degree-``n`` Legendre polynomial at ``x``.

    Uses the expansion ``P_n(x) = sum_k C(n, k) * C(n + k, k) * ((x - 1) / 2) ** k``
    for ``k = 0 .. n``.

    Parameters
    ----------
    x
        Evaluation point(s); a scalar or an array of any shape.
    n
        Polynomial degree, a non-negative Python int.

    Returns
    -------
    An array with the same shape as ``x``.
    """
    kk = jnp.arange(n + 1)
    shifted = (jnp.asarray(x) - 1) / 2
    # The new trailing axis enumerates the powers k; the dot product with the
    # coefficient vector sums it away again.
    return jnp.dot(shifted[..., None] ** kk, legendre_poly_coef(n, kk))


# Plot the Legendre polynomials of degree 0 through 9 at 100 points on [-1, 1].
# Note: this rebinds `xx`, which an earlier cell used for a meshgrid output.
xx = jnp.linspace(-1, 1, 100)
for n in range(10):
    plt.plot(xx, legendre_poly(xx, n), label=f'n = {n}')

plt.legend()

In [None]:
# Squared cosines cos(n * pi * (2x - 1)) ** 2 for n = 0 .. 6, sampled at 200 points on [0, 2].
xx = jnp.linspace(0, 2, 200)
plt.subplots(figsize=(15, 6))
for n in range(7):
    curve = jnp.cos(n * (xx * 2 - 1) * jnp.pi) ** 2
    plt.plot(xx, curve, label=f'n = {n}', lw=1)

plt.legend()

In [None]:
# Number of cosine frequencies per axis (frequencies 0 .. ncheby - 1).
ncheby = 9
# All integer frequency triples: mgrid puts the 3 components on the leading axis,
# so move that axis last to get shape (ncheby, ncheby, ncheby, 3).
nnn = jnp.moveaxis(jnp.mgrid[:ncheby, :ncheby, :ncheby], 0, -1)
# Insert a singleton axis so the triples broadcast against the rows of `xyz`.
nnn = nnn[:, :, :, None, :]

# Per-axis cosine factors for every frequency triple at every row of `xyz`;
# the last axis still holds the three coordinate components.
basis = jnp.cos(nnn * (xyz * 2 - 1) * jnp.pi)
basis.shape

In [None]:
# Multiply the per-axis factors together (last axis of `basis`) to get one
# separable 3D basis function value per frequency triple and grid point.
# Note: despite the name, `basis` is built from cosines.
cheby = basis.prod(axis=-1)

In [None]:
# Projection coefficient per basis function: <cheby, dens> divided by <cheby, cheby>,
# both taken along cheby's last axis.
# NOTE(review): this presumably expects `dens` to be a flat array with one value per
# grid point, whereas plot() sums its input over a trailing axis — confirm which
# `dens` this cell is meant to see.
# NOTE(review): dividing by the squared norm yields least-squares coefficients only if
# the basis functions are mutually orthogonal on the sampled points — confirm.
coefs = jnp.dot(cheby, dens) / jnp.sum(cheby ** 2, axis=-1)
coefs.shape

In [None]:
from einops import einsum
# Reconstruction: weight each basis function by its coefficient and sum over all
# three frequency axes, leaving one value per grid point ('npt').
yhat = einsum(coefs, cheby, 'n1 n2 n3, n1 n2 n3 npt -> npt')
# NOTE(review): plot() must accept a flat per-point array for this call to work — verify.
plot(yhat)

In [None]:
# Plot the density after the affine rescaling 2 * dens - 1.
plot(2 * dens - 1)

In [None]:
# Mean absolute difference between the reconstruction and the density.
# NOTE(review): assumes `yhat` and `dens` have matching (or broadcast-compatible) shapes — confirm.
jnp.mean(jnp.abs(yhat - dens))

In [None]:
# Distribution plot of the residuals (reconstruction minus density).
sns.displot(yhat - dens)