
High GPU memory during empirical NTK calculation #100

Closed
chenwydj opened this issue Feb 13, 2021 · 5 comments
Labels
question Further information is requested

Comments

chenwydj commented Feb 13, 2021

I am working on empirical NTK calculations, and the GPU memory usage seems very high.

Two cases:

  1. Input (28*28) => 1 Linear layer of 1024 channels => output 10 classes; Use 64 samples to calculate NTK.
  2. Input (28x28x1) => 1 3x3 Conv layer of 64 channels => GAP => output 10 classes; Use 64 samples to calculate NTK.

Case 1: over 10 GB of memory
Case 2: OOM on a 1080 Ti

I wonder why the GPU memory consumption is so high. The parameter count in both cases is modest; in particular, case 2 (~1.2K params) has far fewer parameters than case 1 (~0.8M params), yet case 2 consumes more GPU memory than case 1.

Below are my implementations. Could you help point out any possible misuse, and any way to reduce the GPU memory consumption?

Thank you very much!!!

import neural_tangents as nt
from neural_tangents import stax

# Case 1: fully-connected network with a 1024-unit hidden layer
init_fn, apply_fn, kernel_fn_inf = stax.serial(
    stax.Dense(1024, 1.0, 0.05),
    stax.Relu(),
    stax.Dense(10, 1.0, 0.05),
    stax.Relu())

# Case 2: 3x3 Conv with 64 channels, followed by global average pooling
init_fn, apply_fn, kernel_fn_inf = stax.serial(
    stax.Conv(64, (3, 3), padding='SAME', W_std=1.0, b_std=0.05),
    stax.Relu(),
    stax.GlobalAvgPool(),
    stax.Dense(10, 1.0, 0.05),
    stax.Relu())

# Calculate the empirical NTK; X contains 64 samples, params come from init_fn.
kernel_fn = nt.empirical_kernel_fn(apply_fn)
kernel = kernel_fn(X, None, 'ntk', params)
romanngg (Contributor) commented:
Could you try something like kernel_fn = jax.jit(nt.empirical_kernel_fn(apply_fn, vmap_axes=0))? I believe Jitting and specifying the vmap_axes (see description at https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html) should both help with speed and memory, especially for CNNs. Lmk if this works!

chenwydj (Author) commented:
Thank you very much! It works!

I rewrote it as below:

from jax import jit

kernel_fn_emp = nt.empirical_kernel_fn(apply_fn, vmap_axes=0)
# The kernel-type string ('ntk') must be static under jit, hence static_argnums=(2,).
kernel_fn = jit(lambda x1, x2, _type, params: kernel_fn_emp(x1, x2, _type, params),
                static_argnums=(2,))
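For reference, the jitted function is called the same way as the un-jitted one; a minimal usage sketch, assuming X (the 64-sample batch) and params are defined as in the original post:

# Hypothetical usage: X is the 64-sample batch, params come from init_fn.
kernel = kernel_fn(X, None, 'ntk', params)  # 64x64 empirical NTK matrix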

chenwydj (Author) commented:
Sorry for reopening this issue.

With the method mentioned above, on a single 1080 Ti (11 GB memory), I can compute at most a 32x32 empirical NTK (32 samples) for a 2-layer, 1024-channel FCN on MNIST.

However, I can see that the work here studies up to a 3-layer FCN of width 2048.

I would really like to study the empirical NTK of larger networks.

May I ask whether there is any other method that can further reduce GPU memory usage? Could more example code be provided?

Thank you very much!!!

chenwydj reopened this Feb 19, 2021
romanngg (Contributor) commented:
Could you try implementation=2? Here's an example that seems to work:
https://colab.research.google.com/gist/romanngg/78bba7a22cf9d63d2f3711f4f3d5c062/empirical_nntk_sample.ipynb
This might be running on a V100 though, so more powerful than a 1080 Ti. Lmk if this helps!
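For reference, a minimal sketch of what this might look like, assuming the neural-tangents API of that time, where nt.empirical_kernel_fn accepts an implementation argument (the linked notebook is the authoritative example):

import jax
import neural_tangents as nt

# Assumption: implementation=2 computes the NTK via NTK-vector products rather
# than materializing full per-example Jacobians, trading compute for memory.
kernel_fn_emp = nt.empirical_kernel_fn(apply_fn, vmap_axes=0, implementation=2)

# The kernel-type string ('ntk') must stay static under jit.
kernel_fn = jax.jit(kernel_fn_emp, static_argnums=(2,))
kernel = kernel_fn(X, None, 'ntk', params)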

chenwydj (Author) commented:
Hi @romanngg!

Thank you so much for the kind help! It further halved the GPU memory usage! Greatly appreciated!
