Skip to content

Commit

Permalink
Add thresholding preset.
Browse files Browse the repository at this point in the history
Addresses #1058.
  • Loading branch information
jgosmann committed May 31, 2016
1 parent 1a67961 commit 9516022
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nengo/__init__.py
Expand Up @@ -31,7 +31,7 @@
from .simulator import Simulator
from .synapses import Alpha, LinearFilter, Lowpass, Triangle
from .utils.logging import log
from . import dists, exceptions, networks, processes, spa, utils
from . import dists, exceptions, networks, presets, processes, spa, utils

logger = logging.getLogger(__name__)
try:
Expand Down
40 changes: 40 additions & 0 deletions nengo/presets.py
@@ -0,0 +1,40 @@
"""Configuration presets for common use cases."""

import nengo


def ThresholdingPreset(threshold):
"""Configuration preset for a thresholding ensemble.
This preset adjust ensemble parameters for thresholding. The ensemble
neurons will only fire for values above the threshold. One can either
decode the represented value (if it is above the threshold) or decode a
step function if a binary classification is desired.
This preset sets:
- The intercepts to be between `threshold` and 1 with an exponential
distribution (shape parameter of 0.15). This clusters intercepts near
the threshold for a better approximation.
- The encoders to 1.
- The dimensions to 1.
- The evaluation points to be between `threshold` and 1. with a uniform
distribution.
Parameters
----------
threshold : float
Threshold of ensembles using this configuration preset.
Returns
-------
:class:`nengo.Config`
Configuration with presets.
"""
config = nengo.Config(nengo.Ensemble)
config[nengo.Ensemble].dimensions = 1
config[nengo.Ensemble].intercepts = nengo.dists.Exponential(
0.15, threshold, 1.)
config[nengo.Ensemble].encoders = nengo.dists.Choice([[1]])
config[nengo.Ensemble].eval_points = nengo.dists.Uniform(threshold, 1.)
return config
28 changes: 28 additions & 0 deletions nengo/tests/test_presets.py
@@ -0,0 +1,28 @@
import numpy as np

import nengo


def test_thresholding_preset(Simulator, seed, plt):
threshold = 0.3
with nengo.Network(seed) as model:
with nengo.presets.ThresholdingPreset(threshold):
ens = nengo.Ensemble(50, 1)
stimulus = nengo.Node(lambda t: t)
nengo.Connection(stimulus, ens)
p = nengo.Probe(ens, synapse=0.01)

with Simulator(model) as sim:
sim.run(1.)

plt.plot(sim.trange(), sim.trange(), label="optimal")
plt.plot(sim.trange(), sim.data[p], label="actual")
plt.xlabel("Time [s]")
plt.ylabel("Value")
plt.title("Threshold = {}".format(threshold))
plt.legend(loc='best')

se = np.square(np.squeeze(sim.data[p]) - sim.trange())

assert np.allclose(sim.data[p][sim.trange() < threshold], 0.0)
assert np.sqrt(np.mean(se[sim.trange() > 0.5])) < 0.05

0 comments on commit 9516022

Please sign in to comment.