Gradients with odeint slow on GPU #5006
Labels: NVIDIA GPU (issues specific to NVIDIA GPUs), P1 (soon; assignee is working on this now, among other tasks)
The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (the Van der Pol oscillator) using data sampled along a single trajectory. Each iteration of the training loop runs slowly on my GPU compared to running everything on my CPU (roughly estimated with tqdm at 17 iterations/sec on GPU vs. upwards of 800 iterations/sec on CPU). Any first impressions about what might be going on? I can look into doing better profiling if need be.
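The original MWE is not reproduced here; below is a minimal sketch of the kind of setup described, assuming `jax.experimental.ode.odeint` (the API at jax 0.2.6) with a small MLP parameterizing the learned dynamics. All function names, layer sizes, and hyperparameters are illustrative choices, not the author's actual code.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# Van der Pol oscillator: dx/dt = y, dy/dt = mu * (1 - x^2) * y - x
def van_der_pol(state, t, mu=1.0):
    x, y = state
    return jnp.array([y, mu * (1.0 - x**2) * y - x])

# Sample training data along a single trajectory of the true system.
ts = jnp.linspace(0.0, 5.0, 50)
y0 = jnp.array([1.0, 0.0])
data = odeint(van_der_pol, y0, ts)

# A small MLP stands in for the unknown dynamics (sizes are arbitrary).
def init_params(key, sizes=(2, 32, 2)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp_dynamics(state, t, params):
    h = state
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b

# Mean-squared error between the integrated neural ODE and the sampled data.
def loss(params):
    pred = odeint(mlp_dynamics, y0, ts, params)
    return jnp.mean((pred - data) ** 2)

# One jitted gradient-descent step; gradients flow through odeint's
# adjoint, which is the part reported as slow on GPU.
@jax.jit
def step(params, lr=1e-2):
    l, grads = jax.value_and_grad(loss)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, l

params = init_params(jax.random.PRNGKey(0))
for i in range(100):
    params, l = step(params)
```

The 17 vs. 800 iterations/sec comparison would be measured around the `step` loop; each iteration runs both the forward integration and its adjoint.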
Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, cuda 10.2