<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_fcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example of computing the finite-width NTK of an FCN on CIFAR-10-shaped inputs

Tested on NVIDIA T4.

# Imports and setup

In [None]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-70c5645d-1ae8-d42c-8b35-ca8a58b81237)


In [None]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
!pip install -q --upgrade pip
!pip install -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents

In [None]:
from jax import jit
from jax import numpy as np
from jax import random

import neural_tangents as nt
from neural_tangents import stax

# Defining a simple FCN model

In [None]:
def get_ntk_fns(O: int, input_shape: tuple = (3072,)):
  """Builds an FCN and jitted empirical-NTK functions for it.

  The network is three `Dense(2048)` + `Relu` hidden layers followed by a
  `Dense(O)` readout.

  Args:
    O: number of network outputs (width of the final `Dense` layer).
    input_shape: shape of a single input, excluding the batch axis. Defaults
      to `(3072,)`, a flattened 32x32x3 CIFAR-10 image, matching the inputs
      used in this notebook.

  Returns:
    `(params, (jacobian_contraction, ntvp, str_derivatives, auto))`: the
    parameters initialized from `random.PRNGKey(0)`, and jitted NTK functions
    using the `JACOBIAN_CONTRACTION`, `NTK_VECTOR_PRODUCTS`,
    `STRUCTURED_DERIVATIVES` and `AUTO` implementations respectively. Each is
    called as `ntk_fn(x1, x2, params)`.
  """
  # Define an FCN.
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(2048), stax.Relu(),
      stax.Dense(2048), stax.Relu(),
      stax.Dense(2048), stax.Relu(),
      stax.Dense(O)
  )

  # Settings shared by every implementation. `trace_axes=()` keeps the full
  # O x O block of output-output entries for each pair of inputs, so kernels
  # have shape (N1, N2, O, O); `vmap_axes=0` declares axis 0 as the batch axis.
  kwargs = dict(
      f=apply_fn,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations: same kernel, different computational cost.
  jacobian_contraction = jit(nt.empirical_ntk_fn(
      **kwargs, implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION))
  ntvp = jit(nt.empirical_ntk_fn(
      **kwargs, implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS))
  str_derivatives = jit(nt.empirical_ntk_fn(
      **kwargs, implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES))
  auto = jit(nt.empirical_ntk_fn(
      **kwargs, implementation=nt.NtkImplementation.AUTO))

  # Parameters \theta. Dense parameter shapes depend only on the feature
  # dimension, so the batch size is left unspecified (-1) rather than read
  # from a global input array defined in a later cell.
  _, params = init_fn(random.PRNGKey(0), (-1, *input_shape))
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

# Benchmark

Structured derivatives compute the NTK fastest. NTK-vector products also provide a speedup, because the forward pass is cheap relative to the parameter count.

In [None]:
# Number of network outputs (O) and number of inputs per batch (N).
O = 8
N = 16

# Random normal stand-ins for CIFAR-10 images, flattened to 32 * 32 * 3 = 3072 features.
input_shape = (3072,)
k1, k2 = random.split(random.PRNGKey(1), 2)
x1, x2 = (random.normal(k, (N, *input_shape)) for k in (k1, k2))

# Network parameters plus the four jitted NTK implementations.
params, (ntk_fn_jacobian_contraction,
         ntk_fn_ntvp,
         ntk_fn_str_derivatives,
         ntk_fn_auto) = get_ntk_fns(O=O)

In [None]:
# Jacobian contraction (the reference that the other implementations are compared against).
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)
# trace_axes=() keeps both output axes: one O x O block per pair of inputs.
assert k_1.shape == (N, N, O, O)

(16, 16, 8, 8)


In [None]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)
# trace_axes=() keeps both output axes: one O x O block per pair of inputs.
assert k_2.shape == (N, N, O, O)

(16, 16, 8, 8)


In [None]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)
# trace_axes=() keeps both output axes: one O x O block per pair of inputs.
assert k_3.shape == (N, N, O, O)

(16, 16, 8, 8)


In [None]:
# Make sure kernels agree: report the largest elementwise discrepancy between
# two implementations, relative to the mean magnitude of the reference kernel.
def rel_err(k_ref, k_other):
  """Returns max|k_ref - k_other| / mean|k_ref|."""
  return np.max(np.abs(k_ref - k_other)) / np.mean(np.abs(k_ref))

errs = (rel_err(k_1, k_2), rel_err(k_1, k_3), rel_err(k_2, k_3))
print(*errs)

# Implementations should agree up to floating-point error (~1e-5 was observed
# in float32 on an NVIDIA T4). The bound is deliberately loose; reduced-precision
# matmuls (e.g. TPU defaults) may need a looser value — TODO confirm.
MAX_REL_ERR = 1e-2
assert all(e < MAX_REL_ERR for e in errs), f'NTK implementations disagree: {errs}'

8.350792e-06 8.350792e-06 3.7114617e-06


In [None]:
# test {"skip": true}
# Selects the best implementation based on FLOPs at the first call / compilation,
# which makes compilation take about 3x longer.
# WARNING: due to an XLA issue, this currently only works correctly on TPUs!
# FLOP counts of JITted functions are wrong on CPU/GPU.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)
# Whichever implementation is chosen, the kernel has the same shape as the others.
assert k_0.shape == (N, N, O, O)

impl=1, flops=3765027328.0
impl=2, flops=1916764544.0
impl=3, flops=2670960.0
(16, 16, 8, 8)


In [None]:
# test {"skip": true}
%%timeit
ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()

1 loop, best of 5: 243 ms per loop


In [None]:
# test {"skip": true}
%%timeit
# 3X faster.
ntk_fn_ntvp(x1, x2, params).block_until_ready()  

10 loops, best of 5: 81.4 ms per loop


In [None]:
# test {"skip": true}
%%timeit
# 70X faster.
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

100 loops, best of 5: 3.46 ms per loop


In [None]:
# test {"skip": true}
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU, currently is broken, and may not be the fastest.
ntk_fn_auto(x1, x2, params).block_until_ready()

100 loops, best of 5: 3.44 ms per loop
