Skip to content

Commit

Permalink
Processes are now reset with simulator
Browse files Browse the repository at this point in the history
The simulator now re-makes all `step` functions on a reset, which
should reset all operators. Fixes #616.

The simulator seed (currently only used by Processes, I think) can
be changed when calling `Simulator.reset`. Fixes #582.
  • Loading branch information
hunse authored and tbekolay committed May 15, 2015
1 parent c129c51 commit d56b724
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
47 changes: 30 additions & 17 deletions nengo/simulator.py
Expand Up @@ -84,22 +84,19 @@ def __init__(self, network, dt=0.001, seed=None, model=None):
A network object to the built and then simulated.
If a fully built ``model`` is passed in, then you can skip
building the network by passing in network=None.
dt : float
dt : float, optional
The length of a simulator timestep, in seconds.
seed : int
seed : int, optional
A seed for all stochastic operators used in this simulator.
Note that there are not stochastic operators implemented
currently, so this parameters does nothing.
model : nengo.builder.Model instance or None
model : nengo.builder.Model instance or None, optional
A model object that contains build artifacts to be simulated.
Usually the simulator will build this model for you; however,
if you want to build the network manually, or to inject some
build artifacts in the Model before building the network,
then you can pass in a ``nengo.builder.Model`` instance.
"""
dt = float(dt) # make sure it's a float (for division purposes)

if model is None:
dt = float(dt) # make sure it's a float (for division purposes)
self.model = Model(dt=dt,
label="%s, dt=%f" % (network, dt),
decoder_cache=get_default_decoder_cache())
Expand All @@ -112,27 +109,24 @@ def __init__(self, network, dt=0.001, seed=None, model=None):

self.model.decoder_cache.shrink()

self.seed = np.random.randint(npext.maxint) if seed is None else seed
self.rng = np.random.RandomState(self.seed)

# -- map from Signal.base -> ndarray
self.signals = SignalDict(__time__=np.asarray(0.0, dtype=np.float64))
for op in self.model.operators:
op.init_signals(self.signals)

# Order the steps (they are made in `Simulator.reset`)
self.dg = operator_depencency_graph(self.model.operators)
self._step_order = [node for node in toposort(self.dg)
if hasattr(node, 'make_step')]
self._steps = [node.make_step(self.signals, dt, self.rng)
for node in self._step_order]
self._step_order = [op for op in toposort(self.dg)
if hasattr(op, 'make_step')]

# Add built states to the probe dictionary
self._probe_outputs = self.model.params

# Provide a nicer interface to probe outputs
self.data = ProbeDict(self._probe_outputs)

self.reset()
seed = np.random.randint(npext.maxint) if seed is None else seed
self.reset(seed=seed)

@property
def dt(self):
Expand Down Expand Up @@ -240,14 +234,33 @@ def run_steps(self, steps, progress_bar=True):
self.step()
progress.step()

def reset(self):
"""Reset the simulator state."""
def reset(self, seed=None):
"""Reset the simulator state.
Parameters
----------
seed : int, optional
A seed for all stochastic operators used in the simulator.
This will change the random sequences generated for noise
or inputs (e.g. from Processes), but not the built objects
(e.g. ensembles, connections).
"""
if seed is not None:
self.seed = seed

self.n_steps = 0
self.signals['__time__'][...] = 0

# reset signals
for key in self.signals:
if key != '__time__':
self.signals.reset(key)

# rebuild steps (resets ops with their own state, like Processes)
self.rng = np.random.RandomState(self.seed)
self._steps = [op.make_step(self.signals, self.dt, self.rng)
for op in self._step_order]

# clear probe data
for probe in self.model.probes:
self._probe_outputs[probe] = []
20 changes: 20 additions & 0 deletions nengo/tests/test_processes.py
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest

import nengo
import nengo.utils.numpy as npext
from nengo.dists import Distribution, Gaussian
from nengo.processes import BrownNoise, WhiteNoise, WhiteSignal
Expand Down Expand Up @@ -183,3 +184,22 @@ def test_sampling_shape():
assert process.run_steps(1).shape == (1, 1)
assert process.run_steps(5, d=1).shape == (5, 1)
assert process.run_steps(1, d=2). shape == (1, 2)


def test_reset(seed):
trun = 0.1

with nengo.Network() as model:
u = nengo.Node(WhiteNoise(Gaussian(0, 1), scale=False))
up = nengo.Probe(u)

sim = nengo.Simulator(model, seed=seed)

sim.run(trun)
x = np.array(sim.data[up])

sim.reset()
sim.run(trun)
y = np.array(sim.data[up])

assert (x == y).all()

0 comments on commit d56b724

Please sign in to comment.