In [None]:
import os
import glob
import logging
import time
import datetime

# Core scientific and plotting libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
from astropy.io import fits
from scipy.stats import chi2
from scipy.optimize import curve_fit
from scipy.ndimage import rotate

# Sherpa and CIAO imports
from sherpa.astro.ui import *
from sherpa.astro.utils import *
from ciao_contrib.runtool import *
from coords.format import ra2deg, dec2deg

# MCMC and PDF compilation imports
import corner
from PIL import Image

# Plot defaults: high-DPI figures, larger fonts, image origin at lower-left
plt.rcParams['figure.dpi'] = 400
plt.rcParams['font.size'] = 18
plt.rcParams['image.origin'] = 'lower'

# Raise the "sherpa" logger threshold to WARNING so INFO messages are hidden
logger = logging.getLogger("sherpa")
logger.setLevel(logging.WARNING)

In [None]:
def quickpos(x, y, x0, y0, iterations=1, size_list=None, binsize_list=None, doplot=False):
    """
    Iteratively refine a centroid by fitting Gaussians to 1D histograms.

    On each iteration the points with ``|x - cx| < size`` and
    ``|y - cy| < size`` (``cx, cy`` being the current centroid) are
    histogrammed separately along x and y.  A Gaussian plus constant offset
    is fitted to each histogram with ``scipy.optimize.curve_fit`` and the
    fitted means become the centroid used by the next iteration.  If a fit
    fails, a warning is printed and the previous value for that axis is kept.
    If no points fall inside the window, the iteration is skipped.

    Parameters
    ----------
    x, y : array-like
        Point coordinates (converted with ``np.asarray``).
    x0, y0 : float
        Starting centroid guess.
    iterations : int
        Number of refinement passes.
    size_list : sequence of float, optional
        Window half-width for each iteration.  Defaults to
        ``max(x) - min(x)`` for every iteration.
    binsize_list : sequence of float, optional
        Histogram bin width for each iteration.  Defaults to 1.0.
    doplot : bool
        If True, create and show a three-panel figure per iteration
        (scatter plot, x histogram + fit, y histogram + fit).

    Returns
    -------
    best_x0, best_y0 : float
        Refined centroid.
    cnt : float or None
        Mean of the x and y Gaussian areas ``a * |sigma| * sqrt(2*pi)`` from
        the last iteration that found points (a failed fit contributes 0);
        None if no iteration found any points.
        NOTE(review): the area is in histogram-height x coordinate units; it
        may need dividing by the bin width to be a true count -- confirm
        against how callers use it.
    fig_list : list
        Figures created when ``doplot`` is True, otherwise an empty list.

    Raises
    ------
    ValueError
        If ``size_list`` or ``binsize_list`` has fewer than ``iterations``
        entries.
    """

    def step_plot(centers, heights, binwidth):
        """Return coordinates that draw a histogram as a step line."""
        xsteps = np.ravel(np.column_stack((centers - binwidth/2, centers + binwidth/2)))
        ysteps = np.repeat(heights, 2)
        return xsteps, ysteps

    def gaussian(t, a, mu, sigma, offset):
        """1D Gaussian with a constant offset."""
        return a * np.exp(-((t - mu)**2) / (2 * sigma**2)) + offset

    def fit_axis(values, centre, size, binsize, label, iteration):
        """
        Histogram `values` around `centre` and fit the Gaussian model.

        Returns (bin centers, histogram, fitted curve, fitted mean or None,
        Gaussian area).  On a failed fit the curve is all zeros, the mean is
        None and the area is 0.
        """
        edges = np.arange(centre - size, centre + size + binsize, binsize)
        hist, edges = np.histogram(values, bins=edges)
        centers = 0.5 * (edges[:-1] + edges[1:])
        try:
            # Seed the fit with the histogram peak and a width of two bins
            guess = [np.max(hist), centers[np.argmax(hist)], 2 * binsize, 0]
            par, _ = curve_fit(gaussian, centers, hist, p0=guess)
            curve = gaussian(centers, *par)
        except Exception as e:
            print(f"Warning: {label}-fit failed for iteration {iteration}: {e}. Using previous {label} value.")
            return centers, hist, np.zeros_like(centers), None, 0
        # sigma only enters the model squared, so the optimizer may return
        # either sign; use its magnitude so the area cannot become negative.
        area = par[0] * abs(par[2]) * np.sqrt(2 * np.pi)
        return centers, hist, curve, par[1], area

    # Accept plain Python sequences as well as arrays
    x = np.asarray(x)
    y = np.asarray(y)

    # Set default window/bin sizes if none are provided
    if size_list is None:
        size_list = [np.max(x) - np.min(x)] * iterations
    if binsize_list is None:
        binsize_list = [1.0] * iterations

    # Fail early rather than with an IndexError part-way through the loop
    if len(size_list) < iterations or len(binsize_list) < iterations:
        raise ValueError("size_list and binsize_list must each contain at least `iterations` entries.")

    fig_list = []
    current_x0, current_y0 = x0, y0
    best_x0, best_y0 = x0, y0
    cnt = None

    for i in range(iterations):
        size = size_list[i]
        binsize = binsize_list[i]

        # Create debug plots if requested
        if doplot:
            fig, ax = plt.subplots(3, 1, figsize=(6, 10))
            ax[0].scatter(x, y, s=0.5, c='k')
            ax[0].set_xlim(current_x0 - size, current_x0 + size)
            ax[0].set_ylim(current_y0 - size, current_y0 + size)
            ax[0].set_title(f'Centroid Plot (Iteration {i+1})')
        else:
            fig = ax = None

        # Select the points inside the current square window
        in_window = (np.abs(x - current_x0) < size) & (np.abs(y - current_y0) < size)
        if not np.any(in_window):
            print(f"Warning: No points found in window for iteration {i+1}. Using previous values.")
            if doplot:
                fig_list.append(fig)
                plt.show()
            continue

        xval, xhist, xf, x_mu, xcnt = fit_axis(x[in_window], current_x0, size, binsize, 'X', i + 1)
        yval, yhist, yf, y_mu, ycnt = fit_axis(y[in_window], current_y0, size, binsize, 'Y', i + 1)

        # Keep the previous centroid on any axis whose fit failed
        if x_mu is not None:
            best_x0 = x_mu
        if y_mu is not None:
            best_y0 = y_mu

        if doplot:
            panels = (
                (ax[1], xval, xhist, xf, 'x histogram'),
                (ax[2], yval, yhist, yf, 'y histogram'),
            )
            for panel, centers, hist, curve, hist_label in panels:
                steps_x, steps_y = step_plot(centers, hist, binsize)
                panel.plot(steps_x, steps_y, c='b', alpha=0.3, label=hist_label)
                panel.plot(centers, curve, '--', c='red', linewidth=0.75, label='Gaussian fit')
                panel.legend()

        # Average the two area estimates and recentre for the next iteration
        cnt = 0.5 * (xcnt + ycnt)
        current_x0, current_y0 = best_x0, best_y0

        if doplot:
            fig_list.append(fig)
            plt.show()

    return best_x0, best_y0, cnt, fig_list

