## 1. Load data from all sessions

#### Load name and path of all sessions

In [2]:
sheet_path = 'https://docs.google.com/spreadsheets/d/1_Xs5i-rHNTywV-WuQ8-TZliSjTxQCCqGWOD2AL_LIq0/edit#gid=0'
data_path = '/home/isabella/Documents/isabella/jake/recording_data'
############

from spelt.session_utils import find_all_sessions
# Find all included sessions from Google sheet, with structure session_name: path
session_dict = find_all_sessions(sheet_path = sheet_path,
                                 data_path = data_path,
                                 sorting_suffix = 'sorting_ks2_custom')
print(f'{len(session_dict.items())} sessions found')

58 sessions found


## 2. Loop over all sessions, load data and filter for pyramidal cells
### Criteria:
#### - A. Cluster marked 'good' in phy
#### - B. Cluster depth 0 +-200um
#### - C. Mean FR < 10 Hz
#### - D. Mean spike width >500us
#### - E. Burst index - first moment of AC < 25

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
from spelt.ephys import ephys
from spelt.postprocessing.burst_index_and_autocorrelograms import *

total_cells_inc = 0
total_cells_processed = 0

cluster_info_all = pd.DataFrame()
df_all_cells = pd.DataFrame(columns = ['clusters_inc', 'ephys_object'], dtype = 'object')

for session, session_info in session_dict.items():
    
    session_path = session_info[0]
    probe_type = session_info[1]
    if probe_type == '5x12_buz':
        recording_type = 'nexus'
    else:
        recording_type = probe_type

    # Create ephys object
    obj = ephys(recording_type = recording_type, path = session_path, sheet_url = sheet_path)
    
    ## A. Load good cells from phy
    obj.load_spikes('good')
    
    # Get cluster info from phy
    cluster_info = obj.spike_data['cluster_info']

    cluster_info['session'] = session
    
    # Get total good cells for session
    total_cells = len(cluster_info.index)
    total_cells_processed += total_cells
    
    # ## B. Get cluster depths and exclude any outside of 0 +-200um
    # cluster_info = cluster_info[cluster_info['depth'].between(-200, 200)].copy()

    # ## C. Filter for mean firing rate between 0-10 Hz
    # cluster_info = cluster_info[cluster_info['fr'].between(0, 10)]
    
    # ## D. Filter for spike width >300 us from template
    # if not cluster_info.empty:
    #     obj.load_mean_waveforms(clusters_to_load = list(cluster_info.index)) #Pick up to 500 spikes at random for performance

    #     # Calculate spike width for each cluster from mean of every 50th spike 
    #     for cluster, mean_waveform in obj.mean_waveforms.items():
            
    #         # Find peak to trough time in us
    #         peak = np.argmin(mean_waveform)
    #         trough = np.argmax(mean_waveform[peak:]) + peak
    #         peak_to_trough = trough - peak         

    #         # # Plot for sanity
    #         # print(f'Spike width for cluster {cluster}: {peak_to_trough / obj.spike_data["sampling_rate"] * 1e6} us')
    #         # plt.plot(mean_waveform)
    #         # plt.scatter(trough, mean_waveform[trough])
    #         # plt.scatter(peak, mean_waveform[peak])
    #         # plt.show()

    #         cluster_info.loc[cluster, 'spike_width_microseconds'] = (peak_to_trough / obj.spike_data['sampling_rate']) * 1e6
        
    #     # Filter for spike width > 300us as in Wills et al., 2010
    #     cluster_info = cluster_info[cluster_info['spike_width_microseconds'] > 500].copy()

    
    # ## E. Calculate burst index and filter
    # if not cluster_info.empty:
        
    #     # Reload spike data only for included cells
    #     clusters_inc = list(cluster_info.index)
    #     obj.load_spikes(clusters_to_load = clusters_inc)
        
    #     # Generate autocorrelograms and burst index for each cluster
    #     spike_times_inc = obj.spike_data['spike_times']
    #     spike_clusters_inc = obj.spike_data['spike_clusters']

    #     autocorrelograms, first_moments = compute_autocorrelograms_and_first_moment(spike_times_inc, 
    #                                                                                  spike_clusters_inc, 
    #                                                                                  bin_size = 0.001, #1ms
    #                                                                                  time_window = 0.05) #50ms
        
    #     cluster_info['first_moment_AC'] = first_moments.values()
        
    #     # Filter for first moment <25
    #     cluster_info = cluster_info[cluster_info['first_moment_AC'] < 25]
            
    ## SAVE INCLUDED CLUSTER IDs TO .NPY
    clusters_inc = cluster_info.index
    n_clusters_inc = len(cluster_info.index)
    total_cells_inc += n_clusters_inc

    # Append session to cluster_info_all index
    cluster_info.index = cluster_info['session'] + '_' + cluster_info.index.astype(str)
    
    # Concat cluster info for all sessions
    cluster_info_all = pd.concat([cluster_info_all, cluster_info])

    # Save included clusters to .npy
    np.save(f'{session_path}/clusters_inc.npy', clusters_inc)

    # Make df_all_cells
    df_all_cells.loc[session, 'clusters_inc'] = clusters_inc.values
    df_all_cells.loc[session, 'ephys_object'] = obj      
        
    print(f'Session {session}: {n_clusters_inc} cells retained of {total_cells} good cells from phy. Retained cells: {clusters_inc.values}')

