### Configuration

In [1]:
import os
import pandas as pd
import numpy as np
from scipy.optimize import curve_fit

import utils__config

In [2]:
# Relative paths used below (e.g. 'Cache/...') are resolved against this directory.
os.chdir(utils__config.working_directory)
os.getcwd()  # last expression: shown as the cell output

'Z:\\Layton\\Sleep_083023'

### Parameters

In [3]:
# One entry per recording session. `recording_length` is in hours (the
# firing-rate cell multiplies it by 3600); both CSV paths are relative to the
# working directory set above.
recordings = [
    {
        'recording_id': rec_id,
        'recording_length': length_hours,
        'spike_times_path': spike_times_csv,
        'spike_forms_path': spike_forms_csv,
    }
    for rec_id, length_hours, spike_times_csv, spike_forms_csv in (
        ('Feb02', 2,
         'Cache/Subject01/Feb02/S01_spikes.csv',
         'Cache/Subject01/Feb02/S01_spikeforms.csv'),
        ('Jul11', 9.68,
         'Cache/Subject05/Jul11/S05_spikes.csv',
         'Cache/Subject05/Jul11/S05_spikeforms.csv'),
        ('Jul12', 10.55,
         'Cache/Subject05/Jul12/S05_spikes.csv',
         'Cache/Subject05/Jul12/S05_spikeforms.csv'),
        ('Jul13', 10.40,
         'Cache/Subject05/Jul13/S05_spikes.csv',
         'Cache/Subject05/Jul13/S05_spikeforms.csv'),
    )
]

# Destination of the merged per-unit metrics table.
output_path = 'Cache/cell_type_metrics.csv'

### Load Data

In [4]:
# Per-recording frames, concatenated once after the loop.
waveforms_dfs = []
times_dfs = []

for recording in recordings:
    rec_id = recording['recording_id']

    # Waveforms: keep the file's id as `unit_id_old` and build `unit_id` as
    # "<old id>_<recording_id>" so units from different recordings stay distinct.
    # The rename/assign chain produces new frames instead of assigning into a
    # column subset (which can raise SettingWithCopyWarning).
    waveforms_rec = (
        pd.read_csv(recording['spike_forms_path'])
        [['unit_id', 'time_point', 'amplitude']]
        .rename(columns={'unit_id': 'unit_id_old'})
        .assign(unit_id=lambda df: df['unit_id_old'].astype(str) + '_' + rec_id)
    )
    waveforms_dfs.append(waveforms_rec)

    # Spike times: column `seconds` becomes `time`; each row is tagged with its recording.
    times_rec = (
        pd.read_csv(recording['spike_times_path'])
        [['unit_id', 'seconds']]
        .rename(columns={'unit_id': 'unit_id_old', 'seconds': 'time'})
        .assign(
            unit_id=lambda df: df['unit_id_old'].astype(str) + '_' + rec_id,
            recording_id=rec_id,
        )
    )
    times_dfs.append(times_rec)

# Concatenate all recordings for waveforms and spike times.
waveforms = pd.concat(waveforms_dfs, ignore_index=True)
times = pd.concat(times_dfs, ignore_index=True)

### Firing Rate

In [5]:
# Number of spikes recorded for each unit.
spike_counts = times.groupby('unit_id').size()

# Recording length per recording_id, converted from hours (config) to seconds.
recording_lengths = {rec['recording_id']: rec['recording_length'] * 3600 for rec in recordings}
times['recording_length_seconds'] = times['recording_id'].map(recording_lengths)

# Verify the invariants the division relies on instead of assuming them:
# every spike maps to a configured recording, and each unit has a single length
# (otherwise `.first()` would silently pick one of several).
assert times['recording_length_seconds'].notna().all(), \
    'spikes from a recording_id not listed in `recordings`'
assert times.groupby('unit_id')['recording_length_seconds'].nunique().le(1).all(), \
    'unit has spikes from recordings of different lengths'
