### Configuration

In [1]:
import os
import pandas as pd
import numpy as np
from scipy.optimize import curve_fit

import utils__config

In [2]:
os.chdir(utils__config.working_directory)
os.getcwd()

'G:\\My Drive\\Residency\\Research\\Lab - Damisah\\Project - Sleep\\Revisions'

### Parameters

In [3]:
recordings = [
    {
        'recording_id': 'Feb02',
        'recording_length': 2,
        'spike_times_path': 'Data/S01_Feb02_spike_times.csv',
        'spike_forms_path': 'Data/S01_Feb02_spike_waveforms.csv'
    },
    {
        'recording_id': 'Jul11',
        'recording_length': 9.68,
        'spike_times_path': 'Data/S05_Jul11_spike_times.csv',
        'spike_forms_path': 'Data/S05_Jul11_spike_waveforms.csv'
    },
    {
        'recording_id': 'Jul12',
        'recording_length': 10.55,
        'spike_times_path': 'Data/S05_Jul12_spike_times.csv',
        'spike_forms_path': 'Data/S05_Jul12_spike_waveforms.csv'
    },
    {
        'recording_id': 'Jul13',
        'recording_length': 10.40,
        'spike_times_path': 'Data/S05_Jul13_spike_times.csv',
        'spike_forms_path': 'Data/S05_Jul13_spike_waveforms.csv'
    }
]

output_path = 'Data/cell_type_metrics.csv'

### Load Data

In [4]:
# Empty lists to store dataframes for waveforms and times
waveforms_dfs = []
times_dfs = []

# Process each recording
for recording in recordings:
    # Load and process waveforms
    waveforms = pd.read_csv(recording['spike_forms_path'])
    waveforms = waveforms[['unit_id', 'time_point', 'amplitude']]
    waveforms.columns = ['unit_id_old', 'time_point', 'amplitude']
    waveforms['unit_id'] = waveforms['unit_id_old'].astype(str) + '_' + recording['recording_id']
    waveforms_dfs.append(waveforms)
    
    # Load and process times
    times = pd.read_csv(recording['spike_times_path'])
    times = times[['unit_id', 'seconds']]
    times.columns = ['unit_id_old', 'time']
    times['unit_id'] = times['unit_id_old'].astype(str) + '_' + recording['recording_id']
    times['recording_id'] = recording['recording_id']
    times_dfs.append(times)

# Concatenate all dataframes for waveforms and times
waveforms = pd.concat(waveforms_dfs, ignore_index=True)
times = pd.concat(times_dfs, ignore_index=True)

### Firing Rate

In [5]:
# Group by unit_id and count the number of spikes for each unit_id
spike_counts = times.groupby('unit_id').size()

# Extract the recording_id from the unit_id and map it to its corresponding recording length in seconds
recording_lengths = {rec['recording_id']: rec['recording_length'] * 3600 for rec in recordings}  # convert hours to seconds
times['recording_length_seconds'] = times['recording_id'].map(recording_lengths)

# Ensure that each unit_id has the same recording length (this should be the case)
recording_lengths_by_unit = times.groupby('unit_id')['recording_length_seconds'].first()

# Calculate the average firing rate for each unit_id
firing_rates = spike_counts / recording_lengths_by_unit

# Convert the Series into a DataFrame
firing_rates = firing_rates.reset_index()
firing_rates.columns = ['unit_id', 'firing_rate']

### Trough-to-peak time

In [6]:
def trough_to_peak_time(waveforms_df, sampling_rate=30000):
    """
    Calculate the trough-to-peak time for each neuron's average waveform (or will assign NaN if unit is positive-spiking).

    Parameters:
    - waveforms_df: A pandas DataFrame with columns 'unit_id', 'time_point', and 'amplitude'.
    - sampling_rate: The sampling rate in Hz. Default is 30000.

    Returns:
    - A DataFrame with columns 'unit_id' and 'trough_to_peak_time'.
    """
    # Reshape the data to wide format
    waveforms_wide = waveforms_df.pivot(index='time_point', columns='unit_id', values='amplitude')
    
    unit_ids = []
    times_to_peak = []

    # Loop through each unit
    for unit in waveforms_wide.columns:
        # Check if the unit is positive-spiking
        if "pos" in unit:
            unit_ids.append(unit)
            times_to_peak.append(np.nan)
            continue

        # If the unit is negative-spiking, compute the trough-to-peak time
        waveform = waveforms_wide[unit].values
        trough_idx = np.argmin(waveform)
        peak_idx = trough_idx + np.argmax(waveform[trough_idx:])
        samples_to_peak = peak_idx - trough_idx
        time_to_peak = (samples_to_peak / sampling_rate) * 1000

        unit_ids.append(unit)
        times_to_peak.append(time_to_peak)

    # Create a DataFrame for the results
    results_df = pd.DataFrame({
        'unit_id': unit_ids,
        'trough_to_peak': times_to_peak
    })

    return results_df

