# AbsPL (Absorption and Photoluminescence) Analysis

## How to use this notebook:
1. Select batches to analyze (only batches of type "hysprint_batch" are considered)
2. The data will be loaded into a pandas DataFrame
3. Use the plotting tools to visualize your data:
   - Create scatter plots for comparing two parameters
   - Use box plots to analyze parameter distributions
4. Access advanced features for data table viewing and statistics

In [1]:
%matplotlib ipympl
%load_ext autoreload
%autoreload 2
import os
import base64
import io
import sys
import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import display, Markdown, HTML
import pandas as pd
import numpy as np
import json

sys.path.append(os.path.dirname(os.getcwd()))
from api_calls import get_ids_in_batch, get_sample_description, get_batch_ids,  get_all_eqe as get_all_trpl, get_all_batches_wth_data
import batch_selection
import plotting_utils
import access_token

# Base address of the NOMAD Oasis instance; `url` is the REST API root that
# every api_calls helper in this notebook receives.
url_base ="https://nomad-hzb-se.de"
url = f"{url_base}/nomad-oasis/api/v1"
# Authentication token for `url`, obtained through the local access_token module.
token = access_token.get_token(url)
# NOTE(review): presumably records that this notebook was run — see access_token.
access_token.log_notebook_usage()

In [2]:
import matplotlib.pyplot as plt
import scipy.optimize
from functools import partial
from scipy.interpolate import make_splrep, generate_knots
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt


import numpy as np
from scipy.interpolate import BSpline
from scipy.optimize import minimize


def fit_monotone_convex_spline_from_bspline(x, y, spl0,
                                            n_grid=200,
                                            tol_d1=0.0,
                                            tol_d2=0.0,
                                            constraint_tol=1e-6):
    """
    Refit the coefficients of ``spl0`` so the spline is (approximately)
    monotone increasing and convex on ``[x.min(), x.max()]``.

    The knots and degree of ``spl0`` are kept; only the coefficients are
    re-optimised to minimise ``0.5 * ||f(x) - y||**2`` subject to
    ``f'(g) >= tol_d1`` and ``f''(g) >= tol_d2`` on ``n_grid`` evenly spaced
    points ``g``.  COBYLA is used, so the constraints are only met
    approximately; the result is checked afterwards.

    Parameters
    ----------
    x, y : array_like, 1-D
        Data points to fit.
    spl0 : scipy.interpolate.BSpline
        Initial spline; supplies knots, degree and the fallback result.
    n_grid : int
        Number of grid points on which the derivative constraints are enforced.
    tol_d1, tol_d2 : float
        Lower bounds for the first and second derivative.
    constraint_tol : float
        Largest constraint violation accepted after optimisation.

    Returns
    -------
    BSpline
        The constrained spline, or ``spl0`` itself when there are fewer than
        ``k + 2`` data points, when the optimiser returns non-finite
        coefficients, or when a constraint is violated by more than
        ``constraint_tol`` (a warning is printed in the last two cases).
    """
    x = np.asarray(x)
    y = np.asarray(y)

    t_spline = spl0.t
    k = spl0.k
    c0 = spl0.c
    n_coeffs = c0.size

    # Too few data points in the region to meaningfully refit.
    if x.size < k + 2:
        return spl0

    # A BSpline whose coefficient array is the identity evaluates all basis
    # functions at once: column j of basis(x, nu) is B_j^(nu)(x).
    basis = BSpline(t_spline, np.eye(n_coeffs), k)

    # LSQ fit matrix on the constrained region.
    A_data = basis(x, nu=0)

    # Constraint grid on [x.min(), x.max()].
    x_grid = np.linspace(x.min(), x.max(), n_grid)
    A_d1 = basis(x_grid, nu=1)   # first derivative
    A_d2 = basis(x_grid, nu=2)   # second derivative

    def objective(c):
        r = A_data @ c - y
        return 0.5 * np.dot(r, r)

    # Warm start from the unconstrained least-squares solution on the same
    # region; fall back to the original coefficients if that fails.
    try:
        x0, *_ = np.linalg.lstsq(A_data, y, rcond=None)
    except (np.linalg.LinAlgError, ValueError):
        x0 = c0.copy()

    # COBYLA treats fun(c) >= 0 as feasible.
    constraints = [
        {'type': 'ineq', 'fun': lambda c: A_d1 @ c - tol_d1},  # f' >= tol_d1
        {'type': 'ineq', 'fun': lambda c: A_d2 @ c - tol_d2},  # f'' >= tol_d2
    ]

    result = minimize(
        objective,
        x0,
        method='COBYLA',
        constraints=constraints,
        options={
            'maxiter': 2000,
            'rhobeg': 1.0,
        }
    )
    c_opt = result.x

    # NaN coefficients would slip through the violation check below, because
    # max(0.0, nan) evaluates to 0.0, so reject them explicitly.
    if not np.all(np.isfinite(c_opt)):
        print("Warning: constrained spline fit produced non-finite coefficients, "
              "falling back to original spline.")
        return spl0

    # COBYLA does not guarantee feasibility; measure the violation ourselves.
    maxcv = max(0.0,
                tol_d1 - (A_d1 @ c_opt).min(),
                tol_d2 - (A_d2 @ c_opt).min())

    if maxcv > constraint_tol:
        print(f"Warning: monotone/convex constraints violated by {maxcv:.2e}, "
              f"falling back to original spline.")
        return spl0

    return BSpline(t_spline, c_opt, k)

