In [None]:
# Imports
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, FuncFormatter
import numpy as np
import pandas as pd
import pickle
import os
import scipy as sp
from scipy.constants import hbar, physical_constants
from scipy.optimize import minimize
from scipy.stats import chi2
from tqdm import tqdm

# Retrieve physical constants
# Numerical value (first tuple entry) of the atomic unit of charge, i.e. the
# elementary charge in coulombs; numerically this is also the joule-per-eV factor
e_au = physical_constants["atomic unit of charge"][0]

# Conversion factors
# One second expressed in natural units: 1 s = (e / hbar) eV^-1, with hbar in J s
s_to_eVminus1 = e_au / hbar

In [None]:
def plot_polarization(df_dark, df_blue, df_red, day):
    """
    Draw the polarization angle against time for the three datasets on one figure.

    Each dataset is shown as "x" markers with vertical error bars taken from
    its "sigma" column.

    Parameters:
    - df_dark (pd.DataFrame): Dark dataset with "time", "phi" and "sigma" columns.
    - df_blue (pd.DataFrame): Blue dataset with "time", "phi" and "sigma" columns.
    - df_red (pd.DataFrame): Red dataset with "time", "phi" and "sigma" columns.
    - day (str): Day of the observation, inserted into the figure title.

    Returns:
        None
    """

    # Pair every dataset with its colour and legend entry, in drawing order
    series = (
        (df_dark, "black", "CARMA"),
        (df_blue, "blue", "SMTL-CARMAR"),
        (df_red, "red", "SMTR-CARMAL"),
    )

    plt.figure(figsize=(10, 6))

    # Markers first ...
    for frame, colour, legend_name in series:
        plt.scatter(frame["time"], frame["phi"], marker="x", color=colour, label=legend_name)

    # ... then the error bars (fmt="none" draws the bars only, no extra markers)
    for frame, colour, _ in series:
        plt.errorbar(frame["time"], frame["phi"], yerr=frame["sigma"], fmt="none", color=colour, capsize=5)

    # Axis labels, title and legend
    plt.xlabel("UT Hour", fontsize=12)
    plt.ylabel(r"Angle ($^\circ$)", fontsize=12)
    plt.title(f"Polarization versus Time ({day})", fontsize=14)
    plt.legend(loc="upper right", fontsize=10)

    plt.show()

def plot_interp_func(interp_data, i):
    """
    Plot one interpolated function over the full range of its sample points.

    Parameters:
    - interp_data (list): List of interpolated functions. Each entry must be
      callable on an array and expose its sample abscissae as an `x` attribute.
    - i (int): Index of the function to plot.

    Returns:
        None
    """

    # Pick the requested interpolant
    interpolant = interp_data[i]

    # Evaluate it on 1000 evenly spaced points spanning its own sample range
    grid = np.linspace(interpolant.x.min(), interpolant.x.max(), 1000)
    values = interpolant(grid)

    # Draw the curve on a fresh figure
    plt.figure(figsize=(10, 6))
    plt.plot(grid, values)

    # Axis labels and title
    plt.xlabel("Time (hours)", fontsize=12)
    plt.ylabel(r"$\Delta \phi$ (rad)", fontsize=12)
    plt.title(f"Interpolated Function {i}", fontsize=14)

    plt.show()

def mass_to_period(m):
    """
    Convert a mass to the corresponding oscillation period, T = 2*pi / m.

    Parameters:
    - m (float): Mass value in eV.

    Returns:
    - period_hr (float): Period value in hours.
    """

    # Period in natural units (eV^-1)
    period_natural = 2 * np.pi / m

    # eV^-1 -> seconds -> hours
    period_seconds = period_natural / s_to_eVminus1
    return period_seconds / 3600

def period_to_mass(period_hr):
    """
    Convert an oscillation period to the corresponding mass, m = 2*pi / T.

    Parameters:
    - period_hr (float): Period value in hours.

    Returns:
    - m (float): Mass value in eV.
    """

    # Angular frequency per hour
    omega_per_hour = 2 * np.pi / period_hr

    # Apply the seconds <-> eV^-1 factor, then the 3600 s per hour
    omega_scaled = omega_per_hour / s_to_eVminus1
    return omega_scaled / 3600

def extract_data_obs(data):
    """
    Pull the time, angle and error columns out of an observation DataFrame.

    Parameters:
    - data (pd.DataFrame): Observation data with "time", "phi" and "sigma" columns.

    Returns:
    - times_obs (np.ndarray): Array of observation times.
    - phis_obs (np.ndarray): Array of observation angles.
    - sigmas_obs (np.ndarray): Array of observation errors.
    """

    # Convert each required column to a NumPy array, in the order it is returned
    columns = ("time", "phi", "sigma")
    times_obs, phis_obs, sigmas_obs = (data[name].to_numpy() for name in columns)
    return times_obs, phis_obs, sigmas_obs

def compute_chisq(phis_obs, sigmas_obs, phis_theory):
    """
    Chi-squared between observed and theoretical angles, weighted by the errors.

    Evaluates sum((obs - theory)^2 / sigma^2) over all points.

    Parameters:
    - phis_obs (np.ndarray): Array of observed angles.
    - sigmas_obs (np.ndarray): Array of observed errors.
    - phis_theory (np.ndarray): Array of theoretical angles.

    Returns:
    - chisq (float): Chi-squared value.
    """

    # Squared residuals divided by the per-point variance, summed over all points
    residuals = phis_obs - phis_theory
    variances = sigmas_obs ** 2
    return np.sum(residuals ** 2 / variances)

