In [None]:
import os
import glob
import logging
import time
import datetime
from functools import partial

# Core scientific and plotting libraries
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from scipy.stats import chi2
from scipy.optimize import curve_fit
from scipy.ndimage import rotate

# MCMC, PDF compilation, and Parallelization
import corner
from PIL import Image
from multiprocess import Pool, Manager
from tqdm.notebook import tqdm

# Notebook-wide Matplotlib defaults: high-resolution figures, larger text,
# and images drawn with row 0 at the bottom.
plt.rcParams['figure.dpi'] = 400
plt.rcParams['font.size'] = 18
plt.rcParams['image.origin'] = 'lower'

# Raise the "sherpa" logger threshold so only warnings and errors are shown
logger = logging.getLogger("sherpa")
logger.setLevel(logging.WARNING)

In [None]:
def quickpos(x, y, x0, y0, iterations=1, size_list=None, binsize_list=None, doplot=False):
    """
    Iteratively refines the centroid position using 1D histogram fitting.

    On every iteration the points inside a square window around the current
    centroid are histogrammed separately along X and Y, a Gaussian plus a
    constant offset is fitted to each histogram, and the fitted means become
    the centroid used by the next iteration.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points (same length).
    x0, y0 : float
        Starting guess for the centroid.
    iterations : int, optional
        Number of refinement passes.
    size_list : sequence of float, optional
        Window half-width for each pass; needs at least `iterations` entries.
        Defaults to the full X range of the data for every pass.
    binsize_list : sequence of float, optional
        Histogram bin width for each pass; needs at least `iterations`
        entries. Defaults to 1.0 for every pass.
    doplot : bool, optional
        Accepted for interface compatibility; no figures are produced.

    Returns
    -------
    best_x0, best_y0 : float
        Refined centroid. An axis whose fit fails keeps its previous value.
    cnt : float or None
        Gaussian-area estimate (amplitude * |sigma| * sqrt(2*pi)) from the
        last pass that had points in its window, averaged over the axes whose
        fit succeeded; 0 if both fits failed; None if no pass found any point.
    fig_list : list
        Always an empty list.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Nothing to fit: hand back the starting guess instead of failing on
    # np.max/np.min of an empty array when building the default window.
    if x.size == 0:
        return x0, y0, None, []

    # Set default window/bin sizes if none are provided
    if size_list is None:
        size_list = [np.max(x) - np.min(x)] * iterations
    if binsize_list is None:
        binsize_list = [1.0] * iterations

    # 1D Gaussian plus constant offset used for both axes
    def gaussian(coord, a, mu, sigma, offset):
        return a * np.exp(-((coord - mu)**2) / (2 * sigma**2)) + offset

    # Histogram `values` around `center` and fit the Gaussian model.
    # Returns (fitted mean, fitted area) or None when the fit fails.
    def fit_axis(values, center, size, binsize):
        edges = np.arange(center - size, center + size + binsize, binsize)
        hist, edges = np.histogram(values, bins=edges)
        centers = 0.5 * (edges[:-1] + edges[1:])
        try:
            guess = [np.max(hist), centers[np.argmax(hist)], 2 * binsize, 0]
            par, _ = curve_fit(gaussian, centers, hist, p0=guess)
        except Exception:
            # Best effort by design: a failed fit leaves this axis unchanged.
            return None
        # The model depends on sigma**2 only, so the optimiser may return a
        # negative sigma; use its magnitude so the area is never negative.
        return par[1], par[0] * abs(par[2]) * np.sqrt(2 * np.pi)

    fig_list = []
    current_x0, current_y0 = x0, y0
    best_x0, best_y0 = current_x0, current_y0
    cnt = None

    # Loop for each refinement iteration
    for i in range(iterations):
        size = size_list[i]
        binsize = binsize_list[i]

        # Points inside the square window centred on the current centroid
        in_window = (np.abs(x - current_x0) < size) & (np.abs(y - current_y0) < size)
        if not in_window.any():
            continue

        x_fit = fit_axis(x[in_window], current_x0, size, binsize)
        y_fit = fit_axis(y[in_window], current_y0, size, binsize)

        if x_fit is not None:
            best_x0 = x_fit[0]
        if y_fit is not None:
            best_y0 = y_fit[0]

        # Average only the axes that actually produced a fit, so a single
        # failed fit does not halve the estimate.
        if x_fit is not None and y_fit is not None:
            cnt = 0.5 * (x_fit[1] + y_fit[1])
        elif x_fit is not None:
            cnt = x_fit[1]
        elif y_fit is not None:
            cnt = y_fit[1]
        else:
            cnt = 0

        current_x0, current_y0 = best_x0, best_y0

    return best_x0, best_y0, cnt, fig_list

####################################################################################################################################################################################

def data_extract_quickpos_iter(infile, iters=3, sizes=[10, 5, 1.5], binsizes=[0.1, 0.1, 0.05]):
    """
    Read an event list from a FITS file and estimate the source centroid.

    Event positions from HDU 1 are converted to arcsec offsets from the
    reference pixel, events within 20 arcsec of that pixel are kept, and their
    mean position seeds `quickpos`, which refines it over `iters` passes using
    the window sizes in `sizes` and bin widths in `binsizes`.

    Returns a tuple of:
      - mid-exposure date (header 'mjd-obs' plus half of tstop - tstart, in days),
      - the header 'exposure' value,
      - the refined centroid X and Y converted back to pixels,
      - the count estimate returned by `quickpos`,
      - the figure list returned by `quickpos`.

    NOTE(review): the Y axis reuses the X reference pixel and plate scale
    (keywords 'tcrpx20' / 'tcdlt20') - confirm that this is intended.
    """
    with fits.open(infile) as hdul:
        events_hdu = hdul[1]
        header = events_hdu.header
        events = events_hdu.data

        # Plate scale, reference pixel and exposure time
        scale = header['tcdlt20']
        xc = header['tcrpx20']
        exptime = header['exposure']

        # Mid-exposure date: start MJD plus half the observation span (s -> days)
        date = header['mjd-obs'] + 0.5 * (header['tstop'] - header['tstart']) / 86400

        # Offsets from the reference pixel, in arcsec
        x = (events['x'] - xc) * scale * 3600
        y = (events['y'] - xc) * scale * 3600

        # Keep only events within 20 arcsec of the reference pixel
        within_20_arcsec = np.sqrt(x**2 + y**2) < 20
        x_sel = x[within_20_arcsec]
        y_sel = y[within_20_arcsec]

        # Mean position of the selected events is the first centroid guess
        x0_est = np.average(x_sel)
        y0_est = np.average(y_sel)

    # Iterative centroid refinement
    x0_best, y0_best, cnt, qp_figs = quickpos(x_sel, y_sel, x0_est, y0_est, iters, sizes, binsizes)

    # Back from arcsec offsets to pixel coordinates
    arcsec_per_pixel = scale * 3600
    pixel_x0_best = x0_best / arcsec_per_pixel + xc
    pixel_y0_best = y0_best / arcsec_per_pixel + xc

    return date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs

####################################################################################################################################################################################

def rotate_psf_array(psf_file, match_file, outfile):
    """
    Rotate a PSF image to an observation's roll angle and write it to disk.

    The roll angle is read from `match_file`, trying 'ROLL_NOM' in HDU 0,
    then 'ROLL_NOM' in HDU 1, then 'ROLL_PNT' in HDU 1. The PSF image comes
    from HDU 0 of `psf_file` (HDU 1 when HDU 0 holds no data), is rotated by
    (roll - 45) degrees without changing the array shape, and is written to
    `outfile` as a primary HDU carrying the PSF's own header plus two HISTORY
    cards describing the rotation.

    Every failure is reported with `print`; read failures end the function
    early. Nothing is returned in any case.
    """
    # (HDU index, keyword) pairs in order of preference
    roll_sources = ((0, 'ROLL_NOM'), (1, 'ROLL_NOM'), (1, 'ROLL_PNT'))

    # Look up the roll angle in the observation's event file
    try:
        with fits.open(match_file) as match_hdul:
            for hdu_index, keyword in roll_sources:
                candidate_header = match_hdul[hdu_index].header
                if candidate_header and keyword in candidate_header:
                    roll_nom = candidate_header[keyword]
                    break
            else:
                print(f"  ERROR: Could not find 'ROLL_NOM' or 'ROLL_PNT' in {match_file}.")
                return
    except FileNotFoundError:
        print(f"  ERROR: Match file not found: {match_file}")
        return
    except Exception as e:
        print(f"  ERROR: Could not read match file header: {e}")
        return

    # Load the PSF image and header, falling back to HDU 1 when HDU 0 is empty
    try:
        with fits.open(psf_file) as psf_hdul:
            source_hdu = psf_hdul[1] if psf_hdul[0].data is None else psf_hdul[0]
            psf_data = source_hdu.data
            psf_header = source_hdu.header

            if psf_data is None:
                print(f"ERROR: No image data found in HDU 0 or 1 of {psf_file}.")
                return
    except FileNotFoundError:
        print(f"  ERROR: PSF file not found: {psf_file}")
        return
    except Exception as e:
        print(f"  ERROR: Could not read PSF file data/header: {e}")
        return

    # Rotation applied to the PSF (scipy rotates counter-clockwise).
    # NOTE(review): the 45 deg offset is presumably the PSF's native
    # orientation - confirm.
    angle_to_rotate = roll_nom - 45.0

    # Same array shape, cubic interpolation, zeros where no data maps in
    rotated_psf_data = rotate(psf_data, angle_to_rotate, reshape=False, cval=0.0, order=3)

    # Record what was done in the header before saving
    timestamp = datetime.datetime.now().isoformat()
    for history_line in (
        f"Rotated by {angle_to_rotate:.4f} deg (ROLL_NOM={roll_nom:.4f} - 45.0)",
        f"Rotation applied by script on {timestamp}",
    ):
        psf_header.add_history(history_line)

    # Write the rotated image, replacing any existing file
    rotated_hdu = fits.PrimaryHDU(data=rotated_psf_data, header=psf_header)
    try:
        rotated_hdu.writeto(outfile, overwrite=True)
    except Exception as e:
        print(f"  ERROR: Could not write output file: {e}")

####################################################################################################################################################################################

def process_mcmc_results(covar_results, chains, burn_in_frac=0.2, sigma=1):
    """
    Summarise MCMC chains as a median and percentile bounds per parameter.

    `chains` is transposed on entry, so it is expected as (nparams, niter).
    The first `burn_in_frac` of the iterations is dropped (nothing is dropped
    if that would remove every iteration). Parameter names come from
    `covar_results.parnames`; if their count differs from the number of chain
    columns, the name list is cut to the first `n_params` entries.

    Returns
    -------
    mcmc_results : dict
        Lists under 'parnames', 'parvals' (median), 'parmins' (lower
        percentile) and 'parmaxes' (upper percentile).
    valid_chains : ndarray
        The (niter_kept, nparams) samples left after burn-in.
    thawed_parnames : sequence
        The parameter names that were used.

    Raises
    ------
    ValueError
        If `sigma` is not 1, 2 or 3.
    """
    # (sigma level, lower percentile, upper percentile)
    percentile_table = (
        (1, 15.865, 84.135),   # 68.27%
        (2, 2.275, 97.725),    # 95.45%
        (3, 0.135, 99.865),    # 99.73%
    )

    # Work with rows = iterations, columns = parameters
    samples = chains.T

    for level, lower, upper in percentile_table:
        if sigma == level:
            q_low, q_high = lower, upper
            break
    else:
        raise ValueError("Only sigma=1, 2, or 3 is supported.")

    # Number of leading iterations to throw away; never discard everything
    n_iter, n_params = samples.shape
    burn_in = int(n_iter * burn_in_frac)
    if not burn_in < n_iter:
        burn_in = 0
    valid_chains = samples[burn_in:, :]

    # Names of the thawed parameters, trimmed when they do not match the chain
    thawed_parnames = covar_results.parnames
    if len(thawed_parnames) != n_params:
        thawed_parnames = covar_results.parnames[:n_params]

    # One (low, median, high) triple per named parameter
    quantiles = [
        np.percentile(valid_chains[:, column], [q_low, 50, q_high])
        for column, _ in enumerate(thawed_parnames)
    ]

    mcmc_results = {
        'parnames': list(thawed_parnames),
        'parvals': [triple[1] for triple in quantiles],
        'parmins': [triple[0] for triple in quantiles],
        'parmaxes': [triple[2] for triple in quantiles],
    }

    return mcmc_results, valid_chains, thawed_parnames

In [None]:
def process_observation(infile, progress_queue, obsid_coords, mcmc_scale_factors, emp_psf_file,
                        n_components_multi, run_mcmc_multi, mcmc_iter_multi,
                        mcmc_n_walkers, mcmc_ball_size, progress_chunks=50, recalc=False):
    """
    Worker function to process a single observation.
    """
    
    # Process Local Imports
    from sherpa.astro.ui import (
        load_data, image_close, image_data, load_psf, set_psf, image_psf,
        gauss2d, const2d, set_source, freeze, link, show_model, set_stat,
        set_method, set_method_opt, fit, get_fit_results, thaw, set_sampler,
        set_sampler_opt, covar, get_covar_results, get_draws,
        get_model_component_image, get_data_image, get_model_image, get_data,
        clean, calc_stat
    )
    
    from ciao_contrib.runtool import dmcopy, reproject_image, dmhedit
    from coords.format import ra2deg, dec2deg
    
    # Import plotting and math libraries
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    from matplotlib.lines import Line2D
    import corner
    import emcee
    from emcee.backends import HDFBackend
    import numpy as np
    from astropy.io import fits
    from scipy.stats import chi2
    from scipy.optimize import curve_fit
    from scipy.ndimage import rotate
    import datetime
    import time
    import logging
    import os

    logging.getLogger("sherpa").setLevel(logging.WARNING)

    
    # Process Local Helper Functions
    
    def write_pixelscale(file: str, nx: int, ny: int, ra: str, dec: str, hrc_pscale_arcsec: float = 0.13175):
        """
        Stamp a tangent-plane WCS onto a FITS image with dmhedit.

        The reference pixel is the centre of the nx-by-ny image and the
        reference value is (ra, dec) converted to degrees. The pixel size is
        one quarter of `hrc_pscale_arcsec`, negative along axis 1 and positive
        along axis 2. If any dmhedit call fails, the remaining keywords are
        skipped and an error is printed.
        """
        # Centre of the image in FITS (1-based) pixel coordinates
        crpix1 = (nx / 2.0) + 0.5
        crpix2 = (ny / 2.0) + 0.5

        # Pixel size in degrees: a quarter of the given pixel scale
        hrc_pscale_deg = hrc_pscale_arcsec / 3600.
        cdelt1 = -abs(hrc_pscale_deg / 4.)
        cdelt2 = abs(hrc_pscale_deg / 4.)

        # Reference sky position in degrees
        crval1 = ra2deg(ra)
        crval2 = dec2deg(dec)

        # (keyword, value, datatype, unit) for every header entry, in write order
        header_cards = (
            ("WCSAXES", 2, "short", None),
            ("CRPIX1", crpix1, "float", None),
            ("CRPIX2", crpix2, "float", None),
            ("CDELT1", cdelt1, "float", "deg"),
            ("CDELT2", cdelt2, "float", "deg"),
            ("CUNIT1", "deg", "string", None),
            ("CUNIT2", "deg", "string", None),
            ("CTYPE1", "RA---TAN", "string", None),
            ("CTYPE2", "DEC--TAN", "string", None),
            ("CRVAL1", crval1, "float", "deg"),
            ("CRVAL2", crval2, "float", "deg"),
            ("LONPOLE", 180.0, "float", "deg"),
            ("LATPOLE", 0, "float", "deg"),
            ("RADESYS", "ICRS", "string", None),
        )

        try:
            for keyword, card_value, card_type, card_unit in header_cards:
                dmhedit(infile=file, op="add", key=keyword, value=card_value,
                        datatype=card_type, unit=card_unit)
        except Exception as e:
            # The ObsID is taken from the name of the directory holding the file
            obsid = os.path.basename(os.path.dirname(file))
            print(f"  ERROR (ObsID {obsid}): dmhedit failed: {e}")

    def compute_split_rhat(chain):
        """
        Calculates the Split-Rhat statistic for convergence.

        Each walker's chain is cut into two equal halves that are treated as
        separate chains, so drift within a walker shows up as between-chain
        variance.

        Parameters
        ----------
        chain : ndarray
            Samples with shape (n_steps, n_walkers, n_params).

        Returns
        -------
        ndarray
            One R-hat value per parameter, shape (n_params,). All NaN when
            each half would hold fewer than two samples, because the sample
            variance is undefined there.
        """
        n_steps, n_walkers, n_params = chain.shape

        # Length of each half-chain
        half = n_steps // 2
        if half < 2:
            return np.full(n_params, np.nan)

        # With an odd number of steps the two halves would differ in length and
        # could not be stacked; drop the oldest sample so both have `half` steps.
        trimmed = chain[n_steps - 2 * half:]
        split_chain = np.concatenate((trimmed[:half], trimmed[half:]), axis=1)

        N = half

        # Within-chain variance (W): variance of each half-chain, averaged
        var_within = np.var(split_chain, axis=0, ddof=1)
        W = np.mean(var_within, axis=0)

        # Between-chain variance (B): N times the variance of the chain means
        mean_chains = np.mean(split_chain, axis=0)
        B = N * np.var(mean_chains, axis=0, ddof=1)

        # Var+ (marginal posterior variance estimate)
        var_plus = ((N - 1) / N) * W + (1 / N) * B

        # R-hat
        rhat = np.sqrt(var_plus / W)
        return rhat

    def src_psf_images(obsid, infile, x0, y0, diameter, wcs_ra, wcs_dec, binsize=0.25, shape='square', psfimg=True, showimg=False, empirical_psf=None):
        """
        Creates and loads source and (optionally) PSF images into Sherpa.

        A binned source image is cut from `infile` with dmcopy and loaded with
        `load_data`. Then one of:
          - `empirical_psf` given: that PSF image is rotated by
            (roll - 45) degrees, stamped with a WCS centred on
            (wcs_ra, wcs_dec), cropped, reprojected onto the source image grid
            and set as the Sherpa PSF;
          - otherwise, if `psfimg` is true: the PSF image is cut from
            '{obsid}/raytrace_projrays.fits' with the same region and binning.

        Parameters
        ----------
        obsid : used to build the output paths and the Sherpa PSF name.
        infile : event file the source image is cut from; also read for the
            roll angle when `empirical_psf` is given.
        x0, y0, diameter : centre and size of the 'circle' / 'square' region.
        wcs_ra, wcs_dec : passed as strings to `write_pixelscale`.
        binsize : bin size for the source (and raytrace PSF) image.
        shape : 'circle', 'square', or any other string, which is used
            lower-cased as the region expression itself.
        psfimg : whether to load the raytrace PSF when no empirical PSF is given.
        showimg : accepted but not used.
        empirical_psf : path to an empirical PSF image, or None.

        Returns
        -------
        `binsize` on success; None if the roll angle or the empirical PSF
        could not be read (an error is printed in that case).
        """
        if shape.lower() == 'circle':
            region_str = f"circle({x0},{y0},{diameter/2})"
        elif shape.lower() == 'square':
            region_str = f"box({x0},{y0},{diameter},{diameter},0)"
        else:
            region_str = shape.lower()

        # Crop applied to the rotated empirical PSF. It is defined for every
        # `shape`: it used to be set only for 'square', so combining an
        # empirical PSF with any other shape failed on an unbound variable.
        # NOTE(review): assumes the PSF image is centred on pixel (256.5, 256.5)
        # and sampled 4x finer than the source binning - confirm.
        img_region_str = f"box(256.5,256.5,{diameter*4},{diameter*4},0)"

        logical_width = diameter/binsize
        imagefile=f'{obsid}/src_image_{shape}_{int(logical_width)}pixel.fits'
        psf_rotated = f'{obsid}/psf_rotated.fits'
        psf_rotated_cut = f'{obsid}/psf_rotated_cut.fits'
        emp_psf_imagefile = f'{obsid}/psf_image_{shape}_empirical_{int(logical_width)}pixel.fits'

        # Reset both tools to their defaults and allow overwriting outputs
        dmcopy.punlearn()
        dmcopy.clobber = 'yes'
        reproject_image.punlearn()
        reproject_image.clobber = 'yes'

        # Source image: region filter plus binning, then load into Sherpa
        dmcopy.infile = f'{infile}[sky={region_str}][bin x=::{binsize},y=::{binsize}]'
        dmcopy.outfile = imagefile
        dmcopy()
        load_data(imagefile)

        if empirical_psf is not None:
            # Roll angle: ROLL_NOM in HDU 0, then HDU 1, then ROLL_PNT in HDU 1
            try:
                with fits.open(infile) as hdu_match:
                    if 'ROLL_NOM' in hdu_match[0].header:
                        roll_nom = hdu_match[0].header['ROLL_NOM']
                    elif hdu_match[1].header and 'ROLL_NOM' in hdu_match[1].header:
                        roll_nom = hdu_match[1].header['ROLL_NOM']
                    elif hdu_match[1].header and 'ROLL_PNT' in hdu_match[1].header:
                        roll_nom = hdu_match[1].header['ROLL_PNT']
                    else:
                        print(f"  ERROR: Could not find 'ROLL_NOM' or 'ROLL_PNT' in {infile}")
                        return
            except FileNotFoundError:
                print(f"  ERROR: Match file not found: {infile}")
                return
            except Exception as e:
                print(f"  ERROR: Could not read match file header: {e}")
                return
            angle_to_rotate = roll_nom - 45.0

            # PSF image and header: HDU 0, or HDU 1 when HDU 0 holds no data
            try:
                with fits.open(empirical_psf) as hdu_psf:
                    if hdu_psf[0].data is None:
                        psf_data = hdu_psf[1].data
                        psf_header = hdu_psf[1].header
                    else:
                        psf_data = hdu_psf[0].data
                        psf_header = hdu_psf[0].header
                    if psf_data is None:
                        print(f"ERROR: No image data found in {empirical_psf}.")
                        return
            except FileNotFoundError:
                print(f"  ERROR: PSF file not found: {empirical_psf}")
                return
            except Exception as e:
                print(f"  ERROR: Could not read PSF file data/header: {e}")
                return

            # Rotate in place (same shape, cubic interpolation, zero fill)
            rotated_psf_data = rotate(
                psf_data, angle_to_rotate, reshape=False, cval=0.0, order=3
            )
            timestamp = datetime.datetime.now().isoformat()
            psf_header.add_history(f"Rotated by {angle_to_rotate:.4f} deg (ROLL_NOM={roll_nom:.4f} - 45.0)")
            psf_header.add_history(f"Rotation applied by script on {timestamp}")
            hdu_out = fits.PrimaryHDU(data=rotated_psf_data, header=psf_header)
            try:
                hdu_out.writeto(psf_rotated, overwrite=True)
            except Exception as e:
                print(f"  ERROR: Could not write output file: {psf_rotated}")

            # Give the rotated PSF a WCS so it can be reprojected onto the source image
            try:
                with fits.open(psf_rotated) as hdu_rot:
                    nx = hdu_rot[0].header['NAXIS1']
                    ny = hdu_rot[0].header['NAXIS2']
                write_pixelscale(file=psf_rotated, nx=nx, ny=ny, ra=str(wcs_ra), dec=str(wcs_dec))
            except Exception as e:
                print(f"!!! ERROR (ObsID {obsid}): WCS stamping failed: {e}")

            # Crop and rebin the PSF, then match it to the source image grid
            dmcopy.infile = f'{psf_rotated}[{img_region_str}][bin x=::{binsize*4},y=::{binsize*4}]'
            dmcopy.outfile = psf_rotated_cut
            dmcopy()
            reproject_image.infile = psf_rotated_cut
            reproject_image.matchfile = imagefile
            reproject_image.outfile = emp_psf_imagefile
            reproject_image.method = 'sum'
            reproject_image()
            load_psf(f'centr_psf{obsid}', emp_psf_imagefile)
            set_psf(f'centr_psf{obsid}')
        elif psfimg:
            # Raytrace PSF: same region and binning as the source image
            psf_infile = f'{obsid}/raytrace_projrays.fits'
            psf_imagefile = f'{obsid}/psf_image_{shape}_raytrace_{int(logical_width)}pixel.fits'
            dmcopy.infile = f'{psf_infile}[sky={region_str}][bin x=::{binsize},y=::{binsize}]'
            dmcopy.outfile = psf_imagefile
            dmcopy()
            load_psf(f'centr_psf{obsid}', psf_imagefile)
            set_psf(f'centr_psf{obsid}')

        return binsize

    def gaussian_image_fit(observation, n_components, position, ampl, fwhm,
                           background=0, pos_min=(0, 0), pos_max=None, exptime=None, lock_fwhm=True,
                           freeze_components=None, use_mcmc=True, mcmc_iter=5000, mcmc_burn_in_frac=0.2,
                           n_walkers=32, ball_size=1e-4, 
                           prefix="g", confirm=True, imgfit=False, progress_chunks=50):
        """
        Fits multi-component 2D Gaussian models using Sherpa for optimization
        and emcee for error estimation.
        """
        
        # Helper to expand single value inputs
        def process_numeric_param(param, name):
            if isinstance(param, (int, float)): return [param] * n_components
            elif isinstance(param, list):
                if len(param) != n_components: raise ValueError(f"List of {name} must have length {n_components}.")
                return param
            else: raise ValueError(f"{name} must be a number or a list.")

        # Helper to expand single tuple inputs
        def process_tuple_param(param, name):
            if isinstance(param, (tuple, list)) and len(param) == 2 and all(isinstance(x, (int, float)) for x in param):
                return [param] * n_components
            elif isinstance(param, list):
                if len(param) != n_components: raise ValueError(f"List of {name} must have length {n_components}.")
                return param
            else: raise ValueError(f"{name} must be a tuple (x, y) or a list.")

        # Process parameters and Build Model
        positions = process_tuple_param(position, "position")
        ampls = process_numeric_param(ampl, "ampl")
        fwhms = process_numeric_param(fwhm, "fwhm")
        pos_mins = process_tuple_param(pos_min, "pos_min")
        pos_maxs = [None] * n_components if pos_max is None else process_tuple_param(pos_max, "pos_max")

        comp_names = []
        gaussian_components = []
        model_components = []
        for i in range(1, n_components + 1):
            comp_name = f"{prefix}{i}"
            comp_names.append(comp_name)
            comp = gauss2d(comp_name)
            gaussian_components.append(comp)
            model_components.append(comp)
        
        bkg_comp = None
        if background > 0:
            bkg_comp = const2d("c1")
            model_components.append(bkg_comp)
        
        if model_components: set_source(sum(model_components))
        else: raise ValueError("Model expression is empty.")

        freeze_list = (freeze_components if isinstance(freeze_components, list) else ([freeze_components] if freeze_components is not None else []))
        for i, comp in enumerate(gaussian_components):
            comp_number = i + 1
            comp.xpos = positions[i][0]; comp.ypos = positions[i][1]
            comp.ampl = ampls[i]; comp.fwhm = fwhms[i]
            if hasattr(comp.xpos, 'min'): comp.xpos.min = pos_mins[i][0]
            if hasattr(comp.ypos, 'min'): comp.ypos.min = pos_mins[i][1]
            if pos_maxs[i] is not None:
                if hasattr(comp.xpos, 'max'): comp.xpos.max = pos_maxs[i][0]
                if hasattr(comp.ypos, 'max'): comp.ypos.max = pos_maxs[i][1]
            if hasattr(comp.ampl, 'min'): comp.ampl.min = 0
            if comp_number in freeze_list: freeze(comp)

        central_component = 1
        if lock_fwhm:
            master = gaussian_components[central_component-1].fwhm
            for idx, comp in enumerate(gaussian_components):
                if idx != (central_component-1): link(comp.fwhm, master)

        if bkg_comp is not None:
            bkg_comp.c0 = background
            if hasattr(bkg_comp.c0, 'min'): bkg_comp.c0.min = 0

        if confirm:
            show_model()
            if input(f"  (ObsID {observation}) Proceed with fit? (y/n): ").lower() != "y": return None, None, None, None, None

        # Optimization
        set_stat('cstat')
        set_method('moncar'); set_method_opt('numcores', 1)
        set_method_opt('population_size', 10 * 16 * (n_components * 3 + 1)); set_method_opt('xprob', 0.5); set_method_opt('weighting_factor', 0.5)
        fit()
        set_method('simplex'); fit()
        fit_results = get_fit_results()

        # Identify Free Parameters
        thawed_pars = []
        thawed_par_names = []
        for i, comp in enumerate(gaussian_components):
            comp_number = i + 1
            if comp_number not in freeze_list:
                thaw(comp.ampl); thawed_pars.append(comp.ampl); thawed_par_names.append(comp.ampl.fullname)
                if not (lock_fwhm and comp_number != central_component):
                     thaw(comp.fwhm); thawed_pars.append(comp.fwhm); thawed_par_names.append(comp.fwhm.fullname)
                thaw(comp.xpos, comp.ypos)
                thawed_pars.append(comp.xpos); thawed_par_names.append(comp.xpos.fullname)
                thawed_pars.append(comp.ypos); thawed_par_names.append(comp.ypos.fullname)
        if bkg_comp is not None and not bkg_comp.c0.frozen:
            thawed_pars.append(bkg_comp.c0); thawed_par_names.append(bkg_comp.c0.fullname)

        # Capture Best Fit Values (Simplex)
        best_fit_values = [p.val for p in thawed_pars]
        best_fit_stat = fit_results.statval

        mcmc_results = None
        walker_map_fig = None
        corner_fig = None
        mcmc_duration_str = ""
        flux_results = None
        
        if use_mcmc:
            mcmc_start_time = time.time()
            ndim = len(thawed_pars)
            
            def log_probability(theta):
                for param, value in zip(thawed_pars, theta):
                    if value < param.min or value > param.max: return -np.inf
                for param, value in zip(thawed_pars, theta): param.val = value
                return -0.5 * calc_stat()

            current_n_walkers = n_walkers if n_walkers >= 2 * ndim else 2 * ndim + 2
            
            # Chain Storage and Resumption Logic
            # Format: mcmc-chain-4comp-nwalkers-nsteps-0p0001ball
            ball_str = str(ball_size).replace('.', 'p')
            
            param_folder_name = (f"mcmc-chain-{n_components}comp-"
                                 f"{current_n_walkers}walkers-"
                                 f"{mcmc_iter}steps-"
                                 f"{ball_str}ball")
            
            chain_dir = os.path.join(os.getcwd(), "2Dfits", "emcee_chains", param_folder_name)
            os.makedirs(chain_dir, exist_ok=True)
            chain_filename = os.path.join(chain_dir, f"{obsid}_chain.h5")
            
            # Use compression to reduce file size
            backend = HDFBackend(chain_filename, compression="gzip", compression_opts=4)
            
            # Force Recalculation Logic
            # If recalc=True (passed from main), we wipe the backend to start fresh
            if recalc:
                backend.reset(current_n_walkers, ndim)

            # Determine if we are resuming or starting fresh
            run_sampler = True
            p0 = None
            
            # If file exists and has data, check status
            if os.path.exists(chain_filename) and backend.iteration > 0:
                if backend.iteration >= mcmc_iter:
                    print(f"  (ObsID {observation}) Found completed chain ({backend.iteration} steps). Loading results...")
                    run_sampler = False
                else:
                    print(f"  (ObsID {observation}) Resuming chain from step {backend.iteration}...")
                    p0 = None # Setting p0 to None tells emcee to resume from backend
            else:
                # Start fresh: Generate initial walker positions
                backend.reset(current_n_walkers, ndim)
                best_fit_pos = np.array(best_fit_values)
                p0 = best_fit_pos + ball_size * np.random.randn(current_n_walkers, ndim)
                for i, param in enumerate(thawed_pars):
                    p0[:, i] = np.clip(p0[:, i], param.min + 1e-6, param.max - 1e-6)

            # Outer try block to catch MCMC and Plotting errors
            try:
                # Initialize Sampler
                sampler = emcee.EnsembleSampler(current_n_walkers, ndim, log_probability, backend=backend)
                
                if run_sampler:
                    # Inner try to catch sampler crashes specifically
                    try:
                        # Calculate how many steps are left
                        current_step = backend.iteration
                        remaining_steps = mcmc_iter - current_step
                        
                        update_interval = max(1, int(mcmc_iter / progress_chunks))
                        
                        # Run loop (p0 is None if resuming)
                        for i, sample in enumerate(sampler.sample(p0, iterations=remaining_steps, progress=False)):
                            if (i + 1) % update_interval == 0: progress_queue.put(1)
                            
                    except Exception as e:
                        print(f"  ERROR (ObsID {observation}) Sampler crashed: {e}")
                        # Note: We proceed to try loading whatever data we have
                
                # Load chains for analysis (discarding burn-in)
                discard = int(mcmc_iter * mcmc_burn_in_frac)
                
                # Check if we have enough samples to discard
                if sampler.iteration < discard:
                     discard = 0
                     
                flat_samples = sampler.get_chain(discard=discard, flat=True)
                raw_chain = sampler.get_chain(discard=discard) 
                
                # Robust Convergence Statistics
                # 1. Autocorrelation Time (tau)
                try:
                    # tol=0 prevents error if chain is short
                    tau = sampler.get_autocorr_time(tol=0) 
                    tau_max = np.max(tau)
                    # Effective Sample Size (ESS)
                    ess = (raw_chain.shape[0] * raw_chain.shape[1]) / tau_max
                except Exception:
                    tau = [np.nan] * ndim
                    tau_max = np.nan
                    ess = 0

                # 2. Gelman-Rubin (Split-Rhat)
                try:
                    rhat_vals = compute_split_rhat(raw_chain)
                    rhat_max = np.max(rhat_vals)
                except Exception as e:
                    print(f"Warning: R-hat calc failed: {e}")
                    rhat_vals = [np.nan] * ndim
                    rhat_max = np.nan

                # Convergence String for Text Output
                conv_str = (
                    f"Convergence Stats:\n"
                    f"  Max Autocorr Time (tau): {tau_max:.1f} steps\n"
                    f"  Max Split-Rhat:          {rhat_max:.4f} (Goal < 1.1)\n"
                    f"  Effective Samples (ESS): {int(ess)}\n"
                    f"  Chain Length / tau:      {raw_chain.shape[0] / tau_max:.1f} (Goal > 50)\n\n"
                )

                # Compute MCMC Stats
                q_low, q_mid, q_high = 15.865, 50.0, 84.135
                mcmc_results_data = {'parnames': [], 'parvals': [], 'parmins': [], 'parmaxes': []}
                for i, name in enumerate(thawed_par_names):
                    mcmc_vals = flat_samples[:, i]
                    p_low, p_mid, p_high = np.percentile(mcmc_vals, [q_low, q_mid, q_high])
                    mcmc_results_data['parnames'].append(name)
                    mcmc_results_data['parvals'].append(p_mid)
                    mcmc_results_data['parmins'].append(p_low)
                    mcmc_results_data['parmaxes'].append(p_high)
                mcmc_results = mcmc_results_data

                # Get log_prob for all samples in flat chain
                log_probs = sampler.get_log_prob(discard=discard, flat=True)
                max_idx = np.argmax(log_probs)
                
                # Evaluate stat for MCMC best vs Simplex best
                best_mcmc_stat = -2.0 * log_probs[max_idx]
                
                if best_mcmc_stat < best_fit_stat:
                    # MCMC found a better minimum, updating King of the Hill
                    best_fit_values = list(flat_samples[max_idx])
                    best_fit_stat = best_mcmc_stat

                # Flux and count-rate uncertainties
                flux_results = {}
                if exptime is not None:
                    fwhm_master_name = gaussian_components[central_component - 1].fwhm.fullname
                    if fwhm_master_name in thawed_par_names:
                        f_idx_master = thawed_par_names.index(fwhm_master_name)
                        F_chain = flat_samples[:, f_idx_master]
                        for comp in gaussian_components:
                            amp_name = comp.ampl.fullname
                            if amp_name in thawed_par_names:
                                a_idx = thawed_par_names.index(amp_name)
                                A_chain = flat_samples[:, a_idx]
                                flux_chain = A_chain * (F_chain**2)
                                F_low, F_mid, F_high = np.percentile(flux_chain, [q_low, q_mid, q_high])
                                flux_results[comp.name] = (F_low, F_mid, F_high)

                # RESTORE BEST FIT for Plots
                for param, val in zip(thawed_pars, best_fit_values): param.val = val

                # Walker Spatial Map (Shaded Density Contour Version)
                walker_map_fig, ax = plt.subplots(1, 1, figsize=(19, 19))
                
                data_img = get_data_image(); data_vals = data_img.y
                
                # Get dimensions for proper extent
                ny, nx = data_vals.shape
                # Astronomical convention: Center of 1st pixel is (1,1)
                plot_extent = [0.5, nx + 0.5, 0.5, ny + 0.5]
                
                min_pos = np.min(data_vals[data_vals > 0]) if np.any(data_vals > 0) else 1e-9
                display_floor = min_pos / 10.0
                data_masked = np.maximum(data_vals, display_floor)
                log_norm = mcolors.LogNorm(vmin=display_floor, vmax=np.max(data_vals))
                
                # Pass extent to imshow so it aligns with 1-based MCMC coordinates
                im_data = ax.imshow(data_masked, origin='lower', cmap='gray_r', norm=log_norm, 
                                    interpolation='nearest', extent=plot_extent)
                
                colors =      ['cyan', 'lime',      'magenta', 'orange',        'yellow']
                # Darker colors for Best Fit and Median markers, including xkcd colors
                dark_colors = ['navy', 'darkgreen', 'indigo',  'xkcd:burgundy', 'xkcd:shit']
                
                for i, comp_name in enumerate(comp_names):
                    x_name = f"{comp_name}.xpos"; y_name = f"{comp_name}.ypos"
                    
                    if x_name in thawed_par_names and y_name in thawed_par_names:
                        x_idx = thawed_par_names.index(x_name)
                        y_idx = thawed_par_names.index(y_name)
                        
                        # Extract ALL points for this component
                        x_pts = raw_chain[:, :, x_idx].flatten()
                        y_pts = raw_chain[:, :, y_idx].flatten()
                        
                        # Generate 2D Density (Histogram)
                        # Range matches the plot extent exactly
                        H, xedges, yedges = np.histogram2d(
                            y_pts, x_pts, 
                            bins=[ny, nx], 
                            range=[[0.5, ny + 0.5], [0.5, nx + 0.5]]
                        )
                        
                        # Plot Contours if data exists
                        if np.sum(H) > 0:
                            comp_color = colors[i % len(colors)]
                            dark_c = dark_colors[i % len(dark_colors)]
                            base_rgb = mcolors.to_rgb(comp_color)
                            
                            peak = H.max()
                            levels = [peak * 0.1, peak * 0.3, peak * 0.5, peak * 0.7, peak * 0.9]
                            
                            fill_colors = [
                                (*base_rgb, 0.1), 
                                (*base_rgb, 0.3), 
                                (*base_rgb, 0.5),
                                (*base_rgb, 0.7),
                                (*base_rgb, 0.9)
                            ]
                            
                            # Filled Contours with Gradient Shading
                            ax.contourf(H, levels=levels, colors=fill_colors, extend='max', extent=plot_extent)
                            # Line Contours (Thinner)
                            ax.contour(H, levels=levels, colors=[comp_color], linewidths=1.0, alpha=0.9, extent=plot_extent)

                            # Plot Best Fit Position
                            bf_x = best_fit_values[x_idx]; bf_y = best_fit_values[y_idx]
                            ax.scatter(bf_x, bf_y, marker='o', color=dark_c, s=100, zorder=20, 
                                       edgecolors='white', label=f"{comp_name} Best Fit")

                            # Plot Median Position
                            if mcmc_results is not None:
                                med_x = mcmc_results['parvals'][x_idx]
                                med_y = mcmc_results['parvals'][y_idx]
                                ax.scatter(med_x, med_y, marker='x', color=dark_c, s=200, linewidth=3, zorder=19,
                                           label=f"{comp_name} Median")

                ax.set_title(f"Walker Density Map - ObsID {observation}"); ax.set_xlabel("X Pixel"); ax.set_ylabel("Y Pixel")
                
                handles, labels = ax.get_legend_handles_labels()
                by_label = dict(zip(labels, handles))
                
                # Check if Median is present, if not add proxy
                if not any("Median" in l for l in labels) and mcmc_results is not None:
                     p_med = Line2D([0], [0], color='black', marker='x', linestyle='None', 
                                    markersize=10, markeredgewidth=3, label='Median')
                     by_label['Median'] = p_med

                ax.legend(by_label.values(), by_label.keys(), loc='upper right')
                walker_map_fig.colorbar(im_data, ax=ax, label="Counts", shrink=0.8); walker_map_fig.tight_layout()

                # Corner Figure (Conditional Downsampling)
                total_samples = flat_samples.shape[0]
                # Threshold: 1,000,000
                threshold = 1000000 
                
                if total_samples > threshold:
                    stride = int(total_samples / threshold)
                    plot_samples = flat_samples[::stride]
                    title_suffix = f"(Downsampled {stride}x)"
                else:
                    plot_samples = flat_samples
                    title_suffix = "(Full Chain)"

                corner_fig = corner.corner(
                    plot_samples,
                    labels=thawed_par_names, quantiles=[0.16, 0.5, 0.84],
                    show_titles=True, title_fmt=".3f",
                    truths=best_fit_values, truth_color='red',
                    quiet=True
                )
                corner_fig.suptitle(f"Corner Plot {title_suffix} - ObsID {observation}", y=1.02)

                mcmc_end_time = time.time()
                mcmc_duration_min = (mcmc_end_time - mcmc_start_time) / 60.0
                mcmc_duration_str = (f"emcee execution time = {mcmc_duration_min:.2f} minutes\n"
                                     f"Mean acceptance fraction = {np.mean(sampler.acceptance_fraction):.3f}\n"
                                     f"{conv_str}")

            except Exception as e:
                mcmc_results = None
                mcmc_duration_str = f"emcee FAILED: {e}\n\n"

        # Fit Summary Text
        fit_summary = (
            f"Method = {fit_results.methodname}\nStatistic = {fit_results.statname}\n"
            f"Final C-stat = {best_fit_stat:.2f} (Simplex+MCMC)\n" 
            f"Reduced statistic = {fit_results.rstat:.5f}\n\n"
        )
        
        def fmt_val(val): return "------" if val is None else f"{val:>10.3f}"
            
        if mcmc_results is not None:
            param_table = [
                f"emcee Results (Red Line = Best Fit, Black = Median):",
                f"{'Param':<12} {'Best Fit':>10} {'Median':>10} {'-Error':>10} {'+Error':>10}",
                f"{'-'*5:<12} {'-'*8:>10} {'-'*8:>10} {'-'*8:>10} {'-'*8:>10}"
            ]
            for name, median, low, high, best in zip(mcmc_results['parnames'], 
                                                     mcmc_results['parvals'], 
                                                     mcmc_results['parmins'], 
                                                     mcmc_results['parmaxes'],
                                                     best_fit_values):
                # Calculate explicit +/- errors relative to median
                err_minus = median - low
                err_plus = high - median
                display_name = name.split('.')[-2] + '.' + name.split('.')[-1] if '.' in name else name
                param_table.append(f"{display_name:<12} {fmt_val(best)} {fmt_val(median)} {fmt_val(err_minus)} {fmt_val(err_plus)}")
            param_table = "\n".join(param_table)
        else:
            param_table = "Best-Fit Values (No MCMC):\n" + "\n".join([f"{n:<12} {fmt_val(v)}" for n, v in zip(fit_results.parnames, fit_results.parvals)])
                
        summary_output = fit_summary + mcmc_duration_str + param_table + '\n'

        # Component count-rate block uses full MCMC flux (A * FWHM^2)
        if exptime and use_mcmc and mcmc_results is not None:
            rate_block_rows = ["Component count rates (counts/s):"]
            
            for comp in gaussian_components:
                short = comp.name.split('.')[-1]
                if flux_results is not None and comp.name in flux_results:
                    F_low, F_mid, F_high = flux_results[comp.name]
                    rate_mid = F_mid / exptime
                    rate_minus = (F_mid - F_low) / exptime
                    rate_plus = (F_high - F_mid) / exptime
                    rate_block_rows.append(
                        f"  {short:<6}: {rate_mid:7.4f}  (-{rate_minus:6.4f}/+{rate_plus:6.4f})"
                    )
                else:
                    # Fallback: use best-fit model component counts with no propagated error
                    comp_img = get_model_component_image(comp.name)
                    total_cts = comp_img.y.sum()
                    rate = total_cts / exptime
                    rate_block_rows.append(
                        f"  {short:<6}: {rate:7.4f}  (no MCMC rate errors)"
                    )
            summary_output += "\n" + "\n".join(rate_block_rows) + "\n"
        else:
            summary_output = fit_summary + param_table + '\n\n\n\n'

        fig, axs = plt.subplots(1, 3, figsize=(30, 15))
        plot_idx = 0
        data_img = get_data_image(); data_vals = data_img.y
        min_pos = np.min(data_vals[data_vals > 0]) if np.any(data_vals > 0) else 1e-9
        display_floor = min_pos / 10.0
        data_masked = np.maximum(data_vals, display_floor)
        model_img = get_model_image(); model_vals = model_img.y
        model_masked = np.maximum(model_vals, display_floor)
        D = 2.0 * (data_masked * np.log(data_masked / model_masked) - (data_masked - model_masked))
        resid_dev = np.sign(data_vals - model_vals) * np.sqrt(np.abs(D))
        log_norm = mcolors.LogNorm(vmin=display_floor, vmax=np.max(data_vals))
        
        ax = axs[0]
        im = ax.imshow(data_masked, origin='lower', cmap='gnuplot2', norm=log_norm, interpolation='nearest')
        
        base_colors = ['white', 'cyan', 'lime', 'xkcd:periwinkle']
        legend_elements = []
        
        for i, comp_name in enumerate(comp_names):
            comp_vals = get_model_component_image(comp_name).y
            if not np.any(comp_vals > 0): continue

            color = base_colors[i % len(base_colors)]
            
            ax.contour(comp_vals, levels=[0.2 * np.max(comp_vals)], colors=[color], linestyles=['--'], linewidths=2)
            
            # Add to legend list
            legend_elements.append(Line2D([0], [0], lw=2, linestyle='--', color=color, label=f"{comp_name}"))
        
        if legend_elements:
            ax.legend(handles=legend_elements, loc='upper right')
                
        ax.set_title(f"{observation} Data + Best Fit Overlay"); fig.colorbar(im, ax=ax, label="Counts", shrink=0.53)

        ax = axs[1]
        im = ax.imshow(model_masked, origin='lower', cmap='gnuplot2', norm=log_norm, interpolation='nearest')
        ax.set_title("Best Fit Model"); fig.colorbar(im, ax=ax, label="Model Counts", shrink=0.53)

        ax = axs[2]
        im = ax.imshow(np.abs(resid_dev), origin='lower', cmap='gnuplot2', norm=mcolors.Normalize(vmin=0, vmax=5), interpolation='nearest')
        ax.set_title("Poisson Deviance (Best Fit)"); fig.colorbar(im, ax=ax, label="|Residuals|", shrink=0.53)
            
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
        return summary_output, fig, corner_fig, walker_map_fig

    # Main Worker Logic
    
    # Initialize lists to store temp file paths
    pdf_out_files = []
    multi_pdf_out_files = []
    
    # Extract observation directory name
    obsid = os.path.dirname(os.path.dirname(infile))

    # set random seed based on obsid for reproducible chains
    try:
        np.random.seed(int(obsid))
    except ValueError:
        # fallback for non integer obsids
        np.random.seed(hash(obsid) % (2**32 - 1))

    # Get ObsID specific coordinates
    if obsid not in obsid_coords:
        print(f"!!! WARNING: ObsID {obsid} not in coordinate lookup table. Skipping.")
        return (obsid, "", "", "", "", "", [], [])
    current_ra, current_dec = obsid_coords[obsid]
    
    # Run initial data extraction and centroiding
    date, exptime, pixel_x0_best, pixel_y0_best, cnt, qp_figs = data_extract_quickpos_iter(infile)
    
    # Store the header text
    header_text = (
        f"Observation: {obsid}\n"
        f"Infile: {infile}\n"
        f"Date: {date}, Exptime: {exptime}\n"
    )

    # Stage 1: Centroid Fit
    img_width = 40
    cent_binsize = 1.0
    src_psf_images(
        obsid, infile, pixel_x0_best, pixel_y0_best, img_width,
        wcs_ra=current_ra, wcs_dec=current_dec,
        binsize=cent_binsize, 
        psfimg=False, 
        empirical_psf=None
    )
    logical_width = img_width / cent_binsize
    img_center = logical_width / 2.0 + 0.5
    
    centroid_fit_summary, centroid_fit_fig, _, _ = gaussian_image_fit(
        obsid, 1, (img_center, img_center), cnt, (1.0 / cent_binsize),
        prefix="centrg", background=0.1, pos_max=(logical_width, logical_width),
        use_mcmc=False, confirm=False
    )
    
    if centroid_fit_summary is None:
        print(f"Centroid fit for {obsid} canceled. Skipping.")
        clean()
        return (obsid, "", "", "", "", "", [], [])

    # Save plot to a temporary PNG
    temp_cent_fit_png = f"2Dfits/temp_{obsid}_cent_fit.png"
    centroid_fit_fig.savefig(temp_cent_fit_png)
    plt.close(centroid_fit_fig)
    pdf_out_files.append(temp_cent_fit_png)
    progress_queue.put(1)

    # Get best fit physical coordinates
    d = get_data()
    crval_x, crval_y = d.sky.crval
    crpix_x, crpix_y = d.sky.crpix
    cdelt_x, cdelt_y = d.sky.cdelt
    
    # Access the dynamically created model object
    xphys_best = crval_x + (globals()['centrg1'].xpos.val - crpix_x) * cdelt_x
    yphys_best = crval_y + (globals()['centrg1'].ypos.val - crpix_y) * cdelt_y

    # Stage 2: Single Component Source Fit
    img_width = 10
    src_binsize = 0.25
    src_psf_images(
        obsid, infile, xphys_best, yphys_best, img_width,
        wcs_ra=current_ra, wcs_dec=current_dec,
        binsize=src_binsize, 
        psfimg=True, 
        empirical_psf=emp_psf_file
    )
    logical_width = img_width / src_binsize 
    img_center = logical_width / 2.0 + 0.5 
    pixel_scale_guess = 1.0 / src_binsize 
    scaled_cnt_guess = cnt / (pixel_scale_guess**2)
    scaled_fwhm_guess = 1.0 * pixel_scale_guess 
    
    src_fit_summary, src_fit_fig, _, _ = gaussian_image_fit(
        obsid, 1, (img_center, img_center), scaled_cnt_guess, scaled_fwhm_guess,
        prefix="srcg", pos_max=(logical_width, logical_width),
        use_mcmc=False, confirm=False
    )
    
    if src_fit_summary is None:
        print(f"Source fit for {obsid} canceled. Skipping.")
        clean()
        return (obsid, header_text, centroid_fit_summary, "", "", "", pdf_out_files, [])

    # Save plot to temporary PNG
    temp_src_fit_png = f"2Dfits/temp_{obsid}_src_fit.png"
    src_fit_fig.savefig(temp_src_fit_png)
    plt.close(src_fit_fig)
    pdf_out_files.append(temp_src_fit_png)
    progress_queue.put(1) 

    # Stage 3: Multi Component Fit
    # Access the dynamically created model object
    srcfit_off_x = globals()['srcg1'].xpos.val - img_center 
    srcfit_off_y = globals()['srcg1'].ypos.val - img_center 
    src_ampl = globals()['srcg1'].ampl.val
    src_fwhm = globals()['srcg1'].fwhm.val
    
    # This is your desired multi fit setup
    img_width = 40 
    multi_binsize = 0.25
    
    src_psf_images(
        obsid, infile, xphys_best, yphys_best, img_width,
        wcs_ra=current_ra, wcs_dec=current_dec,
        binsize=multi_binsize, 
        empirical_psf=emp_psf_file,
    )
    
    logical_width = img_width / multi_binsize 
    img_center = logical_width / 2.0 + 0.5   
    pixel_scale = src_binsize / multi_binsize 
    
    new_xpos = img_center + (srcfit_off_x * pixel_scale)
    new_ypos = img_center + (srcfit_off_y * pixel_scale)
    scaled_src_fwhm = src_fwhm * pixel_scale
    scaled_src_ampl = src_ampl / (pixel_scale**2)
    
    pixel_scale_guess = 1.0 / multi_binsize 
    scaled_cnt_ampl = cnt / (pixel_scale_guess**2)
    scaled_default_fwhm = 1.0 * pixel_scale_guess 

    # This is your desired component setup
    n_components = n_components_multi 
    positions = [(new_xpos, new_ypos)] + [(img_center, img_center)] * (n_components - 1)
    amplitudes = [scaled_src_ampl] + [scaled_cnt_ampl] * (n_components - 1)
    fwhms = [scaled_src_fwhm] + [scaled_default_fwhm] * (n_components - 1)

    # This is your desired fit call
    multi_fit_summary, multi_fit_fig, multi_corner_fig, multi_walker_fig = gaussian_image_fit(
        obsid, n_components, positions, amplitudes, fwhms,
        prefix="g", background=0.1, pos_max=(logical_width, logical_width),
        pos_min=(0, 0), exptime=exptime, confirm=False, 
        use_mcmc=run_mcmc_multi, 
        mcmc_iter=mcmc_iter_multi,
        n_walkers=mcmc_n_walkers,  
        ball_size=mcmc_ball_size,
        progress_chunks=progress_chunks 
    )

    if multi_fit_summary is None:
        print(f"Multi-component fit for {obsid} canceled. Skipping.")
        clean()
        return (obsid, header_text, centroid_fit_summary, src_fit_summary, "", "", pdf_out_files, [])

    # Save fit plot to temp PNG
    temp_multi_fit_png = f"2Dfits/temp_{obsid}_multi_fit.png"
    multi_fit_fig.savefig(temp_multi_fit_png)
    plt.close(multi_fit_fig)
    pdf_out_files.append(temp_multi_fit_png)
    multi_pdf_out_files.append(temp_multi_fit_png)

    # Save Walker Map (New Visualization)
    if multi_walker_fig is not None:
        temp_walker_png = f"2Dfits/temp_{obsid}_walker_map.png"
        multi_walker_fig.savefig(temp_walker_png)
        plt.close(multi_walker_fig)
        pdf_out_files.append(temp_walker_png)

    # Save Corner Plot
    if multi_corner_fig is not None:
        temp_corner_png = f"2Dfits/temp_{obsid}_corner.png"
        multi_corner_fig.savefig(temp_corner_png)
        plt.close(multi_corner_fig)
        pdf_out_files.append(temp_corner_png)

    # Create the text for the multi results file
    multi_results_text = (
        f"Observation: {obsid}\n"
        f"Infile: {infile}\n"
        f"Date: {date}, Exptime: {exptime}\n"
        f"{multi_fit_summary}\n\n"
    )

    # Clean up the local Sherpa session
    clean()
    progress_queue.put(1) 

    # Return all results for aggregation
    return (
        obsid,
        header_text,
        centroid_fit_summary,
        src_fit_summary,
        multi_fit_summary,
        multi_results_text,
        pdf_out_files,
        multi_pdf_out_files
    )

In [None]:
def compile_pngs_to_pdf(pbar, png_files, pdf_filename):
    """
    Bundle a list of PNG pages into a single multi-page PDF.

    Parameters
    ----------
    pbar : object exposing ``update(n)``
        Progress bar; ticked exactly once per entry of ``png_files``,
        whether or not the file exists, so the bar always reaches its total.
    png_files : list of str
        Paths of the PNG pages, in the order they should appear in the PDF.
    pdf_filename : str
        Destination path of the PDF.

    Returns
    -------
    bool
        True if a PDF was written, False if there was nothing to write
        (empty list, or none of the listed PNG files exist).
    """
    pages = []
    for png_file in png_files:
        if os.path.exists(png_file):
            # convert() returns an independent in-memory copy, so the source
            # file handle can be released right away.
            with Image.open(png_file) as src:
                pages.append(src.convert('RGB'))
        else:
            print(f"Warning: Missing file {png_file}, skipping.")
        pbar.update(1)

    if not pages:
        if png_files:
            print(f"ERROR: None of the {len(png_files)} PNG files for {pdf_filename} exist.")
        return False

    # The first available page anchors the PDF; the rest are appended to it.
    first_page, *other_pages = pages
    first_page.save(pdf_filename, "PDF", resolution=100.0, save_all=True, append_images=other_pages)
    return True


# This guard is essential for multiprocessing in notebooks
if __name__ == '__main__':

    # Fix DOS Warning for large plots
    from PIL import Image
    Image.MAX_IMAGE_PIXELS = None

    # Run Configuration

    # Set the number of components for the multi component fit
    multi_n_components = 3

    # Set to True to run MCMC, or False for a fast test run
    run_mcmc = True

    # Set to True to force the pipeline to delete old chains and run fresh.
    # Set to False to resume or load existing chains.
    recalculate_chains = True

    # MCMC CONTROL PARAMETERS
    mcmc_iterations = 20000
    mcmc_n_walkers = 4 * (multi_n_components * 3 + 2)
    mcmc_ball_size = 1e-4
    # Number of progress bar updates during MCMC (e.g. 50 = every 2%)
    mcmc_updates_per_obs = 400

    # Static Parameters

    # Set the main working directory
    os.chdir('/Users/leodrake/Documents/MIT/ss433/HRC_2024/')

    # Define WCS coordinates for each observation
    obsid_coords = {
        "26568": ("287.9565362", "4.9826061"),
        "26569": ("287.9563218", "4.9827745"),
        "26570": ("287.9563754", "4.9825322"),
        "26571": ("287.9561693", "4.9827006"),
        "26572": ("287.9565032", "4.9826636"),
        "26573": ("287.9565444", "4.9826390"),
        "26574": ("287.9562518", "4.9825651"),
        "26575": ("287.9566969", "4.9828114"),
        "26576": ("287.9566351", "4.9826718"),
        "26577": ("287.9565238", "4.9826020"),
        "26578": ("287.9566021", "4.9826800"),
        "26579": ("287.9565733", "4.9825774")
    }

    # Define the empirical PSF file to be used
    emp_psf_file = "/Users/leodrake/Documents/MIT/ss433/HRC_2024/empPSF_iARLac_v2025_2017-2025.fits"

    # Find all event files to be processed
    event_files = sorted(glob.glob('*/repro/*splinecorr.fits'))
    if not event_files:
        print("!!! WARNING: No event files matched '*/repro/*splinecorr.fits'; nothing to process.")

    # Define output file names
    pdf_out_filename = '2Dfits/0fit-plots.pdf'
    multi_pdf_out_filename = '2Dfits/0multi-comp-plots.pdf'
    results_filename = '2Dfits/0fit-results.txt'
    multi_results_filename = '2Dfits/0multi-comp-fit-results.txt'

    # Make sure the output directory exists before anything is written into it
    os.makedirs(os.path.dirname(pdf_out_filename), exist_ok=True)

    # Define total number of steps for the progress bar
    # 3 standard checkpoints + MCMC updates per file
    ticks_per_file = 3 + mcmc_updates_per_obs if run_mcmc else 3
    total_steps = len(event_files) * ticks_per_file

    # Run in Parallel

    # Create a Manager and a Queue for progress updates
    manager = Manager()
    progress_queue = manager.Queue()

    # Create the partial function, pre loading it with the args
    worker_func = partial(process_observation,
                          progress_queue=progress_queue,
                          obsid_coords=obsid_coords,
                          mcmc_scale_factors={},
                          emp_psf_file=emp_psf_file,
                          n_components_multi=multi_n_components,
                          run_mcmc_multi=run_mcmc,
                          mcmc_iter_multi=mcmc_iterations,
                          mcmc_n_walkers=mcmc_n_walkers,
                          mcmc_ball_size=mcmc_ball_size,
                          progress_chunks=mcmc_updates_per_obs,
                          recalc=recalculate_chains
                         )

    # Set number of cores to use: os.cpu_count() may return None, and there is
    # no point in spawning more workers than there are files to process.
    num_processes = max(1, min(os.cpu_count() or 1, len(event_files)))
    print(f"--- Starting parallel processing on {num_processes} cores ---")
    start_total_time = time.time()

    # Start the pool and run the jobs
    with tqdm(total=total_steps, desc="Processing Observations", bar_format="{l_bar}{r_bar}") as pbar:
        with Pool(processes=num_processes) as pool:
            # Use map_async to run jobs without blocking the main thread
            async_result = pool.map_async(worker_func, event_files)

            # Monitor the queue and update the progress bar
            while not async_result.ready():
                while not progress_queue.empty():
                    pbar.update(progress_queue.get())
                time.sleep(0.1)

            # Update the bar with any remaining items in the queue
            while not progress_queue.empty():
                pbar.update(progress_queue.get())

            # Get the final results from all workers (re-raises worker errors)
            results = async_result.get()

    end_total_time = time.time()
    print(f"\n--- Parallel processing complete in {(end_total_time - start_total_time) / 60.0:.2f} minutes ---")
    print("--- Aggregating all results... ---")

    # Aggregate Results

    # Sort results by obsid (first item in tuple) to ensure correct order
    results.sort(key=lambda x: x[0])

    # Initialize master lists for file paths
    all_pdf_out_files = []
    all_multi_pdf_out_files = []

    # Open text files and write all results in order
    with open(results_filename, 'w') as results_file, open(multi_results_filename, 'w') as multi_results_file:
        for res in results:
            # Unpack the tuple from the worker
            (obsid, header_text, centroid_fit_summary, src_fit_summary,
             multi_fit_summary, multi_results_text,
             pdf_out_files_worker, multi_pdf_out_files_worker) = res

            # Write to 0fit-results.txt
            results_file.write(header_text)
            results_file.write("\nCENTROID FIT SUMMARY:\n\n")
            results_file.write(centroid_fit_summary)
            results_file.write("SOURCE FIT SUMMARY:\n\n")
            results_file.write(src_fit_summary)
            results_file.write("MULTI-COMPONENT FIT SUMMARY:\n\n")
            results_file.write(multi_fit_summary)

            # Write to 0multi-comp-fit-results.txt
            multi_results_file.write(multi_results_text)

            # Add this worker's PNG files to the master lists
            all_pdf_out_files.extend(pdf_out_files_worker)
            all_multi_pdf_out_files.extend(multi_pdf_out_files_worker)

    print("Text files written.")

    # PDF Compilation and Cleanup
    print('Compiling PDFs...\n')

    # Create one single progress bar for all PDF compilation
    total_plots_to_compile = len(all_pdf_out_files) + len(all_multi_pdf_out_files)

    with tqdm(total=total_plots_to_compile, desc="Compiling PDF Plots", bar_format="{l_bar}{r_bar}") as pbar:
        # First the PDF with all plots, then the multi component only PDF
        for png_list, target_pdf in (
            (all_pdf_out_files, pdf_out_filename),
            (all_multi_pdf_out_files, multi_pdf_out_filename),
        ):
            try:
                # Only claim success when a PDF was actually written
                if compile_pngs_to_pdf(pbar, png_list, target_pdf):
                    print(f"Successfully compiled {target_pdf}")
                else:
                    print(f"Nothing to compile for {target_pdf}; PDF not written.")
            except Exception as e:
                print(f"ERROR: Could not compile {target_pdf}: {e}")

    # Final cleanup
    print("Cleaning up temporary PNG files...")
    for f in glob.glob("2Dfits/temp_*.png"):
        try:
            os.remove(f)
        except OSError as e:
            print(f"Warning: Could not remove {f}: {e}")

    print('Process Complete')