In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
import ipywidgets
# Modules for interactive plotting
import bokeh.plotting
import bokeh.io
bokeh.io.output_notebook()
from IPython.display import display
import os


%matplotlib notebook

In [2]:
# Note: data file source from https://github.com/nsteinme/steinmetz-et-al-2019/wiki/data-files
directory = 'Data/Radnitz_2017-01-08/'
# Build the path with os.path.join, like every other load in this notebook, so it
# works whether or not `directory` ends with a separator.
locations = pd.read_csv(os.path.join(directory, 'channels.brainLocation.tsv'), sep='\t')


In [3]:
# Import data
def _load_session_array(name):
    '''Load the file `name` from the session directory with np.load.'''
    return np.load(os.path.join(directory, name))

# Spike-related variables (files are 2-D; column 0 holds the values)
spikes_times = _load_session_array('spikes.times.npy')[:, 0]
spikes_depths = _load_session_array('spikes.depths.npy')[:, 0]
spikes_amps = _load_session_array('spikes.amps.npy')[:, 0]
spikes_clusters = _load_session_array('spikes.clusters.npy')[:, 0]

# Trial-related variables
trials_feedback_times = _load_session_array('trials.feedback_times.npy')[:, 0]
trials_feedback_types = _load_session_array('trials.feedbackType.npy')[:, 0]
trials_gocue_times = _load_session_array('trials.goCue_times.npy')[:, 0]
trials_included = _load_session_array('trials.included.npy')[:, 0]
# Load the intervals file once (it was read twice); column 0 = trial start, column 1 = trial end
trials_intervals = _load_session_array('trials.intervals.npy')
trials_start = trials_intervals[:, 0]
trials_end = trials_intervals[:, 1]
trials_repNum = _load_session_array('trials.repNum.npy')[:, 0]
trials_choice = _load_session_array('trials.response_choice.npy')[:, 0]
trials_response_times = _load_session_array('trials.response_times.npy')[:, 0]
trials_left_contrast = _load_session_array('trials.visualStim_contrastLeft.npy')[:, 0]
trials_right_contrast = _load_session_array('trials.visualStim_contrastRight.npy')[:, 0]
trials_stim_times = _load_session_array('trials.visualStim_times.npy')[:, 0]

# Cluster information
clusters_annotation = _load_session_array('clusters._phy_annotation.npy')[:, 0]
clusters_peakChannel = _load_session_array('clusters.peakChannel.npy')[:, 0]
# peakChannel is treated as 1-based (hence the -1). Index positionally into a plain
# array. The previous label lookup returned a Series whose index held channel labels
# instead of cluster ids, and it relied on float keys matching integer labels.
clusters_brainLoc = locations.allen_ontology.values[clusters_peakChannel.astype(int) - 1]

### Computing decision times from wheel activity

In [4]:
trials_wheel = np.load(os.path.join(directory, 'wheel.position.npy'))[:, 0]
wheel_tstamps = np.load(os.path.join(directory, 'wheel.timestamps.npy'))
wheelMoveType = np.load(os.path.join(directory, 'wheelMoves.type.npy'))[:, 0]
wheelMoveIntervals = np.load(os.path.join(directory, 'wheelMoves.intervals.npy'))

# One evenly spaced time per wheel sample, running from the time in row 0 to the
# time in row 1 of wheel_tstamps (column 1).
# NOTE(review): assumes wheel_tstamps has exactly two sync rows (first and last
# sample); any further rows would be ignored. TODO confirm against the file.
tstampsInterp = np.linspace(start=wheel_tstamps[0, 1],
                            stop=wheel_tstamps[1, 1],
                            num=trials_wheel.shape[0])

In [5]:
# Onset times (column 0 of the intervals) of wheel movements whose type code is 1 or 2
# (presumably the choice-movement types; TODO confirm the type coding).
choicePoints = wheelMoveIntervals[np.isin(wheelMoveType, (1, 2)), 0]

