<a href="https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/elementwise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Examples of automatic NNGP/NTK computation for nonlinearities using `stax.Elementwise` and `stax.ElementwiseNumerical`.

For details, please see "[Fast Neural Kernel Embeddings for General Activations](https://arxiv.org/abs/2209.04121)".

# Imports and setup

In [None]:
!pip install -q --upgrade pip
!pip install -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents



In [None]:
import jax
from jax import numpy as np, random
from neural_tangents import stax

In [None]:
key1, key2, key_init = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (3, 2))
x2 = random.normal(key2, (4, 2))

# 1. Using `Elementwise` to automatically derive the NTK in closed form from the NNGP.

`stax.Elementwise` takes the closed-form NNGP function of an elementwise nonlinearity and, under the hood, derives the corresponding NTK function from it using autodiff.

In [None]:
# Reference kernel: `stax.Sin` uses the hand-derived NTK expression for the sine nonlinearity.
_, _, kernel_fn_manual = stax.serial(stax.Dense(1),
                                     stax.Sin())

# Closed-form NNGP function for the sine nonlinearity, i.e. E[sin(u) sin(v)]
# for zero-mean Gaussian (u, v) with variances `var1`, `var2` and covariance `cov12`:
def nngp_fn(cov12, var1, var2):
  sum_ = var1 + var2
  s1 = np.exp(-0.5 * sum_ + cov12)
  s2 = np.exp(-0.5 * sum_ - cov12)
  return (s1 - s2) / 2

# Let `Elementwise` derive the NTK function in closed form automatically.
_, _, kernel_fn = stax.serial(stax.Dense(1),
                              stax.Elementwise(nngp_fn=nngp_fn))


k_auto = kernel_fn(x1, x2, 'ntk')
k_manual = kernel_fn_manual(x1, x2, 'ntk')

# The two kernels match!
print(np.max(np.abs(k_manual - k_auto)))

0.0
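
The autodiff step can be illustrated in isolation. For zero-mean Gaussian pre-activations, differentiating the NNGP function with respect to the cross-covariance `cov12` gives the kernel of the derivative nonlinearity, which is the factor needed to propagate the NTK through an elementwise layer. The snippet below is a minimal sketch of that identity for the sine example, not the library's internal code; `d_nngp_fn`, `d_nngp_manual`, and the numeric inputs are illustrative names and values chosen here.

```python
import jax
import jax.numpy as jnp

# NNGP function of the sine nonlinearity (same formula as in the cell above):
# E[sin(u) sin(v)] for zero-mean Gaussian (u, v).
def nngp_fn(cov12, var1, var2):
  sum_ = var1 + var2
  s1 = jnp.exp(-0.5 * sum_ + cov12)
  s2 = jnp.exp(-0.5 * sum_ - cov12)
  return (s1 - s2) / 2

# Autodiff with respect to the first argument, the cross-covariance `cov12`.
d_nngp_fn = jax.grad(nngp_fn, argnums=0)

# Hand-derived counterpart (illustrative): E[cos(u) cos(v)],
# the kernel of the derivative of sine.
def d_nngp_manual(cov12, var1, var2):
  sum_ = var1 + var2
  return (jnp.exp(-0.5 * sum_ + cov12) + jnp.exp(-0.5 * sum_ - cov12)) / 2

# Arbitrary illustrative (co)variances.
cov12, var1, var2 = 0.5, 1.0, 1.5
auto = d_nngp_fn(cov12, var1, var2)
manual = d_nngp_manual(cov12, var1, var2)

# The difference should be at the level of float32 round-off.
print(abs(auto - manual))
```

Because the derivative is obtained by differentiating a closed-form expression, the resulting NTK is itself closed-form, which is consistent with the exact match printed above.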




# 2. Using `ElementwiseNumerical` to approximate kernels given only the nonlinearity.

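`stax.ElementwiseNumerical` needs only the forward-pass function of the nonlinearity: it approximates the NNGP and NTK numerically using Gaussian quadrature (with `deg` controlling the quadrature degree) and autodiff.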

In [None]:
# A nonlinearity with a known closed-form expression (GeLU).
_, _, kernel_fn_closed_form = stax.serial(
  stax.Dense(1),
  stax.Gelu(),  # Contains the closed-form GeLU NNGP/NTK expression.
  stax.Dense(1)
)
kernel_closed_form = kernel_fn_closed_form(x1, x2)

# Construct the layer from only the elementwise forward-pass GeLU.
_, _, kernel_fn_numerical = stax.serial(
  stax.Dense(1),
  stax.ElementwiseNumerical(jax.nn.gelu, deg=25),  # Approximated via quadrature and autodiff.
  stax.Dense(1)
)
kernel_numerical = kernel_fn_numerical(x1, x2)

# The two kernels are close!
print(np.max(np.abs(kernel_closed_form.nngp - kernel_numerical.nngp)))
print(np.max(np.abs(kernel_closed_form.ntk - kernel_numerical.ntk)))

3.825128e-05
8.523464e-05
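
To make the quadrature idea concrete, the sketch below approximates an NNGP value E[phi(u) phi(v)] with a tensor-product Gauss-Hermite rule and compares it against the closed-form sine NNGP from section 1. It illustrates the general technique under stated assumptions and is not the scheme `stax.ElementwiseNumerical` necessarily uses internally; `nngp_quadrature` and `nngp_sin` are hypothetical helper names, the input values are arbitrary, and `deg=25` simply mirrors the argument used above.

```python
import numpy as onp  # plain NumPy; the notebook's `np` alias is jax.numpy

def nngp_quadrature(phi, cov12, var1, var2, deg=25):
  """Approximates E[phi(u) phi(v)] for zero-mean jointly Gaussian (u, v)."""
  # Gauss-Hermite nodes and weights for integrals against exp(-x^2).
  x, w = onp.polynomial.hermite.hermgauss(deg)
  z = onp.sqrt(2.0) * x      # nodes rescaled to a standard normal variable
  w = w / onp.sqrt(onp.pi)   # weights rescaled so that they sum to 1
  # Cholesky factor of [[var1, cov12], [cov12, var2]] maps independent
  # standard normals (z1, z2) to the correlated pair (u, v).
  a = onp.sqrt(var1)
  b = cov12 / a
  c = onp.sqrt(var2 - b ** 2)
  u = a * z[:, None]                    # shape (deg, 1)
  v = b * z[:, None] + c * z[None, :]   # shape (deg, deg)
  # Weighted sum over the two-dimensional grid of nodes.
  return onp.sum(w[:, None] * w[None, :] * phi(u) * phi(v))

# Closed-form NNGP of the sine nonlinearity, used as the reference value.
def nngp_sin(cov12, var1, var2):
  s = var1 + var2
  return (onp.exp(-0.5 * s + cov12) - onp.exp(-0.5 * s - cov12)) / 2

# Arbitrary illustrative (co)variances.
cov12, var1, var2 = 0.5, 1.0, 1.5
approx = nngp_quadrature(onp.sin, cov12, var1, var2)
exact = nngp_sin(cov12, var1, var2)

# Absolute error of the quadrature estimate.
print(abs(approx - exact))
```

Only forward evaluations of `phi` are needed here, which is why this approach works for nonlinearities without a known closed-form kernel; the price is an approximation error that depends on `deg` and on how smooth the nonlinearity is.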


