In [2]:
pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238
[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from scipy.optimize import curve_fit
from scipy.stats import linregress, ks_2samp
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.feature_selection import mutual_info_regression
from sklearn.decomposition import PCA
from scipy import stats
import warnings
from scipy.optimize import OptimizeWarning
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_score
from thop import profile

torch.set_num_threads(1)

scaler = GradScaler()

warnings.filterwarnings("ignore", category=OptimizeWarning)
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    module="scipy.optimize._minpack_py"
)

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Utils ===

_ds_cache = {}

def get_data(ds_name: str, resize: int):
    """
    Return a torchvision Dataset for `ds_name` ('cifar' or 'mnist'),
    resized to (resize × resize).  Downloads only on first call.
    """
    key = (ds_name, resize)
    if key not in _ds_cache:
        # pick the right class
        DataClass = datasets.CIFAR10 if ds_name == 'cifar' else datasets.MNIST

        # build & stash it
        tf = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
        ])
        _ds_cache[key] = DataClass(
            './data',
            train=False,
            download=True,
            transform=tf
        )
    return _ds_cache[key]

def set_random_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def summarize_sd(arr):
    arr = np.asarray(arr, float)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return "n/a"
    return f"{arr.mean():.3f} ± {arr.std():.3f} (SD, n={len(arr)})"

def summarize_ci(arr, alpha=0.05):
    """
    mean ± 95 % CI (Student-t, two-sided).
    Returns 'n/a' if nothing to summarise.
    """
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    n   = len(arr)
    if n == 0:
        return "n/a"
    mean = arr.mean()
    sem  = stats.sem(arr)               
    half = stats.t.ppf(1-alpha/2, n-1) * sem
    return f"{mean:.3f} ± {half:.3f} (95 % CI, n={n})"

def summarize_p(arr):
    """median [IQR] for p-values – unchanged."""
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        return "n/a"
    q25, q75 = np.percentile(arr, [25, 75])
    return f"{np.median(arr):.3f} [IQR {q25:.3f}–{q75:.3f}] (n={len(arr)})"

