## 1. Load data from all sessions

#### Load name and path of all sessions

In [None]:
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from session_utils import find_all_sessions

# Google sheet listing which sessions to include
SHEET_PATH = 'https://docs.google.com/spreadsheets/d/1_Xs5i-rHNTywV-WuQ8-TZliSjTxQCCqGWOD2AL_LIq0/edit#gid=0'
# Local root folder holding the recording data
DATA_PATH = '/home/isabella/Documents/isabella/jake/recording_data'
# Suffix of the spike-sorting output folder passed to find_all_sessions
SORTING_SUFFIX = 'sorting_ks2_custom'

# Find all included sessions from the Google sheet, as a mapping {session_name: path}
session_dict = find_all_sessions(sheet_path = SHEET_PATH,
                                 data_path = DATA_PATH,
                                 sorting_suffix = SORTING_SUFFIX)
print(f'{len(session_dict)} sessions found')

## 2. Loop over all sessions, load data and filter for pyramidal cells
### Criteria:
#### - A. Cluster marked 'good' in phy
#### - B. Cluster depth within 0 ± 200 µm
#### - C. Mean firing rate between 0.1 Hz and 10 Hz
#### - D. Template spike width (peak to trough) > 300 µs
#### - E. Burst index: first moment of the spike autocorrelogram < 25
#### - F. Spatial information - not yet implemented; no cutoff value set

In [None]:
import os

import numpy as np
import pandas as pd

from ephys import *
from ephys_utils import select_spikes_by_trial, transform_spike_data, find_template_for_clusters
from spatial_analysis import *
# Provides compute_autocorrelograms_and_first_moment (step E); imported once, not on every loop iteration
from burst_index_and_autocorrelograms import *

# Cut-offs for the pyramidal-cell criteria listed above
DEPTH_RANGE_UM = (-200, 200)        # B. cluster depth window (um)
FIRING_RATE_RANGE_HZ = (0.1, 10)    # C. mean firing rate window (Hz)
MIN_SPIKE_WIDTH_US = 300            # D. minimum peak-to-trough width (us), as in Wills et al., 2010
MAX_FIRST_MOMENT_AC = 25            # E. maximum first moment of the autocorrelogram

for session, session_path in session_dict.items():

    # Create ephys object
    obj = ephys(recording_type = 'nexus', path = session_path)

    ## A. Load good cells from phy
    obj.load_spikes('good')

    # Cluster info table from phy, and the number of 'good' cells before any further filtering
    cluster_info = obj.spike_data['cluster_info']
    total_cells = len(cluster_info.index)

    ## B. Keep clusters whose depth lies within 0 +-200um
    cluster_info = cluster_info[cluster_info['depth'].between(*DEPTH_RANGE_UM)]

    ## C. Keep clusters with mean firing rate between 0.1 and 10 Hz
    cluster_info = cluster_info[cluster_info['fr'].between(*FIRING_RATE_RANGE_HZ)]

    ## D. Keep clusters whose template spike width is > 300 us
    if not cluster_info.empty:
        # Find kilosort template for each cluster (closely approximates spike width) and add to dataframe
        template_per_cluster = find_template_for_clusters(obj.spike_data['spike_clusters'], obj.spike_data['spike_templates'])
        template_df = pd.DataFrame.from_dict(template_per_cluster, orient = 'index')
        template_df.columns = ['template_id']

        cluster_info = cluster_info.join(template_df, how = 'left')

        # Kilosort output folder: '<6 chars taken from the end of session_path>_sorting_ks2_custom'
        # (presumably the session date, e.g. YYYY-MM-DD -> YYMMDD; confirm against the folder layout)
        sorting_dir = os.path.join(
            session_path,
            f'{session_path[-8:-6]}{session_path[-5:-3]}{session_path[-2:]}_sorting_ks2_custom')

        # Load templates and the inverse whitening matrix, then un-whiten the templates
        templates = np.load(os.path.join(sorting_dir, 'templates.npy'))
        whitening_matrix_inv = np.load(os.path.join(sorting_dir, 'whitening_mat_inv.npy'))
        # Multiply the last (channel) axis of every template by the inverse whitening matrix
        unwhitened_templates = np.einsum('ijk,kl->ijl', templates, whitening_matrix_inv)

        # Template waveform on each cluster's channel ('ch'). The UN-whitened templates are used
        # because whitening mixes channels and distorts the waveform shape that the width is measured on.
        cluster_info['template'] = cluster_info.apply(
            lambda row: unwhitened_templates[int(row['template_id']), :, int(row['ch'])], axis=1)

        # Template width, peak to trough, in samples (absolute so either extremum may come first)
        sampling_rate = obj.spike_data['sampling_rate']
        spike_width_samples = cluster_info['template'].apply(
            lambda x: np.abs(np.argmax(x) - np.argmin(x))
        )

        # Convert spike width to microseconds and add to dataframe
        cluster_info['spike_width_microseconds'] = (spike_width_samples / sampling_rate) * 1_000_000

        # Filter for spike width > 300us as in Wills et al., 2010
        cluster_info = cluster_info[cluster_info['spike_width_microseconds'] > MIN_SPIKE_WIDTH_US]

    ## E. Calculate burst index (first moment of the autocorrelogram) and filter
    if not cluster_info.empty:
        # Own copy, so adding a column below does not write into a view of the unfiltered table
        cluster_info = cluster_info.copy()

        # Reload spike data only for included cells
        clusters_inc = list(cluster_info.index)
        obj.load_spikes(clusters_to_load = clusters_inc)

        # Generate autocorrelograms and burst index for each cluster
        spike_times_inc = obj.spike_data['spike_times']
        spike_clusters_inc = obj.spike_data['spike_clusters']

        autocorrelograms, first_moments = compute_autocorrelograms_and_first_moment(spike_times_inc,
                                                                                     spike_clusters_inc,
                                                                                     bin_size = 0.001, #1ms
                                                                                     time_window = 0.05) #50ms

        # Match first moments to clusters by cluster id rather than by dict order.
        # NOTE(review): assumes first_moments is keyed by cluster id — confirm in burst_index_and_autocorrelograms.
        first_moment_series = pd.Series(first_moments)
        missing = cluster_info.index.difference(first_moment_series.index)
        if len(missing) > 0:
            raise KeyError(f'Session {session}: no autocorrelogram first moment for clusters {list(missing)}')
        cluster_info['first_moment_AC'] = first_moment_series.reindex(cluster_info.index)

        # Filter for first moment <25
        cluster_info = cluster_info[cluster_info['first_moment_AC'] < MAX_FIRST_MOMENT_AC]

