memory leak with grad and lax.scan #3348
I can reproduce the leak on CPU, too. It goes away, though, if you `jit` the gradient function.
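For reference, a minimal sketch of that workaround, assuming a hypothetical `loss` over a scanned recurrence (the names and shapes below are illustrative, not from the original reproduction):

```python
import jax
import jax.numpy as jnp

# Hypothetical loss over a scanned recurrence; names are illustrative only.
def loss(params, inputs):
    def step(carry, x):
        carry = jnp.tanh(carry * params + x)
        return carry, carry
    carry, _ = jax.lax.scan(step, jnp.zeros_like(inputs[0]), inputs)
    return jnp.sum(carry)

# Wrapping the gradient in jit means the computation is traced and compiled
# once for this input shape and then reused, instead of re-traced each step.
grad_loss = jax.jit(jax.grad(loss))

params = jnp.ones(3)
inputs = jnp.ones((10, 3))
for _ in range(100):
    g = grad_loss(params, inputs)  # no per-iteration cache growth
```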
We looked into this. Yes, it's a memory leak of sorts. Note that @clemisch's suggestion is the best one: you should use `jit` around the gradient computation anyway for performance.

The leak occurs because the `partial(apply_fun_scan, p1)` call forms a new function object on every step, and the trace cache for the scanned function is keyed on that object, so each call adds a fresh cache entry that is never released.

We should fix this, but as a workaround, I suggest either using `jit` or avoiding the creation of a new function object on each step. We probably need to make our reference to the function in the cache a weak reference, if nothing else, although this would still lead to redundant tracing work on each step.
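To make the mechanism concrete, here is a small sketch (with a placeholder step function) showing that `functools.partial` returns a distinct object on every call, so an identity-keyed cache never gets a hit:

```python
from functools import partial

def apply_fun_scan(p1, carry, x):
    # Placeholder scanned step; the real body doesn't matter here.
    out = x * p1 * carry
    return out, out

p1 = 2.0

# Two "training steps" that each rebuild the scanned function the same way:
fx_step1 = partial(apply_fun_scan, p1)
fx_step2 = partial(apply_fun_scan, p1)

# Distinct objects, so a cache keyed on function identity misses every time
# and keeps a strong reference to each stale entry.
print(fx_step1 is fx_step2)  # False
```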
Thanks. The suggestion by @clemisch to jit the grad function seems to work just fine.
@hawkinsp thanks for responding here! I was going through this thread and something was a bit unclear in my head; I hope you don't mind me raising the question directly here in context. How would one rewrite the two lines of code below to avoid creating a new function on each step call?

```python
fx = partial(apply_fun_scan, p1)
_, ht_new = lax.scan(fx, p2, inputs)
```

Would it be like the following block?

```python
def apply_fun(params, inputs, **kwargs):
    p1 = params[0]  # general parameters
    p2 = params[1]  # the hidden state that evolves over time

    def fx(p2, inputs):
        output = inputs * p1 * p2
        return output, output

    _, ht_new = lax.scan(fx, p2, inputs)
    return ht_new
```

I think this avoids re-creating a new function object, but I'm not 100% sure; it's definitely the cleanest way I can think of to keep `p1` available inside the scanned function. I'm asking because I think I am observing a similar issue in my own code.
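As a side note, one way to check whether the inner function is being retraced on every training step is to count trace-time side effects: a Python statement inside the scanned function runs only while JAX traces it, not when the compiled computation is re-executed. A minimal sketch, assuming hypothetical names and shapes:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def apply_fun(params, inputs):
    p1, p2 = params  # general parameters, evolving hidden state

    def fx(carry, x):
        global trace_count
        trace_count += 1          # executes only while JAX traces fx
        out = x * p1 * carry
        return out, out

    _, ht_new = jax.lax.scan(fx, p2, inputs)
    return jnp.sum(ht_new)

grad_fn = jax.jit(jax.grad(apply_fun))

params = (jnp.ones(3), jnp.ones(3))
inputs = jnp.ones((10, 3))
for _ in range(5):
    grad_fn(params, inputs)

# With jit, tracing happens once for this input shape, so trace_count stays
# small; growth proportional to the number of loop iterations would signal
# per-step retracing (and a growing cache).
print(trace_count)
```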
I coded a GRU for time series analysis that uses `lax.scan` to scan through each time step. It appears that `lax.scan`, when called under `grad`, results in a GPU memory leak. This issue might be related to #282, but I am not sure how to fix it. As I understand it, a recurrent network built with JAX implies use of `lax.scan` (to avoid a Python for loop). Any ideas would be appreciated. I tried clearing the XLA cache, to no avail, and tried using `remat`, which reduced the memory leak only slightly (both efforts are commented out in the code below).
Simplified code to reproduce the problem is below. Each iteration of a call to `grad` produces a 370 MB memory leak. I use CUDA 10.1 with driver 435.21, Ubuntu 19.10, Python 3.7, and JAX 0.1.62.
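The original reproduction script is not included in this excerpt. As a rough, hypothetical sketch only, the pattern being described (a scanned step rebuilt with `partial` inside the loss, differentiated on every iteration without `jit`) looks something like this:

```python
import jax
import jax.numpy as jnp
from functools import partial

def apply_fun_scan(p1, carry, x):
    # Placeholder GRU-like update; the real model is more involved.
    carry = jnp.tanh(x * p1 * carry)
    return carry, carry

def loss(params, inputs):
    p1, h0 = params
    # A new function object is built on every call, defeating identity-keyed caches.
    fx = partial(apply_fun_scan, p1)
    _, hs = jax.lax.scan(fx, h0, inputs)
    return jnp.sum(hs)

params = (jnp.ones(8), jnp.zeros(8))
inputs = jnp.ones((100, 8))
for _ in range(100):
    # Without jit, the scan is re-traced on every iteration; on the affected
    # JAX version (0.1.62) each call reportedly left device memory behind.
    g = jax.grad(loss)(params, inputs)
```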