# Learning Through Target Bursts (LTTB) - Figure 3

This notebook reproduces the results presented in `Figure 3` of the <a href="https://arxiv.org/abs/2201.11717">arXiv 2201.11717</a> preprint: Cristiano Capone<sup>\*</sup>, Cosimo Lupo<sup>\*</sup>, Paolo Muratore, Pier Stanislao Paolucci (2022) "*Burst-dependent plasticity and dendritic amplification support target-based learning and hierarchical imitation learning*". We test the `LTTB` model on the button-and-food navigation task: a high-level network selects the context (reach the button first, then the food), while a low-level network produces the agent's velocities.

Please give credit to this paper if you use or modify the code in a derivative work. This work is licensed under the Creative Commons Attribution 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 International License</a>.

### Libraries Import & Parameter Loading

In this section we import the needed external libraries and load the model parameters for this task (via the `json` configuration file).

In [None]:
import numpy as np
import random
from importlib import reload
from tqdm import trange

import json
import glob

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import lttb as lttb_module
#from env import Unlock
import env as env_module
#from env_module import Unlock

In [None]:
# Read the whole experiment configuration from the JSON file.
with open('./config.json', 'r') as fp:
    all_pars = json.load(fp)

# One parameter dictionary per component: environment, low and high network.
env_par, low_par, high_par = (
    all_pars[section] for section in ('BUTTON_FOOD', 'LOW_NETWORK', 'HIGH_NETWORK')
)

# Environment scalars used throughout the notebook.
tb, tf, rt, rb = (env_par[key] for key in ('tb', 'tf', 'rt', 'rb'))
total_training_T = env_par['T']  # superseded below by the low-network value
max_T = env_par['max_T']

# Boolean options are stored as the strings "True"/"False" in the JSON file.
env_par.update({
    flag: (env_par[flag] == "True")
    for flag in ('hint', 'clump', 'validate', 'verbose')
})

n_examples = env_par['n_examples']
total_training_T = low_par['T']

# Integration step and membrane time constant of the low network.
dt, tau_m = low_par['dt'], low_par['tau_m']

# Steepness passed to the transfer function f.
gamma = 1./low_par['du']
def f(x,gamma):
    """Logistic sigmoid of ``x*gamma``.

    Mathematically equal to ``exp(x*gamma) / (exp(x*gamma) + 1)``, but
    evaluated as ``0.5 * (1 + tanh(x*gamma / 2))``. The tanh form cannot
    overflow, whereas the exponential form evaluates to ``inf/inf = nan``
    once ``x*gamma`` exceeds roughly 709.

    Parameters
    ----------
    x : float or numpy.ndarray
        Input value(s).
    gamma : float
        Multiplicative steepness applied to ``x``.

    Returns
    -------
    float or numpy.ndarray
        Values in ``[0, 1]`` with the same shape as ``x``.
    """
    return 0.5*(1. + np.tanh(0.5*x*gamma))

In [None]:
# ==== Environment Initialization ======
# Default agent start, target and button positions used to build the environment.
init = np.array ((0., 0.))
targ = np.array ((0., 1.))
btn = np.array ((0., 0.))

env = env_module.Unlock (
    init = init,
    targ = targ,
    btn = btn,
    unit = (env_par['dt'], env_par['dx']),
    max_T = env_par['max_T'],
    res = 20,
    rs = (env_par['r_lock'], env_par['r_targ']),
)

# Angle sets from the configuration (degrees) and their radian counterparts.
trainset = np.array (env_par['trainset'])
validset = np.array (env_par['validset'])
testset  = np.linspace (*env_par['testset'])

train_theta = trainset * np.pi / 180.
valid_theta = validset * np.pi / 180.
test_theta  = testset * np.pi / 180.

# Buttons equally spaced on the unit circle; each food sits at a random angle
# on a unit circle centred on its button.
train_thetas = np.linspace(0, 2*np.pi, n_examples, endpoint=False)
train_thetas_targ = [random.random()*2*np.pi for _ in range(n_examples)]

train_bttns = np.array ([(np.cos (angle), np.sin (angle)) for angle in train_thetas])
train_targs = train_bttns + np.array ([(np.cos (angle), np.sin (angle)) for angle in train_thetas_targ])

# Ask the environment for the expert behaviour on every training example.
epar = {'offT' : (1, 1), 'steps' : (tb-1, tf-1), 'T' : (tb, tf)}
train_exp = [
    env.build_expert (food, init, button, **epar)
    for food, button in zip (train_targs, train_bttns)
]
#valid_exp = [env.build_expert (targ, init, btn, **epar) for targ, btn in zip (valid_targs, valid_bttns)]

#### Functions

In [None]:
# ============= TEST FUNCTION =================
def test (lttb, env, testset, env_par):
    """Roll out `lttb` in `env` once per target angle in `testset`.

    The original body referenced an undefined name ``par``; the parameter
    dictionary handed to this function is ``env_par``, so it is used for both
    the network re-initialisation and the episode length.

    Parameters
    ----------
    lttb : object
        Network exposing ``initialize(par)`` and ``step(apicalFactor=...)``.
    env : object
        Environment exposing ``reset``, ``encode`` and ``step``.
    testset : sequence of float
        Target angles (radians); targets are placed at radius ``env_par['rt']``.
    env_par : dict
        Must provide the keys ``'rt'``, ``'rb'``, ``'T'`` and ``'O'``.

    Returns
    -------
    dict
        ``'agent'`` (size, T, 2), ``'action'`` (size, T, O), ``'theta'`` (size,)
        and ``'R'`` (size,), the latter being the maximum reward per episode.

    Notes
    -----
    Relies on the module-level start position ``init``.
    """
    rt, rb = env_par['rt'], env_par['rb']
    n_steps = env_par['T']
    size = np.shape(testset)[-1]

    hist = {'agent'  : np.zeros ((size, n_steps, 2)),
            'action' : np.zeros ((size, n_steps, env_par['O'])),
            'theta'  : np.zeros (size),
            'R'      : np.zeros (size)}

    # Targets on a circle of radius rt; the button is always at (0, rb).
    tars = np.array ([(rt * np.cos (t), rt * np.sin (t)) for t in testset])
    btns =  [(0., rb)] * len (testset)

    for i, (targ, btn) in enumerate (zip (tars, btns)):
        env.reset (init = init, targ = targ, btn = btn)
        lttb.initialize(env_par)
        R = 0

        state = np.hstack ((env.encode (targ - init), env.encode (btn - init)))

        for t in range (n_steps):
            action, _ = lttb.step (apicalFactor = 0)
            state, r, done, agen = env.step (action)

            R = max (R, r)

            hist['action'][i, t] = action
            hist['agent'][i, t]  = agen

            if done: break

        # Pad the remainder of the episode: NaN actions, last known position.
        # NOTE(review): the slice starts at t, so the action of the final step
        # is overwritten by NaN as well — kept as in the original, confirm intent.
        hist['action'][i, t:] = np.nan
        hist['agent'][i, t:]  = agen
        hist['theta'][i]      = testset[i]
        hist['R'][i]          = R

    return hist

def low_target_list():
    """Collect the low network's `B_filt` trace for every training example.

    For each example the global ``low_lttb`` is re-initialised, loaded with the
    example's target and clock input, and run with ``apicalFactor = 1``. The
    context is ``[1,0]`` for ``t < tb`` and ``[0,1]`` afterwards.

    Returns
    -------
    list of numpy.ndarray
        One snapshot of ``low_lttb.B_filt`` per example. A copy is stored so
        that entries cannot alias one another if the network reuses its buffer
        on the next ``initialize`` call.
    """
    SR_low_list = []

    for example_index in range(n_examples):

        low_lttb.initialize(low_par)
        low_lttb.y_targ = low_lttb.y_targ_collection[example_index]
        low_lttb.I_clock = low_lttb.I_clock_collection[example_index]

        for t in range(low_lttb.T-2):
            low_lttb.cont = np.array([1,0]) if t < tb else np.array([0,1])
            low_lttb.step(apicalFactor = 1)

        SR_low_list.append(np.copy(low_lttb.B_filt))

    return SR_low_list

