<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/experimental/empirical_ntk_resnet_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Example of computing the NTK of a **TensorFlow (Keras)** ResNet50 on ImageNet-sized inputs

Tested on an NVIDIA A100 GPU.

More examples: 


*   JAX (Flax):
  * [FCN](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb)
  * [ResNet18](https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb)



# Imports and setup

In [10]:
!nvidia-smi -L

GPU 0: A100-SXM4-40GB (UUID: GPU-097b69f8-fbcd-6775-82c1-4cfd355907e9)


In [None]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
!pip install -q --upgrade pip
!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents.git

In [None]:
import tensorflow as tf
import neural_tangents as nt

In [13]:
input_shape = (224, 224, 3)

# TensorFlow model definition

In [14]:
def get_model(O: int) -> tf.Module:
  return tf.keras.applications.resnet.ResNet50(classes=O, weights=None)

# NTK function definitions

In [15]:
def get_ntk_fns(O: int):
  # Define a TF-Keras ResNet50 with `O` output logits.
  f = get_model(O)
  f.build((None, *input_shape))
  _, params = nt.experimental.get_apply_fn_and_params(f)

  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
  ntvp = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  str_derivatives = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  auto = nt.experimental.empirical_ntk_fn_tf(
      **kwargs, implementation=nt.NtkImplementation.AUTO)
  
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)
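
To make the output layout concrete: with `trace_axes=()` no output axes are traced over, so the kernels computed in the sections below have shape `(N1, N2, O, O)`. The following is a minimal sketch of Jacobian contraction, not the library's implementation. It assumes a hypothetical toy model `f_o(x; w) = sin((o + 1) * <w, x>)` with a hand-written Jacobian, and uses plain Python lists so that it runs without TensorFlow or JAX.

```python
import math


def toy_jacobian(x, w, num_outputs):
    """Jacobian of the toy model f_o(x; w) = sin((o + 1) * <w, x>) w.r.t. w.

    Returns `num_outputs` rows, each of length len(w).
    """
    s = sum(wi * xi for wi, xi in zip(w, x))
    return [[(o + 1) * math.cos((o + 1) * s) * xi for xi in x]
            for o in range(num_outputs)]


def ntk_jacobian_contraction(xs1, xs2, w, num_outputs):
    """K[n1][n2][o1][o2] = sum_p J(xs1[n1])[o1][p] * J(xs2[n2])[o2][p]."""
    jac1 = [toy_jacobian(x, w, num_outputs) for x in xs1]
    jac2 = [toy_jacobian(x, w, num_outputs) for x in xs2]
    return [[[[sum(a * b for a, b in zip(row1, row2)) for row2 in j2]
              for row1 in j1]
             for j2 in jac2]
            for j1 in jac1]


def shape(nested):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


w = [0.3, -0.2, 0.5]                       # P = 3 parameters (hypothetical)
xs1 = [[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]]  # N1 = 2 inputs
xs2 = [[0.1, 0.2, 0.3]]                    # N2 = 1 input
kernel = ntk_jacobian_contraction(xs1, xs2, w, num_outputs=4)
print(shape(kernel))  # (2, 1, 4, 4)
```

The first two axes index the input pair and the last two index the pair of logits, matching the `(8, 8, 8, 8)` and `(1, 1, 128, 128)` shapes printed in this notebook.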

# $\color{blue}O = 8$ logits, batch size $\color{red}N = 8$

Structured derivatives compute the NTK fastest. NTK-vector products are actually the slowest in this setting: the forward pass is costly relative to the number of parameters, so the method scales poorly with batch size $\color{red}N$. Although it scales better with $\color{blue}O$ than the other methods, this is not enough to offset the $\color{red}N^2$ forward passes.
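
The per-pair cost of NTK-vector products can be illustrated with a toy sketch. For each pair of inputs, the `(O, O)` kernel block is recovered one column at a time by applying $\Theta v = J(x_1)\left(J(x_2)^T v\right)$ to each of the $O$ basis vectors $v$. This is an illustration under stated assumptions, not the library's code: the model `f_o(x; w) = sin((o + 1) * <w, x>)` is hypothetical, and an explicit Jacobian stands in for the automatic-differentiation VJP and JVP calls that a real implementation would use.

```python
import math


def toy_jacobian(x, w, num_outputs):
    """Jacobian of the toy model f_o(x; w) = sin((o + 1) * <w, x>) w.r.t. w."""
    s = sum(wi * xi for wi, xi in zip(w, x))
    return [[(o + 1) * math.cos((o + 1) * s) * xi for xi in x]
            for o in range(num_outputs)]


def vjp(jac, v):
    """J^T v: maps an output-space vector (length O) to parameter space (length P)."""
    return [sum(jac[o][p] * v[o] for o in range(len(v)))
            for p in range(len(jac[0]))]


def jvp(jac, u):
    """J u: maps a parameter-space vector (length P) to output space (length O)."""
    return [sum(jp * up for jp, up in zip(row, u)) for row in jac]


def ntk_block_via_ntvp(x1, x2, w, num_outputs):
    """Theta(x1, x2) of shape (O, O), built column by column from Theta @ e_o."""
    jac1 = toy_jacobian(x1, w, num_outputs)
    jac2 = toy_jacobian(x2, w, num_outputs)
    columns = []
    for o in range(num_outputs):
        basis = [1.0 if i == o else 0.0 for i in range(num_outputs)]
        columns.append(jvp(jac1, vjp(jac2, basis)))  # Theta @ e_o
    # columns[o2][o1] holds Theta[o1][o2]; transpose into block[o1][o2].
    return [[columns[o2][o1] for o2 in range(num_outputs)]
            for o1 in range(num_outputs)]


def ntk_via_ntvp(xs1, xs2, w, num_outputs):
    """One block per input pair: N1 * N2 pairs, with O products each."""
    return [[ntk_block_via_ntvp(x1, x2, w, num_outputs) for x2 in xs2]
            for x1 in xs1]


w = [0.3, -0.2, 0.5]                      # hypothetical parameters
xs = [[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]]  # N = 2 inputs
kernel = ntk_via_ntvp(xs, xs, w, num_outputs=3)
print(len(kernel), len(kernel[0]), len(kernel[0][0]), len(kernel[0][0][0]))  # 2 2 3 3
```

In this sketch the outer loops visit all `N1 * N2` input pairs, which is where the quadratic dependence on batch size comes from, while the inner loop grows only linearly with `O`.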

In [16]:
O = 8
N = 8

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [17]:
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(8, 8, 8, 8)


In [18]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(8, 8, 8, 8)


In [19]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(8, 8, 8, 8)


In [20]:
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.00024576858, shape=(), dtype=float32) tf.Tensor(0.0006717233, shape=(), dtype=float32) tf.Tensor(0.00086082943, shape=(), dtype=float32)
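
Each number printed above is the maximum absolute difference between two kernels, normalized by the mean absolute value of one of them. A framework-free sketch of the same metric follows; the helper names and the tolerance `tol` are hypothetical and only serve to turn the printed ratio into a pass/fail check.

```python
def flatten(nested):
    """Yield the scalar entries of an arbitrarily nested list."""
    for item in nested:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


def relative_max_error(k_a, k_b):
    """max |k_a - k_b| / mean |k_a|, the ratio printed by the cell above."""
    a, b = list(flatten(k_a)), list(flatten(k_b))
    max_abs_diff = max(abs(x - y) for x, y in zip(a, b))
    mean_abs = sum(abs(x) for x in a) / len(a)
    return max_abs_diff / mean_abs


def kernels_agree(k_a, k_b, tol=1e-2):
    """True if the relative discrepancy is within `tol` (an assumed threshold)."""
    return relative_max_error(k_a, k_b) <= tol


k_a = [[1.0, 2.0], [3.0, 4.0]]
k_b = [[1.0, 2.0], [3.0, 4.001]]
print(relative_max_error(k_a, k_b), kernels_agree(k_a, k_b))
```

Normalizing by the mean magnitude makes the check independent of the overall scale of the kernel, which is useful because float32 round-off grows with the size of the entries.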


In [21]:
# Selects the best method based on FLOPs at the first call / compilation.
# Takes about 3x longer to compile.
# WARNING: due to an XLA issue, this currently works correctly only on TPUs!
# FLOP counts of JITted functions are wrong on CPU/GPU.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=17815195648.0
impl=2, flops=61376139264.0
impl=3, flops=17957609472.0




(8, 8, 8, 8)
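
The selection rule described in the comments above amounts to picking the implementation with the smallest FLOP estimate. A minimal sketch using the values printed by this cell is given below. The mapping from the printed `impl` indices to implementation names is an assumption made for illustration.

```python
# FLOP estimates copied from the output above.
flops_by_impl = {1: 17815195648.0, 2: 61376139264.0, 3: 17957609472.0}

# Assumed mapping of the printed `impl` indices to implementation names.
impl_names = {
    1: "JACOBIAN_CONTRACTION",
    2: "NTK_VECTOR_PRODUCTS",
    3: "STRUCTURED_DERIVATIVES",
}

# Pick the implementation with the fewest estimated FLOPs.
best_impl = min(flops_by_impl, key=flops_by_impl.get)
print(impl_names[best_impl])  # JACOBIAN_CONTRACTION
```

With these estimates the minimum-FLOPs choice is `impl=1`, although `impl=3` is within about 1% of it. This is consistent with the warning that FLOP counts are unreliable on GPU, so the automatic choice need not be the fastest method here.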


In [22]:
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 317 ms per loop


In [23]:
%%timeit
# Slower: the forward pass (FP) is expensive relative to the number of parameters.
# Time cost scales poorly with batch size N.
ntk_fn_ntvp(x1, x2, params)

1 loop, best of 5: 484 ms per loop


In [24]:
%%timeit
# ~1.7X faster than Jacobian contraction!
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 184 ms per loop


In [25]:
%%timeit 
# On TPU, this should match the fastest method.
# On GPU/CPU, it is currently broken and may not be the fastest.
ntk_fn_auto(x1, x2, params)

10 loops, best of 5: 324 ms per loop
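
For reference, the speedups relative to Jacobian contraction can be computed directly from the timings recorded above. This is plain arithmetic on the printed numbers, which will differ on other hardware or software versions.

```python
# Timings in milliseconds, copied from the %%timeit outputs above.
timings_ms = {
    "jacobian_contraction": 317.0,
    "ntk_vector_products": 484.0,
    "structured_derivatives": 184.0,
    "auto": 324.0,
}

# Speedup > 1 means faster than Jacobian contraction.
baseline = timings_ms["jacobian_contraction"]
speedups = {name: baseline / t for name, t in timings_ms.items()}
for name, s in speedups.items():
    print(f"{name}: {s:.2f}x")
# jacobian_contraction: 1.00x
# ntk_vector_products: 0.65x
# structured_derivatives: 1.72x
# auto: 0.98x
```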


# $\color{blue}O = 128$ logits, batch size $\color{red}N = 1$

Both NTK-vector products and structured derivatives compute the NTK faster than Jacobian contraction. NTK-vector products incur no penalty when the batch size is $\color{red}N = 1$, and they benefit from their favorable scaling with large $\color{blue}O = 128$.

In [26]:
O = 128
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [27]:
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





(1, 1, 128, 128)


In [28]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





(1, 1, 128, 128)


In [29]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 128, 128)


In [30]:
# Make sure kernels agree.
print(
    tf.reduce_max(tf.abs(k_1 - k_2)) / tf.reduce_mean(tf.abs(k_1)), 
    tf.reduce_max(tf.abs(k_1 - k_3)) / tf.reduce_mean(tf.abs(k_1)),
    tf.reduce_max(tf.abs(k_2 - k_3)) / tf.reduce_mean(tf.abs(k_2))
)

tf.Tensor(0.014025839, shape=(), dtype=float32) tf.Tensor(0.0051410757, shape=(), dtype=float32) tf.Tensor(0.0148538705, shape=(), dtype=float32)


In [31]:
# Selects the best method based on FLOPs at the first call / compilation.
# Takes about 3x longer to compile.
# WARNING: due to an XLA issue, this currently works correctly only on TPUs!
# FLOP counts of JITted functions are wrong on CPU/GPU.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=30326192128.0
impl=2, flops=25741133824.0
impl=3, flops=30259060736.0








(1, 1, 128, 128)


In [32]:
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 495 ms per loop


In [33]:
%%timeit
# ~1.8X faster than Jacobian contraction!
ntk_fn_ntvp(x1, x2, params)

1 loop, best of 5: 276 ms per loop


In [34]:
%%timeit
# ~2.1X faster than Jacobian contraction!
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 232 ms per loop


In [35]:
%%timeit 
# On TPU, this should match the fastest method.
# On GPU/CPU, it is currently broken and may not be the fastest.
ntk_fn_auto(x1, x2, params)

1 loop, best of 5: 275 ms per loop


# $\color{blue}O = 1000$ logits, batch size $\color{red}N = 1$, full NTK

Structured derivatives make it possible to compute the full $1000\times 1000$ ImageNet NTK. The other methods run out of memory.
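
A back-of-the-envelope estimate suggests why a method that materializes a dense Jacobian may not fit on this GPU. The sketch below assumes roughly 25.6 million parameters for ResNet50 (an approximate figure that is not taken from this notebook) and float32 storage. The actual memory behavior of each implementation is not documented here, so treat this as one plausible explanation only.

```python
num_params = 25.6e6    # Assumed approximate ResNet50 parameter count (P).
num_logits = 1000      # O in this section.
bytes_per_float32 = 4
gpu_memory_gb = 40     # A100-SXM4-40GB, as reported by nvidia-smi above.

# A dense Jacobian for a single input has O x P float32 entries.
jacobian_gb = num_logits * num_params * bytes_per_float32 / 1e9
print(f"{jacobian_gb:.1f} GB per input vs. {gpu_memory_gb} GB of GPU memory")
# 102.4 GB per input vs. 40 GB of GPU memory
```

Under these assumptions a single dense Jacobian is already about 2.5 times larger than the available GPU memory, before accounting for activations or the second input.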

In [36]:
O = 1000
N = 1

# Input images x
x1 = tf.random.normal((N, *input_shape))
x2 = tf.random.normal((N, *input_shape))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)



In [37]:
# Structured derivatives - fits in memory!
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)





(1, 1, 1000, 1000)


In [38]:
%%timeit
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 1.29 s per loop


In [39]:
# NTK-vector products - OOM!
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)





UnknownError: ignored

In [40]:
# Jacobian contraction - OOM!
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)





UnknownError: ignored