Skip to content

Commit

Permalink
Add warning for 1-step training if synapse=None
Browse files Browse the repository at this point in the history
Fixes #54
  • Loading branch information
drasmuss committed Sep 5, 2018
1 parent fc6b4f7 commit b2e7754
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Release History
1.2.1 (unreleased)
------------------

**Added**

- Added a warning if users run one-timestep training with a network containing
synaptic filters.


1.2.0 (September 5, 2018)
-------------------------
Expand Down
8 changes: 8 additions & 0 deletions nengo_dl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,14 @@ def train(self, inputs, targets, optimizer, n_epochs=1, objective="mse",
raise ValidationError(
"Network was created with inference_only=True, cannot "
"be trained", "inference_only")
if (n_steps == 1 and self.model.toplevel is not None and
any(x.synapse is not None for x in
(self.model.toplevel.all_connections +
list(targets.keys())))):
warnings.warn(
"Training for one timestep, but the network contains "
"synaptic filters (which will introduce at least a "
"one-timestep delay); did you mean to set synapse=None?")

# check for non-differentiable elements in graph
# utils.find_non_differentiable(
Expand Down
40 changes: 39 additions & 1 deletion nengo_dl/tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,7 @@ def test_inference_only(Simulator, neuron_type, seed):
# validation checks (can't do train/gradients in inference-only mode)
with pytest.raises(ValidationError):
sim2.train({a: np.zeros((1, 10, 1))}, {p: np.zeros((1, 10, 1))},
tf.train.GradientDescentOptimizer(1))
tf.train.GradientDescentOptimizer(1))

with pytest.raises(ValidationError):
sim2.check_gradients()
Expand All @@ -1302,3 +1302,41 @@ def test_dtype(Simulator):
with pytest.warns(DeprecationWarning):
with Simulator(None, dtype=tf.float32) as sim:
assert sim.tensor_graph.dtype == tf.float32


def test_synapse_warning(Simulator):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Ensemble(10, 1)
c = nengo.Connection(a, b, synapse=1)
p = nengo.Probe(b)
p2 = nengo.Probe(b)

# warning from connection
with Simulator(net) as sim:
with pytest.warns(UserWarning) as rec:
sim.train({a: np.zeros((1, 1, 1))}, {p: np.zeros((1, 1, 1))},
tf.train.GradientDescentOptimizer(0))
assert any(str(w.message).startswith("Training for one timestep")
for w in rec)

# warning from probe
c.synapse = None
p.synapse = 1
with Simulator(net) as sim:
with pytest.warns(UserWarning):
sim.train({a: np.zeros((1, 1, 1))}, {p: np.zeros((1, 1, 1))},
tf.train.GradientDescentOptimizer(0))
assert any(str(w.message).startswith("Training for one timestep")
for w in rec)

# no warning from non-target probe
p.synapse = None
p2.synapse = 1
with Simulator(net) as sim:
with pytest.warns(UserWarning) as rec:
sim.train({a: np.zeros((1, 1, 1))}, {p: np.zeros((1, 1, 1))},
tf.train.GradientDescentOptimizer(0))
assert not any(
str(w.message).startswith("Training for one timestep")
for w in rec)

0 comments on commit b2e7754

Please sign in to comment.