
Table look-up #1

Closed
jajcayn opened this issue Mar 7, 2020 · 9 comments

@jajcayn

jajcayn commented Mar 7, 2020

Hey again :)
Not sure whether this belongs here or in the specific jitc*de package I am using, but I feel it is a general question/problem.

So, I need to make a table look-up as part of my dynamics. The thing is, I am using a model (a mean-field approximation of a spiking neuron model) which uses pre-computed quantities in order to avoid solving the Fokker-Planck equation at each time step. This table "converts" the mean and standard deviation of the membrane potential into a firing rate. You can imagine it as

x_idx = int(<some function of mean potential>)
y_idx = int(<some function of std potential>)
firing_rate = table[x_idx, y_idx]

The problem is, of course, that the indices x_idx and y_idx are symbolic. Just to illustrate the point, x_idx is

floor(14.0*(-0.5 + sqrt(2.25 + 472.392*current_y(9)/(2.0 + 20.0*(1.0 + del_fr_exc_exc)) + 2178.0*current_y(10)/(5.0 + 20.0*(1.0 + del_fr_exc_inh)))))

So what I tried:

symengine.Lambdify

the code is roughly this:

def compute_from_cascade(..., x_idx, y_idx, dx_idx, dy_idx):
    # bilinear interpolation of the pre-computed table at symbolic indices
    table_se = sympy.MatrixSymbol("table_se", *table.shape)
    expr = (
        table_se[y_idx, x_idx] * (1 - dx_idx) * (1 - dy_idx)
        + table_se[y_idx, x_idx + 1] * dx_idx * (1 - dy_idx)
        + table_se[y_idx + 1, x_idx] * (1 - dx_idx) * dy_idx
        + table_se[y_idx + 1, x_idx + 1] * dx_idx * dy_idx
    )
    func = se.Lambdify(table_se, expr)
    return func(table)

So I need to define the matrix as a symbol and lambdify it in order to get correct indices for the table lookup. However, this is not working; I get

symengine.lib.symengine_wrapper.SympifyError: sympy2symengine: Cannot convert 'table_se[0, 0]' (of type <class 'sympy.matrices.expressions.matexpr.MatrixElement'>) to a symengine type.

Then I found out that MatrixSymbol is not implemented in symengine and nobody knows if and when they'd implement it (I've spent an afternoon searching their GitHub issues).

Use sympy's fallback symbols y and t

So the other option was to use "slow" sympy (still faster, I guess, than pure Python) and instead of

from jitcdde import t as time_vector
from jitcdde import y as state_vector_dde

I did

from jitcdde.sympy_symbols import t as time_vector
from jitcdde.sympy_symbols import y as state_vector_dde

Unfortunately, this is not working either; this time I am getting

TypeError: cannot determine truth value of Relational

This is unrelated to the table lookup and occurs earlier, which means that sympy is not really fully compatible with symengine, since symengine has no problem determining the truth value of a Relational...

However, the problem is that in order to make this work with sympy, I would need to rewrite almost all of the integration logic, which would then become unusable with symengine... Since this is the only model I have trouble implementing in jitc*de, I really do not want to rewrite the core logic to be compatible with sympy; I am more than happy with symengine.

I guess you're not an expert on symengine, but the main question is whether you have any idea how you would do a symbolic table lookup, given that symengine does not implement 2D symbolic arrays? Should I somehow approximate the table with a function?

Thanks a lot!
N.

@Wrzlprmft
Contributor

An inevitable problem with your approach is that there is currently no way to incorporate your table into the C code created by JiTC*DE. At first glance, addressing this seems like a major difficulty (compared to which your symbolic issues are peanuts). Moreover, I consider applications of this very limited (see below), so I am strongly leaning towards not implementing something like this.

To somewhat expand on the latter: One of the main reasons for the existence of JiTC*DE is to avoid look-ups of coupling matrices, since those are usually comparably expensive, in particular if the coupling matrices are sparse. More generally, the JiTC*DE way of dealing with tables is to hard-code them. In a related manner, JiTC*DE exists to make computing the derivative (the right-hand side of the differential equation) as fast as possible, because that is very often the bottleneck. If the derivative includes some complex table look-up, which is the actual main bottleneck, we are entering a different category of problem for which JiTC*DE wasn’t designed. Now, I am aware that this argument doesn’t fully translate to your application, but it may illustrate why what you want to do does not fit well into the general architecture of JiTC*DE.

Another problem with implementing your table as it is, is that the outcome depends on your system’s state in a discontinuous manner – which integrators don’t like at all.

With all that being said, here are some general approaches to solving your problem that I can think of:

  • What I could plausibly implement is the evaluation of Python functions as part of the derivative. This would allow you to implement your table look-up and would have general applicability, e.g., allowing for arbitrary input. The disadvantage is that such a function could not be symbolically differentiated, which prevents computing Lyapunov exponents, using stiff integrators, and some further tools. Also, you would still have to take care that this function depends continuously on the state and time of the system.

  • What you may be able to do is to replace your table look-up by a function. This may be very complicated, using a lot of conditionals. That in turn may be unnecessarily slow. However, if you are more liberal on how to implement your table, a function approximating its look-up may considerably speed things up. This could be a good example of applying symbolic regression. (Of course, when you do something like this, try to go from your mean and standard deviation of membrane potentials to the table entries directly and do not compute the indices first.)

  • If the outcome of your table look-up changes slowly in comparison to the dynamics, implement the results as control parameters in JiTC*DE, and then:

    1. Integrate a bit.
    2. Perform the look-ups in Python.
    3. Change the respective control parameters.
    4. Continue integrating. Beware that for DDEs, you would have to address initial discontinuities now.
  • If your problem is an ODE, consider not using JiTCODE at all, but SciPy’s ODE routines, where you can define the derivative as a Python function and do whatever you want in it. Of course, this is slower, but your table look-up may be your main bottleneck anyway. Also, keep in mind that you still want things to be continuous.

  • If by any chance you can somehow translate your table look-up to be one-dimensional, you can use jitcdde_input for this.
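For illustration, the SciPy route from the list above could be sketched as follows. The table values and the dynamics here are invented toys, not the actual neural-mass model:

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import RegularGridInterpolator

# hypothetical pre-computed table: firing rate over a (mu, sigma) grid
mu_grid = np.linspace(0.0, 2.0, 21)
sigma_grid = np.linspace(0.0, 1.0, 11)
MU, SIGMA = np.meshgrid(mu_grid, sigma_grid, indexing="ij")
table = np.exp(SIGMA) * (1 + MU)  # toy values, not real data

# linear interpolation keeps the look-up continuous in the state
rate = RegularGridInterpolator((mu_grid, sigma_grid), table)

def rhs(t, state):
    mu, sigma = state
    # clamp to the grid so we never query out of bounds
    r = rate([[np.clip(mu, 0.0, 2.0), np.clip(sigma, 0.0, 1.0)]])[0]
    # invented toy dynamics driven by the interpolated firing rate
    return [-mu + 0.1 * r, -sigma + 0.05 * r]

sol = solve_ivp(rhs, (0.0, 10.0), [0.5, 0.2], rtol=1e-6)
```

The important design point is the linear interpolation plus clamping: it makes the look-up a continuous function of the state, which is what the integrator needs.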

@jajcayn
Author

jajcayn commented Mar 8, 2020

Thanks for the response @Wrzlprmft, you are the MVP!

Just in case you're interested, this is how the "table look-up" looks. The inputs mu_I and std_I are computed at each timestep as a result of the continuous dynamics (DDE), and from these the firing rate is approximated using the table (again, instead of solving the Fokker-Planck equation for the mean-field approximation).
[image: heatmap of the firing-rate table over mu_I and std_I]
So, it's not particularly nice. On the other hand, individual slices with mu_I fixed (i.e. vertical slices in the image) usually yield just exponential growth as a function of std_I, with a different exponent per mu_I value.
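If those slices really are close to exponentials, one cheap alternative to a full symbolic regression would be to fit each fixed-mu_I slice separately. A sketch with synthetic data (the table values and parameter names here are made up for illustration):

```python
import numpy as np
from scipy.optimize import curve_fit

sigma_grid = np.linspace(0.1, 1.0, 20)
mu_values = np.linspace(0.0, 2.0, 5)

# synthetic "table": one exponential per mu_I slice, with the exponent
# depending on mu_I (made-up numbers, not the real table)
true_exponents = 1.0 + 0.5 * mu_values
table = np.array([0.3 * np.exp(k * sigma_grid) for k in true_exponents])

def model(sigma, a, k):
    return a * np.exp(k * sigma)

# fit (a, k) for every fixed-mu_I slice; the fitted k(mu_I) can then
# itself be approximated by a simple function of mu_I
fits = [curve_fit(model, sigma_grid, row, p0=(1.0, 1.0))[0] for row in table]
exponents = np.array([k for a, k in fits])
```

If the per-slice fits are good, the two fitted parameter curves a(mu_I) and k(mu_I) give a closed-form surface that can be written symbolically for jitc*de.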

As for the proposed solutions:

  • What I could plausibly implement is the evaluation of Python functions as part of the derivative.

That's nice of you, but I would lose the speed I am getting from using jitc*de across my models, i.e. not an option. My current implementation of this particular model with table lookups is in numba, so it's fast, but not compatible with my other models using jitc*de (by compatibility I mean the way the different models build up a circuit, which is then integrated in time, and how all the connections between models are handled).

  • What you may be able to do is to replace your table lookup by a function.

I thought of this; maybe I'll resort to it. I have some experience with Gaussian Processes, but even if I trained one "off-line" (i.e. before jitting the integration), the prediction (i.e. getting the actual value of the firing rate during the integration) using the trained hyperparameters relies on complicated algebra like Cholesky decomposition and solving matrix equations. Symbolic regression might be the way to go, albeit I never used it and have no experience with it.

  • If the outcome of your table look-up changes slowly in comparison to the dynamics,

Unfortunately, not the case. Table lookups are done in each time-step and as you can see from the plot, they might jump.

  • If your problem is an ODE, consider not using JiTCODE at all, but SciPy’s ODE routines

Unfortunately, it is a DDE (with noise, but I can precompute it – I am using an Ornstein-Uhlenbeck process as noisy input).

  • If by any chance, you can somehow translate your table-look up to be one-dimensional, you can use jitcdde_input for this.

This I used (and actually asked about it in the jitcdde repo) for a different model, but I do not get what you mean. Inputs can be N-dimensional time-dependent arrays, which are stored in a dummy dynamical variable's past as cubic splines, right? Would I use time as the variable that is changing in my 1-D lookup?

@Wrzlprmft
Contributor

  • What I could plausibly implement is the evaluation of Python functions as part of the derivative.

That's nice of you, but I would lose the speed I am getting from using jitc*de across my models, i.e. not an option. My current implementation of this particular model with table lookups is in numba so it's fast, but not compatible with my other models using jitc*de (by compatibility I mean way how the different models build up a circuit which is then integrated in time and handles all the connections between models).

Well, if you have your table look-up in an efficient Python function, the proposed change would allow you to call this function from a JiTC*DE derivative. Python would only act as an interface here. This would of course cause a little overhead, but it should not be too bad. If this is truly relevant to speed, one could also intercept the C code created by JiTC*DE and plug your table look-up into it. It's a bit messy (but I can help you with it), yet certainly easier than providing a general-purpose solution that is more efficient than the proposed Python function calls.

Symbolic regression might be the way to go, albeit I never used it and have no experience with it.

From your diagram, I would strongly recommend it. This does not look utterly complicated and if it works, the result is probably considerably faster than your table look-up. The main annoyance will be the periodic discontinuity (so I wouldn’t make this part of the symbolic regression at first), but that will be a problem either way. Symbolic regression is capable of reconstructing the Bethe–Weizsäcker formula from data; I am quite confident that it can handle your problem.

This I used (and actually asked about it in the jitcdde repo) for a different model, but I do not get what you mean. Inputs can be N-dimensional time-dependent arrays, which are stored in a dummy dynamical variable's past as cubic splines, right? Would I use time as the variable that is changing in my 1-D lookup?

I think you got the idea. jitcdde_input essentially allows you to access additional one-dimensional datasets stored in cubic splines via the same infrastructure that allows you to access the past (which is a bunch of one-dimensional datasets stored in cubic splines). The typical use would make this one-dimensional parameter time, but it can be anything else.

Anyway, looking at your data, this probably cannot be applied here.

@Wrzlprmft
Contributor

@jajcayn: Have you made any progress with this? I am mainly asking because calling Python functions in the derivative is on my to-do list anyway, but if you have a use for it, it is no problem to prioritise it.

@jajcayn
Author

jajcayn commented Mar 30, 2020

Hey, sorry for the delayed response.
Well, no luck so far, unfortunately :( I did try symbolic regression (using gplearn), but to no avail – the regression did catch the main features of the "table" but failed to fit the "details", which are actually pretty important for the neural dynamics.
Then I spent a day or so playing with my problem so that I could maybe plug it in via jitcdde_input, but again – not successfully. I wasn't able to truly turn it into a 1D lookup. Plus I was having problems with out-of-bounds lookups. In the original Python function I am using, this is solved by a very simple interpolation (not extrapolation – if your table "ends" at mu=7 and your input is 7.1, you take the table value at mu=7 and then add a bit using dx). I ended up using a lot of conditionals (something along the lines of input(time=max(min(current_mu, 0.5)), 7)), but still, I wasn't getting correct results at all.

So in the end, I must say, I would be really glad for the option of evaluating a Python function within the symbolic derivatives :) For the time being, I am still using the old numba-based implementation, but as I was saying in my previous post, this is problematic since I already implemented other models in jitc*de and there is no easy way to couple the two models (which is my goal).

All in all, if you plan on working on Python functions within jitted derivatives, I'd be grateful and will definitely try them out! Thanks!

@Wrzlprmft
Contributor

Callbacks from the derivative are now implemented.

Please read the respective documentation (search for callback), get the latest version of JiTC*DE from GitHub, and tell me whether everything works for you.

@jajcayn
Author

jajcayn commented Apr 18, 2020

Hey @Wrzlprmft, wonderful! Thank you very much!
I tried for the time being a very simple example:
Firstly I took your DDE example of Mackey–Glass

import numpy
import matplotlib.pyplot as plt
from jitcdde import jitcdde, y, t

τ = 15
n = 10
β = 0.25
γ = 0.1

f = [β * y(0, t - τ) / (1 + y(0, t - τ) ** n) - γ * y(0) / 2.0 * y(0)]
DDE = jitcdde(f, callback_functions=())
DDE.constant_past([1.0])
DDE.step_on_discontinuities()

data = []
for time in numpy.arange(DDE.t, DDE.t + 500, 10):
    data.append(DDE.integrate(time))
plt.plot(data)

then pretended that the last term in the equation is actually a callback, and did this:

import symengine as se

def func(y, factor, gamma):
    return y[0] * factor * gamma

callback = se.Function("callback")

# when calling callback in the derivatives definition, I omit the state
# vector as the mandatory first argument – is that correct?
f = [β * y(0, t - τ) / (1 + y(0, t - τ) ** n) - callback(y(0) / 2.0, γ)]
DDE = jitcdde(f, callback_functions=[(callback, func, 2)], verbose=True)

DDE.constant_past([1.0])

DDE.step_on_discontinuities()

data = []
for time in numpy.arange(DDE.t, DDE.t + 500, 10):
    data.append(DDE.integrate(time))
plt.plot(data)

and it works nicely. I tried to imitate my use case, in which each neural mass within the network will have three callback functions: each callback will be a table lookup (with a different table each, hence three) taking two arguments: y (one of the state variables, the mean of the synaptic current) and a function of y (a non-linear combination of more state variables, the sigma of the synaptic current), which will be defined within the derivative function.
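For what it's worth, a table-lookup callback along these lines might look like the sketch below. The table, grid parameters, and helper name are all invented; the callback convention (state vector first, then the extra symbolic arguments) follows the example above:

```python
import numpy as np

# hypothetical pre-computed firing-rate table and its grid parameters
table = np.random.default_rng(0).random((40, 30))
mu0, dmu = -1.0, 0.1     # grid origin and step along the mu axis
sig0, dsig = 0.0, 0.05   # grid origin and step along the sigma axis

def table_lookup(y, mu, sigma):
    # callback body: y is the state vector (unused here); mu and sigma
    # are the two symbolic arguments that jitc*de evaluates and passes in
    x = float(np.clip((mu - mu0) / dmu, 0, table.shape[0] - 2))
    z = float(np.clip((sigma - sig0) / dsig, 0, table.shape[1] - 2))
    xi, zi = int(x), int(z)
    dx, dz = x - xi, z - zi
    # bilinear interpolation, clamped at the table edges
    return (table[xi, zi] * (1 - dx) * (1 - dz)
            + table[xi + 1, zi] * dx * (1 - dz)
            + table[xi, zi + 1] * (1 - dx) * dz
            + table[xi + 1, zi + 1] * dx * dz)
```

Such a function would then be registered the same way as func above, i.e. via callback_functions=[(some_symbol, table_lookup, 2)] with some_symbol = se.Function("..."). The clamping keeps out-of-bounds queries from crashing, and the bilinear interpolation keeps the result continuous in the state.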

I'll try to implement the full network of neural masses using the callbacks in the evening.
Just to be sure: is my simple implementation of Mackey-Glass correct? Is it enough to define the symengine function object as callback = se.Function("callback"), or do I need to provide more context for the function object?

Thanks again, I'll let you know whether I succeed with full network implementation.
Best :)

@Wrzlprmft
Contributor

Firstly I took your DDE example of Mackey–Glass

Just checking: You added that / 2.0 * y(0) at the end, right?

when calling callback in derivatives definition, I omit the state vector as the mandatory first argument, is that correct?

Yes.

Just to be sure: is my simple implementation of Mackey-Glass correct?

Apart from not being Mackey–Glass anymore, yes. The two scripts should be equivalent. (I suppose you already checked that they produce the same output, so I will not double-check.)

Is it enough to define the symengine function object as callback = se.Function("callback"), or do I need to provide more context for the function object?

That should be enough – unless you want to do something that requires the derivative of f, e.g. computing Lyapunov exponents. Then things get a bit more tricky.

@jajcayn
Author

jajcayn commented Apr 18, 2020

Yes, I added the / 2.0 * y(0) at the end just to try a callback with more parameters, some of them depending on the state vector.

The scripts are indeed equivalent; both yield the same integrated time series.

I also successfully implemented the table lookup in my neural models, and everything works as expected :) Thanks again, I think you can close the issue.
Best,