def compute_chisqs_params(i, params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs, divisions=100):
    """
    Scan chi-squared over a grid of time shifts and background angles for one parameter set.

    The observation times are shifted by each trial phase and folded into one
    period. The i-th interpolated function is evaluated there, converted from
    radians to degrees and mean-subtracted. Each background angle is then added
    in turn before comparing with the observations.

    Parameters:
    - i (int): Index of the parameter set.
    - params (np.ndarray): Array of (mass, epsilon) parameter sets.
    - times_obs (np.ndarray): Array of observation times.
    - phis_obs (np.ndarray): Array of observation angles.
    - sigmas_obs (np.ndarray): Array of observation errors.
    - interp_data (list): List of interpolated functions, one per parameter set.
    - phi_bkgs (np.ndarray): Array of background angles.
    - divisions (int, optional): Number of trial phases spread over one period.

    Returns:
    - chisqs (list): List of [chisq, phase, phi_bkg] entries, ordered phase-major.
    """

    # Only the mass is needed here; it fixes the oscillation period in hours
    m, _ = params[i]
    period_hr = mass_to_period(m)

    # Theoretical curve belonging to this parameter set
    delta_phi_func = interp_data[i]

    chisqs = []

    # Trial phases from 0 to one full period (both end points included)
    for phase in np.linspace(0, period_hr, divisions):
        # Shift the observation times and fold them into a single period
        folded_times = (times_obs + phase) % period_hr

        # Theoretical delta phi in degrees, shifted to zero mean
        delta_phis = delta_phi_func(folded_times) * 180 / np.pi
        delta_phis = delta_phis - np.mean(delta_phis)

        # One chi-squared entry per background angle
        chisqs.extend(
            [compute_chisq(phis_obs, sigmas_obs, delta_phis + phi_bkg), phase, phi_bkg]
            for phi_bkg in phi_bkgs
        )

    return chisqs

def compute_chisqs_all(params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs, n_jobs=-1, divisions=100):
    """
    Run the chi-squared grid scan for every parameter set in parallel.

    Parameters:
    - params (np.ndarray): Array of parameter sets.
    - times_obs (np.ndarray): Array of observation times.
    - phis_obs (np.ndarray): Array of observation angles.
    - sigmas_obs (np.ndarray): Array of observation errors.
    - interp_data (list): List of interpolated functions.
    - phi_bkgs (np.ndarray): Array of background angles.
    - n_jobs (int, optional): Number of jobs for parallel processing.
    - divisions (int, optional): Number of divisions for the period.

    Returns:
    - chisqs_all (list): One compute_chisqs_params result per parameter set.
    """

    runner = Parallel(n_jobs=n_jobs)
    scan_one = delayed(compute_chisqs_params)

    # Lazy task stream: tqdm advances as the parameter indices are handed to joblib
    tasks = (
        scan_one(idx, params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs, divisions=divisions)
        for idx in tqdm(range(len(params)))
    )

    return runner(tasks)

def objective_chisq(phase_phi_bkg, times_obs, period_hr, phis_obs, sigmas_obs, delta_phi_func):
    """
    Chi-squared as a function of (phase, background angle), for use with an optimizer.

    Parameters:
    - phase_phi_bkg (np.ndarray): Two-element array holding the phase and the background angle.
    - times_obs (np.ndarray): Array of observation times.
    - period_hr (float): Period in hours.
    - phis_obs (np.ndarray): Array of observation angles.
    - sigmas_obs (np.ndarray): Array of observation errors.
    - delta_phi_func (callable): Interpolated function for delta phi.

    Returns:
    - chisq (float): Chi-squared value.
    """

    phase, phi_bkg = phase_phi_bkg

    # Shift the observation times by the phase and fold them into one period
    folded_times = (times_obs + phase) % period_hr

    # Model curve converted from radians to degrees
    model_deg = delta_phi_func(folded_times) * 180 / np.pi

    # Remove the mean of the model, then add the constant background angle
    phis_theory = (model_deg - np.mean(model_deg)) + phi_bkg

    return compute_chisq(phis_obs, sigmas_obs, phis_theory)