def high_target_list():
    """Collect the high network's `B_filt_rec` trace for every training example.

    For each example the global ``high_lttb`` is re-initialised, loaded with the
    example's target and clock input, and run with ``apicalFactor = 1`` and a
    constant context ``[0]``.

    Returns
    -------
    list of numpy.ndarray
        One snapshot of ``high_lttb.B_filt_rec`` per example. A copy is stored
        so that entries cannot alias one another if the network reuses its
        buffer on the next ``initialize`` call.
    """
    SR_high_list = []

    for example_index in range(n_examples):

        high_lttb.initialize(high_par)
        high_lttb.y_targ = high_lttb.y_targ_collection[example_index]
        high_lttb.I_clock = high_lttb.I_clock_collection[example_index]

        for t in range(high_lttb.T-2):
            high_lttb.cont = np.array([0])
            high_lttb.step(apicalFactor = 1)

        SR_high_list.append(np.copy(high_lttb.B_filt_rec))

    return SR_high_list

def short_test_low(use_low_rec=False):
    """Evaluate the low network's readout on every training example.

    Each example is replayed with ``apicalFactor = 0``; the readout
    ``Jout @ spikes`` is compared with the target on the time window ``1:-2``.

    Parameters
    ----------
    use_low_rec : bool
        Read out from ``B_filt_rec`` instead of ``B_filt``.

    Returns
    -------
    numpy.ndarray
        Per-example error, computed as ``np.std(target - readout)**2``.
    """
    errors = np.zeros(n_examples)

    for k in range(n_examples):

        low_lttb.initialize(low_par)
        low_lttb.y_targ = low_lttb.y_targ_collection[k]
        low_lttb.I_clock = low_lttb.I_clock_collection[k]

        # Replay the example: "button" context before tb, "food" context after.
        for step in range(low_lttb.T-2):
            low_lttb.cont = np.array([1,0]) if step < tb else np.array([0,1])
            low_lttb.step(apicalFactor = 0)

        spikes = low_lttb.B_filt_rec if use_low_rec else low_lttb.B_filt
        readout = low_lttb.Jout@spikes[:,1:-2]
        residual = low_lttb.y_targ[:,1:-2] - readout

        errors[k] = np.std(residual)**2

    return errors

def short_test_high():
    """Evaluate the high network's readout on every training example.

    Each example is replayed with ``apicalFactor = 0`` and context ``[0]``; the
    readout ``Jout @ B_filt_rec + Bias`` is compared with the target on the
    time window ``1:-2``.

    Returns
    -------
    numpy.ndarray
        Per-example error, computed as ``np.std(target - readout)**2``.
    """
    errors = np.zeros(n_examples)

    for k in range(n_examples):

        high_lttb.initialize(high_par)
        high_lttb.y_targ = high_lttb.y_targ_collection[k]
        high_lttb.I_clock = high_lttb.I_clock_collection[k]

        # Replay the example with the constant high-level context.
        for _ in range(high_lttb.T-2):
            high_lttb.cont = np.array([0])
            high_lttb.step(apicalFactor = 0)

        spikes = high_lttb.B_filt_rec[:,1:-2]
        # Repeat the bias along time so it matches the readout window.
        bias = np.tile(high_lttb.Bias,(high_lttb.T-3,1)).T
        readout = high_lttb.Jout@spikes + bias
        residual = high_lttb.y_targ[:,1:-2] - readout

        errors[k] = np.std(residual)**2

    return errors

def training_low(nIterRec=100, test_every=5, eta=5., eta_out=0.01, eta_W=0., eta_bias=0.0002, verbose=True, use_low_rec=False):
    """Train the global `low_lttb` network online on all training examples.

    Every iteration sweeps over the `n_examples` examples; at each time step
    the readout weights `Jout` receive a delta-rule update and, when
    `use_low_rec` is set, the recurrent weights `J` are updated as well.
    Every `test_every` iterations the network is evaluated with
    `short_test_low` and the per-example errors are stored.

    Parameters
    ----------
    nIterRec : int
        Number of training iterations (sweeps over all examples).
    test_every : int
        Evaluate after every `test_every` iterations.
    eta : float
        Learning rate of the recurrent weights `J` (used only if `use_low_rec`).
    eta_out : float
        Learning rate of the readout weights `Jout`.
    eta_W, eta_bias : float
        Accepted but not used in this function.
    verbose : bool
        Show the mean test error in the progress-bar description.
    use_low_rec : bool
        Train `J` and read out from `B_filt_rec` instead of `B_filt`.

    Returns
    -------
    numpy.ndarray
        Array of shape (int(nIterRec/test_every), n_examples) with test errors.
    """
    
    ERRORS_low = np.zeros((int(nIterRec/test_every),n_examples))
    
    iterator = trange(nIterRec, desc = 'LTTB Low Network Training', leave = True)
    
    for iteration in iterator:
        
        #initialize simulation
    
        for example_index in range(n_examples):
            
            # NOTE(review): y_targ is assigned before initialize() here, but after
            # it in short_test_low — confirm initialize() does not reset y_targ.
            low_lttb.y_targ = low_lttb.y_targ_collection[example_index]
                
            #ON-LINE
                
            low_lttb.initialize(low_par)
            low_lttb.I_clock = low_lttb.I_clock_collection[example_index]
                
            #run simulation
            # low-pass filtered copy of S_filt, accumulated over time (recurrent rule only)
            dH = 0
                
            for t in range(low_lttb.T-2):
                    
                # context: first component active before tb, second one afterwards
                if(t<tb):
                    low_lttb.cont = np.array([1,0])
                else:
                    low_lttb.cont = np.array([0,1])
                low_lttb.step(apicalFactor = 1)
                
                if use_low_rec:
                    # recurrent update: outer product of a per-neuron error-like factor with dH
                    dH = dH*(1-dt/tau_m) + dt/tau_m*low_lttb.S_filt[:,t]
                    DJ = np.outer( ( low_lttb.S_apic_dist[:,t+1] - f(low_lttb.VapicRec[:,t],gamma) )*(1-low_lttb.S_apic_prox[:,t])*low_lttb.S_wind_soma[:,t+1] ,dH)
                    low_lttb.J =  low_lttb.J + eta*DJ
                    SR_low = low_lttb.B_filt_rec[:,t+1]
                else:
                    SR_low = low_lttb.B_filt[:,t+1]
                
                # readout update: delta rule on (target - readout) at time t+1
                Y_low = low_lttb.Jout@SR_low
                DJRO = np.outer(low_lttb.y_targ[:,t+1] - Y_low,SR_low.T)
                low_lttb.Jout = low_lttb.Jout + eta_out*DJRO
           
        ###### Test
        
        if (iteration+1)%test_every==0:
            
            mses = short_test_low(use_low_rec=use_low_rec)
            
            # row index of the current test within ERRORS_low
            ERRORS_low[int(iteration/test_every),:] = mses
            
            if verbose:
                msg = 'LTTB Low Network Training. <MSE>: %.4f' % np.mean(mses)
                iterator.set_description(msg)

    return ERRORS_low

