In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)

from neural_tangents import stax

In [2]:
# Run a trivial JAX computation first so that JAX's one-time startup
# warnings are printed here, not in the middle of later cells' outputs.
jnp.arange(42).sum()



DeviceArray(861, dtype=int64)

# Reference implementation
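Kernels computed with the [`neural-tangents`](https://github.com/google/neural-tangents) library. These serve as ground truth for our own implementations below.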

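### Residual block

Each residual block computes $x \mapsto x + \mathrm{Dense}(\mathrm{ReLU}(\mathrm{Dense}(x)))$. `FanOut(2)` copies the input, and one copy passes through unchanged (`Identity`). The other copy goes through Dense → ReLU → Dense, and `FanInSum` adds the two branches.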

In [3]:
def get_res_block(
    W_std: float,
    b_std: float,
):
    """Build one residual block: x -> x + Dense(Relu(Dense(x))).

    Both Dense layers use weight std `W_std` and bias std `b_std`.
    Returns the neural-tangents `(init_fn, apply_fn, kernel_fn)` triple.
    """
    def _dense():
        # Width is irrelevant for the infinite-width kernel.
        return stax.Dense(1, W_std=W_std, b_std=b_std)

    residual_branch = stax.serial(_dense(), stax.Relu(), _dense())
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.Identity(), residual_branch),
        stax.FanInSum(),
    )

def _get_kernel_fn(
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Return the neural-tangents kernel_fn of the residual MLP.

    Architecture: Dense -> `num_res_blocks` residual blocks -> Relu -> Dense,
    where the final (output) Dense has no bias (b_std=0).
    """
    embedding = stax.Dense(1, W_std=W_std, b_std=b_std)
    res_blocks = [
        get_res_block(W_std=W_std, b_std=b_std)
        for _ in range(num_res_blocks)
    ]
    readout = [stax.Relu(), stax.Dense(1, W_std=W_std, b_std=0)]

    *_, kernel_fn = stax.serial(embedding, *res_blocks, *readout)
    return kernel_fn


def rmlp_nngp_ref(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Reference NNGP kernel of the residual MLP, computed by neural-tangents.

    Returns the NNGP kernel between every pair of rows of `xs`.
    """
    kernels = _get_kernel_fn(
        num_res_blocks=num_res_blocks, W_std=W_std, b_std=b_std,
    )(xs, xs)
    return kernels.nngp


def rmlp_ntk_ref(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Reference NTK of the residual MLP, computed by neural-tangents.

    Returns the NTK between every pair of rows of `xs`.
    """
    kernels = _get_kernel_fn(
        num_res_blocks=num_res_blocks, W_std=W_std, b_std=b_std,
    )(xs, xs)
    return kernels.ntk

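neural-tangents emits a warning that `FanInSum` assumes its inputs are independent. That assumption does not hold here, because both branches of a residual block depend on the block's input. We ignore the warning and rely on Tensor Programs II (Yang, 2020) for the correctness of the resulting kernels.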

In [4]:
# Sanity-check the reference kernels on a small random batch
# (3 inputs of dimension 10).
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 10))

print("NNGP:\n", rmlp_nngp_ref(xs, num_res_blocks=3, W_std=1, b_std=1))
print("NTK:\n", rmlp_ntk_ref(xs, num_res_blocks=3, W_std=1, b_std=1))



NNGP:
 [[6.40371803 5.71130867 5.42762025]
 [5.71130867 8.54887713 5.56358456]
 [5.42762025 5.56358456 6.05370648]]
NTK:
 [[22.67737213 13.64461843 14.63093884]
 [13.64461843 31.25800854 13.28012561]
 [14.63093884 13.28012561 21.27732594]]


# Our implementations

### NNGP

In [5]:
def v_relu(K: jnp.ndarray):
    """Dual kernel of ReLU: E[relu(u) relu(v)] for (u, v) ~ N(0, K).

    Args:
        K: square (n, n) covariance matrix. Its diagonal must be positive,
            since it is used as a divisor.

    Returns:
        (n, n) matrix with entries
        sqrt(K_ii K_jj) / (2 pi) * (sqrt(1 - c^2) + (pi - arccos c) * c),
        where c = K_ij / sqrt(K_ii K_jj).
    """
    sqrt = jnp.sqrt(
        jnp.diag(K)[:, jnp.newaxis]
      * jnp.diag(K)[jnp.newaxis, :]
    )

    # Floating-point rounding can push |c| slightly above 1 (e.g. for
    # near-duplicate inputs), where arccos and sqrt(1 - c^2) return NaN.
    c = jnp.clip(K / sqrt, -1, 1)

    return 1 / (2 * jnp.pi) * (
        jnp.sqrt(1 - c * c)
        + (jnp.pi - jnp.arccos(c)) * c
    ) * sqrt

def rmlp_nngp_ours(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Closed-form NNGP kernel of the residual MLP.

    The network is Dense -> `num_res_blocks` x [x + Dense(Relu(Dense(x)))]
    -> Relu -> Dense. Every Dense layer uses weight std `W_std`; every one
    except the output layer uses bias std `b_std` (the output bias is 0).

    Args:
        xs: inputs, one per row; shape (batch, data_dim).
        num_res_blocks: number of residual blocks. 0 is allowed and gives
            a plain Dense -> Relu -> Dense network.
        W_std: weight standard deviation.
        b_std: bias standard deviation.

    Returns:
        (batch, batch) NNGP kernel matrix.
    """
    assert num_res_blocks >= 0

    data_dim = xs.shape[1]
    W_var = W_std ** 2
    b_var = b_std ** 2

    # Covariance after the embedding Dense layer (no ReLU before it).
    K = W_var * (xs @ xs.T) / data_dim + b_var

    # Each residual block adds the covariance of its Dense -> Relu -> Dense
    # branch to that of the identity path.
    for _ in range(num_res_blocks):
        K = K + (W_var * v_relu(W_var * K + b_var) + b_var)

    # Final Relu followed by the bias-free output Dense layer.
    return W_var * v_relu(K)

In [6]:
# Compare our NNGP with the reference on 3 random inputs of dimension 5,
# using a non-unit W_std and a nonzero b_std.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 5))

print("NNGP (ref):\n", rmlp_nngp_ref(xs, num_res_blocks=3, W_std=1.1, b_std=0.9))
print("NNGP (ours):\n", rmlp_nngp_ours(xs, num_res_blocks=3, W_std=1.1, b_std=0.9))

NNGP (ref):
 [[ 9.45871832  6.78979695  9.30120483]
 [ 6.78979695  9.76577003  6.89616609]
 [ 9.30120483  6.89616609 11.42867631]]
NNGP (ours):
 [[ 9.45871832  6.78979695  9.30120483]
 [ 6.78979695  9.76577003  6.89616609]
 [ 9.30120483  6.89616609 11.42867631]]


In [7]:
# Stronger tests: 10 random inputs per case, for shallow and deep (20-block)
# networks, each depth seeded with its own PRNG key.
for num_res_blocks in [2, 3, 20]:
    key, _ = jax.random.split(jax.random.PRNGKey(num_res_blocks))
    xs = jax.random.normal(key=key, shape=(10, 5))
    
    assert jnp.allclose(
         rmlp_nngp_ref(xs, num_res_blocks=num_res_blocks,
                       W_std=1.1, b_std=0.9),
         rmlp_nngp_ours(xs, num_res_blocks=num_res_blocks,
                        W_std=1.1, b_std=0.9),
    )



### NTK

In [53]:
def v_relu_prime(K: jnp.ndarray):
    """Dual kernel of the ReLU derivative: E[relu'(u) relu'(v)], (u, v) ~ N(0, K).

    Args:
        K: square (n, n) covariance matrix. Its diagonal must be positive,
            since it is used as a divisor.

    Returns:
        (n, n) matrix with entries (pi - arccos c) / (2 pi),
        where c = K_ij / sqrt(K_ii K_jj).
    """
    sqrt = jnp.sqrt(
        jnp.diag(K)[:, jnp.newaxis]
      * jnp.diag(K)[jnp.newaxis, :]
    )

    # Floating-point rounding can push |c| slightly above 1 (e.g. for
    # near-duplicate inputs), where arccos returns NaN.
    c = jnp.clip(K / sqrt, -1, 1)

    return 1 / (2 * jnp.pi) * (jnp.pi - jnp.arccos(c))

def rmlp_ntk_ours(
    xs: jnp.ndarray,
    num_res_blocks: int,
    W_std: float,
    b_std: float,
):
    """Closed-form NTK (NTK parameterization) of the residual MLP.

    The network is Dense -> `num_res_blocks` x [x + Dense(Relu(Dense(x)))]
    -> Relu -> Dense. Every Dense layer uses weight std `W_std`; every one
    except the output layer uses bias std `b_std` (the output bias is 0).

    Args:
        xs: inputs, one per row; shape (batch, data_dim).
        num_res_blocks: number of residual blocks (0 is allowed).
        W_std: weight standard deviation.
        b_std: bias standard deviation.

    Returns:
        (batch, batch) NTK matrix.
    """
    assert num_res_blocks >= 0

    data_dim = xs.shape[1]
    W_var = W_std ** 2
    b_var = b_std ** 2

    ############################## Begin forward pass

    # Covariance of the embedding Dense layer's output.
    embedding_K = W_var * (xs @ xs.T) / data_dim + b_var

    # forward_Ks[i] is the coordinate covariance of the output of the i-th
    # residual block; forward_Ks[0] is the embedding output.
    # For block i (1-based), hidden_Ks[i - 1] is the covariance of its first
    # Dense output (the ReLU input) and branch_Ks[i - 1] that of its second
    # Dense output (the residual branch added to the skip path).
    forward_Ks = [embedding_K]
    hidden_Ks = []
    branch_Ks = []
    for _ in range(num_res_blocks):
        prv_K = forward_Ks[-1]
        hidden_K = W_var * prv_K + b_var
        branch_K = W_var * v_relu(hidden_K) + b_var
        hidden_Ks.append(hidden_K)
        branch_Ks.append(branch_K)
        forward_Ks.append(prv_K + branch_K)

    # Covariance of the network output, i.e. the NNGP kernel
    # (the output layer has no bias).
    out_K = W_var * v_relu(forward_Ks[-1])

    ############################## End forward pass

    ############################## Begin backward pass
    #
    # In the NTK parameterization, the gradients of a Dense layer's output
    # w.r.t. its own weights and bias have inner product equal to that
    # layer's OUTPUT covariance (W_var * K_in + b_var). Each layer therefore
    # contributes (gradient kernel at its output) * (its output covariance).

    # Output layer: the gradient kernel at the network output is 1.
    ntk = out_K

    # cur_grad_K[i][j] is the limit of the inner product of the gradients of
    # the network output w.r.t. the current activations, for xs[i] and xs[j].
    # It starts at the output of the last residual block (final ReLU input).
    cur_grad_K = W_var * v_relu_prime(forward_Ks[-1])
    for hidden_K, branch_K in zip(reversed(hidden_Ks), reversed(branch_Ks)):
        # Gradient kernel at the block output; it also flows unchanged
        # through the skip connection. (JAX arrays are immutable: no copy.)
        skip_grad_K = cur_grad_K

        # Second Dense layer of the block (output covariance branch_K).
        ntk = ntk + cur_grad_K * branch_K

        # Back up through the second Dense weights and the ReLU.
        cur_grad_K = cur_grad_K * W_var * v_relu_prime(hidden_K)

        # First Dense layer of the block (output covariance hidden_K).
        ntk = ntk + cur_grad_K * hidden_K

        # Back up through the first Dense weights, then add the skip path.
        cur_grad_K = cur_grad_K * W_var + skip_grad_K

    # Embedding Dense layer (output covariance embedding_K).
    ntk = ntk + cur_grad_K * embedding_K

    ############################## End backward pass

    return ntk

In [54]:
# Compare our NTK with the reference on 3 random inputs of dimension 5.
# NOTE(review): with W_std=1 and b_std=0 we get W_var=1 and b_var=0, so this
# comparison cannot detect mistakes in where W_var / b_var enter the NTK
# recursion. Also compare at e.g. W_std=1.1, b_std=0.9 and at several depths,
# as is done for the NNGP.
key, _ = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(key=key, shape=(3, 5))

print("NTK (ref):\n", rmlp_ntk_ref(xs, num_res_blocks=1, W_std=1, b_std=0))
print("NTK (ours):\n", rmlp_ntk_ours(xs, num_res_blocks=1, W_std=1, b_std=0))

NTK (ref):
 [[ 1.26394385 -0.05351018  0.85058322]
 [-0.05351018  1.42538696 -0.07560456]
 [ 0.85058322 -0.07560456  2.2997177 ]]
NTK (ours):
 [[ 1.26394385 -0.05351018  0.85058322]
 [-0.05351018  1.42538696 -0.07560456]
 [ 0.85058322 -0.07560456  2.2997177 ]]
