# WCS Reprojection Usage Examples (XRADIO)

This notebook **generates synthetic XRADIO images** using `xradio.image.make_empty_sky_image`,
then demonstrates WCS-aware reprojection with `wcs_reproject.py`.

Each example explains:
- the goal of the reprojection,
- what you should see in the output image,
- which quantity is expected to be conserved **based on the brightness unit**, and why.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.wcs import WCS

def generate_gaussian_xy(shape, x_c, y_c, sigma_a, sigma_b, theta_math, amplitude=100):
    """
    Generates a 2D Gaussian array with data[x, y] indexing.
    
    theta_math: Angle in radians from pixel +x (Right) toward pixel +y (Up).
    sigma_a: Major axis standard deviation.
    sigma_b: Minor axis standard deviation.
    """
    if sigma_b > sigma_a:
        raise ValueError("sigma_a must be the major axis (>= sigma_b)")
        
    # In data[x, y] convention, first axis must be X
    x, y = np.indices(shape)
    
    cos_t = np.cos(theta_math)
    sin_t = np.sin(theta_math)
    
    # Rotation transformation
    x_rot = (x - x_c) * cos_t + (y - y_c) * sin_t
    y_rot = -(x - x_c) * sin_t + (y - y_c) * cos_t
    
    g = amplitude * np.exp(-( (x_rot**2 / (2 * sigma_a**2)) + (y_rot**2 / (2 * sigma_b**2)) ))
    return g

def calculate_theta_math(data):
    """
    Recovers the pixel-space angle theta_math from data[x, y].
    Returns angle in radians in range [-pi/2, pi/2].
    """
    x_grid, y_grid = np.indices(data.shape)
    
    m00 = np.sum(data)
    if m00 <= 0: return np.nan
    
    # Centroids
    x_c = np.sum(x_grid * data) / m00
    y_c = np.sum(y_grid * data) / m00
    
    # Second-order central moments
    mu20 = np.sum((x_grid - x_c)**2 * data) / m00
    mu02 = np.sum((y_grid - y_c)**2 * data) / m00
    mu11 = np.sum((x_grid - x_c) * (y_grid - y_c) * data) / m00
    
    # Principal axis angle
    return 0.5 * np.arctan2(2 * mu11, mu20 - mu02)

def calculate_pa_from_wcs(data, wcs, beam=True):
    """
    Converts pixel theta_math to Astronomical Position Angle (PA).
    
    Assumptions:
    1. data is indexed as [x, y].
    2. PA is 0 at North, 90 at East.
    3. theta_math is from pixel +x axis toward pixel +y axis.
    if beam, returned angle satisfies -90 < pa <= 90
    otherwise 0 <= pa < 180
    """
    theta_math = calculate_theta_math(data)
    if np.isnan(theta_math):
        return np.nan

    # Get the transformation matrix
    m = wcs.pixel_scale_matrix
    
    # The direction of Celestial North (+Dec) in your pixel grid.
    # In the [x, y] convention, the second column of the matrix (m[0,1], m[1,1])
    # tells us how the pixel x and y components contribute to Declination.
    # We want the angle of the vector that points towards increasing Dec.
    north_pixel_angle = np.arctan2(m[1, 1], m[0, 1])
    
    # Parity: determinant is negative for East-Left (Standard FITS/CASA).
    # parity = -1
    det = np.linalg.det(m)
    parity = np.sign(det)
    
    # THE FIX: 
    # To get the sign you expect (North=0, moving toward West is negative):
    # We take (north_pixel_angle - theta_math).
    # Then multiply by parity to ensure Eastward movement is a positive increase.
    pa_rad = (north_pixel_angle - theta_math) * parity
    
    # Normalize to [0, 180) for elliptical symmetry
    pa_deg = np.degrees(pa_rad)
    pa_deg = ((pa_deg + 90) % 180) - 90 if beam else pa_deg % 180
    return pa_deg
    
def create_demo_wcs(shape):
    """Creates a standard East-Left, North-Up WCS."""
    w = WCS(naxis=2)
    w.wcs.crpix = [shape[0] / 2, shape[1] / 2]
    w.wcs.crval = [180.0, 0.0]
    w.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    
    scale = 1.0 / 3600.0 # 1 arcsec/pixel
    w.wcs.cdelt = [-scale, scale] # East is -x, North is +y
    w.wcs.pc = [[1, 0], [0, 1]]
    return w

