<a href="https://colab.research.google.com/github/dtabuena/Patch_Ephys/blob/main/IV_analyzer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def IV_analyzer_v3(abf,Na_window=[.2, 3],K_window=[40,50],to_plot=True,figopt={'type':'jpg'}):

    """
    Measure raw Na-peak and K steady-state currents per sweep and build an I-V.

    Steps:
        1) Na: inside Na_window, find the sample that deviates most (either
           sign) from the window median and report its raw current value.
        2) K: report the median current inside K_window.
        3) If to_plot: three panels (Na window +/- 20 % margin, whole K sweeps
           zoomed on the K window, I-V with lines at x=0 / y=0 and no legend),
           saved to Saved_Figs/IV_Curves/IV_Curves_<abfID>.<figopt['type']>.

    Sweeps whose last current sample exceeds 1000 pA in magnitude are treated
    as blown and skipped.

    Parameters:
        abf: pyabf-like recording; channel 0 is read as current, channel 1 is
            read to obtain the holding potential.
        Na_window, K_window: [start, stop] in ms, relative to the stimulus
            onset reported by protocol_baseline_and_stim. Not mutated (kept as
            lists because the result keys embed their repr, e.g. "Na_[0.2, 3]").
        to_plot: build, save and show the figure.
        figopt: dict with 'type' (file extension) and optional 'dpi'.

    Returns:
        results = {
            f"Na_{Na_window}": { voltage: raw_peak, ... },
            f"K_{K_window}":   { voltage: raw_median, ... }
        }
        Voltages are command potentials rounded to the nearest 10 mV; if two
        sweeps round to the same voltage, the later sweep overwrites the earlier.
    """

    is_base, is_stim = protocol_baseline_and_stim(abf)
    # Stimulus onset (ms); the user windows are offsets from it.
    t0_relative = abf.sweepX[np.where(is_stim)[0][0]] * 1000
    na_start, na_stop = [x + t0_relative for x in Na_window]
    k_start, k_stop = [x + t0_relative for x in K_window]

    # Absolute analysis windows in seconds, matching abf.sweepX.
    na_start_s, na_stop_s = na_start / 1000, na_stop / 1000
    k_start_s, k_stop_s = k_start / 1000, k_stop / 1000
    na_delta_s = na_stop_s - na_start_s
    k_delta_ms = k_stop - k_start

    # Holding potential is taken from baseline samples only: a median over the
    # whole voltage trace drifts to the test potential when the step lasts
    # longer than the baseline.
    hold_mask = is_base if np.any(is_base) else slice(None)

    def _command_voltage(sweep, mask):
        # Command potential (mV, rounded to 10 mV) of `sweep` within `mask`.
        # Leaves the ABF on channel 0 of `sweep`.
        abf.setSweep(sweep, channel=1)
        v_hold = np.median(abf.sweepY[hold_mask])
        abf.setSweep(sweep, channel=0)
        command_v = np.median(abf.sweepC[mask] + v_hold)
        return round(command_v / 10) * 10

    # ---- Create figure with 3 subplots: (Na-window, K, IV) ----
    if to_plot:
        fig, (ax_na, ax_k, ax_iv) = plt.subplots(1, 3, figsize=(8, 3))
    else:
        ax_na = ax_k = ax_iv = None

    # ----------------- NA MEASUREMENT & PLOT (SUBPLOT 1) ------------------ #
    # The Na panel shows the window padded by 20 % of its width on each side.
    margin_fraction = 0.2
    na_plot_start_s = na_start_s - margin_fraction * na_delta_s
    na_plot_stop_s = na_stop_s + margin_fraction * na_delta_s

    na_peaks, na_voltages = [], []
    for s in abf.sweepList:
        abf.setSweep(s, channel=0)
        if abs(abf.sweepY[-1]) > 1000:  # blown sweep
            continue
        time = abf.sweepX
        current = abf.sweepY

        na_mask = (time >= na_start_s) & (time <= na_stop_s)
        if not np.any(na_mask):
            continue
        I_analysis = current[na_mask]
        t_analysis = time[na_mask]

        # The window median is only a reference for locating the largest
        # excursion (either sign); the raw current at that sample is reported.
        peak_idx = np.argmax(np.abs(I_analysis - np.median(I_analysis)))
        raw_peak_value = I_analysis[peak_idx]
        raw_peak_time = t_analysis[peak_idx]

        na_peaks.append(raw_peak_value)
        na_voltages.append(_command_voltage(s, na_mask))

        if ax_na is not None:
            na_plot_mask = (time >= na_plot_start_s) & (time <= na_plot_stop_s)
            ax_na.plot(time[na_plot_mask]*1000, current[na_plot_mask], 'k')
            ax_na.scatter(raw_peak_time*1000, raw_peak_value, color='m', zorder=5)

    if ax_na is not None:
        ax_na.set_xlim([na_plot_start_s*1000, na_plot_stop_s*1000])
        ax_na.set_xlabel("Time (ms)")
        ax_na.set_ylabel("I (pA)")
        ax_na.set_title(f"Na Window\n({Na_window[0]}-{Na_window[1]} ms)")

    # ----------------- K MEASUREMENT & PLOT (SUBPLOT 2) ------------------ #
    k_means, k_voltages = [], []
    k_mid_ms = (k_start + k_stop) / 2
    for s in abf.sweepList:
        abf.setSweep(s, channel=0)
        if abs(abf.sweepY[-1]) > 1000:  # blown sweep
            continue
        time = abf.sweepX
        current = abf.sweepY

        k_mask = (time >= k_start_s) & (time <= k_stop_s)
        if not np.any(k_mask):
            continue
        # Median (robust stand-in for the mean) of the sustained current.
        I_mean = np.median(current[k_mask])

        k_means.append(I_mean)
        k_voltages.append(_command_voltage(s, k_mask))

        if ax_k is not None:
            ax_k.plot(time*1000, current, 'k')
            ax_k.scatter(k_mid_ms, I_mean, color='c', zorder=5)

    if ax_k is not None:
        # Traces are drawn on the absolute time axis (ms since sweep start),
        # so zoom on the absolute K window padded by half its width per side.
        ax_k.set_xlim([k_start - k_delta_ms/2, k_stop + k_delta_ms/2])
        ax_k.set_xlabel("Time (ms)")
        ax_k.set_ylabel("I (pA)")
        ax_k.set_title(f"K Window\n({K_window[0]}-{K_window[1]} ms)")

    # --------------------- BUILD RESULTS DICT (raw values) --------------------- #
    key_na = f"Na_{Na_window}"
    key_k = f"K_{K_window}"
    results = {
        key_na: dict(zip(na_voltages, na_peaks)),
        key_k: dict(zip(k_voltages, k_means)),
    }

    # --------------------- I–V PLOT (SUBPLOT 3) --------------------- #
    if ax_iv is not None:
        ax_iv.set_title("I–V")
        ax_iv.set_xlabel("Voltage (mV)")
        ax_iv.set_ylabel("I (pA)")
        # Sorted by voltage for a clean line; no legend is drawn.
        for key, color in ((key_na, 'm'), (key_k, 'c')):
            pairs = sorted(results[key].items())
            ax_iv.plot([v for v, _ in pairs], [i for _, i in pairs], '-o', color=color)
        ax_iv.axhline(0, color='k', linewidth=.25)
        ax_iv.axvline(0, color='k', linewidth=.25)

    # ----------- Adjust layout, Save, and Return ----------- #
    if to_plot:
        plt.tight_layout()
        os.makedirs('Saved_Figs/IV_Curves/', exist_ok=True)
        save_name = f"Saved_Figs/IV_Curves/IV_Curves_{abf.abfID}.{figopt['type']}"
        plt.savefig(save_name, dpi=figopt.get('dpi'))
        plt.show()

    return results


