# WCS Reprojection Usage Examples (XRADIO)

This notebook **generates synthetic XRADIO images** using `xradio.image.make_empty_sky_image`,
then demonstrates WCS-aware reprojection with `wcs_reproject.py`.

Each example explains:
- the goal of the reprojection,
- what you should see in the output image,
- which quantity is expected to be conserved **based on the brightness unit**, and why.


## Setup
This section imports dependencies and defines helpers for generating synthetic XRADIO images.
The generated images are small so the examples run quickly.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

from xradio.image import make_empty_sky_image
from wcs_reproject import reproject_to_match, reproject_to_frame

def _gaussian_2d(l_vals, m_vals, amp, l0, m0, sigma_l, sigma_m):
    """Evaluate an axis-aligned elliptical 2-D Gaussian on an (l, m) grid.

    Parameters
    ----------
    l_vals, m_vals : array_like
        Sample positions along the l and m axes (flattened to 1-D).
    amp : float
        Peak value, reached at ``(l0, m0)``.
    l0, m0 : float
        Centre of the Gaussian, in the same units as ``l_vals`` / ``m_vals``.
    sigma_l, sigma_m : float
        Standard deviations along l and m, in the same units.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(l_vals), len(m_vals))`` indexed as ``[l, m]``.
    """
    # A column of l terms plus a row of m terms broadcasts to the same
    # (n_l, n_m) grid that meshgrid(..., indexing='ij') would produce.
    l_term = ((np.ravel(l_vals)[:, np.newaxis] - l0) / sigma_l) ** 2
    m_term = ((np.ravel(m_vals)[np.newaxis, :] - m0) / sigma_m) ** 2
    return amp * np.exp(-0.5 * (l_term + m_term))

def _make_template(n_l, n_m, cell_arcsec, frame='fk5', projection='SIN'):
    """Create an empty single-plane XRADIO sky image to use as a grid template.

    The image is phase-centred at (RA, Dec) = (0, 0) rad and has one time
    sample (59000.0), one frequency (1.4e9, presumably Hz), and Stokes I only.

    Parameters
    ----------
    n_l, n_m : int
        Number of pixels along l and m.
    cell_arcsec : float
        Pixel size in arcsec, used for both axes (converted to radians here).
    frame : str
        Direction reference frame passed as ``direction_reference``.
    projection : str
        Projection code passed to ``make_empty_sky_image``.

    Returns
    -------
    xarray.Dataset
        The dataset returned by ``make_empty_sky_image`` (with sky coordinates).
    """
    cell_rad = np.deg2rad(cell_arcsec / 3600.0)
    return make_empty_sky_image(
        phase_center=[0.0, 0.0],  # [ra, dec] in rad
        image_size=[n_l, n_m],
        cell_size=[cell_rad, cell_rad],
        frequency_coords=np.array([1.4e9]),
        pol_coords=['I'],
        time_coords=np.array([59000.0]),
        direction_reference=frame,
        projection=projection,
        spectral_reference='lsrk',
        do_sky_coords=True,
    )

def _attach_sky(xds, sky_vals, units):
    """Attach a ``SKY`` data variable (and, for Jy/beam, a synthetic beam) to ``xds``.

    ``xds`` is modified in place and also returned.

    Parameters
    ----------
    xds : xarray.Dataset
        Template from ``_make_template``. It must provide ``time``,
        ``frequency``, ``polarization``, ``l`` and ``m`` coordinates, and
        ``beam_params_label`` when ``units == 'Jy/beam'``.
    sky_vals : numpy.ndarray
        Pixel values laid out as ``(time, frequency, polarization, l, m)``.
    units : str
        Brightness unit stored in ``SKY.attrs['units']``. ``'Jy/beam'`` also
        adds ``BEAM_FIT_PARAMS_SKY`` so the unit has beam metadata.

    Returns
    -------
    xarray.Dataset
        The same ``xds`` with ``SKY`` plus the ``data_groups`` and ``type`` attrs.
    """
    sky_dims = ('time', 'frequency', 'polarization', 'l', 'm')
    xds['SKY'] = xr.DataArray(
        sky_vals,
        dims=sky_dims,
        coords={name: xds.coords[name] for name in sky_dims},
    )
    xds['SKY'].attrs.update({'image_type': 'Intensity', 'type': 'sky', 'units': units})
    data_group = {'sky': 'SKY'}

    if units == 'Jy/beam':
        # Synthetic restoring beam (15" x 10", PA 0) so Jy/beam has valid
        # beam metadata; one identical beam per time/frequency/polarization.
        beam_dims = ('time', 'frequency', 'polarization', 'beam_params_label')
        beam_shape = tuple(xds.sizes[name] for name in beam_dims[:-1]) + (3,)
        beam_vals = np.zeros(beam_shape, dtype=np.float64)
        beam_vals[...] = (
            np.deg2rad(15.0 / 3600.0),  # major axis [rad]
            np.deg2rad(10.0 / 3600.0),  # minor axis [rad]
            0.0,  # position angle [rad]
        )
        xds['BEAM_FIT_PARAMS_SKY'] = xr.DataArray(
            beam_vals,
            dims=beam_dims,
            coords={name: xds.coords[name] for name in beam_dims},
        )
        xds['BEAM_FIT_PARAMS_SKY'].attrs['units'] = 'rad'
        data_group['beam_fit_params_sky'] = 'BEAM_FIT_PARAMS_SKY'

    xds.attrs['data_groups'] = {'base': data_group}
    xds.attrs['type'] = 'image_dataset'
    return xds

