Buffer donation to a jit function on GPU #1273

Closed
romanngg opened this issue Aug 30, 2019 · 7 comments
Labels: enhancement (New feature or request)

Comments

@romanngg
Contributor

Below is a CNN applied iteratively to a 2 GB input. It produces a peak memory consumption of 4 × 2 GB = 8 GB.

import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit

@jit
def f(x):
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME', 
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))
# A shape of (2**20, 2**5, 2**5, 1) OOMs!
x = f(x)
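
(For reference, each such float32 array takes 2**19 × 32 × 32 × 4 bytes = 2 GiB, which is where the 2 GB figure above comes from.)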

Without JIT, the peak memory consumption is 2 × 2 GB = 4 GB, as expected.

It would be great to achieve comparable memory usage with JIT by donating the input buffer to the jit function (not sure about the exact terminology).

Thanks a lot!

@hawkinsp
Member

Buffer donation has been checked in!

@romanngg
Contributor Author

Thanks Peter, do you know how I can leverage it to reduce the memory consumption in the example above?

So far, even if I do

f = jit(f, donate_argnums=0)

I still get a peak memory of 4 × 2 GB = 8 GB, and this warning:

jax/interpreters/xla.py:660: UserWarning: Some donated buffers were not usable: f32[524288,32,32,1]{3,2,1,0}

@jekbradbury
Contributor

I believe that means there wasn't an output with the same shape that could have reused that buffer (or there were fewer such outputs than donated inputs).
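
As a hedged illustration of that matching rule (my own sketch, not part of this thread): donation should be usable when some output matches the donated input's shape and dtype, and should trigger the same warning when no output matches. The function names below are made up for the example.

from functools import partial
import jax
import jax.numpy as jnp

# The output matches the donated input's shape and dtype, so XLA can alias
# the donated buffer to the output (on backends where donation is supported).
@partial(jax.jit, donate_argnums=0)
def scale(x):
  return 2.0 * x

# No output matches the donated input's shape, so the donated buffer is
# unusable and JAX emits the "Some donated buffers were not usable" warning.
@partial(jax.jit, donate_argnums=0)
def first_row(x):
  return x[:1]

a = jnp.ones((4, 4))
b = jnp.ones((4, 4))
scale(a)      # donation can be honored
first_row(b)  # warns: no f32[4,4] output whose buffer could alias the input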

@romanngg
Contributor Author

Interesting - how come it doesn't work in this example, then? From my understanding, there is 1 input and 1 output here, both of shape and type f32[524288,32,32,1].

@tomhennigan
Member

FYI, buffer donation is only supported on TPU at the moment. The XLA team is working to support it on CPU/GPU, but that may be why we cannot use the donation here.
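
In the meantime, a possible workaround sketch (my own, assuming jax.devices()[0].platform reports 'tpu'/'gpu'/'cpu', and that f is the un-decorated convolution loop from the first comment): only request donation on TPU, so the unused-donation warning is avoided elsewhere and behavior is unchanged where donation is not supported.

import jax
from jax import jit

# Donation is only honored on TPU right now; requesting it on CPU/GPU is a
# no-op that emits the "Some donated buffers were not usable" warning, so
# only pass donate_argnums when running on a TPU backend.
if jax.devices()[0].platform == 'tpu':
  f_jitted = jit(f, donate_argnums=0)
else:
  f_jitted = jit(f)

x = f_jitted(x)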

@romanngg romanngg changed the title Buffer donation to a jit function Buffer donation to a jit function on GPU Jun 30, 2020
@romanngg
Contributor Author

I see, thanks! Could you please reopen this issue then?

@tomhennigan tomhennigan reopened this Jun 30, 2020
@hawkinsp
Member

Fixed by #3800