####################################################################################################################################################################################

def data_extract_quickpos_iter(infile, iters=3, sizes=(10, 5, 1.5), binsizes=(0.1, 0.1, 0.05)):
    """
    Read an event file and run `quickpos` to get an initial centroid.

    Event positions are read from HDU 1 of `infile`, converted to offsets
    from the reference pixel, restricted to those within 20 units of the
    reference pixel, and passed to `quickpos` with their mean position as the
    starting guess.

    Parameters
    ----------
    infile : str
        Path to the FITS event file.
    iters : int
        Number of `quickpos` refinement iterations.
    sizes : sequence of float
        Window half-width for each iteration, in the same units as the
        converted positions.
    binsizes : sequence of float
        Histogram bin width for each iteration.

    Returns
    -------
    date : float
        ``MJD-OBS`` plus half of ``TSTOP - TSTART`` converted with 86400
        (assumes TSTART/TSTOP are in seconds -- TODO confirm).
    exptime : float
        Value of the ``EXPOSURE`` header keyword.
    pixel_x0_best, pixel_y0_best : float
        Refined centroid converted back to pixel coordinates.
    cnt : float or None
        Count estimate returned by `quickpos`.
    qp_figs : list
        Figure list returned by `quickpos` (plotting is not requested here).
    """
    with fits.open(infile) as obs:
        hdr = obs[1].header
        data = obs[1].data

        # Pixel scale, reference pixel and exposure from the header.
        # NOTE(review): the column-20 keywords are applied to BOTH x and y;
        # presumably the sky x/y columns share reference pixel and scale
        # magnitude -- confirm against the event file layout.
        scale = hdr['tcdlt20']
        xc = hdr['tcrpx20']
        exptime = hdr['exposure']

        # Modified Julian Date at the midpoint of the observation
        mjd_start = hdr['mjd-obs']
        half_expos = 0.5 * (hdr['tstop'] - hdr['tstart'])
        date = mjd_start + half_expos / 86400

        # Offsets from the reference pixel (arcsec if `scale` is deg/pixel)
        x = (data['x'] - xc) * scale * 3600
        y = (data['y'] - xc) * scale * 3600

        # Keep only events within a radius of 20 of the reference pixel
        rr = np.sqrt(x**2 + y**2)
        ok = np.where(rr < 20)

        # Rough starting estimate: mean position of the selected events
        x0_est = np.average(x[ok])
        y0_est = np.average(y[ok])

    # Run the iterative centroid refinement on the selected events
    x0_best, y0_best, cnt, qp_figs = quickpos(
        x[ok], y[ok], x0_est, y0_est,
        iterations=iters, size_list=sizes, binsize_list=binsizes,
    )

    # Invert the offset conversion to get back to pixel coordinates
    pixel_x0_best = x0_best / (scale * 3600) + xc
    pixel_y0_best = y0_best / (scale * 3600) + xc

    return date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs

####################################################################################################################################################################################

def rotate_psf_array(psf_file, match_file, outfile):
    """
    Rotate a PSF image by (roll - 45) degrees and write it to `outfile`.

    The roll angle is taken from `match_file`, looking for ROLL_NOM in HDU 0,
    then ROLL_NOM in HDU 1, then ROLL_PNT in HDU 1.  The PSF image is read
    from HDU 0 of `psf_file`, or HDU 1 when HDU 0 has no data.  The rotated
    array is saved as a primary HDU that reuses the PSF's header, with two
    HISTORY cards recording the rotation.

    Any read or write problem is reported with a printed message; on a read
    problem the function returns without writing.  Nothing is returned.
    """

    # (HDU index, keyword) pairs, searched in priority order
    roll_sources = ((0, 'ROLL_NOM'), (1, 'ROLL_NOM'), (1, 'ROLL_PNT'))

    # Look up the roll angle in the observation's event file
    try:
        with fits.open(match_file) as hdu_match:
            for hdu_index, keyword in roll_sources:
                match_header = hdu_match[hdu_index].header
                if keyword in match_header:
                    roll_nom = match_header[keyword]
                    if keyword == 'ROLL_PNT':
                        print("  Note: Using 'ROLL_PNT' as 'ROLL_NOM' was not found.")
                    break
            else:
                print("  ERROR: Could not find 'ROLL_NOM' or 'ROLL_PNT' in HDU 0 or 1 of match file.")
                return
    except FileNotFoundError:
        print(f"  ERROR: Match file not found: {match_file}")
        return
    except Exception as e:
        print(f"  ERROR: Could not read match file header: {e}")
        return

    # Angle passed to scipy.ndimage.rotate
    angle_to_rotate = roll_nom - 45.0

    # Read the PSF image and its header
    try:
        with fits.open(psf_file) as hdu_psf:
            # Use the primary HDU unless it carries no data
            source_hdu = hdu_psf[1] if hdu_psf[0].data is None else hdu_psf[0]
            psf_data = source_hdu.data
            psf_header = source_hdu.header

            if psf_data is None:
                print(f"ERROR: No image data found in HDU 0 or 1 of {psf_file}.")
                return
    except FileNotFoundError:
        print(f"  ERROR: PSF file not found: {psf_file}")
        return
    except Exception as e:
        print(f"  ERROR: Could not read PSF file data/header: {e}")
        return

    # Rotate in place on the same grid: shape preserved, uncovered pixels
    # set to 0, cubic spline interpolation
    rotated_psf_data = rotate(psf_data, angle_to_rotate, reshape=False, cval=0.0, order=3)

    # Record what was done in the header
    timestamp = datetime.datetime.now().isoformat()
    history_lines = (
        f"Rotated by {angle_to_rotate:.4f} deg (ROLL_NOM={roll_nom:.4f} - 45.0)",
        f"Rotation applied by script on {timestamp}",
    )
    for line in history_lines:
        psf_header.add_history(line)

    # Write the rotated image, replacing any existing file
    hdu_out = fits.PrimaryHDU(data=rotated_psf_data, header=psf_header)
    try:
        hdu_out.writeto(outfile, overwrite=True)
    except Exception as e:
        print(f"  ERROR: Could not write output file: {e}")

####################################################################################################################################################################################