def _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0, loc=None):
    """Build an FK5 image holding a single 1.0 Jy/pixel point source.

    Parameters
    ----------
    n_l, n_m : int
        Image size along l and m.
    cell_arcsec : float
        Pixel size in arcsec (both axes).
    loc : tuple of int, optional
        ``(l_index, m_index)`` of the source pixel; defaults to the central
        pixel ``(n_l // 2, n_m // 2)``.

    Returns
    -------
    xarray.Dataset
        Image dataset whose ``SKY`` is in ``'Jy/pixel'``.
    """
    template = _make_template(n_l, n_m, cell_arcsec, frame='fk5')
    peak = (n_l // 2, n_m // 2) if loc is None else loc
    plane = np.zeros((n_l, n_m), dtype=np.float64)
    plane[peak] = 1.0
    # Prepend singleton time, frequency and polarization axes.
    return _attach_sky(template, plane[np.newaxis, np.newaxis, np.newaxis], units='Jy/pixel')

def _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0, loc=None):
    """Build an FK5 image holding an elliptical Gaussian with a 1.0 Jy/beam peak.

    Parameters
    ----------
    n_l, n_m : int
        Image size along l and m.
    cell_arcsec : float
        Pixel size in arcsec (both axes).
    loc : tuple of int, optional
        ``(l_index, m_index)`` of the Gaussian centre; defaults to the central
        pixel ``(n_l // 2, n_m // 2)``.

    Returns
    -------
    xarray.Dataset
        Image dataset whose ``SKY`` is in ``'Jy/beam'`` (with synthetic beam).
    """
    template = _make_template(n_l, n_m, cell_arcsec, frame='fk5')
    l_coords = template.coords['l'].values
    m_coords = template.coords['m'].values
    i_l, i_m = (n_l // 2, n_m // 2) if loc is None else loc
    plane = _gaussian_2d(
        l_coords,
        m_coords,
        amp=1.0,
        l0=l_coords[i_l],
        m0=m_coords[i_m],
        # Widths are in the units of the l/m coordinates.
        sigma_l=5e-5,
        sigma_m=3e-5,
    )
    # Prepend singleton time, frequency and polarization axes.
    return _attach_sky(template, plane[np.newaxis, np.newaxis, np.newaxis], units='Jy/beam')

def _plot_plane(da, title):
    """Plot the first time/frequency/polarization plane of an image cube.

    The plane is drawn with ``l`` on the horizontal axis and ``m`` on the
    vertical axis, both labelled as offsets in arcsec.

    Parameters
    ----------
    da : xarray.DataArray
        Cube with ``time``, ``frequency``, ``polarization``, ``l`` and ``m``
        dimensions; ``l``/``m`` coordinates are in radians.
    title : str
        Figure title.
    """
    # imshow draws array axis 0 vertically and axis 1 horizontally, so order
    # the plane as (m, l) to put l on the x axis, matching the axis labels.
    plane = da.isel(time=0, frequency=0, polarization=0).transpose('m', 'l')
    arr = plane.values
    l_arcsec = np.rad2deg(plane.coords['l'].values) * 3600.0
    m_arcsec = np.rad2deg(plane.coords['m'].values) * 3600.0

    def _pixel_edges(centers):
        # imshow places the first column/row at extent[0]/extent[2] and the
        # extent describes outer pixel edges; pad the end pixel centres by half
        # a pixel, keeping array order so the labels stay correct whether the
        # coordinate increases or decreases with pixel index.
        if centers.size < 2:
            return centers[0], centers[0]
        half_step = 0.5 * (centers[-1] - centers[0]) / (centers.size - 1)
        return centers[0] - half_step, centers[-1] + half_step

    extent = [*_pixel_edges(l_arcsec), *_pixel_edges(m_arcsec)]
    plt.figure(figsize=(4, 4))
    plt.imshow(arr, origin='lower', extent=extent, aspect='equal')
    plt.title(title)
    plt.xlabel('l offset (arcsec)')
    plt.ylabel('m offset (arcsec)')
    plt.colorbar()
    plt.tight_layout()


## Example 1: Reproject to Match a Target Grid (Jy/pixel)
**Goal:** reproject a `Jy/pixel` source image onto a target image with a different pixel size.

**Expected output:**
- Output has the same shape and grid as the target.
- The point source should remain centered, but the pixel sampling changes.

**Quantity conserved:** **integrated flux**.
- Each `Jy/pixel` value is the flux contained in that pixel, so the integrated flux is the plain sum of the pixel values.
- When the pixel size changes, this total integrated flux should stay the same.
- Use a **flux-conserving** method (e.g. `exact` or `adaptive`).


In [None]:
# 1.0 Jy/pixel point source at pixel (90, 40) of a 128x128 grid with 2" cells,
# reprojected onto a 160x160 grid with 1.5" cells using the 'exact' method.
src = _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90,40))
tgt = _make_template(n_l=160, n_m=160, cell_arcsec=1.5, frame='fk5')

