### Setup: locate (or clone) the repo and make `src` importable

In [None]:
import os
import subprocess
import sys
from pathlib import Path

# Environment setup. This cell is written to be safe to re-run: it never clones
# twice, never `cd`s twice, and never prepends duplicate path entries.
REPO_NAME = "cld_optimization_experiments"
REPO_URL = f"https://github.com/oopir/{REPO_NAME}"
REPO_BRANCH = "mnist"

if "google.colab" in sys.modules:
    from google.colab import drive

    drive.mount("/content/drive")

    # On a re-run we are already inside the repo; checking the cwd first avoids
    # cloning the repo into itself and failing on a second `cd`.
    if Path.cwd().name != REPO_NAME:
        if not os.path.isdir(REPO_NAME):
            # check=True surfaces clone failures instead of continuing silently.
            subprocess.run(
                ["git", "clone", "-b", REPO_BRANCH, "--single-branch", REPO_URL],
                check=True,
            )
        os.chdir(REPO_NAME)
else:
    # Locally the notebook is assumed to live one directory below the repo root
    # (the original cell did `%cd ..`). Step up only if `src/` is not already
    # visible here, so re-running the cell does not keep climbing the filesystem.
    # NOTE(review): assumes the notebook's own directory has no `src/` folder.
    if not (Path.cwd() / "src").is_dir():
        os.chdir("..")

ROOT = Path.cwd()
print(ROOT)

# Make `src` importable in this kernel (both Colab and local runs)...
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# ...and in any child process that inherits the environment. Only prepend once.
_pythonpath_entries = [p for p in os.environ.get("PYTHONPATH", "").split(os.pathsep) if p]
if str(ROOT) not in _pythonpath_entries:
    os.environ["PYTHONPATH"] = os.pathsep.join([str(ROOT), *_pythonpath_entries])

In [None]:
import os
import math
import numpy as np
import torch
import matplotlib.pyplot as plt

from src.training import train_multiseed

# Quick environment check: print the PyTorch version, then show
# (CUDA available?, number of visible CUDA devices) as the cell output.
print(torch.__version__)
cuda_status = (torch.cuda.is_available(), torch.cuda.device_count())
cuda_status

### Configuration

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
n = 60000  # passed to train_multiseed as `n` (presumably # MNIST training samples — confirm)
m = 10000  # passed to train_multiseed as `m` — TODO confirm meaning (e.g. hidden width?)

# NOTE(review): the original cell had the placeholder `epochs = ?????`, which is a
# SyntaxError and made the whole notebook unrunnable. 1000 is a stand-in value —
# set the intended number of epochs before running the sweep.
epochs = 1000
assert isinstance(epochs, int) and epochs > 0, "epochs must be a positive int"

eta = 1e-2
beta_fixed = 1e7  # same value as the original `1e07`; held fixed across the alpha sweep
seed = 0

