diff --git a/.gitignore b/.gitignore
index 72fff0e..603b5b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ dist
 *.swp
 .ipynb_checkpoints/
 .cache
+.pytest_cache/

 # --- data files:
 *.json
diff --git a/CHANGES.rst b/CHANGES.rst
index 5899a4e..bb7a0ec 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -22,7 +22,12 @@ Release history
 0.1.1 (unreleased)
 ==================

+**Added**
+
+- Added the association matrix learning rule (AML) to learn associations
+  from cue vectors to target vectors in a one-shot fashion without
+  catastrophic forgetting. (`#72 `_)

 0.1.0 (March 14, 2018)
 ======================
diff --git a/docs/learning_rules.rst b/docs/learning_rules.rst
index cf4455f..a47a49d 100644
--- a/docs/learning_rules.rst
+++ b/docs/learning_rules.rst
@@ -9,4 +9,6 @@ can be used.

 .. default-role:: obj

+.. autoclass:: nengo_extras.learning_rules.AML
+
 .. autoclass:: nengo_extras.learning_rules.DeltaRule
diff --git a/nengo_extras/learning_rules.py b/nengo_extras/learning_rules.py
index dc0e86e..a8827a8 100644
--- a/nengo_extras/learning_rules.py
+++ b/nengo_extras/learning_rules.py
@@ -3,13 +3,122 @@
 import numpy as np

 from nengo.builder import Builder, Signal
-from nengo.builder.operator import DotInc, ElementwiseInc, Reset, SimPyFunc
+from nengo.builder.connection import get_eval_points, solve_for_decoders
+from nengo.builder.operator import (
+    DotInc, ElementwiseInc, Operator, Reset, SimPyFunc)
 from nengo.exceptions import ValidationError
 from nengo.learning_rules import LearningRuleType
 from nengo.params import EnumParam, FunctionParam, NumberParam
 from nengo.synapses import Lowpass


+class AML(LearningRuleType):
+    r"""Association matrix learning rule (AML).
+
+    Enables one-shot learning of outer-product association matrices without
+    catastrophic forgetting.
+
+    The cue is provided by the pre-synaptic ensemble. The error signal is
+    split up: ``error[0]`` provides a scaling factor for the learning rate,
+    ``error[1]`` provides a decay rate (i.e., the weights are multiplied by
+    this value in every time step), and ``error[2:]`` provides the target
+    vector.
+
+    The update is given by::
+
+        decoders[...] *= error[1]  # decay
+        decoders[...] += alpha * error[0] * error[2:, None] * np.dot(
+            pre, base_decoders.T)
+
+    where *alpha* is the learning rate adjusted for *dt* and *base_decoders*
+    is the decoder matrix for decoding the identity function from the
+    pre-ensemble.
+
+    Parameters
+    ----------
+    d : int
+        Dimensionality of input and output vectors (the error signal will
+        have dimensionality *d + 2*).
+    learning_rate : float, optional
+        Learning rate (increase of dot-product similarity per second).
+    """
+
+    error_type = 'decoded'
+    modifies = 'decoders'
+
+    def __init__(self, d, learning_rate=1.):
+        super(AML, self).__init__(learning_rate, size_in=d + 2)
+
+
+class SimAML(Operator):
+    def __init__(self, learning_rate, base_decoders, pre, error, decoders,
+                 tag=None):
+        super(SimAML, self).__init__(tag=tag)
+
+        self.learning_rate = learning_rate
+        self.base_decoders = base_decoders
+
+        self.sets = []
+        self.incs = []
+        self.reads = [pre, error]
+        self.updates = [decoders]
+
+    def make_step(self, signals, dt, rng):
+        base_decoders = self.base_decoders
+        pre = signals[self.pre]
+        error = signals[self.error]
+        decoders = signals[self.decoders]
+        alpha = self.learning_rate * dt
+
+        def step_assoc_learning():
+            scale = error[0]
+            decay = error[1]
+            target = error[2:]
+            decoders[...] *= decay
+            decoders[...] += alpha * scale * target[:, None] * np.dot(
+                pre, base_decoders.T)
+
+        return step_assoc_learning
+
+    @property
+    def pre(self):
+        return self.reads[0]
+
+    @property
+    def error(self):
+        return self.reads[1]
+
+    @property
+    def decoders(self):
+        return self.updates[0]
+
+
+@Builder.register(AML)
+def build_aml(model, aml, rule):
+    conn = rule.connection
+    rng = np.random.RandomState(model.seeds[conn])
+
+    error = Signal(np.zeros(rule.size_in), name="aml:error")
+    model.add_op(Reset(error))
+    model.sig[rule]['in'] = error
+
+    pre = model.sig[conn.pre_obj]['in']
+    decoders = model.sig[conn]['weights']
+
+    encoders = model.params[conn.pre_obj].encoders
+    gain = model.params[conn.pre_obj].gain
+    bias = model.params[conn.pre_obj].bias
+
+    eval_points = get_eval_points(model, conn, rng)
+    targets = eval_points
+
+    x = np.dot(eval_points, encoders.T)
+
+    wrapped_solver = (model.decoder_cache.wrap_solver(solve_for_decoders)
+                      if model.seeded[conn] else solve_for_decoders)
+    base_decoders, _ = wrapped_solver(conn, gain, bias, x, targets, rng=rng)
+
+    model.add_op(SimAML(
+        aml.learning_rate, base_decoders, pre, error, decoders))
+
+
 class DeltaRuleFunctionParam(FunctionParam):
     function_test_size = 8  # arbitrary size to test function
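A minimal NumPy sketch of the update performed by ``SimAML`` above, with the
shapes produced by ``build_aml`` (``base_decoders`` decodes the identity
function, so it has shape ``(n_neurons, d)``, while the connection's decoder
signal has shape ``(d, n_neurons)``); the sizes and values are placeholders:

    import numpy as np

    d, n_neurons = 4, 200
    rng = np.random.RandomState(0)
    base_decoders = rng.randn(n_neurons, d) / n_neurons  # identity decoders
    decoders = np.zeros((d, n_neurons))  # learned association decoders

    alpha = 1e-3        # learning_rate * dt
    pre = rng.randn(d)  # input signal to the pre ensemble (the cue)
    error = np.concatenate(([1., 1.], rng.randn(d)))  # [scale, decay, target]

    scale, decay, target = error[0], error[1], error[2:]
    decoders *= decay  # decay existing associations (no-op for decay == 1)
    # outer-product update: target (d,) with the cue mapped into neuron space
    decoders += alpha * scale * target[:, None] * np.dot(pre, base_decoders.T)
    assert decoders.shape == (d, n_neurons)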
diff --git a/nengo_extras/ocl.py b/nengo_extras/ocl.py
new file mode 100644
index 0000000..f4ba927
--- /dev/null
+++ b/nengo_extras/ocl.py
@@ -0,0 +1,156 @@
+import mako.template
+import nengo_ocl
+import numpy as np
+import pyopencl as cl
+
+
+def plan_aml_decode(queue, pre, base_decoders, decoded, tag=None):
+    assert pre.ctype == base_decoders.ctype == decoded.ctype
+    assert len(pre) == len(base_decoders) == len(decoded)
+    assert np.all(pre.shape0s == base_decoders.shape1s)
+    assert np.all(base_decoders.shape0s == decoded.shape0s)
+
+    text = '''
+        __kernel void aml_decode(
+            __global const int *ds,
+            __global const int *ns,
+            __global const int *pre_stride0s,
+            __global const int *pre_starts,
+            __global const ${type} *pre_data,
+            __global const int *base_decoders_stride0s,
+            __global const int *base_decoders_starts,
+            __global const ${type} *base_decoders_data,
+            __global const int *decoded_stride0s,
+            __global const int *decoded_starts,
+            __global ${type} *decoded_data
+        ) {
+            const int i = get_global_id(0);
+            const int k = get_global_id(1);
+
+            const int d = ds[k];
+            const int n = ns[k];
+
+            __global const ${type} *pre = pre_data + pre_starts[k];
+            __global const ${type} *base_decoders = base_decoders_data +
+                base_decoders_starts[k];
+            __global ${type} *decoded = decoded_data + decoded_starts[k];
+
+            if (i < n) {
+                ${type} x = 0.;
+                for (int s = 0; s < d; ++s) {
+                    x += base_decoders[i * base_decoders_stride0s[k] + s] * pre[s];
+                }
+                decoded[i] = x;
+            }
+        }
+        '''
+
+    textconf = dict(type=pre.ctype)
+    text = nengo_ocl.utils.as_ascii(mako.template.Template(
+        text, output_encoding='ascii').render(**textconf))
+
+    full_args = (
+        base_decoders.cl_shape1s, base_decoders.cl_shape0s,
+        pre.cl_stride0s, pre.cl_starts, pre.cl_buf,
+        base_decoders.cl_stride0s, base_decoders.cl_starts,
+        base_decoders.cl_buf,
+        decoded.cl_stride0s, decoded.cl_starts, decoded.cl_buf,
+    )
+    _fn = cl.Program(queue.context, text).build().aml_decode
+    _fn.set_args(*(arr.data for arr in full_args))
+
+    lsize = None
+    gsize = (base_decoders.shape0s.max(), len(pre))
+    plan = nengo_ocl.plan.Plan(
+        queue, _fn, gsize, lsize=lsize, name="cl_aml_decode", tag=tag)
+    plan.full_args = full_args  # prevent garbage collection
+    plan.flops_per_call = np.sum(
+        base_decoders.shape0s * base_decoders.shape1s * 2 +
+        base_decoders.shape1s * 2)
+    plan.bw_per_call = decoded.nbytes + pre.nbytes + base_decoders.nbytes
+
+    return plan
+
+
+def plan_aml(queue, error, decoders, alpha, decoded, tag=None):
+    assert error.ctype == decoders.ctype == alpha.ctype == decoded.ctype
+    assert len(error) == len(decoders) == len(alpha) == len(decoded)
+    assert np.all(error.shape0s - 2 == decoders.shape0s)
+
+    text = '''
+        __kernel void aml(
+            __global const int *ds,
+            __global const int *ns,
+            __global const int *error_stride0s,
+            __global const int *error_starts,
+            __global const ${type} *error_data,
+            __global const int *decoders_stride0s,
+            __global const int *decoders_starts,
+            __global ${type} *decoders_data,
+            __global const int *decoded_stride0s,
+            __global const int *decoded_starts,
+            __global const ${type} *decoded_data,
+            __global const ${type} *alphas
+        ) {
+            const int ij = get_global_id(0);
+            const int k = get_global_id(1);
+
+            const int d = ds[k];
+            const int n = ns[k];
+            const int i = ij / n;
+            const int j = ij % n;
+
+            __global ${type} *decoders = decoders_data + decoders_starts[k];
+            const ${type} scale = error_data[error_starts[k]];
+            const ${type} decay = error_data[error_starts[k] + 1];
+            const ${type} error = error_data[error_starts[k] + i + 2];
+            const ${type} decoded = decoded_data[decoded_starts[k] + j];
+            const ${type} alpha = alphas[k];
+
+            if (i < d) {
+                decoders[i * decoders_stride0s[k] + j] *= decay;
+                decoders[i * decoders_stride0s[k] + j] += alpha * scale * error *
+                    decoded;
+            }
+        }
+        '''
+
+    textconf = dict(type=error.ctype)
+    text = nengo_ocl.utils.as_ascii(mako.template.Template(
+        text, output_encoding='ascii').render(**textconf))
+
+    full_args = (
+        decoders.cl_shape0s, decoders.cl_shape1s,
+        error.cl_stride0s, error.cl_starts, error.cl_buf,
+        decoders.cl_stride0s, decoders.cl_starts, decoders.cl_buf,
+        decoded.cl_stride0s, decoded.cl_starts, decoded.cl_buf,
+        alpha,
+    )
+    _fn = cl.Program(queue.context, text).build().aml
+    _fn.set_args(*(arr.data for arr in full_args))
+
+    lsize = None
+    gsize = (decoders.sizes.max(), len(error))
+    plan = nengo_ocl.plan.Plan(
+        queue, _fn, gsize, lsize=lsize, name="cl_aml", tag=tag)
+    plan.full_args = full_args  # prevent garbage collection
+    plan.flops_per_call = np.sum(2 * (error.shape0s * decoded.shape0s))
+    plan.bw_per_call = (
+        decoded.nbytes + error.nbytes + alpha.nbytes + decoders.nbytes)
+
+    return plan
+
+
+class AmlSimulator(nengo_ocl.Simulator):
+    def plan_SimAML(self, ops):
+        alpha = self.Array([op.learning_rate * self.model.dt for op in ops])
+        base_decoders = self.RaggedArray(
+            [op.base_decoders for op in ops], dtype=np.float32)
+        pre = self.all_data[[self.sidx[op.pre] for op in ops]]
+        error = self.all_data[[self.sidx[op.error] for op in ops]]
+        decoders = self.all_data[[self.sidx[op.decoders] for op in ops]]
+        decoded = self.RaggedArray(
+            [np.zeros(op.decoders.shape[1]) for op in ops], dtype=np.float32)
+        return [
+            plan_aml_decode(self.queue, pre, base_decoders, decoded),
+            plan_aml(self.queue, error, decoders, alpha, decoded)]
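A sketch of how the OpenCL path above would be driven; it assumes a working
OpenCL context and a ``model`` containing a connection with an ``AML``
learning rule (as in the test below):

    from nengo_extras.ocl import AmlSimulator

    # AmlSimulator extends nengo_ocl.Simulator with a plan for SimAML ops,
    # so a model using AML runs unchanged on OpenCL.
    with AmlSimulator(model) as sim:
        sim.run(1.0)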
diff --git a/nengo_extras/tests/test_learning_rules.py b/nengo_extras/tests/test_learning_rules.py
index be7a63d..524dea8 100644
--- a/nengo_extras/tests/test_learning_rules.py
+++ b/nengo_extras/tests/test_learning_rules.py
@@ -5,7 +5,74 @@
 import pytest

-from nengo_extras.learning_rules import DeltaRule
+from nengo_extras.learning_rules import AML, DeltaRule
+
+
+@pytest.mark.slow
+def test_aml(Simulator, seed, rng, plt):
+    d = 32
+    vocab = nengo.spa.Vocabulary(d, rng=rng)
+    n_items = 3
+    item_duration = 1.
+
+    def err_stimulus(t):
+        if t <= n_items * item_duration:
+            v = vocab.parse('Out' + str(int(t // item_duration))).v
+        else:
+            v = np.zeros(d)
+        return np.concatenate(((1., 1.), v))
+
+    def pre_stimulus(t):
+        return vocab.parse('In' + str(int((t // item_duration) % n_items))).v
+
+    with nengo.Network(seed=seed) as model:
+        pre = nengo.Ensemble(50 * d, d)
+        post = nengo.Node(size_in=d)
+        c = nengo.Connection(
+            pre, post, learning_rule_type=AML(d),
+            function=lambda x: np.zeros(d))
+        err = nengo.Node(err_stimulus)
+        inp = nengo.Node(pre_stimulus)
+        nengo.Connection(inp, pre)
+        nengo.Connection(err, c.learning_rule)
+        p_pre = nengo.Probe(pre, synapse=0.01)
+        p_post = nengo.Probe(post, synapse=0.01)
+        p_err = nengo.Probe(err, synapse=0.01)
+
+    with Simulator(model) as sim:
+        sim.run(2 * n_items * item_duration)
+
+    vocab_out = vocab.create_subset(['Out' + str(i) for i in range(n_items)])
+    vocab_in = vocab.create_subset(['In' + str(i) for i in range(n_items)])
+
+    fig = plt.figure()
+
+    ax1 = fig.add_subplot(3, 1, 1)
+    ax1.plot(sim.trange(), nengo.spa.similarity(sim.data[p_pre], vocab_in))
+    ax1.set_ylabel(r"Cue $\mathbf{u}(t)$")
+
+    ax2 = fig.add_subplot(3, 1, 2, sharex=ax1, sharey=ax1)
+    ax2.plot(sim.trange(), nengo.spa.similarity(
+        sim.data[p_err][:, 2:], vocab_out))
+    ax2.set_ylabel(r"Target $\mathbf{v}(t)$")
+
+    ax3 = fig.add_subplot(3, 1, 3, sharex=ax1, sharey=ax1)
+    ax3.plot(sim.trange(), nengo.spa.similarity(sim.data[p_post], vocab_out))
+    ax3.set_ylabel("AML output")
+
+    ax1.set_ylim(bottom=0.)
+
+    for ax in [ax1, ax2, ax3]:
+        ax.label_outer()
+    fig.tight_layout()
+
+    t = sim.trange()
+    similarity = nengo.spa.similarity(sim.data[p_post], vocab_out)
+    assert item_duration > 0.3  # recall windows below skip the first 300 ms
+    for i in range(n_items):
+        start = (n_items + i) * item_duration + 0.3
+        end = (n_items + i + 1) * item_duration
+        assert np.all(similarity[(start < t) & (t <= end), i] > 0.8)


 @pytest.mark.parametrize('post_target', [None, 'in', 'out'])
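The ``[scale, decay, target]`` error convention exercised by ``err_stimulus``
above also makes it easy to gate learning explicitly. A hypothetical variant
(``vocab``, ``d``, and the timing window are placeholders) that stores a
single association during a brief window and then freezes the weights:

    def gated_error(t):
        if 0. < t <= 0.1:
            # scale=1, decay=1: store Out0 without decaying other entries
            return np.concatenate(((1., 1.), vocab.parse('Out0').v))
        # scale=0: no update; decay=1: keep stored associations intact
        return np.concatenate(((0., 1.), np.zeros(d)))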