def compute_chisqs_min_total(params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs, method="L-BFGS-B"):
    """
    Compute the minimized chi-squared values, phases, and background angles for all parameter sets.

    For every (mass, epsilon) pair the chi-squared is minimized over the phase
    (bounded to one period) and the background angle (bounded to the range of
    phi_bkgs), starting from phase 0 and the mean of phi_bkgs.

    Parameters:
    - params (np.ndarray): Array of parameter sets.
    - times_obs (np.ndarray): Array of observation times.
    - phis_obs (np.ndarray): Array of observation angles.
    - sigmas_obs (np.ndarray): Array of observation errors.
    - interp_data (list): List of interpolated functions.
    - phi_bkgs (np.ndarray): Array of background angles; only its minimum, maximum and mean are used.
    - method (str, optional): Optimization method. Default is "L-BFGS-B".

    Returns:
    - chisqs_min_total (np.ndarray): Array of shape (len(params), 3) holding
      [chisq_min, phase, phi_bkg] per parameter set. Rows whose optimization
      did not report success are left as NaN.
    """

    # Start from NaN rather than np.empty: a row is only written on success, so an
    # uninitialised array would hand arbitrary memory contents to the caller
    chisqs_min_total = np.full((len(params), 3), np.nan)

    # The background-angle bounds and starting value are the same for every parameter set
    phi_bkg_bounds = (min(phi_bkgs), max(phi_bkgs))
    phi_bkg_start = np.mean(phi_bkgs)

    # Iterate over all (mass, epsilon) pairs; epsilon only enters through interp_data[i]
    for i, (m, _epsilon) in enumerate(tqdm(params)):
        # Calculate the period in hours
        period_hr = mass_to_period(m)

        # Get the interpolated function
        delta_phi_func = interp_data[i]

        # Initial guess for phase and phi_bkg
        initial_guess = [0, phi_bkg_start]

        # Bounds for phase and phi_bkg
        bounds = [(0, period_hr), phi_bkg_bounds]

        # Minimize the chi-squared value
        result = minimize(
            objective_chisq,
            initial_guess,
            args=(times_obs, period_hr, phis_obs, sigmas_obs, delta_phi_func),
            bounds=bounds,
            method=method
        )

        # Store the results; failed fits keep their NaN row
        if result.success:
            chisqs_min_total[i] = [result.fun, result.x[0], result.x[1]]
    return chisqs_min_total

def extract_chisqs_min(chisqs_all):
    """
    Extract the minimum chi-squared entry for every parameter set.

    Parameters:
    - chisqs_all (list or np.ndarray): For each parameter set, a sequence of
      [chisq, phase, phi_bkg] rows. Both the nested lists returned by
      compute_chisqs_all and the array obtained after saving/loading them
      are accepted.

    Returns:
    - chisqs_min (np.ndarray): Array of shape (len(chisqs_all), 3) holding the
      [chisq, phase, phi_bkg] row with the smallest chi-squared per parameter set.
    """

    n_sets = len(chisqs_all)

    # Every row is overwritten in the loop below, so np.empty is safe here
    chisqs_min = np.empty((n_sets, 3))

    for i in range(n_sets):
        # Coerce to a 2-D float array so column slicing also works on plain nested lists
        rows = np.asarray(chisqs_all[i], dtype=float)

        # Keep the whole row (chi-squared, phase, phi_bkg) at the chi-squared minimum
        chisqs_min[i] = rows[np.argmin(rows[:, 0])]

    return chisqs_min

def define_thresholds():
    """
    Compute and print the chi-squared thresholds for the 95% and 90% confidence levels.

    Each threshold is the corresponding quantile of a chi-squared distribution
    with one degree of freedom.

    Returns:
    - threshold95 (float): Threshold for 95% confidence level.
    - threshold90 (float): Threshold for 90% confidence level.
    """

    thresholds = []

    # (label used in the printed message, cumulative probability)
    for label, confidence in (("95", 0.95), ("90", 0.90)):
        value = chi2.ppf(confidence, df=1)
        print(f"Threshold for {label}% CL: {value:.4f}")
        thresholds.append(value)

    threshold95, threshold90 = thresholds
    return threshold95, threshold90

def extract_verify_parameters(params):
    """
    Split the parameter array into its mass/epsilon grid and check its size.

    The epsilon column is reshaped to (num_masses, num_epsilons) in the order
    the rows appear in params, so the rows are expected to be grouped by mass.

    Parameters:
    - params (np.ndarray): Array of input parameters with columns (mass, epsilon).

    Returns:
    - unique_masses (np.ndarray): Array of unique mass values (sorted).
    - unique_epsilons (np.ndarray): Array of unique epsilon values (sorted).
    - num_masses (int): Number of unique mass values.
    - num_epsilons (int): Number of unique epsilon values.
    - epsilons_per_mass (np.ndarray): Array of epsilon values per mass.
    """

    # Column 0 holds the masses, column 1 the couplings
    masses, epsilons = params[:, 0], params[:, 1]

    # Distinct grid values along each axis and their counts
    unique_masses, unique_epsilons = np.unique(masses), np.unique(epsilons)
    num_masses, num_epsilons = unique_masses.size, unique_epsilons.size

    # A full rectangular grid has exactly one row per (mass, epsilon) combination
    assert num_masses * num_epsilons == len(params), "Parameter array size mismatch"

    # One row of epsilon values per mass, following the row order of params
    epsilons_per_mass = np.reshape(epsilons, (num_masses, num_epsilons))

    return unique_masses, unique_epsilons, num_masses, num_epsilons, epsilons_per_mass

