In [8]:
import numpy as np

import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)

from neural_tangents import stax

In [9]:
# Run one trivial JAX computation now so that any JAX start-up warnings are
# emitted here instead of cluttering the outputs of later cells.
jnp.sum(jnp.arange(42))

DeviceArray(861, dtype=int64)

# Reference implementation
Using `neural-tangents`

### Diagram of residual block

In [10]:
def get_res_block(
    W_std: float,
    b_std: float,
):
    """Build one residual block: ``x -> x + Dense(Relu(Dense(x)))``.

    Args:
        W_std: weight standard deviation passed to both ``stax.Dense`` layers.
        b_std: bias standard deviation passed to both ``stax.Dense`` layers.

    Returns:
        The layer produced by ``stax.serial`` for the fan-out / parallel /
        fan-in-sum composition.
    """
    def _dense():
        # Both dense layers of the block share the same hyper-parameters.
        return stax.Dense(1, W_std=W_std, b_std=b_std)

    skip_branch = stax.Identity()
    mlp_branch = stax.serial(_dense(), stax.Relu(), _dense())

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(skip_branch, mlp_branch),
        stax.FanInSum(),
    )

def _get_kernel_fn(
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Return the neural-tangents ``kernel_fn`` of the reference residual MLP.

    Architecture: ``Dense -> num_res_blocks x residual block -> Relu -> Dense``.

    Args:
        num_res_blocks: number of residual blocks stacked after the input layer.
        W_std: weight standard deviation used by every dense layer.
        b_std: bias standard deviation used by the input layer and by the
            dense layers inside each residual block.

    Returns:
        The third element (``kernel_fn``) of the triple returned by
        ``stax.serial``.
    """
    # The input layer uses the caller's ``b_std`` (it was previously
    # hard-coded to 1, which silently ignored the argument).
    input_layer = [stax.Dense(1, W_std=W_std, b_std=b_std)]

    res_blocks = [
        get_res_block(W_std=W_std, b_std=b_std)
        for _ in range(num_res_blocks)
    ]

    # The read-out layer is deliberately bias-free.
    readout = [stax.Relu(), stax.Dense(1, W_std=W_std, b_std=0)]

    _, _, kernel_fn = stax.serial(*(input_layer + res_blocks + readout))
    return kernel_fn


def rmlp_nngp_ref(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Reference NNGP kernel of the residual MLP, computed by neural-tangents.

    Args:
        xs: input batch; the kernel is evaluated between ``xs`` and itself.
        num_res_blocks: number of residual blocks in the network.
        W_std: weight standard deviation.
        b_std: bias standard deviation.

    Returns:
        The ``nngp`` attribute of ``kernel_fn(xs, xs)``.
    """
    reference_kernel = _get_kernel_fn(
        num_res_blocks=num_res_blocks, W_std=W_std, b_std=b_std
    )
    kernels = reference_kernel(xs, xs)
    return kernels.nngp


def rmlp_ntk_ref(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Reference NTK of the residual MLP, computed by neural-tangents.

    Args:
        xs: input batch; the kernel is evaluated between ``xs`` and itself.
        num_res_blocks: number of residual blocks in the network.
        W_std: weight standard deviation.
        b_std: bias standard deviation.

    Returns:
        The ``ntk`` attribute of ``kernel_fn(xs, xs)``.
    """
    reference_kernel = _get_kernel_fn(
        num_res_blocks=num_res_blocks, W_std=W_std, b_std=b_std
    )
    kernels = reference_kernel(xs, xs)
    return kernels.ntk

`neural-tangents` emits an "independent inputs" warning for this architecture.
We ignore it and rely on Tensor Programs II (Yang, 2020) for the correctness of the resulting kernels.

In [11]:
# Fixed-seed toy batch: 3 inputs with 10 features each.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 10))

# Print both reference kernels for the same network settings.
for title, reference in (("NNGP:\n", rmlp_nngp_ref), ("NTK:\n", rmlp_ntk_ref)):
    print(title, reference(xs, num_res_blocks=3, W_std=1, b_std=1))



NNGP:
 [[6.40371803 5.71130867 5.42762025]
 [5.71130867 8.54887713 5.56358456]
 [5.42762025 5.56358456 6.05370648]]
NTK:
 [[22.67737213 13.64461843 14.63093884]
 [13.64461843 31.25800854 13.28012561]
 [14.63093884 13.28012561 21.27732594]]


# Our implementations

### NNGP

In [14]:
def v_relu(K: jnp.ndarray):
    """Apply the closed-form ReLU covariance map to a covariance matrix.

    For each pair ``(i, j)`` this evaluates

        sqrt(K_ii * K_jj) / (2 * pi) * (sqrt(1 - c^2) + (pi - arccos(c)) * c)

    with ``c = K_ij / sqrt(K_ii * K_jj)``.

    Args:
        K: square covariance matrix.

    Returns:
        Array of the same shape as ``K`` holding the transformed covariances.
        Entries whose row or column has zero variance are returned as 0
        instead of NaN.
    """
    variances = jnp.diag(K)
    norm = jnp.sqrt(
        variances[:, jnp.newaxis]
      * variances[jnp.newaxis, :]
    )

    # Guard against 0/0 when a diagonal entry is zero: such entries contribute
    # a correlation of 0 and, being multiplied by norm == 0, a result of 0.
    has_variance = norm > 0
    c = jnp.where(has_variance, K / jnp.where(has_variance, norm, 1.0), 0.0)

    # Rounding can push |c| marginally above 1, which would make both
    # sqrt(1 - c^2) and arccos(c) return NaN; clipping is a no-op otherwise.
    c = jnp.clip(c, -1.0, 1.0)

    return 1 / (2 * jnp.pi) * (
        jnp.sqrt(1 - c * c)
        + (jnp.pi - jnp.arccos(c)) * c
    ) * norm

def rmlp_nngp_ours(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Hand-derived NNGP kernel of the residual MLP.

    Args:
        xs: input batch; ``xs.shape[1]`` is used as the input dimension.
        num_res_blocks: number of residual blocks (must be positive).
        W_std: weight standard deviation.
        b_std: bias standard deviation.

    Returns:
        The kernel matrix obtained from ``xs @ xs.T`` after the input layer,
        the residual blocks and the bias-free ReLU read-out.
    """
    assert num_res_blocks > 0

    W_var, b_var = W_std ** 2, b_std ** 2
    input_dim = xs.shape[1]

    # Gram matrix of the raw inputs, pushed through the first affine layer
    # (no non-linearity yet).
    gram = xs @ xs.T
    K = W_var * gram / input_dim + b_var

    # Each block adds Dense(Relu(Dense(.))) on top of the skip connection.
    for _ in range(num_res_blocks):
        K += W_var * v_relu(W_var * K + b_var) + b_var

    # ReLU followed by the bias-free linear read-out.
    return W_var * v_relu(K)

In [15]:
# Same seed and shape as the reference demo, so the matrix printed below can
# be compared directly with the reference NNGP output.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 10))

print(
    "NNGP:\n",
    rmlp_nngp_ours(xs, num_res_blocks=3, W_std=1, b_std=1),
)

NNGP:
 [[6.40371803 5.71130867 5.42762025]
 [5.71130867 8.54887713 5.56358456]
 [5.42762025 5.56358456 6.05370648]]


In [7]:
# Stronger tests: compare our NNGP against the reference on larger random
# batches and at several depths.
for num_res_blocks in [1, 3, 20]:
    key, _ = jax.random.split(jax.random.PRNGKey(num_res_blocks))
    xs = jax.random.normal(key=key, shape=(10, 5))

    # Tolerances and NaN handling mirror jnp.allclose's defaults, so the
    # pass/fail criterion is unchanged; on failure assert_allclose reports
    # the mismatching entries instead of a bare AssertionError.
    np.testing.assert_allclose(
        rmlp_nngp_ref(xs, num_res_blocks=num_res_blocks,
                      W_std=1, b_std=1),
        rmlp_nngp_ours(xs, num_res_blocks=num_res_blocks,
                       W_std=1, b_std=1),
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    )

