Parisien Transforms #921
Tagging myself here because I had an idea for how to fix the fact that this transform introduces additional filtering (and error propagation) with the intermediate population.

ah kewl, yeah makes sense you could account for the delay!

Has anyone looked at optimizing the E->I connection (rather than just setting it as uniform constant connection weights)? It feels like you could do something like backprop there...

This would also be something good to think about when figuring out what sort of build-process hooks we want: #869

I really should get around to doing this. I was just reminded of it by Chris posting this paper, which does some nice analysis of what happens with pure inhibitory connections, and would be a good resource for implementing this...
Here's a quick implementation of the standard Parisien transform in Nengo 2.0:

```python
import nengo
import nengo.utils.builder
import nengo.utils.connection
import nengo.utils.ensemble
import numpy as np


def parisien_transform(conn, model, inh_synapse, inh_proportion=0.25):
    # only works for ens->ens connections
    assert isinstance(conn.pre_obj, nengo.Ensemble)
    assert isinstance(conn.post_obj, nengo.Ensemble)

    # make sure the pre and post ensembles have seeds
    # so we can guarantee their params
    if conn.pre_obj.seed is None:
        conn.pre_obj.seed = np.random.randint(0x7FFFFFFF)
    if conn.post_obj.seed is None:
        conn.post_obj.seed = np.random.randint(0x7FFFFFFF)

    # compute the encoders, decoders, and tuning curves
    model2 = nengo.Network(add_to_container=False)
    model2.ensembles.append(conn.pre_obj)
    model2.ensembles.append(conn.post_obj)
    model2.connections.append(conn)
    sim = nengo.Simulator(model2)
    enc = sim.data[conn.post_obj].encoders
    dec = sim.data[conn].weights
    eval_points = sim.data[conn].eval_points
    pts, act = nengo.utils.ensemble.tuning_curves(
        conn.pre_obj, sim, inputs=eval_points)

    # compute the original weights
    transform = nengo.utils.builder.full_transform(conn)
    w = np.dot(enc, np.dot(transform, dec))

    # compute the bias function, bias encoders, bias decoders, and bias weights
    total = np.sum(act, axis=1)
    bias_d = np.ones(conn.pre_obj.n_neurons) / np.max(total)
    bias_func = total / np.max(total)
    bias_e = np.max(-w / bias_d, axis=1)
    bias_w = np.outer(bias_e, bias_d)

    # add the new model components
    with model:
        nengo.Connection(conn.pre_obj.neurons, conn.post_obj.neurons,
                         transform=bias_w,
                         synapse=conn.synapse)
        inh = nengo.Ensemble(n_neurons=int(conn.pre_obj.n_neurons*inh_proportion),
                             dimensions=1,
                             encoders=nengo.dists.Choice([[1]]),
                             intercepts=nengo.dists.Uniform(0, 1))
        nengo.Connection(conn.pre_obj, inh,
                         solver=nengo.solvers.NnlsL2(),
                         transform=1,
                         synapse=inh_synapse,
                         **nengo.utils.connection.target_function(pts, bias_func))
        nengo.Connection(inh, conn.post_obj.neurons,
                         solver=nengo.solvers.NnlsL2(),
                         transform=-bias_e[:, None])
```

And here's how to use it:

```python
model = nengo.Network()
with model:
    stim = nengo.Node(lambda t: np.sin(t*np.pi*2))
    a = nengo.Ensemble(100, 1)
    b = nengo.Ensemble(101, 1)
    nengo.Connection(stim, a)
    conn = nengo.Connection(a, b)

parisien_transform(conn, model, inh_synapse=conn.synapse)
```

This should handle slices, functions, and transforms on the Connection too!
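A note on why this works (my own explanation, not from the thread): `bias_e` is chosen as the smallest per-neuron scale such that adding the rank-1 matrix `bias_w = np.outer(bias_e, bias_d)` makes every direct pre->post weight non-negative (purely excitatory); the inhibitory interneurons then decode `bias_func` and subtract that same bias back out, so the net computation is unchanged up to the interneurons' approximation error. A minimal numpy check of the inequality:

```python
import numpy as np

rng = np.random.RandomState(0)
w = rng.randn(5, 8)                    # mixed-sign weights (post x pre)
bias_d = np.ones(8) / 8.0              # uniform positive "bias decoders"
bias_e = np.max(-w / bias_d, axis=1)   # smallest shift per post neuron
bias_w = np.outer(bias_e, bias_d)

# every entry of the shifted weight matrix is non-negative
assert np.all(w + bias_w >= -1e-12)
```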
Bump; is this the most recent implementation, or has anything else been done since in this direction?

I think this may be slightly more recent: https://github.com/tcstewar/nengo_parisien/blob/master/Parisien%20v2.ipynb I also just committed a change I made ages ago to fix the incorrect synapses in that example...
Thanks! So, I finally got around to trying the secret approach I mentioned at the top of this thread. The idea is to implement a custom solver that declares (arbitrarily) that the first `p_inh` proportion of the presynaptic neurons are inhibitory and the rest excitatory:

```python
import time

import cvxpy
import numpy as np

from nengo.params import NumberParam
from nengo.solvers import Solver


class DalesSolver(Solver):
    """Solves for weights subject to Dale's principle."""
    # TODO: needs testing (e.g., transforms), support for slicing

    p_inh = NumberParam('p_inh', low=0, high=1)

    def __init__(self, p_inh=0.2):
        super(DalesSolver, self).__init__(weights=True)
        self.p_inh = p_inh

    def __call__(self, A, Y, rng=None, E=None):
        pre_n_neurons = A.shape[1]
        post_n_neurons = E.shape[1]
        tstart = time.time()

        # the first i presynaptic neurons are (arbitrarily) inhibitory
        i = int(self.p_inh * pre_n_neurons)

        # least-squares solve for the full weight matrix, constrained so
        # that inhibitory rows are <= 0 and excitatory rows are >= 0
        W = cvxpy.Variable(pre_n_neurons, post_n_neurons)
        objective = cvxpy.Minimize(
            cvxpy.sum_entries(cvxpy.square(A*W - Y.dot(E))))
        constraints = [W[:i, :] <= 0,
                       W[i:, :] >= 0]
        prob = cvxpy.Problem(objective, constraints)
        value = prob.solve()

        # Do an extra clip for minor numerical reasons
        W = np.asarray(W.value)
        W[:i, :] = W[:i, :].clip(None, 0)
        W[i:, :] = W[i:, :].clip(0, None)

        return W, {
            'time': time.time() - tstart,
            'status': prob.status,
            'cost': value,
            'i': i,
        }
```

and then used as follows:

```python
nengo.Connection(a, b, solver=DalesSolver(p_inh))
```

The main benefit of this approach is that it does not introduce an intermediate ensemble with an additional round of filtering, and should therefore be preferable for situations where timing is important (e.g., Principle 3). It also appears to be slightly more accurate for the communication channel example I tried (based on Terry's code above). Simulation time is about equal between the two approaches. The main reason training time is slow is that it solves for the all-to-all weight matrix.

Initially, I thought that the difference in accuracy was mostly due to the fact that I haven't added an L2 penalty term to my cost function; it is hard to be completely fair with regularization when using different solvers. However, I tried switching the solver in the Parisien approach as well (see the commented-out `nengo.solvers.Lstsq()` line in the benchmark code below).

Things that would be nice to have: testing (e.g., with transforms) and support for slicing (per the TODO in the solver).
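Side note (mine, not from the thread): the solver above uses the pre-1.0 cvxpy API (`Variable(m, n)`, `sum_entries`, and `*` as matrix multiplication). On cvxpy >= 1.0 the same solve would look roughly like the sketch below, where `solve_dales` is a hypothetical helper and `YE = Y.dot(E)`:

```python
import cvxpy as cp
import numpy as np


def solve_dales(A, YE, i):
    """Least-squares weights with Dale's constraints (cvxpy >= 1.0 API)."""
    W = cp.Variable((A.shape[1], YE.shape[1]))
    objective = cp.Minimize(cp.sum_squares(A @ W - YE))
    constraints = [W[:i, :] <= 0, W[i:, :] >= 0]
    cp.Problem(objective, constraints).solve()
    return np.asarray(W.value)
```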
And... here is my benchmarking code:

```python
import matplotlib.pyplot as plt
import numpy as np

import nengo
from nengo.utils.numpy import rmse
from nengolib import DoubleExp


def go(seed, n_neurons=100, p_inh=0.20,
       u=lambda t: (np.sin(t*2*np.pi*3) / np.sqrt(2),
                    np.cos(t*2*np.pi*5) / np.sqrt(2)),
       tau=0.005, tau_probe=0.03, T=2.0, verbose=False):
    # p_inh is the proportion of *total pre* neurons that are inhibitory;
    # inh_proportion is a different quantity (n_inhibitory/n_excitatory),
    # which is related as follows:
    inh_proportion = p_inh / (1 - p_inh)
    assert np.allclose(p_inh, inh_proportion / (1 + inh_proportion))
    n_neurons_cmp = n_neurons*(1 - p_inh)  # for parisien_transform
    assert np.allclose(n_neurons_cmp*(1 + inh_proportion), n_neurons)

    with nengo.Network(seed=seed) as model:
        # model.config[nengo.Connection].solver = nengo.solvers.Lstsq()
        stim = nengo.Node(u)
        a1 = nengo.Ensemble(int(n_neurons_cmp), 2)
        a2 = nengo.Ensemble(n_neurons, 2)
        b1 = nengo.Ensemble(n_neurons, 2)
        b2 = nengo.Ensemble(n_neurons, 2)

        nengo.Connection(stim, a1, synapse=None)
        nengo.Connection(stim, a2, synapse=None)
        conn1 = nengo.Connection(a1, b1, synapse=tau)
        conn2 = nengo.Connection(a2, b2, synapse=tau, solver=DalesSolver(p_inh))

        p_stim = nengo.Probe(stim, synapse=DoubleExp(tau, tau_probe))
        p1 = nengo.Probe(b1, synapse=tau_probe)
        p2 = nengo.Probe(b2, synapse=tau_probe)

    parisien_transform(
        conn1, model, inh_synapse=conn1.synapse, inh_proportion=inh_proportion)

    with nengo.Simulator(model, progress_bar=verbose) as sim:
        sim.run(T, progress_bar=verbose)

    # Check that Dale's principle has been satisfied
    W = sim.data[conn2].weights
    i = sim.data[conn2].solver_info['i']
    assert W.shape == (b2.n_neurons, a2.n_neurons)
    assert np.all(W[:, :i] <= 0)
    assert np.all(W[:, i:] >= 0)

    rmses = (rmse(sim.data[p_stim], sim.data[p1], axis=0),
             rmse(sim.data[p_stim], sim.data[p2], axis=0))

    if verbose:
        print("Solver Info:", sim.data[conn2].solver_info)

        fig, ax = plt.subplots(2, 1, figsize=(6, 10))
        for i in range(2):
            ax[i].plot(sim.trange(), sim.data[p_stim], linestyle='--')
        ax[0].set_title("parisien_transform (RMSE=%s)" % rmses[0].round(3))
        ax[1].set_title("DalesSolver (RMSE=%s)" % rmses[1].round(3))
        ax[0].plot(sim.trange(), sim.data[p1])
        ax[1].plot(sim.trange(), sim.data[p2])
        ax[1].set_xlabel("Time (s)")
        plt.show()

    return rmses
```

```python
import pandas as pd
import seaborn as sns

num_trials = 25

data = []
for seed in range(num_trials):
    print(seed)
    rmses = go(seed=seed, n_neurons=50)
    for i, method in enumerate(("parisien_transform", "DalesSolver")):
        for dimension in range(2):
            data.append([rmses[i][dimension], method, str(dimension)])

df = pd.DataFrame(data, columns=["RMSE", "Method", "Dimension"])

plt.figure(figsize=(8, 6))
sns.boxplot(x="Dimension", y="RMSE", hue="Method", data=df)
plt.show()
```
Awesome! That's a nice way to do it... If we want a few more benchmark tasks, we can use the ...
You could also pose this as an NNLS problem by making the activities for the inhibitory neurons negative, and having all positive weights. Then you could use a solver like the one in SciPy. Not sure if it would be any faster, but it is a more specific problem than the one you're posing to cvxpy, so I could see it being more efficient.
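To spell out the reduction being suggested (my reading, not verbatim from the thread): for each postsynaptic neuron with target current $j$, the Dale's-constrained problem

$$
\min_{w}\ \lVert A w - j \rVert^2
\quad\text{s.t.}\quad w_k \le 0\ (k < i),\qquad w_k \ge 0\ (k \ge i)
$$

becomes plain non-negative least squares, $\min_{v \ge 0} \lVert \tilde{A} v - j \rVert^2$, after substituting $w_k \mapsto -v_k$ for the inhibitory indices (equivalently, negating those columns of $A$ to form $\tilde{A}$), and then flipping the signs of the first $i$ entries of $v$ back.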
I like this idea. Assuming you're talking about doing this for a full-weight matrix (i.e., `weights=True`)...

@celiasmith suggested looking at the distribution of weights that we get from using this to implement an integrator. But first, I added L2-regularization by changing the objective function to:

```python
objective = cvxpy.Minimize(
    cvxpy.sum_entries(cvxpy.square(A*W - Y.dot(E))) +
    cvxpy.sum_entries(lmbda*cvxpy.square(W)))
```

with `lmbda` as the regularization scale. Here's the integrator:

```python
import numpy as np

import nengo

tau = 0.1
tau_probe = 0.005
dt = 0.001

with nengo.Network(seed=0) as model:
    stim = nengo.Node(output=lambda t: 2*np.pi*5*np.cos(2*np.pi*5*t))
    x = nengo.Ensemble(100, 1)

    # Discrete Principle 3 (Voelker & Eliasmith, 2017; eq 21)
    nengo.Connection(stim, x, transform=dt / (1 - np.exp(-dt/tau)), synapse=tau)
    conn = nengo.Connection(x, x, synapse=tau, solver=DalesSolver())

    p_stim = nengo.Probe(stim, synapse=tau_probe)
    p_x = nengo.Probe(x, synapse=tau_probe)

with nengo.Simulator(model, dt=dt) as sim:
    sim.run(1.0)

print("Solver Info:", sim.data[conn].solver_info)
```

```python
import matplotlib.pyplot as plt

from nengo.utils.numpy import rmse

ideal = np.cumsum(sim.data[p_stim], axis=0)*dt

plt.figure()
plt.title("RMSE: %s" % (rmse(sim.data[p_x], ideal)))
plt.plot(sim.trange(), ideal, linestyle='--', lw=5)
plt.plot(sim.trange(), sim.data[p_x], alpha=0.5)
plt.show()
```

```python
import seaborn as sns

W = sim.data[conn].weights
i = sim.data[conn].solver_info['i']

fig, ax = plt.subplots(2, 2, figsize=(8, 8))
for axis, sW, label in (
        (ax[0, 0], W[:i, :i], "I->I"),
        (ax[0, 1], W[:i, i:], "E->I"),
        (ax[1, 0], W[i:, :i], "I->E"),
        (ax[1, 1], W[i:, i:], "E->E"),):
    sparsity = len(np.where(np.abs(sW.flatten()) < 1e-8)[0]) / float(sW.size)
    axis.set_title("%s (%d%% sparsity)" % (label, 100*sparsity))
    sns.heatmap(sW, ax=axis, center=0, cmap="BrBG",
                xticklabels=False, yticklabels=False)
plt.show()
```

The above visualizes the block structure of `W`.
I tried your suggestion @hunse and it sped things up by a factor of ~50x for 50 neurons! I guess cvxpy is not being clever in this case. :( I just replaced the cvxpy code with the following:

```python
from scipy.optimize import nnls

# flip the sign of the inhibitory activities so that the Dale's-constrained
# problem becomes an ordinary non-negative least-squares (NNLS) problem
A[:, :i] *= (-1)

W = np.empty((pre_n_neurons, post_n_neurons))
J = Y.dot(E)
for j in range(post_n_neurons):
    W[:, j], _ = nnls(A, J[:, j])

# flip the inhibitory weights back to negative
W[:i, :] *= (-1)
```

The accuracy is the same (without regularization). To add regularization we can do the same thing you did for `NnlsL2`.
@arvoelke is it okay if I take this Parisien transform code and make a PR into ...?
I was not planning to, so that's fine by me. You may want to use the above version (without cvxpy) because it's faster, and also to transpose the approach from `NnlsL2` to get a version with L2-regularization.
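A sketch of that regularized version (my own, not from the thread): one standard way to fold an L2 penalty into an NNLS solve is to augment the least-squares system with extra rows, since $\min \lVert Ax - b\rVert^2 + m\sigma^2\lVert x\rVert^2$ is equivalent to stacking $\sqrt{m}\,\sigma I$ beneath $A$ and zeros beneath $b$. The `dales_nnls_l2` helper and its `reg` scaling below are hypothetical:

```python
import numpy as np
from scipy.optimize import nnls


def dales_nnls_l2(A, J, i, reg=0.1):
    """Hypothetical Dale's-constrained solve with L2 regularization.

    A : (n_samples, n_pre) activities; columns [:i] are inhibitory.
    J : (n_samples, n_post) target currents (i.e., Y.dot(E)).
    """
    A = A.copy()
    A[:, :i] *= -1                  # sign-flip trick: reduces to plain NNLS
    m, n = A.shape
    sigma = reg * A.max()           # assumed scaling, mirroring nengo's L2 solvers
    A_aug = np.vstack([A, np.sqrt(m) * sigma * np.eye(n)])
    J_aug = np.vstack([J, np.zeros((n, J.shape[1]))])
    W = np.empty((n, J.shape[1]))
    for j in range(J.shape[1]):
        W[:, j], _ = nnls(A_aug, J_aug[:, j])
    W[:i, :] *= -1                  # restore inhibitory signs
    return W
```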
For reference's sake, a bunch of the above is extended, tested, and compared in https://github.com/tcstewar/nengo_solver_dales and featured in https://github.com/astoeckel/nengo-bio.
An example implementing a Parisien transform in Nengo should be created (and included in utils?), as discussed in the meeting today. According to Chris, it's used to go from a network with positive and negative weights to something more biologically plausible. It's described in this paper.