Skip to content

Commit

Permalink
Processes are now reset with simulator
Browse files Browse the repository at this point in the history
The simulator now re-makes all `step` functions on a reset, which
should reset all operators. Fixes #616.

The simulator seed (currently only used by Processes, I think) can
be changed when calling `Simulator.reset`.
  • Loading branch information
hunse committed Feb 18, 2015
1 parent f943ea1 commit 2647bd9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
47 changes: 30 additions & 17 deletions nengo/simulator.py
Expand Up @@ -84,22 +84,19 @@ def __init__(self, network, dt=0.001, seed=None, model=None):
A network object to the built and then simulated.
If a fully built ``model`` is passed in, then you can skip
building the network by passing in network=None.
dt : float
dt : float, optional
The length of a simulator timestep, in seconds.
seed : int
seed : int, optional
A seed for all stochastic operators used in this simulator.
Note that there are not stochastic operators implemented
currently, so this parameters does nothing.
model : nengo.builder.Model instance or None
model : nengo.builder.Model instance or None, optional
A model object that contains build artifacts to be simulated.
Usually the simulator will build this model for you; however,
if you want to build the network manually, or to inject some
build artifacts in the Model before building the network,
then you can pass in a ``nengo.builder.Model`` instance.
"""
dt = float(dt) # make sure it's a float (for division purposes)

if model is None:
dt = float(dt) # make sure it's a float (for division purposes)
self.model = Model(dt=dt,
label="%s, dt=%f" % (network, dt),
decoder_cache=get_default_decoder_cache())
Expand All @@ -112,27 +109,24 @@ def __init__(self, network, dt=0.001, seed=None, model=None):

self.model.decoder_cache.shrink()

self.seed = np.random.randint(npext.maxint) if seed is None else seed
self.rng = np.random.RandomState(self.seed)

# -- map from Signal.base -> ndarray
self.signals = SignalDict(__time__=np.asarray(0.0, dtype=np.float64))
for op in self.model.operators:
op.init_signals(self.signals)

# Order the steps (they are made in `Simulator.reset`)
self.dg = operator_depencency_graph(self.model.operators)
self._step_order = [node for node in toposort(self.dg)
if hasattr(node, 'make_step')]
self._steps = [node.make_step(self.signals, dt, self.rng)
for node in self._step_order]
self._step_order = [op for op in toposort(self.dg)
if hasattr(op, 'make_step')]

# Add built states to the probe dictionary
self._probe_outputs = self.model.params

# Provide a nicer interface to probe outputs
self.data = ProbeDict(self._probe_outputs)

self.reset()
seed = np.random.randint(npext.maxint) if seed is None else seed
self.reset(seed=seed)

@property
def dt(self):
Expand Down Expand Up @@ -240,14 +234,33 @@ def run_steps(self, steps, progress_bar=True):
self.step()
progress.step()

def reset(self):
"""Reset the simulator state."""
def reset(self, seed=None):
"""Reset the simulator state.
Parameters
----------
seed : int, optional
A seed for all stochastic operators used in the simulator.
This will change the random sequences generated for noise
or inputs (e.g. from Processes), but not the built objects
(e.g. ensembles, connections).
"""
if seed is not None:
self.seed = seed

self.n_steps = 0
self.signals['__time__'][...] = 0

# reset signals
for key in self.signals:
if key != '__time__':
self.signals.reset(key)

# rebuild steps (resets ops with their own state, like Processes)
self.rng = np.random.RandomState(self.seed)
self._steps = [op.make_step(self.signals, self.dt, self.rng)
for op in self._step_order]

# clear probe data
for probe in self.model.probes:
self._probe_outputs[probe] = []
19 changes: 19 additions & 0 deletions nengo/tests/test_processes.py
Expand Up @@ -172,6 +172,25 @@ def test_sampling_shape():
assert process.run_steps(1, d=2). shape == (1, 2)


def test_reset(seed):
trun = 0.1

with nengo.Network() as model:
u = nengo.Node(WhiteNoise(Gaussian(0, 1), scale=False))
up = nengo.Probe(u)

sim = nengo.Simulator(model, seed=seed)

sim.run(trun)
x = np.array(sim.data[up])

sim.reset()
sim.run(trun)
y = np.array(sim.data[up])

assert (x == y).all()


if __name__ == "__main__":
nengo.log(debug=True)
pytest.main([__file__, "-v"])

0 comments on commit 2647bd9

Please sign in to comment.