
memory leak with grad and lax.scan #3348

Closed · John-Boik opened this issue Jun 6, 2020 · 4 comments
Labels: bug (Something isn't working)

John-Boik commented Jun 6, 2020

I coded a GRU for time-series analysis that uses lax.scan to step through each time step. It appears that lax.scan, when called under grad, produces a GPU memory leak. This might be related to #282, but I am not sure how to fix it. As I understand it, a recurrent network built with JAX implies using lax.scan (to avoid a Python for loop). Any ideas would be appreciated. I tried clearing the XLA compilation cache, to no avail, and tried using remat, which reduced the leak only slightly (both attempts are left commented out in the code below).

Simplified code to reproduce the problem is below. Each iteration (i.e., each call to grad) leaks roughly 370 MB of GPU memory. I am using CUDA 10.1 with driver 435.21, Ubuntu 19.10, Python 3.7, and JAX 0.1.62.


import os

# These env vars must be set before the XLA backend is initialized
# (i.e., before any JAX computation runs) in order to take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import time
from functools import partial

import jax.numpy as np
from jax import grad, lax, random, remat
from jax.experimental import optimizers
from jax.experimental.stax import serial
from jax.interpreters import xla
from jax.nn.initializers import glorot_normal, normal

# --------------------------------------------------------------------------------------------------
# Gru
# --------------------------------------------------------------------------------------------------
def GRU(num_latents, W_init=glorot_normal(), b_init=normal()):
    # ----------------------------------------------------------------------------------------------
    def init_fun(rng, input_shape, **kwargs):
        k1, k2 = random.split(rng, num=2)
        p1 = b_init(k1, (input_shape[0], num_latents))  
        p2 = b_init(k2, (input_shape[0], num_latents))  
        return ((p1,p2),(None,))

        
    # ----------------------------------------------------------------------------------------------
    def apply_fun(params, inputs, **kwargs):
        p1 = params[0]  # general parameters
        p2 = params[1]  # the hidden state that evolves over time
        
        # -------------------------------
        def apply_fun_scan(p1_, p2_, inputs_):
            output = inputs_ * p1_ * p2_
            return output, output
        
        #fx = remat (partial(apply_fun_scan, p1), concrete=True)
        fx = partial(apply_fun_scan, p1)
        _, ht_new = lax.scan(fx, p2, inputs)  
        return ht_new
        
    return init_fun, apply_fun


rng = random.PRNGKey(1)
net_init, apply = serial(
    GRU(400)
    )

params = net_init(rng, (300,400))   
init, update, get_params = optimizers.adam(step_size=1e-3)
state = init(params)

# --------------------------------------------------------------------------------------------------
def grad_update(ii, X, state):
    params = get_params(state)
    gradients = grad(forward,  argnums=(0,), has_aux=False)(params, X)
    state = update(ii, gradients[0], state)
    return state

# --------------------------------------------------------------------------------------------------
def forward(params, X):
    Z = apply(params, X) 
    return Z.sum() + 1.


for ii in range(100):
    print(ii)
    X = np.ones((200, 300, 400))
    state = grad_update(ii, X, state)
    #xla._xla_callable.cache_clear()  # does not help memory leak
    time.sleep(3)

clemisch (Contributor) commented Jun 7, 2020

I can reproduce the leak on CPU, too. It goes away though if you use @jax.jit on the grad_update function.
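
For concreteness, a minimal sketch of that change against the repro above (not taken verbatim from this thread) would be:

from jax import jit

@jit  # compile the whole update step; this also avoids re-tracing lax.scan on every call
def grad_update(ii, X, state):
    params = get_params(state)
    gradients = grad(forward, argnums=(0,), has_aux=False)(params, X)
    return update(ii, gradients[0], state)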

hawkinsp (Collaborator) commented Jun 8, 2020

We looked into this. Yes, it's a memory leak of sorts.

Note that @clemisch's suggestion is the best one: you should use jax.jit around your function; otherwise I wouldn't expect great performance at the moment no matter what. It also avoids the memory leak.

The leak occurs because:

fx = partial(apply_fun_scan, p1)
_, ht_new = lax.scan(fx, p2, inputs)

forms a new fx function on each step (i.e., one with a unique object identity). This means we never get cache hits in a cache inside the implementation of lax.scan, and each cache entry costs at least 100 MiB.

We should fix this, but as a workaround, I suggest either using jit or rewriting your code so as to avoid forming a new function object on each step (as determined by id(fx)).

We probably need to make our reference to the function in the cache a weak reference, if nothing else, although this would still lead to redundant tracing work on each step.
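
As a plain-Python illustration of the identity problem described above (my own sketch, not from the thread): every call to functools.partial returns a fresh callable, so a cache keyed on the function object can never be reused across training steps.

from functools import partial

def apply_fun_scan(p1, p2, x):
    out = x * p1 * p2
    return out, out

# Two successive training steps build two distinct callables, even though they
# wrap the same underlying function with the same bound argument:
f_step1 = partial(apply_fun_scan, 1.0)
f_step2 = partial(apply_fun_scan, 1.0)
assert f_step1 is not f_step2   # different object identities -> scan cache miss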

hawkinsp added the bug (Something isn't working) label on Jun 8, 2020
John-Boik (Author) commented

Thanks. The suggestion by @clemisch to jit the grad function seems to work just fine.

ericmjl commented Sep 16, 2020

@hawkinsp thanks for responding here! I was going through this thread and one point was a bit unclear to me; I hope you don't mind me raising the question directly here, in context.

How would one rewrite the two lines of code below so that a new function is not created on each call?

fx = partial(apply_fun_scan, p1)
_, ht_new = lax.scan(fx, p2, inputs)

Would it be like the following block?

    def apply_fun(params, inputs, **kwargs):
        p1 = params[0]  # general parameters
        p2 = params[1]  # the hidden state that evolves over time
        
        def fx(p2, inputs):
            output = inputs * p1 * p2
            return output, output
        
        _, ht_new = lax.scan(fx, p2, inputs)  
        return ht_new

I think this avoids re-creating a new function object, but I'm not 100% sure; it's definitely the cleanest way I can think of to keep p1 in scope without using partial, though.

I'm asking because I think I am observing a similar issue with lax.scan in jax-unirep. To keep things easily testable, though, we have the inner fx-like function refactored outside the scope of apply_fun and use partial to bind certain inputs in scope, but if need be, we can always rearchitect the code base a bit.
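
One caveat: a closure defined inside apply_fun (as in the block above) is still a new function object on every call to apply_fun, so on its own it has the same identity problem as partial; it is wrapping the caller in jax.jit that sidesteps it. A sketch of a rewrite that keeps the callable's identity stable without relying on jit (my own illustration with the shapes from the repro, not code from this thread or from jax-unirep) is to define the scan body once at module scope and thread p1 through the carry:

from jax import lax

def gru_step(carry, x):
    # Defined once at module scope, so its object identity is stable across
    # calls and the cache inside lax.scan can be reused.
    p1, h = carry
    h_new = x * p1 * h
    return (p1, h_new), h_new

def apply_fun(params, inputs, **kwargs):
    p1 = params[0]  # general parameters
    p2 = params[1]  # the hidden state that evolves over time
    _, ht_new = lax.scan(gru_step, (p1, p2), inputs)
    return ht_new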
