In [None]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from xgboost import XGBRegressor
from sklearn.metrics import mean_absolute_error

In [None]:
def build_cnn(input_shape: tuple, num_classes: int) -> nn.Module:
    """Assemble a small two-stage convolutional classifier.

    Args:
        input_shape: ``(channels, height, width)`` of a single input image.
        num_classes: Number of output logits produced by the final layer.

    Returns:
        An ``nn.Sequential`` made of two conv / ReLU / max-pool stages,
        a flatten, and a two-layer fully connected head.
    """
    in_channels, height, width = input_shape

    # The padded 3x3 convolutions keep the spatial size and each MaxPool2d(2)
    # floor-halves it, so Flatten sees 64 maps of (height // 4) x (width // 4).
    flat_features = 64 * (height // 4) * (width // 4)

    # Modules are created in network order so that weight initialisation
    # draws from the RNG in the same sequence as the layer indices.
    layers = [
        nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(flat_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes),
    ]
    return nn.Sequential(*layers)

In [None]:
def extract_probe_features(model, X, y, criterion, probe_size=32):
    """Compute four log-scale probe features of ``model`` on a small batch.

    The model is moved to the module-level ``DEVICE`` and switched to train
    mode as a side effect. Per-sample gradients are taken one example at a
    time over the first ``probe_size`` rows of ``X`` / ``y``.

    Args:
        model: Network to probe.
        X: Input batch; only the first ``probe_size`` rows are used.
        y: Labels aligned with ``X``. Each label is used as an index into the
            flattened logits of its sample (assumes integer class indices —
            TODO confirm against callers).
        criterion: Loss callable invoked as ``criterion(logits, label)``.
        probe_size: Maximum number of samples to probe. Defaults to 32, the
            value that was previously hard-coded.

    Returns:
        ``np.ndarray`` of ``[logP, logB, logG2, logTau]`` where
        ``logP`` is the log of the total parameter count (frozen included),
        ``logB`` is the log of the number of probed samples,
        ``logG2`` is the log of the MEAN squared norm of the per-sample loss
        gradient, and ``logTau`` is the log of the SUM of the squared norm of
        the per-sample gradient of the label's logit.

    Raises:
        ValueError: If ``probe_size`` is smaller than 1.
    """
    if probe_size < 1:
        raise ValueError("probe_size must be at least 1")

    model.to(DEVICE).train()
    # Single source of truth for the probe size: it feeds both logB and the
    # slice below, so the two can no longer drift apart.
    n_probe = min(probe_size, X.size(0))
    logP = np.log(sum(p.numel() for p in model.parameters()))
    logB = np.log(n_probe)
    Xp, yp = X[:n_probe].to(DEVICE), y[:n_probe].to(DEVICE)
    params = [p for p in model.parameters() if p.requires_grad]

    # torch.autograd.grad returns gradients without accumulating into
    # ``.grad``, so clearing stale gradients once up front is sufficient.
    model.zero_grad()

    g2_list, tau_list = [], []
    for xi, yi in zip(Xp, yp):
        xi, yi = xi.unsqueeze(0), yi.unsqueeze(0)
        logits = model(xi)
        loss   = criterion(logits, yi)

        # The graph is reused for the logit gradient below, so keep it here.
        grads  = torch.autograd.grad(loss, params, retain_graph=True)
        gv     = torch.cat([g.contiguous().view(-1) for g in grads])
        g2_list.append((gv**2).sum().item())

        # Last use of this sample's graph: let autograd free it.
        true_logit = logits.view(-1)[yi.item()]
        grads_f    = torch.autograd.grad(true_logit, params)
        fv         = torch.cat([g.contiguous().view(-1) for g in grads_f])
        tau_list.append((fv**2).sum().item())

    # NOTE(review): logG2 averages over samples while logTau sums over them.
    # Kept as-is because the meta-regressor features presumably follow the
    # same convention — confirm before changing either one.
    logG2  = np.log(np.mean(g2_list))
    logTau = np.log(np.sum(tau_list))
    return np.array([logP, logB, logG2, logTau])

In [None]:
def measure_convergence(model, X, y, eps, lr, criterion, max_steps=500):
    """Count full-batch Adam steps until the loss falls to ``eps`` x initial.

    The model is moved to the module-level ``DEVICE``, put in train mode and
    trained in place on the single batch ``(X, y)``.

    Args:
        model: Network to train; its parameters are updated in place.
        X: Input batch used for every step.
        y: Targets aligned with ``X``.
        eps: Relative loss threshold; training stops once
            ``loss <= eps * initial_loss``.
        lr: Learning rate handed to ``torch.optim.Adam``.
        criterion: Loss callable invoked as ``criterion(logits, y)``.
        max_steps: Iteration cap. Defaults to 500, the value that was
            previously hard-coded.

    Returns:
        The 1-based iteration ``t`` at which the threshold was first met. The
        loss is checked before the ``t``-th update, so ``t - 1`` optimizer
        steps have been applied at that point. If the threshold is never met,
        ``max_steps`` is returned, which is indistinguishable from converging
        exactly on the last iteration.
    """
    model.to(DEVICE).train()
    X, y = X.to(DEVICE), y.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    init_loss = None
    for t in range(1, max_steps + 1):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        # Read the scalar once per step; each .item() forces a device sync.
        loss_value = loss.item()
        if init_loss is None:
            init_loss = loss_value
        if loss_value <= eps * init_loss:
            return t
        loss.backward()
        optimizer.step()
    return max_steps

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weight initialisation and DataLoader shuffling are both random; seed them so
# a Restart & Run All reproduces the same trials. (GPU kernels may still be
# nondeterministic; see the PyTorch reproducibility notes.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Dataset name -> (torchvision dataset class, extra constructor kwargs).
DATASETS = {
    'MNIST':        (datasets.MNIST,      {'train': True}),
    'FashionMNIST': (datasets.FashionMNIST,{'train': True}),
    'CIFAR10':      (datasets.CIFAR10,    {'train': True}),
    'CIFAR100':     (datasets.CIFAR100,   {'train': True})
}
# Every dataset gets the same ToTensor-only pipeline, keyed by dataset name.
TRANSFORMS = {name: transforms.Compose([transforms.ToTensor()]) for name in DATASETS}

# Evaluation grid: every (lr, batch size, eps) combination is run
# N_EVAL_TRIALS times per dataset.
LR_VALUES     = [0.0005, 0.001, 0.005]
BATCH_SIZES   = [50, 100]
EPS_VALUES    = [0.1, 0.15, 0.2]
N_EVAL_TRIALS = 10

In [None]:
# Columns of the meta-dataset used as regressor inputs, in the order the
# evaluation loop assembles them.
FEATURES = ['logP', 'logB', 'logG2', 'logTau', 'logLR', 'logN']

# Load the meta-dataset and split it into the feature matrix and the
# 'T_star' target column.
df_meta = pd.read_csv('../meta_datasets/meta_dataset_cnn.csv')
X_meta, y_meta = df_meta[FEATURES].values, df_meta['T_star'].values

# Gradient-boosted regressor with a fixed random_state for repeatable fits.
xgb_params = dict(n_estimators=200, max_depth=4, random_state=42)
meta_reg = XGBRegressor(**xgb_params)
meta_reg.fit(X_meta, y_meta)

In [None]:
# Evaluate the meta-regressor: for every dataset and every (lr, batch size,
# eps) combination, compare the predicted step count against the measured one
# over N_EVAL_TRIALS freshly initialised models.
print("CAPE CNN Evaluation:")
records = []

for ds_name, (ds_cls, ds_args) in DATASETS.items():
    ds = ds_cls(root='./data', download=True,
                transform=TRANSFORMS[ds_name], **ds_args)
    num_classes = len(ds.classes)
    input_shape = ds[0][0].shape
    total_N     = len(ds)
    # Constant per dataset, so compute it once instead of once per trial.
    logN        = np.log(total_N)
    criterion   = nn.CrossEntropyLoss()

    for lr in LR_VALUES:
        logLR = np.log(lr)
        for B in BATCH_SIZES:
            loader = DataLoader(ds, batch_size=B, shuffle=True)
            for eps in EPS_VALUES:
                y_preds, y_trues = [], []
                for _ in range(N_EVAL_TRIALS):
                    model = build_cnn(input_shape, num_classes)
                    # A new iterator reshuffles, so each trial draws a
                    # different random batch.
                    Xp, yp = next(iter(loader))
                    # Probe features must be taken before training mutates
                    # the model in measure_convergence.
                    z0 = extract_probe_features(model, Xp, yp, criterion)
                    z  = np.concatenate([z0, [logLR, logN]])
                    T_pred = meta_reg.predict(z.reshape(1, -1))[0]
                    T_act  = measure_convergence(model, Xp, yp, eps, lr, criterion)
                    y_preds.append(T_pred)
                    y_trues.append(T_act)

                mae  = mean_absolute_error(y_trues, y_preds)
                # corr is NaN when either series is constant (for example
                # when every trial hits the iteration cap).
                corr = np.corrcoef(y_trues, y_preds)[0, 1]
                print(f"{ds_name} | lr={lr} | B={B} | eps={eps} | TT error={mae:.0f} | Corr={corr:.2f}")
                # Keep the hyperparameters with each row so individual
                # configurations stay identifiable in df_results.
                records.append({
                    'dataset':    ds_name,
                    'lr':         lr,
                    'batch_size': B,
                    'eps':        eps,
                    'TT_error':   mae,
                    'Corr':       corr
                })

In [None]:
# One row per evaluated configuration, as collected by the evaluation loop.
df_results = pd.DataFrame(records)

# Average both metrics per dataset, then round each to its reporting precision.
df_summary = (
    df_results
    .groupby("dataset")
    .agg(TT_error=("TT_error", "mean"), Corr=("Corr", "mean"))
    .reset_index()
    .round({"TT_error": 2, "Corr": 3})
)

# Persist the per-dataset averages.
df_summary.to_csv("CNN_evaluation_dataset_avg.csv", index=False)
print("Per-dataset average results saved to CNN_evaluation_dataset_avg.csv")