# Investigate burst firing


## Imports

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import scipy.stats
#%matplotlib widget
import matplotlib.pyplot as plt
import seaborn as sns

import pickle


In [None]:
from helpers import *
from utils import remove_top_right_frame, jitter_scatterplot

In [None]:
# Make the project's `scripts` directory importable by deriving it from the
# notebook's working directory (which is assumed to contain 'notebooks').
# NOTE(review): this cell runs after `from helpers import *`, so it cannot
# affect that import — move it above the import cell if helpers lives in scripts/.
base_path = os.getcwd()
# Replace only the LAST occurrence of 'notebooks' so that parent directories
# whose names also contain the word are left untouched.
_path_head, _path_sep, _path_tail = base_path.rpartition('notebooks')
if _path_sep:
    base_path = f'{_path_head}scripts{_path_tail}'
sys.path.insert(1, base_path)


## Load Original Data (Only Run Once)

In [None]:
# Raw recordings table (one row per sweep); loaded with the standard pickle module.
with open('data_bio482.pkl', mode='rb') as raw_pickle:
    data_df = pickle.load(raw_pickle)

## Add AP Information (Only Run Once)

Using the Function_Detect_APs function, we extract the details of the APs in each sweep. For each entry of the original dataframe, we add a dictionary describing the APs detected in its Sweep_MembranePotential trace.

In [None]:
def explain_ap(ap_params):
    """Map the positional AP parameters of one detected AP to named fields.

    Elements 1..5 of ``ap_params`` become, in order: 'ap_thresh_vm',
    'ap_peak_times', 'ap_peak_vm', 'ap_peak_amp' and 'ap_duration'. Element 0
    is not included here (process_sweep uses it as the dict key). Raises
    IndexError if fewer than six elements are supplied.
    """
    field_names = (
        'ap_thresh_vm',
        'ap_peak_times',
        'ap_peak_vm',
        'ap_peak_amp',
        'ap_duration',
    )
    return {name: ap_params[position]
            for position, name in enumerate(field_names, start=1)}

In [None]:
def process_sweep(row):
    """Detect the action potentials of one sweep and describe each of them.

    ``row`` is one row of the sweep table; it must provide
    'Sweep_MembranePotential', 'Sweep_MembranePotential_SamplingRate' and
    'Cell_APThreshold_Slope', which are passed to Function_Detect_APs.
    Returns a dict mapping element 0 of each detected AP's parameters to the
    named description built by explain_ap.
    """
    detected_aps = Function_Detect_APs(
        row['Sweep_MembranePotential'],
        row['Sweep_MembranePotential_SamplingRate'],
        row['Cell_APThreshold_Slope'],
    )

    # presumably one entry (row) per detected AP — TODO confirm against helpers
    described = {}
    for ap in detected_aps:
        described[ap[0]] = explain_ap(ap)
    return described



In [None]:
# Run AP detection row by row; the result is a Series holding one dict of AP
# descriptions per sweep, aligned with data_df's index.
ap_info = data_df.apply(process_sweep, axis='columns')


In [None]:
# Display the per-sweep AP dictionaries computed above.
ap_info

In [None]:
# Attach the AP dictionaries to the sweep table (overwrites the column if re-run).
data_df['ap_info'] = ap_info


In [None]:
# Cache the AP-annotated table so that AP detection does not have to be re-run.
# When working from Google Drive, prefix the file name with the Drive folder
# (e.g. "/content/drive/MyDrive/NSCCM/").
file_to_save = "data_with_AP_info"
data_df.to_pickle(file_to_save + '.pkl')

## Import New Pickle

There is no need to run the previous two sections if the data_with_AP_info.pkl file is already in your directory (it is too large to push to GitHub).

In [None]:
# Reload the cached, AP-annotated sweep table.
with open('data_with_AP_info.pkl', mode='rb') as cached_pickle:
    data_df = pickle.load(cached_pickle)

## Visualise Data:

In [None]:
# Preview the first five sweeps.
data_df.head(n=5)

Each entry in the dataframe represents one sweep. A cell can have several sweeps (numbered by Sweep_Counter), so several entries may share the same Cell_ID. Each entry contains a membrane potential vs. time trace, which can be visualised below:

