<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/empirical_ntk_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example of computing NTK of a ResNet18 on ImageNet inputs

Tested on NVIDIA V100

# Imports and setup

In [None]:
# List the GPUs visible to this runtime (shell escape; requires `nvidia-smi` on PATH).
!nvidia-smi -L

/bin/sh: line 1: nvidia-smi: command not found


In [None]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
!pip install --upgrade pip
!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -q git+https://github.com/google/flax
!pip install -q git+https://www.github.com/google/neural-tangents

/bin/sh: line 1: pip: command not found
/bin/sh: line 1: pip: command not found
/bin/sh: line 1: pip: command not found


In [None]:
from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional
from flax import linen as nn
import jax.numpy as np
import numpy as onp
from jax import jit
from jax import numpy as np
from jax import random

import neural_tangents as nt

# ResNet18 definition, copied from [FLAX examples](https://github.com/google/flax/blob/main/examples/imagenet/models.py)

In [None]:
# Type alias for "a module class or factory". Left as `Any` because the fields
# annotated with it receive callables such as `partial(nn.Conv, ...)`.
_ModuleDef = Any


class _ResNetBlock(nn.Module):
  """Basic residual block: two 3x3 convolutions added to a shortcut branch.

  Attributes:
    filters: number of features of both 3x3 convolutions (and of the
      projection shortcut, when one is needed).
    conv: factory called as `conv(features, kernel_size[, strides])` that
      returns a convolution module.
    norm: factory that returns a normalisation module.
    act: activation applied after the first norm and to the block output.
    strides: strides of the first convolution and of the projection shortcut.
  """
  filters: int
  conv: _ModuleDef
  norm: _ModuleDef
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x,):
    shortcut = x

    # Main branch: conv (possibly strided) -> norm -> act -> conv -> norm.
    out = self.conv(self.filters, (3, 3), self.strides)(x)
    out = self.norm()(out)
    out = self.act(out)
    out = self.conv(self.filters, (3, 3))(out)
    # The scale of the last norm is initialised to zeros.
    out = self.norm(scale_init=nn.initializers.zeros)(out)

    # When the main branch changed the shape, project the shortcut with a
    # 1x1 convolution (same strides) followed by a norm so the two can be added.
    if shortcut.shape != out.shape:
      projection = self.conv(
          self.filters, (1, 1), self.strides, name='conv_proj')
      shortcut = projection(shortcut)
      shortcut = self.norm(name='norm_proj')(shortcut)

    return self.act(shortcut + out)