def write_pixelscale(file: str, nx: int, ny: int, ra: str, dec: str, hrc_pscale_arcsec: float = 0.13175):
    """
    Stamp a tangent-plane WCS onto a FITS image with repeated dmhedit calls.

    The reference pixel is the image centre ``(n/2 + 0.5)``, the reference
    value is (`ra`, `dec`) converted to degrees, and the pixel scale is one
    quarter of `hrc_pscale_arcsec` (negative along axis 1).  A dmhedit
    failure is printed and stops the remaining keywords from being written.
    """

    # Pixel scale in degrees, then the quarter-pixel step of the empirical PSF
    pscale_deg = hrc_pscale_arcsec / 3600.
    quarter_step = pscale_deg / 4.

    # (key, value, datatype, unit) for every keyword, in write order
    cards = (
        ("WCSAXES", 2, "short", None),
        ("CRPIX1", (nx / 2.0) + 0.5, "float", None),
        ("CRPIX2", (ny / 2.0) + 0.5, "float", None),
        ("CDELT1", -abs(quarter_step), "float", "deg"),
        ("CDELT2", abs(quarter_step), "float", "deg"),
        ("CUNIT1", "deg", "string", None),
        ("CUNIT2", "deg", "string", None),
        ("CTYPE1", "RA---TAN", "string", None),
        ("CTYPE2", "DEC--TAN", "string", None),
        ("CRVAL1", ra2deg(ra), "float", "deg"),
        ("CRVAL2", dec2deg(dec), "float", "deg"),
        ("LONPOLE", 180.0, "float", "deg"),
        ("LATPOLE", 0, "float", "deg"),
        ("RADESYS", "ICRS", "string", None),
    )

    # Add each keyword to the file header
    try:
        for card in cards:
            key, value, dtype, unit = card
            dmhedit(infile=file, op="add", key=key, value=value, datatype=dtype, unit=unit)
    except Exception as e:
        print(f"  ERROR: dmhedit failed: {e}")

####################################################################################################################################################################################

def process_mcmc_results(covar_results, chains, burn_in_frac=0.2, sigma=1):
    """
    Summarise MCMC chains as a median and percentile bounds per parameter.

    `chains` is transposed to (iterations, parameters), the first
    ``int(n_iter * burn_in_frac)`` iterations are dropped, and for each name
    in ``covar_results.parnames`` the median and the lower/upper percentiles
    for the requested `sigma` level (1, 2 or 3) are computed.

    Returns a tuple ``(mcmc_results, valid_chains, thawed_parnames)`` where
    `mcmc_results` is a dict with keys 'parnames', 'parvals' (medians),
    'parmins' and 'parmaxes'.  Raises ValueError for any other `sigma`.
    """

    # Work with one row per iteration and one column per parameter
    chains = chains.T

    # (sigma level, lower percentile, upper percentile)
    percentile_table = (
        (1, 15.865, 84.135),   # 68.27%
        (2, 2.275, 97.725),    # 95.45%
        (3, 0.135, 99.865),    # 99.73%
    )
    selected = None
    for level, lower, upper in percentile_table:
        if sigma == level:
            selected = (lower, upper)
            break
    if selected is None:
        raise ValueError("Only sigma=1, 2, or 3 is supported.")
    q_low, q_high = selected

    # Number of leading iterations to discard
    n_iter, n_params = chains.shape
    burn_in = int(n_iter * burn_in_frac)

    # A burn-in that would discard everything is ignored
    if burn_in >= n_iter:
        print(f"Warning: burn_in_frac ({burn_in_frac}) is too high, resulting in 0 valid chains.")
        print("         Resetting burn_in to 0 for this run.")
        burn_in = 0

    valid_chains = chains[burn_in:, :]

    print(f"  Processing {valid_chains.shape[0]} valid chain iterations (after {burn_in} burn-in).")

    # Parameter names come from the covariance results
    thawed_parnames = covar_results.parnames

    # If the counts disagree, use at most one name per chain column
    if len(thawed_parnames) != n_params:
        print(f"Warning: Number of thawed params ({len(thawed_parnames)}) does not match chain dims ({n_params}).")
        thawed_parnames = covar_results.parnames[:n_params]

    # (low, median, high) for each named parameter's column
    bounds = [
        np.percentile(valid_chains[:, column], [q_low, 50, q_high])
        for column in range(len(thawed_parnames))
    ]

    mcmc_results = {
        'parnames': list(thawed_parnames),
        'parvals': [triple[1] for triple in bounds],   # median
        'parmins': [triple[0] for triple in bounds],   # lower bound
        'parmaxes': [triple[2] for triple in bounds],  # upper bound
    }

    return mcmc_results, valid_chains, thawed_parnames

