# Cell rankings and stability checks


### Prepare workspace

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Add local scripts to path
import os,sys
sys.path.insert(0,os.path.abspath("./"))
import neurotools

# Set up cache
from neurotools.jobs.initialize_system_cache import initialize_caches,cache_test
PYCACHEDIR = os.path.abspath('./')
CACHENAME  = 'PPC_cache'
from neurotools.tools import ensure_dir
ensure_dir(PYCACHEDIR+os.sep+CACHENAME)
initialize_caches(
    level1  = PYCACHEDIR,
    force   = False,
    verbose = False,
    CACHE_IDENTIFIER = CACHENAME)

# Import libraries
from neurotools.nlab import *
import ppc_data_loader

# Set this to the location of the PPC data on your machine
ppc_data_loader.path = '/home/mer49/Dropbox (Cambridge University)/Datasets/PPC_data/'
from ppc_data_loader   import *
from ppc_trial         import *

# Configure Numpy
np.seterr(all='raise');
from numpy.linalg import solve
np.random.seed(0)

# Configure Matplotlib
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 120
TEXTWIDTH = 5.62708
matplotlib.rcParams['figure.figsize'] = (TEXTWIDTH, TEXTWIDTH/sqrt(2))
import warnings
from matplotlib import MatplotlibDeprecationWarning
warnings.filterwarnings("ignore",category=MatplotlibDeprecationWarning)
SMALL_SIZE  = 7
MEDIUM_SIZE = 8
BIGGER_SIZE = 9
matplotlib.rc('font'  , size     =SMALL_SIZE ) # controls default text sizes
matplotlib.rc('axes'  , titlesize=MEDIUM_SIZE) # fontsize of the axes title
matplotlib.rc('axes'  , labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
matplotlib.rc('xtick' , labelsize=SMALL_SIZE ) # fontsize of the tick labels
matplotlib.rc('ytick' , labelsize=SMALL_SIZE ) # fontsize of the tick labels
matplotlib.rc('legend', fontsize =SMALL_SIZE ) # legend fontsize
matplotlib.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
print('Matplotlib configured')

Data location is /home/mer49/Workspace2/PPC_data/
Matplotlib configured


## Get data (just one session for starters)

### Randomly select subject and context

In [2]:
MINNEURONS = 50   # Only consider sessions w. at least this many neurons
MAXSAMPLES = 400  # Remove overly-long trials
NXVAL      = 10   # No. crossvalidation blocks
JITTER     = 25   # No. frames tolerance btwn trial end marker & reward
LOWF       = 0.03 # Low-frequency cutoff for filtering
HIGHF      = 1.25 # High-frequency cutoff for filtering
# NOTE(review): MINNEURONS is not used to filter sessions in this cell.

# Randomly pick a subject, cued direction, previous direction and session.
# Each `choice` presumably draws from the RNG seeded in the setup cell, so
# reordering these lines changes which data are selected.
animals    = get_subject_ids()
animal     = choice(animals)
cues       = [ppc_trial.Trial.CUE_LEFT,ppc_trial.Trial.CUE_RIGHT]
CUE        = choice(cues)
PREV       = choice(cues)
# Assumes CUE_LEFT == 0 and CUE_RIGHT == 1 — TODO confirm in ppc_trial
cname      = ['left','right'][CUE]
pname      = ['left','right'][PREV]
sessions   = get_session_ids(animal)
session    = choice(sessions)
units      = good_units_index(animal,session)
NNEURONS   = len(units)
FS         = get_FS(animal,session)   # sample rate (Hz)

print('Using animal' ,animal )
print('Using trials with cued direction',cname)
print('Using trials with previous direction',pname)
print('Using session',session)

#### Sanity check that physical units seem correct
# presumably a boolean mask over timepoints — TODO confirm
intrials = get_intrial(animal,session)
print('\n%02d%% of timepoints are within a trial'%(100*mean(intrials)))
print('y dimension is %s meters'%np.max(get_y(animal,session)[intrials]))
print('x dimension is %s meters'%np.max(get_x(animal,session)[intrials]))
# Finite-difference x per sample, scaled by FS, to compare against the
# recorded velocity below
print('peak dx from position is',\
      np.max(abs(diff(get_x(animal,session)))[intrials[:-1]])*FS)
print('peak dx from velocity is',\
      np.max(abs(get_dx(animal,session))[intrials]))

# Print the HDF5 layout and the labels of the virmen time-series channels
hdfdata = get_data(animal,session)
printmatHDF5(hdfdata)
for i,label in enumerate(getHDF(hdfdata,'session_obj/timeSeries/virmen/labels')):
    print(i,label)

Using animal 5
Using trials with cued direction right
Using trials with previous direction right
Using session 1

85% of timepoints are within a trial
y dimension is 4.580330147041574 meters
x dimension is 0.15015015015015015 meters
peak dx from position is 0.7997261576589432
peak dx from velocity is 0.47088561809043883
 session_obj:
 :  | confidenceLabel | 955 x 1 | float64 | 'numpy.ndarray | 
 :  | deltaDays       | 1 x 33  | float64 | 'numpy.ndarray | 
 :  | numConditions   | 1 x 34  | float64 | 'numpy.ndarray | 
 :  | sessionList     | 1 x 33  | object  | 'numpy.ndarray | 
 :  | sessionNumber   | 1 x 1   | float64 | 1.0            | 
 : analysis:
 : : glm:
 : : :  | cv_object           | 1 x 955     | object  | 'numpy.ndarray | 
 : : :  | filterMatrix        | 11580 x 144 | float64 | 'numpy.ndarray | 
 : : :  | filterMatrixIDs     | 144 x 1     | object  | 'numpy.ndarray | 
 : : :  | foldID              | 11580 x 1   | float64 | 'numpy.ndarray | 
 : : :  | frame_idx           | 115

## Decode position, speed, and angle

Position has both a real (x,y) location and a pseudotime location. Which should we use?

In [3]:
def get_in_trial(signal,animal,session,MAXSAMPLES=MAXSAMPLES,JITTER=JITTER,meanzero=True,dozscore=False):
    '''
    Concatenate the samples of `signal` that fall inside correct trials.

    Only trials flagged `correct` with fewer than `MAXSAMPLES` samples are
    kept. Each kept trial contributes `signal[istart:istop, ...]`.

    Parameters
    ----------
    signal : array indexed by sample along axis 0
    animal, session : identifiers forwarded to `get_basic_trial_info`
    MAXSAMPLES : int
        Trials with `nsample >= MAXSAMPLES` are dropped.
    JITTER : int
        Forwarded to `get_basic_trial_info`.
    meanzero : bool
        Mean-center each trial separately (along axis 0).
    dozscore : bool
        Apply `zscore(..., axis=0)` to each trial separately, after any
        mean-centering.

    Returns
    -------
    ndarray : the kept trial segments concatenated along axis 0.
    '''
    trials = get_basic_trial_info(animal,session,pad_edges=False,JITTER=JITTER)
    pieces = []
    for trial in trials:
        if not (trial.correct and trial.nsample<MAXSAMPLES):
            continue
        pieces.append(signal[trial.istart:trial.istop,...])
    if meanzero:
        pieces = [zeromean(piece,axis=0) for piece in pieces]
    if dozscore:
        pieces = [zscore(piece,axis=0) for piece in pieces]
    return np.concatenate(pieces)

@memoize
def get_neural_signals_for_training(animal,session,units=None,verbose=False):
    '''
    Band-pass filtered dF/F traces restricted to correct trials.

    Parameters
    ----------
    animal, session : dataset identifiers
    units : sequence of unit indices, optional
        Defaults to `good_units_index(animal,session)`.
    verbose : bool
        Print the sample rate, filter band and filtered array shape.

    Returns
    -------
    times : ndarray
        Sample index divided by the sampling rate, one entry per kept sample.
    ydata : ndarray
        Filtered traces (presumably samples x units), mean-centered within
        each trial by `get_in_trial` and then across all kept samples.
    '''
    if units is None:
        units = good_units_index(animal,session)
    dFF = get_dFF(animal,session,units)
    FS  = get_FS(animal,session)
    # Filter each row of dFF.T (one trace per unit) between LOWF and HIGHF
    filtered = array([bandpass_filter(trace,fa=LOWF,fb=HIGHF,Fs=FS) for trace in dFF.T])
    if verbose:
        print('Sample rate is %f Hz'%get_FS(animal,session))
        print('Filtering between %0.2f and %0.2f Hz'%(LOWF,HIGHF))
        print('Obtained filtered calcium signals, shape is',filtered.shape)
    in_trial = get_in_trial(filtered.T,animal,session)
    ydata    = zeromean(in_trial,axis=0)
    return arange(ydata.shape[0])/FS, ydata

def polar_error_degrees(xdata,xhat):
    '''
    Circular error between true and predicted angles, in degrees.

    The signed difference is wrapped into [-180,180) with modular arithmetic
    and its magnitude taken, giving an error in [0,180]. For example, 350°
    vs. 10° is a 20° error. Unlike a single `360-d` reflection, this stays
    correct for differences of 540° or more, and it accepts scalars.

    Parameters
    ----------
    xdata, xhat : array-like or scalar, degrees
        True and predicted angles (broadcastable against each other).

    Returns
    -------
    rmse : float
        Root-mean-squared wrapped error, degrees.
    mabs : float
        Mean absolute wrapped error, degrees.
    '''
    diff_deg = np.asarray(xdata,dtype=float)-np.asarray(xhat,dtype=float)
    # Shift by 180, reduce mod 360, shift back -> [-180,180); then |.|
    herr = np.abs(np.mod(diff_deg+180.0,360.0)-180.0)
    mabs = np.mean(herr)
    rmse = np.mean(herr**2)**0.5
    return rmse, mabs

@memoize
def head_direction_prediction_error(animal,session,units=None,mode='polar'):
    '''
    Cross-validated linear decoding of the angle returned by `get_theta`.

    Parameters
    ----------
    animal, session : dataset identifiers
    units : optional sequence of unit indices
        Forwarded to `get_neural_signals_for_training`.
    mode : 'polar' or 'linear'
        'polar'  : decode (sin θ, cos θ) and recover θ with `angle`. The
                   returned true and decoded angles are in radians.
        'linear' : decode θ directly in degrees, using the per-trial
                   mean-centered in-trial angle.

    Returns
    -------
    (true angles, decoded angles, rmse, mabs). The errors are in degrees.

    Raises
    ------
    ValueError : if `mode` is not 'polar' or 'linear'.
    '''
    # Validate before any data is loaded. The previous `ArgumentError` is not
    # a Python builtin, so an invalid mode most likely surfaced as NameError.
    if mode not in ('polar','linear'):
        raise ValueError('Mode should be either polar or linear')
    times, ydata = get_neural_signals_for_training(animal,session,units=units)
    if mode=='polar':
        # Encode θ (degrees -> radians) as (sin θ, cos θ), which has no
        # wrap-around discontinuity
        theta      = get_in_trial(get_theta(animal,session),animal,session,meanzero=False)*pi/180
        sincos     = array([sin(theta),cos(theta)]).T
        xmean      = mean(sincos,axis=0)
        xdata      = zeromean(sincos,axis=0)
        W,xhat,_,_ = crossvalidated_least_squares(ydata,xdata,NXVAL)
        # angle(cos + i*sin) recovers the decoded angle in radians
        thetahat   = angle((xmean+xhat)@[1j,1])
        rmse, mabs = polar_error_degrees(theta*180/pi,thetahat*180/pi)
        return theta, thetahat, rmse, mabs
    # mode=='linear': angle decoding in degrees
    xdata      = get_in_trial(get_theta(animal,session),animal,session,meanzero=True)
    xmean      = mean(xdata)
    xdata      = zeromean(xdata)
    W,xhat,_,_ = crossvalidated_least_squares(ydata,xdata,NXVAL)
    rmse, mabs = polar_error_degrees(xdata,xhat)
    return xdata+xmean, xhat+xmean, rmse, mabs

@memoize
def speed_prediction_error(animal,session,units=None):
    '''
    Error in decoding Y speed

    Cross-validated linear decoding of |dy| from the filtered neural signals.

    Returns
    -------
    times : sample times from `get_neural_signals_for_training`
    speed : in-trial |dy| (per-trial mean-centered by `get_in_trial`)
    decoded speed : cross-validated prediction with mean(speed) added back
    rmse, mabs : root-mean-square and mean absolute decoding error
    '''
    times, ydata = get_neural_signals_for_training(animal,session,units=units)
    speed  = get_in_trial(abs(get_dy(animal,session)),animal,session)
    target = zeromean(speed)
    offset = mean(speed)
    W,predicted,_,_ = crossvalidated_least_squares(ydata,target,NXVAL)
    residual = target-predicted
    mabs = mean(abs(residual))
    rmse = mean(abs(residual)**2)**0.5
    return times, speed, predicted+offset, rmse, mabs

@memoize
def position_prediction_error(animal,session,units=None):
    '''
    Error in decoding Y position

    Cross-validated linear decoding of y position from the filtered neural
    signals. Both the target and the prediction are mean-removed before the
    error is computed, so a constant offset is not penalised.

    Returns
    -------
    (y, decoded y, rmse, mabs). `y` is the in-trial position (per-trial
    mean-centered by `get_in_trial`); the decoded trace has the same mean
    added back.
    '''
    times, ydata = get_neural_signals_for_training(animal,session,units=units)
    target = get_in_trial(get_y(animal,session),animal,session)
    offset = mean(target,axis=0)
    target = zeromean(target,axis=0)
    W,predicted,_,_ = crossvalidated_least_squares(ydata,target,NXVAL)
    residual = zeromean(target)-zeromean(predicted)
    mabs = mean(abs(residual))
    rmse = mean(abs(residual)**2)**0.5
    return target+offset, predicted+offset, rmse, mabs

# Greedy optimization of cross-validated fits

### Modify crossvalidated least-squares to precompute covariance structure

 - Break data into testing and training groups
 - Get covariances for each training group

## Closed-form L2 error for greedy sorting

See personal notebook for derivation. We think the L2 error should be

\begin{equation}
\begin{aligned}
\operatorname{tr}\left(
\Sigma_{b_\text{test}} + \Sigma_{ba} \Sigma_{aa}^{-1}
\Sigma_{aa_\text{test}}\Sigma_{aa}^{-1}\Sigma_{ab}-2\Sigma_{ba}\Sigma_{aa}^{-1}\Sigma_{ab_\text{test}}
\right)
\end{aligned}
\end{equation}

Abbreviated derivation of the code:

    mean((tsb - tsa @ solve(aatr,abtr))**2)
    mean(tsb**2 + (tsa @ solve(aatr,abtr))**2 - 2*tsb*(tsa @ solve(aatr,abtr)))
    mean(tsb**2) + mean((tsa @ solve(aatr,abtr))**2) - 2*mean(tsb*(tsa @ solve(aatr,abtr)))

# Neuron-stability and information correlations in an example dataset

Get
 - Individual neuron predictions
 - Neuron predictions in ensemble
 - Neuron prediction with ascending greedy search

# Percentage of cells and composition of informative subsets over time

Get the following for each day: 

 - Ascending  Greedy sort with per-cell Δerrors
 - Descending Greedy sort with per-cell Δerrors
 - Single cell Δerrors
 
Do this for a range of regularization strengths. 

**Caution: changing the source code of the cell below will reset any previously cached results**

In [4]:
@memoize
def get_importance_forward_position(animal,session,units,reg,method,NXVAL):
    '''
    Rank units by their contribution to cross-validated decoding of y position.

    The cross-validated squared error of any subset of units is evaluated in
    closed form from per-fold covariance matrices (see the closed-form L2
    error notes above), so many subsets can be scored cheaply.

    Parameters
    ----------
    animal, session : dataset identifiers
    units : iterable of unit IDs
        Must be a subset of `good_units_index(animal,session)`. Duplicates
        are removed and the IDs are sorted.
    reg : float
        L2 regularization, added as `reg*I` to each sample-normalized
        training covariance.
    method : str
        'ascending_greedy'  : add units one at a time, best first.
        'descending_greedy' : remove the least useful unit at each step, then
                              reverse so the most important comes first.
        'single_cell'       : score each unit alone.
        'leave_one_out'     : score the loss from removing each unit.
    NXVAL : int
        Number of cross-validation folds.

    Returns
    -------
    err0  : cross-validated MSE using all units
    emax  : mean test-fold second moment of the target (error of predicting 0)
    errs  : errors, ordered from most to least important unit
    Δerrs : per-unit change in error, same order
    uids  : indices into the sorted `units`, same order
    units[uids] : the corresponding unit IDs

    Raises
    ------
    ValueError : if `method` is not one of the four listed above.
    '''
    METHODS = ('ascending_greedy','descending_greedy','single_cell','leave_one_out')
    if method not in METHODS:
        # An unknown method used to fall through every branch and fail at
        # the return with UnboundLocalError.
        raise ValueError('Unknown method %r; expected one of %s'%(method,METHODS))

    # Load data
    times,ydata = get_neural_signals_for_training(animal,session)
    avail_units = good_units_index(animal,session)
    xdata       = get_in_trial(get_y(animal,session),animal,session,dozscore=True)

    # Restrict the neural data to the requested units (sorted, unique)
    units = array(sorted(list(set(units))))
    assert(set(units)==(set(avail_units)&set(units)))
    pick  = array([where(avail_units==u)[0][0] for u in units])
    ydata = ydata[:,pick]
    u2    = avail_units[pick]
    assert(all(u2==units))

    # Error with all cells (best-case)
    T,N  = ydata.shape
    R    = reg*eye(N)
    xhat = crossvalidated_least_squares(ydata,xdata,NXVAL,reg=reg)[1]
    err0 = mean((xdata-xhat)**2)

    # Break data into testing and training groups
    trA,trB,tsA,tsB = partition_data_for_crossvalidation(ydata,xdata,NXVAL)

    # Per-fold covariances, normalized by the number of samples.
    # Linear least squares with L2 regularization solves
    #   Q = A.T.dot(A) + np.eye(A.shape[1])*reg*A.shape[0]
    #   x = np.linalg.solve(Q, A.T.dot(B))
    # which after dividing by A.shape[0] is A.T@A/n + reg*I, as below.
    sAtr  = [float32((a.T@ a)/a.shape[0]+R) for a     in trA         ] # Train ind.
    sABtr = [float32((a.T@ b)/a.shape[0])   for (a,b) in zip(trA,trB)] # Train cross
    sBts  = [float32((b.T@ b)/b.shape[0])   for b     in tsB         ] # Test dep.
    sAts  = [float32((a.T@ a)/a.shape[0])   for a     in tsA         ] # Test ind. NO REGULARIZATION FOR THIS ONE
    sABts = [float32((a.T@ b)/a.shape[0])   for (a,b) in zip(tsA,tsB)] # Test cross
    emax  = mean(sBts) # Baseline maximum error rate (variance of testing data)

    def xverrsubset(u):
        '''Cross-validated MSE of a decoder using only the columns in `u`.'''
        u = int32(array(sorted(list(set(u)))))
        error = emax*NXVAL
        for Σaa,Σab,Σaav,Σabv in zip(sAtr,sABtr,sAts,sABts):
            # Weights fit on this fold's training covariances
            w = solve(Σaa[u,:][:,u],Σab[u])
            # Adds w'Σ_test,aa w - 2 w'Σ_test,ab to the test-target variance
            error += w.T @ (Σaav[:,u] @ w - 2 * Σabv)[u]
        return error/NXVAL

    if method=='ascending_greedy':
        # Neurons are added to the population by greedy search. Each neuron
        # added decreases the error (starting at emax) by some amount. We add
        # the best neuron on each iteration, so the result is the neurons
        # sorted from most to least important; errs[k] is the error using
        # the first k+1 neurons.
        unused = set(arange(N))
        uids   = []
        errs   = []
        for i in progress_bar(range(N)):
            check   = array(sorted(list(unused)))
            er      = [xverrsubset(uids+[ch]) for ch in check]
            best    = check[argmin(er)]
            uids   += [best]
            errs   += [np.min(er)]
            unused -= {best}
        Δerrs = -diff([emax]+errs)
        sys.stdout.write('\r'+' '*70+'\r'); sys.stdout.flush()

    elif method=='descending_greedy':
        # Neurons are removed from the population, least useful first, so the
        # removal order is increasing importance. errs[k] is the error after
        # the (k+1)-th removal; error rises steadily.
        used   = set(arange(N))
        uids   = []
        errs   = []
        for i in progress_bar(range(N-1)):
            check = array(sorted(list(used)))
            er    = [xverrsubset(used-{ch}) for ch in check]
            best  = check[argmin(er)]
            uids += [best]
            errs += [np.min(er)]
            used -= {best}
        # Removing the final remaining neuron leaves an empty decoder, whose
        # error is emax. The old code appended xverrsubset(last), which
        # repeats errs[-1], so the most important neuron always got Δerr = 0.
        uids += list(used)
        errs += [emax]
        Δerrs = diff([err0]+errs)
        # Switch to descending order of importance
        Δerrs = Δerrs[::-1]
        errs  =  errs[::-1]
        uids  =  uids[::-1]

    elif method=='single_cell':
        uids  = arange(N)
        # Error after ADDING the unit alone. This should lower the error
        # relative to guessing (emax); the larger the drop, the better the
        # neuron.
        errs  = array([xverrsubset([i]) for i in uids])
        Δerrs = emax-errs
        order = argsort(-Δerrs)
        errs  = errs[order]
        uids  = uids[order]
        Δerrs = Δerrs[order]

    elif method=='leave_one_out':
        uids  = arange(N)
        # Error after REMOVING the unit from the population. This should
        # increase the error above the optimum err0; the higher it is, the
        # more important the neuron.
        errs  = array([xverrsubset(set(uids)-{j}) for j in uids])
        Δerrs = errs - err0
        order = argsort(-Δerrs)
        errs  = errs[order]
        uids  = uids[order]
        Δerrs = Δerrs[order]

    return err0, emax, errs, Δerrs, uids, units[uids]


In [5]:
import pickle

# (subject, [session IDs]) pairs included in the analysis.
# Subject 3 appears twice, as two separate blocks of sessions.
use = [
    (1, [1, 4, 5, 6, 7, 10, 14]),
    (3, [1, 2, 4, 6, 7, 8, 9, 10, 11, 12]),
    (3, [13, 14, 15, 16, 17, 18, 19, 20, 21, 22]),
    (4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]),
    (5, [6, 7, 8, 9, 10, 11, 12]),
]

