### Configuration

In [1]:
import os
import numpy as np
import pandas as pd
import sklearn.preprocessing

import umap
from scipy.integrate import simps
from scipy.spatial.distance import euclidean

from scipy.spatial.distance import mahalanobis
from scipy.linalg import inv
from sklearn.metrics import pairwise_distances
from scipy.stats import chi2

import utils__config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.chdir(utils__config.working_directory)
os.getcwd()

'Z:\\Layton\\Sleep_051324'

### Parameters

In [3]:
recordings = [
    {
        'recording_id': 'Feb02',
        'recording_length': 2,
        'mean_waveforms': 'Data/S01_Feb02_mean_waveforms.csv',
        'sampled_waveforms': 'Data/S01_Feb02_waveforms_sampled.csv'
    },
    {
        'recording_id': 'Jul11',
        'recording_length': 9.68,
        'mean_waveforms': 'Data/S05_Jul11_mean_waveforms.csv',
        'sampled_waveforms': 'Data/S05_Jul11_waveforms_sampled.csv'
    },
    {
        'recording_id': 'Jul12',
        'recording_length': 10.55,
        'mean_waveforms': 'Data/S05_Jul12_mean_waveforms.csv',
        'sampled_waveforms': 'Data/S05_Jul12_waveforms_sampled.csv'
    },
    {
        'recording_id': 'Jul13',
        'recording_length': 10.40,
        'mean_waveforms': 'Data/S05_Jul13_mean_waveforms.csv',
        'sampled_waveforms': 'Data/S05_Jul13_waveforms_sampled.csv'
    }
]

output_path = 'Cache/waveform_stability.csv'

seed = 42

### Load Data

In [4]:
data = pd.DataFrame()

for recording in recordings:
    
    # Process mean_waveforms
    mean_waveforms = pd.read_csv(recording['mean_waveforms'])
    mean_waveforms['type'] = 'mean'
    mean_waveforms['percent'] = np.NaN
    mean_waveforms['spike_id'] = mean_waveforms.apply(lambda row: str(row['spike_id']) + "_" + str('mean'), axis=1)
    
    # Process sampled_waveforms
    sampled_waveforms = pd.read_csv(recording['sampled_waveforms'])
    sampled_waveforms['type'] = 'sample'
    sampled_waveforms['percent'] = ((sampled_waveforms['milliseconds'] / 1000) / (recording['recording_length'] * 3600)) * 100
    sampled_waveforms['spike_id'] = sampled_waveforms.apply(lambda row: str(row['spike_id']) + "_" + str('sample'), axis=1)

    # Row bind mean_waveforms and sampled_waveforms
    combined = pd.concat([mean_waveforms, sampled_waveforms], ignore_index=True)
    combined['unit_id'] = combined['unit_id'].astype(str) + '_' + recording['recording_id']
    combined['spike_id'] = combined.apply(lambda row: str(row['spike_id']) + "_" + str(row['unit_id']), axis=1)
    
    # Append to data
    data = pd.concat([data, combined], ignore_index=True)

In [5]:
# Format data and set final spike_id
data = data[['unit_laterality', 'unit_region', 'unit_id', 'type', 'spike_id', 'percent', 'time_point', 'amplitude']]
data['spike_id'] = pd.factorize(data['spike_id'])[0] + 1
data

# Filter out the mean waveforms for use as reference in calculating Euclidean distances
mean_waveforms = data[data['type'] == 'mean']

# Calculate metrics for sample waveforms
sample_waveforms = data[data['type'] == 'sample']
sample_waveforms

Unnamed: 0,unit_laterality,unit_region,unit_id,type,spike_id,percent,time_point,amplitude
2368,right,CLA,S01_Ch195_neg_Unit3_Feb02,sample,38,0.041281,1,6.967503
2369,right,CLA,S01_Ch195_neg_Unit3_Feb02,sample,38,0.041281,2,4.712843
2370,right,CLA,S01_Ch195_neg_Unit3_Feb02,sample,38,0.041281,3,0.827369
2371,right,CLA,S01_Ch195_neg_Unit3_Feb02,sample,38,0.041281,4,-1.831863
2372,right,CLA,S01_Ch195_neg_Unit3_Feb02,sample,38,0.041281,5,-2.136348
...,...,...,...,...,...,...,...,...
7815803,right,AMY,S05_Ch240_neg_Unit3_Jul13,sample,122122,99.976814,60,16.517321
7815804,right,AMY,S05_Ch240_neg_Unit3_Jul13,sample,122122,99.976814,61,8.810473
7815805,right,AMY,S05_Ch240_neg_Unit3_Jul13,sample,122122,99.976814,62,4.711396
7815806,right,AMY,S05_Ch240_neg_Unit3_Jul13,sample,122122,99.976814,63,4.422690