def compute_epsilon_limits(period, epsilons_per_mass, num_masses, num_epsilons, threshold95, threshold90, scheme="grid-search"):
    """
    Compute epsilon limits for a given period.

    The per-parameter-set minimum chi-squared values are read from a cache file
    in the working directory. For every mass, the upper limit is the first
    epsilon grid point at or above the best-fit epsilon whose chi-squared
    exceeds the minimum by more than the threshold.

    Parameters:
    - period (str): Period tag used in the cache file name (e.g. "82" or "81to82").
    - epsilons_per_mass (np.ndarray): Array of epsilon values per mass.
    - num_masses (int): Number of unique mass values.
    - num_epsilons (int): Number of unique epsilon values.
    - threshold95 (float): Threshold for 95% confidence level.
    - threshold90 (float): Threshold for 90% confidence level.
    - scheme (str, optional): "grid-search" reads chisqs_all_total_day{period}.npy
      and reduces it with extract_chisqs_min; "scipy-minimize" reads
      chisqs_min_total_day{period}.npy directly. Default is "grid-search".

    Returns:
    - chisqs_per_mass (np.ndarray): Array of chi-squared values per mass.
    - chisq_min_per_mass (np.ndarray): Array of minimum chi-squared values per mass.
    - epsilon_min_per_mass (np.ndarray): Array of corresponding epsilon values per mass.
    - upper_limit_epsilons95 (np.ndarray): Array of 95% CL upper limits on epsilon per mass.
    - upper_limit_epsilons90 (np.ndarray): Array of 90% CL upper limits on epsilon per mass.

    Raises:
    - ValueError: If scheme is neither "grid-search" nor "scipy-minimize".
    """

    # Reject unknown schemes up front; otherwise no branch below would assign
    # chisqs_min_total and the function would fail later with an UnboundLocalError
    if scheme not in ("grid-search", "scipy-minimize"):
        raise ValueError(f"Unknown scheme {scheme!r}; expected 'grid-search' or 'scipy-minimize'")

    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, so these
    # cache files must only ever be ones written by this notebook
    if scheme == "grid-search":
        # Load chi-squared values for the given period
        chisqs_all_total = np.load(f"chisqs_all_total_day{period}.npy", allow_pickle=True)
        chisqs_min_total = extract_chisqs_min(chisqs_all_total)
    else:
        chisqs_min_total = np.load(f"chisqs_min_total_day{period}.npy", allow_pickle=True)

    # Column 0 is the minimized chi-squared; arrange it as one row per mass
    chisqs_min = chisqs_min_total[:, 0]
    chisqs_per_mass = chisqs_min.reshape(num_masses, num_epsilons)

    # Initialize arrays to store results
    chisq_min_per_mass = np.zeros(num_masses)
    epsilon_min_per_mass = np.zeros(num_masses)
    upper_limit_epsilons95 = np.zeros(num_masses)
    upper_limit_epsilons90 = np.zeros(num_masses)

    for i in range(num_masses):
        # Extract chi-squared values and epsilons for the i-th mass
        chisqs = chisqs_per_mass[i]
        epsilons = epsilons_per_mass[i]

        # Find the minimum chi-squared value and corresponding epsilon
        chisq_min_idx = np.argmin(chisqs)
        chisq_min = chisqs[chisq_min_idx]
        epsilon_min = epsilons[chisq_min_idx]

        # Store the minimum chi-squared value and corresponding epsilon
        chisq_min_per_mass[i] = chisq_min
        epsilon_min_per_mass[i] = epsilon_min

        # Compute the threshold for this mass
        chisq_threshold95 = chisq_min + threshold95
        chisq_threshold90 = chisq_min + threshold90

        # Sort by epsilon in ascending order to find crossing point
        sort_idx = np.argsort(epsilons)
        epsilons_sorted = epsilons[sort_idx]
        chisqs_sorted = chisqs[sort_idx]
        epsilon_min_idx = np.where(epsilons_sorted == epsilon_min)[0][0]

        # Find the upper limit for 95% and 90% CL, store the results
        # (NaN when chi-squared never exceeds the threshold above the best fit)
        upper_limit_epsilons95[i] = find_upper_limit(epsilons_sorted, chisqs_sorted, epsilon_min_idx, chisq_threshold95)[0]
        upper_limit_epsilons90[i] = find_upper_limit(epsilons_sorted, chisqs_sorted, epsilon_min_idx, chisq_threshold90)[0]

    return chisqs_per_mass, chisq_min_per_mass, epsilon_min_per_mass, upper_limit_epsilons95, upper_limit_epsilons90

def find_upper_limit(epsilons_sorted, chisqs_sorted, start_idx, threshold):
    """
    Locate the first grid point at or after start_idx whose chi-squared exceeds the threshold.

    Parameters:
    - epsilons_sorted (np.ndarray): Array of epsilon values sorted in ascending order.
    - chisqs_sorted (np.ndarray): Chi-squared values in the same order as epsilons_sorted.
    - start_idx (int): Starting index for the search.
    - threshold (float): Threshold value.

    Returns:
    - epsilon_limit (float): Upper limit on epsilon, or NaN if the threshold is never exceeded.
    - chisq_limit (float): Corresponding chi-squared value, or NaN if the threshold is never exceeded.
    """

    # Index of the first strict crossing, or None when there is none
    candidates = range(start_idx, len(epsilons_sorted))
    crossing = next((j for j in candidates if chisqs_sorted[j] > threshold), None)

    if crossing is None:
        return np.nan, np.nan
    return epsilons_sorted[crossing], chisqs_sorted[crossing]