In [None]:
def movmean(x, n=3):
    """Return the n-point moving average of `x`.

    Implemented as a convolution with a flat kernel in numpy's 'same' mode:
    the output has length max(len(x), n), and samples near either edge are
    averaged against implicit zeros, so the edges are pulled toward 0.
    """
    flat_kernel = np.full(n, 1.0 / n)
    smoothed = np.convolve(x, flat_kernel, mode='same')
    return smoothed

In [None]:
def pseudo_pn_subtraction(abf,to_plot=False):
    """Estimate a P/N leak subtraction from the smallest voltage steps.

    For every sweep the raw time/current/command traces are stored, together
    with the signed step amplitude 'stim' (the extreme of sweepC with the
    largest magnitude) and 'Vm' = stim + holding potential. The holding
    potential is the first sample of channel 1 on sweep 0, rounded to 10 mV.

    With at least 7 sweeps and at least one non-zero step, every sweep is
    assigned N = stim / delta, where delta is the smallest non-zero |stim|.
    The sweeps with |N| == 1 are sign-normalised, zeroed at their first
    sample and averaged into a template. Each sweep's 'sweepY_pn' is then
    sweepY - N * template, re-zeroed at its first sample.

    Parameters:
        abf: pyabf-like recording (channel 0 read for current/command,
            channel 1 read for the holding potential).
        to_plot: draw one panel per sweep (raw, corrected, correction).

    Returns:
        DataFrame indexed by sweep with columns sweepX, sweepY, sweepY_pn,
        sweepC, stim, N, Vm. With fewer than 7 sweeps, or when every step is
        zero, N and sweepY_pn are left as None.
    """
    sweep_df = pd.DataFrame({'sweep':abf.sweepList}).set_index('sweep')
    for col in ('sweepX', 'sweepY', 'sweepY_pn', 'sweepC', 'stim', 'N'):
        sweep_df[col] = None

    # Measured holding potential, rounded to 10 mV.
    abf.setSweep(sweepNumber=0, channel=1)
    v_hold = round(abf.sweepY[0]/10)*10
    abf.setSweep(sweepNumber=0, channel=0)

    for s in abf.sweepList:
        abf.setSweep(s)
        sweep_df.at[s,'sweepX'] = abf.sweepX
        sweep_df.at[s,'sweepY'] = abf.sweepY
        sweep_df.at[s,'sweepC'] = abf.sweepC
        # Signed step amplitude: whichever command extreme is larger in magnitude.
        stim = max(abf.sweepC.min(), abf.sweepC.max(), key=abs)
        sweep_df.at[s,'stim'] = stim
        sweep_df.at[s,'Vm'] = stim + v_hold

    # Too few sweeps for a reliable template: return the raw data only.
    if len(abf.sweepList)<7:
        return sweep_df

    nonzero_steps = [abs(v) for v in sweep_df['stim'] if v != 0]
    if not nonzero_steps:
        # No step to scale by; nothing to subtract.
        return sweep_df
    delta = np.min(nonzero_steps)

    sweep_df['N'] = sweep_df['stim']/delta

    # The +/-1 pulses form the template; multiplying by N flips negative
    # pulses so all of them average coherently.
    p_list = []
    for s, n in sweep_df['N'].items():
        if abs(n) == 1:
            p_trace = sweep_df.at[s,'sweepY'] * n
            p_list.append(p_trace - p_trace[0])
    p_val = np.mean(np.stack(p_list, axis=0), axis=0)

    if to_plot:
        fig, ax = plt.subplots(len(abf.sweepList),figsize=(2.5, 50))
    x_lim = [.015, 0.02]
    for i, s in enumerate(abf.sweepList):
        correction = p_val * sweep_df.at[s,'N']
        corrected = sweep_df.at[s,'sweepY'] - correction
        sweep_df.at[s,'sweepY_pn'] = corrected - corrected[0]
        if to_plot:
            ax[i].plot(sweep_df.loc[s,'sweepX'],sweep_df.loc[s,'sweepY'],'k')
            ax[i].plot(sweep_df.loc[s,'sweepX'],sweep_df.loc[s,'sweepY_pn'],'m')
            ax[i].plot(sweep_df.loc[s,'sweepX'],correction,'c')
            ax[i].set_xlim(*x_lim)
            ax[i].set_title( sweep_df.loc[s,'stim'] )
    if to_plot:
        plt.show()

    return sweep_df



