In [1]:
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit
import neural_tangents as nt
from neural_tangents import stax
import optax

In [2]:
# Reproducible toy 1-D regression problem: noisy sin(x) sampled on [-pi, pi].
key = random.PRNGKey(42)
# Split off a dedicated key for the label noise. JAX keys must never be reused,
# and `key` is consumed again later for parameter initialization (`init_fn`).
key, noise_key = random.split(key)
n_train, n_test = 20, 50
x_train = jnp.linspace(-jnp.pi, jnp.pi, n_train)[:, None]
y_train = jnp.sin(x_train) + 0.1 * random.normal(noise_key, shape=(n_train, 1))
# The test grid extends 1 unit past the training range on the right (up to
# pi + 1), so part of every test MSE below measures extrapolation.
x_test = jnp.linspace(-jnp.pi, jnp.pi + 1, n_test)[:, None]
y_test = jnp.sin(x_test)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1758506173.043028 1968454 service.cc:145] XLA service 0x17fe08eb0 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758506173.043099 1968454 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1758506173.048027 1968454 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1758506173.048040 1968454 mps_client.cc:384] XLA backend will use up to 5726240768 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



In [3]:
# Two-layer ReLU MLP (hidden width 512, scalar output). `stax` is the same
# module as `nt.stax`; `kernel_fn` is the matching analytic kernel function.
layers = (
    stax.Dense(512, W_std=1.0, b_std=0.05),
    stax.Relu(),
    stax.Dense(1, W_std=1.0, b_std=0.05),
)
init_fn, apply_fn, kernel_fn = stax.serial(*layers)
output_shape, params_init = init_fn(key, input_shape=x_train.shape)
predictions_init = apply_fn(params_init, x_train)

In [4]:
def compute_ntk(x1, x2, params):
    """Empirical NTK of the module-level ``apply_fn`` between ``x1`` and ``x2``.

    A fresh ``nt.empirical_ntk_fn`` is built on every call and evaluated at
    ``params``. Callers in this notebook treat the result as a
    ``(len(x1), len(x2))`` matrix; presumably that relies on the default
    ``trace_axes`` and the scalar network output (TODO confirm).
    """
    return nt.empirical_ntk_fn(apply_fn)(x1, x2, params)

def mse_loss(params, x, y):
    """Half mean squared error of ``apply_fn(params, x)`` against targets ``y``."""
    residual = apply_fn(params, x) - y
    return 0.5 * jnp.mean(jnp.square(residual))

In [5]:
# Gradient-descent baseline: Adam on the half-MSE loss, starting from the same
# initialization at which the "initial" NTK is evaluated.
optimizer = optax.adam(learning_rate=1e-3)
params_dnn = params_init
opt_state = optimizer.init(params_dnn)
grad_loss = jit(grad(mse_loss))

