# Plot astro images correctly using matplotlib

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.wcs import WCS

def generate_gaussian_xy(shape, x_c, y_c, sigma_a, sigma_b, theta_math, amplitude=100):
    """
    Build a rotated elliptical 2D Gaussian sampled on a grid indexed as data[x, y].

    shape: Grid dimensions; the first axis is x and the second is y.
    x_c, y_c: Centre of the Gaussian in pixel coordinates.
    sigma_a: Standard deviation along the major axis.
    sigma_b: Standard deviation along the minor axis.
    theta_math: Major-axis angle in radians, measured from pixel +x (Right)
        toward pixel +y (Up).
    amplitude: Factor multiplying the exponential, i.e. the value at the centre.

    Raises ValueError when sigma_b is larger than sigma_a.
    """
    if sigma_a < sigma_b:
        raise ValueError("sigma_a must be the major axis (>= sigma_b)")

    # np.indices yields the first-axis index first, which is x in the
    # data[x, y] convention.
    grid_x, grid_y = np.indices(shape)
    dx = grid_x - x_c
    dy = grid_y - y_c

    cos_t, sin_t = np.cos(theta_math), np.sin(theta_math)

    # Project the centre offsets onto the major (along) and minor (across) axes.
    along = dx * cos_t + dy * sin_t
    across = -dx * sin_t + dy * cos_t

    major_term = along**2 / (2 * sigma_a**2)
    minor_term = across**2 / (2 * sigma_b**2)
    return amplitude * np.exp(-(major_term + minor_term))

def calculate_theta_math(data):
    """
    Estimate the major-axis orientation of data[x, y] from its
    intensity-weighted second moments.

    Returns the pixel-space angle theta_math in radians, within
    [-pi/2, pi/2]. NaN is returned when the sum of data is zero or negative.
    """
    xs, ys = np.indices(data.shape)

    total = np.sum(data)
    if total <= 0:
        return np.nan

    # Offsets of every pixel from the intensity-weighted centroid
    dx = xs - np.sum(xs * data) / total
    dy = ys - np.sum(ys * data) / total

    # Normalised second-order central moments
    var_x = np.sum(dx**2 * data) / total
    var_y = np.sum(dy**2 * data) / total
    cov_xy = np.sum(dx * dy * data) / total

    # Half of the atan2 angle is the orientation of the principal axis
    return 0.5 * np.arctan2(2 * cov_xy, var_x - var_y)

def calculate_pa_from_wcs(data, wcs, beam=True):
    """
    Converts pixel theta_math to Astronomical Position Angle (PA) in degrees.

    Assumptions:
    1. data is indexed as [x, y].
    2. PA is 0 at North, 90 at East.
    3. theta_math is from pixel +x axis toward pixel +y axis.
    4. wcs.pixel_scale_matrix is 2x2 with rows = world axes (RA, Dec) and
       columns = pixel axes (x, y), i.e. m[i, j] = d(world_i) / d(pixel_j).
    5. Pixels are square on the sky, so a pixel-space angle differs from a
       sky angle only by a rotation and an optional mirror flip.

    if beam, returned angle satisfies -90 <= pa < 90
    otherwise 0 <= pa < 180

    Returns NaN when theta_math is undefined or when the pixel scale matrix
    is singular (no well-defined North direction).
    """
    theta_math = calculate_theta_math(data)
    if np.isnan(theta_math):
        return np.nan

    # Get the transformation matrix (pixel offsets -> world offsets)
    m = wcs.pixel_scale_matrix

    # Parity: determinant is negative for East-Left (Standard FITS/CASA),
    # positive for a mirrored (East-Right) image.
    det = np.linalg.det(m)
    if det == 0:
        # Degenerate matrix: North cannot be located in the pixel grid.
        return np.nan
    parity = np.sign(det)

    # Direction of Celestial North (+Dec at constant RA) in the pixel grid.
    # It is the pixel step (dx, dy) solving m @ (dx, dy) = (0, 1), i.e. the
    # second column of inv(m) = (-m[0, 1], m[0, 0]) / det. Only the direction
    # matters, so the positive factor 1/|det| is dropped and sign(det) kept.
    # (Reading the second column of m itself is only right when det < 0 or
    # the WCS is unrotated; it is mirrored for a rotated det > 0 WCS.)
    north_dx = -parity * m[0, 1]
    north_dy = parity * m[0, 0]
    north_pixel_angle = np.arctan2(north_dy, north_dx)

    # (north_pixel_angle - theta_math) is the angle from the major axis to
    # North measured counter-clockwise in pixel space. Multiplying by parity
    # makes rotation from North toward East count as positive.
    pa_rad = (north_pixel_angle - theta_math) * parity

    # An ellipse is symmetric under 180 degree rotation, so wrap into a
    # half-turn: [-90, 90) for beams, [0, 180) otherwise.
    pa_deg = np.degrees(pa_rad)
    if beam:
        return ((pa_deg + 90) % 180) - 90
    return pa_deg % 180
    