def iv_analysis_df(sweep_df, abf,measure_windows={'IV_Early':(16, 35),'IV_Steady_State':(100,120)},to_plot=False,pn=True,figopt={'type':'jpg','dpi':300}):
    """Measure negative peak and median current per sweep in each time window.

    Parameters:
        sweep_df: DataFrame indexed by sweep with 'sweepX' (s), 'sweepY',
            'sweepY_pn' and 'Vm' columns, as built by pseudo_pn_subtraction.
        abf: recording; only abf.abfID is used, for the saved figure name.
        measure_windows: {name: (start_ms, stop_ms)} with absolute times in
            ms. Not mutated.
        to_plot: draw traces (top row) and I-V curves (bottom row) and save
            them to Saved_Figs/IV_Curves/IV_Curves_<abfID>.<figopt['type']>.
        pn: measure the P/N-corrected trace ('sweepY_pn') instead of 'sweepY'.
        figopt: dict with 'type' (file extension) and optional 'dpi'.

    Returns:
        {name: {'range': window, 'I_peak': [...], 'I_mean': [...],
                'V_stim': [...]}}. 'I_mean' holds window medians. Sweeps whose
        last sample exceeds 1000 or whose |current| exceeds 10000 anywhere
        are skipped, so the three lists stay aligned.

    Raises:
        ValueError: if a window contains no samples of sweepX.
    """
    y_col = 'sweepY_pn' if pn else 'sweepY'
    time_s = sweep_df['sweepX'].iloc[0]

    if to_plot:
        # squeeze=False keeps `axs` 2-D even when there is a single window.
        fig, axs = plt.subplots(2, len(measure_windows), figsize=[4.5, 3], squeeze=False)

    results = {}
    for w_i, (w, w_range) in enumerate(measure_windows.items()):
        w_range_s = np.array(w_range)/1000
        # Plot the window padded by half its width on each side.
        plot_range = (np.diff(w_range_s) * np.array([-1, 1]) * .5) + w_range_s
        in_window_i = np.flatnonzero((time_s >= w_range_s[0]) & (time_s <= w_range_s[1]))
        if in_window_i.size == 0:
            raise ValueError(f"measurement window {w} {w_range} ms contains no samples")

        peak_list = []
        median_list = []
        stim_list = []

        for s in sweep_df.index:
            trace = sweep_df.loc[s, y_col]
            # Bad sweep (lost seal or oscillation): skip it.
            if abs(trace[-1]) > 1000 or np.max(np.abs(trace)) > 10000:
                continue

            window = trace[in_window_i]
            # Peak index is taken within the window, not the whole trace.
            peak_ind = in_window_i[np.argmin(window)]
            neg_peak = trace[peak_ind]
            wind_median = np.median(window)

            if to_plot:
                sweep_x = sweep_df.loc[s, 'sweepX']
                ax = axs[0, w_i]
                ax.plot(sweep_x, trace, 'k')
                ax.set_xlim(*plot_range)
                ax.scatter(sweep_x[peak_ind], trace[peak_ind], color='m')
                ax.plot(sweep_x[in_window_i], np.full(in_window_i.shape, wind_median), 'turquoise')

            peak_list.append(neg_peak)
            median_list.append(wind_median)
            # Appended in the same iteration as the responses to keep them aligned.
            stim_list.append(sweep_df.loc[s, 'Vm'])

        results[w] = {'range': measure_windows[w],
                      'I_peak': peak_list,
                      'I_mean': median_list,
                      'V_stim': stim_list}

        if to_plot:
            axs[1, w_i].spines['left'].set_position('zero')
            axs[1, w_i].spines['bottom'].set_position('zero')
            axs[1, w_i].plot(stim_list, peak_list, '-o', color='m')
            axs[1, w_i].plot(stim_list, median_list, '-o', color='turquoise')

    if to_plot:
        os.makedirs('Saved_Figs/IV_Curves/', exist_ok=True)
        plt.savefig(f"Saved_Figs/IV_Curves/IV_Curves_{abf.abfID}.{figopt['type']}", dpi=figopt.get('dpi'))
        plt.show()
    return results




