In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from xgboost import XGBRegressor
import matplotlib.pyplot as plt
import seaborn as sns
import timm

In [None]:
reuse = True  # load RESULTS_CSV instead of re-running the evaluation when the file exists
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RESULTS_CSV = "Transformer_eval_records.csv"
ARCHITECTURE = "deit_tiny_patch16_224"

# Fix RNG state so DataLoader shuffling and the randomly initialised classifier heads
# are reproducible under Restart & Run All (torch.manual_seed seeds every device).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

DATASETS = {
    'MNIST':        (datasets.MNIST, {'train': True}),
    'FashionMNIST': (datasets.FashionMNIST, {'train': True}),
    'CIFAR10':      (datasets.CIFAR10, {'train': True}),
    'QMNIST':       (datasets.QMNIST, {'train': True, 'what': 'train'})
}


def _to_three_channels(x):
    """Repeat a single-channel image tensor to 3 channels; return other tensors unchanged.

    A named function (rather than a lambda) keeps the transform picklable, which
    matters if a DataLoader is ever given worker processes.
    """
    return x.repeat(3, 1, 1) if x.size(0) == 1 else x


# Resize + RGB conversion + per-dataset normalisation for all datasets.
# NOTE(review): the grayscale stats are MNIST's and are reused for FashionMNIST/QMNIST;
# left unchanged because the meta-dataset features were presumably computed with them.
TRANSFORMS = {}
for name in DATASETS:
    if name in ('CIFAR10', 'CIFAR100'):
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2470, 0.2435, 0.2616]
    else:
        mean = [0.1307] * 3
        std = [0.3081] * 3
    TRANSFORMS[name] = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Lambda(_to_three_channels),
        transforms.Normalize(mean, std)
    ])


LR_VALUES     = [0.0005, 0.001, 0.002]
BATCH_SIZES   = [50, 100]
EPS_VALUES    = [0.6]  # For Transformers: convergence target is loss <= eps * initial loss
N_EVAL_TRIALS = 3
FONT_SCALE = 1.5

In [None]:
def build_transformer(num_classes: int):
    """Create the pretrained ARCHITECTURE model with `num_classes` outputs and a frozen body.

    Every parameter whose qualified name contains neither "head" nor
    "classifier" gets ``requires_grad = False``; the others are left as they
    are. The model is returned after being moved to ``DEVICE``.
    """
    net = timm.create_model(ARCHITECTURE, pretrained=True, num_classes=num_classes)
    trainable_markers = ("head", "classifier")
    for param_name, tensor in net.named_parameters():
        is_trainable = any(marker in param_name for marker in trainable_markers)
        if not is_trainable:
            tensor.requires_grad = False
    return net.to(DEVICE)

In [None]:
def extract_probe_features(model, X, y, criterion, device=None, probe_size: int = 32):
    """Compute the log-scale probe features fed to the meta-regressor.

    Probes the first ``min(probe_size, len(X))`` samples one at a time and takes
    gradients only w.r.t. parameters with ``requires_grad=True``.

    Args:
        model: classifier producing one row of logits per input; switched to
            train mode. Assumed to already be on ``device``.
        X, y: a batch of inputs and integer class labels.
        criterion: loss called as ``criterion(logits, labels)``.
        device: where the probe samples are moved; defaults to ``DEVICE``.
        probe_size: maximum number of samples to probe (default 32).

    Returns:
        np.ndarray ``[logP, logB, logG2, logTau]``:
        log of the total parameter count (trainable and frozen),
        log of the number of probed samples,
        log of the *mean* per-sample squared norm of the loss gradient,
        log of the *sum* of per-sample squared norms of the true-class-logit gradient.

    Raises:
        ValueError: if no sample would be probed (the features would be -inf/NaN).
    """
    if device is None:
        device = DEVICE
    model.train()
    n_probe = min(probe_size, X.size(0))
    if n_probe < 1:
        raise ValueError("extract_probe_features needs at least one sample to probe")

    logP = np.log(sum(p.numel() for p in model.parameters()))
    logB = np.log(n_probe)
    Xp, yp = X[:n_probe].to(device), y[:n_probe].to(device)
    params = [p for p in model.parameters() if p.requires_grad]

    g2_list, tau_list = [], []
    for xi, yi in zip(Xp, yp):
        xi, yi = xi.unsqueeze(0), yi.unsqueeze(0)
        # autograd.grad returns gradients without accumulating into .grad,
        # so no zero_grad() is needed between the two gradient queries.
        logits = model(xi)
        loss = criterion(logits, yi)
        # Keep the graph alive: the true-class logit is differentiated through it next.
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        g2_list.append(torch.cat([g.reshape(-1) for g in grads]).pow(2).sum().item())

        true_logit = logits.reshape(-1)[yi.item()]
        # Last use of this sample's graph, so autograd may free it right away.
        grads_f = torch.autograd.grad(true_logit, params)
        tau_list.append(torch.cat([g.reshape(-1) for g in grads_f]).pow(2).sum().item())

    logG2 = np.log(np.mean(g2_list))
    logTau = np.log(np.sum(tau_list))
    return np.array([logP, logB, logG2, logTau])

