<a href="https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Example of computing NTK of a ResNet18 on ImageNet inputs

Tested on NVIDIA V100

# Imports and setup

In [None]:
# List the GPUs visible to this runtime (the recorded outputs come from a Tesla V100).
!nvidia-smi -L

GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-77f94e99-2605-bb10-cd34-6c366ced61f9)


In [None]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
# Upgrade pip first, then install JAX with the `cuda11_cudnn805` extra, resolving
# wheels from the JAX release index passed via `-f`.
!pip install --upgrade pip
!pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html

Collecting pip
  Downloading pip-21.3.1-py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 11.9 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-21.3.1
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jaxlib==0.1.73+cuda11.cudnn805
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.73%2Bcuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl (138.5 MB)
     |████████████████████████████████| 138.5 MB 53 kB/s             
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.71+cuda111
    Uninstalling jaxlib-0.1.71+cuda111:
      Successfully uninstalled jaxlib-0.1.71+cuda111
Successfully installed jaxlib-0.1.73+cuda11.cudnn805


In [None]:
# Install the `fast_finite_width_ntk` package directly from its GitHub repository.
!pip install git+https://github.com/icml2022anon/fast_finite_width_ntk.git

Collecting git+https://github.com/icml2022anon/fast_finite_width_ntk.git
  Cloning https://github.com/icml2022anon/fast_finite_width_ntk.git to /tmp/pip-req-build-zxe9pbew
  Running command git clone --filter=blob:none -q https://github.com/icml2022anon/fast_finite_width_ntk.git /tmp/pip-req-build-zxe9pbew
  Resolved https://github.com/icml2022anon/fast_finite_width_ntk.git to commit ee2db7af6795e6f083fa65e8bce67afa3dd0f0ad
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fast-finite-width-ntk
  Building wheel for fast-finite-width-ntk (setup.py) ... [?25l[?25hdone
  Created wheel for fast-finite-width-ntk: filename=fast_finite_width_ntk-0.0.1-py3-none-any.whl size=28541 sha256=952d159295ff8897bdba47de428ec04db7603b1df85dbeabb42772b938d81fc1
  Stored in directory: /tmp/pip-ephem-wheel-cache-8g6yuilw/wheels/f4/cc/bf/251881ca4cc5881e20ddbe06962d314ee2b8aa7c456927027f
Successfully built fast-finite-width-ntk
Installing collected packages: fast

# FLAX Setup

In [None]:
# Install clu, ml-collections & latest Flax version from Github.
# NOTE(review): the same install command and the assignments below are repeated
# in the cells that follow this one.
!pip install -q clu ml-collections git+https://github.com/google/flax

# Path of the ImageNet example inside the Flax repository.
example_directory = 'examples/imagenet'
# Example files (relative to `example_directory`) that get a disclaimer header
# and are opened in the Colab editor by the fetch cell.
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')

# Repository URL and branch passed to `git clone` by the fetch cell.
repo, branch = 'https://github.com/google/flax', 'main'

  Preparing metadata (setup.py) ... [?25l[?25hdone
     |████████████████████████████████| 77 kB 3.4 MB/s             
     |████████████████████████████████| 77 kB 6.4 MB/s             
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
     |████████████████████████████████| 126 kB 15.4 MB/s            
     |████████████████████████████████| 65 kB 3.2 MB/s             
[?25h  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Building wheel for flax (setup.py) ... [?25l[?25hdone


## Setup

In [None]:
# Install clu, ml-collections & latest Flax version from Github.
!pip install -q clu ml-collections git+https://github.com/google/flax

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
# Flax repository URL and the branch to fetch.
repo = 'https://github.com/google/flax'
branch = 'main'

# Path of the ImageNet example inside the Flax repository.
example_directory = 'examples/imagenet'

# Example files, relative to `example_directory`, that are opened for editing.
editor_relpaths = (
    'configs/default.py',
    'input_pipeline.py',
    'models.py',
    'train.py',
)

In [None]:
# (If you run this code in Jupyter[lab], then you're already in the
#  example directory and nothing needs to be done.)

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted an any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(
        f'<h1 style="color:red;" class="blink">{DISCLAIMER}</h1>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')

Cloning into 'flaxrepo'...
remote: Enumerating objects: 343, done.[K
remote: Counting objects: 100% (343/343), done.[K
remote: Compressing objects: 100% (313/313), done.[K
remote: Total 343 (delta 53), reused 118 (delta 17), pack-reused 0[K
Receiving objects: 100% (343/343), 2.10 MiB | 12.82 MiB/s, done.
Resolving deltas: 100% (53/53), done.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Note: In Colab, the cell above changed the working directory.
!pwd

/content/examples/imagenet


## Imports / Helpers

In [None]:
# TPU setup : Boilerplate for connecting JAX to TPU.

import os

# Guard first: unless this is a Colab runtime that exposes a TPU address,
# only the hint below is printed.
if 'google.colab' not in str(get_ipython()) or 'COLAB_TPU_ADDR' not in os.environ:
  print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
else:
  # Make sure the Colab Runtime is set to Accelerator: TPU.
  import requests
  # The driver-version request is sent only once per kernel session;
  # `TPU_DRIVER_MODE` in the notebook globals marks that it has been done.
  if 'TPU_DRIVER_MODE' not in globals():
    url = ''.join((
        'http://',
        os.environ['COLAB_TPU_ADDR'].partition(':')[0],
        ':8475/requestversion/tpu_driver0.1-dev20191206',
    ))
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  from jax.config import config
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
  print('Registered TPU:', config.FLAGS.jax_backend_target)

No TPU detected. Can be changed under "Runtime/Change runtime type".


In [None]:
import json
from absl import logging
import flax
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

logging.set_verbosity(logging.INFO)

# assert len(jax.devices()) == 8, f'Expected 8 TPU cores : {jax.devices()}'

In [None]:
# Helper functions for images.

def show_img(img, ax=None, title=None):
  """Draws one image on a Matplotlib axes with the axis ticks hidden.

  Before drawing, the image is scaled by `input_pipeline.STDDEV_RGB` and
  shifted by `input_pipeline.MEAN_RGB` (both shaped [1, 1, 3]), cast to int
  and clipped to [0, 255] -- presumably undoing the input pipeline's
  normalisation (TODO confirm).

  Args:
    img: Image to draw; must provide `.dtype` and `.numpy()` (presumably a
      channels-last `tf.Tensor` -- TODO confirm).
    ax: Axes to draw on. The current axes (`plt.gca()`) are used when None.
    title: Optional axes title; nothing is set when it is falsy.
  """
  target = plt.gca() if ax is None else ax
  per_channel = [1, 1, 3]
  img *= tf.constant(input_pipeline.STDDEV_RGB, shape=per_channel, dtype=img.dtype)
  img += tf.constant(input_pipeline.MEAN_RGB, shape=per_channel, dtype=img.dtype)
  pixels = np.clip(img.numpy().astype(int), 0, 255)
  target.imshow(pixels)
  for hide_ticks in (target.set_xticks, target.set_yticks):
    hide_ticks([])
  if title:
    target.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a grid of images.

  The images are laid out row by row on an n x n grid, where n is the
  smallest integer with n * n >= len(imgs). Each image is drawn with
  `show_img`.

  Args:
    imgs: Sequence of images accepted by `show_img`.
    titles: Titles paired with `imgs` via `zip`; surplus entries on either
      side are ignored.
  """
  n = int(np.ceil(len(imgs)**.5))
  if n == 0:
    # Nothing to draw; a 0 x 0 subplot grid cannot be created.
    return
  # `squeeze=False` keeps `axs` two-dimensional even for a 1 x 1 grid, so the
  # `axs[row][col]` indexing below also works for a single image.
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
  for i, (img, title) in enumerate(zip(imgs, titles)):
    show_img(img, axs[i // n][i % n], title)

In [None]:
# Local imports from current directory - auto reload.
# Any changes you make to train.py will appear automatically.
%load_ext autoreload
%autoreload 2
import input_pipeline
import models
import train
from configs import default as config_lib

In [None]:
from jax import jit
# NOTE: this rebinds `np` -- bound to `numpy` by the earlier imports cell -- to
# `jax.numpy` for every cell executed after this one.
from jax import numpy as np
from jax import random

from fast_finite_width_ntk import empirical

In [None]:
def get_ntk_fns(O: int, init_inputs=None):
  """Builds a ResNet18 and jitted empirical-NTK functions for it.

  Args:
    O: Number of output logits (`num_classes` of the ResNet18).
    init_inputs: Example input batch passed to `model.init` to create the
      parameters. When None, the notebook-global `x1` is used, which is the
      original behaviour and requires `x1` to be defined before the call.

  Returns:
    A pair `(params, ntk_fns)`. `params` is the result of `model.init`, and
    `ntk_fns` is the tuple `(jacobian_contraction, ntvp, str_derivatives,
    auto)` of jitted functions built by `empirical.empirical_ntk_fn` with
    `implementation` 1, 2, 3 and 0 respectively.
  """
  # Define a ResNet18.
  model = models.ResNet18(num_classes=O)

  # f(x, \theta)
  def apply_fn(params, x):
    # `apply` is called with `mutable=['batch_stats']`; only element [0] of its
    # result is returned (presumably the model output -- see the Flax docs).
    return model.apply(params, x, train=False, mutable=['batch_stats'])[0]

  kwargs = dict(
      f=apply_fn,
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = jit(empirical.empirical_ntk_fn(**kwargs, implementation=1))
  ntvp = jit(empirical.empirical_ntk_fn(**kwargs, implementation=2))
  str_derivatives = jit(empirical.empirical_ntk_fn(**kwargs, implementation=3))
  auto = jit(empirical.empirical_ntk_fn(**kwargs, implementation=0))

  # Parameters \theta
  if init_inputs is None:
    # Backward compatible fallback: existing callers rely on the global `x1`.
    init_inputs = x1
  params = model.init(random.PRNGKey(0), init_inputs)
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)

# $\color{blue}O = 8$ logits, batch size $\color{red}N = 8$

Structured derivatives compute the NTK fastest. NTK-vector products are actually slower in this setting: the forward pass is costly relative to the number of parameters, so this method scales poorly with the batch size $\color{red}N$. While it scales better with $\color{blue}O$ than the other methods, that is not enough to overcome the $\color{red}N^2$ forward passes.

In [None]:
# Number of logits (O) and batch size (N) for this experiment.
O, N = 8, 8

# Input images x: two random normal batches drawn from independent PRNG keys.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1), 2)
x1 = random.normal(k1, (N,) + input_shape)
x2 = random.normal(k2, (N,) + input_shape)

# `x1` has to exist before this call: `get_ntk_fns` reads it to initialise the
# parameters.
params, (
    ntk_fn_jacobian_contraction,
    ntk_fn_ntvp,
    ntk_fn_str_derivatives,
    ntk_fn_auto,
) = get_ntk_fns(O=O)

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


In [None]:
# Jacobian contraction (`implementation=1` in `get_ntk_fns`).
# The result is a 4-D kernel between the batches `x1` and `x2`.
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)

(8, 8, 8, 8)


In [None]:
# NTK-vector products (`implementation=2` in `get_ntk_fns`).
# Same inputs as the Jacobian contraction cell, so the kernels can be compared.
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

(8, 8, 8, 8)


In [None]:
# Structured derivatives (`implementation=3` in `get_ntk_fns`).
# Same inputs as the two cells before, so the kernels can be compared.
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(8, 8, 8, 8)


In [None]:
# Make sure kernels agree: for every pair of kernels, print the largest
# absolute difference divided by the mean absolute value of the first kernel
# of the pair.
print(*(
    np.max(np.abs(lhs - rhs)) / np.mean(np.abs(lhs))
    for lhs, rhs in ((k_1, k_2), (k_1, k_3), (k_2, k_3))
))

2.79299e-06 3.351588e-06 3.9101856e-06


In [None]:
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
# In the recorded GPU run, the output of this cell reports the fewest FLOPs
# for impl=1.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=3964546560.0
impl=2, flops=14645684224.0
impl=3, flops=4283323904.0
(8, 8, 8, 8)


In [None]:
%%timeit
# Baseline timing. `block_until_ready()` makes the timed statement wait for the
# computed kernel instead of returning as soon as the work is dispatched.
ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()

1 loop, best of 5: 222 ms per loop


In [None]:
%%timeit
# Slower - forward pass (FP) is expensive relative to parameters.
# Time cost scales poorly with batch size N.
# Recorded run: 312 ms, versus 222 ms for Jacobian contraction.
ntk_fn_ntvp(x1, x2, params).block_until_ready()  

1 loop, best of 5: 312 ms per loop


In [None]:
%%timeit
# ~2.5X faster than Jacobian contraction in the recorded run (90.1 ms vs 222 ms).
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

10 loops, best of 5: 90.1 ms per loop


In [None]:
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU, currently is broken, and may not be the fastest.
# Recorded GPU run: 222 ms, the same as Jacobian contraction rather than the
# fastest method (structured derivatives, 90.1 ms).
ntk_fn_auto(x1, x2, params).block_until_ready()

1 loop, best of 5: 222 ms per loop


# $\color{blue}O = 128$ logits, batch size $\color{red}N = 1$

Both NTK-vector products and Structured derivatives compute NTK faster than Jacobian contraction. NTK-vector products incur no penalty when batch size $\color{red}N = 1$, and leverage their beneficial scaling with large $\color{blue}O = 128$.

In [None]:
# Number of logits (O) and batch size (N) for this experiment.
O, N = 128, 1

# Input images x: two random normal batches drawn from independent PRNG keys.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1), 2)
x1 = random.normal(k1, (N,) + input_shape)
x2 = random.normal(k2, (N,) + input_shape)

# `x1` has to exist before this call: `get_ntk_fns` reads it to initialise the
# parameters.
params, (
    ntk_fn_jacobian_contraction,
    ntk_fn_ntvp,
    ntk_fn_str_derivatives,
    ntk_fn_auto,
) = get_ntk_fns(O=O)

In [None]:
# Jacobian contraction (`implementation=1` in `get_ntk_fns`).
# Recorded shape: (1, 1, 128, 128), i.e. two axes of size N, then two of size O.
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)

(1, 1, 128, 128)


In [None]:
# NTK-vector products (`implementation=2` in `get_ntk_fns`).
# Same inputs as the Jacobian contraction cell, so the kernels can be compared.
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

(1, 1, 128, 128)


In [None]:
# Structured derivatives (`implementation=3` in `get_ntk_fns`).
# Same inputs as the two cells before, so the kernels can be compared.
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(1, 1, 128, 128)


In [None]:
# Make sure kernels agree: for every pair of kernels, print the largest
# absolute difference divided by the mean absolute value of the first kernel
# of the pair.
print(*(
    np.max(np.abs(lhs - rhs)) / np.mean(np.abs(lhs))
    for lhs, rhs in ((k_1, k_2), (k_1, k_3), (k_2, k_3))
))

1.161448e-05 4.645792e-06 1.3937378e-05


In [None]:
# Selects best method based on FLOPs at first call / compilation.
# Takes about 3x more time to compile.
# WARNING: due to an XLA issue, currently only works correctly on TPUs!
# Wrong FLOPs for CPU/GPU of JITted functions.
# In the recorded GPU run, the output of this cell reports the fewest FLOPs
# for impl=2.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=6864798208.0
impl=2, flops=6120569344.0
impl=3, flops=6878602240.0
(1, 1, 128, 128)


In [None]:
%%timeit
# Baseline timing. `block_until_ready()` makes the timed statement wait for the
# computed kernel instead of returning as soon as the work is dispatched.
ntk_fn_jacobian_contraction(x1, x2, params).block_until_ready()

1 loop, best of 5: 453 ms per loop


In [None]:
%%timeit
# 3X faster!
# (Than Jacobian contraction in the recorded run: 151 ms vs 453 ms.)
ntk_fn_ntvp(x1, x2, params).block_until_ready()  

10 loops, best of 5: 151 ms per loop


In [None]:
%%timeit
# 4X faster!
# (Than Jacobian contraction in the recorded run: 113 ms vs 453 ms.)
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

10 loops, best of 5: 113 ms per loop


In [None]:
%%timeit 
# On TPU should match the fastest method.
# On GPU/CPU, currently is broken, and may not be the fastest.
# Recorded GPU run: 151 ms, the same as NTK-vector products; structured
# derivatives took 113 ms.
ntk_fn_auto(x1, x2, params).block_until_ready()

10 loops, best of 5: 151 ms per loop


# $\color{blue}O = 1000$ logits, batch size $\color{red}N = 1$, full NTK

Structured derivatives make it possible to compute the full $1000\times 1000$ ImageNet NTK. The other methods run out of memory.

In [None]:
# Number of logits (O) and batch size (N) for this experiment.
O, N = 1000, 1

# Input images x: two random normal batches drawn from independent PRNG keys.
input_shape = (224, 224, 3)
k1, k2 = random.split(random.PRNGKey(1), 2)
x1 = random.normal(k1, (N,) + input_shape)
x2 = random.normal(k2, (N,) + input_shape)

# `x1` has to exist before this call: `get_ntk_fns` reads it to initialise the
# parameters.
params, (
    ntk_fn_jacobian_contraction,
    ntk_fn_ntvp,
    ntk_fn_str_derivatives,
    ntk_fn_auto,
) = get_ntk_fns(O=O)

In [None]:
# Structured derivatives - fits in memory!
# Recorded shape: (1, 1, 1000, 1000).
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

(1, 1, 1000, 1000)


In [None]:
%%timeit
# Timing of the full O = 1000 kernel; `block_until_ready()` makes the timed
# statement wait for the computed result.
ntk_fn_str_derivatives(x1, x2, params).block_until_ready()

1 loop, best of 5: 983 ms per loop


In [None]:
# NTK-vector products - OOM!
# The failure is the point of this cell, so the error is caught and displayed;
# this way "Run all" continues with the cells that follow. The result is bound
# to `k_2` (the NTK-vector-products name used in the other sections) so that a
# successful run cannot overwrite the structured-derivatives kernel `k_3`.
try:
  k_2 = ntk_fn_ntvp(x1, x2, params)
  print(k_2.shape)
except RuntimeError as err:
  # Only the start of the message is shown; the full text can be very long.
  print(f'NTK-vector products failed with {type(err).__name__}: {str(err)[:200]}')

RuntimeError: ignored

In [None]:
# Jacobian contraction - OOM!
# The failure is the point of this cell, so the error is caught and displayed
# instead of aborting a "Run all" execution.
try:
  k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
  print(k_1.shape)
except RuntimeError as err:
  # Only the start of the message is shown; the full text can be very long.
  print(f'Jacobian contraction failed with {type(err).__name__}: {str(err)[:200]}')

RuntimeError: ignored