def training_high(nIterRec=100, test_every=5, eta=5., eta_out=0.01, eta_W=0., eta_bias=0.0002, verbose=True):
    """Train the global `high_lttb` network online on all training examples.

    Every iteration sweeps over the `n_examples` examples; at each time step
    the recurrent weights `J`, the readout weights `Jout` and the readout
    `Bias` are updated. Every `test_every` iterations the network is
    evaluated with `short_test_high` and the per-example errors are stored.

    Parameters
    ----------
    nIterRec : int
        Number of training iterations (sweeps over all examples).
    test_every : int
        Evaluate after every `test_every` iterations.
    eta : float
        Learning rate of the recurrent weights `J`.
    eta_out : float
        Learning rate of the readout weights `Jout`.
    eta_W : float
        Accepted but not used in this function.
    eta_bias : float
        Learning rate of the readout `Bias`.
    verbose : bool
        Show the mean test error in the progress-bar description.

    Returns
    -------
    numpy.ndarray
        Array of shape (int(nIterRec/test_every), n_examples) with test errors.
    """
    
    ERRORS_high = np.zeros((int(nIterRec/test_every),n_examples))
    
    iterator = trange(nIterRec, desc = 'LTTB High Network Training', leave = True)
    
    for iteration in iterator:

    #initialize simulation
    
        for example_index in range(n_examples):
        
            # NOTE(review): y_targ is assigned before initialize() here, but after
            # it in short_test_high — confirm initialize() does not reset y_targ.
            high_lttb.y_targ = high_lttb.y_targ_collection[example_index]
        
            #ON-LINE
            
            high_lttb.initialize(high_par)
            high_lttb.I_clock = high_lttb.I_clock_collection[example_index]
            
            #run simulation
            # low-pass filtered copy of S_filt, accumulated over time
            dH = 0
            
            for t in range(high_lttb.T-2):
                
                # the high network always receives the constant context [0]
                high_lttb.cont = np.array([0])
                high_lttb.step(apicalFactor = 1)
                
                # recurrent update: outer product of a per-neuron error-like factor with dH
                dH = dH*(1-dt/tau_m) + dt/tau_m*high_lttb.S_filt[:,t]
                DJ = np.outer( ( high_lttb.S_apic_dist[:,t+1] - f(high_lttb.VapicRec[:,t],gamma) )*(1-high_lttb.S_apic_prox[:,t])*high_lttb.S_wind_soma[:,t+1] ,dH)
                
                high_lttb.J =  high_lttb.J + eta*DJ
                    
                # readout (with bias) at time t+1
                SR_high = high_lttb.B_filt_rec[:,t+1]
                Y_high = high_lttb.Jout@SR_high + high_lttb.Bias
                
                # delta-rule updates of readout weights and bias
                DJRO = np.outer(high_lttb.y_targ[:,t+1] - Y_high,SR_high.T)
                dBias = high_lttb.y_targ[:,t+1] - Y_high
                
                high_lttb.Jout = high_lttb.Jout + eta_out*DJRO
                high_lttb.Bias = high_lttb.Bias + eta_bias*dBias
        
        ###### Test
    
        if (iteration+1)%test_every==0:
        
            mses = short_test_high()
            
            # row index of the current test within ERRORS_high
            ERRORS_high[int(iteration/test_every),:] = mses
        
            if verbose:
                msg = 'LTTB High Network Training. <MSE>: %.4f' % np.mean(mses)
                iterator.set_description(msg)


    return ERRORS_high

def make_segments(x, y):
    """
    Build the segment array expected by LineCollection from x and y
    coordinates: consecutive points are paired, giving an array of shape
    (number of segments) x 2 (segment end points) x 2 (x and y).
    """

    vertices = np.transpose(np.array([x, y])).reshape(-1, 1, 2)
    starts = vertices[:-1]
    ends = vertices[1:]
    return np.hstack([starts, ends])

def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=3, alpha=1.0):
    """
    Draw the polyline (x, y) on the current axes as a multicoloured line and
    return the resulting LineCollection.

    `z` gives the colour values mapped through `cmap` and `norm`; it may be an
    array, a single number, or None (values equally spaced on [0, 1]).

    Adapted from:
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html
    """

    if z is None:
        # default colours: equally spaced on [0, 1]
        z = np.linspace(0.0, 1.0, len(x))
    elif not hasattr(z, "__iter__"):
        # a single number was given; wrap it so it behaves like an array
        z = np.array([z])

    colour_values = np.asarray(z)

    lc = mcoll.LineCollection(
        make_segments(x, y),
        array=colour_values,
        cmap=cmap,
        norm=norm,
        linewidth=linewidth,
        alpha=alpha,
    )

    plt.gca().add_collection(lc)

    return lc