def fit_convex_spline_from_bspline(x, y, spl0, n_grid=200, tol=0.0):
    """
    Refit the coefficients of ``spl0`` so the spline is convex on the data range.

    Minimises ``0.5 * ||f(x) - y||**2`` over the spline coefficients with
    SLSQP, subject to ``f''(g) >= tol`` on ``n_grid`` evenly spaced points
    ``g`` spanning ``[x.min(), x.max()]``.  Knots and degree are taken from
    ``spl0`` and its coefficients are the starting point.

    Parameters
    ----------
    x, y : 1-D array_like
        Data points.
    spl0 : BSpline
        Initial spline (knots, degree, starting coefficients).
    n_grid : int
        Number of grid points on which convexity is enforced.
    tol : float
        Smallest allowed second derivative (0.0 for plain convexity; a tiny
        positive value guards against borderline violations).

    Returns
    -------
    BSpline
        Spline with the same knots/degree as ``spl0`` and refitted coefficients.

    Raises
    ------
    RuntimeError
        If SLSQP reports failure.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    knots, degree, start_coeffs = spl0.t, spl0.k, spl0.c
    n_basis = start_coeffs.size

    # Identity coefficients make one BSpline evaluate every basis function:
    # column j of all_basis(points, nu) is B_j^(nu)(points).
    all_basis = BSpline(knots, np.eye(n_basis), degree)
    fit_matrix = all_basis(x, nu=0)

    grid = np.linspace(x.min(), x.max(), n_grid)
    curvature_matrix = all_basis(grid, nu=2)

    def residual(coeffs):
        return fit_matrix @ coeffs - y

    def loss(coeffs):
        r = residual(coeffs)
        return 0.5 * np.dot(r, r)

    def loss_grad(coeffs):
        return fit_matrix.T @ residual(coeffs)

    # f''(grid) - tol >= 0; the constraint is linear, so its Jacobian is constant.
    convexity = {
        'type': 'ineq',
        'fun': lambda coeffs: curvature_matrix @ coeffs - tol,
        'jac': lambda coeffs: curvature_matrix,
    }
    outcome = minimize(
        loss, start_coeffs,
        jac=loss_grad,
        constraints=convexity,
        method='SLSQP'
    )

    if not outcome.success:
        raise RuntimeError("Convex spline fit failed: " + outcome.message)

    return BSpline(knots, outcome.x, degree)

# Analysis Functions
def calculate_N0s(hc, spot_area, lambda_laser, thickness, bd_ratio, data):
    """
    Compute the pulse fluence and initial carrier density of every row.

    Per row of ``data``:
        pulse energy   = laser_power / repetition_rate
        fluence        = pulse energy / spot_area / (hc / lambda_laser)
        carrier density = 1e-6 * fluence / thickness * bd_ratio

    The unit annotations assume SI inputs (W, Hz, m^2, m, J*m), giving
    photons per m^2 and carriers per cm^3 (including the beam-dump ratio).

    Returns
    -------
    (n0s, fluences) : tuple of numpy.ndarray
        One value per row, in row order.
    """
    photon_energy = hc / lambda_laser  # J

    # No rows: nothing to compute (also covers frames without the columns).
    if len(data) == 0:
        return np.array([]), np.array([])

    power = data["laser_power"].to_numpy()
    rep_rate = data["repetition_rate"].to_numpy()

    energy_per_pulse = power / rep_rate                        # J
    fluences = energy_per_pulse / spot_area / photon_energy    # m-2
    carrier_density = fluences / thickness                     # m-3
    n0s = 1e-6 * carrier_density * bd_ratio                    # cm-3 (includes the beamdump ratio)

    return np.asarray(n0s, dtype=float), np.asarray(fluences, dtype=float)

def calculate_noise(counts, denoise_value):
    """
    Estimate the background level of a TRPL trace.

    Parameters
    ----------
    counts : array_like
        Histogram values of one trace.
    denoise_value : int
        Negative n: mean of the last |n| values once trailing zeros are removed.
        Positive n: mean of the first n values.
        Zero: no background estimate; returns 0.

    Returns
    -------
    float or int
        The background estimate.
    """
    if denoise_value < 0:
        # Drop the all-zero tail before taking the last points.
        tail = np.trim_zeros(counts, trim='b')[denoise_value:]
        return np.mean(tail)
    if denoise_value > 0:
        return np.mean(counts[:denoise_value])
    return 0

def rate_calculation_function(count, integration_time_seconds, binsize_seconds, reprate, tau_COUNT_APD=45e-9):
    """
    Normalise histogram counts by bin width, integration time and repetition rate.

    Returns ``count / (binsize_seconds * integration_time_seconds * reprate)``;
    works element-wise when ``count`` is a numpy array.

    ``tau_COUNT_APD`` is accepted but not used: no detector dead-time
    correction is applied.
    """
    exposure = binsize_seconds * integration_time_seconds * reprate
    return count / exposure

def process_trpl_data(data, row_widgets, denoise_value, lambda_laser, spot_area, thickness, bd_ratio, bg, nc, nv, kt):
    """
    Convert raw TRPL histograms to rates and add the derived columns.

    ``data`` is modified in place (columns are added or overwritten) and is
    also returned.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per measurement; needs at least ``counts`` and ``ns_per_bin``.
    row_widgets : list of dict
        Per-row widgets; entry i supplies the settings of row i (assigned
        positionally). Keys read: ``rep_rate``, ``power``, ``nd``,
        ``integration_time``.
    denoise_value : int
        Forwarded to ``calculate_noise`` to estimate the background per trace.
    lambda_laser, spot_area, thickness, bd_ratio : float
        Forwarded to ``calculate_N0s``.
    bg, nc, nv, kt : float
        Accepted for interface compatibility; not used by any computation here.

    Returns
    -------
    pandas.DataFrame
        ``data`` with columns repetition_rate, laser_power, nd,
        integration_time, noise, counts, counts_no_noise,
        counts_no_noise_normalized, n0s and fluences.
    """
    # Per-sample settings from the widget table (``nd`` is only stored).
    data["repetition_rate"] = [row_widget['rep_rate'].value for row_widget in row_widgets]
    data["laser_power"] = [row_widget['power'].value for row_widget in row_widgets]
    data["nd"] = [row_widget['nd'].value for row_widget in row_widgets]
    data["integration_time"] = [row_widget['integration_time'].value for row_widget in row_widgets]

    # Background level of each trace, still in raw histogram counts.
    data["noise"] = [calculate_noise(counts, denoise_value) for counts in data["counts"]]

    counts_list = []
    counts_no_noise_list = []
    counts_no_noise_normalized_list = []
    noise_new_list = []
    for _, row in data.iterrows():
        # NOTE(review): ns_per_bin is passed where the helper's parameter is
        # named binsize_seconds — a constant scale factor if it is really in
        # ns; confirm the intended rate units.
        scale_args = (row.integration_time, row.ns_per_bin, row.repetition_rate)
        raw_counts = np.array(row.counts)

        counts_list.append(rate_calculation_function(raw_counts, *scale_args))
        counts_no_noise = rate_calculation_function(raw_counts - row.noise, *scale_args)
        noise_rate = rate_calculation_function(row.noise, *scale_args)

        print("Before Noise:", row.noise)
        print("After Noise:", noise_rate)

        noise_new_list.append(noise_rate)
        counts_no_noise_list.append(counts_no_noise)
        # NOTE(review): yields inf/nan if the background-subtracted maximum is 0.
        counts_no_noise_normalized_list.append(counts_no_noise / np.amax(counts_no_noise))

    # Everything below is expressed in the same rate units.
    data["counts_no_noise"] = counts_no_noise_list
    data["counts_no_noise_normalized"] = counts_no_noise_normalized_list
    data["noise"] = noise_new_list
    data["counts"] = counts_list

    # h*c in J*m. The previously computed intrinsic carrier density
    # (from bg, nc, nv, kt) was never used and raised ZeroDivisionError for
    # kT = 0, so it is no longer evaluated.
    hc = 1.98645E-25
    data["n0s"], data["fluences"] = calculate_N0s(hc, spot_area, lambda_laser, thickness, bd_ratio, data)

    return data

def plot_trpl_results(data):
    """
    Plot every sample's TRPL trace on one log-scaled axis.

    For each row of ``data`` this draws ``counts`` against ``time`` as a
    scatter, plus two horizontal lines in the scatter's colour:
    a dashed one at ``noise`` and a solid one at the mean of the first 50
    ``counts`` values.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``time``, ``counts``, ``noise`` and ``sample_id``
        (as produced by ``process_trpl_data``).
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for _, row in data.iterrows():
        time_data = np.array(row['time'])
        sample_id = row['sample_id']

        sc = ax.scatter(time_data, row.counts, label=f'{sample_id}', marker='o', s=3)
        color = sc.get_facecolors()[-1]
        # Dashed: the background level that was subtracted.
        ax.axhline(row.noise, linestyle="--", color=color)
        # Solid: mean of the first 50 points (presumably the pre-pulse baseline).
        ax.axhline(np.average(row.counts[:50]), color=color)

    ax.set_xlabel('Time [ns]')
    ax.set_ylabel('Counts [a.u.]')
    ax.set_yscale('log')
    ax.set_title('TRPL Analysis: Counts vs Time')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def fitfunc(x, *args):
    """
    Evaluate a constant plus a sum of exponential decays.

    ``args`` is ``(a0, a1, k1, a2, k2, ...)`` and the result is
    ``a0 + a1*exp(-k1*x) + a2*exp(-k2*x) + ...``. A well-formed call has
    2n + 1 arguments; an unpaired trailing value is ignored.

    Parameters
    ----------
    x : float or ndarray
        Evaluation points.
    *args : float
        Offset followed by (amplitude, rate) pairs.

    Returns
    -------
    float or ndarray
        Model value(s) at ``x``.

    Notes
    -----
    A message is printed (nothing is raised) when exactly one or more than
    30 arguments are given; with no arguments an IndexError results.
    """
    coefficients = np.array(args)
    count = len(coefficients)
    if count == 1 or count > 30:
        print("Number of params is wrong n = " + str(count) + "\n")

    total = coefficients[0]
    for amplitude, rate in zip(coefficients[1::2], coefficients[2::2]):
        total = total + amplitude * np.exp(-rate * x)
    return total