In [None]:
def src_psf_images(obsid, infile, x0, y0, diameter, wcs_ra, wcs_dec, binsize=0.25, shape='square', psfimg=True, showimg=False, empirical_psf=None):
    """
    Create the source image (and optionally a PSF image) and load them into Sherpa.

    The source image is cut from `infile` with dmcopy and loaded with
    `load_data`.  The PSF is then taken from one of three paths:

    * `empirical_psf` given: the PSF is rotated to the observation roll
      (`rotate_psf_array`), given a WCS (`write_pixelscale`), cut out with
      dmcopy, reprojected onto the source image grid and loaded as the PSF.
    * otherwise, `psfimg` true: ``raytrace_projrays.fits`` in the
      observation directory is binned with dmcopy and loaded as the PSF.
    * otherwise: no PSF is loaded.

    Parameters
    ----------
    obsid :
        NOTE(review): the passed value is never used -- it is immediately
        replaced by the directory two levels above `infile`, which is also
        what ends up in the output paths and the Sherpa PSF model name.
    infile : str
        Event file to image.
    x0, y0 : float
        Region centre, used in the sky region filter.
    diameter : float
        Region diameter (circle) or side length (square).
    wcs_ra, wcs_dec :
        Reference sky position passed (as strings) to `write_pixelscale`.
    binsize : float
        Binning factor for the source image.
    shape : str
        'circle', 'square', or a user-defined region string.
    psfimg : bool
        Use the raytraced PSF when no empirical PSF is given.
    showimg : bool
        Display the loaded data/PSF with Sherpa's image viewer calls.
    empirical_psf : str, optional
        Path to an empirical PSF image.

    Returns
    -------
    float
        The `binsize` that was used.
    """

    # Define the region string based on shape
    shape_key = shape.lower()
    if shape_key == 'circle':
        region_str = f"circle({x0},{y0},{diameter/2})"
    elif shape_key == 'square':
        region_str = f"box({x0},{y0},{diameter},{diameter},0)"
    else:
        print('Shape is not circle or square, using user-defined region...')
        region_str = shape_key

    # Cutout region for the 512x512 logical PSF.  This must exist for every
    # shape: previously it was only set for 'square', so the empirical-PSF
    # path raised NameError for circles and user-defined regions.  A box of
    # the same width is used for the PSF cutout whatever the source shape.
    img_region_str = f"box(256.5,256.5,{diameter/binsize},{diameter/binsize},0)"

    # Observation directory: two levels above the event file
    obsid = os.path.dirname(os.path.dirname(infile))

    # Define all output filenames
    logical_width = diameter/binsize

    imagefile = f'{obsid}/src_image_{shape}_{int(logical_width)}pixel.fits'
    psf_imagefile = f'{obsid}/psf_image_{shape}_raytrace_{int(logical_width)}pixel.fits'  # Raytraced PSF
    psf_rotated = f'{obsid}/psf_rotated.fits'          # Rotated empirical PSF
    psf_rotated_cut = f'{obsid}/psf_rotated_cut.fits'  # Cut, rotated PSF
    emp_psf_imagefile = f'{obsid}/psf_image_{shape}_empirical_{int(logical_width)}pixel.fits'  # Final reprojected PSF

    # Reset CIAO tool parameters for a clean run and allow overwriting
    dmcopy.punlearn()
    dmcopy.clobber = 'yes'
    reproject_image.punlearn()
    reproject_image.clobber = 'yes'

    # Process and load the source image
    print(f"Creating source image: {imagefile}")
    dmcopy.infile = f'{infile}[sky={region_str}][bin x=::{binsize},y=::{binsize}]'
    dmcopy.outfile = imagefile
    dmcopy()

    # Load the data image into Sherpa
    load_data(imagefile)
    if showimg:
        image_close()
        image_data()

    # Main path: empirical PSF
    if empirical_psf is not None:
        # Rotate the empirical PSF to match the observation's roll angle
        rotate_psf_array(psf_file=empirical_psf, match_file=infile, outfile=psf_rotated)

        # Give the newly rotated PSF a WCS header; a failure here is only
        # reported and the remaining steps are still attempted
        try:
            with fits.open(psf_rotated) as hdu_rot:
                nx = hdu_rot[0].header['NAXIS1']
                ny = hdu_rot[0].header['NAXIS2']

            write_pixelscale(file=psf_rotated, nx=nx, ny=ny, ra=str(wcs_ra), dec=str(wcs_dec))

        except Exception as e:
            print(f"!!! ERROR during WCS stamping: {e}")

        # Cut out the rotated, WCS-stamped PSF
        # Note: it's binned by binsize*4 because the empirical PSF is 4x oversampled
        dmcopy.infile = f'{psf_rotated}[{img_region_str}][bin x=::{binsize*4},y=::{binsize*4}]'
        dmcopy.outfile = psf_rotated_cut
        dmcopy()

        # Reproject the cut PSF onto the data image's grid
        reproject_image.infile = psf_rotated_cut
        reproject_image.matchfile = imagefile
        reproject_image.outfile = emp_psf_imagefile
        reproject_image.method = 'sum'
        reproject_image()

        # Load the final reprojected PSF into Sherpa
        load_psf(f'centr_psf{obsid}', emp_psf_imagefile)
        set_psf(f'centr_psf{obsid}')
        print(f"Loaded empirical PSF: {emp_psf_imagefile}\n")

        if showimg:
            image_close()
            image_psf()

    # Raytraced PSF path
    elif psfimg:
        psf_infile = f'{obsid}/raytrace_projrays.fits'

        # NOTE(review): unlike the source image, this filter has no "sky="
        # prefix -- confirm that is intended for the raytrace file.
        dmcopy.infile = f'{psf_infile}[{region_str}][bin x=::{binsize},y=::{binsize}]'
        dmcopy.outfile = psf_imagefile
        dmcopy()

        load_psf(f'centr_psf{obsid}', psf_imagefile)
        set_psf(f'centr_psf{obsid}')

        if showimg:
            image_close()
            image_psf()

    # No PSF requested
    else:
        print("No PSF loaded.\n")

    return binsize

####################################################################################################################################################################################