In [None]:
def plot_aps(data_df, sweep_num, save_figure=False, xlim=(6, 6.5)):
    """Plot one sweep's membrane potential with its detected AP peaks and thresholds.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Sweep table with 'Sweep_MembranePotential',
        'Sweep_MembranePotential_SamplingRate' and 'ap_info' columns.
    sweep_num : int
        Positional row (``iloc``) of the sweep to plot.
    save_figure : bool, default False
        If True, save the figure as ``images/Example_Cell_<sweep_num>.png``;
        the ``images`` directory is created if it does not exist.
    xlim : tuple of float or None, default (6, 6.5)
        Time window (s) shown on the x axis; ``None`` shows the whole sweep.
    """
    sweep = data_df.iloc[sweep_num]

    sr_vm = sweep['Sweep_MembranePotential_SamplingRate']
    membrane_potential = sweep['Sweep_MembranePotential']
    # Sample i was recorded at i / sampling_rate seconds.
    time = np.arange(len(membrane_potential)) / sr_vm

    # Keys of ap_info are plotted as threshold times; values are explain_ap dicts.
    ap_info = sweep['ap_info']
    ap_thresh_times = np.array(list(ap_info.keys()))
    ap_peak_times = np.array([ap['ap_peak_times'] for ap in ap_info.values()])
    ap_peak_vm = np.array([ap['ap_peak_vm'] for ap in ap_info.values()])
    ap_thresh_vm = np.array([ap['ap_thresh_vm'] for ap in ap_info.values()])

    fig, ax = plt.subplots(1, 1, figsize=(10, 3), dpi=150)

    ax.set_title(f'APs for Sweep {sweep_num}')
    ax.plot(time, membrane_potential, lw=0.5, zorder=0)
    ax.scatter(ap_peak_times, ap_peak_vm, lw=0.5, marker='+', color='r')
    ax.scatter(ap_thresh_times, ap_thresh_vm, lw=0.1, marker='o', color='b')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Vm (V)')
    if xlim is not None:
        ax.set_xlim(*xlim)

    if save_figure:
        # savefig fails if the target directory is missing.
        os.makedirs('images', exist_ok=True)
        fname = os.path.join('images', f'Example_Cell_{sweep_num}.png')
        fig.savefig(fname, dpi='figure', format='png', bbox_inches='tight')


In [None]:
# Plot (and save) the first sweep as an example.
plot_aps(data_df, sweep_num=0, save_figure=True)

## Identify Bursts for one sweep 

In [None]:
# Per-AP table for the first sweep (the same row plot_aps(data_df, 0) shows):
# one row per AP indexed by its ap_info key, one column per explain_ap field.
# .iloc keeps the lookup positional even if data_df's index is not 0..n-1.
ap_df = pd.DataFrame.from_dict(data_df['ap_info'].iloc[0], orient='index')

In [None]:
# Label the APs of this sweep that belong to a burst. An AP "completes" a burst
# when it starts less than BURST_WINDOW seconds after the END of the AP two
# spikes earlier, i.e. three APs fall inside the window.
BURST_WINDOW = 0.03  # s
# ap_duration is divided by 1000 before adding it to the index (s) —
# presumably it is in ms; TODO confirm.
ap_df['peak_end_time'] = ap_df.index + ap_df['ap_duration'] / 1000
ap_df['time_since_last_AP'] = ap_df.index - ap_df['peak_end_time'].shift(1)
ap_df['time_since_last_AP_2x'] = ap_df.index - ap_df['peak_end_time'].shift(2)
ap_df['is_final_in_burst'] = ap_df['time_since_last_AP_2x'] < BURST_WINDOW
# The two APs preceding a burst-completing AP belong to the burst too.
# fill_value=False keeps the shifted columns boolean instead of NaN-padded.
ap_df['is_in_burst'] = (
    ap_df['is_final_in_burst']
    | ap_df['is_final_in_burst'].shift(-1, fill_value=False)
    | ap_df['is_final_in_burst'].shift(-2, fill_value=False)
)

In [None]:
# Preview the first five APs with their burst labels.
ap_df.head(n=5)

