In [None]:
import numpy as np
from math import radians, pi, cos
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator, LinearLocator, LogLocator
from astropy.io import fits
from astropy.wcs import WCS
from astropy.visualization import simple_norm
from matplotlib.backends.backend_pdf import PdfPages
import os
import sys

# Import CIAO tools
try:
    from ciao_contrib.runtool import arestore, aconvolve
except ImportError:
    print("FATAL ERROR: Could not import 'arestore' or 'aconvolve' from 'ciao_contrib.runtool'.")
    print("Please make sure you are running this script from within a")
    print("conda environment where CIAO Python is installed.")
    sys.exit(1) # Exit the script if CIAO is not available

# Physical and conversion constants.
# Adopted distance to SS433, in parsecs.
D_SS433_PC = 5.5e3
# Speed of light in parsecs per day: (km/s * seconds per day) / (km per parsec).
C_PC_PER_DAY = (299792.458 * (24 * 60 * 60)) / (3.08567758 * 10 ** 13)
# Arcseconds in one radian.
ARCSEC_PER_RADIAN = (180.0 / pi) * (60.0 * 60.0)

# Plot configuration: render every figure at high resolution.
plt.rcParams.update({'figure.dpi': 400})


def load_full_russian_params():
    """Return the SS433 jet ephemeris (after Cherepashchuk 2025) as a dict.

    Reference epochs are Julian Dates (several are built as
    ``2400000.5 + MJD``), periods are in days, and every angle is converted to
    radians with ``radians``.

    Keys (in insertion order):
        jd0_precession, precession_period: precession epoch (JD) and period.
        beta: jet speed as a fraction of c.
        theta, inclination, prec_pa, phi0: model angles in radians, as used
            by ``ss433_phases``.
        jd0_nut, nut_period, nut_ampl: nutation epoch, period and amplitude
            (radians) applied to ``theta``.
        jd0_orb, orbital_period, beta_orb_ampl, beta_orb_phase0: orbital
            epoch, period, and the amplitude/phase of the orbital modulation
            applied to ``beta``.
    """
    mjd_zero_jd = 2400000.5  # JD corresponding to MJD 0
    return dict(
        # Precession of the jet axis.
        jd0_precession=mjd_zero_jd + 59898.78,
        precession_period=160.14,
        # Jet speed and geometry.
        beta=0.2591,
        theta=radians(19.64),
        inclination=radians(78.92),
        prec_pa=radians(10.0),
        phi0=radians(-49.0),
        # Nutation.
        jd0_nut=mjd_zero_jd + 59797.68,
        nut_period=6.28802,
        nut_ampl=radians(0.0063 * (180 / pi)),
        # Orbital modulation of the jet speed.
        jd0_orb=2460503.14,
        orbital_period=13.082989,
        beta_orb_ampl=0.004,
        beta_orb_phase0=pi,
    )


def ss433_phases(jd_obs, params):
    """Evaluate the SS433 kinematic jet model at one or more Julian Dates.

    The jet direction precesses with angle ``phi``; the jet angle ``theta`` is
    modulated by nutation (``nut_*`` keys) and the speed ``beta`` by the orbit
    (``beta_orb_*`` keys). Each periodic term is driven by the fractional
    phase ``((jd_obs - epoch) / period) % 1``.

    Parameters
    ----------
    jd_obs : float or numpy.ndarray
        Julian Date(s) at which to evaluate the model.
    params : dict
        Ephemeris with the keys produced by ``load_full_russian_params``
        (epochs as JD, periods in days, angles in radians).

    Returns
    -------
    tuple
        ``(mu_blue_ra, mu_blue_dec, mu_red_ra, mu_red_dec)``: RA and Dec
        components of the apparent sky-plane velocity, as fractions of c, of
        the blue jet (``beta * v / (1 - beta*mu)``) and the red jet
        (``-beta * v / (1 + beta*mu)``). Each broadcasts like ``jd_obs``.
    """
    two_pi = 2 * np.pi

    def cycle_phase(epoch_key, period_key):
        # Fractional phase in [0, 1) of the cycle defined by the two keys.
        return ((jd_obs - params[epoch_key]) / params[period_key]) % 1.0

    # Precession angle, jet angle with nutation, and speed with orbital modulation.
    phi = params['phi0'] - two_pi * cycle_phase('jd0_precession', 'precession_period')
    theta = params['theta'] + params['nut_ampl'] * np.cos(
        two_pi * cycle_phase('jd0_nut', 'nut_period'))
    beta = params['beta'] + params['beta_orb_ampl'] * np.sin(
        two_pi * cycle_phase('jd0_orb', 'orbital_period') + params['beta_orb_phase0'])

    inc, chi = params['inclination'], params['prec_pa']
    s_th, c_th = np.sin(theta), np.cos(theta)
    s_in, c_in = np.sin(inc), np.cos(inc)
    s_ph, c_ph = np.sin(phi), np.cos(phi)
    s_ch, c_ch = np.sin(chi), np.cos(chi)

    # Line-of-sight projection of the unit jet direction.
    mu = s_th * s_in * c_ph + c_th * c_in
    # Sky-plane projections of the unit jet direction onto RA and Dec.
    v_ra = s_ch * s_th * s_ph + c_ch * s_in * c_th - c_ch * c_in * s_th * c_ph
    v_dec = c_ch * s_th * s_ph - s_ch * s_in * c_th + s_ch * c_in * s_th * c_ph

    approaching = 1 - beta * mu
    receding = 1 + beta * mu
    return (beta * v_ra / approaching,
            beta * v_dec / approaching,
            -beta * v_ra / receding,
            -beta * v_dec / receding)


