In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)

from neural_tangents import stax

In [2]:
# Warm-up: run one trivial JAX op so any one-time JAX startup warnings are
# printed here instead of in the middle of the cells below.
jnp.sum(jnp.arange(42))



DeviceArray(861, dtype=int64)

# Reference implementation
Using the `neural-tangents` library (`stax`) to compute the infinite-width NNGP and NTK kernels of a ReLU MLP.

In [3]:
def _get_kernel_fn(
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """Build the `stax` kernel function of a ReLU MLP.

    Architecture: `num_hidden_layers` repetitions of Dense -> Relu, followed
    by a Dense output layer whose bias std is 0 (i.e. no output bias).

    Args:
        num_hidden_layers: number of Dense/Relu pairs before the output layer.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers.

    Returns:
        The `kernel_fn` (third element) returned by `stax.serial`.
    """
    # NOTE(review): the width 1 is presumably irrelevant because only the
    # infinite-width kernel_fn is used — confirm against stax.Dense docs.
    hidden_layers = [
        layer
        for _ in range(num_hidden_layers)
        for layer in (stax.Dense(1, W_std=W_std, b_std=b_std), stax.Relu())
    ]
    readout_layer = stax.Dense(1, W_std=W_std, b_std=0)

    _, _, kernel_fn = stax.serial(*hidden_layers, readout_layer)
    return kernel_fn


def mlp_nngp_ref(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """Reference NNGP kernel of a ReLU MLP, computed with neural-tangents.

    Args:
        xs: input batch, one example per row (presumably (n, d) — TODO confirm
            for other ranks).
        num_hidden_layers: number of Dense -> Relu blocks before the output.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers.

    Returns:
        The NNGP kernel between `xs` and itself.
    """
    kernel_fn = _get_kernel_fn(num_hidden_layers, W_std, b_std)
    # Request only the NNGP; kernel_fn then returns that array directly.
    return kernel_fn(xs, xs, get="nngp")


def mlp_ntk_ref(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """Reference NTK of a ReLU MLP, computed with neural-tangents.

    Args:
        xs: input batch, one example per row (presumably (n, d) — TODO confirm
            for other ranks).
        num_hidden_layers: number of Dense -> Relu blocks before the output.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers.

    Returns:
        The NTK between `xs` and itself.
    """
    kernel_fn = _get_kernel_fn(num_hidden_layers, W_std, b_std)
    # Request only the NTK; kernel_fn then returns that array directly.
    return kernel_fn(xs, xs, get="ntk")

In [4]:
# Fixed seed so the printed reference kernels are reproducible.
key = jax.random.split(jax.random.PRNGKey(1))[0]
xs = jax.random.normal(key, (3, 10))  # 3 inputs of dimension 10

demo_hparams = dict(num_hidden_layers=1, W_std=1, b_std=1)
print("NNGP:\n", mlp_nngp_ref(xs, **demo_hparams))
print("NTK:\n", mlp_ntk_ref(xs, **demo_hparams))

NNGP:
 [[0.84184238 0.59124872 0.54153011]
 [0.59124872 1.47744508 0.54834254]
 [0.54153011 0.54834254 0.73813525]]
NTK:
 [[1.68368476 0.85379928 0.88487602]
 [0.85379928 2.95489015 0.78780738]
 [0.88487602 0.78780738 1.47627051]]


# Our implementations

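### NNGP

For inputs $X \in \mathbb{R}^{n \times d}$, weight variance $\sigma_w^2$ (`W_std**2`), bias variance $\sigma_b^2$ (`b_std**2`) and $L \ge 1$ hidden layers, the hidden-layer kernels follow the recursion

$$K^{(1)} = \sigma_w^2 \frac{X X^\top}{d} + \sigma_b^2, \qquad K^{(\ell+1)} = \sigma_w^2\, V_{\mathrm{ReLU}}\big(K^{(\ell)}\big) + \sigma_b^2 \quad (1 \le \ell < L).$$

The output layer has no bias, so the NNGP is $\sigma_w^2\, V_{\mathrm{ReLU}}\big(K^{(L)}\big)$, where

$$V_{\mathrm{ReLU}}(K)_{ij} = \frac{\sqrt{K_{ii} K_{jj}}}{2\pi} \left( \sqrt{1 - c_{ij}^2} + (\pi - \arccos c_{ij})\, c_{ij} \right), \qquad c_{ij} = \frac{K_{ij}}{\sqrt{K_{ii} K_{jj}}}.$$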

In [5]:
def v_relu(K: jnp.ndarray) -> jnp.ndarray:
    """Closed-form ReLU kernel E[relu(u) relu(v)] for (u, v) ~ N(0, K).

    For a square covariance matrix `K` over n inputs, returns the (n, n)
    matrix with entries

        sqrt(K_ii K_jj) / (2 pi) * (sqrt(1 - c_ij^2) + (pi - arccos c_ij) c_ij),
        c_ij = K_ij / sqrt(K_ii K_jj)

    (the degree-1 arc-cosine kernel).

    Args:
        K: square (n, n) covariance matrix (assumed symmetric PSD).

    Returns:
        (n, n) array with the ReLU-transformed kernel.

    Raises:
        ValueError: if `K` is not a square 2-D array.
    """
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError(f"K must be a square 2-D array, got shape {K.shape}")

    diag = jnp.diag(K)
    norms = jnp.sqrt(diag[:, jnp.newaxis] * diag[jnp.newaxis, :])

    # A zero-variance input would give 0/0 = NaN. Divide by 1 there instead;
    # the final multiplication by `norms` (== 0) then yields 0, which is the
    # correct value because such an input is identically 0 after ReLU.
    safe_norms = jnp.where(norms > 0, norms, 1.0)

    # Rounding can push |c| slightly above 1, making sqrt(1 - c^2) and
    # arccos(c) NaN; the exact correlation always lies in [-1, 1].
    c = jnp.clip(K / safe_norms, -1.0, 1.0)

    return 1 / (2 * jnp.pi) * (
        jnp.sqrt(1 - c * c)
        + (jnp.pi - jnp.arccos(c)) * c
    ) * norms

def mlp_nngp_ours(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """NNGP kernel of a ReLU MLP via the closed-form layer recursion.

    Architecture: `num_hidden_layers` x (Dense(W_std, b_std) -> ReLU), then a
    Dense output layer with weight std `W_std` and no bias.

    Args:
        xs: 2-D array of inputs, one example per row, shape (n, d).
        num_hidden_layers: number of Dense -> ReLU blocks; must be >= 0.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers.

    Returns:
        (n, n) NNGP kernel of the network output.

    Raises:
        ValueError: if `num_hidden_layers` is negative.
    """
    if num_hidden_layers < 0:
        raise ValueError(
            f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
        )

    data_dim = xs.shape[1]
    W_var = W_std ** 2
    b_var = b_std ** 2

    # Scaled input Gram matrix W_var * <x, x'> / d.
    linear = W_var * (xs @ xs.T) / data_dim

    # With no hidden layers the network is only the bias-free output layer,
    # so no b_var term may be added.
    if num_hidden_layers == 0:
        return linear

    # Pre-activation kernel of the first hidden layer (before its ReLU).
    K = linear + b_var

    # Remaining hidden layers: ReLU, then Dense with bias.
    for _ in range(num_hidden_layers - 1):
        K = W_var * v_relu(K) + b_var

    # Final ReLU followed by the bias-free linear output layer.
    return W_var * v_relu(K)

In [6]:
# Same fixed-seed inputs as in the reference demo, so the two printed NNGP
# matrices can be compared directly.
key = jax.random.split(jax.random.PRNGKey(1))[0]
xs = jax.random.normal(key, (3, 10))

nngp_ours = mlp_nngp_ours(xs, num_hidden_layers=1, W_std=1, b_std=1)
print("NNGP:\n", nngp_ours)

NNGP:
 [[0.84184238 0.59124872 0.54153011]
 [0.59124872 1.47744508 0.54834254]
 [0.54153011 0.54834254 0.73813525]]


In [7]:
# Stronger tests: compare against the reference over several input dimensions
# and network depths. Depth 1 alone never runs the hidden-layer recursion
# loop, so deeper networks are checked as well.
key, _ = jax.random.split(jax.random.PRNGKey(1))  # loop-invariant, hoisted
for data_dim in [1, 2, 4, 32, 1024]:
    xs = jax.random.normal(key=key, shape=(10, data_dim))

    for num_hidden_layers in [1, 2, 3]:
        kwargs = dict(num_hidden_layers=num_hidden_layers, W_std=jnp.pi, b_std=jnp.e)
        assert jnp.allclose(
            mlp_nngp_ref(xs, **kwargs),
            mlp_nngp_ours(xs, **kwargs),
        ), f"NNGP mismatch: data_dim={data_dim}, num_hidden_layers={num_hidden_layers}"

### NTK

In [None]:
# TODO: implement our own NTK (analogous to mlp_nngp_ours) and check it against mlp_ntk_ref.