<a href="https://colab.research.google.com/github/mmcinnestaylor/NMA-CN-2022/blob/main/project/steinmetz_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so the dataset under /content/drive (see BASE_PATH) can be read
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [30]:
import contextlib
import csv
import io
import math
import os
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

In [3]:
# Each recording session is expected in its own subdirectory of BASE_PATH (see the Load Data cell)
BASE_PATH = '/content/drive/MyDrive/steinmetz' # Path to dataset base directory

In [35]:
# @title Helper Functions
#@markdown - get_spike_trains(clustered_spikes, start_time=0, bin_size=10, window_size=100, format='binary')
#@markdown  - Generates spike trains for each valid cluster in a recording session.
#@markdown - get_area_spike_trains(clustered_spikes, clusters_locs, start_time=0, bin_size=10, window_size=100, format='counts', aggregate=True)
#@markdown  - Generates spike trains separated by brain area
#@markdown - clusters_to_area(cluster_ids, cluster_locs)
#@markdown  - Generates a dictionary keyed by the brain areas monitored during a recording session. Each value is an np array of the IDs of the clusters located in that area.
#@markdown - group_trial_timestamps(intervals, vis_stims, go_cues, responses, feedback)
#@markdown  - Groups all timepoints per trial: start, visual stimulus, go cue, response, feedback, trial end

def get_spike_trains(clustered_spikes, start_time=0, bin_size=10, window_size=100, format='binary'):
  """
  Spike train generator function.
  Builds one spike train per cluster over a time window split into discrete,
  half-open time bins [start, end).

  Args:
    clustered_spikes : a dictionary mapping cluster ID -> array of spike times (seconds)
    start_time(ms)   : the starting time point of the window within a session
    bin_size(ms)     : the size of the discrete time step within the window
    window_size(ms)  : the width of the time window; int(window_size/bin_size) bins are built
    format           : 'binary' (1 if the cluster spiked at least once in the bin)
                       or 'counts' (number of spikes in the bin)

  Returns:
    spike_trains     : int array [num_neurons x num_bins]; row i belongs to the
                       i-th cluster ID in ascending order

  Raises:
    ValueError       : if format is neither 'binary' nor 'counts'
  """
  if format not in ('binary', 'counts'):
    raise ValueError(f"format must be 'binary' or 'counts', got {format!r}")

  num_neurons = len(clustered_spikes)
  num_bins = int(window_size/bin_size)

  # Initialize spike train matrix [num_neurons x num_bins]
  spike_trains = np.zeros((num_neurons, num_bins), dtype=int)

  # Bin j spans [bin_edges[j], bin_edges[j+1]) == [start_time + j*bin_size, start_time + (j+1)*bin_size)
  bin_edges = start_time + np.arange(num_bins + 1) * bin_size

  for i, cluster_id in enumerate(sorted(clustered_spikes.keys())):
    # Convert spike times to ms once per neuron; sort so searchsorted can be used
    neuron_spikes = np.sort(np.asarray(clustered_spikes[cluster_id]) * 1000)

    # searchsorted(side='left') gives the number of spikes strictly below each edge,
    # so consecutive differences are the spike counts of each [start, end) bin
    counts = np.diff(np.searchsorted(neuron_spikes, bin_edges, side='left'))

    spike_trains[i] = (counts > 0).astype(int) if format == 'binary' else counts

  return spike_trains


def clusters_to_area(cluster_ids, cluster_locs):
  """
  Group cluster IDs by the brain area they were recorded in.

  Args:
    cluster_ids  : iterable of integer cluster IDs (e.g. the keys of clustered_spikes)
    cluster_locs : np.array of area labels indexed by cluster ID

  Returns:
    dict mapping every area label present in cluster_locs -> integer np.array of the
    IDs from cluster_ids located in that area (empty if none of them are). Areas are
    keyed in sorted order so the result is identical across kernel restarts.
  """
  # sorted(): iterating a set of strings follows per-process hash randomisation,
  # which would make the key order change every time the kernel restarts
  clusters_per_area = {area: [] for area in sorted(set(cluster_locs.tolist()))}

  # Group clusters by brain area
  for cluster_id in cluster_ids:
    clusters_per_area[cluster_locs[cluster_id]].append(cluster_id)

  # Convert lists to int np arrays so even empty areas yield arrays usable as indices
  return {area: np.array(ids, dtype=int) for area, ids in clusters_per_area.items()}


