High GPU memory during empirical NTK calculation #100
Comments
Could you try something like …
Thank you very much! It works! I rewrote it as below:

```python
kernel_fn_emp = nt.empirical_kernel_fn(apply_fn, vmap_axes=0)
kernel_fn = jit(lambda x1, x2, _type, params:
                kernel_fn_emp(x1, x2, _type, params),
                static_argnums=(2,))
```
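For readers unfamiliar with what `empirical_kernel_fn` computes, here is a minimal pure-JAX sketch of the empirical NTK as a Jacobian contraction. The `apply_fn` below is a hypothetical two-layer scalar-output network standing in for a neural-tangents model; it is not the code from this thread.

```python
import jax
import jax.numpy as jnp

def apply_fn(params, x):
    # Hypothetical 2-layer MLP with scalar output, standing in for a
    # neural-tangents stax model.
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

def empirical_ntk(params, x1, x2):
    # Empirical NTK(x1, x2) = J(x1) @ J(x2)^T, where J is the Jacobian of
    # the network outputs with respect to all parameters, flattened.
    def flat_jacobian(x):
        jac = jax.jacobian(lambda p: apply_fn(p, x))(params)
        leaves = jax.tree_util.tree_leaves(jac)
        return jnp.concatenate(
            [j.reshape(x.shape[0], -1) for j in leaves], axis=-1)
    return flat_jacobian(x1) @ flat_jacobian(x2).T
```

Materializing the full Jacobian costs on the order of batch size times parameter count in memory, which is one reason empirical-NTK memory can grow with network width even for small datasets.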
Sorry for reopening this issue. With the method mentioned above, on one 1080Ti (11 GB memory) I can compute an empirical NTK of size up to 32x32 for a 2-layer, 1024-channel FCN on MNIST. However, the work cited here studies up to a 3-layer FCN of width 2048, and I would really like to study the empirical NTK of larger networks. Is there any further method to reduce the GPU memory usage? Could more example code be demonstrated? Thank you very much!
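One standard way to cap memory, and the idea that `nt.batch` in neural-tangents automates, is to tile the kernel computation over blocks of inputs so that peak memory depends on the block size rather than the full dataset. Below is a hypothetical NumPy sketch of that idea, not the library's implementation:

```python
import numpy as np

def tiled_kernel(kernel_fn, x1, x2, batch_size):
    # Compute kernel_fn(x1, x2) block by block so that peak memory scales
    # with batch_size**2 rather than len(x1) * len(x2).
    rows = []
    for i in range(0, len(x1), batch_size):
        row = [kernel_fn(x1[i:i + batch_size], x2[j:j + batch_size])
               for j in range(0, len(x2), batch_size)]
        rows.append(np.concatenate(row, axis=1))
    return np.concatenate(rows, axis=0)
```

In the library itself this is exposed as `nt.batch(kernel_fn, batch_size=...)`, which additionally shards the blocks across available devices.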
Could you try …
Hi @romanngg! Thank you so much for the kind help! It did further halve the GPU memory usage. Greatly appreciated!
I am working on empirical NTK calculations, and the GPU memory usage seems very high.
Two cases:
Case 1: over 10 GB of memory
Case 2: OOM on a 1080Ti
I wonder why the GPU memory consumption is so high. The parameter counts in both cases are not large. In particular, case 2 (~1.2K params) has far fewer parameters than case 1 (~0.8M params), yet case 2 consumes more GPU memory than case 1.
Below are my implementations. Could you help point out any possible misuse, and any way to reduce the GPU memory consumption?
Thank you very much!!!
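As an illustration of why empirical-NTK memory tracks Jacobian size rather than parameter count alone, here is a hedged sketch (hypothetical helper, scalar-output network assumed) that computes one NTK entry at a time with `jax.grad`, keeping only two parameter-sized gradients alive at once instead of a full Jacobian:

```python
import jax
import jax.numpy as jnp

def ntk_entry(apply_fn, params, xa, xb):
    # One entry of the empirical NTK for a scalar-output network:
    # <grad_params f(xa), grad_params f(xb)>. Peak memory is roughly two
    # copies of the parameters, independent of how many inputs there are.
    ga = jax.grad(lambda p: jnp.sum(apply_fn(p, xa)))(params)
    gb = jax.grad(lambda p: jnp.sum(apply_fn(p, xb)))(params)
    return sum(jnp.vdot(a, b)
               for a, b in zip(jax.tree_util.tree_leaves(ga),
                               jax.tree_util.tree_leaves(gb)))
```

This trades memory for compute (each input's gradient is recomputed per pair), which is the same trade-off that batching the kernel computation makes at a coarser granularity.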