In [None]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path so that
# packages living one level above this notebook can be imported (presumably
# the `datasets` and `rubicon` packages used below — confirm).
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate entries in sys.path.
_parent_dir = str(Path.cwd().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
import os

# Set ahead of the jax / project imports below. Presumably this quiets INFO and
# WARNING output from the TensorFlow/XLA C++ backend — TODO confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from typing import Iterator

import jax.numpy as jnp

from datasets.sine import create_data_factory, load_sine, SineLoaderConfig
from rubicon.nns.mlp import LayerConfig, MLPConfig, MultiLayerPerceptron
from rubicon.nns._base import TrainingConfig, NTKTrainingConfig
from rubicon.nns.metrics.mae import MeanAbsoluteError

In [None]:
def get_input_shape(
    iterator: Iterator[tuple[jnp.ndarray, jnp.ndarray]],
) -> tuple[int, ...]:
    """Return the shape of the input array in the first batch of ``iterator``.

    Args:
        iterator: Iterable yielding ``(inputs, targets)`` pairs. Only the first
            pair is read; if a one-shot iterator is passed, that batch is
            consumed and will not be yielded again.

    Returns:
        The complete ``shape`` of the first inputs array as a plain tuple; no
        dimension is stripped.

    Raises:
        ValueError: If ``iterator`` yields no items. The original
            ``StopIteration`` is attached as ``__cause__``.
    """
    # Keep the try body minimal so only the "no first batch" case is translated.
    try:
        x_batch, _ = next(iter(iterator))
    except StopIteration as err:
        raise ValueError("Empty iterator; cannot determine input shape") from err
    # tuple() guarantees the annotated return type even for array libraries
    # whose ``shape`` is a tuple subclass or another sequence type.
    return tuple(x_batch.shape)

In [None]:
# Sine dataset sizes. `batch_size` and `dataset_config` are reused by the
# training cells further down.
batch_size = 32
n_train, n_test = 1600, 320
dataset_config = SineLoaderConfig(
    batch_size=batch_size, n_train=n_train, n_test=n_test
)

# Throwaway loader: it is only used to read the shape of its first batch,
# which is necessary for the model initialization.
temp_train_iter, _ = load_sine(dataset_config)
input_shape = get_input_shape(temp_train_iter)
input_shape

In [None]:
# MLP configuration: an output layer of size 1 and one hidden layer of size 256.
config = MLPConfig(
    output_layer=LayerConfig(size=1),
    hidden_layers=[LayerConfig(size=256)],
)
model = MultiLayerPerceptron(config)
# The model is called with the input shape and the return value is discarded
# (presumably this initializes the parameters — confirm).
model(input_shape=input_shape)
# Show the model's repr as the cell output.
model

In [None]:
# standard training: 2 epochs, with MeanAbsoluteError passed as `accuracy_fn`
training_config = TrainingConfig(
    data_factory=create_data_factory(dataset_config),
    accuracy_fn=MeanAbsoluteError(),
    batch_size=batch_size,
    num_epochs=2,
    verbose=True,
)
history = model.fit(training_config)

In [None]:
# training with kare: same data factory, epoch count and batch size as the
# standard run, plus the extra options accepted by NTKTrainingConfig.
kare_training_config = NTKTrainingConfig(
    data_factory=create_data_factory(dataset_config),
    accuracy_fn=MeanAbsoluteError(),
    batch_size=batch_size,
    num_epochs=2,
    verbose=True,
    with_kare=True,
    update_params=False,
    z=1e-3,
    lambd=1e-6,
)
# NOTE: this rebinds `history`, replacing the result of the standard run, and
# fits the same `model` object that was already trained above.
history = model.fit(kare_training_config)