def generate_astro_plot(data, wcs):
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(1, 1, 1, projection=wcs)

    # IMPORTANT: data.T is required because imshow expects [row, col] (y, x)
    # but our data is [x, y].
    ax.imshow(data.T, origin='lower', cmap='magma')

    # Draw Arrow showing the major axis in pixel space
    x0, y0 = 100, 100
    length = 40
    ax.arrow(
        x0, y0, length*np.cos(np.radians(recovered_theta)), 
        length*np.sin(np.radians(recovered_theta)), 
        color='cyan', width=1, head_width=5, label='Major Axis'
    )

    ax.coords[0].set_axislabel('Right Ascension')
    ax.coords[1].set_axislabel('Declination')
    plt.title(f"Sky PA: {pa:.2f}째 | Pixel Theta: {recovered_theta:.2f}째")
    plt.legend()
    return plt

# --- DEMO EXECUTION ---

# 1. Setup
shape = (200, 200) # (width, height)
wcs = create_demo_wcs(shape)
plots = []

for target_theta_math in range(0, 360, 30):
    # 2. Define Target (e.g., North-East)
    # North is +y (90 deg), East is -x (180 deg).
    # North-East is halfway between them = 135 degrees.
    target_theta_math_rad = np.radians(target_theta_math)
    data = generate_gaussian_xy(shape, 100, 100, 20, 8, target_theta_math_rad)

    # 3. Analyze
    pa = calculate_pa_from_wcs(data, wcs, False)
    recovered_theta = np.degrees(calculate_theta_math(data))

    print(f"Input theta_math:     {target_theta_math:.2f}째")
    print(f"Recovered theta_math: {recovered_theta:.2f}째")
    print(f"Calculated Sky PA:    {pa:.2f} (Should be {target_theta_math - 90})")
    print()

    # 4. Plotting
    plots.append(generate_astro_plot(data, wcs))
    
for p in plots:
    p.show()

## Setup
This section imports dependencies and defines helpers for generating synthetic XRADIO images.
The generated images are small so the examples run quickly.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

from xradio.image import make_empty_sky_image
from wcs_reproject import reproject_to_match, reproject_to_frame

def _gaussian_2d(l_vals, m_vals, amp, l0, m0, sigma_l, sigma_m):
    ll, mm = np.meshgrid(l_vals, m_vals, indexing='ij')
    rr = ((ll - l0) / sigma_l) ** 2 + ((mm - m0) / sigma_m) ** 2
    return amp * np.exp(-0.5 * rr)

def _make_template(n_l, n_m, cell_arcsec, frame='fk5', projection='SIN'):
    phase_center = [0.0, 0.0]  # [ra, dec] in rad
    image_size = [n_l, n_m]
    cell = np.deg2rad(cell_arcsec / 3600.0)
    cell_size = [cell, cell]
    frequency_coords = np.array([1.4e9])
    pol_coords = ['I']
    time_coords = np.array([59000.0])

    return make_empty_sky_image(
        phase_center=phase_center,
        image_size=image_size,
        cell_size=cell_size,
        frequency_coords=frequency_coords,
        pol_coords=pol_coords,
        time_coords=time_coords,
        direction_reference=frame,
        projection=projection,
        spectral_reference='lsrk',
        do_sky_coords=True,
    )

def _attach_sky(xds, sky_vals, units):
    dims = ('time', 'frequency', 'polarization', 'l', 'm')
    coords = {d: xds.coords[d] for d in dims}
    xds['SKY'] = xr.DataArray(sky_vals, dims=dims, coords=coords)
    xds['SKY'].attrs.update({'image_type': 'Intensity', 'type': 'sky', 'units': units})
    base_group = {'sky': 'SKY'}
    if units == 'Jy/beam':
        # Define a synthetic restoring beam so Jy/beam has valid beam metadata.
        beam = np.zeros((xds.sizes['time'], xds.sizes['frequency'], xds.sizes['polarization'], 3), dtype=np.float64)
        beam[..., 0] = np.deg2rad(15.0 / 3600.0)  # major axis [rad]
        beam[..., 1] = np.deg2rad(10.0 / 3600.0)  # minor axis [rad]
        beam[..., 2] = np.pi / 2  # position angle [rad], aligned with source major axis
        xds['BEAM_FIT_PARAMS_SKY'] = xr.DataArray(
            beam,
            dims=('time', 'frequency', 'polarization', 'beam_params_label'),
            coords={
                'time': xds.coords['time'],
                'frequency': xds.coords['frequency'],
                'polarization': xds.coords['polarization'],
                'beam_params_label': xds.coords['beam_params_label'],
            },
        )
        xds['BEAM_FIT_PARAMS_SKY'].attrs['units'] = 'rad'
        base_group['beam_fit_params_sky'] = 'BEAM_FIT_PARAMS_SKY'
    xds.attrs['data_groups'] = {'base': base_group}
    xds.attrs['type'] = 'image_dataset'
    return xds