In [None]:
def burst_times(ap_df, max_gap=0.03):
    """Group the burst-labelled APs of one sweep into individual bursts.

    Parameters
    ----------
    ap_df : pandas.DataFrame
        One row per AP in chronological order, indexed by the AP's start time
        (s), with an 'is_in_burst' flag column and a 'peak_end_time' column (s).
    max_gap : float, default 0.03
        Largest silence (s) between the end of one AP and the start of the
        next for both to belong to the same burst. Any two consecutive APs of
        a qualifying 3-AP/30 ms triplet are always closer than this, so real
        bursts are never split.

    Returns
    -------
    list of tuple
        ``(burst_start, burst_end, nb_of_aps)`` per burst: the index value of
        the burst's first AP, the 'peak_end_time' of its last AP, and the
        number of APs in it.

    A burst is a run of consecutive in-burst APs that is also cut wherever the
    gap to the next AP reaches ``max_gap``. Without that cut, two bursts not
    separated by a non-burst AP would be merged into one, however long the
    silence between them.
    """
    starts = ap_df.index
    ends = ap_df['peak_end_time']
    in_burst = ap_df['is_in_burst']
    n_aps = len(ap_df)

    bursts = []
    first = None  # position of the first AP of the burst being built
    for k in range(n_aps):
        if not in_burst.iloc[k]:
            continue
        if first is None:
            first = k
        burst_ends_here = (
            k == n_aps - 1                          # last AP of the sweep
            or not in_burst.iloc[k + 1]             # next AP is not in a burst
            or starts[k + 1] - ends.iloc[k] >= max_gap  # next AP too far away
        )
        if burst_ends_here:
            bursts.append((float(starts[first]), float(ends.iloc[k]), k - first + 1))
            first = None
    return bursts

In [None]:
# Show the bursts detected in the first sweep.
burst_times(ap_df=ap_df)

## Bursts for all sweeps

In [None]:
# Detect bursts in every sweep. burst_times_per_sweep[k] holds the bursts of
# the k-th row of data_df as a list of (start, end, nb_of_aps) tuples.
BURST_WINDOW = 0.03  # s: an AP completes a burst if it starts < 30 ms after the end of the AP two spikes earlier
burst_times_per_sweep = []

# Iterate positionally: data_df.ap_info[i] is a label lookup and would pick the
# wrong rows (or fail) if data_df's index were not a plain 0..n-1 range.
for i, sweep_ap_info in enumerate(data_df['ap_info']):
    if i % 10 == 0:
        print(f"completing sweep {i+1}")
    ap_df = pd.DataFrame.from_dict(sweep_ap_info, orient='index')
    if ap_df.empty:
        print("found no peaks")
        burst_times_per_sweep.append([])
        continue

    # ap_duration is divided by 1000 before adding it to the index (s) —
    # presumably it is in ms; TODO confirm.
    ap_df['peak_end_time'] = ap_df.index + ap_df['ap_duration'] / 1000
    ap_df['time_since_last_AP'] = ap_df.index - ap_df['peak_end_time'].shift(1)
    ap_df['time_since_last_AP_2x'] = ap_df.index - ap_df['peak_end_time'].shift(2)
    ap_df['is_final_in_burst'] = ap_df['time_since_last_AP_2x'] < BURST_WINDOW
    # The two APs before a burst-completing AP also belong to the burst;
    # fill_value=False keeps the shifted columns boolean (no NaN padding).
    ap_df['is_in_burst'] = (
        ap_df['is_final_in_burst']
        | ap_df['is_final_in_burst'].shift(-1, fill_value=False)
        | ap_df['is_final_in_burst'].shift(-2, fill_value=False)
    )
    burst_info = burst_times(ap_df)
    burst_times_per_sweep.append(burst_info)
    print(f"found {len(burst_info)} bursts")
    

In [None]:
# Dump one line per sweep (the repr of its list of bursts) to a text file.
file_path = "data_with_burst_info.txt"

with open(file_path, "w") as file:
    file.writelines(f"{item}\n" for item in burst_times_per_sweep)

#### Add burst info for each sweep to the data_df dataframe

In [None]:
# One list of bursts per sweep, formatted [(start, stop, nb_of_ap), ...].
data_df['burst_info'] = burst_times_per_sweep

# Number of bursts in each sweep.
data_df['nb_of_bursts'] = data_df['burst_info'].apply(len)


def mean_burst_duration(burst_inf):
    """Return the mean duration (stop - start) of the bursts of one sweep.

    A sweep without bursts yields NaN rather than 0. pandas ``mean`` then
    skips it when burst durations are averaged per group. Counting it as a
    0-s burst would drag every group's mean duration toward zero.
    """
    if not burst_inf:
        return np.nan
    durations = [stop - start for start, stop, _ in burst_inf]
    return sum(durations) / len(durations)


# TODO: check how variable burst durations are — is the mean a meaningful summary?
data_df['mean_burst_duration_sweep'] = data_df['burst_info'].apply(mean_burst_duration)

