Below is the filtering function (as described in section 1/2)

In [1]:
import numpy as np

def filter_spikes(alldata, session_id):
    """Filter one session down to VISp spikes on equal-contrast choice trials.

    Parameters
    ----------
    alldata : sequence of dict-like
        The collated data from all sessions, as loaded in the tutorial
        notebook. ``alldata[session_id]`` must provide the keys ``'spks'``,
        ``'response'``, ``'response_time'``, ``'contrast_left'``,
        ``'contrast_right'`` and ``'brain_area'``.
    session_id : int
        Index of the session to filter.

    Returns
    -------
    spks : ndarray
        Spikes indexed as (neuron, trial, time bin), restricted to VISp
        neurons, the kept trials and the first 100 time bins.
    chcs : ndarray
        ``dat['response']`` for the kept trials.
    RTs : ndarray
        ``dat['response_time']`` for the kept trials.
    """
    # Index the argument (not a notebook-level global) so the function works
    # on whatever collection the caller passes in, in any cell order.
    dat = alldata[session_id]
    spks = dat['spks']
    chcs = dat['response']
    RTs = dat['response_time']

    # Keep only trials where left/right contrast is equal and nonzero ...
    equal_nonzero_contrast = np.logical_and(
        dat['contrast_right'] == dat['contrast_left'],
        dat['contrast_right'] != 0,
    )
    # ... and where the recorded response is nonzero
    # (presumably 0 encodes "no choice made" -- confirm against the dataset docs).
    keep_trials = np.logical_and(equal_nonzero_contrast, chcs != 0)
    spks = spks[:, keep_trials, :]
    chcs = chcs[keep_trials]
    RTs = RTs[keep_trials]

    # Keep only neurons recorded in VISp (brain_area has one label per neuron).
    spks = spks[dat['brain_area'] == 'VISp', :, :]

    # Keep the first 100 time bins: -500 ms to 500 ms relative to stimulus
    # onset, given 10 ms bins.
    spks = spks[:, :, 0:100]

    return spks, chcs, RTs
        

Below we are grabbing the data and testing the function

In [2]:
#Data retrieval
import os, requests
import numpy as np

# Local file names and the matching download links, in the same order.
fname = ['steinmetz_part%d.npz'%j for j in range(3)]
url = [
  "https://osf.io/agvxh/download",
  "https://osf.io/uv3mw/download",
  "https://osf.io/ehmw2/download",
]

# Fetch each part that is not already on disk; report (but do not raise on) failures.
for local_name, link in zip(fname, url):
  if os.path.isfile(local_name):
    continue
  try:
    r = requests.get(link)
  except requests.ConnectionError:
    print("!!! Failed to download data !!!")
    continue
  if r.status_code != requests.codes.ok:
    print("!!! Failed to download data !!!")
    continue
  with open(local_name, "wb") as fid:
    fid.write(r.content)
        
#@title Data loading
import numpy as np

# Load the 'dat' array from every part and join them into one flat array of sessions.
# The empty seed array is kept so the result matches a step-by-step hstack.
session_parts = [
  np.load('steinmetz_part%d.npz'%j, allow_pickle=True)['dat']
  for j in range(len(fname))
]
alldat = np.hstack([np.array([])] + session_parts)

# Pick a single recording to work with; session 11 has some neurons in visual cortex.
dat = alldat[11]

In [5]:
# Test the function on a single session
session = 11
spks_filtered, chcs_filtered, RTs_filtered = filter_spikes(alldat, session_id=session)

# spks is indexed neuron x trial x time point (-500 to 500);
# choices are indexed by trial
print(*(arr.shape for arr in (spks_filtered, chcs_filtered, RTs_filtered)))

(66, 18, 100) (18,) (18, 1)


In [7]:
# Report how many trials and neurons survived the filtering for this session.
print(
    "For session ", session,
    "we have", len(chcs_filtered),
    "trials to use and ", len(spks_filtered),
    "neurons",
)

For session  11 we have 18 trials to use and  66 neurons