def fitfunc_2(x, *all_args):
    """
    Evaluate a sum of exponential decays without a constant term.

    ``all_args`` is ``(a1, k1, a2, k2, ...)`` and the result is
    ``a1*exp(-k1*x) + a2*exp(-k2*x) + ...``. No constant is read from
    ``all_args``; a background offset must be added by the caller
    (e.g. by wrapping this function in a lambda for ``curve_fit``).

    Parameters
    ----------
    x : float or ndarray
        Evaluation points.
    *all_args : float
        (amplitude, rate) pairs.

    Returns
    -------
    float or ndarray
        Model value(s) at ``x``; ``0`` when no parameters are given.

    Notes
    -----
    A message is printed (nothing is raised) when the number of parameters
    is odd (the last value would have no partner and is ignored) or exceeds 30.
    """
    params = np.array(all_args)
    n_params = len(params)

    # Parameters come in (amplitude, rate) pairs; an odd count means the last
    # value is silently dropped by the pairing below, so flag it.
    if n_params % 2 == 1 or n_params > 30:
        print("Number of params is wrong n = " + str(n_params) + "\n")

    s = 0
    for amplitude, rate in zip(params[0::2], params[1::2]):
        s = s + amplitude * np.exp(-rate * x)
    return s

def fit_difflifetimes(data, n_exp = None, l2 = None, noise = None):
    """
    Fit each TRPL decay and plot differential lifetimes.

    For every row of ``data`` the trace is cut at its maximum and modelled
    in two ways:

    * a multi-exponential ``fitfunc_2(1e3 * t, *p) + row.noise`` fitted with
      ``scipy.optimize.curve_fit``;
    * a convex cubic spline (``fit_convex_spline_from_bspline``) through a
      Savitzky-Golay-smoothed copy of the trace, restricted to points where
      the unconstrained smoothing spline exceeds ``10 * row.noise``.

    From each model the differential lifetime
    ``tau_diff = -2 * dt / d(ln PL)`` (following T. Kirchartz' work) and the
    carrier density ``n0 * sqrt(PL / max(PL))`` are computed.  A 3-row figure
    (decay, tau vs time, tau vs density) per sample and a summary
    tau-vs-density figure are shown.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``time``, ``counts_no_noise``, ``noise`` and ``n0s``
        (as produced by ``process_trpl_data``).
    n_exp : int, sequence of int, or None
        Number of exponentials per row. A single int applies to every row;
        None means 3 for every row.
    l2 : optional
        Currently unused; kept for interface compatibility.
    noise : optional
        Currently unused; the per-row ``noise`` column is used instead.

    Returns
    -------
    None
    """
    n_rows = len(data)
    if n_exp is None:
        n_exp = 3
    if np.ndim(n_exp) == 0:
        # A single value applies to every row (previously the default int
        # was indexed per row, raising TypeError).
        n_exp = [n_exp] * n_rows

    diff_taus = []
    densities2 = []
    print("Number of exponentials for fit used is = " + str(n_exp) + "\n")
    # squeeze=False keeps ``ax`` 2-D even for one sample; otherwise
    # plt.subplots(3, 1) returns a 1-D array and ``ax[0, i]`` fails.
    f, ax = plt.subplots(3, n_rows, figsize=(25, 12), squeeze=False)
    for i, (_, row) in enumerate(data.iterrows()):
        t = np.array(row.time)
        pl = np.array(row.counts_no_noise)
        print("Noise Level:", row.noise)

        # Analyse the decay from the PL maximum onwards.
        pl_argmax = np.argmax(pl)
        # NOTE(review): the 1e12 divisor suggests ``time`` is in ps (seconds
        # afterwards, matching the 1e9-scaled ns axes below) — confirm.
        t = t[pl_argmax:] / 1e12
        pl = pl[pl_argmax:]

        # Shift the decay to start at t = 0 for fitting; undone before plotting.
        t_min = t[0]
        t = t - t_min

        # Initial guess (amplitude 1, rate 1e4) and positive bounds for each
        # exponential term.
        p = [1, 1e4] * (n_exp[i])
        lb = (1e-12,) * (2 * n_exp[i])
        ub = (np.inf,) * (2 * n_exp[i])

        # 1) Smooth with Savitzky-Golay (window 51, cubic).
        fit_sav = savgol_filter(pl, 51, 3)
        s = 1e-5
        # 2) Candidate knot vectors; the last one is used.
        knots = list(generate_knots(t, fit_sav, s=s, k=3, nest=30))
        # 3) Unconstrained smoothing spline: supplies knots + initial coefficients.
        spr0 = make_splrep(t, fit_sav, k=3, s=s, t=knots[-1])
        fit_knots = spr0(t)
        # Restrict the fits to points well above the background level.
        mask = fit_knots > 10 * row.noise
        # 4) Refit the spline coefficients with the convexity constraint f'' >= 0.
        spl_convex = fit_convex_spline_from_bspline(t[mask], fit_sav[mask], spr0, n_grid=300, tol=0.0)
        spl_convex = spl_convex(t[mask])

        # Multi-exponential fit on a 1e3-scaled time axis with the background
        # added as a fixed offset.
        p, _ = scipy.optimize.curve_fit(
            lambda x, *params: fitfunc_2(x, *params) + row.noise,
            1e3 * t[mask], pl[mask], maxfev=100000, p0=p, bounds=(lb, ub))
        fit = fitfunc_2(t * 1e3, *p) + row.noise
        print("P =", p)
        t += t_min

        # tau_diff = -2 dt / d ln(PL); n = n0 * sqrt(PL / PL_max).
        tau_diff = -2 * (np.diff(t) / np.diff(np.log(fit)))
        carrier_densities_fit = np.sqrt(fit / np.max(fit)) * (row.n0s)

        # The constrained spline exists only on the masked points.
        t_masked = t[mask]
        tau_diffSpline = -2 * (np.diff(t_masked) / np.diff(np.log(spl_convex)))
        carrier_densities_fitSpline = np.sqrt(spl_convex / np.max(spl_convex)) * (row.n0s)

        # Row 0: data, background levels and the model curves.
        sc = ax[0, i].scatter(1e9 * t, pl, marker='x')
        facecolors = sc.get_facecolors()
        ax[0, i].axhline(row.noise, linestyle="--", color=facecolors[-1])
        ax[0, i].axhline(2 * row.noise, color='red')
        ax[0, i].plot(1e9 * t_masked, np.abs(fit_knots[mask]), color='red')
        # Plot the spline against its own (masked) times; t[:len(...)] was
        # misaligned whenever the mask was not a prefix of t.
        ax[0, i].plot(1e9 * t_masked, np.abs(spl_convex), color='green')
        ax[0, i].plot(1e9 * t[:len(fit)], np.abs(fit), color='orange')
        ax[0, i].set_yscale("log")
        ax[0, i].set_xlabel("time [ns]")
        ax[0, i].set_ylabel("PL counts [#]")

        # Row 1: differential lifetime vs time.
        ax[1, i].plot(1e9 * t[:len(tau_diff)], tau_diff)
        ax[1, i].plot(1e9 * t_masked[:len(tau_diffSpline)], tau_diffSpline, color="green")
        ax[1, i].set_xlim([min(1e9 * t), max(1e9 * t[:len(tau_diff)])])
        ax[1, i].set_xlabel("time [ns]")
        ax[1, i].set_ylabel("Differential lifetime [s]")

        # Row 2: differential lifetime vs carrier density.
        ax[2, i].plot(carrier_densities_fit[1:], tau_diff)
        ax[2, i].plot(carrier_densities_fitSpline[1:], tau_diffSpline, color="green")
        ax[2, i].set_xlabel("Carrier Concentration [cm-3]")
        ax[2, i].set_ylabel("Differential lifetime [s]")
        ax[2, i].set_xscale("log")
        ax[2, i].set_yscale("log")

        densities2.append(carrier_densities_fit)
        diff_taus.append(tau_diff)
    plt.tight_layout()
    plt.show()

    # Summary: every sample's tau_diff(n) from the exponential fit, log-log.
    f = plt.figure()
    for b, c in zip(diff_taus, densities2):
        plt.plot(c[:-1], b)

    ax = f.gca()
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_xlabel("Carrier Density [cm-3]")
    ax.set_ylabel("Differential lifetime [s]")
    plt.legend(range(len(data)), loc='upper left')

    plt.tight_layout()
    plt.show()