def _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0, loc=None):
    xds = _make_template(n_l, n_m, cell_arcsec, frame='fk5')
    base = np.zeros((n_l, n_m), dtype=np.float64)
    if loc is None:
        base[n_l // 2, n_m // 2] = 1.0
    else:
        base[loc] = 1.0
    sky = base[None, None, None, :, :]
    return _attach_sky(xds, sky, units='Jy/pixel')

def _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0, loc=None):
    xds = _make_template(n_l, n_m, cell_arcsec, frame='fk5')
    l = xds.coords['l'].values
    m = xds.coords['m'].values
    if loc is None:
        i_l, i_m = n_l // 2, n_m // 2
    else:
        i_l, i_m = loc
    base = _gaussian_2d(
        l, m, amp=1.0,
        l0=l[i_l],
        m0=m[i_m],
        sigma_l=5e-5,
        sigma_m=3e-5,
    )
    sky = base[None, None, None, :, :]
    return _attach_sky(xds, sky, units='Jy/beam')

def _plot_plane(da, title, mark_peak=True):
    plane = da.isel(time=0, frequency=0, polarization=0)
    arr = plane.values
    l_arcsec = np.rad2deg(da.coords['l'].values) * 3600.0
    m_arcsec = np.rad2deg(da.coords['m'].values) * 3600.0
    extent = [l_arcsec.max(), l_arcsec.min(), m_arcsec.min(), m_arcsec.max()]

    # Display with x=l and y=m regardless of plane dim ordering.
    if plane.dims == ('l', 'm'):
        plot_arr = arr.T
        idx = np.unravel_index(np.nanargmax(arr), arr.shape)
        i_l, i_m = int(idx[0]), int(idx[1])
    elif plane.dims == ('m', 'l'):
        plot_arr = arr
        idx = np.unravel_index(np.nanargmax(arr), arr.shape)
        i_m, i_l = int(idx[0]), int(idx[1])
    else:
        raise ValueError(f'Unexpected plane dims: {plane.dims}; expected (l,m) or (m,l)')

    plt.figure(figsize=(4, 4))
    plt.imshow(plot_arr, origin='lower', extent=extent, aspect='equal')
    if mark_peak:
        plt.scatter(
            [l_arcsec[i_l]],
            [m_arcsec[i_m]],
            marker='x',
            s=64,
            linewidths=1.5,
            c='white',
            label='peak',
        )
        plt.legend(loc='upper right', framealpha=0.8)
    plt.title(title)
    plt.xlabel('l offset (arcsec)')
    plt.ylabel('m offset (arcsec)')
    plt.colorbar()
    plt.tight_layout()


In [None]:
import numpy as np

def calculate_theta_symmetric(data: np.array):
    """
    For a 2-D elliptical gaussian compute theta (angle between x-axis and major axis) measured 
    in a normal right handed coordinate system with x-axis, y-axis increasing to the right and 
    up respectively. Works best on noiseless data where the gaussian isn't very near the edge.
    For astronomical images, we use pixel coords, so even though the world l-axis
    increases to the left, the x pixel coordinate increases to the right as in any sane
    cartesian system. The angle returned is the angle between the positive pixel coordinate/x
    axis and the major axis of the elliptical footprint. Any conversion to pa is not
    done here.
    The Symmetric Window Logic
    Initial Estimate: Perform a standard moment calculation on the full array to find the 
    centroid $(\bar{x}, \bar{y})$.
    Define Bounds: Determine the distance from the centroid to the nearest edge of the array. 
    Let this be $d_{min}$.
    Crop: Create a symmetric square window of size $2 \times d_{min}$ centered at $(\bar{x}, 
    \bar{y})$.Final Calculation:
    Run the moment and $\phi$ formulas only on the data within this symmetric window.
    """
    rows, cols = data.shape
    y_grid, x_grid = np.indices((rows, cols))
    print("x_grid", x_grid)
    print("y_grid", y_grid)
    
    # --- Pass 1: Global Centroid ---
    m00 = np.sum(data)
    if m00 == 0:
        return np.nan
    print("xgrid[1,2]", x_grid[1,2])
    x_c_initial = np.sum(x_grid * data) / m00
    y_c_initial = np.sum(y_grid * data) / m00

    print(f"pixel coords of gaussian peak: {x_c_initial}, {y_c_initial}")
    
    # --- Pass 2: Define Symmetric Window ---
    # Find distance to the closest edge to stay within array bounds
    dist_x = min(x_c_initial, cols - 1 - x_c_initial)
    dist_y = min(y_c_initial, rows - 1 - y_c_initial)
    
    # Use the smaller of the two to create a square symmetric buffer
    # Note: We use a floor to stay safely within pixel indices
    buffer = int(min(dist_x, dist_y))
    
    if buffer < 1: # Gaussian peak is effectively on the edge
        return np.nan 

    # Define slice indices
    # this is a square centered on the peak and extends to the edge closer
    # to the peak
    x_start, x_end = int(x_c_initial - buffer), int(x_c_initial + buffer)
    y_start, y_end = int(y_c_initial - buffer), int(y_c_initial + buffer)
    
    # Extract symmetric sub-region
    sub_data = data[y_start:y_end+1, x_start:x_end+1]
    sub_y, sub_x = np.indices(sub_data.shape)
    
    # --- Pass 3: Precise Moments on Symmetric Data ---
    m00_sub = np.sum(sub_data)
    x_c_sub = np.sum(sub_x * sub_data) / m00_sub
    y_c_sub = np.sum(sub_y * sub_data) / m00_sub
    
    mu20 = np.sum((sub_x - x_c_sub)**2 * sub_data) / m00_sub
    mu02 = np.sum((sub_y - y_c_sub)**2 * sub_data) / m00_sub
    mu11 = np.sum((sub_x - x_c_sub) * (sub_y - y_c_sub) * sub_data) / m00_sub
    
    theta = 0.5 * np.arctan2(2 * mu11, mu20 - mu02)
    print(f"theta: {theta}")
    return theta