# Interval (in epochs) between tracked metrics / progress prints: ~epochs/100, at least 1.
track_every = max(1, epochs // 100)
print_every = max(1, epochs // 100)

alpha_grid = np.logspace(-2, 2, num=13)  # 13 log-spaced values, 0.01 ... 100
alpha_grid

### Run the $\alpha$ sweep and plot test error

In [None]:
def run_single_alpha(alpha):
    """Run one MNIST training job with initialization scale ``alpha``.

    Every other setting comes from the module-level config cell (``n``, ``m``,
    ``eta``, ``epochs``, ``beta_fixed``, ``seed``, ``device``, ``track_every``,
    ``print_every``). Only ``[seed]`` is trained, so the returned value is
    whatever ``train_multiseed`` produces for that single seed.
    """
    print(f"\n=== alpha = {alpha:.3g}, beta = {beta_fixed:.3g} ===")

    train_kwargs = dict(
        # data
        dataset="mnist",
        seeds=[seed],
        n=n,
        random_labels=False,
        # optimization
        eta=eta,
        epochs=epochs,
        beta=beta_fixed,
        # model / initialization
        m=m,
        init_type="alpha",
        alpha=alpha,           # the swept quantity
        lam_fc1=None,
        lam_fc2=None,
        regularization_scale=1.0,
        use_linearized=False,
        # tracking / logging
        track_jacobian=False,
        jac_probe_size=1,
        track_every=track_every,
        print_every=print_every,
        # hardware
        device=device,
        gpu_ids=[0],           # or None / adapt if you use multiple GPUs
        # dynamics
        add_noise=True,        # CLD noise on
    )
    return train_multiseed(**train_kwargs)

In [None]:
# Run the full alpha sweep: results_by_alpha[alpha] = train_multiseed output for that alpha.
# Keys are plain Python floats (alpha_grid yields np.float64). They print cleanly and
# keep numpy scalar types out of the pickled checkpoint.
save_path = "mnist_alpha_sweep_results.pt"
results_by_alpha = {}

for alpha in alpha_grid:
    alpha_key = float(alpha)
    results_by_alpha[alpha_key] = run_single_alpha(alpha_key)
    # Checkpoint after every alpha, so a crash or runtime disconnect mid-sweep
    # does not lose the runs that already finished.
    torch.save(results_by_alpha, save_path)

print(f"saved to {save_path}")


In [None]:
def plot_alpha_sweep_test_error(results_by_alpha, label=None):
    """Plot test error (%) against alpha on a log-scaled x axis.

    results_by_alpha: dict[alpha -> dict[seed -> metrics]], where each ``metrics``
        dict contains a ``"test_acc_hist"`` sequence. Accuracies are assumed to be
        fractions in [0, 1] — TODO confirm against ``train_multiseed``.
    label: optional legend label; a legend is drawn only when it is given.

    For each alpha, the plotted value is 100 * (1 - best test accuracy over
    training), averaged over all seeds present. With a single seed this is exactly
    that seed's value. NOTE: "best over training" selects on the test set
    itself, so it is an optimistic estimate of generalization error.

    Raises ValueError if ``results_by_alpha`` is empty.
    """
    if not results_by_alpha:
        raise ValueError("results_by_alpha is empty; nothing to plot")

    sorted_alphas = sorted(results_by_alpha)
    test_errors = np.array([
        np.mean([
            100.0 * (1.0 - float(np.max(np.asarray(metrics["test_acc_hist"]))))
            for metrics in results_by_alpha[alpha].values()
        ])
        for alpha in sorted_alphas
    ])

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.semilogx(np.asarray(sorted_alphas, dtype=float), test_errors, "--o", label=label or "NN")
    ax.set_xlabel(r"$\alpha$")
    ax.set_ylabel("Test error (%)")
    ax.set_title(r"Best test error vs. $\alpha$")
    ax.set_ylim(bottom=0)
    if label is not None:
        ax.legend()
    fig.tight_layout()
    plt.show()


plot_alpha_sweep_test_error(results_by_alpha)


### Inspect a single run

In [None]:
# Inspect the training curves of one run from the sweep. The stored alpha that is
# closest to `alpha_example` on a log10 scale is used, so the cell still works
# when the grid does not contain `alpha_example` exactly (float keys).
alpha_example = 1.0
alpha_nearest = min(
    results_by_alpha,
    key=lambda a: abs(np.log10(a) - np.log10(alpha_example)),
)
run_results_by_seed = results_by_alpha[alpha_nearest]
run_seed = next(iter(run_results_by_seed))
metrics = run_results_by_seed[run_seed]


def _tracked_epochs(hist):
    """Epoch of each entry of a tracked history: 1, 1 + track_every, 1 + 2*track_every, ...

    Equals ``np.arange(1, epochs + 1, track_every)`` whenever the history has that
    many entries, and always has exactly ``len(hist)`` entries, so plotting
    cannot fail on a length mismatch.
    NOTE(review): assumes tracking starts at epoch 1 — confirm against train_multiseed.
    """
    return 1 + track_every * np.arange(len(hist))


fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(_tracked_epochs(metrics["train_loss_hist"]), metrics["train_loss_hist"])
ax_loss.set(xlabel="epoch", title="train loss")

ax_acc.plot(_tracked_epochs(metrics["train_acc_hist"]), metrics["train_acc_hist"], label="train acc")
ax_acc.plot(_tracked_epochs(metrics["test_acc_hist"]), metrics["test_acc_hist"], label="test acc")
ax_acc.set(xlabel="epoch", title="accuracy")
ax_acc.legend()

fig.suptitle(f"alpha = {alpha_nearest:.3g}, seed = {run_seed}")
fig.tight_layout()
plt.show()


### Load saved sweep results (and plot them)

In [None]:
import torch

load_path = "mnist_alpha_sweep_results.pt"  # adjust if needed
# The checkpoint is a pickled dict written by the sweep cell. Its alpha keys may be
# numpy scalars, and the metrics may hold non-tensor objects (contents depend on
# train_multiseed). Recent PyTorch defaults torch.load to weights_only=True, which
# rejects such objects, so opt out explicitly.
# SECURITY: weights_only=False does full unpickling, which can execute arbitrary code.
# Only load files you produced yourself.
# map_location="cpu" lets results saved on a GPU machine load on a CPU-only one.
results_by_alpha_loaded = torch.load(load_path, map_location="cpu", weights_only=False)
print(f"loaded from {load_path}")
print("alphas:", sorted(results_by_alpha_loaded.keys()))


In [None]:
# Re-plot the sweep from the results loaded from disk (typo fix: was `ot_alpha_...`).
plot_alpha_sweep_test_error(results_by_alpha_loaded)