# Number of cross-validation folds
NXVAL = 10

# Unit-ranking methods to evaluate
methods = [
    'single_cell',
    'leave_one_out',
    'ascending_greedy',
    'descending_greedy',
]

# L2 regularization strengths. These exact float values are used as keys
# into `results`, so they are written as literals.
regstrengths = [
    1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6,
    1e-8, 1e-10,
]

In [None]:
RECOMPUTE_RESULTS = False
saveto = './datafiles/importance_sorting_result.p'
if RECOMPUTE_RESULTS:
    # Make sure the output directory exists before the first checkpoint
    os.makedirs(os.path.dirname(saveto),exist_ok=True)
    results   = {}
    timestamp = now()
    for animal,sessions in use:
        print('Subject',animal)
        units,uidxs = get_units_in_common(animal,sessions)
        for session in progress_bar(sessions):
            for method in methods:
                for reg in regstrengths: 
                    # Value: (err0, emax, errs, Δerrs, uids, units[uids])
                    results[animal,session,method,reg] = get_importance_forward_position(animal,session,units,reg,method,NXVAL)
            # Checkpoint after every session. `with` flushes and closes the
            # file (the previous bare open() was never closed).
            with open(saveto,'wb') as fh:
                pickle.dump(results,fh)
else:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(saveto,'rb') as fh:
        results = pickle.load(fh)