## Analyse bursting behavior of neurons 

### Define metrics to categorise neurons according to their bursting behavior:
- mean number of bursts
- mean burst duration

In [None]:
def plot_metrics(df_metric, x):
    '''
    Plot one bar chart per metric column of df_metric against column x.

    - df_metric (DataFrame): one column named x plus one column per metric
      (no other kinds of columns), one row per x value
    - x (str): the x-axis column (e.g. 'Cell_Type' or 'Cell_Layer')

    Raises ValueError if df_metric has no metric column.
    NOTE(review): a numeric x (e.g. Cell_Depth) places bars at their numeric
    positions with matplotlib's default bar width.
    '''
    # Every column except x is a metric (Index.difference returns them sorted).
    metrics = df_metric.columns.difference([x])
    if len(metrics) == 0:
        raise ValueError(f"df_metric has no metric columns besides {x!r}")

    # squeeze=False always yields a 2-D array of Axes, so the loop below also
    # works with a single metric (plt.subplots would otherwise return a bare,
    # non-iterable Axes).
    fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 6),
                             sharey=False, squeeze=False)

    # One color per x value (one bar per row).
    palette_x = sns.color_palette("Set2", n_colors=len(df_metric[x].unique()))

    for ax, metric in zip(axes[0], metrics):
        label = metric.replace('_', ' ').capitalize()
        ax.bar(df_metric[x], df_metric[metric], color=palette_x)
        ax.set_xlabel(x)
        ax.set_ylabel(label)
        ax.set_title(f'{label} by {x}')

    plt.tight_layout()
    plt.show()

#### Per cell type

In [None]:
# Average each sweep-level burst metric over the sweeps of every cell type.
mean_nb_burst_cell_type, mean_burst_duration_cell_type = (
    data_df.groupby('Cell_Type')[column].mean()
    for column in ('nb_of_bursts', 'mean_burst_duration_sweep')
)

# One row per cell type, one column per metric.
cell_type_metrics = pd.concat(
    {
        'mean_nb_bursts': mean_nb_burst_cell_type,
        'mean_burst_duration': mean_burst_duration_cell_type,
    },
    axis=1,
).rename_axis('Cell_Type').reset_index()
cell_type_metrics

In [None]:
plot_metrics(df_metric=cell_type_metrics, x='Cell_Type')

#### Per cell layer

In [None]:
# Average each sweep-level burst metric over the sweeps of every cortical layer.
mean_nb_burst_layer, mean_burst_duration_layer = (
    data_df.groupby('Cell_Layer')[column].mean()
    for column in ('nb_of_bursts', 'mean_burst_duration_sweep')
)

# One row per layer, one column per metric.
layer_metrics = pd.concat(
    {
        'mean_nb_bursts': mean_nb_burst_layer,
        'mean_burst_duration': mean_burst_duration_layer,
    },
    axis=1,
).rename_axis('Cell_Layer').reset_index()

In [None]:
plot_metrics(df_metric=layer_metrics, x='Cell_Layer')

### Cell depth 

In [None]:
# Average each sweep-level burst metric over the sweeps recorded at every depth.
mean_nb_burst_depth, mean_burst_duration_depth = (
    data_df.groupby('Cell_Depth')[column].mean()
    for column in ('nb_of_bursts', 'mean_burst_duration_sweep')
)

# One row per depth value, one column per metric.
depth_metrics = pd.concat(
    {
        'mean_nb_bursts': mean_nb_burst_depth,
        'mean_burst_duration': mean_burst_duration_depth,
    },
    axis=1,
).rename_axis('Cell_Depth').reset_index()

In [None]:
plot_metrics(df_metric=depth_metrics, x='Cell_Depth')

### Targeted brain area

In [None]:
# Average each sweep-level burst metric over the sweeps of every targeted brain area.
mean_nb_burst_area, mean_burst_duration_area = (
    data_df.groupby('Cell_TargetedBrainArea')[column].mean()
    for column in ('nb_of_bursts', 'mean_burst_duration_sweep')
)

# One row per brain area, one column per metric.
area_metrics = pd.concat(
    {
        'mean_nb_bursts': mean_nb_burst_area,
        'mean_burst_duration': mean_burst_duration_area,
    },
    axis=1,
).rename_axis('Cell_TargetedBrainArea').reset_index()

In [None]:
plot_metrics(df_metric=area_metrics, x='Cell_TargetedBrainArea')