In [None]:
import sys
import os
from pathlib import Path

if "google.colab" in sys.modules:
    from google.colab import drive
    drive.mount("/content/drive")

    if not os.path.isdir("cld_optimization_experiments"):
        !git clone https://github.com/oopir/cld_optimization_experiments
    %cd cld_optimization_experiments
else:
    %cd ..
    ROOT = Path.cwd()

    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

    os.environ["PYTHONPATH"] = str(ROOT) + os.pathsep + os.environ.get("PYTHONPATH", "")

In [None]:
import os
import glob
from datetime import datetime

import numpy as np
import random
import torch

from src.training import train_multiseed
from src.utils import select_idle_gpus_for_experiment
from src.plots import plot_ex1_multiseed, plot_ex1_multiseed_short
from src.metric_checkpoints import Exp1Config, save_exp1_checkpoint, load_exp1_checkpoint

# Drop cached CUDA allocations left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): util_threshold=5 is presumably a utilisation percentage below
# which a GPU counts as idle - confirm against src.utils.
gpu_ids = select_idle_gpus_for_experiment(device=device, util_threshold=5)
print(f"Using GPUs: {gpu_ids}")

# VERY IMPORTANT THINGS TO SET

In [None]:
# Run-mode switches.
SAVE_CHECKPOINT = True   # after a fresh training run, write results + config to CKPT_DIR
USE_CHECKPOINT = False   # True: skip training and load a saved checkpoint instead

# Directory for exp1 checkpoints. "~" must be expanded explicitly: os.makedirs,
# glob and os.path.join treat it as an ordinary directory name, which would
# create and use "./~/cld_checkpoints/expr1" under the current working directory.
CKPT_DIR = os.path.expanduser("~/cld_checkpoints/expr1")

# Checkpoint loaded when USE_CHECKPOINT is True. Set to None to fall back to the
# most recently modified "exp1_digits_*.pt" in CKPT_DIR.
# NOTE(review): machine-specific absolute path that points into a literal "~"
# directory inside the repo - confirm it is still the intended file.
CKPT_PATH = "/home/ofirg/cld_optimization_experiments/~/cld_checkpoints/expr1/exp1_digits_20260126_161327.pt"

In [None]:
# Hyperparameters for a fresh training run. They are skipped when loading a
# checkpoint, because the loaded checkpoint supplies its own config.
if not USE_CHECKPOINT:
    n = 10
    eta = 1e-5                       # NOTE(review): presumably the step size - confirm
    epochs = 200
    seeds = [*range(1)]              # raise the range bound to run more seeds
    betas_to_plot = [multiple * n for multiple in (10, 50, 100)]

In [None]:
# Make sure the checkpoint directory exists before it is globbed or written to.
# os.makedirs does not expand "~", so CKPT_DIR has to be a concrete path here.
os.makedirs(CKPT_DIR, exist_ok=True)

def latest_exp1_checkpoint():
    """Return the most recently modified exp1 checkpoint in CKPT_DIR.

    Looks for files matching ``exp1_digits_*.pt`` directly inside ``CKPT_DIR``
    and picks the one with the largest modification time.

    Returns:
        The path of the newest matching file, or ``None`` if there is none.
    """
    pattern = os.path.join(CKPT_DIR, "exp1_digits_*.pt")
    return max(glob.glob(pattern), key=os.path.getmtime, default=None)

# Produce `results` and `config` for the plotting cells, either by loading a
# saved checkpoint or by training from scratch (and optionally saving).
if USE_CHECKPOINT:
    # An explicit CKPT_PATH wins; otherwise use the newest checkpoint in CKPT_DIR.
    ckpt_path = CKPT_PATH or latest_exp1_checkpoint()
    if ckpt_path is None:
        raise FileNotFoundError(
            "No exp1 checkpoint found; set USE_CHECKPOINT=False to train."
        )
    results, config = load_exp1_checkpoint(ckpt_path)
    print(f"Loaded checkpoint: {ckpt_path}")

else:
    # m is the largest n*ln(n)*beta*ln(beta) over the requested betas, with a
    # floor of 4096.
    # NOTE(review): presumably the model width passed to training - confirm.
    m = max(n * np.log(n) * beta * np.log(beta) for beta in betas_to_plot)
    m = int(max(4096, m))
    print(f"m={m}")

    config = Exp1Config(
        seeds=seeds,
        n=n,
        random_labels=False,
        betas=betas_to_plot,
        epochs=epochs,
        eta=eta,
        m=m,
        device=device,
        track_every=max(1, epochs // 100),
        # NOTE(review): unlike track_every this is not clamped to >= 1, so it is 0
        # when epochs < 5 - confirm how the training loop treats print_every=0.
        print_every=epochs // 5,
    )

    common = config.train_kwargs()
    common["gpu_ids"] = gpu_ids

    # One multi-seed training run per beta, keyed by a label that expresses beta
    # as a multiple of n. "\u03b2" is the Greek letter beta; the previous
    # literal was its mis-decoded (mojibake) form.
    results = {
        f"\u03b2={beta // config.n}n": train_multiseed(dataset="digits", beta=beta, **common)
        for beta in config.betas
    }

    if SAVE_CHECKPOINT:
        # Timestamped file name so successive runs never overwrite each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        ckpt_path = os.path.join(CKPT_DIR, f"exp1_digits_{timestamp}.pt")
        save_exp1_checkpoint(ckpt_path, results, config)
        print(f"Saved checkpoint: {ckpt_path}")


In [None]:
# plot_ex1_multiseed_short(results, config.epochs, config.track_every)

In [None]:
# Plot the multi-seed results for every beta. epochs and track_every are read
# from `config`, which is set by both the checkpoint-load and the training path.
plot_ex1_multiseed(results, config.epochs, config.track_every)