def plot_epsilon_limits(unique_masses, upper_limit_epsilons95_day82, upper_limit_epsilons90_day82, upper_limit_epsilons95_day81to82, upper_limit_epsilons90_day81to82):
    """
    Plot the upper limits on epsilon against mass (log-log) for day 82 and for days 81 to 82.

    Masses without a limit (NaN) are dropped from the corresponding curve. The
    region above each curve is shaded up to log10(epsilon) = 0.

    Parameters:
    - unique_masses (np.ndarray): Array of unique mass values.
    - upper_limit_epsilons95_day82 (np.ndarray): 95% CL upper limits on epsilon for day 82.
    - upper_limit_epsilons90_day82 (np.ndarray): 90% CL upper limits on epsilon for day 82.
    - upper_limit_epsilons95_day81to82 (np.ndarray): 95% CL upper limits on epsilon for day 81 to 82.
    - upper_limit_epsilons90_day81to82 (np.ndarray): 90% CL upper limits on epsilon for day 81 to 82.

    Returns:
        None
    """

    def _drop_nan(limits):
        # Keep only the masses for which an upper limit was found
        keep = ~np.isnan(limits)
        return unique_masses[keep], limits[keep]

    def _one_decimal(value, pos):
        # Tick label with a single decimal place
        return f"{value:.1f}"

    # One entry per curve, in drawing order: (masses, limits), line keywords, fill opacity
    curves = [
        (_drop_nan(upper_limit_epsilons95_day82), {"color": "blue", "label": "95% CL (1 Day)"}, 0.1),
        (_drop_nan(upper_limit_epsilons90_day82), {"color": "blue", "linestyle": "--", "label": "90% CL (1 Day)"}, 0.05),
        (_drop_nan(upper_limit_epsilons95_day81to82), {"color": "red", "label": "95% CL (2 Days)"}, 0.1),
        (_drop_nan(upper_limit_epsilons90_day81to82), {"color": "red", "linestyle": "--", "label": "90% CL (2 Days)"}, 0.05),
    ]

    plt.figure(figsize=(8, 6))

    # Limit curves in log10-log10 space
    for (masses, limits), line_kwargs, _ in curves:
        plt.plot(np.log10(masses), np.log10(limits), **line_kwargs)

    # Pool all valid points to derive the tick ranges
    all_valid_masses = np.concatenate([masses for (masses, _), _, _ in curves])
    all_valid_limits = np.concatenate([limits for (_, limits), _, _ in curves])

    x_min = np.floor(np.log10(all_valid_masses.min())) - 0.5
    x_max = np.floor(np.log10(all_valid_masses.max())) + 0.5
    y_min = np.floor(np.log10(all_valid_limits.min())) - 0.5
    y_max = 0

    # Fixed ticks every 0.5 dex in x and every 1 dex in y, labelled with one decimal
    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(FixedLocator(np.arange(x_min, x_max + 0.5, 0.5)))
    current_axes.yaxis.set_major_locator(FixedLocator(np.arange(y_min, y_max + 1, 1)))
    current_axes.xaxis.set_major_formatter(FuncFormatter(_one_decimal))
    current_axes.yaxis.set_major_formatter(FuncFormatter(_one_decimal))
    plt.ylim(y_min, y_max)

    # Vertical lines from the first and the last pooled point up to the top of the plot
    for edge in (0, -1):
        plt.vlines(np.log10(all_valid_masses[edge]), np.log10(all_valid_limits[edge]), y_max, colors="black")

    # Shade the region above each curve
    for (masses, limits), line_kwargs, opacity in curves:
        plt.fill_between(np.log10(masses), np.log10(limits), y_max,
                         color=line_kwargs["color"], alpha=opacity)

    # Labels, title and legend
    plt.xlabel(r"$\log_{10}(\mu / \text{eV})$", fontsize=12)
    plt.ylabel(r"$\log_{10}(\epsilon)$", fontsize=12)
    plt.title(r"Upper Limits on Dark Photon-Photon Coupling", fontsize=14)
    plt.legend(loc="lower left", fontsize=10)

    plt.tight_layout()
    plt.show()

def plot_chisq_coupling(axes, unique_masses, chisqs_per_mass, epsilons_per_mass, chisq_min_per_mass, epsilon_min_per_mass, threshold95, threshold90):
    """
    Draw chi-squared versus coupling for each mass, one mass per subplot.

    Masses beyond the number of available subplots are not drawn.

    Parameters:
    - axes (np.ndarray): Array of subplots.
    - unique_masses (np.ndarray): Array of unique mass values.
    - chisqs_per_mass (np.ndarray): Array of chi-squared values per mass.
    - epsilons_per_mass (np.ndarray): Array of epsilon values per mass.
    - chisq_min_per_mass (np.ndarray): Array of minimum chi-squared values per mass.
    - epsilon_min_per_mass (np.ndarray): Array of corresponding epsilon values per mass.
    - threshold95 (float): Threshold for 95% confidence level.
    - threshold90 (float): Threshold for 90% confidence level.

    Returns:
        None
    """

    # zip stops at the shorter of the two sequences, so surplus masses are skipped
    for i, (ax, mass) in enumerate(zip(axes, unique_masses)):
        # Precomputed values for this mass
        chisqs, epsilons = chisqs_per_mass[i], epsilons_per_mass[i]
        chisq_min, epsilon_min = chisq_min_per_mass[i], epsilon_min_per_mass[i]

        # Absolute chi-squared levels of the two confidence limits
        chisq_threshold95 = chisq_min + threshold95
        chisq_threshold90 = chisq_min + threshold90

        # Order the grid by ascending epsilon and locate the best-fit point in it
        order = np.argsort(epsilons)
        epsilons_sorted, chisqs_sorted = epsilons[order], chisqs[order]
        min_idx_sorted = np.where(epsilons_sorted == epsilon_min)[0][0]

        # First point above the best fit where chi-squared exceeds the 95% level
        epsilon_cross, chisqs_cross = find_upper_limit(epsilons_sorted, chisqs_sorted, min_idx_sorted, chisq_threshold95)

        # Draw everything on this mass's subplot
        plot_single_subplot(ax, mass, epsilons, epsilon_min, chisqs, chisq_min, threshold95, threshold90, chisq_threshold95, chisq_threshold90, epsilon_cross, chisqs_cross)

