<a href="https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Example of computing the NTK of a **PyTorch** FCN

Tested on NVIDIA A100

More examples:

*   [JAX (Flax)](https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example.ipynb)
*   [TensorFlow (Keras)](https://colab.research.google.com/github/icml2022anon/fast_finite_width_ntk/blob/main/example_tf.ipynb)

# Imports and setup

In [1]:
!nvidia-smi -L

GPU 0: A100-SXM4-40GB (UUID: GPU-f3b01e18-1bac-8bc5-09d7-b3075ae5c01d)


In [2]:
# We need at least jaxlib-0.1.73 to avoid certain CUDA bugs when using `implementation=auto`
!pip install --upgrade pip
!pip install --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html

In [3]:
!pip install git+https://github.com/deepmind/tf2jax.git --no-deps
!pip install git+https://github.com/icml2022anon/fast_finite_width_ntk.git --upgrade

Collecting git+https://github.com/deepmind/tf2jax.git
  Cloning https://github.com/deepmind/tf2jax.git to /tmp/pip-req-build-vn3jkvr9
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/tf2jax.git /tmp/pip-req-build-vn3jkvr9
  Resolved https://github.com/deepmind/tf2jax.git to commit b3e80f7e9ac18d7d495b995b2e63c182aee6b236
  Preparing metadata (setup.py) ... done
Collecting git+https://github.com/icml2022anon/fast_finite_width_ntk.git
  Cloning https://github.com/icml2022anon/fast_finite_width_ntk.git to /tmp/pip-req-build-1pq9dhf3
  Running command git clone --filter=blob:none --quiet https://github.com/icml2022anon/fast_finite_width_ntk.git /tmp/pip-req-build-1pq9dhf3
  Resolved https://github.com/icml2022anon/fast_finite_width_ntk.git to commit 23b5988ed95f58c55c9a9b7cb7e707b9e0f018c1
  Preparing metadata (setup.py) ... done

In [4]:
import torch
from fast_finite_width_ntk import empirical_ntk_fn_pytorch

In [5]:
input_size = 1024

# PyTorch model definition

In [6]:
def get_model(O: int):
  # TODO: match ONNX and PyTorch tree structures and tensor layouts for CNNs.
  return torch.nn.Sequential(
      torch.nn.Linear(input_size, 2048),
      torch.nn.ReLU(),
      torch.nn.Linear(2048, 2048),
      torch.nn.ReLU(),
      torch.nn.Linear(2048, 2048),
      torch.nn.ReLU(),
      torch.nn.Linear(2048, O),
  )

# NTK functions declaration

In [7]:
def get_ntk_fns(O: int):
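  # Define a PyTorch model with `O` output logits.
  f = get_model(O)
  f.forward(torch.rand((1, input_size)))
  # Transpose 2-D weights from PyTorch's (out, in) layout to (in, out);
  # biases (1-D) are passed through unchanged.
  params = [p.T if p.ndim == 2 else p for p in f.parameters()]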

  kwargs = dict(
      f=f,
      input_shape=(input_size,),
      trace_axes=(),
      vmap_axes=0
  )

  # Different NTK implementations
  jacobian_contraction = empirical_ntk_fn_pytorch(**kwargs, implementation=1)
  ntvp = empirical_ntk_fn_pytorch(**kwargs, implementation=2)
  str_derivatives = empirical_ntk_fn_pytorch(**kwargs, implementation=3)
  auto = empirical_ntk_fn_pytorch(**kwargs, implementation=0)
  
  return params, (jacobian_contraction, ntvp, str_derivatives, auto)
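
The integer codes passed as `implementation` above can be hard to read at a glance. As a minimal sketch, the mapping implied by the variable names in `get_ntk_fns` could be written as an enum. The class `NtkImplementation` below is a hypothetical helper introduced here for illustration; it is not claimed to be part of `fast_finite_width_ntk`.

```python
from enum import IntEnum


class NtkImplementation(IntEnum):
  """Hypothetical names for the integer codes used in `get_ntk_fns`."""
  AUTO = 0
  JACOBIAN_CONTRACTION = 1
  NTK_VECTOR_PRODUCTS = 2
  STRUCTURED_DERIVATIVES = 3


# An IntEnum member compares equal to its integer value.
for impl in NtkImplementation:
  print(f'implementation={int(impl)}: {impl.name}')
```

Whether the library accepts such an enum member directly in place of the plain integer is an assumption that should be checked before relying on it.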

# $\color{blue}O = 10$ logits, batch size $\color{red}N = 10$, NTK

Structured derivatives allow the NTK to be computed much faster than the other methods (about 50x faster than Jacobian contraction in the timings below).
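
For reference, the quantity being computed is the contraction of two parameter Jacobians, $\Theta(x_1, x_2) = J(x_1) J(x_2)^T$. The sketch below computes it naively in plain PyTorch with `torch.autograd.grad`, one logit at a time. It is an illustration only, not the library's implementation: the function name `naive_ntk`, the tiny model, and the `(N1, N2, O, O)` output layout are assumptions made here.

```python
import torch


def naive_ntk(f, x1, x2):
  """Jacobian-contraction NTK of shape (N1, N2, O, O), for illustration only."""
  params = [p for p in f.parameters() if p.requires_grad]

  def jacobian(x):
    # Returns an (N, O, P) tensor: one flattened parameter gradient per logit.
    out = f(x)  # (N, O)
    rows = []
    for n in range(out.shape[0]):
      for o in range(out.shape[1]):
        grads = torch.autograd.grad(out[n, o], params, retain_graph=True)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows).reshape(out.shape[0], out.shape[1], -1)

  j1, j2 = jacobian(x1), jacobian(x2)
  # Contract over the parameter axis P.
  return torch.einsum('nap,mbp->nmab', j1, j2)


torch.manual_seed(0)
# A deliberately tiny model (hypothetical sizes) so the loop stays cheap.
tiny = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3))
k = naive_ntk(tiny, torch.rand(5, 4), torch.rand(2, 4))
print(k.shape)  # torch.Size([5, 2, 3, 3])
```

This loop needs one backward pass per input and per logit, which is why it is only practical for very small models.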

In [8]:
O = 10
N = 10

# Input images x
x1 = torch.rand((N, input_size))
x2 = torch.rand((N, input_size))

params, (ntk_fn_jacobian_contraction, ntk_fn_ntvp, ntk_fn_str_derivatives, ntk_fn_auto) = get_ntk_fns(O=O)

INFO:onnx2keras:Converter is called.
DEBUG:onnx2keras:List input shapes:
DEBUG:onnx2keras:None
DEBUG:onnx2keras:List inputs:
DEBUG:onnx2keras:Input 0 -> input.
DEBUG:onnx2keras:List outputs:
DEBUG:onnx2keras:Output 0 -> output.
DEBUG:onnx2keras:Gathering weights to dictionary.
DEBUG:onnx2keras:Found weight 0.weight with shape (2048, 1024).
DEBUG:onnx2keras:Found weight 0.bias with shape (2048,).
DEBUG:onnx2keras:Found weight 2.weight with shape (2048, 2048).
DEBUG:onnx2keras:Found weight 2.bias with shape (2048,).
DEBUG:onnx2keras:Found weight 4.weight with shape (2048, 2048).
DEBUG:onnx2keras:Found weight 4.bias with shape (2048,).
DEBUG:onnx2keras:Found weight 6.weight with shape (10, 2048).
DEBUG:onnx2keras:Found weight 6.bias with shape (10,).
DEBUG:onnx2keras:Found input input with shape [1024]
DEBUG:onnx2keras:######
DEBUG:onnx2keras:...
DEBUG:onnx2keras:Converting ONNX operation
DEBUG:onnx2keras:type: Gemm
DEBUG:onnx2keras:node_name: 9
DEBUG:onnx2keras:node_params: {'alpha': 1.0

In [9]:
# Jacobian contraction
k_1 = ntk_fn_jacobian_contraction(x1, x2, params)
print(k_1.shape)

DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
DEBUG:absl:Initializing backend 'gpu'
DEBUG:absl:Backend 'gpu' initialized
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
DEBUG:absl:Finished tracing + transforming prim_fun for jit in 0.0005383491516113281 sec
DEBUG:absl:Compiling prim_fun (140285108449760 for args (ShapedArray(float32[10,1024]),).
DEBUG:absl:Finished XLA compilation of broadcast_in_dim in 0.19342517852783203 sec
DEBUG:absl:Finished tracing + transforming matmul for jit in 0.0016980171203613281 sec
DEBUG:absl:Compiling matmul (140287053802864 for args (ShapedArray(float32[10,1,1024]), ShapedArray(float32[1024,2048]))

torch.Size([10, 10, 10, 10])


In [10]:
# NTK-vector products
k_2 = ntk_fn_ntvp(x1, x2, params)
print(k_2.shape)

DEBUG:absl:Finished tracing + transforming matmul for jit in 0.0014281272888183594 sec
DEBUG:absl:Compiling matmul (140284338827904 for args (ShapedArray(float32[10,1,1024]), ShapedArray(float32[1024,2048])).
DEBUG:absl:Finished XLA compilation of vmap(jvp(matmul)) in 0.010274410247802734 sec
DEBUG:absl:Finished tracing + transforming fn for jit in 0.0020987987518310547 sec
DEBUG:absl:Compiling fn (140284338321760 for args (ShapedArray(float32[10,1,2048]), ShapedArray(float32[2048])).
DEBUG:absl:Finished XLA compilation of vmap(jvp(fn)) in 0.015392303466796875 sec
DEBUG:absl:Finished tracing + transforming matmul for jit in 0.0018317699432373047 sec
DEBUG:absl:Compiling matmul (140284338807424 for args (ShapedArray(float32[10,1,2048]), ShapedArray(float32[2048,2048])).
DEBUG:absl:Finished XLA compilation of vmap(jvp(matmul)) in 0.010088682174682617 sec
DEBUG:absl:Finished tracing + transforming matmul for jit in 0.0031430721282958984 sec
DEBUG:absl:Compiling matmul (140284338808624 for

torch.Size([10, 10, 10, 10])


In [11]:
# Structured derivatives
k_3 = ntk_fn_str_derivatives(x1, x2, params)
print(k_3.shape)

DEBUG:absl:Finished tracing + transforming prim_fun for jit in 0.00038623809814453125 sec
DEBUG:absl:Compiling prim_fun (140284338062192 for args (ShapedArray(float32[10,1,1024]), ShapedArray(float32[1024,2048])).
DEBUG:absl:Finished XLA compilation of dot_general in 0.016028881072998047 sec
DEBUG:absl:Finished tracing + transforming prim_fun for jit in 0.0003523826599121094 sec
DEBUG:absl:Compiling prim_fun (140284338779152 for args (ShapedArray(float32[2048]),).
DEBUG:absl:Finished XLA compilation of broadcast_in_dim in 0.007985830307006836 sec
DEBUG:absl:Finished tracing + transforming prim_fun for jit in 0.0005848407745361328 sec
DEBUG:absl:Compiling prim_fun (140284338373488 for args (ShapedArray(float32[1,2048]),).
DEBUG:absl:Finished XLA compilation of broadcast_in_dim in 0.00751805305480957 sec
DEBUG:absl:Finished tracing + transforming prim_fun for jit in 0.0005481243133544922 sec
DEBUG:absl:Compiling prim_fun (140284338486320 for args (ShapedArray(float32[10,1,2048]), ShapedA

torch.Size([10, 10, 10, 10])


In [14]:
# Make sure kernels agree.
print(
    torch.max(torch.abs(k_1 - k_2)) / torch.mean(torch.abs(k_1)), 
    torch.max(torch.abs(k_1 - k_3)) / torch.mean(torch.abs(k_1)),
    torch.max(torch.abs(k_2 - k_3)) / torch.mean(torch.abs(k_2))
)

tensor(2.0115e-06) tensor(3.0172e-06) tensor(2.0115e-06)
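
The agreement check above repeats the same expression three times. As a minimal sketch, it can be wrapped in a helper; the name `max_relative_error` is hypothetical, and the toy tensors below merely stand in for the kernels `k_1`, `k_2`, `k_3`.

```python
import torch


def max_relative_error(a: torch.Tensor, b: torch.Tensor) -> float:
  """Largest absolute difference, normalized by the mean magnitude of `a`."""
  return (torch.max(torch.abs(a - b)) / torch.mean(torch.abs(a))).item()


# Toy tensors standing in for two kernels; real usage would pass k_1, k_2, etc.
a = torch.ones(2, 2)
b = a + 1e-6
assert max_relative_error(a, b) < 1e-5
```

The `1e-5` threshold is an assumption chosen for float32 inputs; the values printed above are on the order of `1e-6`.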


In [15]:
# Selects the best implementation based on FLOPs at the first call / compilation.
# Takes about 3x longer to compile.
# WARNING: due to an XLA issue, this currently only works correctly on TPUs:
# the FLOP counts of JITted functions are wrong on CPU/GPU.
k_0 = ntk_fn_auto(x1, x2, params)
print(k_0.shape)

impl=1, flops=2103388288.0
impl=2, flops=1068667008.0
impl=3, flops=2117817.0
torch.Size([10, 10, 10, 10])
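
To put the reported FLOP estimates in perspective, the short sketch below computes each implementation's cost relative to the cheapest one. The numbers are copied from the output above; given the warning about FLOP counts on CPU/GPU, the ratios should be read as indicative only.

```python
# FLOP estimates copied from the output above.
flops = {1: 2103388288.0, 2: 1068667008.0, 3: 2117817.0}

# Implementation with the smallest estimated cost.
best = min(flops, key=flops.get)
for impl, value in sorted(flops.items()):
  print(f'impl={impl}: {value / flops[best]:.1f}x the FLOPs of impl={best}')
```

By these estimates, Jacobian contraction and NTK-vector products need on the order of 1000x and 500x the FLOPs of structured derivatives, respectively.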


In [19]:
%%timeit
# Slow
ntk_fn_jacobian_contraction(x1, x2, params)

1 loop, best of 5: 14.1 s per loop


In [20]:
%%timeit
# About 2x faster than Jacobian contraction.
ntk_fn_ntvp(x1, x2, params)

1 loop, best of 5: 6.84 s per loop


In [21]:
%%timeit
# About 50x faster than Jacobian contraction.
ntk_fn_str_derivatives(x1, x2, params)

1 loop, best of 5: 266 ms per loop


In [22]:
%%timeit 
# On TPU, this should match the fastest method.
# On GPU/CPU, the selection is currently broken and may not pick the fastest method.
ntk_fn_auto(x1, x2, params)

1 loop, best of 5: 265 ms per loop
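
The speedups quoted in the cell comments can be recomputed from the measured times. The sketch below uses the best-of-5 timings reported above, with Jacobian contraction as the baseline; the dictionary keys are labels chosen here for readability.

```python
# Best-of-5 wall-clock times in seconds, copied from the %%timeit cells above.
timings = {
    'jacobian_contraction': 14.1,
    'ntvp': 6.84,
    'str_derivatives': 0.266,
    'auto': 0.265,
}

baseline = timings['jacobian_contraction']
speedups = {name: baseline / t for name, t in timings.items()}
for name, s in speedups.items():
  print(f'{name}: {s:.1f}x')
```

The measured wall-clock speedup of structured derivatives (about 50x) is much smaller than the ratio of estimated FLOPs, so FLOP counts alone do not predict runtime on this hardware.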
