In [6]:
!which python 

/opt/anaconda3/envs/spiketurnpike_postanalysis/bin/python


In [7]:
!pip show pyabf

Name: pyabf
Version: 2.3.8
Summary: Python library for reading files in Axon Binary Format (ABF)
Home-page: http://swharden.com/pyabf
Author: Scott W Harden
Author-email: SWHarden@gmail.com
License: MIT License
Location: /opt/anaconda3/envs/spiketurnpike_postanalysis/lib/python3.9/site-packages
Requires: matplotlib, numpy, pytest
Required-by: 


In [8]:
from spiketurnpike_postanalysis.Extract_patch_data_from_abf import PatchDataExtractor
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
import re  # Import the re module for regular expressions
from matplotlib.backends.backend_pdf import PdfPages
import pyabf
import pyabf.plot
from scipy.signal import find_peaks
from statsmodels.stats.anova import AnovaRM
from statsmodels.stats.multitest import multipletests
    
from scipy.stats import ttest_ind
import matplotlib.gridspec as gridspec
from scipy.optimize import curve_fit
from scipy.stats import ks_2samp

In [9]:
class BladePatchDataProcessor:
    def __init__(self, base_path, time_units="sec", voltage_units="mV"):
        """
        Initialize the BladePatchDataProcessor with the base path.

        Args:
            base_path (str): Path to the BLADe_patch_data folder.
        """
        self.base_path = base_path  # Base directory path
        self.dataframe = None  # DataFrame to store metadata
        self.unique_groups = None  # List of unique group names
        self.time_units = time_units
        self.voltage_units = voltage_units
        self._abf_cache = {}  # Cache for ABF objects to avoid repeated loading, Cache for ABF objects: {(group, recording_id, label): abf_obj}

    def process_data(self):
        """
        Process the BLADe_patch_data folder to extract metadata about .abf files.

        This method populates the `dataframe` and `unique_groups` attributes.
        """
        data = []

        # Walk through each group in the base path
        for group_dir in os.listdir(self.base_path):
            group_path = os.path.join(self.base_path, group_dir)
            if not os.path.isdir(group_path):
                continue  # Skip non-directory files

            # Determine linking pattern based on the group
            if group_dir == "L + CS-Veh":
                identifier_pattern = r"Veh-\d+"
            else:
                identifier_pattern = r"CTZ-\d+"

            # Look for .abf files in the group directory
            for file_name in os.listdir(group_path):
                if file_name.endswith(".abf"):
                    # Parse the label ("Before" or "After")
                    if "Before" in file_name:
                        label = "Before"
                    elif "After" in file_name:
                        label = "After"
                    else:
                        continue  # Skip files without "Before" or "After"

                    # Extract linking identifier (e.g., "CTZ-1", "Veh-2")
                    match = re.search(identifier_pattern, file_name)
                    if not match:
                        print(f"Warning: No linking identifier found in file {file_name}. Skipping.")
                        continue
                    recording_id = match.group(0)

                    # Append metadata to the list
                    data.append({
                        "Group": group_dir,
                        "Recording_ID": recording_id,
                        "Label": label,
                        "File_Path": os.path.join(group_path, file_name)
                    })

        # Convert the list of metadata to a DataFrame
        self.dataframe = pd.DataFrame(data)

        # Sort the DataFrame by Group and Recording ID for clarity
        if not self.dataframe.empty:
            self.dataframe = self.dataframe.sort_values(by=["Group", "Recording_ID", "Label"]).reset_index(drop=True)
            # Extract unique group names
            self.unique_groups = self.dataframe["Group"].unique().tolist()
        else:
            self.unique_groups = []

    def get_recording_ids(self, group):
        """
        Given a group name, return all unique recording IDs associated with that group.
        """
        group_data = self.dataframe[self.dataframe["Group"] == group]
        return group_data["Recording_ID"].unique()

    def get_abf_file(self, group, recording_id, label):
        """
        Return a pyabf.ABF object for the given group, recording_id, and label.
        Uses caching to avoid reloading the file multiple times.
        """
        key = (group, recording_id, label)
        if key in self._abf_cache:
            return self._abf_cache[key]

        entry = self.dataframe[(self.dataframe["Group"] == group) &
                               (self.dataframe["Recording_ID"] == recording_id) & 
                               (self.dataframe["Label"] == label)]
        if entry.empty:
            raise ValueError(f"No entry found for Group: {group}, Recording ID: {recording_id} with label: {label}")

        file_path = entry["File_Path"].iloc[0]
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"ABF file not found at: {file_path}")

        abf = pyabf.ABF(file_path)
        self._abf_cache[key] = abf
        return abf

    def get_sweep_data(self, group, recording_id, label, sweep_number, channel=0):
        """
        Retrieve time and voltage arrays for a given sweep from the specified group and recording.
        """
        abf = self.get_abf_file(group, recording_id, label)
        abf.setSweep(sweepNumber=sweep_number, channel=channel)
        return abf.sweepX, abf.sweepY

    def colors_binned(self, count, colormap="viridis", reverse=False):
        cmap = plt.get_cmap(colormap)
        colors = [cmap(i / count) for i in range(count)]
        if reverse:
            colors.reverse()
        return colors

    def plot_sweeps(self, ax, group, recording_id, label, sweep_numbers=None,
                    offsetXsec=0.3, offsetYunits=40, startAtSec=0, endAtSec=None,
                    color=None, alpha=0.5, linewidth=1, hideAxis=True):
        """
        Plot multiple sweeps from a given group, recording, and label on the given axis.
        This version includes parameters for offsetting and time-limiting sweeps,
        and the option to hide axes.
        """
        abf = self.get_abf_file(group, recording_id, label)
        if sweep_numbers is None:
            sweep_numbers = abf.sweepList

        data_rate = abf.dataRate
        i1 = int(data_rate * startAtSec)
        i2 = int(data_rate * endAtSec) if endAtSec else None

        # Handle colors
        if color is None and len(sweep_numbers) > 1:
            colors = self.colors_binned(len(sweep_numbers))
        else:
            colors = [color] * len(sweep_numbers)

        for i, sweep_num in enumerate(sweep_numbers):
            time, voltage = self.get_sweep_data(group, recording_id, label, sweep_num)
            ax.plot(
                time[i1:i2] + offsetXsec * sweep_num,
                voltage[i1:i2] + offsetYunits * sweep_num,
                color=colors[i] if colors[i] else 'C0',
                alpha=alpha,
                linewidth=linewidth
            )

        # Remove all axis lines, ticks, and labels if requested
        if hideAxis:
            ax.axis('off')
        else:
            # If you ever want axes, you can customize here:
            ax.set_xlabel(self.time_units)
            ax.set_ylabel(self.voltage_units)

    def plot_scalebar(self, ax, scaleXms=200, scaleYmV=50, fontSize=8, lineWidth=2,
                    hideTicks=True, hideFrame=True):
        """
        Add a scale bar to the given axis dynamically.
        By default, shows a scale bar for 200 ms (0.2 s) and 50 mV.

        Args:
            ax (matplotlib.axes.Axes): The axis to draw the scale bar on.
            scaleXms (float): The horizontal scale length in milliseconds (default 200 ms).
            scaleYmV (float): The vertical scale length in millivolts (default 50 mV).
            fontSize (int): Font size of the scale bar labels.
            lineWidth (int): Line width of the scale bar lines.
            hideTicks (bool): If True, hides the axis ticks.
            hideFrame (bool): If True, hides the axis frame/spines.

        The scale bar will be placed in the lower-right corner of the axis.
        """
        # Convert ms to seconds for data coordinates
        scaleXsize_data = scaleXms / 1000.0  # e.g., 200 ms = 0.2 s
        scaleYsize_data = scaleYmV  # mV stays mV

        x1, x2 = ax.get_xlim()
        y1, y2 = ax.get_ylim()
        xs, ys = abs(x2 - x1), abs(y2 - y1)

        # Position the scale bar in the lower-right corner
        scaleBarPadX = 0.10
        scaleBarPadY = 0.10
        scaleBarX = x2 - scaleBarPadX * xs
        scaleBarX2 = scaleBarX - scaleXsize_data
        scaleBarY = y1 + scaleBarPadY * ys
        scaleBarY2 = scaleBarY + scaleYsize_data

        scaleBarXs = [scaleBarX2, scaleBarX, scaleBarX]
        scaleBarYs = [scaleBarY, scaleBarY, scaleBarY2]

        # Hide ticks/frames if requested
        if hideTicks:
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        if hideFrame:
            for spine in ax.spines.values():
                spine.set_visible(False)

        # Draw the scale bar
        ax.plot(scaleBarXs, scaleBarYs, 'k-', lw=lineWidth)

        # Padding for labels
        lblPadMult = 0.005 + 0.002 * lineWidth
        lblPadX = xs * lblPadMult
        lblPadY = ys * lblPadMult

        # Create labels with units
        # For time, we label in ms; for voltage, in mV
        time_label = f"{scaleXms} ms"
        voltage_label = f"{scaleYmV} mV"

        # Add text labels
        ax.text((scaleBarX + scaleBarX2) / 2, scaleBarY - lblPadY, time_label,
                ha='center', va='top', fontsize=fontSize)
        ax.text(scaleBarX + lblPadX, (scaleBarY + scaleBarY2) / 2, voltage_label,
                ha='left', va='center', fontsize=fontSize)

    def plot_before_after_comparison(self, group, recording_id,
                                    before_label="Before", after_label="After",
                                    sweep_numbers=None, startAtSec=0, endAtSec=1.5,
                                    offsetXsec=0.3, offsetYunits=40,
                                    color_before=None, color_after="red",
                                    alpha=0.5, linewidth=1,
                                    add_suptitle=True, suptitle_fontsize=14):
        """
        Create a figure comparing Before and After sweeps for a single recording in a group.
        This version allows control over offsets, time, and whether to show titles or axes.
        """
        fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

        # Plot Before sweeps
        self.plot_sweeps(axes[0], group, recording_id, before_label,
                        sweep_numbers=sweep_numbers,
                        offsetXsec=offsetXsec, offsetYunits=offsetYunits,
                        startAtSec=startAtSec, endAtSec=endAtSec,
                        color=color_before, alpha=alpha, linewidth=linewidth, hideAxis=True)

        # Plot After sweeps
        self.plot_sweeps(axes[1], group, recording_id, after_label,
                        sweep_numbers=sweep_numbers,
                        offsetXsec=offsetXsec, offsetYunits=offsetYunits,
                        startAtSec=startAtSec, endAtSec=endAtSec,
                        color=color_after, alpha=alpha, linewidth=linewidth, hideAxis=True)

        # Optionally add a figure-level title
        if add_suptitle:
            fig.suptitle(f"Group: {group}, Recording: {recording_id}",
                        fontsize=suptitle_fontsize)

        plt.tight_layout()
        return fig, axes

    def export_group_to_pdf(self, group, output_pdf_path,
                            before_label="Before", after_label="After",
                            sweep_numbers=None, startAtSec=0, endAtSec=1.5,
                            offsetXsec=0.3, offsetYunits=40,
                            color_before=None, color_after="red",
                            alpha=0.5, linewidth=1):
        """
        Create a multipage PDF for all recordings in a given group.
        Each page shows Before/After comparison for a single recording in that group.
        """
        os.makedirs(os.path.dirname(output_pdf_path), exist_ok=True)
        recording_ids = self.get_recording_ids(group)

        with PdfPages(output_pdf_path) as pdf:
            for recording_id in recording_ids:
                before_entry = self.dataframe[(self.dataframe["Group"] == group) &
                                            (self.dataframe["Recording_ID"] == recording_id) &
                                            (self.dataframe["Label"] == before_label)]
                after_entry = self.dataframe[(self.dataframe["Group"] == group) &
                                            (self.dataframe["Recording_ID"] == recording_id) &
                                            (self.dataframe["Label"] == after_label)]

                # Only create the page if both Before and After exist for this (Group, Recording_ID)
                if before_entry.empty or after_entry.empty:
                    continue

                fig, axes = self.plot_before_after_comparison(
                    group, recording_id,
                    before_label=before_label, after_label=after_label,
                    sweep_numbers=sweep_numbers,
                    startAtSec=startAtSec, endAtSec=endAtSec,
                    offsetXsec=offsetXsec, offsetYunits=offsetYunits,
                    color_before=color_before, color_after=color_after,
                    alpha=alpha, linewidth=linewidth,
                    add_suptitle=True
                )

                # Add a scale bar to the bottom axis if desired
                # Adjust hideFrame and hideTicks to maintain a clean look
                self.plot_scalebar(axes[1], hideTicks=True, hideFrame=True)

                pdf.savefig(fig)
                plt.close(fig)
            
    def export_all_groups_to_pdfs(self, output_dir,
                                before_label="Before", after_label="After",
                                sweep_numbers=None,
                                startAtSec=0, endAtSec=1.5,
                                offsetXsec=0.3, offsetYunits=40,
                                color_before=None, color_after="red",
                                alpha=0.5, linewidth=1):
        """
        Export one PDF per group, with each PDF containing all recordings 
        (Before/After) for that group.

        Args:
            output_dir (str): Directory to save all the PDF files.
            before_label (str): Label for the "Before" condition.
            after_label (str): Label for the "After" condition.
            sweep_numbers (list or None): Specific sweeps to plot. 
                                        If None, all sweeps are plotted.
            startAtSec (float): Start time (seconds) of the data to plot.
            endAtSec (float): End time (seconds) of the data to plot.
            offsetXsec (float): Horizontal offset per sweep.
            offsetYunits (float): Vertical offset per sweep.
            color_before (str or None): Color for "Before" sweeps. 
                                        None means use default or colormap.
            color_after (str): Color for "After" sweeps.
            alpha (float): Transparency of the sweep lines.
            linewidth (float): Width of the sweep lines.
        """
        os.makedirs(output_dir, exist_ok=True)
        for grp in self.unique_groups:
            pdf_path = os.path.join(output_dir, f"{grp}_plots.pdf")
            self.export_group_to_pdf(
                group=grp,
                output_pdf_path=pdf_path,
                before_label=before_label,
                after_label=after_label,
                sweep_numbers=sweep_numbers,
                startAtSec=startAtSec,
                endAtSec=endAtSec,
                offsetXsec=offsetXsec,
                offsetYunits=offsetYunits,
                color_before=color_before,
                color_after=color_after,
                alpha=alpha,
                linewidth=linewidth
            )
            print(f"PDF saved for group {grp} at: {pdf_path}")

    def get_summary(self):
        """
        Get a summary of the processed data.

        Returns:
            str: A summary string including number of groups and recordings.
        """
        if self.dataframe is None:
            return "No data processed yet."

        summary = (
            f"Total Groups: {len(self.unique_groups)}\n"
            f"Unique Groups: {self.unique_groups}\n"
            f"Total Recordings: {len(self.dataframe)}"
        )
        return summary
    
    def plot_sweeps_pdf(self, output_dir):
        """
        Generate a multipage PDF for each group, with each page containing Before and After sweeps.

        Args:
            output_dir (str): Path to save the generated PDF files.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        # Group by unique groups
        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            pdf_path = os.path.join(output_dir, f"{group}.pdf")
            saved_pages = 0  # Track the number of pages added to the PDF

            # Determine linking pattern based on the group
            if group == "L + CS-Veh":
                identifier_pattern = "Veh"
            else:
                identifier_pattern = "CTZ"

            # Extract unique linking pairs
            linking_ids = group_data["Recording_ID"].str.extract(f"({identifier_pattern}-\d+)")[0].dropna().unique()

            with PdfPages(pdf_path) as pdf:
                for link_id in linking_ids:
                    # Filter "Before" and "After" files for the current linking pair
                    before_file = group_data[(group_data["Recording_ID"].str.contains(link_id)) & (group_data["Label"] == "Before")]["File_Path"]
                    after_file = group_data[(group_data["Recording_ID"].str.contains(link_id)) & (group_data["Label"] == "After")]["File_Path"]

                    if before_file.empty or after_file.empty:
                        print(f"Skipping incomplete pair for Link ID: {link_id} in group {group}")
                        continue  # Skip if "Before" or "After" is missing

                    before_file = before_file.iloc[0]
                    after_file = after_file.iloc[0]

                    # Create the figure with 1x2 layout
                    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
                    fig.suptitle(f"Group: {group}, Link ID: {link_id}", fontsize=14)

                    # Plot Before
                    abf_before = pyabf.ABF(before_file)
                    for sweepNumber in abf_before.sweepList:
                        abf_before.setSweep(sweepNumber)
                        offset = 140 * sweepNumber
                        axes[0].plot(abf_before.sweepX, abf_before.sweepY + offset, color='C0')
                    axes[0].set_title("Before")
                    axes[0].get_yaxis().set_visible(False)
                    axes[0].set_xlabel(abf_before.sweepLabelX)

                    # Plot After
                    abf_after = pyabf.ABF(after_file)
                    for sweepNumber in abf_after.sweepList:
                        abf_after.setSweep(sweepNumber)
                        offset = 140 * sweepNumber
                        axes[1].plot(abf_after.sweepX, abf_after.sweepY + offset, color='C1')
                    axes[1].set_title("After")
                    axes[1].get_yaxis().set_visible(False)
                    axes[1].set_xlabel(abf_after.sweepLabelX)

                    # Save the current figure to the PDF
                    pdf.savefig(fig)
                    plt.close(fig)
                    saved_pages += 1

                if saved_pages == 0:
                    print(f"No valid data to plot for group '{group}'. Deleting empty PDF.")
                    os.remove(pdf_path)
                else:
                    print(f"Saved PDF for group '{group}' to: {pdf_path}")              

    def detect_action_potentials(self, group, recording_id, label, sweep_number, height=None, prominence=None, distance=None, width=None):
            """
            Detect and plot action potentials (APs) for a specific recording from the DataFrame.

            Args:
                group (str): The group name.
                recording_id (str): The recording ID (e.g., "CTZ-1").
                label (str): The label ("Before" or "After").
                sweep_number (int): Sweep number to analyze.
                height (float, optional): Minimum height of peaks (APs) to detect.
                prominence (float, optional): Minimum prominence of peaks (APs) to detect.
                distance (float, optional): Minimum distance between consecutive peaks.
                width (float, optional): Minimum width of peaks.

            Returns:
                dict: A dictionary with sweep number, detected AP count, and peak indices.
            """
            # Locate the file path from the DataFrame
            entry = self.dataframe[
                (self.dataframe["Group"] == group) &
                (self.dataframe["Recording_ID"] == recording_id) &
                (self.dataframe["Label"] == label)
            ]

            if entry.empty:
                raise ValueError(f"No entry found for Group: {group}, Recording ID: {recording_id}, Label: {label}")

            file_path = entry["File_Path"].iloc[0]

            # Load the ABF file
            abf = pyabf.ABF(file_path)
            abf.setSweep(sweepNumber=sweep_number)

            # Extract time (X-axis) and voltage (Y-axis) data
            time = abf.sweepX
            voltage = abf.sweepY

            # Detect peaks using scipy's find_peaks
            peaks, properties = find_peaks(
                voltage,
                height=height,
                prominence=prominence,
                distance=distance,
                width=width
            )

            # Count the number of action potentials (peaks)
            ap_count = len(peaks)

            # Plot the sweep with peaks overlaid
            plt.figure(figsize=(10, 6))
            plt.plot(time, voltage, label="Voltage Trace", color="C0")
            plt.plot(time[peaks], voltage[peaks], "x", label="Detected Peaks", color="C3")
            plt.title(f"Group: {group}, Recording ID: {recording_id}, Sweep {sweep_number}: Detected {ap_count} Action Potentials")
            plt.xlabel("Time (s)")
            plt.ylabel("Voltage (mV)")
            plt.legend()
            plt.show()

            # Return a summary of results
            return {
                "group": group,
                "recording_id": recording_id,
                "label": label,
                "sweep_number": sweep_number,
                "action_potential_count": ap_count,
                "peak_indices": peaks
            }
            
    def create_group_pdf_with_peaks(self, output_dir, height=None, prominence=None, distance=None, width=None):
        """
        Generate a 2x1 multipage PDF for each group. Each page contains "Before" and "After"
        sweeps of a single recording with action potential peaks annotated.

        Args:
            output_dir (str): Path to save the generated PDF files.
            height (float, optional): Minimum height of peaks (APs) to detect.
            prominence (float, optional): Minimum prominence of peaks (APs) to detect.
            distance (float, optional): Minimum distance between consecutive peaks.
            width (float, optional): Minimum width of peaks.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            pdf_path = os.path.join(output_dir, f"{group}_wide.pdf")

            with PdfPages(pdf_path) as pdf:
                # Get unique recording IDs for the group
                recording_ids = group_data["Recording_ID"].unique()

                for recording_id in recording_ids:
                    # Locate Before and After entries
                    before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                    after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                    if before_entry.empty or after_entry.empty:
                        print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                        continue

                    # Load Before and After ABF files
                    before_file = before_entry["File_Path"].iloc[0]
                    after_file = after_entry["File_Path"].iloc[0]
                    before_abf = pyabf.ABF(before_file)
                    after_abf = pyabf.ABF(after_file)

                    # Create a 2x1 figure for the recording
                    fig, axes = plt.subplots(2, 1, figsize=(16, 12), sharex=True)
                    fig.suptitle(f"Group: {group}, Recording ID: {recording_id}", fontsize=16)

                    # Plot Before sweeps (Top Panel)
                    before_peak_counts = []
                    for sweep_number in before_abf.sweepList:
                        before_abf.setSweep(sweepNumber=sweep_number)
                        time = before_abf.sweepX
                        voltage = before_abf.sweepY

                        # Detect peaks
                        peaks, properties = find_peaks(
                            voltage,
                            height=height,
                            prominence=prominence,
                            distance=distance,
                            width=width
                        )

                        # Count peaks and store for annotation
                        before_peak_counts.append(len(peaks))

                        # Plot the sweep with peaks
                        offset = 140 * sweep_number  # Offset to stack sweeps visually
                        axes[0].plot(time, voltage + offset, label=f"Sweep {sweep_number}", color="C0")
                        axes[0].plot(time[peaks], voltage[peaks] + offset, "x", color="C3")

                        # Annotate number of peaks and sweep number for Before sweeps
                        for sweep_number, peak_count in enumerate(before_peak_counts):
                            axes[0].text(
                                -0.05,  # Slightly to the left of the x-axis start
                                140 * sweep_number,  # Same vertical position as the trace
                                f"Sweep {sweep_number}: {peak_count} APs",  # Add sweep number and AP count
                                fontsize=10,
                                color="C0",
                                ha="right"  # Align text to the right
                            )

                    axes[0].set_title("Before")
                    axes[0].get_yaxis().set_visible(False)

                    # Plot After sweeps (Bottom Panel)
                    after_peak_counts = []
                    for sweep_number in after_abf.sweepList:
                        after_abf.setSweep(sweepNumber=sweep_number)
                        time = after_abf.sweepX
                        voltage = after_abf.sweepY

                        # Detect peaks
                        peaks, properties = find_peaks(
                            voltage,
                            height=height,
                            prominence=prominence,
                            distance=distance,
                            width=width
                        )

                        # Count peaks and store for annotation
                        after_peak_counts.append(len(peaks))

                        # Plot the sweep with peaks
                        offset = 140 * sweep_number  # Offset to stack sweeps visually
                        axes[1].plot(time, voltage + offset, label=f"Sweep {sweep_number}", color="C1")
                        axes[1].plot(time[peaks], voltage[peaks] + offset, "x", color="C4")
                        
                        # Annotate number of peaks and sweep number for After sweeps
                        for sweep_number, peak_count in enumerate(after_peak_counts):
                            axes[1].text(
                                -0.05,  # Slightly to the left of the x-axis start
                                140 * sweep_number,  # Same vertical position as the trace
                                f"Sweep {sweep_number}: {peak_count} APs",  # Add sweep number and AP count
                                fontsize=10,
                                color="C1",
                                ha="right"  # Align text to the right
                            )
                    axes[1].set_title("After")
                    axes[1].get_yaxis().set_visible(False)

                    # Decorate the figure
                    axes[1].set_xlabel("Time (s)")
                    for ax in axes:
                        ax.set_ylabel("Voltage (mV)")
                        ax.legend(loc="upper right")

                    # Save the current page to the PDF
                    pdf.savefig(fig)
                    plt.close(fig)

            print(f"Saved wide PDF for group '{group}' to: {pdf_path}")
            
    def process_peaks(self, height=None, prominence=None, distance=None, width=None, save_csv_path=None):
        """
        Process all recordings to detect action potentials (peaks) for each sweep
        and optionally save the results to a CSV file.

        Args:
            height (float, optional): Minimum height of peaks (APs) to detect.
            prominence (float, optional): Minimum prominence of peaks (APs) to detect.
            distance (float, optional): Minimum distance between consecutive peaks.
            width (float, optional): Minimum width of peaks.
            save_csv_path (str, optional): Path to save the processed peaks data as a CSV file.

        Returns:
            None: Stores the results in self.peak_dataframe and optionally saves it as a CSV.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to process. Run `process_data` first.")
            return

        peak_data = []

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]

            for _, row in group_data.iterrows():
                recording_id = row["Recording_ID"]
                label = row["Label"]
                file_path = row["File_Path"]

                # Load the ABF file
                abf = pyabf.ABF(file_path)

                for sweep_number in abf.sweepList:
                    abf.setSweep(sweepNumber=sweep_number)
                    voltage = abf.sweepY

                    # Detect peaks
                    peaks, _ = find_peaks(
                        voltage,
                        height=height,
                        prominence=prominence,
                        distance=distance,
                        width=width
                    )

                    # Append results to the list
                    peak_data.append({
                        "Group": group,
                        "Recording_ID": recording_id,
                        "Label": label,
                        "Sweep_Number": sweep_number,
                        "AP_Count": len(peaks)
                    })

        # Create a DataFrame from the results
        self.peak_dataframe = pd.DataFrame(peak_data)
        print(f"Processed peaks for {len(self.peak_dataframe)} sweeps.")

        # Save to CSV if path is provided
        if save_csv_path:
            self.peak_dataframe.to_csv(save_csv_path, index=False)
            print(f"Saved peak data to {save_csv_path}.")
                     
    def import_csv_and_plot_mean_peaks(self, csv_path, output_pdf_path):
        """
        Import a CSV file containing peak data and plot the mean AP counts
        for "Before" and "After" sweeps for each group.

        Args:
            csv_path (str): Path to the CSV file containing peak data.
            output_pdf_path (str): Path to save the output PDF.

        Returns:
            None
        """
        # Load the CSV into a DataFrame
        try:
            peak_data = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"File not found: {csv_path}")
            return

        # Validate the required columns
        required_columns = ["Group", "Recording_ID", "Label", "Sweep_Number", "AP_Count"]
        if not all(col in peak_data.columns for col in required_columns):
            print("The CSV file is missing one or more required columns.")
            return

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(output_pdf_path), exist_ok=True)

        # Open a PDF for plotting
        with PdfPages(output_pdf_path) as pdf:
            # Group by "Group" and compute means
            groups = peak_data["Group"].unique()
            for group in groups:
                group_data = peak_data[peak_data["Group"] == group]

                # Compute mean and SEM for Before and After
                before_data = group_data[group_data["Label"] == "Before"]
                after_data = group_data[group_data["Label"] == "After"]

                mean_before = before_data.groupby("Recording_ID")["AP_Count"].mean()
                mean_after = after_data.groupby("Recording_ID")["AP_Count"].mean()

                sem_before = before_data.groupby("Recording_ID")["AP_Count"].sem()
                sem_after = after_data.groupby("Recording_ID")["AP_Count"].sem()

                # Plot the data
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.bar(
                    x=["Before", "After"],
                    height=[mean_before.mean(), mean_after.mean()],
                    yerr=[sem_before.mean(), sem_after.mean()],
                    capsize=5,
                    color=["C0", "C1"],
                    alpha=0.7,
                    label=["Before", "After"]
                )
                ax.set_title(f"Group: {group}", fontsize=14)
                ax.set_ylabel("Mean AP Count (± SEM)")
                ax.set_xlabel("Condition")
                ax.legend()

                # Save the page to the PDF
                pdf.savefig(fig)
                plt.close(fig)

        print(f"Saved mean AP count plots to: {output_pdf_path}")
   
    def import_csv_and_plot_mean_peaks_lineplot(self, csv_path, output_pdf_path):
        """
        Import a CSV file containing peak data and plot the mean AP counts
        as a function of sweep number for each group.

        Args:
            csv_path (str): Path to the CSV file containing peak data.
            output_pdf_path (str): Path to save the output PDF.

        Returns:
            None
        """
        # Load the CSV into a DataFrame
        try:
            peak_data = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"File not found: {csv_path}")
            return

        # Validate the required columns
        required_columns = ["Group", "Recording_ID", "Label", "Sweep_Number", "AP_Count"]
        if not all(col in peak_data.columns for col in required_columns):
            print("The CSV file is missing one or more required columns.")
            return

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(output_pdf_path), exist_ok=True)

        # Open a PDF for plotting
        with PdfPages(output_pdf_path) as pdf:
            # Get unique groups
            groups = peak_data["Group"].unique()
            for group in groups:
                group_data = peak_data[peak_data["Group"] == group]

                # Compute mean and SEM for "Before" and "After" by Sweep_Number
                before_data = group_data[group_data["Label"] == "Before"]
                after_data = group_data[group_data["Label"] == "After"]

                mean_before = before_data.groupby("Sweep_Number")["AP_Count"].mean()
                sem_before = before_data.groupby("Sweep_Number")["AP_Count"].sem()

                mean_after = after_data.groupby("Sweep_Number")["AP_Count"].mean()
                sem_after = after_data.groupby("Sweep_Number")["AP_Count"].sem()

                # Plot the data
                fig, ax = plt.subplots(figsize=(10, 6))
                sweep_numbers = mean_before.index

                # Plot "Before" with SEM
                ax.plot(
                    sweep_numbers, mean_before,
                    label="Before",
                    color="C0",
                    linewidth=2
                )
                ax.fill_between(
                    sweep_numbers,
                    mean_before - sem_before,
                    mean_before + sem_before,
                    color="C0",
                    alpha=0.3
                )

                # Plot "After" with SEM
                ax.plot(
                    sweep_numbers, mean_after,
                    label="After",
                    color="C1",
                    linewidth=2
                )
                ax.fill_between(
                    sweep_numbers,
                    mean_after - sem_after,
                    mean_after + sem_after,
                    color="C1",
                    alpha=0.3
                )

                # Decorate the plot
                ax.set_title(f"Group: {group}", fontsize=14)
                ax.set_xlabel("Sweep Number")
                ax.set_ylabel("Mean AP Count (± SEM)")
                ax.legend()
                ax.grid(True)

                # Save the page to the PDF
                pdf.savefig(fig)
                plt.close(fig)

        print(f"Saved line plots for mean AP counts to: {output_pdf_path}")   

    def import_csv_and_plot_mean_peaks_with_error_bars(self, csv_path, output_pdf_path):
        """
        Import a CSV file containing peak data and plot the mean AP counts
        with SEM for "Before" and "After" at each sweep number as error bars.

        Args:
            csv_path (str): Path to the CSV file containing peak data.
            output_pdf_path (str): Path to save the output PDF.

        Returns:
            None
        """
        # Load the CSV into a DataFrame
        try:
            peak_data = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"File not found: {csv_path}")
            return

        # Validate the required columns
        required_columns = ["Group", "Recording_ID", "Label", "Sweep_Number", "AP_Count"]
        if not all(col in peak_data.columns for col in required_columns):
            print("The CSV file is missing one or more required columns.")
            return

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(output_pdf_path), exist_ok=True)

        # Open a PDF for plotting
        with PdfPages(output_pdf_path) as pdf:
            # Get unique groups
            groups = peak_data["Group"].unique()
            for group in groups:
                group_data = peak_data[peak_data["Group"] == group]

                # Compute mean and SEM for "Before" and "After" by Sweep_Number
                before_data = group_data[group_data["Label"] == "Before"]
                after_data = group_data[group_data["Label"] == "After"]

                mean_before = before_data.groupby("Sweep_Number")["AP_Count"].mean()
                sem_before = before_data.groupby("Sweep_Number")["AP_Count"].sem()

                mean_after = after_data.groupby("Sweep_Number")["AP_Count"].mean()
                sem_after = after_data.groupby("Sweep_Number")["AP_Count"].sem()

                # Plot the data
                fig, ax = plt.subplots(figsize=(10, 6))
                sweep_numbers = mean_before.index

                # Plot "Before" means and SEM as error bars
                ax.errorbar(
                    sweep_numbers - 0.2,  # Offset "Before" slightly to the left
                    mean_before,
                    yerr=sem_before,
                    fmt="o",  # Circle markers for "Before"
                    label="Before",
                    color="C0",
                    capsize=5,
                    markersize=8,
                )

                # Plot "After" means and SEM as error bars
                ax.errorbar(
                    sweep_numbers + 0.2,  # Offset "After" slightly to the right
                    mean_after,
                    yerr=sem_after,
                    fmt="o",  # Circle markers for "After"
                    label="After",
                    color="C1",
                    capsize=5,
                    markersize=8,
                )

                # Decorate the plot
                ax.set_title(f"Group: {group}", fontsize=14)
                ax.set_xlabel("Sweep Number")
                ax.set_ylabel("Mean AP Count (± SEM)")
                ax.legend()
                ax.grid(True)

                # Save the page to the PDF
                pdf.savefig(fig)
                plt.close(fig)

        print(f"Saved error bar plots for mean AP counts to: {output_pdf_path}")
        
    def process_peaks_in_window(self, height=None, prominence=None, distance=None, width=None, 
                                save_csv_path=None, start_time=None, end_time=None):
        """
        Process all recordings to detect action potentials (peaks) for each sweep within a specified time window,
        and optionally save the results to a CSV file.

        Args:
            height (float, optional): Minimum height of peaks (APs) to detect.
            prominence (float, optional): Minimum prominence of peaks (APs) to detect.
            distance (float, optional): Minimum distance between consecutive peaks.
            width (float, optional): Minimum width of peaks.
            save_csv_path (str, optional): Path to save the processed peaks data as a CSV file.
            start_time (float, optional): Start time (in seconds) of the window to analyze. Defaults to the beginning.
            end_time (float, optional): End time (in seconds) of the window to analyze. Defaults to the end.

        Returns:
            None: Stores the results in self.peak_window_dataframe and optionally saves it as a CSV.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to process. Run `process_data` first.")
            return

        peak_data = []

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]

            for _, row in group_data.iterrows():
                recording_id = row["Recording_ID"]
                label = row["Label"]
                file_path = row["File_Path"]

                # Load the ABF file
                abf = pyabf.ABF(file_path)

                for sweep_number in abf.sweepList:
                    abf.setSweep(sweepNumber=sweep_number)
                    voltage = abf.sweepY
                    time = abf.sweepX  # Time vector in seconds

                    # If time window is specified, extract the relevant portion
                    if start_time is not None and end_time is not None:
                        mask = (time >= start_time) & (time <= end_time)
                        voltage = voltage[mask]
                        time = time[mask]

                    # Detect peaks within the specified window
                    peaks, _ = find_peaks(
                        voltage,
                        height=height,
                        prominence=prominence,
                        distance=distance,
                        width=width
                    )

                    # Append results to the list
                    peak_data.append({
                        "Group": group,
                        "Recording_ID": recording_id,
                        "Label": label,
                        "Sweep_Number": sweep_number,
                        "Start_Time": start_time if start_time is not None else 0,
                        "End_Time": end_time if end_time is not None else time[-1],
                        "AP_Count": len(peaks)
                    })

        # Create a DataFrame from the results
        self.peak_window_dataframe = pd.DataFrame(peak_data)
        print(f"Processed peaks for {len(self.peak_window_dataframe)} sweeps.")

        # Save to CSV if path is provided
        if save_csv_path:
            self.peak_window_dataframe.to_csv(save_csv_path, index=False)
            print(f"Saved peak data to {save_csv_path}.")
               
    def process_peaks_by_phase(self, height=None, prominence=None, distance=None, width=None, 
                            early_start=None, early_end=None, late_start=None, late_end=None, 
                            save_csv_path=None):
        """
        Process recordings to detect action potentials (peaks) within early and late phases of each sweep,
        save indices of detected spikes, and optionally save the results to a CSV file.

        Args:
            height (float, optional): Minimum height of peaks (APs) to detect.
            prominence (float, optional): Minimum prominence of peaks (APs) to detect.
            distance (float, optional): Minimum distance between consecutive peaks.
            width (float, optional): Minimum width of peaks.
            early_start (float, optional): Start time (in seconds) for the early phase window.
            early_end (float, optional): End time (in seconds) for the early phase window.
            late_start (float, optional): Start time (in seconds) for the late phase window.
            late_end (float, optional): End time (in seconds) for the late phase window.
            save_csv_path (str, optional): Path to save the processed peaks data as a CSV file.

        Returns:
            None: Stores the results in self.phase_peak_dataframe and optionally saves it as a CSV.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to process. Run `process_data` first.")
            return

        phase_peak_data = []

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]

            for _, row in group_data.iterrows():
                recording_id = row["Recording_ID"]
                label = row["Label"]
                file_path = row["File_Path"]

                # Load the ABF file
                abf = pyabf.ABF(file_path)

                for sweep_number in abf.sweepList:
                    abf.setSweep(sweepNumber=sweep_number)
                    voltage = abf.sweepY
                    time = abf.sweepX  # Time vector in seconds

                    # Initialize variables for indices
                    early_indices = []
                    late_indices = []

                    # Process Early Phase
                    if early_start is not None and early_end is not None:
                        early_mask = (time >= early_start) & (time <= early_end)
                        early_voltage = voltage[early_mask]
                        early_peaks, _ = find_peaks(
                            early_voltage,
                            height=height,
                            prominence=prominence,
                            distance=distance,
                            width=width
                        )
                        early_ap_count = len(early_peaks)
                        early_indices = np.where(early_mask)[0][early_peaks]  # Map local to global indices
                    else:
                        early_ap_count = 0

                    # Process Late Phase
                    if late_start is not None and late_end is not None:
                        late_mask = (time >= late_start) & (time <= late_end)
                        late_voltage = voltage[late_mask]
                        late_peaks, _ = find_peaks(
                            late_voltage,
                            height=height,
                            prominence=prominence,
                            distance=distance,
                            width=width
                        )
                        late_ap_count = len(late_peaks)
                        late_indices = np.where(late_mask)[0][late_peaks]  # Map local to global indices
                    else:
                        late_ap_count = 0

                    # Append results to the list
                    phase_peak_data.append({
                        "Group": group,
                        "Recording_ID": recording_id,
                        "Label": label,
                        "Sweep_Number": sweep_number,
                        "Early_AP_Count": early_ap_count,
                        "Late_AP_Count": late_ap_count,
                        "Early_Indices": early_indices.tolist(),  # Save indices as a list
                        "Late_Indices": late_indices.tolist(),    # Save indices as a list
                        "Early_Start_Time": early_start if early_start is not None else 0,
                        "Early_End_Time": early_end if early_end is not None else 0,
                        "Late_Start_Time": late_start if late_start is not None else 0,
                        "Late_End_Time": late_end if late_end is not None else 0
                    })

        # Create a DataFrame from the results
        self.phase_peak_dataframe = pd.DataFrame(phase_peak_data)
        print(f"Processed early and late peaks for {len(self.phase_peak_dataframe)} sweeps.")

        # Save to CSV if path is provided
        if save_csv_path:
            self.phase_peak_dataframe.to_csv(save_csv_path, index=False)
            print(f"Saved phase peak data to {save_csv_path}.")
            
    def create_group_pdf_with_deltas_from_dataframe(self, output_dir):
        """
        Generate PDFs for each group, with each page showing ΔAP (Late-Early) across sweeps
        for a single recording, using data from phase_peak_dataframe.

        Args:
            output_dir (str): Directory to save the generated PDF files.
        """
        if self.phase_peak_dataframe is None or self.phase_peak_dataframe.empty:
            print("No data in `phase_peak_dataframe`. Run `process_peaks_by_phase` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        # Group the data by Group
        group_data = self.phase_peak_dataframe.groupby("Group")

        for group, group_df in group_data:
            pdf_path = os.path.join(output_dir, f"{group}_delta_ap.pdf")

            with PdfPages(pdf_path) as pdf:
                # Iterate over each recording within the group
                recordings = group_df.groupby("Recording_ID")

                for recording_id, recording_df in recordings:
                    # Separate Before and After conditions
                    before_data = recording_df[recording_df["Label"] == "Before"]
                    after_data = recording_df[recording_df["Label"] == "After"]

                    # Calculate ΔAP (Late - Early)
                    before_data["Delta_AP"] = before_data["Late_AP_Count"] - before_data["Early_AP_Count"]
                    after_data["Delta_AP"] = after_data["Late_AP_Count"] - after_data["Early_AP_Count"]

                    # Debugging Info
                    print(f"\nGroup: {group}, Recording ID: {recording_id}")
                    print("Before ΔAP:")
                    print(before_data[["Sweep_Number", "Early_AP_Count", "Late_AP_Count", "Delta_AP"]])
                    print("After ΔAP:")
                    print(after_data[["Sweep_Number", "Early_AP_Count", "Late_AP_Count", "Delta_AP"]])

                    # Create the plot
                    fig, ax = plt.subplots(figsize=(10, 6))
                    ax.plot(
                        before_data["Sweep_Number"],
                        before_data["Delta_AP"],
                        label="Before Luciferin",
                        marker="o",
                        color="blue"
                    )
                    ax.plot(
                        after_data["Sweep_Number"],
                        after_data["Delta_AP"],
                        label="After Luciferin",
                        marker="o",
                        color="orange"
                    )
                    ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
                    ax.set_xlabel("Sweep Number")
                    ax.set_ylabel("ΔAP (Late - Early)")
                    ax.set_title(f"ΔAP (Late - Early) for Recording ID: {recording_id}")
                    ax.legend()
                    ax.grid()

                    # Ensure y-axis scales dynamically
                    ax.autoscale(enable=True, axis='y', tight=True)

                    # Save the current page to the PDF
                    pdf.savefig(fig)
                    plt.close(fig)

            print(f"Saved ΔAP PDF for group '{group}' to: {pdf_path}")
            
    def create_group_pdf_with_early_vs_late_counts(self, output_dir):
        """
        Generate PDFs for each group, with each page showing Early vs. Late phase spike counts
        for a single recording, using data from phase_peak_dataframe.

        Args:
            output_dir (str): Directory to save the generated PDF files.
        """
        if self.phase_peak_dataframe is None or self.phase_peak_dataframe.empty:
            print("No data in `phase_peak_dataframe`. Run `process_peaks_by_phase` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        # Group the data by Group
        group_data = self.phase_peak_dataframe.groupby("Group")

        for group, group_df in group_data:
            pdf_path = os.path.join(output_dir, f"{group}_early_vs_late_counts.pdf")

            with PdfPages(pdf_path) as pdf:
                # Iterate over each recording within the group
                recordings = group_df.groupby("Recording_ID")

                for recording_id, recording_df in recordings:
                    # Separate Before and After data
                    before_data = recording_df[recording_df["Label"] == "Before"]
                    after_data = recording_df[recording_df["Label"] == "After"]

                    # Skip if no data for Before or After conditions
                    if before_data.empty or after_data.empty:
                        print(f"Skipping recording {recording_id} in group {group}: Missing 'Before' or 'After' data.")
                        continue

                    # Create the plot
                    fig, ax = plt.subplots(figsize=(12, 6))

                    bar_width = 0.35
                    indices = range(len(before_data))

                    # Ensure indices match the size of data
                    if len(before_data) != len(after_data):
                        print(f"Skipping recording {recording_id} in group {group}: Mismatched sweep counts between 'Before' and 'After'.")
                        continue

                    # Early phase counts
                    ax.bar(
                        [i - bar_width / 2 for i in indices],
                        before_data["Early_AP_Count"],
                        width=bar_width,
                        label="Before Early",
                        color="lightgrey",
                        alpha=0.8
                    )
                    ax.bar(
                        [i + bar_width / 2 for i in indices],
                        after_data["Early_AP_Count"],
                        width=bar_width,
                        label="After Early",
                        color="lightblue",
                        alpha=0.8
                    )

                    # Late phase counts
                    ax.bar(
                        [i - bar_width / 2 for i in indices],
                        before_data["Late_AP_Count"],
                        width=bar_width,
                        label="Before Late",
                        color="dimgrey",
                        alpha=0.8,
                        bottom=before_data["Early_AP_Count"]
                    )
                    ax.bar(
                        [i + bar_width / 2 for i in indices],
                        after_data["Late_AP_Count"],
                        width=bar_width,
                        label="After Late",
                        color="royalblue",
                        alpha=0.8,
                        bottom=after_data["Early_AP_Count"]
                    )

                    # Formatting
                    ax.set_xticks(indices)
                    ax.set_xticklabels(before_data["Sweep_Number"])
                    ax.set_xlabel("Sweep Number")
                    ax.set_ylabel("Spike Counts")
                    ax.set_title(f"Early vs. Late Phase Spike Counts\nGroup: {group}, Recording ID: {recording_id}")
                    ax.legend()
                    ax.grid(axis='y', linestyle="--", alpha=0.7)

                    # Save the current page to the PDF
                    pdf.savefig(fig)
                    plt.close(fig)

            print(f"Saved Early vs. Late Phase Spike Counts PDF for group '{group}' to: {pdf_path}")
            
    def create_group_pdf_with_early_to_late_ratios(self, output_dir):
        """
        Generate PDFs for each group, with each page showing Early-to-Late ratios
        for a single recording, using data from phase_peak_dataframe.

        Args:
            output_dir (str): Directory to save the generated PDF files.
        """
        if self.phase_peak_dataframe is None or self.phase_peak_dataframe.empty:
            print("No data in `phase_peak_dataframe`. Run `process_peaks_by_phase` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        # Group the data by Group
        group_data = self.phase_peak_dataframe.groupby("Group")

        for group, group_df in group_data:
            pdf_path = os.path.join(output_dir, f"{group}_early_to_late_ratios.pdf")

            with PdfPages(pdf_path) as pdf:
                # Iterate over each recording within the group
                recordings = group_df.groupby("Recording_ID")

                for recording_id, recording_df in recordings:
                    # Separate Before and After data
                    before_data = recording_df[recording_df["Label"] == "Before"].copy()
                    after_data = recording_df[recording_df["Label"] == "After"].copy()

                    # Calculate Early-to-Late ratios
                    before_data["Ratio"] = before_data["Late_AP_Count"] / before_data["Early_AP_Count"].replace(0, np.nan)
                    after_data["Ratio"] = after_data["Late_AP_Count"] / after_data["Early_AP_Count"].replace(0, np.nan)

                    # Skip if no valid data for either condition
                    if before_data["Ratio"].isna().all() and after_data["Ratio"].isna().all():
                        print(f"Skipping recording {recording_id} in group {group}: No valid ratios to plot.")
                        continue

                    # Create the plot
                    fig, ax = plt.subplots(figsize=(12, 6))

                    # Plot Before and After ratios
                    ax.plot(
                        before_data["Sweep_Number"],
                        before_data["Ratio"],
                        label="Before",
                        marker="o",
                        color="grey",
                        alpha=0.8
                    )
                    ax.plot(
                        after_data["Sweep_Number"],
                        after_data["Ratio"],
                        label="After",
                        marker="o",
                        color="blue",
                        alpha=0.8
                    )

                    # Formatting
                    ax.axhline(1, color="black", linestyle="--", linewidth=0.8, label="Ratio = 1")
                    ax.set_xticks(before_data["Sweep_Number"])
                    ax.set_xlabel("Sweep Number")
                    ax.set_ylabel("Early-to-Late Ratio")
                    ax.set_title(f"Early-to-Late Ratio\nGroup: {group}, Recording ID: {recording_id}")
                    ax.legend()
                    ax.grid(axis='y', linestyle="--", alpha=0.7)

                    # Save the current page to the PDF
                    pdf.savefig(fig)
                    plt.close(fig)

            print(f"Saved Early-to-Late Ratio PDF for group '{group}' to: {pdf_path}")
            
    def create_group_pdf_with_mean_and_sem(self, output_dir):
        """
        Generate a PDF for each group, with a single page showing the mean and SEM
        of ΔAP (Late - Early) across sweeps for the group.

        Args:
            output_dir (str): Directory to save the generated PDF files.
        """
        if self.phase_peak_dataframe is None or self.phase_peak_dataframe.empty:
            print("No data in `phase_peak_dataframe`. Run `process_peaks_by_phase` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        pdf_path = os.path.join(output_dir, "group_mean_and_sem_delta_ap.pdf")

        with PdfPages(pdf_path) as pdf:
            # Group the data by Group
            group_data = self.phase_peak_dataframe.groupby("Group")

            for group, group_df in group_data:
                # Calculate ΔAP (Late - Early)
                group_df["Delta_AP"] = group_df["Late_AP_Count"] - group_df["Early_AP_Count"]

                # Separate Before and After conditions
                before_data = group_df[group_df["Label"] == "Before"]
                after_data = group_df[group_df["Label"] == "After"]

                # Calculate mean and SEM for Before and After
                mean_before = before_data.groupby("Sweep_Number")["Delta_AP"].mean()
                sem_before = before_data.groupby("Sweep_Number")["Delta_AP"].sem()

                mean_after = after_data.groupby("Sweep_Number")["Delta_AP"].mean()
                sem_after = after_data.groupby("Sweep_Number")["Delta_AP"].sem()

                # Create the plot
                fig, ax = plt.subplots(figsize=(10, 6))

                # Plot mean and SEM for Before
                ax.errorbar(
                    mean_before.index,
                    mean_before,
                    yerr=sem_before,
                    label="Before Luciferin",
                    fmt="o-",
                    color="blue",
                    capsize=3
                )

                # Plot mean and SEM for After
                ax.errorbar(
                    mean_after.index,
                    mean_after,
                    yerr=sem_after,
                    label="After Luciferin",
                    fmt="o-",
                    color="orange",
                    capsize=3
                )

                # Formatting
                ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
                ax.set_xlabel("Sweep Number")
                ax.set_ylabel("Mean ΔAP (Late - Early) ± SEM")
                ax.set_title(f"Mean ΔAP (Late - Early) ± SEM for Group: {group}")
                ax.legend()
                ax.grid()

                # Save the page to the PDF
                pdf.savefig(fig)
                plt.close(fig)

            print(f"Saved group mean and SEM ΔAP PDF to: {pdf_path}")
            
    def create_group_pdf_with_mean_and_sem_and_store_data(self, output_dir):
        """
        Generate a PDF for each group, with a single page showing the mean and SEM
        of ΔAP (Late-Early) across sweeps for the group, and store data as an attribute.

        Args:
            output_dir (str): Directory to save the generated PDF files.
        """
        if self.phase_peak_dataframe is None or self.phase_peak_dataframe.empty:
            print("No data in `phase_peak_dataframe`. Run `process_peaks_by_phase` first.")
            return

        os.makedirs(output_dir, exist_ok=True)
        pdf_path = os.path.join(output_dir, "group_mean_and_sem_delta_ap.pdf")

        # Initialize a dictionary to store data
        self.group_delta_ap_stats = {}

        with PdfPages(pdf_path) as pdf:
            # Group the data by Group
            group_data = self.phase_peak_dataframe.groupby("Group")

            for group, group_df in group_data:
                # Calculate ΔAP (Late - Early)
                group_df["Delta_AP"] = group_df["Late_AP_Count"] - group_df["Early_AP_Count"]

                # Separate Before and After conditions
                before_data = group_df[group_df["Label"] == "Before"]
                after_data = group_df[group_df["Label"] == "After"]

                # Calculate mean and SEM for Before and After
                mean_before = before_data.groupby("Sweep_Number")["Delta_AP"].mean()
                sem_before = before_data.groupby("Sweep_Number")["Delta_AP"].sem()
                mean_after = after_data.groupby("Sweep_Number")["Delta_AP"].mean()
                sem_after = after_data.groupby("Sweep_Number")["Delta_AP"].sem()

                # Save raw data for stats
                self.group_delta_ap_stats[group] = {
                    "before": before_data[["Sweep_Number", "Delta_AP"]],
                    "after": after_data[["Sweep_Number", "Delta_AP"]],
                    "mean_before": mean_before,
                    "sem_before": sem_before,
                    "mean_after": mean_after,
                    "sem_after": sem_after,
                }

                # Create the plot
                fig, ax = plt.subplots(figsize=(10, 6))

                # Plot mean and SEM for Before
                ax.errorbar(
                    mean_before.index,
                    mean_before,
                    yerr=sem_before,
                    label="Before Luciferin",
                    fmt="o-",
                    color="blue",
                    capsize=3
                )

                # Plot mean and SEM for After
                ax.errorbar(
                    mean_after.index,
                    mean_after,
                    yerr=sem_after,
                    label="After Luciferin",
                    fmt="o-",
                    color="orange",
                    capsize=3
                )

                # Formatting
                ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
                ax.set_xlabel("Sweep Number")
                ax.set_ylabel("Mean ΔAP (Late - Early) ± SEM")
                ax.set_title(f"Mean ΔAP (Late - Early) ± SEM for Group: {group}")
                ax.legend()
                ax.grid()

                # Save the page to the PDF
                pdf.savefig(fig)
                plt.close(fig)

            print(f"Saved group mean and SEM ΔAP PDF to: {pdf_path}")

    def run_two_way_anova_with_correction(self):
        """
        Perform a two-way repeated-measures ANOVA for each group comparing Before and After
        conditions at each stimulus sweep, with Bonferroni correction for multiple comparisons.

        Returns:
            results (dict): Dictionary of ANOVA results with corrected p-values for each group.
        """
        if not hasattr(self, "group_delta_ap_stats"):
            print("No stored data. Run `create_group_pdf_with_mean_and_sem_and_store_data` first.")
            return

        results = {}

        for group, data in self.group_delta_ap_stats.items():
            # Combine before and after data into one DataFrame
            combined_data = pd.concat(
                [
                    data["before"].assign(Condition="Before"),
                    data["after"].assign(Condition="After")
                ]
            )

            # Aggregate data to ensure one observation per Sweep_Number and Condition
            combined_data = combined_data.groupby(["Sweep_Number", "Condition"])["Delta_AP"].mean().reset_index()

            # Perform repeated-measures ANOVA
            aovrm = AnovaRM(combined_data, depvar="Delta_AP", subject="Sweep_Number", within=["Condition"])
            anova_results = aovrm.fit()

            # Extract p-values and apply Bonferroni correction
            p_values = [anova_results.anova_table["Pr > F"].iloc[0]]  # Single condition comparison
            corrected_p_values = multipletests(p_values, method="bonferroni")[1]

            # Store results
            results[group] = {
                "anova_summary": anova_results.summary(),
                "corrected_p_values": corrected_p_values
            }

            # Print results for debugging
            print(f"\nGroup: {group}")
            print(anova_results)
            print(f"Corrected P-Values: {corrected_p_values}")

        return results, combined_data
        
    def create_group_pdf_with_peaks_complex(
            self,
            output_dir,
            height=None,
            prominence=None,
            distance=None,
            width=None,
            time_range=None,
            early_phase=(0, 0.5),
            late_phase=(0.5, 1.0)):
        """
        Generate a multi-row, multi-column PDF for each group with modular design.
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            pdf_path = os.path.join(output_dir, f"{group}_complex.pdf")

            with PdfPages(pdf_path) as pdf:
                recording_ids = group_data["Recording_ID"].unique()

                for recording_id in recording_ids:
                    before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                    after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                    if before_entry.empty or after_entry.empty:
                        print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                        continue

                    before_file = before_entry["File_Path"].iloc[0]
                    after_file = after_entry["File_Path"].iloc[0]
                    before_abf = pyabf.ABF(before_file)
                    after_abf = pyabf.ABF(after_file)

                    fig = plt.figure(figsize=(20, 24))
                    fig.suptitle(f"Group: {group}, Recording ID: {recording_id}", fontsize=16)
                    gs = gridspec.GridSpec(9, 4, figure=fig)
                    plt.subplots_adjust(wspace=0.25, hspace=0.35)

                    bar_data = {"sweep": [], "early_before": [], "late_before": [], "early_after": [], "late_after": []}
                    before_isi = []
                    after_isi = []

                    for sweep_number in range(min(9, len(before_abf.sweepList))):  # Limit to 9 rows
                        time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                            before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                        )
                        time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                            after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                        )

                        bar_data["sweep"].append(sweep_number + 1)
                        bar_data["early_before"].append(early_before)
                        bar_data["late_before"].append(late_before)
                        bar_data["early_after"].append(early_after)
                        bar_data["late_after"].append(late_after)

                        before_isi.extend(calculate_isi_distribution(peaks_before, time_before))
                        after_isi.extend(calculate_isi_distribution(peaks_after, time_after))

                        ax_before = fig.add_subplot(gs[sweep_number, 0])
                        plot_peaks(ax_before, time_before, voltage_before, peaks_before, "Before", "C0", f"Sweep {sweep_number + 1} (Before)", ylabel="Voltage (mV)")

                        ax_after = fig.add_subplot(gs[sweep_number, 1])
                        plot_peaks(ax_after, time_after, voltage_after, peaks_after, "After", "C1", f"Sweep {sweep_number + 1} (After)")

                    ax_bar = fig.add_subplot(gs[0:2, 2:])
                    generate_wide_plot_1(ax_bar, bar_data)
                    
                    ax_plot_2 = fig.add_subplot(gs[2:4, 2:])
                    generate_wide_plot_2(ax_plot_2, bar_data)

                    ax_plot_3 = fig.add_subplot(gs[4:6, 2:])
                    generate_wide_plot_3(ax_plot_3, before_isi, after_isi)
                    
                    ## Add more plots here to fill the remaining space
                    # Calculate and plot fractions of early spikes
                    fractions = calculate_fractions(bar_data)
                    ax_fraction = fig.add_subplot(gs[6:8, 2:])
                    generate_wide_plot_fraction(ax_fraction, fractions)
                    
                    ### add the ECDF plot 
                    ax_ecdf = fig.add_subplot(gs[8:9, 2:])
                    generate_ecdf_plot(ax_ecdf, fractions)

                    pdf.savefig(fig)
                    plt.close(fig)

                print(f"Saved complex PDF for group '{group}' to: {pdf_path}")

    def create_group_pooled_ecdf(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        step_mode="post"
    ):
        """
        Create a pooled ECDF plot for each group by combining fraction differences (early_fraction - late_fraction)
        from all cells and sweeps. A single PDF per group is saved containing just the ECDF plot.

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting PDFs.
        height : float or None
            Peak detection parameter passed to process_sweep_data.
        prominence : float or None
            Peak detection parameter passed to process_sweep_data.
        distance : float or None
            Peak detection parameter passed to process_sweep_data.
        width : float or None
            Peak detection parameter passed to process_sweep_data.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction of the sweep.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction of the sweep.
        x_min : float
            Minimum x-value for plotting ECDF. Default is -1.
        x_max : float
            Maximum x-value for plotting ECDF. Default is 1.
        step_mode : str
            How the step plot is drawn. Default "post".
            Options: "pre", "post", "mid".
        """
        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        os.makedirs(output_dir, exist_ok=True)

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            # Collect all fraction differences across all recordings of this group
            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep and accumulate early/late spike counts
                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                # Compute fractions for this recording
                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                # Filter out NaNs
                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                # Append to pooled lists
                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

            # Now plot the pooled ECDF for this group
            pdf_path = os.path.join(output_dir, f"{group}_pooled_ecdf.pdf")

            with PdfPages(pdf_path) as pdf:
                fig, ax = plt.subplots(figsize=(8, 6))

                # Sort the data
                diff_before_sorted = np.sort(pooled_diff_before) if len(pooled_diff_before) > 0 else []
                diff_after_sorted = np.sort(pooled_diff_after) if len(pooled_diff_after) > 0 else []

                def ecdf(data):
                    if len(data) == 0:
                        # No data: just plot a flat line at zero
                        x = [x_min, x_max]
                        y = [0.0, 0.0]
                    else:
                        # Compute normal ECDF
                        y = np.arange(1, len(data) + 1) / len(data)
                        # Add starting point (x_min,0.0) and ending point (x_max,1.0)
                        x = np.concatenate([[x_min], data, [x_max]])
                        y = np.concatenate([[0.0], y, [1.0]])
                    return x, y

                x_before, y_before = ecdf(diff_before_sorted)
                x_after, y_after = ecdf(diff_after_sorted)

                # Plot ECDFs without axis lines
                ax.step(x_before, y_before, color="gray", label="Before", where=step_mode)
                ax.step(x_after, y_after, color="blue", label="After", where=step_mode)

                # Set axis limits
                ax.set_xlim(x_min, x_max)
                ax.set_ylim(0, 1.0)

                # Remove axis lines (spines)
                for spine in ax.spines.values():
                    spine.set_visible(False)

                ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
                ax.set_ylabel("Cumulative Probability", fontsize=12)
                ax.set_title("Pooled ECDF of Fraction Differences (Before vs. After)", fontsize=14)
                ax.legend()

                pdf.savefig(fig)
                plt.close(fig)

            print(f"Saved pooled ECDF PDF for group '{group}' to: {pdf_path}")

    def create_group_pooled_ecdf_svg(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        step_mode="post",
        fig_width=8,
        fig_height=6,
        remove_spines=True,
        transparent=True
    ):
        """
        Create a pooled ECDF plot for each group by combining fraction differences (early_fraction - late_fraction)
        from all cells and sweeps. A single .svg per group is saved containing the ECDF plot.

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting SVGs.
        height : float or None
            Peak detection parameter passed to process_sweep_data.
        prominence : float or None
            Peak detection parameter passed to process_sweep_data.
        distance : float or None
            Peak detection parameter passed to process_sweep_data.
        width : float or None
            Peak detection parameter passed to process_sweep_data.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction of the sweep.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction of the sweep.
        x_min : float
            Minimum x-value for plotting ECDF.
        x_max : float
            Maximum x-value for plotting ECDF.
        step_mode : str
            How the step plot is drawn. Default "post".
            Options: "pre", "post", "mid".
        fig_width : float
            Width of the figure in inches.
        fig_height : float
            Height of the figure in inches.
        remove_spines : bool
            If True, removes axis spines for a cleaner look.
        transparent : bool
            If True, sets the plot background to transparent.
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"

        os.makedirs(output_dir, exist_ok=True)

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            # Collect all fraction differences across all recordings of this group
            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep and accumulate early/late spike counts
                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                # Compute fractions for this recording
                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                # Filter out NaNs
                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                # Append to pooled lists
                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

            # Now plot the pooled ECDF for this group
            svg_path = os.path.join(output_dir, f"{group}_pooled_ecdf.svg")

            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            # Set transparent background if requested
            if transparent:
                fig.patch.set_facecolor('none')
                fig.patch.set_alpha(0)
                ax.set_facecolor('none')

            # Sort the data
            diff_before_sorted = np.sort(pooled_diff_before) if len(pooled_diff_before) > 0 else []
            diff_after_sorted = np.sort(pooled_diff_after) if len(pooled_diff_after) > 0 else []

            def ecdf(data):
                if len(data) == 0:
                    # No data: just plot a flat line at zero
                    x = [x_min, x_max]
                    y = [0.0, 0.0]
                else:
                    # Compute normal ECDF
                    y = np.arange(1, len(data) + 1) / len(data)
                    # Add starting point (x_min,0.0) and ending point (x_max,1.0)
                    x = np.concatenate([[x_min], data, [x_max]])
                    y = np.concatenate([[0.0], y, [1.0]])
                return x, y

            x_before, y_before = ecdf(diff_before_sorted)
            x_after, y_after = ecdf(diff_after_sorted)

            # Plot ECDFs
            ax.step(x_before, y_before, color="gray", label="Before", where=step_mode)
            ax.step(x_after, y_after, color="blue", label="After", where=step_mode)

            # Set axis limits
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(0, 1.0)

            # Remove spines if requested
            if remove_spines:
                for spine in ax.spines.values():
                    spine.set_visible(False)

            ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
            ax.set_ylabel("Cumulative Probability", fontsize=12)
            ax.set_title("Pooled ECDF of Fraction Differences (Before vs. After)", fontsize=14)
            ax.legend()

            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=transparent)
            plt.close(fig)

            print(f"Saved pooled ECDF SVG for group '{group}' to: {svg_path}")


    def export_all_groups_to_svgs(self, output_dir,
                                before_label="Before", after_label="After",
                                sweep_numbers=None, startAtSec=0, endAtSec=1.5,
                                offsetXsec=0.3, offsetYunits=40,
                                color_before=None, color_after="red",
                                alpha=0.5, linewidth=1,
                                dpi=300, add_suptitle=True,
                                scaleXms=200, scaleYmV=50):
        """
        Export .svg files for each recording in each group.
        Each group will have its own folder, and inside it each recording
        will have its own folder. One .svg file per recording.

        Args:
            output_dir (str): Base directory to save all the .svg files.
            before_label (str): Label for the "Before" condition.
            after_label (str): Label for the "After" condition.
            sweep_numbers (list or None): Sweeps to plot. None = all sweeps.
            startAtSec (float): Start time for plotting.
            endAtSec (float): End time for plotting.
            offsetXsec (float): Horizontal offset per sweep.
            offsetYunits (float): Vertical offset per sweep.
            color_before (str or None): Color for "Before" sweeps.
            color_after (str): Color for "After" sweeps.
            alpha (float): Transparency of lines.
            linewidth (float): Width of the plotted lines.
            dpi (int): DPI for potential rasterization (mainly for embedded images).
            add_suptitle (bool): Whether to add a suptitle with Group and Recording ID.
            scaleXms (float): Horizontal scale bar length in ms.
            scaleYmV (float): Vertical scale bar height in mV.
        """
        # Ensure text remains editable (not converted to paths)
        plt.rcParams["svg.fonttype"] = "none"
        
        os.makedirs(output_dir, exist_ok=True)

        for grp in self.unique_groups:
            group_dir = os.path.join(output_dir, grp)
            os.makedirs(group_dir, exist_ok=True)

            recording_ids = self.get_recording_ids(grp)
            for recording_id in recording_ids:
                before_entry = self.dataframe[(self.dataframe["Group"] == grp) &
                                            (self.dataframe["Recording_ID"] == recording_id) &
                                            (self.dataframe["Label"] == before_label)]
                after_entry = self.dataframe[(self.dataframe["Group"] == grp) &
                                            (self.dataframe["Recording_ID"] == recording_id) &
                                            (self.dataframe["Label"] == after_label)]

                # Skip if we don't have both Before and After
                if before_entry.empty or after_entry.empty:
                    continue

                # Create a folder for each recording
                recording_dir = os.path.join(group_dir, recording_id)
                os.makedirs(recording_dir, exist_ok=True)

                # Create the figure
                fig, axes = plt.subplots(2, 1, figsize=(4, 10), sharex=True)

                # Plot Before
                self.plot_sweeps(axes[0], grp, recording_id, before_label,
                                sweep_numbers=sweep_numbers,
                                offsetXsec=offsetXsec, offsetYunits=offsetYunits,
                                startAtSec=startAtSec, endAtSec=endAtSec,
                                color=color_before, alpha=alpha, linewidth=linewidth,
                                hideAxis=True)

                # Plot After
                self.plot_sweeps(axes[1], grp, recording_id, after_label,
                                sweep_numbers=sweep_numbers,
                                offsetXsec=offsetXsec, offsetYunits=offsetYunits,
                                startAtSec=startAtSec, endAtSec=endAtSec,
                                color=color_after, alpha=alpha, linewidth=linewidth,
                                hideAxis=True)

                if add_suptitle:
                    fig.suptitle(f"Group: {grp}, Recording: {recording_id}", fontsize=14)

                plt.tight_layout()

                # Add a fixed scalebar to the bottom axis
                self.plot_scalebar(axes[1], scaleXms=scaleXms, scaleYmV=scaleYmV,
                                hideTicks=True, hideFrame=True)

                # Save as SVG
                svg_path = os.path.join(recording_dir, f"{recording_id}.svg")
                fig.savefig(svg_path, format="svg", dpi=dpi, bbox_inches='tight')
                plt.close(fig)

    def import_csv_and_plot_mean_peaks_with_error_bars_svg(self, csv_path, output_dir,
                                                        fig_width=10, fig_height=6,
                                                        ymin=None, ymax=None):
        """
        Import a CSV file containing peak data and plot the mean AP counts
        with SEM for "Before" (grey) and "After" (blue) at each sweep number as error bars.
        Save as .svg files, one per group, with a transparent background and customizable figure size.
        Allows specifying y-axis limits to keep the same scale across plots.

        Args:
            csv_path (str): Path to the CSV file containing peak data.
            output_dir (str): Directory to save the output .svg files.
            fig_width (float): Width of the figure in inches.
            fig_height (float): Height of the figure in inches.
            ymin (float or None): Minimum y-limit. If None, it will auto-scale.
            ymax (float or None): Maximum y-limit. If None, it will auto-scale.

        Returns:
            None
        """
        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"

        # Load the CSV into a DataFrame
        try:
            peak_data = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"File not found: {csv_path}")
            return

        # Validate the required columns
        required_columns = ["Group", "Recording_ID", "Label", "Sweep_Number", "AP_Count"]
        if not all(col in peak_data.columns for col in required_columns):
            print("The CSV file is missing one or more required columns.")
            return

        # Ensure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Get unique groups
        groups = peak_data["Group"].unique()
        for group in groups:
            group_data = peak_data[peak_data["Group"] == group]

            # Compute mean and SEM for "Before" and "After" by Sweep_Number
            before_data = group_data[group_data["Label"] == "Before"]
            after_data = group_data[group_data["Label"] == "After"]

            mean_before = before_data.groupby("Sweep_Number")["AP_Count"].mean()
            sem_before = before_data.groupby("Sweep_Number")["AP_Count"].sem()

            mean_after = after_data.groupby("Sweep_Number")["AP_Count"].mean()
            sem_after = after_data.groupby("Sweep_Number")["AP_Count"].sem()

            # Plot the data with the specified figure size
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))
            
            # Set figure and axis backgrounds to transparent
            fig.patch.set_facecolor('none')
            fig.patch.set_alpha(0)
            ax.set_facecolor('none')

            sweep_numbers = mean_before.index

            # Plot "Before" means and SEM as error bars (grey)
            ax.errorbar(
                sweep_numbers - 0.2,  # Offset "Before" slightly to the left
                mean_before,
                yerr=sem_before,
                fmt="o",  # Circle markers
                label="Before",
                color="gray",
                capsize=5,
                markersize=8,
            )

            # Plot "After" means and SEM as error bars (blue)
            ax.errorbar(
                sweep_numbers + 0.2,  # Offset "After" slightly to the right
                mean_after,
                yerr=sem_after,
                fmt="o",  # Circle markers
                label="After",
                color="blue",
                capsize=5,
                markersize=8,
            )

            # Remove grid lines, keep axis lines
            ax.grid(False)

            # Set y-limits if provided
            if ymin is not None or ymax is not None:
                ax.set_ylim(ymin, ymax)

            # Add labels and legend
            ax.set_xlabel("Sweep Number")
            ax.set_ylabel("Mean AP Count (± SEM)")
            ax.legend()

            # Create a subdirectory for the group
            group_dir = os.path.join(output_dir, group)
            os.makedirs(group_dir, exist_ok=True)

            svg_path = os.path.join(group_dir, f"{group}_mean_peaks.svg")
            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=True)
            plt.close(fig)

        print(f"Saved error bar plots as transparent SVG files in: {output_dir}")
  
    def create_group_pooled_mean_and_individual_traces_ecdf(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        step_mode="post",
        fig_width=8,
        fig_height=6,
        remove_spines=True,
        transparent=True,
        individual_line_alpha=0.3,
        individual_line_style=":",
        individual_line_width=1
    ):
        """
        Similar to create_group_pooled_ecdf, but also plots individual ECDF traces for each recording 
        that contributed to the pooled ECDF. The pooled ECDF lines ("Before" and "After") are plotted 
        on top, and individual recordings' ECDF lines are shown underneath with thinner, dotted lines 
        and higher transparency.

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting SVGs.
        height : float or None
            Peak detection parameter passed to process_sweep_data.
        prominence : float or None
            Peak detection parameter passed to process_sweep_data.
        distance : float or None
            Peak detection parameter passed to process_sweep_data.
        width : float or None
            Peak detection parameter passed to process_sweep_data.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction of the sweep.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction of the sweep.
        x_min : float
            Minimum x-value for plotting ECDF.
        x_max : float
            Maximum x-value for plotting ECDF.
        step_mode : str
            How the step plot is drawn. Default "post".
            Options: "pre", "post", "mid".
        fig_width : float
            Width of the figure in inches.
        fig_height : float
            Height of the figure in inches.
        remove_spines : bool
            If True, removes axis spines for a cleaner look.
        transparent : bool
            If True, sets the plot background to transparent.
        individual_line_alpha : float
            Transparency for the individual ECDF lines.
        individual_line_style : str
            Line style for individual ECDF lines (e.g., ":", "--", "-.").
        individual_line_width : float
            Line width for individual ECDF lines.
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"
        os.makedirs(output_dir, exist_ok=True)

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            # Lists to hold individual recordings' differences for separate plotting
            individual_diffs_before = []
            individual_diffs_after = []

            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep and accumulate early/late spike counts
                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                # Compute fractions for this recording
                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                # Filter out NaNs
                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                # Append to pooled lists
                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

                # Store individual differences for later plotting
                individual_diffs_before.append(diff_before)
                individual_diffs_after.append(diff_after)

            # Now plot the pooled ECDF for this group along with individual traces
            svg_path = os.path.join(output_dir, f"{group}_pooled_ecdf_individual.svg")
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            if transparent:
                fig.patch.set_facecolor('none')
                fig.patch.set_alpha(0)
                ax.set_facecolor('none')

            def ecdf(data):
                if len(data) == 0:
                    # No data: flat line at zero
                    x = [x_min, x_max]
                    y = [0.0, 0.0]
                else:
                    # Compute ECDF
                    data_sorted = np.sort(data)
                    y = np.arange(1, len(data_sorted) + 1) / len(data_sorted)
                    x = np.concatenate([[x_min], data_sorted, [x_max]])
                    y = np.concatenate([[0.0], y, [1.0]])
                return x, y

            # Plot individual traces first (Before)
            for diffs in individual_diffs_before:
                x_i, y_i = ecdf(diffs)
                ax.step(x_i, y_i, where=step_mode,
                        color="gray", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Plot individual traces for After
            for diffs in individual_diffs_after:
                x_i, y_i = ecdf(diffs)
                ax.step(x_i, y_i, where=step_mode,
                        color="blue", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Now plot pooled ECDF lines on top
            x_before, y_before = ecdf(pooled_diff_before)
            x_after, y_after = ecdf(pooled_diff_after)

            ax.step(x_before, y_before, color="gray", label="Before (Pooled)", where=step_mode, linewidth=2)
            ax.step(x_after, y_after, color="blue", label="After (Pooled)", where=step_mode, linewidth=2)

            # Set axis limits
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(0, 1.0)

            if remove_spines:
                for spine in ax.spines.values():
                    spine.set_visible(False)

            ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
            ax.set_ylabel("Cumulative Probability", fontsize=12)
            ax.set_title("Pooled ECDF with Individual Traces (Before vs. After)", fontsize=14)
            ax.legend()

            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=transparent)
            plt.close(fig)

            print(f"Saved pooled ECDF with individual traces for group '{group}' to: {svg_path}")
  
    def create_group_pooled_mean_and_individual_sigmoid_ecdf(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        fig_width=8,
        fig_height=6,
        remove_spines=True,
        transparent=True,
        individual_line_alpha=0.3,
        individual_line_style=":",
        individual_line_width=1,
        main_line_width=2
    ):
        """
        Similar to create_group_pooled_mean_and_individual_traces_ecdf, but instead of plotting a step ECDF, 
        we fit and plot a sigmoid (logistic) curve to both the individual and pooled distributions.

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting SVGs.
        height, prominence, distance, width : Peak detection parameters.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction.
        x_min, x_max : float
            Limits for the x-axis.
        fig_width, fig_height : float
            Dimensions of the figure.
        remove_spines : bool
            If True, remove axis spines.
        transparent : bool
            If True, make background transparent.
        individual_line_alpha : float
            Transparency for individual ECDF lines.
        individual_line_style : str
            Line style for individual ECDF lines.
        individual_line_width : float
            Line width for individual ECDF lines.
        main_line_width : float
            Line width for the pooled mean sigmoid lines.

        The function calculates fraction differences for Before and After for each recording, 
        fits a logistic curve to their cumulative distribution, and plots them along with the pooled distribution.
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"
        os.makedirs(output_dir, exist_ok=True)

        # Logistic (sigmoid) function
        def logistic(x, x0, k):
            return 1.0 / (1.0 + np.exp(-(x - x0) / k))

        def fit_sigmoid(x_data, y_data):
            # Initial guesses for x0 and k
            # x0 ~ median of data, k based on rough guess
            x0_guess = np.median(x_data)
            k_guess = (x_max - x_min) / 10
            try:
                popt, _ = curve_fit(logistic, x_data, y_data, p0=[x0_guess, k_guess])
                return popt
            except:
                # If fitting fails, return default parameters
                return (x0_guess, k_guess)

        # Function to compute ECDF points (x,y)
        # Then we fit a logistic curve to these points
        def compute_and_fit(data):
            if len(data) == 0:
                # No data, return a flat line
                xs = np.linspace(x_min, x_max, 200)
                ys = np.zeros_like(xs)
                return xs, ys

            # Sort data
            data_sorted = np.sort(data)
            y = np.arange(1, len(data_sorted) + 1) / len(data_sorted)
            x_full = np.concatenate([[x_min], data_sorted, [x_max]])
            y_full = np.concatenate([[0.0], y, [1.0]])

            # Fit logistic curve to these ECDF points
            # Use the actual data points (excluding artificial endpoints) for better fitting
            popt = fit_sigmoid(data_sorted, y)

            # Generate a smooth x array and compute the fitted logistic
            xs_smooth = np.linspace(x_min, x_max, 200)
            ys_smooth = logistic(xs_smooth, *popt)
            return xs_smooth, ys_smooth

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            # Lists to hold individual recordings' differences
            individual_diffs_before = []
            individual_diffs_after = []

            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep
                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                # Filter out NaNs
                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                # Append to pooled lists
                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

                # Store individual differences
                individual_diffs_before.append(diff_before)
                individual_diffs_after.append(diff_after)

            # Now plot the pooled ECDF for this group along with individual traces as sigmoid
            svg_path = os.path.join(output_dir, f"{group}_pooled_sigmoid_ecdf_individual.svg")
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            if transparent:
                fig.patch.set_facecolor('none')
                fig.patch.set_alpha(0)
                ax.set_facecolor('none')

            # Plot individual traces (Before)
            for diffs in individual_diffs_before:
                xs, ys = compute_and_fit(diffs)
                ax.plot(xs, ys,
                        color="gray", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Plot individual traces (After)
            for diffs in individual_diffs_after:
                xs, ys = compute_and_fit(diffs)
                ax.plot(xs, ys,
                        color="blue", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Pooled data
            xs_before, ys_before = compute_and_fit(pooled_diff_before)
            xs_after, ys_after = compute_and_fit(pooled_diff_after)

            # Plot pooled lines on top
            ax.plot(xs_before, ys_before, color="gray", label="Before (Pooled)", linewidth=main_line_width)
            ax.plot(xs_after, ys_after, color="blue", label="After (Pooled)", linewidth=main_line_width)

            # Set axis limits
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(0, 1.0)

            if remove_spines:
                for spine in ax.spines.values():
                    spine.set_visible(False)

            ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
            ax.set_ylabel("Cumulative Probability", fontsize=12)
            ax.set_title("Pooled Sigmoid ECDF with Individual Traces (Before vs. After)", fontsize=14)
            ax.legend()

            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=transparent)
            plt.close(fig)

            print(f"Saved pooled sigmoid ECDF with individual traces for group '{group}' to: {svg_path}")
 
    def create_group_pooled_mean_and_individual_sigmoid_ecdf(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        fig_width=8,
        fig_height=6,
        remove_spines=True,
        transparent=True,
        individual_line_alpha=0.3,
        individual_line_style=":",
        individual_line_width=1,
        main_line_width=2,
        smoothing_factor=1.0
    ):
        """
        Create a pooled ECDF plot with individual traces as sigmoid (logistic) curves.
        The smoothing_factor parameter scales the fitted 'k' parameter of the logistic,
        allowing control over the steepness/gradualness of the sigmoid curves.

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting SVGs.
        height, prominence, distance, width : Peak detection parameters.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction.
        x_min, x_max : float
            Limits for the x-axis.
        fig_width, fig_height : float
            Dimensions of the figure.
        remove_spines : bool
            If True, remove axis spines.
        transparent : bool
            If True, make background transparent.
        individual_line_alpha : float
            Transparency for individual ECDF lines.
        individual_line_style : str
            Line style for individual ECDF lines.
        individual_line_width : float
            Line width for individual ECDF lines.
        main_line_width : float
            Line width for the pooled mean sigmoid lines.
        smoothing_factor : float
            Factor by which to multiply the fitted 'k' parameter to adjust smoothing.
            >1 makes the curve more gradual (smoother), <1 makes it steeper.
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"
        os.makedirs(output_dir, exist_ok=True)

        # Logistic (sigmoid) function
        def logistic(x, x0, k):
            return 1.0 / (1.0 + np.exp(-(x - x0) / k))

        def fit_sigmoid(x_data, y_data):
            x0_guess = np.median(x_data)
            k_guess = (x_max - x_min) / 10
            try:
                popt, _ = curve_fit(logistic, x_data, y_data, p0=[x0_guess, k_guess])
                # Adjust k by smoothing_factor
                popt = (popt[0], popt[1] * smoothing_factor)
                return popt
            except:
                # If fitting fails, return defaults adjusted by smoothing_factor
                return (x0_guess, k_guess * smoothing_factor)

        def compute_and_fit(data):
            if len(data) == 0:
                # No data, return a flat line
                xs = np.linspace(x_min, x_max, 200)
                ys = np.zeros_like(xs)
                return xs, ys

            # Sort data
            data_sorted = np.sort(data)
            y = np.arange(1, len(data_sorted) + 1) / len(data_sorted)

            # Fit logistic curve to actual data points (no artificial endpoints for fitting)
            popt = fit_sigmoid(data_sorted, y)

            # Generate a smooth x array and compute the fitted logistic
            xs_smooth = np.linspace(x_min, x_max, 200)
            ys_smooth = logistic(xs_smooth, *popt)
            return xs_smooth, ys_smooth

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            individual_diffs_before = []
            individual_diffs_after = []

            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

                individual_diffs_before.append(diff_before)
                individual_diffs_after.append(diff_after)

            svg_path = os.path.join(output_dir, f"{group}_pooled_sigmoid_ecdf_individual.svg")
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            if transparent:
                fig.patch.set_facecolor('none')
                fig.patch.set_alpha(0)
                ax.set_facecolor('none')

            # Plot individual traces (Before)
            for diffs in individual_diffs_before:
                xs, ys = compute_and_fit(diffs)
                ax.plot(xs, ys, color="gray", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Plot individual traces (After)
            for diffs in individual_diffs_after:
                xs, ys = compute_and_fit(diffs)
                ax.plot(xs, ys, color="blue", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Pooled data
            xs_before, ys_before = compute_and_fit(pooled_diff_before)
            xs_after, ys_after = compute_and_fit(pooled_diff_after)

            # Plot pooled lines on top
            ax.plot(xs_before, ys_before, color="gray", label="Before (Pooled)", linewidth=main_line_width)
            ax.plot(xs_after, ys_after, color="blue", label="After (Pooled)", linewidth=main_line_width)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(0, 1.0)

            if remove_spines:
                for spine in ax.spines.values():
                    spine.set_visible(False)

            ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
            ax.set_ylabel("Cumulative Probability", fontsize=12)
            ax.set_title("Pooled Sigmoid ECDF with Individual Traces (Before vs. After)", fontsize=14)
            ax.legend()

            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=transparent)
            plt.close(fig)

            print(f"Saved pooled sigmoid ECDF with individual traces for group '{group}' to: {svg_path}") 
                     
    def create_group_pooled_mean_and_individual_smoothed_ecdf(
        self,
        output_dir,
        height=None,
        prominence=None,
        distance=None,
        width=None,
        time_range=None,
        early_phase=(0.1, 0.40),
        late_phase=(0.5, 1.12),
        x_min=-1.0,
        x_max=1.0,
        fig_width=8,
        fig_height=6,
        remove_spines=True,
        transparent=True,
        individual_line_alpha=0.3,
        individual_line_style=":",
        individual_line_width=1,
        main_line_width=2,
        smoothing_window=5
    ):
        """
        Create a pooled ECDF plot with individual traces, applying a simple moving average 
        smoothing directly onto the ECDF (no sigmoid, no external filters).

        Parameters
        ----------
        output_dir : str
            Directory to save the resulting SVGs.
        height, prominence, distance, width : Peak detection parameters.
        time_range : tuple or None
            Time range for analyzing data. If None, use full sweep.
        early_phase : tuple (start, end)
            Time window defining the early phase fraction.
        late_phase : tuple (start, end)
            Time window defining the late phase fraction.
        x_min : float
            Minimum x-value for plotting ECDF.
        x_max : float
            Maximum x-value for plotting ECDF.
        fig_width, fig_height : float
            Dimensions of the figure in inches.
        remove_spines : bool
            If True, remove axis spines.
        transparent : bool
            If True, make background transparent.
        individual_line_alpha : float
            Transparency for individual ECDF lines.
        individual_line_style : str
            Line style for individual ECDF lines.
        individual_line_width : float
            Line width for individual ECDF lines.
        main_line_width : float
            Line width for the pooled mean lines.
        smoothing_window : int
            Size of the moving average window for smoothing the ECDF.
            Must be an odd number greater than 1 for best results.

        The method calculates ECDF points for each dataset and then applies a 
        simple moving average to the ECDF y-values to produce a smoother curve.
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to plot. Run `process_data` first.")
            return

        # Ensure text remains editable in SVG
        plt.rcParams["svg.fonttype"] = "none"
        os.makedirs(output_dir, exist_ok=True)

        def compute_ecdf(data):
            if len(data) == 0:
                # No data: flat line at zero
                x = np.linspace(x_min, x_max, 200)
                y = np.zeros_like(x)
                return x, y
            data_sorted = np.sort(data)
            y = np.arange(1, len(data_sorted) + 1) / len(data_sorted)
            # Add boundaries at x_min and x_max
            x_full = np.concatenate([[x_min], data_sorted, [x_max]])
            y_full = np.concatenate([[0.0], y, [1.0]])
            return x_full, y_full

        def moving_average_smooth(y, window_size):
            # Simple moving average:
            # For each point, take the mean of points around it defined by the window.
            # We'll pad the ends so we can average uniformly.
            half_win = window_size // 2
            padded_y = np.pad(y, (half_win, half_win), mode='edge')
            y_smooth = np.empty_like(y)
            for i in range(len(y)):
                y_smooth[i] = np.mean(padded_y[i:i+window_size])
            return y_smooth

        def smooth_ecdf(x, y, window_size):
            # Ensure window_size is valid
            if window_size < 3 or window_size > len(x) or window_size % 2 == 0:
                # If invalid, return original
                return x, y
            y_smooth = moving_average_smooth(y, window_size)
            return x, y_smooth

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            individual_diffs_before = []
            individual_diffs_after = []

            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    print(f"Skipping incomplete pair for Recording ID: {recording_id} in group {group}")
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]
                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep
                for sweep_number in range(len(before_abf.sweepList)):
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                    )

                    bar_data["sweep"].append(sweep_number + 1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

                individual_diffs_before.append(diff_before)
                individual_diffs_after.append(diff_after)

            svg_path = os.path.join(output_dir, f"{group}_pooled_smoothed_ecdf_individual.svg")
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))

            if transparent:
                fig.patch.set_facecolor('none')
                fig.patch.set_alpha(0)
                ax.set_facecolor('none')

            # Plot individual traces (Before)
            for diffs in individual_diffs_before:
                x_i, y_i = compute_ecdf(diffs)
                x_i, y_i = smooth_ecdf(x_i, y_i, smoothing_window)
                ax.plot(x_i, y_i, color="gray", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Plot individual traces (After)
            for diffs in individual_diffs_after:
                x_i, y_i = compute_ecdf(diffs)
                x_i, y_i = smooth_ecdf(x_i, y_i, smoothing_window)
                ax.plot(x_i, y_i, color="blue", alpha=individual_line_alpha,
                        linestyle=individual_line_style, linewidth=individual_line_width)

            # Pooled data
            x_before, y_before = compute_ecdf(pooled_diff_before)
            x_before, y_before = smooth_ecdf(x_before, y_before, smoothing_window)
            x_after, y_after = compute_ecdf(pooled_diff_after)
            x_after, y_after = smooth_ecdf(x_after, y_after, smoothing_window)

            # Plot pooled lines on top
            ax.plot(x_before, y_before, color="gray", label="Before (Pooled)", linewidth=main_line_width)
            ax.plot(x_after, y_after, color="blue", label="After (Pooled)", linewidth=main_line_width)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(0, 1.0)

            if remove_spines:
                for spine in ax.spines.values():
                    spine.set_visible(False)

            ax.set_xlabel("Fraction Difference (Early - Late)", fontsize=12)
            ax.set_ylabel("Cumulative Probability", fontsize=12)
            ax.set_title("Pooled Smoothed ECDF with Individual Traces (Before vs. After)", fontsize=14)
            ax.legend()

            fig.savefig(svg_path, format="svg", bbox_inches='tight', transparent=transparent)
            plt.close(fig)

            print(f"Saved pooled smoothed ECDF with individual traces for group '{group}' to: {svg_path}")
            
    def compare_group_distributions(self):
        """
        Compute pooled group-level fraction difference distributions for Before and After conditions,
        store them in a DataFrame, and run a Kolmogorov-Smirnov test for each group.

        Returns:
            group_distribution_df (pd.DataFrame): 
                A DataFrame in long format with columns:
                ["Group", "Condition", "FractionDifference"].
                Each row is a single sweep's fraction difference value, pooled at the group level.
            
            summary_df (pd.DataFrame): 
                A DataFrame summarizing the KS test results for each group.
                Columns: ["Group", "N_Before_Values", "N_After_Values", "KS_Statistic", "KS_pValue"].
        """

        if self.dataframe is None or self.dataframe.empty:
            print("No data to analyze. Run `process_data` first.")
            return None, None

        # Lists to accumulate group-level fraction differences and results
        distribution_records = []
        summary_records = []

        for group in self.unique_groups:
            group_data = self.dataframe[self.dataframe["Group"] == group]
            recording_ids = group_data["Recording_ID"].unique()

            pooled_diff_before = []
            pooled_diff_after = []

            # Process each recording within the group
            for recording_id in recording_ids:
                before_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "Before")]
                after_entry = group_data[(group_data["Recording_ID"] == recording_id) & (group_data["Label"] == "After")]

                if before_entry.empty or after_entry.empty:
                    # Skip if we don't have both conditions for this recording
                    continue

                before_file = before_entry["File_Path"].iloc[0]
                after_file = after_entry["File_Path"].iloc[0]

                before_abf = pyabf.ABF(before_file)
                after_abf = pyabf.ABF(after_file)

                bar_data = {
                    "sweep": [],
                    "early_before": [],
                    "late_before": [],
                    "early_after": [],
                    "late_after": []
                }

                # Process each sweep for this recording
                for sweep_number in range(len(before_abf.sweepList)):
                    # Use your desired parameters for process_sweep_data:
                    time_before, voltage_before, peaks_before, early_before, late_before = process_sweep_data(
                        before_abf, sweep_number, height=None, prominence=None, distance=None, width=None,
                        time_range=None, early_phase=(0.1,0.40), late_phase=(0.5,1.12)
                    )
                    time_after, voltage_after, peaks_after, early_after, late_after = process_sweep_data(
                        after_abf, sweep_number, height=None, prominence=None, distance=None, width=None,
                        time_range=None, early_phase=(0.1,0.40), late_phase=(0.5,1.12)
                    )

                    bar_data["sweep"].append(sweep_number+1)
                    bar_data["early_before"].append(early_before)
                    bar_data["late_before"].append(late_before)
                    bar_data["early_after"].append(early_after)
                    bar_data["late_after"].append(late_after)

                # Calculate fractions and differences for this recording
                fractions = calculate_fractions(bar_data)
                diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
                diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

                # Filter out NaNs
                diff_before = diff_before[~np.isnan(diff_before)]
                diff_after = diff_after[~np.isnan(diff_after)]

                # Append to pooled group-level distributions
                pooled_diff_before.extend(diff_before)
                pooled_diff_after.extend(diff_after)

            # Now we have pooled distributions for this group
            # Store them in the distribution_records
            for val in pooled_diff_before:
                distribution_records.append({
                    "Group": group,
                    "Condition": "Before",
                    "FractionDifference": val
                })
            for val in pooled_diff_after:
                distribution_records.append({
                    "Group": group,
                    "Condition": "After",
                    "FractionDifference": val
                })

            # Run KS test if both distributions have data
            if len(pooled_diff_before) > 0 and len(pooled_diff_after) > 0:
                ks_stat, ks_pvalue = ks_2samp(pooled_diff_before, pooled_diff_after)
            else:
                ks_stat, ks_pvalue = np.nan, np.nan

            summary_records.append({
                "Group": group,
                "N_Before_Values": len(pooled_diff_before),
                "N_After_Values": len(pooled_diff_after),
                "KS_Statistic": ks_stat,
                "KS_pValue": ks_pvalue
            })

        # Convert lists to DataFrames
        group_distribution_df = pd.DataFrame(distribution_records, 
                                            columns=["Group", "Condition", "FractionDifference"])
        summary_df = pd.DataFrame(summary_records, 
                                columns=["Group", "N_Before_Values", "N_After_Values", "KS_Statistic", "KS_pValue"])

        # Return the DataFrames
        return group_distribution_df, summary_df
        
def plot_peaks(ax, time, voltage, peaks, label, color, title, xlabel=None, ylabel=None):
    """
    Plot peaks on a given axis.
    """
    ax.plot(time, voltage, label=label, color=color)
    ax.plot(time[peaks], voltage[peaks], "x", color="C3")
    ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)


def process_sweep_data(abf_file, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase):
    """
    Process a single sweep from an ABF file and extract AP counts for early and late phases.
    """
    abf_file.setSweep(sweepNumber=sweep_number)
    time = abf_file.sweepX
    voltage = abf_file.sweepY

    if time_range:
        mask = (time >= time_range[0]) & (time <= time_range[1])
        time = time[mask]
        voltage = voltage[mask]

    peaks, _ = find_peaks(
        voltage,
        height=height,
        prominence=prominence,
        distance=distance,
        width=width
    )

    early_mask = (time[peaks] >= early_phase[0]) & (time[peaks] <= early_phase[1])
    late_mask = (time[peaks] >= late_phase[0]) & (time[peaks] <= late_phase[1])

    return time, voltage, peaks, early_mask.sum(), late_mask.sum()


def generate_wide_plot_1(ax, bar_data):
    """
    Generate the first wide plot showing AP counts in early and late phases.
    """
    width = 0.2  # Bar width
    x = range(len(bar_data["sweep"]))
    ax.bar([p - 1.5 * width for p in x], bar_data["early_before"], width, label="Early Before", color="lightgray")
    ax.bar([p - 0.5 * width for p in x], bar_data["late_before"], width, label="Late Before", color="gray")
    ax.bar([p + 0.5 * width for p in x], bar_data["early_after"], width, label="Early After", color="lightblue")
    ax.bar([p + 1.5 * width for p in x], bar_data["late_after"], width, label="Late After", color="blue")

    ax.set_xticks(x)
    ax.set_xticklabels(bar_data["sweep"])
    ax.set_xlabel("Sweep Number")
    ax.set_ylabel("AP Count")
    ax.legend()
    ax.set_title("AP Counts in Early and Late Phases")

def calculate_ap_change(bar_data):
    """
    Calculate the change in AP counts between late and early phases for before and after periods.
    """
    ap_change = {
        "stimulus": [],  # Use the sweep number or any defined stimulus input as x-axis
        "before_change": [],  # Late - Early (Before)
        "after_change": []    # Late - Early (After)
    }

    for i, sweep in enumerate(bar_data["sweep"]):
        before_change = bar_data["late_before"][i] - bar_data["early_before"][i]
        after_change = bar_data["late_after"][i] - bar_data["early_after"][i]

        ap_change["stimulus"].append(sweep)  # Use the sweep number as a proxy for stimulus
        ap_change["before_change"].append(before_change)
        ap_change["after_change"].append(after_change)

    return ap_change


def generate_wide_plot_2(ax, bar_data):
    """
    Generate the second wide plot showing AP change between late and early phases for before and after periods.
    """
    ap_change = calculate_ap_change(bar_data)

    # Plot before and after changes as a dot plot
    ax.scatter(ap_change["stimulus"], ap_change["before_change"], color="gray", label="Before")
    ax.scatter(ap_change["stimulus"], ap_change["after_change"], color="blue", label="After")

    # Connect the points with lines for visualization
    for i in range(len(ap_change["stimulus"])):
        ax.plot(
            [ap_change["stimulus"][i], ap_change["stimulus"][i]],
            [ap_change["before_change"][i], ap_change["after_change"][i]],
            color="black",
            alpha=0.5,
            linestyle="--"
        )

    # Set labels, legend, and title
    ax.set_xlabel("Stimulus Input (Sweep Number)")
    ax.set_ylabel("Change in AP Count (Late - Early)")
    ax.legend()
    ax.set_title("Change in AP Count (Late - Early) Before and After")


def calculate_isi_distribution(peaks, time):
    """
    Calculate interspike intervals (ISIs) from detected peaks.
    """
    if len(peaks) < 2:  # Need at least two peaks to calculate ISIs
        return []
    return np.diff(time[peaks])


def generate_wide_plot_3(ax, before_isi, after_isi):
    """
    Generate the third wide plot showing the ISI distribution for before and after periods using histograms.
    """
    # Define a consistent range and bin size for histograms
    all_isi = before_isi + after_isi
    if not all_isi:  # Handle the case where no ISIs are present
        ax.text(0.5, 0.5, "No ISIs detected", ha="center", va="center", fontsize=12)
        ax.set_title("ISI Distribution Before and After")
        return

    bins = np.histogram_bin_edges(all_isi, bins=30)  # Determine bin edges based on combined ISI data

    # Plot histograms for before and after ISIs
    ax.hist(before_isi, bins=bins, alpha=0.5, color="gray", label="Before", density=True)
    ax.hist(after_isi, bins=bins, alpha=0.5, color="blue", label="After", density=True)

    # Set axis labels, legend, and title
    ax.set_xlabel("Interspike Interval (s)")
    ax.set_ylabel("Density")
    ax.legend()
    ax.set_title("ISI Distribution Before and After")

def calculate_fractions(bar_data):
    """
    Calculate the fractions of early and late spikes for before and after.
    """
    fractions = {
        "sweep": bar_data["sweep"],
        "early_fraction_before": [],
        "late_fraction_before": [],
        "early_fraction_after": [],
        "late_fraction_after": []
    }

    for i, sweep in enumerate(bar_data["sweep"]):
        total_before = bar_data["early_before"][i] + bar_data["late_before"][i]
        total_after = bar_data["early_after"][i] + bar_data["late_after"][i]

        if total_before > 0:
            fractions["early_fraction_before"].append(bar_data["early_before"][i] / total_before)
            fractions["late_fraction_before"].append(bar_data["late_before"][i] / total_before)
        else:
            fractions["early_fraction_before"].append(np.nan)
            fractions["late_fraction_before"].append(np.nan)

        if total_after > 0:
            fractions["early_fraction_after"].append(bar_data["early_after"][i] / total_after)
            fractions["late_fraction_after"].append(bar_data["late_after"][i] / total_after)
        else:
            fractions["early_fraction_after"].append(np.nan)
            fractions["late_fraction_after"].append(np.nan)

    return fractions

def generate_wide_plot_fraction(ax, fractions):
    """
    Generate a bar plot showing the fraction of early and late spikes before and after,
    mirroring the style and color scheme of generate_wide_plot_1.
    """
    width = 0.2  # Same width as generate_wide_plot_1
    x = range(len(fractions["sweep"]))

    # Plot four bars per sweep, just like generate_wide_plot_1
    ax.bar([p - 1.5 * width for p in x], fractions["early_fraction_before"], width, label="Early Fraction Before", color="lightgray")
    ax.bar([p - 0.5 * width for p in x], fractions["late_fraction_before"], width, label="Late Fraction Before", color="gray")
    ax.bar([p + 0.5 * width for p in x], fractions["early_fraction_after"], width, label="Early Fraction After", color="lightblue")
    ax.bar([p + 1.5 * width for p in x], fractions["late_fraction_after"], width, label="Late Fraction After", color="blue")

    ax.set_xticks(x)
    ax.set_xticklabels(fractions["sweep"])
    ax.set_xlabel("Sweep Number")
    ax.set_ylabel("Fraction of Spikes")
    ax.legend()
    ax.set_title("Fraction of Early and Late Spikes Before and After")

def generate_ecdf_plot(ax, fractions):
    """
    Generate an ECDF plot comparing the distribution of fraction differences
    (early_fraction - late_fraction) before and after.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to draw the plot.
    fractions : dict
        A dictionary containing:
            "early_fraction_before"
            "late_fraction_before"
            "early_fraction_after"
            "late_fraction_after"
    """
    # Compute the differences for before and after conditions
    diff_before = np.array(fractions["early_fraction_before"]) - np.array(fractions["late_fraction_before"])
    diff_after = np.array(fractions["early_fraction_after"]) - np.array(fractions["late_fraction_after"])

    # Remove NaN values if any exist (in cases where no spikes were detected)
    diff_before = diff_before[~np.isnan(diff_before)]
    diff_after = diff_after[~np.isnan(diff_after)]

    # Sort the data
    diff_before_sorted = np.sort(diff_before)
    diff_after_sorted = np.sort(diff_after)

    # Compute ECDF y-values
    y_before = np.arange(1, len(diff_before_sorted) + 1) / len(diff_before_sorted) if len(diff_before_sorted) > 0 else []
    y_after = np.arange(1, len(diff_after_sorted) + 1) / len(diff_after_sorted) if len(diff_after_sorted) > 0 else []

    # Plot the ECDFs
    ax.plot(diff_before_sorted, y_before, color="gray", label="Before")
    ax.plot(diff_after_sorted, y_after, color="blue", label="After")

    # Set axis limits and labels
    ax.set_xlim(-1, 1)
    ax.set_xlabel("Fraction Difference (Early - Late)")
    ax.set_ylabel("Cumulative Probability")
    ax.legend()
    ax.set_title("ECDF of Fraction Differences Before and After")

### compare ecdf acrross groups 

### FOR LINE PLOT 
def plot_traces(ax_before, ax_after, time_before, voltage_before, time_after, voltage_after):
    """
    Plot Before and After traces on separate axes.

    Args:
        ax_before (matplotlib axis): Axis to plot Before trace.
        ax_after (matplotlib axis): Axis to plot After trace.
        time_before, time_after: Time arrays for Before and After traces.
        voltage_before, voltage_after: Voltage arrays for Before and After traces.
    """
    # Plot Before traces
    ax_before.plot(time_before, voltage_before, color="gray")
    ax_before.set_title("Before", fontsize=8)
    ax_before.set_xlabel("Time (s)", fontsize=6)
    ax_before.set_ylabel("Voltage (mV)", fontsize=6)

    # Plot After traces
    ax_after.plot(time_after, voltage_after, color="blue")
    ax_after.set_title("After", fontsize=8)
    ax_after.set_xlabel("Time (s)", fontsize=6)

def plot_io_curve(ax, group_data, label, color, offset):
    """
    Plot input-output curve with mean and SEM for a single condition.

    Args:
        ax (matplotlib axis): Axis to plot on.
        group_data (DataFrame): Filtered data for the condition.
        label (str): Legend label.
        color (str): Color for the markers and lines.
        offset (float): Offset for the sweep numbers.
    """
    mean = group_data.groupby("Sweep_Number")["AP_Count"].mean()
    sem = group_data.groupby("Sweep_Number")["AP_Count"].sem()
    sweep_numbers = mean.index

    ax.errorbar(
        sweep_numbers + offset,
        mean,
        yerr=sem,
        fmt="o",
        color=color,
        label=label,
        capsize=3,
        markersize=4,
    )
    ax.set_xlabel("Sweep Number", fontsize=6)
    ax.set_ylabel("Mean AP Count (± SEM)", fontsize=6)
    ax.legend(fontsize=6)
    ax.grid(True, linestyle="--", linewidth=0.5)




In [10]:
# Example Usage
base_path = "/Users/ecrespo/Desktop/BLADe_patch_data"
processor = BladePatchDataProcessor(base_path)

# Process the data
processor.process_data()

# Access the attributes
print("DataFrame:")
print(processor.dataframe)
print("\nUnique Groups:")
print(processor.unique_groups)

# Print a summary
print("\nSummary:")
print(processor.get_summary())

DataFrame:
            Group Recording_ID   Label  \
0    L + ACR2-CTZ        CTZ-1   After   
1    L + ACR2-CTZ        CTZ-1  Before   
2    L + ACR2-CTZ       CTZ-11   After   
3    L + ACR2-CTZ       CTZ-11  Before   
4    L + ACR2-CTZ       CTZ-12   After   
..            ...          ...     ...   
114         Plain        CTZ-7  Before   
115         Plain        CTZ-8   After   
116         Plain        CTZ-8  Before   
117         Plain        CTZ-9   After   
118         Plain        CTZ-9  Before   

                                             File_Path  
0    /Users/ecrespo/Desktop/BLADe_patch_data/L + AC...  
1    /Users/ecrespo/Desktop/BLADe_patch_data/L + AC...  
2    /Users/ecrespo/Desktop/BLADe_patch_data/L + AC...  
3    /Users/ecrespo/Desktop/BLADe_patch_data/L + AC...  
4    /Users/ecrespo/Desktop/BLADe_patch_data/L + AC...  
..                                                 ...  
114  /Users/ecrespo/Desktop/BLADe_patch_data/Plain/...  
115  /Users/ecrespo/Desktop/

Check Available Groups and Recordings

In [50]:
print("Groups:", processor.unique_groups)
# For a given group:
group = processor.unique_groups[0]  # for example, pick the first group
print("Recordings in this group:", processor.get_recording_ids(group))

Groups: ['L + ACR2-CTZ', 'L + CS-CTZ', 'L + CS-Veh', 'L + DUD-CTZ', 'L Only', 'Plain']
Recordings in this group: ['CTZ-1' 'CTZ-11' 'CTZ-12' ... 'CTZ-7' 'CTZ-8' 'CTZ-9']


Plot a Single Recording’s Before/After Comparison

In [None]:
recording_id = processor.get_recording_ids(group)[0]  # take the first recording in the group
fig, axes = processor.plot_before_after_comparison(group, recording_id,
                                                   before_label="Before",
                                                   after_label="After",
                                                   sweep_numbers=None,  # or specify sweeps like [0, 1, 2]
                                                   startAtSec=0,
                                                   endAtSec=1.5)
plt.show()
processor.plot_scalebar(axes[1])
plt.show()

In [None]:
output_pdf_path = "/Users/ecrespo/Desktop/BLADe_patch_data_output"

processor.export_all_groups_to_pdfs(
    output_dir=output_pdf_path,
    before_label="Before",
    after_label="After",
    sweep_numbers=None,
    startAtSec=0.08,
    endAtSec=1.2,
    offsetXsec=0.3,
    offsetYunits=40,
    color_before=None,
    color_after="red",
    alpha=0.5,
    linewidth=1
)

In [12]:
output_pdf_path = "/Users/ecrespo/Desktop/BLADe_patch_data_output"

processor.export_all_groups_to_svgs(output_dir=output_pdf_path,
                                    before_label="Before",
                                    after_label="After",
                                    sweep_numbers=None,
                                    startAtSec=0.08,
                                    endAtSec=1.2,
                                    offsetXsec=0.1,
                                    offsetYunits=90,
                                    color_before="grey",
                                    color_after="blue",
                                    alpha=0.5,
                                    linewidth=0.5,
                                    dpi=300,    # For any raster elements
                                    add_suptitle=True,
                                    scaleXms=200,
                                    scaleYmV=50)

In [None]:
processor.dataframe

In [None]:
abf = pyabf.ABF('/Users/ecrespo/Desktop/BLADe_patch_data/L + ACR2-CTZ/L + ACR2 After CTZ-1 21224028.abf')
fig = plt.figure(figsize=(8, 5))

ax1 = fig.add_subplot(211 )
ax1.set_title("ABF Recording")
ax1.set_ylabel(abf.sweepLabelY)
ax1.plot(abf.sweepX, abf.sweepY, 'b', lw=.5)
plt.tight_layout()
plt.show()

In [None]:
output_dir = "/Users/ecrespo/Desktop/BLADe_patch_data_output"
processor.plot_sweeps_pdf(output_dir)

In [None]:
result = processor.detect_action_potentials(
    group="Plain",
    recording_id="CTZ-1",
    label="Before",
    sweep_number=5,
    height=0.1,  # Adjust based on signal properties
    prominence=0.05,
    distance=50
)

print(result)

In [None]:
output_dir = "/Users/ecrespo/Desktop/BLADe_patch_data_output_standard"
processor.create_group_pdf_with_peaks(
    output_dir=output_dir,
    height=0.1,  # Adjust based on signal properties
    prominence=0.05,
    distance=50
)

In [None]:
processor.create_group_pdf_with_peaks_complex(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_complex",
    height=0.1,  # Adjust based on signal properties
    prominence=0.05,
    distance=50, 
    time_range=None,
    early_phase=(0.1,0.49),
    late_phase=(0.5,1.12)
)

In [51]:
# groups ecdf
processor.create_group_pooled_ecdf(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf",
    height=0.1,  
    prominence=0.05,
    distance=50,
    time_range=None,
    early_phase=(0.1, 0.49),
    late_phase=(0.5, 1.12), 
    x_min=-1.0,
    x_max=1.0,
    step_mode="mid"
)

Saved pooled ECDF PDF for group 'L + ACR2-CTZ' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/L + ACR2-CTZ_pooled_ecdf.pdf
Saved pooled ECDF PDF for group 'L + CS-CTZ' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/L + CS-CTZ_pooled_ecdf.pdf
Saved pooled ECDF PDF for group 'L + CS-Veh' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/L + CS-Veh_pooled_ecdf.pdf
Skipping incomplete pair for Recording ID: CTZ-3 in group L + DUD-CTZ
Saved pooled ECDF PDF for group 'L + DUD-CTZ' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/L + DUD-CTZ_pooled_ecdf.pdf
Saved pooled ECDF PDF for group 'L Only' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/L Only_pooled_ecdf.pdf
Saved pooled ECDF PDF for group 'Plain' to: /Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf/Plain_pooled_ecdf.pdf


In [None]:
processor.create_group_pooled_ecdf_svg(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf_svgs",
    height=0.1,  
    prominence=0.05,
    distance=50,
    time_range=None,
    early_phase=(0.1, 0.49),
    late_phase=(0.5, 1.12), 
    x_min=-1.0,
    x_max=1.0,
    step_mode="mid", 
    fig_width=5,
    fig_height=3,
    remove_spines=True,
    transparent=True
)

In [None]:
processor.create_group_pooled_mean_and_individual_traces_ecdf(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf_svgs_individual",
    height=0.1,
    prominence=0.05,
    distance=50,
    time_range=None,
    early_phase=(0.1, 0.49),
    late_phase=(0.5, 1.12),
    x_min=-1.0,
    x_max=1.0,
    step_mode="mid",
    fig_width=5,
    fig_height=3,
    remove_spines=True,
    transparent=True,
    individual_line_alpha=0.3,
    individual_line_style=":",
    individual_line_width=0.5
)


In [None]:
processor.create_group_pooled_mean_and_individual_sigmoid_ecdf(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf_svgs_sigmoid",
    height=0.1,
    prominence=0.05,
    distance=50,
    time_range=None,
    early_phase=(0.1, 0.49),
    late_phase=(0.5, 1.12),
    x_min=-1.0,
    x_max=1.0,
    fig_width=5,
    fig_height=3,
    remove_spines=True,
    transparent=True,
    individual_line_alpha=0.3,
    individual_line_style=":",
    individual_line_width=0.5,
    smoothing_factor=1.0
)

In [24]:
cell_data_df, summary_df = processor.compare_group_distributions()

In [None]:
print(cell_data_df)


In [None]:
print(summary_df)

In [None]:
processor.create_group_pooled_mean_and_individual_smoothed_ecdf(
    output_dir="/Users/ecrespo/Desktop/BLADe_patch_data_output_ecdf_svgs_smoothed",
    height=0.1,
    prominence=0.05,
    distance=50,
    time_range=None,
    early_phase=(0.1, 0.49),
    late_phase=(0.5, 1.12),
    x_min=-1.0,
    x_max=1.0,
    fig_width=5,
    fig_height=3,
    remove_spines=True,
    transparent=True,
    individual_line_alpha=0.3,
    individual_line_style=":",
    individual_line_width=0.5,
    main_line_width=1.5,
    smoothing_window=-.01
)

In [None]:
output_dir = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized"
processor.create_group_pdf_with_peaks(
    output_dir=output_dir,
    height=0.08,  # Adjust based on signal properties
    prominence=0.1,
    distance=30, 
    width=0.01
)

In [None]:
processor.process_peaks(
    height=0.08,  # Adjust based on signal properties
    prominence=0.1,
    distance=30, 
    width=0.01,
    save_csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data.csv'
)

In [None]:
processor.import_csv_and_plot_mean_peaks(
    csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv',
    output_pdf_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified_plot.pdf'
)

In [None]:
processor.import_csv_and_plot_mean_peaks_lineplot(csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv',
    output_pdf_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified_lineplot.pdf')

In [None]:
processor.import_csv_and_plot_mean_peaks_with_error_bars(csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv',
    output_pdf_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified_lineplotwitherrorbars.pdf')

In [None]:
processor.import_csv_and_plot_mean_peaks_with_error_bars_svg(
    csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv',
    output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/svg_lineplots',
    fig_width=5,
    fig_height=3, 
    ymin=-10, 
    ymax=70,
    )

In [None]:
processor.process_peaks_in_window(height=0.8,
    prominence=0.2,
    distance=20,
    start_time=0.1,
    end_time=1.2,
    save_csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_specific_window.csv')

In [None]:
processor.import_csv_and_plot_mean_peaks_with_error_bars(csv_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_specific_window.csv',
    output_pdf_path='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified_lineplotwitherrorbars_specific_window.pdf')

In [None]:
processor.peak_window_dataframe

In [None]:
processor.process_peaks_by_phase(
    height=0.8,
    prominence=0.2,
    distance=20,
    early_start=0.1,
    early_end=0.2,
    late_start=0.21,
    late_end=0.41,
    save_csv_path=None
)

In [None]:
processor.phase_peak_dataframe

In [None]:
# Assuming `processor` is your class instance
processor.create_group_pdf_with_deltas_from_dataframe(output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/')

In [None]:
processor.create_group_pdf_with_early_vs_late_counts(output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/')

In [None]:
processor.create_group_pdf_with_early_to_late_ratios(output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/')

In [None]:
processor.create_group_pdf_with_mean_and_sem(output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/')

In [None]:
processor.create_group_pdf_with_mean_and_sem_and_store_data(output_dir='/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/')

In [None]:
stats_results = processor.run_two_way_anova_with_correction()

In [None]:
stats_results

In [None]:
processor.dataframe

In [49]:
import matplotlib as mpl
mpl.rcParams['svg.fonttype'] = 'none'


# Helper functions
def load_trace_data(file_path, time_window):
    """
    Load and slice trace data for a given file and time window.

    Args:
        file_path (str): Path to the trace data file (e.g., ABF file).
        time_window (tuple): (start, end) time range to extract.

    Returns:
        tuple: (time, voltage), arrays of time and voltage data within the window.
    """
    abf = pyabf.ABF(file_path)
    abf.setSweep(0)  # Use the first sweep for simplicity
    time = abf.sweepX
    voltage = abf.sweepY

    # Apply time window
    mask = (time >= time_window[0]) & (time <= time_window[1])
    return time[mask], voltage[mask]


def plot_traces(ax_before, ax_after, time_before, voltage_before, time_after, voltage_after):
    """
    Plot Before and After traces on separate axes.

    Args:
        ax_before (matplotlib axis): Axis to plot Before trace.
        ax_after (matplotlib axis): Axis to plot After trace.
        time_before, time_after: Time arrays for Before and After traces.
        voltage_before, voltage_after: Voltage arrays for Before and After traces.
    """
    # Plot Before traces
    ax_before.plot(time_before, voltage_before, color="gray")
    ax_before.set_title("Before", fontsize=8)
    ax_before.set_xlabel("Time (s)", fontsize=6)
    ax_before.set_ylabel("Voltage (mV)", fontsize=6)

    # Plot After traces
    ax_after.plot(time_after, voltage_after, color="blue")
    ax_after.set_title("After", fontsize=8)
    ax_after.set_xlabel("Time (s)", fontsize=6)


def plot_combined_io_curve(ax, before_data, after_data):
    """
    Overlay Before and After Input-Output curves on a single axis.

    Args:
        ax (matplotlib axis): Axis to plot on.
        before_data (DataFrame): Filtered data for the Before condition.
        after_data (DataFrame): Filtered data for the After condition.
    """
    # Compute mean and SEM for Before
    mean_before = before_data.groupby("Sweep_Number")["AP_Count"].mean()
    sem_before = before_data.groupby("Sweep_Number")["AP_Count"].sem()

    # Compute mean and SEM for After
    mean_after = after_data.groupby("Sweep_Number")["AP_Count"].mean()
    sem_after = after_data.groupby("Sweep_Number")["AP_Count"].sem()

    # Sweep numbers
    sweep_numbers = mean_before.index

    # Plot Before data
    ax.errorbar(
        sweep_numbers,
        mean_before,
        yerr=sem_before,
        fmt="o",
        color="gray",
        label="Before",
        capsize=3,
        markersize=4,
    )

    # Plot After data
    ax.errorbar(
        sweep_numbers,
        mean_after,
        yerr=sem_after,
        fmt="o",
        color="blue",
        label="After",
        capsize=3,
        markersize=4,
    )

    # Add labels, legend, and title
    ax.set_xlabel("Sweep Number", fontsize=6)
    ax.set_ylabel("Mean AP Count (± SEM)", fontsize=6)
    ax.legend(fontsize=6)
    ax.grid(True, linestyle="--", linewidth=0.5)

def print_sweep_epoch_info(abf_file_path):
    """
    Open an ABF file and print the epoch table for each sweep.
    For each epoch, we display:
      - The epoch index
      - The point index where the epoch begins
      - The epoch type (e.g., "Step", "Ramp")
      - The epoch level (usually in pA for current-clamp or mV for voltage-clamp)
    """
    abf = pyabf.ABF(abf_file_path)
    print(f"ABF File: {abf_file_path}")
    print(f"Sweep count: {abf.sweepCount}")
    print("-" * 60)

    for sweep_num in abf.sweepList:
        abf.setSweep(sweep_num)
        print(f"Sweep {sweep_num}:")

        # abf.sweepEpochs.p1s    -> list of epoch start points (in data-point indices)
        # abf.sweepEpochs.types  -> list of epoch types (e.g. "Step", "Ramp")
        # abf.sweepEpochs.levels -> list of epoch levels (e.g., -70.0 for mV or 50.0 for pA)
        for i, p1 in enumerate(abf.sweepEpochs.p1s):
            epoch_type = abf.sweepEpochs.types[i]
            epoch_level = abf.sweepEpochs.levels[i]
            print(f"  Epoch {i}: starts at point {p1}, type={epoch_type}, level={epoch_level}")
        print()

    print("Done.\n")
from matplotlib.ticker import MultipleLocator

def create_final_figure_by_group(
    group_recording_map,  # Ordered dictionary mapping group name -> desired recording_id
    processor,
    time_window,
    csv_path,
    output_svg_path,  # Path where the final SVG will be saved
    voltage_y_range=(-100, 500),  # y-axis limits for voltage traces
    io_y_range=(-10, 70)          # y-axis limits for IO curves
):
    """
    Create a multi-row figure where each row corresponds to a specific group.
    Only groups specified in group_recording_map are plotted, in the order provided.
    For each group, the left two columns show the 'Before' and 'After' voltage traces for a specified recording,
    and the right columns show the group summary Input–Output (I–O) curve.
    
    The time axis is converted from seconds to milliseconds (ms) and the corresponding labels are updated.
    The final figure is saved as an SVG file with editable text.
    """
    # Load the Input–Output (I–O) data from the CSV file.
    try:
        peak_data = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"File not found: {csv_path}")
        return

    # Validate that processor has data.
    if processor.dataframe is None or processor.dataframe.empty:
        print("No data available in processor.dataframe.")
        return

    # Use only the groups provided in group_recording_map, in the order given.
    groups_to_plot = list(group_recording_map.keys())
    n_groups = len(groups_to_plot)
    if n_groups == 0:
        print("No groups specified in group_recording_map. Nothing to plot.")
        return

    plt.rcParams.update({"font.family": "Arial", "font.size": 8})
    # Adjust figure height based on the number of groups.
    fig = plt.figure(figsize=(6.5, n_groups * 2.5))
    gs = gridspec.GridSpec(n_groups, 4, figure=fig)
    plt.subplots_adjust(wspace=0.4, hspace=0.6)

    for i, group in enumerate(groups_to_plot):
        # Get the recording_id for this group.
        recording_id = group_recording_map.get(group)
        if not recording_id:
            print(f"No recording specified for group '{group}'. Skipping...")
            continue

        # Find rows matching the desired recording_id and group.
        rep_data = processor.dataframe[
            (processor.dataframe["Recording_ID"] == recording_id) &
            (processor.dataframe["Group"] == group)
        ]
        if rep_data.empty:
            print(f"Recording '{recording_id}' for group '{group}' not found. Skipping...")
            continue

        # For debugging: print which recording(s) are being used.
        used_recordings = rep_data["Recording_ID"].unique()
        print(f"Group '{group}' will use recording(s): {used_recordings}")

        # Get file paths for 'Before' and 'After' conditions.
        try:
            before_file = rep_data[rep_data["Label"] == "Before"]["File_Path"].iloc[0]
            after_file  = rep_data[rep_data["Label"] == "After"]["File_Path"].iloc[0]
        except IndexError:
            print(f"Missing file path information for group '{group}'. Skipping...")
            continue

        # Load the traces using pyABF.
        before_abf = pyabf.ABF(before_file)
        after_abf  = pyabf.ABF(after_file)
        
        print_sweep_epoch_info(before_file)
        print_sweep_epoch_info(after_file)
        
        # --- Plot the voltage traces ---
        # Plot the 'Before' trace.
        ax_before = fig.add_subplot(gs[i, 0])
        for sweep_number in before_abf.sweepList[::2]:  # Plot every other sweep
            before_abf.setSweep(sweep_number)
            time_before = before_abf.sweepX
            voltage_before = before_abf.sweepY
            mask = (time_before >= time_window[0]) & (time_before <= time_window[1])
            # Convert time from seconds to milliseconds.
            ax_before.plot(time_before[mask] * 1000, voltage_before[mask],
                           color="gray", alpha=0.7, linewidth=0.8)
        ax_before.set_title(f"{group} - Before", fontsize=8)
        ax_before.set_xlabel("Time (ms)", fontsize=6)
        ax_before.set_ylabel("Voltage (mV)", fontsize=6)
        ax_before.set_ylim(voltage_y_range)
        # Set x-axis ticks every 200 ms
        ax_before.xaxis.set_major_locator(MultipleLocator(200))

        # Plot the 'After' trace.
        ax_after = fig.add_subplot(gs[i, 1])
        for sweep_number in after_abf.sweepList[::2]:
            after_abf.setSweep(sweep_number)
            time_after = after_abf.sweepX
            voltage_after = after_abf.sweepY
            mask = (time_after >= time_window[0]) & (time_after <= time_window[1])
            ax_after.plot(time_after[mask] * 1000, voltage_after[mask],
                          color="blue", alpha=0.7, linewidth=0.8)
        ax_after.set_title(f"{group} - After", fontsize=8)
        ax_after.set_xlabel("Time (ms)", fontsize=6)
        ax_after.set_ylim(voltage_y_range)
        # Set x-axis ticks every 200 ms
        ax_after.xaxis.set_major_locator(MultipleLocator(200))

        # --- Plot the group summary I–O curve ---
        ax_io = fig.add_subplot(gs[i, 2:4])
        group_before_data = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "Before")]
        group_after_data  = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "After")]
        plot_combined_io_curve(ax_io, group_before_data, group_after_data)
        ax_io.set_ylim(io_y_range)
        ax_io.set_title(f"I–O Curve: {group}", fontsize=8)

    # Save the final figure as an SVG file.
    os.makedirs(os.path.dirname(output_svg_path), exist_ok=True)
    fig.savefig(output_svg_path, format="svg")
    plt.close(fig)

    print(f"Final figure saved to: {output_svg_path}")



group_recording_map = {
    'L + CS-CTZ':   'CTZ-6',
    'L + CS-Veh':   'Veh-5',
    'L + ACR2-CTZ': 'CTZ-5',
    'L + DUD-CTZ':  'CTZ-9'
}
time_window = (0, 1.2)  # Specify your desired time window (start, end in seconds)
csv_path = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv"
output_svg_path = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/final_attempt.svg"

create_final_figure_by_group(
    group_recording_map,
    processor,
    time_window=(0, 1.2),
    csv_path=csv_path,
    output_svg_path=output_svg_path,
    voltage_y_range=(-90, 50),
    io_y_range=(-10, 80)
)

Group 'L + CS-CTZ' will use recording(s): ['CTZ-6']
ABF File: /Users/ecrespo/Desktop/BLADe_patch_data/L + CS-CTZ/L + CS Before CTZ-6 21201009.abf
Sweep count: 10
------------------------------------------------------------
Sweep 0:
  Epoch 0: starts at point 0, type=Step, level=0.0
  Epoch 1: starts at point 234, type=Step, level=0.0
  Epoch 2: starts at point 1234, type=Step, level=-200.0
  Epoch 3: starts at point 11234, type=Step, level=0.0
  Epoch 4: starts at point 14234, type=Step, level=0.0

Sweep 1:
  Epoch 0: starts at point 0, type=Step, level=0.0
  Epoch 1: starts at point 234, type=Step, level=0.0
  Epoch 2: starts at point 1234, type=Step, level=-100.0
  Epoch 3: starts at point 11234, type=Step, level=0.0
  Epoch 4: starts at point 14234, type=Step, level=0.0

Sweep 2:
  Epoch 0: starts at point 0, type=Step, level=0.0
  Epoch 1: starts at point 234, type=Step, level=0.0
  Epoch 2: starts at point 1234, type=Step, level=0.0
  Epoch 3: starts at point 11234, type=Step, lev

In [110]:
import pandas as pd
import numpy as np
import scipy.stats as stats
import os

def save_io_ecdf_data(
    group_recording_map,
    processor,
    csv_path,
    output_io_csv,
    output_ecdf_csv,
    height=0.1,
    prominence=0.05,
    distance=50,
    width=None,
    time_range=(0, 1.2),
    early_phase=(0.1, 0.40),
    late_phase=(0.5, 1.12)
):
    """
    Extracts and saves I-O curve data and ECDF distributions for later statistical analysis.
    This version **removes trials with no spikes** from the ECDF analysis.

    Parameters
    ----------
    group_recording_map : dict
        Ordered dict mapping group_name -> desired recording_id.
    processor : object
        Your data processor with the main dataframe.
    csv_path : str
        Path to the CSV with peak data for I-O curves.
    output_io_csv : str
        Path where the input-output data will be saved.
    output_ecdf_csv : str
        Path where the ECDF data will be saved.
    height, prominence, distance, width : float
        Parameters for peak detection.
    time_range : tuple
        Time window for extracting data.
    early_phase, late_phase : tuple
        Time windows for computing early and late fractions.
    """
    
    # Load I-O peak data
    try:
        peak_data = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"File not found: {csv_path}")
        return
    
    # Store I-O data in tidy format
    io_data_list = []
    ecdf_data_list = []

    for group in group_recording_map.keys():
        # Extract I-O Data
        group_before = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "Before")]
        group_after  = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "After")]

        # Add to list
        io_data_list.append(group_before.assign(Condition="Before"))
        io_data_list.append(group_after.assign(Condition="After"))

        # Extract ECDF Data (Fraction Differences)
        group_data_all = processor.dataframe[processor.dataframe["Group"] == group]
        all_rec_ids = group_data_all["Recording_ID"].unique()

        for rec_id in all_rec_ids:
            sub_df = group_data_all[group_data_all["Recording_ID"] == rec_id]
            before_row = sub_df[sub_df["Label"] == "Before"]
            after_row  = sub_df[sub_df["Label"] == "After"]

            if before_row.empty or after_row.empty:
                continue  # Skip if missing

            before_path = before_row["File_Path"].iloc[0]
            after_path  = after_row["File_Path"].iloc[0]

            # Load ABFs
            b_abf = pyabf.ABF(before_path)
            a_abf = pyabf.ABF(after_path)

            for sweep_number in b_abf.sweepList:
                # Process "Before" sweep
                _, _, _, early_b, late_b = process_sweep_data(
                    b_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                )
                # Process "After" sweep
                _, _, _, early_a, late_a = process_sweep_data(
                    a_abf, sweep_number, height, prominence, distance, width, time_range, early_phase, late_phase
                )

                # Compute fraction difference
                diff_b = early_b - late_b
                diff_a = early_a - late_a

                # **Filter out trials where both early and late have zero spikes**
                if early_b == 0 and late_b == 0:
                    continue  # Skip this trial for Before
                if early_a == 0 and late_a == 0:
                    continue  # Skip this trial for After

                # Store only relevant trials
                ecdf_data_list.append({
                    "Group": group,
                    "Recording_ID": rec_id,
                    "Sweep": sweep_number,
                    "Condition": "Before",
                    "Fraction_Diff": diff_b
                })
                ecdf_data_list.append({
                    "Group": group,
                    "Recording_ID": rec_id,
                    "Sweep": sweep_number,
                    "Condition": "After",
                    "Fraction_Diff": diff_a
                })

    # Convert to DataFrame and save
    io_data_df = pd.concat(io_data_list, ignore_index=True)
    ecdf_data_df = pd.DataFrame(ecdf_data_list)

    os.makedirs(os.path.dirname(output_io_csv), exist_ok=True)
    os.makedirs(os.path.dirname(output_ecdf_csv), exist_ok=True)

    io_data_df.to_csv(output_io_csv, index=False)
    ecdf_data_df.to_csv(output_ecdf_csv, index=False)

    print(f"I-O data saved to: {output_io_csv}")
    print(f"ECDF data saved to: {output_ecdf_csv}")

def perform_statistical_tests(io_csv, ecdf_csv, output_stats_csv):
    """
    Performs Wilcoxon signed-rank test on I-O data (AP_Count) and KS test on ECDF distributions.
    The KS test now directly compares the two ECDF distributions instead of looking at differences.

    Parameters
    ----------
    io_csv : str
        Path to the input-output data CSV.
    ecdf_csv : str
        Path to the ECDF fraction difference data CSV.
    output_stats_csv : str
        Path where statistical results will be saved.
    """
    # Load data
    try:
        io_data = pd.read_csv(io_csv)
        ecdf_data = pd.read_csv(ecdf_csv)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    stats_results = []

    # Validate required columns
    if "AP_Count" not in io_data.columns:
        print("Error: 'AP_Count' column missing in I-O data. Check input data.")
        return
    if "Fraction_Diff" not in ecdf_data.columns:
        print("Error: 'Fraction_Diff' column missing in ECDF data. Check input data.")
        return

    # --- Wilcoxon Signed-Rank Test (I-O) ---
    for group in io_data["Group"].unique():
        group_io = io_data[io_data["Group"] == group]

        before_io = group_io[group_io["Condition"] == "Before"]["AP_Count"]
        after_io = group_io[group_io["Condition"] == "After"]["AP_Count"]

        if len(before_io) == len(after_io) and len(before_io) > 0:
            stat, p = stats.wilcoxon(before_io, after_io)
            mean_diff = np.mean(after_io - before_io)
            median_diff = np.median(after_io - before_io)
            std_diff = np.std(after_io - before_io, ddof=1)
            N = len(before_io)
            stats_results.append(["Wilcoxon", group, N, mean_diff, median_diff, std_diff, p])
        else:
            print(f"Skipping Wilcoxon for {group}: Unequal or insufficient samples.")

    # --- Kolmogorov-Smirnov Test (ECDF) ---
    for group in ecdf_data["Group"].unique():
        group_ecdf = ecdf_data[ecdf_data["Group"] == group]

        before_ecdf = group_ecdf[group_ecdf["Condition"] == "Before"]["Fraction_Diff"]
        after_ecdf = group_ecdf[group_ecdf["Condition"] == "After"]["Fraction_Diff"]

        if len(before_ecdf) > 0 and len(after_ecdf) > 0:
            # Perform KS test directly on the distributions
            stat, p = stats.ks_2samp(before_ecdf, after_ecdf)
            N = len(before_ecdf)
            stats_results.append(["KS Test", group, N, stat, np.nan, np.nan, p])
        else:
            print(f"Skipping KS test for {group}: Insufficient ECDF samples.")

    # Convert to DataFrame
    stats_df = pd.DataFrame(stats_results, columns=["Test", "Group", "N", "Stat", "Median_Diff", "SD", "P_Value"])
    
    os.makedirs(os.path.dirname(output_stats_csv), exist_ok=True)
    stats_df.to_csv(output_stats_csv, index=False)

    print("Statistical Test Results:")
    print(stats_df)

    return stats_df

def create_final_figure_by_group_appendededf(
    group_recording_map,
    processor,
    time_window,
    csv_path,
    output_svg_path,
    voltage_y_range=(-100, 500),
    io_y_range=(-10, 70),
    early_phase=(0.1, 0.40),
    late_phase=(0.5, 1.12),
    step_mode="mid",
    fig_width=8,
    fig_height_per_group=3,
    transparent=True,
    # New parameters (optional) for your peak detection
    height=0.1,
    prominence=0.05,
    distance=50,
    width=None
):
    """
    This function merges the logic of create_final_figure_by_group (plotting Before/After
    traces and I–O curves) with the ECDF logic from create_group_pooled_ecdf_svg.
    Each group gets a row of subplots:
      - col 0: Before traces (representative recording)
      - col 1: After traces  (representative recording)
      - col 2: I–O curve     (entire group, from CSV)
      - col 3: ECDF of (EarlyFraction - LateFraction) pooled across *all* recordings in that group.

    Parameters
    ----------
    group_recording_map : dict
        Ordered dict mapping group_name -> desired recording_id (the "representative" for voltage).
    processor : object
        Your data processor with a dataframe of columns ["Recording_ID", "Group", "Label", "File_Path", ...].
    time_window : tuple
        (start_sec, end_sec) for extracting a slice of the sweep for plotting (columns 0/1).
    csv_path : str
        Path to your CSV with peak_data for plotting I–O curves in column 2.
    output_svg_path : str
        Where to save the final SVG file.
    voltage_y_range : tuple
        Y-axis limits for the voltage traces (Before/After).
    io_y_range : tuple
        Y-axis limits for the I–O curve.
    early_phase : tuple
        (start_sec, end_sec) for the "early" portion of each sweep (for fraction).
    late_phase : tuple
        (start_sec, end_sec) for the "late" portion of each sweep (for fraction).
    step_mode : str
        How the ECDF step lines are drawn: "pre", "post", or "mid".
    fig_width : float
        The overall figure width in inches.
    fig_height_per_group : float
        How tall each group’s row should be (in inches).
    transparent : bool
        If True, set figure background to transparent.
    height, prominence, distance, width : float or None
        Peak detection parameters to pass into `process_sweep_data(...).`
    """

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    from matplotlib.ticker import MultipleLocator
    import pyabf
    import pandas as pd
    import numpy as np
    import os

    # Ensure SVG text remains editable
    mpl.rcParams["svg.fonttype"] = "none"

    # 1) Load the Input–Output data from CSV (for column 2)
    try:
        peak_data = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"File not found: {csv_path}")
        return

    # 2) Validate that processor has data
    if processor.dataframe is None or processor.dataframe.empty:
        print("No data available in processor.dataframe.")
        return

    # 3) Determine which groups to plot
    groups_to_plot = list(group_recording_map.keys())
    n_groups = len(groups_to_plot)
    if n_groups == 0:
        print("No groups specified in group_recording_map. Nothing to plot.")
        return

    # 4) Create the figure with 4 columns: (Before, After, I–O, ECDF)
    fig_height = n_groups * fig_height_per_group
    fig = plt.figure(figsize=(fig_width, fig_height))
    gs = gridspec.GridSpec(n_groups, 4, figure=fig)
    plt.subplots_adjust(wspace=0.4, hspace=0.6)
    if transparent:
        fig.patch.set_facecolor('none')

    # Go group-by-group
    for i, group in enumerate(groups_to_plot):
        recording_id = group_recording_map.get(group)
        if not recording_id:
            print(f"No 'representative' recording specified for group '{group}'. Skipping...")
            continue

        # For the entire group, gather all rows
        group_data_all = processor.dataframe[processor.dataframe["Group"] == group]
        if group_data_all.empty:
            print(f"No data found at all for group '{group}'. Skipping row.")
            continue

        # --- Columns 0 & 1: Plot "representative" Before/After from group_recording_map ---
        rep_data = group_data_all[group_data_all["Recording_ID"] == recording_id]
        if rep_data.empty:
            print(f"Representative rec_id '{recording_id}' not found for group '{group}'.")
            print("We'll skip the 'Before/After' voltage subplots but still do the pooled ECDF.")
        else:
            try:
                before_file = rep_data[rep_data["Label"] == "Before"]["File_Path"].iloc[0]
                after_file  = rep_data[rep_data["Label"] == "After"]["File_Path"].iloc[0]
            except IndexError:
                print(f"Incomplete 'Before'/'After' for group '{group}', rec_id='{recording_id}'.")
                before_file, after_file = None, None

            # Column 0: "Before" trace(s)
            ax_before = fig.add_subplot(gs[i, 0])
            if before_file:
                before_abf = pyabf.ABF(before_file)
                for sweep_number in before_abf.sweepList[::2]:
                    before_abf.setSweep(sweep_number)
                    t = before_abf.sweepX
                    v = before_abf.sweepY
                    mask = (t >= time_window[0]) & (t <= time_window[1])
                    ax_before.plot(t[mask] * 1000, v[mask],
                                   color="gray", alpha=0.7, linewidth=0.8)
                ax_before.set_title(f"{group}\nRep: {recording_id}\nBefore", fontsize=8)
            else:
                ax_before.text(0.5, 0.5, "No 'Before' file", transform=ax_before.transAxes,
                               ha="center", va="center")
            ax_before.set_xlabel("Time (ms)", fontsize=6)
            ax_before.set_ylabel("Voltage (mV)", fontsize=6)
            ax_before.set_ylim(voltage_y_range)
            ax_before.xaxis.set_major_locator(MultipleLocator(200))

            # Column 1: "After" trace(s)
            ax_after = fig.add_subplot(gs[i, 1])
            if after_file:
                after_abf = pyabf.ABF(after_file)
                for sweep_number in after_abf.sweepList[::2]:
                    after_abf.setSweep(sweep_number)
                    t = after_abf.sweepX
                    v = after_abf.sweepY
                    mask = (t >= time_window[0]) & (t <= time_window[1])
                    ax_after.plot(t[mask] * 1000, v[mask],
                                  color="blue", alpha=0.7, linewidth=0.8)
                ax_after.set_title(f"{group}\nRep: {recording_id}\nAfter", fontsize=8)
            else:
                ax_after.text(0.5, 0.5, "No 'After' file", transform=ax_after.transAxes,
                              ha="center", va="center")
            ax_after.set_xlabel("Time (ms)", fontsize=6)
            ax_after.set_ylim(voltage_y_range)
            ax_after.xaxis.set_major_locator(MultipleLocator(200))

        # --- Column 2: The group-level I–O curve from the entire CSV for that group ---
        ax_io = fig.add_subplot(gs[i, 2])
        group_before_data = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "Before")]
        group_after_data  = peak_data[(peak_data["Group"] == group) & (peak_data["Label"] == "After")]
        plot_combined_io_curve(ax_io, group_before_data, group_after_data)
        ax_io.set_ylim(io_y_range)
        ax_io.set_title(f"I–O Curve: {group}", fontsize=8)

        # --- Column 3: Pooled ECDF across all recordings in THIS group ---
        ax_ecdf = fig.add_subplot(gs[i, 3])
        pooled_diff_before = []
        pooled_diff_after  = []

        # For each recording in THIS group, find "Before" / "After" ABFs and
        # gather fraction differences from all sweeps
        all_rec_ids = group_data_all["Recording_ID"].unique()
        for rec_id in all_rec_ids:
            sub_df = group_data_all[group_data_all["Recording_ID"] == rec_id]
            # We expect 1 "Before" row, 1 "After" row
            before_row = sub_df[sub_df["Label"] == "Before"]
            after_row  = sub_df[sub_df["Label"] == "After"]
            if before_row.empty or after_row.empty:
                # skip incomplete
                continue

            before_path = before_row["File_Path"].iloc[0]
            after_path  = after_row["File_Path"].iloc[0]

            # Load ABFs
            b_abf = pyabf.ABF(before_path)
            a_abf = pyabf.ABF(after_path)

            # Loop all sweeps for each ABF, measure fraction difference
            for sweep_number in b_abf.sweepList:
                # "Before" sweep
                time_b, volt_b, peaks_b, early_b, late_b = process_sweep_data(
                    b_abf,
                    sweep_number=sweep_number,
                    height=height,
                    prominence=prominence,
                    distance=distance,
                    width=width,
                    time_range=time_window,
                    early_phase=early_phase,
                    late_phase=late_phase
                )
                # "After" sweep (same sweep_number, typically)
                time_a, volt_a, peaks_a, early_a, late_a = process_sweep_data(
                    a_abf,
                    sweep_number=sweep_number,
                    height=height,
                    prominence=prominence,
                    distance=distance,
                    width=width,
                    time_range=time_window,
                    early_phase=early_phase,
                    late_phase=late_phase
                )

                # Build dict for fraction function
                bar_data = {
                    "sweep": [sweep_number],
                    "early_before": [early_b],
                    "late_before":  [late_b],
                    "early_after":  [early_a],
                    "late_after":   [late_a],
                }
                fractions = calculate_fractions(bar_data)
                diff_b = fractions["early_fraction_before"][0] - fractions["late_fraction_before"][0]
                diff_a = fractions["early_fraction_after"][0]  - fractions["late_fraction_after"][0]

                if not np.isnan(diff_b):
                    pooled_diff_before.append(diff_b)
                if not np.isnan(diff_a):
                    pooled_diff_after.append(diff_a)

        # Now we have big lists of fraction differences from ALL recordings in this group
        diff_before_sorted = np.sort(pooled_diff_before)
        diff_after_sorted  = np.sort(pooled_diff_after)

        # A local ECDF helper
        def ecdf(data, x_min=-1.0, x_max=1.0):
            if len(data) == 0:
                x = [x_min, x_max]
                y = [0.0, 0.0]
            else:
                y = np.arange(1, len(data) + 1) / len(data)
                x = np.concatenate([[x_min], data, [x_max]])
                y = np.concatenate([[0.0], y, [1.0]])
            return x, y

        x_b, y_b = ecdf(diff_before_sorted)
        x_a, y_a = ecdf(diff_after_sorted)

        ax_ecdf.step(x_b, y_b, color="gray", label="Before", where=step_mode)
        ax_ecdf.step(x_a, y_a, color="blue",  label="After",  where=step_mode)
        ax_ecdf.set_xlim(-1.0, 1.0)
        ax_ecdf.set_ylim(0, 1.0)
        ax_ecdf.set_xlabel("Frac. Diff (Early - Late)", fontsize=6)
        ax_ecdf.set_ylabel("Cumulative Probability",    fontsize=6)
        ax_ecdf.set_title("Pooled ECDF", fontsize=8)
        ax_ecdf.legend(fontsize=6)
        ax_ecdf.grid(True, linestyle="--", linewidth=0.5)

    # -- End for each group --

    # Save the final figure as SVG
    os.makedirs(os.path.dirname(output_svg_path), exist_ok=True)
    fig.savefig(output_svg_path, format="svg", transparent=transparent, bbox_inches="tight")
    plt.close(fig)

    print(f"Final figure saved to: {output_svg_path}")
    
    # Paths for data collection and stats

output_io_csv = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/io_curve_data.csv"
output_ecdf_csv = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/ecdf_fraction_data.csv"
output_stats_csv = "/Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/statistical_results.csv"

create_final_figure_by_group_appendededf(
    group_recording_map,
    processor,
    time_window=(0, 1.2),
    csv_path=csv_path,
    output_svg_path=output_svg_path,
    voltage_y_range=(-90, 50),
    io_y_range=(-10, 80), 
    early_phase=(0.1, 0.40),
    late_phase=(0.5, 1.12),
    step_mode="post",
    fig_width=8,
    fig_height_per_group=3,
    transparent=True
)



Final figure saved to: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/final_attempt.svg


In [109]:
save_io_ecdf_data(
    group_recording_map,
    processor,
    csv_path=csv_path,
    output_io_csv=output_io_csv,
    output_ecdf_csv=output_ecdf_csv
)

I-O data saved to: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/io_curve_data.csv
ECDF data saved to: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/ecdf_fraction_data.csv


In [108]:
# Run statistical analysis on the exact data used for plotting
stats_results = perform_statistical_tests(
    io_csv=output_io_csv,
    ecdf_csv=output_ecdf_csv,
    output_stats_csv=output_stats_csv
)

Skipping Wilcoxon for L + DUD-CTZ: Unequal or insufficient samples.
Statistical Test Results:
       Test         Group    N      Stat  Median_Diff  SD       P_Value
0  Wilcoxon    L + CS-CTZ  110       NaN          NaN NaN  4.954578e-09
1  Wilcoxon    L + CS-Veh   90       NaN          NaN NaN  3.117688e-01
2  Wilcoxon  L + ACR2-CTZ  110       NaN          NaN NaN  4.309037e-10
3   KS Test    L + CS-CTZ   48  0.375000          NaN NaN  2.134584e-03
4   KS Test    L + CS-Veh   48  0.145833          NaN NaN  6.926601e-01
5   KS Test  L + ACR2-CTZ   23  0.565217          NaN NaN  9.901949e-04
6   KS Test   L + DUD-CTZ   29  0.068966          NaN NaN  1.000000e+00


In [101]:
# Debug: Print path variables
print(f"csv_path: {csv_path}")
print(f"output_svg_path: {output_svg_path}")
print(f"output_io_csv: {output_io_csv}")
print(f"output_ecdf_csv: {output_ecdf_csv}")
print(f"output_stats_csv: {output_stats_csv}")

# Ensure none of the paths are empty
assert csv_path.strip() != "", "Error: csv_path is empty!"
assert output_svg_path.strip() != "", "Error: output_svg_path is empty!"
assert output_io_csv.strip() != "", "Error: output_io_csv is empty!"
assert output_ecdf_csv.strip() != "", "Error: output_ecdf_csv is empty!"
assert output_stats_csv.strip() != "", "Error: output_stats_csv is empty!"

csv_path: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/peak_data_modified.csv
output_svg_path: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/final_attempt.svg
output_io_csv: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/io_curve_data.csv
output_ecdf_csv: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/ecdf_fraction_data.csv
output_stats_csv: /Users/ecrespo/Desktop/BLADe_patch_data_output_optimized/statistical_results.csv


Making a GUI for counting spikes

In [12]:
import os
import tkinter as tk
from tkinter import ttk, messagebox
import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import pyabf
import pandas as pd
from scipy.signal import find_peaks

class DraggableThresholdLine:
    """
    A simpler line-drag tool (no concurrency management). 
    Uses direct "button_press_event" / "button_release_event" / "motion_notify_event" 
    and a small pixel-distance check around the line for user-friendly picking.
    """
    def __init__(self, ax, init_y, on_release_callback, on_drag_start=None, pick_tolerance_pixels=10):
        """
        Args:
            ax (Axes): The subplot on which to place the line.
            init_y (float): Initial Y position for the line.
            on_release_callback(callable): Called when user releases mouse, 
                with final y-value as argument.
            on_drag_start(callable): Called when user first picks the line.
            pick_tolerance_pixels(int): Pixel radius to consider a line "picked".
        """
        self.ax = ax
        self.canvas = ax.figure.canvas
        self.on_release_callback = on_release_callback
        self.on_drag_start = on_drag_start
        self.pick_tolerance = pick_tolerance_pixels

        # Create the line
        self.line = ax.axhline(y=init_y, color='r', linestyle='--', lw=1)

        # Drag state
        self.is_dragging = False

        # Event IDs
        self.cid_press   = self.canvas.mpl_connect("button_press_event",   self.on_press)
        self.cid_release = self.canvas.mpl_connect("button_release_event", self.on_release)
        self.cid_motion  = self.canvas.mpl_connect("motion_notify_event",  self.on_motion)

    def on_press(self, event):
        """Check if the user clicked near the line."""
        # Only consider left-click in the same axes
        if event.inaxes != self.ax:
            return
        if event.button != 1:  # left mouse button
            return

        # Convert the line's y-value to pixel coords
        y_line = self.line.get_ydata()[0]
        x_min, x_max = self.ax.get_xlim()
        y_min, y_max = self.ax.get_ylim()

        # Transform to display (pixel) space
        line_disp = self.ax.transData.transform((x_min, y_line))
        click_disp = self.ax.transData.transform((event.xdata, event.ydata))
        # We only compare the Y difference in display space 
        # (since it doesn't matter where horizontally you click)
        dist_pixels = abs(line_disp[1] - click_disp[1])

        if dist_pixels <= self.pick_tolerance:
            self.is_dragging = True
            if self.on_drag_start:
                self.on_drag_start()

    def on_motion(self, event):
        """While dragging, move the line's y."""
        if not self.is_dragging:
            return
        if event.inaxes != self.ax:
            return
        if event.button != 1:  # must keep left mouse pressed
            return

        new_y = event.ydata
        self.line.set_ydata([new_y, new_y])
        self.canvas.draw()

    def on_release(self, event):
        """Finalize the drag on mouse release."""
        if not self.is_dragging:
            return
        if event.button != 1:  # left mouse
            return

        self.is_dragging = False
        final_y = self.line.get_ydata()[0]
        if self.on_release_callback:
            self.on_release_callback(final_y)

class ManualPeakCounterGUI:
    """
    A simpler user-friendly GUI:
      - Each threshold line uses DraggableThresholdLine (no concurrency manager).
      - We do a confirm/unconfirm approach for each sweep.
      - We store peak counts & indices in the final CSV.
      - We always run detection on both subplots at load-time 
        so you see spikes right away.
    """

    def __init__(self, master, dataframe, output_csv="manual_peak_counts.csv"):
        self.master = master
        self.master.title("Simple Manual Peak Counter (User-Friendly)")

        self.dataframe = dataframe.copy()
        self.output_csv = output_csv

        # Sort data
        self.dataframe.sort_values(by=["Group", "Recording_ID"], inplace=True)

        # Build list of (group, rec_id)
        self.group_rec_pairs = []
        for group in self.dataframe["Group"].unique():
            sub = self.dataframe[self.dataframe["Group"] == group]
            for rec_id in sub["Recording_ID"].unique():
                self.group_rec_pairs.append((group, rec_id))

        self.current_cell_index  = 0
        self.current_sweep_index = 0

        self.before_abf = None
        self.after_abf  = None
        self.before_sweep_count = 0
        self.after_sweep_count  = 0

        # results dict -> store threshold + counts + indices
        # results[(group, rec_id)] = {
        #   "before_counts": [...],
        #   "before_thresholds": [...],
        #   "before_indices": [...],
        #   "after_counts": [...],
        #   "after_thresholds": [...],
        #   "after_indices": [...]
        # }
        self.results = {}

        # Confirmation dictionary
        # confirm_dict[(group, rec_id, sweep_idx)] = bool
        self.confirm_dict = {}

        self.create_widgets()
        self.load_current_cell()

    def create_widgets(self):
        # Top bar: nav + zero
        top_frame = ttk.Frame(self.master)
        top_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        btn_prev_cell = ttk.Button(top_frame, text="Prev Cell", command=self.go_to_prev_cell)
        btn_prev_cell.pack(side=tk.LEFT, padx=3)
        btn_next_cell = ttk.Button(top_frame, text="Next Cell", command=self.go_to_next_cell)
        btn_next_cell.pack(side=tk.LEFT, padx=3)

        btn_prev_sweep = ttk.Button(top_frame, text="Prev Sweep", command=self.go_to_prev_sweep)
        btn_prev_sweep.pack(side=tk.LEFT, padx=3)
        btn_next_sweep = ttk.Button(top_frame, text="Next Sweep", command=self.go_to_next_sweep)
        btn_next_sweep.pack(side=tk.LEFT, padx=3)

        btn_b_zero = ttk.Button(top_frame, text="Set Before=0", command=self.set_before_zero)
        btn_b_zero.pack(side=tk.LEFT, padx=3)
        btn_a_zero = ttk.Button(top_frame, text="Set After=0", command=self.set_after_zero)
        btn_a_zero.pack(side=tk.LEFT, padx=3)

        mid_frame = ttk.Frame(self.master)
        mid_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        self.confirm_var = tk.BooleanVar(value=False)
        self.chk_confirm = ttk.Checkbutton(
            mid_frame, text="Confirm Current Sweep",
            variable=self.confirm_var,
            command=self.on_confirm_toggle
        )
        self.chk_confirm.pack(side=tk.LEFT, padx=5)

        # Figure area
        self.fig_frame = ttk.Frame(self.master)
        self.fig_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Status label
        self.status_label = ttk.Label(self.master, text="No data yet.")
        self.status_label.pack(side=tk.TOP, padx=5, pady=5)

        # Save & Quit
        bottom_frame = ttk.Frame(self.master)
        bottom_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        btn_save_quit = ttk.Button(bottom_frame, text="Save & Quit", command=self.save_and_quit)
        btn_save_quit.pack(side=tk.LEFT, padx=5)

        self.fig = None
        self.canvas = None

    # ---------------------------------------------------------------
    # Navigation
    # ---------------------------------------------------------------
    def go_to_prev_cell(self):
        if not self.allow_navigation():
            return
        self.current_cell_index -= 1
        if self.current_cell_index < 0:
            self.current_cell_index = 0
        self.load_current_cell()

    def go_to_next_cell(self):
        if not self.allow_navigation():
            return
        self.current_cell_index += 1
        if self.current_cell_index >= len(self.group_rec_pairs):
            self.current_cell_index = len(self.group_rec_pairs) - 1
        self.load_current_cell()

    def go_to_prev_sweep(self):
        if not self.allow_navigation():
            return
        self.current_sweep_index -= 1
        if self.current_sweep_index < 0:
            self.current_sweep_index = 0
        self.show_current_sweep()

    def go_to_next_sweep(self):
        if not self.allow_navigation():
            return
        self.current_sweep_index += 1
        max_sw = max(self.before_sweep_count, self.after_sweep_count)
        if self.current_sweep_index >= max_sw:
            self.current_sweep_index = max_sw - 1
        self.show_current_sweep()

    def allow_navigation(self):
        """Block nav unless user has confirmed this sweep."""
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        cval = self.confirm_dict.get((group, rec_id, self.current_sweep_index), False)
        if not cval:
            messagebox.showwarning("Sweep Not Confirmed",
                "Please confirm the current sweep before navigating away.")
            return False
        return True

    # ---------------------------------------------------------------
    # Confirm / Unconfirm
    # ---------------------------------------------------------------
    def on_confirm_toggle(self):
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        idx = self.current_sweep_index
        val = self.confirm_var.get()
        self.confirm_dict[(group, rec_id, idx)] = val

    def unconfirm_current_sweep(self):
        """If user drags or sets zero, unconfirm."""
        self.confirm_var.set(False)
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        idx = self.current_sweep_index
        self.confirm_dict[(group, rec_id, idx)] = False

    # ---------------------------------------------------------------
    # Loading a cell & sweeps
    # ---------------------------------------------------------------
    def load_current_cell(self):
        if self.current_cell_index < 0:
            self.current_cell_index = 0
        if self.current_cell_index >= len(self.group_rec_pairs):
            self.current_cell_index = len(self.group_rec_pairs)-1

        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        sdf = self.dataframe[
            (self.dataframe["Group"] == group) &
            (self.dataframe["Recording_ID"] == rec_id)
        ]
        before_entry = sdf[sdf["Label"] == "Before"]
        after_entry  = sdf[sdf["Label"] == "After"]
        if before_entry.empty or after_entry.empty:
            print(f"Incomplete pair for {group}, {rec_id}")
            return

        before_file = before_entry["File_Path"].iloc[0]
        after_file  = after_entry["File_Path"].iloc[0]

        self.before_abf = pyabf.ABF(before_file)
        self.after_abf  = pyabf.ABF(after_file)

        self.before_sweep_count = len(self.before_abf.sweepList)
        self.after_sweep_count  = len(self.after_abf.sweepList)

        if (group, rec_id) not in self.results:
            self.results[(group, rec_id)] = {
                "before_counts":    [0]*self.before_sweep_count,
                "before_thresholds":[None]*self.before_sweep_count,
                "before_indices":   [None]*self.before_sweep_count,
                "after_counts":     [0]*self.after_sweep_count,
                "after_thresholds": [None]*self.after_sweep_count,
                "after_indices":    [None]*self.after_sweep_count,
            }

        self.current_sweep_index = 0
        self.show_current_sweep()

    def show_current_sweep(self):
        # Destroy old figure
        if self.fig and self.canvas:
            self.canvas.get_tk_widget().destroy()
            plt.close(self.fig)

        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        b_idx = min(self.current_sweep_index, self.before_sweep_count-1)
        a_idx = min(self.current_sweep_index, self.after_sweep_count-1)

        self.fig = Figure(figsize=(8,6))
        axB = self.fig.add_subplot(2,1,1)
        axA = self.fig.add_subplot(2,1,2)
        self.fig.suptitle(f"{group} - {rec_id} | Sweep {self.current_sweep_index}")

        # Plot BEFORE
        self.before_abf.setSweep(b_idx)
        tB, vB = self.before_abf.sweepX, self.before_abf.sweepY
        axB.plot(tB, vB, "C0")
        axB.set_title(f"Before (sweep {b_idx})")

        # Plot AFTER
        self.after_abf.setSweep(a_idx)
        tA, vA = self.after_abf.sweepX, self.after_abf.sweepY
        axA.plot(tA, vA, "C1")
        axA.set_title(f"After (sweep {a_idx})")
        axA.set_xlabel("Time (s)")

        b_thresh = data["before_thresholds"][b_idx]
        if b_thresh is None:
            b_thresh = np.mean(vB) + 0.5*np.std(vB)
        a_thresh = data["after_thresholds"][a_idx]
        if a_thresh is None:
            a_thresh = np.mean(vA) + 0.5*np.std(vA)

        # Make draggable lines
        # If user picks line => unconfirm
        self.line_before = DraggableThresholdLine(
            axB, b_thresh,
            on_release_callback=lambda val: self.update_peak_counts("before", val),
            on_drag_start=self.unconfirm_current_sweep,
            pick_tolerance_pixels=10
        )
        self.line_after = DraggableThresholdLine(
            axA, a_thresh,
            on_release_callback=lambda val: self.update_peak_counts("after", val),
            on_drag_start=self.unconfirm_current_sweep,
            pick_tolerance_pixels=10
        )

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.fig_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Always do detection on load
        self.update_peak_counts("before", b_thresh, replot=False)
        self.update_peak_counts("after",  a_thresh, replot=False)
        self.draw_peaks()

        # Reflect confirm
        is_conf = self.confirm_dict.get((group, rec_id, self.current_sweep_index), False)
        self.confirm_var.set(is_conf)
        self.update_status_label()

    # ---------------------------------------------------------------
    # Peak counting
    # ---------------------------------------------------------------
    def update_peak_counts(self, subplot, threshold, replot=True):
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        b_idx = min(self.current_sweep_index, self.before_sweep_count-1)
        a_idx = min(self.current_sweep_index, self.after_sweep_count-1)

        if subplot == "before":
            self.before_abf.setSweep(b_idx)
            volt = self.before_abf.sweepY
            pk, _ = find_peaks(volt, height=threshold)
            data["before_counts"][b_idx]     = len(pk)
            data["before_thresholds"][b_idx] = threshold
            data["before_indices"][b_idx]    = pk
        else:
            self.after_abf.setSweep(a_idx)
            volt = self.after_abf.sweepY
            pk, _ = find_peaks(volt, height=threshold)
            data["after_counts"][a_idx]     = len(pk)
            data["after_thresholds"][a_idx] = threshold
            data["after_indices"][a_idx]    = pk

        if replot:
            self.draw_peaks()
        self.update_status_label()

    def draw_peaks(self):
        """Re-plot both subplots with the final thresholds & peaks."""
        if not self.fig or not self.canvas:
            return

        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        b_idx = min(self.current_sweep_index, self.before_sweep_count-1)
        a_idx = min(self.current_sweep_index, self.after_sweep_count-1)

        axB, axA = self.fig.axes
        axB.clear()
        axA.clear()

        # BEFORE
        self.before_abf.setSweep(b_idx)
        tB, vB = self.before_abf.sweepX, self.before_abf.sweepY
        axB.plot(tB, vB, 'C0')
        b_thr = data["before_thresholds"][b_idx]
        b_pks = data["before_indices"][b_idx]
        if b_thr is not None:
            axB.axhline(b_thr, color='r', linestyle='--', lw=1)
        if b_pks is not None and len(b_pks) > 0:
            axB.plot(tB[b_pks], vB[b_pks], 'rx')
        axB.set_title(f"Before (sweep {b_idx})")

        # AFTER
        self.after_abf.setSweep(a_idx)
        tA, vA = self.after_abf.sweepX, self.after_abf.sweepY
        axA.plot(tA, vA, 'C1')
        a_thr = data["after_thresholds"][a_idx]
        a_pks = data["after_indices"][a_idx]
        if a_thr is not None:
            axA.axhline(a_thr, color='r', linestyle='--', lw=1)
        if a_pks is not None and len(a_pks) > 0:
            axA.plot(tA[a_pks], vA[a_pks], 'rx')
        axA.set_title(f"After (sweep {a_idx})")
        axA.set_xlabel("Time (s)")

        self.canvas.draw()

    # ---------------------------------------------------------------
    # Zeroing
    # ---------------------------------------------------------------
    def set_before_zero(self):
        self.unconfirm_current_sweep()
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        b_idx = min(self.current_sweep_index, self.before_sweep_count-1)
        data["before_counts"][b_idx] = 0
        data["before_thresholds"][b_idx] = None
        data["before_indices"][b_idx] = []
        self.draw_peaks()
        self.update_status_label()

    def set_after_zero(self):
        self.unconfirm_current_sweep()
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        a_idx = min(self.current_sweep_index, self.after_sweep_count-1)
        data["after_counts"][a_idx] = 0
        data["after_thresholds"][a_idx] = None
        data["after_indices"][a_idx] = []
        self.draw_peaks()
        self.update_status_label()

    # ---------------------------------------------------------------
    # Status & Saving
    # ---------------------------------------------------------------
    def update_status_label(self):
        group, rec_id = self.group_rec_pairs[self.current_cell_index]
        data = self.results[(group, rec_id)]
        bc = data["before_counts"]
        ac = data["after_counts"]
        b_str = ','.join(str(x) for x in bc)
        a_str = ','.join(str(x) for x in ac)
        txt = (f"Cell: {group}-{rec_id}\n"
               f"Before counts: [{b_str}]\n"
               f"After counts : [{a_str}]")
        self.status_label.config(text=txt)

    def save_and_quit(self):
        rows = []
        for (group, rec_id), val in self.results.items():
            bc = val["before_counts"]
            bt = val["before_thresholds"]
            bi = val["before_indices"]
            ac = val["after_counts"]
            at = val["after_thresholds"]
            ai = val["after_indices"]

            max_len = max(len(bc), len(ac))
            for i in range(max_len):
                # Convert arrays to strings, e.g. "[10, 22, 55]"
                bi_str = str(list(bi[i])) if bi[i] is not None else "[]"
                ai_str = str(list(ai[i])) if ai[i] is not None else "[]"

                row = {
                    "Group": group,
                    "Recording_ID": rec_id,
                    "Sweep_Index": i,
                    "Manual_Before_Count": bc[i],
                    "Before_Threshold":    bt[i],
                    "Before_Peak_Indices": bi_str,
                    "Manual_After_Count":  ac[i],
                    "After_Threshold":     at[i],
                    "After_Peak_Indices":  ai_str
                }
                rows.append(row)

        out_df = pd.DataFrame(rows)
        out_df.to_csv(self.output_csv, index=False)
        print(f"Saved {len(rows)} rows to {self.output_csv}")
        self.master.quit()


# -------------------
# Example usage
# -------------------
if __name__ == "__main__":
    # Example mock DataFrame

    df = pd.DataFrame(processor.dataframe)

    root = tk.Tk()
    app = ManualPeakCounterGUI(root, df, output_csv="manual_peak_counts.csv")
    root.mainloop()