def plot_single_subplot(ax, mass, epsilons, epsilon_min, chisqs, chisq_min, threshold95, threshold90, chisq_threshold95, chisq_threshold90, epsilon_cross, chisqs_cross):
    """
    Draw the chi-squared versus coupling curve of one mass on the given subplot.

    Parameters:
    - ax (plt.Axes): Subplot to plot on.
    - mass (float): Mass value in eV.
    - epsilons (np.ndarray): Array of epsilon values.
    - epsilon_min (float): Epsilon value corresponding to the minimum chi-squared.
    - chisqs (np.ndarray): Array of chi-squared values.
    - chisq_min (float): Minimum chi-squared value.
    - threshold95 (float): Chi-squared increment for the 95% confidence level (legend text only).
    - threshold90 (float): Chi-squared increment for the 90% confidence level (legend text only).
    - chisq_threshold95 (float): Absolute chi-squared level of the 95% line.
    - chisq_threshold90 (float): Absolute chi-squared level of the 90% line.
    - epsilon_cross (float): Epsilon value where chi-squared just exceeds threshold (NaN if none).
    - chisqs_cross (float): Chi-squared value where chi-squared just exceeds threshold.

    Returns:
        None
    """

    # Chi-squared curve on a logarithmic epsilon axis
    ax.semilogx(epsilons, chisqs, "o-", color="blue", markersize=3)

    # Dashed horizontal guides: the 95% level, the 90% level and the minimum itself
    guides = (
        (chisq_threshold95, "r", r"95% CL ($\chi^2_\text{min}$" + f"+{threshold95:.2f})"),
        (chisq_threshold90, "orange", r"90% CL ($\chi^2_\text{min}$" + f"+{threshold90:.2f})"),
        (chisq_min, "g", r"$\chi^2_\text{min}$"),
    )
    for level, colour, legend_text in guides:
        ax.axhline(y=level, color=colour, linestyle="--", label=legend_text)

    # Mark the best-fit point
    ax.plot(epsilon_min, chisq_min, "go", markersize=5)

    # Circle the first point above threshold, when there is one
    has_crossing = not np.isnan(epsilon_cross)
    if has_crossing:
        ax.plot(epsilon_cross, chisqs_cross, "o", markersize=10, markeredgecolor="red", markerfacecolor="none")

    # Axis labels
    ax.set_xlabel(r"$\epsilon$", fontsize=12)
    ax.set_ylabel(r"$\chi^2$", fontsize=12)

    # Title with the mass written as coefficient x 10^exponent
    decade = np.floor(np.log10(mass))
    mass_coefficient = mass / 10**decade
    mass_exponent = int(decade)
    ax.set_title(rf"$\mu={mass_coefficient:.2f} \times 10^{{{mass_exponent}}}$ eV", fontsize=14)

    ax.legend(loc="upper right", fontsize=10)

def hide_unused_subplots(axes, num_used):
    """
    Switch off the axes of every subplot from position num_used onwards.

    Parameters:
    - axes (np.ndarray): Flat array of subplots.
    - num_used (int): Number of leading subplots that hold a plot.

    Returns:
        None
    """

    # Walk over the trailing, unused subplots and turn their axes off
    position = num_used
    total = len(axes)
    while position < total:
        axes[position].axis("off")
        position += 1

In [None]:
# Load the polarization measurements of days 81 and 82 and assemble, per day,
# a dark / blue / red DataFrame plus a combined "total" DataFrame.
# Every DataFrame ends up with the columns "time", "phi" and "sigma".

# Define the data folder
data_folder = "sgrdata"
day81_folder = os.path.join(data_folder, "day81")
day82_folder = os.path.join(data_folder, "day82")

# Define the file paths for day 81
file_dark_day81 = os.path.join(day81_folder, "dark.npy")
file_blue1_day81 = os.path.join(day81_folder, "blue1.npy")
file_blue2_day81 = os.path.join(day81_folder, "blue2.npy")
file_red1_day81 = os.path.join(day81_folder, "red1.npy")
file_red2_day81 = os.path.join(day81_folder, "red2.npy")

# Define the file paths for day 82
file_dark_day82 = os.path.join(day82_folder, "dark.npy")
file_blue1_day82 = os.path.join(day82_folder, "blue1.npy")
file_blue2_day82 = os.path.join(day82_folder, "blue2.npy")
file_red1_day82 = os.path.join(day82_folder, "red1.npy")
file_red2_day82 = os.path.join(day82_folder, "red2.npy")

# Read the data for day 81
# (each array is wrapped in a two-column DataFrame below, so it must be 2-D with two columns)
data_dark_day81 = np.load(file_dark_day81)
data_blue1_day81 = np.load(file_blue1_day81)
data_blue2_day81 = np.load(file_blue2_day81)
data_red1_day81 = np.load(file_red1_day81)
data_red2_day81 = np.load(file_red2_day81)