# Offset between Julian Date and Modified Julian Date: JD = MJD + 2400000.5.
_MJD_TO_JD = 2400000.5


def _run_ciao_pipeline(obs_id_str, src_image, psf_image, deconvolved_image,
                       smoothed_image, numiter, kernelspec):
    """Run arestore then aconvolve for one observation; return True on success."""
    try:
        print(f"  Running arestore for Obs ID {obs_id_str}...")
        arestore(
            infile=src_image,
            psffile=psf_image,
            outfile=deconvolved_image,
            numiter=numiter,
            clobber=True
        )
        print("  ...arestore finished successfully.")

        # Smooth the deconvolved image written by arestore.
        print(f"  Running aconvolve for Obs ID {obs_id_str}...")
        aconvolve(
            infile=deconvolved_image,
            outfile=smoothed_image,
            kernelspec=kernelspec,
            clobber=True
        )
        print("  ...aconvolve finished successfully.")

    except FileNotFoundError as e:
        # Presumably raised by the tools when an input file is missing.
        print(f"--> ERROR: File not found during processing: {e.filename}")
        print("    Please check if source or PSF images exist.")
        return False

    except Exception as e:
        # Boundary for any other CIAO tool failure: report it and skip this
        # observation rather than aborting the whole batch.
        print(f"--> ERROR: CIAO tool failed for Obs ID {obs_id_str}.")
        print(f"    Error: {e}")
        return False

    return True


def _bright_center_pixel(image_data, n_brightest=4):
    """Return the mean 0-based (x, y) pixel position of the n brightest finite pixels."""
    # argsort places NaN after every finite value, so a NaN pixel would be
    # chosen as "brightest"; rank non-finite pixels below everything instead.
    ranked = np.where(np.isfinite(image_data), image_data, -np.inf)
    flat_idx = np.argsort(ranked, axis=None)[-n_brightest:]
    rows, cols = np.unravel_index(flat_idx, image_data.shape)
    return np.mean(cols), np.mean(rows)


