In [None]:
import efel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import scipy
from scipy.signal import find_peaks
from scipy import signal
from scipy import stats
import seaborn as sns
import glob
import matplotlib
sns.set_style("ticks")
sns.set_context("paper")
plt.rcParams['font.family'] = 'Arial'
#plt.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['pdf.fonttype'] = 42

In [None]:
# define some functions


def find_spike_times(voltages, dt, detection_level, min_interval):
    """Return the times of upward crossings of ``detection_level``.

    Sample ``i`` is assigned the time ``i * dt``. A crossing is accepted when
    the previous sample is below ``detection_level`` and the current sample is
    at or above it, the sample-to-sample rise is at least -10, and at least
    ``min_interval`` has elapsed since the last accepted crossing.

    Returns a list of crossing times in the same units as ``dt``.
    """
    min_rise = -10
    crossings = []
    previous_crossing = -min_interval  # lets a crossing at t >= 0 be accepted
    for idx in range(1, len(voltages)):
        before, now = voltages[idx - 1], voltages[idx]
        if not (before < detection_level <= now):
            continue
        if not (now - before >= min_rise):
            continue
        moment = idx * dt
        if not (moment - previous_crossing >= min_interval):
            continue
        crossings.append(moment)
        previous_crossing = moment
    return crossings


def group_cvs(values, group_size):
    """Return the coefficient of variation of consecutive groups of values.

    ``values`` is split into non-overlapping chunks of ``group_size`` items,
    starting at index 0. Each complete chunk contributes ``std / mean``;
    a trailing incomplete chunk is ignored. An input shorter than
    ``group_size`` yields an empty list.

    The range stop is ``len(values) - group_size + 1`` so that the last
    complete chunk is included (also when ``len(values) == group_size``).
    """
    cvs = []
    for start in range(0, len(values) - group_size + 1, group_size):
        chunk = values[start:start + group_size]
        cvs.append(np.std(chunk) / np.mean(chunk))
    return cvs


def segment(values, dx, x_min, x_max):
    """Slice ``values`` between the sample indices nearest ``x_min`` and ``x_max``.

    Positions are converted to indices by dividing by the sample spacing
    ``dx`` and rounding; the stop index is exclusive.
    """
    first = round(x_min / dx)
    stop = round(x_max / dx)
    return values[first:stop]


def find_slopes(values, dx):
    """Estimate the derivative of ``values`` sampled at constant spacing ``dx``.

    The two end points use the one-sided difference to their neighbour;
    interior points use the average of the two adjacent differences divided
    by ``dx``. Returns a list with one slope per input sample. Inputs with
    fewer than two samples raise ``IndexError``.
    """
    steps = np.diff(values)
    leading = steps[0] / dx
    trailing = steps[-1] / dx
    interior = [(steps[k - 1] + steps[k]) / (2 * dx) for k in range(1, len(values) - 1)]
    return [leading] + interior + [trailing]


#use neo to import either voltage or current clamp data in the correct, scaled units!
def load_neo_file(file_name, **kwargs):
    """Load a recording with neo and return its traces in scaled units.

    Parameters
    ----------
    file_name : str
        Path of the file; the reader is chosen by ``neo.io.get_io``.
    **kwargs
        Forwarded unchanged to the reader's ``read`` method.

    Returns
    -------
    list
        One entry per segment (across all blocks, in file order). Each entry
        is a list with one dict per analog signal holding ``'T'`` (times in
        ms) and either ``'A'`` (signal in pA) or, when the signal cannot be
        expressed as a current, ``'V'`` (signal in mV). An empty list is
        returned when the file contains no blocks.
    """
    import neo
    reader = neo.io.get_io(file_name)
    blocks = reader.read(**kwargs)
    # Accumulate across blocks; resetting this per block would keep only the
    # segments of the last block and leave it undefined for an empty file.
    new_segments = []
    for bl in blocks:
        for seg in bl.segments:
            traces = []
            for sig in seg.analogsignals:
                trace = {'T': sig.times.rescale('ms').magnitude}
                try:
                    trace['A'] = sig.rescale('pA').magnitude
                except ValueError:
                    # Incompatible units (not a current): treat as a voltage.
                    trace['V'] = sig.rescale('mV').magnitude
                traces.append(trace)
            new_segments.append(traces)
    return new_segments


