# Learning Through Target Bursts (LTTB) - Figure 1

This notebook reproduces the results presented in `Figure 1` of the <a href="https://arxiv.org/abs/2201.11717">arXiv 2201.11717</a> preprint paper: Cristiano Capone<sup>\*</sup>, Cosimo Lupo<sup>\*</sup>, Paolo Muratore, Pier Stanislao Paolucci (2022) "*Burst-dependent plasticity and dendritic amplification support target-based learning and hierarchical imitation learning*". We test the `LTTB` model on a 3D-trajectory task.

Please give credit to this paper if you use or modify the code in a derivative work. This work is licensed under the Creative Commons Attribution 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 International License</a>.

### Library Imports

In this section we import the needed external libraries.

In [2]:
import os
import glob
import json
import random

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tqdm import trange

import lttb as lttb_module

### Function Definitions

In this section we define several useful functions that will be used during the experiments.

In [3]:
def f(x,gamma):
    """Logistic sigmoid ``1 / (1 + exp(-gamma * x))`` evaluated without overflow.

    The naive form ``exp(x*gamma) / (exp(x*gamma) + 1)`` overflows to
    ``inf / inf = nan`` as soon as ``x*gamma`` exceeds ~709. In ``training``
    this function is applied to ``lttb.VapicRec`` with ``gamma = 1/du``, so a
    single large value would turn the recurrent-weight update into NaN. Here
    only ``exp(-|z|) <= 1`` is ever computed, which cannot overflow and keeps
    full relative precision on both tails.

    Parameters
    ----------
    x : float or array_like
        Input value(s).
    gamma : float
        Gain (slope) of the sigmoid.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Values in [0, 1] with the broadcast shape of ``x`` and ``gamma``.
    """
    z = np.asarray(x, dtype=float) * gamma
    e = np.exp(-np.abs(z))  # always in (0, 1]
    out = np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # out[()] turns a 0-d result back into a scalar and is a no-op for arrays
    return out[()]

# * Define clock and target
def init_clock_targ():
    """Draw one target per context, then initialise the clock signal.

    Works on the notebook-level globals ``lttb`` (the model), ``n_contexts``
    and ``par`` (configuration dict). ``lttb.init_targ(par)`` is called
    ``n_contexts`` times; after each call the freshly drawn ``lttb.y_targ``
    is appended to ``lttb.y_targ_collection``, so entry ``k`` is the target
    used for context ``k``. Finally ``lttb.init_clock(par)`` is called.
    """
    # start from an empty collection on the model itself
    lttb.y_targ_collection = []

    for _ in range(n_contexts):
        # init_targ overwrites lttb.y_targ, so store it right after each draw
        lttb.init_targ(par)
        lttb.y_targ_collection.append(lttb.y_targ)

    lttb.init_clock(par)

# * -------------- TESTING FUNCTIONS --------------------
def short_test(apicalFactor=0):
    """Quick evaluation: replay each context once and score the readout.

    For every context index a one-hot ``lttb.cont`` is set, the matching
    target from ``lttb.y_targ_collection`` is selected, the network is
    re-initialised with ``par`` and stepped ``lttb.T - 2`` times with the
    given ``apicalFactor``. ``training`` calls this with ``apicalFactor = 0``
    for its periodic evaluation.

    Reads the globals ``lttb``, ``par`` and ``n_contexts``.

    Parameters
    ----------
    apicalFactor : float
        Forwarded unchanged to ``lttb.step``.

    Returns
    -------
    numpy.ndarray, shape (n_contexts,)
        For each context, ``np.std(target - output)**2`` over the time slice
        ``1:-2``. This is the variance of the residual, i.e. a constant offset
        between output and target is not penalised.
    """
    scores = np.zeros(n_contexts)

    for idx in range(n_contexts):
        one_hot = lttb.cont * 0
        one_hot[idx] = 1
        lttb.cont = one_hot
        lttb.y_targ = lttb.y_targ_collection[idx]
        lttb.initialize(par)

        # free run of the trial
        for _ in range(lttb.T - 2):
            lttb.step(apicalFactor)

        # linear readout of the filtered activity, bias repeated over time
        rates = lttb.B_filt_total[:, 1:-2]
        bias_columns = np.tile(lttb.Bias, (lttb.T - 3, 1)).T
        output = lttb.Jout @ rates + bias_columns

        residual = lttb.y_targ[:, 1:-2] - output
        scores[idx] = np.std(residual) ** 2

    return scores

