<a href="https://colab.research.google.com/github/neurips2022sub/ntk_activations/blob/main/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Examples of using new nonlinearities from [`ntk_activations`](https://github.com/neurips2022sub/ntk_activations)

In [1]:
!pip install git+https://github.com/neurips2022sub/ntk_activations.git

Collecting git+https://github.com/neurips2022sub/ntk_activations.git
  Cloning https://github.com/neurips2022sub/ntk_activations.git to /tmp/pip-req-build-hz0_m17q
  Running command git clone -q https://github.com/neurips2022sub/ntk_activations.git /tmp/pip-req-build-hz0_m17q
Collecting neural-tangents>=0.5.0
  Downloading neural_tangents-0.5.0-py2.py3-none-any.whl (193 kB)
Collecting frozendict>=2.3
  Downloading frozendict-2.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (99 kB)
Building wheels for collected packages: ntk-activations
  Building wheel for ntk-activations (setup.py) ... done
  Created wheel for ntk-activations: filename=ntk_activations-0.0.1-py3-none-any.whl size=12912 sha256=8d3424a6173b5b2f9d7a30dff518300fca97c0bbfaeb1203e45f28e762b0f3c0
  Stored in directory: /tmp/pip-ephem-wheel-cache-8g93oen8/wheels/b5/6c/15/4329dce81d43

In [2]:
import jax
from jax import numpy as np, random
from neural_tangents import stax
from ntk_activations import stax_extensions

## 1. Using `ntk_activations.stax_extensions` with `neural_tangents.stax`

Layers from `neural_tangents.stax` and `ntk_activations.stax_extensions` can be combined freely within the same `stax.serial` model.

In [3]:
key1, key2, key_init = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (3, 2))
x2 = random.normal(key2, (4, 2))

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(128),
    stax_extensions.Gaussian(),
    stax.Dense(10)
)

_, params = init_fn(key_init, x1.shape)
outputs = apply_fn(params, x1)
kernel = kernel_fn(x1, x2)
print(kernel)



Kernel(nngp=DeviceArray([[0.4366224 , 0.45867887, 0.38753727, 0.52652365],
             [0.4365696 , 0.4073899 , 0.27690652, 0.38512793],
             [0.35802934, 0.44013247, 0.29739842, 0.44431883]],            dtype=float32), ntk=DeviceArray([[0.47899756, 0.54227614, 0.6494441 , 0.7565323 ],
             [0.6249889 , 0.49593154, 0.3092413 , 0.38644305],
             [0.3599234 , 0.6529554 , 0.39999655, 0.61957943]],            dtype=float32), cov1=DeviceArray([0.48287258, 0.41158044, 0.4013612 ], dtype=float32), cov2=DeviceArray([0.54617095, 0.5546928 , 0.3666358 , 0.5852136 ], dtype=float32), x1_is_x2=False, is_gaussian=True, is_reversed=False, is_input=False, diagonal_batch=True, diagonal_spatial=False, shape1=(3, 10), shape2=(4, 10), batch_axis=0, channel_axis=1, mask1=None, mask2=None)


## 2. Using `Elementwise` to automatically derive the NTK in closed form from the NNGP

`ntk_activations.stax_extensions.Elementwise` takes only the NNGP function and uses autodiff under the hood to derive the NTK function from it.

In [4]:
# Hand-derived NTK expression for the sine nonlinearity.
_, _, kernel_fn_manual = stax.serial(stax.Dense(1),
                                     stax_extensions.Sin())

# NNGP function for the sine nonlinearity:
def nngp_fn(cov12, var1, var2):
  sum_ = (var1 + var2)
  s1 = np.exp((-0.5 * sum_ + cov12))
  s2 = np.exp((-0.5 * sum_ - cov12))
  return (s1 - s2) / 2

# Let `Elementwise` derive the NTK function in closed form automatically.
_, _, kernel_fn = stax.serial(stax.Dense(1),
                              stax_extensions.Elementwise(nngp_fn=nngp_fn))


k_auto = kernel_fn(x1, x2, 'ntk')
k_manual = kernel_fn_manual(x1, x2, 'ntk')

# The two kernels match!
print(np.max(np.abs(k_manual - k_auto)))

  'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '


0.0
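
The agreement above rests on a standard identity for zero-mean Gaussian pairs: differentiating the NNGP function E[φ(u)φ(v)] with respect to `cov12` gives E[φ'(u)φ'(v)], the derivative kernel that enters the NTK. The sketch below checks this for the sine nonlinearity using only the standard library. It is an illustration rather than the library's implementation: a central finite difference stands in for autodiff, and the helper names and input values are hypothetical.

```python
import math


def nngp_fn(cov12, var1, var2):
    """Sine NNGP from the cell above: E[sin(u) sin(v)] for zero-mean Gaussian (u, v)."""
    sum_ = var1 + var2
    s1 = math.exp(-0.5 * sum_ + cov12)
    s2 = math.exp(-0.5 * sum_ - cov12)
    return (s1 - s2) / 2


def d_nngp_closed_form(cov12, var1, var2):
    """Hand-derived derivative kernel for sine: E[cos(u) cos(v)]."""
    sum_ = var1 + var2
    s1 = math.exp(-0.5 * sum_ + cov12)
    s2 = math.exp(-0.5 * sum_ - cov12)
    return (s1 + s2) / 2


def d_nngp_finite_diff(cov12, var1, var2, eps=1e-5):
    """Central finite difference in cov12 (a stand-in for autodiff in this sketch)."""
    upper = nngp_fn(cov12 + eps, var1, var2)
    lower = nngp_fn(cov12 - eps, var1, var2)
    return (upper - lower) / (2 * eps)


# Illustrative (hypothetical) covariance entries, not taken from x1 and x2 above.
var1, var2, cov12 = 0.5, 0.8, 0.3
d_cf = d_nngp_closed_form(cov12, var1, var2)
d_fd = d_nngp_finite_diff(cov12, var1, var2)

# The difference is limited only by finite-difference and rounding error.
print(abs(d_cf - d_fd))
```

With autodiff in place of the finite difference, the derivative is exact up to floating-point rounding, which is consistent with the `0.0` difference printed above.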


## 3. Using `ElementwiseNumerical` to approximate kernels given only the nonlinearity

`ntk_activations.stax_extensions.ElementwiseNumerical` approximates the NNGP and NTK using Gaussian quadrature and autodiff.

In [5]:
# A nonlinearity with a known closed-form expression (GeLU).
_, _, kernel_fn_closed_form = stax.serial(
  stax.Dense(1),
  stax_extensions.Gelu(),  # Contains the closed-form GeLU NNGP/NTK expression.
  stax.Dense(1)
)
kernel_closed_form = kernel_fn_closed_form(x1, x2)

# Construct the layer from only the elementwise forward-pass GeLU.
_, _, kernel_fn_numerical = stax.serial(
  stax.Dense(1),
  stax.ElementwiseNumerical(jax.nn.gelu, deg=25),  # Approximated with quadrature and autodiff.
  stax.Dense(1)
)
kernel_numerical = kernel_fn_numerical(x1, x2)

# The two kernels are close!
print(np.max(np.abs(kernel_closed_form.nngp - kernel_numerical.nngp)))
print(np.max(np.abs(kernel_closed_form.ntk - kernel_numerical.ntk)))

  f'Numerical Activation Layer with fn={fn}, deg={deg} used!'
  'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '


3.823638e-05
8.529425e-05
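
To illustrate the quadrature idea behind this section, the sketch below approximates the NNGP expectation E[φ(u)φ(v)] for a zero-mean Gaussian pair (u, v) with a tensor-product Gauss-Hermite rule and compares it with the closed-form sine NNGP from Section 2. It assumes Gauss-Hermite nodes and a Cholesky change of variables, so it may differ from what `ElementwiseNumerical` does internally. The function name `nngp_quadrature` and the input values are hypothetical, and `np` here is plain NumPy rather than `jax.numpy`.

```python
import numpy as np  # Plain NumPy here, not jax.numpy as in the cells above.


def nngp_quadrature(fn, cov12, var1, var2, deg=25):
    """Approximate E[fn(u) fn(v)] for (u, v) ~ N(0, [[var1, cov12], [cov12, var2]])."""
    # Gauss-Hermite nodes and weights for the weight function exp(-x^2).
    x, w = np.polynomial.hermite.hermgauss(deg)
    z = np.sqrt(2.0) * x  # Rescale nodes so they integrate against a standard normal.
    z1, z2 = z[:, None], z[None, :]

    # Cholesky factor of the covariance maps independent (z1, z2) to correlated (u, v).
    a = np.sqrt(var1)
    b = cov12 / a
    c = np.sqrt(var2 - b ** 2)
    u = a * z1            # shape (deg, 1)
    v = b * z1 + c * z2   # shape (deg, deg)

    # Product weights; the 1/pi factor normalizes the two Gaussian densities.
    weights = w[:, None] * w[None, :] / np.pi
    return float(np.sum(weights * fn(u) * fn(v)))


def sin_nngp_closed_form(cov12, var1, var2):
    """Closed-form E[sin(u) sin(v)], matching nngp_fn in Section 2."""
    sum_ = var1 + var2
    return (np.exp(-0.5 * sum_ + cov12) - np.exp(-0.5 * sum_ - cov12)) / 2


# Illustrative (hypothetical) covariance entries.
var1, var2, cov12 = 0.5, 0.8, 0.3
approx = nngp_quadrature(np.sin, cov12, var1, var2, deg=25)
exact = float(sin_nngp_closed_form(cov12, var1, var2))
print(abs(approx - exact))
```

The rule evaluates the nonlinearity on a `deg × deg` grid, so the cost grows quadratically with `deg`. Smooth nonlinearities such as sine converge quickly, while a nonlinearity with a kink would be expected to need more nodes for the same accuracy.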
