Skip to content

Commit

Permalink
Reload tf_indices outside while in get_tensor
Browse files Browse the repository at this point in the history
Fixes #56
  • Loading branch information
drasmuss committed Sep 11, 2018
1 parent f44a115 commit bcf714d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Release History
1.2.1 (unreleased)
------------------

**Fixed**

- Fixed an error that was thrown when calling ``get_tensor`` on a ``Signal``
that was first initialized inside the Simulation while loop
(`#56 <https://github.com/nengo/nengo-dl/issues/56>`_)

1.2.0 (September 5, 2018)
-------------------------
Expand Down
5 changes: 5 additions & 0 deletions nengo_dl/tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,11 @@ def get_tensor(self, sig):
tensor_sig = self.signals[sig]

base = self.base_vars[tensor_sig.key][0]

if "while/" in tensor_sig.tf_indices.name:
# rebuild tf indices outside the while loop
tensor_sig._tf_indices = None

return tf.gather(base, tensor_sig.tf_indices)

def mark_signals(self):
Expand Down
21 changes: 21 additions & 0 deletions nengo_dl/tests/test_tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,24 @@ def test_create_signals_partition():
graph = dummies.TensorGraph(plan, tf.float32, 10)
graph.create_signals(sigs)
assert len(graph.base_arrays_init) == 4


def test_get_tensor(Simulator):
with nengo.Network() as net:
a = nengo.Node([1])
b = nengo.Ensemble(10, 1)
c = nengo.Connection(a, b.neurons, transform=np.arange(10)[:, None],
synapse=None)
p = nengo.Probe(c)

# build a signal probe so that the indices get loaded into the sim
# (checks that the indices reloading works properly)
nengo.Probe(c, "weights")

with Simulator(net) as sim:
tensor = sim.tensor_graph.get_tensor(sim.model.sig[c]["weights"])

assert np.allclose(sim.sess.run(tensor), np.arange(10)[:, None])

sim.run_steps(10)
assert np.allclose(sim.data[p], np.arange(10)[None, :])

0 comments on commit bcf714d

Please sign in to comment.