diff --git a/CHANGES.rst b/CHANGES.rst index 632b023b7..5a9ce1daf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -21,6 +21,11 @@ Release History 1.2.1 (unreleased) ------------------ +**Fixed** + +- Fixed an error that was thrown when calling ``get_tensor`` on a ``Signal`` + that was first initialized inside the Simulation while loop + (`#56 `_) 1.2.0 (September 5, 2018) ------------------------- diff --git a/nengo_dl/tensor_graph.py b/nengo_dl/tensor_graph.py index 8ec0beeae..6b0ded442 100644 --- a/nengo_dl/tensor_graph.py +++ b/nengo_dl/tensor_graph.py @@ -709,6 +709,11 @@ def get_tensor(self, sig): tensor_sig = self.signals[sig] base = self.base_vars[tensor_sig.key][0] + + if "while/" in tensor_sig.tf_indices.name: + # rebuild tf indices outside the while loop + tensor_sig._tf_indices = None + return tf.gather(base, tensor_sig.tf_indices) def mark_signals(self): diff --git a/nengo_dl/tests/test_tensor_graph.py b/nengo_dl/tests/test_tensor_graph.py index 2a912ee69..8e59d98f2 100644 --- a/nengo_dl/tests/test_tensor_graph.py +++ b/nengo_dl/tests/test_tensor_graph.py @@ -395,3 +395,24 @@ def test_create_signals_partition(): graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs) assert len(graph.base_arrays_init) == 4 + + +def test_get_tensor(Simulator): + with nengo.Network() as net: + a = nengo.Node([1]) + b = nengo.Ensemble(10, 1) + c = nengo.Connection(a, b.neurons, transform=np.arange(10)[:, None], + synapse=None) + p = nengo.Probe(c) + + # build a signal probe so that the indices get loaded into the sim + # (checks that the indices reloading works properly) + nengo.Probe(c, "weights") + + with Simulator(net) as sim: + tensor = sim.tensor_graph.get_tensor(sim.model.sig[c]["weights"]) + + assert np.allclose(sim.sess.run(tensor), np.arange(10)[:, None]) + + sim.run_steps(10) + assert np.allclose(sim.data[p], np.arange(10)[None, :])