recording_lengths_by_unit = times.groupby('unit_id')['recording_length_seconds'].first()

# Mean firing rate (spikes per second) over the whole recording, one row per unit.
firing_rates = (
    (spike_counts / recording_lengths_by_unit)
    .rename('firing_rate')
    .rename_axis('unit_id')
    .reset_index()
)

### Trough-to-peak time

In [6]:
import numpy as np
import pandas as pd

def trough_to_peak_times_df(waveforms_df, sampling_rate=30000):
    """
    Calculate the trough-to-peak time of each unit's waveform.

    Units whose id contains "pos" (positive-spiking units) get NaN. For the
    others, the trough is the waveform minimum and the peak is the maximum at or
    after the trough; the distance between them in samples is converted to
    milliseconds. Samples missing after pivoting (NaN) are ignored.

    Parameters:
    - waveforms_df: A pandas DataFrame with columns 'unit_id', 'time_point', and 'amplitude'.
      Consecutive time_point rows are treated as consecutive samples (assumption — TODO confirm).
    - sampling_rate: The sampling rate in Hz. Default is 30000.

    Returns:
    - A DataFrame with columns 'unit_id' and 'trough_to_peak' (milliseconds),
      one row per unit in the pivot's column order.
    """
    # One column per unit, one row per time_point.
    waveforms_wide = waveforms_df.pivot(index='time_point', columns='unit_id', values='amplitude')

    records = []
    for unit in waveforms_wide.columns:
        waveform = waveforms_wide[unit].to_numpy(dtype=float)

        # str() so non-string unit ids do not raise on the substring test.
        if 'pos' in str(unit) or np.isnan(waveform).all():
            ttp_ms = np.nan
        else:
            # nan-aware: plain argmin/argmax would return the position of a NaN gap.
            trough_idx = int(np.nanargmin(waveform))
            # The slice starts at the (non-NaN) trough, so nanargmax always has data.
            peak_idx = trough_idx + int(np.nanargmax(waveform[trough_idx:]))
            ttp_ms = ((peak_idx - trough_idx) / sampling_rate) * 1000

        records.append({'unit_id': unit, 'trough_to_peak': ttp_ms})

    return pd.DataFrame(records, columns=['unit_id', 'trough_to_peak'])

# Trough-to-peak time (ms) per unit from the concatenated waveforms.
ttp_times = trough_to_peak_times_df(waveforms_df=waveforms)

### Burst Index

In [7]:
# ISI threshold (seconds) below which neighbouring spikes count as a burst (10 ms).
threshold = 0.01

def calculate_burst_index(group, isi_threshold=None):
    """
    Fraction of a unit's spikes that belong to a burst.

    A spike belongs to a burst when the interval to its previous or its next
    spike is shorter than the ISI threshold. Every spike of every burst
    (including each burst's first spike) is therefore counted exactly once, and a
    unit without short ISIs scores 0.

    Parameters:
    - group: DataFrame for one unit with a 'time' column of spike times (seconds, any order).
    - isi_threshold: ISI threshold in seconds; defaults to the module-level `threshold`.

    Returns:
    - Burst index in [0, 1]; NaN for an empty group.
    """
    if isi_threshold is None:
        isi_threshold = threshold

    # ISIs are only meaningful on time-ordered spikes.
    spike_times = np.sort(group['time'].to_numpy(dtype=float))
    n_spikes = spike_times.size
    if n_spikes == 0:
        return np.nan

    short_isi = np.diff(spike_times) < isi_threshold
    in_burst = np.zeros(n_spikes, dtype=bool)
    in_burst[:-1] |= short_isi  # spike i starts a short interval
    in_burst[1:] |= short_isi   # spike i+1 ends it

    return in_burst.sum() / n_spikes