Subject 1
Subject 3                                                             
Subject 3                                                             
Subject 4                                                             
[######                                            ] 12%  18/142

# Check unit-ID consistency (recording stability index)

 - 1: same cell
 - 2: likely the same cell
 - 3: maybe the same cell
 - 4: Nuclear 

In [None]:
# For each subject, print the per-unit minimum (over sessions) of the
# recording stability index, for the units common to all of its sessions.
# NOTE(review): with 1 = same cell ... 4 = nuclear (table above), the minimum
# is the most favourable label across sessions; np.max would give the worst
# case. Confirm which is intended.
for animal, sessions in use:
    units,uidxs = get_units_in_common(animal,sessions)
    # presumably (n_sessions, n_units), indexed by unit ID on axis 1 — TODO confirm
    stability_index = array([get_recording_stability_index(animal,s) for s in sessions])
    print(np.min(stability_index,axis=0)[units])

# What to check? 

 - Pick optimal subset that explains most of the variance
 - Reproduce these claims in the paper
   - 6% of neurons are consistently in the top 50%
   - 0% of neurons are in the top 10%
   - These check out, results are similar for various levels of regularization


In [None]:

# For each ranking method, subject and regularization strength, report how
# many units stay in the top 10% / 50% on every session and, for the greedy
# methods, how much of the full-population performance a subset recovers.
for method in methods:
    print('\nUsing method:',method)
    print('==============================')

    for animal,sessions in use:
        # NOTE(review): sessions[-1]-sessions[0] is a difference of session
        # IDs, which is not necessarily the number of days spanned.
        print('\nM%d (%2d sessions spanning %2d days)'%(animal,len(sessions),sessions[-1]-sessions[0]))
        print('---------------------------------')

        for reg in [1e-10,1e-4,1e-2,1e-1]:
            print('Regularization strength %0.0e:'%reg)
            top10 = []   # per session: unit IDs in the top 10% of the ranking
            top50 = []   # per session: unit IDs in the top 50% of the ranking
            errs  = []   # per session: full-population NMSE, err0/emax
            err30 = []   # per session: NMSE at ranking index N30
            pct99 = []   # per session: fraction of units reaching 99% of explained variance
            for session in sessions:
                # Unpack the ranking errors under their own name. Unpacking
                # into `errs` discarded the accumulator above, and
                # `errs += [...]` then modified the arrays/lists stored in
                # `results` (and `errs/emax` failed on lists).
                err0, emax, sess_errs, Δerrs, uids, units = results[animal,session,method,reg]
                sess_errs = np.asarray(sess_errs)
                Nunits = len(units)
                N10    = int(Nunits/10)
                N30    = int(Nunits/3)   # NOTE(review): reported below as "30%"
                N50    = int(Nunits/2)
                top10 += [units[:N10]]
                top50 += [units[:N50]]
                err30 += [sess_errs[N30]/emax]
                errs  += [err0/emax]
                if 'greedy' in method:
                    # First ranking index whose explained variance reaches 99%
                    # of the full population's; all units if none does.
                    hits   = find((1-sess_errs/emax)>=(1-err0/emax)*0.99)
                    N99    = hits[0] if len(hits) else Nunits
                    pct99 += [N99/Nunits]
            all10 = set.intersection(*map(set,top10))
            all50 = set.intersection(*map(set,top50))
            print(' - Cells always in the top 10%% : %2d units, %2d%%'%(len(all10),len(all10)*100/Nunits))
            print(' - Cells always in the top 50%% : %2d units, %2d%%'%(len(all50),len(all50)*100/Nunits))
            if 'greedy' in method:
                print(' - Average NMSE using all cells: %2d%% ±%2dσ'%(mean(errs)*100,std(errs)*100))
                print(' - Average NMSE using top 30%%  : %2d%%'%(100*mean(err30)))
                print(' - 30%% of cells account for %d%% of variance explained by full population'%(100*(1-mean(err30))/(1-mean(errs))))
                print(' - %2d%% of cells account for 99%% of variance explained by full population'%(100*mean(pct99)))


# Table format

We considered whether it might be possible to identify a long-term stable subset based on single-day decoding performance. To assess this, we ranked subsets of cells using greedy search, based on their contribution to cross-validated decoding within a single session (Methods). On any single session, no more than 30% of cells were needed to achieve 99% of the explained variance of the full population. Despite this, it was impossible to identify a core subset that was stable over time. No more than 1% of cells were ranked in the top 10% for importance on all days, and no more than 13% in the top 50% for all days. 

We also evaluated whether regularization could affect these results. Regularization was performed by adding a constant times the identity matrix, $\lambda I$, to the covariance matrix of the neural data. For all sessions, we found that generalization error could be slightly reduced using regularization strengths between $\lambda=10^{-4}$ and $10^{-3}$. Despite this, regularization did not improve our ability to predict which cells would be stable over days. Similar results were obtained for regularization values ranging from $\lambda=10^{-10}$ to $10^{-1}$.


In [None]:
# Full-population NMSE (err0/emax) as a function of regularization strength,
# normalized to the weakest regularization shown (the last entry, 1e-6).
reg_method = 'descending_greedy'  # previously inherited from the preceding loop's `method`
reg_shown  = regstrengths[1:-2]   # 1e-2 ... 1e-6
subplot(221)
for animal, sessions in use:
    all_errs = []
    for reg in reg_shown:
        # Keep per-session NMSE in its own list; unpacking `results` into
        # the accumulator name (as before) replaced it and then mutated the
        # cached ranking errors.
        nmse = []
        for session in sessions:
            err0, emax = results[animal,session,reg_method,reg][:2]
            nmse += [err0/emax]
        all_errs += [mean(nmse)]
    # Convert to an array so the elementwise normalization works
    all_errs = np.asarray(all_errs)
    plot(log10(reg_shown), all_errs/all_errs[-1]*100,lw=1)

simpleraxis()
xlabel('Regularization strength')
ylabel(r'NMSE (% of value at $\lambda=10^{-6}$)')
xticks(range(-6,-1,2),['$10^{%d}$'%d for d in range(-6,-1,2)]);
axhline(100,lw=1,color='k',linestyle=':')


In [None]:

# For each method, print the range (min, max over subjects and
# regularization strengths) of:
#   - % of units in the top 10% on every session
#   - % of units in the top 50% on every session
#   - (greedy methods) mean fraction of units reaching 99% of the
#     population's explained variance
for method in methods:
    print('\nUsing method:',method)
    print('==============================')

    min10,max10 = inf,-inf
    min50,max50 = inf,-inf
    min99,max99 = inf,-inf
    for animal,sessions in use:
        for reg in [1e-10,1e-6,1e-3,1e-2]:
            top10 = []
            top50 = []
            pct99 = []
            for session in sessions:
                # Separate name for the ranking errors: unpacking into the
                # old `errs` accumulator replaced it and `errs += [...]`
                # mutated the cached `results`.
                err0, emax, sess_errs, Δerrs, uids, units = results[animal,session,method,reg]
                sess_errs = np.asarray(sess_errs)
                Nunits = len(units)
                N10    = int(Nunits/10)
                N50    = int(Nunits/2)
                top10 += [units[:N10]]
                top50 += [units[:N50]]
                if 'greedy' in method:
                    # First index reaching 99% of explained variance; all units if none does
                    hits   = find((1-sess_errs/emax)>=(1-err0/emax)*0.99)
                    N99    = hits[0] if len(hits) else Nunits
                    pct99 += [N99/Nunits]
            all10 = set.intersection(*map(set,top10))
            all50 = set.intersection(*map(set,top50))
            pct10 = len(all10)*100/Nunits
            min10 = min(min10,pct10)
            max10 = max(max10,pct10)
            pct50 = len(all50)*100/Nunits
            min50 = min(min50,pct50)
            max50 = max(max50,pct50)
            if 'greedy' in method:
                min99 = min(min99,mean(pct99))
                max99 = max(max99,mean(pct99))

    print('%0.3f %0.3f'%(min10,max10))
    print('%0.3f %0.3f'%(min50,max50))
    if 'greedy' in method:
        print('%0.3f %0.3f'%(min99,max99))

# How does decoding performance degrade? 

 - Get top 10%, 30%, 50%, 100%
 - Compute the growth of NMSE over time

In [None]:

@memoize
def get_all_xy_data(animal,sessions,uu):
    '''
    Per-session position targets and neural signals for a fixed set of units.

    Parameters
    ----------
    animal : subject identifier
    sessions : sequence of session IDs
    uu : unit IDs; every session's `good_units_index` must contain all of them

    Returns
    -------
    xdata : list with one array per session: in-trial y position, z-scored
            per trial
    ydata : list with one array per session: filtered neural signals with
            columns selected (and ordered) as in `uu`
    '''
    xdata = []
    ydata = []
    for session in sessions:
        _, signals = get_neural_signals_for_training(animal,session)
        available  = good_units_index(animal,session)
        assert(set(uu)&set(available)==set(uu))
        xdata.append(get_in_trial(get_y(animal,session),animal,session,dozscore=True))
        columns = [where(available==unit)[0][0] for unit in uu]
        ydata.append(signals[:,columns])
    return xdata,ydata

def top_K_percent(animal,session,method,reg,PCT):
    '''
    Unit IDs ranked in the top `PCT` percent for one session.

    Reads the precomputed ranking from the global `results` dict and returns
    the first int(N*PCT/100) of its ranked unit IDs, where N is the number of
    ranked units.
    '''
    _, _, _, _, _, ranked_units = results[animal,session,method,reg]
    n_keep = int(len(ranked_units)*PCT/100)
    return ranked_units[:n_keep]

def get_inter_day_generalization_errors(animal,session,sessions,method,reg,PCT):
    '''
    Test how a position decoder trained on one session generalizes over days.

    The top `PCT`% of units, ranked on `session`, are used to fit a
    regularized linear decoder (`reglstsq`) on the first half of `session`.
    It is evaluated on the second half of `session` and on the whole of every
    other session in `sessions`.

    Returns
    -------
    Δdays  : sorted array of absolute day differences from `session`
    Δerrs  : mean test MSE at each day difference
    Δderrs : dict mapping day difference -> list of per-session test MSEs
    '''
    from collections import defaultdict
    # Identify top percent of units
    top = top_K_percent(animal,session,method,reg,PCT)
    # Collect data for these neurons. get_all_xy_data is memoized and may
    # hand back a shared cached object, so copy the lists before the target
    # session is truncated below. Otherwise repeat calls could see data that
    # has already been halved.
    xdata, ydata = get_all_xy_data(animal,sessions,top)
    xdata, ydata = list(xdata), list(ydata)
    # Fit model on target session, first half
    i = where(session==array(sessions))[0][0]
    T = xdata[i].shape[0]
    w = reglstsq(ydata[i][:T//2,:],xdata[i][:T//2],reg=reg).ravel()
    # Only test on second half of target session
    ydata[i] = ydata[i][T//2:,:]
    xdata[i] = xdata[i][T//2:]
    # Test model on all days
    ee = array([mean((x-y@w)**2) for x,y in zip(xdata,ydata)])
    # Get actual elapsed days between sessions.
    # NOTE(review): this pairs sessions[k] with get_days(animal)[k]. That is
    # only correct if get_days returns days for exactly these sessions; if it
    # lists every session of the animal, it should be indexed by session ID
    # (e.g. for `use` entries that start after session 1) — TODO confirm.
    daymap = dict(zip(sessions,get_days(animal)))
    days   = [daymap[s] for s in sessions]
    # Group test errors by absolute day difference from the training session
    Δderrs = defaultdict(list)
    for j in range(len(sessions)):
        Δderrs[abs(days[i]-days[j])] += [ee[j]]
    Δdays = array(sorted(Δderrs.keys()))
    Δerrs = array([mean(Δderrs[d]) for d in Δdays])
    return Δdays,Δerrs,Δderrs

@memoize
def get_day_error_growth(animal,sessions,method,PCT,reg):
    '''
    Pool inter-day generalization errors over every choice of training session.

    Each session in turn supplies its own top-`PCT`% units and training data
    (via get_inter_day_generalization_errors). The resulting test errors are
    pooled by absolute day difference.

    Returns
    -------
    Δdays      : sorted array of day differences
    Δerrs      : mean pooled error at each day difference
    all_Δderrs : dict mapping day difference -> list of all pooled errors
    '''
    from collections import defaultdict
    # Removed: a leftover debug print('hi') and unused locals (units
    # in common, session count).
    all_Δderrs = defaultdict(list)
    for session in sessions:
        _,_,Δderrs = get_inter_day_generalization_errors(animal,session,sessions,method,reg,PCT)
        for Δ,ee in Δderrs.items():
            all_Δderrs[Δ] += ee
    Δdays = array(sorted(all_Δderrs.keys()))
    Δerrs = array([mean(all_Δderrs[d]) for d in Δdays])
    return Δdays,Δerrs,all_Δderrs

# Plot decoding error against the number of days between training and test
# sessions, for decoders built from the top 10/30/50/100% of ranked units.
# These assignments set notebook globals (method, reg, animal, sessions).
method = 'descending_greedy'
#method = 'single_cell'
#method = 'ascending_greedy'
#method = 'leave_one_out'
reg    = 1e-5
animal, sessions = use[2]
Nsessions = len(sessions)

subplot(221)
for i,PCT in enumerate((10,30,50,100)):
    Δdays,Δerrs,all_Δderrs = get_day_error_growth(animal,sessions,method,PCT,reg)
    # riley(...) presumably maps a value in [0,1] to a colour — TODO confirm
    plot(Δdays,Δerrs,color=riley(i*0.8/4),label='%d%%'%PCT)

simpleraxis()
xlabel('Δ days')
ylabel('NMSE')
rightlegend()

### Identify good SNR signals as in Driscoll et al. 

 - You'll need to get position-triggered averages of the signal
 - You'll need to identify cells that are peaked

In [None]:
def get_SNR_metrics(animal,session,units,
                    N=100,
                    threshold=0.1,
                    doplot=False,
                    miny = -1.8,
                    maxy = 1.8):
    '''
    Position-tuning metrics per unit, in the spirit of Driscoll et al.

    Builds, over `N` bins of (z-scored) y position spanning [miny, maxy], a
    smoothed position-triggered mean μ, standard deviation σ and 95%
    confidence half-width ε = 1.96·σ/√m of each unit's activity.

    Parameters
    ----------
    animal, session : dataset identifiers
    units : iterable of unit IDs (deduplicated and sorted internally)
    N : int, number of position bins
    threshold : float
        Used twice. A bin counts when μ-ε > threshold, and a unit is
        "tuned" when its fraction of such bins exceeds threshold.
    doplot : bool
        Plot every unit's tuning curve, ordered by `modix`.
    miny, maxy : float, position range covered by the bins

    Returns
    -------
    peaks : per unit, index of the bin with maximal μ
    SNRs  : per unit, fraction of bins with μ-ε > threshold
    tuned : indices of units with SNRs > threshold (independent of doplot)
    modix : per unit, root-mean-square of μ/σ over bins
    '''
    times,ydata = get_neural_signals_for_training(animal,session)
    xdata       = get_in_trial(get_y(animal,session),animal,session,dozscore=True)
    avail_units = good_units_index(animal,session)
    units       = array(sorted(list(set(units))))
    pick        = array([where(avail_units==u)[0][0] for u in units])
    ydata       = ydata[:,pick]
    # NOTE(review): xdata/ydata are already restricted to in-trial samples by
    # get_in_trial; confirm extract_in_trial expects such input rather than
    # a full-session signal.
    xdata       = cat(extract_in_trial(xdata,animal,session,dozscore=True))
    ydata       = cat(extract_in_trial(ydata,animal,session,dozscore=True))
    edges   = linspace(miny,maxy,N+1)
    centers = (edges[:-1]+edges[1:])/2
    # Per-bin sums, sums of squares and sample counts
    μ,σ,m = [],[],[]
    for a,b in zip(edges[:-1],edges[1:]):
        i = (xdata>=a)&(xdata<b)
        y = ydata[i,:]
        μ.append(np.sum(y,axis=0))
        σ.append(np.sum(y**2,axis=0))
        m.append(len(y))
    μ = array(μ)
    σ = array(σ)
    m = array(m)+0.5   # pseudo-count keeps empty bins from dividing by zero
    # Gaussian smoothing across bins, rows normalized to sum to 1.
    # NOTE(review): toeplitz(fftshift(kernel)) also places kernel mass in the
    # matrix corners, so the first and last bins are smoothed into each
    # other. Confirm this wrap-around is intended.
    smooth = exp(-linspace(-8,8,N)**2)
    smooth/= sum(smooth)
    smooth = toeplitz(fftshift(smooth))
    smooth/= sum(smooth,1)[:,None]
    μ = smooth @ μ
    σ = smooth @ σ
    m = smooth @ m
    μ = μ/m[:,None]
    σ = sqrt( σ/m[:,None] - μ**2)
    ε = σ/sqrt(m)[:,None]*1.96

    peaks = argmax(μ,axis=0)
    SNRs  = array([mean(μi-εi>threshold) for μi,εi in zip(μ.T,ε.T)])
    tuned = find(SNRs>threshold)
    modix = mean((μ/σ)**2,axis=0)**0.5
    
    if doplot:
        figure(figsize=(TEXTWIDTH,)*2)
        # Plot every unit ordered by modix. A separate name is used because
        # this branch used to overwrite `tuned`, so the returned value
        # depended on doplot.
        plot_order = argsort(modix)
        nplot      = len(plot_order)
        k = int(ceil(sqrt(nplot)))
        for j,i in enumerate(plot_order):
            subplot(k,k,j+1)
            title(i)
            u = array(μ)[:,i]
            plot(centers,u,lw=0.6);
            e = array(ε)[:,i]
            fill_between(centers,u-e,u+e,lw=0.6,alpha=0.25,color=TURQUOISE);
            axhline(0,lw=1,linestyle=':')
            noxyaxes();
            simpleraxis()
        tight_layout()
        
    return peaks, SNRs, tuned, modix

# Example: plot tuning curves on the first session of the first subject in
# `use`, for the units returned by get_units_in_common for its sessions.
animal,sessions = use[0]
units,uidxs = get_units_in_common(animal,sessions)
session     = sessions[0]
peaks, SNRs, tuned, modix = get_SNR_metrics(animal,session,units,doplot=True)

### Check how ranking stability relates to SNR

In [None]:
# For one subject, collect per-session tuning metrics and per-session
# importance ranks of the units common to all of its sessions.
#method = 'single_cell'
method = 'descending_greedy'
reg    = 1e-3

animal,sessions = use[1]
#for animal,sessions in use:    
units,uidxs = get_units_in_common(animal,sessions)
Nunits = len(units)
# One get_SNR_metrics result per session. `tuned` has a different length on
# each session, so it stays a list: building an ndarray from ragged arrays
# raises ValueError on NumPy >= 1.24.
metrics   = [get_SNR_metrics(animal,s,units) for s in sessions]
all_peak  = array([mt[0] for mt in metrics])   # (sessions, units)
all_SNRs  = array([mt[1] for mt in metrics])   # (sessions, units)
all_tuned = [mt[2] for mt in metrics]          # ragged list of index arrays
all_modix = array([mt[3] for mt in metrics])   # (sessions, units)
best = find(np.min(all_SNRs,axis=0)>0.05)
# all_rank[s,i] = position of unit index i in session s's importance order
all_rank = [] 
for session in sessions:
    err0, emax, errs, Δerrs, order_idx, order_units = results[animal,session,method,reg]
    order_idx = array(order_idx)
    rank      = array([find(order_idx==i) for i in range(Nunits)]).squeeze()
    all_rank.append(rank)
all_rank = array(all_rank)

In [None]:
# These have consistently sharp peaks.
# Here "SNR" is get_SNR_metrics' 2nd output: the per-session fraction of location
# bins whose lower bound (μ - 1.96·SE) exceeds its threshold.
minSNR = np.min(all_SNRs,axis=0)
print('Median minimum SNR is',median(minSNR))
best = find(minSNR>0.025)
avg  = find(np.mean(all_SNRs,axis=0)>0.1)

# These have qualitatively stable peak locations (peak range over sessions, in bins)
peak_Δ = np.max(all_peak,axis=0)-np.min(all_peak,axis=0)
stable = find(peak_Δ<30)

# Report some numbers
a = set(avg)
b = set(best)
s = set(stable)
print('')
print('%d cells (%d%%) are strongly tuned on average'%(len(a),100*len(a)/len(units)))
print('%d cells (%d%%) are consistently strongly tuned'%(len(b),100*len(b)/len(units)))
print('%d cells (%d%%) have stable peaks'%(len(s),100*len(s)/len(units)))
# Conditional fractions: either set can be empty under strict thresholds, so report nan instead of crashing
print('%2.1f%% of consistently strongly-tuned cells have stable peaks'%(100*len(b&s)/len(b) if b else float('nan')))
print('%2.1f%% of cells with stable peaks are consistently strongly tuned'%(100*len(b&s)/len(s) if s else float('nan')))

In [None]:
# These have at least one sharp peak (maximum over sessions of the "SNR" fraction
# returned as get_SNR_metrics' 2nd output)
maxSNR = np.max(all_SNRs,axis=0)
print('Median maximum SNR is',median(maxSNR))
best = find(maxSNR>0.25)
avg  = find(np.mean(all_SNRs,axis=0)>0.1)

# These have qualitatively stable peak locations (peak range over sessions, in bins)
peak_Δ = np.max(all_peak,axis=0)-np.min(all_peak,axis=0)
stable = find(peak_Δ<30)

# Report some numbers
a = set(avg)
b = set(best)
s = set(stable)
print('')
print('%d cells (%d%%) are strongly tuned on average'%(len(a),100*len(a)/len(units)))
print('%d cells (%d%%) are strongly tuned sometimes'%(len(b),100*len(b)/len(units)))
print('%d cells (%d%%) have stable peaks'%(len(s),100*len(s)/len(units)))
# Conditional fractions: either set can be empty under strict thresholds, so report nan instead of crashing
print('%2.1f%% of sometimes strongly-tuned cells have stable peaks'%(100*len(b&s)/len(b) if b else float('nan')))
print('%2.1f%% of cells with stable peaks are sometimes strongly tuned'%(100*len(b&s)/len(s) if s else float('nan')))

### Run sorting based only on the most modulated cells (the code keeps cells above the 30th percentile of peak modulation index, i.e. the top ~70%)

In [None]:
RECOMPUTE_RESULTS = False

saveto = "./datafiles/importance_sorting_result_top_modulated_only.p"

if RECOMPUTE_RESULTS:
    results_modulated = {}
    for animal,sessions in use:
        print('Subject',animal)
        units,uidxs = get_units_in_common(animal,sessions)

        Nunits = len(units)
        all_peak, all_SNRs, all_tuned, all_modix = list(map(array,zip(*[get_SNR_metrics(animal,s,units) for s in sessions])))
        # Select cells by their maximum modulation index over sessions.
        # NOTE(review): keeping cells *above* the 30th percentile retains the
        # top ~70% of cells, not the top 30%. Confirm which was intended
        # before recomputing (the cached file was built with this rule).
        maxMIX  = np.max(all_modix,axis=0)
        mmm     = percentile(maxMIX,30)
        study   = maxMIX>mmm
        units   = array(units)[study]

        for session in progress_bar(sessions):
            for method in methods:
                for reg in regstrengths:
                    # Stored as (units, (err0, emax, errs, Δerrs, uids, units[uids]))
                    results_modulated[animal,session,method,reg] = (units,get_importance_forward_position(animal,session,units,reg,method,NXVAL))
            # Checkpoint after every session; `with` guarantees the file is flushed and closed
            with open(saveto,'wb') as fh:
                pickle.dump(results_modulated,fh)
else:
    # Local cache written by this notebook (pickle: only load trusted files)
    with open(saveto,'rb') as fh:
        results_modulated = pickle.load(fh)


### Check: are ranking results any different now? 

Not really

In [None]:
reg = 1e-3
print('Regularization strength %0.0e:'%reg)
for method in methods:
    print('\nUsing method:',method)
    print('==============================')

    for animal,sessions in use:
        print('\nM%d (%2d sessions spanning %2d days)'%(animal,len(sessions),sessions[-1]-sessions[0]))
        print('---------------------------------')

        top10    = []   # per session: units ranked in the top 10%
        top50    = []   # per session: units ranked in the top 50%
        err_full = []   # per session: NMSE using all cells (err0/emax)
        err30    = []   # per session: NMSE at index N30 of the error curve
        pct99    = []   # per session: fraction of cells needed for 99% of explained variance
        for session in sessions:
            # Entries are (units, (err0, emax, errs, Δerrs, uids, ordered_units)).
            # `errs` is this session's error curve; it must stay distinct from
            # the across-session accumulator `err_full`.
            err0, emax, errs, Δerrs, uids, units = results_modulated[animal,session,method,reg][1]
            errs   = np.asarray(errs)
            Nunits = len(units)
            N10    = int(Nunits/10)
            # NOTE(review): N30 is a third of the cells (~33%) although the report says 30%
            N30    = int(Nunits/3)
            N50    = int(Nunits/2)
            top10    += [units[:N10]]
            top50    += [units[:N50]]
            err30    += [errs[N30]/emax]
            err_full += [err0/emax]
            if 'greedy' in method:
                # First point on the curve reaching 99% of the full-population explained variance
                N99    = find((1-errs/emax)>=(1-err0/emax)*0.99)[0]
                pct99 += [N99/Nunits]
        all10 = set.intersection(*map(set,top10))
        all50 = set.intersection(*map(set,top50))
        print(' - Cells always in the top 10%% : %2d units, %2d%%'%(len(all10),len(all10)*100/Nunits))
        print(' - Cells always in the top 50%% : %2d units, %2d%%'%(len(all50),len(all50)*100/Nunits))
        if 'greedy' in method:
            print(' - Average NMSE using all cells: %2d%% ±%2dσ'%(mean(err_full)*100,std(err_full)*100))
            print(' - Average NMSE using top 30%%  : %2d%%'%(100*mean(err30)))
            print(' - 30%% of cells account for %d%% of variance explained by full population'%(100*(1-mean(err30))/(1-mean(err_full))))
            print(' - %2d%% of cells account for 99%% of variance explained by full population'%(100*mean(pct99)))


## Measures we can correlate

Check how the single-cell, pull-one-out (leave-one-out), and modulation-index measures compare.
Can any of these be used to predict the others in future sessions?

In [None]:
reg = 1e-3
animal,sessions = use[1]
units,uidxs = get_units_in_common(animal,sessions)

Nunits = len(units)
all_peak, all_SNRs, all_tuned, all_modix = list(map(array,zip(*[get_SNR_metrics(animal,s,units) for s in sessions])))

# Per session: Δerror values (in rank order) and the unit index at each rank
Δerrs1 = array([results[animal,session,'single_cell'  ,reg][3] for session in sessions])
uids1  = array([results[animal,session,'single_cell'  ,reg][4] for session in sessions])
Δerrs2 = array([results[animal,session,'leave_one_out',reg][3] for session in sessions])
uids2  = array([results[animal,session,'leave_one_out',reg][4] for session in sessions])

# Convert ranked orderings to list of ranks: pos[s,i] = rank of unit i in session s
pos1   = array([[find(uu==i)[0] for i in range(Nunits)] for uu in uids1])
pos2   = array([[find(uu==i)[0] for i in range(Nunits)] for uu in uids2])
# Re-order the Δε so column i refers to unit i in every session.
# Building rows directly keeps the (sessions, units) shape even for a single
# session, where wrapping in [[...]] and calling .squeeze() would collapse it to 1-D.
Δε1 = array([e[p] for e,p in zip(Δerrs1,pos1)])
Δε2 = array([e[p] for e,p in zip(Δerrs2,pos2)])

# One heat-map row per measure (rows = sessions, columns = units)
panels = [
    (Δε1      , 'Single Cell Δerror'),
    (Δε2      , 'Pull-one-out Δerror'),
    (all_modix, 'Modulation Index'),
    # Negative absolute deviation of each unit's peak location from its mean over sessions
    (-abs(zeromean(all_peak,axis=0)), 'Peak Shift Range Magnitude'),
]
for k,(img,label) in enumerate(panels):
    subplot(len(panels),1,k+1)
    imshow(img,aspect='auto')
    noxyaxes()
    title(label)
tight_layout()

In [None]:
# Pairwise scatter of each unit's minimum (over sessions) of the three measures
_panels = [
    (221, np.min(Δε1,0)      , np.min(Δε2,0)      , 'Δε1'       , 'Δε2'       ),
    (222, np.min(all_modix,0), np.min(Δε2,0)      , 'modulation', 'Δε2'       ),
    (223, np.min(Δε1,0)      , np.min(all_modix,0), 'Δε1'       , 'modulation'),
]
for _pos,_xv,_yv,_xl,_yl in _panels:
    subplot(_pos)
    scatter(_xv,_yv)
    simpleaxis(); xlabel(_xl); ylabel(_yl)
tight_layout()

## Check how much different criteria overlap

In [None]:
# Range of each unit's peak location over sessions (in location bins)
Δpeak = np.max(all_peak,axis=0)-np.min(all_peak,axis=0)
stable    = find(Δpeak<20)   # NOTE(review): earlier cells use a threshold of 30
modulated = find(np.min(all_modix,0)>0.1)
minΔε     = np.min(Δε1,0)
important = find(minΔε>median(minΔε))
s = set(stable   )
m = set(modulated)
i = set(important)
ns,nm,ni = len(s),len(m),len(i)
print('stable: %d, modulated: %d, important: %d'%(ns,nm,ni))

def _frac(num,den):
    """Return num/den, or nan when den is 0 (an empty criterion set)."""
    return num/den if den else float('nan')

print('Fraction stable    cells that are modulated',_frac(len(s&m),ns))
print('Fraction modulated cells that are stable   ',_frac(len(s&m),nm))
print('Fraction stable    cells that are important',_frac(len(s&i),ns))
print('Fraction important cells that are stable   ',_frac(len(s&i),ni))
print('Fraction modulated cells that are important',_frac(len(m&i),nm))
print('Fraction important cells that are modulated',_frac(len(m&i),ni))

## Ranking by modulation index is also unstable

In [None]:
print('Modulation index rank stability')
print('M%d sessions'%animal,sessions)
m      = all_modix
# Per session (row): unit indices ordered from most to least modulated
mxrank = argsort(-m,axis=1)
n_top10pct = Nunits//10
n_top50pct = Nunits//2
# Percentage of all units that appear in the top-k of every session
print('%2d%% always in top 10%%'%\
      (100*len(set.intersection(*map(set,mxrank[:,:n_top10pct])))/Nunits))
print('%2d%% always in top 50%%'%\
      (100*len(set.intersection(*map(set,mxrank[:,:n_top50pct])))/Nunits))

## 1. Plot the variance of a ranking as a function of its mean: do good cells tend to stay good-ish?


### Get various statistics

In [None]:
method = 'single_cell'
reg    = 1e-3

animal,sessions = use[1]
units,uidxs = get_units_in_common(animal,sessions)
Nunits = len(units)
all_peak, all_SNRs, all_tuned, all_modix = list(map(array,zip(*[get_SNR_metrics(animal,s,units) for s in sessions])))

# all_rank[s,i] = position of unit index i within session s's importance ordering
all_rank = []
for session in sessions:
    err0, emax, errs, Δerrs, order_idx, order_units = results[animal,session,method,reg]
    order_idx = array(order_idx)   # convert once per session, not once per unit
    all_rank.append(array([find(order_idx==i) for i in range(Nunits)]).squeeze())
all_rank = array(all_rank)

### Check how variance of metrics changes for good/bad cells

In [None]:

# Left: mean vs. variance over sessions of each unit's normalised rank
subplot(221)
x = np.mean(all_rank/Nunits,0)
y = var(all_rank/Nunits,0)
scatter(x,y,s=5,clip_on=False)
simpleaxis()
# Title follows the ranking method actually used to build all_rank
title('%s ranking stability'%method.replace('_',' ').capitalize())
xlabel('Mean rank\n(lower = more important)')
ylabel('Rank variance')
xlim(0,1); ylim(0,0.125)
# Reference curve x(1-x)/4. NOTE(review): document which null model this represents
x = linspace(0,1,100)
plot(x,x*(1-x)/4,color=TURQUOISE,zorder=-inf)

# Right: mean vs. standard deviation over sessions of each unit's modulation index
subplot(222)
x = np.mean(all_modix,0)
y = std(all_modix,0)
scatter(x,y,s=5,clip_on=False)
simpleaxis()
title('Modulation index stability')
xlabel('Mean over days\n(higher = stronger task modulation)')
ylabel('Standard deviation')

tight_layout()

In [None]:
# Order units by their mean modulation index across sessions (ascending),
# and collect per-unit summary statistics in that order
unit_mean = mean(all_modix,0)
o  = argsort(unit_mean)
mm = unit_mean[o]
mx = all_modix[:,o]
ee = 1.96*std(mx,0)   # 1.96 × std across sessions
m1, m2 = (reduce_fn(all_modix,0)[o] for reduce_fn in (np.min,np.max))

In [None]:
# Units sorted by mean modulation index: dot = mean over sessions, bar = min–max over sessions
scatter(arange(Nunits),mm,s=5)
for i,(m1i,m2i) in enumerate(zip(m1,m2)):
    plot([i,i],[m1i,m2i],lw=0.8,color=BLACK)
simpleraxis()
xlabel('Unit (sorted by mean modulation index)')
ylabel('Modulation index\n(mean and min–max over sessions)');


### Thoughts? 

 - Variability is high
 - The middle 50% of cells are extremely variable and unstable
 - Variance in modulation does not decrease for high-quality cells
 
Overall this is consistent with rankings being highly unstable. Nevertheless, Driscoll et al. found that 40% of task-modulated cells had stable peaks. This does not mean that the usefulness of these cells for decoding was also stable. We need to better dissociate these various effects. 

## Inspect the so-called stable cells

In [None]:
# Per-unit standard deviation, across sessions, of the tuning-curve peak location
# (in location bins). Despite the name, larger values mean *less* stable peaks.
stability = std(all_peak,0)

In [None]:
# Peak-location variability vs. average modulation strength, one point per unit
scatter(mean(all_modix,0),stability)
xlabel('Mean modulation index over sessions')
ylabel('Std. dev. of peak location (bins)\n(higher = less stable)');

In [None]:
animal,sessions = use[1]
# Set session and units explicitly instead of relying on values left over from
# earlier cells. Under Run All the leftover session was the last one of use[1].
session     = sessions[-1]
units,uidxs = get_units_in_common(animal,sessions)

# Load data
times,ydata = get_neural_signals_for_training(animal,session)
avail_units = good_units_index(animal,session)
xdata       = get_in_trial(get_y(animal,session),animal,session,dozscore=True)

units = array(sorted(set(units)))
assert set(units) <= set(avail_units), 'some units are not available in this session'
# Column of ydata corresponding to each requested unit
pick  = array([where(avail_units==u)[0][0] for u in units])
ydata = ydata[:,pick]
u2    = avail_units[pick]
assert all(u2==units)

If a neuron becomes relatively quieter or less reliable, the weight assigned to it may become inappropriate for decoding. This affects our analyses, and would also physiologically affect a downstream neuron with fixed synaptic weights.

 - Show that stably-tuned neurons have unstable SNR
 - or unstable rates
 - These are equivalent for normalized (z-scored) signals, so SNR is the only usable proxy
 
We define the signal-to-noise ratio (SNR) of a single-cell tuning curve as the root-mean-square, over location bins, of the ratio between the location-triggered mean rate and the location-triggered standard deviation:

$$
k^2 = \left<\frac{\mu^2}{\sigma^2}\right>
$$

Data checks out:

For example, no more than 8% of neurons that were in the top 20% in terms of tuning-curve stability were also consistently in the top 25% in terms of SNR. 

If a neuron becomes relatively quieter or less reliable, the weight assigned to it may become inappropriate for decoding. This affects our analyses, and would also physiologically affect a downstream neuron with fixed synaptic weights.

In [None]:
method = 'ascending_greedy'   # not used by this cell
reg    = 1e-3                 # not used by this cell
pct    = 25                   # percentile cut for both peak stability and "SNR"

animal,sessions = use[1]
for animal,sessions in use:
    print('Mouse',animal)
    units,uidxs = get_units_in_common(animal,sessions)
    Nunits = len(units)
    metrics = [get_SNR_metrics(animal,s,units) for s in sessions]
    # NOTE: get_SNR_metrics returns (peaks, SNRs, tuned, modix). The 4th value,
    # the RMS over location bins of μ/σ, is the "SNR" defined in the text and
    # is bound to all_SNRs here; the 2nd value is discarded.
    all_peak, _, all_tuned, all_SNRs = list(map(array,zip(*metrics)))
    Δpeaks       = np.max(all_peak,0)-np.min(all_peak,0)
    stable_peaks = Δpeaks < percentile(Δpeaks,pct)
    # Threshold pooled over all sessions and units
    snr_cut           = percentile(all_SNRs.ravel(),100-pct)
    consistently_high = np.min(all_SNRs[:,stable_peaks],axis=0) > snr_cut
    print('%d%%'%(100*mean(consistently_high)))

For example, no more than 8% of neurons that were in the top 20% in terms of tuning-curve stability were also consistently in the top 25% in terms of SNR. 

Between 8% and 36% of neurons that had above-median stability also consistently had above-median SNR. 