#Calculates dV/dt of the mean spike waveform for each file; requires the folder path and the index of the sweep to analyse.
def dvdt(path, sweep):
    """Compute dV/dt of the mean spike waveform for every AxoGraph file in a folder.

    The working directory is changed to ``path``. For each ``.axgd``/``.axgx``
    file (processed in sorted name order), sweep number ``sweep`` is loaded,
    spikes are detected as upward crossings of 0, a window from 18 ms before
    to 15 ms after each spike is cut out, the windows are averaged, and the
    slope of that mean waveform is computed with ``find_slopes``.

    When a sweep holds several traces, the result of the last trace that
    contained usable spikes is kept for the file. Files without usable spikes
    are skipped with a message.

    Parameters
    ----------
    path : str
        Folder containing the recordings.
    sweep : int
        Index of the sweep to analyse in every file.

    Returns
    -------
    pandas.DataFrame
        One column per file (named after the file) holding the slope of its
        mean spike waveform. The same table is written to
        ``dvdtmaster_file.xlsx`` in ``path``. An empty DataFrame is returned,
        and nothing is written, when no file produced a result.
    """
    os.chdir(path)
    detection_level = 0
    min_interval = 0.0001
    time_before = .018
    time_after = .015
    per_file_columns = []
    for filename in sorted(os.listdir()):
        # only AxoGraph files are analysed
        if not filename.endswith((".axgd", ".axgx")):
            continue
        [traces] = efel.io.load_neo_file(filename, stim_start=0, stim_end=10000)
        file_slopes = None  # reset per file so a previous file's result is never reused
        for data in traces[sweep]:
            # presumably efel returns time in ms, so this converts to seconds -- TODO confirm
            times = (data['T']) / 1000
            voltages = (data['V'])
            times -= times[0]
            dt = times[2] - times[1]
            spike_times = find_spike_times(voltages, dt, detection_level, min_interval)
            # keep only spikes whose full window lies inside the recording
            spike_windows = [
                np.array(segment(voltages, dt, t - time_before, t + time_after))
                for t in spike_times
                if time_before < t < times[-1] - time_after
            ]
            mean_waveform = [np.mean(k) for k in zip(*spike_windows)]
            if len(mean_waveform) < 2:
                continue  # no usable spikes in this trace
            file_slopes = pd.DataFrame({filename: find_slopes(mean_waveform, dt)})
        if file_slopes is None:
            print('No usable spikes in ' + filename + '; skipping')
            continue
        per_file_columns.append(file_slopes)

    if not per_file_columns:
        return pd.DataFrame()
    df_concat = pd.concat(per_file_columns, axis=1)
    df_concat.to_excel('dvdt' + 'master_file.xlsx', index=False)
    return df_concat