def make_test(btn, targ, max_T=None, show_plot=True, use_HighNetwork=False, verbose=True, closed_loop=True, cont_index=0):
    """Run one button-then-food episode with the trained networks.

    The agent starts at the module-level ``init`` position. At every step the
    low network (``low_lttb``) produces an action that is fed to ``env``. The
    context given to the low network is either derived from whether the
    button has been reached (``use_HighNetwork=False``) or decided by the
    readout of the high network (``use_HighNetwork=True``).

    Fixes with respect to the previous version:

    * ``max_T`` is now an argument. Every call in this notebook already passes
      it (third positional argument or ``max_T=...``), which raised a
      ``TypeError`` with the old signature.
    * The open-loop branch referenced the undefined names ``lttb`` and
      ``cont_index``; it now uses ``low_lttb`` and the ``cont_index`` argument.
    * The verbose output printed the global ``btn_theta``; the angle is now
      computed from ``btn`` itself.

    Parameters
    ----------
    btn : sequence of 2 floats
        Button position.
    targ : numpy.ndarray
        Food (target) position.
    max_T : int, optional
        Maximum episode length; defaults to ``env_par['max_T']``.
    show_plot : bool
        Plot the trajectory with the button (red) and food (black) regions.
    use_HighNetwork : bool
        Let ``high_lttb`` select the context of the low network.
    verbose : bool
        Print a short summary of the episode.
    closed_loop : bool
        If True (default) the networks receive the current environment state
        as clock input; otherwise the recorded clock input of training example
        ``cont_index`` is replayed (only meaningful while ``t+1`` stays within
        the length of that recording).
    cont_index : int
        Training example replayed when ``closed_loop`` is False.

    Returns
    -------
    dict
        Keys: ``'POSITION'``, ``'ACTIONS'``, ``'ACTIONS_HIGH'`` (only with the
        high network), ``'context'``, ``'S_soma'``, ``'S_wind'``, ``'B_filt'``,
        ``'R'``, ``'R_timeResc'``, ``'min_btn_dist'``, ``'min_targ_dist'``,
        ``'t_reach_btn'``, ``'t_reach_targ'``.

    Notes
    -----
    Side effects: sets ``low_par['T']`` / ``low_lttb.T`` (and the high-network
    counterparts when used) to ``max_T`` and re-initialises the networks.
    """

    if max_T is None:
        max_T = env_par['max_T']

    r_lock, r_targ = env_par['r_lock'], env_par['r_targ']
    cont_noise = 0.0

    def noisy_context(button_pressed):
        # "reach the button" is [1,0], "reach the food" is [0,1]
        base = np.array([0,1]) if button_pressed else np.array([1,0])
        return base + np.random.normal(loc=0, scale=cont_noise, size=2)

    # "not reached" is encoded as max_T
    t_reach_btn = max_T
    t_reach_targ = max_T

    # Resize and reset the low network for an episode of length max_T.
    low_par['T'] = max_T
    low_lttb.T = max_T
    low_lttb.init_clock(low_par)
    low_lttb.initialize(low_par)
    low_lttb.y_targ = np.zeros((low_lttb.O,low_lttb.T))

    if use_HighNetwork:
        high_par['T'] = max_T
        high_lttb.T = max_T
        high_lttb.init_clock(high_par)
        high_lttb.initialize(high_par)
        high_lttb.y_targ = np.zeros((high_lttb.O,high_lttb.T))

    env.reset (init=init, targ=targ, btn=btn)

    agen = init
    R = 0

    ACTIONS = []
    ACTIONS_HIGH = []
    POSITION = []
    context = []

    # Initial input: encoded displacement to the food and to the button.
    state = np.hstack ((env.encode (targ - init), env.encode (btn - init)))

    low_lttb.I_clock[:,0] = state
    low_lttb.cont = np.array([1,0])

    if use_HighNetwork:
        high_lttb.I_clock[:,0] = state
        high_lttb.cont = np.array([0])

    btn_on = False

    for t in range (max_T-2):

        # --- choose the context of the low network ---
        if use_HighNetwork:
            Y_high = high_lttb.Jout@high_lttb.B_filt_rec[:,t] + high_lttb.Bias
            if Y_high[0] > Y_high[1]:
                low_lttb.cont = np.array([1,0])
            else:
                low_lttb.cont = np.array([0,1])
        else:
            low_lttb.cont = noisy_context(btn_on)

        # --- feed the clock input and track the button ---
        if closed_loop:
            low_lttb.I_clock[:,low_lttb.t+1] = state
            if use_HighNetwork:
                high_lttb.I_clock[:,high_lttb.t+1] = state

            if not btn_on and np.sqrt( np.sum((agen - btn)**2) ) < r_lock:
                btn_on = True
                t_reach_btn = t
                if not use_HighNetwork:
                    # switch to the "food" context right away
                    low_lttb.cont = noisy_context(True)
        else:
            low_lttb.I_clock[:,low_lttb.t+1] = low_lttb.I_clock_collection[cont_index][:,low_lttb.t+1]
            if use_HighNetwork:
                high_lttb.I_clock[:,high_lttb.t+1] = high_lttb.I_clock_collection[cont_index][:,high_lttb.t+1]
            else:
                # open loop: the context follows the training schedule
                low_lttb.cont = noisy_context(t >= tb)

        context.append(low_lttb.cont)

        # --- advance the networks and the environment ---
        low_lttb.step (apicalFactor = 0)
        if use_HighNetwork:
            high_lttb.step (apicalFactor = 0)
            ACTIONS_HIGH.append(high_lttb.Jout@high_lttb.B_filt_rec[:,t] + high_lttb.Bias)
        action = low_lttb.Jout@low_lttb.B_filt[:,t]
        ACTIONS.append(action)

        state, r, done, agen = env.step (action)

        POSITION.append( list(agen) )

        R = max (R, r)

        # The episode ends when the food is reached after pressing the button.
        if np.sqrt( np.sum((agen - targ)**2) )  < r_targ and btn_on:
            done = True
            t_reach_targ = t

        if done:
            break

    positions = np.array(POSITION)
    min_btn_dist = np.sqrt( np.sum((positions - btn)**2, axis=1) ).min()
    min_targ_dist = np.sqrt( np.sum((positions - targ)**2, axis=1) ).min()

    if verbose:
        # angle of the button as seen from the origin, in [0, 2*pi)
        btn_angle = np.arctan2(btn[1], btn[0]) % (2*np.pi)
        print('btn_theta = %.3f' % btn_angle)
        print('R = %.3f' % R)
        print('min_btn_dist = %.3f' % min_btn_dist)
        print('min_food_dist = %.3f' % min_targ_dist)
        print('t_reach = [%d,%d]' % (t_reach_btn,t_reach_targ))

    if show_plot:
        circle = np.linspace(0,2*np.pi,100)
        plt.figure()
        plt.scatter(0,0,marker='X',color='black')
        plt.plot(positions[:,0],positions[:,1],'-')
        plt.plot(targ[0]+r_targ*np.cos(circle),targ[1]+r_targ*np.sin(circle),'k-')
        plt.plot(btn[0]+r_lock*np.cos(circle),btn[1]+r_lock*np.sin(circle),'r-')
        plt.xlim(-2.2,2.2)
        plt.ylim(-2.2,2.2)
        plt.show()

    res = {}
    res['POSITION'] = POSITION
    res['ACTIONS'] = ACTIONS
    if use_HighNetwork:
        res['ACTIONS_HIGH'] = ACTIONS_HIGH
    res['context'] = context
    res['S_soma'] = low_lttb.S_soma
    res['S_wind'] = low_lttb.S_wind
    res['B_filt'] = low_lttb.B_filt
    res['R'] = R
    # Reward discounted exponentially when the food is reached later than the training duration.
    res['R_timeResc'] = R * (1. if t_reach_targ<total_training_T else np.exp(-(t_reach_targ-total_training_T)/total_training_T))
    res['min_btn_dist'] = min_btn_dist
    res['min_targ_dist'] = min_targ_dist
    res['t_reach_btn'] = t_reach_btn
    res['t_reach_targ'] = t_reach_targ

    return res

def mean_test(nn):
    """Average the outcome of `nn` closed-loop episodes with random geometry.

    For each episode the button is placed at a random angle on the unit circle
    and the food at a random angle on a unit circle around the button; the
    episode is run through `make_test` with the high network enabled.

    Returns
    -------
    tuple
        (mean reward, mean minimum button distance, mean minimum food
        distance, fraction of episodes with R > 0, fraction with R == 1).
    """
    rewards = []
    min_btn = []
    min_targ = []

    for _ in range(nn):

        # same draw order as before: button angle first, then food angle
        btn_theta = random.random()*2*np.pi
        targ_theta = random.random()*2*np.pi

        btn_dist = targ_dist = 1.0

        btn_test = (btn_dist * np.cos (btn_theta), btn_dist * np.sin (btn_theta))
        offset = np.array((btn_dist * np.cos (targ_theta), btn_dist * np.sin (targ_theta)))
        targ_test = np.array(btn_test) + offset

        res = make_test(btn_test, targ_test, max_T, show_plot=False, use_HighNetwork=True, verbose=False)

        rewards.append(res['R'])
        min_btn.append(res['min_btn_dist'])
        min_targ.append(res['min_targ_dist'])

    n_runs = len(rewards)
    button_rate = sum(1 for r in rewards if r > 0)/n_runs
    food_rate = sum(1 for r in rewards if r == 1)/n_runs

    return np.mean(rewards), np.mean(min_btn), np.mean(min_targ), button_rate, food_rate

def mean_of_vector(vec):
    """Element-wise mean and standard deviation of sequences of unequal length.

    Position ``l`` of the result is computed only from the sequences that are
    at least ``l+1`` long.

    Parameters
    ----------
    vec : sequence of 1-D sequences of numbers
        May be empty; the inner sequences may have different lengths.

    Returns
    -------
    m : numpy.ndarray
        Mean at every position; its length is that of the longest sequence.
    s : numpy.ndarray
        Population standard deviation at every position.
    """
    # An empty collection has no longest sequence: return empty statistics.
    if len(vec) == 0:
        return np.zeros(0), np.zeros(0)

    L = max(len(v) for v in vec)
    m = np.zeros(L)
    m2 = np.zeros(L)
    count = np.zeros(L)

    for vv in vec:
        values = np.asarray(vv, dtype=float)
        n = len(values)
        m[:n] += values
        m2[:n] += values*values
        count[:n] += 1

    seen = count > 0
    m[seen] /= count[seen]
    m2[seen] /= count[seen]

    # E[x^2] - E[x]^2 can come out slightly negative through rounding when the
    # values are (almost) identical; clamp it so the square root is never NaN.
    s = np.sqrt(np.maximum(m2 - m*m, 0.))

    return m,s

