Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add association matrix learning rule (AML). #72

Merged
merged 1 commit into from May 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -9,6 +9,7 @@ dist
*.swp
.ipynb_checkpoints/
.cache
.pytest_cache/

# --- data files:
*.json
Expand Down
5 changes: 5 additions & 0 deletions CHANGES.rst
Expand Up @@ -22,7 +22,12 @@ Release history
0.1.1 (unreleased)
==================

**Added**

- Added the association matrix learning rule (AML)
to learn associations from cue vectors to target vectors
in a one-shot fashion without catastrophic forgetting.
(`#72 <https://github.com/nengo/nengo-extras/pull/72>`_)

0.1.0 (March 14, 2018)
======================
Expand Down
2 changes: 2 additions & 0 deletions docs/learning_rules.rst
Expand Up @@ -9,4 +9,6 @@ can be used.

.. default-role:: obj

.. autoclass:: nengo_extras.learning_rules.AML

.. autoclass:: nengo_extras.learning_rules.DeltaRule
111 changes: 110 additions & 1 deletion nengo_extras/learning_rules.py
Expand Up @@ -3,13 +3,122 @@
import numpy as np

from nengo.builder import Builder, Signal
from nengo.builder.operator import DotInc, ElementwiseInc, Reset, SimPyFunc
from nengo.builder.connection import get_eval_points, solve_for_decoders
from nengo.builder.operator import (
DotInc, ElementwiseInc, Operator, Reset, SimPyFunc)
from nengo.exceptions import ValidationError
from nengo.learning_rules import LearningRuleType
from nengo.params import EnumParam, FunctionParam, NumberParam
from nengo.synapses import Lowpass


class AML(LearningRuleType):
    r"""Association matrix learning rule (AML).

    Learns outer product association matrices in a one-shot fashion without
    catastrophic forgetting.

    The pre-synaptic ensemble supplies the cue. The signal fed into the
    learning rule is partitioned as follows:

    * ``error[0]`` is a scaling factor applied to the learning rate,
    * ``error[1]`` is the decay rate (i.e., the weights are multiplied with
      this value in every time step),
    * ``error[2:]`` is the target vector.

    In every time step the decoders are updated according to::

        decoders[...] *= error[1]  # decay
        decoders[...] += alpha * error[0] * error[2:, None] * np.dot(
            pre, base_decoders.T)

    with *alpha* being the learning rate adjusted for *dt* and
    *base_decoders* being the decoder matrix for decoding the identity from
    the pre-ensemble.

    Parameters
    ----------
    d : int
        Dimensionality of input and output vectors (error signal will be
        *d+2*).
    learning_rate : float, optional
        Learning rate (increase of dot product similarity per second).
    """

    modifies = 'decoders'
    error_type = 'decoded'

    def __init__(self, d, learning_rate=1.):
        # The learning rule input carries two control channels (learning
        # rate scale and decay) in front of the d-dimensional target.
        n_control_channels = 2
        super(AML, self).__init__(
            learning_rate, size_in=d + n_control_channels)


class SimAML(Operator):
    """Operator applying the AML decoder update in every time step.

    Reads the ``pre`` and ``error`` signals and updates the ``decoders``
    signal in place.

    Parameters
    ----------
    learning_rate : float
        Learning rate; it is multiplied by *dt* when the step function is
        created.
    base_decoders : ndarray
        Matrix whose transpose is multiplied with the ``pre`` signal in the
        update.
    pre : Signal
        Signal providing the cue.
    error : Signal
        Signal laid out as ``[scale, decay, target...]``.
    decoders : Signal
        Signal holding the decoders that are learned.
    tag : str, optional
        Passed on to the ``Operator`` constructor.
    """

    def __init__(self, learning_rate, base_decoders, pre, error, decoders,
                 tag=None):
        super(SimAML, self).__init__(tag=tag)

        self.base_decoders = base_decoders
        self.learning_rate = learning_rate

        self.reads = [pre, error]
        self.updates = [decoders]
        self.sets = []
        self.incs = []

    @property
    def pre(self):
        """Signal providing the cue (first read signal)."""
        return self.reads[0]

    @property
    def error(self):
        """Signal providing scale, decay and target (second read signal)."""
        return self.reads[1]

    @property
    def decoders(self):
        """Signal holding the learned decoders (the only updated signal)."""
        return self.updates[0]

    def make_step(self, signals, dt, rng):
        """Return the function performing one AML update per call."""
        cue = signals[self.pre]
        control = signals[self.error]
        weights = signals[self.decoders]
        identity_decoders = self.base_decoders
        rate = self.learning_rate * dt

        def step_assoc_learning():
            scale, decay = control[0], control[1]
            target_column = control[2:][:, None]
            # Decay the existing associations first, then add the outer
            # product of the target and the projected cue.
            weights[...] *= decay
            weights[...] += rate * scale * target_column * np.dot(
                cue, identity_decoders.T)

        return step_assoc_learning


