In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from xgboost import XGBRegressor
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr

In [None]:
def build_mlp(input_dim: int, num_classes: int, rng=None) -> nn.Module:
    """Build a randomly-sized MLP classifier.

    The network is ``Flatten -> (Linear -> ReLU) x depth -> Linear`` where
    ``depth`` is drawn uniformly from {1, 2, 3} and each hidden width from
    {64, 128, 256}.

    Args:
        input_dim: number of input features after flattening.
        num_classes: number of output logits.
        rng: optional object with ``randint``/``choice`` (e.g. ``random.Random``)
            used for the architecture draw; defaults to the global ``random``
            module, so default behaviour matches previous calls.

    Returns:
        nn.Sequential producing raw logits (no activation on the output layer).
    """
    rng = random if rng is None else rng
    depth = rng.randint(1, 3)
    hidden = [rng.choice([64, 128, 256]) for _ in range(depth)]
    dims = [input_dim, *hidden, num_classes]

    layers = [nn.Flatten()]
    n_linear = len(dims) - 1
    for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        # Decide by position, not by width: a hidden layer whose width happens to
        # equal num_classes must still get its ReLU; only the output layer is linear.
        if idx < n_linear - 1:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

In [None]:
def extract_probe_features(model, X, y, criterion, probe_size=32, device=None):
    """Compute cheap gradient-based probe features of a model on a few samples.

    For each probe sample (the first ``min(probe_size, len(X))`` rows of X) one
    forward pass is run, and two gradients w.r.t. the trainable parameters are
    taken from it: the gradient of ``criterion`` and the gradient of the
    true-class logit. Side effects: the model is moved to ``device`` and put in
    train mode; its parameters are not modified.

    Args:
        model: network; assumed to return logits shaped (1, num_classes) for a
            single sample, since the true logit is read as
            ``logits.view(-1)[label]``.
        X: input tensor whose first dimension indexes samples.
        y: integer class labels aligned with X.
        criterion: loss called as ``criterion(logits, labels)``.
        probe_size: maximum number of samples to probe (default 32).
        device: device to run on; defaults to the notebook-level ``DEVICE``.

    Returns:
        np.ndarray ``[logP, logB, logG2, logTau]``:
          logP   - log of the total parameter count,
          logB   - log of the number of probe samples,
          logG2  - log of the MEAN per-sample squared loss-gradient norm,
          logTau - log of the SUM of per-sample squared true-logit-gradient norms.

    Raises:
        ValueError: if X has no rows (the features would be -inf / nan).
    """
    if device is None:
        device = DEVICE  # notebook-level config from the configuration cell

    n_probe = min(probe_size, X.size(0))
    if n_probe == 0:
        raise ValueError("extract_probe_features needs at least one probe sample")

    model.to(device).train()
    logP = np.log(sum(p.numel() for p in model.parameters()))
    # NOTE(review): whenever X has >= probe_size rows this is the constant
    # log(probe_size) and does not encode the training batch size; confirm this
    # matches how the meta-dataset features were computed.
    logB = np.log(n_probe)
    Xp, yp = X[:n_probe].to(device), y[:n_probe].to(device)
    params = [p for p in model.parameters() if p.requires_grad]

    g2_list, tau_list = [], []
    for xi, yi in zip(Xp, yp):
        xi, yi = xi.unsqueeze(0), yi.unsqueeze(0)
        logits = model(xi)
        loss = criterion(logits, yi)
        # retain_graph: the same forward graph is differentiated again below.
        # autograd.grad returns gradients without touching .grad, so no zero_grad is needed.
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        gv = torch.cat([g.contiguous().view(-1) for g in grads])
        g2_list.append((gv**2).sum().item())

        true_logit = logits.view(-1)[yi.item()]
        grads_f = torch.autograd.grad(true_logit, params)  # last use of this graph
        fv = torch.cat([g.contiguous().view(-1) for g in grads_f])
        tau_list.append((fv**2).sum().item())

    # Mean for G2 but sum for Tau: kept exactly as the original feature definitions.
    logG2 = np.log(np.mean(g2_list))
    logTau = np.log(np.sum(tau_list))
    return np.array([logP, logB, logG2, logTau])