def get_area_spike_trains(clustered_spikes, clusters_locs, start_time=0, bin_size=10, window_size=100, format='counts', aggregate=True, verbose=True):
  """
  Build spike trains separated by brain area.

  Args:
    clustered_spikes : a dictionary mapping cluster ID -> array of spike times (seconds)
    clusters_locs    : np.array of area labels indexed by cluster ID
    start_time(ms)   : the starting time point of the window within a session
    bin_size(ms)     : the size of the discrete time step within the window
    window_size(ms)  : the width of the time window; int(window_size/bin_size) bins are built
    format           : 'counts' (spikes per bin) or 'binary' (1 if any spike in the bin)
    aggregate        : if truthy, collapse each area's clusters into one train:
                       summed counts for 'counts', 1 where any cluster spiked for 'binary'
    verbose          : print a progress line per area (default keeps the original output)

  Returns:
    dict keyed by area label. Without aggregation each value is a float array
    [num_clusters_in_area x num_bins]; with aggregation it is a 1-D array of length num_bins
    (float for 'counts', int 0/1 for 'binary').

  Raises:
    ValueError : if format is neither 'binary' nor 'counts'
  """
  if format not in ('binary', 'counts'):
    raise ValueError(f"format must be 'binary' or 'counts', got {format!r}")

  num_bins = int(window_size/bin_size)
  # Bin j spans [bin_edges[j], bin_edges[j+1]); the edges are the same for every cluster
  bin_edges = start_time + np.arange(num_bins + 1) * bin_size
  clusters_by_area = clusters_to_area(clustered_spikes.keys(), clusters_locs)
  area_spikes = {}

  for area, cluster_ids in clusters_by_area.items():
    if verbose:
      print(f'Processing area: {area} - {len(cluster_ids)} clusters')

    trains = np.zeros((len(cluster_ids), num_bins))
    for i, cluster_id in enumerate(cluster_ids):
      # Convert spike times to ms once per neuron; sort so searchsorted can be used
      neuron_spikes = np.sort(np.asarray(clustered_spikes[cluster_id]) * 1000)
      # Number of spikes below each edge -> differences are per-bin [start, end) counts
      counts = np.diff(np.searchsorted(neuron_spikes, bin_edges, side='left'))
      trains[i] = (counts > 0) if format == 'binary' else counts

    # Combine all neurons of the area into a single train
    if aggregate:
      summed = trains.sum(axis=0)
      trains = summed if format == 'counts' else 1 * (summed > 0)

    area_spikes[area] = trains

  return area_spikes

def group_trial_timestamps(intervals, vis_stims, go_cues, responses, feedback):
  """
  Collect the key event times of every trial into one table.

  Args:
    intervals : array whose row t holds [trial start, trial end]
    vis_stims, go_cues, responses, feedback : arrays whose row t holds the
                                              event time in column 0

  Returns:
    float array [num_trials x 6] with columns
    (trial start, visual stimulus, go cue, response (wheel turn), feedback, trial end)
  """
  event_rows = [
    (
      intervals[t][0],   # trial start
      vis_stims[t][0],   # visual stimulus
      go_cues[t][0],     # go cue
      responses[t][0],   # response (wheel turn)
      feedback[t][0],    # feedback
      intervals[t][1],   # trial end
    )
    for t in range(intervals.shape[0])
  ]
  return np.array(event_rows, dtype=float).reshape(len(event_rows), 6)


In [None]:
# @title Load Data
#@markdown - Uses the path defined in BASE_PATH
#@markdown - Assumes each recording session resides in a separate subdirectory named mouseID_YYYY-MM-DD (e.g. Cori_2016-12-14)
#@markdown - **'mouse_id'**: (str) Mouse name per session directory format
#@markdown - **'session_date'**: (date) Date of session per session directory format
#@markdown - **'clustered_spikes'**: A dictionary of clusters(neurons) recorded during the session. Clusters with annotation values <= 1 are not included.
#@markdown  - keys(int): The cluster's integer ID as specified in the datafiles.
#@markdown  - values(np.array): A 1-D array of size *nSpikes* where each entry corresponds to a time point in seconds during the recording session in which the cluster(neuron) produced a spike.
#@markdown - **'clusters_locs'**: (np.array) A 1-D  array of size *nClusters* where each entry corresponds to the Allen CCF brain area of cluster *n*.
#@markdown  - This array includes **all** clusters from a recording session. Thus the keys of `clustered_spikes` should be used as the index number when accessing the values here.
#@markdown - **'trials_fb_times'**: trials.feedback_times.npy
#@markdown - **'trials_fb_type'**: trials.feedbackType.npy
#@markdown - **'trials_go_times'**: trials.goCue_times.npy
#@markdown - **'trials_included'**: trials.included.npy
#@markdown - **'trials_intervals'**: trials.intervals.npy
#@markdown - **'trials_rep_num'**: trials.repNum.npy
#@markdown - **'trials_resp_choice'**: trials.response_choice.npy
#@markdown - **'trials_resp_times'**: trials.response_times.npy
#@markdown - **'trials_vis_times'**: trials.visualStim_times.npy


