# Full-FORCE demonstration
Eli Pollock, Jazayeri Lab, MIT Brain and Cognitive Sciences

The following demonstration is based on the full-FORCE algorithm described in [DePasquale et al. 2018](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0191527).

Thanks to Brian DePasquale for providing example code. For a MATLAB demo, see [https://github.com/briandepasquale/full-FORCE-demos/tree/master/Matlab](https://github.com/briandepasquale/full-FORCE-demos/tree/master/Matlab).

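First, we set up the notebook environment. The function that outputs our desired inputs, targets, and hints, `fullforce_poisson_clicks`, is imported from the `trials` module further down: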

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import time
import seaborn as sns
import joblib
from tqdm import tnrange
sns.set(palette="deep", style="darkgrid")

Below, we train and test the network using the FF_Demo module. The key component is the RNN class, which has activity and weights as attributes, along with methods for training, running, and testing the network on arbitrary tasks. Open the module file to see how recursive least squares (RLS) is implemented.
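
For orientation, a single RLS update can be sketched in a few lines of NumPy. This is a generic textbook form, not the code in FF_Demo: the names `rls_step`, `P`, `w`, `r`, and `alpha` are illustrative assumptions, and the toy linear target exists only to show that the update converges.

```python
import numpy as np

def rls_step(P, w, r, err):
    """One recursive-least-squares update (generic sketch, not FF_Demo's code).

    P   : (N, N) running estimate of the inverse correlation matrix of r
    w   : (N,) weights being learned
    r   : (N,) current unit activity
    err : scalar error (w @ r - target) computed before the update
    """
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)      # gain vector
    P = P - np.outer(k, Pr)      # rank-one update of the inverse correlation
    w = w - err * k              # move weights against the error
    return P, w

# Toy problem (hypothetical): recover fixed weights from noiseless samples.
rng = np.random.default_rng(0)
N, alpha = 20, 1.0               # alpha sets the initial P and acts as a regularizer
w_true = rng.standard_normal(N)
P = np.eye(N) / alpha
w = np.zeros(N)

for _ in range(500):
    r = np.tanh(rng.standard_normal(N))   # stand-in for network activity
    err = w @ r - w_true @ r
    P, w = rls_step(P, w, r, err)

final_error = np.linalg.norm(w - w_true)
print(f"distance from target weights: {final_error:.4f}")
```

In full-FORCE the same kind of update is applied to the recurrent weights rather than to a single readout, so treat the sketch as a reading aid for the module file rather than a substitute for it.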

Here, I create a parameter dictionary, hand-tune some of them, instantiate an RNN, train it with full-FORCE, and then test it. You can see where the algorithm spends most of its time with %lprun. It should only take a couple of minutes to get good training results.

Note that FF_Demo requires numpy, scipy, and matplotlib, but that's it! 

In [2]:
import FF_Demo
from trials import fullforce_poisson_clicks

fullforce_poisson_clicks(dt=0.001,showplots=True);

In [25]:
p = FF_Demo.create_parameters(dt=0.001)
p['g'] = 1.5 # From paper
p['network_size'] = 1000
p['ff_num_batches'] = 50
p['ff_trials_per_batch'] = 0
p['ff_init_trials'] = 10
p['test_init_trials']=10
p['ff_steps_per_update'] = 1

rnn = FF_Demo.RNN(p,2,1)

In [None]:
#rnn.train(fullforce_poisson_clicks, monitor_training=True)

In [3]:
rnn = joblib.load('data/fullforce_poisson_rnn.p.z')

In [42]:
rnn.plot_training()

In [4]:
rnn.p['test_trials'] = 5
rnn.test(fullforce_poisson_clicks)

Initializing


Testing: 5 trials




Normalized error: 0.22285


array([[ 0.22284976]])

In [8]:
inp, targ, = fullforce_poisson_clicks(dt=0.001)[:2]
nt = inp.shape[0]
N = 5000 # trials

inps = np.zeros((nt, inp.shape[1], N))
targs = np.zeros((nt, N))
L_all = np.zeros(N)
t_stims = np.zeros_like(L_all)

for i in tnrange(N):
    inps[:,:,i], targ, _, L_all[i], t_stims[i] = fullforce_poisson_clicks(dt=0.001)
    targs[:,i] = targ.ravel()
    
n_outputs = 20
outputs = rnn.test_batch(inps[:,:,:n_outputs], targs[:,:n_outputs], t_stims,
                         norm_only=False, inps_and_targs=fullforce_poisson_clicks)[0]


Initializing


Testing: 20 trials



In [9]:
sns.set_style("white")
plt.figure()
plt.imshow((outputs - targs[:,:n_outputs]).T, aspect='auto', vmin=-2, vmax=2, cmap="PRGn")
plt.ylabel("Trial")
plt.title("Output - Target")
plt.colorbar();

In [10]:
# use only period from stim on to end of target
norms, decisions = rnn.test_batch(inps, targs, t_stims,
                       norm_only=True, norm_idx = [200, 2000],
                       inps_and_targs=fullforce_poisson_clicks)

Initializing


Testing: 5000 trials



In [46]:
import pandas as pd
df = pd.DataFrame(data=np.stack((L_all, t_stims, norms, decisions), axis=1), columns=['L','dur', 'norm', 'decision'])
df['dur_bin'] = pd.cut(df.dur, 6, labels=False)
df.head()

|   | L | dur | norm | decision | dur_bin |
|---|---|---|---|---|---|
| 0 | 27.0 | 1200.0 | 0.184667 | 1.0 | 2 |
| 1 | 28.0 | 1120.0 | 0.337339 | 1.0 | 2 |
| 2 | 7.0 | 900.0 | 0.085295 | 0.0 | 0 |
| 3 | 8.0 | 980.0 | 0.478526 | 1.0 | 0 |
| 4 | 19.0 | 900.0 | 0.07536 | 0.0 | 0 |


In [47]:
df.to_pickle('data/poisson_data.pkl')

In [90]:
# error as a function of L
fig,ax = plt.subplots(figsize=(10,8))
g = sns.pointplot(x="L",y="norm", data=df, ax=ax)
g.axes.set_xticks(np.arange(40, step=2))
g.axes.set_xticklabels(g.axes.get_xticks());
g.set_ylabel('Normalized error')


In [84]:
fig,ax = plt.subplots(figsize=(10,8))
g = sns.pointplot(x="L",y="decision", data=df, join=False, ax=ax)
sns.regplot(x="L",y="decision", data=df, logistic=True, n_boot=100, ax=ax, scatter=False)
g.axes.set_xticks(np.arange(40, step=2))
g.axes.set_xticklabels(g.axes.get_xticks());

In [62]:
sns.lmplot(x="L", y="decision", data=df, logistic=True, scatter=False,
           ci=95, n_boot=100, hue="dur_bin", palette='inferno', markers=None, size=8);

## Look at eigenvalue spectra
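
As a reference point for reading the spectra, the eigenvalues of a large random matrix with independent entries of variance g²/N fill a disk of radius roughly g in the complex plane (the circular law). The sketch below checks this for g = 1.5. It assumes Gaussian entries scaled by g/√N, which may differ from how FF_Demo actually initializes its recurrent weights; `N`, `g`, and `J` here are local illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
N, g = 400, 1.5   # g matches p['g']; N is kept small so the example runs quickly

# Assumed initialization: i.i.d. Gaussian entries with standard deviation g / sqrt(N)
J = g * rng.standard_normal((N, N)) / np.sqrt(N)

eigvals = np.linalg.eigvals(J)
spectral_radius = np.abs(eigvals).max()
print(f"largest |eigenvalue|: {spectral_radius:.3f} (circular-law radius is about {g})")
```

Eigenvalues of the trained network that sit well outside this disk are therefore the ones most clearly attributable to learning rather than to the random initialization.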

In [44]:
def get_eig(rnn):
    J = np.transpose(rnn.rnn_par['rec_weights'])
    w,v = np.linalg.eig(J)
    return w

In [99]:
untrained_rnn = FF_Demo.RNN(p,2,1)
w_untrained = get_eig(untrained_rnn)
w = get_eig(rnn)

In [116]:
fig, ax = plt.subplots(1,2, figsize=(10,6))
ax[0].scatter(w_untrained.real, w_untrained.imag, alpha=0.25, label="Untrained")
ax[0].scatter(w.real, w.imag, alpha=0.5, label="Trained")
ax[0].set_xlabel(r'$\mathrm{Re}(\lambda)$')
ax[0].set_ylabel(r'$\mathrm{Im}(\lambda)$')
ax[0].axis('equal')
ax[0].legend()

# Sort eigenvalue magnitudes in descending order (matches the |lambda| axis label)
sort_w = lambda x: np.sort(np.abs(x))[::-1]
ax[1].plot(sort_w(w_untrained), label="Untrained")
ax[1].plot(sort_w(w), label="Trained")
ax[1].set_xlabel('EV #')
ax[1].set_ylabel(r'$\mid\lambda\mid$')
ax[1].legend()


In [8]:
# %load_ext line_profiler
# %lprun -f rnn.train rnn.train(fullforce_oscillation_test, monitor_training=1)

The training statistics include an "error ratio" that is the ratio of the error after the update to the error before the update. It should converge to 1, meaning that the network cannot do any better.

The error magnitude is just the norm of the error. It should decrease to nearly 0.

The weights norm is, as one would expect, the norm of all weights. That should converge on some constant value as the weights stabilize.
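
These three statistics can be reproduced on a toy problem. The sketch below runs a generic RLS loop on a made-up linear target and records each quantity at every update; all names (`error_ratio`, `error_mag`, `weight_norm`, `w_true`) are illustrative assumptions and are not taken from FF_Demo, which may compute its statistics differently.

```python
import numpy as np

rng = np.random.default_rng(1)
N, steps = 20, 400
w_true = rng.standard_normal(N)   # hypothetical target weights
w = np.zeros(N)
P = np.eye(N)                     # initial inverse-correlation estimate

error_ratio, error_mag, weight_norm = [], [], []
for _ in range(steps):
    r = np.tanh(rng.standard_normal(N))   # stand-in for network activity
    target = w_true @ r

    err_before = w @ r - target           # error with the current weights
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)
    P = P - np.outer(k, Pr)
    w = w - err_before * k                # RLS weight update
    err_after = w @ r - target            # error on the same sample after the update

    error_ratio.append(abs(err_after) / abs(err_before))
    error_mag.append(abs(err_before))
    weight_norm.append(np.linalg.norm(w))

print(f"error ratio  first/last: {error_ratio[0]:.3f} / {error_ratio[-1]:.3f}")
print(f"error size   first/last: {error_mag[0]:.3f} / {error_mag[-1]:.5f}")
print(f"weight norm  last: {weight_norm[-1]:.3f}")
```

For this update rule the ratio equals 1 / (1 + rᵀPr), so it rises toward 1 as P shrinks with accumulated data, while the error magnitude falls and the weight norm settles.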