def check_wc_comp(abf,cut_off=40,verbose =False,to_plot=False):
    """Compare recorded and commanded voltage around the end of the step on sweep 0.

    The window runs from 0.5 ms before to 1 ms after the first falling edge of
    the command step, i.e. where sweepC returns to its first value. Inside it,
    the recorded voltage (channel 1, referenced to its first sample) is
    compared with the command waveform (sweepC).

    Parameters:
        abf: pyabf-like recording with sampleRate, setSweep, sweepY, sweepC.
        cut_off: threshold on the summed absolute mismatch.
        verbose: print the summed mismatch.
        to_plot: plot the referenced recorded voltage and the command.

    Returns:
        (is_compensated, sum_delta), where sum_delta is
        sum(|(V_rec - V_rec[0]) - sweepC|) over the window and
        is_compensated = sum_delta > cut_off. Returns ('unknown', nan) when
        sweep 0 has no falling step edge or the window falls outside the sweep.
        The ABF is left on sweep 0, channel 0.
    """
    abf.setSweep(sweepNumber=0, channel=0)
    sweepC = np.asarray(abf.sweepC)
    abf.setSweep(sweepNumber=0, channel=1)
    command_real = np.asarray(abf.sweepY)
    abf.setSweep(sweepNumber=0, channel=0)

    # 1 wherever the command differs from its initial (holding) value.
    is_stim = (sweepC != sweepC[0]).astype(int)
    offsets = np.flatnonzero(np.diff(is_stim, prepend=is_stim[0]) == -1)
    if offsets.size == 0:
        return 'unknown', np.nan
    offset_ind = offsets[0]

    window_ms = [-0.5, 1]
    window_I = np.arange(*[int(w*abf.sampleRate/1000) for w in window_ms])
    offset_range = offset_ind + window_I
    # Negative indices would silently wrap to the end of the sweep.
    if offset_range.size == 0 or offset_range[0] < 0 or offset_range[-1] >= len(sweepC):
        return 'unknown', np.nan

    command_real_bal = command_real[offset_range] - command_real[0]
    command_window = sweepC[offset_range]
    sum_delta = np.sum(np.abs(command_real_bal - command_window))

    if to_plot:
        plt.plot(command_real_bal)
        plt.plot(command_window)
        plt.show()

    if verbose:
        print(sum_delta)
    is_compensated = sum_delta > cut_off
    return is_compensated, sum_delta