def gaussian_image_fit(observation, n_components, position, ampl, fwhm,
                       background=0, pos_min=(0, 0), pos_max=None, exptime=None, lock_fwhm=True,
                       freeze_components=None, use_mcmc=True, mcmc_iter=5000, mcmc_burn_in_frac=0.2,
                       mcmc_scale=1.0, 
                       prefix="g", confirm=True, imgfit=False):
    """
    Fits multi-component 2D Gaussian models to image data, runs MCMC, 
    and generates summary plots and text.
    """
    
    # Helper function to expand single-value inputs
    def process_numeric_param(param, name):
        """Expand a single number to a list for all components."""
        if isinstance(param, (int, float)):
            return [param] * n_components
        elif isinstance(param, list):
            if len(param) != n_components:
                raise ValueError(f"The list of {name} values must have length equal to n_components.")
            return param
        else:
            raise ValueError(f"{name} must be either a number or a list of numbers.")

    # Helper function to expand single-tuple inputs
    def process_tuple_param(param, name):
        """Expand a single (x, y) tuple to a list for all components."""
        if isinstance(param, (tuple, list)) and len(param) == 2 and all(isinstance(x, (int, float)) for x in param):
            return [param] * n_components
        elif isinstance(param, list):
            if len(param) != n_components:
                raise ValueError(f"The list of {name} values must have length equal to n_components.")
            return param
        else:
            raise ValueError(f"{name} must be either a tuple (x, y) or a list of such tuples.")

    # Process all input parameters
    positions = process_tuple_param(position, "position")
    ampls = process_numeric_param(ampl, "ampl")
    fwhms = process_numeric_param(fwhm, "fwhm")
    pos_mins = process_tuple_param(pos_min, "pos_min")
    if pos_max is None:
        pos_maxs = [None] * n_components
    else:
        pos_maxs = process_tuple_param(pos_max, "pos_max")

    # Build the model expression
    comp_names = []
    gaussian_components = []
    model_components = []
    for i in range(1, n_components + 1):
        comp_name = f"{prefix}{i}"
        comp_names.append(comp_name)
        comp = gauss2d(comp_name)
        gaussian_components.append(comp)
        model_components.append(comp)
    
    # Add background component if requested
    bkg_comp = None
    if background > 0:
        bkg_comp = const2d("c1")
        model_components.append(bkg_comp)
    
    # Set the final model in Sherpa
    if model_components:
        set_source(sum(model_components))
    else:
        raise ValueError("Model expression is empty. Cannot set source.")

    # Assign parameters and constraints for each component
    freeze_list = (freeze_components if isinstance(freeze_components, list)
                   else ([freeze_components] if freeze_components is not None else []))
    for i, comp in enumerate(gaussian_components):
        comp_number = i + 1
        comp.xpos = positions[i][0]
        comp.ypos = positions[i][1]
        comp.ampl = ampls[i]
        comp.fwhm = fwhms[i]
        
        # Set parameter limits
        if hasattr(comp.xpos, 'min'): comp.xpos.min = pos_mins[i][0]
        if hasattr(comp.ypos, 'min'): comp.ypos.min = pos_mins[i][1]
        if pos_maxs[i] is not None:
            if hasattr(comp.xpos, 'max'): comp.xpos.max = pos_maxs[i][0]
            if hasattr(comp.ypos, 'max'): comp.ypos.max = pos_maxs[i][1]
        if hasattr(comp.ampl, 'min'): comp.ampl.min = 0
        
        # Freeze components if requested
        if comp_number in freeze_list:
            freeze(comp)
            print(f"Froze entire component {comp_number} ({comp.name}) as requested.")

    # Link FWHMs of all components to the first one
    central_component = 1
    if lock_fwhm:
        master = gaussian_components[central_component-1].fwhm
        for idx, comp in enumerate(gaussian_components):
            if idx != (central_component-1):
                link(comp.fwhm, master)

    # Set up background component
    if bkg_comp is not None:
        bkg_comp.c0 = background
        if hasattr(bkg_comp.c0, 'min'):
            bkg_comp.c0.min = 0

    # Confirm model with user
    if confirm:
        show_model()
        proceed = input("Proceed with fit? (y/n): ")
        if proceed.lower() != "y":
            print("Fit canceled.")
            return None, None, None # Return None for summary, fig, and corner_fig

    # Run global 'moncar' fit
    print("  Running global fit (moncar)...")
    set_stat('cstat')
    set_method('moncar')
    set_method_opt('numcores', 12)
    set_method_opt('population_size', 10 * 16 * (n_components * 3 + 1))
    set_method_opt('xprob', 0.5)
    set_method_opt('weighting_factor', 0.5)
    fit()
    
    # Run local 'simplex' fit to polish the result
    print("  Polishing fit (simplex)...")
    set_method('simplex')
    fit()
    
    fit_results = get_fit_results()

    # MCMC Error Estimation
    mcmc_results = None
    corner_fig = None 
    mcmc_duration_str = ""
    if use_mcmc:
        # Start timer
        mcmc_start_time = time.time()
        print(f"-->Running MCMC for error estimation with {mcmc_iter} iterations...")
        
        # Thaw parameters for MCMC run
        for i, comp in enumerate(gaussian_components):
            comp_number = i + 1
            if comp_number not in freeze_list:
                thaw(comp.ampl)
                if not (lock_fwhm and comp_number != central_component):
                     thaw(comp.fwhm)
                thaw(comp.xpos, comp.ypos)
        
        # Set the sampler to Metropolis-Hastings
        set_sampler('metropolismh')
        set_sampler_opt('scale', mcmc_scale)
        
        # Calculate covariance matrix (required by sampler)
        print("  Calculating covariance matrix for MCMC proposal...")
        covar()
        
        # Get covariance results (for parameter names)
        covar_results = get_covar_results()
        
        # Run the MCMC chain
        stats, accept, chains = get_draws(niter=mcmc_iter)
        print(f"  MCMC complete. Overall acceptance rate: {np.mean(accept):.3f}")

        # Process the raw chains into a results dictionary
        mcmc_results, valid_chains, thawed_parnames = process_mcmc_results(
            covar_results,
            chains,
            burn_in_frac=mcmc_burn_in_frac,
            sigma=1
        )
        
        # Generate corner plot
        print(f"  Generating corner plot for prefix '{prefix}'...")
        corner_fig = corner.corner(
            valid_chains,
            labels=thawed_parnames,
            quantiles=[0.16, 0.5, 0.84],
            show_titles=True,
            title_fmt=".3f"
        )
        corner_fig.suptitle(f"MCMC Corner Plot - ObsID {observation} ({prefix})", y=1.02)

        # Stop timer and format duration string
        mcmc_end_time = time.time()
        mcmc_duration = mcmc_end_time - mcmc_start_time
        mcmc_duration_min = mcmc_duration / 60.0
        print(f"  MCMC block execution time: {mcmc_duration_min:.2f} minutes")
        mcmc_duration_str = f"MCMC block execution time = {mcmc_duration_min:.2f} minutes\n\n"

    # Optional: Display imgfit in ds9
    if imgfit:
        print("\nDisplaying image_fit() in ds9...")
        image_fit()
        input("  `imgfit` is active. Press Enter in this terminal to continue...")

    # Build Fit Result Summary
    fit_summary = (
        f"Method = {fit_results.methodname}\n"\
        f"Statistic = {fit_results.statname}\n"\
        f"Initial fit statistic = {fit_results.istatval:.2f}\n"\
        f"Final fit statistic = {fit_results.statval:.2f} at function evaluation {fit_results.nfev}\n"\
        f"Data points = {fit_results.numpoints}\n"\
        f"Degrees of freedom = {fit_results.dof}\n"\
        f"Probability [Q-value] = {fit_results.qval}\n"\
        f"Reduced statistic = {fit_results.rstat:.5f}\n"\
        f"Change in statistic = {fit_results.dstatval:.2f}\n\n"\
    )
    
    # Helper for formatting parameter values
    def fmt_val(val, width=10, prec=3):
        if val is None:
            return "------".rjust(width)
        return f"{val:>{width}.{prec}f}"
        
    # Build the parameter table (MCMC results)
    if mcmc_results is not None:
        param_table_rows = [
            f"MCMC ({mcmc_burn_in_frac*100:.0f}% burn-in) 1-sigma bounds:",
            f"{'Param':<12} {'Median':>10} {'Lower':>10} {'Upper':>10}",
            f"{'-'*5:<12} {'-'*8:>10} {'-'*5:>10} {'-'*5:>10}"
        ]
        for name, best, low, high in zip(mcmc_results['parnames'], 
                                         mcmc_results['parvals'], 
                                         mcmc_results['parmins'], 
                                         mcmc_results['parmaxes']):
            param_table_rows.append(
                f"{name:<12} {fmt_val(best)} {fmt_val(low)} {fmt_val(high)}"
            )
        param_table = "\n".join(param_table_rows)
    # Build the parameter table (no MCMC)
    else:
        param_table_rows = [
            "Best-Fit Parameter Values (No MCMC):",
            f"{'Param':<12} {'Best-Fit':>10}",
            f"{'-'*5:<12} {'-'*8:>10}"
        ]
        for name, best in zip(fit_results.parnames, fit_results.parvals):
            param_table_rows.append(f"{name:<12} {fmt_val(best)}")
        param_table = "\n".join(param_table_rows)
            
    # Combine all text components
    summary_output = fit_summary + mcmc_duration_str + param_table + '\n'

    # Build component count rate block
    if exptime and use_mcmc and mcmc_results is not None:
        rate_block_rows = ["Component count rates (counts/s):"]
        parnames = mcmc_results['parnames']
        parvals  = mcmc_results['parvals']
        parmins  = mcmc_results['parmins']
        parmaxes = mcmc_results['parmaxes']
        
        # Loop over each Gaussian component
        for comp in gaussian_components:
            comp_img   = get_model_component_image(comp.name)
            total_cts  = comp_img.y.sum()
            rate       = total_cts / exptime
            short      = comp.name.split('.')[-1]
            amp_name   = f"{short}.ampl"
            fwhm_name  = f"{short}.fwhm"
            
            # Get amplitude param values
            if amp_name in parnames:
                a_idx      = parnames.index(amp_name)
                A_best     = parvals[a_idx]
                dA_minus_val = parmins[a_idx]
                dA_plus_val  = parmaxes[a_idx]
                dA_minus = abs(A_best - dA_minus_val) if dA_minus_val is not None else 0
                dA_plus  = abs(dA_plus_val - A_best)  if dA_plus_val  is not None else 0
            else:
                A_best = 1; dA_minus = 0; dA_plus = 0
                
            # Get FWHM param values
            if fwhm_name in parnames:
                f_idx      = parnames.index(fwhm_name)
                F_best     = parvals[f_idx]
                dF_minus_val = parmins[f_idx]
                dF_plus_val  = parmaxes[f_idx]
                dF_minus = abs(F_best - dF_minus_val) if dF_minus_val is not None else 0
                dF_plus  = abs(dF_plus_val - F_best)  if dF_plus_val  is not None else 0
            else:
                F_best = 1; dF_minus = 0; dF_plus = 0
                
            # Propagate errors for count rate (CR ~ A * FWHM^2)
            frac_minus = np.sqrt((dA_minus/A_best)**2 + (2*dF_minus/F_best)**2) if A_best > 0 and F_best > 0 else 0
            frac_plus  = np.sqrt((dA_plus /A_best)**2 + (2*dF_plus /F_best)**2) if A_best > 0 and F_best > 0 else 0
            dR_minus = (total_cts * frac_minus) / exptime
            dR_plus  = (total_cts * frac_plus)  / exptime
            
            # Append formatted string
            rate_block_rows.append(
                f"  {short:<6}: {rate:7.4f}  "\
                f"(-{dR_minus:6.4f}/+{dR_plus:6.4f})"\
            )
        summary_output += "\n" + "\n".join(rate_block_rows) + "\n"
    else:
        summary_output = fit_summary + param_table + '\n\n\n\n'

    # Get images for plotting
    plot_options = ["data_fit", "model", "deviance"]
    n_plots = len(plot_options)
    fig, axs = plt.subplots(1, n_plots, figsize=(10 * n_plots, 5 * n_plots))
    if n_plots == 1:
        axs = [axs]
    plot_idx = 0
    data_img = get_data_image()
    data_vals = data_img.y
    
    # Set a display floor to avoid log(0)
    min_positive_val = np.min(data_vals[data_vals > 0]) if np.any(data_vals > 0) else 1e-9
    display_floor = min_positive_val / 10.0
    data_masked = np.maximum(data_vals, display_floor) 
    
    model_img = get_model_image()
    model_vals = model_img.y
    model_masked = np.maximum(model_vals, display_floor)
    
    # Calculate C-stat deviance residuals
    d_vals = data_vals
    m_vals = model_vals
    D = 2.0 * (data_masked * np.log(data_masked / model_masked) - (data_masked - model_masked))
    D = np.where(m_vals <= 0, 2.0 * d_vals, D)
    D = np.where((m_vals > 0) & (d_vals <= 0), 2.0 * m_vals, D)
    resid_dev = np.sign(data_vals - model_vals) * np.sqrt(np.abs(D))
    
    # Set log normalization for plots
    vmax_display = np.max(data_vals)
    log_norm = mcolors.LogNorm(
        vmin=display_floor, 
        vmax=vmax_display if vmax_display > display_floor else display_floor + 1
    )
    
    # Plot data and fit contours
    if "data_fit" in plot_options:
        ax = axs[plot_idx]
        im = ax.imshow(data_masked, origin='lower', cmap='gnuplot2', norm=log_norm,
                       interpolation='nearest')

        legend_elements = []
        # Use visible colors for contours
        base_colors = ['yellow', 'cyan', 'lime', 'xkcd:light lavender']
        linestyles = ['--', ':', '-.']
        
        for i, comp_name in enumerate(comp_names):
            comp_vals = get_model_component_image(comp_name).y
            
            # Skip any all-zero (frozen) components
            if not np.any(comp_vals > 0): continue
            
            color = base_colors[i % len(base_colors)]
            linestyle = '--' if i < len(base_colors) else linestyles[(i // len(base_colors)) % len(linestyles)]

            # Plot multi-component contours
            if n_components > 1:
                level = 0.2 * np.max(comp_vals)
                ax.contour(comp_vals, levels=[level], colors=[color],
                           linestyles=linestyle, linewidths=2)
            # Plot single-component contours
            else:
                levels = np.linspace(np.min(comp_vals), np.max(comp_vals), 6)
                if len(np.unique(levels)) > 1:
                    ax.contour(comp_vals, levels=levels[1:], colors=[color],
                               linestyles=linestyle, linewidths=2)

            legend_elements.append(Line2D([0], [0], lw=2, linestyle=linestyle,
                                          color=color, label=f"{comp_name}"))
        if legend_elements:
            ax.legend(handles=legend_elements, loc='upper right')
        ax.set_title(f"{observation} Data + Fit Overlay")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="Counts", shrink=0.53)
        plot_idx += 1

    # Plot model
    if "model" in plot_options:
        ax = axs[plot_idx]
        im = ax.imshow(model_masked, origin='lower', cmap='gnuplot2', norm=log_norm,
                       interpolation='nearest')
        ax.set_title("Model")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="Model Counts", shrink=0.53)
        plot_idx += 1

    # Plot deviance
    if "deviance" in plot_options:
        ax = axs[plot_idx]
        im = ax.imshow(np.abs(resid_dev), origin='lower', cmap='gnuplot2',
                       norm=mcolors.Normalize(vmin=0, vmax=5),
                       interpolation='nearest')
        ax.set_title("Poisson Deviance Residuals")
        ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
        fig.colorbar(im, ax=ax, label="|Residuals|", shrink=0.53)
        plot_idx += 1
        
    # Clean up and close plot
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.close(fig)

    # Return all results
    return summary_output, fig, corner_fig