### Model Initialization & training

In [None]:
# Multiply the learning rates by factor_eta after every training epoch.
rescale_eta = True
# Decay factor applied to the learning rates when rescale_eta is True.
factor_eta = 0.9
# Passed to training_low: when True the low network also trains its recurrent
# weights and reads out from B_filt_rec instead of B_filt.
use_low_rec = False

#### Low network

In [None]:
# Build the low network and attach the expert data of every training example:
# element [1] of each expert record is the target output, element [0] the clock input.
low_lttb = lttb_module.LTTB(low_par)

low_lttb.y_targ_collection = [train_exp[k][1] for k in range(n_examples)]
low_lttb.I_clock_collection = [train_exp[k][0] for k in range(n_examples)]

# Mask the context weights: rows of the first half of Ne keep only context 0,
# rows of the second half keep only context 1, rows from Ne onwards get none.
low_lttb.j_apical_cont[:int(low_par['Ne']/2), 1] = 0
low_lttb.j_apical_cont[int(low_par['Ne']/2):, 0] = 0
low_lttb.j_apical_cont[low_par['Ne']:, :] = 0

# Containers filled during training.
ERRORS_low, Rs_training = [], []

# Initial learning rates, taken from the configuration.
eta_low, eta_out_low = low_par['eta'], low_par['eta_out']

In [None]:
# 3 epochs of 100 iterations each
# eta = 0
# eta_out = 0.03
# after each epoch, learning rates are diminished by factor_eta

In [None]:
nEpochs, nIterRec = 3, 100

# Restore the training duration before (re)initialising the low network.
low_lttb.T = low_par['T'] = total_training_T
low_lttb.initialize(low_par)

for epoch in range(nEpochs):
    ERRORS_low.extend(
        training_low(
            nIterRec=nIterRec,
            eta=eta_low,
            eta_out=eta_out_low,
            eta_bias=0,
            use_low_rec=use_low_rec,
        )
    )
    # Shrink both learning rates for the next epoch.
    if rescale_eta:
        eta_low, eta_out_low = eta_low*factor_eta, eta_out_low*factor_eta

#### High network

In [None]:
high_lttb = lttb_module.LTTB(high_par)

# Target of the high network: a two-row step signal, row 0 is on for t < tb
# and row 1 is on from tb onwards, over tb+tf time steps.
high_targ = np.array([np.arange(tb+tf) < tb, np.arange(tb+tf) >= tb], dtype=float)

# Every example shares the same target; the clock input is the one used for
# the low network.
high_lttb.y_targ_collection = [high_targ for k in range(n_examples)]
high_lttb.I_clock_collection = [train_exp[k][0] for k in range(n_examples)]

# The high network receives no context input.
high_lttb.j_apical_cont[:,:] = 0

ERRORS_high = []

# Initial learning rates, taken from the configuration.
eta_high, eta_out_high = high_par['eta'], high_par['eta_out']

In [None]:
# 2 epochs of 100 iterations each
# eta = 0.5
# eta_out = 0.03
# after each epoch, learning rates are diminished by factor_eta

In [None]:
nEpochs, nIterRec = 2, 100

# Restore the training duration before (re)initialising the high network.
high_lttb.T = high_par['T'] = total_training_T
high_lttb.initialize(high_par)

for epoch in range(nEpochs):
    ERRORS_high.extend(
        training_high(
            nIterRec=nIterRec,
            eta=eta_high,
            eta_out=eta_out_high,
            eta_bias=0.0002,
        )
    )
    # Shrink both learning rates for the next epoch.
    if rescale_eta:
        eta_high, eta_out_high = eta_high*factor_eta, eta_out_high*factor_eta

#### Check the training

In [None]:
# Training curves of both networks. Top row: low network, bottom row: high
# network. Left: error of every example, right: error averaged over examples.
TEST_EVERY = 5  # default `test_every` of training_low / training_high, used above

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))
# The figure shows both networks, so the title must not read 'LOW' only.
fig.suptitle('Training error - LOW (top) and HIGH (bottom)')

for row, (label, errors) in enumerate([('LOW', ERRORS_low), ('HIGH', ERRORS_high)]):
    # iteration at which each stored test was performed
    iterations = TEST_EVERY*np.arange(1, 1+len(errors))

    axes[row, 0].plot(iterations, errors)
    axes[row, 0].set_xlabel("iteration")
    axes[row, 0].set_ylabel("mse - " + label)

    axes[row, 1].plot(iterations, [np.mean(per_example) for per_example in errors], marker='.')
    axes[row, 1].set_xlabel("iteration")
    axes[row, 1].set_ylabel("aver. mse - " + label)

plt.show()

In [None]:
# Replay one training example through the trained low network and compare its
# readout with the target.
example_index = 4

low_lttb.T = low_par['T'] = total_training_T
low_lttb.initialize(low_par)
low_lttb.y_targ = low_lttb.y_targ_collection[example_index]
low_lttb.I_clock = low_lttb.I_clock_collection[example_index]

# Run the network without the apical (teaching) signal; the context switches at tb.
for t in range(low_lttb.T-2):
    low_lttb.cont = np.array([1,0]) if t < tb else np.array([0,1])
    low_lttb.step(apicalFactor = 0)

SR_low = low_lttb.B_filt[:,1:-2]
Y_low = low_lttb.Jout@SR_low
mse_ro_train_low = np.std(low_lttb.y_targ[:,1:-2] - Y_low)**2
print(mse_ro_train_low)

# Target first, readout second, so the default colour cycle pairs them up.
plt.plot(low_lttb.y_targ[:,1:-2].T)
plt.plot(Y_low.T)
plt.show()

#### Simple closed-loop test

In [None]:
n_session = 1

# Random episode geometry: button on the unit circle, food on a unit circle
# centred on the button. The button angle is drawn first, then the food angle.
btn_theta = random.random()*2*np.pi
targ_theta = random.random()*2*np.pi

btn_dist = targ_dist = 1.0

btn_test = tuple(btn_dist * trig (btn_theta) for trig in (np.cos, np.sin))
targ_test = np.array(btn_test) + btn_dist * np.array((np.cos (targ_theta), np.sin (targ_theta)))

# Closed-loop run in which the high network selects the context.
res = make_test(btn_test, targ_test, max_T, show_plot=False, use_HighNetwork=True)

In [None]:
# Figure: agent trajectory of the last closed-loop test, coloured by time, with
# a time colorbar and a side panel showing the context signals.
fs = 12
cm = 1/2.54  # centimeters in inches
final_dpi = 500

# trajectory coordinates
X = np.array(res['POSITION'])[:,0]
Y = np.array(res['POSITION'])[:,1]

# clamp "never reached" (== max_T) to the index of the last simulated step
t_reach_btn = min(res['t_reach_btn'],max_T-3)
t_reach_targ = min(res['t_reach_targ'],max_T-3)

# square bounding box containing the trajectory and both regions, enlarged by 1.5
x_min = min([np.min(X),targ_test[0]-env_par['r_targ'],btn_test[0]-env_par['r_lock']])
x_max = max([np.max(X),targ_test[0]+env_par['r_targ'],btn_test[0]+env_par['r_lock']])
dX = x_max-x_min
y_min = min([np.min(Y),targ_test[1]-env_par['r_targ'],btn_test[1]-env_par['r_lock']])
y_max = max([np.max(Y),targ_test[1]+env_par['r_targ'],btn_test[1]+env_par['r_lock']])
dY = y_max-y_min
d = 1.5*max(dX,dY)

