In [None]:
import os, glob, re
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt # Plotting interface
from astropy.io import fits
from scipy.optimize import minimize
from matplotlib.cm import get_cmap
from scipy.signal import find_peaks
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from collections import defaultdict
import cmasher as cmr
from scipy.optimize import curve_fit
from scipy.special import erf
import warnings
from pathlib import Path

# Render every figure at 400 dpi.
plt.rcParams['figure.dpi'] = 400
# 0 turns off matplotlib's "too many open figures" warning; the fitting code
# below opens one figure per fitted alpha bin, so many figures stay open at once.
mpl.rcParams['figure.max_open_warning'] = 0


In [None]:
# ---------- Globbing Functions ---------- #
# Directory containing the DU subfolders (du1, du2, du3, ...) that list_sim_fits() searches.
# NOTE: this is a machine-specific absolute path; point it at the local copy of the data.
ROOT_DIR = "/Users/leodrake/Library/CloudStorage/Box-Box/IXPE_rmfs/sim_data_mit"

def list_sim_fits(du_specifier='all', pattern='sim_*_pol_recon*.fits'):
    """
    Collect simulation FITS paths from the DU subfolders of ROOT_DIR.

    Parameters
    ----------
    du_specifier : str or list or tuple, optional
        'all' (case-insensitive) selects every subfolder whose name starts
        with 'du' (case-insensitive).
        Any other string must exactly match one of those subfolder names.
        A list/tuple selects the listed names that exist; unknown names in
        the list are silently dropped.
        Defaults to 'all'.
    pattern : str, optional
        Glob pattern applied inside each selected subfolder.

    Returns
    -------
    du_label : str
        'du123' for 'all', the subfolder name for a single DU, or
        'du-' followed by the sorted selected names joined with '-'
        (e.g. 'du-du1-du2') for a list.
    files : list
        Sorted list of matching paths. May be empty, in which case a warning
        is printed.

    Raises
    ------
    FileNotFoundError
        If ROOT_DIR is not an existing directory.
    ValueError
        If the specifier is not recognised, or none of the listed DUs exist.
    """
    if not os.path.isdir(ROOT_DIR):
        raise FileNotFoundError(f"ROOT_DIR '{ROOT_DIR}' not found.")

    # Every directory directly under ROOT_DIR whose name starts with "du".
    available = []
    for entry in os.listdir(ROOT_DIR):
        if entry.lower().startswith('du') and os.path.isdir(os.path.join(ROOT_DIR, entry)):
            available.append(entry)

    is_name = isinstance(du_specifier, str)
    if is_name and du_specifier.lower() == 'all':
        chosen, du_label = available, 'du123'
    elif is_name and du_specifier in available:
        chosen, du_label = [du_specifier], du_specifier
    elif isinstance(du_specifier, (list, tuple)):
        chosen = [name for name in du_specifier if name in available]
        if not chosen:
            raise ValueError(f"None of the specified DUs {du_specifier} found in {ROOT_DIR}")
        du_label = 'du-' + "-".join(sorted(chosen))
    else:
        raise ValueError(f"Invalid 'du_specifier': {du_specifier}. Use 'all', a valid DU name, or a list.")

    matches = []
    for name in chosen:
        matches += glob.glob(os.path.join(ROOT_DIR, name, pattern))

    if not matches:
        print(f"Warning: No FITS files found for '{du_label}' with pattern '{pattern}' in {ROOT_DIR}.")

    return du_label, sorted(matches)

def parse_energy(fname):
    """Return the energy in keV encoded in a 'sim_NNNNN...' file name, or NaN if absent."""
    base = os.path.basename(fname)
    found = re.search(r'sim_(\d{5})', base)
    if found is None:
        # No 'sim_' followed by five digits in the file name.
        return np.nan
    # The five-digit field divided by 1000 gives keV (e.g. 01000 -> 1.0).
    return int(found.group(1)) / 1000.0


In [None]:
# ---------- Helper Functions ---------- #

def bg_flag(alpha, nrg):
    """
    Return the keep-mask of the background cut in the (energy, alpha) plane.

    An event is rejected when its alpha lies above either of two straight
    lines in energy:
      * line 1 passes through (2.0 keV, 0.35) and (5.5 keV, 1.0);
      * line 2 passes through (2.0 keV, 0.70) and (8.0 keV, 0.95).

    Parameters
    ----------
    alpha : array_like or scalar
        Track-shape parameter per event.
    nrg : array_like or scalar
        Event energy, in the same units as the 2.0/5.5/8.0 anchors (keV).

    Returns
    -------
    numpy boolean array (or numpy bool for scalar input)
        True for events that pass the cut (kept), False for rejected events.
    """
    # np.asarray lets plain Python scalars through the same path as arrays;
    # applying `~` to a Python bool would give the ints -1/-2 (both truthy).
    alpha = np.asarray(alpha)
    nrg = np.asarray(nrg)

    a1, a2, nrg0 = 0.35, 0.7, 2.0  # Line intercepts at the reference energy nrg0
    b1 = (1 - a1) / (5.5 - nrg0)  # Slope of the first boundary line
    b2 = (0.95 - a2) / (8.0 - nrg0)  # Slope of the second boundary line

    above_line1 = alpha > (a1 + b1 * (nrg - nrg0))
    above_line2 = alpha > (a2 + b2 * (nrg - nrg0))
    # NaN alpha compares False against both lines and is therefore kept,
    # exactly as in the previous implementation.
    return np.logical_not(above_line1 | above_line2)

# Extract arrays from simulation FITS (full PI range, no XY cut)
def extract_common_data(d, pi_min, pi_max, bgflag=False):
    """
    Pull PI, energy, alpha and phi for the events inside a PI band.

    PI and energy are both derived from the raw 'PHA' column using
    3000 ADC counts per keV and 40 eV wide PI channels.

    Parameters
    ----------
    d : FITS_rec
        Event table with 'PHA', 'TRK_M2L', 'TRK_M2T' and 'DETPHI2' columns.
    pi_min, pi_max : float
        PI band in channel units; an event is kept when pi_min < PI <= pi_max.
    bgflag : bool
        If True, additionally keep only events accepted by bg_flag().

    Returns
    -------
    pi, nrg_pi, alpha, phi : tuple of 1D numpy arrays
        PI channel, energy in keV, alpha = (TRK_M2L - TRK_M2T) / (TRK_M2L + TRK_M2T),
        and the DETPHI2 angle of the selected events.
    """
    pha = d['PHA'].astype(float)
    pha2pi = 1.0 / (3000 * 0.04)  # PI channels per ADC count

    pi_all = pha * pha2pi
    # 3000 ADC counts keV^-1: https://doi.org/10.1016/j.astropartphys.2021.102628
    nrg_all = pha / 3000

    # Lower bound is exclusive, upper bound inclusive.
    in_band = (pi_all > pi_min) & (pi_all <= pi_max)

    pi = pi_all[in_band]
    nrg_pi = nrg_all[in_band]
    length = d['TRK_M2L'][in_band]
    width = d['TRK_M2T'][in_band]
    alpha = (length - width) / (length + width)
    phi = d['DETPHI2'][in_band]

    if not bgflag:
        return pi, nrg_pi, alpha, phi

    keep = bg_flag(alpha, nrg_pi)
    return pi[keep], nrg_pi[keep], alpha[keep], phi[keep]

def summarize_alpha_vs_pi_bins(pi, alpha, pi_min, pi_max, n_bins=4):
    """
    Average alpha inside equal-width PI bins spanning [pi_min, pi_max].

    Each bin is half-open, [low, high), so an event at exactly pi_max is not
    counted. Returns (bin centers, NaN-ignoring mean alpha per bin, event count
    per bin); bins without events keep a NaN mean. For n_bins <= 0 three empty
    arrays are returned.
    """
    if n_bins <= 0:
        return np.array([]), np.array([]), np.array([], dtype=int)

    edges = np.linspace(pi_min, pi_max, n_bins + 1)
    lower, upper = edges[:-1], edges[1:]
    centers = (lower + upper) / 2

    means = np.full(n_bins, np.nan)
    counts = np.zeros(n_bins, dtype=int)
    for k, (lo_edge, hi_edge) in enumerate(zip(lower, upper)):
        in_bin = (pi >= lo_edge) & (pi < hi_edge)
        n_in_bin = int(np.sum(in_bin))
        counts[k] = n_in_bin
        if n_in_bin > 0:
            means[k] = np.nanmean(alpha[in_bin])
    return centers, means, counts

