# Analysis of the minimum size of every tile in a grid

In [None]:
# HEALPix nside of the grid to analyse; the coordinate object below is built from it
NSIDE = 3

## Imports

In [None]:
import os
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from numpyro.infer import MCMC, NUTS

from healpix_geometry_analysis.coordinates import HealpixCoordinates
from healpix_geometry_analysis.geometry.tile import TileGeometry
from healpix_geometry_analysis.problems.numpyro_sampler import NumpyroSamplerProblem
from healpix_geometry_analysis.problems.optax_optimizer import OptaxOptimizerProblem
from healpix_geometry_analysis.enable_x64 import enable_x64

# Project helper; presumably switches JAX to 64-bit (x64) precision — confirm.
# It is called before any JAX arrays are created in the cells below.
enable_x64()

## Initialize a coordinate object, which knows a few coordinate-system transformations

In [None]:
# Coordinate helper for this nside. Below it provides the grid parameters
# (`coord.grid.nside`, `coord.grid.average_pixel_size_degree`) and conversions
# between (phi, z) and diagonal indices (k, k').
coord = HealpixCoordinates.from_nside(NSIDE)

## Making a list of tiles

### The equatorial region requires one tile for each equatorial ring of the Northern Hemisphere (the ring `z = 2/3 - delta_z` is treated separately)

In [None]:
# Step in z between neighbouring equatorial rings
delta_z = 2 / 3 / coord.grid.nside
# Step in longitude between neighbouring tile centers on a ring
delta_phi = 0.5 * jnp.pi / coord.grid.nside

# Northern equatorial rings lie at z = m * delta_z, m = 0, ..., nside - 2
# (the ring m = nside - 1 is covered by the intermediate-region cell).
# Tile centers alternate from ring to ring between phi = 0 and phi = delta_phi / 2.
# Ring m = nside - 1 has a center at phi = 0 (see the intermediate cell), so ring m
# has one at phi = 0 exactly when (nside - m) is odd. Which ring starts each
# sequence therefore depends on the parity of nside; it is not always m = 1.
first_meridian_ring = (coord.grid.nside - 1) % 2
first_offset_ring = 1 - first_meridian_ring

# First, rings with a tile at longitude = 0
z_meridian = jnp.arange(first_meridian_ring, coord.grid.nside - 1, 2) * delta_z
phi_meridian = jnp.zeros_like(z_meridian)
# Next, rings whose tiles are shifted by a half-step in phi
z_offset = jnp.arange(first_offset_ring, coord.grid.nside - 1, 2) * delta_z
phi_offset = jnp.full_like(z_offset, 0.5 * delta_phi)

z_eq = jnp.concatenate([z_meridian, z_offset])
phi_eq = jnp.concatenate([phi_meridian, phi_offset])

k_eq, kp_eq = coord.diag_from_phi_z(phi_eq, z_eq)

### The intermediate region requires all tiles from the ring `z = 2/3 - delta_z`

In [None]:
# Longitudes 0, delta_phi, ..., (nside // 2) * delta_phi on the ring z = 2/3 - delta_z
phi_inter = delta_phi * jnp.arange(coord.grid.nside // 2 + 1)
# All these tiles share the same z
z_inter = jnp.full_like(phi_inter, 2 / 3 - delta_z)

k_inter, kp_inter = coord.diag_from_phi_z(phi_inter, z_inter)

### The polar region requires all tiles with 0 < lon <= pi/4 and 2/3 <= z < 1

In [None]:
# Polar tiles are enumerated with rectangular indices (i, j):
# i = 1, ..., nside and j = 0, ..., nside - 1.
i_pol_ = jnp.arange(1, coord.grid.nside + 1)
j_pol_ = jnp.arange(coord.grid.nside)
# Full grid of candidate pairs; most of them are discarded just below
i_pol_all, j_pol_all = jnp.meshgrid(i_pol_, j_pol_)

# Only pairs with j <= (i - 1) // 2 lie inside the required "triangle"
j_pol_idx = (i_pol_all - 1) // 2 >= j_pol_all
i_pol = i_pol_all[j_pol_idx]
j_pol = j_pol_all[j_pol_idx]

# Diagonal indices k & k' of the tile centers
k_pol = 0.5 + j_pol
kp_pol = (i_pol - j_pol) - 0.5

### Combine all diagonal indices (geometry objects are created per tile in the solver cells below)

In [None]:
# Flat arrays of diagonal indices for the tiles of all three regions
k = jnp.concatenate((k_eq, k_inter, k_pol))
kp = jnp.concatenate((kp_eq, kp_inter, kp_pol))

# Plot the tile centers in (phi, z) as a visual sanity check
phi_tiles, z_tiles = coord.phi_z(k, kp)
plt.scatter(phi_tiles, z_tiles, s=10)
plt.xlabel(r"$\phi$")
plt.ylabel("$z$")

# Number of tiles to analyse
print(k.shape)

## Use the NUTS sampler

In [None]:
%%time


@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_nuts(direction, k_c, kp_c, random_seed=0):
    geometry = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance="chord_squared",
    )
    problem = NumpyroSamplerProblem(geometry, track_arc_length=True)

    kernel = NUTS(problem.model)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=10_000, jit_model_args=True, progress_bar=False)
    rng_key = jax.random.PRNGKey(random_seed)
    mcmc.run(rng_key)

    samples = mcmc.get_samples()

    argmin = jnp.argmin(samples["distance"])
    return jax.tree.map(lambda x: x[argmin], samples)


random_seeds = {"p": 1, "m": -1}
samples = {direction: solve_with_nuts(direction, k, kp, seed) for direction, seed in random_seeds.items()}

min_arc_length = min(float(jnp.min(samples["arc_length_degree"])) for samples in samples.values())
average_size = coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")

## Use the AdaBelief optimizer

In [None]:
%%time


@partial(jax.vmap, in_axes=[None, 0, 0, None])
def solve_with_ada(direction, k_c, kp_c, random_seed=0):
    geometry = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance="chord_squared",
    )
    problem = OptaxOptimizerProblem(geometry)

    optimizer = problem.freeze_optimizer(optax.adabelief(1e-1))
    rng_key = jax.random.PRNGKey(random_seed)
    params = problem.initial_params(rng_key)
    opt_state = optimizer.init(params)

    for _ in range(100):
        loss, grads = jax.value_and_grad(problem.loss)(params)
        grads = jax.tree.map(lambda x: jnp.where(jnp.isfinite(x), x, 0.0), grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        params = optax.projections.projection_box(
            params, problem.geometry.lower_bounds, problem.geometry.upper_bounds
        )

    arc_distance_deg = problem.geometry.arc_length_degrees(loss)
    return arc_distance_deg


random_seeds = {"p": 1, "m": -1}
arc_distance_deg = jnp.concatenate(
    [solve_with_ada(direction, k, kp, seed) for direction, seed in random_seeds.items()]
)
min_arc_length = jnp.min(arc_distance_deg)
average_size = coord.grid.average_pixel_size_degree
ratio = min_arc_length / average_size
print(f"{min_arc_length = :.4f}, {average_size = : .4f} {ratio = : .4f}")