src = _make_gaussian_jy_per_beam(n_l=140, n_m=128, cell_arcsec=2.0, loc=(90, 40))
arr = src["SKY"].isel(time=0, polarization=0, frequency=0).values
print("90, 40", arr[90, 40])
print("40, 90", arr[40, 90])
print("shape", arr.shape)
theta = calculate_theta_symmetric(arr)
print(f"theta: {theta}")


In [None]:
z = np.zeros([40, 30])
print(z.shape)

## Example 1: Reproject to Match a Target Grid (Jy/pixel)
**Goal:** reproject a `Jy/pixel` source image onto a target image with a different pixel size.

**Expected output:**
- Output has the same shape and grid as the target.
- The point source should remain centered, but the pixel sampling changes.

**Quantity conserved:** **integrated flux**.
- `Jy/pixel` represents flux per pixel area.
- When pixel sizes change, total integrated flux should be preserved.
- Use a **flux-conserving** method (e.g. `exact` or `adaptive`).


In [None]:
src = _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90,40))
tgt = _make_template(n_l=160, n_m=160, cell_arcsec=1.5, frame='fk5')

out = reproject_to_match(src, tgt, data_group='base', method='exact')

_plot_plane(src['SKY'], 'Source (Jy/pixel)')
_plot_plane(out['SKY'], 'Reprojected to Target Grid (Flux Conserved)')

In [None]:
flux = {}
ra = {}
dec = {}
dl = {}
dm = {}
for xds, label in zip([src,out], ["input", "output"]):
    sky = xds["SKY"]
    print(f"{label} shape: {sky.shape}")
    l = np.rad2deg(xds.l.values) * 3600
    m = np.rad2deg(xds.m.values) * 3600
    dl[label] = abs(l[1] - l[0])
    dm[label] = abs(m[1] - m[0])
    pixel_area = dl[label] * dm[label]
    flux[label] = float(xds.SKY.sum().values) * pixel_area
    print(
        f"{label} cell size: {dl} x {dm} arcsec"
    )
    print(f"{label} flux density: {flux[label]}")
    arr = sky.values
    idx = np.unravel_index(np.argmax(arr), arr.shape)
    idx_list = [int(i) for i in idx]
    print(f"{label} max value pixel coord: {idx_list}")
    print("ra shape", xds["right_ascension"].shape)
    ra[label] = xds["right_ascension"].values[idx_list[3], idx_list[4]]
    dec[label] = xds["declination"].values[idx_list[3], idx_list[4]]
    print(
        f"{label} max value world coord (ra, dec): "
        f"{ra[label]}, {dec[label]}"
    )
result = "flux is conserved as" if np.isclose(flux["input"], flux["output"], rtol=1e-4) else "flux is not conserved which is not"
diff = np.abs(1 - flux["output"]/flux["input"])
result += f" expected for this case. Relative difference: {diff}"
print()
print(result)
ra_diff = np.rad2deg(
    abs(ra["output"] - ra["input"])*3600*np.cos(dec["input"])
)
dec_diff = np.rad2deg(abs(dec["output"] - dec["input"])) * 3600
print(
    f"ra, dec difference of peak: {ra_diff}, {dec_diff} arcsec"
)
if (
    ra_diff <= max(dl["input"], dl["output"])
    and dec_diff <= max(dm["input"], dm["output"])
):
    res = "less"