# Calculate Burst Index for each unit
# Burst index per unit as a (unit_id, burst_index) frame.
# Selecting [['time']] keeps the grouping column out of the frames handed to
# calculate_burst_index (pandas >= 2.2 deprecates passing grouping columns to apply).
burst_indices = (
    times.groupby('unit_id')[['time']]
    .apply(calculate_burst_index)
    .rename('burst_index')
    .reset_index()
)

### Merge metrics and save

In [8]:
# Join the per-unit metrics (inner joins on unit_id) and add log2 firing rate.
merged_df = (
    firing_rates
    .merge(ttp_times, on='unit_id')
    .merge(burst_indices, on='unit_id')
    .assign(log_firing_rate=lambda df: np.log2(df['firing_rate']))
)

# Attach the original (per-recording) unit id and write the table.
id_frame = waveforms.loc[:, ['unit_id', 'unit_id_old']].drop_duplicates()
final_df = pd.merge(merged_df, id_frame, how='left', on='unit_id')
final_df.to_csv(output_path)

### Autocorrelogram

In [9]:
# NOTE(review): this whole cell is commented out and never runs. As written it
# iterates over `df`, which is not defined anywhere in this notebook (spike times
# live in `times`), and its nested loop over all spike pairs is O(n^2) per unit.
# Re-enable only after fixing both, or delete the cell.
# import matplotlib.pyplot as plt

# def plot_autocorrelogram(times, bin_size=0.003, window=0.1):
#     """Plot the autocorrelogram of spike times."""
#     time_diffs = []
#     for i in range(len(times)):
#         for j in range(len(times)):
#             if i != j:
#                 time_diffs.append(times[i] - times[j])
    
#     bins = np.arange(-window, window + bin_size, bin_size)
#     hist, _ = np.histogram(time_diffs, bins=bins)
    
#     plt.bar(bins[:-1], hist, width=bin_size, align='edge')
#     plt.xlabel('Time lag (s)')
#     plt.ylabel('Spike count')
#     plt.title('Autocorrelogram')
#     plt.xlim([-window, window])
#     plt.show()

# # Generate autocorrelograms for each unit
# for unit_id, group in df.groupby('unit_id'):
#     print(f"Unit: {unit_id}")
#     plot_autocorrelogram(group['time'].values)

### Tau Rise

In [10]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from tqdm import tqdm

# Define the triple-exponential function
def acg_fit(x, tau_decay, tau_rise, c, d, rateasymptote, trefrac, tau_burst, h):
    """Triple-exponential autocorrelogram model, clipped at zero.

    rate(x) = max(c * (exp(-(x - trefrac)/tau_decay) - d * exp(-(x - trefrac)/tau_rise))
                  + h * exp(-(x - trefrac)/tau_burst) + rateasymptote, 0)

    x, trefrac and the three time constants must be in the same time unit.
    """
    shifted = x - trefrac
    decay_term = np.exp(-shifted / tau_decay)
    rise_term = d * np.exp(-shifted / tau_rise)
    burst_term = h * np.exp(-shifted / tau_burst)
    return np.maximum(c * (decay_term - rise_term) + burst_term + rateasymptote, 0)

def compute_autocorrelogram(spike_times, bin_size=0.0005, max_lag=0.05):
    time_diffs = []
    
    for i in range(len(spike_times)):
        diffs = spike_times - spike_times[i]
        relevant_diffs = diffs[(diffs > -max_lag) & (diffs < max_lag) & (diffs != 0)]
        time_diffs.extend(relevant_diffs)

    autocorr, bin_edges = np.histogram(time_diffs, bins=np.arange(-max_lag, max_lag + bin_size, bin_size))
    autocorr[int(len(autocorr)/2)] = 0  # Set the value at zero lag to zero
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    return autocorr, bin_centers