#This is the workhorse function. It requires the path to a folder that contains perforated-patch spontaneous-firing AxoGraph files.
#It also requires the sweep start and end and the time start and end (in seconds). These matter because the recordings are
#split half into unperturbed firing and half with zero-mean noise injection, and the latter is not relevant to the current
#study.
#The function returns several arrays that are used for downstream analysis, including plotting and calculation of
#mean and minimum voltages.
#Example usage is shown in the cells below.
def collect_isis_global_excel(path1, sample_rate, sweep_start, sweep_end, time_start_s, time_end_s, metadata):
    """Collect time-normalised interspike-interval (ISI) voltage trajectories.

    The working directory is changed to ``path1``. Every ``.axgd``/``.axgx``
    file there is processed in sorted name order. For each trace of sweeps
    ``sweep_start:sweep_end``, the samples between ``time_start_s`` and
    ``time_end_s`` are kept, spikes are detected as upward crossings of -10,
    and the voltage between each pair of consecutive spikes is resampled to
    1000 points and averaged into one trajectory per trace.

    Files (and traces) that yield no ISI are skipped with a message, so they
    never contribute NaN rows or misalign filenames with results.

    Parameters
    ----------
    path1 : str
        Folder containing the recordings.
    sample_rate : number
        Samples per second; used to convert the time limits to sample indices.
    sweep_start, sweep_end : int
        Slice bounds applied to the list of sweeps of each file.
    time_start_s, time_end_s : number
        Start and end of the analysed part of each sweep, in seconds.
    metadata : str
        Prefix for the names of the Excel files written to ``path1``.

    Returns
    -------
    tuple
        ``(final_mean_trace, isi_final_mean, global_mean_traces,
        file_results, global_mean_isis)``:
        the mean trajectory over files; the mean ISI over files as a 1x1
        array; one mean trajectory per file; every per-trace trajectory of
        every file; and one mean ISI per file.

    Raises
    ------
    ValueError
        If no file in ``path1`` produced any ISI.
    """
    os.chdir(path1)
    num_points = 1000  # samples per time-normalised trajectory
    detection_level = -10
    min_interval = 0.0001
    # integer indices so that fractional second limits can be used for slicing
    time_start = int(round(time_start_s * sample_rate))
    time_end = int(round(time_end_s * sample_rate))
    times_rel = np.linspace(0, 1, num_points)

    filename_list = []
    global_mean_isi = []
    global_mean_traces = []
    trajectory_filenames = []  # source file of every row in total_trajectories
    total_trajectories = []

    for filename in sorted(os.listdir()):
        if not filename.endswith((".axgd", ".axgx")):
            continue
        print('Working on ' + filename)
        [traces] = efel.io.load_neo_file(filename, stim_start=0, stim_end=10000)
        # per-file accumulators are always reset, even if the sweep slice is empty
        file_results = []
        average_isi = []
        for trace in traces[sweep_start:sweep_end]:
            for p in trace:
                times = p['T'][time_start:time_end]
                voltages = p['V'].flatten()[time_start:time_end]
                if len(times) < 3:
                    continue  # not enough samples to measure dt
                times = times - times[0]
                dt = times[2] - times[1]
                spike_times = find_spike_times(voltages, dt, detection_level, min_interval)
                interspike_intervals = np.diff(spike_times)
                resampled = []
                for start_time, interval in zip(spike_times, interspike_intervals):
                    isi_voltages = segment(voltages, dt, start_time, start_time + interval)
                    if len(isi_voltages) == 0:
                        continue  # interval shorter than one sample
                    resampled.append(
                        np.interp(times_rel, np.linspace(0, 1, len(isi_voltages)), isi_voltages))
                if not resampled:
                    continue
                average_isi.append(np.mean(interspike_intervals))
                file_results.append(np.mean(np.vstack(resampled), axis=0))
        if not file_results:
            print('No interspike intervals found in ' + filename + '; skipping')
            continue
        filename_list.append(filename)
        global_mean_isi.append(np.mean(np.array(average_isi)))
        global_mean_traces.append(np.mean(np.array(file_results), axis=0))
        trajectory_filenames.extend([filename] * len(file_results))
        total_trajectories.extend(file_results)

    if not filename_list:
        raise ValueError('No interspike intervals found in any AxoGraph file in ' + str(path1))

    file_index = pd.Index(filename_list, name='filename')

    global_mean_isi_df = pd.DataFrame(np.array(global_mean_isi), index=file_index).sort_index()
    global_mean_isis = global_mean_isi_df.values[:, 0]
    global_mean_isi_df.to_excel(metadata + '_global_mean_isi.xlsx')

    isi_final_mean = np.array([[np.mean(global_mean_isi, axis=0)]])
    isi_final_mean_df = pd.DataFrame(isi_final_mean)
    display(isi_final_mean_df)
    isi_final_mean_df.to_excel(metadata + '_isi_final_mean.xlsx')

    global_mean_traces_df = pd.DataFrame(np.array(global_mean_traces), index=file_index).sort_index()
    global_mean_traces = np.array(global_mean_traces_df)
    global_mean_traces_df.to_excel(metadata + '_global_mean_traces.xlsx')

    final_mean_trace = np.mean(global_mean_traces, axis=0)
    pd.DataFrame(final_mean_trace).to_excel(metadata + '_final_mean_trace.xlsx')

    # every trajectory is labelled with the file it came from; rows are left
    # in processing order (files sorted by name, traces in recording order)
    file_results = np.array(total_trajectories)
    file_results_df = pd.DataFrame(file_results, index=pd.Index(trajectory_filenames, name='filename'))
    file_results_df.to_excel(metadata + '_total_trajectories.xlsx')
    print('Analysis complete')
    return final_mean_trace, isi_final_mean, global_mean_traces, file_results, global_mean_isis



#
def plot_trajectory_independent(mean_trace, avg_isi, total_trajectory_voltages, color):
    """Plot one mean ISI voltage trajectory against time from the spike.

    Parameters
    ----------
    mean_trace : array-like
        Time-normalised mean voltage trajectory.
    avg_isi : number or size-1 array
        Mean interspike interval; the trace is stretched over ``0..avg_isi``
        on the x-axis (labelled in ms).
    total_trajectory_voltages : array-like
        Accepted for interface compatibility; not used because the
        confidence-interval band is not drawn.
    color : matplotlib colour
        Colour of the trace.
    """
    # avg_isi may arrive as a 1x1 array; reduce it to a plain float
    isi = float(np.asarray(avg_isi).item())
    # linspace yields exactly one x value per sample, which a float-step
    # arange does not guarantee
    time_axis = np.linspace(0, isi, len(mean_trace))

    fig, ax = plt.subplots()
    ax.plot(time_axis, mean_trace, color=color, label='Mean Trace')
    ax.set_xlabel('Time from spike (ms)')
    ax.set_ylabel('Membrane Potential (mV)')
    ax.set_title('ISI voltage trajectories')

    ax.legend(loc='lower right', borderpad=0.1, labelspacing=.3)
    ax.set_ylim(-65, -30)
    plt.show()


