very large memory footprint for a simple UNet #18
Comments
Hey! Thanks for pushing on this. We'd love to iterate on this to get it working for you (though looking at the UNet architecture I am a bit concerned that the vanilla version violates some independence assumptions wrt FanInConcat). A few things off the top of my head:
TPUs must store data in blocks of size 8 x 128. To fit arbitrary data into blocks of this size, XLA will often pad data. Here you can see that the raw size of the data is 6.25 GB, but it is getting padded by a factor of 2. I might recommend trying to run this on GPU rather than TPU and seeing whether the calculation will fit into memory, since GPUs don't need to pad. Generally, we have not figured out a way of phrasing our convolutions that doesn't get padded by the TPU (since our channel count is 1). This is an ongoing area of work, but I have to say we have limited tools at our disposal to make progress here (though maybe @romanngg can comment if he's more hopeful than myself). Let us know how the GPU works. Glancing at the sizes, I would expect it to easily fit on a V100 (since it has 32 GB of RAM whereas this calculation is consuming around 19 GB unpadded).
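To make those sizes concrete, here is a back-of-the-envelope sketch of where the 6.25 GB and ~19 GB figures come from (assuming float32 and the shapes quoted in this thread):

```python
# Rough size estimate for the intermediary NNGP covariance of a convolutional
# network evaluated on 10 images of 64x64 pixels (float32 assumed).
n = 10                 # number of images
pixels = 64 * 64       # spatial size per image
bytes_per_elem = 4     # float32

# The intermediary covariance has shape (n, n, pixels, pixels).
raw = n * n * pixels * pixels * bytes_per_elem
print(f"raw covariance:      {raw / 2**30:.2f} GiB")       # ~6.25 GiB

# TPUs tile tensors into 8x128 blocks; with a channel count of 1 this roughly
# doubles the footprint here. GPUs don't pad, but propagating the covariance
# from layer to layer still needs ~3 copies in practice (see the next comment).
print(f"padded on TPU (~2x): {2 * raw / 2**30:.2f} GiB")
print(f"peak with 3 copies:  {3 * raw / 2**30:.2f} GiB")   # ~18.75 GiB
```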
+1 to Sam re padding, and also note that even unpadded, the intermediary NNGP covariance of shape 10x10x(64x64)x(64x64) is 6.25 GB. To propagate this tensor through the NNGP computation from one layer to the next, you need 2X of that. Unfortunately, due to JAX internals, in practice it requires 3X (see & upvote google/jax#1733, google/jax#1273), which results in a peak memory consumption of ~19 GB and would require a 32 GB GPU (note that V100s come in 16 and 32 GB varieties, so even a V100 may not be enough). For this reason you'd probably need to work on even smaller batches in this case (see …).
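For illustration, this is roughly how the kernel computation can be batched over inputs so that only small blocks of the covariance are materialized at a time (a minimal sketch; the toy ConvNet below stands in for the actual UNet, and all layer sizes are placeholders):

```python
import numpy as np
import neural_tangents as nt
from neural_tangents import stax

# Toy convolutional architecture standing in for the UNet from the gist.
_, _, kernel_fn = stax.serial(
    stax.Conv(32, (3, 3), padding='SAME'), stax.Relu(),
    stax.Conv(32, (3, 3), padding='SAME'), stax.Relu(),
    stax.Flatten(), stax.Dense(1))

# Compute the 10x10 kernel in 2x2 blocks: each block's intermediary covariance
# has shape (2, 2, 64*64, 64*64) instead of (10, 10, 64*64, 64*64), cutting the
# peak memory by ~25x at the cost of more sequential work.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=2, store_on_device=False)

x = np.random.randn(10, 64, 64, 1)
nngp = batched_kernel_fn(x, None, 'nngp')
print(nngp.shape)  # (10, 10)
```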
Hi @sschoenholz, my understanding from #16 was that … Since there was not … Thanks
FYI, we've just added … Two caveats: …
FYI, we have finally added …
Hello @kayhan-batmanghelich, I am currently learning about NTK as well as the UNet architecture. Would you mind sharing the Colab notebook? Thank you very much.
Hi,
I hit a roadblock! I tried to compute the kernel for a typical UNet on 10 images. The image size is not big (64, 64) and the number of images is just 10 (for testing purposes). However, it crashes complaining about memory (see below). I think the intermediate layers are probably using a lot of memory, but that limits the usability. Perhaps I am missing something?
Gist Colab: https://gist.github.com/kayhan-batmanghelich/f444e6cec65139070f1b3e5ade230de5
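For reference, a UNet-style skip connection can be sketched in neural_tangents.stax along the following lines (an illustrative sketch only, not the architecture from the gist; all widths and filter shapes are placeholders, and note the point in the replies above that FanInConcat makes independence assumptions a vanilla UNet may violate):

```python
from neural_tangents import stax

def SkipBlock(channels=16):
    # Split the input, transform one branch, and concatenate it back with a
    # "skip" branch along the channel axis. Both branches end in a Conv here so
    # that the concatenated inputs are Gaussian.
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(stax.Conv(channels, (3, 3), padding='SAME'), stax.Relu(),
                        stax.Conv(channels, (3, 3), padding='SAME')),
            stax.Conv(channels, (3, 3), padding='SAME')),
        stax.FanInConcat(axis=-1),
        stax.Relu())

init_fn, apply_fn, kernel_fn = stax.serial(
    SkipBlock(16),
    stax.Conv(1, (3, 3), padding='SAME'))
```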
Side notes: a proper UNet would need an upsample layer, but that needs developing a new layer in neural-tangents and I am not sure how to do that.

Error message: