# Explore the boundaries of a HEALPix tile and its neighbours

In [None]:
import healpy as hp
import matplotlib.pyplot as plt
import numpy as np

# Number of samples per curve and the HEALPix resolution (nside is derived from order).
steps = 100
order = 4
nside = hp.order2nside(order)

# Reference point at longitude 0 with z = sin(latitude) = 0.6, i.e. just below
# the z = 2/3 threshold that the centre computation further down uses to choose
# between the polar and the equatorial formulas. The tile containing this point
# is the one explored in the plots.
lonlat = (0, np.degrees(np.arcsin(0.6)))
tile_id = hp.ang2pix(nside, lonlat[0], lonlat[1], nest=True, lonlat=True)

# Two side-by-side panels: a plain (phi, z) plot and a polar-axes view.
fig = plt.figure(figsize=(16, 6))
ax_cyl = fig.add_subplot(121)
ax_pol = fig.add_subplot(122, projection="polar")


def plot(phi, z, **kwargs):
    """Draw one curve on both panels of the figure.

    Parameters
    ----------
    phi
        Longitude in degrees (converted to radians for the polar axes).
    z
        Cosine of the colatitude; the polar axes use the colatitude in
        degrees as the radial coordinate.
    **kwargs
        Forwarded unchanged to both ``Axes.plot`` calls.
    """
    colatitude_deg = np.degrees(np.arccos(z))
    panels = (
        (ax_cyl, phi, z),
        (ax_pol, np.radians(phi), colatitude_deg),
    )
    for axes, xs, ys in panels:
        axes.plot(xs, ys, **kwargs)


def hp_boundaries(tile_id, *, n=steps):
    """Return the healpy boundary of a NESTED tile as stacked ``(phi, z)`` rows.

    Parameters
    ----------
    tile_id
        Pixel index in the NESTED scheme at the module-level ``nside``.
    n
        Passed to ``hp.boundaries`` as ``step``.

    Returns
    -------
    numpy.ndarray
        Row 0 is the longitude in degrees, with values of 180 or more shifted
        down by 360; row 1 is the third component of the boundary vectors,
        which the plots use as z.
    """
    xyz = hp.boundaries(nside, tile_id, nest=True, step=n)
    lon = hp.vec2ang(xyz.T, lonlat=True)[0]
    # Wrap longitudes so the curve stays continuous around phi = 0.
    wrapped_lon = np.where(lon < 180, lon, lon - 360)
    return np.stack([wrapped_lon, xyz[2]])


# Outline the selected tile with a bold line, then each existing neighbour
# with a thin, faint line.
plot(*hp_boundaries(tile_id), color="black", label="healpy")
for neighbour_tile_id in hp.get_all_neighbours(nside, *lonlat, nest=True, lonlat=True):
    # healpy reports a missing neighbour as -1. Pixel index 0 is a real tile,
    # so the guard must accept it (a strict "> 0" would silently drop it).
    if neighbour_tile_id >= 0:
        plot(*hp_boundaries(neighbour_tile_id), color="black", lw=0.5, alpha=0.5, label=None)

# Centre of the selected tile: longitude in degrees and z = sin(latitude).
lonlat_center = hp.pix2ang(nside, tile_id, nest=True, lonlat=True)
phi_c = lonlat_center[0]
z_c = np.sin(np.radians(lonlat_center[1]))
phi_c_rad = np.radians(phi_c)

# Convert the centre to the (k, k') indices used by the curve helpers below.
if z_c > 2 / 3:
    # Polar formulas. i_c and j_c are only defined on this branch.
    i_c = np.sqrt(3) * nside * np.sqrt(1 - z_c)
    j_c = 2 * i_c / np.pi * phi_c_rad - 0.5
    k_c = j_c + 0.5
    kp_c = i_c - j_c - 0.5
else:
    # Equatorial formulas: the inverse of the linear (k, k') -> (phi, z) map in eq().
    k_c = 3 * nside / 4 * (2 / 3 - z_c + 8 * phi_c_rad / (3 * np.pi))
    kp_c = nside + 3 * nside / 4 * (2 / 3 - z_c - 8 * phi_c_rad / (3 * np.pi))

# k_c += 0.7
# kp_c += 0.7

# plot(phi_c, z_c, ls='none', marker='x', color='black', label=f"i={i_c:.1f}, j={j_c:.1f}, k={k_c:.1f}, k'={kp_c:.1f}")


def phi_z_from_k(delta, *, n=steps):
    """Trace a curve of constant k' by sweeping k around the tile centre.

    k' is fixed at ``kp_c + delta`` while k takes ``n`` evenly spaced values
    from ``k_c - delta`` to ``k_c + delta``.

    Returns
    -------
    tuple
        ``(phi, z)`` as produced by ``phi_z`` (phi in degrees).
    """
    k_sweep = np.linspace(k_c - delta, k_c + delta, n)
    return phi_z(k_sweep, kp_c + delta)