class _BottleneckResNetBlock(nn.Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

  Attributes:
    filters: number of features of the first two convolutions; the last
      convolution (and the projection shortcut) uses `4 * filters`.
    conv: factory called as `conv(features, kernel_size[, strides])` that
      returns a convolution module.
    norm: factory that returns a normalisation module.
    act: activation applied after the first two norms and to the block output.
    strides: strides of the 3x3 convolution and of the projection shortcut.
  """
  filters: int
  conv: _ModuleDef
  norm: _ModuleDef
  act: Callable
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    shortcut = x
    expanded_filters = self.filters * 4

    # 1x1 convolution.
    out = self.conv(self.filters, (1, 1))(x)
    out = self.norm()(out)
    out = self.act(out)

    # 3x3 convolution; this is the only strided layer of the main branch.
    out = self.conv(self.filters, (3, 3), self.strides)(out)
    out = self.norm()(out)
    out = self.act(out)

    # 1x1 convolution to 4x the features; its norm scale is initialised to zeros.
    out = self.conv(expanded_filters, (1, 1))(out)
    out = self.norm(scale_init=nn.initializers.zeros)(out)

    # Project the shortcut whenever its shape no longer matches the main branch.
    if shortcut.shape != out.shape:
      projection = self.conv(
          expanded_filters, (1, 1), self.strides, name='conv_proj')
      shortcut = projection(shortcut)
      shortcut = self.norm(name='norm_proj')(shortcut)

    return self.act(shortcut + out)


class _ResNet(nn.Module):
  """ResNet v1 classifier assembled from `block_cls` residual blocks.

  Attributes:
    stage_sizes: number of residual blocks in each stage.
    block_cls: residual block class instantiated for every block.
    num_classes: number of features of the final dense layer.
    num_filters: features of the initial convolution; stage `i` uses
      `num_filters * 2 ** i`.
    dtype: dtype passed to the conv, norm and dense layers, and the dtype the
      output is converted to.
    act: activation handed to the residual blocks.
    conv: convolution module class.
  """
  stage_sizes: Sequence[int]
  block_cls: _ModuleDef
  num_classes: int
  num_filters: int = 64
  dtype: Any = np.float32
  act: Callable = nn.relu
  conv: _ModuleDef = nn.Conv

  @nn.compact
  def __call__(self, x, train: bool = True):
    # Layer factories shared by the stem and by all residual blocks.
    conv_factory = partial(self.conv, use_bias=False, dtype=self.dtype)
    norm_factory = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=self.dtype,
    )

    # Stem: strided 7x7 convolution -> batch norm -> ReLU -> 3x3 max pooling.
    stem_conv = conv_factory(
        self.num_filters, (7, 7), (2, 2),
        padding=[(3, 3), (3, 3)],
        name='conv_init')
    x = stem_conv(x)
    x = norm_factory(name='bn_init')(x)
    # Note: the stem always applies `nn.relu`, independently of `self.act`.
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    # Residual stages: features double at every stage, and the first block of
    # every stage except the first one uses strides (2, 2).
    for stage_index, num_blocks in enumerate(self.stage_sizes):
      stage_filters = self.num_filters * 2 ** stage_index
      for block_index in range(num_blocks):
        is_strided = stage_index > 0 and block_index == 0
        block = self.block_cls(stage_filters,
                               strides=(2, 2) if is_strided else (1, 1),
                               conv=conv_factory,
                               norm=norm_factory,
                               act=self.act)
        x = block(x)

    # Head: average over axes 1 and 2, then a dense layer to `num_classes`.
    x = np.mean(x, axis=(1, 2))
    logits = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    return np.asarray(logits, self.dtype)


# ResNet18: four stages of two basic `_ResNetBlock`s each.
_ResNet18 = partial(
    _ResNet,
    stage_sizes=[2, 2, 2, 2],
    block_cls=_ResNetBlock,
)

In [None]:
def get_ntk_fns(O: int, input_shape: Tuple[int, ...] = (224, 224, 3)):
  """Builds a ResNet18 and jitted empirical-NTK functions for it.

  Args:
    O: number of output logits (`num_classes` of the ResNet18).
    input_shape: shape of a single input, without the batch axis. It is only
      used to build a dummy batch for parameter initialisation, so that this
      function does not depend on a global input array being defined.

  Returns:
    A pair `(params, (jacobian_contraction, ntvp, str_derivatives, auto))`:
    the initialised model variables and four jitted NTK functions, one per
    `implementation` value, each called as `ntk_fn(x1, x2, params)`.
  """
  # Define a ResNet18.
  model = _ResNet18(num_classes=O)

  # f(x, \theta)
  def apply_fn(params, x):
    # With `mutable=[...]`, `apply` returns `(outputs, mutated_variables)`;
    # only the outputs are needed here.
    return model.apply(params, x, train=False, mutable=['batch_stats'])[0]

  # `trace_axes=()` keeps every output axis, which yields the full kernel of
  # shape (N1, N2, O, O). `vmap_axes=0` declares axis 0 as the batch axis
  # (see the `nt.empirical_ntk_fn` documentation).
  kwargs = dict(
      f=apply_fn,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = jit(nt.empirical_ntk_fn(**kwargs, implementation=1))
  ntvp = jit(nt.empirical_ntk_fn(**kwargs, implementation=2))
  str_derivatives = jit(nt.empirical_ntk_fn(**kwargs, implementation=3))
  auto = jit(nt.empirical_ntk_fn(**kwargs, implementation=0))

  # Parameters \theta
  # Initialise from a dummy single-example batch instead of reading a global
  # input: initialisation needs an input of the right shape, not real data.
  dummy_batch = np.zeros((1, *input_shape))
  params = model.init(random.PRNGKey(0), dummy_batch)
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

# $\color{blue}O = 8$ logits, batch size $\color{red}N = 8$

Structured derivatives compute the NTK fastest. NTK-vector products are actually slower in this setting: the forward pass is costly relative to the number of parameters, so the method scales poorly with batch size $\color{red}N$. While it scales better with $\color{blue}O$ than the other methods, this is not enough to overcome the $\color{red}N^2$ forward passes.

In [None]:
# Problem size: O output logits, N images per batch.
O, N = 8, 8

# Two independent batches of random inputs, x1 and x2.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1))
x1, x2 = (random.normal(key, (N,) + input_shape) for key in (k1, k2))

# Model variables and the four NTK implementations for this output size.
(params,
 (ntk_fn_jacobian_contraction,
  ntk_fn_ntvp,
  ntk_fn_str_derivatives,
  ntk_fn_auto)) = get_ntk_fns(O=O)

In [None]:
#@test {"skip": true}
# Jacobian contraction (`implementation=1`).
# The printed kernel shape is (N, N, O, O).
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)

(8, 8, 8, 8)


In [None]:
#@test {"skip": true}
# NTK-vector products (`implementation=2`).
# The printed kernel shape is (N, N, O, O).
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

(8, 8, 8, 8)


In [None]:
# Structured derivatives (`implementation=3`).
# The printed kernel shape is (N, N, O, O).
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(8, 8, 8, 8)


In [None]:
#@test {"skip": true}
# Make sure kernels agree: for every pair of implementations, print the largest
# absolute difference relative to the mean magnitude of the first kernel.
print(*(
    np.max(np.abs(first - second)) / np.mean(np.abs(first))
    for first, second in ((k_1, k_2), (k_1, k_3), (k_2, k_3))
))

3.4810557e-06 3.4810557e-06 3.916188e-06


In [None]:
#@test {"skip": true}
# Automatic selection (`implementation=0`): picks the best method based on
# FLOPs at the first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, this currently only works correctly on TPUs!
# The FLOPs reported for JITted functions on CPU/GPU are wrong.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=3964546560.0
impl=2, flops=20214009856.0
impl=3, flops=4047975424.0
(8, 8, 8, 8)


In [None]:
#@test {"skip": true}
%%timeit
# Baseline timing: Jacobian contraction.
# `block_until_ready()` waits for the result so the whole computation is timed.
ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()

1 loops, best of 5: 224 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# NTK-vector products: slower here, because the forward pass (FP) is expensive
# relative to the number of parameters.
# Time cost scales poorly with batch size N.
ntk_fn_ntvp(x1, x2, params).block_until_ready()

1 loops, best of 5: 329 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# Structured derivatives: about 2X faster than Jacobian contraction
# (103 ms vs. 224 ms in the recorded run).
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

10 loops, best of 5: 103 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# Automatic selection: on TPU this should match the fastest method.
# On GPU/CPU the selection is currently broken, so it may not be the fastest.
ntk_fn_auto(x1, x2, params).block_until_ready()

1 loops, best of 5: 224 ms per loop


# $\color{blue}O = 128$ logits, batch size $\color{red}N = 1$

Both NTK-vector products and Structured derivatives compute NTK faster than Jacobian contraction. NTK-vector products incur no penalty when batch size $\color{red}N = 1$, and leverage their beneficial scaling with large $\color{blue}O = 128$.

In [None]:
#@test {"skip": true}
# Problem size: O output logits, N images per batch.
O, N = 128, 1

# Two independent batches of random inputs, x1 and x2.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1))
x1, x2 = (random.normal(key, (N,) + input_shape) for key in (k1, k2))

# Model variables and the four NTK implementations for this output size.
(params,
 (ntk_fn_jacobian_contraction,
  ntk_fn_ntvp,
  ntk_fn_str_derivatives,
  ntk_fn_auto)) = get_ntk_fns(O=O)

In [None]:
#@test {"skip": true}
# Jacobian contraction (`implementation=1`).
# The printed kernel shape is (N, N, O, O).
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)

(1, 1, 128, 128)


In [None]:
#@test {"skip": true}
# NTK-vector products (`implementation=2`).
# The printed kernel shape is (N, N, O, O).
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

(1, 1, 128, 128)


In [None]:
#@test {"skip": true}
# Structured derivatives (`implementation=3`).
# The printed kernel shape is (N, N, O, O).
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(1, 1, 128, 128)


In [None]:
#@test {"skip": true}
# Make sure kernels agree: for every pair of implementations, print the largest
# absolute difference relative to the mean magnitude of the first kernel.
print(*(
    np.max(np.abs(first - second)) / np.mean(np.abs(first))
    for first, second in ((k_1, k_2), (k_1, k_3), (k_2, k_3))
))

1.637553e-05 1.0234707e-05 1.2281647e-05


In [None]:
#@test {"skip": true}
# Automatic selection (`implementation=0`): picks the best method based on
# FLOPs at the first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, this currently only works correctly on TPUs!
# The FLOPs reported for JITted functions on CPU/GPU are wrong.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=6864192000.0
impl=2, flops=7510226432.0
impl=3, flops=6847879168.0
(1, 1, 128, 128)


In [None]:
#@test {"skip": true}
%%timeit
# Baseline timing: Jacobian contraction.
# `block_until_ready()` waits for the result so the whole computation is timed.
ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()

1 loops, best of 5: 454 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# NTK-vector products: about 3X faster than Jacobian contraction
# (146 ms vs. 454 ms in the recorded run).
ntk_fn_ntvp(x1, x2, params).block_until_ready()

10 loops, best of 5: 146 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# Structured derivatives: about 4X faster than Jacobian contraction
# (112 ms vs. 454 ms in the recorded run).
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

10 loops, best of 5: 112 ms per loop


In [None]:
#@test {"skip": true}
%%timeit
# Automatic selection: on TPU this should match the fastest method.
# On GPU/CPU the selection is currently broken, so it may not be the fastest.
ntk_fn_auto(x1, x2, params).block_until_ready()

10 loops, best of 5: 112 ms per loop


# $\color{blue}O = 1000$ logits, batch size $\color{red}N = 1$, full NTK

Structured derivatives make it possible to compute the full $1000\times 1000$ ImageNet NTK. The other methods run out of memory.

In [None]:
#@test {"skip": true}
# Problem size: O output logits, N images per batch.
O, N = 1000, 1

# Two independent batches of random inputs, x1 and x2.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1))
x1, x2 = (random.normal(key, (N,) + input_shape) for key in (k1, k2))

# Model variables and the four NTK implementations for this output size.
(params,
 (ntk_fn_jacobian_contraction,
  ntk_fn_ntvp,
  ntk_fn_str_derivatives,
  ntk_fn_auto)) = get_ntk_fns(O=O)

In [None]:
#@test {"skip": true}
# Structured derivatives (`implementation=3`) - fits in memory!
# The printed kernel shape is (N, N, O, O).
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(1, 1, 1000, 1000)


In [None]:
#@test {"skip": true}
%%timeit
# Timing of the full O x O kernel with structured derivatives.
# `block_until_ready()` waits for the result so the whole computation is timed.
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

1 loops, best of 5: 975 ms per loop


In [None]:
#@test {"skip": true}
# NTK-vector products - OOM!
# The result is bound to `k_2`, the name used for NTK-vector products in the
# earlier sections, so a successful run can never overwrite the
# structured-derivatives kernel stored in `k_3`.
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

XlaRuntimeError: ignored

In [None]:
#@test {"skip": true}
# Jacobian contraction (`implementation=1`) - OOM!
# In the recorded run this call raises an XlaRuntimeError.
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)


XlaRuntimeError: ignored