In [None]:
def measure_convergence(model, X, y, eps, lr, criterion, max_steps: int = 500, device=None):
    """Count full-batch Adam steps until the loss drops to ``eps`` times its initial value.

    Trains ``model`` in place on the single batch ``(X, y)``.

    Args:
        model: module to train; moved to ``device`` and set to train mode.
        X, y: the batch used at every step.
        eps: target ratio; the run has converged once ``loss <= eps * initial_loss``.
        lr: Adam learning rate.
        criterion: loss called as ``criterion(logits, labels)``.
        max_steps: step budget (default 500).
        device: target device; defaults to ``DEVICE``.

    Returns:
        The 1-based step ``t`` at which the loss, evaluated before step ``t``'s
        update, first meets the target (1 means the initial loss already does),
        or ``max_steps`` if the budget runs out. Converging exactly at step
        ``max_steps`` and not converging both return ``max_steps``.

    Raises:
        ValueError: if ``max_steps < 1``.
    """
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")
    if device is None:
        device = DEVICE
    model.to(device).train()
    X, y = X.to(device), y.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    init_loss = None
    for t in range(1, max_steps + 1):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss_value = loss.item()  # a single host sync per step
        if init_loss is None:
            init_loss = loss_value
        if loss_value <= eps * init_loss:
            return t
        loss.backward()
        optimizer.step()
    return max_steps

In [None]:
def _fit_meta_regressor(meta_csv, feature_cols):
    """Fit an XGBoost regressor mapping `feature_cols` of `meta_csv` to its 'T_star' column."""
    meta_df = pd.read_csv(meta_csv)
    regressor = XGBRegressor(n_estimators=200, max_depth=4, random_state=42)
    regressor.fit(meta_df[feature_cols].values, meta_df['T_star'].values)
    return regressor