def find_pi_peak_band(pi, bin_width=1, min_prominence=0.1, tail_frac=0.3):
    """
    Identify the *highest* peak in the PI histogram and return an expanded band.

    The peak band is defined as the full Gaussian core:
    - Low edge: where the data begin to exceed the Gaussian fit (tail transition).
    - High edge: where the histogram flattens to the background (or the PI limit).

    Procedure
    ---------
    1. Histogram `pi` between int(min) and int(max) with bins of `bin_width`.
    2. Pick the tallest peak reported by scipy's find_peaks (or the tallest bin
       if no peak meets the prominence requirement).
    3. Estimate a flat background from the median of the last
       max(5, 10% of the bins) bins; its scatter is taken as sqrt(background).
    4. High edge: upper edge of the first bin at/after the peak whose count is
       <= background + 2*sigma (the last bin if no such bin exists).
    5. Fit a Gaussian to the background-subtracted counts between the peak and
       the high edge (i.e. the high-PI side of the peak only). If the fit fails
       or fewer than 3 positive points exist, the initial guess is used instead.
    6. Low edge: walking from the peak towards low PI, the upper edge of the
       first bin whose count exceeds (Gaussian + background) by more than
       `tail_frac`. If no such bin exists, the upper edge of the first bin that
       drops to the background threshold; failing that, the first histogram edge.

    Parameters
    ----------
    pi              : 1D array of PI channels
    bin_width       : width of each PI bin
    min_prominence : minimal prominence for find_peaks (fraction of max count)
    tail_frac       : fractional excess above the Gaussian to flag tail transition

    Returns
    -------
    pi_min, pi_max : floats
        Lower and upper PI channel bounds of the Gaussian core.
        (nan, nan) for empty input; (lo, hi) of the truncated data range when
        that range is degenerate; the full histogram range when the derived
        high edge does not lie above the low edge.
    """
    # 1) Histogram the PI data
    pi = np.asarray(pi)  # Ensure PI is a numpy array
    if pi.size == 0:
        return np.nan, np.nan
    # int() truncates, so the histogram range starts/ends on whole channels
    lo, hi = int(np.nanmin(pi)), int(np.nanmax(pi))  # Min/max PI values for binning
    if hi <= lo:
        # Degenerate range: nothing to histogram
        return float(lo), float(hi)
    bins = np.arange(lo, hi + bin_width, bin_width)  # Define histogram bins
    counts, edges = np.histogram(pi, bins=bins)  # Calculate histogram
    centers = (edges[:-1] + edges[1:]) / 2  # Bin centers

    # 2) Locate the tallest sufficiently prominent peak
    peak_idxs, _ = find_peaks(counts, prominence=min_prominence * counts.max())
    if peak_idxs.size == 0:
        # Fallback: if no peaks found, use the highest single bin
        idx = counts.argmax()
    else:
        # Choose the index of the peak with the largest height among found peaks
        idx = peak_idxs[np.argmax(counts[peak_idxs])]

    # 3) Flat background level from the high-PI end of the histogram
    bg_bins = max(5, int(0.1 * len(counts)))
    bg_counts = counts[-bg_bins:]
    bg_level = float(np.median(bg_counts))
    # Square-root scatter of the background level; 1.0 avoids a zero threshold width
    bg_sigma = np.sqrt(bg_level) if bg_level > 0 else 1.0

    # 4) High edge: first bin at/after the peak that is consistent with background
    high_idx = len(counts) - 1
    for j in range(idx, len(counts)):
        if counts[j] <= bg_level + 2.0 * bg_sigma:
            high_idx = j
            break

    def _gaussian(x, a, mu, sigma):
        # Un-normalised Gaussian with amplitude a, mean mu and width sigma
        return a * np.exp(-0.5 * ((x - mu) / sigma) ** 2)

    # 5) Fit only the high-PI side of the peak, background-subtracted,
    #    keeping the bins that remain positive after subtraction
    x_fit = centers[idx:high_idx + 1]
    y_fit = counts[idx:high_idx + 1] - bg_level
    fit_mask = y_fit > 0
    x_fit = x_fit[fit_mask]
    y_fit = y_fit[fit_mask]

    # Initial guesses: mean at the peak bin, width from treating the
    # peak-to-high-edge distance as a FWHM (2.355 sigma), amplitude above background
    mu0 = centers[idx]
    sigma0 = max(bin_width, (centers[high_idx] - centers[idx]) / 2.355) if high_idx > idx else bin_width
    a0 = max(1.0, counts[idx] - bg_level)

    # Defaults used when the fit is skipped or fails
    mu = mu0
    sigma = sigma0
    amp = a0
    if x_fit.size >= 3:
        try:
            popt, _ = curve_fit(_gaussian, x_fit, y_fit, p0=[a0, mu0, sigma0], maxfev=2000)
            amp, mu, sigma = popt
            # The Gaussian is symmetric in sigma, so a negative fitted width is folded back
            sigma = abs(sigma) if sigma != 0 else sigma0
        except Exception:
            # Best effort: any fit failure silently falls back to the initial guesses
            pass

    # Expected counts per bin from the Gaussian core plus flat background
    gaussian_total = _gaussian(centers, amp, mu, sigma) + bg_level

    # 6) Low edge: scan from the peak towards low PI for the start of the tail
    low_edge = edges[0]
    tail_found = False
    for j in range(idx, -1, -1):
        if counts[j] > gaussian_total[j] * (1.0 + tail_frac):
            low_edge = centers[j] + bin_width / 2  # Upper edge of bin j
            tail_found = True
            break

    if not tail_found:
        # No excess over the Gaussian: fall back to where the counts reach background
        for j in range(idx, -1, -1):
            if counts[j] <= bg_level + 2.0 * bg_sigma:
                low_edge = centers[j] + bin_width / 2
                break

    high_edge = centers[high_idx] + bin_width / 2  # Upper edge of the high-edge bin

    if high_edge <= low_edge:
        # Inconsistent band: return the full histogram range instead
        return float(edges[0]), float(edges[-1])
    return float(low_edge), float(high_edge)

def get_combined_data(file_list):
    """
    Read extension 1 of every FITS file in `file_list` and stack the tables.

    Files that cannot be read are skipped with a printed warning. Returns None
    when no file could be read or when the tables do not all share the dtype
    of the first one; otherwise returns the concatenated table.
    """
    print(f"    Combining data from {len(file_list)} file(s)...")
    tables = []
    for path in file_list:
        try:
            table = fits.getdata(path, 1)
        except Exception as err:
            print(f"    Warning: Could not read {os.path.basename(path)}: {err}")
        else:
            tables.append(table)

    if not tables:
        return None

    # Concatenation only makes sense when every table has the same column layout.
    if len(tables) > 1:
        reference = tables[0].dtype
        same_layout = all(t.dtype == reference for t in tables[1:])
        if not same_layout:
            print("    Warning: FITS files have different structures. Cannot combine.")
            return None
    return np.concatenate(tables)

# Build step-plot X,Y coordinates for histograms
def step_plot(x, y, binwidth):
    """
    Turn histogram bin centers/heights into step-outline coordinates.

    Every (center, height) pair contributes two points: the left and right
    edge of the bin, both at the bin height. Returns two lists (xsteps, ysteps).
    """
    half = binwidth/2
    pairs = list(zip(x, y))
    xsteps = [edge for center, _ in pairs for edge in (center - half, center + half)]
    ysteps = [height for _, height in pairs for _ in range(2)]
    return xsteps, ysteps

# Plot XY vs Time and φ vs α (Unchanged, but not explicitly used by generate_summary_plots)
def plot_xy_vs_t(d, pi_min, pi_max, alpha_min, title):
    """
    Plot the phi distribution and the phi-vs-alpha scatter for one PI band.

    Despite the name, no position-vs-time plot is produced. The previous
    version read the TIME/ABSX/ABSY columns without using them, which made
    those columns mandatory and crashed on empty tables (np.min of an empty
    array). Those dead reads have been removed.

    Parameters
    ----------
    d : FITS_rec
        Event table, passed on to extract_common_data().
    pi_min, pi_max : float
        PI band passed on to extract_common_data() (no background cut).
    alpha_min : float
        Unused; kept so existing callers keep working.
    title : str
        Prefix for both figure titles.

    Returns
    -------
    list of matplotlib figures
        [phi histogram, phi-vs-alpha scatter].
    """
    _, _, alpha, phi = extract_common_data(d, pi_min, pi_max, bgflag=False)
    figs = []

    # φ histogram over [-π, π] with bins of width 0.001π
    dphi = 0.001 * np.pi
    phist, edges = np.histogram(phi, bins=np.arange(-np.pi, np.pi + dphi, dphi))
    pval = (edges[:-1] + edges[1:]) / 2  # Bin centers
    xs, ys = step_plot(pval, phist, dphi)
    fig, ax = plt.subplots()
    ax.plot(xs, ys)
    ax.set(title=f'{title} – φ dist', xlabel='φ')
    figs.append(fig)

    # φ vs α scatter plot (s=1 keeps the markers small)
    fig, ax = plt.subplots()
    ax.scatter(phi, alpha, s=1)
    ax.set(title=f'{title} – φ vs α', xlabel='φ', ylabel='α')
    figs.append(fig)
    return figs

