
any suggestions on how to improve performance for gradient step with odeint? #3993

Open · ibulu opened this issue Aug 7, 2020 · 31 comments

@ibulu

ibulu commented Aug 7, 2020

I have a large system of coupled ODEs (~280×16). The loss computation, which involves solving the ODEs (I am using odeint from jax.experimental.ode), takes 163 ms. The gradient step takes ~33 s. It is a relatively straightforward parameter estimation problem, except that one of the parameters is modeled by a neural network as a function of an input quantity.

@shoyer
Member

shoyer commented Aug 7, 2020

It's hard to say what's going wrong without more diagnostic information. Ideally you could share a simplified example that illustrates the problem, but short of that:

  • What platform (CPU/TPU/GPU) are you running this on?
  • Have you tried using Tensorboard profiling to see where the time is being spent?
  • How many odeint steps does your model take for the forward and backward passes? host_callback could help here (see the sketch after this list).
  • What sort of ODE are you solving? The gradient method currently implemented in JAX is probably only suitable for a restricted class of ODEs (e.g., neural ODEs or wave equations) because it relies on the ODE being invertible. If you're trying to differentiate through long roll-outs of a dissipative system, the gradient calculation will be very ill-conditioned (which manifests itself in a large number of required time steps for the corresponding ODE solve).
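As a rough sketch (not code from this thread), counting right-hand-side evaluations with host_callback could look something like the following; the count_evals wrapper, the counter, and the commented-out usage are illustrative assumptions only:

from collections import Counter

from jax.experimental import host_callback

counts = Counter()

def count_evals(f, name):
    """Wrap an ODE right-hand side so every evaluation bumps a host-side counter."""
    def wrapped(y, t, *args):
        dy = f(y, t, *args)
        # Thread the tap through the return value so it is not dead-code-eliminated.
        return host_callback.id_tap(lambda arg, transforms: counts.update([name]), t, result=dy)
    return wrapped

# Hypothetical usage: wrap the RHS before handing it to odeint, run a forward
# solve and then a gradient step, and compare the two counts.
# sol = odeint(count_evals(rhs, "rhs"), y0, ts, *ode_args, rtol=1e-3, atol=1e-6)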

@ibulu
Author

ibulu commented Aug 7, 2020

Thank you. I will try to create a self contained example. The ode system is a set of coupled epidemiological models, box-car SEIRD. The coupling matrix and mobility are available at discrete time points, day scale.

def SEIRD_mobility_coupled(u, t, p_, mobility_, coupling_matrix_):
    s, e, id1, id2, id3, id4, id5, id6, id7, d, ir1, ir2, ir3, ir4, ir5, r = u
    κ, α, γ = softplus(p_[:3])
    # κ*α and γ*η are not independent: the probabilities of transitioning from e to the
    # Ir and Id chains have to add up to 1.
    η = -jnp.log(-jnp.expm1(-κ*α))/γ
    ind = jnp.rint(t.astype(jnp.float32))
    n_c = coupling_matrix_.shape[0]
    scaler_ = softplus(p_[3:3+n_c])
    cm_ = jnp.expand_dims(scaler_,(1))*coupling_matrix_[...,ind.astype(jnp.int32)]
    β = nn_batch(p_[3+n_c:], mobility_[...,ind.astype(jnp.int32)])[:,0,0]
    i = id1+id2+id3+ir1+ir2+ir3+ir4+ir5
    
    a = β*s*i+β*s*(jnp.matmul(i,cm_.T)+jnp.matmul(cm_,i))
    ds = -a
    de = a - κ*α*e - γ*η*e
    
    d_id1 = κ*(α*e-id1)
    d_id2 = κ*(id1-id2)
    d_id3 = κ*(id2-id3)
    d_id4 = κ*(id3-id4)
    d_id5 = κ*(id4-id5)
    d_id6 = κ*(id5-id6)
    d_id7 = κ*(id6-id7)
    d_d = κ*id7
    
    d_ir1 = γ*(η*e-ir1)
    d_ir2 = γ*(ir1-ir2)
    d_ir3 = γ*(ir2-ir3)
    d_ir4 = γ*(ir3-ir4)
    d_ir5 = γ*(ir4-ir5)
    d_r = γ*ir5
    
    return jnp.stack([ds,
                      de,
                      d_id1, d_id2, d_id3, d_id4, d_id5, d_id6, d_id7, d_d,
                      d_ir1 ,d_ir2, d_ir3, d_ir4, d_ir5, d_r])

nn_batch is the neural network part. It is defined this way, with a flat parameter vector, mostly so that I can also use BFGS, L-BFGS, etc. optimizers. Initial optimization is done using rmsprop_momentum with momentum decay, which works really well for this problem.

from functools import partial

import jax.numpy as jnp
from jax import vmap, grad, pmap, value_and_grad, random
from jax.nn.initializers import (xavier_normal, xavier_uniform, glorot_normal, glorot_uniform, uniform,
                                 normal, lecun_uniform, lecun_normal, kaiming_uniform, kaiming_normal)
from jax.nn.functions import (softplus, selu, gelu, glu, swish, relu, relu6, elu, sigmoid)

key = random.PRNGKey(0)
layers = [7, 7, 7, 7, 1]
activations = [swish, swish, swish, softplus]
weight_initializer = xavier_normal
bias_initializer = normal

def init_layers(nn_layers, nn_weight_initializer_,
                nn_bias_initializer_):
    init_w = nn_weight_initializer_()
    init_b = nn_bias_initializer_()
    params = []
    for in_, out_ in zip(nn_layers[:-1], nn_layers[1:]):
        key = random.PRNGKey(in_)
        weights = init_w(key, (in_, out_)).reshape((in_*out_,))
        biases = init_b(key, (out_,))
        params.append(jnp.concatenate((weights, biases)))
    return jnp.concatenate(params)

def nnet(nn_layers, nn_activations, nn_params, x):
    n_s = 0
    x_in = jnp.expand_dims(x, axis=1)
    for in_,out_, act_ in zip(nn_layers[:-1],nn_layers[1:],nn_activations):
        n_w = in_*out_
        n_b = out_
        n_t = n_w+n_b
        weights = nn_params[n_s:n_s+n_w].reshape((out_,in_))
        biases = jnp.expand_dims(nn_params[n_s+n_w:n_s+n_t],axis=1)
        x_in = act_(jnp.matmul(weights,x_in)+biases)
        n_s += n_t

    return x_in
nn_batch = vmap(partial(nnet, layers, activations), in_axes=(None, 0), out_axes=0)
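For reference, a sketch of how these pieces fit together; the batch shape here is an assumption (280 locations, 7 input features) based on the problem description above:

# Flat parameter vector for all layers, suitable for optimizers that expect a 1-D array.
nn_params = init_layers(layers, weight_initializer, bias_initializer)

# nn_batch maps nnet over the leading axis of its second argument, so for a batch
# of 280 inputs with 7 features each the output has shape (280, 1, 1), matching
# the [:, 0, 0] indexing in SEIRD_mobility_coupled.
x_batch = jnp.ones((280, 7))
out = nn_batch(nn_params, x_batch)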

@mattjj
Member

mattjj commented Aug 8, 2020

Is it slow to execute (after compiling), or just to compile? That is, if you try evaluating the gradient a second time, is it faster on the second evaluation?
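A minimal way to check this (loss_ and params_ stand in for the loss function and parameter vector from the snippets in this thread):

import time
import jax

grad_loss = jax.jit(jax.grad(loss_))

t0 = time.perf_counter()
grad_loss(params_).block_until_ready()   # first call: trace + compile + run
print("first call (compile + run):", time.perf_counter() - t0)

t0 = time.perf_counter()
grad_loss(params_).block_until_ready()   # second call: runs the cached executable
print("second call (run only):", time.perf_counter() - t0)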

@mattjj
Member

mattjj commented Aug 8, 2020

As a wild guess, not packing values together with jnp.stack (as in the return value of SEIRD_mobility_coupled) can help both compilation time and performance.

@shoyer
Member

shoyer commented Aug 8, 2020

Matt -- I'm pretty sure we implicitly concatenate output arrays from odeint currently (that's why we think tree-vectorize would help). So I doubt that part will help.

@ibulu
Author

ibulu commented Aug 8, 2020

If helpful, my loss function is below:

def diff(sol_,data_):
    l1 = jnp.square(jnp.ediff1d((1-sol_[:,0])) - data_[:,0])
    l2 = jnp.square(jnp.ediff1d(sol_[:,9]) - data_[:,1])
    return l1+20000*l2
diff_v = vmap(diff,(2,2))

def loss(data_,m_array_, coupling_matrix_, params_):
    sol_ = odeint(SEIRD_mobility_coupled, u0, t0, params_, m_array_,coupling_matrix_, 
                  rtol=1e-3, atol=1e-6)
    return jnp.sum(diff_v(sol_,data_)) 

loss_ = partial(loss, epi_array,mobilitypopulation_array_scaled,coupling_matrix)

@mattjj stacking was faster for the gradient step (5 s vs 14 s). I was able to get the gradient step down to 5 s by relaxing the odeint tolerances.


@jekbradbury
Contributor

I think the code snippets you've included in this issue are very close to being a full self-contained example (something we'd be able to run), but not quite there (e.g. they don't include epi_array; note that fake data of the right shape and distribution would probably be fine to demonstrate the problem).

I can't think of a good reason for the gradient to take so much longer than the forward pass (though maybe there are fixed-point iterations that take more steps?), so I agree that these times feel like something's wrong.

Are you on CPU or GPU?

@ibulu
Author

ibulu commented Aug 9, 2020

@jekbradbury I have tried the same code on CPU and GPU (V100). GPU execution is a factor of ~2 faster, but the timing ratios are roughly the same. I have included the necessary files here; they are ndarrays. Hopefully this is helpful.

@ibulu
Author

ibulu commented Aug 9, 2020

I have added a self-contained example (with the necessary data) here.

@shoyer
Member

shoyer commented Aug 10, 2020

@ibulu thanks for providing the example!

Running your code locally on my laptop's CPU, I noticed that this is actually a case where both the compilation/tracing and evaluation are slow:

  • Forward compilation: 9 s
  • Forward evaluation: 120 ms
  • Gradient compilation: 84 s
  • Gradient evaluation: 36 s

@ibulu
Author

ibulu commented Aug 10, 2020

ohh, that sounds painful :-) The example I included is a smaller version of the actual model I am trying to train, which has about ~500 coupled meta-populations.
I tried jit with the static_argnums argument. It led to a minor (very minor) improvement.

@ibulu
Author

ibulu commented Aug 12, 2020

@shoyer thanks for looking into this. I have implemented the benchmarks in Julia, included here: Julia benchmarks. Though this may not be the most efficient implementation, it gives an idea of what to expect. Julia has quite a few other local sensitivity algorithms implemented. I have tried BacksolveAdjoint, which I believe is what's implemented in odeint. Also, everything in the Julia benchmarks is 64-bit, so the comparison may not be fully representative of the performance gap.

@shoyer
Member

shoyer commented Aug 12, 2020

Very interesting, thanks for sharing the Julia benchmarks!

Yes, I think BacksolveAdjoint is exactly what we use in JAX. It's encouraging that we see a similar runtime trend: 300 ms for the forward pass vs 18 s for the backward pass.

Have you tried running one of the other adjoint calculation methods? I would be very curious if one of the other approaches has much better performance. That could be good motivation for implementing one of them in JAX.

@ibulu
Author

ibulu commented Aug 12, 2020

@shoyer
To establish a baseline, I set both Julia and JAX to use 64-bit and used DP5 as the ODE solver (same tolerances, etc.). These are the benchmarks for JAX:
[benchmark screenshot]
Julia with BacksolveAdjoint:
[benchmark screenshot]
Julia with InterpolatingAdjoint:
[benchmark screenshot]
With QuadratureAdjoint:
[benchmark screenshot]
And the best for last: InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)). From the DiffEqSensitivity docs:
"ReverseDiffVJP(compile=false): Uses ReverseDiff.jl for the vjp. compile is a boolean for whether to precompile the tape, which should only be done if there are no branches (if or while statements) in the f function. When applicable, ReverseDiffVJP(true) is the fastest method"

[benchmark screenshot]

@ibulu
Author

ibulu commented Aug 12, 2020

I should mention that, in practice, I don't use InterpolatingAdjoint with autojacvec and precompile; I just never got good results in the downstream tasks I have used it for so far. Usually, BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)) works best for the type of problems I have worked with, so I thought I'd include the benchmarks for that as well.

[benchmark screenshot]

@mattjj
Member

mattjj commented Aug 12, 2020

By the way, XLA:CPU (and hence JAX on CPU) has some known performance issues with float64, e.g. I believe the 64bit GEMM kernel being called is a slow Eigen one, while the 32bit one is from MKL-DNN. I mention it just because I don't want it to be a confounder, though I haven't looked at the details here enough to know if it could be relevant.
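For reference, toggling 64-bit mode in JAX is a single config update (set before creating any arrays), which makes it easy to check whether float64 is the confounder:

import jax
jax.config.update("jax_enable_x64", True)   # or False to stay in float32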

I'd love it if an outcome of this investigation is that we should add some alternative VJPs for odeint, because that sounds really fun!

@ibulu
Author

ibulu commented Aug 12, 2020

@mattjj Happy to provide evidence in that direction :-) This is with 64-bit disabled:
[benchmark screenshot]
I'll try to see whether Julia allows me to do all the computations in 32-bit.

@shoyer
Member

shoyer commented Aug 12, 2020

I really would love to see how many function evaluations are being done in the backward pass, but that will require fixing #4015 first.

One thing that puzzles me a little bit is that I thought there were some general guarantees about gradient calculations not being much more expensive than the forward pass (aside from increased memory consumption). I guess that can be violated here because we aren't differentiating through the implementation of the forward pass? If it's feasible, I would be curious how doing that in Julia compares (I think ReverseDiffAdjoint() is the relevant method on CPU).
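To make the "differentiate through the implementation of the forward pass" option concrete, a rough sketch of a fixed-step RK4 solver built on lax.scan is below; reverse-mode AD then backpropagates through the solver steps themselves instead of solving a separate adjoint ODE. This is illustrative only (fixed steps, no error control) and is not how odeint or ReverseDiffAdjoint is implemented:

import jax.numpy as jnp
from jax import lax

def rk4_odeint(f, y0, ts, *args):
    """Fixed-step RK4 on the grid ts; jax.grad differentiates through the steps."""
    def step(y, t_pair):
        t, t_next = t_pair
        h = t_next - t
        k1 = f(y, t, *args)
        k2 = f(y + 0.5 * h * k1, t + 0.5 * h, *args)
        k3 = f(y + 0.5 * h * k2, t + 0.5 * h, *args)
        k4 = f(y + h * k3, t_next, *args)
        y_next = y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return y_next, y_next

    _, ys = lax.scan(step, y0, jnp.stack([ts[:-1], ts[1:]], axis=1))
    return jnp.concatenate([y0[None], ys], axis=0)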

@ibulu
Author

ibulu commented Aug 12, 2020

I think ReverseDiffAdjoint is probably a lot slower than the others. I had to kill the computation after waiting for about 5 minutes, and from the look of it, it consumes a lot more memory.

@ibulu
Author

ibulu commented Aug 17, 2020

I tried to implement my own function call counter to count calls to runge_kutta_step (without success). I am guessing this is probably where the performance hit is happening during the backward call. I am wondering whether this behavior, i.e. a large performance difference between the forward and backward calls, is common. The other surprising thing was that I didn't see any performance improvement with jit.

@shoyer
Member

shoyer commented Aug 18, 2020

I made a notebook that shows how to use host_callback to benchmark an ODE solver:
https://gist.github.com/shoyer/9c6593ef6f65ddfcb394e96b90f87a72

It includes an example solving the dynamics of a damped oscillator. One nice thing about this system is that it's easy to make it "stiff" (in both the literal and ODE sense) by increasing k. Note that this ODE is also linear, so we don't need to worry about numerics for recalculating the primal values -- although our gradient calculation doesn't know that, and recalculates the primal values anyway.

For most parameter choices, the ratio of evaluations for calculating the gradient versus only the forward pass is around 3x. This is what we would expect if we need the same number of evaluations for both the forward and backward solve, because the ODE gets evaluated twice in each backward step (once for the primal and once for the cotangent).
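For context, the backward pass here roughly integrates the augmented adjoint system backwards from t_1 to t_0 (with a(t_1) = dL/dy(t_1)), which is why each backward step touches the right-hand side twice, once for the recomputed primal and once for the vector-Jacobian product:

\frac{dy}{dt} = f(y, t, \theta), \qquad
\frac{da}{dt} = -a^\top \frac{\partial f}{\partial y}, \qquad
\frac{d}{dt}\frac{\partial L}{\partial \theta} = -a^\top \frac{\partial f}{\partial \theta}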

I've been able to increase the ratio up to 30x or so by increasing the effective level of damping. This makes sense in the context of our adaptive ODE solver:

  • The forward solution goes to near zero due to the damping, so the adaptive solver can take very large steps.
  • The adaptive solver for the backward pass can't take large steps because the cotangents don't start small.

That said, the gradients in this scenario are totally unreliable, because the forward solution is effectively just "approximately zero" (within the tolerance of the ODE solvers). We also have large floating point error, and gradients that either explode or go to zero.

I'm not entirely sure about the general conclusions of this exercise. One tentative answer (validated by the similar results in Julia) is that perhaps long gradient evaluation costs are an indication that your gradient calculation itself is ill-posed?

@ibulu
Author

ibulu commented Aug 18, 2020

@shoyer
thank you for looking into this. This is awesome. I will give host_callback a try right away. Do you think the solution to this issue is maybe to:
a) use a different solver for the backward call (maybe a solver suitable for stiff problems)
b) simply rescale the ODE to avoid numerical issues
c) use smaller tolerances (slow but potentially reliable gradients)
Yeah, the agreement between Julia and JAX is really encouraging.

@shoyer
Member

shoyer commented Aug 18, 2020

a) use a different solver for the backward call (maybe a solver suitable for stiff problems)
b) simply rescale the ODE to avoid numerical issues
c) use smaller tolerances (slow but potentially reliable gradients)

Yes, I agree with these! Given the long compilation times, I would also add (d) improving compilation speed in JAX and/or XLA. But that really needs a more specific benchmark first.

@ibulu
Author

ibulu commented Aug 18, 2020

Cool. I think I have seen a JAX implementation of the BDF solver somewhere. I will try to plug that in for the backward call.
@shoyer thanks for the host_callback script. I just tried it on the problem I am working on. The evaluation ratio is ~4: the forward call takes ~70 ms and the gradient step takes 21 s. I am puzzled :-)
I am hoping that there is a factor of 10 to be gained somewhere :-)
The only explanation I can come up with is that there must be additional overhead somewhere.

@ibulu
Author

ibulu commented Aug 20, 2020

to get @mattjj excited :-) I think something like this may help: Adaptive Checkpoint Adjoint Method

@shoyer
Member

shoyer commented Aug 20, 2020

to get @mattjj excited :-) I think something like this may help: Adaptive Checkpoint Adjoint Method

To be honest, I'm rather shocked by this paper, which fails to mention the extensive pre-existing literature on gradient checkpointing for adjoint methods. Gradient checkpointing is very well established (since the 1990s) in the context of autodiff; see, e.g., the "Related Work" section of this paper for a more balanced perspective: https://arxiv.org/pdf/1902.10298.pdf.

We really should have methods like this in JAX. One challenge is that we will need to require that the user pre-specify the number of distinct gradient checkpoints to use, because XLA can't do dynamic memory allocation inside JIT.
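A rough sketch of what a fixed-checkpoint scheme could look like using jax.checkpoint (a.k.a. jax.remat); integrate_segment is a hypothetical function that advances the solution over one sub-interval, not an existing odeint API:

import jax
import jax.numpy as jnp

def checkpointed_solve(integrate_segment, y0, t0, t1, n_checkpoints):
    # n_checkpoints is a plain Python int, so the number of stored states is
    # fixed at trace time, which is what XLA needs.
    bounds = jnp.linspace(t0, t1, n_checkpoints + 1)
    y = y0
    ys = [y0]
    for i in range(n_checkpoints):
        # Only the segment-boundary states are saved for the backward pass;
        # everything inside a segment is recomputed under jax.checkpoint.
        y = jax.checkpoint(integrate_segment)(y, bounds[i], bounds[i + 1])
        ys.append(y)
    return jnp.stack(ys)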

@ibulu
Author

ibulu commented Aug 20, 2020

@shoyer agreed, there is a rich literature on checkpointing.
I am wondering whether the number of checkpoints can be pre-computed from the forward run.
I think this investigation uncovered a few things:

  • the JAX implementation is on par with Julia
  • the number of function calls can be quite large during the backward call, depending on the problem
  • there is potentially some large overhead during the backward call (the evaluation ratio was ~4, while the execution-time ratio is ~300), so there is some room for more optimization
  • for the particular problem I shared, there is almost no benefit in using jit (at least on CPU), and I wonder whether this generalizes to other ODE problems (I suspect it might)

@ibulu
Author

ibulu commented Oct 28, 2020

@shoyer wondering if you saw this: Faster ODE Adjoints with 12 Lines of Code

@shoyer
Member

shoyer commented Oct 28, 2020 via email

@benjaminpope

Just commenting on this: did anything happen here? We are integrating a simple ODE (22-dimensional, linear homogeneous with a 4-parameter inhomogeneous term) and finding that the jitted model takes about 5 ms, the gradient about 600 ms, and the jacrev(jacrev(fun)) Hessian takes about a full minute. It would be awesome to know how we could speed up the gradients.

@patrick-kidger
Collaborator

patrick-kidger commented Feb 14, 2022

For those watching this thread -- check out Diffrax, which is a new library of differential equation solvers written in pure JAX. Of relevance to this discussion:

  • You can use a stiff solver on the backward pass if desired, e.g. diffeqsolve(..., adjoint=BacksolveAdjoint(solver=Kvaerno5())); see the sketch after this list.
  • You can backpropagate directly through the solver using checkpointing, if desired.
  • It sounds like flatten-a-PyTree-to-a-vector might have been a bottleneck in the original code. Diffrax supports using PyTree-valued state without flattening.
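A rough usage sketch of the first point (argument names follow the Diffrax documentation; double-check the current API for exact signatures, and the toy vector field is only for illustration):

import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -args * y                      # toy linear ODE, just for illustration

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Tsit5(),                      # explicit solver for the forward pass
    t0=0.0, t1=10.0, dt0=0.1,
    y0=jnp.ones(3), args=0.5,
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 11)),
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    adjoint=diffrax.BacksolveAdjoint(solver=diffrax.Kvaerno5()),  # stiff solver on the backward pass
)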