def full_test(apicalFactor=0):
    """Full evaluation that also records outputs, contexts and spike rasters.

    For each context index a constant one-hot context matrix of shape
    ``(lttb.T - 2, n_contexts)`` is built, the matching target is selected,
    the network is re-initialised with ``par`` and stepped ``lttb.T - 2``
    times, feeding row ``t`` of the context matrix at step ``t``.

    Reads the globals ``lttb``, ``par`` and ``n_contexts``.

    Parameters
    ----------
    apicalFactor : float
        Forwarded unchanged to ``lttb.step``.

    Returns
    -------
    dict
        ``'targs'``, ``'outputs'``: per-context target / readout over the time
        slice ``1:-2``; ``'contexts'``: per-context context matrices;
        ``'S_somas'``, ``'S_winds'``: per-context copies of ``lttb.S_soma`` /
        ``lttb.S_wind`` after the run; ``'mses'``: per-context
        ``np.std(target - output)**2``. When ``n_contexts == 2`` also
        ``'mses_offDiag'``: the same score against the *other* context's target.
    """
    stats = {
        'targs': [],
        'outputs': [],
        'contexts': [],
        'S_somas': [],
        'S_winds': [],
        'mses': np.zeros(n_contexts),
    }
    if n_contexts == 2:
        stats['mses_offDiag'] = np.zeros(n_contexts)

    for cont_index in range(n_contexts):

        n_steps = lttb.T - 2
        # constant one-hot context, one row per time step
        context = np.zeros((n_steps, n_contexts))
        context[:, cont_index] = 1
        lttb.y_targ = lttb.y_targ_collection[cont_index]
        # Set the context *before* initialize(), as short_test and training
        # do, so initialisation never sees the previous context's vector.
        lttb.cont = context[0]
        lttb.initialize(par)

        # run simulation
        for t in range(n_steps):
            lttb.cont = context[t]
            lttb.step(apicalFactor)

        SR = lttb.B_filt_total[:, 1:-2]
        Y = lttb.Jout @ SR + np.tile(lttb.Bias, (lttb.T - 3, 1)).T
        targ = lttb.y_targ[:, 1:-2]

        stats['outputs'].append(Y)
        stats['contexts'].append(context)
        # Copies: the stored rasters must not change if lttb reuses these
        # buffers in place when the next context is initialised.
        stats['S_somas'].append(np.array(lttb.S_soma, copy=True))
        stats['S_winds'].append(np.array(lttb.S_wind, copy=True))
        stats['targs'].append(targ)
        stats['mses'][cont_index] = np.std(targ - Y) ** 2
        if n_contexts == 2:
            wrong_targ = lttb.y_targ_collection[1 - cont_index][:, 1:-2]
            stats['mses_offDiag'][cont_index] = np.std(wrong_targ - Y) ** 2

    return stats

