# Multi-pair visibilities with CROISSANT JAX

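This notebook extends the single-beam example (`croissant_jax.ipynb`) to simulate
two antenna beams in a single call using `croissant.jax.multipair`. Each "pair" here is
the auto-correlation of one beam with itself, so the two pairs are x–x and y–y.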

In [None]:
import croissant as cro
import croissant.jax as crojax
from functools import partial
from healpy import get_nside, projview
import jax
import jax.numpy as jnp
import lunarsky
import matplotlib.pyplot as plt
import s2fft

In [None]:
# --- simulation parameters ---
world = "moon"  # the telescope sits on the lunar surface
freq = jnp.arange(1, 51)  # frequency channels: 1, 2, ..., 50 MHz

# telescope site and the start time of the simulation
loc = lunarsky.MoonLocation(lon=0, lat=-22.5)
time = lunarsky.Time("2025-12-01 09:00:00")
topo = lunarsky.LunarTopo(obstime=time, location=loc)  # the telescope's local (topocentric) frame

# divide one lunar sidereal day into `ntimes` equal time bins of length `dt`
ntimes = 240
dt = cro.constants.sidereal_day[world] / ntimes

## Two beams

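We create two short-dipole beams: one aligned with the x-axis and one with the y-axis.
The power pattern of a short dipole along a unit vector $\hat{d}$ is $1 - (\hat{r}\cdot\hat{d})^2$, which gives

$$B_x = \cos^2\theta\,\cos^2\phi + \sin^2\phi, \qquad B_y = \cos^2\theta\,\sin^2\phi + \cos^2\phi,$$

each multiplied by a power-law frequency dependence $(\nu / 50\,\mathrm{MHz})^{-2}$.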

In [None]:
beam_L = 90  # band-limit of the beam grid

# Driscoll-Healy sampling grid; after meshgrid, theta varies along rows and phi along columns
theta = s2fft.sampling.s2_samples.thetas(L=beam_L, sampling="dh")
phi = s2fft.sampling.s2_samples.phis_equiang(L=beam_L, sampling="dh")
phi, theta = jnp.meshgrid(phi, theta)
ct, cp, sp = jnp.cos(theta), jnp.cos(phi), jnp.sin(phi)

# power-law frequency scaling shared by both beams; equals 1 at the top channel freq[-1]
spec = (freq[:, None, None] / freq[-1]) ** (-2)

# beam 0: x-dipole, 1 - (sin(theta) cos(phi))^2
beam_x = spec * (ct**2 * cp**2 + sp**2)[None, :, :]
# beam 1: y-dipole, 1 - (sin(theta) sin(phi))^2
beam_y = spec * (ct**2 * sp**2 + cp**2)[None, :, :]

In [None]:
# Compare the two beam patterns at one frequency channel.
ifreq = 30  # channel index to display; the title reports its frequency
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, beam, label in zip(axes, [beam_x, beam_y], ["x-dipole", "y-dipole"]):
    im = ax.imshow(beam[ifreq], aspect="auto")
    ax.set_title(f"{label} at {freq[ifreq]} MHz")
    ax.set_xlabel("$\\phi$")
    ax.set_ylabel("$\\theta$")
    fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

In [None]:
lmax = beam_L - 1

# spherical-harmonic transform applied to every frequency slice of each beam
beam2alm = partial(s2fft.forward_jax, L=beam_L, spin=0, nside=None, sampling="dh", reality=True)
batched_beam2alm = jax.vmap(beam2alm)
beam_x_alm = batched_beam2alm(beam_x)
beam_y_alm = batched_beam2alm(beam_y)

# Total power of each beam. The normalization of a pair (p, q) is
# sqrt(tp_p * tp_q); for the auto-correlation pairs used here it reduces to tp_p.
tp_x = crojax.alm.total_power(beam_x_alm, lmax)
tp_y = crojax.alm.total_power(beam_y_alm, lmax)

## Sky

In [None]:
# ULSA sky model: expected to hold one HEALPix map per entry of `freq`
sky_map = jnp.load("ulsa.npy")
# The notebook pairs sky_map[i] with freq[i]; fail early if the channel counts differ.
assert sky_map.shape[0] == freq.shape[0], (
    f"ulsa.npy has {sky_map.shape[0]} maps but freq has {freq.shape[0]} channels"
)

ix = -6  # channel to preview (6th from the top of the band)
projview(m=sky_map[ix], title=f"ULSA sky at {freq[ix]} MHz")
plt.show()

In [None]:
# band-limit used for the sky transform is tied to the HEALPix resolution
nside = get_nside(sky_map[0])
sky_L = 2 * nside

# transform each frequency's map separately and stack the results along axis 0
sky2alm = partial(s2fft.forward, L=sky_L, spin=0, nside=nside, sampling="healpix", method="jax_healpy", reality=True)
sky_alm = jnp.stack([sky2alm(channel_map) for channel_map in sky_map])

## Coordinate transforms and setup

In [None]:
# Run the simulation at the sky's band-limit.
sim_L = sky_L
sim_lmax = sim_L - 1
# The beams are only truncated down to sim_lmax below (never padded up), so the
# sky must not have a higher band-limit than the beams.
assert sim_L <= beam_L, f"sky band-limit L={sim_L} exceeds beam band-limit L={beam_L}"

# NOTE: this cell overwrites beam_x_alm, beam_y_alm and sky_alm with their
# truncated/rotated versions. Re-run the transform cells above before re-running
# this one, otherwise the reduction and rotation are applied twice.
beam_x_alm = crojax.alm.reduce_lmax(beam_x_alm, sim_lmax)
beam_y_alm = crojax.alm.reduce_lmax(beam_y_alm, sim_lmax)

# per-time-step phases for the sky rotation over the simulated day
# (presumably a rotation about the z-axis, per the function name)
phases = crojax.simulator.rot_alm_z(sim_lmax, ntimes, dt, world=world)

# Euler angles and dl (Wigner-d) arrays for topocentric -> MCMF and galactic -> MCMF
eul_topo, dl_topo = crojax.rotations.generate_euler_dl(sim_lmax, topo, "mcmf")
eul_gal, dl_gal = crojax.rotations.generate_euler_dl(sim_lmax, "galactic", "mcmf")

topo2mcmf = partial(s2fft.utils.rotation.rotate_flms, L=sim_L, rotation=eul_topo, dl_array=dl_topo)
gal2mcmf = partial(s2fft.utils.rotation.rotate_flms, L=sim_L, rotation=eul_gal, dl_array=dl_gal)

# express the beams (topocentric) and the sky (galactic) in the common MCMF frame
beam_x_alm = jax.vmap(topo2mcmf)(beam_x_alm)
beam_y_alm = jax.vmap(topo2mcmf)(beam_y_alm)
sky_alm = jax.vmap(gal2mcmf)(sky_alm)

## Multi-pair simulation

Stack the beams of the two auto-correlation pairs (x–x and y–y) along a new leading axis and pass them, together with one normalization per pair, to `compute_visibilities`.

In [None]:
# One beam per pair, stacked along a new leading axis:
# resulting shape (N_pairs, N_freq, lmax+1, 2*lmax+1)
beam_alm = jnp.stack((beam_x_alm, beam_y_alm))

# one normalization per pair; for an auto-correlation pair (p, p) it is tp_p
norm = jnp.array((tp_x, tp_y))

# a single call simulates every pair
vis = crojax.multipair.compute_visibilities(beam_alm, sky_alm, phases, norm)
print(f"Output shape (N_times, N_pairs, N_freq): {vis.shape}")

In [None]:
# Waterfalls (time vs frequency) of the log antenna temperature for each pair.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Pixel edges for imshow: freq is a 1 MHz grid, so pad half a channel on each side;
# rows are time bins (origin="upper", so bin 0 is at the top).
extent = (float(freq[0]) - 0.5, float(freq[-1]) + 0.5, vis.shape[0] - 0.5, -0.5)
for i, (ax, label) in enumerate(zip(axes, ["x-dipole", "y-dipole"])):
    im = ax.imshow(jnp.log(vis[:, i, :].real), aspect="auto", extent=extent)
    ax.set_title(label)
    ax.set_xlabel("Frequency [MHz]")
    ax.set_ylabel("Time bin")
    fig.colorbar(im, ax=ax, label="ln(T / K)")
plt.tight_layout()
plt.show()

In [None]:
# Spectra of both pairs at a single time bin.
itime = 150
fig, ax = plt.subplots()
for i, label in enumerate(["x-dipole", "y-dipole"]):
    ax.plot(freq, vis[itime, i, :].real, label=label)
ax.set_xlabel("Frequency [MHz]")
ax.set_ylabel("Antenna temperature [K]")
ax.set_yscale("log")
ax.legend()
plt.show()