In [6]:
# Squared distance between every choice onset (rows) and every stimulus onset (columns).
diff = np.subtract.outer(choicePoints, trials_stim_times) ** 2
closestTrial = diff.argmin(axis=1)      # nearest trial index for each choice point
closestChoicePt = diff.argmin(axis=0)   # nearest choice-point index for each trial
# NOTE: closestChoicePt is rebuilt as a list of times in the next cell, which overwrites this value.

In [7]:
def findClosestChoicePt(tstim, choicepts, tolerance=0.1):
    '''
    Return the earliest choice time that is later than `tstim - tolerance`.

    tstim: stimulus onset time (s)
    choicepts: array-like of candidate choice-movement onset times (s), any order
    tolerance: how far (s) before `tstim` a movement may start and still count
    Returns NaN when there is no candidate (e.g. a trial after the last recorded
    wheel movement), instead of raising ValueError from min() on an empty array.
    '''
    choicepts = np.asarray(choicepts, dtype=float)
    candidates = choicepts[choicepts > tstim - tolerance]
    if candidates.size == 0:
        return np.nan
    return candidates.min()

In [8]:
# Decision time per trial: the first choice movement found by findClosestChoicePt.
# Trials with choice == 0 (no movement) get NaN.
closestChoicePt = []
for idx in range(len(trials_stim_times)):
    time = trials_stim_times[idx]
    closest = np.nan if trials_choice[idx] == 0 else findClosestChoicePt(time, choicePoints)
    closestChoicePt.append(closest)
    print('Trial: ', idx, ', stim on at ', time, ', closest: ', closest)
    
    

Trial:  0 , stim on at  65.26940832339372 , closest:  65.43250410481991
Trial:  1 , stim on at  71.2027029221732 , closest:  71.41450410481991
Trial:  2 , stim on at  76.05238024406377 , closest:  76.23250410481991
Trial:  3 , stim on at  81.23526287848605 , closest:  nan
Trial:  4 , stim on at  86.80095161626113 , closest:  86.86450410481991
Trial:  5 , stim on at  90.65061299417887 , closest:  91.92350410481991
Trial:  6 , stim on at  95.98429803295227 , closest:  96.6805041048199
Trial:  7 , stim on at  99.55075489553691 , closest:  103.13550410481992
Trial:  8 , stim on at  105.56645080809975 , closest:  nan
Trial:  9 , stim on at  111.0669385063278 , closest:  111.26550410481991
Trial:  10 , stim on at  115.98261688052057 , closest:  116.2795041048199
Trial:  11 , stim on at  123.71634018482871 , closest:  nan
Trial:  12 , stim on at  129.73243610376912 , closest:  129.8115041048199
Trial:  13 , stim on at  133.28249270487262 , closest:  133.40450410481992
Trial:  14 , stim on at 

In [9]:
# Delay from stimulus onset to the matched choice movement (NaN for no-go trials).
tdiff = np.asarray(closestChoicePt, dtype=float) - trials_stim_times
# Count slow decisions among go trials *before* NaNs are replaced by inf; afterwards
# every no-go trial would also satisfy `> 3`. Print the count so it is not discarded.
n_slow_decisions = np.sum(tdiff[~np.isnan(tdiff)] > 3)
print('Go trials with decision delay > 3 s:', n_slow_decisions)
tdiff[np.isnan(tdiff)] = np.inf
# NOTE: this drops no-go and slow trials, so trial_decision_times is NOT one entry
# per row of the trial table.
trial_decision_times = np.array(closestChoicePt)[tdiff < 3]

plt.figure()
plt.plot(tdiff)
plt.xlabel('Trial #')
plt.ylabel('Choice time - stimulus time (s)')
plt.title('Decision delay per trial (no-go trials = inf)');

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x214f4fd65c0>]