# * --------------------- TRAINING FUNCTION --------------------------------
def training(nIterRec=100, test_every=5, eta=5., eta_out=0.01, etaW=0., eta_bias=0.0002,
             verbose = True):
    """Run ``nIterRec`` LTTB training iterations over all contexts.

    In each iteration every context is presented once: ``lttb.cont`` is set
    to its one-hot vector, the matching target is selected, the network is
    re-initialised and stepped ``lttb.T - 2`` times with
    ``lttb.step(apicalFactor = 1)``. At every step two online updates follow.

    * Recurrent weights: ``lttb.J += eta * outer(row_factor, dH)`` where
      ``row_factor = (S_apic_dist[:, t+1] - f(VapicRec[:, t], gamma))
      * (1 - S_apic_prox[:, t]) * S_wind_soma[:, t+1]`` and ``dH`` is a
      leaky trace (rate ``dt/tau_m``) of ``S_filt[:, t]``. When the global
      ``pin_Jrec`` is true, rows ``Ne:Ne+Ni`` are then zeroed,
      ``J[0:Ne, 0:Ne]`` is clamped to [0, 100], ``J[0:Ne, Ne:Ne+Ni]`` to
      [-100, 0], and the diagonal is zeroed.
    * Readout: delta rule on ``err = y_targ[:, t+1] - (Jout @ B_filt_total[:, t+1] + Bias)``,
      ``Jout += eta_out * outer(err, rates)`` and ``Bias += eta_bias * err``.

    After every ``test_every`` iterations one row of each log is filled.
    INNER_ERRORS gets ``np.std(lttb.B_filt_rec - lttb.B_filt)**2``, computed on
    the state left by the last context trained; the same scalar goes into
    every context column. OUTER_ERRORS gets ``short_test(apicalFactor = 0)``.

    ``etaW`` is accepted but unused. Reads the globals ``lttb``, ``par``,
    ``n_contexts``, ``dt``, ``tau_m``, ``gamma``, ``Ne``, ``Ni`` and
    ``pin_Jrec``.

    Returns
    -------
    (OUTER_ERRORS, INNER_ERRORS) : tuple of numpy.ndarray
        Each of shape ``(int(nIterRec/test_every), n_contexts)``.
    """
    n_tests = int(nIterRec / test_every)
    outer_log = np.zeros((n_tests, n_contexts))
    inner_log = np.zeros((n_tests, n_contexts))
    leak = dt / tau_m  # integration rate of the presynaptic trace

    progress = trange(nIterRec, desc='LTTB Training', leave=True)

    for it in progress:

        for ctx in range(n_contexts):

            lttb.cont = lttb.cont * 0
            lttb.cont[ctx] = 1
            lttb.y_targ = lttb.y_targ_collection[ctx]
            lttb.initialize(par)

            trace = 0
            for t in range(lttb.T - 2):
                lttb.step(apicalFactor=1)

                # --- recurrent (burst-dependent) update ---
                trace = trace * (1 - leak) + leak * lttb.S_filt[:, t]
                row_factor = (lttb.S_apic_dist[:, t + 1] - f(lttb.VapicRec[:, t], gamma)) \
                    * (1 - lttb.S_apic_prox[:, t]) * lttb.S_wind_soma[:, t + 1]
                lttb.J = lttb.J + eta * np.outer(row_factor, trace)

                if pin_Jrec:
                    W = lttb.J  # alias: the edits below act on lttb.J in place
                    W[Ne:Ne + Ni, :] *= 0
                    W[0:Ne, 0:Ne] = np.minimum(100, np.maximum(0, W[0:Ne, 0:Ne]))
                    W[0:Ne, Ne:Ne + Ni] = np.maximum(-100, np.minimum(0, W[0:Ne, Ne:Ne + Ni]))
                    np.fill_diagonal(W, 0.)

                # --- readout update (delta rule) ---
                rates = lttb.B_filt_total[:, t + 1]
                err = lttb.y_targ[:, t + 1] - (lttb.Jout @ rates + lttb.Bias)
                lttb.Jout = lttb.Jout + eta_out * np.outer(err, rates)
                lttb.Bias = lttb.Bias + eta_bias * err

        # --- periodic evaluation ---
        if (it + 1) % test_every == 0:
            row = int(it / test_every)
            inner_log[row, :] = np.std(lttb.B_filt_rec - lttb.B_filt) ** 2
            mses = short_test(apicalFactor=0)
            outer_log[row, :] = mses

            if verbose:
                progress.set_description(
                    'LTTB Training. Outer MSEs: ' + ''.join(f'{mse:.4f} | ' for mse in mses))

    return outer_log, inner_log

