# Figure S2: Effect of daily variability on stability

Plot to explore a range of noise levels in the encoding, as well as drift on the readout synaptic weights.

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

'''
We need to use start-method 'spawn' or 'forkserver' for compatibility
with Jax. 'forkserver' is allegedly faster. Spawn/forkserver starts
each worker from a clean state. Code and data required to compute the
job are serialized and sent to the worker. Functions being run should
be class methods (easily serialized). Any static initialization must
be done again, manually, in each worker process. We also need to use
multiprocess rather than multiprocessing to get around pickling errors.
(Multiprocess uses 'dill' instead of pickle). When working on a
remote server, workers should only print ASCII-compatible outputs to
avoid unicode encoding errors in the worker processes. Also note that
both Jax and Multiprocessing will fail if the process IDs on the host
system have been exhausted, with obscure errors.
'''
import traceback
import multiprocess as multi
# Select the 'spawn' start method for worker processes (rationale in the
# note above). A RuntimeError here is reported but not fatal, so the rest
# of the cell still runs.
# NOTE(review): presumably this fires when the start method was already
# fixed, e.g. on re-running this cell in a live kernel -- confirm.
try:
    multi.set_start_method('spawn')
except RuntimeError:
    traceback.print_exc()
    print('couldn\'t set context')
    
import config, master, standard_options
from config           import *
from master           import *
from standard_options import *

import parallel
from parallel import _parmap, helper, try_variations, methods, method_names

# If running parallel jobs, limit cores/job to 1
# Disable logging (too verbose)
# NOTE(review): with debug=False the assignment below sets
# master.PRINT_LOGGING to True, which reads as the opposite of "disable
# logging" -- confirm the polarity of PRINT_LOGGING.
limit_cores()
debug = False
master.PRINT_LOGGING = not debug


In [None]:
# ---- Sweep configuration ----
nseeds     = 20   # Number of replicas
εθ         = 0.5  # Error threshold for code "failure"
maxr       = 1.0  # i.i.d. encoding tuning variability per timepoint
maxn       = 0.25 # fraction decoding weight noise per timepoint
ntest      = 9    # Number of levels tested for each noise parameter
test_r     = linspace(0.0,maxr,ntest)
test_n     = linspace(0.0,maxn,ntest)
methods    = ('hebbhomeo','predictive','recurrent')
geometries = ('ring','line','tee')

# Shared simulation options: the defaults from `options` plus overrides.
noiseopts  = {
    **options,
    'fail_early':False,
    'features':'ougaussian',
    'T':4005
}

def sweep_jobs(param, levels):
    '''
    Build one job dictionary per (seed, level, method, geometry).

    Jobs are emitted with seed as the slowest-varying index, then level,
    then method, then geometry (fastest). Each job starts from `noiseopts`,
    sets `param` to the level being tested, uses the geometry for both the
    'geometry' and 'readout' entries, and finally applies `rates[method]`.

    Parameters
    ----------
    param : str
        Option key to sweep ('r' or 'n' in this notebook).
    levels : iterable
        Values assigned to `param`.

    Returns
    -------
    list of dict
    '''
    jobs = []
    for seed in range(nseeds):
        for level in levels:
            for method in methods:
                for geometry in geometries:
                    jobs.append({
                        **noiseopts,
                        'seed'    : seed,
                        param     : level,
                        'method'  : method,
                        'geometry': geometry,
                        'readout' : geometry,
                        **rates[method],
                    })
    return jobs

rjobs = sweep_jobs('r', test_r)
njobs = sweep_jobs('n', test_n)

# Submit both sweeps in a single batch: the 'n' jobs first, then the 'r' jobs.
results  = _parmap(helper,njobs+rjobs,debug=debug)
notify('simulating readout weight drift. done')



Shuffling job sequence
Parallel using 20 cores
>>> 20 cores available.
>>> limited each process to 1 cores.
Preparing to run 3240 jobs
Starting...
[███████████████▉                        ] 39.692% 

In [None]:
# The batch was submitted as njobs+rjobs, so the first len(njobs) entries
# belong to the 'n' sweep and the remainder to the 'r' sweep.
n_count = len(njobs)
results_n, results_r = results[:n_count], results[n_count:]

In [None]:
NTOP  = len(geometries)
NMETH = len(methods)
NR    = len(test_r)
NN    = len(test_n)
NS    = nseeds

def index_results(job_results, nlevels):
    '''
    Key a flat list of job results by (seed, level, method, geometry) index.

    The key order (seed slowest, geometry fastest) mirrors the nesting used
    when the job lists were built, so results are paired with keys purely
    by position.

    Parameters
    ----------
    job_results : sequence
        Results for one sweep, in job-submission order.
    nlevels : int
        Number of levels tested in that sweep.

    Returns
    -------
    dict
        Maps (seed index, level index, method index, geometry index) to
        the corresponding result.

    Raises
    ------
    ValueError
        If the number of results differs from the number of keys. `zip`
        would otherwise truncate silently and leave the map incomplete.
    '''
    keys = [(i,v,m,t)
            for i in range(NS)
            for v in range(nlevels)
            for m in range(NMETH)
            for t in range(NTOP)]
    if len(job_results) != len(keys):
        raise ValueError(
            'Expected %d results but received %d'%(len(keys),len(job_results)))
    return dict(zip(keys,job_results))

nmap = index_results(results_n, NN)
rmap = index_results(results_r, NR)

In [None]:
def find_first(x):
    '''
    Return the first element of `find(x)`, or `len(x)` when no entry of
    `x` is truthy.

    `find` is a project helper; presumably it returns the indices of the
    nonzero entries of `x`, making this the index of the first hit.
    '''
    if not np.any(x):
        return len(x)
    return find(x)[0]