# Plot PI & α distributions (Unchanged)
def plot_pi_alpha(d, pi_min, pi_max, title, dist_component='Full', a_color='blueviolet'):
    """
    Draw the PI and alpha distributions of the events in a PI band.

    Two figures are returned: the PI histogram of the band, and the alpha
    histogram of the band. The alpha figure gets a y-range fixed by the alpha
    histogram of the full PI range (1-374) and, when the full range holds any
    events, a small inset showing the full PI histogram with the band limits
    marked by dashed lines (and shaded for 'tail...'/'peak...' components).

    Parameters
    ----------
    d : FITS_rec
        Event table, passed on to extract_common_data().
    pi_min, pi_max : float
        PI band to plot.
    title : str
        Prefix for the figure titles.
    dist_component : str
        Name used in the legend label; names starting with 'tail' or 'peak'
        (case-insensitive) select how the inset is shaded.
    a_color : str
        Color of the alpha curve and of the inset markers.

    Returns
    -------
    list of matplotlib figures
        [PI distribution, alpha distribution].
    """
    full_pi_min, full_pi_max = 1, 374
    band_pi, _, band_alpha, _ = extract_common_data(d, pi_min, pi_max, bgflag=False)
    pi_full, _, alpha_full, _ = extract_common_data(d, full_pi_min, full_pi_max, bgflag=False)
    curve_label = f'{dist_component} distribution'
    figs = []

    # ----- Figure 1: PI distribution of the band (unit-width bins) -----
    pi_counts, pi_edges = np.histogram(band_pi, bins=np.arange(pi_min, pi_max+1, 1))
    pi_centers = (pi_edges[:-1] + pi_edges[1:]) / 2
    pi_x, pi_y = step_plot(pi_centers, pi_counts, 1)
    pi_fig, pi_ax = plt.subplots()
    pi_ax.plot(pi_x, pi_y, 'k', label=curve_label)
    pi_ax.set(title=f'{title} – PI dist', xlabel='PI')
    figs.append(pi_fig)
    pi_ax.legend()

    # ----- Figure 2: alpha distribution of the band (0.01-wide bins) -----
    alpha_counts, alpha_edges = np.histogram(band_alpha, bins=np.arange(0, 1.0, 0.01))
    alpha_centers = (alpha_edges[:-1] + alpha_edges[1:]) / 2
    a_x, a_y = step_plot(alpha_centers, alpha_counts, 0.01)
    alpha_fig, alpha_ax = plt.subplots()
    alpha_ax.plot(a_x, a_y, a_color, label=curve_label)
    alpha_ax.set(title=f'{title} – α dist', xlabel='α')
    figs.append(alpha_fig)

    # Common y-range: scale to the tallest bin of the full-range alpha histogram.
    if alpha_full.size:
        full_alpha_counts, _ = np.histogram(alpha_full, bins=np.arange(0, 1.0, 0.01))
        full_alpha_max = full_alpha_counts.max()
        if full_alpha_max > 0:
            alpha_ax.set_ylim(0, full_alpha_max * 1.05)

    legend = alpha_ax.legend()
    if not pi_full.size:
        return figs

    # ----- Inset: full PI histogram, placed next to the legend -----
    # The figure must be drawn once so the legend has a measurable extent.
    alpha_fig.canvas.draw()
    renderer = alpha_fig.canvas.get_renderer()
    legend_box = legend.get_window_extent(renderer=renderer).transformed(alpha_ax.transAxes.inverted())

    width, height = 0.35, 0.35  # Inset size in axes fractions
    pad = 0.02                  # Minimum distance to the axes border
    gap = 0.02                  # Distance between legend and inset
    x0 = min(max(legend_box.x0, pad), 1.0 - width - pad)
    # Prefer the spot below the legend; if it does not fit, go above it.
    y0 = legend_box.y0 - gap - height
    if y0 < pad:
        y0 = min(legend_box.y1 + gap, 1.0 - height - pad)
    y0 = min(max(y0, pad), 1.0 - height - pad)

    axins = inset_axes(
        alpha_ax,
        width=f"{int(width*100)}%",
        height=f"{int(height*100)}%",
        bbox_to_anchor=(x0, y0, 1, 1),
        bbox_transform=alpha_ax.transAxes,
        loc="lower left",
        borderpad=0,
    )

    full_counts, full_edges = np.histogram(pi_full, bins=np.arange(full_pi_min, full_pi_max + 1, 1))
    full_centers = (full_edges[:-1] + full_edges[1:]) / 2
    full_x, full_y = step_plot(full_centers, full_counts, 1)
    axins.plot(full_x, full_y, "k", lw=0.8)
    for limit in (pi_min, pi_max):
        axins.axvline(limit, color=a_color, linestyle="--", lw=1)
    axins.set_xlim(full_pi_min, full_pi_max)
    axins.set_xticks([])
    axins.set_yticks([])
    axins.patch.set_alpha(0.7)

    component = dist_component.lower()
    if component.startswith('tail') and pi_max > full_pi_min:
        # Tail: shade everything from the start of the full range up to pi_max.
        axins.axvspan(full_pi_min, pi_max, color=a_color, alpha=0.05)
    elif component.startswith('peak'):
        axins.axvspan(pi_min, pi_max, color=a_color, alpha=0.05)
    return figs


In [None]:
# ---------- Fitting Functions ---------- #

def minimizer(func, p0, args=(), tol=1e-6):
    """
    Minimise `func` from the start point `p0` with scipy's Nelder-Mead simplex.

    `args` are forwarded to `func` as extra positional arguments and `tol` is
    passed to scipy.optimize.minimize. Only the best-fit parameter vector is
    returned; the convergence status of the result is not checked.
    """
    settings = dict(args=args, method='Nelder-Mead', tol=tol)
    result = minimize(func, p0, **settings)
    return result.x

def safe_polar_likelihood_1d(param, ci):
    """
    -2 log-likelihood of a one-parameter modulation model, safe for optimisers.

    `param` is the modulation factor mu and `ci` holds cos(2*phase_i) per event.
    The likelihood terms are 1 + mu*ci; when any of them is non-positive the
    logarithm is undefined, so a large penalty growing with |mu| is returned
    instead of raising.
    """
    log_args = 1 + param * ci
    has_invalid = np.any(log_args <= 0)
    if not has_invalid:
        return -2 * np.sum(np.log(log_args))
    # Penalty branch: steers the optimiser back towards the valid region.
    return 1e6 + 1e3 * abs(param)

def invert_matrix(matrix):
    """
    Invert `matrix`, reporting failure through a status code instead of raising.

    Returns (0, inverse) on success and (1, None) when numpy raises
    LinAlgError (e.g. for a singular matrix).
    """
    try:
        inverse = np.linalg.inv(matrix)
    except np.linalg.LinAlgError:
        return 1, None
    return 0, inverse

def polar_likelihood(param, evtq, evtu):
    """
    -2 log-likelihood of the Stokes parameters (q, u) given per-event terms.

    `param` unpacks to (q, u); `evtq`/`evtu` are the per-event cos(2*phi) and
    sin(2*phi) values. Raises ValueError when any term 1 + q*evtq + u*evtu is
    non-positive, because its logarithm is undefined.
    """
    q, u = param
    log_args = 1 + q*evtq + u*evtu
    if not np.any(log_args <= 0):
        return -2 * np.sum(np.log(log_args))
    raise ValueError("Non-positive argument encountered in log")

def polar_evpa_likelihood(param, evtq, evtu):
    """
    -2 log-likelihood parameterised by polarization degree and angle.

    `param[0]` is the polarization degree P and `param[1]` the EVPA in degrees;
    they are converted to q = P*cos(2*EVPA), u = P*sin(2*EVPA) before the same
    likelihood as polar_likelihood() is evaluated. Raises ValueError when any
    term 1 + q*evtq + u*evtu is non-positive.
    """
    pol_degree, evpa_deg = param[0], param[1]
    dtor = np.pi/180.0  # Degrees -> radians
    q = pol_degree*np.cos(2*evpa_deg*dtor)
    u = pol_degree*np.sin(2*evpa_deg*dtor)

    log_args = 1 + q*evtq + u*evtu
    if not np.any(log_args <= 0):
        return -2 * np.sum(np.log(log_args))
    raise ValueError("Non-positive argument encountered in log")

def pderiv(func, x, i, dx):
    """
    Central-difference estimate of d func / d x[i] with total step `dx`.

    `func` is evaluated at x with x[i] shifted by +dx/2 and -dx/2; `x` itself
    is left untouched because the shifts are applied to copies.
    """
    half_step = 0.5*dx
    lower = x.copy()
    upper = x.copy()
    lower[i] -= half_step
    upper[i] += half_step
    return (func(upper) - func(lower)) / dx

def pderiv2(func, x, dx):
    """
    Numerical Hessian of `func` at `x` from nested central differences.

    `dx[k]` is the total step used for parameter k. Only the upper triangle
    (col >= row) is evaluated: each entry is the central difference, along
    x[row], of the first derivative with respect to x[col] obtained from
    pderiv(). The lower triangle is filled in by symmetry.
    """
    n = len(x)
    upper = np.zeros((n, n))
    for row in range(n):
        half = 0.5*dx[row]
        for col in range(row, n):
            below, above = x.copy(), x.copy()
            below[row] -= half
            above[row] += half
            slope_below = pderiv(func, below, col, dx[col])
            slope_above = pderiv(func, above, col, dx[col])
            upper[row, col] = (slope_above - slope_below)/dx[row]
    # upper + upper.T counts the diagonal twice; subtract it once.
    return upper + upper.T - np.diag(np.diagonal(upper))

