# Run the floating slide analysis and compare it to the brute-force method

In [None]:
# HEALPix NSIDE passed to HealpixCoordinates.from_nside below (2**8 == 256).
NSIDE = 2**8

# Distance metric name passed to the geometry objects.
DISTANCE_TYPE = "chord_squared"

## Imports

In [None]:
import jax
import jax.numpy as jnp
import optax
from numpyro.infer import MCMC, NUTS

from healpix_geometry_analysis.coordinates import HealpixCoordinates
from healpix_geometry_analysis.geometry.equatorial import EquatorialGeometry
from healpix_geometry_analysis.geometry.intermediate import IntermediateGeometry
from healpix_geometry_analysis.geometry.polar import PolarGeometry
from healpix_geometry_analysis.geometry.tile import TileGeometry
from healpix_geometry_analysis.problems.numpyro_sampler import NumpyroSamplerProblem
from healpix_geometry_analysis.problems.optax_optimizer import OptaxOptimizerProblem
from healpix_geometry_analysis.enable_x64 import enable_x64

# Use float64 with JAX for all subsequent computations.
# NOTE(review): JAX expects 64-bit mode to be switched on before arrays are
# created, so keep this call at the top of the notebook.
enable_x64()

## Initialize a coordinate object, which knows a few coordinate-system transformations

In [None]:
# Shared coordinate helper for NSIDE; it is passed to every geometry and to
# the tile-by-tile solver below.
coord = HealpixCoordinates.from_nside(NSIDE)

## Define solvers

In [None]:
def solve_with_nuts(geometry, *, random_seed, num_samples):
    """Search for the minimum-distance configuration of ``geometry`` with NUTS.

    Draws ``num_samples`` samples from the geometry's numpyro model (single
    chain, no warmup) and keeps the draw with the smallest ``"distance"`` site.

    Parameters
    ----------
    geometry
        Geometry object wrapped by ``NumpyroSamplerProblem``.
    random_seed : int
        Seed for the JAX PRNG key driving the sampler.
    num_samples : int
        Number of NUTS samples to draw.

    Returns
    -------
    tuple
        ``(arc_distance_radian, k_mean, kp_mean)``: the best draw's distance
        converted with ``geometry.arc_length_radians``, and the means of the
        two diagonal ``k`` and ``kp`` indices of that draw.
    """
    sampler = MCMC(
        NUTS(NumpyroSamplerProblem(geometry, track_arc_length=False).model),
        num_warmup=0,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
        jit_model_args=True,
    )
    sampler.run(jax.random.PRNGKey(random_seed))
    draws = sampler.get_samples()

    best = jnp.argmin(draws["distance"])
    # Pick the best draw from every sampled site, then add the parameters the
    # geometry keeps fixed so the geometry helpers see a complete set.
    best_params = jax.tree.map(lambda site: site[best], draws) | geometry.frozen_parameters

    k1, k2, kp1, kp2 = geometry.diagonal_indices(best_params)
    distance = geometry.calc_distance(best_params)
    return geometry.arc_length_radians(distance), 0.5 * (k1 + k2), 0.5 * (kp1 + kp2)