def render_fig_v1(dct):
    """Render the three-panel version of Figure 1 and save it as PDF/EPS/PNG.

    Panels, top to bottom:
    1. Context signal(s).
    2. Output (solid) vs. target (dashed) trajectories of the first context.
    3. Raster of non-burst spikes (orange) and bursts (blue), y range 0-50.

    Parameters
    ----------
    dct : dict
        Results dictionary as assembled in the "Save/Load" section. The
        function reads:
        - ``dct['par']``: ``sigma_apical_cont``, ``sigma_basal_cont``,
          ``n_contexts``, ``O`` and, if present, ``context``.
        - ``dct['output']``, ``dct['target']``, ``dct['spikes_idx']``,
          ``dct['bursts_idx']`` and ``dct['session_name']``.
        - ``dct['contexts']``, as a fallback for the context panel.

    Files are written to ``./figures/Fig_1/Figure1_<session_name>_v1.<ext>``;
    the folder is created if missing.
    """
    fs = 12
    cm = 1/2.54  # centimeters in inches
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12*cm, 12*cm))

    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=fs, pad=1)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_linewidth(0.5)
        ax.spines['bottom'].set_linewidth(0.5)
        ax.xaxis.set_tick_params(width=0.5)
        ax.yaxis.set_tick_params(width=0.5)

    par_ = dct['par']
    context_colors = ['black', 'red']

    ax = axes[0]
    # Values coming from JSON are nested lists, which support neither `.T`
    # nor multiplication by a float, so convert to an ndarray first. If the
    # configuration has no 'context' entry, fall back to the recorded
    # context matrix of the first trial (TODO confirm this is the intended
    # signal for this panel).
    if 'context' in par_:
        context = np.asarray(par_['context'], dtype=float)
    else:
        context = np.asarray(dct['contexts'][0], dtype=float)
    # as in the original figure: when both sigma_*_cont are 0 the trace is drawn flat at 0
    flat = par_['sigma_apical_cont'] == 0 and par_['sigma_basal_cont'] == 0
    scale = 0. if flat else 1.
    for d in range(par_['n_contexts']):
        # cycle colours so more than two contexts do not raise IndexError
        ax.plot(scale*context.T[d], zorder=1, ls='--',
                color=context_colors[d % len(context_colors)], lw=1, alpha=0.5)
    ax.set_ylabel('', fontsize=fs)
    ax.set_xlim([0,1000])
    ax.text(-0.15, 0.5, 'context', fontsize=fs, ha='center', va='center', rotation=90,
            transform=ax.transAxes, rotation_mode='anchor')

    ax = axes[1]
    for d in range(par_['O']):
        ax.plot(dct['output'][0][d], zorder=0, ls='-', color='C' + str(d+1), lw=2)
    for d in range(par_['O']):
        ax.plot(dct['target'][0][d], zorder=1, ls='--', color='C' + str(d+1), lw=1)
    ax.set_ylabel('', fontsize=fs)
    ax.set_xlim([0,1000])
    ax.text(-0.15, 0.5, 'trajectories', fontsize=fs, ha='center', va='center', rotation=90,
            transform=ax.transAxes, rotation_mode='anchor')

    ax = axes[2]
    ax.scatter(dct['spikes_idx'][1], dct['spikes_idx'][0], color='orange', marker='.', s=2)
    ax.scatter(dct['bursts_idx'][1], dct['bursts_idx'][0], color='blue', marker='.', s=2)
    ax.set_xlabel('t', fontsize=fs)
    ax.set_ylabel('', fontsize=fs)
    ax.set_xlim([0,1000])
    ax.set_ylim([50,0])
    ax.set_xticks([0,500,1000])
    ax.set_yticks([50,25,0])
    ax.text(-0.15, 0.5, 'neuron id', fontsize=fs, ha='center', va='center', rotation=90,
            transform=ax.transAxes, rotation_mode='anchor')

    plt.tight_layout()
    plt.subplots_adjust(left=0.23, bottom=0.12, right=0.93, top=0.95, wspace=None, hspace=0.4)
    out_dir = './figures/Fig_1'
    os.makedirs(out_dir, exist_ok=True)  # savefig does not create folders
    for ext in ['pdf','eps','png']:
        fig.savefig("%s/Figure1_%s_v1.%s" % (out_dir, dct['session_name'], ext), transparent=False)

    plt.show()