print(f'Total cells retained: {total_cells_inc} of total {total_cells_processed} good cells from phy')

# Save cluster_info_all and df_all_cells to pickle
pd.to_pickle(cluster_info_all, '/home/isabella/Documents/isabella/jake/ephys_analysis/processed_data/cluster_info_all.pkl')
pd.to_pickle(df_all_cells, '/home/isabella/Documents/isabella/jake/ephys_analysis/processed_data/df_all_cells.pkl')
print('cluster_info_all and df_all_cells saved to pickle')

Session 230503_r1354: 5 cells retained of 5 good cells from phy. Retained cells: [317 318 322 324 328]
Session 230504_r1354: 8 cells retained of 8 good cells from phy. Retained cells: [215 223 229 238 240 243 245 248]
Session 230505_r1354: 4 cells retained of 4 good cells from phy. Retained cells: [162 166 172 180]
Session 230506_r1354: 23 cells retained of 23 good cells from phy. Retained cells: [191 203 208 214 220 226 233 236 238 242 247 251 256 260 264 269 274 281
 292 300 308 322 325]
Session 230507_r1354: 17 cells retained of 17 good cells from phy. Retained cells: [ 37  62 141 155 163 192 205 215 231 248 253 326 371 379 380 388 389]
Session 230508_r1354: 15 cells retained of 15 good cells from phy. Retained cells: [ 30 239 243 261 263 297 318 320 326 343 350 356 363 374 390]
Session 230509_r1354: 22 cells retained of 22 good cells from phy. Retained cells: [131 192 194 200 213 215 232 233 261 282 301 307 356 359 361 362 363 364
 365 366 367 368]
Session 230510_r1354: 15 cells re

In [5]:
plt.scatter(cluster_info_all['spike_width_microseconds'], cluster_info_all['fr'])
plt.xlabel('Spike width (us)')
# Set y axis to log 10 scale
plt.yscale('log')
plt.ylabel('Firing rate')
plt.title('Spike width vs firing rate')
plt.show()

KeyError: 'spike_width_microseconds'

In [None]:
# Load spatial info all
spatial_info_all = pd.read_pickle('/home/isabella/Documents/isabella/jake/ephys_analysis/processed_data/spatial_info_all.pkl')

# Select only rows including 'open-field_1'
spatial_info_open_field_1 = spatial_info_all[spatial_info_all.index.str.contains('raw_open-field_1')]

# Drop 'open-field_1' from index
spatial_info_open_field_1.index = spatial_info_open_field_1.index.str.replace('raw_open-field_1_', '')