def likelihood(evtq, evtu):
    """
    Maximum-likelihood estimate of the Stokes parameters (Q, U) of an event list.

    Parameters
    ----------
    evtq, evtu : 1D numpy arrays
        Per-event cos(2*phi_i) and sin(2*phi_i).

    Returns
    -------
    tuple of 10 items, in this order:
        qu0         : analytic first-guess [Q, U]
        qu_err_init : analytic first-guess errors [sigma_Q, sigma_U]
        qu          : Nelder-Mead best-fit [Q, U] (falls back to qu0 if the fit fails)
        qu_err_fit  : errors from the inverse Hessian of the log-likelihood
                      (NaN where the Hessian could not be computed or inverted)
        poln        : polarization degree sqrt(Q^2 + U^2)
        poln_err    : always NaN (not computed)
        evpa        : 0.5 * atan2(U, Q), in degrees
        evpa_err    : always NaN (not computed)
        dlike       : always NaN (not computed)
        mdp         : 4.29 / sqrt(sum(evtq^2) + sum(evtu^2)), NaN if that sum is 0

    Notes
    -----
    All numerical failures are absorbed (fallback values / NaN) so that batch
    processing keeps running, but only `Exception` subclasses are caught:
    KeyboardInterrupt and SystemExit propagate.
    """
    # evtq = cos(2*phi_i), evtu = sin(2*phi_i) for each event i
    sumq2 = np.sum(evtq**2)  # Sum of cos^2(2*phi_i)
    sumu2 = np.sum(evtu**2)  # Sum of sin^2(2*phi_i)

    # Simple (approximate) initial estimates for Q, U errors
    qu_err_init = np.array([1/np.sqrt(sumq2) if sumq2>0 else np.nan,
                            1/np.sqrt(sumu2) if sumu2>0 else np.nan])
    # Simple initial estimates for Q, U values
    qu0 = np.array([np.sum(evtq)*(qu_err_init[0]**2) if sumq2>0 else 0.0,
                    np.sum(evtu)*(qu_err_init[1]**2) if sumu2>0 else 0.0])

    # Minimizer tolerance: relative to the starting likelihood when that value is
    # usable, otherwise a fixed default. A NaN/inf starting value would otherwise
    # turn into a NaN/zero tolerance that the simplex can never satisfy.
    tol = 1e-6
    try:
        initial_like_val = polar_likelihood(qu0, evtq, evtu)
    except ValueError:  # qu0 lies outside the region where the log is defined
        pass
    else:
        if np.isfinite(initial_like_val) and initial_like_val != 0 and (sumq2 + sumu2 > 0):
            tol = abs(0.01 / initial_like_val)

    # Fit Q,U using Nelder-Mead minimization of polar_likelihood
    try:
        qu = minimizer(polar_likelihood, qu0, args=(evtq, evtu), tol=tol)
    except Exception:  # Best effort: any failure falls back to the analytic guess
        qu = qu0.copy()

    # Error estimation from Hessian matrix of the likelihood function
    try:
        # Step sizes for numerical differentiation, use initial errors or small default
        hess_dx = qu_err_init if np.all(np.isfinite(qu_err_init)) and np.all(qu_err_init > 0) else np.ones_like(qu_err_init)*1e-3
        # polar_likelihood is -2 log L, so half its Hessian is the information matrix
        H = 0.5 * pderiv2(lambda x: polar_likelihood(x, evtq, evtu), qu, hess_dx)
        _, M = invert_matrix(H)  # Inverse information matrix = covariance estimate
        if M is not None:
            qu_err_fit = np.sqrt(np.abs(np.diag(M)))  # Errors are sqrt of diagonal elements
        else:  # Singular Hessian
            qu_err_fit = np.array([np.nan, np.nan])
    except Exception:  # e.g. polar_likelihood raising at a perturbed point
        qu_err_fit = np.array([np.nan, np.nan])

    # Calculate polarization degree (poln) and electric vector position angle (evpa)
    poln = np.linalg.norm(qu)  # Polarization degree: sqrt(Q^2 + U^2)
    evpa = 0.5 * np.arctan2(qu[1], qu[0]) * (180.0/np.pi)  # EVPA in degrees

    # Minimum Detectable Polarization (MDP) estimate
    mdp = 4.29/np.sqrt(sumq2+sumu2) if sumq2+sumu2>0 else np.nan

    # Errors on poln and evpa (and a likelihood difference) are not computed;
    # NaN placeholders keep the 10-item return layout that callers unpack.
    return qu0, qu_err_init, qu, qu_err_fit, poln, np.nan, evpa, np.nan, np.nan, mdp

def fit_mu_alpha(d, pi_min, pi_max, nalpha, title, bgflag=False):
    """
    IDL-style fit_mu_alpha for SIM data, full-PI.
    Fits modulation factor mu in bins of alpha.

    Steps
    -----
    1. Select the events of the PI band with extract_common_data().
    2. Fit the overall Stokes (Q, U) of the band with likelihood() and plot the
       phi histogram together with the fitted model.
    3. Rotate every phi by the fitted EVPA, split the events into `nalpha`
       equal-width alpha bins on [0, 1), and fit mu per bin with a 1D
       likelihood. Bins with fewer than 100 events are skipped (mu = NaN).
       One figure is created for every fitted bin.

    Parameters
    ----------
    d : FITS_rec
        Event table of one simulation.
    pi_min, pi_max : float
        PI band, forwarded to extract_common_data().
    nalpha : int
        Number of alpha bins.
    title : str
        Prefix for all figure titles.
    bgflag : bool
        Forwarded to extract_common_data() to apply the background cut.

    Returns
    -------
    outputs : dict
        "mu_noweight"     : polarization degree of the whole band from likelihood()
        "mu_noweight_err" : its error as returned by likelihood()
        "alpha_bins"      : alpha bin centers
        "mu_bins"         : fitted mu per alpha bin (NaN for skipped bins)
        "mu_bins_err"     : error of mu per alpha bin (NaN for skipped bins)
        "nevt_bins"       : number of events per alpha bin
    figs : list of matplotlib figures
        The phi-distribution figure followed by one figure per fitted alpha bin.

    Returns: (outputs_dict, list_of_figures)
    """
    # 1) Extract common data arrays (pi, energy, alpha, phi)
    pi, nrg_pi, alpha, phi = extract_common_data(d, pi_min, pi_max, bgflag)
    figs = []  # Initialize list to store figures

    # 2) Phi histogram (bins of 0.001*pi over [-pi, pi]) and per-event Stokes terms
    dphi = 0.001 * np.pi  # Bin width for phi histogram
    phist, edges = np.histogram(phi, bins=np.arange(-np.pi, np.pi + dphi, dphi))  # Compute phi histogram
    phival = (edges[:-1] + edges[1:]) / 2  # Phi bin centers
    evtq = np.cos(2*phi); evtu = np.sin(2*phi)  # Per-event Q and U proxies

    # Handle cases with no events in the selected PI band
    if len(evtq) == 0 or len(evtu) == 0:
        # Prepare an empty output structure (same keys as the normal return value)
        outputs = {
            "mu_noweight": np.nan, "mu_noweight_err": np.nan,
            "alpha_bins": (np.arange(nalpha)+0.5)*(1.0/float(nalpha)),
            "mu_bins": np.full(nalpha, np.nan), "mu_bins_err": np.full(nalpha, np.nan),
            "nevt_bins": np.zeros(nalpha, int)
        }
        # Create an empty placeholder plot for phi distribution
        phi_fig, phi_ax = plt.subplots()
        phi_ax.set(title=rf"{title} – $\phi$ Dist. (No Data)", xlabel=r"$\phi$")
        figs.append(phi_fig)
        return outputs, figs

    # Calculate overall Q, U, polarization, etc. using the likelihood function
    qu0, qe0, qu, qe, mu_nw, mu_nw_err, evpa, evpa_err, dlike, mdp = likelihood(evtq, evtu)

    # Correct normalization: number of φ-bins = len(phist)
    nbin_phi_hist = len(phist)  # Number of bins in the phi histogram
    model = np.zeros_like(phival)  # Initialize model array
    # Calculate model if there are events and histogram bins
    if nbin_phi_hist > 0 and len(phi) > 0:
        # Model: N_total * (1 + Q*cos(2*phi_val) + U*sin(2*phi_val)) / N_bins
        model = len(phi) * (1 + qu[0]*np.cos(2*phival) + qu[1]*np.sin(2*phival)) / nbin_phi_hist
    else:  # If no phi values, fill model with NaNs
        model.fill(np.nan)

    # Plot phi distribution with the fitted model
    xs2, ys2 = step_plot(phival, phist, dphi)  # Coordinates for step plot of histogram
    phi_fig, phi_ax = plt.subplots()
    phi_ax.plot(xs2, ys2, 'darkseagreen', alpha=0.75, label="Data")  # Plot histogram
    phi_ax.plot(phival, model, 'k', label="Model")  # Plot model
    phi_ax.set(title=rf"{title} – $\phi$ Dist.", xlabel=r"$\phi$")
    # phi_ax.legend() # Optional: add legend if desired
    figs.append(phi_fig)

    # 3) Phase = phi measured from the fitted EVPA, folded into [0, pi).
    # Handle qu being all NaNs (e.g., if likelihood fit failed or no data)
    phase = (phi - 0.5*np.arctan2(qu[1], qu[0])) % np.pi if not np.all(np.isnan(qu)) else np.full_like(phi, np.nan)
    dalpha = 1.0/float(nalpha)  # Width of each alpha bin
    alpha_centers = (np.arange(nalpha)+0.5)*dalpha  # Center of each alpha bin

    mu_bins = np.zeros(nalpha)  # Initialize array for mu values in alpha bins
    mu_err  = np.full(nalpha, np.nan)  # Initialize array for mu errors (with NaNs)
    nevt    = np.zeros(nalpha, int)  # Initialize array for number of events in alpha bins

    MIN_EVENTS = 100  # Minimum number of events required in an alpha bin to perform fit

    # Iterate over alpha bins to fit mu
    for i, ac in enumerate(alpha_centers):
        # Select events within the current alpha bin (half-open: low <= alpha < high)
        sel = np.where((alpha >= ac - dalpha/2) & (alpha < ac + dalpha/2))[0]
        nevt[i] = sel.size  # Number of events in this bin

        # Only fit this bin if enough events and phase is valid
        if nevt[i] < MIN_EVENTS or np.all(np.isnan(phase)):
            mu_bins[i] = np.nan  # Not enough events or invalid phase, set mu to NaN
            continue  # Skip to next alpha bin

        ci = np.cos(2*phase[sel])  # cos(2*aligned_phase) for selected events
        # Initial guess for mu parameter for 1D likelihood fit
        p0 = np.array([np.sum(ci)/np.sum(ci*ci)]) if np.sum(ci*ci)>0 else np.array([0.0])
        # Fit mu using safe_polar_likelihood_1d minimizer
        mu_bins[i] = minimizer(safe_polar_likelihood_1d, p0, args=(ci,), tol=1e-6)[0]
        # Error from the likelihood curvature: 1/sqrt(sum(ci^2 / (1 + mu*ci)^2))
        denom = np.sum(ci*ci/((1+mu_bins[i]*ci)**2))  # Denominator for error calculation
        mu_err[i] = np.sqrt(1/denom) if denom>0 else np.nan

        # Diagnostic figure for this bin (one figure per fitted bin)
        pbin = 0.01*np.pi  # Bin width for phase histogram in this alpha bin
        ph2, ed2 = np.histogram(phase[sel], bins=np.arange(0,np.pi+pbin,pbin))  # Histogram of aligned phases
        pv2 = (ed2[:-1]+ed2[1:])/2  # Phase bin centers
        norm = ph2/ph2.sum() if ph2.sum() > 0 else ph2  # Fraction of events per bin
        # NOTE(review): the 0.01 prefactor matches a per-bin fraction of 1/100, i.e. it
        # presumes ~100 phase bins (bin width 0.01*pi over [0, pi)) — confirm if pbin changes.
        model_ph = 0.01*(1+mu_bins[i]*np.cos(2*pv2))  # Expected fraction per bin: 0.01 * (1 + mu*cos(2*phase))

        fig, ax = plt.subplots()
        xs3, ys3 = step_plot(pv2, norm, pbin)  # Coordinates for step plot
        ax.plot(xs3, ys3, 'darkseagreen', label="Data")
        ax.plot(pv2, model_ph, 'k--', label=rf"$\mu={mu_bins[i]:.2f}$")  # Plot model with fitted mu
        ax.set(title=f"{title} – {ac-dalpha/2:.1f}<α<{ac+dalpha/2:.1f}",
               xlabel="Phase (rel. to EVPA)", ylabel="Normalized Counts")
        ax.legend()
        figs.append(fig)

    # Store results in a dictionary
    outputs = {
        "mu_noweight": mu_nw,  # Overall modulation factor (no alpha weighting)
        "mu_noweight_err": mu_nw_err,  # Error for mu_noweight
        "alpha_bins": alpha_centers,  # Centers of alpha bins
        "mu_bins": mu_bins,  # Fitted mu in each alpha bin
        "mu_bins_err": mu_err,  # Error of mu in each alpha bin
        "nevt_bins": nevt  # Number of events in each alpha bin
    }
    return outputs, figs  # Return results dictionary and list of figures


