Any suggestions on how to improve performance for gradient step with odeint? #3993
I have a large system of coupled ODEs (~280x16). The loss computation, which involves solving the ODEs (I am using odeint from jax.experimental.ode), takes 163 ms. The gradient step takes ~33 s. It is a relatively straightforward parameter estimation problem, except that one of the parameters is modeled by a neural network as a function of some input parameter.

Comments
It's hard to say what's going wrong without more diagnostic information. Ideally you could share a simplified example that illustrates the problem, but short of that, a few diagnostics would help.
Thank you. I will try to create a self-contained example. The ODE system is a set of coupled epidemiological models (box-car SEIRD). The coupling matrix and mobility are available at discrete time points, on a daily scale.

nn_batch is the neural network part. It is defined the way it is mostly to be able to use BFGS, L-BFGS, etc. optimizers as well (a sketch of the pattern is below). Initial optimization is done using rmsprop_momentum with momentum decay, which works really well for this problem.
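The actual definition didn't make it into this comment; the pattern is roughly the following toy reconstruction, with hypothetical layer sizes (the real nn_batch differs):

```python
import jax.numpy as jnp

sizes = [4, 32, 1]  # hypothetical layer widths

def unpack(theta):
    """Slice one flat parameter vector into (W, b) pairs, so scipy-style
    BFGS / L-BFGS optimizers (which want a single 1-D vector) can be used."""
    params, i = [], 0
    for m, n in zip(sizes[:-1], sizes[1:]):
        W = theta[i:i + m * n].reshape(m, n)
        b = theta[i + m * n:i + m * n + n]
        params.append((W, b))
        i += m * n + n
    return params

def nn_batch(theta, x):
    """Apply the MLP to a batch of inputs x with shape (batch, sizes[0])."""
    params = unpack(theta)
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b
```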
Is it slow to execute (after compiling), or just to compile? That is, if you try evaluating the gradient a second time, is it faster on the second evaluation?
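Schematically, the check I mean (with a stand-in loss, not your model):

```python
import time
import jax
import jax.numpy as jnp

def loss(params):                 # stand-in for the real ODE loss
    return jnp.sum(params ** 2)

params = jnp.ones(1000)
grad_fn = jax.jit(jax.grad(loss))

start = time.perf_counter()
grad_fn(params).block_until_ready()   # first call: trace + compile + run
print("first call:", time.perf_counter() - start)

start = time.perf_counter()
grad_fn(params).block_until_ready()   # second call: cached executable only
print("second call:", time.perf_counter() - start)
```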
As a wild guess, not packing values together (into a single stacked array) might be part of the problem; stacking them could help.
Matt -- I'm pretty sure we implicitly concatenate output arrays from odeint currently (that's why we think tree-vectorize would help). So I doubt that part will help.
if helpful, my loss function is below:
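Schematically, it has this shape, with toy dynamics standing in for the real coupled SEIRD right-hand side and a plain quadratic data term:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def rhs(y, t, theta):
    # toy stand-in for the coupled SEIRD dynamics; in the real model one
    # rate comes from nn_batch(theta, ...)
    rate = 1e-3 * jnp.dot(theta, theta)
    return -rate * y

def loss(theta, y0, ts, observed):
    ys = odeint(rhs, y0, ts, theta)
    return jnp.mean((ys - observed) ** 2)

theta = jnp.ones(8)
y0 = jnp.ones(16)
ts = jnp.linspace(0.0, 30.0, 31)   # daily time points
observed = jnp.zeros((31, 16))
print(loss(theta, y0, ts, observed))
```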
@mattjj stacking was faster for the gradient step (5 s vs 14 s). I was able to get the gradient step down to 5 s by relaxing the odeint tolerances.
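In case it's useful to others: the tolerance knobs are just keyword arguments, and stacking means passing one flat state array (toy example, not my actual system):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def f(y, t):
    # y is one stacked array (compartments x regions, flattened),
    # rather than a tuple/list of per-compartment arrays
    return -0.1 * y

y0 = jnp.ones(280 * 16)
ts = jnp.linspace(0.0, 30.0, 31)

# odeint defaults are rtol=1.4e-8, atol=1.4e-8; loosening them cuts step count
ys = odeint(f, y0, ts, rtol=1e-5, atol=1e-5)
```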
I think the code snippets you've included in this issue are very close to being a full self-contained example (something we'd be able to run), but not quite there (e.g. they don't include all the definitions and data needed to run them).

I can't think of a good reason for the gradient to take so much longer than the forward pass (though maybe there are fixed-point iterations that take more steps?), so I agree that these times feel like something's wrong. Are you on CPU or GPU?
@jekbradbury I have tried the same code on CPU and GPU (V100). GPU execution is a factor of ~2 faster, but the timing ratios are roughly the same. I have included the necessary files here; they are ndarrays. Hopefully this is helpful.
I have added a self-contained example (with the necessary data) here. |
@ibulu thanks for providing the example! Running your code locally on my laptop's CPU, I noticed that this is actually a case where both the compilation/tracing and the evaluation are slow.
Ohh, that sounds painful :-) The example I included is a smaller version of the actual model I am trying to train, which has ~500 coupled meta-populations.
@shoyer thanks for looking into this. I have implemented the benchmarks in Julia, included here: Julia benchmarks. Though this may not be the most efficient implementation, it gives an idea of what to expect. Julia has quite a few other local sensitivity algorithms implemented. I have tried BacksolveAdjoint, which I believe is what's implemented in odeint. Also, everything in the Julia benchmarks is 64-bit, so the comparison may not be fully representative of the performance gap.
Very interesting, thanks for sharing the Julia benchmarks! Yes, I think BacksolveAdjoint is exactly what we use in JAX. It's encouraging that we see a similar runtime trend: 300 ms for the forward pass vs 18 seconds for the backward pass. Have you tried running one of the other adjoint calculation methods? I would be very curious whether one of the other approaches has much better performance; that could be good motivation for implementing one of them in JAX.
@shoyer |
I should mention that, in practice, I don't use InterpolatingAdjoint with autojacvec and precompile; I just never got good results with it in the downstream tasks I've used it for so far. Usually BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)) works best for the type of problems I have worked with, so I thought I'd include the benchmarks for that as well.
By the way, XLA:CPU (and hence JAX on CPU) has some known performance issues with float64; e.g., I believe the 64-bit GEMM kernel being called is a slow Eigen one, while the 32-bit one is from MKL-DNN. I mention it just because I don't want it to be a confounder, though I haven't looked at the details here enough to know if it could be relevant. I'd love it if an outcome of this investigation is that we should add some alternative VJPs for odeint, because that sounds really fun!
@mattjj Happy to provide evidence in that direction :-) This is with 64-bit disabled:
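(For reference, the flag being flipped; x64 is off by default in JAX:)

```python
import jax
import jax.numpy as jnp

# x64 is off by default; enabling it is what hits the slow CPU GEMM path above
jax.config.update("jax_enable_x64", False)
print(jnp.ones(3).dtype)  # float32
```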
I really would love to see how many function evaluations are being done in the backward pass, but that will require fixing #4015 first. One thing that puzzles me a little bit is that I thought there were some general guarantees about gradient calculations not being much more expensive than the forward pass (aside from increased memory consumption). I guess that can be violated here because we aren't differentiating through the implementation of the forward pass? If it's feasible, I would be curious how doing that in Julia compares (I think that's ReverseDiffAdjoint?).
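For contrast, "differentiating through the implementation" (discretise-then-optimise) is easy to try in JAX with a fixed-step integrator; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def rk4_solve(func, y0, t0, t1, n_steps, *args):
    """Fixed-step RK4; jax.grad backpropagates through every solver step."""
    dt = (t1 - t0) / n_steps

    def step(y, t):
        k1 = func(y, t, *args)
        k2 = func(y + 0.5 * dt * k1, t + 0.5 * dt, *args)
        k3 = func(y + 0.5 * dt * k2, t + 0.5 * dt, *args)
        k4 = func(y + dt * k3, t + dt, *args)
        return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), None

    ts = t0 + dt * jnp.arange(n_steps)
    y_final, _ = jax.lax.scan(step, y0, ts)
    return y_final

f = lambda y, t, k: -k * y  # toy dynamics
grad_k = jax.grad(lambda k: jnp.sum(rk4_solve(f, jnp.ones(3), 0.0, 5.0, 200, k)))
print(grad_k(0.7))
```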
I think ReverseDiffAdjoint is probably a lot slower than the others; I had to kill the computation after waiting for about 5 minutes. And from the look of it, it consumes a lot more memory.
I tried to implement my own function-call counter to count calls to runge_kutta_step (without success). I am guessing this is probably where the performance hit is happening during the backward call. I am wondering whether this behavior is common, i.e. a large performance difference between the forward and backward call. The other surprising thing was that I didn't see any performance improvement with jit.
I made a notebook that shows how to use host_callback to benchmark an ODE solver. It includes an example solving the dynamics of a damped oscillator. One nice thing about this system is that it's easy to make it "stiff" (in both the literal and the ODE sense) by increasing the damping.

For most parameter choices, the ratio of evaluations for calculating the gradient versus only the forward pass is around 3x. This is what we would expect if we need the same number of evaluations for both the forward and backward solve, because the ODE gets evaluated twice in each backward step (once for the primal and once for the cotangent). I've been able to increase the ratio up to 30x or so by increasing the effective level of damping, which makes sense in the context of our adaptive ODE solver.
That said, the gradients in this scenario are totally unreliable, because the forward solution is effectively just "approximately zero" (within the tolerance of the ODE solvers). We also have large floating-point error, and gradients that either explode or go to zero.

I'm not entirely sure about the general conclusions of this exercise. One tentative answer (validated by the similar results in Julia) is that perhaps long gradient evaluation costs are an indication that your gradient calculation itself is ill-posed?
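The core trick in the notebook is an id_tap that bumps a host-side counter every time the dynamics function is evaluated. Roughly (note that taps can also fire for tangent/cotangent evaluations under differentiation, so treat the counts as approximate):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback
from jax.experimental.ode import odeint

n_evals = 0

def _count(arg, transforms):
    global n_evals
    n_evals += 1

def damped_oscillator(y, t, damping):
    y = host_callback.id_tap(_count, y)  # fires once per solver evaluation
    x, v = y
    return jnp.array([v, -x - damping * v])

y0 = jnp.array([1.0, 0.0])
ts = jnp.linspace(0.0, 10.0, 100)

odeint(damped_oscillator, y0, ts, 10.0).block_until_ready()
host_callback.barrier_wait()   # make sure all taps have run
fwd = n_evals

n_evals = 0
loss = lambda d: jnp.sum(odeint(damped_oscillator, y0, ts, d) ** 2)
jax.grad(loss)(10.0).block_until_ready()
host_callback.barrier_wait()
print(f"forward: {fwd} evals, gradient: {n_evals} evals (~3x expected)")
```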
@shoyer |
Yes, I agree with these! Given the long compilation times, I would also add (d) improving compilation speed in JAX and/or XLA. But that really needs a more specific benchmark first.
Cool. I think I have seen a JAX implementation of the BDF solver somewhere. Will try to plug that in for the backward call.
To get @mattjj excited :-) I think something like this may help: Adaptive Checkpoint Adjoint Method
To be honest, I'm rather shocked by this paper, which fails to mention the extensive pre-existing literature on gradient checkpointing for adjoint methods. Gradient checkpointing has been well established in the context of autodiff since the 1990s; see, e.g., the "Related Work" section of this paper for a more balanced perspective: https://arxiv.org/pdf/1902.10298.pdf

We really should have methods like this in JAX. One challenge is that we will need to require that the user pre-specify the number of distinct gradient checkpoints to use, because XLA can't do dynamic memory allocation inside JIT. A rough sketch of what that could look like is below.
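This isn't an existing JAX API; it's just one way to make "pre-specified checkpoints" concrete: split the interval into a static number of segments and scan over them, so each segment's backward solve starts from a stored boundary state instead of integrating all the way back from t1.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def segmented_odeint(func, y0, t0, t1, n_segments, *args):
    """Solve in n_segments pieces. lax.scan stores each segment's initial
    state as a residual, so the adjoint for each segment only spans that
    segment, bounding backsolve error accumulation. n_segments is a static
    Python int, satisfying XLA's static-memory requirement."""
    ts = jnp.linspace(t0, t1, n_segments + 1)

    def step(y, i):
        seg_ts = jnp.stack([ts[i], ts[i + 1]])
        y_next = odeint(func, y, seg_ts, *args)[-1]
        return y_next, y_next

    y_final, _ = jax.lax.scan(step, y0, jnp.arange(n_segments))
    return y_final

# Toy usage: gradient w.r.t. a decay-rate parameter.
f = lambda y, t, k: -k * y
grad_k = jax.grad(lambda k: jnp.sum(segmented_odeint(f, jnp.ones(3), 0.0, 10.0, 8, k)))
print(grad_k(0.5))
```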
@shoyer Agreed; there is a rich literature on checkpointing.
@shoyer wondering if you saw this: Faster ODE Adjoints with 12 Lines of Code (https://arxiv.org/pdf/2009.09457.pdf)
Yes, that also looks like a good idea for us.
Just commenting on this - did anything happen here? We are integrating a simple ODE (22-dimensional, linear homogeneous with a 4-parameter inhomogeneous term) and finding that the jitted model takes about 5 ms, the gradient about 600 ms, and the jacrev(jacrev(fun)) Hessian about a full minute. It would be awesome to know how we could speed up the gradients.
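One note for the Hessian specifically, in case it helps anyone else: jax.hessian composes forward-over-reverse (jacfwd of jacrev), which is generally much cheaper than jacrev(jacrev(f)) for scalar losses:

```python
import jax
import jax.numpy as jnp

def loss(p):                        # hypothetical stand-in scalar loss
    return jnp.sum(jnp.sin(p) ** 2)

p = jnp.ones(22)

H_slow = jax.jacrev(jax.jacrev(loss))(p)   # reverse-over-reverse
H_fast = jax.jacfwd(jax.jacrev(loss))(p)   # forward-over-reverse: preferred
H_std  = jax.hessian(loss)(p)              # same as jacfwd(jacrev(...))

assert jnp.allclose(H_slow, H_fast) and jnp.allclose(H_fast, H_std)
```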
For those watching this thread -- check out Diffrax, which is a new library of differential equation solvers written in pure JAX. Of relevance to this discussion: it implements several of the ideas raised above, including checkpointed adjoints (its default) and backsolve adjoints, with support for the seminorm trick.
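A minimal usage sketch (based on my reading of the Diffrax docs; treat the docs, not this comment, as the authority on the exact API):

```python
import diffrax
import jax.numpy as jnp

def f(t, y, args):
    return -y  # toy dynamics

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(f),
    diffrax.Tsit5(),
    t0=0.0, t1=10.0, dt0=0.01,
    y0=jnp.ones(3),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 50)),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
    # default adjoint is RecursiveCheckpointAdjoint (checkpointed backprop);
    # adjoint=diffrax.BacksolveAdjoint() gives the optimise-then-discretise
    # adjoint discussed earlier in this thread
)
print(sol.ys.shape)
```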