else:
    res = "greater"
print(f"world coord difference {res} than larger pixel size")

    
    


## Example 2: Reproject to Match a Target Grid (Jy/beam)
**Goal:** reproject a `Jy/beam` image to a target grid using interpolation.

**Expected output:**
- Output has the same shape and grid as the target.
- The Gaussian source should preserve its **local shape and peak** approximately.
- The source is intentionally off-center to verify this is not a center-only special case.

**Quantity conserved:** **flux density** (for fixed beam), not raw pixel sum.
- In the `Jy/beam` case, the physically meaningful integrated quantity is flux density, not sum, and it scales like `sum(SKY) * pixel_area / beam_area`.
- Raw `sum(SKY)` changes when pixel area changes, so it is not conserved.
- Interpolation should preserve local morphology/peak approximately while this integrated flux-density quantity remains consistent.


In [None]:
src = _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90, 40))
tgt = _make_template(n_l=160, n_m=160, cell_arcsec=1.5, frame='fk5')

out = reproject_to_match(src, tgt, data_group='base', method='interp', order=1)

_plot_plane(src['SKY'], 'Source (Jy/beam)')
_plot_plane(out['SKY'], 'Reprojected to Target Grid (Interpolated)')


In [None]:
metrics = {}
for xds, label in zip([src, out], ["input", "output"]):
    sky = xds["SKY"]
    l_arcsec = np.rad2deg(xds.l.values) * 3600.0
    m_arcsec = np.rad2deg(xds.m.values) * 3600.0
    dl = float(abs(l_arcsec[1] - l_arcsec[0]))
    dm = float(abs(m_arcsec[1] - m_arcsec[0]))
    arr = sky.values
    idx = np.unravel_index(np.argmax(arr), arr.shape)
    idx_list = [int(i) for i in idx]
    i_l, i_m = idx_list[-2], idx_list[-1]
    peak_ra = float(xds["right_ascension"].values[i_l, i_m])
    peak_dec = float(xds["declination"].values[i_l, i_m])
    metrics[label] = {
        "shape": sky.shape,
        "dl": dl,
        "dm": dm,
        "peak": float(sky.max().values),
        "sum": float(sky.sum().values),
        "idx": idx_list,
        "ra": peak_ra,
        "dec": peak_dec,
    }

shape_match = (
    out["SKY"].sizes["l"] == tgt.sizes["l"]
    and out["SKY"].sizes["m"] == tgt.sizes["m"]
)
peak_rel_diff = abs(metrics["output"]["peak"] - metrics["input"]["peak"]) / metrics["input"]["peak"]
peak_ok = peak_rel_diff <= 0.05

cell_area_input = metrics["input"]["dl"] * metrics["input"]["dm"]
cell_area_output = metrics["output"]["dl"] * metrics["output"]["dm"]
sum_ratio = metrics["output"]["sum"] / metrics["input"]["sum"]
inv_cell_area_ratio = cell_area_input / cell_area_output

# For Jy/beam maps: integrated flux density scales as sum * pixel_area / beam_area.
# If beam_area is unchanged between input/output, comparing sum*pixel_area is sufficient
# because beam_area cancels in the ratio.
beam_area_arcsec2 = src["SKY"].attrs.get("beam_area_arcsec2", None)
if beam_area_arcsec2 is not None:
    fd_input = metrics["input"]["sum"] * cell_area_input / beam_area_arcsec2
    fd_output = metrics["output"]["sum"] * cell_area_output / beam_area_arcsec2
    fd_label = "Integrated flux density estimate (Jy)"
else:
    fd_input = metrics["input"]["sum"] * cell_area_input
    fd_output = metrics["output"]["sum"] * cell_area_output
    fd_label = "Integrated flux-density proxy (sum*pixel_area; beam area cancels)"
fd_rel_diff = abs(fd_output - fd_input) / abs(fd_input)
fd_ok = fd_rel_diff <= 0.02

ra_diff_arcsec = np.rad2deg(abs(metrics["output"]["ra"] - metrics["input"]["ra"]) * np.cos(metrics["input"]["dec"])) * 3600.0
dec_diff_arcsec = np.rad2deg(abs(metrics["output"]["dec"] - metrics["input"]["dec"])) * 3600.0
within_pixel = (
    ra_diff_arcsec <= max(metrics["input"]["dl"], metrics["output"]["dl"])
    and dec_diff_arcsec <= max(metrics["input"]["dm"], metrics["output"]["dm"])
)

