In [26]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
import ipywidgets
# Modules for interactive plotting
import bokeh.plotting
import bokeh.io
bokeh.io.output_notebook()
from IPython.display import display
import os


%matplotlib notebook

In [27]:
# Note: data file source from https://github.com/nsteinme/steinmetz-et-al-2019/wiki/data-files
directory = 'Data/Radnitz_2017-01-08/'
locations = pd.read_csv(directory + 'channels.brainLocation.tsv', sep='\t')


In [28]:
# Import data
# Spike-related variables
spikes_times = np.load(os.path.join(directory, 'spikes.times.npy'))[:,0]
spikes_depths = np.load(os.path.join(directory, 'spikes.depths.npy'))[:,0]
spikes_amps = np.load(os.path.join(directory, 'spikes.amps.npy'))[:,0]
spikes_clusters = np.load(os.path.join(directory, 'spikes.clusters.npy'))[:,0]

# Trial-related variables
trials_feedback_times = np.load(os.path.join(directory, 'trials.feedback_times.npy'))[:,0]
trials_feedback_types = np.load(os.path.join(directory, 'trials.feedbackType.npy'))[:,0]
trials_gocue_times = np.load(os.path.join(directory, 'trials.goCue_times.npy'))[:,0]
trials_included = np.load(os.path.join(directory, 'trials.included.npy'))[:,0]
trials_start = np.load(os.path.join(directory, 'trials.intervals.npy'))[:,0]
trials_end = np.load(os.path.join(directory, 'trials.intervals.npy'))[:,1]
trials_repNum = np.load(os.path.join(directory, 'trials.repNum.npy'))[:,0]
trials_choice = np.load(os.path.join(directory, 'trials.response_choice.npy'))[:,0]
trials_response_times = np.load(os.path.join(directory, 'trials.response_times.npy'))[:,0]
trials_left_contrast = np.load(os.path.join(directory, 'trials.visualStim_contrastLeft.npy'))[:,0]
trials_right_contrast = np.load(os.path.join(directory, 'trials.visualStim_contrastRight.npy'))[:,0]
trials_stim_times = np.load(os.path.join(directory, 'trials.visualStim_times.npy'))[:,0]

# Cluster information
clusters_annotation = np.load(os.path.join(directory, 'clusters._phy_annotation.npy'))[:,0]
clusters_peakChannel = np.load(os.path.join(directory, 'clusters.peakChannel.npy'))[:,0]
clusters_brainLoc = locations.allen_ontology[clusters_peakChannel - 1]

Computing decision times based on wheel activity

In [29]:
trials_wheel = np.load(os.path.join(directory, 'wheel.position.npy'))[:,0]
wheel_tstamps = np.load(os.path.join(directory, 'wheel.timestamps.npy'))
wheelMoveType = np.load(os.path.join(directory, 'wheelMoves.type.npy'))[:,0]
wheelMoveIntervals = np.load(os.path.join(directory, 'wheelMoves.intervals.npy'))
tstampsInterp = np.linspace(wheel_tstamps[0,1], wheel_tstamps[1,1], len(trials_wheel))


#plt.figure()
#plt.plot(tstampsInterp, trials_wheel)

In [30]:
choicePoints = wheelMoveIntervals[(wheelMoveType == 1) | (wheelMoveType == 2), 0]

In [31]:
diff = (trials_stim_times - choicePoints[:,np.newaxis])**2
closestTrial = np.argmin(diff, axis=1)
closestChoicePt = np.argmin(diff, axis=0)

In [32]:
def findClosestChoicePt(tstim, choicepts):
    candidates = choicepts[choicepts > tstim - 0.1]
    return min(candidates)

In [33]:
closestChoicePt = []
for idx, time in enumerate(trials_stim_times):
    if trials_choice[idx] == 0:
        closest = np.nan
    else:
        closest = findClosestChoicePt(time, choicePoints)
    closestChoicePt.append(closest)
    print('Trial: ', idx, ', stim on at ', time, ', closest: ', closest)
    
    

