In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import equinox as eqx
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Plot theming via rho_plus: one dark/light flag configures both matplotlib and plotly.
# `is_dark` is read again later to pick the plotly colour scale for 3-D scatter plots.
is_dark = True
theme, cs = rp.mpl_setup(is_dark)  # NOTE(review): return values as named here are presumably theme + colours; confirm against rho_plus
rp.plotly_setup(is_dark)

In [None]:
%cd /home/nicholas/programs/avid

In [None]:
m = 56
n = 128

# jax.lax.linalg.householder_product requires a tall matrix (rows >= cols) and at
# most `cols` scalar factors, so the reflectors are laid out in the tall orientation
# (max(m, n), min(m, n)) and the result is transposed back to (m, n) when m < n.
HH_ROWS, HH_COLS = max(m, n), min(m, n)

# Size `params` from the same strictly-lower-triangular index set that `make_mat`
# fills (the original referenced `ii`, which only exists inside `make_mat`).
n_params = len(np.tril_indices(HH_ROWS, k=-1, m=HH_COLS)[0])
params = np.random.randn(n_params).clip(-3, 3) * 0.2
rot = np.random.randn(HH_COLS).clip(-3, 3) * 1

def make_mat(params, rot):
    """Build an (m, n) matrix from a product of Householder-style reflectors.

    Parameters
    ----------
    params : array, shape (n_params,)
        Values added below the unit diagonal of the tall (HH_ROWS, HH_COLS)
        reflector matrix, in `tril_indices` order.
    rot : array, shape (HH_COLS,)
        Scalar factors (taus) for `householder_product`. They are not tied to
        2 / ||v||^2, so the raw product is generally not orthogonal; each column
        of the tall product is therefore rescaled to unit norm.

    Returns
    -------
    Array of shape (m, n). For m >= n its columns have unit norm; for m < n the
    tall result is transposed, so its rows have unit norm.
    """
    mat = jnp.eye(HH_ROWS, HH_COLS)
    ii, jj = jnp.tril_indices_from(mat, k=-1)
    mat = mat.at[ii, jj].add(params)

    orth = jax.lax.linalg.householder_product(mat, rot)
    orth = orth / jnp.linalg.norm(orth, axis=0)
    return orth if m >= n else orth.T

# Visualise the reflector-product matrix; diverging colours centred at zero.
orth = make_mat(params, rot)
ax = sns.heatmap(orth, center=0, cmap='rho_diverging')
ax

In [None]:
from eins import EinsOp
vs = np.random.randn(m, n)

# Each row v_i of `vs` defines a reflector H_i = I - 2 v_i v_i^T / (v_i^T v_i).
outer_op = EinsOp('m n1, m n2 -> m n1 n2')
sq_norm_op = EinsOp('m n, m n -> m 1 1')
vvt = outer_op(vs, vs)
vtv = sq_norm_op(vs, vs)
mats = jnp.eye(n, n) - (2 / vtv) * vvt

# Chain all m reflectors; a product of orthogonal matrices is orthogonal (n x n).
orth = jnp.linalg.multi_dot(mats)
sns.heatmap(orth, center=0, cmap='rho_diverging')

In [None]:
N_GRID = 24
# Leading dimension of the saved arrays (presumably the number of stored samples — TODO confirm).
N_SAMPLES = 52

# tree_deserialise_leaves needs a template pytree whose leaves match the saved
# shapes and dtypes; the relative path is resolved against the project directory.
data = eqx.tree_deserialise_leaves(
    'precomputed/densities/batch1.eqx',
    {
        'density': jnp.zeros((N_SAMPLES, N_GRID ** 3), dtype=jnp.float32),
        'species': jnp.zeros((N_SAMPLES, N_GRID ** 3), dtype=jnp.int16),
    },
)

In [None]:
import functools as ft
from einops import rearrange, reduce

# N_GRID evenly spaced coordinates on [0, 1), with the endpoint 1 dropped.
grid_vals = jnp.linspace(0, 1, N_GRID + 1)[:-1]

# NOTE(review): jnp.meshgrid defaults to indexing='xy', which swaps the first two
# axes relative to 'ij'; confirm this matches how 'density' was flattened.
xx, yy, zz = jnp.meshgrid(grid_vals, grid_vals, grid_vals)
# One (x, y, z) row per grid point, in C order over the meshgrid axes.
xyz = jnp.stack([xx, yy, zz], axis=-1).reshape(-1, 3)

dens = data['density'][0]

if is_dark:
    cmap = rp.list_aquaria
else:
    cmap = rp.list_cabana

