In [1]:
import os
import glob
import logging

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
from matplotlib.backends.backend_pdf import PdfPages

from astropy.io import fits
from scipy.stats import chi2
from scipy.optimize import curve_fit

from sherpa.astro.ui import *
from sherpa.astro.utils import *
from ciao_contrib.runtool import *

import datetime
from scipy.ndimage import rotate
from coords.format import ra2deg, dec2deg

# Matplotlib defaults for this notebook: high-DPI figures, larger fonts,
# and images drawn with their origin at the lower-left corner.
plt.rcParams['figure.dpi'] = 400
plt.rcParams['font.size'] = 18
plt.rcParams['image.origin'] = 'lower'

# Only let WARNING and above through from the "sherpa" logger, which hides
# its INFO-level messages.
logger = logging.getLogger("sherpa")
logger.setLevel(logging.WARNING)

In [2]:
def quickpos(x, y, x0, y0, iterations=1, size_list=None, binsize_list=None, doplot=False):
    """
    Iteratively refine a centroid by Gaussian fits to the x and y histograms.

    On every iteration the points lying within ``size`` of the current
    centroid in both x and y are histogrammed along each axis with bin width
    ``binsize``. A Gaussian plus constant offset is fitted to each histogram
    and the fitted means become the centroid used by the next iteration.

    Parameters
    ----------
    x, y : array-like
        Point coordinates. Any sequence accepted by ``np.asarray`` works.
    x0, y0 : float
        Starting centroid estimate.
    iterations : int
        Number of refinement passes.
    size_list : sequence of float, optional
        Window half-width for each iteration. Defaults to the full extent of
        ``x`` for every iteration (0 when ``x`` is empty).
    binsize_list : sequence of float, optional
        Histogram bin width for each iteration. Defaults to 1.0.
    doplot : bool
        When True, draw a three-panel diagnostic figure per iteration.

    Returns
    -------
    best_x0, best_y0 : float
        Last successfully fitted Gaussian means (or the inputs if no fit
        ever succeeded).
    cnt : float or None
        Mean of the x and y Gaussian areas, ``a * |sigma| * sqrt(2*pi)``, from
        the last iteration that had points in its window; a failed fit
        contributes 0. ``None`` if no iteration had any points.
    fig_list : list
        Figures created when ``doplot`` is True, otherwise empty.
    """

    def step_plot(centers, heights, binwidth):
        # Duplicate every bin so the histogram can be drawn as a step outline.
        edges = np.ravel(np.column_stack((centers - binwidth/2, centers + binwidth/2)))
        return edges, np.repeat(heights, 2)

    def gaussian(t, a, mu, sigma, offset):
        return a * np.exp(-((t - mu)**2) / (2 * sigma**2)) + offset

    def fit_marginal(values, centre, size, binsize, axis_label, iteration):
        """Histogram `values` around `centre` and fit a Gaussian to the result.

        Returns (bin centers, histogram, fitted curve, fitted mean or None,
        Gaussian area). On a failed fit the mean is None and the area is 0.
        """
        bins = np.arange(centre - size, centre + size + binsize, binsize)
        hist, edges = np.histogram(values, bins=bins)
        centers = 0.5 * (edges[:-1] + edges[1:])
        try:
            p0 = [np.max(hist), centers[np.argmax(hist)], 2 * binsize, 0]
            par, _ = curve_fit(gaussian, centers, hist, p0=p0)
            # The model depends on sigma**2 only, so the optimiser may return
            # a negative sigma; use its magnitude so the area stays positive.
            area = par[0] * abs(par[2]) * np.sqrt(2 * np.pi)
            return centers, hist, gaussian(centers, *par), par[1], area
        except Exception as e:
            print(f"Warning: {axis_label}-fit failed for iteration {iteration}: {e}. "
                  f"Using previous {axis_label} value.")
            return centers, hist, np.zeros_like(centers), None, 0

    # Accept plain sequences as well as arrays.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Set default lists if none are provided
    if size_list is None:
        full_extent = (np.max(x) - np.min(x)) if x.size else 0.0
        size_list = [full_extent] * iterations
    if binsize_list is None:
        binsize_list = [1.0] * iterations

    fig_list = []
    current_x0, current_y0 = x0, y0
    best_x0, best_y0 = current_x0, current_y0
    cnt = None

    for i in range(iterations):
        size = size_list[i]
        binsize = binsize_list[i]

        if doplot:
            fig, ax = plt.subplots(3, 1, figsize=(6, 10))
            ax[0].scatter(x, y, s=0.5, c='k')
            ax[0].set_xlim(current_x0 - size, current_x0 + size)
            ax[0].set_ylim(current_y0 - size, current_y0 + size)
            ax[0].set_title(f'Centroid Plot (Iteration {i+1})')

        in_window = (np.abs(x - current_x0) < size) & (np.abs(y - current_y0) < size)

        if not np.any(in_window):
            print(f"Warning: No points found in window for iteration {i+1}. Using previous values.")
            if doplot:
                fig_list.append(fig)
                plt.show()
            continue

        xval, xhist, xf, x_mu, xcnt = fit_marginal(x[in_window], current_x0, size, binsize, 'X', i + 1)
        yval, yhist, yf, y_mu, ycnt = fit_marginal(y[in_window], current_y0, size, binsize, 'Y', i + 1)

        # Keep the previous value for an axis whose fit failed.
        if x_mu is not None:
            best_x0 = x_mu
        if y_mu is not None:
            best_y0 = y_mu

        if doplot:
            xval_s, xhist_s = step_plot(xval, xhist, binsize)
            ax[1].plot(xval_s, xhist_s, c='b', alpha=0.3, label='x histogram')
            ax[1].plot(xval, xf, '--', c='red', linewidth=0.75, label='Gaussian fit')
            ax[1].legend()

            yval_s, yhist_s = step_plot(yval, yhist, binsize)
            ax[2].plot(yval_s, yhist_s, c='b', alpha=0.3, label='y histogram')
            ax[2].plot(yval, yf, '--', c='red', linewidth=0.75, label='Gaussian fit')
            ax[2].legend()

        # Estimate counts and update centroid
        cnt = 0.5 * (xcnt + ycnt)
        current_x0, current_y0 = best_x0, best_y0

        if doplot:
            fig_list.append(fig)
            plt.show()

    return best_x0, best_y0, cnt, fig_list