Trial:  0 , stim on at  40.25402086431727 , closest:  40.23912258863113
Trial:  1 , stim on at  49.220577893551884 , closest:  49.41112258863114
Trial:  2 , stim on at  57.52032324506816 , closest:  57.652122588631144
Trial:  3 , stim on at  61.26998891209373 , closest:  61.275122588631135
Trial:  4 , stim on at  68.05290769977805 , closest:  68.26212258863114
Trial:  5 , stim on at  72.1189789079405 , closest:  72.38712258863114
Trial:  6 , stim on at  79.71911200730959 , closest:  79.88712258863114
Trial:  7 , stim on at  83.88478495977431 , closest:  84.23412258863114
Trial:  8 , stim on at  88.7512701854019 , closest:  nan
Trial:  9 , stim on at  93.86775978929823 , closest:  94.09212258863113
Trial:  10 , stim on at  98.73424501492583 , closest:  98.93112258863114
Trial:  11 , stim on at  103.7511328745199 , closest:  103.95412258863114
Trial:  12 , stim on at  107.93440613521473 , closest:  nan
Trial:  13 , stim on at  119.14980254785213 , closest:  119.10712258863113
Trial:  14 

Trial:  217 , stim on at  1076.180962812701 , closest:  1076.460122588631
Trial:  218 , stim on at  1079.9146281995172 , closest:  1080.760122588631
Trial:  219 , stim on at  1083.613892983884 , closest:  nan
Trial:  220 , stim on at  1088.8475846403126 , closest:  nan
Trial:  221 , stim on at  1095.8801077992605 , closest:  nan
Trial:  222 , stim on at  1104.9962674484511 , closest:  nan
Trial:  223 , stim on at  1109.9131535567376 , closest:  nan
Trial:  224 , stim on at  1115.2624472376779 , closest:  1116.546122588631
Trial:  225 , stim on at  1119.512521668246 , closest:  nan
Trial:  226 , stim on at  1124.2622048483465 , closest:  1125.0431225886311
Trial:  227 , stim on at  1127.89546847685 , closest:  nan
Trial:  228 , stim on at  1132.6287513697362 , closest:  nan
Trial:  229 , stim on at  1137.2944330787384 , closest:  nan
Trial:  230 , stim on at  1142.1273177159267 , closest:  nan
Trial:  231 , stim on at  1147.1446055825259 , closest:  nan
Trial:  232 , stim on at  1153.16

In [34]:
tdiff = closestChoicePt - trials_stim_times
tdiff[np.isnan(tdiff)] = np.inf
np.sum(tdiff[~np.isnan(tdiff)] > 3)
trial_decision_times = np.array(closestChoicePt)[tdiff < 3]
plt.figure()
plt.plot(tdiff)


#plt.figure()
#plt.hist(closestChoicePt - trials_stim_times)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x21482753358>]

In [35]:
# Now we make a 'labeled wheel' array to classify movements, timepoint by timepoint
# labeledWheel = tstampsInterp * np.nan
# for i in range(len(wheelMoveType)):
#     interval = wheelMoveIntervals[i, :]
#     label = wheelMoveType[i]
#     labeledWheel[(tstampsInterp > interval[0]) & (tstampsInterp < interval[1])] = label
    #print(interval, label)

In [36]:
# Find the closest timestamp
# def find_id_min(tarr, tpoint):
#     '''Returns the id of the closest element to tpoint'''
#     diff = (tarr - tpoint)**2
#     return np.argmin(diff)

# IDChoices = []
# IDStims = []
# IDFeedback = []

# for time in choicePoints:
#     idChoice = find_id_min(tstampsInterp, time)
#     IDChoices.append(idChoice)
# for time in trials_stim_times:
#     idStim = find_id_min(tstampsInterp, time)
#     IDStims.append(idStim)
# for time in trials_feedback_times:
#     idFB = find_id_min(tstampsInterp, time)
#     IDFeedback.append(idFB)   
    