print("Example 2 Numerical Verification (Jy/beam, interp)")
print("-" * 56)
print(f"Input shape:  {metrics['input']['shape']}")
print(f"Output shape: {metrics['output']['shape']}")
print(f"Target (l,m): ({tgt.sizes['l']}, {tgt.sizes['m']})")
print(f"Shape/grid match target: {shape_match}")
print()
print(f"Input cell size:  {metrics['input']['dl']:.6f} x {metrics['input']['dm']:.6f} arcsec")
print(f"Output cell size: {metrics['output']['dl']:.6f} x {metrics['output']['dm']:.6f} arcsec")
print()
print(f"Input peak Jy/beam:  {metrics['input']['peak']:.6e}")
print(f"Output peak Jy/beam: {metrics['output']['peak']:.6e}")
print(f"Peak relative difference: {peak_rel_diff:.6e} (<= 5% ? {peak_ok})")
print()
print(f"Input sum (not conserved for Jy/beam):  {metrics['input']['sum']:.6e}")
print(f"Output sum (not conserved for Jy/beam): {metrics['output']['sum']:.6e}")
print(f"Sum ratio output/input: {sum_ratio:.6e}")
print(f"Reciprocal cell-area ratio (A_in/A_out): {inv_cell_area_ratio:.6e}")
print()
print(fd_label)
print(f"Input:  {fd_input:.6e}")
print(f"Output: {fd_output:.6e}")
print(f"Relative difference: {fd_rel_diff:.6e} (<= 2% ? {fd_ok})")
print()
print(f"Input peak pixel index:  {metrics['input']['idx']}")
print(f"Output peak pixel index: {metrics['output']['idx']}")
print(f"Peak world offset (RA, Dec): ({ra_diff_arcsec:.6f}, {dec_diff_arcsec:.6f}) arcsec")
print(f"Peak world offset within <= 1 pixel scale: {within_pixel}")
print()
all_ok = shape_match and peak_ok and fd_ok and within_pixel
summary = "PASS" if all_ok else "CHECK"
print(f"Result summary: {summary}")


## Example 3: Reproject to a New Sky Frame (FK5 -> Galactic)
**Goal:** convert to Galactic frame on a Galactic-aligned pixel grid (`keep_grid=False`) and verify sky-position consistency.

**Expected output:**
- Output world-coordinate axes are Galactic (`galactic_longitude`, `galactic_latitude`).
- Original world-coordinate axes can be retained via `keep_input_world_coords=True`.
- The image content can rotate/re-sample in pixel space because the basis changes with frame conversion.

**Numerical checks below:**
- Peak world-position consistency (input FK5 peak transformed with Astropy vs output Galactic peak).
- Galactic grid alignment to image edges for `keep_grid=False`.


In [None]:
import importlib
import wcs_reproject

# Ensure notebook picks up latest local edits to wcs_reproject.py
importlib.reload(wcs_reproject)
reproject_to_frame = wcs_reproject.reproject_to_frame

src = _make_gaussian_jy_per_beam(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90, 40))

out = reproject_to_frame(
    src,
    'galactic',
    keep_grid=False,
    method='interp',
    order=1,
    keep_input_world_coords=True,
)

print('Output world coords:', [
    c for c in out.coords
    if c in ('galactic_longitude', 'galactic_latitude', 'input_right_ascension', 'input_declination')
] )
if 'BEAM_FIT_PARAMS_SKY' in src and 'BEAM_FIT_PARAMS_SKY' in out:
    src_pa = float(src['BEAM_FIT_PARAMS_SKY'].sel(beam_params_label='pa').isel(time=0, frequency=0, polarization=0).values)
    out_pa = float(out['BEAM_FIT_PARAMS_SKY'].sel(beam_params_label='pa').isel(time=0, frequency=0, polarization=0).values)
    print(f"Beam PA (rad): input={src_pa:.6e}, output={out_pa:.6e}")

_plot_plane(src['SKY'], 'Source (FK5, Jy/beam)')
_plot_plane(out['SKY'], 'Reprojected to Galactic (Grid Aligned to Galactic Axes)')


In [None]:
from astropy.coordinates import SkyCoord
from astropy import units as u

print('Example 3 Numerical Verification (off-center point source)')
print('-' * 72)

