Gradients with odeint slow on GPU #5006

Open
spenrich opened this issue Nov 24, 2020 · 5 comments
Labels
NVIDIA GPU Issues specific to NVIDIA GPUs P1 (soon) Assignee is working on this now, among other tasks. (Assignee required)


@spenrich

The following MWE uses gradient descent to train a simple neural ODE model to match a 2-D dynamical system (a Van der Pol oscillator), using data sampled along a single trajectory. Each iteration of the training loop runs much more slowly on my GPU than when everything runs on my CPU (roughly estimated with tqdm: 17 iterations/sec on the GPU vs. upwards of 800 iterations/sec on the CPU).

Any first impressions about what might be going on? I can look into doing better profiling if need be.

Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, cuda 10.2

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

# Uncomment this line to force using the CPU
# jax.config.update('jax_platform_name', 'cpu')

# Some utilities for dealing with PyTrees of parameters
def tree_axpy(a, x_tree, y_tree):
    """Compute `a*x + y` leafwise for two PyTrees `(x, y)` and a scalar `a`."""
    # Recent JAX versions accept several trees in `tree_map`; older ones
    # (such as 0.2.6) need `jax.tree_util.tree_multimap` here instead.
    return jax.tree_util.tree_map(lambda x, y: a * x + y, x_tree, y_tree)

def tree_normsq(x_tree):
    """Compute the sum of squared norms across a PyTree."""
    # `acc` is the running total; `leaf` is the next array in the tree.
    return jax.tree_util.tree_reduce(lambda acc, leaf: acc + jnp.sum(leaf**2), x_tree, 0.)

# Define the true ODE, our approximator, and the loss function
def f(x, t):
    """Compute state derivative of a Van der Pol oscillator."""
    mu = 1.
    dx = jnp.hstack([
        mu*(x[0] - x[0]**3/3 - x[1]),
        x[0]/mu
    ])
    return dx

def f_est(x, t, params):
    """Estimate state derivative with a two-layer neural network."""
    W = params['W']
    b = params['b']
    y = W[0]@x + b[0]
    y = W[1]@jnp.tanh(y) + b[1]
    return y

def loss(params, x, t, reg_coeff):
    """Sum of squared errors along a queried trajectory, plus an L2 penalty."""
    x_hat = odeint(f_est, x[0], t, params)
    error = jnp.sum((x - x_hat)**2)
    loss_value = error + reg_coeff*tree_normsq(params)
    return loss_value

# Generate data along a trajectory of the true system
x0 = jnp.array([1., 0.])
t0, tf = (0., 5.)
dt = 0.1
num_steps = int((tf - t0) / dt) + 1
t = jnp.linspace(t0, tf, num_steps)
x = odeint(f, x0, t)

# Initialize neural network parameters
n = 2
hdim = 32  # size of hidden layer
key = jax.random.PRNGKey(0)
# Use a separate subkey per array; reusing one key gives correlated draws.
k0, k1, k2, k3 = jax.random.split(key, 4)
params = {
    'W': [
        0.1*jax.random.normal(k0, (hdim, n)),
        0.1*jax.random.normal(k1, (n, hdim)),
    ],
    'b': [
        0.1*jax.random.normal(k2, (hdim,)),
        0.1*jax.random.normal(k3, (n,)),
    ]
}

# Training
loss_buffer = []
step_size = 1e-4
reg_coeff = 1e-6
value_and_grad = jax.jit(jax.value_and_grad(loss))
for _ in tqdm(range(5000)):
    value, grad = value_and_grad(params, x, t, reg_coeff)
    loss_buffer.append(value)
    params = tree_axpy(-step_size, grad, params)  # gradient descent step
print('Regularized fit loss:', loss_buffer[-1])

# Plotting (optional)
try:
    import matplotlib.pyplot as plt

    x_est = odeint(f_est, x0, t, params)

    fig, axes = plt.subplots(1, 2, figsize=(15,5))
    axes[0].plot(x[:,0], x[:,1], '--x')
    axes[0].plot(x_est[:,0], x_est[:,1], '-')
    axes[1].plot(loss_buffer)
    axes[1].set_yscale('log')
    plt.show()
except ImportError:
    print('Package `matplotlib` not found! Skipping plots.')
```
@shoyer
Member

shoyer commented Nov 24, 2020

The short answer is that, unfortunately, XLA GPU is currently not great at code generation for tight loops like those in odeint. The body of the while_loop is compiled into one or more GPU kernels, and launching them carries significant overhead because control flow returns to the CPU on every iteration.
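If per-iteration launch overhead dominates, one workaround to experiment with (an assumption on our part, not something proposed in this thread) is a fixed-step integrator built on `jax.lax.scan`, whose `unroll` argument places several steps in each loop iteration and so may reduce the number of iterations that bounce back to the host. Whether it actually helps on a given GPU has to be measured. `rk4_odeint`, `steps_per_interval=10`, and `unroll=8` are hypothetical, illustrative choices. Unlike `odeint`, this sketch has no adaptive error control, supports only array (not PyTree) states, and gets gradients by backpropagating through the solver steps.

```python
import jax
import jax.numpy as jnp


def rk4_odeint(func, y0, t, *args, steps_per_interval=10, unroll=8):
    """Hypothetical fixed-step RK4 solver with odeint's calling convention.

    Integrates dy/dt = func(y, t, *args) and returns y at every time in `t`
    (row 0 is y0). Each interval [t[i], t[i+1]] is split into
    `steps_per_interval` equal RK4 steps; there is no error control.
    """

    def rk4_step(y, t_i, h):
        k1 = func(y, t_i, *args)
        k2 = func(y + 0.5 * h * k1, t_i + 0.5 * h, *args)
        k3 = func(y + 0.5 * h * k2, t_i + 0.5 * h, *args)
        k4 = func(y + h * k3, t_i + h, *args)
        return y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    # Start time and step size of every sub-step, flattened into one sequence.
    h = (t[1:] - t[:-1]) / steps_per_interval               # (N-1,)
    offsets = jnp.arange(steps_per_interval)                 # (S,)
    t_sub = t[:-1, None] + offsets[None, :] * h[:, None]     # (N-1, S)
    h_sub = jnp.broadcast_to(h[:, None], t_sub.shape)        # (N-1, S)

    def body(y, t_h):
        y_next = rk4_step(y, t_h[0], t_h[1])
        return y_next, y_next

    # One scan over all N-1 * S sub-steps; `unroll` puts several per iteration.
    _, ys = jax.lax.scan(body, y0, (t_sub.reshape(-1), h_sub.reshape(-1)),
                         unroll=unroll)
    # Keep only the states that land on the requested times t[1:].
    ys = ys[steps_per_interval - 1::steps_per_interval]
    return jnp.concatenate([y0[None], ys], axis=0)


# Demo: dy/dt = -y has the exact solution exp(-t).
t_demo = jnp.linspace(0.0, 1.0, 11)
y_demo = rk4_odeint(lambda y, s: -y, jnp.array([1.0]), t_demo)
print(jnp.max(jnp.abs(y_demo[:, 0] - jnp.exp(-t_demo))))
```

In the MWE, this would replace the solver inside `loss` via `x_hat = rk4_odeint(f_est, x[0], t, params)`. Because the step control and the gradient method (odeint uses an adjoint solve) both differ, loss values will not match the original exactly.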

@awav

awav commented Dec 8, 2020

@shoyer, is anyone working on improving while_loop?

@shoyer
Member

shoyer commented Dec 8, 2020

Yes, there are several ongoing streams of work to improve while_loop.

@awav

awav commented Dec 9, 2020

@shoyer, apologies for the off-topic question, but could you share links to the PRs or branches for the ongoing work, if they are publicly available?

@shoyer
Member

shoyer commented Dec 9, 2020 via email

@sudhakarsingh27 sudhakarsingh27 added NVIDIA GPU Issues specific to NVIDIA GPUs P1 (soon) Assignee is working on this now, among other tasks. (Assignee required) labels Aug 10, 2022
5 participants