# # Visualize movement types
# IDChoices = np.array(IDChoices)
# plt.figure()
# plt.plot(tstampsInterp[labeledWheel == 1], trials_wheel[labeledWheel == 1], '.', 
#          color='tomato', ms=1)
# plt.plot(tstampsInterp[labeledWheel == 2], trials_wheel[labeledWheel == 2], '.', 
#          color='dodgerblue', ms=1)
# plt.plot(tstampsInterp[(labeledWheel != 1) & (labeledWheel != 2)], trials_wheel[(labeledWheel != 1) & (labeledWheel != 2)], '.', 
#          color='k', ms=1)
# #plt.plot(tstampsInterp, trials_wheel)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == 1]], trials_wheel[IDChoices[trials_choice == 1]], 'or', ms=5)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == -1]], trials_wheel[IDChoices[trials_choice == -1]], 'ob', ms=5)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == 0]], trials_wheel[IDChoices[trials_choice == 0]], 'ok', ms=5)
# plt.plot(tstampsInterp[IDChoices], trials_wheel[IDChoices], 'or', ms=4, label='choice')
# plt.plot(tstampsInterp[IDStims], trials_wheel[IDStims], 'ok', ms=4, label='stim')
# plt.plot(tstampsInterp[IDFeedback], trials_wheel[IDFeedback], 'ob', ms=4, label='feedback')
# #plt.legend()

In [37]:
spikes_fr = pd.DataFrame({'times': spikes_times, 'depths': spikes_depths, 'amps': spikes_amps,
                          'clusters': spikes_clusters})
trials_fr = pd.DataFrame({'feedbackTimes': trials_feedback_times, 'feedbackType': trials_feedback_types,
                         'gocueTimes': trials_gocue_times, 'included': trials_included,
                         'start': trials_start, 'end': trials_end, 'repNum': trials_repNum,
                         'choice': trials_choice, 'response_times': trials_response_times,
                         'leftContrast': trials_left_contrast, 'rightContrast': trials_right_contrast,
                         'stimTimes': trials_stim_times})
clusters_fr = pd.DataFrame({'id': np.arange(len(clusters_annotation)), 'annotation': clusters_annotation, 
                            'brainLoc': clusters_brainLoc, 'peakChannel': clusters_peakChannel})

trials_fr['signedContrast'] = trials_fr.leftContrast - trials_fr.rightContrast

ncells = spikes_fr.clusters.unique().max() + 1

In [38]:
trials_fr.head()

Unnamed: 0,feedbackTimes,feedbackType,gocueTimes,included,start,end,repNum,choice,response_times,leftContrast,rightContrast,stimTimes,signedContrast
0,41.292039,1.0,41.148036,True,39.535837,42.25658,1.0,1.0,41.254635,1.0,0.0,40.254021,1.0
1,50.556601,1.0,50.367798,True,43.255557,51.524418,1.0,1.0,50.520903,0.5,0.25,49.220578,0.25
2,58.189935,1.0,58.022732,True,52.520271,59.15952,1.0,-1.0,58.153744,0.25,1.0,57.520323,-0.75
3,63.63843,1.0,62.393208,True,60.153976,64.606675,1.0,1.0,63.603413,0.25,0.0,61.269989,0.25
4,69.489733,1.0,69.242928,True,65.603947,70.45713,1.0,1.0,69.453177,1.0,0.0,68.052908,1.0


In [39]:
plt.figure()
plt.plot(trials_fr.end - trials_fr.feedbackTimes)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x214806ef470>]

In [40]:
leftChoiceTrials = np.where(trials_fr['choice'] == 1.0)[0]
rightChoiceTrials = np.where(trials_fr['choice'] == -1.0)[0]
correctTrials = np.where(trials_fr['feedbackType'] == 1.0)[0]
incorrTrials = np.where(trials_fr['feedbackType'] == -1.0)[0]
hardTrials = np.where((trials_fr['leftContrast'] <= 0.25) & (trials_fr['rightContrast'] <= 0.25))[0]
easyTrials = np.where((trials_fr['leftContrast'] > 0.25) | (trials_fr['rightContrast'] > 0.25))[0]

In [41]:
def find_cluster_by_area(clusters_fr, area):
    '''
    clusters_fr: a frame of clusters
    area: a string indicating brain area
    Return a np array of good clusters (area and annotation >=2)
    '''
    good_clusters = clusters_fr[(clusters_fr.brainLoc == area) & (clusters_fr.annotation >= 2)].id
    return np.array(good_clusters)
    
