In [1]:
import matplotlib.pyplot as plt
import numpy as np
import nengo

In [2]:
class Learner(nengo.processes.Process):
    """Error-driven linear readout with per-input, "meta"-gated plasticity.

    The process input is the concatenation ``[pre, meta, error]`` holding
    ``size_in``, ``size_in`` and ``size_out`` values respectively.  The output
    is ``w @ pre`` for an internally held ``(size_out, size_in)`` weight
    matrix ``w`` that starts at zero and is updated on every step.

    Parameters
    ----------
    size_in : int
        Length of the ``pre`` vector (and of the ``meta`` vector).
    size_out : int
        Length of the ``error`` vector and of the process output.
    tau_learn : float
        Convergence time of the learning process, in the same time units as
        the simulator ``dt``.
    """

    def __init__(self, size_in, size_out, tau_learn):
        self.size_in = size_in
        self.size_out = size_out
        # Convergence time of the learning process.
        self.tau_learn = tau_learn
        # Input layout: pre (size_in) + meta (size_in) + error (size_out).
        super().__init__(default_size_in=size_in*2+size_out,
                         default_size_out=size_out)

    def make_step(self, shape_in, shape_out, dt, rng, state=None):
        """Return the function applied on every simulation timestep.

        A fresh all-zero weight matrix is allocated on each call.  ``shape_in``,
        ``shape_out``, ``rng`` and ``state`` are accepted but not used; the
        sizes stored on the instance determine the layout.
        """
        weights = np.zeros((self.size_out, self.size_in))
        # Fraction of the remaining error removed per step of length dt.
        step_fraction = 1-np.exp(-dt/self.tau_learn)

        def step_learn(t, x):
            n_pre = self.size_in
            pre = x[:n_pre]
            meta = x[n_pre:n_pre*2]
            error = x[n_pre*2:]

            # Normalise the learning rate so that the update changes
            # ``weights @ pre`` by exactly ``-step_fraction * error``; when the
            # normaliser is zero no update is applied.
            norm = np.sum(pre*pre*meta)
            inverse = 1.0 / norm if norm != 0 else norm
            rate = step_fraction * inverse

            # In-place update (slice assignment keeps ``weights`` a closure
            # variable rather than rebinding a local name).
            weights[...] -= rate * np.outer(error, pre*meta)

            return weights @ pre

        return step_learn

In [5]:
N = 100            # number of neurons in the ensemble
TAU_LEARN = 0.02   # requested convergence time of the learner
SIM_TIME = 0.5     # duration passed to sim.run
SEED = 0           # fixes the randomly generated ensemble parameters across runs

model = nengo.Network(seed=SEED)
with model:
    # Connections in this network are unfiltered unless a synapse is given.
    model.config[nengo.Connection].synapse = None
    stim = nengo.Node(1)
    target = nengo.Node(-1)
    output = nengo.Node(None, size_in=1)
    ens = nengo.Ensemble(n_neurons=N, dimensions=1, neuron_type=nengo.LIFRate())
    nengo.Connection(stim, ens)

    learn = nengo.Node(Learner(size_in=ens.n_neurons, size_out=target.size_out,
                               tau_learn=TAU_LEARN))
    error = nengo.Node(None, size_in=1)

    # "pre" slot of the learner: the neuron activities.
    nengo.Connection(ens.neurons, learn[:ens.n_neurons])
    # "meta" slot of the learner: two connections target the same slice, so it
    # receives a constant 1 per neuron plus that neuron's activity.
    nengo.Connection(nengo.Node([1]*ens.n_neurons), learn[ens.n_neurons:ens.n_neurons*2])
    nengo.Connection(ens.neurons, learn[ens.n_neurons:ens.n_neurons*2])

    nengo.Connection(learn, output)
    # error = output - target
    nengo.Connection(output, error)
    nengo.Connection(target, error, transform=-1)
    # "error" slot of the learner.
    # NOTE(review): synapse=0 (instead of the network default None) presumably
    # adds the one-step delay needed to break the learn -> output -> error ->
    # learn loop; confirm before changing it.
    nengo.Connection(error, learn[-target.size_out:], synapse=0)

    p_error = nengo.Probe(error)

with nengo.Simulator(model) as sim:
    sim.run(SIM_TIME)

# Compare the measured error with the exponential decay expected for tau_learn.
times = sim.trange()
fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(times, np.exp(-times/learn.output.tau_learn), ls='--', label=r'$e^{-t/ \tau_{learn}}$')
ax.plot(times, sim.data[p_error], lw=5, alpha=0.5, label='error')
ax.set_xlabel('time (s)')
ax.set_ylabel('error (output - target)')
ax.set_title('Learning error vs. expected exponential decay')
ax.legend(fontsize=18)
plt.show()
    

In [6]:
# Hand the network built in the model cell to nengo_gui's InlineGUI so it can be
# inspected from inside the notebook. Requires `model` to exist, i.e. the
# model-building cell must have been run first.
# NOTE(review): cfg is a relative path, so it is presumably resolved against the
# kernel's working directory — confirm 'metaplasticity.cfg' is found (or created) there.
import nengo_gui.jupyter
nengo_gui.jupyter.InlineGUI(model, cfg='metaplasticity.cfg')

error: global flags not at the start of the expression at position 399 (line 7, column 33)