fig, ax = plt.subplots()

# interpolate the path so that the colour gradient along the line is smoother
path = mpath.Path(np.column_stack([X, Y]))
verts = path.interpolated(steps=3).vertices
x, y = verts[:, 0], verts[:, 1]
z = np.linspace(0, 1, len(x))
colorline(x, y, z, cmap=plt.get_cmap('gnuplot'), linewidth=2)
ax.set_xlim([0.5*(x_min+x_max)-0.5*d,0.5*(x_min+x_max)+0.5*d])
ax.set_ylim([0.5*(y_min+y_max)-0.5*d,0.5*(y_min+y_max)+0.5*d])
# start position (black cross), food region (black circle), button region (red circle)
ax.scatter(0,0,marker='X',color='black',s=100)
ax.plot(targ_test[0]+env_par['r_targ']*np.cos(np.linspace(0,2*np.pi,100)),targ_test[1]+env_par['r_targ']*np.sin(np.linspace(0,2*np.pi,100)),'k-')
ax.plot(btn_test[0]+env_par['r_lock']*np.cos(np.linspace(0,2*np.pi,100)),btn_test[1]+env_par['r_lock']*np.sin(np.linspace(0,2*np.pi,100)),'r-')
#ax.scatter([1*np.cos(t) for t in np.linspace(0,2*np.pi,100)],[1*np.sin(t) for t in np.linspace(0,2*np.pi,100)], color='grey', marker='.', s=5)
#ax.scatter([btn_test[0]+1*np.cos(t) for t in np.linspace(0,2*np.pi,100)],[btn_test[1]+1*np.sin(t) for t in np.linspace(0,2*np.pi,100)], color='grey', marker='.', s=5)
ax.set_xlabel('x', fontsize=16)
ax.set_ylabel('y', fontsize=16)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.5))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

# button-reach time as a fraction of the food-reach time (tick position on the time axes)
t1 = t_reach_btn/t_reach_targ
# invisible image whose only purpose is to provide a mappable for the colorbar
img = plt.imshow(np.array([[0,1]]), cmap=plt.get_cmap('gnuplot'))
img.set_visible(False)
axins = inset_axes(ax, width="5%", height="100%", bbox_to_anchor=(1.05, 0.00, 1, 1), bbox_transform=ax.transAxes, loc=3, borderpad=0)
cbar = fig.colorbar(img, cax=axins, ticks=[0,t1,1])
cbar.ax.set_yticklabels(['','','']) # [r'$t_0=0$', r'$t_1=%d$' % res['t_reach_btn'], r'$t_2=%d$' % res['t_reach_targ']])
cbar.set_label('time', rotation=90, fontsize=16, labelpad=100)

# side panel: context signals (x axis) against normalised time (y axis)
axins2 = inset_axes(ax, width="20%", height="100%", bbox_to_anchor=(1.15, 0.00, 1, 1), bbox_transform=ax.transAxes, loc=3, borderpad=0)
#cbar = fig.colorbar(img, cax=axins, ticks=[0,t1,1])
#cbar.ax.set_yticklabels([r'$t_0=0$', r'$t_1=%d$' % res['t_reach_btn'], r'$t_2=%d$' % res['t_reach_targ']])
#cbar.set_label('time', rotation=270, fontsize=16, labelpad=18)
axins2.plot()
# NOTE(review): time is normalised by t_reach_targ here but by t_reach_targ+1
# for the high-network outputs below — confirm which one is intended.
axins2.plot(res['context'],[_/t_reach_targ for _ in range(t_reach_targ+1)])
if('ACTIONS_HIGH' in res.keys()):
    for n in range(2):
        axins2.plot(np.array(res['ACTIONS_HIGH']).T[n],[_/(t_reach_targ+1) for _ in range(t_reach_targ+1)], color='C'+str(n))
# flip the x axis of the side panel
a,b = axins2.get_xlim()
axins2.set_xlim([b,a])
axins2.set_ylim([0,1])
axins2.set_yticks([0,t1,1])
axins2.set_yticklabels([r'$t_0=0$', r'$t_1=%d$' % res['t_reach_btn'], r'$t_2=%d$' % res['t_reach_targ']])
axins2.yaxis.tick_right()
axins2.set_xlabel('context', fontsize=16)

fig.tight_layout()
fig.subplots_adjust(top=0.96, bottom=0.13, left=-0.06, right=0.80)
# save the figure in every format; the target directory must already exist
fig_title = 'BF_apical_n%03d' % n_session
for ext in ['eps','pdf','png']:
    fig.savefig('./figures/Fig_3/' + fig_title + '.' + ext, dpi=final_dpi, transparent=False)
plt.show()

In [None]:
# Figure: outputs of the high network (top panel) and of the low network
# (bottom panel) against normalised time for the last closed-loop test.
# Uses t_reach_targ and t1 computed in the trajectory-figure cell.
fs = 12
cm = 1/2.54  # centimeters in inches
final_dpi = 500
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12*cm, 12*cm))

# common axes styling: thin left/bottom spines, no top/right spines
for ax in axes:
    ax.tick_params(axis='both', which='major', labelsize=fs, pad=1)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_linewidth(0.5)
    ax.spines['bottom'].set_linewidth(0.5)
    ax.xaxis.set_tick_params(width=0.5)
    ax.yaxis.set_tick_params(width=0.5)

# --- top panel: the two readout components of the high network ---
ax = axes[0]
lb = [r'$\to$ button',r'$\to$ food']
clr = ['magenta','green']
if('ACTIONS_HIGH' in res.keys()):
    for n in range(2):
        ax.plot([_/(t_reach_targ+1) for _ in range(t_reach_targ+1)], np.array(res['ACTIONS_HIGH']).T[n], color=clr[n], label=lb[n])
    #ax.plot([_/t_reach_targ for _ in range(t_reach_targ+1)], res['context'])
ax.set_ylabel('', fontsize=fs)
#a,b = ax.get_ylim()
#ax.set_ylim([a,b])
ax.set_xlim([0,1])
ax.set_ylim([-0.05,1.55])
ax.set_yticks([0,0.5,1.0])
ax.set_xticks([0,t1,1])
ax.set_xticklabels([])
# y label drawn as free text so that both panels can share the same offset
ax.text(-0.2, 0.5, 'output high', fontsize=fs, ha='center', va='center', rotation=90, \
        transform=ax.transAxes, rotation_mode='anchor')
#ax.text(0.5, 1.1, 'context\n(target selection)', fontsize=fs, ha='center', va='top', rotation=0, \
#        transform=ax.transAxes, rotation_mode='anchor')

# legend of the top panel, anchored in that panel's axes coordinates
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(
    handles,
    labels,
    borderaxespad=0.00,    # Small spacing around legend box
    title="context (target selection)",  # Title for the legend
    ncol=2,
    bbox_to_anchor=(0.42,1.05),
    bbox_transform=ax.transAxes,
    loc="upper center",   # Position of legend
    frameon=False,
    fontsize=fs,
    facecolor=None,
    edgecolor='black',
    handletextpad=0.5
)
plt.setp(legend.get_title(), multialignment='center', fontsize=fs)

