In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit
import neural_tangents as nt
from neural_tangents import stax
import optax

In [None]:
key = random.PRNGKey(42)
# Split off a dedicated subkey for the label noise. JAX keys must never be
# reused: the original code fed the same key to both the noise draw and the
# later network init (which consumes `key`), correlating the two.
key, noise_key = random.split(key)

n_train, n_test = 20, 50
x_train = jnp.linspace(-jnp.pi, jnp.pi, n_train)[:, None]
# Noisy sine targets: sin(x) plus Gaussian noise with std 0.1.
y_train = jnp.sin(x_train) + 0.1 * random.normal(noise_key, shape=(n_train, 1))
# NOTE(review): the test grid extends 1 unit past the training range on the
# right. Part of the test error therefore measures extrapolation. Presumably
# this is intentional; confirm.
x_test = jnp.linspace(-jnp.pi, jnp.pi + 1, n_test)[:, None]
y_test = jnp.sin(x_test)

In [None]:
# Two-layer ReLU MLP: hidden width 512, scalar output.
_mlp_layers = (
    stax.Dense(512, W_std=1.0, b_std=0.05),
    stax.Relu(),
    stax.Dense(1, W_std=1.0, b_std=0.05),
)
init_fn, apply_fn, kernel_fn = stax.serial(*_mlp_layers)

# Initial parameters for inputs shaped like the training set.
output_shape, params_init = init_fn(key, input_shape=x_train.shape)
# Outputs of the untrained network on the training inputs.
predictions_init = apply_fn(params_init, x_train)

In [None]:
# One jitted empirical-NTK function per apply_fn. It is built lazily and
# reused across calls, instead of rebuilding an un-jitted closure on every
# call. Keying on apply_fn means redefining the network rebuilds the function.
_ntk_fn_cache = {}


def compute_ntk(x1, x2, params):
    """Empirical NTK between inputs ``x1`` and ``x2`` at ``params``.

    Uses ``nt.empirical_ntk_fn(apply_fn)`` with its default settings, where
    ``apply_fn`` is the network apply function in the notebook's namespace.

    Args:
        x1: first batch of inputs.
        x2: second batch of inputs.
        params: network parameters at which the kernel is evaluated.

    Returns:
        The kernel matrix produced by the empirical NTK function.
    """
    ntk_fn = _ntk_fn_cache.get(apply_fn)
    if ntk_fn is None:
        ntk_fn = jit(nt.empirical_ntk_fn(apply_fn))
        _ntk_fn_cache[apply_fn] = ntk_fn
    return ntk_fn(x1, x2, params)


def mse_loss(params, x, y):
    """Half the mean squared error of the network's outputs on ``(x, y)``.

    Args:
        params: network parameters passed to ``apply_fn``.
        x: inputs.
        y: targets, broadcast-compatible with ``apply_fn(params, x)``.
    """
    residual = apply_fn(params, x) - y
    return jnp.mean(residual**2) * 0.5

In [None]:
# Jit-compiled gradient of the training loss with respect to the parameters.
grad_loss = jit(grad(mse_loss))

# Adam (lr 1e-3) for ordinary supervised training, starting from the init.
optimizer = optax.adam(learning_rate=1e-3)
params_dnn = params_init  # JAX arrays are immutable, so aliasing is safe
opt_state = optimizer.init(params_dnn)

In [None]:
def _dnn_step(params, state):
    """One full-batch Adam step on the training MSE; returns (params, state)."""
    step_grads = grad_loss(params, x_train, y_train)
    step_updates, next_state = optimizer.update(step_grads, state)
    return optax.apply_updates(params, step_updates), next_state


# 5000 full-batch steps (no minibatching).
for epoch in range(5000):
    params_dnn, opt_state = _dnn_step(params_dnn, opt_state)

In [None]:
initial_ntk = compute_ntk(x_train, x_train, params_init)
after_ntk = compute_ntk(x_train, x_train, params_dnn)

In [None]:
def kare(y, K, z, n):
    """Kernel Alignment Risk Estimator (KARE), Jacot et al. (2020).

        KARE(z) = [(1/n) y^T (K/n + z I)^{-2} y] / [(1/n) Tr (K/n + z I)^{-1}]^2

    KARE estimates the test risk of kernel ridge regression with ridge ``z``
    using training data only. The two terms are combined as a RATIO. The
    previous version subtracted them, which is not the estimator.

    Args:
        y: training targets, shape ``(n,)`` or ``(n, k)``. For ``k > 1`` the
            numerator sums the per-column quadratic forms.
        K: ``(n, n)`` training Gram matrix.
        z: ridge parameter; ``z > 0`` keeps the system invertible for PSD ``K``.
        n: number of training points.

    Returns:
        Scalar KARE value.
    """
    K_norm = K / n
    resolvent = jnp.linalg.inv(K_norm + z * jnp.eye(n))
    y_2d = jnp.reshape(y, (n, -1))
    # (1/n) * y^T R^2 y, summed over output columns (R = resolvent).
    numerator = jnp.sum(y_2d * (resolvent @ (resolvent @ y_2d))) / n
    # Normalised trace of the resolvent.
    mean_trace = jnp.trace(resolvent) / n
    return numerator / mean_trace**2