out = reproject_to_match(src, tgt, data_group='base', method='exact')

_plot_plane(src['SKY'], 'Source (Jy/pixel)')
_plot_plane(out['SKY'], 'Reprojected to Target Grid (Flux Conserved)')

In [None]:
# Example 1 check. For a Jy/pixel image every pixel value is the flux
# contained in that pixel, so the integrated flux is the plain sum of the
# pixel values and a flux-conserving reprojection must keep it unchanged.
# The peak should also stay at (nearly) the same sky position.
flux = {}
ra = {}
dec = {}
dl = {}
dm = {}
for xds, label in zip([src, out], ["input", "output"]):
    sky = xds["SKY"]
    print(f"{label} shape: {sky.shape}")
    # Cell sizes in arcsec (l/m coordinates are in radians).
    l = np.rad2deg(xds.l.values) * 3600
    m = np.rad2deg(xds.m.values) * 3600
    dl[label] = abs(l[1] - l[0])
    dm[label] = abs(m[1] - m[0])
    pixel_area = dl[label] * dm[label]
    # xarray's sum skips NaN by default for float data, so blank pixels are ignored.
    flux[label] = float(sky.sum().values)
    print(f"{label} cell size: {dl[label]} x {dm[label]} arcsec")
    print(f"{label} pixel sum (integrated flux, Jy): {flux[label]}")
    # For reference: this product, not the plain sum, is what stays constant
    # if the reprojected values behave as surface brightness instead.
    print(f"{label} sum * pixel area (arcsec^2): {flux[label] * pixel_area}")
    arr = sky.values
    # nanargmax: plain argmax would return the first NaN if any pixel is blank.
    idx = np.unravel_index(np.nanargmax(arr), arr.shape)
    idx_list = [int(i) for i in idx]
    print(f"{label} max value pixel coord: {idx_list}")
    # The last two indices are (l, m), which index the 2-D sky-coordinate arrays.
    ra[label] = xds["right_ascension"].values[idx_list[-2], idx_list[-1]]
    dec[label] = xds["declination"].values[idx_list[-2], idx_list[-1]]
    print(f"{label} max value world coord (ra, dec): {ra[label]}, {dec[label]}")

conserved = np.isclose(flux["input"], flux["output"], rtol=1e-4)
result = "flux is conserved as" if conserved else "flux is not conserved which is not"
diff = np.abs(1 - flux["output"] / flux["input"])
result += f" expected for this case. Relative difference: {diff}"
print()
print(result)