####################################################################################################################################################################################

def data_extract_quickpos_iter(infile, iters=3, sizes=(10, 5, 1.5), binsizes=(0.1, 0.1, 0.05)):
    """
    Read an event file, estimate the source centroid with `quickpos`, and
    return it in pixel coordinates together with the observation date.

    Parameters
    ----------
    infile : str
        Path to a FITS event file; header and table are read from HDU 1.
    iters : int
        Number of `quickpos` iterations.
    sizes, binsizes : sequence of float
        Per-iteration window half-widths and histogram bin widths passed to
        `quickpos`. Each needs at least `iters` entries.

    Returns
    -------
    date : float
        ``MJD-OBS`` plus half of ``TSTOP - TSTART`` converted from seconds to
        days.
    exptime : float
        The ``EXPOSURE`` header value.
    pixel_x0_best, pixel_y0_best : float
        Refined centroid converted back to the pixel system of the ``x``/``y``
        columns.
    cnt : float or None
        Count estimate returned by `quickpos`.
    qp_figs : list
        Figures returned by `quickpos` (empty, since plotting is not requested).
    """
    with fits.open(infile) as obs:
        hdr = obs[1].header
        data = obs[1].data

        # Extract scale and reference coordinate.
        # NOTE(review): the column-20 keywords are applied to both x and y
        # below; this presumably relies on both axes sharing the same reference
        # pixel and |scale| -- confirm against the event-file header.
        scale = hdr['tcdlt20']
        xc = hdr['tcrpx20']
        exptime = hdr['exposure']

        # Form the modified Julian date at mid-observation.
        mjd_start = hdr['mjd-obs']
        half_expos = 0.5 * (hdr['tstop'] - hdr['tstart'])
        date = mjd_start + half_expos / 86400

        # Convert event positions to offsets from the reference pixel
        # (degrees-per-pixel scale times 3600). The arithmetic yields new
        # arrays, so they remain valid after the file is closed.
        x = (data['x'] - xc) * scale * 3600
        y = (data['y'] - xc) * scale * 3600

    # Starting estimate of the centroid: mean of the events within a radius
    # of 20 of the reference pixel.
    rr = np.sqrt(x**2 + y**2)
    ok = np.where(rr < 20)
    x0_est = np.average(x[ok])
    y0_est = np.average(y[ok])

    x0_best, y0_best, cnt, qp_figs = quickpos(x[ok], y[ok], x0_est, y0_est, iters, sizes, binsizes)

    # Undo the offset conversion to express the centroid in pixels again.
    pixel_x0_best = x0_best / (scale * 3600) + xc
    pixel_y0_best = y0_best / (scale * 3600) + xc

    return date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs

####################################################################################################################################################################################

def rotate_psf_array(psf_file, match_file, outfile):
    """
    Rotate a PSF image by (roll - 45) degrees and write it to `outfile`.

    The roll angle is read from `match_file`: ``ROLL_NOM`` in HDU 0, then
    ``ROLL_NOM`` in HDU 1, then ``ROLL_PNT`` in HDU 1. The PSF image is taken
    from HDU 0 of `psf_file`, or HDU 1 when HDU 0 carries no data. The rotated
    array is written as a primary HDU using the PSF's own header plus two
    HISTORY cards. Every failure is reported with a printed message and the
    function returns None without raising.
    """

    # --- 1. Look up the roll angle in the match file ---
    try:
        with fits.open(match_file) as hdu_match:
            primary_hdr = hdu_match[0].header
            if 'ROLL_NOM' in primary_hdr:
                roll_nom = primary_hdr['ROLL_NOM']
            else:
                # Not in the primary header: try HDU 1, ROLL_NOM first and
                # ROLL_PNT as the fallback.
                ext_hdr = hdu_match[1].header
                if ext_hdr and 'ROLL_NOM' in ext_hdr:
                    roll_nom = ext_hdr['ROLL_NOM']
                elif ext_hdr and 'ROLL_PNT' in ext_hdr:
                    roll_nom = ext_hdr['ROLL_PNT']
                    print("  Note: Using 'ROLL_PNT' as 'ROLL_NOM' was not found.")
                else:
                    print("  ERROR: Could not find 'ROLL_NOM' or 'ROLL_PNT' in HDU 0 or 1 of match file.")
                    return
    except FileNotFoundError:
        print(f"  ERROR: Match file not found: {match_file}")
        return
    except Exception as err:
        print(f"  ERROR: Could not read match file header: {err}")
        return

    # --- 2. Rotation angle (scipy.ndimage.rotate turns counter-clockwise) ---
    angle_to_rotate = roll_nom - 45.0

    # --- 3. Read the PSF image and its header ---
    try:
        with fits.open(psf_file) as hdu_psf:
            # Use HDU 1 when the primary HDU has no data.
            hdu_index = 1 if hdu_psf[0].data is None else 0
            source_hdu = hdu_psf[hdu_index]
            psf_data = source_hdu.data
            psf_header = source_hdu.header

            if psf_data is None:
                print(f"ERROR: No image data found in HDU 0 or 1 of {psf_file}.")
                return
    except FileNotFoundError:
        print(f"  ERROR: PSF file not found: {psf_file}")
        return
    except Exception as err:
        print(f"  ERROR: Could not read PSF file data/header: {err}")
        return

    # --- 4. Rotate: same output shape, zero fill, cubic spline interpolation ---
    rotated_psf_data = rotate(psf_data, angle_to_rotate, reshape=False, cval=0.0, order=3)

    # --- 5. Record what was done in the header and write the result ---
    timestamp = datetime.datetime.now().isoformat()
    for history_line in (
        f"Rotated by {angle_to_rotate:.4f} deg (ROLL_NOM={roll_nom:.4f} - 45.0)",
        f"Rotation applied by script on {timestamp}",
    ):
        psf_header.add_history(history_line)

    hdu_out = fits.PrimaryHDU(data=rotated_psf_data, header=psf_header)

    try:
        hdu_out.writeto(outfile, overwrite=True)
    except Exception as err:
        print(f"  ERROR: Could not write output file: {err}")

####################################################################################################################################################################################