@Builder.register(AML)
def build_aml(model, aml, rule):
    """Build the signals and the `SimAML` operator for an `AML` rule.

    Creates the learning rule input signal, solves for the decoders that
    decode the evaluation points themselves (identity) from the
    pre-synaptic object and adds the operator that updates the connection
    weights.

    Parameters
    ----------
    model
        Builder model to which signals and operators are added.
    aml : AML
        Learning rule type providing the learning rate.
    rule
        Learning rule instance; provides the connection and the input size.
    """
    conn = rule.connection
    pre_obj = conn.pre_obj
    rng = np.random.RandomState(model.seeds[conn])

    # Input of the learning rule: [scale, decay, target...].
    error = Signal(np.zeros(rule.size_in), name="aml:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error

    pre = model.sig[pre_obj]['in']
    decoders = model.sig[conn]['weights']

    pre_params = model.params[pre_obj]
    encoders = pre_params.encoders
    gain = pre_params.gain
    bias = pre_params.bias

    # The targets are the evaluation points themselves, i.e. the solver is
    # asked for decoders of the identity function.
    eval_points = get_eval_points(model, conn, rng)
    x = np.dot(eval_points, encoders.T)

    if model.seeded[conn]:
        solver = model.decoder_cache.wrap_solver(solve_for_decoders)
    else:
        solver = solve_for_decoders
    base_decoders, _ = solver(conn, gain, bias, x, eval_points, rng=rng)

    model.add_op(SimAML(
        aml.learning_rate, base_decoders, pre, error, decoders))


class DeltaRuleFunctionParam(FunctionParam):
function_test_size = 8 # arbitrary size to test function

Expand Down
156 changes: 156 additions & 0 deletions nengo_extras/ocl.py
@@ -0,0 +1,156 @@
import mako
import nengo_ocl
import numpy as np
import pyopencl as cl


def plan_aml_decode(queue, pre, base_decoders, decoded, tag=None):
    """Plan the kernel that projects each cue through its base decoders.

    For every op ``k`` and every row ``i`` of ``base_decoders[k]`` the
    kernel stores ``sum_s base_decoders[k][i, s] * pre[k][s]`` in
    ``decoded[k][i]``.

    Parameters
    ----------
    queue
        OpenCL command queue; its context is used to build the program.
    pre
        Cue vectors, one per op. Each must have as many entries as the
        matching base decoder matrix has columns.
    base_decoders
        Base decoder matrices, one per op.
    decoded
        Output vectors, one per op, with one entry per base decoder row.
    tag : optional
        Passed on to the created plan.

    Returns
    -------
    nengo_ocl.plan.Plan
        Plan named ``"cl_aml_decode"``.
    """
    assert pre.ctype == base_decoders.ctype == decoded.ctype
    assert len(pre) == len(base_decoders) == len(decoded)
    assert np.all(pre.shape0s == base_decoders.shape1s)
    assert np.all(base_decoders.shape0s == decoded.shape0s)

    template = '''
__kernel void aml_decode(
__global const int *ds,
__global const int *ns,
__global const int *pre_stride0s,
__global const int *pre_starts,
__global const ${type} *pre_data,
__global const int *base_decoders_stride0s,
__global const int *base_decoders_starts,
__global const ${type} *base_decoders_data,
__global const int *decoded_stride0s,
__global const int *decoded_starts,
__global ${type} *decoded_data
) {
const int i = get_global_id(0);
const int k = get_global_id(1);

const int d = ds[k];
const int n = ns[k];

__global const ${type} *pre = pre_data + pre_starts[k];
__global const ${type} *base_decoders = base_decoders_data +
base_decoders_starts[k];
__global ${type} *decoded = decoded_data + decoded_starts[k];

if (i < n) {
${type} x = 0.;
for (int s = 0; s < d; ++s) {
x += base_decoders[i * base_decoders_stride0s[k] + s] * pre[s];
}
decoded[i] = x;
}
}
'''

    rendered = mako.template.Template(
        template, output_encoding='ascii').render(type=pre.ctype)
    source = nengo_ocl.utils.as_ascii(rendered)

    # Argument order must match the kernel signature above.
    full_args = (
        base_decoders.cl_shape1s, base_decoders.cl_shape0s,
        pre.cl_stride0s, pre.cl_starts, pre.cl_buf,
        base_decoders.cl_stride0s, base_decoders.cl_starts,
        base_decoders.cl_buf,
        decoded.cl_stride0s, decoded.cl_starts, decoded.cl_buf,
    )
    kernel = cl.Program(queue.context, source).build().aml_decode
    kernel.set_args(*[arg.data for arg in full_args])

    # One work-item per (row of the largest base decoder matrix, op).
    global_size = (base_decoders.shape0s.max(), len(pre))
    plan = nengo_ocl.plan.Plan(
        queue, kernel, global_size, lsize=None, name="cl_aml_decode",
        tag=tag)
    plan.full_args = full_args  # prevent garbage collection
    plan.flops_per_call = np.sum(
        base_decoders.shape0s * base_decoders.shape1s * 2 +
        base_decoders.shape1s * 2)
    plan.bw_per_call = decoded.nbytes + pre.nbytes + base_decoders.nbytes

    return plan


def plan_aml(queue, error, decoders, alpha, decoded, tag=None):
    """Plan the kernel that applies the AML update to the decoders.

    For every op ``k`` and every element ``(i, j)`` of ``decoders[k]`` the
    kernel performs::

        decoders[k][i, j] *= error[k][1]
        decoders[k][i, j] += (alpha[k] * error[k][0] * error[k][i + 2] *
                              decoded[k][j])

    Parameters
    ----------
    queue
        OpenCL command queue; its context is used to build the program.
    error
        Learning rule inputs, one per op, laid out as
        ``[scale, decay, target...]``. Each must have two more entries than
        the matching decoder matrix has rows.
    decoders
        Decoder matrices, one per op; updated in place.
    alpha
        One value per op that scales the increment.
    decoded
        Vectors, one per op, indexed by decoder column.
    tag : optional
        Passed on to the created plan.

    Returns
    -------
    nengo_ocl.plan.Plan
        Plan named ``"cl_aml"``.
    """
    assert error.ctype == decoders.ctype == alpha.ctype == decoded.ctype
    assert len(error) == len(decoders) == len(alpha) == len(decoded)
    assert np.all(error.shape0s - 2 == decoders.shape0s)

    text = '''
__kernel void aml(
__global const int *ds,
__global const int *ns,
__global const int *error_stride0s,
__global const int *error_starts,
__global const ${type} *error_data,
__global const int *decoders_stride0s,
__global const int *decoders_starts,
__global ${type} *decoders_data,
__global const int *decoded_stride0s,
__global const int *decoded_starts,
__global const ${type} *decoded_data,
__global const ${type} *alphas
) {
const int ij = get_global_id(0);
const int k = get_global_id(1);

const int d = ds[k];
const int n = ns[k];
const int i = ij / n;
const int j = ij % n;

// The global size is padded to the largest op, so work-items with i >= d
// have no decoder element. They must not load from any buffer either,
// because error[i + 2] would lie beyond the end of this op's error vector.
if (i < d) {
__global ${type} *decoders = decoders_data + decoders_starts[k];
const ${type} scale = error_data[error_starts[k]];
const ${type} decay = error_data[error_starts[k] + 1];
const ${type} error = error_data[error_starts[k] + i + 2];
const ${type} decoded = decoded_data[decoded_starts[k] + j];
const ${type} alpha = alphas[k];

decoders[i * decoders_stride0s[k] + j] *= decay;
decoders[i * decoders_stride0s[k] + j] += alpha * scale * error *
decoded;
}
}
'''

    textconf = dict(type=error.ctype)
    text = nengo_ocl.utils.as_ascii(mako.template.Template(
        text, output_encoding='ascii').render(**textconf))

    # Argument order must match the kernel signature above.
    full_args = (
        decoders.cl_shape0s, decoders.cl_shape1s,
        error.cl_stride0s, error.cl_starts, error.cl_buf,
        decoders.cl_stride0s, decoders.cl_starts, decoders.cl_buf,
        decoded.cl_stride0s, decoded.cl_starts, decoded.cl_buf,
        alpha,
    )
    _fn = cl.Program(queue.context, text).build().aml
    _fn.set_args(*(arr.data for arr in full_args))

    lsize = None
    # One work-item per (element of the largest decoder matrix, op); ops
    # with fewer elements rely on the bounds check inside the kernel.
    gsize = (decoders.sizes.max(), len(error))
    plan = nengo_ocl.plan.Plan(
        queue, _fn, gsize, lsize=lsize, name="cl_aml", tag=tag)
    plan.full_args = full_args  # prevent garbage collection
    plan.flops_per_call = np.sum(2 * (error.shape0s * decoded.shape0s))
    plan.bw_per_call = (
        decoded.nbytes + error.nbytes + alpha.nbytes + decoders.nbytes)

    return plan


class AmlSimulator(nengo_ocl.Simulator):
    """OpenCL simulator extended with a planner for ``SimAML`` operators."""

    def plan_SimAML(self, ops):
        """Create the OpenCL plans for a group of ``SimAML`` operators.

        Returns a list with two plans: the first projects each cue through
        the base decoders into a scratch array, the second uses that
        scratch array to update the decoders.
        """
        def gather(signals):
            # Views into the simulator's data for the given signals.
            return self.all_data[[self.sidx[sig] for sig in signals]]

        alpha = self.Array([op.learning_rate * self.model.dt for op in ops])
        base_decoders = self.RaggedArray(
            [op.base_decoders for op in ops], dtype=np.float32)
        pre = gather(op.pre for op in ops)
        error = gather(op.error for op in ops)
        decoders = gather(op.decoders for op in ops)

        # Scratch space with one entry per decoder column of each op.
        scratch = [np.zeros(op.decoders.shape[1]) for op in ops]
        decoded = self.RaggedArray(scratch, dtype=np.float32)

        decode_plan = plan_aml_decode(
            self.queue, pre, base_decoders, decoded)
        update_plan = plan_aml(self.queue, error, decoders, alpha, decoded)
        return [decode_plan, update_plan]
69 changes: 68 additions & 1 deletion nengo_extras/tests/test_learning_rules.py
Expand Up @@ -5,7 +5,74 @@
import pytest


from nengo_extras.learning_rules import DeltaRule
from nengo_extras.learning_rules import AML, DeltaRule


@pytest.mark.slow
def test_aml(Simulator, seed, rng, plt):
    """Check that AML recalls associations learned in a single presentation.

    During the first ``n_items`` intervals each cue ``In<i>`` is presented
    together with the target ``Out<i>``. Afterwards the cues are presented
    again with a zero target, and the output has to be similar to the
    matching target.
    """
    d = 32
    vocab = nengo.spa.Vocabulary(d, rng=rng)
    n_items = 3
    item_duration = 1.
    # Time at the start of each recall interval that is excluded from the
    # check.
    settle_time = 0.3
    assert item_duration > settle_time

    def err_stimulus(t):
        # Select by item index rather than comparing t against the end of
        # the learning phase: at exactly t == n_items * item_duration the
        # index is already n_items, which must not be parsed as an item.
        item = int(t // item_duration)
        if item < n_items:
            v = vocab.parse('Out' + str(item)).v
        else:
            v = np.zeros(d)
        # Learning rate scale and decay are both fixed to 1.
        return np.concatenate(((1., 1.), v))

    def pre_stimulus(t):
        return vocab.parse('In' + str(int((t // item_duration) % n_items))).v

    with nengo.Network(seed=seed) as model:
        pre = nengo.Ensemble(50 * d, d)
        post = nengo.Node(size_in=d)
        c = nengo.Connection(
            pre, post, learning_rule_type=AML(d),
            function=lambda x: np.zeros(d))
        err = nengo.Node(err_stimulus)
        inp = nengo.Node(pre_stimulus)
        nengo.Connection(inp, pre)
        nengo.Connection(err, c.learning_rule)
        p_pre = nengo.Probe(pre, synapse=0.01)
        p_post = nengo.Probe(post, synapse=0.01)
        p_err = nengo.Probe(err, synapse=0.01)

    with Simulator(model) as sim:
        sim.run(2 * n_items * item_duration)

    vocab_out = vocab.create_subset(['Out' + str(i) for i in range(n_items)])
    vocab_in = vocab.create_subset(['In' + str(i) for i in range(n_items)])

    t = sim.trange()
    similarity = nengo.spa.similarity(sim.data[p_post], vocab_out)

    fig = plt.figure()

    ax1 = fig.add_subplot(3, 1, 1)
    ax1.plot(t, nengo.spa.similarity(sim.data[p_pre], vocab_in))
    ax1.set_ylabel(r"Cue $\mathbf{u}(t)$")

    ax2 = fig.add_subplot(3, 1, 2, sharex=ax1, sharey=ax1)
    ax2.plot(t, nengo.spa.similarity(sim.data[p_err][:, 2:], vocab_out))
    ax2.set_ylabel(r"Target $\mathbf{v}(t)$")

    ax3 = fig.add_subplot(3, 1, 3, sharex=ax1, sharey=ax1)
    ax3.plot(t, similarity)
    ax3.set_ylabel("AML output")

    ax1.set_ylim(bottom=0.)

    for ax in [ax1, ax2, ax3]:
        ax.label_outer()
    fig.tight_layout()

    for i in range(n_items):
        # Recall interval of item i, without the initial settling time.
        start = (n_items + i) * item_duration + settle_time
        end = (n_items + i + 1) * item_duration
        assert np.all(similarity[(start < t) & (t <= end), i] > 0.8)


@pytest.mark.parametrize('post_target', [None, 'in', 'out'])
Expand Down