In [1]:
import sys

# Mount Google Drive only when the kernel is a Colab runtime; in any other
# environment this cell does nothing.
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive

    drive.mount("/content/drive")

In [None]:
import os

if not os.path.isdir("cld_optimization_experiments"):
    !git clone https://github.com/oopir/cld_optimization_experiments

%cd cld_optimization_experiments

In [None]:
import os
import glob
from datetime import datetime

import numpy as np
import random
import torch

from src.data import load_digits_data
from src.training import train_multiseed
from src.plots import plot_ex1_multiseed, plot_ex1_multiseed_short
from src.metric_checkpoints import Exp1Config, save_exp1_checkpoint, load_exp1_checkpoint

# Drop cached, unused CUDA allocations left over from earlier runs in this
# kernel, then select the GPU when one is available.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Experiment 1: either load a previously saved run or train from scratch.
# Set USE_CHECKPOINT = True to skip training and reuse saved results.
USE_CHECKPOINT = False
# Explicit checkpoint file to load; None means "use the most recent one found".
CKPT_PATH = None
# CKPT_PATH = "/content/drive/MyDrive/cld_checkpoints/?????.pt"

# Checkpoints go to Google Drive when it is mounted and to a local folder
# otherwise, so saving and "latest checkpoint" lookup use the same location.
DRIVE_ROOT = "/content/drive/MyDrive"
LOCAL_CKPT_DIR = "checkpoints"
CKPT_DIR = (
    os.path.join(DRIVE_ROOT, "cld_checkpoints")
    if os.path.isdir(DRIVE_ROOT)
    else LOCAL_CKPT_DIR
)

os.makedirs(LOCAL_CKPT_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

def latest_exp1_checkpoint():
    """Return the most recently modified exp1 checkpoint path, or None.

    Looks for files named ``exp1_digits_*.pt`` in the active checkpoint
    directory (``CKPT_DIR``) and in the local ``checkpoints`` folder, and picks
    the one with the newest modification time.

    Returns:
        str | None: path of the newest matching file, or None when no
        checkpoint exists in either location.
    """
    # dict.fromkeys de-duplicates while keeping order (both names are equal
    # when Drive is not mounted).
    search_dirs = dict.fromkeys((CKPT_DIR, LOCAL_CKPT_DIR))
    paths = [
        path
        for ckpt_dir in search_dirs
        for path in glob.glob(os.path.join(ckpt_dir, "exp1_digits_*.pt"))
    ]
    return max(paths, key=os.path.getmtime, default=None)

if USE_CHECKPOINT:
    ckpt_path = CKPT_PATH or latest_exp1_checkpoint()
    if ckpt_path is None:
        raise FileNotFoundError(
            "No exp1 checkpoint found; set USE_CHECKPOINT=False to train."
        )
    results, config = load_exp1_checkpoint(ckpt_path)
    print(f"Loaded checkpoint: {ckpt_path}")

else:
    epochs = 200
    n      = 10
    # beta values are expressed as multiples of n.
    betas_to_plot = [10 * n, 50 * n, 100 * n]

    # m = max over beta of n*log(n)*beta*log(beta), floored at 4096.
    # NOTE(review): presumably the model width/size passed to training --
    # confirm against Exp1Config.
    m = max(n * np.log(n) * beta * np.log(beta) for beta in betas_to_plot)
    m = int(max(4096, m))
    print(f"m={m}")

    config = Exp1Config(
        seeds=list(range(2)),
        n=n,
        random_labels=False,
        betas=betas_to_plot,
        epochs=epochs,
        eta=1e-4,
        m=m,
        device=device,
        track_every=max(1, epochs // 100),
        # Same guard as track_every: never pass 0 when epochs < 5.
        print_every=max(1, epochs // 5),
    )

    # One multi-seed training run per beta, keyed by a label such as "β=10n".
    common = config.train_kwargs()
    results = {
        f"β={beta // config.n}n": train_multiseed(beta=beta, **common)
        for beta in config.betas
    }

    # Timestamped file name so successive runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_path = os.path.join(CKPT_DIR, f"exp1_digits_{timestamp}.pt")
    save_exp1_checkpoint(ckpt_path, results, config)
    print(f"Saved checkpoint: {ckpt_path}")


In [None]:
# Plot from `config` rather than loose variables so the cell works both after
# training and after loading a checkpoint.
plot_ex1_multiseed_short(
    results,
    config.epochs,
    config.track_every,
)

In [None]:
# Optional: the non-short variant of the plot above (uncomment to run).
# plot_ex1_multiseed(results, config.epochs, config.track_every)