<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example of computing the NTK of a **TensorFlow (Keras)** ResNet50 on ImageNet inputs
Warning: computing the NTK in TensorFlow currently appears to have very long compile times (though runtime is OK), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model.

Tested on NVIDIA A100

More examples:

*   JAX (Flax):
  * [FCN](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb)
  * [ResNet18](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb)



# Imports and setup

In [1]:
!nvidia-smi -L

GPU 0: A100-SXM4-40GB (UUID: GPU-07b846cf-7f39-fff7-8224-d367cef00104)


In [2]:
!pip install --upgrade pip
!pip install --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install git+https://www.github.com/google/neural-tangents.git

Collecting pip
  Downloading pip-22.1.2-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.1.2
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_cudnn805]
  Downloading jax-0.3.13.tar.gz (951 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m951.0/951.0 kB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jaxlib==0.3.10+cuda11.cudnn805
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.10%2Bcuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl (175.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.7/175.7 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Building whe

In [3]:
import tensorflow as tf
import neural_tangents as nt

In [4]:
input_shape = (224, 224, 3)

# TensorFlow model definition

In [5]:
def get_model(O: int) -> tf.Module:
  return tf.keras.applications.resnet.ResNet50(classes=O, weights=None)

# NTK functions declaration

In [6]:
def get_ntk_fns(O: int):
  # Define a TF-Keras ResNet50 with `O` output logits.
  f = get_model(O)
  f.build((None, *input_shape))
  _, params = nt.experimental.get_apply_fn_and_params(f)

  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
  ntvp = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  str_derivatives = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  auto = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.AUTO)
  
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)
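To make concrete what the functions returned above compute, here is a minimal pure-Python sketch of Jacobian contraction for a toy linear model `f(x) = W @ x`. It is not how `neural_tangents` computes the kernel; `linear_model_ntk` and `shape` are hypothetical helpers written for illustration. The `(N1, N2, O, O)` axis layout mirrors the shapes printed later in this notebook when `trace_axes=()`.

```python
from typing import List

Vector = List[float]


def linear_model_ntk(x1: List[Vector], x2: List[Vector], num_outputs: int):
    """Empirical NTK of the toy model f(x) = W @ x, with W of shape (O, D).

    The Jacobian entry d f_i(x) / d W[j][k] equals x[k] if i == j, else 0.
    Contracting the two Jacobians over all parameters therefore gives
    ntk[n1][n2][i][j] = <x1[n1], x2[n2]> if i == j, else 0.
    """
    ntk = []
    for a in x1:
        row = []
        for b in x2:
            dot = sum(ak * bk for ak, bk in zip(a, b))
            row.append([[dot if i == j else 0.0 for j in range(num_outputs)]
                        for i in range(num_outputs)])
        ntk.append(row)
    return ntk


def shape(nested) -> tuple:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


x1 = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]  # N1 = 3 inputs, D = 2 features
x2 = [[1.0, 0.0], [2.0, 2.0]]               # N2 = 2 inputs
k = linear_model_ntk(x1, x2, num_outputs=4)
print(shape(k))  # (3, 2, 4, 4)
```

For a real network the Jacobians have no such closed form, which is why the three implementations above differ in how they trade forward passes, backward passes, and memory.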

# $\color{blue}O = 8$ logits, batch size $\color{red}N = 8$

Structured derivatives compute the NTK fastest here. NTK-vector products are actually the slowest method in this setting: the forward pass is costly relative to the parameter count, so the method scales poorly with batch size $\color{red}N$. Although it scales better with $\color{blue}O$ than the other methods, that is not enough to offset the $\color{red}N^2$ forward passes it requires.

In [7]:
O = 8
N = 8

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [8]:
# test {"skip": true}
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(8, 8, 8, 8)


In [9]:
# test {"skip": true}
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(8, 8, 8, 8)


In [10]:
# test {"skip": true}
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(8, 8, 8, 8)


In [11]:
# test {"skip": true}
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.00030575716, shape=(), dtype=float32) tf.Tensor(0.00075477577, shape=(), dtype=float32) tf.Tensor(0.0010306665, shape=(), dtype=float32)
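The agreement check above divides the largest absolute entry-wise difference by the mean absolute value of the reference kernel. A minimal pure-Python sketch of the same metric, using hypothetical helpers `flatten` and `relative_max_error` on nested lists instead of tensors, could look like this:

```python
def flatten(nested):
    """Yield every scalar of an arbitrarily nested list or tuple as a float."""
    if isinstance(nested, (list, tuple)):
        for item in nested:
            yield from flatten(item)
    else:
        yield float(nested)


def relative_max_error(k_ref, k_other) -> float:
    """max |k_ref - k_other| divided by mean |k_ref|."""
    ref = list(flatten(k_ref))
    other = list(flatten(k_other))
    if len(ref) != len(other):
        raise ValueError("kernels must have the same number of entries")
    max_abs_diff = max(abs(a - b) for a, b in zip(ref, other))
    mean_abs_ref = sum(abs(a) for a in ref) / len(ref)
    return max_abs_diff / mean_abs_ref


# Toy 2x2 "kernels" that differ in a single entry by 0.5.
k_a = [[1.0, 2.0], [3.0, 4.0]]
k_b = [[1.0, 2.0], [3.0, 4.5]]
print(relative_max_error(k_a, k_b))
```

Because the numerator is a maximum and the denominator a mean, this is a conservative measure: a single outlying entry dominates the result.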


In [12]:
# test {"skip": true}
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=17815195648.0
impl=2, flops=61376139264.0
impl=3, flops=17957609472.0




(8, 8, 8, 8)
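The selection rule described in the comments (pick the implementation with the fewest reported FLOPs) can be sketched as below, using the FLOP counts printed above. The mapping from `impl` numbers to implementation names is an assumption about the `nt.NtkImplementation` enum ordering and is not stated in this notebook; `pick_lowest_flops` is a hypothetical helper, not the library's code.

```python
# Assumption: impl numbers follow the enum order used in `get_ntk_fns`.
IMPLEMENTATION_NAMES = {
    1: "JACOBIAN_CONTRACTION",
    2: "NTK_VECTOR_PRODUCTS",
    3: "STRUCTURED_DERIVATIVES",
}

# FLOP counts as printed by the AUTO call for O = 8, N = 8.
reported_flops = {
    1: 17815195648.0,
    2: 61376139264.0,
    3: 17957609472.0,
}


def pick_lowest_flops(flops: dict) -> int:
    """Return the implementation id with the smallest reported FLOP count."""
    return min(flops, key=flops.get)


best = pick_lowest_flops(reported_flops)
print(best, IMPLEMENTATION_NAMES[best])
```

With these numbers the rule selects `impl=1`, even though the count for `impl=3` is within about 1% of it. This shows how inaccurate FLOP estimates on GPU/CPU could lead the automatic choice away from the method that is fastest in practice.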


In [13]:
# test {"skip": true}
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 319 ms per loop


In [14]:
# test {"skip": true}
%%timeit
# Slower - forward pass (FP) is expensive relative to parameters.
# Time cost scales poorly with batch size N.
ntk_fn_ntvp(x1, x2, params)

1 loop, best of 5: 486 ms per loop


In [15]:
# test {"skip": true}
%%timeit
# About 1.7X faster than Jacobian contraction in this run.
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 185 ms per loop


In [16]:
# test {"skip": true}
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU, currently is broken, and may not be the fastest.
ntk_fn_auto(x1, x2, params)

10 loops, best of 5: 325 ms per loop


# $\color{blue}O = 128$ logits, batch size $\color{red}N = 1$

Both NTK-vector products and structured derivatives compute the NTK faster than Jacobian contraction. NTK-vector products incur no batch-size penalty when $\color{red}N = 1$, and benefit from their favorable scaling with the large $\color{blue}O = 128$.

In [17]:
O = 128
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [18]:
# test {"skip": true}
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(1, 1, 128, 128)


In [19]:
# test {"skip": true}
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(1, 1, 128, 128)


In [20]:
# test {"skip": true}
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 128, 128)


In [21]:
# test {"skip": true}
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.016565105, shape=(), dtype=float32) tf.Tensor(0.0032940311, shape=(), dtype=float32) tf.Tensor(0.01386257, shape=(), dtype=float32)


In [22]:
# test {"skip": true}
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=30326192128.0
impl=2, flops=25741133824.0
impl=3, flops=30259060736.0








(1, 1, 128, 128)


In [23]:
# test {"skip": true}
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 481 ms per loop


In [24]:
# test {"skip": true}
%%timeit
# About 1.8X faster than Jacobian contraction in this run.
ntk_fn_ntvp(x1, x2, params)

1 loop, best of 5: 265 ms per loop


In [25]:
# test {"skip": true}
%%timeit
# About 2.3X faster than Jacobian contraction in this run.
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 212 ms per loop


In [26]:
# test {"skip": true}
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU, currently is broken, and may not be the fastest.
ntk_fn_auto(x1, x2, params)

10 loops, best of 5: 264 ms per loop
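To compare the two settings side by side, the `%%timeit` results reported in this notebook can be turned into speedups relative to Jacobian contraction. This is plain arithmetic on the numbers printed above; the dictionary keys and the `speedups` helper are hypothetical names chosen for illustration.

```python
# Best-of-5 timings in milliseconds, copied from the %%timeit outputs above.
timings_ms = {
    "O=8, N=8": {
        "jacobian_contraction": 319,
        "ntk_vector_products": 486,
        "structured_derivatives": 185,
        "auto": 325,
    },
    "O=128, N=1": {
        "jacobian_contraction": 481,
        "ntk_vector_products": 265,
        "structured_derivatives": 212,
        "auto": 264,
    },
}


def speedups(timings: dict, baseline: str = "jacobian_contraction") -> dict:
    """Speedup of each method relative to the baseline (values > 1 are faster)."""
    return {name: timings[baseline] / t for name, t in timings.items()}


for setting, timings in timings_ms.items():
    print(setting)
    for name, factor in speedups(timings).items():
        print(f"  {name:24s} {factor:.2f}x")
```

On these numbers, structured derivatives are roughly 1.7x and 2.3x faster than Jacobian contraction in the two settings, while NTK-vector products go from slower than the baseline at $N = 8$ to roughly 1.8x faster at $N = 1$.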


# $\color{blue}O = 1000$ logits, batch size $\color{red}N = 1$, full NTK

Structured derivatives make it possible to compute the full $1000\times 1000$ ImageNet NTK. The other methods run out of memory.
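A back-of-the-envelope estimate suggests why memory becomes the bottleneck. The sketch below assumes that a method materializes a full float32 Jacobian of shape `(N, O, P)`, and uses an approximate ResNet50 parameter count. Both are assumptions made for illustration, not figures from this notebook, and actual peak memory depends on what XLA fuses or rematerializes.

```python
# Assumption: approximate parameter count of the Keras ResNet50 backbone
# (everything except the final dense layer). Not stated in this notebook.
BACKBONE_PARAMS = 23_600_000
FEATURES = 2048                 # assumption: width of the pooled feature vector
BYTES_PER_FLOAT32 = 4
GPU_MEMORY_BYTES = 40 * 10**9   # "A100-SXM4-40GB", treating 1 GB as 10^9 bytes


def num_params(num_outputs: int) -> int:
    """Backbone plus a dense head with `num_outputs` logits (weights + biases)."""
    return BACKBONE_PARAMS + (FEATURES + 1) * num_outputs


def jacobian_bytes(batch_size: int, num_outputs: int) -> int:
    """Size of a dense float32 Jacobian of shape (N, O, P)."""
    return batch_size * num_outputs * num_params(num_outputs) * BYTES_PER_FLOAT32


def ntk_bytes(n1: int, n2: int, num_outputs: int) -> int:
    """Size of the float32 NTK of shape (N1, N2, O, O)."""
    return n1 * n2 * num_outputs ** 2 * BYTES_PER_FLOAT32


for n, o in [(8, 8), (1, 128), (1, 1000)]:
    jac_gb = jacobian_bytes(n, o) / 1e9
    ntk_mb = ntk_bytes(n, n, o) / 1e6
    fits = jacobian_bytes(n, o) < GPU_MEMORY_BYTES
    print(f"N={n}, O={o}: Jacobian ~{jac_gb:.1f} GB, NTK ~{ntk_mb:.3f} MB, "
          f"Jacobian fits in 40 GB: {fits}")
```

Under these assumptions a single Jacobian for $O = 1000$ is on the order of 100 GB, well above the 40 GB available, while the resulting NTK itself needs only about 4 MB.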

In [27]:
O = 1000
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [28]:
# test {"skip": true}
# Structured derivatives - fits in memory!
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 1000, 1000)


In [29]:
# test {"skip": true}
%%timeit
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 1.28 s per loop


In [30]:
# test {"skip": true}
# NTK-vector products - OOM!
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





UnknownError: ignored

In [31]:
# test {"skip": true}
# Jacobian contraction - OOM!
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





UnknownError: ignored