def write_pixelscale(file: str, nx: int, ny: int, ra: str, dec: str, hrc_pscale_arcsec: float = 0.13175):
    """
    Add a tangent-plane (RA---TAN / DEC--TAN) WCS to a FITS image with `dmhedit`.

    Parameters
    ----------
    file : str
        Image file whose header is edited in place.
    nx, ny : int
        Image dimensions; the reference pixel is placed at the image centre.
    ra, dec : str
        Reference sky position, converted to degrees with `ra2deg`/`dec2deg`.
    hrc_pscale_arcsec : float
        Pixel scale in arcsec. The image pixel size written to CDELT is one
        quarter of this value.

    Notes
    -----
    Floating-point keywords are written with datatype ``"double"``. If
    dmhedit's ``"float"`` is single precision, its resolution near RA ~ 288 deg
    is about 3e-5 deg, which is coarser than the differences between the
    per-observation reference positions this function is called with.
    Failures from `dmhedit` are printed, not raised; keywords written before
    the failure remain in the file.
    """

    # --- 1. CRPIX: centre of the image in 1-based FITS pixel coordinates ---
    x_pix_ctr = (nx / 2.0) + 0.5
    y_pix_ctr = (ny / 2.0) + 0.5

    # --- 2. CDELT: quarter-pixel scale in degrees; negative along axis 1 ---
    hrc_pscale_deg = hrc_pscale_arcsec / 3600.
    x_platescale = -abs(hrc_pscale_deg / 4.)
    y_platescale = abs(hrc_pscale_deg / 4.)

    # --- 3. CRVAL: reference sky position in degrees ---
    ra_deg = ra2deg(ra)
    dec_deg = dec2deg(dec)

    # --- 4. Apply header keywords ---
    wcs_params = [
        # (key, value, datatype, unit)
        ("WCSAXES", 2, "short", None),
        ("CRPIX1", x_pix_ctr, "double", None),
        ("CRPIX2", y_pix_ctr, "double", None),
        ("CDELT1", x_platescale, "double", "deg"),
        ("CDELT2", y_platescale, "double", "deg"),
        ("CUNIT1", "deg", "string", None),
        ("CUNIT2", "deg", "string", None),
        ("CTYPE1", "RA---TAN", "string", None),
        ("CTYPE2", "DEC--TAN", "string", None),
        ("CRVAL1", ra_deg, "double", "deg"),
        ("CRVAL2", dec_deg, "double", "deg"),
        ("LONPOLE", 180.0, "double", "deg"),
        ("LATPOLE", 0, "double", "deg"),
        ("RADESYS", "ICRS", "string", None),
    ]
    try:
        for key, value, dtype, unit in wcs_params:
            dmhedit(infile=file, op="add", key=key, value=value, datatype=dtype, unit=unit)
    except Exception as e:
        print(f"  ERROR: dmhedit failed: {e}")

In [3]:
def src_psf_images(obsid, infile, x0, y0, diameter, wcs_ra, wcs_dec, binsize=0.25, shape='square', psfimg=True, showimg=False, empirical_psf=None):
    """
    Create a binned source image (and optionally a PSF image) and load them
    into Sherpa.

    Parameters
    ----------
    obsid : str
        Accepted for interface compatibility. NOTE(review): the value is
        replaced by the grandparent directory of `infile` before use.
    infile : str
        Event file to bin; also the match file for the PSF rotation.
    x0, y0 : float
        Centre of the extraction region in sky pixels.
    diameter : float
        Region width (box side or circle diameter) in sky pixels.
    wcs_ra, wcs_dec
        Sky position stamped onto the rotated empirical PSF.
    binsize : float
        Image bin size in sky pixels.
    shape : str
        'square', 'circle', or any other string, which is then used verbatim
        (lower-cased) as the region expression.
    psfimg : bool
        When True and no empirical PSF is given, bin
        ``<obsid>/raytrace_projrays.fits`` and load it as the PSF.
    showimg : bool
        When True, display the loaded images through the Sherpa image calls.
    empirical_psf : str, optional
        Path to an empirical PSF image. When given it is rotated, WCS-stamped,
        cut out, reprojected onto the source image and loaded as the PSF.

    Returns
    -------
    float
        `binsize`, unchanged.
    """

    shape_key = shape.lower()
    if shape_key == 'circle':
        region_str = f"circle({x0},{y0},{diameter/2})"
    elif shape_key == 'square':
        region_str = f"box({x0},{y0},{diameter},{diameter},0)"
    else:
        print('Shape is not circle or square, using user-defined region...')
        region_str = shape_key

    # Image width in output (logical) pixels.
    logical_width = diameter / binsize

    # Cutout region for the 512x512 empirical PSF, in LOGICAL pixels. It is
    # defined for every shape because the empirical-PSF branch below always
    # needs it, not only when shape == 'square'.
    img_region_str = f"box(256.5,256.5,{logical_width},{logical_width},0)"

    # The observation directory is derived from the event-file path; this
    # overrides whatever was passed in as `obsid`.
    obsid = os.path.dirname(os.path.dirname(infile))

    # Define all output filenames
    width_tag = int(logical_width)
    imagefile = f'{obsid}/src_image_{shape}_{width_tag}pixel.fits'
    psf_imagefile = f'{obsid}/psf_image_{shape}_raytrace_{width_tag}pixel.fits'  # Raytraced PSF
    psf_rotated = f'{obsid}/psf_rotated.fits'          # Rotated empirical PSF
    psf_rotated_cut = f'{obsid}/psf_rotated_cut.fits'  # Cut, rotated PSF
    emp_psf_imagefile = f'{obsid}/psf_image_{shape}_empirical_{width_tag}pixel.fits'  # Final reprojected PSF
    psf_model_name = f'centr_psf{obsid}'

    dmcopy.punlearn()
    dmcopy.clobber = 'yes'
    reproject_image.punlearn()
    reproject_image.clobber = 'yes'

    # --- 1. Process and load the source image ---
    print(f"Creating source image: {imagefile}")
    dmcopy.infile = f'{infile}[sky={region_str}][bin x=::{binsize},y=::{binsize}]'
    dmcopy.outfile = imagefile
    dmcopy()

    load_data(imagefile)
    if showimg:
        image_close()
        image_data()

    # --- 2. Process the PSF image ---
    if empirical_psf is not None:
        # Rotate the PSF array
        rotate_psf_array(psf_file=empirical_psf, match_file=infile, outfile=psf_rotated)

        # Give the rotated PSF a WCS
        try:
            with fits.open(psf_rotated) as hdu_rot:
                nx = hdu_rot[0].header['NAXIS1']
                ny = hdu_rot[0].header['NAXIS2']

            write_pixelscale(file=psf_rotated, nx=nx, ny=ny, ra=str(wcs_ra), dec=str(wcs_dec))

        except Exception as e:
            print(f"!!! ERROR during WCS stamping: {e}")

        # Cut out the rotated, WCS-stamped PSF. The bin factor is 4x the image
        # bin size (write_pixelscale stamps a quarter-pixel scale).
        dmcopy.infile = f'{psf_rotated}[{img_region_str}][bin x=::{binsize*4},y=::{binsize*4}]'
        dmcopy.outfile = psf_rotated_cut
        dmcopy()

        # Reproject the cut PSF onto the source image grid
        reproject_image.infile = psf_rotated_cut
        reproject_image.matchfile = imagefile
        reproject_image.outfile = emp_psf_imagefile
        reproject_image.method = 'sum'
        reproject_image()

        # Load the final PSF
        load_psf(psf_model_name, emp_psf_imagefile)
        set_psf(psf_model_name)
        print(f"Loaded empirical PSF: {emp_psf_imagefile}\n")

        if showimg:
            image_close()
            image_psf()

    elif psfimg:
        psf_infile = f'{obsid}/raytrace_projrays.fits'

        dmcopy.infile = f'{psf_infile}[{region_str}][bin x=::{binsize},y=::{binsize}]'
        dmcopy.outfile = psf_imagefile
        dmcopy()

        load_psf(psf_model_name, psf_imagefile)
        set_psf(psf_model_name)

        if showimg:
            image_close()
            image_psf()
    else:
        print("No PSF loaded.\n")

    return binsize