# Wrap the RA difference into [-pi, pi) so peaks on either side of RA = 0 are
# not reported as ~360 deg apart; scale by cos(dec) to get a true sky offset.
d_ra = (ra["output"] - ra["input"] + np.pi) % (2 * np.pi) - np.pi
ra_diff = np.rad2deg(abs(d_ra) * np.cos(dec["input"])) * 3600
dec_diff = np.rad2deg(abs(dec["output"] - dec["input"])) * 3600
print(f"ra, dec difference of peak: {ra_diff}, {dec_diff} arcsec")
if (
    ra_diff <= max(dl["input"], dl["output"])
    and dec_diff <= max(dm["input"], dm["output"])
):
    res = "less"
else:
    res = "greater"
print(f"world coord difference {res} than larger pixel size")

    
    


## Example 2: Reproject to Match a Target Grid (Jy/beam)
**Goal:** reproject a `Jy/beam` image to a target grid using interpolation.

**Expected output:**
- Output has the same shape and grid as the target.
- The Gaussian source should preserve its **local shape and peak** approximately.
- The source is intentionally off-center to verify this is not a center-only special case.

**Quantity conserved:** **flux density** (for fixed beam), not raw pixel sum.
- In the `Jy/beam` case, the physically meaningful integrated quantity is the flux density, `sum(SKY) * pixel_area / beam_area`, not the raw pixel sum.
- Raw `sum(SKY)` changes when pixel area changes, so it is not conserved.
- Interpolation should preserve local morphology/peak approximately while this integrated flux-density quantity remains consistent.


In [None]:
# Off-centre Gaussian (1.0 Jy/beam peak at pixel (90, 40)) on a 128x128, 2" grid,
# resampled onto a 160x160, 1.5" grid with linear (order=1) interpolation.
src = _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90, 40))
tgt = _make_template(n_l=160, n_m=160, cell_arcsec=1.5, frame='fk5')

out = reproject_to_match(src, tgt, data_group='base', method='interp', order=1)

_plot_plane(src['SKY'], 'Source (Jy/beam)')
_plot_plane(out['SKY'], 'Reprojected to Target Grid (Interpolated)')


In [None]:
# Example 2 check (Jy/beam, interp). Expected: output on the target grid, peak
# value and peak sky position preserved, and integrated flux density
# (sum * pixel_area / beam_area) preserved while the raw pixel sum is not.


def _beam_area_arcsec2(xds):
    """Return the restoring-beam solid angle of ``xds`` in arcsec^2, or None.

    Prefers an explicit ``SKY.attrs['beam_area_arcsec2']``; otherwise uses the
    first plane of ``BEAM_FIT_PARAMS_SKY`` with (major, minor, pa) in rad, the
    layout written by ``_attach_sky``.
    """
    area = xds["SKY"].attrs.get("beam_area_arcsec2")
    if area is not None:
        return float(area)
    if "BEAM_FIT_PARAMS_SKY" not in xds:
        return None
    beam = xds["BEAM_FIT_PARAMS_SKY"]
    first = beam.isel({d: 0 for d in beam.dims if d != "beam_params_label"}).values
    major_arcsec, minor_arcsec = np.rad2deg(first[:2]) * 3600.0
    # Elliptical Gaussian beam solid angle: pi * a * b / (4 ln 2).
    # NOTE(review): assumes the stored axes are FWHM (usual restoring-beam convention).
    return float(np.pi * major_arcsec * minor_arcsec / (4.0 * np.log(2.0)))


metrics = {}
for xds, label in zip([src, out], ["input", "output"]):
    sky = xds["SKY"]
    l_arcsec = np.rad2deg(xds.l.values) * 3600.0
    m_arcsec = np.rad2deg(xds.m.values) * 3600.0
    dl = float(abs(l_arcsec[1] - l_arcsec[0]))
    dm = float(abs(m_arcsec[1] - m_arcsec[0]))
    arr = sky.values
    # nanargmax: plain argmax would return the first NaN if any pixel is blank.
    idx = np.unravel_index(np.nanargmax(arr), arr.shape)
    idx_list = [int(i) for i in idx]
    i_l, i_m = idx_list[-2], idx_list[-1]
    peak_ra = float(xds["right_ascension"].values[i_l, i_m])
    peak_dec = float(xds["declination"].values[i_l, i_m])
    metrics[label] = {
        "shape": sky.shape,
        "dl": dl,
        "dm": dm,
        "peak": float(sky.max().values),
        "sum": float(sky.sum().values),
        "idx": idx_list,
        "ra": peak_ra,
        "dec": peak_dec,
    }

