Avoid overflow on zero-weight connections
By reducing `vth_max` for these connections.

This is most common on learning connections with zeroed
initial weights, so we'll test in `test_multiple_pes`.
hunse committed Dec 7, 2018
1 parent e1c4f9b commit 3dd3b3f
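
To make the failure mode concrete, here is a minimal sketch; this is not nengo_loihi code, and the constants are assumptions (Loihi-style thresholds stored as a 17-bit mantissa scaled by 2**6, with synaptic input accumulating in a register of roughly Q_BITS magnitude bits):

    VTH_MAX = (2**17 - 1) * 2**6  # assumed: max discretized threshold, ~2**23
    Q_BITS = 21                   # assumed: magnitude bits in the input accumulator

    def wrap(q, bits=Q_BITS):
        """Fixed-point wrap-around: fold q into [-2**bits, 2**bits - 1]."""
        m = 2**bits
        return (q + m) % (2 * m) - m

    # Before this commit: with all-zero initial weights there is no weight
    # scale to satisfy, so vth gets pushed all the way up to VTH_MAX.  Once
    # learning grows the weights, an input that reaches threshold overflows
    # the accumulator and wraps negative:
    print(wrap(VTH_MAX))                      # -64
    # After this commit: vth is capped at what the accumulator can hold.
    print(wrap(min(VTH_MAX, 2**Q_BITS - 1)))  # 2097151, fits unchanged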
Showing 2 changed files with 19 additions and 5 deletions.
nengo_loihi/loihi_cx.py (16 additions, 4 deletions)
@@ -249,6 +249,15 @@ def discretize(target, value):
     self.vmax = 2**(9 + 2*vmaxe) - 1

     # --- discretize weights and vth
+    # To avoid overflow, we can either lower vth_max or lower wgtExp_max.
+    # Lowering vth_max is more robust, but has the downside that it may
+    # force a smaller wgtExp on connections than necessary, potentially
+    # leading to lost weight bits (see SynapseFmt.discretize_weights).
+    # Lowering wgtExp_max lets us keep vth_max higher, but overflow is
+    # still possible on connections with many small inputs (uncommon).
+    vth_max = VTH_MAX
+    wgtExp_max = 0
+
     w_maxs = [s.max_abs_weight() for s in self.synapses]
     w_max = max(w_maxs) if len(w_maxs) > 0 else 0
     b_max = np.abs(self.bias).max()
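
The tradeoff described in the new comment can be seen numerically. A hypothetical sketch (the scale factors and the power-of-two get_scale are illustrative assumptions, not the real SynapseFmt): the lower the vth ceiling, the smaller the wgtExp that the search in the next hunk settles on, and each step down costs low-order weight bits when the weights are later discretized.

    def first_feasible_wgt_exp(vth, s_scale, w_scale, vth_max):
        # Mirrors the search below: try wgtExp from 0 downward until the
        # scaled threshold fits under vth_max.  get_scale is a hypothetical
        # power-of-two stand-in for SynapseFmt.get_scale.
        get_scale = lambda e: 2.0**e
        for wgtExp in range(0, -8, -1):
            if round(vth * s_scale * w_scale * get_scale(wgtExp)) <= vth_max:
                return wgtExp
        raise ValueError("Could not find appropriate wgtExp")

    # Illustrative numbers only:
    print(first_feasible_wgt_exp(1000., 1.0, 50., vth_max=2**17 - 1))  # 0
    print(first_feasible_wgt_exp(1000., 1.0, 50., vth_max=2**10 - 1))  # -6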
@@ -258,12 +267,12 @@ def discretize(target, value):
         w_scale = (255. / w_max)
         s_scale = 1. / (u_infactor * v_infactor)

-        for wgtExp in range(0, -8, -1):
+        for wgtExp in range(wgtExp_max, -8, -1):
             v_scale = s_scale * w_scale * SynapseFmt.get_scale(wgtExp)
             b_scale = v_scale * v_infactor
             vth = np.round(self.vth * v_scale)
             bias = np.round(self.bias * b_scale)
-            if (vth <= VTH_MAX).all() and (np.abs(bias) <= BIAS_MAX).all():
+            if (vth <= vth_max).all() and (np.abs(bias) <= BIAS_MAX).all():
                 break
         else:
             raise BuildError("Could not find appropriate wgtExp")
@@ -274,14 +283,17 @@ def discretize(target, value):
             w_scale = b_scale * u_infactor / SynapseFmt.get_scale(wgtExp)
             vth = np.round(self.vth * v_scale)
             bias = np.round(self.bias * b_scale)
-            if np.all(vth <= VTH_MAX):
+            if np.all(vth <= vth_max):
                 break

             b_scale /= 2.
         else:
             raise BuildError("Could not find appropriate bias scaling")
     else:
-        v_scale = np.array([VTH_MAX / (self.vth.max() + 1)])
+        # reduce vth_max in this case to avoid overflow since we're setting
+        # all vth to vth_max (esp. in learning with zeroed initial weights)
+        vth_max = min(vth_max, 2**Q_BITS - 1)
+        v_scale = np.array([vth_max / (self.vth.max() + 1)])
         vth = np.round(self.vth * v_scale)
         b_scale = v_scale * v_infactor
         bias = np.round(self.bias * b_scale)
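
Condensed, the fallback branch that this hunk changes now behaves as sketched below; Q_BITS and VTH_MAX are assumed constants here, and the surrounding w_max/b_max branches are elided:

    import numpy as np

    def zero_weight_v_scale(vth, vth_max=(2**17 - 1) * 2**6, q_bits=21):
        # With no weights or biases to constrain the scale, every threshold
        # is pushed toward the ceiling; capping the ceiling at the (assumed)
        # accumulator range 2**q_bits - 1 leaves headroom for inputs once
        # learning makes the weights nonzero.
        vth_max = min(vth_max, 2**q_bits - 1)
        return np.array([vth_max / (vth.max() + 1)])

    v_scale = zero_weight_v_scale(np.array([1.0, 1.5]))
    print(np.round(np.array([1.0, 1.5]) * v_scale))  # [ 838860. 1258291.]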
nengo_loihi/tests/test_learning.py (3 additions, 1 deletion)

@@ -64,7 +64,8 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims):
     assert allclose(y_loihi, y_nengo, atol=0.15, rtol=0.1)


-def test_multiple_pes(allclose, plt, seed, Simulator):
+@pytest.mark.parametrize('init_function', [None, lambda x: 0])
+def test_multiple_pes(init_function, allclose, plt, seed, Simulator):
     n_errors = 5
     targets = np.linspace(-0.9, 0.9, n_errors)
     with nengo.Network(seed=seed) as model:
@@ -77,6 +78,7 @@ def test_multiple_pes(allclose, plt, seed, Simulator):
         conn = nengo.Connection(
             pre_ea.ea_ensembles[i],
             output[i],
+            function=init_function,
             learning_rule_type=nengo.PES(learning_rate=3e-3),
         )
         nengo.Connection(target[i], conn.learning_rule, transform=-1)
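
The parametrization covers both paths through the discretizer: with init_function=None the connection decodes the identity function as before, while lambda x: 0 solves for all-zero decoders, giving exactly the zeroed initial weights that the commit message identifies as the common overflow case.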