### Calculate Stability Metrics

In [6]:
def calculate_unit_metrics(df, mean_waveforms):
    # Initialize results list
    results = []

    for unit_id, group in df.groupby('unit_id'):
        polarity = 'neg' if 'neg' in unit_id else 'pos'
        metrics = {
            'fwhm': [],
            'auc': [],
            'max_amplitude': [],
            'euclidean_distance': []
        }
        
        mean_waveform = mean_waveforms[(mean_waveforms['unit_id'] == unit_id)].sort_values('time_point')['amplitude'].values
        
        for _, spike_group in group.groupby('spike_id'):
            waveform = spike_group.sort_values('time_point')['amplitude'].values
            
            # Handle polarity
            if polarity == 'neg':
                waveform = -waveform  # Flip waveform for negative-spiking units
            
            # Calculate AUC
            auc = simps(waveform, dx=1)
            metrics['auc'].append(auc)
            
            # Calculate Maximum Amplitude
            max_amplitude = np.max(waveform)
            metrics['max_amplitude'].append(max_amplitude)
            
            # Calculate FWHM
            half_max = max_amplitude / 2
            indices_above_half_max = np.where(waveform >= half_max)[0]
            fwhm = indices_above_half_max[-1] - indices_above_half_max[0] if len(indices_above_half_max) > 0 else np.nan
            metrics['fwhm'].append(fwhm)
            
            # Calculate Euclidean distance from the mean waveform
            if len(mean_waveform) == len(waveform):
                euclidean_dist = euclidean(mean_waveform, waveform)
                metrics['euclidean_distance'].append(euclidean_dist)

        # Calculate Coefficient of Variation for each metric
        cv_results = {
            'unit_id': unit_id,
            'fwhm_cv': np.std(metrics['fwhm']) / np.mean(metrics['fwhm']) if np.mean(metrics['fwhm']) else np.nan,
            'auc_cv': np.std(metrics['auc']) / np.mean(metrics['auc']) if np.mean(metrics['auc']) else np.nan,
            'max_amplitude_cv': np.std(metrics['max_amplitude']) / np.mean(metrics['max_amplitude']) if np.mean(metrics['max_amplitude']) else np.nan,
            'euclidean_distance_cv': np.std(metrics['euclidean_distance']) / np.mean(metrics['euclidean_distance']) if np.mean(metrics['euclidean_distance']) else np.nan
        }
        results.append(cv_results)

    return pd.DataFrame(results)

In [7]:
def perform_umap(df):
    # Ensure the DataFrame is sorted by unit_id, spike_id, and time_point for consistent row ordering
    sorted_df = df.sort_values(['unit_id', 'spike_id', 'time_point'])

    # Pivot the data to create a 2D array where each row represents a spike and columns are time points
    waveforms = sorted_df.pivot(index='spike_id', columns='time_point', values='amplitude').values

    # Initialize UMAP reducer
    reducer = umap.UMAP(n_neighbors=15, random_state=42)

    # Perform UMAP transformation
    embedding = reducer.fit_transform(waveforms)

    # Create a DataFrame of the UMAP results
    umap_df = pd.DataFrame(embedding, columns=['umap_1', 'umap_2'])

    # Reset the index in sorted_df to make sure we can pull the correct spike_id and unit_id
    sorted_df = sorted_df.drop_duplicates('spike_id').set_index('spike_id')

    # Assign spike_id back to umap_df from the index of waveforms to ensure correct alignment
    umap_df['spike_id'] = sorted_df.index

    # Retrieve the unit_id for each spike_id and add it to umap_df
    umap_df['unit_id'] = sorted_df['unit_id']
    umap_df['unit_region'] = sorted_df['unit_region']

    return umap_df