# --- bottom panel: the two action components produced by the low network ---
ax = axes[1]
ax.plot([_/(t_reach_targ+1) for _ in range(t_reach_targ+1)], np.array(res['ACTIONS'])[:,0], color=clr[0], label=r'$v_x$')
ax.plot([_/(t_reach_targ+1) for _ in range(t_reach_targ+1)], np.array(res['ACTIONS'])[:,1], color=clr[1], label=r'$v_y$')
ax.set_xlim([0,1])
ax.set_xticks([0,t1,1])
ax.set_ylim([-1.5,3.5])
#ax.set_xticklabels([r'$t_0=0$', r'$t_1=%d$' % res['t_reach_btn'], r'$t_2=%d$' % res['t_reach_targ']])
ax.set_xticklabels([r'$0$', r'$t_{\rm button}$', r'$t_{\rm food}$'])
ax.xaxis.set_tick_params(pad=5)
ax.text(-0.2, 0.5, 'output low', fontsize=fs, ha='center', va='center', rotation=90, \
        transform=ax.transAxes, rotation_mode='anchor')
#ax.text(0.5, 1.1, 'velocities\n(target actuation)', fontsize=fs, ha='center', va='top', rotation=0, \
#        transform=ax.transAxes, rotation_mode='anchor')

# legend of the bottom panel, anchored in that panel's axes coordinates
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(
    handles,
    labels,
    borderaxespad=0.00,    # Small spacing around legend box
    title="velocities (target actuation)",  # Title for the legend
    ncol=2,
    bbox_to_anchor=(0.42,1.05),
    bbox_transform=ax.transAxes,
    loc="upper center",   # Position of legend
    frameon=False,
    fontsize=fs,
    facecolor=None,
    edgecolor='black',
    handletextpad=0.5
)
plt.setp(legend.get_title(), multialignment='center', fontsize=fs)

plt.tight_layout()
plt.subplots_adjust(left=0.20, bottom=0.10, right=0.93, top=0.95, wspace=None, hspace=0.15)
# save the figure in every format; the target directory must already exist
fig_title = 'BF_apical_n%03d_actions' % n_session
for ext in ['eps','pdf','png']:
    fig.savefig('./figures/Fig_3/' + fig_title + '.' + ext, dpi=final_dpi, transparent=False)

plt.show()

#### Many closed-loop tests

In [None]:
# Closed-loop evaluation: sweep the button position over a full circle and,
# for each button angle, place the food at a random direction from the button.
# The per-trial results returned by make_test are collected in the lists below.
Rs = []
Rs_timeResc = []
btn_dists = []
targ_dists = []

# Button angles in DEGREES: this array is also the x-axis of the diagnostic
# plots (which span [0, 360]), so it is kept in degrees and converted to
# radians only where trigonometric functions are evaluated.
th = np.linspace(0,360,100,endpoint=False)

# Distance of the button from the origin and of the food from the button.
btn_dist = 1.0
targ_dist = 1.0

for btn_theta in th:

    # np.cos / np.sin expect radians; feeding them degrees scattered the
    # button over unrelated angles instead of sweeping the circle in order.
    btn_theta_rad = np.deg2rad(btn_theta)

    # Direction of the food relative to the button (already in radians).
    targ_theta = random.random()*2*np.pi

    btn_test = (btn_dist * np.cos(btn_theta_rad), btn_dist * np.sin(btn_theta_rad))
    # The food offset is scaled by targ_dist (previously btn_dist was used,
    # leaving targ_dist unused; both are 1.0, so the geometry is unchanged).
    targ_test = np.array(btn_test) + np.array((targ_dist * np.cos(targ_theta), targ_dist * np.sin(targ_theta)))

    res = make_test(btn_test, targ_test, max_T=1500, show_plot=False, use_HighNetwork=True, verbose=False)

    Rs.append(res['R'])
    Rs_timeResc.append(res['R_timeResc'])
    btn_dists.append(res['min_btn_dist'])
    targ_dists.append(res['min_targ_dist'])

In [None]:
# Diagnostics of the closed-loop sweep: reward vs. button angle, minimum
# distances vs. button angle, and the distribution of the rewards.
mean_R = np.mean(Rs)

# --- reward as a function of the button angle ---
plt.title('contesto sopra')
for curve, curve_clr, curve_lbl in ((Rs, 'C0', 'R'), (Rs_timeResc, 'C1', 'R(t)')):
    plt.plot(th, curve, marker='.', color=curve_clr, label=curve_lbl)
# Horizontal dashed line at the mean reward.
plt.plot(th, [mean_R for _ in th], ls='--', color='red', label=None)
plt.ylabel('R')
plt.legend(loc='upper left')
plt.show()

# --- minimum distance reached from button and food ---
plt.title('contesto sopra - n_examples = %d' % n_examples)
plt.plot(th, btn_dists, marker='.', color='red', label='button')
plt.plot(th, targ_dists, marker='.', color='black', label='food')
plt.plot([0,360], [env_par['r_lock'], env_par['r_lock']], ls='--', color='blue')
# Thin vertical guides at the training angles (converted from radians to degrees).
for theta in train_thetas:
    theta_deg = theta/2/np.pi*360
    plt.plot([theta_deg, theta_deg], [0,1], ls='--', color='gray', lw=0.25)
plt.xlabel('angle of button')
plt.ylabel('min dist from btn/food')
plt.legend(loc='upper left')
plt.show()

# --- normalised histogram of the rewards ---
bins = np.linspace(0,12,12, endpoint=False)
h,b = np.histogram(Rs, bins=bins, density=True)
x = bins[:-1]
plt.bar(x, h, width=0.5, color='C0', label='apical context')
# Vertical dashed line at the mean reward.
plt.plot([mean_R, mean_R], [0,1], ls='--', color='C0')
plt.xlabel('R')
plt.ylabel('density')
plt.legend(loc='upper right')
plt.ylim([0,1.0])
plt.show()

#### Train the low and high networks together, testing along the way

In [None]:
# Index of this training session; it is formatted into the names of the files
# written by the notebook (e.g. the JSON file with the training results).
n_session = 1

# Selects which context projection of the low network is kept when the
# network is initialised: the other one is set to zero.
context_type = 'apical' # can be 'apical' or 'basal'

In [None]:
# init the low network

low_lttb = lttb_module.LTTB(low_par)

# For every training example, element 1 of train_exp[k] is stored as target
# and element 0 as clock input.
low_lttb.y_targ_collection = []
low_lttb.I_clock_collection = []

for k in range(n_examples):
    low_lttb.y_targ_collection.append(train_exp[k][1])
    low_lttb.I_clock_collection.append(train_exp[k][0])

# Pick the context projection that stays active; the other one is silenced.
# An unknown context type is a configuration error: fail loudly instead of
# printing a message and training with both projections left untouched.
if context_type == 'apical':
    active_cont, silent_cont = low_lttb.j_apical_cont, low_lttb.j_basal_cont
elif context_type == 'basal':
    active_cont, silent_cont = low_lttb.j_basal_cont, low_lttb.j_apical_cont
else:
    raise ValueError(
        "Wrong type of context selected: %r (expected 'apical' or 'basal')" % (context_type,)
    )

ne_low = low_par['Ne']
half_ne_low = int(ne_low/2)

# Rows [0, Ne/2) keep only context column 0, rows [Ne/2, Ne) keep only
# context column 1, and rows from Ne onwards receive no context at all.
active_cont[0:half_ne_low,1] = 0
active_cont[half_ne_low:,0] = 0
active_cont[ne_low:,:] = 0
silent_cont[:,:] = 0

In [None]:
# init the high network

high_lttb = lttb_module.LTTB(high_par)

# Two-row step target over tb+tf time steps: row 0 is 1 for t < tb and 0
# afterwards, row 1 is the complement.
high_targ = np.zeros((2, tb+tf))
high_targ[0, :tb] = 1.
high_targ[1, tb:] = 1.

# Every example shares the same target array; the clock inputs are the
# same ones used for the low network.
high_lttb.y_targ_collection = [high_targ for _ in range(n_examples)]
high_lttb.I_clock_collection = [train_exp[k][0] for k in range(n_examples)]