def kare_objective(params, z=1e-3):
    """KARE of the empirical NTK at ``params`` on the training set.

    Args:
        params: network parameters at which the empirical NTK is evaluated.
        z: KARE ridge parameter (previously hard-coded as 1e-3; still the
            default, so ``grad(kare_objective)(params)`` behaves as before).

    Returns:
        Scalar KARE value, computed from the module-level ``x_train``,
        ``y_train`` and ``n_train``.
    """
    ntk_train = compute_ntk(x_train, x_train, params)
    return kare(y_train, ntk_train, z=z, n=n_train)

In [None]:
# Jit-compiled gradient of the KARE objective with respect to the parameters.
grad_kare = jit(grad(kare_objective))

# Fine-tune starting from the supervised-trained weights, not from the init.
params_kare = params_dnn
optimizer_kare = optax.adam(learning_rate=1e-6)
# Initialise the optimizer state from the parameters actually being optimised
# (previously params_init). This is identical for Adam but stays correct if the
# optimizer is swapped for one whose state depends on the parameter values.
opt_state_kare = optimizer_kare.init(params_kare)

In [None]:
def _kare_step(params, state):
    """One Adam step on the KARE objective; returns (params, state)."""
    step_grads = grad_kare(params)
    step_updates, next_state = optimizer_kare.update(step_grads, state)
    return optax.apply_updates(params, step_updates), next_state


# 1000 steps of KARE minimisation.
for epoch in range(1000):
    params_kare, opt_state_kare = _kare_step(params_kare, opt_state_kare)

In [None]:
def kernel_predict(
    kernel_matrix_train, x_test, y_train, params, lambd=1e-6, n=n_train
):
    """Kernel ridge regression predictions using the empirical NTK at ``params``.

    Computes ``(1/n) * K(x_test, X) @ (K/n + lambd * I)^{-1} @ y_train``, where
    ``K`` is ``kernel_matrix_train`` and ``X`` is the module-level ``x_train``
    (not a parameter of this function).

    Args:
        kernel_matrix_train: ``(n, n)`` train/train kernel matrix. Presumably
            ``compute_ntk(x_train, x_train, params)``; confirm at the call site.
        x_test: inputs to predict at.
        y_train: training targets.
        params: parameters at which the test/train NTK is evaluated.
        lambd: ridge regulariser added to ``K/n``.
        n: number of training points.

    Returns:
        Predictions at ``x_test``.
    """
    K_test_train = compute_ntk(x_test, x_train, params)
    K_norm = kernel_matrix_train / n
    # Solve the regularised system rather than forming an explicit inverse:
    # with lambd as small as 1e-6 the matrix can be badly conditioned.
    alpha = jnp.linalg.solve(K_norm + lambd * jnp.eye(n), y_train)
    return (1 / n) * K_test_train @ alpha


def mse(pred, true):
    """Mean squared error between ``pred`` and ``true`` (no 1/2 factor)."""
    error = pred - true
    return jnp.mean(jnp.square(error))

In [None]:
# Train/train empirical NTK at the KARE-fine-tuned parameters.
kare_ntk = compute_ntk(x1=x_train, x2=x_train, params=params_kare)

In [None]:
# Test-set predictions: the trained network itself, and kernel ridge
# regression with the NTK taken at init, after training, and after KARE.
dnn_preds = apply_fn(params_dnn, x_test)
initial_ntk_preds = kernel_predict(initial_ntk, x_test, y_train, params_init)
after_ntk_preds = kernel_predict(after_ntk, x_test, y_train, params_dnn)
kare_preds = kernel_predict(kare_ntk, x_test, y_train, params_kare)

# Test MSE per method, printed as an aligned "label = value" table.
_test_mse = {
    "Neural network": mse(dnn_preds, y_test),
    "Initial NTK": mse(initial_ntk_preds, y_test),
    "After NTK": mse(after_ntk_preds, y_test),
    "NTK KARE": mse(kare_preds, y_test),
}
print("\n".join(f"{label:<14} = {value}" for label, value in _test_mse.items()))