In [None]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from xgboost import XGBRegressor
from sklearn.metrics import mean_absolute_error
import timm

In [None]:
# --- Experiment configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ARCHITECTURE = 'deit_tiny_patch16_224'

# Hyper-parameter grid swept during evaluation.
LR_VALUES = [0.0005, 0.001, 0.002]
BATCH_SIZES = [50, 100]
EPS_VALUES = [0.6, 0.5, 0.4]

# Repetitions per grid point and the cap on optimisation steps per run.
N_EVAL_TRIALS = 10
MAX_STEPS = 500

# --- Datasets for meta-evaluation: name -> (dataset class, constructor kwargs) ---
DATASETS = {
    'MNIST': (datasets.MNIST, {'train': True}),
    'FashionMNIST': (datasets.FashionMNIST, {'train': True}),
    'CIFAR10': (datasets.CIFAR10, {'train': True}),
    'QMNIST': (datasets.QMNIST, {'train': True, 'what': 'train'}),
}

# --- Per-dataset preprocessing: resize, tensor conversion, 3-channel tiling, normalisation ---
TRANSFORMS = {}
for name in DATASETS:
    if name in ('CIFAR10', 'CIFAR100'):
        # Per-channel statistics used for the CIFAR datasets.
        mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
    else:
        # Every other dataset shares one mean/std, replicated over three channels.
        mean, std = [0.1307] * 3, [0.3081] * 3
    TRANSFORMS[name] = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        # Single-channel tensors are tiled to three channels; others pass through.
        transforms.Lambda(lambda x: x if x.size(0) != 1 else x.repeat(3, 1, 1)),
        transforms.Normalize(mean, std),
    ])

In [None]:
def build_timm_model(name, num_classes):
    """Create a pretrained timm model with a frozen backbone and a trainable head.

    Args:
        name: timm architecture identifier passed to ``timm.create_model``.
        num_classes: Output size requested for the model's classifier.

    Returns:
        The model, moved to ``DEVICE``. Only parameters that live under a
        top-level module named ``head``, ``head_dist``, ``fc`` or ``classifier``
        keep ``requires_grad=True``; every other parameter is frozen.
    """
    model = timm.create_model(name, pretrained=True, num_classes=num_classes)
    # Match the *top-level* module name rather than a substring of the full
    # parameter path: a substring test for "fc" also matches the per-block MLP
    # layers (e.g. "blocks.0.mlp.fc1.weight" in timm vision transformers),
    # which would silently leave a large part of the backbone trainable.
    # NOTE(review): if the meta-dataset CSV was generated with the substring
    # rule, regenerate it so the probe features stay comparable.
    head_modules = {"head", "head_dist", "fc", "classifier"}
    for pname, p in model.named_parameters():
        if pname.split(".", 1)[0] not in head_modules:
            p.requires_grad = False
    return model.to(DEVICE)

In [None]:
def extract_probe_features(model, X, y, criterion, n_probe=32):
    """Compute the four log-scale probe features of a model on a small batch.

    Each probe sample is pushed through the model individually. Two gradients
    with respect to the trainable parameters are taken per sample: the gradient
    of the loss and the gradient of the logit of the true class.

    Args:
        model: Module to probe. It is switched to train mode.
        X: Batch of inputs; only the first ``n_probe`` rows are used.
        y: Integer class labels aligned with ``X``.
        criterion: Loss callable invoked as ``criterion(logits, target)``.
        n_probe: Maximum number of samples to probe (default 32).

    Returns:
        ``np.ndarray`` of shape ``(4,)`` holding ``[logP, logB, logG2, logTau]``:
        log of the total parameter count (frozen parameters included), log of
        the number of probed samples, log of the mean squared loss-gradient
        norm, and log of the summed squared true-logit-gradient norm.

    Raises:
        ValueError: If there is nothing to probe (empty batch, ``n_probe < 1``)
            or the model has no trainable parameters.
    """
    n_used = min(n_probe, X.size(0))
    if n_used < 1:
        raise ValueError("extract_probe_features needs at least one probe sample")
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters to probe")

    model.train()
    # Follow the model's own device instead of relying on a module-level global.
    device = params[0].device
    logP = np.log(sum(p.numel() for p in model.parameters()))
    logB = np.log(n_used)
    Xp, yp = X[:n_used].to(device), y[:n_used].to(device)

    g2_list, tau_list = [], []
    for xi, yi in zip(Xp, yp):
        xi, yi = xi.unsqueeze(0), yi.unsqueeze(0)
        logits = model(xi)
        loss = criterion(logits, yi)
        # Keep the graph alive: the true-logit gradient below reuses it.
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        gv = torch.cat([g.contiguous().view(-1) for g in grads])
        g2_list.append((gv**2).sum().item())
        true_logit = logits.view(-1)[yi.item()]
        # Last use of this sample's graph, so it can be released here.
        grads_f = torch.autograd.grad(true_logit, params)
        fv = torch.cat([g.contiguous().view(-1) for g in grads_f])
        tau_list.append((fv**2).sum().item())

    # logG2 averages over the probed samples, logTau sums over them.
    logG2 = np.log(np.mean(g2_list))
    logTau = np.log(np.sum(tau_list))
    return np.array([logP, logB, logG2, logTau])