shape_match = (
    out["SKY"].sizes["l"] == tgt.sizes["l"]
    and out["SKY"].sizes["m"] == tgt.sizes["m"]
)
peak_rel_diff = abs(metrics["output"]["peak"] - metrics["input"]["peak"]) / metrics["input"]["peak"]
peak_ok = peak_rel_diff <= 0.05

cell_area_input = metrics["input"]["dl"] * metrics["input"]["dm"]
cell_area_output = metrics["output"]["dl"] * metrics["output"]["dm"]
sum_ratio = metrics["output"]["sum"] / metrics["input"]["sum"]
inv_cell_area_ratio = cell_area_input / cell_area_output

# For Jy/beam maps: integrated flux density = sum * pixel_area / beam_area,
# using each image's own beam. Without beam metadata, comparing
# sum * pixel_area is sufficient as long as the beam is unchanged, because
# beam_area then cancels in the ratio.
beam_area_input = _beam_area_arcsec2(src)
beam_area_output = _beam_area_arcsec2(out)
have_beams = beam_area_input is not None and beam_area_output is not None
if have_beams:
    fd_input = metrics["input"]["sum"] * cell_area_input / beam_area_input
    fd_output = metrics["output"]["sum"] * cell_area_output / beam_area_output
    fd_label = "Integrated flux density estimate (Jy)"
else:
    fd_input = metrics["input"]["sum"] * cell_area_input
    fd_output = metrics["output"]["sum"] * cell_area_output
    fd_label = "Integrated flux-density proxy (sum*pixel_area; beam area cancels)"
fd_rel_diff = abs(fd_output - fd_input) / abs(fd_input)
fd_ok = fd_rel_diff <= 0.02

# Wrap the RA difference into [-pi, pi) so peaks straddling RA = 0 are not
# reported as ~360 deg apart; scale by cos(dec) to get a true sky offset.
d_ra = (metrics["output"]["ra"] - metrics["input"]["ra"] + np.pi) % (2.0 * np.pi) - np.pi
ra_diff_arcsec = np.rad2deg(abs(d_ra) * np.cos(metrics["input"]["dec"])) * 3600.0
dec_diff_arcsec = np.rad2deg(abs(metrics["output"]["dec"] - metrics["input"]["dec"])) * 3600.0
within_pixel = (
    ra_diff_arcsec <= max(metrics["input"]["dl"], metrics["output"]["dl"])
    and dec_diff_arcsec <= max(metrics["input"]["dm"], metrics["output"]["dm"])
)

print("Example 2 Numerical Verification (Jy/beam, interp)")
print("-" * 56)
print(f"Input shape:  {metrics['input']['shape']}")
print(f"Output shape: {metrics['output']['shape']}")
print(f"Target (l,m): ({tgt.sizes['l']}, {tgt.sizes['m']})")
print(f"Shape/grid match target: {shape_match}")
print()
print(f"Input cell size:  {metrics['input']['dl']:.6f} x {metrics['input']['dm']:.6f} arcsec")
print(f"Output cell size: {metrics['output']['dl']:.6f} x {metrics['output']['dm']:.6f} arcsec")
print()
print(f"Input peak Jy/beam:  {metrics['input']['peak']:.6e}")
print(f"Output peak Jy/beam: {metrics['output']['peak']:.6e}")
print(f"Peak relative difference: {peak_rel_diff:.6e} (<= 5% ? {peak_ok})")
print()
print(f"Input sum (not conserved for Jy/beam):  {metrics['input']['sum']:.6e}")
print(f"Output sum (not conserved for Jy/beam): {metrics['output']['sum']:.6e}")
print(f"Sum ratio output/input: {sum_ratio:.6e}")
print(f"Reciprocal cell-area ratio (A_in/A_out): {inv_cell_area_ratio:.6e}")
print()
print(fd_label)
if have_beams:
    print(f"Beam area input/output: {beam_area_input:.6f} / {beam_area_output:.6f} arcsec^2")