def phi_z_from_kp(delta, *, n=steps):
    """Trace a curve of constant k by sweeping k' around the tile centre.

    k is fixed at ``k_c + delta`` while k' takes ``n`` evenly spaced values
    from ``kp_c - delta`` to ``kp_c + delta``.

    Returns
    -------
    tuple
        ``(phi, z)`` as produced by ``phi_z`` (phi in degrees).
    """
    kp_sweep = np.linspace(kp_c - delta, kp_c + delta, n)
    return phi_z(k_c + delta, kp_sweep)


# def phi_z(k, kp):
#     j = np.abs(k) - 0.5
#     i = np.abs(kp) + np.abs(k)

#     z = np.where(i <= nside, 1 - (i / nside)**2 / 3, 4/3 - 2 * i / (3 * nside))

#     phi_polar = 0.5 * np.pi * (j + 0.5) / i
#     # phi_eq = 0.5 * np.pi * (nside - kp) / nside - 3 * np.pi / 8 * (z - 2/3)
#     phi_eq = np.pi/4/nside * (nside - np.abs(kp) + np.abs(k))
#     phi = np.where(i <= nside, phi_polar, phi_eq)
#     phi = np.where(kp >= 0, phi, np.pi - phi)
#     phi = np.where(k >= 0, phi, -phi)

#     return np.degrees(phi), z


def phi_z(k, kp):
    """Map (k, k') indices to ``(phi in degrees, z)``.

    Both the equatorial and the polar formulas are evaluated; wherever the
    equatorial z is at most 2/3 the equatorial result is used, otherwise the
    result of ``polar`` is used.
    """
    belt_phi, belt_z = eq(k, kp)
    cap_phi, cap_z = polar(k, kp)
    use_belt = belt_z <= 2 / 3
    z = np.where(use_belt, belt_z, cap_z)
    phi = np.where(use_belt, belt_phi, cap_phi)
    return np.degrees(phi), z


def eq(k, kp):
    """Equatorial formulas: map (k, k') to ``(phi in radians, z)``.

    z depends linearly on ``kp + k`` and phi linearly on ``k - kp``; both are
    scaled by the module-level ``nside``.
    """
    index_sum = kp + k
    z = 2 / 3 * (2 - index_sum / nside)
    phi_per_unit = np.pi / 4 / nside
    phi = phi_per_unit * (nside - kp + k)
    return phi, z


def polar(k, kp):
    """Polar formulas: map (k, k') to ``(phi in radians, z)``.

    The computation works on ``|k|`` and ``|k'|``. Where ``|k'| + |k|`` is at
    most ``nside`` the polar expressions are used; elsewhere it falls back to
    ``eq`` evaluated on the absolute values. The signs of k and k' are then
    restored by reflecting phi.
    """
    abs_k = np.abs(k)
    abs_kp = np.abs(kp)
    i = abs_kp + abs_k
    j = abs_k - 0.5

    cap_z = 1 - (i / nside) ** 2 / 3
    cap_phi = 0.5 * np.pi * (j + 0.5) / i
    belt_phi, belt_z = eq(abs_k, abs_kp)

    inside_cap = i <= nside
    z = np.where(inside_cap, cap_z, belt_z)
    phi = np.where(inside_cap, cap_phi, belt_phi)

    # Negative k' reflects phi about pi/2; negative k then flips its sign.
    phi = np.where(kp >= 0, phi, np.pi - phi)
    phi = np.where(k >= 0, phi, -phi)

    return phi, z


# One entry per tile edge: (legend label, curve helper, sign of the offset, colour).
edges = (
    ("NE", phi_z_from_k, -1, "red"),
    ("NW", phi_z_from_kp, -1, "orange"),
    ("SW", phi_z_from_k, 1, "green"),
    ("SE", phi_z_from_kp, 1, "purple"),
)
for edge_label, trace_edge, sign, edge_color in edges:
    # Solid line at an offset of 0.5 from the centre indices, dashed line at 1.5.
    plot(*trace_edge(delta=sign * 0.5), color=edge_color, label=edge_label)
    plot(*trace_edge(delta=sign * 1.5), color=edge_color, ls="--", label=None)

ax_cyl.set_xlabel(r"$\phi$")
ax_cyl.set_ylabel(r"$z$")
ax_cyl.legend()

# The polar panel's radius is the colatitude in degrees; pin the current ticks
# and relabel them as 90 minus the tick value.
radial_ticks = ax_pol.get_yticks()
ax_pol.set_yticks(radial_ticks)
ax_pol.set_yticklabels([90.0 - tick for tick in radial_ticks])