In [None]:
# ---------- Summary Figures ---------- #
import matplotlib.patheffects as pe  # Explicit import for text effects used in stackplot labels

def peak_with_erf_tail(E, A_peak, E_break, alpha1, alpha2, C_tail, E_trans, W_trans):
    """
    Broken power-law peak blended into a constant tail by an erf switch.

    Parameters
    ----------
    E : numpy array
        Energies at which the model is evaluated.
    A_peak : float
        Amplitude of the broken power-law core.
    E_break : float
        Break energy; for E_break <= 0 an all-inf array shaped like E is
        returned so that a fit rejects the parameter set.
    alpha1, alpha2 : float
        The two power-law indices of the core.
    C_tail : float
        The constant height of the flat tail.
    E_trans : float
        The energy where the transition to the tail occurs.
    W_trans : float
        The width (speed) of the erf transition.

    Returns
    -------
    numpy array
        core * (1 - s) + C_tail * s with s = (1 + erf((E - E_trans) / W_trans)) / 2.
    """
    if E_break <= 0:
        return np.full_like(E, np.inf)

    # Numerical warnings (overflow, division by zero, ...) in the core are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ratio = E / E_break
        core = A_peak * (ratio**(-alpha1) + ratio**(-alpha2))**(-1)

    # Smooth 0 -> 1 switch centred on E_trans
    blend = (1 + erf((E - E_trans) / W_trans)) / 2.0
    return core * (1 - blend) + C_tail * blend