In [None]:
def measure_convergence(model, X, y, eps, lr, criterion, max_steps=500, device=None):
    """Count Adam steps until the loss on one fixed batch falls to eps x initial loss.

    ``model`` is trained in place on the single batch (X, y). At step t the loss
    is evaluated before the update. If it is <= eps * (loss at step 1), t is
    returned. Otherwise the function backpropagates and takes one Adam step.

    Args:
        model: network to train (moved to ``device``, set to train mode).
        X, y: the batch of inputs and labels used for every step.
        eps: relative loss target (e.g. 0.1 means "loss reduced to 10%").
        lr: Adam learning rate.
        criterion: loss called as ``criterion(model(X), y)``.
        max_steps: step budget (default 500).
        device: device to run on; defaults to the notebook-level ``DEVICE``.

    Returns:
        int in [1, max_steps]. ``max_steps`` is returned both when the target is
        first met at step ``max_steps`` and when it is never met, so the result
        is censored at the budget.

    Raises:
        ValueError: if max_steps < 1.
    """
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")
    if device is None:
        device = DEVICE  # notebook-level config from the configuration cell

    model.to(device).train()
    X, y = X.to(device), y.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    threshold = None
    for t in range(1, max_steps + 1):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss_val = loss.item()
        if threshold is None:
            # Target is relative to the untrained loss observed at step 1.
            threshold = eps * loss_val
        if loss_val <= threshold:
            return t
        loss.backward()
        optimizer.step()
    return max_steps

In [None]:
def run_unseen_eval():
    """Evaluate meta-regressor training-time predictions on unseen datasets.

    Fits an XGBoost regressor on the MLP meta-dataset. Then, for every unseen
    dataset x learning rate x batch size, it samples N_EVAL_TRIALS random MLPs.
    For each one it predicts the convergence step from probe features, measures
    the actual step with measure_convergence on the same batch, and records both.
    The results are written to RESULTS_CSV.

    Returns:
        pd.DataFrame with columns dataset, lr, batch_size, T_star (measured),
        T_pred (predicted) and abs_error, one row per trial.
    """
    print("Running unseen dataset evaluation...")

    # Meta-model: probe features (+ log lr, log dataset size) -> steps to converge.
    meta_df = pd.read_csv('../meta_datasets/meta_dataset_mlp.csv')
    feature_cols = ['logP','logB','logG2','logTau','logLR','logN']
    regressor = XGBRegressor(n_estimators=200, max_depth=4, random_state=42)
    regressor.fit(meta_df[feature_cols].values, meta_df['T_star'].values)

    rows = []
    for name, (dataset_cls, extra_kwargs) in UNSEEN_DATASETS.items():
        dataset = dataset_cls(root='./data', download=True, transform=TRANSFORMS[name], **extra_kwargs)
        n_classes = len(dataset.classes)
        flat_dim = int(np.prod(dataset[0][0].shape))
        log_n = np.log(len(dataset))
        loss_fn = nn.CrossEntropyLoss()

        for lr in LR_VALUES:
            log_lr = np.log(lr)
            for batch_size in BATCH_SIZES:
                loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
                for _ in range(N_EVAL_TRIALS):
                    # Keep the order "build model, then draw batch": both presumably
                    # consume the torch RNG, so reordering would change seeded runs.
                    net = build_mlp(flat_dim, n_classes)
                    xb, yb = next(iter(loader))
                    xb = xb.view(xb.size(0), -1)

                    probe = extract_probe_features(net, xb, yb, loss_fn)
                    features = np.concatenate([probe, [log_lr, log_n]])
                    predicted = regressor.predict(features.reshape(1, -1))[0]
                    actual = measure_convergence(net, xb, yb, eps=EPSILON, lr=lr, criterion=loss_fn)

                    rows.append(dict(
                        dataset=name,
                        lr=lr,
                        batch_size=batch_size,
                        T_star=actual,
                        T_pred=predicted,
                        abs_error=abs(predicted - actual),
                    ))

    results = pd.DataFrame(rows)
    results.to_csv(RESULTS_CSV, index=False)
    print(f"Saved to {RESULTS_CSV}")
    return results