areas = ['ACA', 'VISp', 'SCig']
area_dict = {}
for area in areas:
    good_clusters = find_cluster_by_area(clusters_fr, area)
    area_dict[area] = good_clusters
    
#scipy.io.savemat('cell_areas_Lederberg1207.mat', area_dict)


In [42]:
area_dict

{'ACA': array([ 15,  19,  23,  31,  37,  98, 166, 201, 219, 223, 233, 251, 253,
        315, 328, 333, 356, 398, 401, 424, 457, 626, 627, 628, 629]),
 'SCig': array([ 648,  658,  667,  668,  673,  674,  682,  687,  700,  703,  715,
         717,  718,  722,  728,  730,  732,  738,  748,  752,  754,  755,
         758,  759,  760,  771,  773,  774,  785,  788,  790,  793,  794,
         795,  797,  798,  800,  807,  808,  811,  821,  822,  828,  829,
         830,  831,  837,  839,  843,  851,  853,  855,  860,  869,  870,
         872,  873,  881,  892,  893,  898,  904,  907,  908,  909,  915,
         916,  917,  924,  926,  927,  929,  935,  936,  943,  948,  950,
         952,  956,  957,  958,  961,  963,  970,  971,  987,  989,  990,
         995,  996, 1002, 1003, 1005, 1008, 1009, 1010, 1011, 1014, 1018,
        1022, 1024, 1027, 1028, 1033, 1035, 1036, 1037, 1041, 1044, 1047,
        1049, 1052, 1058, 1059, 1060, 1064, 1068, 1069, 1071, 1072, 1075,
        1082, 1083, 1087, 10

### Get the overall activity statistics
Here, we try to understand how 'active' the ACC or VC neurons are

In [43]:
cell_nspikes = []
for cellid in area_dict['ACA']:
    nspikes = np.sum(spikes_fr.clusters == cellid)
    #nspikes = spike_cell.shape[0]
    cell_nspikes.append(nspikes)
cell_nspikes = np.array(cell_nspikes)

In [44]:
# Plot psychometric curve for this session
performance = trials_fr.groupby(['signedContrast'])['choice'].mean()
plt.figure()
plt.plot(performance)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x214840ce550>]

In [45]:
# Align all rasters to visual stimulus onset
def splitSpikeTimesToTrials(spikes, tstarts, tends):
    '''
    spikes: the spikes array (nspikes)
    tstarts: trial start times array (ntrials)
    tends: trial end times array (ntrials)
    Returns a list of grouped spike times, ntrials sublists
    '''
    spikeTrialTimes = []
    for i in range(len(tstarts)):
        spikeTrial = spikes[(spikes > tstarts[i]) & (spikes < tends[i])]
        spikeTrialTimes.append(spikeTrial)
        
    return spikeTrialTimes
    
# Perform binning
def BinSpikeTimes(spikes, tstarts, window, nbins):
    '''
    spikes: an array of spike times, split into trials
    tstarts: trial start times (ntrials)
    window: [start, end], window to visualize from start to end
    binwidth: width of bin 
    Returns an nbins x ntrials array of spike counts
    '''
    spikeBinnedCounts = []
    for i in range(len(tstarts)):
        spikeTrial = spikes[(spikes > tstarts[i] + window[0]) & (spikes < tstarts[i] + window[1])]
        
        # Bin
        edges = np.linspace(tstarts[i] + window[0], tstarts[i] + window[1], nbins)
        counts, times = np.histogram(spikeTrial, bins=edges)
        spikeBinnedCounts.append(counts)
               
    return edges[:-1], np.array(spikeBinnedCounts)

def BinTrials(spikes, tstarts, window):
    ''' Group by trials and align on tstarts
    Returns a list of ntrials list, each with the aligned spikes '''
    spikeGroups = []
    for i in range(len(tstarts)):
        spikeTrial = spikes[(spikes > tstarts[i] + window[0]) & (spikes < tstarts[i] + window[1])]
        spikeTrial = spikeTrial - tstarts[i]
        spikeGroups.append(spikeTrial)
    return spikeGroups

        

In [46]:
relativeTime = (trials_fr.response_times - trials_fr.stimTimes)[:,np.newaxis]
plt.figure()
plt.eventplot(relativeTime)

<IPython.core.display.Javascript object>

[<matplotlib.collections.EventCollection at 0x214840d3f98>,
 <matplotlib.collections.EventCollection at 0x214812a41d0>,
 <matplotlib.collections.EventCollection at 0x214812a44a8>,
 <matplotlib.collections.EventCollection at 0x214812a4710>,
 <matplotlib.collections.EventCollection at 0x214812a4978>,
 <matplotlib.collections.EventCollection at 0x214812a4be0>,
 <matplotlib.collections.EventCollection at 0x214812a4e48>,
 <matplotlib.collections.EventCollection at 0x214812d10f0>,
 <matplotlib.collections.EventCollection at 0x214812d1358>,
 <matplotlib.collections.EventCollection at 0x214812d15c0>,
 <matplotlib.collections.EventCollection at 0x214812d1828>,
 <matplotlib.collections.EventCollection at 0x214812d1a90>,
 <matplotlib.collections.EventCollection at 0x214812d1cf8>,
 <matplotlib.collections.EventCollection at 0x214812d1f60>,
 <matplotlib.collections.EventCollection at 0x214812a0208>,
 <matplotlib.collections.EventCollection at 0x214812a0470>,
 <matplotlib.collections.EventCollection

In [47]:
def plot_binned_spikes(cluster_id, area, aligned_by, window):
    # TODO: PLOT RESPONSE AND FEEDBACK TIMES, SUPERIMPOSED
    if area != 'None':
        cellid = area_dict[area][cluster_id]
    else:
        cellid = cluster_id
    spike_cluster = spikes_fr[spikes_fr.clusters == cellid]
    nspikes = np.sum(spikes_fr.clusters == cellid)
    print(locations.allen_ontology[clusters_peakChannel[cellid]])
    print('Good', area, 'cell number', cluster_id, ', corresponding to unit #', cellid)
    print('Number of spikes = ', nspikes)
    
    csfont = {'fontname':'Comic Sans MS'}
    hfont = {'fontname':'Helvetica'}

    #print(spike_cluster.times[:100])
    
    if aligned_by == 'default':
        spikeGroups1 = BinTrials(spike_cluster.times, trials_fr.stimTimes, window)
        #spikeGroups2 = BinTrials(spike_cluster.times, trials_fr.response_times, window)
        spikeGroups2 = BinTrials(spike_cluster.times, trial_decision_times, window)
        spikeGroups3 = BinTrials(spike_cluster.times, trials_fr.feedbackTimes, window)
    elif aligned_by == 'choice':
        #spikeGroups1 = BinTrials(spike_cluster.times, trials_fr.response_times, window)
        times = np.array(closestChoicePt)[np.concatenate((leftChoiceTrials,rightChoiceTrials))]
        spikeGroups1 = BinTrials(spike_cluster.times, times, window)
        binnedActivityLeft = np.array(spikeGroups1)[:len(leftChoiceTrials)]
        binnedActivityRight = np.array(spikeGroups1)[len(leftChoiceTrials):]
        titleL = 'Left choice'
        titleR = 'Right choice'
    elif aligned_by == 'feedback':
        spikeGroups1 = BinTrials(spike_cluster.times, trials_fr.feedbackTimes, window)
        binnedActivityLeft = np.array(spikeGroups1)[correctTrials]
        binnedActivityRight = np.array(spikeGroups1)[incorrTrials]
        titleL = 'Correct'
        titleR = 'Incorrect'
    elif aligned_by == 'difficulty':
        spikeGroups1 = BinTrials(spike_cluster.times, trials_fr.feedbackTimes, window)
        binnedActivityLeft = np.array(spikeGroups1)[easyTrials]
        binnedActivityRight = np.array(spikeGroups1)[hardTrials]
        titleL = 'Easy'
        titleR = 'Difficult'
        
    if aligned_by == 'default':
        #plt.figure(figsize=(10, 5))
        plt.subplot(131)
        plt.eventplot(spikeGroups1, colors='k')
        relativeResp = (np.array(closestChoicePt) - trials_fr.stimTimes)[:,np.newaxis]
        relativeRew = (trials_fr.feedbackTimes - trials_fr.stimTimes)[:,np.newaxis]
        #plt.eventplot(relativeResp, colors='r')  
        #plt.eventplot(relativeRew, colors='r')        

        plt.title('Stimulus', **hfont)
        plt.xlabel('Time (s)')
        plt.ylabel('Trial #')

        plt.subplot(132)
        plt.eventplot(spikeGroups2, colors='k')
        plt.xlabel('Time (s)')
        plt.title('Response', **hfont)

        plt.subplot(133)
        plt.eventplot(spikeGroups3, colors='k')
        plt.xlabel('Time (s)')
        plt.title('Feedback', **hfont)
    elif aligned_by in ['choice', 'feedback', 'difficulty']:
        #plt.figure(figsize=(8, 4))
        plt.subplot(121)
        plt.eventplot(binnedActivityLeft, colors='k')
        plt.title(titleL)
        plt.ylim([0, 200])
        
        plt.subplot(122)
        plt.eventplot(binnedActivityRight, colors='k')
        plt.title(titleR)
        plt.ylim([0, 200])
    else:
        raise ValueError('Invalid alignment')
    
    #plt.savefig('figure' + str(cluster_id) + area + aligned_by + '.png')
        

    
    #return binned_activity
    

In [54]:
def update(align_by, area, cell=0, window=[-1, 3]):
    # Find allen brain region
    plt.figure(figsize=(10, 5))
    plot_binned_spikes(cell, area, aligned_by=align_by, window=window)
    plt.savefig('ACA_unit626_Radnitz_2017-01-08.pdf')
    
ipywidgets.interactive(update, 
                       cell=ipywidgets.widgets.BoundedIntText(
                            value=0,
                            min=0,
                            max=ncells,
                            step=1,
                            description='CellID: ',
                            disabled=False),
                      window=ipywidgets.widgets.FloatRangeSlider(
                            value=[-0.5, 2],
                            min=-4,
                            max=4,
                            step=0.1,
                            description='Window (s):',
                            disabled=False,
                            continuous_update=False,
                            orientation='horizontal',
                            readout=True,
                            readout_format='.1f'),
                       area=ipywidgets.widgets.RadioButtons(
                            options=areas + ['None'],
                            value='None',
                            description='Brain area:',
                            disabled=False),
                    align_by=ipywidgets.widgets.RadioButtons(
                        options=['default', 'choice', 'feedback', 'difficulty'],
                        description='Align by:',
                        disabled=False
))

# For Radnitz_2017-01-08
# VC cells: 740, 817, 836, 862
# SC cells: 1250, 1155, 1170, 1163, 
# 1142, 1122 (choice)


A Jupyter Widget

In [49]:
# Plot all cells in ACC
area = 'SCig'
aligned_by = 'default'
window = [-0.5, 2]
plt.figure(figsize=(10, 5))
for i in range(len(area_dict[area])):
    plot_binned_spikes(i, area, aligned_by, window)
    plt.clf()

<IPython.core.display.Javascript object>

SCig
Good SCig cell number 0 , corresponding to unit # 648
Number of spikes =  15431
SCig
Good SCig cell number 1 , corresponding to unit # 658
Number of spikes =  38527
SCig
Good SCig cell number 2 , corresponding to unit # 667
Number of spikes =  65728
SCig
Good SCig cell number 3 , corresponding to unit # 668
Number of spikes =  18217
SCig
Good SCig cell number 4 , corresponding to unit # 673
Number of spikes =  11063
SCig
Good SCig cell number 5 , corresponding to unit # 674
Number of spikes =  173
SCig
Good SCig cell number 6 , corresponding to unit # 682
Number of spikes =  13598
SCig
Good SCig cell number 7 , corresponding to unit # 687
Number of spikes =  15412
SCig
Good SCig cell number 8 , corresponding to unit # 700
Number of spikes =  8511
SCig
Good SCig cell number 9 , corresponding to unit # 703
Number of spikes =  1773
SCig
Good SCig cell number 10 , corresponding to unit # 715
Number of spikes =  13144
SCig
Good SCig cell number 11 , corresponding to unit # 717
Number o

KeyboardInterrupt: 