In [None]:
import sys
from pathlib import Path
# Make the repository root (assumed to be the parent of the notebook's working
# directory — TODO confirm the notebook is always launched from its own folder)
# importable, so the project packages imported below resolve. Guarded so that
# re-running this cell does not append duplicate entries to sys.path.
_repo_root = str(Path.cwd().parent)
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [None]:
import os
# Raise TensorFlow's C++ logging threshold to '2' to quiet its startup chatter.
# It only takes effect if set before TensorFlow is imported (presumably by
# `datasets.mnist` below — TODO confirm), hence its place among the imports.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

from typing import Callable, Iterable, Iterator

import jax
from jax import numpy as jnp
import optax

from datasets.mnist import load_mnist, MnistLoaderConfig
from rubicon.nns.convnet import ConvNetConfig, ConvNet
from rubicon.ntk.ntk import NTKConfig, NeuralTangentKernel

# Show which devices JAX will run on (e.g. to confirm an accelerator is visible);
# as the cell's last expression, the returned device list is rendered as output.
jax.devices()

In [None]:
def get_input_shape(iterator: Iterable[tuple[jnp.ndarray, jnp.ndarray]]) -> tuple[int, ...]:
    """Return the shape of the inputs in the first ``(x, y)`` batch of ``iterator``.

    The full shape of ``x`` is returned, batch dimension included.

    Args:
        iterator: An iterable (or iterator) yielding ``(x_batch, y_batch)`` pairs.
            ``iter()`` is called on it, so plain iterables such as lists work too.
            NOTE: if a one-shot iterator is passed, its first batch is consumed.

    Returns:
        ``x_batch.shape`` of the first batch, as a plain ``tuple``.

    Raises:
        ValueError: If ``iterator`` yields no batches.
    """
    try:
        x_batch, _ = next(iter(iterator))
    except StopIteration:
        # The StopIteration is an implementation detail; don't chain it into
        # the user-facing traceback.
        raise ValueError("Empty iterator; cannot determine input shape") from None
    # tuple() normalises framework-specific shape objects to the tuple promised
    # by the return annotation.
    return tuple(x_batch.shape)

def create_data_factory(
    config: MnistLoaderConfig,
) -> Callable[
    [],
    tuple[
        Iterator[tuple[jnp.ndarray, jnp.ndarray]],
        Iterator[tuple[jnp.ndarray, jnp.ndarray]],
    ],
]:
    """Create a zero-argument factory that loads MNIST train and test iterators.

    Each call of the returned function calls ``load_mnist(config=config)`` again,
    so callers presumably receive fresh, unconsumed iterators every time (likely
    why the training APIs take a factory rather than iterators — TODO confirm).

    Args:
        config: Loader configuration, captured by the returned closure.

    Returns:
        A callable returning the pair ``(train_iter, test_iter)``.
    """
    def data_factory() -> tuple[
        Iterator[tuple[jnp.ndarray, jnp.ndarray]],
        Iterator[tuple[jnp.ndarray, jnp.ndarray]],
    ]:
        """Load MNIST with the captured ``config``; return ``(train_iter, test_iter)``."""
        # Unpacking enforces that load_mnist yields exactly a (train, test) pair.
        train_iter, test_iter = load_mnist(config=config)
        return train_iter, test_iter
    return data_factory


In [None]:
mnist_loader_config = MnistLoaderConfig(batch_size=128)
data_factory = create_data_factory(mnist_loader_config)

# Probe the input shape through the same factory the training code uses, so the
# probe goes through exactly the same load_mnist call. The probe iterator is
# dropped afterwards so it doesn't linger in the kernel namespace.
probe_train_iter = data_factory()[0]
input_shape = get_input_shape(probe_train_iter)  # includes the batch dimension
del probe_train_iter

In [None]:
config = ConvNetConfig(
    conv_filters=[32],
    kernel_sizes=[(3, 3)],
    dense_sizes=[10],
    # Presumably must match MnistLoaderConfig(batch_size=128) above, since
    # input_shape was probed from a loader batch — keep the two in sync.
    batch_size=128,
    input_shape=input_shape,
    num_epochs=10,
    seed=42,
    # NOTE: the nonlinearity can presumably be swapped via `activation_function`
    # (e.g. `stax.Gelu`, which requires importing `stax`) — TODO confirm.
)
net = ConvNet(config)
net.initialize()
# Check the flag instead of only displaying it, so a failed init stops the notebook.
assert net.initialized, "ConvNet.initialize() did not leave the network initialized"

In [None]:
# NTK / KARE hyperparameters. `optimizer` is given the optax factory itself
# (not an instantiated optimizer); the learning rate is passed separately.
# NOTE(review): `z` is presumably the KARE ridge parameter and `lambd` an extra
# regularisation weight — TODO confirm against NTKConfig.
ntk_config = NTKConfig(
    optimizer=optax.adam,
    learning_rate=1e-3,
    z=1e-4,
    lambd=1e-6,
)
ntk = NeuralTangentKernel.from_convnet(net, config=ntk_config)

In [None]:
# NOTE(review): this KARE run is temporarily placed *before* `net.train` so its
# results are available without waiting for the long ConvNet training run.
# Whether `ntk` shares parameters with `net` (and hence whether this run affects
# or mirrors the post-training run below) is not visible here — TODO confirm,
# then move or remove this cell once the ordering is final.
kare_history_before_training = ntk.train_with_kare(
    data_factory=data_factory,
    num_epochs=10,
    start_from_init=True,
    return_metrics=True,
    verbose=True,
)
kare_history_before_training

In [None]:
# Standard training of the ConvNet; the returned metrics are kept in
# `training_history` (epoch count presumably comes from ConvNetConfig — TODO confirm).
training_history = net.train(
    verbose=True,
    return_metrics=True,
    data_factory=data_factory,
)


In [None]:
# KARE run after ConvNet training, with the same arguments as the earlier run.
# NOTE(review): with start_from_init=True this may start from the same point as
# the pre-training run, depending on how `ntk` relates to `net` — TODO confirm.
kare_history_after_training = ntk.train_with_kare(
    data_factory=data_factory,
    num_epochs=10,
    start_from_init=True,
    return_metrics=True,
    verbose=True,
)
kare_history_after_training