In [None]:
# Driver cell: for every observation, run the three-stage 2D image fit
# (centroid fit, single-component source fit, multi-component fit with MCMC),
# write the text summaries, and compile the saved plots into PDFs.

# Set the main working directory
os.chdir('/Users/leodrake/Documents/MIT/ss433/HRC_2024/')

# Define WCS coordinates for each observation, keyed by ObsID.  Each value is
# the (RA, Dec) string pair passed to src_psf_images as wcs_ra / wcs_dec.
obsid_coords = {
    "26568": ("287.9565362", "4.9826061"),
    "26569": ("287.9563218", "4.9827745"),
    "26570": ("287.9563754", "4.9825322"),
    "26571": ("287.9561693", "4.9827006"),
    "26572": ("287.9565032", "4.9826636"),
    "26573": ("287.9565444", "4.9826390"),
    "26574": ("287.9562518", "4.9825651"),
    "26575": ("287.9566969", "4.9828114"),
    "26576": ("287.9566351", "4.9826718"),
    "26577": ("287.9565238", "4.9826020"),
    "26578": ("287.9566021", "4.9826800"),
    "26579": ("287.9565733", "4.9825774"),
}

# Define tuned MCMC scale factors for each observation.  ObsIDs missing from
# this table fall back to 1.0 inside the main loop.
mcmc_scale_factors = {
    "26568": 0.4,   # Rate 21.2%
    "26569": 0.03,  # Rate was 12.6% @ 0.05 trying 0.03
    "26570": 0.25,  # Rate 21.8%
    "26571": 0.03,  # Rate was 12.6% @ 0.05 trying 0.03
    "26572": 0.1,   # Rate 21.1%
    "26573": 0.25,  # Rate 23.0%
    "26574": 0.5,   # Rate 22.3%
    "26575": 0.2,   # Rate 20.2%
    "26576": 0.3,   # Rate 20.6%
    "26577": 0.5,   # Rate 25.2%
    "26578": 0.6,   # Rate 26.4%
    "26579": 0.4,   # Rate 20.1%
}

