## Imports

In [1]:
from config.config import (
    Config,
    LinearModelConfig,
    ModelConfig,
    RandomRegressionConfig,
    TrainingConfig,
    TrainingConfig,
    get_config,
)
from main import create_dataset_and_model, train_and_evaluate
from data.data import create_dataset

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import random

from config.config import (
    Config,
    LinearModelConfig,
    ModelConfig,
    RandomRegressionConfig,
    TrainingConfig,
    TrainingConfig,
    get_config,
)
from main import train_and_evaluate
import jax
import jax.numpy as jnp

## KFAC api


In [None]:
# config = get_config("random_regression_single_feature")
config = Config(
    dataset=RandomRegressionConfig(
        n_samples=100,
        n_features=1,
        n_targets=1,
        noise=30,
        random_state=42,
        train_test_split=1,
    ),
    model=LinearModelConfig(
        name="linear", loss="mse", hidden_dim=[10, 15, 30, 2, 4, 19, 10]
    ),
    training=TrainingConfig(
        epochs=0,
        lr=0.01,
        optimizer="sgd",
        loss="mse",
    ),
)

key = jax.random.PRNGKey(42)

dataset = create_dataset(config.dataset)



In [3]:
from kfac import LinearModel

model_kwargs = vars(config.model).copy()
model_kwargs.pop("name")
model_kwargs.pop("loss")
model_kwargs.update(
    {"input_dim": dataset.input_dim(), "output_dim": dataset.output_dim()}
)

model = LinearModel(**model_kwargs)

In [4]:
key = jax.random.PRNGKey(42)
params = model.init(key, jnp.ones((1, 1)))["params"]

dataset.split_dataset()
x_train, y_train = dataset.get_train_data()
x_train = jnp.array(x_train)
y_train = jnp.array(y_train)

In [5]:
def loss_fn_with_capture(params, x, y, model, collector):
    """
    Calculates the loss by calling the model's kfac_apply method.
    """
    pred = model.apply(
        {"params": params},
        x,
        collector,
        method=model.kfac_apply,
    )

    return jnp.mean((pred - y) ** 2)

In [6]:
def normal_loss_fn(params, x, y, model):
    """Standard loss function without KFAC wrapper."""
    pred = model.apply({"params": params}, x)
    return jnp.mean((pred.reshape(y.shape) - y) ** 2)


In [7]:
from kfac import KFACOptimizer, KFACCollector


# This function does not need any changes.
def train_step(params, opt_state, x, y):
    collector = KFACCollector()
    loss_fn_for_grad = lambda p: loss_fn_with_capture(p, x, y, model, collector)  # type: ignore

    loss, grads = jax.value_and_grad(loss_fn_for_grad)(params)

    captured_data_for_step = collector.captured_data

    updated_params, new_opt_state = kfac_optimizer.step(
        params, grads, opt_state, captured_data_for_step
    )

    return updated_params, new_opt_state, loss, collector


# --- Let's run it again ---
kfac_optimizer = KFACOptimizer()
opt_state = kfac_optimizer.init(params)
losses = []

epochs = 3
batch_size = 9
for epoch in range(epochs):
    print(f"\n--- Epoch {epoch} ---")
    batch_x = x_train[epoch * batch_size : (epoch + 1) * batch_size]
    batch_y = y_train[epoch * batch_size : (epoch + 1) * batch_size]

    params, opt_state, loss, collector = train_step(params, opt_state, batch_x, batch_y)
    ground_truth_grads = jax.grad(normal_loss_fn)(params, batch_x, batch_y, model)

    losses.append(float(loss))
    print(f"Epoch {epoch}: loss={loss:.4f}")
    break

print("\nFinal Optimizer State Covariances:")
# Use jax.tree_map to print shapes for clarity



--- Epoch 0 ---


Epoch 0: loss=1447.8220

Final Optimizer State Covariances:


In [8]:
ground_truth_grads["linear_1"]["kernel"]


Array([[  6.169838  ,   9.5792    ,  -1.229347  ,   4.0899725 ,
         -5.266316  ,   1.7055271 ,  -2.9594076 ,  -1.6442432 ,
         -3.0903447 ,   7.8578806 ,   2.159025  , -11.400885  ,
         -0.52694327,  -2.5515263 ,  -0.150404  ],
       [  7.6029997 ,  11.804305  ,  -1.5149062 ,   5.0400124 ,
         -6.489604  ,   2.101696  ,  -3.6468341 ,  -2.026177  ,
         -3.8081863 ,   9.68315   ,   2.6605344 , -14.049141  ,
         -0.6493444 ,  -3.1442084 ,  -0.18534063],
       [  3.7146704 ,   5.767342  ,  -0.7401523 ,   2.4624472 ,
         -3.170688  ,   1.0268458 ,  -1.7817687 ,  -0.9899487 ,
         -1.860602  ,   4.7309895 ,   1.2998829 ,  -6.8641233 ,
         -0.31725642,  -1.536196  ,  -0.09055365],
       [ -4.5299573 ,  -7.0331445 ,   0.902599  ,  -3.0028987 ,
          3.8665826 ,  -1.2522151 ,   2.1728265 ,   1.20722   ,
          2.2689621 ,  -5.7693343 ,  -1.5851778 ,   8.370645  ,
          0.38688704,   1.8733563 ,   0.11042814],
       [  4.714502  ,   7.31

In [9]:
keys = list(ground_truth_grads.keys())

In [10]:
# outer product of a and g
for key in keys:
    a = collector.captured_data[key][0]
    g = collector.captured_data[key][1]

    ag = jnp.einsum("bi,bj->ij", a, g)

    diff = ag - ground_truth_grads[key]["kernel"]
    print(jnp.abs(diff).max(), jnp.abs(diff).sum())


0.0 0.0
2.3841858e-07 8.791685e-07
9.536743e-07 8.29529e-06
1.9073486e-06 1.8246472e-05
7.450581e-09 9.313226e-09
9.536743e-07 5.3886324e-06
4.7683716e-07 2.8163195e-06
0.0 0.0


In [11]:
# absolute value
jnp.abs(diff).max(), jnp.abs(diff).sum()

(Array(0., dtype=float32), Array(0., dtype=float32))

In [12]:
diff

Array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float32)