# abf = abf_or_name('my_ephys_data/2022_08_15/2022x08x15_RNF182_E4KI_F_P251_s001_c001_CA3xPOS_0002.abf')
# _ = IV_analyzer_v2(abf)

In [None]:
def collect_valid_sweeps(abf, is_base, leak_threshold):
    """
    Collect all valid sweeps, filtering out blown seals and high leak.

    A sweep is rejected when its last current sample exceeds 1000 pA in
    magnitude (blown seal) or when its median baseline current is below
    leak_threshold.

    Parameters:
        abf: pyabf-like object (channel 0 read as current, channel 1 read to
            obtain the holding potential).
        is_base: boolean baseline mask from protocol_baseline_and_stim, the
            same length as the sweep.
        leak_threshold: float, minimum acceptable baseline current in pA.

    Returns:
        valid_sweeps: list of dicts with keys 'sweep', 'time', 'current'
            (copies of the channel-0 arrays), 'voltage' (median command
            potential outside the baseline, rounded to 10 mV) and 'v_hold'
            (median channel-1 value over the baseline).
    """
    # Holding potential comes from baseline samples only: a median over the
    # whole trace drifts to the test potential when the step is longer than
    # the baseline. Fall back to the whole trace if the mask is empty.
    hold_mask = is_base if np.any(is_base) else slice(None)
    valid_sweeps = []

    for s in abf.sweepList:
        abf.setSweep(s, channel=0)

        # Skip blown seals
        if abs(abf.sweepY[-1]) > 1000:
            continue

        # Skip sweeps with excessive (too negative) baseline leak current
        if np.median(abf.sweepY[is_base]) < leak_threshold:
            continue

        # Command voltage = relative command waveform + measured holding
        abf.setSweep(s, channel=1)
        v_hold = np.median(abf.sweepY[hold_mask])
        abf.setSweep(s, channel=0)
        command_v = np.median(abf.sweepC[~is_base] + v_hold)
        command_v = round(command_v / 10) * 10

        valid_sweeps.append({
            'sweep': s,
            'time': abf.sweepX.copy(),
            'current': abf.sweepY.copy(),
            'voltage': command_v,
            'v_hold': v_hold
        })

    return valid_sweeps