####################################################################################################################################################################################

def gaussian_image_fit(observation, n_components, position, ampl, fwhm,
                       background=0, pos_min=(0, 0), pos_max=None, exptime=None, central_component=None, lock_fwhm=False,
                       freeze_components=None, confidence=0, stat='cstat', method='moncar', prefix="g", 
                       confirm=False, imgfit=False, plot=False, plot_options=None, results=False):
    """
    Fit a sum of 2-D Gaussians (plus an optional constant background) to the
    image currently loaded in Sherpa, and build diagnostic plots.

    Parameters
    ----------
    observation
        Label used only in the title of the "data_fit" panel.
    n_components : int
        Number of gauss2d components; they are named ``{prefix}1``,
        ``{prefix}2``, ...
    position : (x, y) or list of (x, y)
        Initial xpos/ypos. A single pair is applied to every component.
    ampl, fwhm : number or list of numbers
        Initial amplitude and FWHM. A single number is applied to every
        component; a list must have exactly `n_components` entries.
    background : number
        When > 0, a const2d component named "c1" is added with c0 set to this
        value and a lower limit of 0.
    pos_min, pos_max : (x, y) or list of (x, y), optional
        Lower/upper limits for xpos/ypos. `pos_max=None` leaves the upper
        limits untouched.
    exptime : number, optional
        When set (and confidence intervals were computed), a per-component
        count-rate block is appended to the summary.
    central_component : int, optional
        1-based index of the component whose FWHM the others are linked to
        when `lock_fwhm` is True.
    lock_fwhm : bool
        See `central_component`.
    freeze_components : int or list of int, optional
        1-based component numbers to freeze entirely.
    confidence : number
        When > 0: xpos, ypos and ampl of every component are thawed, the fit
        is repeated with the 'simplex' method, and `conf()` is run with
        sigma set to this value.
    stat, method : str
        Passed to `set_stat` / `set_method`. Extra optimiser options are set
        only when `method == 'moncar'`.
    prefix : str
        Name prefix for the Gaussian components.
    confirm : bool
        When True, show the model and ask on stdin before fitting.
    imgfit : bool
        When True, call `image_close()` and `image_fit()` after fitting.
    plot : bool
        When True the figure is shown; otherwise it is closed before return.
    plot_options : str or list of str, optional
        Any of "data_fit", "model", "deviance". Defaults to all three. Panels
        are always drawn in that order regardless of the order given.
    results : bool
        When True, return ``(summary_text, fig)``; otherwise return ``fig``.

    Returns
    -------
    ``fig`` or ``(summary_output, fig)``; ``None`` if the user declines at
    the `confirm` prompt.

    Raises
    ------
    ValueError
        If a parameter list has the wrong length or type, or if there are no
        model components.
    """

    # Helper Functions
    def process_numeric_param(param, name):
        """If param is a single number, expand it into a list for all components;
        otherwise verify that the list is of correct length."""
        if isinstance(param, (int, float)):
            return [param] * n_components
        elif isinstance(param, list):
            if len(param) != n_components:
                raise ValueError(f"The list of {name} values must have length equal to n_components.")
            return param
        else:
            raise ValueError(f"{name} must be either a number or a list of numbers.")

    def process_tuple_param(param, name):
        """If param is a single (x, y) tuple, expand it for all components;
        otherwise verify that the list of tuples is of correct length."""
        if isinstance(param, (tuple, list)) and len(param) == 2 and all(isinstance(x, (int, float)) for x in param):
            return [param] * n_components
        elif isinstance(param, list):
            if len(param) != n_components:
                raise ValueError(f"The list of {name} values must have length equal to n_components.")
            return param
        else:
            raise ValueError(f"{name} must be either a tuple (x, y) or a list of such tuples.")

    # Process Input Parameters
    positions = process_tuple_param(position, "position")
    ampls = process_numeric_param(ampl, "ampl")
    fwhms = process_numeric_param(fwhm, "fwhm")
    pos_mins = process_tuple_param(pos_min, "pos_min")
    if pos_max is None:
        pos_maxs = [None] * n_components
    else:
        pos_maxs = process_tuple_param(pos_max, "pos_max")

    # Build the Model Expression and Create Components
    comp_names = []
    gaussian_components = []
    model_components = []

    for i in range(1, n_components + 1):
        comp_name = f"{prefix}{i}"
        comp_names.append(comp_name)
        comp = gauss2d(comp_name)
        gaussian_components.append(comp)
        model_components.append(comp)

    # Add background component if needed.
    # NOTE(review): the background is always named "c1", independent of
    # `prefix`; a prefix of "c" would reuse that name for the first Gaussian.
    bkg_comp = None
    if background > 0:
        bkg_comp = const2d("c1")
        model_components.append(bkg_comp)

    if model_components:
        set_source(sum(model_components))
    else:
        raise ValueError("Model expression is empty. Cannot set source.")

    # Assign Parameters and Constraints for Each Gaussian Component
    # freeze_components may be a single number or a list; normalise to a list.
    freeze_list = (freeze_components if isinstance(freeze_components, list)
                   else ([freeze_components] if freeze_components is not None else []))
    for i, comp in enumerate(gaussian_components):
        comp_number = i + 1  # component numbers exposed to the caller are 1-based

        # positions & amplitude
        comp.xpos = positions[i][0]
        comp.ypos = positions[i][1]
        comp.ampl = ampls[i]
        comp.fwhm = fwhms[i]

        # constraints
        if hasattr(comp.xpos, 'min'): comp.xpos.min = pos_mins[i][0]
        if hasattr(comp.ypos, 'min'): comp.ypos.min = pos_mins[i][1]
        if pos_maxs[i] is not None:
            if hasattr(comp.xpos, 'max'): comp.xpos.max = pos_maxs[i][0]
            if hasattr(comp.ypos, 'max'): comp.ypos.max = pos_maxs[i][1]
        if hasattr(comp.ampl, 'min'): comp.ampl.min = 0

        # any full-component freezes
        if comp_number in freeze_list:
            freeze(comp)
            print(f"Froze entire component {comp_number} ({comp.name}) as requested.")

    # Link FWHMs: every other component's fwhm follows the central component's.
    if central_component is not None and lock_fwhm:
        master = gaussian_components[central_component-1].fwhm
        for idx, comp in enumerate(gaussian_components):
            if idx != (central_component-1):
                link(comp.fwhm, master)

    # Background Component Setup
    if bkg_comp is not None:
        bkg_comp.c0 = background
        if hasattr(bkg_comp.c0, 'min'):
            bkg_comp.c0.min = 0

    # Confirm Model (Optional) -- blocks on stdin until the user answers.
    if confirm:
        show_model()
        proceed = input("Proceed with fit? (y/n): ")
        if proceed.lower() != "y":
            print("Fit canceled.")
            return None

    # Set Fitting Options and Run the Fit
    set_stat(stat)
    set_method(method)
    if method == 'moncar':
        # Optimiser options applied only for 'moncar'; the population size
        # grows with the number of components.
        set_method_opt('numcores', 12)
        set_method_opt('population_size', 10 * 16 * (n_components * 3 + 1))
        set_method_opt('xprob', 0.5)
        set_method_opt('weighting_factor', 0.5)

    fit()
    fit_results = get_fit_results()

    # Confidence Interval Calculation
    conf_results = None # Ensure conf_results exists
    if confidence > 0:
        # NOTE(review): the message below says "Moncar" whatever `method` was.
        print(f"-->Moncar fit statistic: {fit_results.statval:.2f}\n")
        # NOTE(review): this thaws xpos/ypos/ampl on every component,
        # including any frozen via `freeze_components` -- confirm intended.
        for comp in gaussian_components:
            thaw(comp.xpos, comp.ypos, comp.ampl)
        print(f"Calculating {confidence}-sigma confidence intervals using simplex...")
        # Refit with simplex before running conf(); fit_results now refers to
        # this second fit.
        set_method('simplex')
        fit()
        fit_results = get_fit_results()
        print(f"-->Simplex fit statistic: {fit_results.statval:.2f}\n")
        set_conf_opt('numcores', 12)
        set_conf_opt('sigma', confidence)
        conf()
        conf_results = get_conf_results()

    if imgfit:
        image_close()
        image_fit()

    # Build Fit Result Summary (Optional)
    # `summary_output` only exists when `results` is True; every later use of
    # it is guarded by `results`.
    if results:
        fit_summary = (
            f"Method = {fit_results.methodname}\n"
            f"Statistic = {fit_results.statname}\n"
            f"Initial fit statistic = {fit_results.istatval:.2f}\n"
            f"Final fit statistic = {fit_results.statval:.2f} at function evaluation {fit_results.nfev}\n"
            f"Data points = {fit_results.numpoints}\n"
            f"Degrees of freedom = {fit_results.dof}\n"
            f"Probability [Q-value] = {fit_results.qval}\n"
            f"Reduced statistic = {fit_results.rstat:.5f}\n"
            f"Change in statistic = {fit_results.dstatval:.2f}\n\n"
        )
        
        def fmt_val(val, width=10, prec=3):
            # A missing value (None) is rendered as a dashed placeholder.
            if val is None:
                return "-------".rjust(width)
            return f"{val:>{width}.{prec}f}"

        if confidence and conf_results is not None:
            param_table_rows = [
                f"confidence {confidence}-sigma bounds:",
                f"{'Param':<10}\t{'Best-Fit':>10}\t{'Lower':>10}\t{'Upper':>10}",
                f"{'-'*5:<10}\t{'-'*8:>10}\t{'-'*5:>10}\t{'-'*5:>10}"
            ]
            for name, best, low, high in zip(conf_results.parnames, 
                                             conf_results.parvals, 
                                             conf_results.parmins, 
                                             conf_results.parmaxes):
                param_table_rows.append(
                    f"{name:<10}\t{fmt_val(best)}\t{fmt_val(low)}\t{fmt_val(high)}"
                )
            param_table = "\n".join(param_table_rows)
        else:
            param_table_rows = [
                "Best-Fit Parameter Values:",
                f"{'Param':<10}\t{'Best-Fit':>10}",
                f"{'-'*5:<10}\t{'-'*8:>10}"
            ]
            for name, best in zip(fit_results.parnames, fit_results.parvals):
                param_table_rows.append(f"{name:<10}\t{fmt_val(best)}")
            param_table = "\n".join(param_table_rows)
                
        summary_output = fit_summary + param_table + '\n'

        # Count rate block (with amplitude & FWHM error propagation)
        if exptime and confidence > 0 and conf_results is not None:
            rate_block_rows = ["Component count rates (counts/s):"]
            for comp in gaussian_components:
                # Rate = summed model-component image divided by exposure time.
                comp_img   = get_model_component_image(comp.name)
                total_cts  = comp_img.y.sum()
                rate       = total_cts / exptime
                # Take the part of the component name after the last '.' so it
                # can be matched against conf_results.parnames entries.
                short      = comp.name.split('.')[-1]
                amp_name   = f"{short}.ampl"
                fwhm_name  = f"{short}.fwhm"
    
                # Parameters absent from the conf results (e.g. frozen or
                # linked) contribute zero uncertainty.
                if amp_name in conf_results.parnames:
                    a_idx      = conf_results.parnames.index(amp_name)
                    A_best     = conf_results.parvals[a_idx]
                    
                    dA_minus_val = conf_results.parmins[a_idx]
                    dA_plus_val  = conf_results.parmaxes[a_idx]
                    
                    dA_minus = abs(dA_minus_val) if dA_minus_val is not None else 0
                    dA_plus  = abs(dA_plus_val)  if dA_plus_val  is not None else 0
                else:
                    A_best = 1; dA_minus = 0; dA_plus = 0
    
                if fwhm_name in conf_results.parnames:
                    f_idx      = conf_results.parnames.index(fwhm_name)
                    F_best     = conf_results.parvals[f_idx]
                    
                    dF_minus_val = conf_results.parmins[f_idx]
                    dF_plus_val  = conf_results.parmaxes[f_idx]
                    
                    dF_minus = abs(dF_minus_val) if dF_minus_val is not None else 0
                    dF_plus  = abs(dF_plus_val)  if dF_plus_val  is not None else 0
                else:
                    F_best = 1; dF_minus = 0; dF_plus = 0
    
                # Fractional error: amplitude term and twice the FWHM term
                # added in quadrature (total counts scale as ampl * fwhm**2).
                frac_minus = np.sqrt((dA_minus/A_best)**2 + (2*dF_minus/F_best)**2) if A_best > 0 and F_best > 0 else 0
                frac_plus  = np.sqrt((dA_plus /A_best)**2 + (2*dF_plus /F_best)**2) if A_best > 0 and F_best > 0 else 0
    
                dR_minus = (total_cts * frac_minus) / exptime
                dR_plus  = (total_cts * frac_plus)  / exptime
    
                rate_block_rows.append(
                    f"  {short:<6}: {rate:7.4f}  "
                    f"(-{dR_minus:6.4f}/+{dR_plus:6.4f})"
                )
            summary_output += "\n" + "\n".join(rate_block_rows) + "\n"
        else:
            # No rate block: pad with blank lines instead.
            summary_output = fit_summary + param_table + '\n\n\n\n'

    # Retrieve Data/Model Images and Compute Residuals
    data_img = get_data_image()
    data_vals = data_img.y
    # Floor for log-scaled display: one tenth of the smallest positive data
    # value (or 1e-10 when the image has no positive pixel).
    min_positive_val = np.min(data_vals[data_vals > 0]) if np.any(data_vals > 0) else 1e-9
    display_floor = min_positive_val / 10.0
    
    data_masked = np.maximum(data_vals, display_floor) 

    model_img = get_model_image()
    model_vals = model_img.y
    model_masked = np.maximum(model_vals, display_floor)
    
    # Use the original, unmasked arrays for logic
    d_vals = data_vals
    m_vals = model_vals

    # Start with the standard formula, using masked arrays to avoid log(0)
    # This is Case 3 (d > 0, m > 0). We will overwrite the other cases.
    D = 2.0 * (data_masked * np.log(data_masked / model_masked) - (data_masked - model_masked))
    
    # Apply Case 1 (m <= 0), dev = 2*d
    D = np.where(m_vals <= 0, 2.0 * d_vals, D)
    
    # Apply Case 2 (m > 0 AND d <= 0), dev = 2*m
    D = np.where((m_vals > 0) & (d_vals <= 0), 2.0 * m_vals, D)

    # Signed deviance residual: sign of (data - model) times sqrt(|D|).
    resid_dev = np.sign(data_vals - model_vals) * np.sqrt(np.abs(D))

    # Determine and Prepare Plot Options
    if plot_options is None:
        plot_options = ["data_fit", "model", "deviance"]
    elif isinstance(plot_options, str):
        plot_options = [plot_options]
    elif not isinstance(plot_options, list):
        print(f"Warning: plot_options was of type {type(plot_options)}. Using default plots.")
        plot_options = ["data_fit", "model", "deviance"]

    n_plots = len(plot_options)
    if n_plots == 0:
        print("No plots requested.")
        fig, axs = plt.subplots(1, 1); plt.close(fig) # Create and close empty fig
        return (summary_output, fig) if results else fig

    # One axis per requested option; unrecognised option names leave their
    # axis empty.
    fig, axs = plt.subplots(1, n_plots, figsize=(10 * n_plots, 5 * n_plots))
    if n_plots == 1:
        axs = [axs]
    plot_idx = 0

    # Shared log colour scale for the data and model panels.
    vmax_display = np.max(data_vals)
    log_norm = mcolors.LogNorm(
        vmin=display_floor, 
        vmax=vmax_display if vmax_display > display_floor else display_floor + 1
    )

    # Plot "data_fit"
    if "data_fit" in plot_options:
        ax = axs[plot_idx]
        im = ax.imshow(data_masked, origin='lower', cmap='gnuplot2', norm=log_norm,
                       interpolation='nearest')

        legend_elements = []
        base_colors = ['white', 'cyan', 'lime', 'xkcd:light lavender']
        linestyles = ['--', ':', '-.']

        for i, comp_name in enumerate(comp_names):
            comp_vals = get_model_component_image(comp_name).y
            # Skip components with no positive pixel in the image.
            if not np.any(comp_vals > 0): continue

            # Cycle colours first; once they are exhausted, vary the line style.
            color = base_colors[i % len(base_colors)]
            linestyle = '--' if i < len(base_colors) else linestyles[(i // len(base_colors)) % len(linestyles)]

            if n_components > 1:
                # Multi-component fit: one contour per component at 20% of its peak.
                level = 0.2 * np.max(comp_vals)
                if level <= 0: level = 1e-9 * np.max(comp_vals)
                if level > 0:
                    ax.contour(comp_vals, levels=[level], colors=[color],
                               linestyles=linestyle, linewidths=2)
            else:
                # Single component: five evenly spaced contours above the minimum.
                levels = np.linspace(np.min(comp_vals), np.max(comp_vals), 6)
                if len(np.unique(levels)) > 1:
                    ax.contour(comp_vals, levels=levels[1:], colors=[color],
                               linestyles=linestyle, linewidths=2)

            legend_elements.append(Line2D([0], [0], lw=2, linestyle=linestyle,
                                          color=color, label=f"{comp_name}"))
        if legend_elements:
            ax.legend(handles=legend_elements, loc='upper right')
        ax.set_title(f"{observation} Data + Fit Overlay")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="Counts", shrink=0.53)
        plot_idx += 1

    # Plot "model"
    if "model" in plot_options:
        # Defensive guard against running out of axes.
        if plot_idx >= n_plots: return (summary_output, fig) if results else fig
        ax = axs[plot_idx]
        im = ax.imshow(model_masked, origin='lower', cmap='gnuplot2', norm=log_norm,
                       interpolation='nearest')
        ax.set_title("Model")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="Model Counts", shrink=0.53)
        plot_idx += 1

    # Plot "deviance"
    if "deviance" in plot_options:
        # Defensive guard against running out of axes.
        if plot_idx >= n_plots: return (summary_output, fig) if results else fig
        ax = axs[plot_idx]
        # The absolute residual is shown on a fixed 0-5 linear scale.
        im = ax.imshow(np.abs(resid_dev), origin='lower', cmap='gnuplot2',
                       norm=mcolors.Normalize(vmin=0, vmax=5),
                       interpolation='nearest')
        ax.set_title("Poisson Deviance Residuals")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="|Residuals|", shrink=0.53)
        plot_idx += 1

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    if plot:
        plt.show()
    else:
        # Close so the figure is not displayed; the Figure object is still
        # returned to the caller.
        plt.close(fig)

    # Return Results
    return (summary_output, fig) if results else fig

