-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BDF solver for stiff ODEs #3781
base: main
Are you sure you want to change the base?
Conversation
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here with What to do if you already signed the CLAIndividual signers
Corporate signers
ℹ️ Googlers: Go here for more info. |
@googlebot I signed it! |
CLAs look good, thanks! ℹ️ Googlers: Go here for more info. |
I have been benchmarking the current solver against SciPy's VODE with large chemical systems (e.g. systems with IVPs of ~600-1000) and the solver seems to be really really slow even on an accelerator (NVIDIA Tesla V100). I searched through issues and couldn't find anything solid to track recompilation. I assumed JIT+GPU should at least be comparable to CPU times for large systems. I profiled the code using NVProf but I can only see kernel calls and I'm in the process of understanding and setting up profiling using TensorboardX so I can track python function calls. Any help would be greatly appreciated regarding this as the potential use case for this solver (along with adjoint sensitivities) is to solve chemical systems in the order of ~10000-40000 IVPs. Thanks!!!! 🙏 |
Hey @skrsna , awesome work! Sorry we haven't had a chance to follow up; we're really slammed with work right now. About the slowness, that's something we'd certainly like to fix! Can you try profiling using the TensorBoard profiler? That can show a lot more information, including both host and device activity in detail and often with source code provenance information. |
Sure! I had trouble setting up Tensorboard profiler on our campus' cluster due to port/network issues and our on premise GPU workstation is really old so I kept running into CUDA OOM for large systems. I'll work with our sys admins to figure out Tensorboard. The project I'm working on is in really early stages of development so it's not that urgent anyway. Thanks!! |
If you're concerned about recompilation, try setting |
Thanks @jekbradbury, I tried your suggestion and I'm seeing lots of Full compile log.
|
This PR addresses #3686 .
Implements BDF solver with JAX control flow and JIT support using code from tensforflow_probability. The current implementation is tested against SciPy's VODE wrapper with stiff chemical kinetics.
To-Do:
Adjoint sensitivities using
custom_vjp
Lazy jacobian evaluation
Automatically calculate jacobian using JAX's autodiff if user doesn't provide jacobian function
Automatically choose times if user only provides
end_time
Comments and suggestions are welcome. @shoyer