<a href="https://colab.research.google.com/github/mmcinnestaylor/NMA-CN-2022/blob/main/project/steinmetz_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive; BASE_PATH
# (defined in a later cell) points below this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import numpy as np

from datetime import datetime
from tqdm import tqdm

In [None]:
# Root of the dataset on the mounted Drive; the loading cell lists this
# directory and treats each entry as one recording session.
BASE_PATH = '/content/drive/MyDrive/steinmetz' # Path to dataset base directory

# NOTE(review): units are not stated here — presumably milliseconds, matching
# the bin/window parameters of get_spike_trains; confirm before use.
WINDOW_SIZE = 10 #will be used for binning spikes

In [None]:
# @title Helper Functions
#@markdown - get_spike_trains(clustered_spikes, start_time=0, bin_size=10, window_size=100, format='binary')
def get_spike_trains(clustered_spikes, start_time=0, bin_size=10, window_size=100, format='binary'):
  """
  Spike train generator function.
  Builds trains over a given time window using discrete time bins.

  Args:
    clustered_spikes : a dictionary mapping cluster ID -> 1-D array of spike
                       times in seconds
    start_time(ms)   : the starting time point of the window within a session
    bin_size(ms)     : the size of the discrete time step within the window
    window_size(ms)  : the width of the time window
    format           : 'binary' (1 if at least one spike fell in the bin,
                       else 0) or 'counts' (number of spikes in the bin)

  Returns:
    spike_trains     : integer array of shape [num_neurons x num_bins]; rows
                       follow ascending cluster ID, columns are consecutive
                       half-open bins [bin_start, bin_end)

  Raises:
    ValueError       : if format is not 'binary'/'counts' or bin_size <= 0
  """

  # Other formats are not implemented; fail with a clear message instead of
  # falling through to an undefined return value.
  if format not in ('binary', 'counts'):
    raise ValueError(f"Unsupported format {format!r}; expected 'binary' or 'counts'")
  if bin_size <= 0:
    raise ValueError("bin_size must be positive")

  num_neurons = len(clustered_spikes)
  num_bins = int(window_size/bin_size)

  # Initialize spike train matrix [num_neurons x num_bins]
  spike_trains = np.zeros((num_neurons, num_bins), dtype=int)

  # Bin boundaries in milliseconds; bin j covers [bin_edges[j], bin_edges[j+1])
  bin_edges = start_time + np.arange(num_bins + 1) * bin_size

  # Iterate over all neurons in a recording session (ascending cluster ID)
  for i, cluster_id in enumerate(sorted(clustered_spikes.keys())):
    # Convert spike times of a given neuron to millisecond scale (once per neuron)
    neuron_spikes = np.asarray(clustered_spikes[cluster_id], dtype=float).ravel() * 1000

    # side='right' maps a spike exactly on an edge to the bin that starts
    # there, which keeps every bin half-open: [start, end)
    bin_index = np.searchsorted(bin_edges, neuron_spikes, side='right') - 1

    # Drop spikes before the window (-1) or at/after its end (num_bins)
    in_window = (bin_index >= 0) & (bin_index < num_bins)
    counts = np.bincount(bin_index[in_window], minlength=num_bins)

    spike_trains[i] = (counts > 0) if format == 'binary' else counts

  return spike_trains

In [None]:
# @title Load Data
#@markdown - Uses the path defined in BASE_PATH
#@markdown - Assumes each recording session resides in a separate subdirectory formatted as mouseID_year-month-date
#@markdown - **'mouse_id'**: (str) Mouse name per session directory format
#@markdown - **'session_date'**: (date) Date of session per session directory format
#@markdown - **'clustered_spikes'**: A dictionary of clusters(neurons) recorded during the session. Clusters with annotation values < 2 are not included.
#@markdown  - keys(int): The cluster's integer ID as specified in the datafiles.
#@markdown  - values(np.array): A 1-D array of size *nSpikes* where each entry corresponds to a time point in seconds during the recording session in which the cluster(neuron) produced a spike.
#@markdown - **'clusters_peak_channels'**: (np.array) A 1-D  array of size *nClusters* where each entry corresponds to the channel number of the location of the peak of the cluster's waveform. The number maps to the brain region of the channel, using the Allen CCF.

all_session_data = dict()

# Only subdirectories are recording sessions; stray files in BASE_PATH are skipped.
# Sorting orders sessions by mouse name + date.
session_names = [
  name for name in sorted(os.listdir(BASE_PATH))
  if os.path.isdir(os.path.join(BASE_PATH, name))
]