In [3]:
# Unicode warning sign (U+26A0).
warning_sign = "\u26A0"

# Output areas for the UI. `out` receives load-status messages and
# `dynamic_content` the parameter UI built by on_load_data_clicked.
# NOTE(review): out2, read, results_content and cell_edit are not referenced
# in the cells shown here — confirm they are used elsewhere.
out = widgets.Output()
out2 = widgets.Output()
read = widgets.Output()
dynamic_content = widgets.Output()  # For dynamically updated content
results_content = widgets.Output(layout={
    # 'border': '1px solid black',  # Optional: adds a border to the widget
    'max_height': '1000px',  # Maximum height before scrolling
    'overflow': 'scroll',  # Adds a scrollbar if content overflows
    })
cell_edit = widgets.VBox() 

# Naming preset for samples in plots.
default_variables = widgets.Dropdown(
    options=['sample name', 'batch',"sample description", 'custom'],
    index=0,
    description='name preset:',
    disabled=False,
    tooltip="Presets for how the samples will be named in the plot"
)
# Loaded TRPL DataFrame; assigned by on_load_data_clicked.
data = None
original_data = None  # To store original data for filter reset


# Fetch the TRPL measurements of the given samples as a single DataFrame.
def get_trpl_data(try_sample_ids, variation):
    """
    Fetch TRPL measurements for the given samples as one DataFrame.

    Parameters
    ----------
    try_sample_ids : iterable
        Sample ids to query.
    variation : mapping
        sample_id -> description; samples without an entry get ''.

    Returns
    -------
    pandas.DataFrame or None
        One row per measurement with columns counts, time, ns_per_bin,
        sample_id, data_file, variation and name, indexed 0..n-1; None if no
        measurement was returned.
    """
    # get_all_trpl is api_calls.get_all_eqe imported under an alias, queried
    # here for the TRPL entry type.
    all_trpl = get_all_trpl(url, token, try_sample_ids, eqe_type="HySprint_TimeResolvedPhotoluminescence")

    records = []
    for sample_id, sample_data in all_trpl.items():
        for trpl_entry in sample_data:
            # NOTE(review): each entry appears to be a sequence whose first
            # item is the measurement dict — confirm against api_calls.
            entry = trpl_entry[0]
            properties = entry["trpl_properties"]

            data_file = entry.get("data_file")
            if data_file:
                # Keep only the middle of the file name: drop the first and
                # the last two dot-separated parts.
                data_file = ".".join(data_file.split(".")[1:-2])

            records.append({
                "counts": properties["counts"],
                "time": properties["time"],
                "ns_per_bin": properties["ns_per_bin"],
                "sample_id": sample_id,
                "data_file": data_file,
                "variation": variation.get(sample_id, ''),
                "name": entry.get("name", ''),
            })

    # No measurements at all -> None, so callers can report "no data".
    if not records:
        return None
    # Building the frame in one go gives a unique 0..n-1 index (concatenating
    # one-row frames labelled every row 0).
    return pd.DataFrame(records)

