In [1]:
import jax
import jax.numpy as jnp
# Enable 64-bit dtypes globally so the kernels below are computed in float64.
# `jax.config.update` works across JAX versions; the older
# `from jax.config import config` submodule import is deprecated and removed in
# newer releases.
jax.config.update("jax_enable_x64", True)

from neural_tangents import stax

In [2]:
# Run one trivial computation first so JAX's one-time initialization (and any
# warnings it prints) happens in this cell, not in the result cells below.
jnp.sum(jnp.arange(42))



DeviceArray(861, dtype=int64)

# Reference implementation
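As ground truth we use `neural-tangents`. For example, an MLP with two hidden ReLU layers and a zero-bias-variance readout layer is built as: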

```python
stax.serial(
    stax.Dense(1, W_std=W_std, b_std=b_std), stax.Relu(),
    stax.Dense(1, W_std=W_std, b_std=b_std), stax.Relu(),
    stax.Dense(1, W_std=W_std, b_std=0)
)
```

In [3]:
def _get_kernel_fn(
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """Return the neural-tangents ``kernel_fn`` for a ReLU MLP.

    The network is ``num_hidden_layers`` repetitions of
    ``Dense(W_std, b_std) -> Relu`` followed by a ``Dense(W_std, b_std=0)``
    readout layer. Every Dense layer is declared with width 1; only the
    returned ``kernel_fn`` is used by the callers in this notebook.
    """
    # Dense/Relu pairs, constructed in the same order as the layers appear.
    hidden_blocks = [
        layer
        for _ in range(num_hidden_layers)
        for layer in (stax.Dense(1, W_std=W_std, b_std=b_std), stax.Relu())
    ]
    readout = stax.Dense(1, W_std=W_std, b_std=0)

    _init_fn, _apply_fn, kernel_fn = stax.serial(*hidden_blocks, readout)
    return kernel_fn


def mlp_nngp_ref(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
) -> jnp.ndarray:
    """Reference NNGP kernel of a ReLU MLP, computed with neural-tangents.

    Args:
        xs: inputs, one example per row (assumed shape ``(n, d)``).
        num_hidden_layers: number of Dense+ReLU hidden blocks.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers (the readout
            layer uses ``b_std=0``).

    Returns:
        The NNGP kernel between ``xs`` and itself.
    """
    kernel_fn = _get_kernel_fn(
        num_hidden_layers=num_hidden_layers,
        W_std=W_std,
        b_std=b_std,
    )

    # Request only the NNGP: with `get="nngp"` neural-tangents returns the array
    # directly and does not need to compute the (unused) NTK as well.
    return kernel_fn(xs, xs, "nngp")


def mlp_ntk_ref(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
):
    """Reference NTK of a ReLU MLP, computed with neural-tangents.

    Builds the same architecture as ``_get_kernel_fn`` and evaluates its
    kernel on ``(xs, xs)``, returning the ``.ntk`` component.
    """
    return _get_kernel_fn(
        num_hidden_layers=num_hidden_layers,
        W_std=W_std,
        b_std=b_std,
    )(xs, xs).ntk

In [4]:
# Small random input batch: 3 examples, 5 features each.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 5))

# Both reference kernels use the same architecture and hyperparameters.
hparams = dict(num_hidden_layers=3, W_std=1, b_std=1)

print("xs:\n", xs)
print("NNGP:\n", mlp_nngp_ref(xs, **hparams))
print("NTK:\n", mlp_ntk_ref(xs, **hparams))

xs:
 [[ 0.51232451 -0.76248157  0.47484656 -1.44580725 -0.01293077]
 [-0.01876019  0.96316201 -1.30182157  0.92914455  0.27818229]
 [-0.41780672  0.07287297  1.92944858 -1.29686923 -0.40593064]]
NNGP:
 [[0.95399649 0.87331003 0.95016414]
 [0.87331003 0.96408668 0.88058314]
 [0.95016414 0.88058314 1.01873236]]
NTK:
 [[2.56598596 1.86727954 2.28452466]
 [1.86727954 2.60634674 1.83368718]
 [2.28452466 1.83368718 2.82492942]]


# Our implementations

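### NNGP

For a ReLU MLP the NNGP kernel follows a layer-wise recursion. With input dimension $d$:

$$K^{(1)} = \sigma_w^2 \frac{X X^\top}{d} + \sigma_b^2, \qquad K^{(l+1)} = \sigma_w^2\, V_{\mathrm{relu}}\big(K^{(l)}\big) + \sigma_b^2,$$

$$V_{\mathrm{relu}}(K)_{ij} = \frac{\sqrt{K_{ii}K_{jj}}}{2\pi}\Big(\sqrt{1-c_{ij}^2} + (\pi - \arccos c_{ij})\,c_{ij}\Big), \qquad c_{ij} = \frac{K_{ij}}{\sqrt{K_{ii}K_{jj}}}.$$

The readout layer has no bias, so the output kernel is $\sigma_w^2\, V_{\mathrm{relu}}\big(K^{(L)}\big)$.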

In [7]:
def v_relu(K: jnp.ndarray) -> jnp.ndarray:
    """Push a pre-activation covariance through a ReLU (arc-cosine kernel).

    For a square covariance matrix ``K`` returns the matrix with entries
    ``E[relu(u_i) relu(u_j)]`` where ``u ~ N(0, K)``:

        sqrt(K_ii K_jj) / (2 pi) * (sqrt(1 - c^2) + (pi - arccos c) * c),
        c = K_ij / sqrt(K_ii K_jj).

    Args:
        K: square covariance matrix of shape ``(n, n)``.

    Returns:
        The ``(n, n)`` post-ReLU covariance matrix.
    """
    diag = jnp.diag(K)
    # Product of the marginal standard deviations, sqrt(K_ii * K_jj).
    norm = jnp.sqrt(diag[:, jnp.newaxis] * diag[jnp.newaxis, :])

    # Correlation coefficient. A zero-variance row (e.g. an all-zero input with
    # b_std == 0) would give 0/0, so divide by a safe denominator and define
    # c = 0 there (the result is multiplied by norm == 0 anyway).
    has_var = norm > 0
    c = jnp.where(has_var, K / jnp.where(has_var, norm, 1.0), 0.0)
    # Round-off can push |c| slightly past 1 (e.g. for collinear inputs), which
    # would turn sqrt(1 - c*c) and arccos(c) into NaN.
    c = jnp.clip(c, -1.0, 1.0)

    return 1 / (2 * jnp.pi) * (
        jnp.sqrt(1 - c * c)
        + (jnp.pi - jnp.arccos(c)) * c
    ) * norm

def mlp_nngp_ours(
    xs: jnp.ndarray,
    num_hidden_layers: int,
    W_std: float,
    b_std: float,
) -> jnp.ndarray:
    """NNGP kernel of a ReLU MLP via the analytic layer-wise recursion.

    The architecture matches ``_get_kernel_fn``: ``num_hidden_layers`` blocks
    of Dense(W_std, b_std) + ReLU, followed by a Dense(W_std) readout with no
    bias.

    Args:
        xs: inputs of shape ``(n, d)``, one example per row.
        num_hidden_layers: number of Dense+ReLU blocks. ``0`` gives a purely
            linear model whose kernel is ``W_std**2 * xs @ xs.T / d``.
        W_std: weight standard deviation of every Dense layer.
        b_std: bias standard deviation of the hidden Dense layers.

    Returns:
        The ``(n, n)`` NNGP kernel matrix.

    Raises:
        ValueError: if ``num_hidden_layers`` is negative.
    """
    if num_hidden_layers < 0:
        raise ValueError(
            f"num_hidden_layers must be non-negative, got {num_hidden_layers}"
        )

    data_dim = xs.shape[1]
    W_var = W_std ** 2
    b_var = b_std ** 2

    # Second moment of the inputs, normalized by the input dimension.
    K = xs @ xs.T / data_dim

    # Each hidden block: Dense (affine map of the kernel), then ReLU.
    for _ in range(num_hidden_layers):
        K = W_var * K + b_var
        K = v_relu(K)

    # Final linear readout layer (no bias).
    return W_var * K

In [8]:
# Same inputs and hyperparameters as the reference cell above, so the printed
# matrix can be compared entry-by-entry with the neural-tangents NNGP.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 5))

nngp_ours = mlp_nngp_ours(xs, num_hidden_layers=3, W_std=1, b_std=1)
print("NNGP:\n", nngp_ours)

NNGP:
 [[0.95399649 0.87331003 0.95016414]
 [0.87331003 0.96408668 0.88058314]
 [0.95016414 0.88058314 1.01873236]]


In [9]:
# Stronger tests: compare against neural-tangents across several depths, with
# deliberately non-round hyperparameters (W_std = pi, b_std = e) and a larger
# batch of 10 random inputs.
for depth in [1, 4, 20]:
    key, _ = jax.random.split(jax.random.PRNGKey(depth))
    xs = jax.random.normal(key=key, shape=(10, 5))

    hparams = dict(num_hidden_layers=depth, W_std=jnp.pi, b_std=jnp.e)
    expected = mlp_nngp_ref(xs, **hparams)
    actual = mlp_nngp_ours(xs, **hparams)
    assert jnp.allclose(expected, actual)

### NTK

In [8]:
# TODO