# Read the data for day 82
data_dark_day82 = np.load(file_dark_day82)
data_blue1_day82 = np.load(file_blue1_day82)
data_blue2_day82 = np.load(file_blue2_day82)
data_red1_day82 = np.load(file_red1_day82)
data_red2_day82 = np.load(file_red2_day82)

# Convert dark data to DataFrames and add a sigma column with default value 3
# NOTE(review): the dark files carry no error column; 3 is a hard-coded uncertainty,
# presumably in degrees like "phi" -- confirm the value and its source
df_dark_day81 = pd.DataFrame(data_dark_day81, columns=["time", "phi"])
df_dark_day81["sigma"] = 3

df_dark_day82 = pd.DataFrame(data_dark_day82, columns=["time", "phi"])
df_dark_day82["sigma"] = 3

# Convert blue and red data to DataFrames for day 81
df_blue1_day81 = pd.DataFrame(data_blue1_day81, columns=["time", "phi"])
df_blue2_day81 = pd.DataFrame(data_blue2_day81, columns=["time", "phi"])
df_red1_day81 = pd.DataFrame(data_red1_day81, columns=["time", "phi"])
df_red2_day81 = pd.DataFrame(data_red2_day81, columns=["time", "phi"])

# Convert blue and red data to DataFrames for day 82
df_blue1_day82 = pd.DataFrame(data_blue1_day82, columns=["time", "phi"])
df_blue2_day82 = pd.DataFrame(data_blue2_day82, columns=["time", "phi"])
df_red1_day82 = pd.DataFrame(data_red1_day82, columns=["time", "phi"])
df_red2_day82 = pd.DataFrame(data_red2_day82, columns=["time", "phi"])

# Create combined blue and red DataFrames for day 81
# The two files of a colour are merged row by row: "time" and "phi" are the
# mean of the pair and "sigma" is half the absolute difference of the two angles.
# Rows are matched by index, so the "1" and "2" files are assumed to have the same
# number of rows in the same order (rows present in only one file would become NaN).
# NOTE(review): sigma is exactly 0 wherever the two angles agree, and compute_chisq
# divides by sigma**2 -- check that no such rows occur in the data.
df_blue_day81 = pd.DataFrame({
    "time": (df_blue1_day81["time"] + df_blue2_day81["time"]) / 2,
    "phi": (df_blue1_day81["phi"] + df_blue2_day81["phi"]) / 2,
    "sigma": abs(df_blue1_day81["phi"] - df_blue2_day81["phi"]) / 2
})

df_red_day81 = pd.DataFrame({
    "time": (df_red1_day81["time"] + df_red2_day81["time"]) / 2,
    "phi": (df_red1_day81["phi"] + df_red2_day81["phi"]) / 2,
    "sigma": abs(df_red1_day81["phi"] - df_red2_day81["phi"]) / 2
})

# Create combined blue and red DataFrames for day 82 (same recipe as for day 81)
df_blue_day82 = pd.DataFrame({
    "time": (df_blue1_day82["time"] + df_blue2_day82["time"]) / 2,
    "phi": (df_blue1_day82["phi"] + df_blue2_day82["phi"]) / 2,
    "sigma": abs(df_blue1_day82["phi"] - df_blue2_day82["phi"]) / 2
})

df_red_day82 = pd.DataFrame({
    "time": (df_red1_day82["time"] + df_red2_day82["time"]) / 2,
    "phi": (df_red1_day82["phi"] + df_red2_day82["phi"]) / 2,
    "sigma": abs(df_red1_day82["phi"] - df_red2_day82["phi"]) / 2
})

# Combine dark, blue, and red DataFrames for day 81
# (sorted by time and re-indexed so the rows are in chronological order)
df_total_day81 = pd.concat([df_dark_day81, df_blue_day81, df_red_day81])
df_total_day81 = df_total_day81.sort_values(by="time").reset_index(drop=True)

# Combine dark, blue, and red DataFrames for day 82
df_total_day82 = pd.concat([df_dark_day82, df_blue_day82, df_red_day82])
df_total_day82 = df_total_day82.sort_values(by="time").reset_index(drop=True)

# Combine day 81 and day 82 DataFrames
# NOTE(review): the two days are merged on the raw "time" values; this is only
# meaningful if "time" runs continuously across both days rather than restarting
# each day -- confirm against the data files
df_total_day81to82 = pd.concat([df_total_day81, df_total_day82])
df_total_day81to82 = df_total_day81to82.sort_values(by="time").reset_index(drop=True)

In [None]:
# Read the pickle file
# NOTE: pickle.load can execute arbitrary code while loading; only open files from a trusted source
# interp_data is indexed with the same index as the rows of `params` built below
interp_file = os.path.join(data_folder, "interp_datanew.pkl")
with open(interp_file, "rb") as f:
    interp_data = pickle.load(f)