src_pt = _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0, loc=(90, 40))
out_keep_false = reproject_to_frame(
    src_pt,
    'galactic',
    keep_grid=False,
    method='interp',
    order=1,
    keep_input_world_coords=True,
)
out_keep_true = reproject_to_frame(
    src_pt,
    'galactic',
    keep_grid=True,
    method='interp',
    order=1,
    keep_input_world_coords=True,
)

def _peak_idx(arr):
    valid = np.isfinite(arr)
    if np.count_nonzero(valid) == 0:
        raise RuntimeError('No finite pixels available for peak detection.')
    return np.unravel_index(np.nanargmax(np.where(valid, arr, np.nan)), arr.shape)

src_idx = _peak_idx(src_pt['SKY'].values)
src_i_l, src_i_m = int(src_idx[-2]), int(src_idx[-1])
src_ra = float(src_pt['right_ascension'].values[src_i_l, src_i_m])
src_dec = float(src_pt['declination'].values[src_i_l, src_i_m])
src_gal = SkyCoord(src_ra * u.rad, src_dec * u.rad, frame='fk5').transform_to('galactic')
src_glon = float(src_gal.spherical.lon.to_value(u.rad))
src_glat = float(src_gal.spherical.lat.to_value(u.rad))

def _gal_peak_offsets_arcsec(xds):
    idx = _peak_idx(xds['SKY'].values)
    i_l, i_m = int(idx[-2]), int(idx[-1])
    out_glon = float(xds['galactic_longitude'].values[i_l, i_m])
    out_glat = float(xds['galactic_latitude'].values[i_l, i_m])
    dlon = np.rad2deg(abs(np.arctan2(np.sin(out_glon - src_glon), np.cos(out_glon - src_glon))) * np.cos(src_glat)) * 3600.0
    dlat = np.rad2deg(abs(out_glat - src_glat)) * 3600.0
    return [int(i) for i in idx], dlon, dlat

idx_false, dlon_false, dlat_false = _gal_peak_offsets_arcsec(out_keep_false)
idx_true, dlon_true, dlat_true = _gal_peak_offsets_arcsec(out_keep_true)

pix_scale_arcsec = max(
    abs(np.rad2deg(src_pt['l'].values[1] - src_pt['l'].values[0]) * 3600.0),
    abs(np.rad2deg(src_pt['m'].values[1] - src_pt['m'].values[0]) * 3600.0),
)
peak_ok_false = dlon_false <= pix_scale_arcsec and dlat_false <= pix_scale_arcsec
peak_ok_true = dlon_true <= pix_scale_arcsec and dlat_true <= pix_scale_arcsec

# Grid-alignment test for keep_grid=False: Galactic lon should vary mostly along l,
# and Galactic lat should vary mostly along m.
glon = out_keep_false['galactic_longitude'].values
glat = out_keep_false['galactic_latitude'].values
dlon_l = np.nanmedian(np.abs(np.diff(glon, axis=0)))
dlat_l = np.nanmedian(np.abs(np.diff(glat, axis=0)))
dlon_m = np.nanmedian(np.abs(np.diff(glon, axis=1)))
dlat_m = np.nanmedian(np.abs(np.diff(glat, axis=1)))
cross_ratio_l = float(dlat_l / max(dlon_l, 1e-30))
cross_ratio_m = float(dlon_m / max(dlat_m, 1e-30))
grid_parallel_ok = cross_ratio_l < 0.01 and cross_ratio_m < 0.01

print(f"Input peak pixel coord:                  {[int(i) for i in src_idx]}")
print(f"Output peak pixel coord (keep_grid=False): {idx_false}")
print(f"Output peak pixel coord (keep_grid=True):  {idx_true}")
print()
print(f"Pixel scale for tolerance: {pix_scale_arcsec:.6f} arcsec")
print(f"Peak galactic offset keep_grid=False (lon, lat): ({dlon_false:.6f}, {dlat_false:.6f}) arcsec")
print(f"Peak galactic offset keep_grid=True  (lon, lat): ({dlon_true:.6f}, {dlat_true:.6f}) arcsec")
print(f"Peak world-position consistent (keep_grid=False): {peak_ok_false}")
print(f"Peak world-position consistent (keep_grid=True):  {peak_ok_true}")
print()
print(f"Galactic-grid cross-axis ratio along l: {cross_ratio_l:.6e}")
print(f"Galactic-grid cross-axis ratio along m: {cross_ratio_m:.6e}")
print(f"Galactic grid parallel to image edges (keep_grid=False): {grid_parallel_ok}")

all_ok = peak_ok_false and peak_ok_true and grid_parallel_ok
print()
print(f"Result summary: {'PASS' if all_ok else 'CHECK'}")