def plot_jet_overlay(base_path, obs_id_str, mjd_obs, pdf_object, *,
                     numiter=100, kernelspec="lib:gaus(2, 2, 1, 1.5, 1.5)"):
    """
    Deconvolve and smooth one observation, then overlay the model jet tracks.

    Runs CIAO ``arestore`` (with the empirical PSF image) and ``aconvolve``
    on the 160-pixel images in ``<base_path>/<obs_id_str>/``. It then
    plots the predicted blue- and red-jet tracks, coloured by age, on the
    smoothed image and saves the figure as a page of ``pdf_object``. The
    deconvolved and smoothed FITS files are written (clobbered) in the same
    directory. If a CIAO tool fails or the smoothed image is missing, a
    message is printed and the observation is skipped.

    Parameters
    ----------
    base_path : str
        Directory containing one sub-directory per observation.
    obs_id_str : str
        Observation ID; also the name of its sub-directory.
    mjd_obs : float
        Modified Julian Date of the observation. It is converted to a Julian
        Date before evaluating the ephemeris, whose epochs are JDs.
    pdf_object : matplotlib.backends.backend_pdf.PdfPages
        Open PDF to which the figure page is saved.
    numiter : int, optional
        Number of arestore iterations (default 100).
    kernelspec : str, optional
        aconvolve kernel specification. Per the syntax
        ``lib:gaus(dim, n_sigma_cutoff, amplitude, sigma_x, sigma_y)``, the
        default is a 2-D Gaussian with sigma = 1.5 pixels cut at 2 sigma
        (a 3-pixel radius).

    Returns
    -------
    None
    """
    # Define file paths for processing.
    obs_dir = os.path.join(base_path, obs_id_str)
    src_image = os.path.join(obs_dir, "src_image_square_160pixel.fits")
    psf_image = os.path.join(obs_dir, "psf_image_square_empirical_160pixel.fits")
    deconvolved_image = os.path.join(obs_dir, "deconvolved_image_square_empirical_160pixel.fits")
    smoothed_deconvolved_image = os.path.join(obs_dir, "smoothed_deconvolved_image_empirical_160pixel.fits")

    if not _run_ciao_pipeline(obs_id_str, src_image, psf_image, deconvolved_image,
                              smoothed_deconvolved_image, numiter, kernelspec):
        return  # Error already reported; skip this observation.

    try:
        # Open the final, smoothed image.
        with fits.open(smoothed_deconvolved_image) as hdul:
            hdu = hdul[0]
            wcs = WCS(hdu.header)
            # astype copies the data, so it remains valid after the file closes.
            image_data = hdu.data.astype(np.float32)
    except FileNotFoundError:
        # aconvolve reported success but did not create its output.
        print(f"--> SKIPPING: Smoothed FITS file not found at '{smoothed_deconvolved_image}'")
        return

    # Centre the model on the mean position of the 4 brightest pixels.
    avg_x_pix, avg_y_pix = _bright_center_pixel(image_data, n_brightest=4)
    center_ra, center_dec = wcs.pixel_to_world_values(avg_x_pix, avg_y_pix)

    params = load_full_russian_params()
    # Ages of jet material (days), from just ejected to two precession periods.
    travel_time = np.linspace(0, 2 * params['precession_period'], 1000)
    # The ephemeris epochs are Julian Dates while mjd_obs is an MJD: convert
    # first, otherwise every model phase is evaluated 2400000.5 days off.
    ejection_jd = (mjd_obs + _MJD_TO_JD) - travel_time

    mu_b_ra, mu_b_dec, mu_r_ra, mu_r_dec = ss433_phases(ejection_jd, params)

    def get_offset(mu):
        # Apparent velocity (fraction of c) -> angular offset (arcsec) after travel_time days.
        return (mu * C_PC_PER_DAY / D_SS433_PC * ARCSEC_PER_RADIAN) * travel_time

    off_b_ra, off_b_dec = get_offset(mu_b_ra), get_offset(mu_b_dec)
    off_r_ra, off_r_dec = get_offset(mu_r_ra), get_offset(mu_r_dec)

    # Convert arcsec offsets to degrees; RA offsets are divided by cos(Dec)
    # to turn a sky-plane offset into a change in RA.
    cos_d = cos(radians(center_dec))
    path_b_ra, path_b_dec = center_ra + off_b_ra / 3600.0 / cos_d, center_dec + off_b_dec / 3600.0
    path_r_ra, path_r_dec = center_ra + off_r_ra / 3600.0 / cos_d, center_dec + off_r_dec / 3600.0

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(1, 1, 1, projection=wcs)
        world = ax.get_transform('world')

        norm_image = simple_norm(image_data, stretch='log', percent=99.5)
        im = ax.imshow(image_data, origin='lower', cmap='gnuplot2', norm=norm_image)

        # Colour each model point by the age of the jet material.
        jet_cmap = 'rainbow'
        norm_jet = mcolors.Normalize(vmin=0, vmax=2 * params['precession_period'])
        mappable_jet = plt.cm.ScalarMappable(cmap=jet_cmap, norm=norm_jet)
        kwargs = {'c': travel_time, 'cmap': jet_cmap, 'norm': norm_jet,
                  's': 15, 'transform': world}

        ax.scatter(path_b_ra, path_b_dec, **kwargs)
        ax.scatter(path_r_ra, path_r_dec, **kwargs)

        # Black star at the adopted centre.
        ax.scatter(center_ra, center_dec, marker='*', s=150, c='black',
                   transform=world, zorder=10)

        ax.set_title(f"SS433 Model - Obs ID: {obs_id_str}\nMJD: {mjd_obs:.2f}")
        ax.set_xlabel("Right Ascension")
        ax.set_ylabel("Declination")

        ax.coords[0].set_major_formatter('hh:mm:ss.s')
        ax.coords[1].set_major_formatter('dd:mm:ss')

        cax_jet = ax.inset_axes([1.04, 0.0, 0.0525, 1])
        cbar_jet = fig.colorbar(mappable_jet, cax=cax_jet)
        cbar_jet.set_label('Age of Jet Material (days)')

        cbar_img = fig.colorbar(im, ax=ax, orientation='horizontal', shrink=0.8, pad=0.08)
        cbar_img.set_label('Counts')
        cbar_img.locator = LogLocator(numticks=5)
        cbar_img.update_ticks()

        ax.grid(color='w', linestyle=':', alpha=0.5)

        pdf_object.savefig(fig, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    output_pdf_path = "ss433_jet_overlays.pdf"
    base_path = "/Users/leodrake/Documents/MIT/ss433/HRC_2024/"
    # Observation ID -> MJD of the observation (passed as mjd_obs).
    obs_data = {
        26568: 60454.37, 26569: 60461.87, 26570: 60467.48, 26571: 60476.27,
        26572: 60482.89, 26573: 60487.02, 26574: 60492.28, 26575: 60502.04,
        26576: 60507.10, 26577: 60517.51, 26578: 60522.28, 26579: 60530.20
    }

    print(f"Starting plot generation. Output will be saved to '{output_pdf_path}'...")
    with PdfPages(output_pdf_path) as pdf:
        # One PDF page per observation, in dictionary order; plot_jet_overlay
        # performs the deconvolution, smoothing and plotting for each.
        for obs_id in obs_data:
            mjd = obs_data[obs_id]
            print(f"Processing Obs ID: {obs_id}...")
            plot_jet_overlay(base_path, str(obs_id), mjd, pdf)

    print("...Processing complete. PDF has been saved.")