# tqdm wraps the list (not the enumerate object) so it knows the total
for i, session in enumerate(tqdm(session_names)):
  session_path = os.path.join(BASE_PATH, session)

  print(f"Loading session: {session} ")

  # Load spike and cluster data
  raw_spikes = np.load(os.path.join(session_path, 'spikes.times.npy'))
  raw_clusters = np.load(os.path.join(session_path, 'spikes.clusters.npy'))
  cluster_annotations = np.load(os.path.join(session_path, 'clusters._phy_annotation.npy')) # cluster quality
  num_clusters = raw_clusters.max() # highest cluster ID seen in this session

  # Assumes both arrays are shaped (nSpikes, 1): only the first column is used
  spike_times = raw_spikes[:, 0]
  spike_cluster_ids = raw_clusters[:, 0].astype(np.intp)

  # Group spike times by cluster ID without a per-spike Python loop.
  # The stable sort keeps each cluster's spikes in their original file order.
  order = np.argsort(spike_cluster_ids, kind='stable')
  spikes_per_cluster = np.bincount(spike_cluster_ids, minlength=int(num_clusters) + 1)
  # List indexed by cluster ID; entry k holds the spike times of cluster k
  sorted_spike_times = np.split(spike_times[order], np.cumsum(spikes_per_cluster)[:-1])

  # Dict to store valid clusters (annotation quality > 1)
  filtered_clusters = {
    cluster_id: cluster_spikes
    for cluster_id, cluster_spikes in enumerate(sorted_spike_times)
    if cluster_annotations[cluster_id][0] > 1
  }

  # Load session data into dictionary
  all_session_data[i] = {
    # Session information
    'mouse_id': session.split('_')[0],
    'session_date': datetime.strptime(session.split('_')[1], '%Y-%m-%d').date(),

    # Neural Data 
    'clustered_spikes': filtered_clusters, # Dict: key=cluster_ID, val=spike_times(seconds)

    # Cluster Data
    'clusters_peak_channels': np.load(os.path.join(session_path, 'clusters.peakChannel.npy')),

    # Channels Data
    #'channels_site': np.load(session_path+'/channels.site.npy')
  }

In [None]:
# @title File writer for Kobayashi method

# Session to export; the output file is written into that session's directory.
KOBAYASHI_SESSION = 'Cori_2016-12-14'

# Only spikes at or before this time (seconds) are exported.
# NOTE(review): an earlier comment said "first 5 minutes" while the code used
# 10; the value 10 is kept as-is — confirm the intended cutoff.
KOBAYASHI_MAX_TIME = 10

kobayashi_path = os.path.join(BASE_PATH, KOBAYASHI_SESSION, 'kobayashi_datafile_5n.txt')

# Look the session up in all_session_data by mouse and date instead of relying
# on variables left over from the last iteration of the loading loop.
kobayashi_mouse, kobayashi_date_text = KOBAYASHI_SESSION.split('_')
kobayashi_date = datetime.strptime(kobayashi_date_text, '%Y-%m-%d').date()

kobayashi_spikes = None
for session_data in all_session_data.values():
  if session_data['mouse_id'] == kobayashi_mouse and session_data['session_date'] == kobayashi_date:
    kobayashi_spikes = session_data['clustered_spikes']
    break

if kobayashi_spikes is None:
  raise KeyError(f"Session {KOBAYASHI_SESSION} was not found in all_session_data")

# 'clustered_spikes' only holds clusters that passed the annotation filter
# applied while loading, so no further validity check is needed here.
with open(kobayashi_path, "w") as f:
  for cluster in sorted(kobayashi_spikes.keys()):
    cluster_times = np.asarray(kobayashi_spikes[cluster])
    early_times = cluster_times[cluster_times <= KOBAYASHI_MAX_TIME]

    # Skip neurons without any spike before the cutoff
    if len(early_times) == 0:
      continue

    # One spike time per line, then a ';' line closing this neuron's block
    for spike_time in early_times:
      f.write(str(spike_time)+'\n')
    f.write(';\n')

In [None]:
# Sanity check: first 50 spike times (seconds) of cluster ID 0 in session 0.
print(all_session_data[0]['clustered_spikes'][0][:50])

[  0.8149      14.82246667  24.9646      25.1436      38.8709
  50.8208      54.80686667  59.51183333  80.47036667 127.09636667
 166.78713333 175.2476     177.2883     178.31753333 231.78033333
 240.66676667 271.28896667 305.5415     305.64986667 311.1602
 312.75296667 318.46916667 322.8011     324.85113333 356.86103333
 356.9972     377.01963333 388.1814     390.37333333 434.21206667
 442.29883333 473.92586667 484.7502     503.32483333 503.3358
 503.36313333 503.4157     530.48366667 574.1442     584.98683333
 586.90366667 593.3421     614.83113333 654.2239     654.53083333
 654.8104     719.13886667 739.67513333 744.6198     745.3653    ]


In [None]:
# Binary spike trains for session 0: a 1000 ms window starting at t=0, split into 100 ms bins.
train = get_spike_trains(all_session_data[0]['clustered_spikes'], start_time=0, window_size=1000, bin_size=100, format='binary')

In [None]:
# Row 0 belongs to the lowest cluster ID (rows follow sorted cluster IDs).
train[0]

In [None]:
# Spike counts for session 0: a 4000 ms window starting at 24000 ms, split into two 2000 ms bins.
train = get_spike_trains(all_session_data[0]['clustered_spikes'], start_time=24000, window_size=4000, bin_size=2000, format='counts')

In [None]:
# Per-bin spike counts of the lowest cluster ID within the window above.
train[0]

array([2, 0])