def _make_set_all_handler(row_widgets, key):
    """
    Build an ``observe`` callback that copies a new value into every row widget.

    Parameters
    ----------
    row_widgets : list of dict
        One dict per sample row mapping parameter keys to widgets.
    key : str
        Entry of each dict to update (e.g. ``'rep_rate'``).
    """
    def handler(change):
        # observe() is registered without ``names``, so it fires for every
        # trait notification; only propagate edits of ``value``.
        if change['type'] == 'change' and change['name'] == 'value':
            for row_widget in row_widgets:
                row_widget[key].value = change['new']
    return handler


def on_load_data_clicked(batch_ids_selector):
    """
    Load TRPL data for the selected batches and build the analysis UI.

    Load callback for ``batch_selection.create_batch_selection``. Fetches the
    measurements of all samples in ``batch_ids_selector.value``, stores them
    in the globals ``data`` and ``original_data``, and renders into
    ``dynamic_content``: global analysis parameters, a per-sample parameter
    table with "Set all" controls, and an Analysis button.

    Parameters
    ----------
    batch_ids_selector : widget
        Widget whose ``value`` holds the selected batch ids.
    """
    global data, original_data
    dynamic_content.clear_output()
    with out:
        out.clear_output()
        print("Loading Data")

        try_sample_ids = get_ids_in_batch(url, token, batch_ids_selector.value)

        # Sample descriptions become the per-sample "variation" column.
        identifiers = get_sample_description(url, token, list(try_sample_ids))
        data = get_trpl_data(try_sample_ids, identifiers)

        if data is None:
            out.clear_output()
            print("The batches selected don't contain any TRPL measurements")
            return

        # Store original data for filter reset functionality
        original_data = data.copy()
        # Private, unprocessed snapshot of exactly the rows the table below is
        # built from. process_trpl_data modifies its input, so every Analysis
        # click starts from a fresh copy; repeated clicks no longer re-apply
        # the count -> rate conversion to already-converted data.
        loaded_data = data.copy()

        out.clear_output()
        print("Data Loaded")

        with dynamic_content:
            dynamic_content.clear_output()

            print("TRPL Analysis Parameters:")

            # NOTE(review): calculate_N0s annotates its intermediates in SI
            # units (J, m-2, m-3), but the labels below say cm, nm, uW and kHz
            # and the values are passed through unconverted — confirm the
            # intended units.
            bg_widget = widgets.FloatText(
                description='BG:',
                tooltip='Background value',
                style={'description_width': '120px'}
            )
            lambda_laser_widget = widgets.FloatText(
                value=705e-9,
                description='lambda_laser:',
                tooltip='Laser wavelength',
                style={'description_width': '120px'}
            )
            spot_diameter_widget = widgets.FloatText(
                value=2.72e-04,
                description='Spot Diameter [cm]:',
                tooltip='Spot diameter',
                style={'description_width': '120px'}
            )
            thickness_widget = widgets.FloatText(
                description='Thickness [nm]:',
                tooltip='Thickness value',
                style={'description_width': '120px'}
            )
            nc_widget = widgets.FloatText(
                value=2e18,
                description='Nc:',
                tooltip='Nc parameter',
                style={'description_width': '120px'}
            )
            nv_widget = widgets.FloatText(
                value=2e18,
                description='Nv:',
                tooltip='Nv parameter',
                style={'description_width': '120px'}
            )
            kt_widget = widgets.FloatText(
                value=27.7e-3,
                description='kT:',
                tooltip='kT parameter',
                style={'description_width': '120px'}
            )
            bd_ratio_widget = widgets.FloatText(
                value=0.21,
                description='BD_ratio:',
                tooltip='BD ratio parameter',
                style={'description_width': '120px'}
            )

            # Integer denoise setting (see calculate_noise) and a retime flag.
            # NOTE(review): retime_widget is displayed but never read.
            denoise_widget = widgets.IntText(
                value=0,
                description='denoise',
                tooltip='Choose denoise'
            )
            retime_widget = widgets.Checkbox(
                value=True,
                description='retime',
                tooltip='Enable retiming'
            )

            float_widgets_box = widgets.VBox([
                widgets.HTML("<h4>Float Parameters:</h4>"),
                bg_widget,
                lambda_laser_widget,
                spot_diameter_widget,
                thickness_widget,
                nc_widget,
                nv_widget,
                kt_widget,
                bd_ratio_widget
            ])
            checkbox_widgets_box = widgets.VBox([
                widgets.HTML("<h4>Boolean Parameters:</h4>"),
                denoise_widget,
                retime_widget
            ])
            parameter_widgets = widgets.HBox([float_widgets_box, checkbox_widgets_box])
            display(parameter_widgets)

            # Sample-specific parameter table.
            print("\nSample-specific Parameters:")
            rep_rate_default = 10000 # Default: 10 kHz
            integration_time_default = 10 # Default: 10 seconds
            power_default = 0.4 # Default: 0.4 uW
            fitting_interval_default = 100 # Default: 100 data points
            num_exponentials_default = 3 # Default: 3 exponentials

            # Global widgets that overwrite the value in every table row.
            global_rep_rate_widget = widgets.FloatText(
                value=rep_rate_default,
                description='Set all Rep. Rates:',
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )
            global_power_widget = widgets.FloatText(
                description='Set all Powers:',
                value=power_default,
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )
            global_nd_widget = widgets.FloatText(
                description='Set all ND:',
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )
            global_integration_time_widget = widgets.FloatText(
                description='Set all Int. Times:',
                value=integration_time_default,
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )
            global_fitting_interval_widget = widgets.FloatText(
                description='Set all Fitting Intervals:',
                value=fitting_interval_default,
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )
            global_num_exponentials_widget = widgets.IntText(
                description='Set all Num. Exponentials:',
                value=num_exponentials_default,
                tooltip='Set this value to apply to all samples',
                style={'description_width': '150px'}
            )

            # One dict of widgets per data row, in row order (process_trpl_data
            # assigns the values positionally).
            row_widgets = []

            header = widgets.HBox([
                widgets.HTML(value="<b>Sample ID</b>", layout=widgets.Layout(width='200px')),
                widgets.HTML(value="<b>Data File</b>", layout=widgets.Layout(width='300px')),
                widgets.HTML(value="<b>Rep. Rate [kHz]</b>", layout=widgets.Layout(width='120px')),
                widgets.HTML(value="<b>Power [uW]</b>", layout=widgets.Layout(width='120px')),
                widgets.HTML(value="<b>ND</b>", layout=widgets.Layout(width='120px')),
                widgets.HTML(value="<b>Int. Time [s]</b>", layout=widgets.Layout(width='120px')),
                widgets.HTML(value="<b>Fitting Interval</b>", layout=widgets.Layout(width='120px')),
                widgets.HTML(value="<b>Num. Exponentials</b>", layout=widgets.Layout(width='130px'))
            ])

            table_rows = [header]
            for idx, row in loaded_data.iterrows():
                sample_id_label = widgets.HTML(
                    value=str(row.get('sample_id', '')),
                    layout=widgets.Layout(width='200px')
                )
                data_file_label = widgets.HTML(
                    value=str(row.get('data_file', '')),
                    layout=widgets.Layout(width='300px')
                )
                rep_rate_widget = widgets.FloatText(
                    value=rep_rate_default,
                    layout=widgets.Layout(width='120px'),
                    tooltip=f'Repetition rate for {row.get("sample_id", "")}'
                )
                power_widget = widgets.FloatText(
                    layout=widgets.Layout(width='120px'),
                    value=power_default,
                    tooltip=f'Power for {row.get("sample_id", "")}'
                )
                nd_widget = widgets.FloatText(
                    layout=widgets.Layout(width='120px'),
                    tooltip=f'ND for {row.get("sample_id", "")}'
                )
                integration_time_widget = widgets.FloatText(
                    value=integration_time_default,
                    layout=widgets.Layout(width='120px'),
                    tooltip=f'Integration time for {row.get("sample_id", "")}'
                )
                fitting_interval_widget = widgets.FloatText(
                    value=fitting_interval_default,
                    layout=widgets.Layout(width='120px'),
                    tooltip=f'Fitting interval for {row.get("sample_id", "")}'
                )
                num_exponentials_widget = widgets.IntText(
                    value=num_exponentials_default,
                    layout=widgets.Layout(width='130px'),
                    tooltip=f'Number of exponentials for {row.get("sample_id", "")}'
                )

                row_widgets.append({
                    'rep_rate': rep_rate_widget,
                    'power': power_widget,
                    'nd': nd_widget,
                    'integration_time': integration_time_widget,
                    'fitting_interval': fitting_interval_widget,
                    'num_exponentials': num_exponentials_widget
                })

                row_box = widgets.HBox([
                    sample_id_label,
                    data_file_label,
                    rep_rate_widget,
                    power_widget,
                    nd_widget,
                    integration_time_widget,
                    fitting_interval_widget,
                    num_exponentials_widget
                ])
                table_rows.append(row_box)

            # Connect each "Set all" widget to the matching column of the table.
            for global_widget, key in (
                (global_rep_rate_widget, 'rep_rate'),
                (global_power_widget, 'power'),
                (global_nd_widget, 'nd'),
                (global_integration_time_widget, 'integration_time'),
                (global_fitting_interval_widget, 'fitting_interval'),
                (global_num_exponentials_widget, 'num_exponentials'),
            ):
                global_widget.observe(_make_set_all_handler(row_widgets, key))

            global_widgets_box = widgets.VBox([
                widgets.HTML("<h4>Set All Values:</h4>"),
                global_rep_rate_widget,
                global_power_widget,
                global_nd_widget,
                global_integration_time_widget,
                global_fitting_interval_widget,
                global_num_exponentials_widget
            ])
            display(global_widgets_box)

            table_widget = widgets.VBox(table_rows)
            display(table_widget)

            analysis_button = widgets.Button(
                description='Analysis',
                button_style='success',
                tooltip='Run TRPL analysis with current parameters',
                layout=widgets.Layout(width='200px', height='40px')
            )
            analysis_output = widgets.Output()

            def run_analysis(b):
                global data
                with analysis_output:
                    analysis_output.clear_output()

                    # Area of a circular spot from its diameter.
                    spot_area = np.pi * (spot_diameter_widget.value / 2) ** 2
                    processed_data = process_trpl_data(
                        data=loaded_data.copy(),
                        row_widgets=row_widgets,
                        denoise_value=denoise_widget.value,
                        lambda_laser=lambda_laser_widget.value,
                        spot_area=spot_area,
                        thickness=thickness_widget.value,
                        bd_ratio=bd_ratio_widget.value,
                        bg=bg_widget.value,
                        nc=nc_widget.value,
                        nv=nv_widget.value,
                        kt=kt_widget.value
                    )
                    # Keep the global pointing at the latest processed frame,
                    # as the former in-place processing did.
                    data = processed_data

                    plot_trpl_results(processed_data)
                    # NOTE(review): fit_difflifetimes does not use l2 or noise.
                    fit_difflifetimes(
                        processed_data,
                        n_exp=[r.get("num_exponentials").value for r in row_widgets],
                        l2=[r.get("fitting_interval").value for r in row_widgets],
                        noise=processed_data['noise'].to_numpy(),
                    )

                    print("Analysis completed successfully!")

            analysis_button.on_click(run_analysis)

            display(analysis_button)
            display(analysis_output)
            