def plot_voltage_trajectories_two(mean_trace1, mean_trace2, avg_isi1, avg_isi2, total_trajectory_voltages1, total_trajectory_voltages2, mean_trace1_label, mean_trace2_label, mean_trace1_color, mean_trace2_color, save, metadata, save_dir='/Users/HBLANKEN/Library/CloudStorage/OneDrive-UniversityofOklahoma/Beckstead Lab/DA-AD paper files/Grand collection of axograph files/Noise Traces/Figures'):
    """Plot two mean ISI voltage trajectories on shared axes.

    Each trace is stretched over ``0..avg_isi`` on the x-axis (labelled in
    ms), so traces with different mean ISIs end at different times.

    Parameters
    ----------
    mean_trace1, mean_trace2 : array-like
        Time-normalised mean voltage trajectories.
    avg_isi1, avg_isi2 : number or size-1 array
        Mean interspike interval of each trace.
    total_trajectory_voltages1, total_trajectory_voltages2 : array-like
        Accepted for interface compatibility; not used because the
        confidence-interval bands are not drawn.
    mean_trace1_label, mean_trace2_label : str
        Legend labels.
    mean_trace1_color, mean_trace2_color : matplotlib colour
        Line colours.
    save : bool
        When true, the working directory is changed to ``save_dir`` and the
        figure is saved there as ``metadata + '.pdf'``.
    metadata : str
        Base name of the saved figure.
    save_dir : str, optional
        Folder used when ``save`` is true. Defaults to the folder that was
        previously hard-coded; pass your own folder on other machines.
    """
    # avg_isi values may arrive as 1x1 arrays; reduce them to plain floats
    isi1 = float(np.asarray(avg_isi1).item())
    isi2 = float(np.asarray(avg_isi2).item())

    # One x value per sample, ending exactly at the mean ISI. The earlier
    # "/ num_points-1" divided first and then subtracted 1 from the duration.
    time_axis1 = np.linspace(0, isi1, len(mean_trace1))
    time_axis2 = np.linspace(0, isi2, len(mean_trace2))

    fig, ax = plt.subplots(figsize=(2.75, 2.75))
    ax.plot(time_axis1, mean_trace1, color=mean_trace1_color, label=mean_trace1_label, linewidth=0.5)
    ax.plot(time_axis2, mean_trace2, color=mean_trace2_color, label=mean_trace2_label, linewidth=0.5)
    ax.set_xlabel('Time from spike (ms)', fontsize=9)
    ax.set_ylabel('Membrane Potential (mV)', fontsize=9)
    ax.tick_params(axis='both', which='both', labelsize=8, width=0.5)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.50)
    ax.spines['left'].set_linewidth(0.50)

    # a single legend call; a second call would replace the first anyway
    ax.legend(frameon=False, edgecolor='none', fontsize='large', handlelength=2, handleheight=1, loc='lower right')

    # lower limit follows the data, upper limit is fixed at -30
    min_value = min(np.min(mean_trace1), np.min(mean_trace2))
    ax.set_ylim(min_value - 1, -30)

    if save:
        os.chdir(save_dir)
        plt.savefig(metadata + '.pdf', dpi=600, bbox_inches='tight', transparent=True)
    plt.show()




#Uses the output of collect_isis_global_excel to calculate the slope and mean voltage. This code was superseded by
#the windowed version (calculate_slope_and_mean_voltage_90ms_window), but was still useful for tests.
def calculate_slope_and_mean_voltage(global_mean_traces, global_mean_isis, save_path, metadata):
    """Summarise each mean ISI trajectory by slope, mean voltage and minimum.

    The working directory is changed to ``save_path``. Each trace is placed on
    a time axis running from 0 to its ISI. The slope is the mean point-wise
    gradient (times 1000) over the samples between 30% and 70% of the trace,
    the mean voltage is taken from the minimum of the trace up to 98% of its
    length, and the minimum is taken over the whole trace. The 30-70% part of
    every trace is also drawn on the current matplotlib figure. Empty traces
    are reported and skipped.

    Three Excel files (``metadata`` + ``_slopes.xlsx``, ``_mean_voltages.xlsx``,
    ``_min_voltages.xlsx``) are written, and the three columns are returned
    side by side in one DataFrame.
    """
    os.chdir(save_path)
    slopes, mean_voltages, min_voltages = [], [], []
    for idx, trace in enumerate(global_mean_traces):
        n_samples = len(trace)
        if n_samples == 0:
            print(f"Empty trace found at index {idx}")
            continue

        # time axis spanning the actual ISI of this trace
        time_axis = np.linspace(0, global_mean_isis[idx], n_samples)
        central = slice(int(0.3 * n_samples), int(0.7 * n_samples))
        trough_idx = np.argmin(trace)
        mean_stop_idx = int(0.98 * n_samples)

        plt.plot(time_axis[central], trace[central])

        pointwise_slopes = np.gradient(trace[central], time_axis[central]) * 1000
        slopes.append(pointwise_slopes.mean())
        mean_voltages.append(trace[trough_idx:mean_stop_idx].mean())
        min_voltages.append(trace.min())

    # one single-column frame per measure, each written to its own workbook
    frames = []
    for column, column_values, suffix in (
        ('slope', slopes, '_slopes.xlsx'),
        ('mean_voltage', mean_voltages, '_mean_voltages.xlsx'),
        ('min_voltage', min_voltages, '_min_voltages.xlsx'),
    ):
        frame = pd.DataFrame({column: column_values})
        frame.to_excel(metadata + suffix)
        frames.append(frame)

    return pd.concat(frames, axis=1)