#     ## F. Calculate spatial information and filter
#     if not cluster_info.empty:
#         # Load position data for all trials
#         obj.load_pos([i for i, s in enumerate(obj.trial_list)], output_flag = False)
        
#         # Loop through trials and generate rate maps
#         rate_maps = {}
#         occupancy = {}

#         for trial, trial_name in enumerate(obj.trial_list):

#             # Select spikes for current trial and transform to create a dict of {cluster: spike_times, cluster:spike_times}
#             current_trial_spikes = select_spikes_by_trial(obj.spike_data, trial, obj.trial_offsets)
#             current_trial_spikes = transform_spike_data(current_trial_spikes)


#             rate_maps[trial], occupancy[trial] = make_rate_maps(spike_data = current_trial_spikes,
#                                        positions = obj.pos_data[trial]['xy_position'],  
#                                        ppm = 400, 
#                                        x_bins = 50,
#                                        y_bins = 50,
#                                        dt = 1.0,
#                                        smoothing_window = 10)
            
#         # Calculate spatial information - NEEDS ADDING
        
#         # Filter for spatial information > ??? - NEEDS ADDING

    ## Save included cluster IDs to <session_path>/clusters_inc.npy
    clusters_inc = cluster_info.index
    n_clusters_inc = len(cluster_info.index)

    np.save(os.path.join(session_path, 'clusters_inc.npy'), clusters_inc.to_numpy())

    print(f'Session {session}: {n_clusters_inc} cells retained of {total_cells} good cells from phy. Retained cells: {clusters_inc.values}')

In [None]:
# Quick check: shape of the spike-cluster array from the last session that reached step E
# in the filtering loop above (in a fresh kernel this is a NameError if no session reached step E)
spike_clusters_inc.shape

## 7. Calculate spatial information

In [None]:
from scipy.stats import entropy

def calculate_spatial_information(rate_maps, occupancy, dt=1.0):
    """
    Calculate Skaggs' spatial information score for each cluster's rate map.

    With occupancy probability p_i and firing rate r_i in bin i, and mean rate
    r = sum_i p_i * r_i, the score is

        I = sum_i p_i * r_i * log2(r_i / r)

    (bits per second when the rate maps are in Hz; divide by r for bits per spike).

    Only bins where the rate map and the occupancy are both finite, and the occupancy
    is positive, contribute. NaN bins (e.g. unvisited positions) are therefore excluded
    consistently from both the mean rate and the information sum.

    Parameters:
    - rate_maps: dict
        Dictionary of rate maps keyed by cluster; each map is broadcast against `occupancy`.
    - occupancy: np.ndarray
        Occupancy of each bin. Only relative values matter, so time or sample counts both work.
    - dt: float
        Kept for backward compatibility. It has no effect, because the score depends only on
        relative occupancy.

    Returns:
    - skaggs_info_dict: dict
        Skaggs' spatial information per cluster. The value is NaN when no bin has valid,
        positive occupancy, and 0.0 when the cluster does not fire in any valid bin.
    """
    occupancy = np.asarray(occupancy, dtype=float)
    # Occupancy validity is shared by all clusters, so compute it once
    occupancy_ok = np.isfinite(occupancy) & (occupancy > 0)

    skaggs_info_dict = {}

    for cluster, rate_map in rate_maps.items():
        rate_map = np.asarray(rate_map, dtype=float)
        valid = occupancy_ok & np.isfinite(rate_map)

        if not valid.any():
            # No usable bins: the score is undefined
            skaggs_info_dict[cluster] = np.nan
            continue

        # Occupancy probability over valid bins only, so it sums to 1 over the bins used below
        valid_occupancy = np.broadcast_to(occupancy, valid.shape)[valid]
        prob_occupancy = valid_occupancy / valid_occupancy.sum()
        rates = np.broadcast_to(rate_map, valid.shape)[valid]

        # Occupancy-weighted mean firing rate (independent of the occupancy units)
        mean_firing_rate = np.sum(prob_occupancy * rates)

        # Bins with zero rate contribute 0 (lim x->0 of x*log x), so skip them to avoid log2(0)
        firing = rates > 0
        skaggs_info_dict[cluster] = np.sum(
            prob_occupancy[firing] *
            rates[firing] *
            np.log2(rates[firing] / mean_firing_rate)
        )

    return skaggs_info_dict

# Spatial information for every trial that has a rate map: {trial: {cluster: score}}.
# NOTE: `rate_maps` and `occupancy` come from step F of the filtering loop, which is currently
# commented out, so they must be defined before running this cell.
# TODO: confirm calculate_spatial_information handles NaN bins in real rate maps.
spatial_info = {
    trial: calculate_spatial_information(rate_maps[trial], occupancy[trial], dt = 1)
    for trial in rate_maps
}
spatial_info

In [None]:
# Display the filtered cluster table left over from the last session processed in the filtering loop
cluster_info