In [14]:
# Train the network with Adam. Restart from `params_init` with a fresh optimizer
# state so that re-running this cell reproduces the same `params_dnn`, instead
# of silently continuing training from wherever the previous run stopped.
params_dnn = params_init
opt_state = optimizer.init(params_dnn)
for _ in range(5000):
    grads = grad_loss(params_dnn, x_train, y_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_dnn = optax.apply_updates(params_dnn, updates)

In [7]:
# Empirical train-train NTK at initialization and after the gradient-descent fit.
initial_ntk, after_ntk = (
    compute_ntk(x_train, x_train, p) for p in (params_init, params_dnn)
)

In [8]:
def kare(y, K, z, n=None):
    r"""Kernel Alignment Risk Estimator (KARE) for kernel ridge regression.

    Computes

        KARE(z) = [ (1/n) y^T (K/n + z I)^{-2} y ] / [ (1/n) Tr (K/n + z I)^{-1} ]^2

    Note this is the *ratio* of the two terms, not their difference.

    Args:
        y: Training targets with ``n`` rows, shape ``(n,)`` or ``(n, k)``.
            For several columns the numerator sums over the columns.
        K: Train-train kernel (Gram) matrix, shape ``(n, n)``. Assumed
            symmetric, as a Gram matrix K(X, X) is.
        z: Ridge added to the normalized kernel ``K / n``.
        n: Number of training points. Defaults to ``K.shape[0]``.

    Returns:
        Scalar KARE value (differentiable w.r.t. ``K``).
    """
    if n is None:
        n = K.shape[0]
    mat = K / n + z * jnp.eye(n)
    inv = jnp.linalg.inv(mat)
    # For symmetric `mat`, y^T mat^{-2} y == ||mat^{-1} y||^2; this avoids the
    # O(n^3) `inv @ inv` product and works for 1-D or 2-D `y`.
    alpha = inv @ y
    numerator = jnp.sum(alpha ** 2) / n
    denominator = (jnp.trace(inv) / n) ** 2
    return numerator / denominator

def kare_objective(params, z=1e-3):
    """KARE of the empirical NTK at ``params`` on the training set.

    Args:
        params: Network parameters; the argument ``grad`` differentiates.
        z: Ridge passed to ``kare`` (default 1e-3). Note that ``kernel_predict``
            uses its own ``lambd`` for the final predictions.

    Returns:
        Scalar KARE value computed from ``x_train`` / ``y_train``.
    """
    ntk_train = compute_ntk(x_train, x_train, params)
    return kare(y_train, ntk_train, z=z, n=n_train)
    

In [9]:
# Gradient of the KARE objective w.r.t. the network parameters.
grad_kare = jit(grad(kare_objective))
# Adam with a learning rate 1000x smaller than the one used for the MSE fit.
optimizer_kare = optax.adam(learning_rate=1e-6)
# KARE fine-tuning starts from the gradient-descent-trained network.
params_kare = params_dnn
opt_state_kare = optimizer_kare.init(params_kare)

In [10]:
# Fine-tune the network by minimizing KARE. Restart from `params_dnn` with a
# fresh optimizer state so re-running this cell gives the same `params_kare`
# instead of continuing from the previous run.
# NOTE(review): on the experimental METAL backend, compiling `grad_kare` fails
# with "failed to legalize operation 'mhlo.triangular_solve'", raised from the
# `jnp.linalg.inv` call in `kare`. Presumably this needs the CPU backend
# (e.g. JAX_PLATFORMS=cpu) — confirm.
params_kare = params_dnn
opt_state_kare = optimizer_kare.init(params_kare)
for _ in range(1000):
    grads = grad_kare(params_kare)
    updates, opt_state_kare = optimizer_kare.update(grads, opt_state_kare)
    params_kare = optax.apply_updates(params_kare, updates)

XlaRuntimeError: UNKNOWN: /var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: error: failed to legalize operation 'mhlo.triangular_solve'
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:13:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/833026880.py:2:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3579:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3519:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3336:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:128:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3132:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3077:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:577:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: note: see current operation: %553 = "mhlo.triangular_solve"(%475#4, %552) {left_side = true, lower = true, transpose_a = #mhlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xf32>


In [11]:
def kernel_predict(kernel_matrix_train, x_test, y_train, params, lambd=1e-6, n=None):
    """Kernel ridge regression predictions using the empirical NTK at ``params``.

    Computes ``(1/n) K(x_test, x_train) (K_train / n + lambd * I)^{-1} y_train``.

    Args:
        kernel_matrix_train: Train-train kernel ``K_train``, shape ``(n, n)``.
        x_test: Inputs to predict at.
        y_train: Training targets (``n`` rows).
        params: Network parameters at which the test-train NTK is evaluated.
        lambd: Ridge added to the normalized kernel ``K_train / n``.
        n: Normalization constant. Defaults to ``kernel_matrix_train.shape[0]``.

    Returns:
        Predictions with one row per test input.

    NOTE(review): the test-train kernel is taken against the *global*
    ``x_train``; it must be the same set ``kernel_matrix_train`` was built from.
    """
    if n is None:
        n = kernel_matrix_train.shape[0]
    K_test_train = compute_ntk(x_test, x_train, params)
    K_norm = kernel_matrix_train / n
    # Solve the linear system rather than forming an explicit inverse. This is
    # cheaper and more accurate, which matters with a ridge as small as 1e-6.
    alpha = jnp.linalg.solve(K_norm + lambd * jnp.eye(n), y_train)
    return (K_test_train @ alpha) / n

def mse(pred, true):
    """Mean squared error between ``pred`` and ``true``, averaged over all entries."""
    err = pred - true
    return jnp.mean(err * err)

In [12]:
# Empirical train-train NTK of the KARE-fine-tuned network.
kare_ntk = compute_ntk(x1=x_train, x2=x_train, params=params_kare)

In [13]:
# Test predictions: the trained network itself, plus kernel ridge regression
# with the empirical NTK at initialization, after GD training, and after KARE.
dnn_preds = apply_fn(params_dnn, x_test)
initial_ntk_preds, after_ntk_preds, kare_preds = (
    kernel_predict(ntk, x_test, y_train, p)
    for ntk, p in (
        (initial_ntk, params_init),
        (after_ntk, params_dnn),
        (kare_ntk, params_kare),
    )
)

print(
    f"Neural network = {mse(dnn_preds, y_test)}\n"
    f"Initial NTK    = {mse(initial_ntk_preds, y_test)}\n"
    f"After NTK      = {mse(after_ntk_preds, y_test)}\n"
    f"NTK KARE       = {mse(kare_preds, y_test)}"
)

XlaRuntimeError: UNKNOWN: /var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: error: failed to legalize operation 'mhlo.triangular_solve'
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:13:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/833026880.py:2:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3579:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3519:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3336:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:128:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3132:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3077:0: note: called from
/Users/nahum/Programming/research/rubicon/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:577:0: note: called from
/var/folders/nw/8ph9ksx515nc537nnyyjj34w0000gn/T/ipykernel_24425/1887060143.py:4:0: note: see current operation: %126 = "mhlo.triangular_solve"(%48#4, %125) {left_side = true, lower = true, transpose_a = #mhlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xf32>