def run_transformer_evaluation():
    """Compare meta-regressor T* predictions with measured convergence steps.

    For every dataset in DATASETS, learning rate in LR_VALUES, batch size in
    BATCH_SIZES and each of N_EVAL_TRIALS trials: build a fresh model, draw one
    shuffled batch, predict T* from its probe features plus log(lr) and
    log(dataset size), then measure the actual step count with EPS_VALUES[0].
    The records are written to RESULTS_CSV.

    Returns:
        pd.DataFrame with columns dataset, learning_rate, batch_size,
        T_star (measured), T_pred and TT_error (absolute difference).
    """
    print("Running Transformer Evaluation...")
    # Column order must match the feature vector assembled in the trial loop.
    meta_reg = _fit_meta_regressor(
        '../meta_datasets/meta_dataset_transformer.csv',
        ['logP', 'logB', 'logG2', 'logTau', 'logLR', 'logN'],
    )

    loss_fn = nn.CrossEntropyLoss()
    rows = []

    for ds_name, (ds_cls, ds_kwargs) in DATASETS.items():
        dataset = ds_cls(root='./data', download=True, transform=TRANSFORMS[ds_name], **ds_kwargs)
        n_classes = len(dataset.classes) if hasattr(dataset, 'classes') else 10
        log_n = np.log(len(dataset))

        for lr in LR_VALUES:
            log_lr = np.log(lr)
            for B in BATCH_SIZES:
                loader = DataLoader(dataset, batch_size=B, shuffle=True)
                for _trial in range(N_EVAL_TRIALS):
                    model = build_transformer(n_classes)
                    X_batch, y_batch = next(iter(loader))
                    probe = extract_probe_features(model, X_batch, y_batch, loss_fn)
                    meta_x = np.concatenate([probe, [log_lr, log_n]]).reshape(1, -1)
                    T_pred = meta_reg.predict(meta_x)[0]
                    T_act = measure_convergence(
                        model, X_batch, y_batch,
                        eps=EPS_VALUES[0], lr=lr, criterion=loss_fn,
                    )
                    rows.append({
                        'dataset': ds_name,
                        'learning_rate': lr,
                        'batch_size': B,
                        'T_star': T_act,
                        'T_pred': T_pred,
                        'TT_error': abs(T_act - T_pred),
                    })

    results = pd.DataFrame(rows)
    results.to_csv(RESULTS_CSV, index=False)
    print(f"Results saved to {RESULTS_CSV}")
    return results

In [None]:
def _plot_error_by(df, group_col, xlabel, title):
    """Draw one bar chart of TT_error against `group_col`, with one bar per dataset.

    Bar height is the group mean and the error bar is +/- one standard deviation,
    both computed by seaborn from the per-trial rows of `df`.
    """
    # Per-group stats are used only to size the y-axis so error bars and legend fit.
    stats = df.groupby([group_col, 'dataset'])['TT_error'].agg(['mean', 'std'])
    # std is NaN for single-row groups; if every group is single-row, treat the
    # spread as zero so the axis limit stays finite.
    top = stats['mean'].max() + stats['std'].fillna(0).max() + 2

    fig, ax = plt.subplots(figsize=(8, 6))
    # Plot the raw rows, not pre-aggregated means: one value per bar has no spread,
    # so seaborn would draw no error bars at all. errorbar='sd' needs seaborn >= 0.12;
    # errwidth is kept for compatibility even though seaborn 0.13 deprecates it.
    sns.barplot(data=df, x=group_col, y='TT_error', hue='dataset',
                errorbar='sd', palette='Set2', edgecolor='black',
                capsize=0.15, errwidth=1.5, ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Mean Absolute TT Error (MAE)')
    ax.set_title(title)
    ax.set_ylim(top=top)
    ax.legend(title='Dataset', loc='upper left')
    fig.tight_layout()
    plt.show()


def plot_tt_error(df):
    """Plot mean absolute TT error (+/- 1 std) by learning rate and by batch size.

    Args:
        df: per-trial records containing the columns learning_rate, batch_size,
            dataset and TT_error.
    """
    # A single set_theme call: the old set_context(...) followed by sns.set(style=...)
    # reset the context and font scale to seaborn's defaults, discarding FONT_SCALE.
    sns.set_theme(context="talk", style="ticks", font_scale=FONT_SCALE)
    _plot_error_by(df, 'learning_rate', 'Learning Rate', 'TT Error by Learning Rate')
    _plot_error_by(df, 'batch_size', 'Batch Size', 'TT Error by Batch Size')

In [None]:
# Reuse the saved evaluation table when allowed and present; otherwise run the
# (slow) evaluation, which also writes RESULTS_CSV for next time.
use_cached = reuse and os.path.exists(RESULTS_CSV)
if not use_cached:
    df_results = run_transformer_evaluation()
else:
    print(f"Using cached results from {RESULTS_CSV}")
    df_results = pd.read_csv(RESULTS_CSV)

plot_tt_error(df_results)