#Helper that computes the mean spike waveform for each file. Useful for detecting break-ins during perforated-patch recordings: a sharp shift
#towards depolarization will be present if the perforated patch has ruptured.
def mean_spike_voltages(path, sweep_start, sweep_end, detection_level, metadata):
    """Compute the mean spike waveform of every AxoGraph file in a folder.

    The working directory is changed to ``path``. For each ``.axgd``/``.axgx``
    file (sorted by name), spikes in sweeps ``sweep_start:sweep_end`` are
    detected as upward crossings of ``detection_level`` at least 5 ms apart.
    A window from 25 ms before to 15 ms after each spike is cut out, the
    windows of a trace are averaged, and the per-trace means are averaged
    into one waveform per file. Files in which no spike window could be
    extracted are skipped with a message.

    Parameters
    ----------
    path : str
        Folder containing the recordings.
    sweep_start, sweep_end : int
        Slice bounds applied to the list of sweeps of each file.
    detection_level : number
        Voltage level whose upward crossing marks a spike.
    metadata : str
        Appended to the column name of the across-file mean waveform.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The per-file mean waveforms (one row per file, indexed by
        ``'Filename'``) and the mean waveform across files (one column named
        ``'Voltage (mV) ' + metadata``).

    Raises
    ------
    ValueError
        If no file in ``path`` contained a usable spike.
    """
    os.chdir(path)
    min_interval = 0.005
    time_before = .025
    time_after = .015
    spike_voltage_list = []
    filename_list = []
    for filename in sorted(os.listdir()):
        # only AxoGraph files are analysed
        if not filename.endswith((".axgd", ".axgx")):
            continue
        print('Working on ' + filename)
        [traces] = efel.io.load_neo_file(filename, stim_start=0, stim_end=10000)
        trace_means = []
        for data in traces[sweep_start:sweep_end]:
            for p in data:
                # presumably efel returns time in ms, so this converts to seconds -- TODO confirm
                times = (p['T']) / 1000
                voltages = (p['V'])
                times -= times[0]
                dt = times[2] - times[1]
                spike_times = find_spike_times(voltages, dt, detection_level, min_interval)
                # keep only spikes whose full window lies inside the recording
                spike_windows = [
                    np.array(segment(voltages, dt, t - time_before, t + time_after))
                    for t in spike_times
                    if time_before < t < times[-1] - time_after
                ]
                if len(spike_windows) > 0:
                    trace_means.append([np.mean(k) for k in zip(*spike_windows)])
        if not trace_means:
            # averaging an empty table would give a NaN scalar and make the
            # per-file table ragged, so the file is left out entirely
            print('No spikes found in ' + filename + '; skipping')
            continue
        filename_list.append(filename)
        spike_voltage_list.append(np.mean(np.array(trace_means), axis=0))

    if not filename_list:
        raise ValueError('No spikes found in any AxoGraph file in ' + str(path))

    global_spike_voltages_df = pd.DataFrame(
        np.array(spike_voltage_list),
        index=pd.Index(filename_list, name='Filename'),
    ).sort_index()
    mean_spike_voltages_df = global_spike_voltages_df.mean(axis=0).to_frame(name='Voltage (mV) ' + metadata)

    return global_spike_voltages_df, mean_spike_voltages_df


#phase plane plots can be derived off of the dvdt function here that returns a dataframe
def calculate_dVdt(df, dt):
    """Differentiate every column of ``df`` with respect to time.

    Values are first coerced to numbers (anything non-numeric becomes NaN).
    Each column is then differentiated with ``np.gradient`` using the sample
    spacing ``dt`` and the result is divided by 1000.

    Returns a new DataFrame with the same column labels as ``df``.
    """
    numeric = df.apply(pd.to_numeric, errors='coerce')
    derivative_df = pd.DataFrame(columns=numeric.columns)
    for name in numeric.columns:
        samples = numeric[name].values
        derivative_df[name] = np.gradient(samples, dt) / 1000
    return derivative_df