# Define the empirical PSF file to be used
emp_psf_file = "/Users/leodrake/Documents/MIT/ss433/HRC_2024/empPSF_iARLac_v2025_2017-2025.fits"

# Find all event files to be processed
event_files = sorted(glob.glob('*/repro/*splinecorr.fits'))

# Create lists to hold temporary plot filenames
pdf_out_files = []
multi_pdf_out_files = []

# Define output PDF and text file names
pdf_out_filename = '2Dfits/0fit-plots.pdf'
multi_pdf_out_filename = '2Dfits/0multi-comp-plots.pdf'
results_filename = '2Dfits/0fit-results.txt'
multi_results_filename = '2Dfits/0multi-comp-fit-results.txt'

# Make sure the output directory exists before any file is opened inside it
os.makedirs('2Dfits', exist_ok=True)

# Open the main results file and the multi-component-only results file
with open(results_filename, 'w') as results_file, open(multi_results_filename, 'w') as multi_results_file:

    # Start main loop over each event file
    for infile in event_files:

        # Get ObsID from file path (the directory two levels above the file)
        obsid = os.path.dirname(os.path.dirname(infile))
        print(f'\nProcessing {obsid}\n')

        # Get ObsID-specific coordinates
        if obsid not in obsid_coords:
            print(f"!!! WARNING: ObsID {obsid} not in coordinate lookup table. Skipping.")
            continue
        current_ra, current_dec = obsid_coords[obsid]

        # Get ObsID-specific MCMC scale factor (default to 1.0)
        current_mcmc_scale = mcmc_scale_factors.get(obsid, 1.0)
        print(f"  Using MCMC scale factor: {current_mcmc_scale}")

        # Run initial data extraction and centroiding
        date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs = data_extract_quickpos_iter(infile)

        # Write the header block to the main results file
        header_text = (
            f"Observation: {obsid}\n"
            f"Infile: {infile}\n"
            f"Date: {date}, Exptime: {exptime}\n"
        )
        results_file.write(header_text)

        # Stage 1: Centroid Fit
        img_width = 40  # physical pixels
        cent_binsize = 1.0  # 1.0 pixel bins

        # Load image for centroiding (no PSF)
        src_psf_images(
            obsid, infile, pixel_x0_best, pixel_y0_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec,
            binsize=cent_binsize,
            psfimg=False,
            empirical_psf=None
        )

        # Set logical image dimensions
        logical_width = img_width / cent_binsize
        img_center = logical_width / 2.0 + 0.5

        print('Centroiding...\n')

        # Call the fit function (1 component, no MCMC)
        centroid_fit_summary, centroid_fit_fig, cent_corner_fig = gaussian_image_fit(
            obsid, 1, (img_center, img_center), cnt, (1.0 / cent_binsize),
            prefix="centrg",
            background=0.1,
            pos_max=(logical_width, logical_width),
            use_mcmc=False,
            confirm=False
        )

        # A None summary is treated as a canceled fit: reset Sherpa and move on
        if centroid_fit_summary is None:
            print(f"Centroid fit for {obsid} canceled. Skipping observation.")
            clean()
            continue

        # Save plot to a temporary PNG
        temp_cent_fit_png = f"2Dfits/temp_{obsid}_cent_fit.png"
        centroid_fit_fig.savefig(temp_cent_fit_png)
        plt.close(centroid_fit_fig)
        pdf_out_files.append(temp_cent_fit_png)

        # Write summary to text file
        results_file.write("\nCENTROID FIT SUMMARY:\n\n")
        results_file.write(centroid_fit_summary)

        # Convert the best-fit centroid from logical to physical coordinates
        # using the transform stored on the loaded data set
        d = get_data()
        crval_x, crval_y = d.sky.crval
        crpix_x, crpix_y = d.sky.crpix
        cdelt_x, cdelt_y = d.sky.cdelt
        xphys_best = crval_x + (centrg1.xpos.val - crpix_x) * cdelt_x
        yphys_best = crval_y + (centrg1.ypos.val - crpix_y) * cdelt_y

        # Stage 2: Single-Component Source Fit
        img_width = 10  # physical pixels
        src_binsize = 0.25  # 1/4 pixel bins

        # Load image and empirical PSF
        src_psf_images(
            obsid, infile, xphys_best, yphys_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec,
            binsize=src_binsize,
            psfimg=True,
            empirical_psf=emp_psf_file
        )

        # Set logical image dimensions
        logical_width = img_width / src_binsize
        img_center = logical_width / 2.0 + 0.5

        # Scale initial guesses to the new binning
        pixel_scale_guess = 1.0 / src_binsize
        scaled_cnt_guess = cnt / (pixel_scale_guess**2)
        scaled_fwhm_guess = 1.0 * pixel_scale_guess

        print('Fitting Source...\n')

        # Call the fit function (1 component, no MCMC)
        src_fit_summary, src_fit_fig, src_corner_fig = gaussian_image_fit(
            obsid, 1, (img_center, img_center), scaled_cnt_guess, scaled_fwhm_guess,
            prefix="srcg",
            pos_max=(logical_width, logical_width),
            use_mcmc=False,
            confirm=False
        )

        # A None summary is treated as a canceled fit: reset Sherpa and move on
        if src_fit_summary is None:
            print(f"Source fit for {obsid} canceled. Skipping observation.")
            clean()
            continue

        # Save plot to temporary PNG
        temp_src_fit_png = f"2Dfits/temp_{obsid}_src_fit.png"
        src_fit_fig.savefig(temp_src_fit_png)
        plt.close(src_fit_fig)
        pdf_out_files.append(temp_src_fit_png)

        # Write summary to text file
        results_file.write("SOURCE FIT SUMMARY:\n\n")
        results_file.write(src_fit_summary)

        # Stage 3: Multi-Component Fit

        # Get scaling info from the single-component fit.  This must be read
        # before the next image is loaded, while img_center still refers to
        # the Stage 2 image.
        srcfit_off_x = srcg1.xpos.val - img_center
        srcfit_off_y = srcg1.ypos.val - img_center
        src_ampl = srcg1.ampl.val
        src_fwhm = srcg1.fwhm.val

        # Set up image properties
        img_width = 40
        multi_binsize = 0.25

        # Load image and empirical PSF (psfimg is not passed here, so
        # src_psf_images uses its own default for it)
        src_psf_images(
            obsid, infile, xphys_best, yphys_best, img_width,
            wcs_ra=current_ra, wcs_dec=current_dec,
            binsize=multi_binsize,
            empirical_psf=emp_psf_file
        )

        # Set logical image dimensions
        logical_width = img_width / multi_binsize
        img_center = logical_width / 2.0 + 0.5

        # Calculate pixel scale factor between the Stage 2 and Stage 3 binning
        pixel_scale = src_binsize / multi_binsize

        # Scale parameters for g1 (from src_fit)
        new_xpos = img_center + (srcfit_off_x * pixel_scale)
        new_ypos = img_center + (srcfit_off_y * pixel_scale)
        scaled_src_fwhm = src_fwhm * pixel_scale
        scaled_src_ampl = src_ampl / (pixel_scale**2)

        # Scale parameters for g2, g3... (from quickpos cnt)
        pixel_scale_guess = 1.0 / multi_binsize
        scaled_cnt_ampl = cnt / (pixel_scale_guess**2)
        scaled_default_fwhm = 1.0 * pixel_scale_guess

        # Set component lists for the fit: the first component starts from the
        # Stage 2 result, the remaining ones start at the image center
        n_components = 3  # total number of components (change as needed)
        positions = [(new_xpos, new_ypos)] + [(img_center, img_center)] * (n_components - 1)
        amplitudes = [scaled_src_ampl] + [scaled_cnt_ampl] * (n_components - 1)
        fwhms = [scaled_src_fwhm] + [scaled_default_fwhm] * (n_components - 1)

        print('Fitting multi-component Gaussian...')

        # Call the fit function (N components, with MCMC)
        multi_fit_summary, multi_fit_fig, multi_corner_fig = gaussian_image_fit(
            obsid, n_components, positions, amplitudes, fwhms,
            prefix="g",
            background=0.1,
            pos_max=(logical_width, logical_width),
            pos_min=(0, 0),
            exptime=exptime,
            confirm=False,
            use_mcmc=True,
            mcmc_iter=100000,
            mcmc_scale=current_mcmc_scale
        )

        # A None summary is treated as a canceled fit: reset Sherpa and move on
        if multi_fit_summary is None:
            print(f"Multi-component fit for {obsid} canceled. Skipping observation.")
            clean()
            continue

        print('\nPlots saving to PDF...\n')

        # Save fit plot to temp PNG
        temp_multi_fit_png = f"2Dfits/temp_{obsid}_multi_fit.png"
        multi_fit_fig.savefig(temp_multi_fit_png)
        plt.close(multi_fit_fig)
        pdf_out_files.append(temp_multi_fit_png)
        multi_pdf_out_files.append(temp_multi_fit_png)

        # Save corner plot to temp PNG
        if multi_corner_fig is not None:
            temp_multi_corner_png = f"2Dfits/temp_{obsid}_multi_corner.png"
            multi_corner_fig.savefig(temp_multi_corner_png)
            plt.close(multi_corner_fig)
            pdf_out_files.append(temp_multi_corner_png)
            multi_pdf_out_files.append(temp_multi_corner_png)

        # Save multi-component fit parameters
        results_file.write("MULTI-COMPONENT FIT SUMMARY:\n\n")
        results_file.write(multi_fit_summary)

        # Aggregate text block for writing to separate file
        multi_results_text = (
            f"Observation: {obsid}\n"
            f"Infile: {infile}\n"
            f"Date: {date}, Exptime: {exptime}\n"
            f"{multi_fit_summary}\n\n"
        )
        multi_results_file.write(multi_results_text)

        # Clear the current Sherpa session to prepare for the next observation
        clean()
        print('Sherpa Session Cleaned\n\n')