In [8]:
def calculate_l_ratio(umap_df):
    l_ratios = []

    for (unit_id, unit_region), cluster in umap_df.groupby(['unit_id', 'unit_region']):
        cluster_center = cluster[['umap_1', 'umap_2']].mean().values
        other_data = umap_df[umap_df['unit_id'] != unit_id][['umap_1', 'umap_2']].values

        # Calculate covariance matrix of the cluster
        covariance_matrix = np.cov(cluster[['umap_1', 'umap_2']], rowvar=False)
        # Regularize matrix to prevent issues with matrix inversion if covariance matrix is singular or near-singular
        inv_cov_matrix = np.linalg.inv(covariance_matrix + np.eye(covariance_matrix.shape[0]) * 1e-5)

        # Calculate Mahalanobis distances within the cluster
        mahalanobis_distances = cluster[['umap_1', 'umap_2']].apply(
            lambda row: mahalanobis(row.values, cluster_center, inv_cov_matrix), axis=1
        )

        # Chi-squared cumulative distribution function values for the distances
        chi2_vals = chi2.cdf(mahalanobis_distances**2, df=2)  # 2 degrees of freedom for 2 features

        # Calculate Mahalanobis distances for all non-cluster spikes to the cluster center
        all_distances = [mahalanobis(point, cluster_center, inv_cov_matrix) for point in other_data]

        # L-Ratio calculation
        L_ratio = chi2_vals.sum() / len(all_distances)
        
        l_ratios.append({
            'unit_id': unit_id,
            'unit_region': unit_region,
            'l_ratio': L_ratio
        })

    return pd.DataFrame(l_ratios)

In [9]:
unit_metrics = calculate_unit_metrics(sample_waveforms, mean_waveforms)
umap_results = perform_umap(sample_waveforms)
l_ratios = calculate_l_ratio(umap_results)
final_df = unit_metrics.merge(l_ratios, on='unit_id')
final_df

Unnamed: 0,unit_id,fwhm_cv,auc_cv,max_amplitude_cv,euclidean_distance_cv,unit_region,l_ratio
0,S01_Ch195_neg_Unit3_Feb02,1.085909,-0.803846,0.223294,0.090769,CLA,0.002575
1,S01_Ch195_pos_Unit2_Feb02,0.347622,3.463531,0.187287,0.264948,CLA,0.002741
2,S01_Ch196_neg_Unit1_Feb02,1.058625,-1.036876,0.235700,0.088132,CLA,0.002795
3,S01_Ch196_neg_Unit3_Feb02,0.436251,0.524794,0.177187,0.108910,CLA,0.003424
4,S01_Ch196_neg_Unit4_Feb02,1.537203,-0.474102,0.132667,0.055160,CLA,0.000935
...,...,...,...,...,...,...,...
117,S05_Ch239_neg_Unit3_Jul11,0.403287,0.408167,0.180976,0.081559,AMY,0.004064
118,S05_Ch240_neg_Unit1_Jul11,0.523054,0.597725,0.249160,0.127569,AMY,0.004394
119,S05_Ch240_neg_Unit2_Jul12,0.786767,0.962331,0.211506,0.074270,AMY,0.004138
120,S05_Ch240_neg_Unit2_Jul13,0.846109,0.957269,0.241092,0.117439,AMY,0.004216


### Export for UMAP

In [10]:
final_df.to_csv(output_path, index=False)
final_df

Unnamed: 0,unit_id,fwhm_cv,auc_cv,max_amplitude_cv,euclidean_distance_cv,unit_region,l_ratio
0,S01_Ch195_neg_Unit3_Feb02,1.085909,-0.803846,0.223294,0.090769,CLA,0.002575
1,S01_Ch195_pos_Unit2_Feb02,0.347622,3.463531,0.187287,0.264948,CLA,0.002741
2,S01_Ch196_neg_Unit1_Feb02,1.058625,-1.036876,0.235700,0.088132,CLA,0.002795
3,S01_Ch196_neg_Unit3_Feb02,0.436251,0.524794,0.177187,0.108910,CLA,0.003424
4,S01_Ch196_neg_Unit4_Feb02,1.537203,-0.474102,0.132667,0.055160,CLA,0.000935
...,...,...,...,...,...,...,...
117,S05_Ch239_neg_Unit3_Jul11,0.403287,0.408167,0.180976,0.081559,AMY,0.004064
118,S05_Ch240_neg_Unit1_Jul11,0.523054,0.597725,0.249160,0.127569,AMY,0.004394
119,S05_Ch240_neg_Unit2_Jul12,0.786767,0.962331,0.211506,0.074270,AMY,0.004138
120,S05_Ch240_neg_Unit2_Jul13,0.846109,0.957269,0.241092,0.117439,AMY,0.004216