Δ       = noiseopts['Δ']
maxtime = noiseopts['T']

def survival_times(result_map, im, iv, it):
    '''
    For every seed, Δ times the index at which the stored result first
    exceeds the threshold εθ (or Δ times its length if it never does).

    Returns an array with one entry per seed for the given method index
    `im`, level index `iv` and geometry index `it`.
    '''
    return array([Δ*find_first(result_map[i,iv,im,it]>εθ) for i in range(nseeds)])

# Collect results, keyed by (method index, level index, geometry index)
all_replicas_r = {}
all_replicas_n = {}
for it,geometry in enumerate(geometries):
    for im,method in enumerate(methods):
        for iv in range(ntest):
            key = (im,iv,it)
            all_replicas_r[key] = survival_times(rmap,*key)
            all_replicas_n[key] = survival_times(nmap,*key)
notify('Supplemental Figure 2 finished simulating')

In [None]:
import matplotlib as mpl
mpl.rcParams['savefig.dpi']=mpl.rcParams['figure.dpi']=140

# Panel titles, paired with `methods` purely by position.
# NOTE(review): confirm that this order matches the order of `methods`;
# the pairing cannot be verified from this notebook alone.
method_names = (
    'Response normalization',
    'Recurrent feedback',
    'Linear-nonlinear map')

figure(figsize=(PNAS_SMALL_WIDTH,3.5))
subplots_adjust(wspace=0.2,hspace=1,left=0.13,right=0.97,bottom=0.1,top=0.85)

# One column per method; row 0 = encoding variability, row 1 = readout drift.
ncols = len(methods)
axs = {}
for im in range(ncols):
    axs[0,im] = subplot(2,ncols,im+1)
    axs[1,im] = subplot(2,ncols,im+ncols+1)

colors = [MAUVE,OCHRE,AZURE]

# Shared x-axis geometry. Boxes for geometry `it` are offset by (it-1.5)/4
# around each integer level index, and the axis limits are set to match.
limit = ntest-1
xlo   = -1.5/4
xhi   = limit+(2-1.5)/4

def plot_survival(ax, replicas, im, it, maxlevel, nticks):
    '''
    Draw the survival-time boxplots of one geometry into one panel and
    (re)apply the panel formatting.

    Parameters
    ----------
    ax : axes
        Panel to draw into; made the current axes.
    replicas : dict
        Maps (method index, level index, geometry index) to the per-seed
        survival times.
    im, it : int
        Method (column) index and geometry index.
    maxlevel : float
        Largest level of the swept parameter; shown as a percentage on the
        last x tick of the first column.
    nticks : int
        Number of evenly spaced x ticks.
    '''
    sca(ax)
    reps = [replicas[im,iv,it] for iv in range(ntest)]
    # Only draw levels whose median survival time exceeds Δ.
    show = np0.median(reps,axis=1)>Δ
    colored_boxplot([*array(reps)[show]],arange(ntest)[show]+(it-1.5)/4,
                    color=colors[it],
                    widths=0.1,
                    whis=[5,95],
                    linewidth=0.8)
    title(method_names[im],va='top',fontsize=MEDIUM,pad=10)
    simpleaxis()
    # Only the first column labels the ends of the x axis.
    if im==0:
        labels = ['0%']+['']*(nticks-2)+['%02d%%'%(100*maxlevel)]
    else:
        labels = ['']*nticks
    xticks(np0.linspace(xlo,xhi,nticks),labels,fontsize=6)
    xlim(xlo,xhi)
    ylim(0,maxtime)
    gca().yaxis.set_tick_params(labelsize=6)
    # The original had an `if im==3:` branch here that drew a
    # 'Max. Iterations' marker. With three methods `im` never reaches 3,
    # so the branch never executed; it is omitted and the figure is the same.

for it,geometry in enumerate(geometries):
    # Plot results for tuning variability (5 x ticks, as in the original)
    for im in range(ncols):
        plot_survival(axs[0,im],all_replicas_r,im,it,maxr,5)
    # Plot results for readout synapse drift (6 x ticks, as in the original)
    for im in range(ncols):
        plot_survival(axs[1,im],all_replicas_n,im,it,maxn,6)

# Legend: plot one off-screen marker per geometry, then restore the limits.
sca(axs[0,0])
xl,yl = xlim(),ylim()
for it,geometry in enumerate(geometries):
    scatter([-100],[-100],marker='s',s=20,color=colors[it],label=geometry.title())
xlim(*xl); ylim(*yl)
nice_legend(fontsize=SMALL)
xlabel('Excess variability',fontsize=6, labelpad=3);
ylabel('Survival Time\n(# reconfigurations)',fontsize=SMALL)
subfigurelabel('(a) Per-timepoint encoding variability',fontsize=7,dx=70,dy=30)
sca(axs[1,0])
xlabel('Decoding weight drift',fontsize=6, labelpad=3);
ylabel('Survival Time\n(# reconfigurations)',fontsize=SMALL)
subfigurelabel('(b) Per-timepoint readout weight drift',fontsize=7,dx=70,dy=30)

# Final y axes: every panel spans [0, maxtime-Δ]; only the first column
# carries tick labels. Right-hand panels are processed first so that the
# last current axes is axs[1,0], as in the original.
ymax         = maxtime-options['Δ']
right_panels = [(row,col) for row in (0,1) for col in range(1,ncols)]
left_panels  = [(0,0),(1,0)]
for row,col in right_panels+left_panels:
    sca(axs[row,col])
    ylim(0,ymax)
    yticks([0,ymax],['0',str(maxtime//options['K'])] if col==0 else ['',''])
figurebox('w')
savefig('s2.pdf')
notify('Supplemental Figure 2 finished plotting')