In [4]:
# Change to your working directory. All relative paths below (event files,
# output directory) are resolved against it.
WORK_DIR = '/Users/leodrake/Documents/MIT/ss433/HRC_2024/'
os.chdir(WORK_DIR)

# We define the specific RA/Dec for each ObsID here.
# Keys are strings to match the 'obsid' variable.
# Values are (ra, dec) tuples of strings.
obsid_coords = {
    "26568": ("287.9565362", "4.9826061"),
    "26569": ("287.9563218", "4.9827745"),
    "26570": ("287.9563754", "4.9825322"),
    "26571": ("287.9561693", "4.9827006"),
    "26572": ("287.9565032", "4.9826636"),
    "26573": ("287.9565444", "4.9826390"),
    "26574": ("287.9562518", "4.9825651"),
    "26575": ("287.9566969", "4.9828114"),
    "26576": ("287.9566351", "4.9826718"),
    "26577": ("287.9565238", "4.9826020"),
    "26578": ("287.9566021", "4.9826800"),
    "26579": ("287.9565733", "4.9825774")
}

# Define the empirical PSF file to be used (lives in the working directory).
emp_psf_file = os.path.join(WORK_DIR, "empPSF_iARLac_v2025_2017-2025.fits")

# Find the spline-corrected event file in each observation's repro/ directory.
event_files = sorted(glob.glob('*/repro/*splinecorr.fits'))