In [None]:
def IV_analyzer_v4(abf, Na_window=[.2, 3], K_window=[40,50], to_plot=True,
                   figopt={'type':'jpg'}, leak_threshold=-200, use_PN=True, n_pulses=3):

    """
    Build Na-peak and K steady-state I-V relations from leak-filtered sweeps.

    Steps:
        1) collect_valid_sweeps drops blown sweeps and sweeps whose baseline
           current is below leak_threshold.
        2) If use_PN and at least one sweep survived, PN_subtraction removes a
           scaled leak template built from the n_pulses most negative sweeps.
        3) Na: inside Na_window, the sample that deviates most (either sign)
           from the window median.
        4) K: median current inside K_window.
        5) If to_plot: Na-window panel, K panel (whole sweeps zoomed on the K
           window) and I-V panel, saved to
           Saved_Figs/IV_Curves/IV_Curves_<abfID>.<figopt['type']>.

    NOTE: PN_subtraction is called with its default arguments and may show
    its own diagnostic figure regardless of to_plot.

    Parameters:
        Na_window, K_window : [start, stop] in ms relative to the stimulus
            onset from protocol_baseline_and_stim. Not mutated (result keys
            embed their repr, e.g. "Na_[0.2, 3]").
        to_plot : bool, build/save/show the summary figure.
        figopt : dict with 'type' (file extension) and optional 'dpi'.
        leak_threshold : float, default=-200
            Minimum acceptable baseline current in pA.
        use_PN : bool, default=True
            Whether to perform P/N subtraction for leak correction.
        n_pulses : int, default=3
            Number of most negative pulses to use for P/N subtraction template.

    Returns:
        results = {
            f"Na_{Na_window}": { voltage: peak, ... },
            f"K_{K_window}":   { voltage: median, ... }
        }
        Voltages are rounded command potentials; if two sweeps share a
        voltage, the later one overwrites the earlier.
    """

    is_base, is_stim = protocol_baseline_and_stim(abf)
    # Stimulus onset (ms); the user windows are offsets from it.
    t0_relative = abf.sweepX[np.where(is_stim)[0][0]] * 1000
    na_start, na_stop = [x + t0_relative for x in Na_window]
    k_start, k_stop = [x + t0_relative for x in K_window]

    # Absolute analysis windows in seconds, matching the sweep time arrays.
    na_start_s, na_stop_s = na_start / 1000, na_stop / 1000
    k_start_s, k_stop_s = k_start / 1000, k_stop / 1000
    na_delta_s = na_stop_s - na_start_s
    k_delta_ms = k_stop - k_start

    valid_sweeps = collect_valid_sweeps(abf, is_base, leak_threshold)

    # P/N needs at least one sweep to build its leak template.
    if use_PN and valid_sweeps:
        sweeps, _PN_info = PN_subtraction(valid_sweeps, n_pulses, is_base)
    else:
        sweeps = valid_sweeps

    if to_plot:
        fig, (ax_na, ax_k, ax_iv) = plt.subplots(1, 3, figsize=(8, 3))
    else:
        ax_na = ax_k = ax_iv = None

    # ----------------- NA MEASUREMENT & PLOT (SUBPLOT 1) ------------------ #
    # The Na panel shows the window padded by 20 % of its width on each side.
    margin_fraction = 0.2
    na_plot_start_s = na_start_s - margin_fraction * na_delta_s
    na_plot_stop_s = na_stop_s + margin_fraction * na_delta_s

    na_peaks, na_voltages = [], []
    for sweep_data in sweeps:
        time = sweep_data['time']
        current = sweep_data['current']

        na_mask = (time >= na_start_s) & (time <= na_stop_s)
        if not np.any(na_mask):
            continue
        I_analysis = current[na_mask]
        t_analysis = time[na_mask]

        # The window median is only a reference for locating the largest
        # excursion (either sign); the current at that sample is reported.
        peak_idx = np.argmax(np.abs(I_analysis - np.median(I_analysis)))
        peak_value = I_analysis[peak_idx]
        peak_time = t_analysis[peak_idx]

        na_peaks.append(peak_value)
        na_voltages.append(sweep_data['voltage'])

        if ax_na is not None:
            na_plot_mask = (time >= na_plot_start_s) & (time <= na_plot_stop_s)
            ax_na.plot(time[na_plot_mask]*1000, current[na_plot_mask], 'k')
            ax_na.scatter(peak_time*1000, peak_value, color='m', zorder=5)

    if ax_na is not None:
        ax_na.set_xlim([na_plot_start_s*1000, na_plot_stop_s*1000])
        ax_na.set_xlabel("Time (ms)")
        ax_na.set_ylabel("I (pA)")
        ax_na.set_title(f"Na Window\n({Na_window[0]}-{Na_window[1]} ms)")

    # ----------------- K MEASUREMENT & PLOT (SUBPLOT 2) ------------------ #
    k_means, k_voltages = [], []
    k_mid_ms = (k_start + k_stop) / 2
    for sweep_data in sweeps:
        time = sweep_data['time']
        current = sweep_data['current']

        k_mask = (time >= k_start_s) & (time <= k_stop_s)
        if not np.any(k_mask):
            continue
        # Median (robust stand-in for the mean) of the sustained current.
        I_mean = np.median(current[k_mask])

        k_means.append(I_mean)
        k_voltages.append(sweep_data['voltage'])

        if ax_k is not None:
            ax_k.plot(time*1000, current, 'k')
            ax_k.scatter(k_mid_ms, I_mean, color='c', zorder=5)

    if ax_k is not None:
        # Traces are drawn on the absolute time axis, so zoom on the absolute
        # K window padded by half its width on each side.
        ax_k.set_xlim([k_start - k_delta_ms/2, k_stop + k_delta_ms/2])
        ax_k.set_xlabel("Time (ms)")
        ax_k.set_ylabel("I (pA)")
        ax_k.set_title(f"K Window\n({K_window[0]}-{K_window[1]} ms)")

    # --------------------- BUILD RESULTS DICT --------------------- #
    key_na = f"Na_{Na_window}"
    key_k = f"K_{K_window}"
    results = {
        key_na: dict(zip(na_voltages, na_peaks)),
        key_k: dict(zip(k_voltages, k_means)),
    }

    # --------------------- I–V PLOT (SUBPLOT 3) --------------------- #
    if ax_iv is not None:
        ax_iv.set_title("I–V")
        ax_iv.set_xlabel("Voltage (mV)")
        ax_iv.set_ylabel("I (pA)")
        # Sorted by voltage for a clean line; no legend is drawn.
        for key, color in ((key_na, 'm'), (key_k, 'c')):
            pairs = sorted(results[key].items())
            ax_iv.plot([v for v, _ in pairs], [i for _, i in pairs], '-o', color=color)
        ax_iv.axhline(0, color='k', linewidth=.25)
        ax_iv.axvline(0, color='k', linewidth=.25)

    # ----------- Adjust layout, Save, and Return ----------- #
    if to_plot:
        plt.tight_layout()
        os.makedirs('Saved_Figs/IV_Curves/', exist_ok=True)
        save_name = f"Saved_Figs/IV_Curves/IV_Curves_{abf.abfID}.{figopt['type']}"
        plt.savefig(save_name, dpi=figopt.get('dpi'))
        plt.show()

    return results