spatial_info_open_field_1

In [None]:
# Merge cluster_info_all with spatial_info_open_field_1
cluster_info_spatial_info = pd.merge(cluster_info_all, spatial_info_open_field_1, left_index = True, right_index = True)

# Plot spike width against bits per spike
plt.scatter(cluster_info_spatial_info['spike_width_microseconds'], cluster_info_spatial_info['bits_per_spike'])
plt.xlabel('Spike width (us)')
plt.ylabel('Bits per spike')
plt.title('Spike width vs bits per spike')
plt.show()

In [None]:
# Plot firing rate against bits per spike
plt.scatter(cluster_info_spatial_info['fr'], cluster_info_spatial_info['bits_per_spike'])
plt.xlabel('Firing rate')
plt.ylabel('Bits per spike')
plt.title('Firing rate vs bits per spike')
plt.show()

In [None]:
# Plot first moment of autocorrelogram against bits per spike
plt.scatter(cluster_info_spatial_info['first_moment_AC'], cluster_info_spatial_info['bits_per_spike'])
plt.xlabel('First moment of autocorrelogram')
plt.ylabel('Bits per spike')
plt.title('First moment of autocorrelogram vs bits per spike')
plt.show()

In [None]:
# Plot spike width vs burst index
plt.scatter(cluster_info_spatial_info['spike_width_microseconds'], cluster_info_spatial_info['first_moment_AC'])
plt.xlabel('Spike width (us)')
plt.ylabel('Burst index')
plt.title('Spike width vs burst index')
plt.show()

In [None]:
import spikeinterface.full as si
import spikeinterface.extractors as se
import spikeinterface.curation as sc
import numpy as np

from spikeinterface.postprocessing import compute_template_metrics, compute_spike_amplitudes

recording_path = '/home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-06/230506_r1354_raw_open-field_1_preprocessed'

recording = si.load_extractor(recording_path)
recording

sorting = se.read_phy('/home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-06/230506_sorting_ks2_custom', exclude_cluster_groups=['noise', 'mua'])
sorting

sorting = sc.remove_excess_spikes(sorting, recording)

clusters_inc = np.load('/home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-06/clusters_inc.npy')


waveform_extractor = si.extract_waveforms(recording=recording, 
                                          sorting=sorting, 
                                          folder = f'waveforms',
                                          load_if_exists=True,
                                          n_jobs = -1)


template_metrics = compute_template_metrics(waveform_extractor)
amplitudes = compute_spike_amplitudes(waveform_extractor, outputs = 'by_unit')
display(template_metrics.loc[clusters_inc, :])

In [None]:
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
import matplotlib.pyplot as plt
import ipywidgets as widgets

recording_highpass = spre.highpass_filter(recording, freq_min=300)

channel_dict = si.get_template_extremum_channel(waveform_extractor)
good_units = {key: channel_dict[key] for key in channel_dict if key in clusters_inc}

channels = list(good_units.values())
unit_ids = list(good_units.keys())

unit_index = 1

def plot_data(start, time_window):
    fig, ax = plt.subplots()
    sw.plot_traces(recording_highpass, time_range=(start, start + time_window), channel_ids=[channels[unit_index]], show_channel_ids=True, ax=ax, color = 'blue')
    sw.plot_rasters(sorting=sorting, segment_index=0, time_range=(start, start + time_window), unit_ids=[unit_ids[unit_index]], color='orange', ax=ax)
    plt.show()

# Define the time range for the widget
total_time = recording_highpass.get_num_frames() / recording_highpass.get_sampling_frequency()
time_window = 1  # Length of the time window in seconds

# Create the widget
widgets.interact(plot_data,
                 start=widgets.FloatSlider(min=0, max=total_time-time_window, step=0.008, value=0, description='Start Time'),
                 time_window=widgets.FloatSlider(min=0.001, max=time_window, step=0.0005, value=time_window, description='Time Window'))