# All outputs go to this directory; create it up front so that opening the
# PDF and text files cannot fail on a fresh checkout.
OUTPUT_DIR = '2Dfits'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Open PdfPages objects to save one page per observation.
pdf_out = PdfPages(f'{OUTPUT_DIR}/0fit-plots.pdf')
multi_pdf_out = PdfPages(f'{OUTPUT_DIR}/0multi-comp-plots.pdf')

# Open the main results file and the multi-component-only results file.
results_filename = f'{OUTPUT_DIR}/0fit-results.txt'
multi_results_filename = f'{OUTPUT_DIR}/0multi-comp-fit-results.txt'
with open(results_filename, 'w') as results_file, open(multi_results_filename, 'w') as multi_results_file:

    # Loop over each event file.
    for infile in event_files[:]:
        # Extract observation directory/name.
        obsid = os.path.dirname(os.path.dirname(infile))
        print(f'\nProcessing {obsid}\n')

        # Get coordinates for this ObsID
        if obsid not in obsid_coords:
            print(f"!!! WARNING: ObsID {obsid} not in coordinate lookup table. Skipping.")
            continue
        current_ra, current_dec = obsid_coords[obsid]
        
        # Data extraction and initial quickpos
        date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs = data_extract_quickpos_iter(infile)
        
        # Aggregate text block for writing
        header_text = (
            f"Observation: {obsid}\n"
            f"Infile: {infile}\n"
            f"Date: {date}, Exptime: {exptime}\n"
        )
        results_file.write(header_text)

        # PSF image and centroid fit
        img_width = 40  # physical pixels
        
        cent_binsize = 1.0 # Define binsize for this step
        src_psf_images(
            obsid, infile, pixel_x0_best, pixel_y0_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec, # Pass WCS info
            binsize=cent_binsize, 
            psfimg=False, 
            empirical_psf=None
        )
        
        logical_width = img_width / cent_binsize
        img_center = logical_width / 2.0 + 0.5

        print('Centroiding...\n')
        
        centroid_fit_summary, centroid_fit_fig = gaussian_image_fit(
            obsid, 1, (img_center, img_center), cnt, (1.0 / cent_binsize), # Scale FWHM guess
            prefix="centrg",
            background=0.1, 
            pos_max=(logical_width, logical_width), # Use logical width
            results=True
        )
        pdf_out.savefig(centroid_fit_fig)
        plt.close(centroid_fit_fig)

        # Save centroid fit parameters
        results_file.write("\nCENTROID FIT SUMMARY:\n\n")
        results_file.write(centroid_fit_summary)

        # Retrieve centroid fit physical coordinates
        d = get_data()
        crval_x, crval_y = d.sky.crval
        crpix_x, crpix_y = d.sky.crpix
        cdelt_x, cdelt_y = d.sky.cdelt
        xphys_best = crval_x + (centrg1.xpos.val - crpix_x) * cdelt_x
        yphys_best = crval_y + (centrg1.ypos.val - crpix_y) * cdelt_y

        # Source fit in physical coordinates
        img_width = 10  # physical pixels
        
        # Using EMPIRICAL PSF at binsize=0.25
        src_binsize = 0.25 # This is forced by using empirical_psf
        src_psf_images(
            obsid, infile, xphys_best, yphys_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec, # Pass WCS info
            binsize=src_binsize, 
            psfimg=True, 
            empirical_psf=emp_psf_file
        )

        logical_width = img_width / src_binsize # 10 / 0.25 = 40
        img_center = logical_width / 2.0 + 0.5 # 40 / 2 + 0.5 = 20.5
        
        # Scale 'cnt' and 'fwhm' guess to this 0.25-binned image
        pixel_scale_guess = 1.0 / src_binsize # 4.0
        scaled_cnt_guess = cnt / (pixel_scale_guess**2)
        scaled_fwhm_guess = 1.0 * pixel_scale_guess # fwhm=1.0 at binsize=1.0 -> fwhm=4.0 at binsize=0.25

        print('Fitting Source...\n')
        
        src_fit_summary, src_fit_fig = gaussian_image_fit(
            obsid, 1, (img_center, img_center), scaled_cnt_guess, scaled_fwhm_guess,
            prefix="srcg",
            pos_max=(logical_width, logical_width),
            results=True
        )
        pdf_out.savefig(src_fit_fig)
        plt.close(src_fit_fig)

        # Save source fit parameters
        results_file.write("SOURCE FIT SUMMARY:\n\n")
        results_file.write(src_fit_summary)

        # Multi-component fit
        
        # Dynamic scaling logic
        srcfit_off_x = srcg1.xpos.val - img_center # img_center is 20.5
        srcfit_off_y = srcg1.ypos.val - img_center # img_center is 20.5
        src_ampl = srcg1.ampl.val
        src_fwhm = srcg1.fwhm.val
        
        # Set up for multi-component fit
        img_width = 40 # physical pixels
        multi_binsize = 0.5
        
        src_psf_images(
            obsid, infile, xphys_best, yphys_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec, # Pass WCS info
            binsize=multi_binsize, 
            empirical_psf=emp_psf_file
        )
        
        logical_width = img_width / multi_binsize # 160.0
        img_center = logical_width / 2.0 + 0.5   # 80.5
        
        # Calculate scale factor
        pixel_scale = src_binsize / multi_binsize # 0.25 / 0.25 = 1.0
        
        # Scale parameters
        new_xpos = img_center + (srcfit_off_x * pixel_scale)
        new_ypos = img_center + (srcfit_off_y * pixel_scale)
        scaled_src_fwhm = src_fwhm * pixel_scale
        scaled_src_ampl = src_ampl / (pixel_scale**2)
        
        # Re-scale the 'cnt' guess for this new image
        pixel_scale_guess = 1.0 / multi_binsize # 4.0
        scaled_cnt_ampl = cnt / (pixel_scale_guess**2)
        scaled_default_fwhm = 1.0 * pixel_scale_guess # 4.0

        n_components = 3  # total number of components
        positions = [(new_xpos, new_ypos)] + [(img_center, img_center)] * (n_components - 1)
        amplitudes = [scaled_src_ampl] + [scaled_cnt_ampl] * (n_components - 1)
        fwhms = [scaled_src_fwhm] + [scaled_default_fwhm] * (n_components - 1)

        print('Fitting multi-component Gaussian...')

        multi_fit_summary, multi_fit_fig = gaussian_image_fit(
            obsid, n_components, positions, amplitudes, fwhms,
            prefix="g",
            background=0.1,
            pos_max=(logical_width, logical_width),
            pos_min=(0, 0),
            exptime=exptime,
            central_component=1,
            lock_fwhm=True,
            confidence=1,
            results=True
        )

        print('\nPlots saving to PDF...\n')
        pdf_out.savefig(multi_fit_fig)
        multi_pdf_out.savefig(multi_fit_fig)
        plt.close(multi_fit_fig)

        # Save multi-component fit parameters
        results_file.write("MULTI-COMPONENT FIT SUMMARY:\n\n")
        results_file.write(multi_fit_summary)

        # Aggregate text block for writing
        multi_results_text = (
            f"Observation: {obsid}\n"
            f"Infile: {infile}\n"
            f"Date: {date}, Exptime: {exptime}\n"
            f"{multi_fit_summary}\n\n"
        )
        multi_results_file.write(multi_results_text)

        # Clear the current Sherpa session to prepare for the next observation
        clean()
        print('Sherpa Session Cleaned\n\n')

# Close the PDF files after processing all files.
print('Tidying Up...\n')
pdf_out.close()
multi_pdf_out.close()

print('Process Complete')


Processing 26568

Creating source image: 26568/src_image_square_40pixel.fits
No PSF loaded.

Centroiding...

Creating source image: 26568/src_image_square_40pixel.fits
Loaded empirical PSF: 26568/psf_image_square_empirical_40pixel.fits

Fitting Source...

Creating source image: 26568/src_image_square_80pixel.fits
Loaded empirical PSF: 26568/psf_image_square_empirical_80pixel.fits

Fitting multi-component Gaussian...
-->Moncar fit statistic: 4181.17

Calculating 1-sigma confidence intervals using simplex...
-->Simplex fit statistic: 4181.17


Plots saving to PDF...

Sherpa Session Cleaned



Processing 26569

Creating source image: 26569/src_image_square_40pixel.fits
No PSF loaded.

Centroiding...

Creating source image: 26569/src_image_square_40pixel.fits
Loaded empirical PSF: 26569/psf_image_square_empirical_40pixel.fits

Fitting Source...

Creating source image: 26569/src_image_square_80pixel.fits
Loaded empirical PSF: 26569/psf_image_square_empirical_80pixel.fits

Fitting multi-com