In [None]:




def PN_subtraction(valid_sweeps, n_pulses, is_base, to_plot=True):
    """
    Perform P/N subtraction using the N most negative pulses as the leak template.

    The n_pulses sweeps with the most negative 'voltage' are averaged into a
    leak template. The template is baseline-subtracted so that DC leak is not
    magnified when scaling. The baseline is the average of (a) the mean over
    the first 10 % of the first baseline region and (b) the mean over the
    last 10 % of the last baseline region. Each sweep then has
    template * (V_test - v_hold) / (avg_PN_voltage - v_hold) subtracted,
    where v_hold is taken from the most negative sweep.

    Parameters:
        valid_sweeps: list of sweep dicts from collect_valid_sweeps (keys
            'time', 'current', 'voltage', 'v_hold', ...). Not mutated.
        n_pulses: int >= 1, number of most negative pulses used for the template.
        is_base: boolean baseline mask from protocol_baseline_and_stim.
        to_plot: show the three-panel diagnostic figure (default True).

    Returns:
        corrected_sweeps: shallow copies of the input dicts with 'current'
            replaced by the corrected trace (empty list if no input sweeps).
        PN_info: dict with 'voltages', 'template', 'holding_potential' and
            'avg_PN_voltage' (None values when there were no input sweeps).

    Raises:
        ValueError: if n_pulses < 1, is_base marks no samples, or the mean
            template voltage equals the holding potential (zero scale factor
            denominator).
    """
    if not valid_sweeps:
        return [], {'voltages': [], 'template': None,
                    'holding_potential': None, 'avg_PN_voltage': None}
    if n_pulses < 1:
        raise ValueError(f"n_pulses must be >= 1, got {n_pulses}")

    # Select the n_pulses most negative sweeps and average them.
    PN_sweeps = sorted(valid_sweeps, key=lambda x: x['voltage'])[:n_pulses]
    v_hold = PN_sweeps[0]['v_hold']
    leak_template = np.mean([s['current'] for s in PN_sweeps], axis=0)

    # Split the baseline into contiguous regions (gaps in the index list).
    base_indices = np.flatnonzero(is_base)
    if base_indices.size == 0:
        raise ValueError("is_base marks no baseline samples; cannot zero the leak template")
    breaks = np.flatnonzero(np.diff(base_indices) > 1)
    if breaks.size:
        first_region = base_indices[:breaks[0] + 1]
        last_region = base_indices[breaks[-1] + 1:]
    else:
        first_region = last_region = base_indices

    # Use 10 % (at least one sample) from the start and the end of the baseline.
    first_base_indices = first_region[:max(1, int(len(first_region) * 0.1))]
    last_base_indices = last_region[-max(1, int(len(last_region) * 0.1)):]
    template_baseline = (np.mean(leak_template[first_base_indices]) +
                         np.mean(leak_template[last_base_indices])) / 2
    leak_template = leak_template - template_baseline

    PN_voltages = [s['voltage'] for s in PN_sweeps]
    avg_PN_voltage = np.mean(PN_voltages)
    denom = avg_PN_voltage - v_hold
    if np.isclose(denom, 0):
        raise ValueError("P/N template voltage equals the holding potential; cannot scale the template")

    corrected_sweeps = []
    for sweep_info in valid_sweeps:
        scale = (sweep_info['voltage'] - v_hold) / denom
        corrected_sweep = dict(sweep_info)
        corrected_sweep['current'] = sweep_info['current'] - leak_template * scale
        corrected_sweeps.append(corrected_sweep)

    PN_info = {
        'voltages': PN_voltages,
        'template': leak_template,
        'holding_potential': v_hold,
        'avg_PN_voltage': avg_PN_voltage
    }

    if to_plot:
        fig, ax = plt.subplots(3, 1, figsize=(3, 3))
        time = valid_sweeps[0]['time'] * 1000  # ms

        # Panel 1: raw traces
        for sweep_info in valid_sweeps:
            ax[0].plot(time, sweep_info['current'], 'k')
        ax[0].set_ylabel("I (pA)")
        ax[0].set_title("Raw Traces")

        # Panel 2: correction for a +10 mV step, with the baseline samples used shaded
        ax[1].plot(time, leak_template * (10 / denom), 'r')
        ax[1].axvspan(time[first_base_indices[0]], time[first_base_indices[-1]],
                      color='gray', zorder=0)
        ax[1].axvspan(time[last_base_indices[0]], time[last_base_indices[-1]],
                      color='gray', zorder=0)
        ax[1].set_ylabel("I (pA)")
        ax[1].set_title("Correction for +10mV")

        # Panel 3: corrected sweeps
        for corrected_sweep in corrected_sweeps:
            ax[2].plot(time, corrected_sweep['current'], 'k')
        ax[2].set_ylabel("I (pA)")
        ax[2].set_xlabel("Time (ms)")
        ax[2].set_title("Corrected Sweeps")

        plt.tight_layout()
        plt.show()

    return corrected_sweeps, PN_info