def compute_model(eplot, aa=-0.28, bb=0.2, cc=0.21, dd=1./24.):
    """
    Compute the IDL-like empirical model for modulation factor mu vs. energy.

    Model: mu = (1 / [(-aa - bb*E)^-4 + (-cc - dd*E)^-4])^0.25

    Parameters
    ----------
    eplot : array_like or scalar
        Energies E at which the model is evaluated (converted to float).
    aa, bb, cc, dd : float
        Coefficients of the two linear terms.

    Returns
    -------
    numpy array (float for scalar input)
        Model values; any non-finite result is replaced by NaN.

    Notes
    -----
    The power terms are evaluated inside the np.errstate block so that the
    division-by-zero raised when a linear term is exactly zero is really
    suppressed (previously they were computed before the block was entered).
    """
    energy = np.asarray(eplot, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        term1 = (-aa - bb * energy)**(-4)
        term2 = (-cc - dd * energy)**(-4)
        model = (1.0 / (term1 + term2))**0.25
    # np.where (rather than in-place masking) also works for 0-d/scalar input.
    model = np.where(np.isfinite(model), model, np.nan)
    return float(model) if model.ndim == 0 else model

def _fit_erf_tail_slice(x, y, yerr, p0, lower, upper, prev_ebreak=None, context=""):
    """
    Fit ``peak_with_erf_tail`` to one mu-vs-energy slice with bounded least squares.

    Parameters
    ----------
    x, y, yerr : ndarray
        Energies, mu values and mu errors of the points to fit.
    p0, lower, upper : sequence of 7 floats
        Initial guess and bounds, in the parameter order of
        ``peak_with_erf_tail``. The sequences are copied, not modified.
    prev_ebreak : float or None, optional
        E_break of the most recent successful fit for the same PI selection.
        When given it becomes both the lower bound and the starting value of
        E_break, so the break energy cannot decrease from bin to bin.
    context : str, optional
        Text identifying the slice in the warning printed on failure.

    Returns
    -------
    ndarray or None
        Best-fit parameters, or None when ``curve_fit`` raised.
    """
    p0, lower, upper = list(p0), list(lower), list(upper)
    if prev_ebreak is not None:
        # curve_fit requires lower < upper strictly. If the previous fit ended
        # on the upper bound, using it unchanged would make every later fit
        # raise ValueError, so keep the floor just below the ceiling.
        ebreak_floor = min(prev_ebreak, upper[1] - 1e-3)
        lower[1] = ebreak_floor
        p0[1] = ebreak_floor
    try:
        popt, _ = curve_fit(peak_with_erf_tail, x, y, p0=p0, sigma=yerr,
                            maxfev=10000, bounds=(lower, upper))
    except (RuntimeError, ValueError) as exc:
        # Best effort: the bin is still plotted without a curve, but report why.
        print(f"   -> Warning: erf-tail fit failed for {context}: {exc}")
        return None
    return popt


def generate_summary_sim_plots(energies_list, peak_results_list, low_results_list, du_label, filename="summary_plots_sim.pdf"):
    """
    Create a multipage PDF of summary plots for simulation data using only the Erf Tail model.

    Entries are kept only when both the peak and the tail result hold a finite
    ``mu_noweight`` and the energy is finite. Pages are written in this order:

    1. mu vs. energy for Peak PI, Tail PI and their event-weighted combination,
       with two ``compute_model`` curves.
    2. Peak PI mu vs. alpha, one line per energy.
    3. Tail PI mu vs. alpha, one line per energy.
    4. Peak PI mu minus the linear reference ``0.05 + 0.8*alpha``.
    5. Tail PI mean alpha vs. PI bin (only if any result carries that data).
    6. Stacked event fractions per alpha bin, split by PI selection.
    7. One page per alpha bin: mu vs. energy with erf-tail fits.
    8. Two pages of fitted parameters vs. alpha (Peak PI, then Tail PI).

    Parameters
    ----------
    energies_list : sequence of float
        Energy of each simulation (plotted as keV).
    peak_results_list, low_results_list : sequence of dict or None
        Per-energy result dictionaries for the peak and tail PI selections.
        Keys read here: ``mu_noweight``, ``mu_noweight_err``, ``nevt_bins``,
        ``alpha_bins``, ``mu_bins``, ``mu_bins_err`` and, optionally on the
        tail results, ``alpha_pi_low_centers`` / ``alpha_pi_low_mean`` /
        ``alpha_pi_low_counts``.
    du_label : str
        Label used in titles and messages.
    filename : str, optional
        Output PDF path.

    Returns
    -------
    None
        Returns early (after printing an error) when no entry is valid.
    """
    def result_is_valid(res):
        """Helper function to check if a result dictionary is valid (contains finite mu_noweight)."""
        return (
            res is not None and
            ("mu_noweight" in res) and
            np.isfinite(res["mu_noweight"])
        )

    def total_events(res):
        """Sum of ``nevt_bins`` for one result, or 0 when the key is missing/None."""
        return np.sum(res["nevt_bins"]) if "nevt_bins" in res and res["nevt_bins"] is not None else 0

    valid_indices = [
        i for i, (p, l, e) in enumerate(zip(peak_results_list, low_results_list, energies_list))
        if result_is_valid(p) and result_is_valid(l) and np.isfinite(e)
    ]

    if not valid_indices:
        print(f"Error: No valid results found for DU {du_label}. Cannot generate summary plots.")
        return

    energies = np.array([energies_list[i] for i in valid_indices])
    peak_results = [peak_results_list[i] for i in valid_indices]
    low_results  = [low_results_list[i] for i in valid_indices]
    e_min, e_max = energies.min(), energies.max()

    print(f"\nGenerating summary plots for {len(energies)} valid simulations ({du_label}) in '{filename}'...")

    # Event-weighted combination of the peak and tail modulation factors.
    tot_hi = np.array([total_events(res) for res in peak_results])
    tot_lo = np.array([total_events(res) for res in low_results])
    valid_evt_mask = ((tot_hi + tot_lo) > 0)
    mu_noweight_hi = np.array([res["mu_noweight"] for res in peak_results])
    mu_noweight_hi_err = np.array([res["mu_noweight_err"] for res in peak_results])
    mu_noweight_lo = np.array([res["mu_noweight"] for res in low_results])
    mu_noweight_lo_err = np.array([res["mu_noweight_err"] for res in low_results])
    mu_noweight_combined = np.full_like(mu_noweight_hi, np.nan)
    numerator_combined = np.nan_to_num(tot_hi * mu_noweight_hi) + np.nan_to_num(tot_lo * mu_noweight_lo)
    denominator_combined = tot_hi + tot_lo
    mu_noweight_combined[valid_evt_mask] = numerator_combined[valid_evt_mask] / denominator_combined[valid_evt_mask]

    eplot = np.linspace(e_min, e_max, 500)
    mu_model_orig   = compute_model(eplot, aa=-0.28, bb=0.2, cc=0.21, dd=1./24.)
    mu_model_better = compute_model(eplot, aa=-0.28, bb=0.2, cc=0.21, dd=1./18.5)

    # Shared energy colour scale for every "one line per energy" page.
    norm = plt.Normalize(vmin=e_min, vmax=e_max)

    def energy_color(E_val):
        return plt.cm.rainbow_r(norm(E_val))

    def add_energy_colorbar(fig, ax):
        fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap="rainbow_r"), ax=ax, label="Energy (keV)")

    with PdfPages(filename) as pdf:
        # ---- Page 1: mu vs. energy ----
        fig1, ax1 = plt.subplots()
        ax1.errorbar(energies, mu_noweight_hi, yerr=mu_noweight_hi_err, c='deeppink', fmt='.', capsize=3, label="Peak PI")
        ax1.errorbar(energies, mu_noweight_lo, yerr=mu_noweight_lo_err, c='cornflowerblue', fmt='s', markersize=3, capsize=3, label="Tail PI")
        ax1.plot(energies, mu_noweight_combined, 'd', markersize=3, c='blueviolet', label="All PI")
        ax1.plot(eplot, mu_model_better, 'k', label="Peak Only Model", zorder=10)
        ax1.plot(eplot, mu_model_orig, 'k--', label="All PI Model (Di Marco+)", zorder=9)
        ax1.set(xlabel="Energy (keV)", ylabel=r"$\mu$", title=f"IXPE Simulations {du_label}")
        ax1.legend()
        pdf.savefig(fig1)
        plt.close(fig1)

        # ---- Pages 2-3: mu vs. alpha for Peak PI, then Tail PI ----
        a_ref, b_ref = 0.05, 0.8
        a_line_ref = np.linspace(0, 1, 200)
        mu_line_ref = a_ref + b_ref * a_line_ref
        for results, page_title in ((peak_results, "Peak PI: μ vs α"), (low_results, "Tail PI: μ vs α")):
            fig_ma, ax_ma = plt.subplots()
            for E_val, res_dict in zip(energies, results):
                mask = res_dict["nevt_bins"] > 100  # only bins with more than 100 events
                ax_ma.plot(res_dict["alpha_bins"][mask], res_dict["mu_bins"][mask], color=energy_color(E_val), lw=1)
            ax_ma.plot(a_line_ref, mu_line_ref, 'k--', linewidth=3, label="Di Marco et al.")
            ax_ma.set(title=page_title, xlabel="α", ylabel="μ")
            ax_ma.legend()
            add_energy_colorbar(fig_ma, ax_ma)
            fig_ma.tight_layout()
            pdf.savefig(fig_ma)
            plt.close(fig_ma)

        # ---- Page 4: Peak PI residual against the linear reference ----
        fig4, ax4 = plt.subplots()
        a_model_linear, b_model_linear = 0.05, 0.8
        for E_val, res_dict in zip(energies, peak_results):
            mask = res_dict["nevt_bins"] > 100
            mu_improvement = res_dict["mu_bins"] - (a_model_linear + b_model_linear * res_dict["alpha_bins"])
            ax4.plot(res_dict["alpha_bins"][mask], mu_improvement[mask], color=energy_color(E_val), linewidth=1)
        ax4.axhline(0, color='k', linestyle='--', linewidth=3)
        ax4.set(title="Peak PI: Model Improvement", xlabel="α", ylabel="μ − model (linear)")
        add_energy_colorbar(fig4, ax4)
        fig4.tight_layout()
        pdf.savefig(fig4)
        plt.close(fig4)

        # ---- Page 5 (optional): Tail PI mean alpha vs. PI bin ----
        fig4b, ax4b = plt.subplots()
        series_plotted = 0
        for E_val, res_dict in zip(energies, low_results):
            centers = res_dict.get("alpha_pi_low_centers")
            means = res_dict.get("alpha_pi_low_mean")
            if centers is None or means is None:
                continue
            centers = np.asarray(centers)
            means = np.asarray(means)
            counts = res_dict.get("alpha_pi_low_counts")
            mask = np.isfinite(means)
            if counts is not None:
                mask = (np.asarray(counts) > 0) & mask
            if np.any(mask):
                ax4b.plot(centers[mask], means[mask], color=energy_color(E_val), lw=1)
                series_plotted += 1
        if series_plotted == 0:
            print("   -> Tail PI alpha vs PI bin: no valid bins to plot.")
        else:
            ax4b.set(title="Tail PI: mean alpha vs PI bin", xlabel="PI bin center", ylabel="Mean alpha")
            add_energy_colorbar(fig4b, ax4b)
            fig4b.tight_layout()
            pdf.savefig(fig4b)
        plt.close(fig4b)

        # ---- Page 6 (optional): stacked alpha-bin fractions ----
        # The alpha binning is taken from the first peak result; bins are
        # treated as equal-width over a unit interval (d_alpha = 1 / n_bins).
        alpha_bin_centers = peak_results[0]["alpha_bins"]
        n_alpha_bins = len(alpha_bin_centers)
        d_alpha = 1.0 / n_alpha_bins

        labels = [fr"${center-d_alpha/2.0:.2f} \leq \alpha < {center+d_alpha/2.0:.2f}$" for center in alpha_bin_centers]
        colors = plt.cm.viridis(np.linspace(0, 1, n_alpha_bins))

        # Each column holds the tail fractions followed by the peak fractions
        # of one energy, normalised so the full stack sums to 1.
        energies_for_stackplot = []
        stack_columns = []
        boundary_line = []  # total tail fraction = line between the two blocks
        for E_val, res_p, res_l in zip(energies, peak_results, low_results):
            n_peak = np.asarray(res_p["nevt_bins"])
            n_low = np.asarray(res_l["nevt_bins"])
            tot_low = np.sum(n_low)
            total = np.sum(n_peak) + tot_low
            if total <= 0:
                continue
            energies_for_stackplot.append(E_val)
            stack_columns.append(np.concatenate([n_low[:n_alpha_bins], n_peak[:n_alpha_bins]]) / total)
            boundary_line.append(tot_low / total)

        if energies_for_stackplot:
            stack_data = np.array(stack_columns).T  # shape (2 * n_alpha_bins, n_energies)
            fig5, ax5 = plt.subplots(figsize=(12, 7))

            # Same gradient for both blocks; legend labels only for the first.
            full_colors = list(colors) + list(colors)
            full_labels = labels + [None] * n_alpha_bins

            ax5.stackplot(energies_for_stackplot, stack_data, labels=full_labels, colors=full_colors)

            # White line under a dashed black one keeps the boundary visible
            # on any background colour.
            ax5.plot(energies_for_stackplot, boundary_line, color='white', linewidth=3, linestyle='-')
            ax5.plot(energies_for_stackplot, boundary_line, color='black', linewidth=1.5, linestyle='--')

            # Region labels. The tail label sits at 7 keV, clamped so it
            # cannot fall outside the plotted energy range.
            mid_energy = (e_min + e_max) / 2
            tail_label_x = min(max(7, e_min), e_max)
            outline = [pe.withStroke(linewidth=3, foreground="black")]
            ax5.text(tail_label_x, 0.09, "Tail PI", ha='center', va='center',
                     fontsize=12, fontweight='bold', color='white', path_effects=outline)
            ax5.text(mid_energy, 0.9, "Peak PI", ha='center', va='center',
                     fontsize=12, fontweight='bold', color='white', path_effects=outline)

            ax5.set(title="Alpha Fraction Distribution (Split by PI Selection)",
                    xlabel="Energy (keV)",
                    ylabel="Fraction of Total Events (stacked)",
                    ylim=(0, 1),
                    xlim=(e_min, e_max))

            # Legend (only shows the alpha bins once)
            ax5.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize='small', title="Alpha Bins")
            fig5.tight_layout(rect=[0, 0, 0.8, 1])
            pdf.savefig(fig5)
            plt.close(fig5)

        # ---- Pages 7..: mu vs. energy with erf-tail fits, one page per alpha bin ----
        print("   -> Generating μ vs. Energy plots for each alpha bin...")

        erf_tail_params_log = []
        erf_tail_params_log_low = []

        # E_break of the latest successful fit, carried over to the next bin.
        last_successful_ebreak_peak = None
        last_successful_ebreak_low = None

        e_fit_plot = np.linspace(e_min, e_max, 200)

        for i in range(n_alpha_bins):
            alpha_low, alpha_high = alpha_bin_centers[i] - d_alpha/2.0, alpha_bin_centers[i] + d_alpha/2.0
            mu_peak_slice = np.array([res['mu_bins'][i] for res in peak_results])
            mu_err_peak_slice = np.array([res['mu_bins_err'][i] for res in peak_results])
            mu_low_slice = np.array([res['mu_bins'][i] for res in low_results])
            mu_err_low_slice = np.array([res['mu_bins_err'][i] for res in low_results])

            fig_bin, ax_bin = plt.subplots()
            ax_bin.errorbar(energies, mu_peak_slice, yerr=mu_err_peak_slice, c='deeppink', fmt='.', capsize=3, label="Peak PI")
            ax_bin.errorbar(energies, mu_low_slice, yerr=mu_err_low_slice, c='cornflowerblue', fmt='s', markersize=3, capsize=3, label="Tail PI")

            # NaN rows mark bins where no fit was attempted or the fit failed.
            erf_tail_params_this_bin = [np.nan] * 7
            erf_tail_params_this_bin_low = [np.nan] * 7

            # Fit Peak PI data. A_peak is pinned to 2 * alpha_high within
            # +/- 0.001 by its bounds.
            mask_peak = np.isfinite(mu_peak_slice) & (mu_err_peak_slice > 0)
            if np.sum(mask_peak) > 6:
                popt_erf = _fit_erf_tail_slice(
                    energies[mask_peak], mu_peak_slice[mask_peak], mu_err_peak_slice[mask_peak],
                    p0=[2 * alpha_high, 2, 6, -2.0, 0.1, 6.0, 2.0],
                    lower=[2 * alpha_high - 0.001, 1.5, 5, -np.inf, 0, 4, 0.1],
                    upper=[2 * alpha_high + 0.001, 6, np.inf, 0, 1, 12, 10],
                    prev_ebreak=last_successful_ebreak_peak,
                    context=f"Peak PI, alpha bin {i}",
                )
                if popt_erf is not None:
                    ax_bin.plot(e_fit_plot, peak_with_erf_tail(e_fit_plot, *popt_erf), color='k', ls='-', lw=2)
                    erf_tail_params_this_bin = popt_erf
                    last_successful_ebreak_peak = popt_erf[1]

            erf_tail_params_log.append(erf_tail_params_this_bin)

            # Fit Low PI data
            mask_low = np.isfinite(mu_low_slice) & (mu_err_low_slice > 0)
            if np.sum(mask_low) > 4:
                popt_erf_low = _fit_erf_tail_slice(
                    energies[mask_low], mu_low_slice[mask_low], mu_err_low_slice[mask_low],
                    p0=[1, 3, 2.0, -2.0, 0.1, 6.0, 0.5],
                    lower=[0, 2, 0, -np.inf, 0, 4, 0.1],
                    upper=[2, 6, np.inf, 0, np.inf, 12, 1],
                    prev_ebreak=last_successful_ebreak_low,
                    context=f"Tail PI, alpha bin {i}",
                )
                if popt_erf_low is not None:
                    ax_bin.plot(e_fit_plot, peak_with_erf_tail(e_fit_plot, *popt_erf_low), color='k', ls='--', lw=1.5)
                    erf_tail_params_this_bin_low = popt_erf_low
                    last_successful_ebreak_low = popt_erf_low[1]

            erf_tail_params_log_low.append(erf_tail_params_this_bin_low)

            ax_bin.set(xlabel="Energy (keV)", ylabel=r"$\mu$", title=fr"Modulation vs. Energy for ${alpha_low:.2f} \leq \alpha < {alpha_high:.2f}$")
            ax_bin.legend()
            ax_bin.grid(True, linestyle='--', alpha=0.6)
            ax_bin.set_ylim(-0.15, 1.05)
            pdf.savefig(fig_bin)
            plt.close(fig_bin)

        # ---- Final pages: fitted parameters vs. alpha (Peak PI, then Tail PI) ----
        print("   -> Generating parameter vs. alpha plots...")
        erf_param_names = ['Peak Amp (A_peak)', 'Break Energy (E_break)', 'Rise Index (α1)', 'Fall Index (α2)', 'Tail Height (C_tail)', 'Transition Energy (E_trans)', 'Transition Width (W_trans)']

        for params_log, selection, line_color in (
            (erf_tail_params_log, 'Peak PI', 'deeppink'),
            (erf_tail_params_log_low, 'Tail PI', 'cornflowerblue'),
        ):
            params = np.array(params_log)  # shape (n_alpha_bins, 7)
            fig_par, axes_par = plt.subplots(4, 2, figsize=(12, 16), constrained_layout=True)
            fig_par.suptitle(f'Peak + Erf Tail Fit Parameters vs. α ({selection})', fontsize=16)
            # 7 parameters on a 4x2 grid: hide the unused last panel.
            if len(erf_param_names) % 2 != 0:
                axes_par.flat[-1].set_visible(False)
            for idx, (ax, param_name) in enumerate(zip(axes_par.flat, erf_param_names)):
                ax.plot(alpha_bin_centers, params[:, idx], 'o-', color=line_color)
                ax.set(title=param_name, xlabel='α bin center', ylabel='Parameter Value')
                ax.grid(True, ls=':')
            pdf.savefig(fig_par)
            plt.close(fig_par)

    print(f"Summary plots saved to '{filename}'.")


