Skip to content

Commit

Permalink
Fix bug with multiple indices in post ObjView
Browse files Browse the repository at this point in the history
Also test that boolean indexing works. With this, advanced indexing
is fully supported for ObjViews in connections.

Fixes #947.
  • Loading branch information
AllenHW authored and tbekolay committed Sep 25, 2017
1 parent 51ff3ed commit 9bc1e22
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Release History
(`#1340 <https://github.com/nengo/nengo/pull/1340>`_)
- Fixed an issue in which ``ShapeParam`` would always store ``None``.
(`#1342 <https://github.com/nengo/nengo/pull/1342>`_)
- Fixed an issue in which multiple identical indices in a slice were ignored.
(`#947 <https://github.com/nengo/nengo/issues/947>`_)

**Deprecated**

Expand Down
2 changes: 1 addition & 1 deletion nengo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self, obj, key=slice(None)):
# Node.size_in != size_out, so one of these can be invalid
try:
self.size_in = np.arange(self.obj.size_in)[key].size
except IndexError:
except (IndexError, ValueError):
self.size_in = None
try:
self.size_out = np.arange(self.obj.size_out)[key].size
Expand Down
15 changes: 14 additions & 1 deletion nengo/builder/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import numpy as np

import nengo.utils.numpy as npext
from nengo.utils.compat import is_array_like
from nengo.utils.connection import function_name
from nengo.exceptions import BuildError, SimulationError

Expand Down Expand Up @@ -386,9 +387,21 @@ def make_step(self, signals, dt, rng):
dst_slice = self.dst_slice if self.dst_slice is not None else Ellipsis
inc = self.inc

if inc:
dst_slice = (np.asarray(dst_slice) if is_array_like(dst_slice) else
dst_slice)
# There are repeated indices in dst_slice, special case
repeats = (is_array_like(dst_slice) and dst_slice.dtype != np.bool and
len(np.unique(dst_slice)) < len(dst_slice))
if inc and repeats:
def step_copy():
np.add.at(dst, dst_slice, src[src_slice])
elif inc:
def step_copy():
dst[dst_slice] += src[src_slice]
elif repeats:
raise BuildError("%s: Cannot have repeated indices in "
"``dst_slice`` when copy is not an increment"
% self)
else:
def step_copy():
dst[dst_slice] = src[src_slice]
Expand Down
62 changes: 62 additions & 0 deletions nengo/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,68 @@ def test_slicing_function(Simulator, plt, seed):
assert np.allclose(w, y, atol=0.1)


def test_list_indexing(Simulator, plt, seed):

with nengo.Network(seed=seed) as model:
u = nengo.Node([-1, 1])
a = nengo.Ensemble(40, dimensions=1)
b = nengo.Ensemble(40, dimensions=1, radius=2.2)
c = nengo.Ensemble(80, dimensions=2, radius=1.3)
d = nengo.Ensemble(80, dimensions=2, radius=1.3)
nengo.Connection(u[[0, 1]], a[[0, 0]])
nengo.Connection(u[[1, 1]], b[[0, 0]])
nengo.Connection(u[[0, 1]], c[[0, 1]])
nengo.Connection(u[[1, 1]], d[[0, 1]])

a_probe = nengo.Probe(a, synapse=0.03)
b_probe = nengo.Probe(b, synapse=0.03)
c_probe = nengo.Probe(c, synapse=0.03)
d_probe = nengo.Probe(d, synapse=0.03)

with Simulator(model) as sim:
sim.run(0.2)

t = sim.trange()
a_data = sim.data[a_probe]
b_data = sim.data[b_probe]
c_data = sim.data[c_probe]
d_data = sim.data[d_probe]

line = plt.plot(t, a_data)
plt.axhline(0, color=line[0].get_color())
assert np.allclose(a_data[t > 0.15], [0], atol=0.1)
line = plt.plot(t, b_data)
plt.axhline(2, color=line[0].get_color())
assert np.allclose(b_data[t > 0.15], [2], atol=0.1)
line = plt.plot(t, c_data)
plt.axhline(-1, color=line[0].get_color())
assert np.allclose(c_data[t > 0.15], [-1, 1], atol=0.1)
line = plt.plot(t, d_data)
plt.axhline(1, color=line[1].get_color())
assert np.allclose(d_data[t > 0.15], [1, 1], atol=0.1)


def test_boolean_indexing(Simulator, rng, plt):
D = 10
mu = np.arange(D) % 2 == 0
mv = np.arange(D) % 2 == 1
x = rng.uniform(-1, 1, size=D)
y = np.zeros(D)
y[mv] = x[mu]

with nengo.Network() as model:
u = nengo.Node(x)
v = nengo.Node(size_in=D)
nengo.Connection(u[mu], v[mv], synapse=None)
v_probe = nengo.Probe(v)

with Simulator(model) as sim:
sim.run(0.01)

plt.plot(sim.trange(), sim.data[v_probe])
assert np.allclose(sim.data[v_probe][1:], y, atol=1e-5, rtol=1e-3)


def test_set_weight_solver():
with nengo.Network():
a = nengo.Ensemble(10, 2)
Expand Down
3 changes: 3 additions & 0 deletions nengo/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
return Mock()

def __getitem__(self, key):
return Mock()

def __mul__(self, other):
return 1.0

Expand Down

0 comments on commit 9bc1e22

Please sign in to comment.