In [7]:
ttp_times = trough_to_peak_time(waveforms)

### Full-Width Half Maximum (FWHM)

In [8]:
def spike_width_fwhm(waveforms_df, sampling_rate=30000):
    """
    Calculate the Full Width at Half Maximum (FWHM) for each neuron's average waveform,
    distinguishing between positive and negative spiking units based on 'unit_id'.

    Parameters:
    - waveforms_df: A pandas DataFrame with columns 'unit_id', 'time_point', and 'amplitude'.
    - sampling_rate: The sampling rate in Hz. Default is 30000.

    Returns:
    - A DataFrame with columns 'unit_id' and 'fwhm'.
    """
    
    waveforms_wide = waveforms_df.pivot(index='time_point', columns='unit_id', values='amplitude')
    
    unit_ids = []
    fwhm_values = []

    for unit in waveforms_wide.columns:
        waveform = waveforms_wide[unit].values

        if "_neg_" in unit:  # Negative-spiking unit

            half_amplitude = np.min(waveform) / 2
            reference_idx = np.argmin(waveform)

        elif "_pos_" in unit:  # Positive-spiking unit

            half_amplitude = np.max(waveform) / 2
            reference_idx = np.argmax(waveform)

        # Find indices where waveform crosses the half amplitude level
        cross_points = np.where(np.diff(np.sign(waveform - half_amplitude)))[0]

        # Find the closest points to the reference index (trough or peak)
        before_idx = cross_points[cross_points < reference_idx][-1] if len(cross_points[cross_points < reference_idx]) > 0 else reference_idx
        after_idx = cross_points[cross_points > reference_idx][0] if len(cross_points[cross_points > reference_idx]) > 0 else reference_idx
        
        # Calculate FWHM
        samples_width = after_idx - before_idx
        fwhm_time = (samples_width / sampling_rate) * 1000  # Convert to milliseconds

        unit_ids.append(unit)
        fwhm_values.append(fwhm_time)

    results_df = pd.DataFrame({
        'unit_id': unit_ids,
        'fwhm': fwhm_values
    })

    return results_df

In [9]:
fwhm_times = spike_width_fwhm(waveforms)

### Burst Index

In [10]:
# Define the ISI threshold for bursts (10 ms)
threshold = 0.01

def calculate_burst_index(group):
    # Calculate ISIs
    isis = group['time'].diff().dropna()
    
    # Identify spikes that are part of a burst
    burst_spikes = isis[isis < threshold].count() + 1  # + 1 to account for the first spike in each burst
    
    # Calculate Burst Index
    bi = burst_spikes / len(group)
    
    return bi

# Calculate Burst Index for each unit
burst_indices = times.groupby('unit_id').apply(calculate_burst_index).reset_index()
burst_indices.columns = ['unit_id', 'burst_index']

### Merge metrics and plot

In [11]:
merged_df = firing_rates.merge(ttp_times, on='unit_id').merge(burst_indices, on='unit_id').merge(fwhm_times, on='unit_id')
merged_df['log_firing_rate'] = np.log2(merged_df['firing_rate'])

id_frame = waveforms[['unit_id', 'unit_id_old']].drop_duplicates()
final_df = merged_df.merge(id_frame, on='unit_id', how='left')

final_df.to_csv(output_path)
final_df

Unnamed: 0,unit_id,firing_rate,trough_to_peak,burst_index,fwhm,log_firing_rate,unit_id_old
0,S01_Ch195_neg_Unit3_Feb02,2.494861,0.566667,0.082225,0.166667,1.318960,S01_Ch195_neg_Unit3
1,S01_Ch195_pos_Unit2_Feb02,2.695000,,0.063080,0.433333,1.430285,S01_Ch195_pos_Unit2
2,S01_Ch196_neg_Unit1_Feb02,1.947639,0.533333,0.038793,0.166667,0.961726,S01_Ch196_neg_Unit1
3,S01_Ch196_neg_Unit3_Feb02,1.253750,1.433333,0.039548,0.366667,0.326250,S01_Ch196_neg_Unit3
4,S01_Ch196_neg_Unit4_Feb02,1.324167,0.500000,0.062828,0.166667,0.405085,S01_Ch196_neg_Unit4
...,...,...,...,...,...,...,...
117,S05_Ch239_neg_Unit3_Jul11,1.039199,1.100000,0.103413,0.333333,0.055472,S05_Ch239_neg_Unit3
118,S05_Ch240_neg_Unit1_Jul11,10.708907,1.200000,0.164058,0.300000,3.420739,S05_Ch240_neg_Unit1
119,S05_Ch240_neg_Unit2_Jul12,4.995735,1.066667,0.097055,0.200000,2.320697,S05_Ch240_neg_Unit2
120,S05_Ch240_neg_Unit2_Jul13,9.179140,1.133333,0.137654,0.200000,3.198359,S05_Ch240_neg_Unit2