def _wrap_half_turn(angle):
    while angle > np.pi / 2:
        angle -= np.pi
    while angle < -np.pi / 2:
        angle += np.pi
    return angle

def _major_axis_pa_plot(ds, frac_threshold=0.2):
    """Major-axis PA in plot convention (north->east, east is left on plot)."""
    arr = ds['SKY'].isel(time=0, frequency=0, polarization=0).values
    l = ds['l'].values
    m = ds['m'].values
    ll, mm = np.meshgrid(l, m, indexing='ij')

    valid = np.isfinite(arr)
    if np.count_nonzero(valid) == 0:
        raise RuntimeError('No finite pixels available for PA estimate.')

    threshold = float(np.nanmax(arr[valid])) * frac_threshold
    mask = valid & (arr >= threshold)
    if np.count_nonzero(mask) < 5:
        raise RuntimeError('Too few pixels for PA estimate.')

    w = arr[mask]
    w = w / np.sum(w)

    # Plot convention PA is measured from north (+m) toward east (+l).
    east = -ll
    north = mm

    east_sel = east[mask]
    north_sel = north[mask]
    mean_east = float(np.sum(w * east_sel))
    mean_north = float(np.sum(w * north_sel))
    de = east_sel - mean_east
    dn = north_sel - mean_north

    c_ee = float(np.sum(w * de * de))
    c_nn = float(np.sum(w * dn * dn))
    c_en = float(np.sum(w * de * dn))
    cov = np.array([[c_ee, c_en], [c_en, c_nn]])

    evals, evecs = np.linalg.eigh(cov)
    major = evecs[:, int(np.argmax(evals))]
    pa = float(np.arctan2(float(major[0]), float(major[1])))
    return _wrap_half_turn(pa)


beam_pa_ok = True
if 'BEAM_FIT_PARAMS_SKY' in src and 'BEAM_FIT_PARAMS_SKY' in out:
    src_pa = float(
        src['BEAM_FIT_PARAMS_SKY']
        .sel(beam_params_label='pa')
        .isel(time=0, frequency=0, polarization=0)
        .values
    )
    out_pa = float(
        out['BEAM_FIT_PARAMS_SKY']
        .sel(beam_params_label='pa')
        .isel(time=0, frequency=0, polarization=0)
        .values
    )
    input_source_pa = np.rad2deg(_major_axis_pa_plot(src))
    output_source_pa = np.rad2deg(_major_axis_pa_plot(out))
    input_beam_pa = np.rad2deg(_wrap_half_turn(src_pa))
    output_beam_pa = np.rad2deg(_wrap_half_turn(out_pa))

    input_pa_sep = abs(
        np.rad2deg(
            _wrap_half_turn(np.deg2rad(input_beam_pa - input_source_pa))
        )
    )
    output_pa_sep = abs(
        np.rad2deg(
            _wrap_half_turn(np.deg2rad(output_beam_pa - output_source_pa))
        )
    )
    beam_pa_ok = input_pa_sep < 1.0 and output_pa_sep < 1.0

    print()
    print('Beam/source PA consistency (Gaussian Example 3 image):')
    print(f'Input source PA [deg]:  {input_source_pa:.6f}')
    print(f'Input beam PA [deg]:    {input_beam_pa:.6f}')
    print(f'Output source PA [deg]: {output_source_pa:.6f}')
    print(f'Output beam PA [deg]:   {output_beam_pa:.6f}')
    print(f'Input source-beam PA separation [deg]:  {input_pa_sep:.6f}')
    print(f'Output source-beam PA separation [deg]: {output_pa_sep:.6f}')
    print(f'Beam/source PA parallel within 1 deg: {beam_pa_ok}')

all_ok = all_ok and beam_pa_ok
print() 
print(f"Result summary (including beam/source PA check): {'PASS' if all_ok else 'CHECK'}")


In [None]:
# Quick handle for inspection if needed
out


## Example 4: Reproject to a New Frame With Re-Centered Grid (Jy/pixel)
**Goal:** change to a new frame and rebuild a same-sized grid centered in the target frame.

**Expected output:**
- Same shape and pixel size as the source.
- Spatial coordinates re-centered for the target frame.

**Quantity conserved:** **integrated flux** for `Jy/pixel` when using a flux-conserving method.


In [None]:
src = _make_point_source_jy_per_pixel(n_l=128, n_m=128, cell_arcsec=2.0)

out = reproject_to_frame(src, 'galactic', keep_grid=False, method='exact')

_plot_plane(src['SKY'], 'Source (Jy/pixel, FK5)')
_plot_plane(out['SKY'], 'Reprojected to Galactic (Re-centered Grid)')
