<a href="https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Example of computing the NTK of a **TensorFlow (Keras)** ResNet50 on ImageNet inputs

Tested on an NVIDIA A100.

More examples:

*   [JAX (Flax)](https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example.ipynb)
*   [PyTorch](https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example_pytorch.ipynb)



# Imports and setup

In [1]:
!nvidia-smi -L

GPU 0: A100-SXM4-40GB (UUID: GPU-716bb71d-bc31-5489-3eca-51341eec18f9)


In [2]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
!pip install --upgrade pip
!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html

Collecting pip
  Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.0.4
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jax[cuda11_cudnn805]
  Downloading jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
Collecting jaxlib==0.3.5+cuda11.cudnn805
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.5%2Bcuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl (208.0 MB)
Building wheels for 

In [7]:
!pip install git+https://github.com/deepmind/tf2jax.git --no-deps
!pip install git+https://github.com/icml2022anon/fast_finite_width_ntk.git --upgrade

Collecting git+https://github.com/deepmind/tf2jax.git
  Cloning https://github.com/deepmind/tf2jax.git to /tmp/pip-req-build-s8tm47gp
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/tf2jax.git /tmp/pip-req-build-s8tm47gp
  Resolved https://github.com/deepmind/tf2jax.git to commit b3e80f7e9ac18d7d495b995b2e63c182aee6b236
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: tf2jax
  Building wheel for tf2jax (setup.py) ... done
  Created wheel for tf2jax: filename=tf2jax-0.1.1-py3-none-any.whl size=56508 sha256=d8941a4c1b481ceb0f62f73ed96e0a7fc22c628caf060598d45c02d7824168b2
  Stored in directory: /tmp/pip-ephem-wheel-cache-ko61pwrf/wheels/a9/d4/2a/09130b0825ec10a6157d88661459480b93f58bb0e62278ab79
Successfully built tf2jax
Installing collected packages: tf2jax
Successfully installed tf2jax-0.1.1
Collecting git+https://github.com/icml2022anon/fast_finite_width_ntk.git
  Cloning https://github.com/i

In [8]:
import tensorflow as tf
from fast_finite_width_ntk import empirical_ntk_fn_tf

In [9]:
input_shape = (224, 224, 3)

# Tensorflow model definition

In [10]:
def get_model(O: int) -> tf.keras.Model:
  # Randomly initialized (`weights=None`) Keras ResNet50 with `O` output classes.
  return tf.keras.applications.resnet.ResNet50(classes=O, weights=None)

# NTK functions declaration

In [11]:
def get_ntk_fns(O: int):
  # Define a TF-Keras ResNet50 with `O` output logits.
  f = get_model(O)
  f.build((None, *input_shape))
  params = f.weights

  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = empirical_ntk_fn_tf(**kwargs, implementation=1)
  ntvp = empirical_ntk_fn_tf(**kwargs, implementation=2)
  str_derivatives = empirical_ntk_fn_tf(**kwargs, implementation=3)
  auto = empirical_ntk_fn_tf(**kwargs, implementation=0)
  
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

# $\color{blue}O = 8$ logits, batch size $\color{red}N = 8$

Structured derivatives compute the NTK fastest. NTK-vector products are actually slower in this setting: the forward pass is costly relative to the number of parameters, so the method scales poorly with batch size $\color{red}N$. Although it scales better with $\color{blue}O$ than the other methods, that is not enough to overcome the $\color{red}N^2$ forward passes.

In [12]:
O = 8
N = 8

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [13]:
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(8, 8, 8, 8)
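
The `(8, 8, 8, 8)` result has shape `(N, N, O, O)`: one `O × O` block per pair of inputs. As a rough illustration of what Jacobian contraction computes, the sketch below builds the same kind of kernel for a toy linear model `f(x) = W x` in plain Python. The helper names (`linear_jacobian`, `ntk_by_jacobian_contraction`) and the toy inputs are assumptions made for illustration; they are not part of `fast_finite_width_ntk`.

```python
# Illustrative only: a tiny linear model f(x) = W x, not the ResNet50 above.

def linear_jacobian(x, num_outputs):
    """Jacobian of f(x) = W x w.r.t. row-major flattened W: an (O, O*D) nested list."""
    d = len(x)
    jac = [[0.0] * (num_outputs * d) for _ in range(num_outputs)]
    for a in range(num_outputs):
        for k in range(d):
            # d f_a / d W[a, k] = x[k]; all other entries stay zero.
            jac[a][a * d + k] = x[k]
    return jac


def ntk_by_jacobian_contraction(xs1, xs2, num_outputs):
    """Contract per-example Jacobians over the parameter axis -> (N1, N2, O, O)."""
    jacs1 = [linear_jacobian(x, num_outputs) for x in xs1]
    jacs2 = [linear_jacobian(x, num_outputs) for x in xs2]
    return [
        [
            [
                [sum(p * q for p, q in zip(row_a, row_b)) for row_b in j2]
                for row_a in j1
            ]
            for j2 in jacs2
        ]
        for j1 in jacs1
    ]


xs1 = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]  # N1 = 2 inputs of dimension D = 3
xs2 = [[1.0, 0.0, 1.0], [2.0, 2.0, 2.0]]  # N2 = 2 inputs of dimension D = 3
kernel = ntk_by_jacobian_contraction(xs1, xs2, num_outputs=2)
print(kernel[0][0])  # O x O block for the pair (xs1[0], xs2[0])
```

For this toy model each block is the input dot product times the identity, so `kernel[0][0]` is `[[4.0, 0.0], [0.0, 4.0]]`. The real network has no such closed form, which is why the Jacobians have to be computed and contracted explicitly.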


In [14]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(8, 8, 8, 8)


In [15]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(8, 8, 8, 8)


In [16]:
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.00027873155, shape=(), dtype=float32) tf.Tensor(0.0006046945, shape=(), dtype=float32) tf.Tensor(0.00086610956, shape=(), dtype=float32)


In [17]:
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=18100342784.0
impl=2, flops=61626617856.0
impl=3, flops=18217189376.0




(8, 8, 8, 8)
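
The selection rule described in the comments above (pick the implementation with the lowest estimated FLOPs) can be sketched as follows. This is a minimal illustration using the FLOP counts printed by the cell, not the library's actual code; `pick_cheapest` is a hypothetical helper name.

```python
# FLOP estimates copied from the printed output above (O = 8, N = 8).
flops_by_impl = {
    1: 18100342784.0,  # Jacobian contraction
    2: 61626617856.0,  # NTK-vector products
    3: 18217189376.0,  # Structured derivatives
}


def pick_cheapest(flops):
    """Hypothetical helper: return the implementation id with the fewest FLOPs."""
    return min(flops, key=flops.get)


best_impl = pick_cheapest(flops_by_impl)
print(best_impl)  # prints 1
```

With these estimates the rule picks implementation 1 (Jacobian contraction) by a margin of under 1% over implementation 3, which is consistent with the warning that FLOP counts are unreliable on GPU.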


In [18]:
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 264 ms per loop


In [19]:
%%timeit
# Slower - forward pass (FP) is expensive relative to parameters.
# Time cost scales poorly with batch size N.
ntk_fn_ntvp(x1, x2, params)

10 loops, best of 5: 427 ms per loop


In [20]:
%%timeit
# 2X faster!
ntk_fn_str_derivatives(x1, x2, params)

The slowest run took 8.69 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 127 ms per loop


In [21]:
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU this is currently broken and may not pick the fastest method.
ntk_fn_auto(x1, x2, params)

10 loops, best of 5: 264 ms per loop
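
To make the comparison explicit, the speedups relative to Jacobian contraction can be computed from the `%%timeit` results above. The snippet is a small convenience sketch; the dictionary keys are labels chosen here, and the numbers are the timings reported for this A100 run.

```python
# Best-of-5 timings in milliseconds, copied from the cells above (O = 8, N = 8).
timings_ms = {
    "jacobian_contraction": 264.0,
    "ntk_vector_products": 427.0,
    "structured_derivatives": 127.0,
}

baseline = timings_ms["jacobian_contraction"]
# Speedup > 1 means faster than Jacobian contraction, < 1 means slower.
speedups = {name: baseline / t for name, t in timings_ms.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
```

This gives roughly 2.08x for structured derivatives and 0.62x for NTK-vector products. The `auto` timing of 264 ms equals the Jacobian contraction timing, so on this GPU it did not select the fastest method.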


# $\color{blue}O = 128$ logits, batch size $\color{red}N = 1$

Both NTK-vector products and structured derivatives compute the NTK faster than Jacobian contraction. NTK-vector products incur no penalty when the batch size is $\color{red}N = 1$, and they benefit from their favorable scaling with large $\color{blue}O = 128$.

In [23]:
O = 128
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [24]:
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(1, 1, 128, 128)


In [25]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(1, 1, 128, 128)


In [26]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 128, 128)


In [27]:
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.014684605, shape=(), dtype=float32) tf.Tensor(0.0035293968, shape=(), dtype=float32) tf.Tensor(0.016298601, shape=(), dtype=float32)


In [28]:
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=30358691840.0
impl=2, flops=25772449792.0
impl=3, flops=31345016832.0








(1, 1, 128, 128)


In [29]:
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 421 ms per loop


In [30]:
%%timeit
# 2X faster!
ntk_fn_ntvp(x1, x2, params)

10 loops, best of 5: 205 ms per loop


In [31]:
%%timeit
# 2.5X faster!
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 162 ms per loop


In [32]:
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU this is currently broken and may not pick the fastest method.
ntk_fn_auto(x1, x2, params)

10 loops, best of 5: 205 ms per loop


# $\color{blue}O = 1000$ logits, batch size $\color{red}N = 1$, full NTK

Structured derivatives make it possible to compute the full $1000\times 1000$ ImageNet NTK. The other methods run out of memory.
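
A back-of-the-envelope estimate suggests why Jacobian contraction is likely to run out of memory here. The sketch assumes that the full per-example Jacobian is materialized in float32 and that ResNet50 has roughly 25.6 million parameters (an approximate figure, not measured in this notebook). The 40 GB limit comes from the `nvidia-smi` output at the top.

```python
NUM_PARAMS = 25.6e6       # assumption: approximate Keras ResNet50 parameter count
NUM_LOGITS = 1000         # O in this section
BATCH_SIZE = 1            # N in this section
BYTES_PER_FLOAT32 = 4
GPU_MEMORY_BYTES = 40e9   # A100-SXM4-40GB, treated as 40 decimal gigabytes

# One Jacobian has N * O * P entries (one gradient per logit per example).
jacobian_bytes = BATCH_SIZE * NUM_LOGITS * NUM_PARAMS * BYTES_PER_FLOAT32

print(f"Jacobian size: {jacobian_bytes / 1e9:.1f} GB")
print(f"Fits in GPU memory: {jacobian_bytes <= GPU_MEMORY_BYTES}")
```

Under these assumptions a single Jacobian takes about 102.4 GB, well above the 40 GB available. The estimate says nothing about the memory profile of NTK-vector products, which fail for their own reasons.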

In [33]:
O = 1000
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [34]:
# Structured derivatives - fits in memory!
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 1000, 1000)


In [35]:
%%timeit
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 1.26 s per loop


In [36]:
# NTK-vector products - OOM!
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





UnknownError: ignored

In [37]:
# Jacobian contraction - OOM!
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





UnknownError: ignored