all_session_data = dict()


def _group_spikes_by_cluster(spike_times, spike_clusters):
  """
  Split a session's spike times into one array per cluster ID.

  Args:
    spike_times    : np.array of spike times in seconds (spikes.times.npy); flattened before use
    spike_clusters : np.array with the cluster ID of each spike (spikes.clusters.npy); flattened before use

  Returns:
    list of length max(cluster ID) + 1 whose entry k holds cluster k's spike times,
    in file order (an empty array if cluster k has no spikes)
  """
  times = np.ravel(spike_times)
  clusters = np.ravel(spike_clusters).astype(int)
  # Stable sort groups spikes by cluster without reordering spikes within a cluster
  order = np.argsort(clusters, kind='stable')
  # Position in the sorted order where each cluster ID 1..max begins
  starts = np.searchsorted(clusters[order], np.arange(1, clusters.max() + 1), side='left')
  return np.split(times[order], starts)


def _cluster_brain_areas(session_path):
  """
  Map every cluster of a session to the Allen CCF area of its peak channel.

  Returns:
    np.array of area labels (str) indexed by cluster ID, covering all clusters
  """
  peak_channels = np.ravel(np.load(os.path.join(session_path, 'clusters.peakChannel.npy')))
  channel_areas = pd.read_csv(os.path.join(session_path, 'channels.brainLocation.tsv'), sep='\t')['allen_ontology'].to_numpy()
  # clusters.peakChannel.npy appears to use 1-indexing; the channel table is 0-indexed
  channel_idx = peak_channels.astype(int) - 1
  if channel_idx.size and (channel_idx.min() < 0 or channel_idx.max() >= len(channel_areas)):
    raise IndexError(f'peak channel out of range in {session_path}')
  return channel_areas[channel_idx].astype(str)


def _load_session(session_name, session_path):
  """Load one session directory into the per-session dictionary layout described above."""
  def load(filename):
    return np.load(os.path.join(session_path, filename))

  spikes_per_cluster = _group_spikes_by_cluster(load('spikes.times.npy'), load('spikes.clusters.npy'))
  # Cluster quality, one entry per cluster
  annotations = np.ravel(load('clusters._phy_annotation.npy'))
  # Only store valid clusters (annotation quality > 1)
  filtered_clusters = {
    cluster_id: times
    for cluster_id, times in enumerate(spikes_per_cluster)
    if annotations[cluster_id] > 1
  }

  name_parts = session_name.split('_')
  return {
    # Session information
    'mouse_id': name_parts[0],
    'session_date': datetime.strptime(name_parts[1], '%Y-%m-%d').date(),

    # Neural Data
    'clustered_spikes': filtered_clusters, # Dict: key=cluster_ID, val=spike_times(seconds)

    # Cluster Data
    'clusters_locs': _cluster_brain_areas(session_path),

    # Trial Data
    'trials_fb_type': load('trials.feedbackType.npy'),
    'trials_fb_times': load('trials.feedback_times.npy'),
    'trials_go_times': load('trials.goCue_times.npy'),
    'trials_included': load('trials.included.npy'),
    'trials_rep_num': load('trials.repNum.npy'),
    'trials_resp_choice': load('trials.response_choice.npy'),
    'trials_resp_times': load('trials.response_times.npy'),
    'trials_intervals': load('trials.intervals.npy'),
    'trials_vis_times': load('trials.visualStim_times.npy')
  }


# Only subdirectories are sessions; sorting orders them by mouse name + date
session_names = sorted(
  name for name in os.listdir(BASE_PATH)
  if os.path.isdir(os.path.join(BASE_PATH, name))
)

