
very large memory footprint for a simple UNet #18

Open
kayhan-batmanghelich opened this issue Jan 22, 2020 · 7 comments
Labels
enhancement New feature or request

Comments

@kayhan-batmanghelich
Contributor

Hi,

I hit a roadblock! I tried to compute the kernel of a typical UNet for 10 images. The image size is not big (64, 64) and the number of images is just 10 (for testing purposes). However, it crashes with an out-of-memory error (see below). I think the intermediate layers are probably what use so much memory, but that limits usability. Perhaps I am missing something?

Colab gist: https://gist.github.com/kayhan-batmanghelich/f444e6cec65139070f1b3e5ade230de5
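For reference, a minimal sketch of the kind of call that fails; the two-layer conv model below is only a placeholder, not the UNet from the gist:

```python
# Minimal sketch of the failing call; the two-layer conv model is a
# placeholder, not the actual UNet from the gist.
from jax import random
from neural_tangents import stax

_, _, mykernel = stax.serial(
    stax.Conv(1, (3, 3), padding='SAME'), stax.Relu(),
    stax.Conv(1, (3, 3), padding='SAME'),
)

random_image = random.normal(random.PRNGKey(0), (10, 64, 64, 1))  # 10 images, 64x64
kernel = mykernel(random_image[:10], random_image[:10], get='nngp')  # may OOM on a ~8 GB HBM TPU core
```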

Side notes:

  • If you train the model using gradient descent, the performance is not always good; you should try different random seeds. I have a different JAX implementation that uses upsampling, but that would require developing a new layer in neural-tangents and I am not sure how to do that.

Error message:

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py:4571: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-053c34ab30f7> in <module>()
----> 1 kernel = mykernel(random_image[:10],random_image[:10])

6 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    147     _check_args(args_flat)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    150     return tree_unflatten(out_tree(), out)
    151 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    600   if top_trace is None:
    601     with new_sublevel():
--> 602       outs = primitive.impl(f, *args, **params)
    603   else:
    604     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    440   device = params['device']
    441   backend = params['backend']
--> 442   compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
    443   try:
    444     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    221       fun.populate_stores(stores)
    222     else:
--> 223       ans = call(fun, *args)
    224       cache[key] = (ans, fun.stores)
    225     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
    497   options = xb.get_compile_options(
    498       num_replicas=nreps, device_assignment=(device.id,) if device else None)
--> 499   compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
    500 
    501   if nreps == 1:

/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in Compile(self, argument_shapes, compile_options, backend)
    607     if argument_shapes:
    608       compile_options.argument_layouts = argument_shapes
--> 609     return backend.compile(self.computation, compile_options)
    610 
    611   def GetProgramShape(self):

/usr/local/lib/python3.6/dist-packages/jaxlib/tpu_client.py in compile(self, c_computation, compile_options)
    103                                              compile_options.argument_layouts,
    104                                              options, self.client,
--> 105                                              compile_options.device_assignment)
    106 
    107   def get_default_device_assignment(self, num_replicas):

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 25.99G of 7.48G hbm. Exceeded hbm capacity by 18.50G.

Total hbm usage >= 26.50G:
    reserved        529.00M 
    program          25.99G 
    arguments       unknown size 

Output size unknown.

Program hbm requirement 25.99G:
    reserved           4.0K
    global            36.0K
    HLO temp         25.99G (74.4% utilization: Unpadded (19.34G) Padded (25.98G), 0.0% fragmentation (10.31M))

  Largest program allocations in hbm:

  1. Size: 12.50G
     Operator: op_type="conv_general_dilated"
     Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
     Unpadded size: 6.25G
     Extra memory due to padding: 6.25G (2.0x expansion)
     XLA label: %convolution.5785 = f32[409600,1,64,64]{0,1,3,2:T(2,128)} convolution(bf16[409600,1,64,64]{0,1,3,2:T(4,128)(2,1)} %reshape.1452, bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)} %constant.2723), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, metadata={op_t...
     Allocation type: HLO temp
     ==========================

  2. Size: 6.25G
     Shape: f32[409600,1,64,64]{0,3,2,1}
     Unpadded size: 6.25G
     XLA label: %copy.1516 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.4620)
     Allocation type: HLO temp
     ==========================

  3. Size: 6.25G
     Shape: f32[409600,1,64,64]{0,3,2,1}
     Unpadded size: 6.25G
     XLA label: %copy.1540 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.5785)
     Allocation type: HLO temp
     ==========================

  4. Size: 640.00M
     Operator: op_type="reshape"
     Shape: bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)}
     Unpadded size: 320.00M
     Extra memory due to padding: 320.00M (2.0x expansion)
     XLA label: %reshape.753 = bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)} reshape(bf16[40960,1,64,64]{0,1,3,2:T(4,128)(2,1)} %fusion.420), metadata={op_type="reshape"}
     Allocation type: HLO temp
     ==========================

  5. Size: 160.00M
     Operator: op_type="transpose"
     Shape: bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)}
     Unpadded size: 80.00M
     Extra memory due to padding: 80.00M (2.0x expansion)
     XLA label: %copy.1153 = bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)} copy(bf16[10,64,64,32,32]{2,1,4,3,0:T(8,128)(2,1)} %bitcast.127), metadata={op_type="transpose"}
     Allocation type: HLO temp
     ==========================

  6. Size: 100.00M
     Shape: f32[409600,64]{0,1:T(8,128)}
     Unpadded size: 100.00M
     XLA label: %reshape.1326 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.1682.remat)
     Allocation type: HLO temp
     ==========================

  7. Size: 100.00M
     Shape: f32[409600,64]{0,1:T(8,128)}
     Unpadded size: 100.00M
     XLA label: %reshape.1332 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.2053)
     Allocation type: HLO temp
     ==========================

  8. Size: 256.0K
     Operator: op_type="slice"
     Shape: f32[10,4096]{1,0:T(8,128)}
     Unpadded size: 160.0K
     Extra memory due to padding: 96.0K (1.6x expansion)
     XLA label: %fusion.671 = f32[10,4096]{1,0:T(8,128)} fusion(f32[10,4096,4096]{2,1,0:T(8,128)} %reshape.4392, pred[4096,4096]{1,0:T(8,128)E(32)} %fusion.1076.remat), kind=kLoop, calls=%fused_computation.591, metadata={op_type="slice"}
     Allocation type: HLO temp
     ==========================

  9. Size: 9.0K
     Shape: bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)}
     Unpadded size: 18B
     Extra memory due to padding: 9.0K (512.0x expansion)
     XLA label: constant literal
     Allocation type: global
     ==========================

  10. Size: 4.0K
     XLA label: profiler
     Allocation type: reserved
     ==========================

  11. Size: 4.0K
     Shape: bf16[2,2,1,1]{3,2,1,0:T(4,128)(2,1)}
     Unpadded size: 8B
     Extra memory due to padding: 4.0K (512.0x expansion)
     XLA label: constant literal
     Allocation type: global
     ==========================

  12. Size: 4.0K
     Shape: u32[8,128]{1,0}
     Unpadded size: 4.0K
     XLA label: constant literal
     Allocation type: global
     ==========================

@sschoenholz
Contributor

Hey! Thanks for pushing on this. We'd love to iterate on this to get it working for you (though looking at the UNet architecture I am a bit concerned that the vanilla version violates some independence assumptions wrt FanInConcat).

A few things off the top of my head:

  1. Convolutions generically require quite a bit of storage (a (dataset × pixels) × (dataset × pixels) covariance matrix in the general case). It can be helpful to use the batching functionality to run on even modestly large datasets (see the sketch at the end of this comment).

  2. Having said that, I think we ought to be able to do 10 images at a time! Looking at this stack trace you might notice lines like the following:

Size: 12.50G
     Operator: op_type="conv_general_dilated"
     Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
     Unpadded size: 6.25G
     Extra memory due to padding: 6.25G (2.0x expansion)

TPUs must store data in blocks of size 8 x 128. To fit arbitrary data into blocks of this size, XLA will often pad it. Here you can see that the raw size of the data is 6.25 GB, but it is getting padded by a factor of 2. I would recommend trying to run this on GPU rather than TPU and seeing whether the calculation fits into memory, since GPUs don't need to pad. Generally, we have not figured out a way of phrasing our convolutions that doesn't get padded by the TPU (since our channel count is 1). This is an ongoing area of work, but I have to say we have limited tools at our disposal to make progress here (though maybe @romanngg can comment if he's more hopeful than I am).

Let us know how the GPU works. Glancing at the sizes, I would expect it to fit easily on a V100 (since it has 32 GB of RAM whereas this calculation consumes around 19 GB unpadded).
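As a concrete (hypothetical) sketch of point 1: the kernel function can be wrapped with nt.batch so that the 10 × 10 kernel is assembled block by block; the tiny conv model here only stands in for the actual UNet kernel_fn:

```python
# Hedged sketch: wrapping a kernel_fn with nt.batch so the covariance is
# computed in small blocks. The two-layer conv model is a stand-in for the
# actual UNet kernel_fn from the gist.
from jax import random
import neural_tangents as nt
from neural_tangents import stax

_, _, kernel_fn = stax.serial(
    stax.Conv(1, (3, 3), padding='SAME'), stax.Relu(),
    stax.Conv(1, (3, 3), padding='SAME'),
)

random_image = random.normal(random.PRNGKey(0), (10, 64, 64, 1))

# batch_size=2 -> the 10x10 kernel is computed as 5x5 = 25 blocks of 2x2 examples;
# store_on_device=False accumulates the result in host memory instead of HBM.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=2, store_on_device=False)
kernel = batched_kernel_fn(random_image, random_image, get='nngp')
```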

@romanngg
Contributor

+1 to Sam re padding, and also note that even unpadded, the intermediate NNGP covariance of shape 10x10x(64x64)x(64x64) is 6.25 GB. To propagate this tensor through the NNGP computation from one layer to the next, you need 2X of that. Unfortunately, due to JAX internals, in practice it requires 3X (see & upvote google/jax#1733, google/jax#1273), which results in a peak memory consumption of ~19 GB and would require a 32 GB GPU (note that V100s come in 16 and 32 GB varieties, so even a V100 may not be enough). For this reason you'd probably need to work with even smaller batches in this case (see nt.batch), or reduce the image sizes.
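A quick back-of-the-envelope check of those numbers:

```python
# Back-of-the-envelope check of the figures above: the intermediate NNGP
# covariance for 10 images of 64x64 pixels, stored as float32.
n, pixels, bytes_per_float32 = 10, 64 * 64, 4

cov_bytes = n * n * pixels * pixels * bytes_per_float32
print(cov_bytes / 2**30)      # -> 6.25  (GiB for a single copy of the covariance)
print(3 * cov_bytes / 2**30)  # -> 18.75 (GiB peak with the ~3x overhead noted above)
```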

@kayhan-batmanghelich
Contributor Author

Hi @sschoenholz ,

My understanding from #16 was that FanInConcat is theoretically OK, and my superficial understanding of the Greg Yang paper was that these kinds of linear operations do not break the theory, but I might be totally wrong.

Since there was no FanInConcat, I implemented the UNet using ConvTranspose, which resulted in more parameters and less stable SGD training. However, that is a separate issue from computing the NNGP kernel for ten samples. I will re-run on GPU and report back.

Thanks

@romanngg
Contributor

FYI, we've just added FanInConcat support in c485052!

Two caveats:

  1. When concatenating along the channel/feature axis, a Dense or Convolutional layer is required afterwards (so, assuming NHWC with channel axis -1, you can't have stax.serial(..., stax.FanInConcat(axis=-1), stax.Relu(), ...) for now; this might be implemented later). See the sketch after this list.
  2. This will not reduce the memory footprint (which should be identical to FanInSum for channel axis concatenation, and larger for spatial or batch axis concatenation).
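For illustration, a minimal sketch of a U-Net-style skip block that respects caveat 1 (widths and filter sizes are made up; the important part is the Conv immediately after the channel-axis FanInConcat):

```python
# Minimal sketch (made-up widths/filter sizes) of a skip connection using the
# new FanInConcat; note the Conv immediately after the channel-axis concat,
# as required by caveat 1.
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(                              # processing branch
            stax.Conv(64, (3, 3), padding='SAME'), stax.Relu(),
            stax.Conv(64, (3, 3), padding='SAME'), stax.Relu(),
        ),
        stax.Identity(),                          # skip branch
    ),
    stax.FanInConcat(axis=-1),                    # concatenate along channels (NHWC)
    stax.Conv(64, (3, 3), padding='SAME'),        # Dense/Conv must follow the concat
    stax.Relu(),
)
```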

@romanngg romanngg added the enhancement New feature or request label Apr 15, 2020
@romanngg
Contributor

FYI, 100afac altered the tensor layout, which might reduce the TPU memory footprint; in general, there's still work to do to fully eliminate padding, and GPUs are still strongly recommended (see 100afac).

(Otherwise, speed was also improved by ~3-5X, which should allow using smaller batches.)

@romanngg
Contributor

romanngg commented Sep 9, 2020

FYI, we have finally added ConvTranspose in 780ad0c!
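A minimal sketch of what a 2x upsampling step might look like with it (sizes are illustrative):

```python
# Minimal sketch (illustrative sizes) of a 2x upsampling step with the newly
# added stax.ConvTranspose.
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(64, (3, 3), padding='SAME'), stax.Relu(),
    stax.ConvTranspose(64, (2, 2), strides=(2, 2), padding='SAME'),  # 2x spatial upsample
    stax.Relu(),
)
```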

@n17dccn151

Hello @kayhan-batmanghelich, I am currently learning about the NTK as well as the UNet architecture. Would you mind sharing the Colab notebook? Thank you very much.
