# Full-FORCE demonstration
Eli Pollock, Jazayeri Lab, MIT Brain and Cognitive Sciences

The following demonstration is based on the full-FORCE algorithm described in [DePasquale et al. 2018](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0191527).

Thanks to Brian DePasquale for providing example code. For a MATLAB demo, see [full-FORCE-demos/Matlab](https://github.com/briandepasquale/full-FORCE-demos/tree/master/Matlab).

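First, we import the libraries used throughout this notebook. The function that outputs our desired inputs, targets, and hints for the task is imported afterwards from the `trials` module.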

```python
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import time
import seaborn as sns
import joblib
sns.set(palette="deep", style="darkgrid")
```



Below, we train and test the network using the `FF_Demo` module. Its key component is the `RNN` class, which stores the network's activity and weights as attributes and provides methods for training, running, and testing the network on arbitrary tasks. Open the module file to see how recursive least squares (RLS) is implemented.
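`FF_Demo` trains its weights with RLS, as in FORCE learning. The snippet below is a minimal, generic RLS sketch in NumPy, not the module's code: the name `rls_update`, the regularization `alpha`, and the toy regression check are assumptions for illustration. `P` tracks a running estimate of the inverse correlation matrix of the activity, so each step is a cheap rank-one correction instead of a full matrix inversion.

```python
import numpy as np

def rls_update(w, P, r, target):
    """One generic RLS step for a linear readout y = w @ r (illustrative sketch).

    w: (n_out, N) weights, P: (N, N) inverse-correlation estimate,
    r: (N,) presynaptic activity, target: (n_out,) desired output.
    """
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)       # gain vector
    err = w @ r - target          # error before the update, shape (n_out,)
    P = P - np.outer(k, Pr)       # rank-one update of P
    w = w - np.outer(err, k)      # move each output row against its error
    return w, P

# Toy check (assumed setup): recover a known linear map from streaming samples.
rng = np.random.default_rng(0)
N, n_out, alpha = 20, 2, 1.0
w_true = rng.standard_normal((n_out, N))
w = np.zeros((n_out, N))
P = np.eye(N) / alpha             # FORCE-style initialization, P(0) = I / alpha
for _ in range(500):
    r = rng.standard_normal(N)
    w, P = rls_update(w, P, r, w_true @ r)

print(np.max(np.abs(w - w_true)))
```

Because the toy targets are noise-free, the learned weights end up close to `w_true`, differing only by the small bias that the `P = I / alpha` initialization introduces.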

Here, I create a parameter dictionary, hand-tune some of its entries, instantiate an `RNN`, train it with full-FORCE, and then test it. You can see where the algorithm spends most of its time with the `%lprun` magic from `line_profiler` (see the commented-out cell at the end). Training should take only a couple of minutes to reach good results.
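In full-FORCE (DePasquale et al. 2018), a second, randomly connected "target-generating" network is driven by the task input together with the desired output and any hints. The task-performing network's recurrent weights `J` are then trained with RLS so that its recurrent input `J @ r` reproduces the driven network's `J_D @ r_D + u_out @ f_out + u_hint @ f_hint`. The sketch below computes only those recurrent targets, assuming tanh rate units and Euler integration; `driven_step`, the time constant `tau`, and the toy signals are hypothetical and not taken from `FF_Demo`.

```python
import numpy as np

def driven_step(x_D, J_D, u_in, f_in, u_out, f_out, u_hint, f_hint, dt, tau):
    """Advance the (hypothetical) target-generating network one Euler step.

    Returns (new state, recurrent target); the target is
    J_D @ tanh(x_D) + u_out @ f_out + u_hint @ f_hint, evaluated before the step.
    """
    r_D = np.tanh(x_D)
    target = J_D @ r_D + u_out @ f_out + u_hint @ f_hint
    x_D = x_D + (dt / tau) * (-x_D + target + u_in @ f_in)
    return x_D, target

rng = np.random.default_rng(1)
N, T, dt, tau, g = 100, 500, 0.001, 0.01, 1.5
J_D = g * rng.standard_normal((N, N)) / np.sqrt(N)    # random recurrent weights with gain g
u_in, u_out, u_hint = (rng.uniform(-1, 1, (N, 1)) for _ in range(3))
f_in = np.zeros((T, 1))
f_in[:50] = 1.0                                        # hypothetical input pulse
f_out = np.sin(2 * np.pi * 2.0 * np.arange(T) * dt)[:, None]  # hypothetical 2 Hz target
f_hint = np.zeros((T, 1))                              # no hint in this toy example

x_D = 0.1 * rng.standard_normal(N)
targets = np.empty((T, N))
for t in range(T):
    x_D, targets[t] = driven_step(x_D, J_D, u_in, f_in[t], u_out, f_out[t],
                                  u_hint, f_hint[t], dt, tau)
```

During training, each `targets[t]` would serve as the RLS target for `J @ np.tanh(x)` in the task-performing network, whose state `x` receives only `u_in @ f_in`.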

Note that FF_Demo requires numpy, scipy, and matplotlib, but that's it! 

```python
from trials import fullforce_oscillation_test

fullforce_oscillation_test(dt=0.001, showplots=1);
```

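The `trials` module itself is not shown in this notebook. As a rough illustration of what a task function supplies, here is a hypothetical generator, `oscillation_trial`, for an oscillation task: a brief input pulse, a sinusoidal target that starts when the pulse ends, and a hint. The name, the `(T, 1)` array shapes, the 2 Hz frequency, and the use of a cosine as the hint are all assumptions for illustration, not what `fullforce_oscillation_test` actually returns.

```python
import numpy as np

def oscillation_trial(dt=0.001, duration=1.0, freq=2.0, pulse_len=0.05):
    """Hypothetical trial generator returning (inputs, targets, hints), each of shape (T, 1).

    Not the `trials` module's implementation; shapes and signals are illustrative.
    """
    T = int(round(duration / dt))
    n_pulse = int(round(pulse_len / dt))
    idx = np.arange(T)
    after = idx >= n_pulse                       # time steps after the pulse ends
    phase = 2 * np.pi * freq * (idx[after] - n_pulse) * dt

    inputs = np.zeros((T, 1))
    inputs[:n_pulse, 0] = 1.0                    # brief "go" pulse at trial start

    targets = np.zeros((T, 1))
    targets[after, 0] = np.sin(phase)            # oscillation begins when the pulse ends

    hints = np.zeros((T, 1))
    hints[after, 0] = np.cos(phase)              # illustrative hint: a 90-degree-shifted copy
    return inputs, targets, hints

inputs, targets, hints = oscillation_trial()
print(inputs.shape, targets.shape, hints.shape)
```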

```python
import FF_Demo

p = FF_Demo.create_parameters(dt=0.001)
p['g'] = 1.5  # From paper
p['network_size'] = 400
p['ff_num_batches'] = 20
p['ff_trials_per_batch'] = 10
p['ff_init_trials'] = 10
p['test_init_trials'] = 10

rnn = FF_Demo.RNN(p, 1, 1)  # 1 input channel, 1 output channel

rnn.train(fullforce_oscillation_test, monitor_training=True)
```

Initializing



Training network...



Done training!


```python
rnn.plot_training()
```


```python
rnn.test(fullforce_oscillation_test)
```

Initializing


Testing: 10 trials




Normalized error: 6.51385e-07


array([[  6.51384701e-07]])

```python
joblib.dump(rnn, filename='data/fullforce_oscillation_rnn.p.z', compress=3)
```

['data/fullforce_oscillation_rnn.p.z']

```python
# %load_ext line_profiler
# %lprun -f rnn.train rnn.train(fullforce_oscillation_test, monitor_training=1)
```

The training statistics include an "error ratio": the ratio of the error after each weight update to the error before that update. It should converge to 1, meaning that the network cannot do any better.
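Assuming `FF_Demo` follows the standard RLS update, and that the post-update error is evaluated on the same activity vector $\mathbf{r}$, the error ratio has a closed form. With symmetric, positive definite inverse-correlation estimate $P$, readout weights $\mathbf{w}$, target $f$, and pre-update error $e^- = \mathbf{w}^\top\mathbf{r} - f$:

$$
\mathbf{k} = \frac{P\mathbf{r}}{1 + \mathbf{r}^\top P\mathbf{r}}, \qquad
\mathbf{w}^+ = \mathbf{w} - e^-\,\mathbf{k}, \qquad
e^+ = (\mathbf{w}^+)^\top\mathbf{r} - f = e^-\left(1 - \mathbf{k}^\top\mathbf{r}\right) = \frac{e^-}{1 + \mathbf{r}^\top P\mathbf{r}}.
$$

So the ratio $e^+/e^- = 1/(1 + \mathbf{r}^\top P\mathbf{r})$ always lies in $(0, 1]$. As training accumulates activity samples, $P$ shrinks, $\mathbf{r}^\top P\mathbf{r} \to 0$, and the ratio approaches 1: each new update barely changes the error. This is a sketch of the textbook RLS step; the exact bookkeeping in `FF_Demo` may differ.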

The error magnitude is just the norm of the error. It should decrease to close to 0.

The weights norm is, as the name suggests, the norm of all the weights. It should converge to a constant value as the weights stabilize.
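To see how these three statistics behave, the toy below trains a single linear readout with a generic RLS rule and records, at every step, the error ratio, the pre-update error magnitude, and the weight norm. It is a sketch under stated assumptions (one output, noise-free targets from a hypothetical `w_true`, random tanh activity standing in for firing rates), not `FF_Demo`'s own bookkeeping.

```python
import numpy as np

rng = np.random.default_rng(0)
N, alpha, n_steps = 30, 1.0, 400
w_true = rng.standard_normal(N)          # hypothetical readout the learner should find
w = np.zeros(N)                          # learned readout weights
P = np.eye(N) / alpha                    # inverse-correlation estimate, P(0) = I / alpha

ratios, err_mags, w_norms = [], [], []
for _ in range(n_steps):
    r = np.tanh(rng.standard_normal(N))  # random stand-in for firing rates
    f = w_true @ r                       # desired output at this step
    e_minus = w @ r - f                  # error before the update
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)              # RLS gain vector
    P -= np.outer(k, Pr)                 # rank-one update of P
    w = w - e_minus * k                  # weight update
    e_plus = w @ r - f                   # error after the update, same r
    ratios.append(abs(e_plus) / abs(e_minus))  # "error ratio"
    err_mags.append(abs(e_minus))              # "error magnitude"
    w_norms.append(np.linalg.norm(w))          # "weights norm"

print(np.mean(ratios[:50]), np.mean(ratios[-50:]))
```

In this toy, the ratios should climb toward 1, the error magnitudes fall toward 0, and the weight norm level off near `np.linalg.norm(w_true)`.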