# Extract the parameters
# Only every 20th row of the parameter table is kept; columns 1 and 2 are read
# as mass and epsilon.
# NOTE(review): the stride of 20 is hard-coded -- presumably each parameter set
# occupies 20 consecutive rows of the file; confirm against how it was written.
# Trailing rows beyond the last full group of 20 are ignored.
params_data = np.loadtxt("paralist_all_SgrltNE.dat")
params = []
for i in range(len(params_data) // 20):
    m = params_data[i * 20, 1]
    epsilon = params_data[i * 20, 2]
    params.append([m, epsilon])

# Convert to numpy array
# Resulting shape is (number of parameter sets, 2) with columns (mass, epsilon)
params = np.array(params)

In [None]:
# Plot the polarization versus time for day 82
plot_polarization(df_dark_day82, df_blue_day82, df_red_day82, "Day 82")

# Plot the interpolation function 
# NOTE(review): 839 is a hard-coded index into interp_data, presumably chosen as
# an illustrative example -- it fails if fewer than 840 interpolants were loaded
plot_interp_func(interp_data, 839)

In [None]:
# Compute, for every period, the chi-squared grid scan and the optimizer-based
# minimum, and cache both as .npy files in the working directory.

# Define the datasets with labels
# The dictionary key ("total") becomes part of the cache file names below
datasets_day81 = {"total": {"data": df_total_day81, "label": "All"}}

datasets_day82 = {"total": {"data": df_total_day82, "label": "All"}}

datasets_day81to82 = {"total": {"data": df_total_day81to82, "label": "All"}}

datasets_group = {
    "81": datasets_day81,
    "82": datasets_day82,
    "81to82": datasets_day81to82
}

# Define the periods
periods = ["81", "82", "81to82"]

# Define the number of divisions
# Used both for the number of trial phases per period and for the background-angle grid
divisions = 200

# Define the background phases
# Background angles from 0 to 180 (both end points included)
phi_bkgs = np.linspace(0, 180, divisions)

for period in periods:
    for colour, info in datasets_group[period].items():
        # Define the filename
        # These names are what compute_epsilon_limits later loads for colour == "total"
        filename = f"chisqs_all_{colour}_day{period}.npy"
        min_filename = f"chisqs_min_{colour}_day{period}.npy"

        # Extract observational data
        data = info["data"]
        label = info["label"]
        times_obs, phis_obs, sigmas_obs = extract_data_obs(data)

        # Check if file already exists to avoid recomputing chi-squared values
        # NOTE: the cache is keyed by file name only; delete the .npy files after
        # changing the data, `params`, `divisions` or `phi_bkgs`, otherwise stale
        # results are reused
        if os.path.exists(filename):
            print(f"File {filename} already exists. Skipping computation.")
        else:
            # Compute chi-squared values for all parameter combinations
            chisqs_all = compute_chisqs_all(params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs, divisions=divisions)
            
            # Save the results to a .npy file
            # (np.save converts the nested list returned above into an array)
            np.save(filename, chisqs_all)
            print(f"Saved {filename} for {label}.")

        # Check if minimum chi-squared file exists
        if os.path.exists(min_filename):
            print(f"File {min_filename} already exists. Skipping computation.")
        else:
            # Compute the minimum chi-squared values
            # (optimizer-based alternative to the grid scan; phi_bkgs only supplies
            # the bounds and the starting value of the background angle)
            chisqs_min = compute_chisqs_min_total(params, times_obs, phis_obs, sigmas_obs, interp_data, phi_bkgs)
            
            # Save the results to a .npy file
            np.save(min_filename, chisqs_min)
            print(f"Saved {min_filename} for {label}.")

In [None]:
# Chi-squared increments for the two confidence levels and the (mass, epsilon) grid layout
threshold95, threshold90 = define_thresholds()
unique_masses, unique_epsilons, num_masses, num_epsilons, epsilons_per_mass = extract_verify_parameters(params)

# Epsilon limits from the cached results of day 82 and of days 81 to 82.
# No `scheme` is passed, so the default "grid-search" cache files are read;
# for day 82 only the two limit arrays are kept.
_, _, _, upper_limit_epsilons95_day82, upper_limit_epsilons90_day82 = compute_epsilon_limits("82", epsilons_per_mass, num_masses, num_epsilons, threshold95, threshold90)
chisqs_per_mass_day81to82, chisq_min_per_mass_day81to82, epsilon_min_per_mass_day81to82, upper_limit_epsilons95_day81to82, upper_limit_epsilons90_day81to82 = compute_epsilon_limits("81to82", epsilons_per_mass, num_masses, num_epsilons, threshold95, threshold90)

# Plot upper limits on epsilon for day 82 and day 81 to 82
plot_epsilon_limits(unique_masses, upper_limit_epsilons95_day82, upper_limit_epsilons90_day82, upper_limit_epsilons95_day81to82, upper_limit_epsilons90_day81to82)

# Create a figure with multiple subplots
# The 4 x 4 grid holds at most 16 masses; plot_chisq_coupling stops once the subplots run out
fig, axes = plt.subplots(4, 4, figsize=(20, 20))
axes = axes.flatten()

# Plot chi-squared values against coupling for each mass
# (uses the combined day 81 to 82 results)
plot_chisq_coupling(axes, unique_masses, chisqs_per_mass_day81to82, epsilons_per_mass, chisq_min_per_mass_day81to82, epsilon_min_per_mass_day81to82, threshold95, threshold90)

# Hide unused subplots
hide_unused_subplots(axes, len(unique_masses))

# Adjust the spacing and show the plot
plt.tight_layout()
plt.show()