From d767ec47f8d09f245f289754687814b6449ef4a3 Mon Sep 17 00:00:00 2001 From: Aaron Voelker Date: Mon, 15 Apr 2019 13:36:11 -0400 Subject: [PATCH] Handle sliced probes Closes #205 and #206. --- nengo_loihi/builder/probe.py | 10 +++++-- nengo_loihi/tests/test_simulator.py | 41 +++++++++++++++++++++++++++++ nengo_loihi/tests/test_splitter.py | 20 ++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/nengo_loihi/builder/probe.py b/nengo_loihi/builder/probe.py index c4b972dcd..a2280a87f 100644 --- a/nengo_loihi/builder/probe.py +++ b/nengo_loihi/builder/probe.py @@ -1,5 +1,6 @@ import nengo from nengo import Ensemble, Connection, Node +from nengo.base import ObjView from nengo.connection import LearningRule from nengo.ensemble import Neurons from nengo.exceptions import BuildError @@ -74,10 +75,15 @@ def conn_probe(model, nengo_probe): model.seeded[conn] = model.seeded[nengo_probe] model.seeds[conn] = model.seeds[nengo_probe] + if isinstance(nengo_probe.target, ObjView): + target_obj = nengo_probe.target.obj + else: + target_obj = nengo_probe.target + d = conn.size_out - if isinstance(nengo_probe.target, Ensemble): + if isinstance(target_obj, Ensemble): # probed values are scaled by the target ensemble's radius - scale = nengo_probe.target.radius + scale = target_obj.radius w = np.diag(scale * np.ones(d)) weights = np.vstack([w, -w]) else: diff --git a/nengo_loihi/tests/test_simulator.py b/nengo_loihi/tests/test_simulator.py index 8d5c778c5..eab5fcc3f 100644 --- a/nengo_loihi/tests/test_simulator.py +++ b/nengo_loihi/tests/test_simulator.py @@ -801,6 +801,47 @@ def test_simulator_passthrough(remove_passthrough, Simulator): assert conn_y_d not in model.params +def test_slicing_bugs(Simulator, seed): + + n = 50 + with nengo.Network() as model: + a = nengo.Ensemble(n, 1, label="a") + p0 = nengo.Probe(a[0]) + p = nengo.Probe(a) + + with Simulator(model) as sim: + sim.run(0.1) + + assert np.allclose(sim.data[p0], sim.data[p]) + assert a in sim.model.params + assert a not in sim.model.host.params + + with nengo.Network() as model: + nengo_loihi.add_params(model) + + a = nengo.Ensemble(n, 1, label="a") + + b0 = nengo.Ensemble(n, 1, label="b0", seed=seed) + model.config[b0].on_chip = False + nengo.Connection(a[0], b0) + + b = nengo.Ensemble(n, 1, label="b", seed=seed) + model.config[b].on_chip = False + nengo.Connection(a, b) + + p0 = nengo.Probe(b0) + p = nengo.Probe(b) + + with Simulator(model) as sim: + sim.run(0.1) + + assert np.allclose(sim.data[p0], sim.data[p]) + assert a in sim.model.params + assert a not in sim.model.host.params + assert b not in sim.model.params + assert b in sim.model.host.params + + def test_network_unchanged(Simulator): with nengo.Network() as model: nengo.Ensemble(100, 1) diff --git a/nengo_loihi/tests/test_splitter.py b/nengo_loihi/tests/test_splitter.py index a30033799..6195550b5 100644 --- a/nengo_loihi/tests/test_splitter.py +++ b/nengo_loihi/tests/test_splitter.py @@ -262,6 +262,26 @@ def test_split_remove_passthrough(remove_passthrough): assert split.passthrough.to_add == set() +def test_sliced_passthrough_bug(): + with nengo.Network() as model: + add_params(model) + + a = nengo.Ensemble(1, 1, label="a") + passthrough = nengo.Node(size_in=1, label="passthrough") + + nengo.Connection(a, passthrough) + p = nengo.Probe(passthrough[0]) + + split = Split(model, remove_passthrough=True) + + assert len(split.passthrough.to_add) == 0 + assert len(split.passthrough.to_remove) == 0 + + assert split.on_chip(a) + assert not split.on_chip(passthrough) + assert not split.on_chip(p) + + def test_precompute_remove_passthrough(): with nengo.Network() as net: add_params(net)