def create_demo_wcs(shape):
    """
    Build a demonstration East-Left, North-Up TAN WCS with 1 arcsec pixels.

    The reference pixel is placed at half of shape[0] and shape[1], and the
    reference sky position is (RA, Dec) = (180, 0) degrees.
    """
    arcsec_in_deg = 1.0 / 3600.0  # 1 arcsec/pixel

    demo = WCS(naxis=2)
    demo.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    demo.wcs.crval = [180.0, 0.0]
    demo.wcs.crpix = [shape[0] / 2, shape[1] / 2]
    # Negative RA step puts East at -x; positive Dec step puts North at +y.
    demo.wcs.cdelt = [-arcsec_in_deg, arcsec_in_deg]
    demo.wcs.pc = [[1, 0], [0, 1]]
    return demo

def generate_astro_plot(data, wcs, pa=None, recovered_theta=None,
                        arrow_origin=(100, 100), arrow_length=40):
    """
    Plot data[x, y] on WCS axes and overlay an arrow along its major axis.

    data: 2D array indexed as [x, y].
    wcs: WCS used as the projection of the axes.
    pa: Sky position angle in degrees shown in the title. When omitted it is
        computed with calculate_pa_from_wcs(data, wcs, beam=False).
    recovered_theta: Pixel-space major-axis angle in degrees, used for the
        arrow and the title. When omitted it is computed with
        calculate_theta_math(data).
    arrow_origin: (x, y) pixel position where the arrow starts.
    arrow_length: Arrow length in pixels.

    Returns the matplotlib.pyplot module (not the figure), so existing
    callers that call .show() on the result keep working.
    """
    # Derive the angles from the inputs rather than relying on module-level
    # globals that the caller happens to have defined.
    if recovered_theta is None:
        recovered_theta = np.degrees(calculate_theta_math(data))
    if pa is None:
        pa = calculate_pa_from_wcs(data, wcs, beam=False)

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(1, 1, 1, projection=wcs)

    # IMPORTANT: data.T is required because imshow expects [row, col] (y, x)
    # but our data is [x, y].
    ax.imshow(data.T, origin='lower', cmap='magma')

    # Draw Arrow showing the major axis in pixel space. Skipped when the
    # angle is undefined (NaN), since there is no direction to draw.
    if np.isfinite(recovered_theta):
        x0, y0 = arrow_origin
        theta_rad = np.radians(recovered_theta)
        ax.arrow(
            x0, y0,
            arrow_length * np.cos(theta_rad),
            arrow_length * np.sin(theta_rad),
            color='cyan', width=1, head_width=5, label='Major Axis'
        )
        ax.legend()

    ax.coords[0].set_axislabel('Right Ascension')
    ax.coords[1].set_axislabel('Declination')
    ax.set_title(f"Sky PA: {pa:.2f}° | Pixel Theta: {recovered_theta:.2f}°")
    return plt

# --- DEMO EXECUTION ---

# 1. Setup
shape = (200, 200) # (width, height)
wcs = create_demo_wcs(shape)

for target_theta_math in range(0, 360, 30):
    # 2. Define Target: sweep the pixel-space major-axis angle in 30 degree steps.
    # With the demo WCS, North is +y (90 deg) and East is -x (180 deg),
    # so e.g. 135 degrees points North-East.
    target_theta_math_rad = np.radians(target_theta_math)
    data = generate_gaussian_xy(shape, 100, 100, 20, 8, target_theta_math_rad)

    # 3. Analyze
    pa = calculate_pa_from_wcs(data, wcs, False)
    recovered_theta = np.degrees(calculate_theta_math(data))

    # The PA above is wrapped into [0, 180), so wrap the expectation the
    # same way; otherwise matching angles look different in the printout.
    expected_pa = (target_theta_math - 90) % 180

    print(f"Input theta_math:     {target_theta_math:.2f}°")
    print(f"Recovered theta_math: {recovered_theta:.2f}°")
    print(f"Calculated Sky PA:    {pa:.2f}° (Should be {expected_pa:.2f}°)")
    print()

    # 4. Plotting (one new figure per orientation)
    generate_astro_plot(data, wcs)

# pyplot tracks every open figure, so a single show() displays them all.
plt.show()