
vmap of cond's predicate results in select, leading to unexpected compute/memory use #8409

Open
aespielberg opened this issue Oct 30, 2021 · 32 comments
Labels
enhancement New feature or request performance make things lean and fast

Comments

@aespielberg

I have been playing around with converting diffmpm from the difftaichi package into a jax version, and while the forward pass has been working wonderfully, the backward pass has been using way too much GPU memory.

Today, I was able to trace that memory usage to the grid op. The grid op step is a series of nested if statements. At first I used jnp.where, which evaluates all branches. That is extremely inefficient and can lead to OOM errors. I simplified my code and switched to jax.lax.cond, but my only conclusion is that cond is also evaluating both branches; otherwise I cannot see why this would run out of memory.

Below is a modified version of the grid op, composed with itself 4,000 times, like a simulation. Even when run with the XLA_PYTHON_CLIENT_PREALLOCATE=false flag, this quickly uses up the whole GPU, and it needs even more memory if the loop length is increased. This is not the case if every line from lin = .... until right before the return of grid_op is commented out; then memory usage is practically negligible. Note that because bound = 0, the predicate of every line written v_out = jax.lax.cond ... evaluates to False by definition, so most of the expressions, including the v_out_gate's and their dependencies, shouldn't even need to be evaluated in the jitted function.

Maybe I am misunderstanding cond; if so, what is the proper way to get this sparse branching behavior? I don't want to evaluate and hold onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in a backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using Version: 0.1.69+cuda101

Code to reproduce is below.

import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
import jax.lax as jlax
import timeit
import jax

dim = 2
n_grid = 128
dt = 1e-3
gravity = 3.8


def allocate_arrays():
  global grid_m_in, grid_v_in, grid_v_out, loss, index_array
  grid_m_in = jnp.ones((n_grid, n_grid))
  grid_v_in = jnp.zeros((n_grid, n_grid, dim))
  grid_v_out = jnp.zeros((n_grid, n_grid, dim))


  index_array = np.zeros((n_grid, n_grid, dim))
  
  for i in range(n_grid):
    for j in range(n_grid):
      index_array[i, j] = np.array([i, j])
 
  index_array = jnp.array(index_array)

  

def grid_op(grid_v_in, grid_m_in, index_tuple):
  bound = 0  # with bound = 0, every boundary predicate below is False
  coeff = 0.5
  
  i = index_tuple[0]
  j = index_tuple[1]
  
  normal = jnp.array([0., 1.])
  
  inv_m = 1 / (grid_m_in + 1e-10)
  v_out = jnp.expand_dims(inv_m, -1) * grid_v_in
  v_out -= dt * gravity * jnp.array([0., 1.])
  
  v_out = jax.lax.cond(jnp.logical_and(i < bound, v_out[0] < 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  
  v_out = jax.lax.cond(jnp.logical_and(i > n_grid - bound, v_out[0] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  lin = (v_out.transpose() @ normal)
  
  vit = v_out - lin * normal
  lit = jnp.linalg.norm(vit + 1e-10)  + 1e-10
  
  
  v_out_gate_2 = jax.lax.cond(lit + coeff * lin <= 0, lambda _: jnp.zeros_like(v_out), lambda _: (1 + coeff * lin / lit) * vit, operand=None)
  v_out_gate_1 = jax.lax.cond(lin < 0, lambda _: v_out_gate_2, lambda _: jnp.zeros_like(v_out), operand=None)
  v_out = jax.lax.cond(jnp.logical_and(j < bound, v_out[1] < 0), lambda _: v_out_gate_1, lambda _: v_out, operand=None)          
  v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  return v_out

go_j = jit(vmap(vmap(grid_op)))  # map over both grid axes; this batches every cond predicate in grid_op


def advance2(t, args):
  grid_v_in = args[0]
  grid_m_in = args[1]
  index_array = args[2]
  grid_v_in = go_j(grid_v_in, grid_m_in, index_array)
  
  return grid_v_in, grid_m_in, index_array
  
  

def forward2(grid_v_in, grid_m_in, index_array):
  grid_v_in, grid_m_in, index_array = jlax.fori_loop(0, 4000, advance2, (grid_v_in, grid_m_in, index_array))
  return jnp.mean(grid_v_in)


def main():
  # initialization
  allocate_arrays()
  
  f2 = jit(forward2)
  forward_grad2 = jit(grad(forward2))

  number = 10
  

  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  
if __name__ == "__main__":
  main()

@aespielberg aespielberg added the bug Something isn't working label Oct 30, 2021
@aespielberg
Author

In fact, appending the idempotent line v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None) a few more times to the end of grid_op adds several GB of memory usage 😞

@froystig
Member

froystig commented Nov 1, 2021

XLA may choose to evaluate several or all branches of cond or select, but indeed one intent of exposing cond separately from select is to allow you to indicate when that would be expensive. Using cond suggests that only one branch should be evaluated (should other XLA optimizations allow for it).

But XLA has no "batched" conditional construct. Today, when batching the predicate to a cond using vmap, jax will transform the cond to a select instead. From a quick glance through your example, it seems like this might be happening.
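One way to see this conversion, sketched under the assumption of a recent JAX release (f and top_level_primitives are illustrative names), is to inspect the traced program with jax.make_jaxpr. With an unbatched predicate, or one shared across the batch, a cond primitive survives. With a batched predicate, only the branch computations remain, followed by select_n.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(pred, x):
    # One "expensive" branch and one cheap placeholder branch.
    return lax.cond(pred, lambda x: jnp.sin(x) ** 2, lambda x: jnp.zeros_like(x), x)


def top_level_primitives(fn, *args):
    # Names of the primitives at the top level of the traced program.
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns]


preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)

# Unbatched: a real conditional survives tracing.
unbatched = top_level_primitives(f, True, jnp.float32(1.0))
# Predicate shared across the batch (in_axes=None): still a cond.
shared_pred = top_level_primitives(jax.vmap(f, in_axes=(None, 0)), True, xs)
# Batched predicate: both branches run for every example, then select_n.
batched_pred = top_level_primitives(jax.vmap(f), preds, xs)

print("unbatched:   ", unbatched)
print("shared pred: ", shared_pred)
print("batched pred:", batched_pred)
```

The values are identical in all three cases. Only the amount of work, and the amount of memory kept around for autodiff, changes.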

Could jax.checkpoint serve you here, by working around having to store material arrays in the course of autodiff?
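Here is a minimal sketch of that suggestion, assuming the per-step intermediates are what dominate memory. step and simulate are hypothetical stand-ins, not the grid_op above. When you differentiate through a fori_loop with static bounds, JAX saves each iteration's residuals. Wrapping the step in jax.checkpoint keeps only each step's input and recomputes the step's internals on the backward pass, trading compute for memory.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(v):
    # Hypothetical stand-in for one simulation step.
    return jnp.tanh(v) * 0.99 + 0.01


def simulate(v0, remat=False, n_steps=100):
    # With remat=True, only each step's input is kept for the backward pass;
    # the step's intermediates are recomputed when gradients are taken.
    body = jax.checkpoint(step) if remat else step
    v_final = lax.fori_loop(0, n_steps, lambda _, v: body(v), v0)
    return jnp.mean(v_final)


v0 = jnp.linspace(-1.0, 1.0, 8)
g_plain = jax.jit(jax.grad(simulate))(v0)
g_remat = jax.jit(jax.grad(lambda v: simulate(v, remat=True)))(v0)
print(jnp.max(jnp.abs(g_plain - g_remat)))
```

Whether this helps depends on how much of a step's memory is intermediate rather than carried state. jax.checkpoint also accepts a policy argument for saving selected intermediates.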

@aespielberg
Author

aespielberg commented Nov 2, 2021

I see. Is there any reason why vmapped conds are not supported? What is the logic there? I think this would be very important for a lot of people (and it would be great if these nuances were documented). I don't know the internals of the jax compiler, but I do know that other similar systems support some form of branching. Is there any way to add this capability to XLA, or somehow to jax?

Actually, since I could manually batch everything, including conditionals, if I were careful about my code, I see no reason why vmap couldn't support this.

I did try jax.checkpoint, but even a single checkpoint that splits the loop in half balloons the runtime (by about 4x IIRC; I'll have to double-check). That's surprising, since I would expect that as long as checkpoints don't nest, the extra runtime would be at most 2x. On that note, using cond instead of where also doubles the runtime... eventually these things really start to add up.

@shailesh1729

Often, one branch of lax.cond is practically a no-op, while the other involves a significant amount of computation. If vmap ignores the hint that lax.cond provides, then I guess it will lead to both memory and compute overhead. Using vmap might turn out to be more expensive than writing a for loop.
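Here is a sketch of the loop alternative, using a hypothetical expensive/cheap pair of branches. jax.lax.map runs a rolled loop over the batch, so each iteration sees an unbatched predicate and lax.cond remains a real conditional. vmap, by contrast, computes both branches for every example. The two produce the same values. Which one is faster depends on how much parallelism the loop gives up.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive(x):
    # Hypothetical heavy branch: a few normalized matrix products.
    for _ in range(3):
        x = jnp.tanh(x @ x.T / x.shape[-1])
    return x


def cheap(x):
    # Placeholder branch.
    return jnp.zeros_like(x)


def per_example(args):
    pred, x = args
    return lax.cond(pred, expensive, cheap, x)


@jax.jit
def batched_select(preds, xs):
    # vmap with a batched predicate: both branches run for every example.
    return jax.vmap(lambda p, x: per_example((p, x)))(preds, xs)


@jax.jit
def batched_loop(preds, xs):
    # lax.map: a rolled loop in which each iteration executes a real cond.
    return lax.map(per_example, (preds, xs))


preds = jnp.array([True, False, False, True])
xs = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 16))
print(jnp.allclose(batched_select(preds, xs), batched_loop(preds, xs), atol=1e-5))
```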

@aespielberg
Author

A loop may be possible in some cases, but not all, and I think that makes this problematic. If I may, I would like to argue for why I think supporting cond in vmapped code is so important.

The wonderful thing about vmap, as advertised, is that anything should be easily vmappable. It's essentially the beauty of abstraction and modularity, as described. Someone writes a module, foo, and if you want to parallelize it, it should be as easy as just calling vmap(foo). You don't need to understand foo, it should just work. In this case, it seems like cond, a very critical dataflow construct, has different behaviors whether vmap'd or not. The problem, of course, is that people who use a module might not even know if cond is being used inside. This leads to functionally very different performance profiles when vmap'd or not that might require deep investigation of foo, which could be a quite complex piece of code. I don't think this is good for the mission of vmap.

In some cases (as I would argue is true here), vmap appears to be the right way to process such code rather than a for loop. As far as I understand, indexing into arrays via .at is expensive at runtime in jax (correct me if I'm mistaken), and creating large, sparse one-hot masks would be prohibitively memory-expensive without support for sparse matrices.

@froystig
Member

froystig commented Nov 2, 2021

We're in agreement here by and large. This is something that we've thought about improving before, whether at the JAX or XLA level. I can't find an open issue for it, so let's use this one for it.

@froystig froystig changed the title Jax cond evaluating too many branches. vmap of cond's predicate results in select, leading to unexpected compute/memory use Nov 2, 2021
@froystig froystig added enhancement New feature or request performance make things lean and fast and removed bug Something isn't working labels Nov 2, 2021
@minqi

minqi commented Jan 6, 2023

Hi there! I was wondering what the JAX team's latest thinking is regarding the behavior of lax.cond when batched via vmap.

I find myself often running into the design pattern of conditionally branching into two subroutines, one expensive, and the other a "placeholder," for example, returning a dummy zero tensor.
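One workaround for the expensive-vs-placeholder pattern is sketched here. It is a swapped-in technique, not a batched cond provided by JAX. The idea is to compact the indices whose predicate is True into a fixed-size buffer with jnp.nonzero(..., size=...), run the expensive function only on that buffer, and scatter the results back. It assumes a static upper bound, the hypothetical capacity parameter, on how many predicates can be True. Any True entries beyond that bound are silently left at the placeholder value.

```python
import jax
import jax.numpy as jnp


def expensive(x):
    # Hypothetical heavy per-example computation.
    return jnp.sin(x) ** 2 + jnp.cos(x)


def sparse_apply(preds, xs, capacity):
    """Apply `expensive` only where preds is True (at most `capacity` of them).

    `capacity` must be static; True entries beyond it keep the placeholder 0.
    """
    # Static-size compaction of the selected indices; padding slots get index 0.
    (idx,) = jnp.nonzero(preds, size=capacity, fill_value=0)
    valid = jnp.arange(capacity) < jnp.sum(preds)
    # Run the expensive branch on `capacity` examples instead of the whole batch.
    results = jax.vmap(expensive)(xs[idx])
    out = jnp.zeros_like(xs)  # the cheap "placeholder" branch
    # Padding slots are redirected to an out-of-bounds index and dropped.
    safe_idx = jnp.where(valid, idx, xs.shape[0])
    return out.at[safe_idx].set(results, mode="drop")


sparse_apply_jit = jax.jit(sparse_apply, static_argnums=2)
preds = jnp.array([False, True, False, False, True, False])
xs = jnp.arange(6.0)
out = sparse_apply_jit(preds, xs, 2)
print(out)
```

The cost of the expensive branch now scales with capacity rather than with the batch size, at the price of a gather and a scatter.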

@froystig
Member

@minqi – the thinking hasn't changed much since this issue was last active. Although there's a fundamental puzzle regarding whether/how to do better, for now we're still producing select when we batch cond's predicate.

@HHalva

HHalva commented Mar 15, 2023

I find this one of the biggest practical issues with jax: vmap and jit are great, but in a lot of code they also necessitate using cond, which results in the compute/memory issues described above, plus the time spent trying to work around them.

@HHalva

HHalva commented Mar 16, 2023

Btw does switch suffer from this same problem when used with vmap?

@jakevdp
Collaborator

jakevdp commented Mar 16, 2023

Switch is implemented in terms of cond, so yes it has the same characteristics.
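A quick way to check this, assuming a recent JAX release, is to trace a vmapped lax.switch with a batched index and list the top-level primitives. As with cond, no cond primitive remains. Every branch is evaluated, and the results are combined with select_n.

```python
import jax
import jax.numpy as jnp
from jax import lax

branches = [jnp.sin, jnp.cos, jnp.tanh]


def g(i, x):
    # lax.switch clamps i into range and dispatches to one of the branches.
    return lax.switch(i, branches, x)


idx = jnp.array([0, 2, 1])
xs = jnp.array([0.1, 0.2, 0.3])

jaxpr = jax.make_jaxpr(jax.vmap(g))(idx, xs)
names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(names)
out = jax.vmap(g)(idx, xs)
```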

@HHalva

HHalva commented Mar 17, 2023

That's what I thought. It might be helpful to document that for switch, similar to what's in the cond docstring, e.g. "However, when transformed with vmap to operate over a batch of predicates/indices, switch is converted to select."

@pablo2909

Hi,

I was wondering what the status of this is.

I face the following situation:

jax.vmap(lambda x,y: lax.cond(y<0, heavy_computation_1, heavy_computation_2,x))(X, Y)

IIUC, this executes both branches, and I would rather it didn't :)

@jakevdp
Collaborator

jakevdp commented Nov 15, 2023

This is still the case, as mentioned in the docstring of jax.lax.cond.

@jakevdp
Collaborator

jakevdp commented Nov 15, 2023

I don't know of any active work to change this.

@evanatyourservice

Correct me if I'm wrong but I think you can get around this by splitting the vmap dimension into a list of single pieces, jnp.split(to_vmap, to_vmap.shape[0]), tree_map the list, then concat things back together how you'd like. It could get fancier with the various tree utils, but overall it avoids the vmap and essentially unrolls the batch apply into a sort of list comprehension.

@pablo2909

Hey, thanks for the suggestion. I created an MWE that I think illustrates your idea; let me know if it doesn't. This example batches both the predicate and the operand of lax.cond. It computes 1000 matrix-vector products.

import jax
from jax import lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import random

X = jnp.arange(1000*4).reshape(1000, 4)
sigma = random.normal(random.PRNGKey(0), (1000,)) + 3.
A = random.normal(random.PRNGKey(1), (1000, 4))

def true_f(x):
    return A@x

def false_f(x):
    return -A@x

def f(sigma, x):
    return lax.cond(sigma < 3, true_f, false_f, x)

@jax.jit
def F(SIGMA, X):
    return jax.vmap(lambda sigma, x: f(sigma, x))(SIGMA,X)


@jax.jit
def splitF(SIGMA, X):
    split_sigma = jnp.split(SIGMA, SIGMA.shape[0])
    split_x = jnp.split(X, X.shape[0])
    return jnp.stack(jtu.tree_map(lambda sigma, x: f(sigma[0], x[0]), split_sigma, split_x))

print(F(sigma, X))
print("-----------")
print(splitF(sigma, X))

Compilation of splitF takes significantly longer than compilation of F. Intuitively I would have expected splitF to also run slower, but I get the following timings:

%timeit F(sigma, X).block_until_ready()
>>> 3.01 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit splitF(sigma, X).block_until_ready()
>>> 2.81 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I'm not sure if we can draw meaningful conclusions from this small example but it's a starting point.

@evanatyourservice

@pablo2909 Yeah the compilation time would take longer but it should allow taking advantage of cond only computing one branch during run, so it depends what you’d like to take advantage of.

@evanatyourservice

Here's an example that shows the differences for true and false functions with highly skewed compute times, 20 qr decomps vs a single elementwise multiply:

from time import time
import jax
from jax import numpy as jnp, jit


def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x


def false_fn(_, x):
    return x * 1.01


def main_fn(rng_key, x):
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )


@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn)(rng_keys, xs)


@jit
def unrolled(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    x_out = [main_fn(k, x) for k, x in zip(rng_keys, xs)]
    x_out = jnp.concatenate([jnp.expand_dims(x, 0) for x in x_out])
    return x_out


rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)

t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x)
x = jax.block_until_ready(x)
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x)
x = jax.block_until_ready(x)
print("list_vmap", (time() - t) / n)

The times on my machine are:

regular_vmap compile 3.4660181999206543
list_vmap compile 18.33980703353882
regular_vmap 0.79718918800354
list_vmap 0.22351880073547364

The unroll has several times longer compile time, but only about a quarter the run time, so if it's going to be run for many iterations the compile time would be worth it. In the case where true and false functions are both light, or can be computed simultaneously on a gpu, it might be worth it to just use vmap, you'd have to experiment.

@pablo2909

Thanks a lot, I'm a bit surprised by your results. Your false_fn is very cheap compared to true_fn, so I would have expected vmapping to be faster than what is essentially a compiled for loop. I might be missing something.

To adapt it more closely to my original question, I made true_fn and false_fn equally costly and replaced the list comprehension with a lax.scan. The code is below, followed by the timings.

from time import time
import jax
from jax import numpy as jnp, jit


def true_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x + y
    return x


def false_fn(rng_key, x):
    for _ in range(20):
        x = jnp.linalg.qr(x)[0]
        rng_key, subkey = jax.random.split(rng_key)
        y = jax.random.uniform(subkey, (200, 200))
        x = x - y
    return x


def main_fn(carry, rng_key_x):
    rng_key, x = rng_key_x
    rng_key, subkey = jax.random.split(rng_key)
    c = jax.random.choice(subkey, 2).astype(bool)
    return carry, jax.lax.cond(
        c,
        true_fn,
        false_fn,
        rng_key,
        x,
    )

@jit
def regular_vmap(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    return jax.vmap(main_fn, in_axes=(None, 0))(None, (rng_keys, xs))[1]


@jit
def unrolled(rng_key, xs):
    rng_keys = jax.random.split(rng_key, xs.shape[0])
    _, x_out = jax.lax.scan(main_fn, None, (rng_keys, xs))
    return x_out


rng_key = jax.random.PRNGKey(0)
x_in = jax.random.normal(jax.random.PRNGKey(1), (15, 200, 200))

# compile
t = time()
regular_vmap(rng_key, x_in).block_until_ready()
print("regular_vmap compile", time() - t)

t = time()
unrolled(rng_key, x_in).block_until_ready()
print("list_vmap compile", time() - t)

n = 5

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = regular_vmap(subkey, x).block_until_ready()
print("regular_vmap", (time() - t) / n)

x = x_in
t = time()
for _ in range(n):
    rng_key, subkey = jax.random.split(rng_key)
    x = unrolled(subkey, x).block_until_ready()
print("list_vmap", (time() - t) / n)
>>> regular_vmap compile 7.062215089797974
>>> list_vmap compile 4.893687963485718
>>> regular_vmap 2.9101763725280763
>>> list_vmap 1.1428837776184082

@evanatyourservice

evanatyourservice commented Nov 20, 2023

Scanning over the inputs is a good idea; I think it was mentioned in a similar issue before. It sacrifices a bit of runtime for faster compilation (you could of course also do this for the qr for loop).

Edit: Actually after testing this setup the scan-over-batch version runs faster than the unroll!

@evanatyourservice

Unfortunately unrolling using scan or list comprehension only seems to consistently improve performance on cpu, not gpu or tpu, unless the batch size is very small and the branches are wildly unequal in compute. So I don't think there's a good solution to this problem without batched cond through XLA :(

@pablo2909

I've been observing the same on gpu :/

@evanatyourservice

I'm not familiar with triton at all but maybe there's a way to batch cond through pallas. I don't have enough of a need to look that deep into it though 😁

@kerupp

kerupp commented Feb 12, 2024

Hello,
I just want to voice my support for a "batchable cond". It would be very useful for our models!

@nmonette

nmonette commented Jul 3, 2024

I have had the same issue as @minqi here. Just wanted to point out that this is still an issue people deal with.

@inversecrime

Because this hasn't been mentioned yet:

As far as I know, using vmap and while_loop can cause similar problems, since a batched while loop executes the loop body for every batch item until all batch items satisfy the termination condition (see #15954).
For computationally expensive loop body functions, it might be faster to use a for loop over all batch items instead.
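Here is a small sketch of that behavior, using a hypothetical per-example loop whose trip count depends on the data. Under vmap with a batched predicate, the loop keeps running until the slowest element finishes. Elements that finish early are held in place by a select, so the body is evaluated max(targets) times for the whole batch. The per-element values are still correct, which is why the extra cost is easy to miss.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sum_sqrt_to(target):
    # Returns sqrt(1) + sqrt(2) + ... + sqrt(target) using a data-dependent loop.
    def cond_fn(state):
        i, _ = state
        return i < target

    def body_fn(state):
        i, acc = state
        # Stand-in for an expensive loop body.
        return i + 1, acc + jnp.sqrt((i + 1).astype(jnp.float32))

    _, acc = lax.while_loop(cond_fn, body_fn, (jnp.int32(0), jnp.float32(0.0)))
    return acc


targets = jnp.array([1, 3, 10])
# Under vmap, all three lanes run 10 iterations; lanes that finished early
# keep their state via select, so their results are unchanged.
out = jax.jit(jax.vmap(sum_sqrt_to))(targets)
print(out)
```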

Regarding the cond discussion:

It would be nice to have an option to disable this conversion completely, i.e. raise an exception when jax tries to convert a vmapped cond to a select.
This could make debugging a lot easier (e.g. "where does this huge memory usage come from?").
It seems to me that most users here agree that converting a vmapped cond to a select is unwanted, or even unexpected.

@inversecrime

It might be worth mentioning that it is possible to use jax.custom_batching.sequential_vmap.
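Here is a minimal sketch of that approach, assuming a JAX version that ships jax.custom_batching.sequential_vmap (per_example is a hypothetical function). Under vmap, the decorated function is mapped with a loop instead of being batched. Each iteration therefore sees an unbatched predicate, and lax.cond stays a real conditional. This gives up parallelism across the batch in exchange for evaluating only one branch per example.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_batching.sequential_vmap
def per_example(pred, x):
    # Under vmap this function is run once per example inside a loop,
    # so pred is a scalar here and lax.cond is not converted to select.
    return lax.cond(pred, lambda x: jnp.sin(x) ** 2, lambda x: jnp.zeros_like(x), x)


preds = jnp.array([True, False, True, False])
xs = jnp.linspace(0.0, 1.0, 4)
out = jax.jit(jax.vmap(per_example))(preds, xs)
expected = jnp.where(preds, jnp.sin(xs) ** 2, 0.0)
print(out)
```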

@evanatyourservice

@inversecrime would sequential_vmap be similar to unrolling the batch using a for loop?

@froystig
Member

> @inversecrime would sequential_vmap be similar to unrolling the batch using a for loop?

It generates a jax.lax.map, which bottoms out in a (rolled) XLA loop.

@guyuntian

I tried using jax.lax.map + jax.lax.cond in my program, but it appears to be significantly slower than jax.vmap + jax.numpy.where on an NVIDIA GeForce RTX 3090. I want to express my support for the "batch cond" feature mentioned by others, as it would be highly valuable for my current work!

@froystig
Member

froystig commented Sep 8, 2024

Indeed, it all depends on the program (ignoring possible compiler rewrites), i.e. what's being computed within each branch. The two approaches trade off total compute vs. parallelism.

For similar reasons, there are many possible implementations of "batched cond," all along this tradeoff curve.
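As an illustration of one point on that curve, here is a sketch of a hypothetical any_gated helper (not a JAX API). It keeps select semantics but wraps the whole batched computation in a lax.cond on jnp.any(preds). That predicate is a scalar, so the outer cond is not batched. The expensive branch is therefore skipped entirely when no example needs it, and costs the same as a plain select otherwise. If any_gated were itself vmapped over preds, the outer cond would turn back into a select.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive(x):
    # Hypothetical heavy per-example computation.
    for _ in range(5):
        x = jnp.sin(x) + 0.1 * x
    return x


def cheap(x):
    # Placeholder branch.
    return jnp.zeros_like(x)


@jax.jit
def any_gated(preds, xs):
    # Broadcast the (batch,) predicate against xs of shape (batch, ...).
    mask = preds.reshape(preds.shape + (1,) * (xs.ndim - 1))

    def some_true(_):
        # At least one example needs the expensive branch: compute both, select.
        return jnp.where(mask, jax.vmap(expensive)(xs), jax.vmap(cheap)(xs))

    def none_true(_):
        return jax.vmap(cheap)(xs)

    # jnp.any(preds) is a scalar, so this cond is not converted to a select.
    return lax.cond(jnp.any(preds), some_true, none_true, None)


xs = jnp.array([0.5, 1.0, 1.5])
mixed = any_gated(jnp.array([False, True, False]), xs)
all_off = any_gated(jnp.zeros(3, dtype=bool), xs)
print(mixed, all_off)
```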


14 participants