In [None]:
# Dataset selection and path setup
# Maps a dataset key to the sub-directory of `base_path` that holds its files.
dataset_config = dict(
    original="sim_data_mit",
    scrambled_80="scrambled_sim_data_80percent",
    scrambled_50="scrambled_sim_data_50percent",
)

# Home-relative locations: input root and the directory used for caches.
base_path, working_dir = (
    Path(os.path.expanduser(home_relative))
    for home_relative in ("~/Library/CloudStorage/Box-Box/IXPE_rmfs", "~/Documents/IXPE")
)
# Destination of the generated PDFs (absolute path).
plot_output_dir = Path("/Users/leodrake/MIT Dropbox/Leonardo Drake/IXPE")

# Key of the dataset to process; must be one of the keys of `dataset_config`.
current_dataset_key = "original"

# Resolve dataset paths and suffix
try:
    input_subdir = dataset_config[current_dataset_key]
except KeyError:
    raise KeyError(f"Dataset key '{current_dataset_key}' not found.")
else:
    root_dir = base_path / input_subdir

    # Outputs of the original dataset carry no suffix; others get "-<key>".
    if current_dataset_key == 'original':
        output_suffix = ''
    else:
        output_suffix = f'-{current_dataset_key}'

    print("\n".join([
        f"Processing dataset: '{current_dataset_key}'",
        f"  - Input: {root_dir}",
        f"  - Output Plots: {plot_output_dir}",
        f"  - Suffix: '{output_suffix}'",
    ]))

# Run configuration
du_specifier, force_recompute, generate_individual_plots = 'all', False, True

In [None]:
# Main processing entry point
#
# For every simulated energy: load the peak/tail fit results from the cache or
# compute them, optionally write per-energy diagnostic pages to one PDF, then
# build the summary PDF from the energies between 2 and 8 keV.

def _load_cached_outputs(cache_path):
    """Load one cached result dict from an ``.npz`` file, unwrapping 0-d entries with ``.item()``."""
    # NOTE(review): allow_pickle=True can execute arbitrary code stored in the
    # file. Only load caches written by this notebook, never untrusted files.
    loaded = {}
    with np.load(cache_path, allow_pickle=True) as data:
        for key in data:
            value = data[key]
            loaded[key] = value.item() if value.ndim == 0 else value
    return loaded


if not root_dir.exists():
    print(f"Error: ROOT_DIR '{root_dir}' not found.")
