In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from xgboost import XGBRegressor
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def build_cnn(input_shape: tuple, num_classes: int) -> nn.Module:
    """Build a small two-stage convolutional classifier.

    Args:
        input_shape: ``(channels, height, width)`` of a single input image.
        num_classes: Number of output logits.

    Returns:
        An ``nn.Sequential`` made of two Conv-ReLU-MaxPool stages (32 and 64
        channels), a flatten, a 256-unit hidden layer and a linear head.
    """
    in_channels, height, width = input_shape
    # The padded 3x3 convolutions keep H and W; each 2x2 max-pool floor-halves
    # them, so after two stages the spatial size is (H // 4, W // 4).
    flat_features = 64 * (height // 4) * (width // 4)

    layers = []
    for src_ch, dst_ch in ((in_channels, 32), (32, 64)):
        layers += [
            nn.Conv2d(src_ch, dst_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
    layers += [
        nn.Flatten(),
        nn.Linear(flat_features, 256),
        nn.ReLU(),
        nn.Linear(256, num_classes),
    ]
    return nn.Sequential(*layers)

In [None]:
def extract_probe_features(model, X, y, criterion):
    """Compute four log-scale probe features of ``model`` on a small batch.

    For each of the first 32 samples of ``(X, y)`` two squared gradient norms
    are measured with respect to the trainable parameters: that of the loss,
    and that of the logit at the sample's label index.

    Args:
        model: Network to probe; it is moved to ``DEVICE`` and put in train mode.
        X: Batch of inputs, indexed along dimension 0.
        y: Integer labels aligned with ``X`` (each is used as a logit index).
        criterion: Callable ``criterion(logits, target)`` returning a scalar loss.

    Returns:
        ``np.ndarray`` ``[logP, logB, logG2, logTau]``: log of the total
        parameter count, log of ``min(32, X.size(0))``, log of the *mean*
        squared loss-gradient norm, and log of the *summed* squared
        logit-gradient norm over the probed samples.
    """
    model.to(DEVICE).train()

    def _squared_norm(grad_tensors):
        # Flatten all per-parameter gradients into one vector, then sum squares.
        flat = torch.cat([g.contiguous().view(-1) for g in grad_tensors])
        return (flat ** 2).sum().item()

    n_params = sum(p.numel() for p in model.parameters())
    logP = np.log(n_params)
    logB = np.log(min(32, X.size(0)))

    probe_x = X[:32].to(DEVICE)
    probe_y = y[:32].to(DEVICE)
    trainable = [p for p in model.parameters() if p.requires_grad]

    loss_sq_norms = []
    logit_sq_norms = []
    for sample, label in zip(probe_x, probe_y):
        # Re-add the batch dimension so the model sees a batch of one.
        sample, label = sample.unsqueeze(0), label.unsqueeze(0)

        model.zero_grad()
        logits = model(sample)
        loss = criterion(logits, label)
        # retain_graph keeps the graph alive for the logit gradient below.
        loss_grads = torch.autograd.grad(loss, trainable, retain_graph=True)
        loss_sq_norms.append(_squared_norm(loss_grads))

        model.zero_grad()
        label_logit = logits.view(-1)[label.item()]
        logit_grads = torch.autograd.grad(label_logit, trainable, retain_graph=True)
        logit_sq_norms.append(_squared_norm(logit_grads))

    logG2 = np.log(np.mean(loss_sq_norms))
    logTau = np.log(np.sum(logit_sq_norms))
    return np.array([logP, logB, logG2, logTau])

In [None]:
def measure_convergence(model, X, y, eps, lr, criterion, max_steps=500):
    """Count full-batch Adam steps until the loss drops to ``eps`` x its initial value.

    The model is trained on the single batch ``(X, y)``. The loss is checked
    *before* each update, so a return value of ``t`` means the threshold was
    met after ``t - 1`` optimizer steps.

    Args:
        model: Network to train in place; moved to ``DEVICE`` and put in train mode.
        X: Input batch.
        y: Targets for ``X``.
        eps: Relative threshold; stop once ``loss <= eps * initial_loss``.
        lr: Adam learning rate.
        criterion: Callable ``criterion(logits, target)`` returning a scalar loss.
        max_steps: Step budget (default 500, the previously hard-coded value).

    Returns:
        The 1-based step index at which the threshold was first met, or
        ``max_steps`` if it was never met within the budget.
    """
    model.to(DEVICE).train()
    X, y = X.to(DEVICE), y.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    target_loss = None
    for step in range(1, max_steps + 1):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        # Read the scalar once per step: every .item() is a host/device sync.
        loss_value = loss.item()
        if target_loss is None:
            # First step defines the reference the threshold is relative to.
            target_loss = eps * loss_value
        if loss_value <= target_loss:
            return step
        loss.backward()
        optimizer.step()
    return max_steps

In [None]:
def run_cnn_evaluation():
    """Compare predicted and measured convergence steps for CNNs on every dataset.

    An XGBoost regressor is fitted on the stored CNN meta-dataset. Then, for
    each dataset / learning rate / batch size / trial, a fresh CNN is probed on
    one random batch, the regressor predicts its convergence step count, and
    the actual count is measured by training on that same batch.

    Returns:
        ``pd.DataFrame`` with one row per trial (columns ``dataset``,
        ``learning_rate``, ``batch_size``, ``T_star``, ``T_pred``,
        ``TT_error``); the frame is also written to ``RESULTS_CSV``.
    """
    print("Running CNN Evaluation...")

    # --- Fit the meta-regressor on the precomputed meta-dataset ---
    FEATURES = ['logP','logB','logG2','logTau','logLR','logN']
    meta_table = pd.read_csv('../meta_datasets/meta_dataset_cnn.csv')
    meta_reg = XGBRegressor(n_estimators=200, max_depth=4, random_state=42)
    meta_reg.fit(meta_table[FEATURES].values, meta_table['T_star'].values)

    # --- Sweep datasets x learning rates x batch sizes x trials ---
    records = []
    for ds_name, (ds_cls, ds_args) in DATASETS.items():
        ds = ds_cls(root='./data', download=True, transform=TRANSFORMS[ds_name], **ds_args)
        num_classes, input_shape = len(ds.classes), ds[0][0].shape
        logN = np.log(len(ds))  # constant for the whole dataset
        criterion = nn.CrossEntropyLoss()

        for lr in LR_VALUES:
            # Hyper-parameter features appended after the probe features,
            # in the same order as FEATURES.
            tail_features = [np.log(lr), logN]
            for B in BATCH_SIZES:
                loader = DataLoader(ds, batch_size=B, shuffle=True)
                for _trial in range(N_EVAL_TRIALS):
                    model = build_cnn(input_shape, num_classes)
                    # A fresh iterator reshuffles, so each trial sees a new batch.
                    Xp, yp = next(iter(loader))
                    probe = extract_probe_features(model, Xp, yp, criterion)
                    z = np.concatenate([probe, tail_features])
                    T_pred = meta_reg.predict(z.reshape(1, -1))[0]
                    T_act = measure_convergence(
                        model, Xp, yp, eps=0.1, lr=lr, criterion=criterion
                    )
                    records.append(dict(
                        dataset=ds_name,
                        learning_rate=lr,
                        batch_size=B,
                        T_star=T_act,
                        T_pred=T_pred,
                        TT_error=abs(T_act - T_pred),
                    ))

    df = pd.DataFrame(records)
    df.to_csv(RESULTS_CSV, index=False)
    print(f"Results saved to {RESULTS_CSV}")
    return df

In [None]:
def _plot_tt_error_panel(df, group_col, xlabel):
    """Show one grouped bar chart: mean TT error (+/- 1 std) per ``group_col`` and dataset."""
    # The aggregated statistics are only used to size the y-axis; bars and
    # error bars are computed by seaborn from the raw per-trial records.
    stats = df.groupby([group_col, 'dataset'])['TT_error'].agg(['mean', 'std'])

    _fig, ax = plt.subplots(figsize=(8, 6))
    # Plot the raw records so seaborn can draw real error bars. Feeding it
    # pre-aggregated means gives one observation per bar and therefore no
    # visible spread. `errorbar='sd'` needs seaborn >= 0.12; `errwidth` is
    # kept from the original call (seaborn >= 0.13 prefers `err_kws`).
    sns.barplot(data=df, x=group_col, y='TT_error', hue='dataset',
                errorbar='sd', palette='Set2', edgecolor='black',
                capsize=0.15, errwidth=1.5, ax=ax)

    ax.set_xlabel(xlabel, fontsize=14 * FONT_SCALE)
    ax.set_ylabel('Mean Absolute TT Error (MAE)', fontsize=14 * FONT_SCALE)
    ax.tick_params(axis='both', labelsize=12 * FONT_SCALE)
    ax.grid(False)

    # Leave room above the tallest bar for its error bar and the legend.
    # std is NaN for single-observation groups, so treat that as zero spread.
    ylim_max = stats['mean'].max() + stats['std'].fillna(0.0).max() + 2
    if np.isfinite(ylim_max):  # NaN for empty input; set_ylim rejects NaN
        ax.set_ylim(top=ylim_max)

    ax.legend(title='Dataset', title_fontsize=12 * FONT_SCALE,
              fontsize=11 * FONT_SCALE, loc='upper left',
              bbox_to_anchor=(0.02, 0.92),
              frameon=True, fancybox=True, framealpha=1.0)

    plt.tight_layout()
    plt.show()


def plot_tt_error(df):
    """Plot the mean absolute TT error per dataset, by learning rate and by batch size.

    Two figures are shown in sequence. Bars are the mean of ``TT_error`` and
    error bars span one standard deviation across the records in each group.

    Args:
        df: Evaluation records with columns ``dataset``, ``learning_rate``,
            ``batch_size`` and ``TT_error``.
    """
    # One call sets both style and context. Calling sns.set(style=...) *after*
    # sns.set_context(...) resets the context to its defaults.
    sns.set(context="talk", style='ticks', font_scale=FONT_SCALE)

    # --- Plot 1: TT Error vs Learning Rate ---
    _plot_tt_error_panel(df, 'learning_rate', 'Learning Rate')

    # --- Plot 2: TT Error vs Batch Size ---
    _plot_tt_error_panel(df, 'batch_size', 'Batch Size')

In [None]:
# --- Run switches, plotting scale, device and output location ---
reuse = True  # Set to False to rerun evaluation
FONT_SCALE = 1.5
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
RESULTS_CSV = "CNN_eval_records.csv"

# --- Benchmarks: name -> (torchvision dataset class, extra constructor kwargs) ---
DATASETS = dict(
    MNIST=(datasets.MNIST, dict(train=True)),
    FashionMNIST=(datasets.FashionMNIST, dict(train=True)),
    CIFAR10=(datasets.CIFAR10, dict(train=True)),
    CIFAR100=(datasets.CIFAR100, dict(train=True)),
)
# Every dataset gets its own (identical) tensor-conversion pipeline.
TRANSFORMS = dict(
    (name, transforms.Compose([transforms.ToTensor()])) for name in DATASETS
)

# --- Evaluation grid ---
LR_VALUES = [0.0005, 0.001, 0.005]
BATCH_SIZES = [50, 100]
EPS_VALUES = [0.1]
N_EVAL_TRIALS = 3


In [None]:
# Reuse cached records when allowed and present; otherwise run the full
# (expensive) evaluation, which also writes RESULTS_CSV.
if not (reuse and os.path.exists(RESULTS_CSV)):
    df_results = run_cnn_evaluation()
else:
    df_results = pd.read_csv(RESULTS_CSV)

plot_tt_error(df_results)