## Load data from a single session

#### Select session

In [28]:
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from session_utils import *

session_list = find_all_sessions(sheet_path = 'https://docs.google.com/spreadsheets/d/1_Xs5i-rHNTywV-WuQ8-TZliSjTxQCCqGWOD2AL_LIq0/edit#gid=0',
                                 data_path = '/home/isabella/Documents/isabella/jake/recording_data',
                                 sorting_suffix = 'sorting_ks2_custom')

# Create an instance of the SessionSelector class
selector = SessionSelector(session_list)

interactive(children=(Dropdown(description='Select Session:', options=('230503_r1354', '230504_r1354', '230505…

#### Load spike times and position for a single session

In [30]:
from ephys import *
from ephys_utils import select_spikes_by_trial, transform_spike_data, find_template_for_clusters
from spatial_analysis import *

path_to_session = selector.path_to_session

obj = ephys(recording_type = 'nexus', path = path_to_session)

obj.load_spikes('good')
obj.load_pos([i for i, s in enumerate(obj.trial_list)])
pos_sample_rate = obj.pos_data[0]['pos_sampling_rate']

Loading pos file: /home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-05/230505_r1354_raw_open-field_1.pos
431 LED swaps detected and fixed
Loading pos file: /home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-05/230505_r1354_raw_t-maze_1.pos
244 LED swaps detected and fixed
Loading pos file: /home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-05/230505_r1354_raw_open-field_2.pos
272 LED swaps detected and fixed
Loading pos file: /home/isabella/Documents/isabella/jake/recording_data/r1354/2023-05-05/230505_r1354_raw_t-maze_2.pos
220 LED swaps detected and fixed


## Filter good cells to only include CA1 pyramidal cells following criteria for inclusion in Wills et al., 2010
### Criteria:
#### - Cluster depth 0 +-200um
#### - Mean firing rate between 1Hz and 10Hz
#### - Template spike width >300us
#### - Burst index - currently no cutoff value set

In [31]:
cluster_info = obj.spike_data['cluster_info'] #Get cluster info from phy

## Get cluster depths and exclude any outside of 0 +-200um
cluster_info = cluster_info[cluster_info['depth'].between(-200, 200)]

## Filter for mean firing rate - CHECK EXACT VALUES
cluster_info = cluster_info[cluster_info['fr'].between(1, 10)]

## Filter for spike width from template
# Find kilosort template for each cluster (closely approximates spike width) and add to dataframe
template_per_cluster = find_template_for_clusters(obj.spike_data['spike_clusters'], obj.spike_data['spike_templates'])
template_df = pd.DataFrame.from_dict(template_per_cluster, orient = 'index')
template_df.columns = ['template_id']

cluster_info = cluster_info.join(template_df, how = 'left')

# Load templates.npy
templates = np.load(f'{path_to_session}/{path_to_session[-8:-6]}{path_to_session[-5:-3]}{path_to_session[-2:]}_sorting_ks2_custom/templates.npy')
# Load inverse whitening matrix and apply to unwhiten templates
whitening_matrix_inv =  np.load(f'{path_to_session}/{path_to_session[-8:-6]}{path_to_session[-5:-3]}{path_to_session[-2:]}_sorting_ks2_custom/whitening_mat_inv.npy')
unwhitened_templates = np.einsum('ijk,kl->ijl', templates, whitening_matrix_inv)

# Add template values to dataframe
cluster_info['template'] = cluster_info.apply(lambda row: templates[int(row['template_id']), :, int(row['ch'])], axis=1)

# Work out template width peak to trough
sampling_rate = obj.spike_data['sampling_rate']

spike_width_samples = cluster_info['template'].apply(
    lambda x: np.abs(np.argmax(x) - np.argmin(x))
)

# Convert spike width to microseconds and add to dataframe
spike_width_microseconds = (spike_width_samples/sampling_rate)*1000000
cluster_info['spike_width_microseconds'] = spike_width_microseconds

# Filter for spike width > 300us as in Wills et al., 2010
cluster_info = cluster_info[cluster_info['spike_width_microseconds'] > 300]

print(f'{len(cluster_info.index)} cells retained of {len(obj.spike_data["cluster_info"].index)} good cells from phy')

ValueError: Wrong number of items passed 14, placement implies 1

In [34]:
templates.shape

(147, 82, 64)

In [None]:
# Reload spike data only for included cells
clusters_inc = list(cluster_info.index)
obj.load_spikes(clusters_to_load = clusters_inc)

print(f'Reloaded spike data for {len(clusters_inc)} candidate CA1 pyramidal cells')

In [21]:
# Generate autocorrelograms and burst index for each cluster
from burst_index_and_autocorrelograms import *

spike_times_inc = obj.spike_data['spike_times']
spike_clusters_inc = obj.spike_data['spike_clusters']

autocorrelograms, burst_indices = compute_autocorrelograms_and_burst_indices(spike_times_inc, 
                                                                             spike_clusters_inc, 
                                                                             bin_size = 0.001, #1ms
                                                                             time_window = 0.05, #50ms
                                                                             burst_threshold = 0.01 #10ms
                                                                            )
cluster_info['burst_index'] = burst_indices.values()

## Save candidate pyramidal cell cluster IDs to .npy in sorting folder

In [22]:
clusters_inc = cluster_info.index
n_clusters_inc = len(cluster_info.index)

np.save(f'{path_to_session}/clusters_inc.npy', clusters_inc)
print(f'{n_clusters_inc} clusters saved as candidate pyramidal cells. Cluster IDs: {clusters_inc.values}')

1 clusters saved as candidate pyramidal cells. Cluster IDs: [229]


## Plot autocorrelograms for all candidate pyramidal cells

In [6]:
# Plot the autocorrelograms
plot_autocorrelograms_with_dropdown(autocorrelograms)



interactive(children=(Dropdown(description='Cluster ID:', options=(), value=None), Output()), _dom_classes=('w…

## Generate rate maps

In [7]:
# Loop through trials and generate rate maps
rate_maps = {}
occupancy = {}

for trial, trial_name in enumerate(obj.trial_list):
    
    # Select spikes for current trial and transform to create a dict of {cluster: spike_times, cluster:spike_times}
    current_trial_spikes = select_spikes_by_trial(obj.spike_data, trial, obj.trial_offsets)
    current_trial_spikes = transform_spike_data(current_trial_spikes)
    
    
    rate_maps[trial], occupancy[trial] = generate_rate_maps(spike_data = current_trial_spikes,
                               positions = obj.pos_data[trial]['xy_position'],  
                               ppm = 400, 
                               x_bins = 50,
                               y_bins = 50,
                               dt = 1.0,
                               smoothing_window = 10)


## Plot rate maps for all candidate pyramidal cells

In [8]:
#Plot rate maps
interactive_cluster_plot(rate_maps, title_prefix="Rate Maps")

interactive(children=(Dropdown(description='Cluster ID:', options=(), value=None), Output()), _dom_classes=('w…

## Calculate spatial information

In [9]:
from scipy.stats import entropy

def calculate_spatial_information(rate_maps, occupancy, dt=1.0):
    """
    Calculate Skaggs' spatial information score for given rate maps and occupancy.
    
    Parameters:
    - rate_maps: dict
        Dictionary containing smoothed rate maps organized by clusters.
    - occupancy: np.ndarray
        2D array indicating occupancy of each bin.
    - dt: float
        Time window for spike count.
        
    Returns:
    - skaggs_info_dict: dict
        Dictionary containing Skaggs' spatial information scores organized by clusters.
    """
    
    skaggs_info_dict = {}
    
    # Calculate the total time spent in the environment
    total_time = np.nansum(occupancy) * dt

    for cluster, rate_map in rate_maps.items():
        
        # Calculate mean firing rate across all bins
        mean_firing_rate = np.nansum(rate_map * occupancy) / total_time

        # Calculate probability of occupancy for each bin
        prob_occupancy = occupancy / np.nansum(occupancy)

        # Calculate Skaggs' spatial information score
        non_zero_idx = (rate_map > 0) & (prob_occupancy > 0)
        skaggs_info = np.nansum(
            prob_occupancy[non_zero_idx] *
            rate_map[non_zero_idx] *
            np.log2(rate_map[non_zero_idx] / mean_firing_rate)
        )

        skaggs_info_dict[cluster] = skaggs_info
            
    return skaggs_info_dict

# Calculate spatial information - NEEDS ADJUSTING TO HANDLE NAN VALUES
spatial_info = {}

for i in rate_maps.keys():
    spatial_info[i] = calculate_spatial_information(rate_maps[i], occupancy[i], dt = 1)
# spatial_info

In [11]:
cluster_info

Unnamed: 0_level_0,%RPV,Amplitude,ContamPct,ISI_viol,KSLabel,RPV,amp,ch,depth,fr,group,n_spikes,sh,template_id,template,spike_width_microseconds,burst_index
cluster_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