def render_fig_v2(dct):
    """Render the compact two-panel version of Figure 1 and save it.

    Top panel: output (solid) and target (dashed) trajectories of the first
    context, one colour per output dimension. Bottom panel: raster of
    non-burst spikes (orange) and bursts (blue), y range 0-50.

    Reads ``dct['par']['O']``, ``dct['output']``, ``dct['target']``,
    ``dct['spikes_idx']``, ``dct['bursts_idx']`` and ``dct['session_name']``.
    Writes ``./Figure1_<session_name>_v2.{pdf,eps,png}`` at 300 dpi.
    """
    fs = 7
    cm = 1/2.54  # centimeters in inches
    fig, (ax_traj, ax_raster) = plt.subplots(nrows=2, ncols=1, figsize=(6*cm, 6*cm))

    def _side_label(ax, text):
        # rotated panel label to the left of the y axis, in axes coordinates
        ax.text(-0.175, 0.5, text, fontsize=fs, ha='center', va='center', rotation=90,
                transform=ax.transAxes, rotation_mode='anchor')

    # thin axes without top/right spines for both panels
    for ax in (ax_traj, ax_raster):
        ax.tick_params(axis='both', which='major', labelsize=fs, pad=1)
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)
        for side in ('left', 'bottom'):
            ax.spines[side].set_linewidth(0.5)
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_tick_params(width=0.5)

    n_out = dct['par']['O']
    outputs, targets = dct['output'][0], dct['target'][0]
    # outputs first (zorder 0), then targets drawn on top (zorder 1)
    for d in range(n_out):
        ax_traj.plot(outputs[d], zorder=0, ls='-', color='C%d' % (d + 1), lw=1)
    for d in range(n_out):
        ax_traj.plot(targets[d], zorder=1, ls='--', color='C%d' % (d + 1), lw=0.5)
    ax_traj.set_ylabel('', fontsize=fs)
    ax_traj.set_xlim([0, 1000])
    ax_traj.set_xticks([0, 500, 1000], [])  # empty label list: ticks without labels (matplotlib >= 3.5)
    _side_label(ax_traj, 'trajectories')

    spikes, bursts = dct['spikes_idx'], dct['bursts_idx']
    # x: time step (second index), y: neuron id (first index)
    ax_raster.scatter(spikes[1], spikes[0], color='orange', marker='.', s=1)
    ax_raster.scatter(bursts[1], bursts[0], color='blue', marker='.', s=1)
    ax_raster.set_xlabel('t', fontsize=fs)
    ax_raster.set_ylabel('', fontsize=fs)
    ax_raster.set_xlim([-0.5, 1000.5])
    ax_raster.set_ylim([50.5, -0.5])
    ax_raster.set_xticks([0, 500, 1000])
    ax_raster.set_yticks([50, 25, 0])
    _side_label(ax_raster, 'neuron id')

    plt.tight_layout()
    plt.subplots_adjust(left=0.18, bottom=0.13, right=0.94, top=0.96, wspace=None, hspace=0.15)
    for ext in ('pdf', 'eps', 'png'):
        fig.savefig("./Figure1_%s_v2.%s" % (dct['session_name'], ext), transparent=False, dpi=300)

    plt.show()

### Model Initialization

In this section we load the model parameters for this task (via the `json` configuration file) and then we initialize the network.

In [None]:
with open('./config.json', 'r') as fp:
    par = json.load(fp)['FIGURE_1']

# Hyper-parameters that the helper functions above read as globals.
(n_contexts, T, dt, tau_m,
 eta, eta_out, eta_bias,
 Ne, Ni) = (par[key] for key in ('n_contexts', 'T', 'dt', 'tau_m',
                                 'eta', 'eta_out', 'eta_bias',
                                 'Ne', 'Ni'))
# gain of the sigmoid f() used in the training rule
gamma = 1./par['du']

In [5]:
# Run configuration. Everything here must exist before it is used:
# `pin_Jrec` is read a few lines below (and by `training`). Defining it at the
# end of the cell, as before, raised a NameError on a fresh kernel.
nTotalEpochs = 0
nIterRec = 50
test_every = 5
rescale_eta = True
factor_eta = 0.99
pin_Jrec = True

# True: draw new targets. False: reuse Y_TARG_COLLECTION from a previous
# run of this cell in the same kernel session.
GENERATE_NEW_TARGET = True

lttb = lttb_module.LTTB(par)

if GENERATE_NEW_TARGET:
    init_clock_targ()
    Y_TARG_COLLECTION = np.array([lttb.y_targ_collection[k] for k in range(n_contexts)])
else:
    lttb.init_clock(par)
    lttb.y_targ_collection = np.array([Y_TARG_COLLECTION[k] for k in range(n_contexts)])

# Initial sign constraints on the recurrent matrix (training additionally
# clamps the magnitudes to 100 after every update).
if pin_Jrec:
    lttb.J[Ne:Ne+Ni,:] *= 0
    lttb.J[0:Ne,0:Ne] = np.maximum(0,lttb.J[0:Ne,0:Ne])
    lttb.J[0:Ne,Ne:Ne+Ni] = np.minimum(0,lttb.J[0:Ne,Ne:Ne+Ni])
    np.fill_diagonal(lttb.J, 0.)

