Skip to content

Commit

Permalink
WIP: error connects to LearningRule
Browse files Browse the repository at this point in the history
- as discussed in #632
- TODO: cannot build error connection until post LearningRule has
  been built, but cannot build LearningRule until target connection
  has been built.
  • Loading branch information
hunse committed Feb 2, 2015
1 parent 66c463f commit 8365d4c
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 16 deletions.
3 changes: 2 additions & 1 deletion nengo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def __init__(self, optional=False, readonly=True,

def validate(self, instance, nengo_obj):
from nengo.ensemble import Neurons
if not isinstance(nengo_obj, (NengoObject, Neurons, ObjView)):
from nengo.connection import LearningRule
if not isinstance(nengo_obj, (NengoObject, Neurons, ObjView, LearningRule)):
raise ValueError("'%s' is not a Nengo object" % nengo_obj)
if self.nonzero_size_in and nengo_obj.size_in < 1:
raise ValueError("'%s' must have size_in > 0." % nengo_obj)
Expand Down
14 changes: 7 additions & 7 deletions nengo/builder/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from nengo.builder.operator import DotInc, ElementwiseInc, PreserveValue, Reset
from nengo.builder.signal import Signal
from nengo.builder.synapses import filtered_signal
from nengo.connection import Connection
from nengo.connection import Connection, LearningRule
from nengo.ensemble import Ensemble, Neurons
from nengo.neurons import Direct
from nengo.node import Node
Expand Down Expand Up @@ -140,12 +140,12 @@ def get_prepost_signal(is_pre):
if conn.synapse is not None:
signal = filtered_signal(model, conn, signal, conn.synapse)

if conn.modulatory:
# Make a new signal, effectively detaching from post
model.sig[conn]['out'] = Signal(
np.zeros(model.sig[conn]['out'].size),
name="%s.mod_output" % conn)
model.add_op(Reset(model.sig[conn]['out']))
# if conn.modulatory:
# # Make a new signal, effectively detaching from post
# model.sig[conn]['out'] = Signal(
# np.zeros(model.sig[conn]['out'].size),
# name="%s.mod_output" % conn)
# model.add_op(Reset(model.sig[conn]['out']))

# Add operator for transform
if isinstance(conn.post_obj, Neurons):
Expand Down
10 changes: 9 additions & 1 deletion nengo/builder/learning_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,17 @@ def build_oja(model, oja, rule):
@Builder.register(PES)
def build_pes(model, pes, rule):
# TODO: Filter activities

# Error signal
error = Signal(np.zeros(rule.size_in), name="PES:error")
model.sig[rule]['in'] = error # so we can connect into it



conn = rule.connection
activities = model.sig[conn.pre_obj]['out']
error = model.sig[pes.error_connection]['out']
# error = model.sig[pes.error_connection]['out']
# error = model.sig[pes.error_connection]['out']

scaled_error = Signal(np.zeros(error.shape),
name="PES:error * learning_rate")
Expand Down
24 changes: 23 additions & 1 deletion nengo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from nengo.base import NengoObject, NengoObjectParam, ObjView
from nengo.ensemble import Ensemble
from nengo.ensemble import Ensemble, Neurons
from nengo.learning_rules import LearningRuleType, LearningRuleTypeParam
from nengo.node import Node
from nengo.params import (Default, BoolParam, DistributionParam, FunctionParam,
Expand Down Expand Up @@ -339,3 +339,25 @@ def __str__(self):
@property
def probeable(self):
return self.learning_rule_type.probeable

@property
def size_in(self): # size of error signal
if self.learning_rule_type.error_type == 'none':
return 0
elif self.learning_rule_type.error_type == 'decoder':
if isinstance(self.connection.pre_obj, Neurons):
return self.connection.pre_obj.ensemble.dimensions
elif isinstance(self.connection.pre_obj, Ensemble):
return self.connection.size_in # sliced Ensemble dimensions
else:
raise ValueError("Cannot learn on '%s' type" % (
self.connection.pre_obj.__class__.__name__))
elif self.learning_rule_type.error_type == 'neuron':
raise NotImplementedError()
else:
raise ValueError("Unrecognized error type '%s'" % (
self.learning_rule_type.error_type))

# @property
# def size_out(self): # shape of learned decoders/weights
# return self.connection.size_out
8 changes: 5 additions & 3 deletions nengo/learning_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class LearningRuleType(object):
"""

learning_rate = NumberParam(low=0, low_open=True)
error_type = 'none'
probeable = []

def __init__(self, learning_rate=1e-6):
Expand Down Expand Up @@ -56,12 +57,11 @@ class PES(LearningRuleType):
The modulatory connection created to project the error signal.
"""

error_connection = ConnectionParam()
error_type = 'decoder'
modifies = ['Ensemble', 'Neurons']
probeable = ['scaled_error', 'activities']

def __init__(self, error_connection, learning_rate=1e-6):
self.error_connection = error_connection
def __init__(self, learning_rate=1e-6):
super(PES, self).__init__(learning_rate)


Expand Down Expand Up @@ -97,6 +97,7 @@ class BCM(LearningRuleType):
pre_tau = NumberParam(low=0, low_open=True)
post_tau = NumberParam(low=0, low_open=True)
theta_tau = NumberParam(low=0, low_open=True)
error_type = 'none'
modifies = ['Neurons']
probeable = ['theta', 'pre_filtered', 'post_filtered']

Expand Down Expand Up @@ -140,6 +141,7 @@ class Oja(LearningRuleType):
pre_tau = NumberParam(low=0, low_open=True)
post_tau = NumberParam(low=0, low_open=True)
beta = NumberParam(low=0)
error_type = 'none'
modifies = ['Neurons']
probeable = ['pre_filtered', 'post_filtered']

Expand Down
6 changes: 3 additions & 3 deletions nengo/tests/test_learning_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_pes_weights(Simulator, nl_nodirect, plt, seed, rng):
assert np.allclose(sim.data[se_p][tend] / rate, 0, atol=0.05)


def test_pes_decoders(Simulator, nl_nodirect, seed, plt):
def test_pes_decoders1(Simulator, nl_nodirect, seed, plt):
n = 200
learned_vector = [0.5, -0.5]

Expand All @@ -75,8 +75,8 @@ def test_pes_decoders(Simulator, nl_nodirect, seed, plt):
nengo.Connection(u, a)
nengo.Connection(u_learned, e, transform=-1)
nengo.Connection(u, e)
e_c = nengo.Connection(e, u_learned, modulatory=True)
conn = nengo.Connection(a, u_learned, learning_rule_type=PES(e_c))
conn = nengo.Connection(a, u_learned, learning_rule_type=PES())
nengo.Connection(e, conn.learning_rule)

u_learned_p = nengo.Probe(u_learned, synapse=0.1)
e_p = nengo.Probe(e, synapse=0.1)
Expand Down

0 comments on commit 8365d4c

Please sign in to comment.