# BATCH SELECTION WITH OPTIONAL FILTERING
def create_batch_selection_with_optional_filtering():
    """
    Create the batch selection widget together with an optional filter button.

    The widget from ``batch_selection.create_batch_selection`` is shown
    immediately (no filtering, fast). Clicking the filter button queries the
    API via ``get_all_batches_wth_data`` for batches containing
    "HySprint_TimeResolvedPhotoluminescence" entries. It then replaces the
    options of the embedded ``SelectMultiple`` with that subset.

    The filter is one-shot: after a successful run the button stays disabled.
    If the API calls raise, the button is re-enabled so the user can retry.

    Returns:
        widgets.VBox: intro text, the filter button, a status output area and
        the original batch selection widget.
    """
    # NOTE(review): user-facing text mixes "AbsPL" and "TRPL" while the entity
    # queried is HySprint_TimeResolvedPhotoluminescence -- confirm the intended label.
    idle_description = "🔍 Filter to show only batches with AbsPL data"

    # Create the original batch selection widget (fast, unfiltered).
    original_batch_widget = batch_selection.create_batch_selection(url, token, on_load_data_clicked)

    # Locate the multi-select inside the original widget. It is used to count
    # the batches now and to swap in the filtered options later.
    batch_selector = next(
        (child for child in original_batch_widget.children
         if isinstance(child, widgets.SelectMultiple)),
        None,
    )

    total_batches = len(batch_selector.options) if batch_selector else 0

    filter_button = widgets.Button(
        description=idle_description,
        button_style='info',
        tooltip=f'Click to filter {total_batches} batches (this may take a few minutes)',
        layout=widgets.Layout(width='400px')
    )

    # Status area for progress and result messages.
    filter_status = widgets.Output()

    def start_filtering(b):
        """Button callback: query batches with data and narrow the selector's options."""
        filter_button.disabled = True
        filter_button.description = "🔄 Filtering in progress..."

        with filter_status:
            filter_status.clear_output(wait=True)
            print("Finding batches with TRPL data...")

            try:
                # Mirror batch_selection's filtering. Skip an id whose prefix
                # (everything before its last "_") is itself a batch id.
                # The list is only used for the count message below.
                batch_ids_list_tmp = list(get_batch_ids(url, token))
                known_ids = set(batch_ids_list_tmp)  # set: O(1) membership instead of O(n)
                all_batch_ids = [
                    batch for batch in batch_ids_list_tmp
                    if "_".join(batch.split("_")[:-1]) not in known_ids
                ]

                print(f"Testing {len(all_batch_ids)} batches...")

                valid_batches = get_all_batches_wth_data(url, token, "HySprint_TimeResolvedPhotoluminescence")
            except Exception as err:  # UI callback boundary: report instead of leaving the button stuck
                filter_status.clear_output(wait=True)
                print(f"❌ Filtering failed ({type(err).__name__}): {err}")
                filter_button.description = idle_description
                filter_button.disabled = False
                return

            # Update the original widget's options.
            if batch_selector:
                batch_selector.options = valid_batches

            # Show final results.
            filter_status.clear_output(wait=True)
            print("="*60)
            print("FILTERING COMPLETE")
            print("="*60)
            print(f"✅ Found {len(valid_batches)} batches with TRPL data out of {total_batches} total")
            if len(valid_batches) > 0:
                print(f"Valid batches: {valid_batches}")
            else:
                print("⚠️  No batches with TRPL data found!")
            if not batch_selector:
                print("⚠️  No batch selector found in the widget; options were not updated.")

            # One-shot filter: keep the button disabled after success.
            filter_button.description = f"✅ Filtering complete - {len(valid_batches)} valid batches found"
            filter_button.disabled = True

            # Prepend a summary line to the original widget.
            info_html = widgets.HTML(
                value=f"<p><b>Showing {len(valid_batches)} of {total_batches} batches with confirmed AbsPL data</b></p>"
            )
            original_batch_widget.children = (info_html,) + original_batch_widget.children

    filter_button.on_click(start_filtering)

    # Assemble the complete widget.
    complete_widget = widgets.VBox([
        widgets.HTML(f"<p>Select batches from all {total_batches} available batches, or use the filter button below:</p>"),
        filter_button,
        filter_status,
        original_batch_widget
    ])

    return complete_widget


# Build the batch selector (with its optional data filter), then render it
# followed by the shared output area and the dynamic_content container, which
# is updated later with the variables menu. display() emits each argument in order.
batch_widget = create_batch_selection_with_optional_filtering()
display(batch_widget, out, dynamic_content)

VBox(children=(HTML(value='<p>Select batches from all 215 available batches, or use the filter button below:</…

Output()

Output()