#Another workhorse function that takes the output of collect_isis_global_excel and calculates the mean, minimum and slope
#of the ISI voltage.
def calculate_slope_and_mean_voltage_90ms_window(global_mean_traces, global_mean_isis, save_path, metadata):
    """Summarise each mean ISI trajectory using a fixed 10-100 time window.

    The working directory is changed to ``save_path``. Each trace is placed on
    a time axis running from 0 to its ISI. The mean voltage is taken between
    the samples closest to times 10 and 100 on that axis; the slope is the
    mean point-wise gradient (times 1000) over the samples between 30% and
    70% of the trace; the minimum is taken over the whole trace. Every trace
    is drawn on the current matplotlib figure with dashed red lines at the
    window limits. Empty traces are reported and skipped.

    Three Excel files (``metadata`` + ``_slopes_mAHP.xlsx``,
    ``_mean_voltages_mAHP.xlsx``, ``_min_voltages_mAHP.xlsx``) are written,
    and the three columns are returned side by side in one DataFrame.
    """
    os.chdir(save_path)
    window_start_time = 10
    window_end_time = 100
    slopes, mean_voltages, min_voltages = [], [], []
    for idx, trace in enumerate(global_mean_traces):
        n_samples = len(trace)
        if n_samples == 0:
            print(f"Empty trace found at index {idx}")
            continue

        # time axis spanning the actual ISI of this trace
        time_axis = np.linspace(0, global_mean_isis[idx], n_samples)

        # samples closest to the limits of the mean-voltage window
        window_start_idx = np.argmin(np.abs(time_axis - window_start_time))
        window_end_idx = np.argmin(np.abs(time_axis - window_end_time))

        # the slope is measured over the central 30-70% of the trace
        central = slice(int(0.3 * n_samples), int(0.7 * n_samples))
        pointwise_slopes = np.gradient(trace[central], time_axis[central]) * 1000

        slopes.append(pointwise_slopes.mean())
        mean_voltages.append(trace[window_start_idx:window_end_idx].mean())
        min_voltages.append(trace.min())

        plt.plot(time_axis, trace)
        plt.axvline(time_axis[window_start_idx], color='r', linestyle='--')
        plt.axvline(time_axis[window_end_idx], color='r', linestyle='--')

    # one single-column frame per measure, each written to its own workbook
    frames = []
    for column, column_values, suffix in (
        ('slope', slopes, '_slopes_mAHP.xlsx'),
        ('mean_voltage', mean_voltages, '_mean_voltages_mAHP.xlsx'),
        ('min_voltage', min_voltages, '_min_voltages_mAHP.xlsx'),
    ):
        frame = pd.DataFrame({column: column_values})
        frame.to_excel(metadata + suffix)
        frames.append(frame)

    return pd.concat(frames, axis=1)


#This function calculates the spike width at half-maximum and the apparent spike threshold at the given dV/dt threshold.
def calculate_spike_widths_and_thresholds(df, sampling_rate, dvdt_threshold, metadata, save_path):
    """Measure spike threshold and half-maximum width for each spike waveform.

    The working directory is changed to ``save_path``. Every row of ``df`` is
    treated as one voltage waveform sampled at ``sampling_rate``; a
    ``'Filename'`` column, if present, is used as the row label.

    For each row:

    * dV/dt is computed along the row with ``np.gradient`` and divided by
      1000 (the same scaling as ``calculate_dVdt``);
    * the threshold is the voltage at the first sample where dV/dt reaches
      ``dvdt_threshold``; rows that never reach it are left out;
    * the peak is the maximum at or after that sample;
    * the half-maximum level is midway between threshold and peak;
    * the width runs from the first sample at or above that level on the
      rising phase to the first sample at or below it after the peak, and is
      reported in ms. It is NaN when the waveform ends before falling back
      to the half-maximum level.

    Parameters
    ----------
    df : pandas.DataFrame
        One waveform per row.
    sampling_rate : number
        Samples per second.
    dvdt_threshold : number
        dV/dt value that defines the start of the action potential.
    metadata : str
        Prefix of the Excel file written to ``save_path``.
    save_path : str
        Folder for the output file.

    Returns
    -------
    pandas.DataFrame
        Columns ``'spike_width_half_max'`` and ``'threshold'``, indexed by
        row label. Also written to
        ``metadata + '_spike_widths_and_thresholds.xlsx'``.
    """
    os.chdir(save_path)
    dt = 1 / sampling_rate
    # keep only numeric samples in the rows: move the filename into the index
    if 'Filename' in df.columns:
        df = df.set_index('Filename')

    filenames = []
    spike_widths = []
    threshold = []
    for index, row in df.iterrows():
        voltages = pd.to_numeric(row, errors='coerce').to_numpy(dtype=float)
        voltages = voltages[~np.isnan(voltages)]
        if voltages.size < 2:
            continue

        # derivative of this waveform only (along time, not across rows)
        dvdt_trace = np.gradient(voltages, dt) / 1000
        start_indices = np.where(dvdt_trace >= dvdt_threshold)[0]
        if len(start_indices) == 0:
            continue
        ap_begin = start_indices[0]
        peak = ap_begin + np.argmax(voltages[ap_begin:])
        half_max_voltage = (voltages[peak] + voltages[ap_begin]) / 2

        # rising phase: between threshold and peak (the peak always qualifies)
        rise = ap_begin + np.where(voltages[ap_begin:peak + 1] >= half_max_voltage)[0][0]
        # falling phase: first sample after the peak back at the half-max level
        fall_candidates = np.where(voltages[peak:] <= half_max_voltage)[0]
        if len(fall_candidates) == 0:
            spike_width = np.nan
        else:
            fall = peak + fall_candidates[0]
            spike_width = (fall - rise) * dt * 1000  # seconds -> ms

        filenames.append(index)
        threshold.append(voltages[ap_begin])
        spike_widths.append(spike_width)

    final_df = pd.DataFrame(
        {'spike_width_half_max': spike_widths, 'threshold': threshold},
        index=filenames,
    )
    final_df.to_excel(metadata + '_spike_widths_and_thresholds.xlsx')
    return final_df