OUTER_ERRORS = []
INNER_ERRORS = []

### Model Training

In this section we train the model via the `training` function on the `3D Trajectory` task.

In [6]:
nEpochs = 25
nTotalEpochs += nEpochs  # running total across repeated executions of this cell

for epoch in range(nEpochs):
    outer_err, inner_err = training(nIterRec=nIterRec, test_every=test_every,
                                    eta=eta, eta_out=eta_out, eta_bias=eta_bias,
                                    verbose=True)
    OUTER_ERRORS += list(outer_err)
    INNER_ERRORS += list(inner_err)
    # anneal the recurrent and readout learning rates (eta_bias is left as is)
    if rescale_eta:
        eta, eta_out = eta * factor_eta, eta_out * factor_eta

LTTB Training. MSEs: 0.1992 | : 100%|██████████| 100/100 [02:42<00:00,  1.63s/it]
LTTB Training. MSEs: 0.1380 | : 100%|██████████| 100/100 [03:16<00:00,  1.97s/it]
LTTB Training. MSEs: 0.1157 | : 100%|██████████| 100/100 [03:08<00:00,  1.89s/it]
LTTB Training. MSEs: 0.0588 | : 100%|██████████| 100/100 [03:06<00:00,  1.87s/it]
LTTB Training. MSEs: 0.0326 | : 100%|██████████| 100/100 [03:06<00:00,  1.86s/it]


In [None]:
cm = 1/2.54  # centimeters in inches

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20*cm, 8*cm))
xlabel = 'training iterations (x%d)' % test_every  # one logged point per test

ax = axes[0]
ax.plot([l[0] for l in OUTER_ERRORS])
ax.set_xlabel(xlabel)
ax.set_ylabel('outer mse')
ax.grid(True)
# Use the Axes method: after plt.subplots the *current* axes is axes[1], so
# the former plt.ylim call never touched this panel. It also froze axes[1]'s
# y-range before the inner error was plotted there.
ax.set_ylim([0, ax.get_ylim()[1]])

ax = axes[1]
ax.plot([np.sqrt(l[0]) for l in INNER_ERRORS])
ax.set_xlabel(xlabel)
ax.set_ylabel('inner rmse')  # the square root of the logged inner error is plotted
ax.grid(True)
ax.set_ylim([0, ax.get_ylim()[1]])

plt.tight_layout()
plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.3, hspace=0)

plt.show()

### Testing the Model

In this section we perform a `full_test` of the model and collect the results for visualization.

In [7]:
res = full_test(apicalFactor=1)

# Context whose activity is extracted (and plotted below).
cont_index = 0

contexts = res['contexts']
context = contexts[cont_index]
Y, targ, S_soma, S_wind = (res[key][cont_index]
                           for key in ('outputs', 'targs', 'S_somas', 'S_winds'))

# Index arrays: first coordinate is the neuron, second the time step.
somatic_spike = S_soma > 0
# somatic spikes while the window (S_wind) is closed --> no burst
spikes_idx = np.where(somatic_spike & (S_wind == 0))
# somatic spikes while the window is open --> burst
bursts_idx = np.where(somatic_spike & (S_wind > 0))

### Save/Load and Visualize Results

In this section we save/load the results and compose the final visualization used in `Figure_1` of the paper.

In [None]:
Load = False

