Skip to content

Commit

Permalink
Simplified test_learning.test_pes_comm_channel
Browse files Browse the repository at this point in the history
Can directly compare delayed input and output.
  • Loading branch information
hunse committed Oct 2, 2018
1 parent 5197019 commit f10dd23
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions nengo_loihi/tests/test_learning.py
Expand Up @@ -8,6 +8,7 @@
def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
scale = np.linspace(1, 0, dims + 1)[:-1]
input_fn = lambda t: np.sin(t * 2 * np.pi) * scale
tau = 0.01

with nengo.Network(seed=seed) as model:
stim = nengo.Node(input_fn)
Expand All @@ -19,12 +20,12 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
conn = nengo.Connection(
pre, post,
function=lambda x: np.zeros(dims),
synapse=0.01,
learning_rule_type=nengo.PES(learning_rate=1e-3))
synapse=tau,
learning_rule_type=nengo.PES(learning_rate=2e-4))

error = nengo.Node(None, size_in=dims)
nengo.Connection(post, error)
nengo.Connection(stim, error, transform=-1)
nengo.Connection(stim, error, transform=-1, synapse=tau)
nengo.Connection(error, conn.learning_rule)

p_stim = nengo.Probe(stim, synapse=0.02)
Expand All @@ -35,30 +36,32 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
sim.run(5.0)

t = sim.trange()
plt.subplot(211)
plt.plot(t, sim.data[p_stim])
stim_delayed_pre = nengo.Lowpass(sim.model.inter_tau).filt(
sim.data[p_stim])
stim_delayed_post = nengo.Lowpass(tau).filt(stim_delayed_pre)
post_tmask = t > 4.0

rows, cols = 3, 1
plt.subplot(rows, cols, 1)
plt.plot(t, stim_delayed_pre)
plt.plot(t, sim.data[p_pre])
plt.plot(t, sim.data[p_post])

# --- fit input_fn to output, determine magnitude
# The larger the magnitude, the closer the output is to the input
x = np.array([input_fn(tt)[0] for tt in t[t > 4]])
y = sim.data[p_post][t > 4][:, 0]
m = np.linspace(0, 1, 21)
errors = np.abs(y - m[:, None]*x).mean(axis=1)
m_best = m[np.argmin(errors)]
plt.subplot(rows, cols, 2)
plt.plot(t, stim_delayed_post)
plt.plot(t, sim.data[p_post])

plt.subplot(212)
plt.plot(t[t > 4], x)
plt.plot(t[t > 4], y)
plt.plot(t[t > 4], m_best * x, ':')
plt.subplot(rows, cols, 3)
plt.plot(t[post_tmask], stim_delayed_post[post_tmask])
plt.plot(t[post_tmask], sim.data[p_post][post_tmask])

assert allclose(sim.data[p_pre][t > 0.1],
sim.data[p_stim][t > 0.1],
atol=0.15,
rtol=0.15)
assert np.min(errors) < 0.3, "Not able to fit correctly"
assert m_best > (0.3 if n_per_dim < 150 else 0.6)
stim_delayed_pre[t > 0.1],
atol=0.1,
rtol=0.05)
assert allclose(sim.data[p_post][post_tmask],
stim_delayed_post[post_tmask],
atol=0.1,
rtol=0.05)


def test_multiple_pes(allclose, plt, seed, Simulator):
Expand Down

0 comments on commit f10dd23

Please sign in to comment.