In [None]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable (presumably where
# the project packages imported below live - confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
_parent_dir = str(Path.cwd().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from typing import Iterator

import jax.numpy as jnp

from datasets.sine import create_data_factory, load_sine, SineLoaderConfig
from rubicon.nns.mlp import LayerConfig, MLPConfig, MultiLayerPerceptron
from rubicon.nns._base import TrainingConfig, NTKTrainingConfig
from rubicon.nns.metrics.mae import MeanAbsoluteError

In [None]:
def get_input_shape(
    iterator: Iterator[tuple[jnp.ndarray, jnp.ndarray]],
) -> tuple[int, ...]:
    """Return the shape of the inputs in the first batch yielded by ``iterator``.

    Args:
        iterator: Iterable of ``(inputs, targets)`` pairs. The first pair is
            consumed; for a one-shot iterator that batch is gone afterwards.

    Returns:
        The full shape of the first batch's inputs (no axis is dropped).

    Raises:
        ValueError: If ``iterator`` yields nothing. The original
            ``StopIteration`` is attached as ``__cause__``.
    """
    try:
        x_batch, _ = next(iter(iterator))
    except StopIteration as exc:
        # Chain explicitly so the traceback shows this is a translated error,
        # not a second failure raised while handling the first.
        raise ValueError("Empty iterator; cannot determine input shape") from exc
    return tuple(x_batch.shape)

In [None]:
# Dataset dimensions reused by the experiments further down.
batch_size, n_train, n_test = 32, 1600, 320
dataset_config = SineLoaderConfig(
    n_train=n_train, n_test=n_test, batch_size=batch_size
)

# The model needs the input shape at initialisation time, so peek at the first
# batch of a throw-away iterator (the second value returned by load_sine is
# discarded).
temp_train_iter, _ = load_sine(dataset_config)
input_shape = get_input_shape(temp_train_iter)
input_shape

In [None]:
# One hidden layer of size 256 followed by an output layer of size 1.
config = MLPConfig(
    hidden_layers=[LayerConfig(size=256)],
    output_layer=LayerConfig(size=1),
)
model = MultiLayerPerceptron(config)
# Calling the model with the input shape presumably initialises its parameters
# - confirm against MultiLayerPerceptron.__call__.
model(input_shape=input_shape)
model

In [None]:
# Baseline run: fit the model using the plain TrainingConfig.
training_config = TrainingConfig(
    num_epochs=2,
    batch_size=batch_size,
    data_factory=create_data_factory(dataset_config),
    accuracy_fn=MeanAbsoluteError(),
    verbose=True,
)
history = model.fit(training_config)

In [None]:
# KARE run: same data factory, epochs and metric as the standard run, but
# configured through NTKTrainingConfig with `with_kare=True`.
kare_training_config = NTKTrainingConfig(
    num_epochs=2,
    batch_size=batch_size,
    data_factory=create_data_factory(dataset_config),
    accuracy_fn=MeanAbsoluteError(),
    verbose=True,
    with_kare=True,
    update_params=False,
    z=1e-3,
    lambd=1e-6,
)
# NOTE(review): this rebinds `history`, so the result of the standard training
# run is no longer reachable under that name afterwards.
history = model.fit(kare_training_config)

In [None]:
import jax
import optax
from rubicon.nns.losses import KARELoss, MSELoss


In [None]:
def compute_gradient(p, x):
    """Per-sample gradient of the summed model output w.r.t. the parameters.

    Args:
        p: Parameters accepted by ``model.apply_fn``. The gradient is iterated
            two levels deep, so ``p`` is expected to be a sequence of layers,
            each itself a sequence of arrays.
        x: Batch of inputs; the leading axis is mapped over.

    Returns:
        A 2-D array with one row per sample in ``x``; each row is the
        concatenation of the raveled parameter gradients for that sample.
    """
    def _compute_gradient(p, x):
        # Gradient of the (summed) output w.r.t. the parameters only.
        grad_fn = jax.grad(lambda p, x: model.apply_fn(p, x).sum())
        flat_grads = [
            jnp.ravel(element) for item in grad_fn(p, x) for element in item
        ]
        return jnp.concatenate(flat_grads)

    # Only `x` carries a batch axis; the parameters are shared by every sample,
    # hence in_axes=None for them. The default in_axes=0 would also try to
    # slice each parameter array along its leading axis.
    per_sample_grads = jax.vmap(_compute_gradient, in_axes=(None, 0))
    return per_sample_grads(p, x)
    

def compute_ntk(params, x1, x2):
    """Empirical NTK evaluated on aligned pairs ``(x1[i], x2[i])``.

    The flattened per-sample gradients of both batches are computed with
    ``compute_gradient`` and, for every index ``i``, their scalar product is
    taken. The result has one ``(1, 1)`` entry per pair; it is not the full
    Gram matrix between the two batches.
    """
    grads_1 = compute_gradient(params, x1)
    grads_2 = compute_gradient(params, x2)

    def _pair_kernel(g1, g2):
        # Row vector times column vector: (1, P) @ (P, 1) -> (1, 1).
        return g1[None, :] @ g2[:, None]

    return jax.vmap(_pair_kernel)(grads_1, grads_2)

def kare_loss(p, x, y, z):
    """Sum of per-sample KARE values for the batch ``(x, y)``.

    Args:
        p: Model parameters, forwarded to ``compute_ntk``.
        x: Batch of inputs.
        y: Batch of targets; mapped over its leading axis together with ``K``.
        z: Scalar multiplying the identity that is added to the normalised
            kernel. It is shared by all samples (not mapped over).

    Returns:
        The sum over the batch of the per-sample KARE values.
    """
    def _kare_loss(y, K, z):
        """Compute KARE pointwise."""
        n = K.shape[0]
        K_norm = K / n
        mat = K_norm + z * jnp.eye(n)
        # The inversion is explicitly jitted for the CPU backend.
        inv = jax.jit(jnp.linalg.inv, backend="cpu")(mat)
        inv2 = inv @ inv
        return ((1/n) * y.T @ inv2 @ y) / ((1/n) * jnp.trace(inv)) ** 2

    K = compute_ntk(p, x, x)
    # NOTE(review): when every slice of K is 1x1 (as returned by the per-pair
    # compute_ntk defined alongside this function), n == 1 and the ratio above
    # reduces algebraically to y**2, i.e. the loss does not depend on `p`.
    # Vectorise over samples of (y, K). `z` is a shared scalar, so it must not
    # be mapped: vmap's default in_axes=0 fails on a rank-0 argument.
    per_sample_kare = jax.vmap(_kare_loss, in_axes=(0, 0, None))
    return jnp.sum(per_sample_kare(y, K, z))

# Gradient of the loss w.r.t. its first argument, the parameters `p`.
grad_kare = jax.grad(kare_loss)


In [None]:
# Print each top-level entry of the model's parameter container on its own
# line, to inspect how the parameters are organised.
for layer in model.params:
    print(layer)

In [None]:
# Manual training loop driven by the KARE gradient.
datafactory = create_data_factory(dataset_config)
train_iter, _ = datafactory()
params = model.params
z = 1e-3
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

for x, y in train_iter:
    grads = grad_kare(params, x, y, z)
    # `update` is a method of the optimizer (a GradientTransformation); there
    # is no module-level `optax.update`.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Apply the step so that the next iteration differentiates at the new
    # parameters instead of re-using the initial ones.
    params = optax.apply_updates(params, updates)

In [None]:
# Build a fresh pair of iterators and keep only the second one; the first is
# discarded.
_, test_iter = create_data_factory(dataset_config)()
# Scalar passed as `z` to the KARE computations defined below (it multiplies
# the identity added to the normalised kernel).
z = 1e-3

# computing the gradient, then flatten, in batches
def flat_grad(p, x):
    """Gradient of the summed model output w.r.t. ``p``, as one flat vector.

    The gradient returned by ``jax.grad`` is walked two levels deep (outer
    container, then its members); every member is raveled and the pieces are
    concatenated in iteration order.
    """
    def _summed_output(params, inputs):
        return model.apply_fn(params, inputs).sum()

    nested_grads = jax.grad(_summed_output)(p, x)
    return jnp.concatenate(
        [jnp.ravel(leaf) for group in nested_grads for leaf in group]
    )

xs, ys = next(test_iter)
# Vectorise flat_grad over the input batch only. The parameters are shared by
# every sample, so they are passed through unmapped (in_axes=None); with the
# default in_axes=0 vmap would also try to slice each parameter array along its
# leading axis. The mapped function now honours the `p` it is given instead of
# silently reading model.params.
per_sample_grads = jax.vmap(flat_grad, in_axes=(None, 0))
grads = per_sample_grads(model.params, xs.squeeze())
grads.shape  # (32, 769)

# compute the ntk in batches
def compute_ntk(x1, x2, params=None):
    """Empirical NTK evaluated on aligned pairs ``(x1[i], x2[i])``.

    Args:
        x1: First batch of inputs (squeezed before differentiation).
        x2: Second batch of inputs (squeezed before differentiation).
        params: Parameters to differentiate at. Defaults to ``model.params``
            so that two-argument calls work.

    Returns:
        One ``(1, 1)`` scalar product of flattened gradients per pair.
    """
    if params is None:
        params = model.params
    # per_sample_grads takes (params, inputs) in that order.
    G1 = per_sample_grads(params, x1.squeeze())  # e.g. (32, 769)
    G2 = per_sample_grads(params, x2.squeeze())  # e.g. (32, 769)

    def _compute_ntk(g1, g2):
        # (1, P) @ (P, 1) -> (1, 1)
        return g1[:, None].T @ g2[:, None]

    per_sample_ntk = jax.vmap(_compute_ntk)
    return per_sample_ntk(G1, G2)

# Draw two batches and evaluate the per-pair NTK between them.
x1, _ = next(test_iter)
x2, _ = next(test_iter)
# compute_ntk takes the parameters as its third argument; pass them explicitly.
ntks = compute_ntk(x1, x2, model.params)
ntks.shape  # (32, 1, 1)

# compute an implementation of kare that supports batches
def kare(y, K, z):
    """Sum of KARE values computed independently for each ``(y[i], K[i])``.

    ``y`` and ``K`` are mapped over their leading axis; ``z`` is shared by all
    samples and multiplies the identity added to the normalised kernel.
    """
    def _single(targets, kernel):
        n = kernel.shape[0]
        regularised = kernel / n + z * jnp.eye(n)
        # The inversion is explicitly jitted for the CPU backend.
        inv = jax.jit(jnp.linalg.inv, backend="cpu")(regularised)
        inv_sq = inv @ inv
        numerator = (1/n) * targets.T @ inv_sq @ targets
        denominator = ((1/n) * jnp.trace(inv)) ** 2
        return numerator / denominator

    return jnp.sum(jax.vmap(_single)(y, K))

@jax.jit
def compute_kare(x, y, z):
    """KARE of the batch ``(x, y)`` at the model's current parameters.

    The kernel is built from ``x`` against itself with ``compute_ntk`` and then
    passed to ``kare`` together with ``y`` and ``z``; the result is squeezed.
    """
    # compute_ntk takes the parameters as its third argument. model.params is
    # read when this function is traced by jit.
    K = compute_ntk(x, x, model.params)
    return kare(y, K, z).squeeze()

# NOTE(review): jax.grad differentiates w.r.t. the first positional argument by
# default, which for compute_kare(x, y, z) is the input batch `x`, not the
# model parameters - confirm this is intended.
grad_kare = jax.grad(compute_kare)

params = model.params
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Evaluate the gradient on a single batch and show it.
x, y = next(test_iter)
grads = grad_kare(x, y, z)
print(grads)

# for x, y in test_iter:
#     grads = grad_kare(x, y, z)
#     grads = list(grads.squeeze())
#     print(grads)
#     # updates, opt_state = optimizer.update(grads, opt_state)

In [None]:
# Reset the experiment state: fresh iterators, the model's current parameters
# and a new Adam optimiser state.
z = 1e-3
params = model.params
datafactory = create_data_factory(dataset_config)
train_iter, _ = datafactory()
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)


# for x, y in train_iter:
#     grads = grad_kare(params, x, y, z)
#     updates, opt_state = optax.update(grads, opt_state) 

In [None]:
# Fresh iterators for this cell; only the first one returned by the factory is
# kept.
datafactory = create_data_factory(dataset_config)
train_iter, _ = datafactory()

def compute_gradient(f, p, xs) -> jnp.ndarray:
    """Flattened gradient of ``f(p, x)`` w.r.t. ``p`` for every ``x`` in ``xs``.

    Args:
        f: Callable ``f(p, x)`` whose output squeezes to a scalar.
        p: Parameters; the gradient is walked two levels deep (outer
            container, then its members).
        xs: Batch of inputs, mapped over axis 0.

    Returns:
        Array with one row per element of ``xs``; each row is the
        concatenation of the flattened gradient pieces.
    """
    scalar_grad = jax.grad(lambda params, sample: f(params, sample).squeeze())

    def _flat_grad_at(sample):
        pieces = [
            jnp.ravel(leaf)
            for group in scalar_grad(p, sample)
            for leaf in group
        ]
        return jnp.concatenate(pieces)

    # `p` is closed over, so only the inputs are vectorised.
    return jax.vmap(_flat_grad_at, in_axes=0)(xs)

def compute_ntk(p, xs1, xs2) -> jnp.ndarray:
    """Gram matrix of flattened model gradients between two input batches.

    Entry ``(i, j)`` is the scalar product of the gradient at ``xs1[i]`` with
    the gradient at ``xs2[j]``, both taken w.r.t. ``p`` through
    ``model.apply_fn``.
    """
    grads_1, grads_2 = (
        compute_gradient(model.apply_fn, p, batch) for batch in (xs1, xs2)
    )
    return grads_1.dot(grads_2.T)

def kare(p, x, y, z):  # y.shape=(32,1), K.shape=(32,32), z.shape=()
    """KARE of the batch ``(x, y)`` for the kernel built at parameters ``p``.

    The kernel from ``compute_ntk(p, x, x)`` is divided by its size ``n``,
    shifted by ``z`` times the identity and inverted. The returned scalar is
    ``(1/n) * y.T @ inv @ inv @ y`` divided by ``((1/n) * trace(inv)) ** 2``.
    """
    kernel = compute_ntk(p, x, x)
    n = kernel.shape[0]
    regularised = kernel / n + z * jnp.eye(n)
    # The inversion is explicitly jitted for the CPU backend.
    inv = jax.jit(jnp.linalg.inv, backend="cpu")(regularised)
    inv_sq = inv @ inv
    numerator = (1/n) * y.T @ inv_sq @ y
    denominator = ((1/n) * jnp.trace(inv)) ** 2
    # The quotient is indexed as a 2-D array to return its single entry.
    return (numerator / denominator)[0, 0]

# Evaluate KARE once on a single batch with the model's current parameters and
# z = 1e-3; the value is displayed as the cell output.
xs, ys = next(train_iter)
kare(model.params, xs, ys, 1e-3)


In [None]:
# Fresh iterators, parameters and optimiser state for the gradient check.
datafactory = create_data_factory(dataset_config)
train_iter, _ = datafactory()
params = model.params
z = 1e-3
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
# Gradient of `kare` w.r.t. its first argument, the parameters.
grad_kare = jax.grad(kare)

xs, ys = next(train_iter)
# Differentiate on the batch that was just drawn. `x` / `y` are not defined in
# this cell and would be stale leftovers from an earlier one.
grads = grad_kare(params, xs, ys, z)
print(len(grads))
print(grads)

# for x, y in train_iter:
#     grads = grad_kare(model.params, x, y, z)
#     updates, opt_state = optimizer.update(grads, opt_state)
#     params = optax.apply_updates(params, updates)