def collect_frequency_and_cv(path1, sample_rate, sweep_start, sweep_end, time_start_s, time_end_s, metadata):
    """Compute mean firing frequency and ISI coefficient of variation per file.

    The working directory is changed to ``path1``. Every ``.axgd``/``.axgx``
    file there is processed in sorted name order. For each trace of sweeps
    ``sweep_start:sweep_end``, the samples between ``time_start_s`` and
    ``time_end_s`` are kept and spikes are detected as upward crossings of
    -10. Per trace, the frequency is ``1000 / mean ISI`` and the CV is
    ``std(ISI) / mean(ISI)``; both are then averaged over the traces of the
    file. Files without any ISI are skipped with a message, so every row of
    the result belongs to the file named in its index.

    Parameters
    ----------
    path1 : str
        Folder containing the recordings.
    sample_rate : number
        Samples per second; used to convert the time limits to sample indices.
    sweep_start, sweep_end : int
        Slice bounds applied to the list of sweeps of each file.
    time_start_s, time_end_s : number
        Start and end of the analysed part of each sweep, in seconds.
    metadata : str
        Stored in the ``'Metadata'`` column of every row.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``'filename'`` with columns ``'Frequency'``, ``'CV_ISI'``
        and ``'Metadata'``.
    """
    def one_by_isi(x):
        # presumably ISI in ms -> frequency in Hz; confirm the time units returned by efel
        return 1 / (x / 1000)

    def calc_cv(interspike_intervals):
        return np.std(interspike_intervals) / np.mean(interspike_intervals)

    os.chdir(path1)
    detection_level = -10
    min_interval = 0.0001
    # integer indices so that fractional second limits can be used for slicing
    time_start = int(round(time_start_s * sample_rate))
    time_end = int(round(time_end_s * sample_rate))

    filename_list = []
    freq_file_means = []
    cv_file_means = []
    for filename in sorted(os.listdir()):
        if not filename.endswith((".axgd", ".axgx")):
            continue
        print('Working on ' + filename)
        [traces] = efel.io.load_neo_file(filename, stim_start=0, stim_end=10000)
        cv_list = []
        freq_list = []
        for trace in traces[sweep_start:sweep_end]:
            for p in trace:
                times = p['T'][time_start:time_end]
                voltages = p['V'].flatten()[time_start:time_end]
                if len(times) < 3:
                    continue  # not enough samples to measure dt
                times = times - times[0]
                dt = times[2] - times[1]
                spike_times = find_spike_times(voltages, dt, detection_level, min_interval)
                interspike_intervals = np.diff(spike_times)
                if len(interspike_intervals) > 0:
                    freq_list.append(one_by_isi(np.mean(interspike_intervals)))
                    cv_list.append(calc_cv(interspike_intervals))
        if not freq_list:
            print('No interspike intervals found in ' + filename + '; skipping')
            continue
        # the filename is recorded only together with its values
        filename_list.append(filename)
        freq_file_means.append(np.mean(np.array(freq_list)))
        cv_file_means.append(np.mean(np.array(cv_list)))

    result_df = pd.DataFrame(
        {'Frequency': freq_file_means, 'CV_ISI': cv_file_means},
        index=pd.Index(filename_list, name='filename'),
    ).sort_index()
    result_df['Metadata'] = metadata
    return result_df



In [None]:
# Set the data directory. Every analysis function below works in batch mode: it is given a
# folder, picks up all AxoGraph files (.axgd / .axgx) in it and analyses each of them.
# Some functions consume the output of others, so run the cells in order.
# This example notebook shows a general workflow using a few of the functions.
# The folder used here is the directory the notebook is started from.
example_dir = os.getcwd()

In [None]:
# Collect the interspike-interval voltage trajectories for the 'control' recordings.
# The call returns, in order: the mean trajectory over files, the mean ISI over files,
# the per-file mean trajectories, the individual trajectories and the per-file mean ISIs.
(
    control_noise_mean_trace,
    control_noise_isi,
    control_noise_global_mean_traces,
    control_noise_total_trajectories_baseline,
    control_noise_global_mean_isis,
) = collect_isis_global_excel(
    path1=example_dir,
    sample_rate=20000,
    sweep_start=0,
    sweep_end=50,
    time_start_s=0,
    time_end_s=10,
    metadata='control',
)