In [None]:
def plot_scatter(df):
    """Plot predicted vs. real training time, one colour per dataset.

    Draws the identity line, dashed bands at +/- the mean relative error, and
    one legend entry per dataset, with its colour marker, annotated with MAE
    and Pearson r. Uses the notebook-level FONT_SCALE and shows the figure.

    Args:
        df: DataFrame with columns 'dataset', 'T_star' (measured steps) and
            'T_pred' (predicted steps), e.g. the output of run_unseen_eval().

    Raises:
        ValueError: if df has no rows.
    """
    if df.empty:
        raise ValueError("plot_scatter: df has no rows to plot")

    sns.set(style='white', font_scale=FONT_SCALE)
    fig, ax = plt.subplots(figsize=(10, 8))

    max_val = max(df['T_star'].max(), df['T_pred'].max())
    lims = [0, max_val + 20]  # margin so the largest points are not on the border

    # The mean relative error sets the width of the dashed bands. Rows with
    # T_star == 0 are excluded because they would make the ratio infinite.
    nonzero = df['T_star'] != 0
    t_star, t_pred = df.loc[nonzero, 'T_star'], df.loc[nonzero, 'T_pred']
    rel_errors = ((t_pred - t_star) / t_star).abs()
    avg_rel_error = rel_errors.mean() if len(rel_errors) else 0.0
    upper = [x * (1 + avg_rel_error) for x in lims]
    lower = [x * (1 - avg_rel_error) for x in lims]

    # Perfect prediction and error bands
    ax.plot(lims, lims, 'b-', label='Perfect prediction')
    ax.plot(lims, upper, 'c--', label=f'+{int(avg_rel_error * 100)}% error')
    ax.plot(lims, lower, 'c--', label=f'-{int(avg_rel_error * 100)}% error')

    ds_names = df['dataset'].unique()
    colors = sns.color_palette("Set2", len(ds_names))
    for color, ds in zip(colors, ds_names):
        sub = df[df['dataset'] == ds]
        mae = mean_absolute_error(sub['T_star'], sub['T_pred'])
        # pearsonr is undefined for fewer than 2 points or constant input; report nan.
        if len(sub) >= 2 and sub['T_star'].nunique() > 1 and sub['T_pred'].nunique() > 1:
            corr, _ = pearsonr(sub['T_star'], sub['T_pred'])
        else:
            corr = float('nan')
        # Put the label on the scatter itself, not on an invisible dummy line,
        # so the legend shows each dataset's colour next to its metrics.
        ax.scatter(sub['T_star'], sub['T_pred'], color=color, s=60,
                   edgecolor='black', alpha=0.8,
                   label=f'{ds}   MAE={mae:.1f}, r={corr:.2f}')

    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_xlabel("Real Training Time", fontsize=14 * FONT_SCALE)
    ax.set_ylabel("Predicted Training Time", fontsize=14 * FONT_SCALE)
    ax.tick_params(axis='both', labelsize=12 * FONT_SCALE)
    ax.legend(title=None, loc='upper left', frameon=True, fontsize=11 * FONT_SCALE)
    plt.tight_layout()
    plt.show()

In [None]:
# --- Configuration ---
reuse = True  # if True and RESULTS_CSV exists, load it instead of re-running the evaluation
FONT_SCALE = 1.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RESULTS_CSV = "MLP_unseen_scatter.csv"

# Seed every RNG the evaluation draws from so a fresh run is reproducible:
# `random` drives build_mlp's architecture sampling; torch drives weight init
# and DataLoader shuffling. GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# name -> (torchvision dataset class, extra constructor kwargs)
UNSEEN_DATASETS = {
    'KMNIST': (datasets.KMNIST, {'train': True}),
    'EMNIST': (datasets.EMNIST, {'split': 'balanced', 'train': True})
}
TRANSFORMS = {name: transforms.Compose([transforms.ToTensor()]) for name in UNSEEN_DATASETS}

LR_VALUES     = [0.0005, 0.001, 0.005]
BATCH_SIZES   = [50, 100]   # each trial probes and trains on one batch of this size
EPSILON       = 0.1         # "converged" once loss <= EPSILON * initial loss
N_EVAL_TRIALS = 100         # random MLPs per (dataset, lr, batch size) combination

In [None]:
# Use the cached results when allowed and present; otherwise run the (slow) evaluation.
use_cached = reuse and os.path.exists(RESULTS_CSV)
df_results = pd.read_csv(RESULTS_CSV) if use_cached else run_unseen_eval()
plot_scatter(df_results)