else:
    print(f"--- Processing {du_specifier} ---")
    if force_recompute:
        print("--- Caching is OFF (force_recompute = True) ---")

    # Best effort: a failure to switch directory is only reported, because the
    # cache and plot paths below are built from working_dir / plot_output_dir
    # directly rather than from the current directory.
    try:
        working_dir.mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)
        print(f"Working directory set to: {os.getcwd()}")
    except OSError as e:
        print(f"Warning: Could not set directory to {working_dir}: {e}")

    if not plot_output_dir.exists():
        try:
            plot_output_dir.mkdir(parents=True, exist_ok=True)
            print(f"Created plot output directory: {plot_output_dir}")
        except OSError as e:
            print(f"Warning: Could not create plot directory {plot_output_dir}: {e}")

    # list_sim_fits() reads the module-level ROOT_DIR, so it must point at the
    # dataset selected above; otherwise a non-default dataset key would still
    # read the default directory while caching under the new suffix.
    ROOT_DIR = str(root_dir)

    # Discover input files and group by energy
    du_label, fits_files = list_sim_fits(du_specifier=du_specifier)

    if not fits_files:
        print(f"No FITS files found for {du_label} in {root_dir}. Exiting.")
    else:
        files_by_energy = defaultdict(list)
        for fname in fits_files:
            energy = parse_energy(fname)
            if not np.isnan(energy):  # skip files whose energy could not be parsed
                files_by_energy[energy].append(fname)

        all_energies_found = sorted(files_by_energy)
        energies_to_process = all_energies_found

        print(f"Found {len(fits_files)} files. Processing {len(energies_to_process)} energy points (full range).")

        cache_dir = working_dir / f'NewRMFsADP/fit-cache-{du_label}{output_suffix}'
        cache_dir.mkdir(parents=True, exist_ok=True)

        all_energies_for_du = []
        all_peak_results_for_du = []
        all_low_results_for_du = []

        output_pdf_individual = plot_output_dir / f'process-all-sim-{du_label}{output_suffix}.pdf'
        pdf = None

        if generate_individual_plots:
            print(f"Individual plots: {output_pdf_individual}")
            pdf = PdfPages(output_pdf_individual)

        try:
            # Process each energy (cache or compute)
            for e_kev_current_file in energies_to_process:
                base_name = f"sim-{e_kev_current_file*1000:05.0f}-{du_label}"
                cache_path_peak = cache_dir / f"{base_name}-peak.npz"
                cache_path_low = cache_dir / f"{base_name}-low.npz"

                try:
                    # Cache load (cached energies get no per-energy pages)
                    if cache_path_peak.exists() and cache_path_low.exists() and not force_recompute:
                        print(f"Loading {e_kev_current_file:.2f} keV from cache...")

                        outputs_peak = _load_cached_outputs(cache_path_peak)
                        outputs_low = _load_cached_outputs(cache_path_low)

                        all_energies_for_du.append(e_kev_current_file)
                        all_peak_results_for_du.append(outputs_peak)
                        all_low_results_for_du.append(outputs_low)
                        continue

                    print(f"Processing {e_kev_current_file:.2f} keV...")

                    # Compute per-energy outputs
                    current_files = files_by_energy[e_kev_current_file]
                    d = get_combined_data(current_files)
                    if d is None:
                        print(f"    Warning: No data loaded for {e_kev_current_file:.2f} keV.")
                        continue

                    pi_full, _, _, _ = extract_common_data(d, 1, 374, bgflag=False)
                    if len(pi_full) == 0:
                        print(f"    Warning: No data in full PI range for {e_kev_current_file:.2f} keV.")
                        continue

                    # The tail selection runs from PI 1 up to the lower edge of
                    # the detected peak band.
                    peak_lo, peak_hi = find_pi_peak_band(pi_full, bin_width=1, min_prominence=0.1)
                    non_lo, non_hi   = 1, peak_lo
                    print(f"    PI Bands - Peak: {peak_lo:.1f}-{peak_hi:.1f}, Tail: {non_lo:.1f}-{non_hi:.1f}")

                    base_title = f"SIM {e_kev_current_file:.2f} keV ({du_label}) "
                    title_non = f"{base_title}– Tail PI {int(non_lo)}–{int(non_hi)}"
                    title_pk = f"{base_title}– PI {int(peak_lo)}–{int(peak_hi)}"

                    outputs_peak, fit_figs_peak = fit_mu_alpha(d, peak_lo, peak_hi, nalpha=10, title=title_pk)
                    outputs_low, fit_figs_low = fit_mu_alpha(d, non_lo, non_hi, nalpha=10, title=title_non)

                    # Mean alpha in coarse PI bins of the tail selection.
                    pi_low, _, alpha_low, _ = extract_common_data(d, non_lo, non_hi, bgflag=False)
                    n_tail_pi_bins = 4
                    low_pi_centers, low_pi_alpha_mean, low_pi_counts = summarize_alpha_vs_pi_bins(
                        pi_low, alpha_low, non_lo, non_hi, n_bins=n_tail_pi_bins
                    )
                    outputs_low["alpha_pi_low_centers"] = low_pi_centers
                    outputs_low["alpha_pi_low_mean"] = low_pi_alpha_mean
                    outputs_low["alpha_pi_low_counts"] = low_pi_counts

                    # Alpha histogram and quartiles for each tail PI bin.
                    # NOTE(review): bins are half-open [lo, hi), so events with
                    # PI exactly equal to non_hi are not counted in the last
                    # bin - confirm this is intended.
                    alpha_edges = np.linspace(0, 1.0, 21)
                    tail_pi_edges = np.linspace(non_lo, non_hi, n_tail_pi_bins + 1)
                    tail_hist = np.zeros((n_tail_pi_bins, len(alpha_edges) - 1), dtype=float)
                    tail_median = np.full(n_tail_pi_bins, np.nan)
                    tail_q1 = np.full(n_tail_pi_bins, np.nan)
                    tail_q3 = np.full(n_tail_pi_bins, np.nan)
                    for i in range(n_tail_pi_bins):
                        sel = (pi_low >= tail_pi_edges[i]) & (pi_low < tail_pi_edges[i + 1])
                        if np.any(sel):
                            tail_hist[i], _ = np.histogram(alpha_low[sel], bins=alpha_edges)
                            tail_median[i] = np.nanmedian(alpha_low[sel])
                            tail_q1[i] = np.nanpercentile(alpha_low[sel], 25)
                            tail_q3[i] = np.nanpercentile(alpha_low[sel], 75)

                    outputs_low["alpha_pi_tail_hist"] = tail_hist
                    outputs_low["alpha_pi_tail_edges"] = alpha_edges
                    outputs_low["alpha_pi_tail_pi_edges"] = tail_pi_edges
                    outputs_low["alpha_pi_tail_median"] = tail_median
                    outputs_low["alpha_pi_tail_q1"] = tail_q1
                    outputs_low["alpha_pi_tail_q3"] = tail_q3

                    all_energies_for_du.append(e_kev_current_file)
                    all_peak_results_for_du.append(outputs_peak)
                    all_low_results_for_du.append(outputs_low)

                    # Cache outputs
                    np.savez_compressed(cache_path_peak, **outputs_peak)
                    np.savez_compressed(cache_path_low, **outputs_low)
                    print("    Saved to cache.")

                    # Per-energy plots
                    if pdf is not None:
                        figs_full = plot_pi_alpha(d, 1, 374, base_title + '- PI 1-374')
                        for fig in figs_full or []:
                            pdf.savefig(fig)
                            plt.close(fig)

                        # Only the second figure of each band-restricted call is saved.
                        figs_non = plot_pi_alpha(d, non_lo, non_hi, title_non, dist_component='Tail', a_color='cornflowerblue')
                        if figs_non and len(figs_non) > 1:
                            pdf.savefig(figs_non[1])
                        for fig in figs_non or []:
                            plt.close(fig)

                        figs_peak = plot_pi_alpha(d, peak_lo, peak_hi, title_pk, dist_component='Peak', a_color='deeppink')
                        if figs_peak and len(figs_peak) > 1:
                            pdf.savefig(figs_peak[1])
                        for fig in figs_peak or []:
                            plt.close(fig)

                        for fig in fit_figs_peak or []:
                            pdf.savefig(fig)
                            plt.close(fig)
                        for fig in fit_figs_low or []:
                            plt.close(fig)
                    else:
                        # The fit figures were created regardless; close them
                        # so they do not accumulate across energies.
                        for fig in fit_figs_peak or []:
                            plt.close(fig)
                        for fig in fit_figs_low or []:
                            plt.close(fig)

                except Exception as e:
                    # Top-level boundary: one failing energy must not stop the run.
                    print(f"ERROR processing {e_kev_current_file:.2f} keV: {e}")
        finally:
            # Finalise the PDF even if the loop is interrupted.
            if pdf is not None:
                pdf.close()

        if pdf is not None:
            print("Individual PDF generation complete.")

        # Summary PDF (2-8 keV)
        if all_energies_for_du:
            summary_mask = [(2.0 <= e <= 8.0) for e in all_energies_for_du]
            if any(summary_mask):
                summary_energies = [e for e, m in zip(all_energies_for_du, summary_mask) if m]
                summary_peak_results = [r for r, m in zip(all_peak_results_for_du, summary_mask) if m]
                summary_tail_results = [r for r, m in zip(all_low_results_for_du, summary_mask) if m]
                summary_filename = plot_output_dir / f"simulation-summary-plots-{du_label}{output_suffix}.pdf"
                generate_summary_sim_plots(summary_energies, summary_peak_results, summary_tail_results,
                                           du_label, filename=str(summary_filename))
            else:
                print("No energies in 2–8 keV range. Skipping summary.")
        else:
            print("No valid data processed. Skipping summary.")

    print("--- Processing complete ---")