# Analysis of the minimum size of every tile of a HEALPix grid

In [None]:
# Grid resolution parameter; kept as an explicit power of two (2**4 == 16).
NSIDE = 2**4

## Imports

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

from healpix_geometry_analysis.coordinates import HealpixCoordinates
from healpix_geometry_analysis.geometry.tile import TileGeometry
from healpix_geometry_analysis.problems.optax_optimizer import OptaxOptimizerProblem
from healpix_geometry_analysis.enable_x64 import enable_x64

# Use float64 with Jax: called once here, before any coordinate or geometry
# objects are built in the cells below.
enable_x64()

## Initialize a coordinate object, which knows a few coordinate-system transformations

In [None]:
# Coordinate helper built for the chosen NSIDE; the cells below use it to list
# tiles, convert them to (phi, z) / lon-lat, and build per-tile geometries.
coord = HealpixCoordinates.from_nside(NSIDE)

## Making a list of tiles

In [None]:
# Tile coordinates: `k` and `kp` are indexed together below (same index -> same tile),
# and `k.size` is used as the number of tiles.
k, kp = coord.unique_tiles()
print("tile count:", k.size)

In [None]:
# Scatter the tile centers in the (phi, z) plane, unless there are too many to draw.
if k.size >= 100_000:
    print("Too many tiles to plot")
else:
    plt.scatter(*coord.phi_z(k, kp), s=10)
    plt.xlabel(r"$\phi$")
    plt.ylabel("$z$")
    # Dashed reference line at z = 2/3 for 0 <= phi <= pi/4
    # (presumably the HEALPix polar-cap / equatorial boundary).
    plt.hlines(2 / 3, 0, jnp.pi / 4, ls="--", color="gray")

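## Use NUTS sampler (disabled)

The cell below is commented out. Re-enabling it requires `functools.partial` and the sampler names it uses (`NUTS`, `MCMC`, `NumpyroSamplerProblem`), none of which are imported in the Imports cell.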

In [None]:
# %%time
#
#
# @partial(jax.vmap, in_axes=[None, 0, 0, None])
# def solve_with_nuts(direction, k_c, kp_c, random_seed=0):
#     geometry = TileGeometry(
#         coord=coord,
#         k_center=k_c,
#         kp_center=kp_c,
#         direction=direction,
#         distance="chord_squared",
#     )
#     problem = NumpyroSamplerProblem(geometry, track_arc_length=True)
#
#     kernel = NUTS(problem.model)
#     mcmc = MCMC(kernel, num_warmup=0, num_samples=10_000, jit_model_args=True, progress_bar=False)
#     rng_key = jax.random.PRNGKey(random_seed)
#     mcmc.run(rng_key)
#
#     samples = mcmc.get_samples()
#
#     argmin = jnp.argmin(samples["distance"])
#     return jax.tree.map(lambda x: x[argmin], samples)
#
#
# random_seeds = {"p": 1, "m": -1}
# samples = {direction: solve_with_nuts(direction, k, kp, seed) for direction, seed in random_seeds.items()}
#
# min_arc_length = min(float(jnp.min(samples["arc_length_degree"])) for samples in samples.values())
# average_size = coord.grid.average_pixel_size_degree
# ratio = min_arc_length / average_size
# print(f"{min_arc_length = :.6f}, {average_size = : .4f} {ratio = : .4f}")

## Use AdaBelief optimizer

In [None]:
%%time


def solve_with_ada(direction, k_c, kp_c, random_seed=0, n_iter=100, learning_rate=1e-1):
    """Minimize the tile distance with AdaBelief and return it as an arc length.

    Parameters
    ----------
    direction : str
        Direction passed to ``TileGeometry`` (this notebook uses "p" and "m").
    k_c, kp_c
        Tile-center coordinates, as produced by ``coord.unique_tiles()``.
    random_seed : int
        Seed for ``jax.random.PRNGKey`` used to draw the initial parameters.
    n_iter : int
        Number of optimizer steps; 0 evaluates the initial parameters only.
    learning_rate : float
        AdaBelief learning rate (default 1e-1, the previously hard-coded value).

    Returns
    -------
    The value of ``problem.geometry.arc_length_radians`` applied to the
    ``chord_squared`` loss evaluated at the final (projected) parameters.
    """
    geometry = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance="chord_squared",
    )
    problem = OptaxOptimizerProblem(geometry)

    optimizer = problem.freeze_optimizer(optax.adabelief(learning_rate))
    rng_key = jax.random.PRNGKey(random_seed)
    params = problem.initial_params(rng_key)
    opt_state = optimizer.init(params)
    value_and_grad = jax.jit(jax.value_and_grad(problem.loss))

    for _ in range(n_iter):
        _, grads = value_and_grad(params)
        # Zero non-finite gradient components so a single NaN/inf does not
        # corrupt the optimizer state.
        grads = jax.tree.map(lambda x: jnp.where(jnp.isfinite(x), x, 0.0), grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # Clip the parameters back into the geometry's box bounds.
        params = optax.projections.projection_box(
            params, problem.geometry.lower_bounds, problem.geometry.upper_bounds
        )

    # Evaluate at the final parameters. The in-loop loss is computed before the
    # last update (one step stale) and does not exist at all when n_iter == 0.
    loss, _ = value_and_grad(params)
    arc_distance_radian = problem.geometry.arc_length_radians(loss)
    return arc_distance_radian


# Coarse search: a short optimization for every tile and each direction,
# vectorized over the tile arrays (k, kp); direction, seed and n_iter are shared.
vmapped = jax.vmap(solve_with_ada, in_axes=(None, 0, 0, None, None))
random_seeds = {"p": 1, "m": -1}
arc_distance_radian = jnp.concatenate(
    [vmapped(direction, k, kp, seed, 100) for direction, seed in random_seeds.items()]
)
argmin = jnp.argmin(arc_distance_radian)

# The concatenation is direction-major (k.size entries per direction, in
# random_seeds order), so the flat index splits into (direction, tile).
direction_index, tile_index = divmod(int(argmin), k.size)
k_center = k[tile_index]
kp_center = kp[tile_index]
direction = list(random_seeds)[direction_index]

# Refine the best candidate with a much longer run. Reuse the seed from the
# coarse search so the refinement starts from the same initial parameters that
# produced the minimum (previously a different seed, 0, was used here).
min_arc_length_radian = solve_with_ada(
    direction, k_center, kp_center, random_seed=random_seeds[direction], n_iter=10_000
)
ra_min, dec_min = coord.lonlat_degrees(k_center, kp_center)
average_size = coord.grid.average_pixel_size_radian
ratio = min_arc_length_radian / average_size
# average_size uses scientific notation: a fixed 4-decimal format prints 0.0000
# once NSIDE is large enough that the pixel size drops below 5e-5 rad.
print(
    f"{min_arc_length_radian = :.16e}, {average_size = :.6e} {ratio = : .4f} {ra_min = :.5f}, {dec_min = :.5f} {direction = }"
)