def solve_with_ada(geometry, *, random_seed, n_iter):
    """Minimise ``geometry``'s loss with box-projected AdaBelief steps.

    Parameters
    ----------
    geometry
        Geometry object wrapped by ``OptaxOptimizerProblem``; its
        ``lower_bounds``/``upper_bounds`` define the projection box.
    random_seed : int
        Seed for the JAX PRNG key used to draw the initial parameters.
    n_iter : int
        Number of optimisation steps (may be 0).

    Returns
    -------
    tuple
        ``(arc_distance_radian, k_mean, kp_mean)``, all evaluated at the
        *final* parameters: the loss converted with
        ``geometry.arc_length_radians``, and the means of the two diagonal
        ``k`` and ``kp`` indices.
    """
    problem = OptaxOptimizerProblem(geometry)

    optimizer = problem.freeze_optimizer(optax.adabelief(1e-1))
    rng_key = jax.random.PRNGKey(random_seed)
    params = problem.initial_params(rng_key)
    opt_state = optimizer.init(params)
    value_and_grad = jax.jit(jax.value_and_grad(problem.loss))

    for _ in range(n_iter):
        _step_loss, grads = value_and_grad(params)
        # Replace NaN/inf gradient entries with 0 so one bad gradient does not
        # propagate into the optimizer state.
        grads = jax.tree.map(lambda x: jnp.where(jnp.isfinite(x), x, 0.0), grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # Keep the parameters inside the geometry's box constraints.
        params = optax.projections.projection_box(params, geometry.lower_bounds, geometry.upper_bounds)

    # Evaluate the loss at the final parameters. The loss returned inside the
    # loop belongs to the parameters *before* the last update (so it would not
    # match the k/kp reported below), and it would be unbound for n_iter == 0.
    loss = problem.loss(params)
    arc_distance_radian = geometry.arc_length_radians(loss)

    k1, k2, kp1, kp2 = geometry.diagonal_indices(params)
    k_mean = 0.5 * (k1 + k2)
    kp_mean = 0.5 * (kp1 + kp2)

    return arc_distance_radian, k_mean, kp_mean


def solve_with_ada_one_tile(coord, direction, k_c, kp_c, random_seed=0, n_iter=100):
    """Minimise the distance inside the single tile centred at ``(k_c, kp_c)``.

    Uses the notebook-wide ``DISTANCE_TYPE`` (instead of a hard-coded metric),
    so the tile-by-tile results stay comparable with the floating-geometry
    results whenever the metric is changed.

    Returns
    -------
    The tile's minimum distance, as returned (first element) by
    ``solve_with_ada``.
    """
    geometry = TileGeometry(
        coord=coord,
        k_center=k_c,
        kp_center=kp_c,
        direction=direction,
        distance=DISTANCE_TYPE,
    )
    arc_distance_radian, *_diag = solve_with_ada(geometry, random_seed=random_seed, n_iter=n_iter)
    return arc_distance_radian


# Batched version: only the tile centres (k_c, kp_c) are mapped over their
# leading axis; coord, direction, random_seed and n_iter are shared by all tiles.
solve_with_ada_one_tile_vmap = jax.vmap(
    solve_with_ada_one_tile,
    in_axes=(None, None, 0, 0, None, None),
)

In [None]:
def solve_float(geometry, *, random_seed, num_samples, n_iter, solver="nuts"):
    """Solve a region-wide ("floating") geometry with the chosen solver.

    Parameters
    ----------
    geometry
        Geometry to solve.
    random_seed : int
        PRNG seed forwarded to the solver.
    num_samples : int
        Number of samples, used only by the ``"nuts"`` solver.
    n_iter : int
        Number of optimisation steps, used only by the ``"ada"`` solver.
    solver : {"nuts", "ada"}
        Which solver to run.

    Returns
    -------
    tuple
        ``(arc_distance_radian, k_mean, kp_mean)`` from the selected solver.

    Raises
    ------
    ValueError
        If ``solver`` is not one of the supported names. Previously an unknown
        name silently returned ``None``.
    """
    if solver == "nuts":
        return solve_with_nuts(geometry, random_seed=random_seed, num_samples=num_samples)
    if solver == "ada":
        return solve_with_ada(geometry, random_seed=random_seed, n_iter=n_iter)
    raise ValueError(f"Unknown solver {solver!r}; expected 'nuts' or 'ada'")

In [None]:
def solve_1by1(coord, direction, k, kp, *, random_seed, n_iter_per_tile):
    """Brute-force reference: optimise every tile separately and keep the best.

    Each tile centre ``(k[i], kp[i])`` is solved independently (batched via
    ``solve_with_ada_one_tile_vmap``) and the tile with the smallest distance
    wins.

    Returns
    -------
    tuple
        ``(arc_distance_radian, k, kp)`` of the best tile.
    """
    per_tile = solve_with_ada_one_tile_vmap(coord, direction, k, kp, random_seed, n_iter_per_tile)
    best = jnp.argmin(per_tile)
    return per_tile[best], k[best], kp[best]

## Run analysis separately for each of the three regions

### Equatorial region

In [None]:
%%time

intermediate_geometry = EquatorialGeometry(coord=coord, distance=DISTANCE_TYPE)

float_dist, float_k, float_kp = solve_float(
    intermediate_geometry,
    random_seed=0,
    num_samples=100_000,
    n_iter=10_000,
)

one_dist, one_k, one_kp = solve_1by1(
    coord,
    intermediate_geometry.direction,
    *coord.unique_equatorial_tiles(),
    random_seed=0,
    n_iter_per_tile=1000,
)

print(
    f"Floating: {float_dist} k={float_k} kp={float_kp} z={float(coord.phi_z(jnp.floor(float_k) + 0.5, jnp.floor(float_kp) + 0.5)[1])}"
)
print(f"One  by one: {one_dist} k={one_k} kp={one_kp} z={float(coord.phi_z(one_k, one_kp)[1])}")

### Intermediate region

In [None]:
%%time

for direction in ["p", "m"]:
    intermediate_geometry = IntermediateGeometry(coord=coord, distance=DISTANCE_TYPE, direction=direction)

    float_dist, float_k, float_kp = solve_float(
        intermediate_geometry,
        random_seed=0,
        num_samples=1_000_000,
        n_iter=10_000,
    )

    one_dist, one_k, one_kp = solve_1by1(
        coord,
        intermediate_geometry.direction,
        *coord.unique_intermediate_tiles(),
        random_seed=0,
        n_iter_per_tile=1000,
    )

    print(f"Direction: {direction}")
    print(f"Floating: {float_dist} k={float_k} kp={float_kp}")
    print(f"One  by one: {one_dist} k={one_k} kp={one_kp}")

### Polar region

In [None]:
%%time

for direction in ["p", "m"]:
    polar_geometry = PolarGeometry(coord=coord, distance=DISTANCE_TYPE, direction=direction)

    float_dist, float_k, float_kp = solve_float(
        polar_geometry,
        random_seed=0,
        num_samples=1_000_000,
        n_iter=10_000,
    )

    one_dist, one_k, one_kp = solve_1by1(
        coord,
        polar_geometry.direction,
        *coord.unique_polar_tiles(),
        random_seed=0,
        n_iter_per_tile=1000,
    )

    print(f"Direction: {direction}")
    print(f"Floating: {float_dist} k={float_k} kp={float_kp}")
    print(f"One  by one: {one_dist} k={one_k} kp={one_kp}")