In [11]:
# Output columns: unit id followed by acg_fit's parameters in signature order.
keys = ['unit_id', 'tau_decay', 'tau_rise', 'c', 'd', 'rateasymptote', 'trefrac', 'tau_burst', 'h']
# Initial values and bounds in that same order (same values as the MATLAB-derived
# set used further down). Their scale (trefrac <= 20, tau_decay >= 1) implies a
# millisecond lag axis. Lower bounds on the three time constants are strictly
# positive because acg_fit divides by them; only the asymptote may go negative.
a0 = [20, 1, 30, 2, 0.5, 5, 1.5, 2]
lb = [1, 0.1, 0, 0, -30, 0, 0.1, 0]
ub = [500, 50, 500, 15, 50, 20, 5, 100]

In [12]:
# Fit acg_fit to every unit's autocorrelogram.
# The initial values/bounds are on a millisecond / spikes-per-second scale, so
# the ACG is fitted on positive lags in ms and as a conditional rate in Hz. On
# the raw seconds axis, x spans only +/-0.05, far below the time-constant scale
# the bounds allow, and raw counts grow with recording length.
ACG_BIN_SIZE_S = 0.0005
ACG_MAX_LAG_S = 0.05

results_list = []

for unit, group in tqdm(times.groupby('unit_id'), total=times['unit_id'].nunique()):
    spike_times = group['time'].values
    autocorr, bin_centers = compute_autocorrelogram(
        spike_times, bin_size=ACG_BIN_SIZE_S, max_lag=ACG_MAX_LAG_S)

    # The ACG is symmetric: keep positive lags inside the window, in ms.
    positive = (bin_centers > 0) & (bin_centers < ACG_MAX_LAG_S)
    lags_ms = bin_centers[positive] * 1000
    # Pair counts -> conditional firing rate: counts / (n_spikes * bin width).
    rate_hz = autocorr[positive] / (len(spike_times) * ACG_BIN_SIZE_S)

    try:
        popt, _ = curve_fit(acg_fit, lags_ms, rate_hz, p0=a0, bounds=(lb, ub))
        results_list.append([unit] + list(popt))
    except (RuntimeError, ValueError) as err:
        # RuntimeError: no convergence; ValueError: invalid data/bounds/too few points.
        tqdm.write(f'ACG fit failed for unit {unit}: {err}')
        results_list.append([unit] + [np.nan] * (len(keys) - 1))

# One row per unit, NaN parameters where the fit failed.
results_df = pd.DataFrame(results_list, columns=keys)

### Test plot: ACG fits for the first 5 units

In [13]:
import matplotlib.pyplot as plt

# Plot ACG data and fitted curve for the first 5 units, on the axes the fit
# actually uses (positive lags in ms, conditional rate in Hz).
ACG_BIN_SIZE_S = 0.0005
ACG_MAX_LAG_S = 0.05

selected_unit_ids = times['unit_id'].unique()[:5]
subset_times = times[times['unit_id'].isin(selected_unit_ids)]

for unit, group in subset_times.groupby('unit_id'):
    spike_times = group['time'].values
    autocorr, bin_centers = compute_autocorrelogram(
        spike_times, bin_size=ACG_BIN_SIZE_S, max_lag=ACG_MAX_LAG_S)
    positive = (bin_centers > 0) & (bin_centers < ACG_MAX_LAG_S)
    lags_ms = bin_centers[positive] * 1000
    rate_hz = autocorr[positive] / (len(spike_times) * ACG_BIN_SIZE_S)

    fig, ax = plt.subplots()
    ax.plot(lags_ms, rate_hz, 'b-', label='Data')
    try:
        popt, _ = curve_fit(acg_fit, lags_ms, rate_hz, p0=a0, bounds=(lb, ub))
    except (RuntimeError, ValueError) as err:
        # Show the data anyway and say why there is no fit, instead of skipping silently.
        ax.set_title(f'Unit {unit} (fit failed: {err})')
    else:
        ax.plot(lags_ms, acg_fit(lags_ms, *popt), 'r-', label='Fit')
        ax.set_title(f'Unit {unit}')
    ax.set(xlabel='Lag (ms)', ylabel='Rate (Hz)')
    ax.legend()