print(f"Input:  {fd_input:.6e}")
print(f"Output: {fd_output:.6e}")
print(f"Relative difference: {fd_rel_diff:.6e} (<= 2% ? {fd_ok})")
print()
print(f"Input peak pixel index:  {metrics['input']['idx']}")
print(f"Output peak pixel index: {metrics['output']['idx']}")
print(f"Peak world offset (RA, Dec): ({ra_diff_arcsec:.6f}, {dec_diff_arcsec:.6f}) arcsec")
print(f"Peak world offset within <= 1 pixel scale: {within_pixel}")
print()
all_ok = shape_match and peak_ok and fd_ok and within_pixel
summary = "PASS" if all_ok else "CHECK"
print(f"Result summary: {summary}")


## Example 3: Reproject to a New Sky Frame (FK5 -> Galactic)
**Goal:** change the sky frame while keeping the same pixel grid.

**Expected output:**
- Output retains the original grid shape.
- The source appears **rotated** in the new frame.
- Output world-coordinate axes are in the target frame (`galactic_longitude`, `galactic_latitude`).
- Original world-coordinate axes can be retained via `keep_input_world_coords=True`.

**Quantity conserved:**
- If `Jy/pixel` and using `exact/adaptive`, total flux is conserved.
- If `Jy/beam` and using `interp`, local intensity patterns are preserved.


In [None]:
import importlib
import wcs_reproject

# Ensure notebook picks up latest local edits to wcs_reproject.py.
# importlib.reload does not update names bound earlier via
# `from wcs_reproject import ...`, so rebind every such name explicitly.
importlib.reload(wcs_reproject)
reproject_to_match = wcs_reproject.reproject_to_match
reproject_to_frame = wcs_reproject.reproject_to_frame

# Centred Gaussian Jy/beam source on an FK5 grid.
src = _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0)

# Same pixel grid, Galactic sky frame; linear interpolation for Jy/beam data.
# keep_input_world_coords=True also keeps the input world coordinates.
out = reproject_to_frame(
    src,
    'galactic',
    keep_grid=True,
    method='interp',
    order=1,
    keep_input_world_coords=True,
)
world_coord_names = ('galactic_longitude', 'galactic_latitude', 'input_right_ascension', 'input_declination')
print('Output world coords:', [c for c in out.coords if c in world_coord_names])
if 'BEAM_FIT_PARAMS_SKY' in src and 'BEAM_FIT_PARAMS_SKY' in out:
    # Compare the beam position angle before and after the frame change.
    src_pa = float(src['BEAM_FIT_PARAMS_SKY'].sel(beam_params_label='pa').isel(time=0, frequency=0, polarization=0).values)
    out_pa = float(out['BEAM_FIT_PARAMS_SKY'].sel(beam_params_label='pa').isel(time=0, frequency=0, polarization=0).values)
    print(f"Beam PA (rad): input={src_pa:.6e}, output={out_pa:.6e}, delta={out_pa-src_pa:.6e}")

_plot_plane(src['SKY'], 'Source (FK5)')
_plot_plane(out['SKY'], 'Reprojected to Galactic Frame (Same Grid)')


In [None]:
# Display the input dataset (Jupyter renders the last expression of a cell).
src

In [None]:
# Display the reprojected dataset to inspect its coordinates and attributes.
out

## Example 4: Reproject to a New Frame With Re-Centered Grid (Jy/pixel)
**Goal:** change to a new frame and rebuild a same-sized grid centered in the target frame.

**Expected output:**
- Same shape and pixel size as the source.
- Spatial coordinates re-centered for the target frame.

**Quantity conserved:** **integrated flux** for `Jy/pixel` when using a flux-conserving method.


In [None]:
# Centred 1.0 Jy/pixel point source on a 128x128 FK5 grid with 2" cells.
src = _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0)

# keep_grid=False: rebuild a grid for the Galactic frame instead of reusing the
# input pixel grid; 'exact' is the flux-conserving method used for Jy/pixel.
out = reproject_to_frame(src, 'galactic', keep_grid=False, method='exact')

_plot_plane(src['SKY'], 'Source (Jy/pixel, FK5)')
_plot_plane(out['SKY'], 'Reprojected to Galactic (Re-centered Grid)')