In [10]:
# Now we make a 'labeled wheel' array to classify movements, timepoint by timepoint
# labeledWheel = tstampsInterp * np.nan
# for i in range(len(wheelMoveType)):
#     interval = wheelMoveIntervals[i, :]
#     label = wheelMoveType[i]
#     labeledWheel[(tstampsInterp > interval[0]) & (tstampsInterp < interval[1])] = label
    #print(interval, label)

In [11]:
# Find the closest timestamp
# def find_id_min(tarr, tpoint):
#     '''Returns the id of the closest element to tpoint'''
#     diff = (tarr - tpoint)**2
#     return np.argmin(diff)

# IDChoices = []
# IDStims = []
# IDFeedback = []

# for time in choicePoints:
#     idChoice = find_id_min(tstampsInterp, time)
#     IDChoices.append(idChoice)
# for time in trials_stim_times:
#     idStim = find_id_min(tstampsInterp, time)
#     IDStims.append(idStim)
# for time in trials_feedback_times:
#     idFB = find_id_min(tstampsInterp, time)
#     IDFeedback.append(idFB)   
    
# # Visualize movement types
# IDChoices = np.array(IDChoices)
# plt.figure()
# plt.plot(tstampsInterp[labeledWheel == 1], trials_wheel[labeledWheel == 1], '.', 
#          color='tomato', ms=1)
# plt.plot(tstampsInterp[labeledWheel == 2], trials_wheel[labeledWheel == 2], '.', 
#          color='dodgerblue', ms=1)
# plt.plot(tstampsInterp[(labeledWheel != 1) & (labeledWheel != 2)], trials_wheel[(labeledWheel != 1) & (labeledWheel != 2)], '.', 
#          color='k', ms=1)
# #plt.plot(tstampsInterp, trials_wheel)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == 1]], trials_wheel[IDChoices[trials_choice == 1]], 'or', ms=5)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == -1]], trials_wheel[IDChoices[trials_choice == -1]], 'ob', ms=5)
# #plt.plot(tstampsInterp[IDChoices[trials_choice == 0]], trials_wheel[IDChoices[trials_choice == 0]], 'ok', ms=5)
# plt.plot(tstampsInterp[IDChoices], trials_wheel[IDChoices], 'or', ms=4, label='choice')
# plt.plot(tstampsInterp[IDStims], trials_wheel[IDStims], 'ok', ms=4, label='stim')
# plt.plot(tstampsInterp[IDFeedback], trials_wheel[IDFeedback], 'ob', ms=4, label='feedback')
# #plt.legend()

In [12]:
spikes_fr = pd.DataFrame({'times': spikes_times, 'depths': spikes_depths, 'amps': spikes_amps,
                          'clusters': spikes_clusters})
trials_fr = pd.DataFrame({'feedbackTimes': trials_feedback_times, 'feedbackType': trials_feedback_types,
                         'gocueTimes': trials_gocue_times, 'included': trials_included,
                         'start': trials_start, 'end': trials_end, 'repNum': trials_repNum,
                         'choice': trials_choice, 'response_times': trials_response_times,
                         'leftContrast': trials_left_contrast, 'rightContrast': trials_right_contrast,
                         'stimTimes': trials_stim_times})
# np.asarray: if clusters_brainLoc is a pandas Series, its index holds peak-channel
# labels (possibly repeated) and pandas would use that as the frame's index. As a
# plain array it is placed positionally, so row k of clusters_fr describes cluster k.
clusters_fr = pd.DataFrame({'id': np.arange(len(clusters_annotation)), 'annotation': clusters_annotation,
                            'brainLoc': np.asarray(clusters_brainLoc), 'peakChannel': clusters_peakChannel})

# Positive values = more contrast on the left
trials_fr['signedContrast'] = trials_fr.leftContrast - trials_fr.rightContrast

# Cluster ids start at 0, so the largest id + 1 bounds the id range
ncells = spikes_fr.clusters.max() + 1