# No apical context for the high network.
high_lttb.j_apical_cont[:,:] = 0

In [None]:
# init other variables

# Set the simulation length of both networks to the full training duration
# and (re-)initialise them with the updated parameters.
for net, net_par in ((low_lttb, low_par), (high_lttb, high_par)):
    net.T = total_training_T
    net_par['T'] = total_training_T
    net.initialize(net_par)

# One list per tracked quantity; the key order is the order in which the
# results are later serialised.
training_results = {
    metric: []
    for metric in (
        'ERRORS_low',
        'ERRORS_high',
        'R_mean',
        'button_rate',
        'food_rate',
        'btn_dist_mean',
        'targ_dist_mean',
    )
}

# Learning rates handed to training_low / training_high as `eta` and `eta_out`.
eta_low, eta_out_low = 0., 0.05
eta_high, eta_out_high = 0.2, 0.01

# Multiplicative decay applied to the learning rates after every epoch.
reduce_eta = 0.95

# Baseline measurements, taken before any training step.
mse_low = short_test_low()
mse_high = short_test_high()
R_mean, btn_dist_mean, targ_dist_mean, button_rate, food_rate = mean_test(50)

print('Before training...')
for report_fmt, report_val in (
    (' R_mean: %f', R_mean),
    (' average mse low: %f', np.mean(mse_low)),
    (' average mse high: %f', np.mean(mse_high)),
    (' button rate: %f', button_rate),
    (' food rate: %f', food_rate),
    (' average button dist: %f', btn_dist_mean),
    (' average food dist: %f', targ_dist_mean),
):
    print(report_fmt % report_val)
print()

# Store the baseline as the first entry of every history.
for metric, metric_val in (
    ('R_mean', R_mean),
    ('button_rate', button_rate),
    ('food_rate', food_rate),
    ('btn_dist_mean', btn_dist_mean),
    ('targ_dist_mean', targ_dist_mean),
):
    training_results[metric].append(metric_val)
# The baseline errors are wrapped in an extra dimension so that `extend`
# adds them as a single entry.
training_results['ERRORS_low'].extend(np.array([mse_low]))
training_results['ERRORS_high'].extend(np.array([mse_high]))

In [None]:
# train the two networks together

nEpochs = 100
nIterRec = 10

for epoch in range(nEpochs):

    # Both networks are re-initialised for the full training duration at the
    # beginning of every epoch.
    for net, net_par in ((low_lttb, low_par), (high_lttb, high_par)):
        net.T = total_training_T
        net_par['T'] = total_training_T
        net.initialize(net_par)

    # Low network first, then decay its learning rates.
    mse_low = training_low(
        nIterRec=nIterRec, test_every=nIterRec, eta=eta_low,
        eta_out=eta_out_low, eta_bias=0, use_low_rec=use_low_rec,
    )
    eta_low *= reduce_eta
    eta_out_low *= reduce_eta

    # Then the high network, with the same decay schedule.
    mse_high = training_high(
        nIterRec=nIterRec, test_every=nIterRec, eta=eta_high,
        eta_out=eta_out_high, eta_bias=0.0002,
    )
    eta_high *= reduce_eta
    eta_out_high *= reduce_eta

    # Closed-loop evaluation of the pair of networks after this epoch.
    R_mean, btn_dist_mean, targ_dist_mean, button_rate, food_rate = mean_test(50)

    print('Epoch: %d' % epoch)
    for report_fmt, report_val in (
        (' R_mean: %f', R_mean),
        (' average mse low: %f', np.mean(mse_low)),
        (' average mse high: %f', np.mean(mse_high)),
        (' button_rate: %f', button_rate),
        (' food_rate: %f', food_rate),
        (' average button dist: %f', btn_dist_mean),
        (' average food dist: %f', targ_dist_mean),
    ):
        print(report_fmt % report_val)
    print()

    # Scalar metrics get one new entry per epoch ...
    for metric, metric_val in (
        ('R_mean', R_mean),
        ('button_rate', button_rate),
        ('food_rate', food_rate),
        ('btn_dist_mean', btn_dist_mean),
        ('targ_dist_mean', targ_dist_mean),
    ):
        training_results[metric].append(metric_val)
    # ... while the error histories grow by every item returned by training.
    training_results['ERRORS_low'].extend(mse_low)
    training_results['ERRORS_high'].extend(mse_high)

In [None]:
# Summary of the joint training as a 2x3 grid: reward, button rate and button
# distance on the first row; mean squared errors, food rate and food distance
# on the second row.
#
# plt.figure is used instead of plt.subplots: the latter also creates one
# full-size Axes, which recent Matplotlib releases no longer remove when the
# plt.subplot calls below overlap it, leaving a stray frame behind the grid.
plt.figure(figsize=(16,7))
# The title follows the selected context type rather than a fixed string.
plt.suptitle('%s context' % context_type)

plt.subplot(231)
plt.plot(training_results['R_mean'], marker='o')
plt.xlabel('training epochs')
plt.ylabel('average reward')

plt.subplot(232)
plt.plot(training_results['button_rate'], marker='o')
plt.xlabel('training epochs')
plt.ylabel('average button rate')

plt.subplot(233)
n_btn_pts = len(training_results['btn_dist_mean'])
plt.plot(training_results['btn_dist_mean'], marker='o')
plt.xlabel('training epochs')
plt.ylabel('average distance from button')
plt.xlim([-2,n_btn_pts+1])
plt.ylim([0,0.5])
# Dashed reference line at the r_lock radius of the environment.
plt.plot([-2,n_btn_pts+1], [env_par['r_lock'],env_par['r_lock']], ls='--', color='red')

plt.subplot(234)
# Each stored error entry is averaged down to a single number before plotting.
mean_mse_low = [np.mean(mse) for mse in training_results['ERRORS_low']]
mean_mse_high = [np.mean(mse) for mse in training_results['ERRORS_high']]
plt.plot(range(len(mean_mse_low)), mean_mse_low, marker='o', label='low')
plt.plot(range(len(mean_mse_high)), mean_mse_high, marker='o', label='high')
plt.xlabel('training epochs')
plt.ylabel('average mses')
plt.ylim([0,0.5])
plt.grid(True)
plt.legend(loc='upper right')

plt.subplot(235)
plt.plot(training_results['food_rate'], marker='o')
plt.xlabel('training epochs')
plt.ylabel('average food rate')

plt.subplot(236)
n_targ_pts = len(training_results['targ_dist_mean'])
plt.plot(training_results['targ_dist_mean'], marker='o')
plt.xlabel('training epochs')
plt.ylabel('average distance from food')
plt.xlim([-2,n_targ_pts+1])
plt.ylim([0,0.5])
# Dashed reference line at the r_targ radius of the environment.
plt.plot([-2,n_targ_pts+1], [env_par['r_targ'],env_par['r_targ']], ls='--', color='black')

plt.show()

In [None]:
# Save the training histories as JSON.
#
# Every stored entry is converted with ndarray.tolist(), which turns NumPy
# arrays into plain lists and NumPy scalars into native Python numbers; the
# json module cannot serialise NumPy arrays nor non-float64 NumPy scalars.
for key in list(training_results.keys()):
    training_results[key] = [np.asarray(entry).tolist() for entry in training_results[key]]

# The file name carries the context type, so that an 'apical' and a 'basal'
# run of the same session do not overwrite each other.
results_path = './data/Fig_3/Apical_vs_Basal/R_vs_epochs_%s_n%03d.json' % (context_type, n_session)
with open (results_path, 'w') as fp:
    json.dump(training_results, fp)