In [None]:
def measure_convergence(model, loader, eps, lr, criterion, max_steps=None):
    """Count Adam steps until the batch loss drops to ``eps`` times its initial value.

    The reference loss is measured once, in eval mode and without gradients,
    on the first batch drawn from ``loader``. Training then cycles through
    ``loader`` (restarting it when exhausted). The loss of each batch is
    compared to the threshold *before* the parameter update for that batch.

    Args:
        model: Module to train; only parameters with ``requires_grad`` are
            optimised.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        eps: Fraction of the initial loss that defines convergence.
        lr: Adam learning rate.
        criterion: Loss callable invoked as ``criterion(logits, target)``.
        max_steps: Step budget. Defaults to the module-level ``MAX_STEPS``.

    Returns:
        The 1-based index of the first step whose batch loss is at or below
        the threshold, or ``max_steps`` if the budget runs out first. The two
        cases are not distinguishable when the threshold is met on the very
        last step.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if max_steps is None:
        max_steps = MAX_STEPS
    # Follow the model's own device instead of relying on a module-level global.
    device = next(model.parameters()).device

    try:
        X0, y0 = next(iter(loader))
    except StopIteration:
        raise ValueError("measure_convergence needs a loader with at least one batch") from None
    X0, y0 = X0.to(device), y0.to(device)

    model.eval()
    with torch.no_grad():
        init_loss = criterion(model(X0), y0).item()
    threshold = eps * init_loss

    model.train()
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)
    it = iter(loader)
    for step in range(1, max_steps + 1):
        try:
            Xb, yb = next(it)
        except StopIteration:
            # Epoch finished: start another pass over the data.
            it = iter(loader)
            Xb, yb = next(it)
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(Xb), yb)
        if loss.item() <= threshold:
            return step
        loss.backward()
        optimizer.step()
    return max_steps

In [None]:
# --- Meta-regressor: fit on the stored transformer meta-dataset ---
FEATURES = ['logP','logB','logG2','logTau','logLR','logN']
df_meta = pd.read_csv('../meta_datasets/meta_dataset_transformer.csv')
# Design matrix in FEATURES column order; the regression target is 'T_star'.
X_meta, y_meta = df_meta[FEATURES].values, df_meta['T_star'].values
meta_reg = XGBRegressor(
    n_estimators=200,
    max_depth=4,
    random_state=42,
)
meta_reg.fit(X_meta, y_meta)

# --- Evaluation state shared by the sweep that follows ---
print("CAPE Transformer Evaluation:")
records = []  # one dict per evaluated configuration
criterion = nn.CrossEntropyLoss()

In [None]:
# Meta-evaluation sweep. For every dataset and every (lr, batch size, eps)
# combination, run N_EVAL_TRIALS independent trials. Each trial compares the
# step count predicted by the meta-regressor with the measured step count.
# One record per combination is appended to `records`.
for ds_name, (ds_cls, ds_args) in DATASETS.items():
    print(f"\nEvaluating on: {ds_name}")
    full_ds = ds_cls(root='./data', download=True, transform=TRANSFORMS[ds_name], **ds_args)
    subset = Subset(full_ds, range(1000))  # first 1000 samples only, to keep evaluation fast
    # Falls back to 10 classes when the dataset object exposes no `classes` attribute.
    num_classes = len(full_ds.classes) if hasattr(full_ds, 'classes') else 10
    total_N = len(subset)

    for lr in LR_VALUES:
        logLR = np.log(lr)
        for B in BATCH_SIZES:
            loader = DataLoader(subset, batch_size=B, shuffle=True)
            for eps in EPS_VALUES:
                y_preds, y_trues = [], []
                for _ in range(N_EVAL_TRIALS):
                    # Fresh model per trial, so each trial starts from a newly built model.
                    model = build_timm_model(ARCHITECTURE, num_classes)
                    # One shuffled batch serves as the probe batch for this trial.
                    Xp, yp = next(iter(loader))
                    z0 = extract_probe_features(model, Xp, yp, criterion)
                    # Feature order must match FEATURES used to fit meta_reg:
                    # [logP, logB, logG2, logTau] + [logLR, logN].
                    # NOTE(review): no eps term is part of the feature vector, so the
                    # prediction does not react to the eps being swept — confirm intended.
                    z = np.concatenate([z0, [logLR, np.log(total_N)]])
                    T_pred = meta_reg.predict(z.reshape(1, -1))[0]
                    # Measured on the same model instance that was just probed.
                    T_act = measure_convergence(model, loader, eps, lr, criterion)
                    y_preds.append(T_pred)
                    y_trues.append(T_act)

                mae = mean_absolute_error(y_trues, y_preds)
                # NOTE(review): the correlation is presumably NaN when either list is
                # constant across trials (e.g. every run returns MAX_STEPS) — verify
                # how such rows should enter the later per-dataset mean.
                corr = np.corrcoef(y_trues, y_preds)[0, 1]
                print(f"{ds_name} | lr={lr} | B={B} | eps={eps} | MAE={mae:.1f} | Corr={corr:.2f}")
                # lr, B and eps are only printed; the record keeps the dataset name
                # and the two metrics.
                records.append({
                    'dataset': ds_name,
                    'TT_error': mae,
                    'Corr': corr
                })

In [None]:
# --- Per-dataset averages of the evaluation metrics ---
df_results = pd.DataFrame(records)
df_summary = (
    df_results
    .groupby("dataset")
    .agg({"TT_error": "mean", "Corr": "mean"})
    .reset_index()
    # Two decimals for the error column, three for the correlation column.
    .round({"TT_error": 2, "Corr": 3})
)

# --- Save to CSV ---
df_summary.to_csv("Transformer_evaluation_dataset_avg.csv", index=False)
print("\nSaved to Transformer_evaluation_dataset_avg.csv")