# Final compilation and cleanup block
print('Tidying Up and Compiling PDFs...\n')


def compile_pngs_to_pdf(png_files, pdf_filename):
    """
    Compile a list of PNG files into a single multi-page PDF.

    Files that do not exist on disk are skipped with a warning; the PDF is
    built from the remaining files in their original order.

    Parameters
    ----------
    png_files : list of str
        Paths of the PNG images, one per PDF page.
    pdf_filename : str
        Path of the PDF file to write.

    Returns
    -------
    bool
        True if a PDF was written, False if there was nothing to compile.
    """
    if not png_files:
        print(f"No images to compile for {pdf_filename}.")
        return False

    # Keep only the files that are actually present, so that one missing
    # image (including the first) does not abort the whole PDF
    available = []
    for png_file in png_files:
        if os.path.exists(png_file):
            available.append(png_file)
        else:
            print(f"Warning: Missing file {png_file}, skipping.")

    if not available:
        print(f"ERROR: Cannot find any of the image files to build {pdf_filename}.")
        return False

    pages = []
    try:
        for png_file in available:
            # convert() returns an independent in-memory copy, so the file
            # handle can be released as soon as the page has been converted
            with Image.open(png_file) as img:
                pages.append(img.convert('RGB'))

        # Save the final PDF: first page plus all remaining pages appended
        pages[0].save(pdf_filename, "PDF", resolution=100.0, save_all=True,
                      append_images=pages[1:])
    finally:
        for page in pages:
            page.close()

    return True


# Compile the main PDF (all plots) and the multi-component-only PDF.
# Failures are reported but do not stop the cleanup below.
for png_list, pdf_name in (
    (pdf_out_files, pdf_out_filename),
    (multi_pdf_out_files, multi_pdf_out_filename),
):
    try:
        if compile_pngs_to_pdf(png_list, pdf_name):
            print(f"Successfully compiled {pdf_name}")
    except Exception as e:
        print(f"ERROR: Could not compile {pdf_name}: {e}")

# Final, separate cleanup step
print("Cleaning up temporary PNG files...")
for f in glob.glob("2Dfits/temp_*.png"):
    try:
        os.remove(f)
    except OSError as e:
        print(f"Warning: Could not remove {f}: {e}")

print('Process Complete')