# === Models ===
class ModularCNN(nn.Module):
    def __init__(
        self,
        conv_channels,
        kernel_sizes,
        strides,
        use_leaky_relu: bool = False,
        use_batchnorm: bool = False,
        in_channels: int = 3
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns   = nn.ModuleList() if use_batchnorm else None
        self.use_bn = use_batchnorm
        self.strides = strides
        self.kernel_sizes = kernel_sizes
        self.use_leaky = use_leaky_relu
        for out_ch, k, s in zip(conv_channels, kernel_sizes, strides):
            conv = nn.Conv2d(in_channels, out_ch, kernel_size=k, padding=k//2, bias=True)
            self.convs.append(conv)
            if use_batchnorm:
                self.bns.append(nn.BatchNorm2d(out_ch))
            in_channels = out_ch

    def forward(self, x):
        activations = []
        for idx, conv in enumerate(self.convs):
            x = conv(x)
            if self.use_bn:
                x = self.bns[idx](x)
            x = F.leaky_relu(x, 0.01) if self.use_leaky else F.relu(x)
            activations.append(x.clone())
            s = self.strides[idx]
            x = F.avg_pool2d(x, kernel_size=s)
        return activations

class ResNetBlock(nn.Module):
    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int = 1,
        kernel_size: int = 3,
        bottleneck_ratio: float = 1.0,
        projection_type: str = 'identity',
        activation: str = 'relu'
    ):
        super().__init__()
        mid_ch = int(out_ch * bottleneck_ratio)
        # 1st conv
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size, stride, kernel_size//2, bias=False)
        self.bn1   = nn.BatchNorm2d(mid_ch)
        # 2nd conv
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size, 1, kernel_size//2, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        # shortcut
        if projection_type == 'conv1x1' or stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )
        else:
            self.shortcut = nn.Identity()
        # activation
        if activation == 'leaky_relu':
            self.act = lambda x: F.leaky_relu(x, 0.01)
        else:
            self.act = F.relu

    def forward(self, x):
        y = self.act(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + self.shortcut(x)
        return self.act(y)

class ModularResNet(nn.Module):
    def __init__(
        self,
        channels,
        block_sizes,
        kernel_sizes,
        strides,
        bottleneck_ratios,
        projection_types,
        activation_functions,
        in_channels: int = 3
    ):
        super().__init__()
        # stem
        self.conv1 = nn.Conv2d(in_channels, channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(channels[0])
        self.act1  = F.relu
        self.in_ch = channels[0]
        # store params
        self.channels = channels
        self.block_sizes = block_sizes
        self.kernel_sizes = kernel_sizes
        self.strides_list = strides
        self.bottleneck_ratios = bottleneck_ratios
        self.projection_types = projection_types
        self.activation_functions = activation_functions
        # build layers
        self.layers = nn.ModuleList()
        for idx in range(len(channels)):
            self.layers.append(self._make_layer(
                out_ch = channels[idx],
                num_blocks = block_sizes[idx],
                stride = strides[idx],
                kernel_size = kernel_sizes[idx],
                bottleneck_ratio = bottleneck_ratios[idx],
                projection_type = projection_types[idx],
                activation = activation_functions[idx]
            ))

    def _make_layer(
        self,
        out_ch,
        num_blocks,
        stride,
        kernel_size,
        bottleneck_ratio,
        projection_type,
        activation
    ):
        blocks = []
        # first block
        blocks.append(ResNetBlock(
            self.in_ch, out_ch, stride,
            kernel_size, bottleneck_ratio,
            projection_type, activation
        ))
        self.in_ch = out_ch
        # rest
        for _ in range(1, num_blocks):
            blocks.append(ResNetBlock(
                self.in_ch, out_ch, 1,
                kernel_size, bottleneck_ratio,
                'identity', activation
            ))
        return nn.Sequential(*blocks)

    def forward(self, x):
        acts = []
        x = self.act1(self.bn1(self.conv1(x)))
        acts.append(x.clone())
        for layer in self.layers:
            x = layer(x)
            acts.append(x.clone())
        return acts

# === Architecture Generators ===
def generate_random_cnn(
    n_layers: int,
    stride_choices,
    kernel_choices,
    channel_choices,
    input_size: int,
    in_channels: int,
    use_leaky_relu: bool,
    use_batchnorm: bool
):
    size = input_size
    channels, kernels, strides = [], [], []
    for _ in range(n_layers):
        s = random.choice([st for st in stride_choices if size >= 2*st] or [1])
        k = random.choice(kernel_choices)
        c = random.choice(channel_choices)
        strides.append(s)
        kernels.append(k)
        channels.append(c)
        size = max(1, size // s)
    return ModularCNN(channels, kernels, strides, use_leaky_relu, use_batchnorm, in_channels)

def generate_random_resnet(
    n_layers: int,
    channel_choices,
    input_size: int,
    in_channels: int,
    block_sizes,
    kernel_sizes,
    stride_choices,
    bottleneck_ratios,
    projection_types,
    activation_functions
):
    """
    Build a random ResNet with n_layers blocks. For each block we sample:
      - out_channels from channel_choices
      - num_blocks from block_sizes
      - kernel_size from kernel_sizes
      - stride from stride_choices (only if spatial size allows)
      - bottleneck_ratio from bottleneck_ratios
      - projection_type from projection_types
      - activation from activation_functions
    """
    size = input_size
    channels, blocks, ks, strides, brs, pts, afs = [], [], [], [], [], [], []

    # Sample up to n_layers (stop early if spatial collapses)
    for _ in range(n_layers):
        c  = random.choice(channel_choices)
        b  = random.choice(block_sizes)
        k  = random.choice(kernel_sizes)
        br = random.choice(bottleneck_ratios)
        pt = random.choice(projection_types)
        af = random.choice(activation_functions)

        # only allow stride>1 if size>=2*stride
        possible_strides = [s for s in stride_choices if size >= 2*s]
        s = random.choice(possible_strides) if possible_strides else 1

        channels.append(c)
        blocks.append(b)
        ks.append(k)
        strides.append(s)
        brs.append(br)
        pts.append(pt)
        afs.append(af)

        size = max(1, size // s)
        if size <= 1:
            break

    return ModularResNet(
        channels       = channels,
        block_sizes    = blocks,
        kernel_sizes   = ks,
        strides        = strides,
        bottleneck_ratios   = brs,
        projection_types    = pts,
        activation_functions= afs,
        in_channels    = in_channels
    )

# === Property Extraction ===
def extract_cnn_properties(model: ModularCNN):
    """
    Extracts layer-wise properties for a ModularCNN.
    Uses naming consistent with ResNet properties for shared concepts.
    """
    strides, kernels, channels, depths, overlaps = [], [], [], [], []
    for i, conv in enumerate(model.convs):
        k = model.kernel_sizes[i]
        s = model.strides[i]
        c = conv.out_channels
        strides.append(s)
        kernels.append(k)
        channels.append(c)
        depths.append(i+1)
        overlaps.append((k - s) / k)

    # Use 'kernel_size' and 'stride' to be consistent with ResNet extractor
    return {
        'stride': np.array(strides),
        'kernel_size': np.array(kernels),
        'c': np.array(channels),
        'inv_c': 1/np.array(channels),
        'd': np.array(depths),
        'o': np.array(overlaps),
        's_r': np.array([s*r for s,r in zip(strides, kernels)]),
        'arch_type': np.full_like(np.array(depths), 'cnn', dtype=object)
    }

def extract_resnet_properties(model: ModularResNet):
    ks = model.kernel_sizes
    ss = model.strides_list
    br = model.bottleneck_ratios
    pt = model.projection_types
    af = model.activation_functions
    skip_width, block_depth, layer_index, channels, res_type = [], [], [], [], []
    kernel_size, stride, bottleneck, projection, activation, depths = [], [], [], [], [], []

    # initial conv
    layer = 1
    channels.append(model.conv1.out_channels)
    skip_width.append(0)
    block_depth.append(0)
    layer_index.append(layer)
    res_type.append(0)
    kernel_size.append(3)
    stride.append(1)
    bottleneck.append(1.0)
    projection.append(0)
    activation.append(0)
    depths.append(layer)
    layer += 1

    for li, seq in enumerate(model.layers):
        for bj, blk in enumerate(seq):
            in_ch = blk.conv1.in_channels
            out_ch = blk.conv2.out_channels
            has_proj = not isinstance(blk.shortcut, nn.Identity)

            skip_width.append(out_ch - in_ch if has_proj else 0)
            block_depth.append(bj)
            layer_index.append(layer)
            res_type.append(1 if has_proj else 0)

            if bj == 0:
                kernel_size.append(ks[li])
                stride.append(ss[li])
                bottleneck.append(br[li])
                projection.append(1 if pt[li] != 'identity' else 0)
                activation.append(1 if af[li]=='leaky_relu' else 0)
            else:
                kernel_size.append(ks[li])
                stride.append(1)
                bottleneck.append(br[li])
                projection.append(0)
                activation.append(1 if af[li]=='leaky_relu' else 0)

            channels.append(out_ch)
            depths.append(layer)
            layer += 1

    return {
        'skip_width': np.array(skip_width),
        'block_depth': np.array(block_depth),
        'layer_index': np.array(layer_index),
        'c': np.array(channels),
        'inv_c': 1/np.array(channels),
        'res_type': np.array(res_type),
        'kernel_size': np.array(kernel_size),
        'stride': np.array(stride),
        'bottleneck': np.array(bottleneck),
        'projection': np.array(projection),
        'activation': np.array(activation),
        'd': np.array(depths),
        'arch_type': np.full_like(np.array(depths), 'resnet', dtype=object)
    }

def extract_model_properties(model):
    if isinstance(model, ModularCNN):
        return extract_cnn_properties(model)
    elif isinstance(model, ModularResNet):
        return extract_resnet_properties(model)
    else:
        raise ValueError("Unknown model type")
        
def extract_natural_features(model, raw_props_df: pd.DataFrame, input_shape: tuple) -> dict:
    """
    Calculates a rich, "natural" set of features for a given model,
    combining global budget, shape, and architecture-specific properties.
    Now uses MACs for computational budget.

    Args:
        model: The instantiated PyTorch model object.
        raw_props_df: The DataFrame containing the layer-wise properties for this single model.
        input_shape: A tuple representing the input shape (e.g., (1, 3, 44, 44)) for MACs calculation.

    Returns:
        A dictionary containing the engineered features for this one model.
    """
    # --- Basic Properties ---
    total_depth = len(raw_props_df)
    arch_type = raw_props_df['arch_type'].iloc[0]
    
    # --- 1. Computational Budget Features (Using MACs) ---
    # Create a dummy input tensor with the correct shape to profile the model
    dummy_input = torch.randn(input_shape)
    
    # Use thop to calculate MACs and params. model must be on CPU for thop.
    macs, params = profile(model.to('cpu'), inputs=(dummy_input,), verbose=False)
    
    features = {
        # --- THIS IS THE KEY CHANGE ---
        'total_macs_log': np.log1p(macs),
        'param_density': params / total_depth # Keep param_density as it's still interesting
    }
    
    # --- 2. Global Shape ("Pyramid") Features ---
    channels = raw_props_df['c']
    strides = raw_props_df['stride']
    
    features['initial_channels'] = channels.iloc[0]
    features['max_channels'] = channels.max()
    features['width_expansion_factor'] = features['max_channels'] / (features['initial_channels'] + 1e-8)
    
    downsampling_stages = (strides > 1).sum()
    features['downsampling_stages'] = downsampling_stages
    if downsampling_stages > 0:
        features['layers_per_stage'] = total_depth / downsampling_stages
    else:
        features['layers_per_stage'] = total_depth

    # --- 3. Flow Consistency Features ---
    channel_ratio = channels / channels.shift(1).fillna(channels.iloc[0])
    features['mean_channel_ratio'] = channel_ratio.mean()
    features['std_channel_ratio'] = channel_ratio.std()
    features['mean_stride'] = strides.mean()
    
    # --- 4. Architecture-Specific Features ---
    if arch_type == 'resnet':
        if 'projection' in raw_props_df and 'res_type' in raw_props_df:
            num_projections = raw_props_df['projection'].sum()
            num_blocks = (raw_props_df['res_type'] > 0).sum()
            num_identities = num_blocks - num_projections
            
            features['num_projection_shortcuts'] = num_projections
            features['num_identity_shortcuts'] = num_identities
            if num_blocks > 0:
                features['identity_shortcut_ratio'] = num_identities / num_blocks
            else:
                features['identity_shortcut_ratio'] = 0
            features['mean_bottleneck_ratio'] = raw_props_df['bottleneck'].mean()
        
    elif arch_type == 'cnn':
        if 'o' in raw_props_df:
            features['mean_overlap'] = raw_props_df['o'].mean()

    return features

# === Metrics ===

def compute_robust_lambda(model, x, device, num_noise_steps=10, max_noise_std=0.3, eps=1e-8):
    """
    Computes a robust, intrinsic lambda by fitting a sensitivity curve for each layer.
    It measures the amplification of perturbations across a spectrum of noise levels.

    Args:
        model: The PyTorch model.
        x: The input tensor.
        device: The device to run on.
        num_noise_steps: The number of noise levels to test.
        max_noise_std: The maximum noise standard deviation to test.

    Returns:
        np.ndarray: An array of robust lambda values for the model.
    """
    noise_levels_to_test = np.linspace(0.01, max_noise_std, num_noise_steps)
    all_deltas = []

    model.eval()
    with torch.no_grad():
        clean_acts = model(x)
        if isinstance(clean_acts, tuple):
            clean_acts = clean_acts[0]

        for noise_std in noise_levels_to_test:
            noisy_input = x + noise_std * torch.randn_like(x)
            noisy_acts = model(noisy_input)
            if isinstance(noisy_acts, tuple):
                noisy_acts = noisy_acts[0]

            deltas_for_level = torch.stack(
                [(n_act - c_act).flatten(1).norm(dim=1) for n_act, c_act in zip(noisy_acts, clean_acts)]
            ).mean(dim=1).cpu().numpy()
            all_deltas.append(deltas_for_level)
    
    all_deltas = np.array(all_deltas)
    num_layers = all_deltas.shape[1]
    robust_lambdas = []
    
    for i in range(num_layers - 1):
        x_data = all_deltas[:, i]
        y_data = all_deltas[:, i+1]
        
        try:
            # Fit a line: y = m*x. The slope 'm' is the amplification factor.
            # We force the intercept through zero, which is physically motivated.
            x_data_reshaped = x_data[:, np.newaxis]
            slope, _, _, _ = np.linalg.lstsq(x_data_reshaped, y_data, rcond=None)
            slope = slope[0]
        except np.linalg.LinAlgError:
            slope = 1.0

        amplification_factor = max(slope, eps)
        robust_lambdas.append(-np.log(amplification_factor))
        
    return np.array(robust_lambdas)


def compute_input_sensitivity(model, x):
    model.eval()
    xv = x.clone().requires_grad_(True)
    acts = model(xv)

    input_sens = []
    for act in acts:
        # exactly the same mathematical expression as before
        loss = act.square().sum()
        (grad,) = torch.autograd.grad(loss, xv, retain_graph=True)
        input_sens.append(grad.norm().item())

    return input_sens

def mi_permutation_test(X, y, n_perm=200):
    n_neighbors = max(1, min(2, len(y) - 1))  
    actual_mi = mutual_info_regression(X, y, random_state=0, n_neighbors=n_neighbors)[0]
    perm_mis = []
    for _ in range(n_perm):
        y_perm = np.random.permutation(y)
        perm_mi = mutual_info_regression(X, y_perm, random_state=0, n_neighbors=n_neighbors)[0]
        perm_mis.append(perm_mi)
    pval = (np.sum(np.array(perm_mis) >= actual_mi) + 1) / (n_perm + 1)
    return actual_mi, pval

# === Regression & Analysis ===

def plot_key_results(results, title_suffix=""):
    # 1) Input Sensitivity vs Depth
    grouped = results['grouped']
    plt.figure(figsize=(8,6))
    plt.errorbar(grouped['depth'], grouped['mean'], yerr=grouped['std'], fmt='o-', capsize=3)
    plt.xlabel("Depth")
    plt.ylabel("Mean Input Sensitivity")
    plt.title(f"Input Sensitivity vs Depth{title_suffix}")
    plt.grid(True)
    plt.show()

    # 2) σ² vs 1/(L–1) Regression
    inv   = results['inv_Lminus1']      # array of 1/(L_i - 1)
    sigma2 = results['sigma2']          # array of σ_i²
    slope, intercept, r_value, p_value, stderr = linregress(inv, sigma2)

    plt.figure(figsize=(8,6))
    plt.scatter(inv, sigma2, alpha=0.5, label="data")
    xs = np.linspace(inv.min(), inv.max(), 200)
    plt.plot(xs, intercept + slope*xs, 'r--',
             label=f"fit: y={slope:.2e}x+{intercept:.2e}\n$R^2$={r_value**2:.2f}")
    plt.xlabel("1 / (L - 1)")
    plt.ylabel("σ²")
    plt.title(f"σ² vs 1/(L–1){title_suffix}")
    plt.legend()
    plt.grid(True)
    plt.show()
    
def compute_global_complexity(raw_props, n_components=1):
    """
    Compute a 1D "ACS" per model as the projection onto the first principal component
    of the global (minmax‐scaled) parameter space.
    """
    # --- THIS IS THE FIX ---
    # 1) Get keys for numeric columns only, filtering out strings like 'arch_type'.
    numeric_keys = [k for k, v in raw_props.items() if np.issubdtype(v.dtype, np.number)]
    if not numeric_keys:
        # If no numeric columns are found, return an array of zeros.
        num_models = len(next(iter(raw_props.values())))
        return np.zeros(num_models)
        
    # Stack parameters into (n_models, n_features) using only numeric keys
    X = np.vstack([ raw_props[k] for k in numeric_keys ]).T

    # 2) Min–max scale each column to [0,1]
    # Add a small epsilon to the denominator to prevent division by zero
    # if a feature is constant across all models.
    mins = X.min(axis=0, keepdims=True)
    maxs = X.max(axis=0, keepdims=True)
    X_norm = (X - mins) / (maxs - mins + 1e-8)

    # 3) PCA to 1 component
    pca = PCA(n_components=n_components)
    pcs = pca.fit_transform(X_norm)  # shape (n_models, 1)

    # 4) Return the first principal component (flattened)
    return pcs[:, 0]

def classify_models(complexity_scores, percentile=10):
    sorted_idx = np.argsort(complexity_scores)
    n = len(complexity_scores)
    labels = np.array(['unlabeled'] * n)
    k = max(1, int(n * percentile / 100))
    labels[sorted_idx[:k]] = 'simple'
    labels[sorted_idx[-k:]] = 'complex'
    return labels

def bin_and_validate(input_sens_means, complexity_labels, n_bins=10):
    bins = np.linspace(min(input_sens_means), max(input_sens_means), n_bins + 1)
    bin_indices = np.digitize(input_sens_means, bins) - 1
    valid_bins = []
    for bin_id in range(n_bins):
        idx = np.where(bin_indices == bin_id)[0]
        labels_in_bin = complexity_labels[idx]
        simple_count = np.sum(labels_in_bin == 'simple')
        complex_count = np.sum(labels_in_bin == 'complex')
        if len(idx) >= 5 and simple_count >= 3 and complex_count >= 3:
            valid_bins.append({
                'indices': idx,
                'simple_count': simple_count,
                'complex_count': complex_count,
                'bin_id': bin_id
            })
    return valid_bins

def compute_bin_metrics(valid_bins, lambda_vals, complexity_scores, raw_props, complexity_labels):
    """
    For each valid FI‐bin, compute:
      - CV of σ
      - Mutual Information on ACS and on all_params (with p‐values)
      - σ(simple)/σ(complex) ratio and its KS‐test p‐value
    Returns a pandas.DataFrame with columns:
      ['IS_bin', '#simple', '#complex', 'CV',
       'ratio', 'pval_ratio',
       'MI_ACS', 'MI_pval_ACS',
       'MI_all_params', 'MI_pval_all_params']
    """
    metrics = []
    # --- THIS IS THE FIX ---
    # Filter for numeric keys once, outside the loop.
    numeric_keys = [k for k, v in raw_props.items() if np.issubdtype(v.dtype, np.number)]
    
    for b in valid_bins:
        idx = b['indices']

        # compute σ for all models in this bin
        sigma_l = np.array([np.std(lambda_vals[i]) for i in idx])

        # coefficient of variation
        CV = sigma_l.std() / (sigma_l.mean() + 1e-12)

        # --- simple vs complex σ ---
        lam_s = [np.std(lambda_vals[i]) for i in idx if complexity_labels[i] == 'simple']
        lam_c = [np.std(lambda_vals[i]) for i in idx if complexity_labels[i] == 'complex']
        # KS‐test on the two distributions
        _, p_val = ks_2samp(lam_s, lam_c)
        # ratio of means
        ratio = np.nanmean(lam_s) / np.nanmean(lam_c)

        # mutual information tests
        acs         = complexity_scores[idx].reshape(-1, 1)
        
        # --- THIS IS THE FIX ---
        # Build all_params using only the pre-filtered numeric keys.
        all_params  = np.array([[raw_props[k][i] for k in numeric_keys] for i in idx])
        
        mi_acs, p_acs       = mi_permutation_test(acs, sigma_l)
        mi_params, p_params = mi_permutation_test(all_params, sigma_l)

        metrics.append({
            'IS_bin':            b['bin_id'],
            '#simple':           b['simple_count'],
            '#complex':          b['complex_count'],
            'CV':                CV,
            'ratio':             ratio,
            'pval_ratio':        p_val,
            'MI_ACS':            mi_acs,
            'MI_pval_ACS':       p_acs,
            'MI_all_params':     mi_params,
            'MI_pval_all_params': p_params
        })

    return pd.DataFrame(metrics)

def train_and_analyze_stability_model(X: pd.DataFrame, y: np.ndarray):
    """
    Trains a Gradient Boosting model to predict sigma (RG flow stability)
    from engineered architectural features and returns feature importances.

    Args:
        X: DataFrame of engineered features (one row per model).
        y: NumPy array of sigma values (std(lambda_i)) for each model.

    Returns:
        A tuple containing:
        - A DataFrame of sorted feature importances.
        - The mean cross-validated R^2 score of the model.
    """
    if len(X) != len(y):
        raise ValueError("Mismatch between number of models in X and y.")
    if len(X) < 10: # Not enough data to train a meaningful model
        return pd.DataFrame(), np.nan

    # Define the model
    gbr = GradientBoostingRegressor(n_estimators=50, max_depth=3, random_state=0)

    # Calculate cross-validated R^2 score to see how well we can predict sigma
    try:
        cv_scores = cross_val_score(gbr, X, y, cv=min(5, len(X)), scoring='r2')
        mean_sigma_r2 = np.mean(cv_scores)
    except Exception:
        mean_sigma_r2 = np.nan

    # Train the model on all data to get final feature importances
    gbr.fit(X, y)

    # Create a DataFrame of feature importances
    importances = pd.DataFrame({
        'feature': X.columns,
        'importance': gbr.feature_importances_
    }).sort_values(by='importance', ascending=False).reset_index(drop=True)

    return importances, mean_sigma_r2

# === Experiment Runner ===
def run_experiment(cfg):
    # ───────────────────────────── SET-UP ──────────────────────────────
    set_random_seed(cfg['seed'])
    ds = cfg['dataset'].lower()
    base, ch = (32, 3) if ds == 'cifar' else (28, 1)
    scale = random.uniform(*cfg['resize_range'])
    tgt = int(base * scale)

    print(f"Running run_id       : {cfg['run_id']}")
    print(f"\nResize scale factor  : {scale:.2f}  →  target size {tgt}")

    data = get_data(ds, tgt)
    n = max(1, int(len(data) * cfg['data_fraction']))
    idx = np.random.choice(len(data), n, replace=False)
    samples = [data[i][0] for i in idx]
    x = torch.stack(samples, dim=0).to(device)
    # noise_std is now IGNORED by the lambda calculation but kept for compatibility
    input_size, in_channels = tgt, ch

    print(f"Input shape          : {tuple(x.shape)}")

    # ───────────────────── GENERATE & MEASURE MODELS ────────────────────
    data_pts, input_sens_vals, lambda_vals = [], [], []
    lambda_stds, model_depths          = [], []

    all_natural_features = []

    def _measure_model(model):
        
        lam = compute_robust_lambda(model, x, device)

        if lam.size == 0 or len(lam) < 3:
            return None
        with autocast(enabled=False):
            input_sens = compute_input_sensitivity(model, x)
        props = extract_model_properties(model)
        ml = min(len(lam)+1, len(input_sens))
        for k in props:
            props[k] = props[k][:ml]
        rows = []
        for j in range(ml-1):
            rows.append({**{k: props[k][j] for k in props}, 'empirical': float(lam[j])})
        return lam, input_sens, rows, props

    model_id_counter = 0
    for _ in tqdm(range(cfg['n_models']), desc="Measuring models"):
        depth = random.choice(cfg['depth_choices'])
        if cfg['architecture'] == 'cnn':
            model = generate_random_cnn(
                depth,
                cfg['param_ranges']['strides'],
                cfg['param_ranges']['kernels'],
                cfg['param_ranges']['channels'],
                input_size, in_channels,
                cfg.get('use_leaky_relu', True),
                cfg.get('use_batchnorm', False)
            ).to(device)
        else:
            pr = cfg['param_ranges']
            model = generate_random_resnet(
                depth, pr['channels'], input_size, in_channels,
                pr['block_sizes'], pr['kernel_sizes'], pr['strides'],
                pr['bottleneck_ratios'], pr['projection_types'],
                pr['activation_functions']
            ).to(device)
        
        res = _measure_model(model)
        
        if res is None:
            del model
            continue
            
        lam, input_sens, rows, raw_props_dict = res
        
        for row in rows:
            row['model_id'] = model_id_counter

        lambda_stds.append(float(np.std(lam)))
        model_depths.append(len(lam) + 1)
        input_sens_vals.append(input_sens[:-1])
        lambda_vals.append(lam)
        data_pts.extend(rows)

        raw_props_df = pd.DataFrame(raw_props_dict)
        # Define the input shape for profiling (batch size of 1 is sufficient)
        profile_input_shape = (1, in_channels, input_size, input_size)
        natural_features = extract_natural_features(model, raw_props_df, profile_input_shape)
        all_natural_features.append(natural_features)
        
        model_id_counter += 1
        del model

    torch.cuda.empty_cache()

    if not data_pts:
        print("No valid data collected for this run.")
        return {
            'run_id': cfg['run_id'],
            'architecture': cfg['architecture'],
            'dataset': cfg['dataset']
        }

    n_models_valid = len(lambda_stds)
    print(f"Generated models accepted       : {n_models_valid}")
    print(f"Unique datapoints               : {len(set(tuple(d.items()) for d in data_pts))}")

    df          = pd.DataFrame(data_pts)
    raw_props   = {c: df[c].values for c in df.columns if c != 'empirical' and c != 'model_id'}
    empirical   = df['empirical'].values
    input_sens_flat = np.concatenate([f for f in input_sens_vals if f])
    lambda_flat = np.concatenate([l for l in lambda_vals if l.size > 0])

    complexity_scores = compute_global_complexity(raw_props)
    percentile = cfg.get("complexity_percentile", 10)
    complexity_labels = classify_models(complexity_scores, percentile)
    input_sens_means = np.array([np.mean(f) for f in input_sens_vals if f])
    valid_bins = bin_and_validate(input_sens_means, complexity_labels)
    per_bin_metrics = compute_bin_metrics(valid_bins, lambda_vals, complexity_scores, raw_props, complexity_labels)
    
    ratios = []
    pvals  = []
    for b in valid_bins:
       lam_s = [ np.std(lambda_vals[i]) for i in b['indices'] if complexity_labels[i]=='simple' ]
       lam_c = [ np.std(lambda_vals[i]) for i in b['indices'] if complexity_labels[i]=='complex' ]
       if lam_s and lam_c:
           _, p = ks_2samp(lam_s, lam_c)
           ratios.append(np.nanmean(lam_s) / np.nanmean(lam_c))
           pvals.append(p)
    mean_ratio = np.nanmean(ratios) if ratios else np.nan
    median_p   = np.nanmedian(pvals) if pvals else np.nan

    slope_input_sens = r_value_input_sens = np.nan
    grp = None
    try:
        if 'd' in raw_props and len(raw_props['d']) == len(input_sens_flat):
            grp_df = pd.DataFrame({'depth': raw_props['d'], 'input_sens': input_sens_flat})
            grp = grp_df.groupby('depth')['input_sens'].agg(['mean','std']).reset_index()
            log_d2 = np.log(grp['depth'] + 1e-8)
            log_f2 = np.log(grp['mean']  + 1e-8)
            slope_input_sens, _, r_value_input_sens, _, _ = linregress(log_d2, log_f2)
    except Exception:
        grp = None
    
    print("\n--- Supervised RG-Flow Stability Analysis ---")
    feature_importances, sigma_r2 = pd.DataFrame(), np.nan
    if n_models_valid > 10:
        X_eng = pd.DataFrame(all_natural_features).fillna(0)
        X_eng['mean_IS'] = input_sens_means
        y_sigma = np.array(lambda_stds)
        
        feature_importances, sigma_r2 = train_and_analyze_stability_model(X_eng, y_sigma)
        
        if not feature_importances.empty:
            print(f"Stability vs Architecture Model R^2 (sigma): {sigma_r2:.3f}")
            print("Most Important Architectural Features for Predicting RG-Flow Stability (sigma):")
            print(feature_importances.to_string(index=False))
        else:
            print("Not enough models or data to run supervised analysis.")
    else:
        print("Skipping supervised analysis due to insufficient models.")

    print("\nIS Scaling with Depth")
    if not np.isnan(slope_input_sens):
        print(f"  IS   ~ depth^{slope_input_sens:.3f}   (R² = {r_value_input_sens**2:.3f})")
        
    if valid_bins:
        bin_centers = [np.mean([input_sens_means[i] for i in b['indices']]) for b in valid_bins]
        all_complexities = [complexity_scores[i] for i in b['indices']]
        all_sigmas      = [np.std(lambda_vals[i]) for i in b['indices']]
        print("\nBin Analysis Summary")
        print(f"  Number of valid IS-bins   : {len(valid_bins)}")
        if bin_centers:
            print(f"  IS-bins range             : [{min(bin_centers):.3f}, {max(bin_centers):.3f}]")
        if all_complexities:
            print(f"  Complexity score range    : [{min(all_complexities):.3f}, {max(all_complexities):.3f}]")
        if all_sigmas:
            print(f"  σ range                   : [{min(all_sigmas):.3f}, {max(all_sigmas):.3f}]")
    else:
        print("\nBin Analysis Summary")
        print("  No valid IS bins — cannot compute summary ranges.")

    print("\nPer-Bin Table (one row per valid bin)")
    print("|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |")
    print("|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|")
    if not per_bin_metrics.empty:
        bin_centers_lookup = {b['bin_id']: np.mean([input_sens_means[i] for i in b['indices']]) for b in valid_bins}
        for _, row in per_bin_metrics.iterrows():
            center = bin_centers_lookup.get(row['IS_bin'], 0)
            print(f"| {center:6.2f} | {int(row['#simple']):6d} | {int(row['#complex']):7d} | {row['CV']:7.3f} | {row['ratio']:7.3f} | {row['pval_ratio']:10.3f} | {row['MI_ACS']:8.3f} | {row['MI_pval_ACS']:8.3f} | {row['MI_all_params']:9.3f} | {row['MI_pval_all_params']:11.3f} |")

    print("\nRun-Level Metrics Table")
    if per_bin_metrics.empty:
        mean_cv = max_cv = mean_mi_acs = median_mi_acs = max_mi_acs = np.nan
        mean_mi_params = median_mi_params = max_mi_params = np.nan
    else:
        mean_cv = per_bin_metrics['CV'].mean()
        max_cv = per_bin_metrics['CV'].max()
        mean_mi_acs = per_bin_metrics['MI_ACS'].mean()
        median_mi_acs = per_bin_metrics['MI_ACS'].median()
        max_mi_acs = per_bin_metrics['MI_ACS'].max()
        mean_mi_params = per_bin_metrics['MI_all_params'].mean()
        median_mi_params = per_bin_metrics['MI_all_params'].median()
        max_mi_params = per_bin_metrics['MI_all_params'].max()

    std_ratio = np.nanstd(ratios) if ratios else np.nan
    header_fmt = "| {0:35s} | {1:43s} |"
    row_fmt    = "| {0:35s} | {1:43s} |"
    print(header_fmt.format("Metric", "Value"))
    print("|" + "-"*37 + "|" + "-"*45 + "|")
    print(row_fmt.format("Mean CV(bin) ± SD", f"{mean_cv:.3f} ± {per_bin_metrics['CV'].std():.3f}" if not np.isnan(mean_cv) else "n/a"))
    print(row_fmt.format("Max CV(bin)", f"{max_cv:.3f}" if not np.isnan(max_cv) else "n/a"))
    print(row_fmt.format("Ratio σ(simple/complex) ± SD", f"{mean_ratio:.3f} ± {std_ratio:.3f}" if not np.isnan(mean_ratio) else "n/a"))
    print(row_fmt.format("KS p-value σ(simple/complex)", f"{median_p:.3f}" if not np.isnan(median_p) else "n/a"))
    print(row_fmt.format("MI ACS mean ± SD", f"{mean_mi_acs:.3f} ± {per_bin_metrics['MI_ACS'].std():.3f}" if not np.isnan(mean_mi_acs) else "n/a"))
    print(row_fmt.format("MI all_params mean ± SD", f"{mean_mi_params:.3f} ± {per_bin_metrics['MI_all_params'].std():.3f}" if not np.isnan(mean_mi_params) else "n/a"))
        
    ratios_stability = (np.array(lambda_stds)**2) * (np.array(model_depths) - 1)
    mean_sigma_stability, std_sigma_stability = np.mean(ratios_stability), np.std(ratios_stability)
    cv_sigma_stability = std_sigma_stability / (mean_sigma_stability + 1e-12)
    
    print("\nRun-level Geometric Stability")
    print(f"  Mean σ²×(L-1): {mean_sigma_stability:.3e}")
    print(f"  Std  σ²×(L-1): {std_sigma_stability:.3e}")
    print(f"  CV   σ²×(L-1): {cv_sigma_stability:.3f}")
    
    if cfg.get('make_plots', False) and (grp is not None):
        inv_Lminus1 = 1.0 / (np.array(model_depths) - 1)
        sigma2      = np.array(lambda_stds)**2
        plot_key_results({'grouped': grp, 'inv_Lminus1': inv_Lminus1, 'sigma2': sigma2}, title_suffix = f" ({cfg['run_id']})")

    final_results = {
        'run_id': cfg['run_id'],
        'architecture': cfg['architecture'],
        'dataset': cfg['dataset'],
        'n_models_accepted': n_models_valid,
        'n_models_used_in_bins': sum(len(b['indices']) for b in valid_bins),
        'n_lambda_vals': len(lambda_flat),
        'n_input_sens_vals': len(input_sens_flat),
        'n_IS_bins_attempted': cfg.get('n_variance_levels', 20),
        'n_IS_bins_valid': len(valid_bins),
        'input_sens_depth_slope': slope_input_sens,
        'input_sens_depth_r2': r_value_input_sens**2 if not np.isnan(r_value_input_sens) else np.nan,
        'mean_sigma_ratio_simple_to_complex': mean_ratio,
        'ks_pvalue_sigma_ratio': median_p,
        'cv_bins_mean': mean_cv,
        'cv_bins_std':  np.nan if per_bin_metrics.empty else per_bin_metrics['CV'].std(),
        'cv_bins_max':  max_cv,
        'mean_mi_acs': mean_mi_acs,
        'median_mi_acs': median_mi_acs,
        'max_mi_acs': max_mi_acs,
        'mean_mi_all_params': mean_mi_params,
        'median_mi_all_params': median_mi_params,
        'max_mi_all_params': max_mi_params,
        'mean_sigma2_ratio': mean_sigma_stability,
        'std_sigma2_ratio': std_sigma_stability,
        'cv_sigma2_ratio': cv_sigma_stability,
        'rg_flow_model_r2': sigma_r2,
        'rg_flow_feature_importances': feature_importances.to_dict('records') if not feature_importances.empty else [],
    }
    
    return final_results

def analyze_experiment_results(results_list):
    import numpy as np
    from scipy import stats
    from collections import defaultdict

    # Helper to compute mean ± 95% CI
    def mean_ci(arr):
        a = np.array(arr, float)
        a = a[~np.isnan(a)]
        if len(a) < 2:
            return None
        m = a.mean()
        se = stats.sem(a)
        h = stats.t.ppf(0.975, len(a)-1) * se
        return m, m-h, m+h

    # Format an array as mean ± half-width (95% CI, n=…)
    def fmt_ci(arr):
        ci = mean_ci(arr)
        cnt = np.count_nonzero(~np.isnan(arr))
        if ci is None:
            return "n/a"
        m, lo, hi = ci
        return f"{m:.3f} ± {(m-lo):.3f} (95% CI, n={cnt})"

    # Format p-values as median [IQR] (n=…)
    def fmt_p(arr):
        a = np.array(arr, float)
        a = a[~np.isnan(a)]
        if len(a) == 0:
            return "n/a"
        q25, q75 = np.percentile(a, [25, 75])
        return f"{np.median(a):.3f} [IQR {q25:.3f}, {q75:.3f}] (n={len(a)})"

    # Simple mean±SD
    def fmt_sd(arr):
        a = np.array(arr, float)
        a = a[~np.isnan(a)]
        if len(a) == 0:
            return "n/a"
        return f"{a.mean():.3f} ± {a.std():.3f}"

    n_runs = len(results_list)

    # collect arrays across all runs, using .get to default to nan
    ratio_arr   = [r.get('mean_sigma_ratio_simple_to_complex',   np.nan) for r in results_list]
    pval_arr    = [r.get('ks_pvalue_sigma_ratio',               np.nan) for r in results_list]
    cv_arr      = [r.get('cv_bins_mean',                        np.nan) for r in results_list]
    cvmax_arr   = [r.get('cv_bins_max',                         np.nan) for r in results_list]
    mi_acs_arr  = [r.get('mean_mi_acs',                         np.nan) for r in results_list]
    mi_ap_arr   = [r.get('mean_mi_all_params',                  np.nan) for r in results_list]
    s2_m_arr    = [r.get('mean_sigma2_ratio',                   np.nan) for r in results_list]
    s2_s_arr    = [r.get('std_sigma2_ratio',                    np.nan) for r in results_list]
    s2_cv_arr   = [r.get('cv_sigma2_ratio',                     np.nan) for r in results_list]

    # --- CROSS-RUN SUMMARY ---
    print("\n===== CROSS-RUN SUMMARY =====\n")
    # Metadata
    tm = sum(r.get('n_models_accepted',0) for r in results_list)
    tl = sum(r.get('n_lambda_vals',0)       for r in results_list)
    tf = sum(r.get('n_input_sens_vals',0)       for r in results_list)
    ba = sum(r.get('n_IS_bins_attempted',0) for r in results_list)
    bv = sum(r.get('n_IS_bins_valid',0)     for r in results_list)
    ub = np.nanmean([r.get('n_models_used_in_bins', np.nan) for r in results_list])
    print("Metadata:")
    print(f"  Total models accepted          : {tm}")
    print(f"  Total λᵢ values                : {tl}")
    print(f"  Total IS values            : {tf}")
    print(f"  Valid IS bins                  : {bv}")
    print(f"  Total IS bins attempted        : {ba}")
    print(f"  Avg models per run             : {tm/n_runs:.2f}")
    print(f"  Avg λᵢ per model               : {tl/tm:.2f}")
    print(f"  Avg models used-in-bins per run: {ub:.2f}\n")

    # 1. Per-IS-bin σ Variation
    print("1. Per-IS-bin σ Variation (Old Method):")
    print(f"   Ratio σ(simple/complex)  : {fmt_ci(ratio_arr)}")
    print(f"   KS-p-value summary       : {fmt_p(pval_arr)}")
    print(f"   Runs with p<0.05         : {int(np.sum(np.array(pval_arr)<0.05))} / {n_runs}\n")

    # 2. Per-IS-bin CV(σ)
    print("2. Per-IS-bin CV(σ):")
    print(f"   Mean CV(bin) : {fmt_ci(cv_arr)}")
    print(f"   Max  CV(bin) : {fmt_ci(cvmax_arr)}\n")

    # 3. Mutual Information ACS
    print("3. Mutual Information ACS (Old Method):")
    if all(np.isnan(mi_acs_arr)):
        print("   MI_ACS: n/a\n")
    else:
        print(f"   Mean ± SD     : {fmt_sd(mi_acs_arr)}")
        med = np.nanmedian(mi_acs_arr)
        iqr = (np.nanpercentile(mi_acs_arr,25), np.nanpercentile(mi_acs_arr,75))
        print(f"   Median [IQR]  : {med:.3f} [{iqr[0]:.3f}–{iqr[1]:.3f}]")
        print(f"   Max           : {np.nanmax(mi_acs_arr):.3f}\n")

    # 4. Mutual Information all_params
    print("4. Mutual Information all_params (Old Method):")
    if all(np.isnan(mi_ap_arr)):
        print("   MI_all_params: n/a\n")
    else:
        print(f"   Mean ± SD     : {fmt_sd(mi_ap_arr)}")
        med = np.nanmedian(mi_ap_arr)
        iqr = (np.nanpercentile(mi_ap_arr,25), np.nanpercentile(mi_ap_arr,75))
        print(f"   Median [IQR]  : {med:.3f} [{iqr[0]:.3f}–{iqr[1]:.3f}]")
        print(f"   Max           : {np.nanmax(mi_ap_arr):.3f}\n")

    # 5. Stability σ²×(L-1)
    print("5. Stability σ²×(L-1):")
    print(f"   Mean : {fmt_ci(s2_m_arr)}")
    print(f"   Std  : {fmt_ci(s2_s_arr)}")
    print(f"   CV   : {fmt_ci(s2_cv_arr)}\n")

    # --- NEW SECTION 6: RG-FLOW ANALYSIS ---
    print("6. RG-Flow Stability Prediction (New Supervised Method):")
    r2_arr = [r.get('rg_flow_model_r2', np.nan) for r in results_list]
    print(f"   Mean Model R^2 (across runs) : {fmt_ci(r2_arr)}")

    agg_importances = defaultdict(list)
    for r in results_list:
        if 'rg_flow_feature_importances' in r and r['rg_flow_feature_importances']:
            for item in r['rg_flow_feature_importances']:
                agg_importances[item['feature']].append(item['importance'])
    
    if agg_importances:
        mean_imps = {name: np.mean(vals) for name, vals in agg_importances.items()}
        sorted_imps = sorted(mean_imps.items(), key=lambda item: item[1], reverse=True)
        
        print("\n   Aggregated Feature Importances (Top 5):")
        print("   " + "-"*45)
        print(f"   {'Feature':<30} | {'Mean Importance':<15}")
        print("   " + "-"*45)
        for feature, importance in sorted_imps[:5]:
            print(f"   {feature:<30} | {importance:<15.4f}")
        print("   " + "-"*45 + "\n")
    else:
        print("   No feature importance data to report.\n")


    # --- PER-ARCH/DATASET BREAKDOWN ---
    # (This section is unchanged, but will now implicitly benefit from the new data)
    print("\n===== PER-ARCH/DATASET BREAKDOWN =====\n")
    by_group = defaultdict(list)
    for r in results_list:
        by_group[(r.get('architecture','?'), r.get('dataset','?'))].append(r)

    for (arch, ds), grp in by_group.items():
        print(f"--- {arch.upper()} / {ds.upper()} (n={len(grp)}) ---")
        # You can add a mini version of the RG-Flow analysis here if desired,
        # but for simplicity, we'll leave this section as is. The overall
        # summary in section 6 is the most important part.
        g_ratio = [g.get('mean_sigma_ratio_simple_to_complex', np.nan) for g in grp]
        g_pval  = [g.get('ks_pvalue_sigma_ratio',             np.nan) for g in grp]
        g_cv    = [g.get('cv_bins_mean',                       np.nan) for g in grp]
        g_cvmax = [g.get('cv_bins_max',                        np.nan) for g in grp]
        g_s2_m  = [g.get('mean_sigma2_ratio',                  np.nan) for g in grp]
        g_s2_s  = [g.get('std_sigma2_ratio',                   np.nan) for g in grp]
        g_s2_cv = [g.get('cv_sigma2_ratio',                    np.nan) for g in grp]
        g_r2    = [g.get('rg_flow_model_r2',                   np.nan) for g in grp]

        print("1. Per-IS-bin σ Variation:")
        print(f"   Ratio σ(simple/complex)  : {fmt_ci(g_ratio)}")
        print(f"   KS-p-value summary       : {fmt_p(g_pval)}\n")
        
        print("2. Per-IS-bin CV(σ):")
        print(f"   Mean CV(bin) : {fmt_ci(g_cv)}")
        print(f"   Max  CV(bin) : {fmt_ci(g_cvmax)}\n")

        print("3. Stability σ²×(L-1):")
        print(f"   Mean : {fmt_ci(g_s2_m)}")
        print(f"   Std  : {fmt_ci(g_s2_s)}")
        print(f"   CV   : {fmt_ci(g_s2_cv)}\n")
        
        print("4. RG-Flow Stability Prediction (R^2):")
        print(f"   Mean Model R^2 : {fmt_ci(g_r2)}\n")


    return {
        "overall": {
            "ratio_CI": fmt_ci(ratio_arr),
            "ks_p":     fmt_p(pval_arr),
            "cv_mean":  fmt_ci(cv_arr),
            "cv_max":   fmt_ci(cvmax_arr),
            "mi_acs":   fmt_sd(mi_acs_arr) if not all(np.isnan(mi_acs_arr)) else "n/a",
            "mi_ap":    fmt_sd(mi_ap_arr)  if not all(np.isnan(mi_ap_arr))  else "n/a",
            "s2_mean":  fmt_ci(s2_m_arr),
            "s2_std":   fmt_ci(s2_s_arr),
            "s2_cv":    fmt_ci(s2_cv_arr),
            "rg_flow_r2": fmt_ci(r2_arr),
        }
    }

def run_statistical_consistency_experiments(
    seeds,
    architectures,
    datasets,
    data_fraction,
    # The 'noise_levels' parameter is no longer used and can be removed
    n_models=100,
    complexity_percentile=10,
    n_variance_levels=20,
    n_bootstrap=50,
    make_plots=False,
    architecture_configs=None
):
    """
    Runs experiments over all combinations of seed, architecture, and dataset.
    The notion of a single noise_level is removed in favor of a robust lambda calculation.
    """
    # Remove `noise_levels` from the total calculation and the inner loop
    total = len(seeds) * len(architectures) * len(datasets)
    count = 0
    all_results = []

    for seed in seeds:
        for arch in architectures:
            for ds in datasets:
                count += 1
                print("\n" + "=" * 60)
                print(f"Experiment {count}/{total}: "
                      f"arch={arch}, dataset={ds}, seed={seed}")

                # The config no longer needs a specific `noise_std`
                cfg = {
                    'seed': seed,
                    'architecture': arch,
                    'dataset': ds,
                    'data_fraction': data_fraction,
                    'n_models': n_models,
                    'complexity_percentile': complexity_percentile,
                    'n_variance_levels': n_variance_levels,
                    'n_bootstrap': n_bootstrap,
                    'make_plots': make_plots,
                    'run_id': f"{arch}_{ds}_s{seed}",
                    **architecture_configs[arch]
                }

                res = run_experiment(cfg)
                all_results.append(res)

    analysis = analyze_experiment_results(all_results)
    return all_results, analysis

Using device: cuda


In [4]:
# architecture_configs
architecture_configs = {
    'cnn': {
        'depth_choices': [4, 6, 8, 10, 20, 40, 80, 100],
        'resize_range': (1.0, 4.0),
        'param_ranges': {
            'strides':  [1, 2, 3],
            'kernels':  [3, 5, 7, 9, 11, 13, 15, 17],
            'channels': [8, 16, 32, 64, 128, 256, 512, 1024]
        }
    },
    'resnet': {
        'depth_choices': [4, 6, 8, 10, 20, 40, 80, 100],
        'resize_range': (1.0, 4.0),
        'param_ranges': {
            'channels':          [16, 32, 64, 128, 256, 512, 1024],
            'block_sizes':       [1, 2, 3, 4, 5],
            'kernel_sizes':      [1, 3, 5, 7, 9],
            'strides':           [1, 2, 3],
            'bottleneck_ratios': [1, 0.75, 0.5, 0.25, 0.125],
            'projection_types':  ['identity', 'conv1x1'],
            'activation_functions': ['relu', 'leaky_relu']
        }
    }
}

# call to main script:
results, analysis = run_statistical_consistency_experiments(
    seeds=[1, 12],
    architectures=['cnn', 'resnet'],
    datasets=['cifar', 'mnist'],
    data_fraction=0.0004,          # proportion of test set to sample
    n_models=30,                  # how many random nets per run
    complexity_percentile=25,      # top/bottom 25% are “complex”/“simple”
    n_variance_levels=10,          # how many IS‐bins to attempt
    n_bootstrap=20,
    make_plots=False,               # toggle your two matplotlib plots
    architecture_configs=architecture_configs
)

print(analysis)


Experiment 1/8: arch=cnn, dataset=cifar, seed=1
Running run_id       : cnn_cifar_s1

Resize scale factor  : 1.40  →  target size 44
Files already downloaded and verified
Input shape          : (4, 3, 44, 44)


Measuring models: 100%|██████████| 30/30 [01:56<00:00,  3.89s/it]


Generated models accepted       : 30
Unique datapoints               : 1074

--- Supervised RG-Flow Stability Analysis ---
Stability vs Architecture Model R^2 (sigma): 0.901
Most Important Architectural Features for Predicting RG-Flow Stability (sigma):
               feature  importance
      layers_per_stage    0.553929
           mean_stride    0.416045
          mean_overlap    0.008549
    mean_channel_ratio    0.006241
      initial_channels    0.004428
               mean_IS    0.004309
         param_density    0.002003
          max_channels    0.001375
        total_macs_log    0.001192
width_expansion_factor    0.001025
     std_channel_ratio    0.000888
   downsampling_stages    0.000016

IS Scaling with Depth
  IS   ~ depth^-3.884   (R² = 0.580)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [36.272, 36.272]
  Complexity score range    : [-0.833, 0.795]
  σ range                   : [1.207, 8.494]

Per-Bin Table (one row per valid bin)


Measuring models: 100%|██████████| 30/30 [01:45<00:00,  3.51s/it]


Generated models accepted       : 30
Unique datapoints               : 1074

--- Supervised RG-Flow Stability Analysis ---
Stability vs Architecture Model R^2 (sigma): 0.670
Most Important Architectural Features for Predicting RG-Flow Stability (sigma):
               feature  importance
      layers_per_stage    0.767296
      initial_channels    0.073599
           mean_stride    0.059844
width_expansion_factor    0.026406
    mean_channel_ratio    0.021729
        total_macs_log    0.020855
          mean_overlap    0.020091
               mean_IS    0.004049
     std_channel_ratio    0.002830
         param_density    0.002509
          max_channels    0.000793
   downsampling_stages    0.000000

IS Scaling with Depth
  IS   ~ depth^-3.868   (R² = 0.571)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [33.684, 33.684]
  Complexity score range    : [-0.833, 0.795]
  σ range                   : [1.373, 8.286]

Per-Bin Table (one row per valid bin)


Measuring models:  43%|████▎     | 13/30 [02:23<03:08, 11.07s/it]


KeyboardInterrupt: 

In [7]:
# ===================================================================
# Part 1: Imports and Necessary Building Blocks (Unchanged)
# ===================================================================
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from thop import profile
from tqdm import tqdm
import matplotlib.pyplot as plt

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def set_random_seed(seed: int) -> None:
    """Sets the random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

class ModularCNN(nn.Module):
    def __init__(self, conv_channels, kernel_sizes, strides, num_classes=10, in_channels=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.strides_list = strides
        self.kernel_sizes = kernel_sizes
        
        # Convolutional layers
        for out_ch, k, s in zip(conv_channels, kernel_sizes, strides):
            conv = nn.Conv2d(in_channels, out_ch, kernel_size=k, padding=k//2, bias=True)
            self.convs.append(conv)
            in_channels = out_ch

        # Classifier
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(conv_channels[-1], num_classes)

    def forward(self, x):
        for idx, conv in enumerate(self.convs):
            x = conv(x)
            x = F.relu(x)
            s = self.strides_list[idx]
            if s > 1:
                x = F.avg_pool2d(x, kernel_size=s)
        
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# ===================================================================
# Part 2: The New, More Powerful Model Engineering Function
# ===================================================================

### --- NEW/MODIFIED --- ###
def generate_controlled_cnn_v2(
    design_type: str,
    target_value: float,
    control_variable: str, # 'macs' or 'params'
    input_shape: tuple,
    tolerance: float = 0.05, # Tighter tolerance for better matching
    max_iter: int = 50
) -> ModularCNN:
    """
    Engineers a CNN to have a target budget for either MACs or Parameters.

    Args:
        design_type (str): 'stable' or 'erratic'.
        target_value (float): The desired value for the control variable.
        control_variable (str): What to target: 'macs' or 'params'.
        input_shape (tuple): The input shape for MACs calculation.
        tolerance (float): The allowed relative error.
        max_iter (int): Maximum attempts to find a matching model.

    Returns:
        A ModularCNN instance.
    """
    print(f"\n--- Engineering '{design_type.upper()}' model (Controlling for {control_variable.upper()}) ---")
    if control_variable == 'macs':
        print(f"Target MACs: {target_value / 1e6:.1f}M")
    else:
        print(f"Target Params: {target_value / 1e6:.2f}M")

    # Define the core design principle based on our findings
    if design_type == 'stable':
        strides = [1, 2, 1, 2, 1, 2] # "wide and fast-shrinking"
        base_channels = 32
    elif design_type == 'erratic':
        strides = [1, 1, 1, 1, 2, 1, 1, 1, 1, 2] # "narrow and slow-shrinking"
        base_channels = 16
    else:
        raise ValueError("design_type must be 'stable' or 'erratic'")
    
    depth = len(strides)
    kernels = [3] * depth

    # Iteratively adjust channel width to meet the budget
    for i in range(max_iter):
        channels = [int(base_channels * (1.5**s)) for s in range(len(strides))]
        model = ModularCNN(conv_channels=channels, kernel_sizes=kernels, strides=strides)
        
        dummy_input = torch.randn(input_shape)
        macs, params = profile(model, inputs=(dummy_input,), verbose=False)

        current_value = macs if control_variable == 'macs' else params
        
        if abs(current_value - target_value) / target_value < tolerance:
            print(f"Success! Model has {macs/1e6:.1f}M MACs and {params/1e6:.2f}M params.")
            print(f"Design -> Strides: {strides}, Channels: {channels[:4]}...")
            return model.to(device)

        # Adjust for next iteration based on a simple scaling rule
        if control_variable == 'macs':
             # MACs scale roughly with channels^2
            ratio = np.sqrt(target_value / (current_value + 1e-9))
        else: # params
            # Params also scale roughly with channels^2 in conv layers
            ratio = np.sqrt(target_value / (current_value + 1e-9))
        base_channels = int(base_channels * ratio)
        if base_channels < 1: base_channels = 1
        
    raise RuntimeError(f"Failed to generate a model with target {control_variable} after {max_iter} iterations.")


# ===================================================================
# Part 3: Training and Evaluation Infrastructure (Unchanged)
# ===================================================================

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, desc="Training", leave=False)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        pbar.set_postfix(loss=running_loss/total, acc=100.*correct/total)
    return running_loss / len(loader), 100. * correct / total

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        pbar = tqdm(loader, desc="Evaluating", leave=False)
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            pbar.set_postfix(loss=running_loss/total, acc=100.*correct/total)
    return running_loss / len(loader), 100. * correct / total

# ===================================================================
# Part 4: The New "Three-Point Comparison" Experiment Runner
# ===================================================================

### --- NEW/MODIFIED --- ###
def run_three_point_comparison(config):
    set_random_seed(config['seed'])

    # --- 1. Data Loading ---
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    
    if config['data_fraction'] < 1.0:
        # Code to use a subset of data (unchanged)
        pass

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2)
    
    # --- 2. The Three-Point Model Generation ---
    input_shape = (1, 3, 32, 32)
    
    # Model A: The 'Stable' baseline, controlling for a target parameter count.
    model_a = generate_controlled_cnn_v2('stable', config['target_params'], 'params', input_shape)
    _, model_a_params = profile(model_a, inputs=(torch.randn(input_shape).to(device),), verbose=False)
    model_a_macs, _ = profile(model_a, inputs=(torch.randn(input_shape).to(device),), verbose=False)
    
    # Model B: 'Erratic' model matched to Model A's MACs.
    # This replicates our original flawed "Erratic" model.
    model_b = generate_controlled_cnn_v2('erratic', model_a_macs, 'macs', input_shape)

    # Model C: 'Erratic' model matched to Model A's Parameters.
    # This is the crucial new control model.
    model_c = generate_controlled_cnn_v2('erratic', model_a_params, 'params', input_shape)
    
    # --- 3. Training Setup ---
    models = {
        f"A: Stable (Low σ)": model_a,
        f"B: Erratic (MACs-Matched)": model_b,
        f"C: Erratic (Params-Matched)": model_c,
    }
    optimizers = {name: optim.Adam(model.parameters(), lr=config['lr']) for name, model in models.items()}
    criterion = nn.CrossEntropyLoss()
    results = {name: {'val_acc': []} for name in models.keys()}

    # --- 4. The Race ---
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")
        for name, model in models.items():
            print(f"--- Training {name} Model ---")
            train_one_epoch(model, train_loader, criterion, optimizers[name], device)
            _, val_acc = evaluate(model, val_loader, criterion, device)
            results[name]['val_acc'].append(val_acc)
            print(f"Epoch {epoch+1} [{name}] - Val Acc: {val_acc:.2f}%")
    return results

# ===================================================================
# Part 5: Modified Visualization
# ===================================================================

### --- NEW/MODIFIED --- ###
def plot_three_point_results(results, config):
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(1, 1, figsize=(12, 7))

    colors = ['#1f77b4', '#ff7f0e', '#d62728'] # Blue, Orange, Red
    for i, (name, metrics) in enumerate(results.items()):
        ax.plot(metrics['val_acc'], label=f'{name}', color=colors[i], marker='o', markersize=4, linestyle='--')

    ax.set_title(f"3-Point Comparison: The Impact of Architectural Shape vs. Capacity\n(Target Params: {config['target_params']/1e6:.2f}M)", fontsize=16)
    ax.set_xlabel("Epoch", fontsize=12)
    ax.set_ylabel("Validation Accuracy (%)", fontsize=12)
    ax.legend(fontsize=12, title="Models")
    ax.tick_params(axis='both', which='major', labelsize=10)
    plt.tight_layout()
    plt.show()



Using device: cuda


In [None]:

# ===================================================================
# Part 6: Main Execution Block
# ===================================================================

experiment_config = {
    'seed': 42,
    'target_params': 600000, # Target 0.6 Million Parameters
    'epochs': 10,
    'lr': 1e-3,
    'batch_size': 128,
    'data_fraction': 1.0,
}

training_results = run_three_point_comparison(experiment_config)
plot_three_point_results(training_results, experiment_config)

Files already downloaded and verified
Files already downloaded and verified

--- Engineering 'STABLE' model (Controlling for PARAMS) ---
Target Params: 0.60M
Success! Model has 68.3M MACs and 0.59M params.
Design -> Strides: [1, 2, 1, 2, 1, 2], Channels: [31, 46, 69, 104]...

--- Engineering 'ERRATIC' model (Controlling for MACS) ---
Target MACs: 68.3M
Success! Model has 67.8M MACs and 0.25M params.
Design -> Strides: [1, 1, 1, 1, 2, 1, 1, 1, 1, 2], Channels: [4, 6, 9, 13]...

--- Engineering 'ERRATIC' model (Controlling for PARAMS) ---
Target Params: 0.59M
Success! Model has 152.9M MACs and 0.57M params.
Design -> Strides: [1, 1, 1, 1, 2, 1, 1, 1, 1, 2], Channels: [6, 9, 13, 20]...

Epoch 1/10
--- Training A: Stable (Low σ) Model ---


                                                                                   

Epoch 1 [A: Stable (Low σ)] - Val Acc: 45.07%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 1 [B: Erratic (MACs-Matched)] - Val Acc: 34.15%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                  

Epoch 1 [C: Erratic (Params-Matched)] - Val Acc: 34.49%

Epoch 2/10
--- Training A: Stable (Low σ) Model ---


                                                                                   

Epoch 2 [A: Stable (Low σ)] - Val Acc: 54.19%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 2 [B: Erratic (MACs-Matched)] - Val Acc: 42.67%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                  

Epoch 2 [C: Erratic (Params-Matched)] - Val Acc: 45.15%

Epoch 3/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 3 [A: Stable (Low σ)] - Val Acc: 60.34%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                  

Epoch 3 [B: Erratic (MACs-Matched)] - Val Acc: 47.08%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                  

Epoch 3 [C: Erratic (Params-Matched)] - Val Acc: 50.98%

Epoch 4/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 4 [A: Stable (Low σ)] - Val Acc: 64.33%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                  

Epoch 4 [B: Erratic (MACs-Matched)] - Val Acc: 49.94%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 4 [C: Erratic (Params-Matched)] - Val Acc: 55.31%

Epoch 5/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 5 [A: Stable (Low σ)] - Val Acc: 66.61%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 5 [B: Erratic (MACs-Matched)] - Val Acc: 51.90%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 5 [C: Erratic (Params-Matched)] - Val Acc: 56.54%

Epoch 6/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 6 [A: Stable (Low σ)] - Val Acc: 70.57%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 6 [B: Erratic (MACs-Matched)] - Val Acc: 57.74%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 6 [C: Erratic (Params-Matched)] - Val Acc: 59.39%

Epoch 7/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 7 [A: Stable (Low σ)] - Val Acc: 70.91%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 7 [B: Erratic (MACs-Matched)] - Val Acc: 59.33%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 7 [C: Erratic (Params-Matched)] - Val Acc: 62.06%

Epoch 8/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 8 [A: Stable (Low σ)] - Val Acc: 72.05%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                   

Epoch 8 [B: Erratic (MACs-Matched)] - Val Acc: 58.39%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 8 [C: Erratic (Params-Matched)] - Val Acc: 63.79%

Epoch 9/10
--- Training A: Stable (Low σ) Model ---


                                                                                    

Epoch 9 [A: Stable (Low σ)] - Val Acc: 74.60%
--- Training B: Erratic (MACs-Matched) Model ---


                                                                                    

Epoch 9 [B: Erratic (MACs-Matched)] - Val Acc: 63.92%
--- Training C: Erratic (Params-Matched) Model ---


                                                                                   

Epoch 9 [C: Erratic (Params-Matched)] - Val Acc: 65.04%

Epoch 10/10
--- Training A: Stable (Low σ) Model ---


Training:  23%|██▎       | 90/391 [00:01<00:03, 92.78it/s, acc=78.3, loss=0.00481]

In [3]:
# ===================================================================
# Part 1: Imports and Core Building Blocks
# ===================================================================
import time
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from thop import profile
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def set_random_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

### --- RESNET & CNN DEFINITIONS --- ###
class ResNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ModularResNet(nn.Module):
    def __init__(self, block_config, channel_config, num_classes=10, in_channels=3):
        super().__init__()
        self.in_planes = channel_config[0]
        self.conv1 = nn.Conv2d(in_channels, channel_config[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channel_config[0])
        self.layers = nn.ModuleList()
        
        for i, num_blocks in enumerate(block_config):
            s = 2 if i > 0 else 1
            self.layers.append(self._make_layer(channel_config[i], num_blocks, stride=s))

        self.linear = nn.Linear(channel_config[-1], num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(ResNetBlock(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for layer in self.layers:
            out = layer(out)
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

### --- MODIFIED: Corrected ModularCNN Constructor --- ###
class ModularCNN(nn.Module):
    def __init__(self, conv_channels, stride_config, num_classes=10, in_channels=3):
        super().__init__()
        self.layers = nn.ModuleList()
        in_ch_local = in_channels
        
        # Create a dictionary for easy lookup: {layer_index: stride_value}
        stride_map = {idx: stride for idx, stride in stride_config}

        for i in range(len(conv_channels)):
            # Get the stride for the current layer, default to 1
            s = stride_map.get(i, 1)
            
            # The stride is applied in the Conv2d layer itself for plain CNNs
            conv_layer = nn.Conv2d(in_ch_local, conv_channels[i], kernel_size=3, stride=s, padding=1)
            self.layers.append(nn.Sequential(conv_layer, nn.ReLU()))
            in_ch_local = conv_channels[i]
        
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(conv_channels[-1], num_classes)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# ===================================================================
# Part 2: Universal Model Engineering Function with NEW Blueprints
# ===================================================================
def generate_controlled_model(arch, design_type, target_value, control_variable, input_shape, num_classes=10, tolerance=0.1, max_iter=50):
    print(f"\n--- Engineering '{arch.upper()}/{design_type.upper()}' (Controlling for {control_variable.upper()}) ---")
    if control_variable == 'macs': print(f"Target MACs: {target_value / 1e6:.1f}M")
    else: print(f"Target Params: {target_value / 1e6:.2f}M")

    base_channels = 24
    
    if arch == 'cnn':
        num_layers = 6
        strides_stable = [(1, 2), (3, 2), (5, 2)]
        strides_erratic = [(2, 2), (5, 2)]
        stride_config = strides_stable if design_type == 'stable' else strides_erratic
    elif arch == 'resnet':
        block_config_stable = [2, 2, 2, 2]  # ~ResNet-18
        block_config_erratic = [4, 4, 4, 4] # ~ResNet-34
        block_config = block_config_stable if design_type == 'stable' else block_config_erratic
    
    for _ in range(max_iter):
        if arch == 'cnn':
            channels = [int(base_channels * (1.5**i)) for i in range(num_layers)]
            model = ModularCNN(conv_channels=channels, stride_config=stride_config, num_classes=num_classes, in_channels=input_shape[1])
        elif arch == 'resnet':
            channel_config = [int(base_channels), int(base_channels*2), int(base_channels*4), int(base_channels*8)]
            model = ModularResNet(block_config, channel_config, num_classes=num_classes, in_channels=input_shape[1])

        macs, params = profile(model.to('cpu'), inputs=(torch.randn(input_shape),), verbose=False)
        current_value = macs if control_variable == 'macs' else params
        
        if abs(current_value - target_value) / target_value < tolerance:
            print(f"Success! Model has {macs/1e6:.1f}M MACs and {params/1e6:.2f}M params.")
            return model.to(device)
        
        ratio = np.sqrt(target_value / (current_value + 1e-9))
        base_channels = max(4, int(base_channels * ratio))
        
    raise RuntimeError(f"Failed to generate model for {arch}/{design_type} with target {target_value:.2e}")

# ===================================================================
# Part 3: Training and Evaluation Infrastructure
# ===================================================================
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    pbar = tqdm(loader, desc="Training", leave=False)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    return

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return running_loss / total, 100. * correct / total

# ===================================================================
# Part 4: Main Experimental Suite Runner
# ===================================================================
def run_experimental_suite(config):
    set_random_seed(config['seed'])
    
    if config['dataset'] == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        val_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        input_shape, num_classes = (1, 3, 32, 32), 10
    elif config['dataset'] == 'fashion_mnist':
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        train_ds = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
        val_ds = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
        input_shape, num_classes = (1, 1, 28, 28), 10

    train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, num_workers=2)
    
    model_a = generate_controlled_model(config['arch'], 'stable', config['target_params'], 'params', input_shape, num_classes)
    model_c = generate_controlled_model(config['arch'], 'erratic', config['target_params'], 'params', input_shape, num_classes)
    
    models = {"A: Stable (Low σ)": model_a, "C: Erratic (Params-Matched)": model_c}
    optimizers = {name: optim.Adam(model.parameters(), lr=config['lr']) for name, model in models.items()}
    criterion = nn.CrossEntropyLoss()
    results = {name: {'val_acc': []} for name in models.keys()}

    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch+1}/{config['epochs']} for {config['arch'].upper()}/{config['dataset'].upper()}")
        for name, model in models.items():
            train_one_epoch(model, train_loader, criterion, optimizers[name], device)
            _, val_acc = evaluate(model, val_loader, criterion, device)
            results[name]['val_acc'].append(val_acc)
            print(f"  {name} - Val Acc: {val_acc:.2f}%")
            
    return results, models

# ===================================================================
# Part 5: Loss Landscape Visualization
# ===================================================================
def get_weights(model): return torch.cat([p.data.view(-1) for p in model.parameters()])
def set_weights(model, weights):
    offset = 0
    for p in model.parameters():
        size = torch.prod(torch.tensor(p.shape)).item()
        p.data.copy_(weights[offset:offset+size].view(p.shape))
        offset += size

def get_random_directions(model, device):
    weights = [p.data for p in model.parameters()]
    direction1 = [torch.randn_like(w, device=device) for w in weights]
    norm1 = torch.sqrt(sum(torch.sum(d**2) for d in direction1))
    for d in direction1: d.div_(norm1)
    direction2 = [torch.randn_like(w, device=device) for w in weights]
    dot_product = sum(torch.sum(d1 * d2) for d1, d2 in zip(direction1, direction2))
    for d1, d2 in zip(direction1, direction2): d2.sub_(dot_product * d1)
    norm2 = torch.sqrt(sum(torch.sum(d**2) for d in direction2))
    for d in direction2: d.div_(norm2)
    return direction1, direction2

def calculate_loss_grid(model, directions, loader, criterion, device, grid_range=(-1, 1), grid_points=21):
    w_star_flat = get_weights(model)
    d1_flat = torch.cat([d.view(-1) for d in directions[0]])
    d2_flat = torch.cat([d.view(-1) for d in directions[1]])
    alphas, betas = np.linspace(*grid_range, grid_points), np.linspace(*grid_range, grid_points)
    loss_grid = np.zeros((grid_points, grid_points))
    temp_model = copy.deepcopy(model).to(device)
    pbar = tqdm(total=grid_points*grid_points, desc="Scanning Landscape", leave=False)
    for i, alpha in enumerate(alphas):
        for j, beta in enumerate(betas):
            set_weights(temp_model, w_star_flat + alpha * d1_flat + beta * d2_flat)
            loss, _ = evaluate(temp_model, loader, criterion, device)
            loss_grid[j, i] = loss
            pbar.update(1)
    pbar.close()
    return loss_grid, alphas, betas

def plot_landscapes(models_to_visualize, val_loader, criterion, device, title):
    num_models = len(models_to_visualize)
    fig, axes = plt.subplots(1, num_models, figsize=(8 * num_models, 7), squeeze=False)
    fig.suptitle(f'Loss Landscape Near Final Weights\n({title})', fontsize=24, y=1.02)
    for i, (name, model) in enumerate(models_to_visualize.items()):
        ax = axes[0, i]
        directions = get_random_directions(model, device)
        loss_grid, alphas, betas = calculate_loss_grid(model, directions, val_loader, criterion, device)
        contour = ax.contourf(alphas, betas, loss_grid, levels=20, cmap='viridis', norm=LogNorm())
        fig.colorbar(contour, ax=ax, label='Validation Loss (Log Scale)')
        ax.contour(alphas, betas, loss_grid, levels=20, colors='white', linewidths=0.5, alpha=0.5)
        ax.plot(0, 0, 'r*', markersize=15, label='Final Weights (w*)')
        ax.set_title(f"Landscape for: {name}", fontsize=16)
        ax.set_xlabel("Direction 1 (α)"); ax.set_ylabel("Direction 2 (β)")
        ax.legend(); ax.set_aspect('equal', adjustable='box')
    plt.tight_layout(rect=[0, 0, 1, 0.95]); plt.show()



Using device: cuda


In [None]:
# ===================================================================
# Part 6: Main Execution and Visualization
# ===================================================================
all_results = {}

### --- MODIFIED: Adjusted parameter budgets for feasibility --- ###
experiment_configs = [
    {'arch': 'resnet', 'dataset': 'cifar10', 'target_params': 2000000},       # 2.0M params
    {'arch': 'cnn', 'dataset': 'cifar10', 'target_params': 900000},          # 0.9M params (Increased)
    {'arch': 'resnet', 'dataset': 'fashion_mnist', 'target_params': 1500000}, # 1.5M params
    {'arch': 'cnn', 'dataset': 'fashion_mnist', 'target_params': 400000},      # 0.4M params (Increased)
]

base_config = {'seed': 42, 'epochs': 10, 'lr': 1e-3, 'batch_size': 128}

# --- Run all experiments ---
for exp_config in experiment_configs:
    config = {**base_config, **exp_config}
    key = (config['arch'], config['dataset'])
    try:
        results, trained_models = run_experimental_suite(config)
        all_results[key] = (results, trained_models)
    except RuntimeError as e:
        print(f"\n!!!!!! SKIPPING EXPERIMENT {key} due to error: {e} !!!!!!\n")
        all_results[key] = (None, None)


# --- Plot training curves ---
fig, axes = plt.subplots(2, 2, figsize=(18, 14), sharey=True)
fig.suptitle('Training Efficiency: The Impact of Architectural Shape (RG Flow)', fontsize=24, y=0.95)

plotted_count = 0
for i, config in enumerate(experiment_configs):
    key = (config['arch'], config['dataset'])
    if key in all_results and all_results[key][0] is not None:
        results, _ = all_results[key]
        arch, dataset = key
        ax = axes[plotted_count//2, plotted_count%2]
        colors = ['#1f77b4', '#d62728'] # Blue for Stable, Red for Erratic
        for j, (name, metrics) in enumerate(results.items()):
            label_name = name.split(":")[1].strip().replace(" (Params-Matched)","")
            ax.plot(metrics['val_acc'], label=label_name, color=colors[j], marker='o', markersize=4, linestyle='--')
        ax.set_title(f"{arch.upper()} on {dataset.replace('_', ' ').title()}", fontsize=16)
        ax.set_xlabel("Epoch"); ax.set_ylabel("Validation Accuracy (%)")
        ax.legend(fontsize=12); ax.grid(True)
        plotted_count += 1

plt.tight_layout(rect=[0, 0, 1, 0.93]); plt.show()

# --- Plot the most illustrative loss landscape ---
print("\n--- Visualizing key loss landscape for ResNet on CIFAR-10 ---")
key_to_plot = ('resnet', 'cifar10')
if key_to_plot in all_results and all_results[key_to_plot][1] is not None:
    _, trained_models = all_results[key_to_plot]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    val_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    val_loader = DataLoader(val_ds, batch_size=512, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    plot_landscapes(trained_models, val_loader, criterion, device, title="ResNet on CIFAR-10")

Files already downloaded and verified
Files already downloaded and verified

--- Engineering 'RESNET/STABLE' (Controlling for PARAMS) ---
Target Params: 2.00M
Success! Model has 100.3M MACs and 1.99M params.

--- Engineering 'RESNET/ERRATIC' (Controlling for PARAMS) ---
Target Params: 2.00M
Success! Model has 93.3M MACs and 1.88M params.

Epoch 1/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 66.09%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 58.39%

Epoch 2/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 72.57%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 70.45%

Epoch 3/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 76.48%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 71.15%

Epoch 4/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 80.34%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 75.50%

Epoch 5/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 80.66%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 74.55%

Epoch 6/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 80.30%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 79.20%

Epoch 7/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 78.37%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 74.99%

Epoch 8/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 81.44%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 78.68%

Epoch 9/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 80.01%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 80.21%

Epoch 10/10 for RESNET/CIFAR10


                                                           

  A: Stable (Low σ) - Val Acc: 79.01%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 80.87%
Files already downloaded and verified
Files already downloaded and verified

--- Engineering 'CNN/STABLE' (Controlling for PARAMS) ---
Target Params: 0.90M
Success! Model has 45.6M MACs and 0.88M params.

--- Engineering 'CNN/ERRATIC' (Controlling for PARAMS) ---
Target Params: 0.90M
Success! Model has 145.7M MACs and 0.88M params.

Epoch 1/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 44.64%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 42.01%

Epoch 2/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 53.10%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 53.14%

Epoch 3/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 61.18%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 58.93%

Epoch 4/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 63.37%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 62.63%

Epoch 5/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 68.14%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 66.96%

Epoch 6/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 70.19%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 69.90%

Epoch 7/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 72.95%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 71.65%

Epoch 8/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 74.91%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 73.94%

Epoch 9/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 76.97%


                                                           

  C: Erratic (Params-Matched) - Val Acc: 72.30%

Epoch 10/10 for CNN/CIFAR10


                                                            

  A: Stable (Low σ) - Val Acc: 77.70%


Training:  64%|██████▍   | 250/391 [00:03<00:01, 74.29it/s]