plt.show()

### Tau rise: refit with the MATLAB initial values and bounds

In [16]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from tqdm import tqdm

# Define the triple-exponential function
def acg_fit(x, a, b, c, d, e, f, g, h):
    """Triple-exponential autocorrelogram model with MATLAB-style coefficient names, clipped at zero.

    a = tau_decay, b = tau_rise, c / d = amplitudes of the decay / rise terms,
    e = rate asymptote, f = refractory offset (trefrac), g = tau_burst,
    h = amplitude of the burst term. x, f, a, b and g share one time unit.
    """
    lag = x - f
    decay_minus_rise = np.exp(-lag / a) - d * np.exp(-lag / b)
    burst = h * np.exp(-lag / g)
    return np.maximum(c * decay_minus_rise + burst + e, 0)

def compute_autocorrelogram(spike_times, bin_size=0.0005, max_lag=0.05):
    time_diffs = []

    for i in range(len(spike_times)):
        diffs = spike_times - spike_times[i]
        relevant_diffs = diffs[(diffs > -max_lag) & (diffs < max_lag) & (diffs != 0)]
        time_diffs.extend(relevant_diffs)

    autocorr, bin_edges = np.histogram(time_diffs, bins=np.arange(-max_lag, max_lag + bin_size, bin_size))
    autocorr[int(len(autocorr)/2)] = 0  # Set the value at zero lag to zero
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    return autocorr, bin_centers

# Initial values and bounds from the MATLAB implementation, in acg_fit's parameter
# order: a=tau_decay, b=tau_rise, c, d, e=rateasymptote, f=trefrac, g=tau_burst, h.
# Their scale (trefrac <= 20, tau_decay >= 1, asymptote <= 50) implies a
# millisecond lag axis and a rate in spikes/s, so the ACG is fitted on positive
# lags in ms and normalized to Hz instead of raw counts on a seconds axis.
# NOTE(review): confirm against the MATLAB reference that it also fits the
# positive half of a rate-normalized ACG.
a0 = [20, 1, 30, 2, 0.5, 5, 1.5, 2]
lb = [1, 0.1, 0, 0, -30, 0, 0.1, 0]
ub = [500, 50, 500, 15, 50, 20, 5, 100]

ACG_BIN_SIZE_S = 0.0005
ACG_MAX_LAG_S = 0.05

results_list = []
keys = ['unit_id', 'tau_decay', 'tau_rise', 'c', 'd', 'rateasymptote', 'trefrac', 'tau_burst', 'h']

for unit, group in tqdm(times.groupby('unit_id'), total=times['unit_id'].nunique()):
    spike_times = group['time'].values
    autocorr, bin_centers = compute_autocorrelogram(
        spike_times, bin_size=ACG_BIN_SIZE_S, max_lag=ACG_MAX_LAG_S)

    # The ACG is symmetric: keep positive lags inside the window, in ms.
    positive = (bin_centers > 0) & (bin_centers < ACG_MAX_LAG_S)
    lags_ms = bin_centers[positive] * 1000
    # Pair counts -> conditional firing rate: counts / (n_spikes * bin width).
    rate_hz = autocorr[positive] / (len(spike_times) * ACG_BIN_SIZE_S)

    try:
        popt, _ = curve_fit(acg_fit, lags_ms, rate_hz, p0=a0, bounds=(lb, ub))
        results_list.append([unit] + list(popt))
    except (RuntimeError, ValueError) as err:
        # RuntimeError: no convergence; ValueError: invalid data/bounds/too few points.
        tqdm.write(f'ACG fit failed for unit {unit}: {err}')
        results_list.append([unit] + [np.nan] * (len(keys) - 1))

# One row per unit, NaN parameters where the fit failed.
results_df = pd.DataFrame(results_list, columns=keys)

 35%|███▌      | 43/122 [01:21<10:20,  7.85s/it]