In [None]:
def plot_average_voltage_trajectory(mean_trace1, avg_isi1, mean_trace1_label, mean_trace1_color, save, metadata, save_dir='/Users/HBLANKEN/Library/CloudStorage/OneDrive-UniversityofOklahoma/Beckstead Lab/DA-AD paper files/Grand collection of axograph files/Noise Traces/Figures'):
    """Plot one mean ISI voltage trajectory against time from the spike.

    The trace is stretched over ``0..avg_isi1`` on the x-axis (labelled in ms).

    Parameters
    ----------
    mean_trace1 : array-like
        Time-normalised mean voltage trajectory.
    avg_isi1 : number or size-1 array
        Mean interspike interval of the trace.
    mean_trace1_label : str
        Legend label.
    mean_trace1_color : matplotlib colour
        Line colour.
    save : bool
        When true, the working directory is changed to ``save_dir`` and the
        figure is saved there as ``metadata + '.pdf'``.
    metadata : str
        Base name of the saved figure.
    save_dir : str, optional
        Folder used when ``save`` is true. Defaults to the folder that was
        previously hard-coded; pass your own folder on other machines.
    """
    # avg_isi1 may arrive as a 1x1 array; reduce it to a plain float
    isi1 = float(np.asarray(avg_isi1).item())

    # One x value per sample, ending exactly at the mean ISI. The earlier
    # "/ num_points-1" divided first and then subtracted 1 from the duration.
    time_axis1 = np.linspace(0, isi1, len(mean_trace1))

    fig, ax = plt.subplots(figsize=(2.75, 2.75))
    ax.plot(time_axis1, mean_trace1, color=mean_trace1_color, label=mean_trace1_label, linewidth=0.5)
    ax.set_xlabel('Time from spike (ms)', fontsize=9)
    ax.set_ylabel('Membrane Potential (mV)', fontsize=9)
    ax.tick_params(axis='both', which='both', labelsize=8, width=0.5)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.50)
    ax.spines['left'].set_linewidth(0.50)

    # a single legend call; a second call would replace the first anyway
    ax.legend(frameon=False, edgecolor='none', fontsize='large', handlelength=2, handleheight=1, loc='lower right')

    # lower limit follows the data, upper limit is fixed at -30
    min_value = np.min(mean_trace1)
    ax.set_ylim(min_value - 1, -30)

    if save:
        os.chdir(save_dir)
        plt.savefig(metadata + '.pdf', dpi=600, bbox_inches='tight', transparent=True)
    plt.show()

In [None]:
# Plot the mean ISI voltage trajectory of the control group (shown only, not saved).
plot_average_voltage_trajectory(
    mean_trace1=control_noise_mean_trace,
    avg_isi1=control_noise_isi,
    mean_trace1_label='Control',
    mean_trace1_color='black',
    save=False,
    metadata='Control',
)

In [None]:
# Calculate the slope, mean voltage and minimum voltage of each control cell's mean ISI trajectory.
# 'Output path' is a placeholder: the function changes into this folder and writes its Excel
# files there, so replace it with an existing folder before running.

save_path = 'Output path'
# The window function defined above is named ..._90ms_window; the ..._50ms_window name is not defined.
control_slopes = calculate_slope_and_mean_voltage_90ms_window(control_noise_global_mean_traces, control_noise_global_mean_isis, save_path, 'control')

control_slopes

In [None]:
# Move the working directory up one level (the previous analysis call changed into save_path).
os.chdir(os.pardir)

In [None]:
# Collect the mean spike waveform of every file in the current working directory;
# these waveforms are the input for the spike-width and threshold measurement below.
control_global_spike_voltages, control_mean_spike_voltages = mean_spike_voltages(
    path=os.getcwd(),
    sweep_start=0,
    sweep_end=50,
    detection_level=-10,
    metadata='Control',
)

In [None]:
# Move the working directory up one level before the next analysis step.
os.chdir(os.pardir)

In [None]:
# Calculate the spike width at half-maximum and the spike threshold from the mean spike waveforms.
calculate_spike_widths_and_thresholds(
    df=control_global_spike_voltages,
    sampling_rate=20_000,
    dvdt_threshold=10,
    metadata='WT',
    save_path=save_path,
)

In [None]:
# Calculate the mean firing frequency and the ISI coefficient of variation for every file of the group.
# NOTE(review): sweep_end=-10 leaves out the last 10 sweeps of each file, whereas the other
# cells use sweeps 0-50 -- confirm this is intended.
collect_frequency_and_cv(
    path1=os.getcwd(),
    sample_rate=20000,
    sweep_start=0,
    sweep_end=-10,
    time_start_s=0,
    time_end_s=10,
    metadata='control',
)