
Colab Reformer Prediction Could not allocate bytes in memory #431

Open
Elkia-Federation opened this issue Apr 5, 2020 · 1 comment
Elkia-Federation commented Apr 5, 2020

Description

After using Colab to train and then loading the model in prediction mode, the notebook runs out of memory on the second prediction run on the TPU runtime:
https://colab.research.google.com/drive/1v2q5Qp2-68hLG-uTZ3gZZHvkm9Ovbpkc

Reformer model details:

def reformer(mode):
  return trax.models.reformer.ReformerLM(
    d_model=32,
    d_ff=128,
    n_layers=8,
    vocab_size=1024,
    mode=mode)

sequence length = 100
batch size = 128
...

Environment information

OS: Google Colab

$ pip freeze | grep tensor
mesh-tensorflow==0.1.13
tensor2tensor==1.15.4
tensorboard==2.2.0
tensorboard-plugin-wit==1.6.0.post2
tensorboardcolab==0.0.22
tensorflow==2.2.0rc2
tensorflow-addons==0.8.3
tensorflow-datasets==2.1.0
tensorflow-estimator==2.2.0rc0
tensorflow-gan==2.0.0
tensorflow-gcs-config==2.1.8
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

# Steps to reproduce:
Run all cells up to the "Speed" markdown cell
# Error logs:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    443       else:
--> 444         outputs, s = self._do_custom_gradients(x, weights, state, rng=rng)
    445       self._state = s

16 frames
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available

During handling of the above exception, another exception occurred:

LayerError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
    449       name, trace = self.__class__.__name__, _short_traceback()
    450       raise LayerError(name, 'pure_fn',
--> 451                        self._caller, signature(x), trace)
    452 
    453   def output_signature(self, input_signature):

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 811
  layer input shapes: ShapeDtype{shape:(100, 1), dtype:int32}

  File [...]/trax/layers/combinators.py, line 77, in forward_with_state
    outputs, s = layer.pure_fn(inputs, w, s, rng)

LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
  layer created in file [...]/models/reformer/reformer.py, line 802
  layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})

  File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
    output, state = _do_forward(x, weights)

  File [...]/dist-packages/jax/api.py, line 1460, in __call__
    num_consts=len(consts))

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, *args)

  File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **params)

  File [...]/dist-packages/jax/core.py, line 179, in bind
    return self.impl(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
    return compiled_fun(*args)

  File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
    out_buf = compiled.Execute(input_bufs)

RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available
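For reference, the byte counts in the trace above convert as follows (a quick arithmetic sketch using the numbers taken directly from the error message, not trax code):

```python
# Numbers reported in the RuntimeError above (bytes, HBM0)
requested = 419430400    # single allocation that failed
allocatable = 370884608  # bytes the allocator could hand out
available = 376881152    # bytes nominally free

MiB = 2 ** 20
print(f"requested:   {requested / MiB:.1f} MiB")    # 400.0 MiB
print(f"allocatable: {allocatable / MiB:.1f} MiB")  # 353.7 MiB
print(f"available:   {available / MiB:.1f} MiB")    # 359.4 MiB
```

So the run asks for a single ~400 MiB buffer while only ~354 MiB can be allocated from HBM0: the failure is about the size of one buffer relative to the per-core HBM budget, not cumulative usage.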
@NightMachinery

I am also getting these OOM errors. Is there any way to monitor TPU RAM usage? Are there any docs on garbage collection on the TPU?

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-72-63dca48c8c17> in <module>()
     20   params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames
UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098           for i, x in enumerate(args)
   1099           if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: RESOURCE_EXHAUSTED: Attempting to allocate 31.06M. That was not possible. There are 58.64M free. Due to fragmentation, the largest contiguous region of free memory is 30.56M.; (0x0x0_HBM0)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-69-63dca48c8c17> in <module>()
     20   params, state, opt_state, model_output, loss = (
---> 21     train_step(params, state, opt_state, input_batch, target_batch, k1))
     22

9 frames
UnfilteredStackTrace: RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in _execute_compiled(name, compiled, output_buffer_counts, handlers, kept_var_idx, *args)
   1098           for i, x in enumerate(args)
   1099           if x is not token and i in kept_var_idx))
-> 1100   out_bufs = compiled.execute(input_bufs)
   1101   check_special(name, out_bufs)
   1102   if output_buffer_counts is None:

RuntimeError: FAILED_PRECONDITION: Dependency failed: Could not allocate 32571392 bytes in memory 0x0x0_HBM0; 32047104 bytes allocatable, 59981824 bytes available
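Note that the RESOURCE_EXHAUSTED trace above is a fragmentation failure rather than a plain shortage: there is more free memory in total than the request, but no single contiguous region large enough. The relationship between the three numbers JAX prints can be checked directly (a minimal sketch over the reported values):

```python
# Values from the RESOURCE_EXHAUSTED message above (MiB, as printed by JAX)
requested = 31.06           # size of the failed allocation
total_free = 58.64          # total free HBM
largest_contiguous = 30.56  # largest contiguous free region

# Enough memory is free in total...
assert total_free > requested
# ...but no contiguous region can hold the buffer, hence the OOM
assert largest_contiguous < requested
```

This matters for debugging: freeing more small buffers may not help if the live allocations still split HBM into regions smaller than the requested buffer.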