def plot(dens, n_skip=2, thresh=0.1, cmap=cmap):
    """3-D plotly scatter of a field sampled at the global `xyz` points.

    A point is drawn only if |dens| > thresh and all of its integer grid
    coordinates (round(xyz * N_GRID)) are multiples of `n_skip`, which thins
    the lattice. `dens` must be indexed in the same order as `xyz`.
    """
    on_lattice = jnp.round(xyz * N_GRID) % n_skip == 0
    keep = jnp.all(on_lattice, axis=1) & (abs(dens) > thresh)
    pts = xyz[keep]
    return px.scatter_3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        color=dens[keep],
        color_continuous_scale=cmap,
        range_x=[0, 1],
        range_y=[0, 1],
        range_z=[0, 1],
        opacity=1,
    )

plot(dens, n_skip=2)

In [None]:
from jaxtyping import Array, Float
from jax.scipy import special

@eqx.filter_jit
def choose(x, y):
    """Binomial coefficient C(x, y) computed through log-gamma (works elementwise on arrays)."""
    log_c = special.gammaln(x + 1) - special.gammaln(y + 1) - special.gammaln(x - y + 1)
    return jnp.exp(log_c)

@eqx.filter_jit
def legendre_poly_coef(n, k):
    """Coefficient C(n, k) * C(n + k, k) of ((x - 1) / 2) ** k in the degree-n polynomial."""
    binom_n_k = choose(n, k)
    binom_npk_k = choose(n + k, k)
    return binom_n_k * binom_npk_k

@eqx.filter_vmap
@eqx.filter_jit
def legendre_poly(x: Float[Array, ''], n: int):
    """Evaluate P_n(x) = sum_{k=0}^{n} C(n, k) C(n + k, k) ((x - 1) / 2) ** k.

    Called as legendre_poly(array_of_x, python_int); this relies on filter_vmap
    mapping over `x` while the integer `n` stays static, so `jnp.arange(n + 1)`
    has a concrete length.
    """
    ks = jnp.arange(n + 1)
    half_shift = (x - 1) / 2
    return jnp.dot(half_shift ** ks, legendre_poly_coef(n, ks))


xs = jnp.linspace(-1, 1, 100)
fig, ax = plt.subplots()
# Loop variable is `degree`, not `n`: reusing `n` would overwrite the global
# n = 128 and silently change the Householder cells if they are re-run.
for degree in range(10):
    ax.plot(xs, legendre_poly(xs, degree), label=f'n = {degree}')

ax.set(xlabel='x', ylabel='P_n(x)', title='Legendre polynomials')
ax.legend();

In [None]:
xs = jnp.linspace(0, 2, 200)
fig, ax = plt.subplots(figsize=(15, 6))
# `freq` rather than `n`, so this loop does not clobber the global n = 128.
for freq in range(7):
    ax.plot(xs, jnp.cos(freq * (xs * 2 - 1) * jnp.pi) ** 2, label=f'n = {freq}', lw=1)

ax.set(xlabel='x', title='cos^2(n * pi * (2x - 1))')
ax.legend();

In [None]:
ncheby = 9
# Integer frequency triples in [0, ncheby)^3, laid out as (n1, n2, n3, d).
nnn = jnp.moveaxis(jnp.mgrid[:ncheby, :ncheby, :ncheby], 0, -1)
# Add a point axis so frequencies broadcast against xyz of shape (npt, 3).
nnn = nnn[..., None, :]

# basis[n1, n2, n3, p, d] = cos(n_d * pi * (2 * xyz[p, d] - 1)).
# NOTE(review): this is a cosine basis in x, not Chebyshev polynomials T_n(2x - 1),
# despite the `ncheby` / `cheby` names.
basis = jnp.cos(nnn * (xyz * 2 - 1) * jnp.pi)
basis.shape

In [None]:
# Tensor-product basis: multiply the per-dimension cosine factors over the last (d) axis.
cheby = jnp.prod(basis, axis=-1)

In [None]:
# Projection coefficient <phi, dens> / <phi, phi> for every basis function phi.
# This equals the least-squares fit only when the sampled basis is orthogonal on the grid.
proj = jnp.dot(cheby, dens)
norm_sq = jnp.sum(cheby ** 2, axis=-1)
coefs = proj / norm_sq
coefs.shape

In [None]:
from einops import einsum
# Reconstruct the field at every grid point: sum over all (n1, n2, n3) of coef * basis.
yhat = einsum(
    coefs,
    cheby,
    'n1 n2 n3, n1 n2 n3 npt -> npt',
)
plot(yhat)

In [None]:
# Density under the affine map d -> 2d - 1, for visual comparison with the reconstruction.
rescaled = 2 * dens - 1
plot(rescaled)

In [None]:
# Mean absolute reconstruction error over all grid points.
jnp.abs(yhat - dens).mean()

In [None]:
# Distribution of pointwise reconstruction residuals.
residual = yhat - dens
sns.displot(residual)