In [13]:
# Preview the first rows of the per-trial table
trials_fr.head(n=5)

Unnamed: 0,feedbackTimes,feedbackType,gocueTimes,included,start,end,repNum,choice,response_times,leftContrast,rightContrast,stimTimes,signedContrast
0,66.456227,1.0,66.296625,True,62.900284,67.423484,1.0,1.0,66.419612,1.0,0.0,65.269408,1.0
1,72.640326,1.0,72.077117,True,68.420838,73.604476,1.0,-1.0,72.602206,0.0,0.5,71.202703,-0.5
2,77.038396,1.0,76.877593,True,74.602902,78.006757,1.0,1.0,77.001671,1.0,0.5,76.05238,0.5
3,83.531699,1.0,81.996875,True,79.003653,84.506778,1.0,0.0,83.502065,0.0,0.0,81.235263,0.0
4,87.628565,-1.0,87.462962,True,85.501795,88.621336,1.0,1.0,87.617727,0.5,1.0,86.800952,-0.5


In [14]:
# Time between feedback delivery and the end of each trial interval
plt.figure()
plt.plot(trials_fr.end - trials_fr.feedbackTimes)
plt.xlabel('Trial #')
plt.ylabel('Trial end - feedback time (s)')
plt.title('Post-feedback interval per trial');

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x2148004ae10>]

In [15]:
# Trial-index groups used to split rasters in plot_binned_spikes
leftChoiceTrials = np.flatnonzero(trials_fr['choice'] == 1.0)
rightChoiceTrials = np.flatnonzero(trials_fr['choice'] == -1.0)
correctTrials = np.flatnonzero(trials_fr['feedbackType'] == 1.0)
incorrTrials = np.flatnonzero(trials_fr['feedbackType'] == -1.0)
# "Hard": both contrasts at most 0.25 (includes zero-contrast trials); "easy": the rest
hardTrials = np.flatnonzero((trials_fr['leftContrast'] <= 0.25) & (trials_fr['rightContrast'] <= 0.25))
easyTrials = np.flatnonzero((trials_fr['leftContrast'] > 0.25) | (trials_fr['rightContrast'] > 0.25))

In [16]:
def find_cluster_by_area(clusters_fr, area, min_annotation=2):
    '''
    Return the ids of well-isolated clusters located in one brain area.

    clusters_fr: a frame of clusters with columns 'id', 'brainLoc' and 'annotation'
    area: a string indicating brain area (matched exactly against 'brainLoc')
    min_annotation: minimum 'annotation' value for a cluster to count as good
        (default 2, the threshold this notebook has always used)
    Returns a np array of the matching cluster ids (empty if none match).
    '''
    is_good = (clusters_fr.brainLoc == area) & (clusters_fr.annotation >= min_annotation)
    return np.array(clusters_fr.loc[is_good, 'id'])
    
areas = ['ACA', 'VISp', 'SCig']
# Map each area of interest to the ids of its good clusters
area_dict = {area_name: find_cluster_by_area(clusters_fr, area_name) for area_name in areas}
    
#scipy.io.savemat('cell_areas_Lederberg1207.mat', area_dict)


In [17]:
# Good-cluster ids per area (an area with no good clusters maps to an empty array)
area_dict