for i, session in enumerate(tqdm(session_names, desc='SESSION')):
  session_path = os.path.join(BASE_PATH, session)
  print(f"Loading session: {session} ")
  all_session_data[i] = _load_session(session, session_path)

In [None]:
# @title File writer for Kobayashi method
# Writes the spike times (seconds) of every valid cluster of one session to a text file:
# one time per line, each cluster's spikes terminated by a ';' line.
# Uses all_session_data from the Load Data cell (already filtered to annotation > 1).

KOBAYASHI_SESSION = 'Cori_2016-12-14'  # session directory name (mouseID_YYYY-MM-DD)
# Only spikes at or before this time (seconds) are written.
# TODO(review): an older comment said "first 5 minutes" while the code used 10 s — confirm.
KOBAYASHI_MAX_TIME = 10
# NOTE(review): the "_5n" suffix suggests 5 neurons, but no neuron limit is applied — confirm.
KOBAYASHI_OUTFILE = os.path.join(BASE_PATH, KOBAYASHI_SESSION, 'kobayashi_datafile_5n.txt')

kob_mouse, kob_date_str = KOBAYASHI_SESSION.split('_')
kob_date = datetime.strptime(kob_date_str, '%Y-%m-%d').date()
kob_matches = [
  data for data in all_session_data.values()
  if data['mouse_id'] == kob_mouse and data['session_date'] == kob_date
]
if not kob_matches:
  raise KeyError(f'Session {KOBAYASHI_SESSION} not found in all_session_data')
kobayashi_session = kob_matches[0]

with open(KOBAYASHI_OUTFILE, 'w') as f:
  for cluster_id in sorted(kobayashi_session['clustered_spikes']):
    spike_times = kobayashi_session['clustered_spikes'][cluster_id]
    early_spikes = spike_times[spike_times <= KOBAYASHI_MAX_TIME]
    # Append the ';' terminator only if a spike time has been written for this neuron
    if early_spikes.size == 0:
      continue
    f.writelines(f'{t}\n' for t in early_spikes)
    f.write(';\n')

In [26]:
# Group the valid clusters of session 0 by brain area
clustered_locs = clusters_to_area(
  cluster_ids=all_session_data[0]['clustered_spikes'].keys(),
  cluster_locs=all_session_data[0]['clusters_locs'],
)

In [31]:
# Show the area labels that clusters_to_area produced for session 0
print(clustered_locs.keys())

dict_keys(['DG', 'CA3', 'ACA', 'VISp', 'SUB', 'MOs', 'LS', 'root'])


In [28]:
# Aggregated spike counts per brain area over the first 100 ms of session 0, in 10 ms bins
area_spikes = get_area_spike_trains(
  clustered_spikes=all_session_data[0]['clustered_spikes'],
  clusters_locs=all_session_data[0]['clusters_locs'],
  start_time=0,
  bin_size=10,
  window_size=100,
  format='counts',
  aggregate=True,
)

In [29]:
# Per-trial event table for session 0: start, vis stim, go cue, response, feedback, end
trials_times = group_trial_timestamps(
  intervals=all_session_data[0]['trials_intervals'],
  vis_stims=all_session_data[0]['trials_vis_times'],
  go_cues=all_session_data[0]['trials_go_times'],
  responses=all_session_data[0]['trials_resp_times'],
  feedback=all_session_data[0]['trials_fb_times'],
)

In [None]:
# Aggregated spike counts of one area in a window centred on each trial's visual stimulus
visual_area_all = []
window = 1000          # ms, total window width centred on visual stimulus onset
psth_bin_size = 50     # ms
psth_area = 'VISp'     # brain area to extract

psth_session = all_session_data[0]
psth_locs = psth_session['clusters_locs']
# Bin only the target area's clusters; the other areas were computed and discarded before
psth_clusters = {
  cluster_id: spike_times
  for cluster_id, spike_times in psth_session['clustered_spikes'].items()
  if psth_locs[cluster_id] == psth_area
}

for trial in tqdm(trials_times, desc='TRIAL'):
  # trial[1] is the visual stimulus time in seconds; convert to ms and centre the window
  start = trial[1]*1000 - window/2
  # Silence the per-area progress prints so they don't flood the output once per trial
  with contextlib.redirect_stdout(io.StringIO()):
    spikes = get_area_spike_trains(psth_clusters, psth_locs, start_time=start, bin_size=psth_bin_size, window_size=window)

  visual_area_all.append(spikes[psth_area])