if not Load:

    # SAVING

    dct = {}
    dct['session_name'] = 'n02_TeachON_TrainON_fixedJ'
    # Copy `par`: the run metadata added below must not leak into the global
    # configuration dict that is also passed to the model.
    dct['par'] = dict(par)
    dct['par']['N_epochs'] = nTotalEpochs
    dct['par']['N_iterEpoch'] = nIterRec
    dct['par']['test_every'] = test_every
    dct['par']['rescale_eta'] = rescale_eta
    dct['par']['factor_eta'] = factor_eta
    dct['par']['pin_Jrec'] = pin_Jrec
    # arrays -> nested lists (one entry per context), rounded to 8 decimals for JSON
    dct['target'] = [np.round(np.asarray(t), 8).tolist() for t in lttb.y_targ_collection]
    dct['contexts'] = [_.tolist() for _ in contexts]
    dct['output'] = [np.round(np.asarray(y), 8).tolist() for y in res['outputs']]
    dct['outer_mse_during_training'] = [np.asarray(l).tolist() for l in OUTER_ERRORS]
    dct['inner_mse_during_training'] = [np.asarray(l).tolist() for l in INNER_ERRORS]
    dct['spikes_idx'] = [spikes_idx[0].tolist(), spikes_idx[1].tolist()]
    dct['bursts_idx'] = [bursts_idx[0].tolist(), bursts_idx[1].tolist()]

    data_dir = './data/Fig_1'
    os.makedirs(data_dir, exist_ok=True)  # open(..., 'w') fails if the folder is missing
    with open('%s/Figure1_%s.json' % (data_dir, dct['session_name']), 'w') as fp:
        json.dump(dct, fp)

else:

    # LOADING: read a file written by the SAVING branch (same folder and
    # naming scheme; the former './Figure1_n<name>.json' path never matched it).

    session_name = 'n01_noTeach'

    with open('./data/Fig_1/Figure1_%s.json' % session_name) as fp:
        dct = json.load(fp)

# Render in both cases (a loaded session used to be left unplotted).
render_fig_v2(dct)

### Many-sample averages

In this section we average the results for `Figure_1` of the paper over many independent runs (samples).

In [3]:
def _read_json(path):
    """Parse and return the JSON document stored at `path`."""
    with open(path) as fp:
        return json.load(fp)


# Saved sessions named Figure1_n??_..., where ?? is presumably a two-character
# run id; sorting keeps the run order stable across file systems.
files_TeachON = sorted(glob.glob('./data/Fig_1/Figure1_n??_TeachON_TrainON.json'))
print(files_TeachON)
fs_TeachON = [_read_json(path) for path in files_TeachON]

files_TeachOFF = sorted(glob.glob('./data/Fig_1/Figure1_n??_TeachOFF_TrainON.json'))
print(files_TeachOFF)
fs_TeachOFF = [_read_json(path) for path in files_TeachOFF]

[]
[]


In [None]:
cm = 1/2.54  # centimeters in inches (redefined so this cell does not depend on earlier cells)

# Runs to average and the tag used in the output filename, kept together so
# the saved file always names the data it shows. The old code averaged the
# TeachOFF runs but saved them as "..._TeachON".
runs, runs_tag = fs_TeachOFF, 'TeachOFF'


def _plot_mean_std(ax, curves, color):
    """Plot the across-run mean of equally long curves with a +/-1 std band."""
    stacked = np.array([np.asarray(c) for c in curves])
    mean, std = stacked.mean(axis=0), stacked.std(axis=0)
    ax.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.3, color=color)
    ax.plot(mean, label=None, lw=2, color=color)


if len(runs) == 0:
    # averaging zero runs yields a 0-d NaN, and len() on it raised TypeError
    print('No %s runs found under ./data/Fig_1 - nothing to average.' % runs_tag)
else:
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8*cm, 6*cm))

    # outer error, first context column
    _plot_mean_std(ax, [[e[0] for e in r['outer_mse_during_training']] for r in runs], 'C0')
    # inner error, first context column, scaled by 1e5 (presumably to share the log axis)
    _plot_mean_std(ax, [[100000*e[0] for e in r['inner_mse_during_training']] for r in runs], 'C1')

    ax.set_xlim([0, 200])
    ax.set_ylim([0.01, 1])
    ax.set_yscale('log')
    ax.set_xlabel('training iterations (x5)')
    ax.set_ylabel('outer mse')  # NOTE(review): the C1 curve is the scaled inner error
    ax.grid(True)

    plt.tight_layout()
    plt.subplots_adjust(left=0.22, bottom=0.20, right=0.96, top=0.96, wspace=None, hspace=0.15)
    out_dir = './figures/Fig_1'
    os.makedirs(out_dir, exist_ok=True)  # savefig does not create folders
    for ext in ['pdf', 'eps', 'png']:
        fig.savefig('%s/Figure1_mean_outer_mse_%s.%s' % (out_dir, runs_tag, ext),
                    transparent=False, dpi=300)

    plt.show()