{'ACA': array([  0,   2,   7,  23,  26,  28,  35,  47,  53,  59,  67,  68,  72,
         73,  79,  86,  87,  90, 100, 106, 107, 110, 115, 119, 132, 133,
        134, 141, 142, 146, 162, 168, 169, 185, 191, 192, 199, 203, 204,
        205, 211, 213, 215, 224, 245, 250, 261, 262, 275, 279, 280, 286,
        287, 297, 300, 303, 305, 309, 310, 312, 318, 332, 334, 336, 341,
        342, 343, 348, 350, 358, 364, 367, 371, 375, 379, 383, 387, 396,
        399, 404, 408, 409, 410, 419, 420, 440, 444, 451, 458, 459, 460,
        461, 462, 463, 464, 468, 477, 487, 488, 493, 504, 509, 510, 516,
        527, 532, 538, 540, 543]),
 'SCig': array([], dtype=int32),
 'VISp': array([ 551,  552,  553,  557,  558,  560,  564,  565,  566,  569,  571,
         574,  575,  577,  578,  584,  586,  587,  588,  591,  594,  597,
         598,  600,  601,  602,  606,  610,  613,  617,  622,  626,  628,
         630,  632,  635,  637,  639,  640,  642,  643,  646,  647,  655,
         656,  658,  662,  665,  670,

### Overall activity statistics
Here we quantify how active the ACA (anterior cingulate area) and VISp (primary visual cortex) neurons are, starting with the total spike count of each good ACA cell.

In [18]:
# Total spike count of every good ACA cell. A single bincount pass over all spikes
# replaces one boolean scan of the full spike table per cell (O(n_cells * n_spikes)).
aca_ids = area_dict['ACA']
spike_counts_by_cluster = np.bincount(spikes_fr.clusters.values.astype(np.int64),
                                      minlength=int(max(aca_ids, default=-1)) + 1)
cell_nspikes = spike_counts_by_cluster[aca_ids]

In [19]:
# Plot psychometric curve for this session
performance = trials_fr.groupby(['signedContrast'])['choice'].mean()
plt.figure()
plt.plot(performance)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x2148009b470>]

In [20]:
# Helpers for splitting spike trains into trials and aligning rasters to events
def splitSpikeTimesToTrials(spikes, tstarts, tends):
    '''
    Cut a spike-time array into per-trial pieces (times stay absolute).

    spikes: the spikes array (nspikes)
    tstarts: trial start times array (ntrials); bound is exclusive
    tends: trial end times array (ntrials); bound is exclusive
    Returns a list with one array of spike times per trial.
    '''
    return [spikes[(spikes > tstarts[k]) & (spikes < tends[k])]
            for k in range(len(tstarts))]
    
# Perform binning
def BinSpikeTimes(spikes, tstarts, window, nbins):
    '''
    Histogram spikes in a fixed window around each trial's alignment time.

    spikes: an array of spike times (s)
    tstarts: per-trial alignment times (ntrials)
    window: [start, end] in seconds relative to each tstart
    nbins: number of equal-width bins spanning the window
    Returns (bin_starts, counts):
        bin_starts: left edge of each bin relative to tstart, shape (nbins,)
        counts: spike counts, shape (ntrials, nbins)
    '''
    spikes = np.asarray(spikes, dtype=float)
    tstarts = np.asarray(tstarts, dtype=float)
    # nbins + 1 edges give exactly nbins bins (linspace(..., nbins) gave nbins - 1).
    # Edges relative to tstart are shared by every trial, so the returned bin times
    # are valid for all trials, and zero trials no longer raises NameError.
    edges = np.linspace(window[0], window[1], nbins + 1)
    counts = np.zeros((len(tstarts), nbins), dtype=np.int64)
    for i, t0 in enumerate(tstarts):
        in_window = spikes[(spikes > t0 + window[0]) & (spikes < t0 + window[1])]
        counts[i], _ = np.histogram(in_window - t0, bins=edges)
    return edges[:-1], counts

def BinTrials(spikes, tstarts, window):
    '''
    Align spikes to each trial's event time.

    spikes: spike times (array or Series)
    tstarts: per-trial alignment times (ntrials)
    window: [start, end] in seconds relative to each alignment time; both bounds exclusive
    Returns a list of ntrials items, each holding spike times relative to tstarts[i].
    '''
    def _aligned(t0):
        in_window = (spikes > t0 + window[0]) & (spikes < t0 + window[1])
        return spikes[in_window] - t0

    return [_aligned(tstarts[k]) for k in range(len(tstarts))]

        

In [21]:
# Response time relative to stimulus onset, one eventplot row per trial.
# Convert with .values first: multi-dimensional indexing (ser[:, None]) on a pandas
# Series is deprecated and was removed in pandas 2.0.
relativeTime = (trials_fr.response_times - trials_fr.stimTimes).values[:, np.newaxis]
plt.figure()
plt.eventplot(relativeTime)
plt.xlabel('Response time - stimulus time (s)')
plt.ylabel('Trial #');

<IPython.core.display.Javascript object>

[<matplotlib.collections.EventCollection at 0x214800dc908>,
 <matplotlib.collections.EventCollection at 0x214800dcbe0>,
 <matplotlib.collections.EventCollection at 0x214800dce48>,
 <matplotlib.collections.EventCollection at 0x214800ef0f0>,
 <matplotlib.collections.EventCollection at 0x214800ef358>,
 <matplotlib.collections.EventCollection at 0x214800ef5c0>,
 <matplotlib.collections.EventCollection at 0x214800ef828>,
 <matplotlib.collections.EventCollection at 0x214800efa90>,
 <matplotlib.collections.EventCollection at 0x214800efcf8>,
 <matplotlib.collections.EventCollection at 0x214800eff60>,
 <matplotlib.collections.EventCollection at 0x214800fd208>,
 <matplotlib.collections.EventCollection at 0x214800fd470>,
 <matplotlib.collections.EventCollection at 0x214800fd6d8>,
 <matplotlib.collections.EventCollection at 0x214800fd940>,
 <matplotlib.collections.EventCollection at 0x214800fdba8>,
 <matplotlib.collections.EventCollection at 0x214800fde10>,
 <matplotlib.collections.EventCollection

In [22]:
def plot_binned_spikes(cluster_id, area, aligned_by, window):
    '''
    Draw spike rasters for one cluster into the current matplotlib figure.

    cluster_id: when `area` is a key of area_dict, an index into area_dict[area]
        (the n-th good cell of that area); when area == 'None', the raw cluster id.
    area: a key of area_dict, or the string 'None'.
    aligned_by: 'default'    -> 3 panels aligned to stimulus, decision and feedback times;
                'choice'     -> leftChoiceTrials vs rightChoiceTrials, aligned to the choice movement;
                'feedback'   -> correct vs incorrect trials, aligned to feedback;
                'difficulty' -> easy vs hard trials, aligned to feedback.
    window: [start, end] in seconds relative to each alignment time.
    Raises ValueError for any other `aligned_by`.
    '''
    # TODO: PLOT RESPONSE AND FEEDBACK TIMES, SUPERIMPOSED
    if aligned_by not in ('default', 'choice', 'feedback', 'difficulty'):
        raise ValueError('Invalid alignment')

    if area != 'None':
        good_cells = area_dict[area]
        # The widget bounds cluster_id by ncells, not by the (smaller, possibly empty)
        # number of good cells in the chosen area, so guard against IndexError.
        if not 0 <= cluster_id < len(good_cells):
            print('No good', area, 'cell number', cluster_id,
                  '- this area has', len(good_cells), 'good cells')
            return
        cellid = good_cells[cluster_id]
    else:
        cellid = cluster_id
    spike_cluster = spikes_fr[spikes_fr.clusters == cellid]
    nspikes = len(spike_cluster)
    # peakChannel is treated as 1-based where clusters_brainLoc is built (peakChannel - 1);
    # without the -1 this printed the region of the neighbouring channel.
    print(locations.allen_ontology.iloc[int(clusters_peakChannel[cellid]) - 1])
    print('Good', area, 'cell number', cluster_id, ', corresponding to unit #', cellid)
    print('Number of spikes = ', nspikes)

    hfont = {'fontname': 'Helvetica'}

    if aligned_by == 'default':
        panels = [
            (trials_fr.stimTimes, 'Stimulus'),
            # trial_decision_times omits no-go and slow trials, so rows in this panel
            # do not correspond one-to-one with rows in the other two.
            (trial_decision_times, 'Response'),
            (trials_fr.feedbackTimes, 'Feedback'),
        ]
        for k, (event_times, title) in enumerate(panels):
            plt.subplot(1, 3, k + 1)
            plt.eventplot(BinTrials(spike_cluster.times, event_times, window), colors='k')
            plt.xlabel('Time (s)')
            plt.title(title, **hfont)
            if k == 0:
                plt.ylabel('Trial #')
        return

    # (alignment time for every trial, trial indices A, trial indices B, title A, title B)
    comparisons = {
        'choice': (np.array(closestChoicePt), leftChoiceTrials, rightChoiceTrials,
                   'Left choice', 'Right choice'),
        'feedback': (trials_fr.feedbackTimes, correctTrials, incorrTrials,
                     'Correct', 'Incorrect'),
        'difficulty': (trials_fr.feedbackTimes, easyTrials, hardTrials,
                       'Easy', 'Difficult'),
    }
    event_times, trials_a, trials_b, title_a, title_b = comparisons[aligned_by]
    spike_groups = BinTrials(spike_cluster.times, event_times, window)
    # Select per-trial groups with a list comprehension. np.array() on ragged
    # per-trial spike arrays raises on NumPy >= 1.24, and it silently becomes 2-D
    # when every trial happens to have the same spike count.
    groups_a = [spike_groups[i] for i in trials_a]
    groups_b = [spike_groups[i] for i in trials_b]
    # Shared y-range sized to the larger group; a fixed [0, 200] hid later trials.
    ymax = max(len(groups_a), len(groups_b), 1)

    plt.subplot(121)
    plt.eventplot(groups_a, colors='k')
    plt.title(title_a)
    plt.ylim([0, ymax])

    plt.subplot(122)
    plt.eventplot(groups_b, colors='k')
    plt.title(title_b)
    plt.ylim([0, ymax])
    

In [25]:
def update(align_by, area, cell=0, window=[-1, 3]):
    '''
    Callback for the interactive widget below: open a fresh figure and draw the
    rasters for the selected cell, brain area and alignment.
    '''
    plt.figure(figsize=(10, 5))
    plot_binned_spikes(cell, area, aligned_by=align_by, window=window)

ipywidgets.interactive(update,
                       cell=ipywidgets.widgets.BoundedIntText(
                           value=0,
                           min=0,
                           # Valid cluster ids are 0 .. ncells - 1 (max is inclusive).
                           # With an area selected, the value indexes that area's good cells.
                           max=int(ncells) - 1,
                           step=1,
                           description='CellID: ',
                           disabled=False),
                       window=ipywidgets.widgets.FloatRangeSlider(
                           value=[-0.5, 2],
                           min=-4,
                           max=4,
                           step=0.1,
                           description='Window (s):',
                           disabled=False,
                           continuous_update=False,
                           orientation='horizontal',
                           readout=True,
                           readout_format='.1f'),
                       area=ipywidgets.widgets.RadioButtons(
                           options=areas + ['None'],
                           value='None',
                           description='Brain area:',
                           disabled=False),
                       align_by=ipywidgets.widgets.RadioButtons(
                           options=['default', 'choice', 'feedback', 'difficulty'],
                           description='Align by:',
                           disabled=False))

# For Radnitz_2017-01-08
# VC cells: 740, 817, 836, 862
# SC cells: 1250, 1155, 1170, 1163,
# 1142, 1122 (choice)


A Jupyter Widget

In [24]:
# Plot all cells in ACC
area = 'SCig'
aligned_by = 'default'
window = [-0.5, 2]
plt.figure(figsize=(10, 5))
for i in range(len(area_dict[area])):
    plot_binned_spikes(i, area, aligned_by, window)
    plt.clf()

<IPython.core.display.Javascript object>