In [3]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from scipy.optimize import curve_fit
from scipy.stats import linregress, ks_2samp
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.feature_selection import mutual_info_regression
from sklearn.decomposition import PCA
from scipy import stats
import warnings
from scipy.optimize import OptimizeWarning

# Restrict PyTorch's intra-op CPU parallelism to a single thread.
torch.set_num_threads(1)

# Module-level AMP gradient scaler, created at import time.
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether later code still needs it.
scaler = GradScaler()

# Silence SciPy's OptimizeWarning globally, plus RuntimeWarnings raised from
# inside scipy.optimize._minpack_py (the module behind curve_fit).
warnings.filterwarnings("ignore", category=OptimizeWarning)
warnings.filterwarnings(
    "ignore",
    category=RuntimeWarning,
    module="scipy.optimize._minpack_py"
)

# Choose GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Utils ===

# Maps (dataset name, resize) -> already-constructed torchvision dataset.
_ds_cache = {}

def get_data(ds_name: str, resize: int):
    """
    Fetch the test split of a torchvision dataset, resized to a square.

    `ds_name == 'cifar'` selects CIFAR10; every other value selects MNIST.
    Images are resized to (resize × resize) and converted to tensors.
    Datasets are memoised in `_ds_cache` per (ds_name, resize) pair, so the
    dataset object is only constructed on the first request for a pair.
    """
    cache_key = (ds_name, resize)
    try:
        return _ds_cache[cache_key]
    except KeyError:
        pass

    if ds_name == 'cifar':
        dataset_cls = datasets.CIFAR10
    else:
        dataset_cls = datasets.MNIST

    preprocessing = transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.ToTensor(),
    ])
    dataset = dataset_cls('./data', train=False, download=True, transform=preprocessing)
    _ds_cache[cache_key] = dataset
    return dataset

def set_random_seed(seed: int) -> None:
    """
    Seed every random number generator this script relies on.

    Seeds Python's `random`, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available), records the seed in PYTHONHASHSEED, and switches
    cuDNN to deterministic mode with benchmarking turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def summarize_sd(arr):
    """
    Format `arr` as 'mean ± SD (SD, n=…)' after discarding NaNs.

    The SD is NumPy's default (population) standard deviation.
    Returns 'n/a' when no non-NaN values remain.
    """
    values = np.asarray(arr, float)
    keep = ~np.isnan(values)
    values = values[keep]
    if not values.size:
        return "n/a"
    mean, sd = values.mean(), values.std()
    return f"{mean:.3f} ± {sd:.3f} (SD, n={len(values)})"

def choose_summary(key, arr):
    """
    Summarise `arr` with the formatter that matches the metric name `key`.

    p-value metrics get median [IQR], 'mean_mi_std_lambda' gets mean ± SD,
    and every other metric gets mean ± 95 % CI.
    """
    formatters = {
        "global_pvalue": summarize_p,
        "mean_mi_pval": summarize_p,
        "mean_mi_std_lambda": summarize_sd,
    }
    # default → 95 % CI
    formatter = formatters.get(key, summarize_ci)
    return formatter(arr)

def summarize_ci(arr, alpha=0.05):
    """
    Format `arr` as 'mean ± half-width' of a two-sided Student-t confidence
    interval, after discarding NaNs.

    Parameters
    ----------
    arr : array-like of numbers
    alpha : float
        Significance level; the interval covers 100·(1 − alpha) %.
        The default 0.05 gives the 95 % CI.

    Returns
    -------
    str
        'mean ± half (<level> % CI, n=…)', or 'n/a' if nothing is left to
        summarise. With a single value the half-width is reported as 'nan',
        because a standard error cannot be estimated from one observation.
    """
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    n = len(arr)
    if n == 0:
        return "n/a"
    mean = arr.mean()
    if n == 1:
        # SEM and the t quantile are undefined for one observation; skip the
        # SciPy calls (they only emit warnings) and report nan directly.
        half = float("nan")
    else:
        half = stats.t.ppf(1 - alpha / 2, n - 1) * stats.sem(arr)
    # The label follows alpha instead of always claiming 95 %.
    level = 100 * (1 - alpha)
    return f"{mean:.3f} ± {half:.3f} ({level:g} % CI, n={n})"

def summarize_p(arr):
    """
    Format p-values as 'median [IQR q25–q75] (n=…)' after discarding NaNs.

    Returns 'n/a' when no non-NaN values remain.
    """
    pvals = np.asarray(arr, dtype=float)
    pvals = pvals[np.logical_not(np.isnan(pvals))]
    if pvals.size == 0:
        return "n/a"
    lower, upper = np.percentile(pvals, [25, 75])
    middle = np.median(pvals)
    return f"{middle:.3f} [IQR {lower:.3f}–{upper:.3f}] (n={len(pvals)})"

# === Models ===
class ModularCNN(nn.Module):
    """
    Plain convolutional stack that returns every layer's activation.

    Each layer is conv → optional BatchNorm → (leaky) ReLU. The activation is
    recorded, then average-pooled with kernel size equal to that layer's
    entry in `strides` before it feeds the next layer. The convolutions
    themselves use the default stride; downsampling happens only in the
    pooling step.
    """

    def __init__(
        self,
        conv_channels,
        kernel_sizes,
        strides,
        use_leaky_relu: bool = False,
        use_batchnorm: bool = False,
        in_channels: int = 3
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList() if use_batchnorm else None
        self.use_bn = use_batchnorm
        self.strides = strides
        self.kernel_sizes = kernel_sizes
        self.use_leaky = use_leaky_relu

        prev_width = in_channels
        for width, ksize, _pool in zip(conv_channels, kernel_sizes, strides):
            # padding = k // 2 keeps the spatial size for odd kernels
            self.convs.append(
                nn.Conv2d(prev_width, width, kernel_size=ksize, padding=ksize // 2, bias=True)
            )
            if use_batchnorm:
                self.bns.append(nn.BatchNorm2d(width))
            prev_width = width

    def forward(self, x):
        """Return the list of post-nonlinearity (pre-pooling) activations."""
        recorded = []
        for layer_idx, conv in enumerate(self.convs):
            x = conv(x)
            if self.use_bn:
                x = self.bns[layer_idx](x)
            if self.use_leaky:
                x = F.leaky_relu(x, 0.01)
            else:
                x = F.relu(x)
            recorded.append(x.clone())
            x = F.avg_pool2d(x, kernel_size=self.strides[layer_idx])
        return recorded

class ResNetBlock(nn.Module):
    """
    Two-convolution residual block with a configurable shortcut.

    The intermediate width is int(out_ch * bottleneck_ratio). The shortcut is
    a 1×1 conv + BatchNorm when `projection_type == 'conv1x1'`, when
    `stride != 1`, or when the channel count changes; otherwise it is the
    identity. `activation == 'leaky_relu'` selects leaky ReLU (slope 0.01);
    any other value selects ReLU.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        stride: int = 1,
        kernel_size: int = 3,
        bottleneck_ratio: float = 1.0,
        projection_type: str = 'identity',
        activation: str = 'relu'
    ):
        super().__init__()
        hidden = int(out_ch * bottleneck_ratio)
        pad = kernel_size // 2

        # main branch: strided conv followed by a stride-1 conv
        self.conv1 = nn.Conv2d(in_ch, hidden, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(hidden, out_ch, kernel_size, stride=1, padding=pad, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # shortcut branch
        needs_projection = projection_type == 'conv1x1' or stride != 1 or in_ch != out_ch
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )
        else:
            self.shortcut = nn.Identity()

        # nonlinearity shared by both activation sites
        self.act = (lambda t: F.leaky_relu(t, 0.01)) if activation == 'leaky_relu' else F.relu

    def forward(self, x):
        """Apply the main branch, add the shortcut, then the nonlinearity."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + self.shortcut(x)
        return self.act(out)

class ModularResNet(nn.Module):
    """
    ResNet assembled from per-stage hyper-parameter lists.

    A 3×3 stem (conv → BatchNorm → ReLU) is followed by one stage per entry
    of `channels`. Entry i of every list argument configures stage i. The
    forward pass returns the stem activation followed by the output of each
    stage.
    """

    def __init__(
        self,
        channels,
        block_sizes,
        kernel_sizes,
        strides,
        bottleneck_ratios,
        projection_types,
        activation_functions,
        in_channels: int = 3
    ):
        super().__init__()
        stem_width = channels[0]

        # stem
        self.conv1 = nn.Conv2d(in_channels, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.act1 = F.relu
        # running input width, advanced by _make_layer as stages are built
        self.in_ch = stem_width

        # keep the raw configuration for later inspection
        self.channels = channels
        self.block_sizes = block_sizes
        self.kernel_sizes = kernel_sizes
        self.strides_list = strides
        self.bottleneck_ratios = bottleneck_ratios
        self.projection_types = projection_types
        self.activation_functions = activation_functions

        # one nn.Sequential of blocks per stage
        self.layers = nn.ModuleList()
        for stage_idx, width in enumerate(channels):
            stage = self._make_layer(
                out_ch=width,
                num_blocks=block_sizes[stage_idx],
                stride=strides[stage_idx],
                kernel_size=kernel_sizes[stage_idx],
                bottleneck_ratio=bottleneck_ratios[stage_idx],
                projection_type=projection_types[stage_idx],
                activation=activation_functions[stage_idx]
            )
            self.layers.append(stage)

    def _make_layer(
        self,
        out_ch,
        num_blocks,
        stride,
        kernel_size,
        bottleneck_ratio,
        projection_type,
        activation
    ):
        """
        Build one stage. Only the first block gets `stride` and
        `projection_type`; the remaining blocks use stride 1 and request an
        identity shortcut.
        """
        leading = ResNetBlock(
            self.in_ch, out_ch, stride,
            kernel_size, bottleneck_ratio,
            projection_type, activation
        )
        self.in_ch = out_ch
        trailing = [
            ResNetBlock(
                out_ch, out_ch, 1,
                kernel_size, bottleneck_ratio,
                'identity', activation
            )
            for _ in range(1, num_blocks)
        ]
        return nn.Sequential(leading, *trailing)

    def forward(self, x):
        """Return [stem activation, stage 1 output, stage 2 output, ...]."""
        x = self.act1(self.bn1(self.conv1(x)))
        collected = [x.clone()]
        for stage in self.layers:
            x = stage(x)
            collected.append(x.clone())
        return collected

# === Architecture Generators ===
def generate_random_cnn(
    n_layers: int,
    stride_choices,
    kernel_choices,
    channel_choices,
    input_size: int,
    in_channels: int,
    use_leaky_relu: bool,
    use_batchnorm: bool
):
    """
    Sample a random `ModularCNN` with `n_layers` layers.

    Per layer, a stride is drawn from the entries of `stride_choices` that
    still fit the current spatial size (size >= 2 * stride), falling back to
    1 when none fit; kernel size and channel count are drawn freely. The
    tracked spatial size is divided by the stride after each layer and never
    drops below 1.
    """
    spatial = input_size
    layer_specs = []
    for _ in range(n_layers):
        feasible = [st for st in stride_choices if spatial >= 2 * st]
        if not feasible:
            feasible = [1]
        # draw order (stride, kernel, channels) fixes the RNG stream
        stride = random.choice(feasible)
        kernel = random.choice(kernel_choices)
        width = random.choice(channel_choices)
        layer_specs.append((width, kernel, stride))
        spatial = max(1, spatial // stride)

    channels = [spec[0] for spec in layer_specs]
    kernels = [spec[1] for spec in layer_specs]
    strides = [spec[2] for spec in layer_specs]
    return ModularCNN(channels, kernels, strides, use_leaky_relu, use_batchnorm, in_channels)

def generate_random_resnet(
    n_layers: int,
    channel_choices,
    input_size: int,
    in_channels: int,
    block_sizes,
    kernel_sizes,
    stride_choices,
    bottleneck_ratios,
    projection_types,
    activation_functions
):
    """
    Sample a random `ModularResNet` with at most `n_layers` stages.

    For each stage the following are drawn independently:
      - out_channels from channel_choices
      - num_blocks from block_sizes
      - kernel_size from kernel_sizes
      - bottleneck_ratio from bottleneck_ratios
      - projection_type from projection_types
      - activation from activation_functions
      - stride from the entries of stride_choices with size >= 2*stride
        (1 if no entry qualifies)

    Sampling stops early once the tracked spatial size has shrunk to 1.
    """
    spatial = input_size
    stages = []

    for _ in range(n_layers):
        # draw order is part of the behaviour: it fixes the RNG stream
        width = random.choice(channel_choices)
        n_blocks = random.choice(block_sizes)
        kernel = random.choice(kernel_sizes)
        ratio = random.choice(bottleneck_ratios)
        proj = random.choice(projection_types)
        act = random.choice(activation_functions)

        feasible = [s for s in stride_choices if spatial >= 2 * s]
        stride = random.choice(feasible) if feasible else 1

        stages.append((width, n_blocks, kernel, stride, ratio, proj, act))

        spatial = max(1, spatial // stride)
        if spatial <= 1:
            break

    def column(position):
        return [stage[position] for stage in stages]

    return ModularResNet(
        channels=column(0),
        block_sizes=column(1),
        kernel_sizes=column(2),
        strides=column(3),
        bottleneck_ratios=column(4),
        projection_types=column(5),
        activation_functions=column(6),
        in_channels=in_channels
    )

# === Property Extraction ===
def extract_cnn_properties(model: ModularCNN):
    """
    Collect per-layer architectural descriptors of a `ModularCNN`.

    Returns a dict of equal-length arrays, one entry per conv layer:
      's'     pooling stride          'r'   kernel size
      'c'     output channels         'inv_c' 1 / output channels
      'd'     1-based layer index     'o'   overlap (kernel - stride) / kernel
      's_r'   stride * kernel size
    """
    per_layer = [
        (model.strides[i], model.kernel_sizes[i], conv.out_channels)
        for i, conv in enumerate(model.convs)
    ]
    stride_list = [s for s, _, _ in per_layer]
    kernel_list = [k for _, k, _ in per_layer]
    width_list = [c for _, _, c in per_layer]
    depth_list = list(range(1, len(per_layer) + 1))
    overlap_list = [(k - s) / k for s, k, _ in per_layer]
    product_list = [s * k for s, k, _ in per_layer]

    return {
        's': np.array(stride_list),
        'r': np.array(kernel_list),
        'c': np.array(width_list),
        'inv_c': 1 / np.array(width_list),
        'd': np.array(depth_list),
        'o': np.array(overlap_list),
        's_r': np.array(product_list)
    }

def extract_resnet_properties(model: ModularResNet):
    """
    Collect per-layer architectural descriptors of a `ModularResNet`.

    The first entry describes the stem convolution; one further entry is
    produced for every residual block, in stage order. Returns a dict of
    equal-length arrays keyed by 'skip_width', 'block_depth', 'layer_index',
    'inv_c', 'res_type', 'kernel_size', 'stride', 'bottleneck', 'projection',
    'activation' and 'd'.
    """
    stage_kernels = model.kernel_sizes
    stage_strides = model.strides_list
    stage_ratios = model.bottleneck_ratios
    stage_projs = model.projection_types
    stage_acts = model.activation_functions

    # stem: fixed 3x3 / stride-1 conv, no shortcut, plain ReLU
    cols = {
        'skip_width': [0],
        'block_depth': [0],
        'layer_index': [1],
        'channels': [model.conv1.out_channels],
        'res_type': [0],
        'kernel_size': [3],
        'stride': [1],
        'bottleneck': [1.0],
        'projection': [0],
        'activation': [0],
        'depths': [1],
    }

    position = 2
    for stage_idx, stage in enumerate(model.layers):
        for block_idx, block in enumerate(stage):
            width_in = block.conv1.in_channels
            width_out = block.conv2.out_channels
            projected = not isinstance(block.shortcut, nn.Identity)
            is_first = block_idx == 0

            cols['skip_width'].append(width_out - width_in if projected else 0)
            cols['block_depth'].append(block_idx)
            cols['layer_index'].append(position)
            cols['res_type'].append(1 if projected else 0)

            # stride and requested projection only apply to a stage's first block
            cols['kernel_size'].append(stage_kernels[stage_idx])
            cols['stride'].append(stage_strides[stage_idx] if is_first else 1)
            cols['bottleneck'].append(stage_ratios[stage_idx])
            if is_first:
                cols['projection'].append(1 if stage_projs[stage_idx] != 'identity' else 0)
            else:
                cols['projection'].append(0)
            cols['activation'].append(1 if stage_acts[stage_idx] == 'leaky_relu' else 0)

            cols['channels'].append(width_out)
            cols['depths'].append(position)
            position += 1

    return {
        'skip_width': np.array(cols['skip_width']),
        'block_depth': np.array(cols['block_depth']),
        'layer_index': np.array(cols['layer_index']),
        'inv_c': 1 / np.array(cols['channels']),
        'res_type': np.array(cols['res_type']),
        'kernel_size': np.array(cols['kernel_size']),
        'stride': np.array(cols['stride']),
        'bottleneck': np.array(cols['bottleneck']),
        'projection': np.array(cols['projection']),
        'activation': np.array(cols['activation']),
        'd': np.array(cols['depths'])
    }

def extract_model_properties(model):
    """
    Dispatch to the property extractor matching the model's class.

    Raises ValueError for anything that is neither a `ModularCNN` nor a
    `ModularResNet`.
    """
    extractors = (
        (ModularCNN, extract_cnn_properties),
        (ModularResNet, extract_resnet_properties),
    )
    for model_cls, extractor in extractors:
        if isinstance(model, model_cls):
            return extractor(model)
    raise ValueError("Unknown model type")

# === Metrics ===

def compute_lambda_i_empirical(model, x, noise_std=0.2, eps=1e-8):
    """
    Estimate per-layer contraction rates from a noise perturbation.

    The model is switched to eval mode and run on `x` and on `x` plus
    Gaussian noise of scale `noise_std`. For every returned activation the
    per-sample L2 norm of (noisy - clean) is taken, and consecutive layers
    are compared:

        lambda_i = -log((delta_{i+1} + eps) / (delta_i + eps))

    Parameters
    ----------
    model : callable with `.eval()` returning a list of L activations, each
        with the batch on its first axis.
    x : torch.Tensor, input batch of size B.
    noise_std : float, standard deviation of the additive noise.
    eps : float, guard against division by zero / log(0).

    Returns
    -------
    numpy.ndarray of shape (L-1, B), or (L-1,) when B == 1.
    """
    model.eval()
    with torch.no_grad():
        noisy = x + noise_std * torch.randn_like(x)
        clean_acts = model(x)
        noisy_acts = model(noisy)

    # (L, B): one row of per-sample perturbation norms per layer
    deltas = torch.stack(
        [(n - c).flatten(1).norm(dim=1) for n, c in zip(noisy_acts, clean_acts)]
    )
    # Drop only a singleton *batch* axis. A bare squeeze() would also drop the
    # layer axis for single-activation models, making the slices below run
    # across samples instead of across layers.
    deltas = deltas.squeeze(1)

    lam = -torch.log((deltas[1:] + eps) / (deltas[:-1] + eps))
    return lam.cpu().numpy()


def compute_input_sensitivity(model, x):
    """
    Measure how strongly each activation's energy depends on the input.

    For every activation returned by the model (in eval mode), the gradient
    of sum(activation ** 2) with respect to the input batch is computed and
    its overall norm recorded.

    Returns a list of floats, one per activation.
    """
    model.eval()
    probe = x.clone().requires_grad_(True)
    layer_outputs = model(probe)

    def _grad_norm(act):
        energy = act.square().sum()
        # the graph is shared by all activations, so it must be retained
        (grad,) = torch.autograd.grad(energy, probe, retain_graph=True)
        return grad.norm().item()

    return [_grad_norm(act) for act in layer_outputs]

def mi_permutation_test(X, y, n_perm=200):
    """
    Permutation test for the mutual information between X and y.

    The statistic is the mutual-information estimate for the FIRST column of
    X only (index [0] of the estimator's output). The null distribution comes
    from `n_perm` random permutations of y, and the p-value uses the usual
    +1 correction.
    NOTE(review): callers that pass a multi-column X get the first column's
    MI, not a joint score — confirm this is intended.

    Returns (observed_mi, p_value).
    """
    k = max(1, min(2, len(y) - 1))

    def _first_column_mi(target):
        return mutual_info_regression(X, target, random_state=0, n_neighbors=k)[0]

    observed = _first_column_mi(y)
    null_mis = np.array(
        [_first_column_mi(np.random.permutation(y)) for _ in range(n_perm)]
    )
    exceed = np.sum(null_mis >= observed)
    p_value = (exceed + 1) / (n_perm + 1)
    return observed, p_value

# === Regression & Analysis ===
def create_feature_matrix(raw, config):
    """
    Assemble a (n_features, n_samples) design matrix from raw property arrays.

    Each config entry names a feature and may set:
      'log'        apply log1p after clipping negatives to 0
      'power_law'  mark the feature as a power-law term (also min-max scales)
      'normalize'  min-max scale to roughly [0, 1]
    Entries whose name is missing from `raw`, or already used, are skipped.

    Returns (X, names, powers), where `powers` holds each kept feature's
    'power_law' flag.
    """
    rows, names, powers = [], [], []
    seen = set()
    for spec in config:
        feature_name = spec['name']
        if feature_name in seen or feature_name not in raw:
            continue
        seen.add(feature_name)

        is_power = spec.get('power_law', False)
        column = raw[feature_name].astype(float).copy()
        if spec.get('log', False):
            column = np.log1p(np.clip(column, 0, None))
        if is_power or spec.get('normalize', False):
            lo, hi = column.min(), column.max()
            # the 1e-8 keeps constant features from dividing by zero
            column = (column - lo) / (hi - lo + 1e-8)

        rows.append(column)
        names.append(feature_name)
        powers.append(is_power)
    return np.vstack(rows), names, powers

def fit_model(X, y, powers, n_boot=100):
    """
    Fit an additive model with optional power-law terms and bootstrap the
    parameter spread.

    The model is  sum_j term_j + c  with
        term_j = a_j * x_j ** b_j   if powers[j] is truthy
        term_j = a_j * x_j          otherwise.

    Parameters
    ----------
    X : ndarray, shape (n_features, n_samples)
    y : ndarray, shape (n_samples,)
    powers : sequence of bool, one flag per feature row of X
    n_boot : int, number of bootstrap resamples

    Returns
    -------
    dict with keys 'popt' (fitted parameters), 'param_stds' (bootstrap SD per
    parameter, zeros if every resample failed), 'pred', 'rmse', 'r2'
    (squared correlation between y and the prediction) and 'model_fn'.

    Raises
    ------
    Whatever `curve_fit` raises for the fit on the full data. Failures on
    individual bootstrap resamples are skipped.
    """
    def model_fn(x, *p):
        total = np.zeros_like(x[0], dtype=float)
        idx = 0
        for j, pl in enumerate(powers):
            a = p[idx]
            if pl:
                b = p[idx+1]
                xj = x[j] + 1e-12
                # Safely calculate power term
                with np.errstate(over='ignore', invalid='ignore'):
                    log_xj = np.log(xj)
                    # Clip to avoid overflow in exp()
                    power = np.clip(b * log_xj, -500, 500)
                    term = np.exp(power)
                    # Clean up any invalid values
                    term = np.nan_to_num(term, nan=0.0, posinf=0.0, neginf=0.0)
                total += a * term
                idx += 2
            else:
                total += a * x[j]
                idx += 1
        total += p[idx]
        return total

    # initial guess: (a=0.1, b=1.0) per power-law term, a=0.1 per linear term,
    # and an intercept of 1.0
    p0 = []
    for pl in powers:
        p0 += [0.1, 1.0] if pl else [0.1]
    p0 += [1.0]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", OptimizeWarning)
        popt, _ = curve_fit(model_fn, X, y, p0=p0, maxfev=10000)

    pred = model_fn(X, *popt)
    rmse = np.sqrt(((y-pred)**2).mean())
    r2 = linregress(y, pred).rvalue**2

    samples = []
    for _ in tqdm(range(n_boot), desc="Bootstrapping", leave=False):
        idxs = np.random.choice(len(y), len(y), replace=True)
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", OptimizeWarning)
                pp, _ = curve_fit(model_fn, X[:,idxs], y[idxs], p0=p0, maxfev=5000)
        except (RuntimeError, ValueError):
            # curve_fit signals non-convergence with RuntimeError and bad
            # (e.g. non-finite) data with ValueError. Skip that resample, but
            # let KeyboardInterrupt/SystemExit and real bugs propagate.
            continue
        samples.append(pp)

    param_stds = np.std(np.stack(samples), axis=0) if samples else np.zeros_like(popt)
    return {'popt': popt, 'param_stds': param_stds, 'pred': pred, 'rmse': rmse, 'r2': r2, 'model_fn': model_fn}



def plot_key_results(results, title_suffix=""):
    """
    Show the two summary figures of a run.

    `results` must provide:
      'grouped'      table with 'depth', 'mean' and 'std' columns
      'inv_Lminus1'  array of 1/(L_i - 1)
      'sigma2'       array of σ_i²
    `title_suffix` is appended verbatim to both figure titles.
    """
    # Figure 1: input sensitivity vs depth, with error bars
    depth_stats = results['grouped']
    plt.figure(figsize=(8,6))
    plt.errorbar(
        depth_stats['depth'],
        depth_stats['mean'],
        yerr=depth_stats['std'],
        fmt='o-',
        capsize=3,
    )
    plt.xlabel("Depth")
    plt.ylabel("Mean Input Sensitivity")
    plt.title(f"Input Sensitivity vs Depth{title_suffix}")
    plt.grid(True)
    plt.show()

    # Figure 2: σ² against 1/(L–1) with a least-squares line
    inv_depth = results['inv_Lminus1']
    variances = results['sigma2']
    fit = linregress(inv_depth, variances)
    slope, intercept = fit.slope, fit.intercept
    r_squared = fit.rvalue**2

    plt.figure(figsize=(8,6))
    plt.scatter(inv_depth, variances, alpha=0.5, label="data")
    grid_x = np.linspace(inv_depth.min(), inv_depth.max(), 200)
    fit_label = f"fit: y={slope:.2e}x+{intercept:.2e}\n$R^2$={r_squared:.2f}"
    plt.plot(grid_x, intercept + slope*grid_x, 'r--', label=fit_label)
    plt.xlabel("1 / (L - 1)")
    plt.ylabel("σ²")
    plt.title(f"σ² vs 1/(L–1){title_suffix}")
    plt.legend()
    plt.grid(True)
    plt.show()
    
def compute_global_complexity(raw_props, n_components=1):
    """
    Compute a 1D "ACS" score as the projection onto the first principal
    component of the min–max-scaled parameter space.

    Parameters
    ----------
    raw_props : dict mapping property name -> 1D array; all arrays must have
        the same length. Each array position becomes one row (one score).
    n_components : int, number of PCA components to fit. Only the first one
        is returned.

    Returns
    -------
    numpy.ndarray of shape (n_rows,): the first principal-component score of
    every row.
    """
    # 1) Stack parameters into (n_rows, n_features)
    keys = list(raw_props.keys())
    X = np.vstack([raw_props[k] for k in keys]).T.astype(float)

    # 2) Min–max scale each column to [0,1]. A constant column has zero range;
    #    dividing by it would give 0/0 = NaN, which PCA rejects, so such
    #    columns are mapped to all zeros instead.
    mins = X.min(axis=0, keepdims=True)
    spans = X.max(axis=0, keepdims=True) - mins
    spans[spans == 0] = 1.0
    X_norm = (X - mins) / spans

    # 3) PCA
    pca = PCA(n_components=n_components)
    pcs = pca.fit_transform(X_norm)  # shape (n_rows, n_components)

    # 4) Return the first principal component (flattened)
    return pcs[:, 0]

def classify_models(complexity_scores, percentile=10):
    """
    Tag the extremes of a score distribution.

    The k lowest scores become 'simple' and the k highest 'complex', with
    k = max(1, int(n * percentile / 100)); everything else stays
    'unlabeled'. Where the two groups overlap, 'complex' wins because it is
    assigned last.

    Returns an array of string labels aligned with `complexity_scores`.
    """
    n_items = len(complexity_scores)
    order = np.argsort(complexity_scores)
    tail = max(1, int(n_items * percentile / 100))

    labels = np.array(['unlabeled' for _ in range(n_items)])
    lowest, highest = order[:tail], order[-tail:]
    labels[lowest] = 'simple'
    labels[highest] = 'complex'
    return labels

def bin_and_validate(input_sens_means, complexity_labels, n_bins=10):
    """
    Split models into equal-width input-sensitivity bins and keep the bins
    that are populated well enough for a simple-vs-complex comparison.

    Parameters
    ----------
    input_sens_means : sequence of float, one value per model.
    complexity_labels : array of 'simple' / 'complex' / other labels, indexed
        by the same positions as `input_sens_means`.
    n_bins : int, number of equal-width bins between the min and max value.

    Returns
    -------
    list of dict, one per valid bin, with keys 'indices' (positions of the
    members), 'simple_count', 'complex_count' and 'bin_id'. A bin is valid
    when it holds at least 5 members, of which at least 3 are 'simple' and
    at least 3 are 'complex'. Empty input yields an empty list.
    """
    values = np.asarray(input_sens_means, dtype=float)
    labels = np.asarray(complexity_labels)
    if values.size == 0:
        return []

    bins = np.linspace(values.min(), values.max(), n_bins + 1)
    bin_indices = np.digitize(values, bins) - 1
    # np.digitize uses half-open intervals, so values equal to the last edge
    # (the maximum) get index n_bins and would fall outside every bin.
    # Treat the last bin as closed on the right.
    bin_indices[values == bins[-1]] = n_bins - 1

    valid_bins = []
    for bin_id in range(n_bins):
        idx = np.where(bin_indices == bin_id)[0]
        labels_in_bin = labels[idx]
        simple_count = np.sum(labels_in_bin == 'simple')
        complex_count = np.sum(labels_in_bin == 'complex')
        if len(idx) >= 5 and simple_count >= 3 and complex_count >= 3:
            valid_bins.append({
                'indices': idx,
                'simple_count': simple_count,
                'complex_count': complex_count,
                'bin_id': bin_id
            })
    return valid_bins

def compute_bin_metrics(valid_bins, lambda_vals, complexity_scores, raw_props, complexity_labels):
    """
    For each valid IS-bin, compute:
      - CV of σ (σ = std of a model's lambda values)
      - Mutual information of σ with ACS and with all_params (with p-values)
      - σ(simple)/σ(complex) ratio of means and the KS-test p-value comparing
        the two σ samples

    Parameters
    ----------
    valid_bins : list of dict as produced by `bin_and_validate`.
    lambda_vals : sequence of per-model lambda arrays.
    complexity_scores : ndarray indexed by the bins' 'indices'.
    raw_props : dict of property arrays indexed by the bins' 'indices'.
    complexity_labels : array of 'simple' / 'complex' / other labels.

    Returns a pandas.DataFrame with one row per bin and columns:
      ['IS_bin', '#simple', '#complex', 'CV',
       'ratio', 'pval_ratio',
       'MI_ACS', 'MI_pval_ACS',
       'MI_all_params', 'MI_pval_all_params']
    The columns are present even when `valid_bins` is empty.
    """
    columns = [
        'IS_bin', '#simple', '#complex', 'CV',
        'ratio', 'pval_ratio',
        'MI_ACS', 'MI_pval_ACS',
        'MI_all_params', 'MI_pval_all_params',
    ]
    metrics = []
    for b in valid_bins:
        idx = b['indices']

        # σ for every model in this bin
        sigma_l = np.array([np.std(lambda_vals[i]) for i in idx])

        # coefficient of variation
        CV = sigma_l.std() / (sigma_l.mean() + 1e-12)

        # --- simple vs complex σ (reusing the values computed above) ---
        lam_s = [s for i, s in zip(idx, sigma_l) if complexity_labels[i] == 'simple']
        lam_c = [s for i, s in zip(idx, sigma_l) if complexity_labels[i] == 'complex']
        # KS-test on the two distributions
        _, p_val = ks_2samp(lam_s, lam_c)
        # ratio of means
        ratio = np.nanmean(lam_s) / np.nanmean(lam_c)

        # mutual information tests
        acs = complexity_scores[idx].reshape(-1, 1)
        all_params = np.array([[raw_props[k][i] for k in raw_props] for i in idx])
        mi_acs, p_acs = mi_permutation_test(acs, sigma_l)
        mi_params, p_params = mi_permutation_test(all_params, sigma_l)

        metrics.append({
            'IS_bin':            b['bin_id'],
            '#simple':           b['simple_count'],
            '#complex':          b['complex_count'],
            'CV':                CV,
            'ratio':             ratio,
            'pval_ratio':        p_val,
            'MI_ACS':            mi_acs,
            'MI_pval_ACS':       p_acs,
            'MI_all_params':     mi_params,
            'MI_pval_all_params': p_params
        })

    # Passing `columns` keeps the documented schema for the empty case too.
    return pd.DataFrame(metrics, columns=columns)

# === Experiment Runner ===
def run_experiment(cfg):
    """
    Run one experiment: sample random architectures, measure them on a data
    subset, and report binned statistics.

    Keys read from `cfg`:
      required: 'seed', 'dataset', 'resize_range', 'run_id', 'data_fraction',
                'noise_std', 'n_models', 'depth_choices', 'architecture',
                'param_ranges', 'n_variance_levels'
      optional: 'use_leaky_relu' (default True), 'use_batchnorm' (default
                False), 'complexity_percentile' (default 10), 'make_plots'
                (default False)

    Returns
    -------
    dict of run-level metrics. When no model yields usable data, only
    'run_id', 'architecture' and 'dataset' are present. When no IS-bin is
    valid, the bin-derived metrics are NaN.

    NOTE(review): 'n_IS_bins_attempted' reports cfg['n_variance_levels'], but
    `bin_and_validate` is called with its default bin count — confirm which
    is intended.
    NOTE(review): complexity scores/labels are computed per row of the
    per-layer table, while the bins index them by model position — confirm
    the intended alignment.
    """
    # ───────────────────────────── SET-UP ──────────────────────────────
    set_random_seed(cfg['seed'])
    ds = cfg['dataset'].lower()
    base, ch = (32, 3) if ds == 'cifar' else (28, 1)
    scale = random.uniform(*cfg['resize_range'])
    tgt = int(base * scale)

    print(f"Running run_id       : {cfg['run_id']}")
    print(f"\nResize scale factor  : {scale:.2f}  →  target size {tgt}")

    data = get_data(ds, tgt)
    n = max(1, int(len(data) * cfg['data_fraction']))
    idx = np.random.choice(len(data), n, replace=False)
    samples = [data[i][0] for i in idx]
    x = torch.stack(samples, dim=0).to(device)
    noise_std = cfg['noise_std']
    input_size, in_channels = tgt, ch

    print(f"Input shape          : {tuple(x.shape)}")

    # ───────────────────── GENERATE & MEASURE MODELS ────────────────────
    data_pts, input_sens_vals, lambda_vals = [], [], []
    lambda_stds, model_depths          = [], []

    def _measure_model(model):
        """Return (lambda, input sensitivities, per-layer rows) or None."""
        with autocast():
            lam_raw = compute_lambda_i_empirical(model, x, noise_std)
        # average over the batch when per-sample values are returned
        lam = lam_raw.mean(axis=1) if lam_raw.ndim > 1 else lam_raw
        if len(lam) < 3:
            return None
        with autocast(enabled=False):
            input_sens = compute_input_sensitivity(model, x)
        props = extract_model_properties(model)
        ml = min(len(lam)+1, len(input_sens))
        for k in props:
            props[k] = props[k][:ml]
        rows = []
        for j in range(ml-1):
            rows.append({**{k: props[k][j] for k in props}, 'empirical': float(lam[j])})
        return lam, input_sens, rows

    for _ in tqdm(range(cfg['n_models']), desc="Measuring models"):
        depth = random.choice(cfg['depth_choices'])
        if cfg['architecture'] == 'cnn':
            model = generate_random_cnn(
                depth,
                cfg['param_ranges']['strides'],
                cfg['param_ranges']['kernels'],
                cfg['param_ranges']['channels'],
                input_size, in_channels,
                cfg.get('use_leaky_relu', True),
                cfg.get('use_batchnorm', False)
            ).to(device)
        else:
            pr = cfg['param_ranges']
            model = generate_random_resnet(
                depth, pr['channels'], input_size, in_channels,
                pr['block_sizes'], pr['kernel_sizes'], pr['strides'],
                pr['bottleneck_ratios'], pr['projection_types'],
                pr['activation_functions']
            ).to(device)
        res = _measure_model(model)
        del model
        if res is None:
            continue
        lam, input_sens, rows = res
        lambda_stds.append(float(np.std(lam)))
        model_depths.append(len(lam) + 1)
        input_sens_vals.append(input_sens[:-1])
        lambda_vals.append(lam)
        data_pts.extend(rows)

    torch.cuda.empty_cache()

    if not data_pts:
        print("No valid data collected for this run.")
        return {
            'run_id': cfg['run_id'],
            'architecture': cfg['architecture'],
            'dataset': cfg['dataset']
        }

    n_models_valid = len(lambda_stds)
    print(f"Generated models accepted       : {n_models_valid}")
    print(f"Unique datapoints               : {len(set(tuple(d.items()) for d in data_pts))}")

    df          = pd.DataFrame(data_pts)
    raw_props   = {c: df[c].values for c in df.columns if c != 'empirical'}
    input_sens_flat = np.concatenate(input_sens_vals)
    lambda_flat = np.concatenate(lambda_vals)

    complexity_scores = compute_global_complexity(raw_props)
    percentile = cfg.get("complexity_percentile", 10)
    complexity_labels = classify_models(complexity_scores, percentile)

    input_sens_means = np.array([np.mean(f) for f in input_sens_vals])
    valid_bins = bin_and_validate(input_sens_means, complexity_labels)

    per_bin_metrics = compute_bin_metrics(valid_bins, lambda_vals, complexity_scores, raw_props, complexity_labels)

    # ───────────── Run-level statistics derived from the bins ─────────────
    # Everything is computed once, up front, with NaN defaults, so that runs
    # without any valid bin still reach the end of the function.
    nan = np.nan
    mean_ratio = std_ratio = median_p = nan
    mean_cv = std_cv = max_cv = nan
    mean_mi_acs = median_mi_acs = max_mi_acs = nan
    mean_mi_params = median_mi_params = max_mi_params = nan
    has_bins = not per_bin_metrics.empty
    if has_bins:
        # per-bin σ(simple)/σ(complex) ratios and KS p-values come straight
        # from the per-bin table
        bin_ratios = per_bin_metrics['ratio'].to_numpy(dtype=float)
        bin_pvals  = per_bin_metrics['pval_ratio'].to_numpy(dtype=float)
        mean_ratio = np.nanmean(bin_ratios)
        std_ratio  = np.nanstd(bin_ratios)
        median_p   = np.nanmedian(bin_pvals)

        mean_cv = per_bin_metrics['CV'].mean()
        std_cv  = per_bin_metrics['CV'].std()
        max_cv  = per_bin_metrics['CV'].max()
        mean_mi_acs   = per_bin_metrics['MI_ACS'].mean()
        median_mi_acs = per_bin_metrics['MI_ACS'].median()
        max_mi_acs    = per_bin_metrics['MI_ACS'].max()
        mean_mi_params   = per_bin_metrics['MI_all_params'].mean()
        median_mi_params = per_bin_metrics['MI_all_params'].median()
        max_mi_params    = per_bin_metrics['MI_all_params'].max()

    # ───────── Group by depth (mean & std) and fit slope ────────────
    slope_input_sens = r_value_input_sens = np.nan
    grp = None
    try:
        grp = (
            pd.DataFrame({
                'depth': raw_props['d'],
                'input_sens': input_sens_flat
            })
            .groupby('depth')['input_sens']
            .agg(['mean','std'])
            .reset_index()
        )
        # now do the log-log fit on mean
        log_d2 = np.log(grp['depth'] + 1e-8)
        log_f2 = np.log(grp['mean']  + 1e-8)
        slope_input_sens, _, r_value_input_sens, _, _ = linregress(log_d2, log_f2)
    except Exception:
        # best effort: the depth analysis and plots are skipped on failure
        grp = None

    print("\nIS Scaling with Depth")
    if not np.isnan(slope_input_sens):
        print(f"  IS   ~ depth^{slope_input_sens:.3f}   (R² = {r_value_input_sens**2:.3f})")

    # mean input sensitivity of every valid bin, keyed by bin id
    bin_centers = {
        b['bin_id']: np.mean([input_sens_means[i] for i in b['indices']])
        for b in valid_bins
    }

    if valid_bins:
        all_complexities = [complexity_scores[i] for b in valid_bins for i in b['indices']]
        all_sigmas      = [np.std(lambda_vals[i])        for b in valid_bins for i in b['indices']]

        print("\nBin Analysis Summary")
        print(f"  Number of valid IS-bins   : {len(valid_bins)}")
        print(f"  IS-bins range             : [{min(bin_centers.values()):.3f}, {max(bin_centers.values()):.3f}]")
        print(f"  Complexity score range    : [{min(all_complexities):.3f}, {max(all_complexities):.3f}]")
        print(f"  σ range                   : [{min(all_sigmas):.3f}, {max(all_sigmas):.3f}]")
    else:
        print("\nBin Analysis Summary")
        print("  No valid IS bins — cannot compute summary ranges.")

    # ───────────────────────────── Per-Bin Table ─────────────────────────────
    print("\nPer-Bin Table (one row per valid bin)")
    print("|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |")
    print("|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|")

    for _, row in per_bin_metrics.iterrows():
        center = bin_centers[row['IS_bin']]
        print(
            f"| {center:6.2f} | "
            f"{int(row['#simple']):6d} | "
            f"{int(row['#complex']):7d} | "
            f"{row['CV']:7.3f} | "
            f"{row['ratio']:7.3f} | "
            f"{row['pval_ratio']:10.3f} | "
            f"{row['MI_ACS']:8.3f} | "
            f"{row['MI_pval_ACS']:8.3f} | "
            f"{row['MI_all_params']:9.3f} | "
            f"{row['MI_pval_all_params']:11.3f} |"
        )

    # ───────────────────────── Run-Level Metrics Table ─────────────────────────
    print("\nRun-Level Metrics Table")
    if not has_bins:
        print("No valid run-level metrics due to empty bins.")
    else:
        header_fmt = "| {0:35s} | {1:43s} |"
        row_fmt    = "| {0:35s} | {1:43s} |"

        print(header_fmt.format("Metric", "Value"))
        # separator line: 37 dashes for the first column, 45 for the second
        print("|" + "-"*37 + "|" + "-"*45 + "|")

        print(row_fmt.format(
            "Mean CV(bin) ± SD",
            f"{mean_cv:.3f} ± {std_cv:.3f}"
        ))
        print(row_fmt.format(
            "Max CV(bin)",
            f"{max_cv:.3f}"
        ))
        print(row_fmt.format(
            "Ratio σ(simple/complex) ± SD",
            f"{mean_ratio:.3f} ± {std_ratio:.3f}"
        ))
        print(row_fmt.format(
            "KS p-value σ(simple/complex)",
            f"{median_p:.3f}"
        ))
        print(row_fmt.format(
            "MI ACS mean ± SD",
            f"{mean_mi_acs:.3f} ± {per_bin_metrics['MI_ACS'].std():.3f}"
        ))
        print(row_fmt.format(
            "MI ACS median (IQR)",
            f"{median_mi_acs:.3f} [{per_bin_metrics['MI_ACS'].quantile(0.25):.3f},{per_bin_metrics['MI_ACS'].quantile(0.75):.3f}]"
        ))
        print(row_fmt.format(
            "MI ACS max",
            f"{max_mi_acs:.3f}"
        ))
        print(row_fmt.format(
            "MI all_params mean ± SD",
            f"{mean_mi_params:.3f} ± {per_bin_metrics['MI_all_params'].std():.3f}"
        ))
        print(row_fmt.format(
            "MI all_params median (IQR)",
            f"{median_mi_params:.3f} [{per_bin_metrics['MI_all_params'].quantile(0.25):.3f},{per_bin_metrics['MI_all_params'].quantile(0.75):.3f}]"
        ))
        print(row_fmt.format(
            "MI all_params max",
            f"{max_mi_params:.3f}"
        ))

    # ───────── Stability of σ² × (L-1) across models ─────────
    sigma2_scaled = (np.array(lambda_stds) ** 2) * (np.array(model_depths) - 1)
    mean_sigma, std_sigma = np.mean(sigma2_scaled), np.std(sigma2_scaled)
    cv_sigma = std_sigma / (mean_sigma + 1e-12)

    # ───────── RUN-LEVEL STABILITY PRINT ─────────
    print("\nRun-level Geometric Stability")
    print(f"  Mean σ²×(L-1): {mean_sigma:.3e}")
    print(f"  Std  σ²×(L-1): {std_sigma:.3e}")
    print(f"  CV   σ²×(L-1): {cv_sigma:.3f}")

    # ────────── PLOTS ──────────────────────
    if cfg.get('make_plots', False) and (grp is not None):
        inv_Lminus1 = 1.0 / (np.array(model_depths) - 1)
        sigma2      = np.array(lambda_stds)**2
        plot_key_results(
            {
                'grouped':     grp,
                'inv_Lminus1': inv_Lminus1,
                'sigma2':      sigma2
            },
            title_suffix = f" ({cfg['run_id']})"
        )
    # ────────────────────────────────────────

    return {
        'run_id': cfg['run_id'],
        'architecture': cfg['architecture'],
        'dataset': cfg['dataset'],
        'n_models_accepted': len(lambda_stds),
        'n_models_used_in_bins': sum(len(b['indices']) for b in valid_bins),
        'n_lambda_vals': len(lambda_flat),
        'n_input_sens_vals': len(input_sens_flat),
        'n_IS_bins_attempted': cfg['n_variance_levels'],
        'n_IS_bins_valid': len(valid_bins),
        'input_sens_depth_slope': slope_input_sens,
        'input_sens_depth_r2': r_value_input_sens**2 if not np.isnan(r_value_input_sens) else np.nan,
        'mean_sigma_ratio_simple_to_complex': mean_ratio,
        'ks_pvalue_sigma_ratio':             median_p,
        'cv_bins_mean': mean_cv,
        'cv_bins_std':  std_cv,
        'cv_bins_max':  max_cv,
        'mean_mi_acs': mean_mi_acs,
        'median_mi_acs': median_mi_acs,
        'max_mi_acs': max_mi_acs,
        'mean_mi_all_params':    mean_mi_params,
        'median_mi_all_params':  median_mi_params,
        'max_mi_all_params':     max_mi_params,
        'mean_sigma2_ratio': mean_sigma,
        'std_sigma2_ratio': std_sigma,
        'cv_sigma2_ratio': cv_sigma
    }

def analyze_experiment_results(results_list):
    """
    Print a cross-run summary and a per-(architecture, dataset) breakdown.

    Parameters
    ----------
    results_list : list of dict
        One result dict per run (the dicts appended by
        `run_statistical_consistency_experiments`).  Missing metric keys are
        treated as NaN, missing count keys as 0.  May be empty.

    Returns
    -------
    dict
        ``{"overall": {...}}`` holding the formatted cross-run strings under
        the keys ``ratio_CI``, ``ks_p``, ``cv_mean``, ``cv_max``, ``mi_acs``,
        ``mi_ap``, ``s2_mean``, ``s2_std`` and ``s2_cv``.  The per-group
        breakdown is printed only, not returned.
    """
    import numpy as np
    from scipy import stats
    from collections import defaultdict

    # short name -> key in the per-run result dict
    metric_keys = {
        'ratio':  'mean_sigma_ratio_simple_to_complex',
        'pval':   'ks_pvalue_sigma_ratio',
        'cv':     'cv_bins_mean',
        'cvmax':  'cv_bins_max',
        'mi_acs': 'mean_mi_acs',
        'mi_ap':  'mean_mi_all_params',
        's2_m':   'mean_sigma2_ratio',
        's2_s':   'std_sigma2_ratio',
        's2_cv':  'cv_sigma2_ratio',
    }

    # Convert to a float array and drop NaNs (None becomes NaN on conversion,
    # so a run that stored None for a metric no longer raises).
    def drop_nan(arr):
        a = np.array(arr, float)
        return a[~np.isnan(a)]

    # Helper to compute mean ± 95% CI (Student-t); None with < 2 usable values
    def mean_ci(arr):
        a = drop_nan(arr)
        if len(a) < 2:
            return None
        m = a.mean()
        h = stats.t.ppf(0.975, len(a) - 1) * stats.sem(a)
        return m, m - h, m + h

    # Format an array as mean ± half-width (95% CI, n=…)
    def fmt_ci(arr):
        ci = mean_ci(arr)
        if ci is None:
            return "n/a"
        m, lo, _ = ci
        return f"{m:.3f} ± {(m-lo):.3f} (95% CI, n={len(drop_nan(arr))})"

    # Format p-values as median [IQR] (n=…)
    def fmt_p(arr):
        a = drop_nan(arr)
        if len(a) == 0:
            return "n/a"
        q25, q75 = np.percentile(a, [25, 75])
        return f"{np.median(a):.3f} [IQR {q25:.3f}, {q75:.3f}] (n={len(a)})"

    # Simple mean±SD (population SD, ddof=0)
    def fmt_sd(arr):
        a = drop_nan(arr)
        if len(a) == 0:
            return "n/a"
        return f"{a.mean():.3f} ± {a.std():.3f}"

    # Ratio formatted to 2 decimals, or "n/a" instead of dividing by zero
    def fmt_avg(total, count):
        return f"{total/count:.2f}" if count else "n/a"

    # Number of runs whose KS p-value is below 0.05 (NaNs never count)
    def count_significant(pvals):
        return int(np.sum(drop_nan(pvals) < 0.05))

    # Gather every metric column for a set of runs, defaulting to NaN
    def collect(runs):
        return {
            short: [r.get(key, np.nan) for r in runs]
            for short, key in metric_keys.items()
        }

    # Print one mutual-information section (3. or 4.)
    def print_mi(title, na_label, values):
        print(title)
        a = drop_nan(values)
        if a.size == 0:
            print(f"   {na_label}: n/a\n")
            return
        q25, q75 = np.percentile(a, [25, 75])
        print(f"   Mean ± SD     : {fmt_sd(values)}")
        print(f"   Median [IQR]  : {np.median(a):.3f} [{q25:.3f}–{q75:.3f}]")
        print(f"   Max           : {a.max():.3f}\n")

    # Print sections 1-5 for one set of runs.  `single_run_shortcut` keeps the
    # per-group behaviour of printing raw values when the group has one run.
    def print_sections(m, n, single_run_shortcut=False):
        print("1. Per-IS-bin σ Variation:")
        if single_run_shortcut and n == 1:
            print(f"   Ratio σ(simple/complex)  : {m['ratio'][0]:.3f}")
            print(f"   KS-p-value summary       : {m['pval'][0]:.3f}")
            print(f"   Runs with p<0.05         : {int(m['pval'][0]<0.05)} / 1\n")
        else:
            print(f"   Ratio σ(simple/complex)  : {fmt_ci(m['ratio'])}")
            print(f"   KS-p-value summary       : {fmt_p(m['pval'])}")
            print(f"   Runs with p<0.05         : {count_significant(m['pval'])} / {n}\n")

        print("2. Per-IS-bin CV(σ):")
        print(f"   Mean CV(bin) : {fmt_ci(m['cv'])}")
        print(f"   Max  CV(bin) : {fmt_ci(m['cvmax'])}\n")

        print_mi("3. Mutual Information ACS:", "MI_ACS", m['mi_acs'])
        print_mi("4. Mutual Information all_params:", "MI_all_params", m['mi_ap'])

        print("5. Stability σ²×(L-1):")
        print(f"   Mean : {fmt_ci(m['s2_m'])}")
        print(f"   Std  : {fmt_ci(m['s2_s'])}")
        print(f"   CV   : {fmt_ci(m['s2_cv'])}\n")

    n_runs = len(results_list)
    overall = collect(results_list)

    # --- CROSS-RUN SUMMARY ---
    print("\n===== CROSS-RUN SUMMARY =====\n")
    # Metadata
    tm = sum(r.get('n_models_accepted', 0) for r in results_list)
    tl = sum(r.get('n_lambda_vals', 0) for r in results_list)
    tf = sum(r.get('n_input_sens_vals', 0) for r in results_list)
    ba = sum(r.get('n_IS_bins_attempted', 0) for r in results_list)
    bv = sum(r.get('n_IS_bins_valid', 0) for r in results_list)
    used = drop_nan([r.get('n_models_used_in_bins', np.nan) for r in results_list])
    ub = used.mean() if used.size else np.nan
    print("Metadata:")
    print(f"  Total models accepted          : {tm}")
    print(f"  Total λᵢ values                : {tl}")
    print(f"  Total IS values            : {tf}")
    print(f"  Valid IS bins                  : {bv}")
    print(f"  Total IS bins attempted        : {ba}")
    # Guarded: an empty results list or zero accepted models must not crash
    # the report after a long sweep.
    print(f"  Avg models per run             : {fmt_avg(tm, n_runs)}")
    print(f"  Avg λᵢ per model               : {fmt_avg(tl, tm)}")
    print(f"  Avg models used-in-bins per run: {ub:.2f}\n")

    print_sections(overall, n_runs)

    # --- PER-ARCH/DATASET BREAKDOWN ---
    print("\n===== PER-ARCH/DATASET BREAKDOWN =====\n")
    by_group = defaultdict(list)
    for r in results_list:
        by_group[(r.get('architecture', '?'), r.get('dataset', '?'))].append(r)

    for (arch, ds), grp in by_group.items():
        print(f"--- {arch.upper()} / {ds.upper()} (n={len(grp)}) ---")
        print_sections(collect(grp), len(grp), single_run_shortcut=True)

    return {
        "overall": {
            "ratio_CI": fmt_ci(overall['ratio']),
            "ks_p":     fmt_p(overall['pval']),
            "cv_mean":  fmt_ci(overall['cv']),
            "cv_max":   fmt_ci(overall['cvmax']),
            # fmt_sd already yields "n/a" when every value is NaN
            "mi_acs":   fmt_sd(overall['mi_acs']),
            "mi_ap":    fmt_sd(overall['mi_ap']),
            "s2_mean":  fmt_ci(overall['s2_m']),
            "s2_std":   fmt_ci(overall['s2_s']),
            "s2_cv":    fmt_ci(overall['s2_cv']),
        }
    }

def run_statistical_consistency_experiments(
    seeds,
    architectures,
    datasets,
    data_fraction,
    noise_levels=(0.1, 0.2),
    n_models=100,
    complexity_percentile=10,
    n_variance_levels=20,
    n_bootstrap=50,
    make_plots=False,
    architecture_configs=None
):
    """
    Runs experiments over all combinations of seed, architecture, dataset and noise level.
    Each call to run_experiment gets a single scalar noise_std.

    Parameters
    ----------
    seeds, architectures, datasets, noise_levels : sequences
        Axes of the sweep; runs are executed with `seeds` as the outermost
        axis and `noise_levels` as the innermost.  Note that the `datasets`
        parameter shadows the `torchvision.datasets` import inside this
        function.
    data_fraction, n_models, complexity_percentile, n_variance_levels,
    n_bootstrap, make_plots
        Forwarded unchanged to every run's config dict.
    architecture_configs : dict, optional
        Maps each architecture name to a dict of extra config entries that
        are merged into the run config last (so they win on key clashes).
        Must contain an entry for every name in `architectures`.

    Returns
    -------
    (list, dict)
        The list of per-run results from `run_experiment` and the aggregate
        returned by `analyze_experiment_results`.

    Raises
    ------
    KeyError
        If any requested architecture has no entry in `architecture_configs`
        (including when `architecture_configs` is left as None).
    """
    from itertools import product

    # Validate before launching anything: a missing entry would otherwise only
    # surface mid-sweep, discarding every run completed up to that point.
    arch_settings = {} if architecture_configs is None else architecture_configs
    missing = [arch for arch in architectures if arch not in arch_settings]
    if missing:
        raise KeyError(f"architecture_configs has no entry for: {missing}")

    # product() iterates the rightmost axis fastest, matching the nesting
    # seed -> arch -> dataset -> noise.
    grid = list(product(seeds, architectures, datasets, noise_levels))
    total = len(grid)
    all_results = []

    for count, (seed, arch, ds, noise) in enumerate(grid, start=1):
        print("\n" + "=" * 60)
        print(f"Experiment {count}/{total}: "
              f"arch={arch}, dataset={ds}, seed={seed}, noise={noise}")

        # build config for this run
        cfg = {
            'seed': seed,
            'architecture': arch,
            'dataset': ds,
            'data_fraction': data_fraction,
            'noise_std': noise,
            'n_models': n_models,
            'complexity_percentile': complexity_percentile,
            'n_variance_levels': n_variance_levels,
            'n_bootstrap': n_bootstrap,
            'make_plots': make_plots,
            # unique run_id
            'run_id': f"{arch}_{ds}_s{seed}_n{noise}",
            # bring in architecture-specific settings
            **arch_settings[arch]
        }

        all_results.append(run_experiment(cfg))

    # once all runs are done, aggregate
    analysis = analyze_experiment_results(all_results)
    return all_results, analysis

Using device: cuda


In [None]:
# Per-architecture search space.  Both families use the same depth choices and
# resize range; only the `param_ranges` entry differs between them.
architecture_configs = {
    family: {
        'depth_choices': [4, 6, 8, 10, 20, 40, 80, 100],
        'resize_range': (1.0, 4.0),
        'param_ranges': ranges,
    }
    for family, ranges in [
        ('cnn', {
            'strides':  [1, 2, 3],
            'kernels':  [3, 5, 7, 9, 11, 13, 15, 17],
            'channels': [8, 16, 32, 64, 128, 256, 512, 1024],
        }),
        ('resnet', {
            'channels':          [16, 32, 64, 128, 256, 512, 1024],
            'block_sizes':       [1, 2, 3, 4, 5],
            'kernel_sizes':      [1, 3, 5, 7, 9],
            'strides':           [1, 2, 3],
            'bottleneck_ratios': [1, 0.75, 0.5, 0.25, 0.125],
            'projection_types':  ['identity', 'conv1x1'],
            'activation_functions': ['relu', 'leaky_relu'],
        }),
    ]
}

# Launch the full sweep: 3 seeds x 2 architectures x 2 datasets x 3 noise levels.
results, analysis = run_statistical_consistency_experiments(
    # sweep axes
    seeds=[1, 12, 123],
    architectures=['cnn', 'resnet'],
    datasets=['cifar', 'mnist'],
    noise_levels=[0.1, 0.25, 0.5],   # noise standard deviations
    # per-run settings
    data_fraction=0.0008,            # proportion of the test set to sample
    n_models=100,                    # random networks per run
    complexity_percentile=25,        # top/bottom 25% are "complex"/"simple"
    n_variance_levels=20,            # number of IS-bins to attempt
    n_bootstrap=20,
    make_plots=False,                # set True to enable the matplotlib plots
    architecture_configs=architecture_configs,
)

print(analysis)


Experiment 1/36: arch=cnn, dataset=cifar, seed=1, noise=0.1
Running run_id       : cnn_cifar_s1_n0.1

Resize scale factor  : 1.40  →  target size 44
Files already downloaded and verified
Input shape          : (8, 3, 44, 44)


Measuring models: 100%|██████████| 100/100 [03:52<00:00,  2.33s/it]


Generated models accepted       : 100
Unique datapoints               : 2834

IS Scaling with Depth
  IS   ~ depth^-4.295   (R² = 0.628)

Bin Analysis Summary
  Number of valid IS-bins   : 2
  IS-bins range             : [68.870, 351.859]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.160, 2.231]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  68.87 |     21 |      18 |   0.335 |   1.125 |      0.604 |    0.000 |    1.000 |     0.000 |       1.000 |
| 351.86 |      3 |       3 |   0.507 |   2.085 |      0.600 |    0.000 |    1.000 |     0.000 |       1.000 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|-------------------------------------

Measuring models: 100%|██████████| 100/100 [03:38<00:00,  2.19s/it]


Generated models accepted       : 100
Unique datapoints               : 2831

IS Scaling with Depth
  IS   ~ depth^-4.295   (R² = 0.628)

Bin Analysis Summary
  Number of valid IS-bins   : 2
  IS-bins range             : [68.870, 351.859]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.200, 2.168]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  68.87 |     21 |      18 |   0.321 |   1.119 |      0.604 |    0.000 |    1.000 |     0.000 |       1.000 |
| 351.86 |      3 |       3 |   0.543 |   1.935 |      0.600 |    0.017 |    0.408 |     0.078 |       0.214 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|-------------------------------------

Measuring models: 100%|██████████| 100/100 [03:41<00:00,  2.21s/it]


Generated models accepted       : 100
Unique datapoints               : 2836

IS Scaling with Depth
  IS   ~ depth^-4.295   (R² = 0.628)

Bin Analysis Summary
  Number of valid IS-bins   : 2
  IS-bins range             : [68.870, 351.859]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.242, 2.162]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  68.87 |     21 |      18 |   0.309 |   1.096 |      0.473 |    0.000 |    1.000 |     0.000 |       1.000 |
| 351.86 |      3 |       3 |   0.573 |   1.818 |      0.600 |    0.079 |    0.299 |     0.089 |       0.184 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|-------------------------------------

Measuring models: 100%|██████████| 100/100 [03:47<00:00,  2.28s/it]


Generated models accepted       : 100
Unique datapoints               : 2821

IS Scaling with Depth
  IS   ~ depth^-4.341   (R² = 0.631)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [70.838, 70.838]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.184, 2.285]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  70.84 |     21 |      20 |   0.346 |   1.163 |      0.410 |    0.200 |    0.040 |     0.000 |       1.000 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.346 ± nan                                 |
| Max CV(bin)   

Measuring models: 100%|██████████| 100/100 [03:45<00:00,  2.26s/it]


Generated models accepted       : 100
Unique datapoints               : 2829

IS Scaling with Depth
  IS   ~ depth^-4.341   (R² = 0.631)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [70.838, 70.838]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.219, 2.107]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  70.84 |     21 |      20 |   0.332 |   1.130 |      0.450 |    0.041 |    0.328 |     0.000 |       1.000 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.332 ± nan                                 |
| Max CV(bin)   

Measuring models: 100%|██████████| 100/100 [03:41<00:00,  2.21s/it]


Generated models accepted       : 100
Unique datapoints               : 2829

IS Scaling with Depth
  IS   ~ depth^-4.341   (R² = 0.631)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [70.838, 70.838]
  Complexity score range    : [-0.715, 0.742]
  σ range                   : [0.290, 1.983]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
|  70.84 |     21 |      20 |   0.320 |   1.101 |      0.466 |    0.035 |    0.338 |     0.000 |       0.383 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.320 ± nan                                 |
| Max CV(bin)   

Measuring models: 100%|██████████| 100/100 [10:36<00:00,  6.37s/it]


Generated models accepted       : 100
Unique datapoints               : 3283

IS Scaling with Depth
  IS   ~ depth^-5.338   (R² = 0.726)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [107.637, 107.637]
  Complexity score range    : [-0.615, 1.213]
  σ range                   : [0.510, 1.489]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
| 107.64 |     16 |      16 |   0.236 |   1.039 |      0.952 |    0.052 |    0.264 |     0.087 |       0.040 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.236 ± nan                                 |
| Max CV(bin) 

Measuring models: 100%|██████████| 100/100 [11:16<00:00,  6.77s/it]


Generated models accepted       : 100
Unique datapoints               : 3283

IS Scaling with Depth
  IS   ~ depth^-5.338   (R² = 0.726)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [107.637, 107.637]
  Complexity score range    : [-0.615, 1.213]
  σ range                   : [0.515, 1.483]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
| 107.64 |     16 |      16 |   0.232 |   1.052 |      0.952 |    0.000 |    1.000 |     0.058 |       0.075 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.232 ± nan                                 |
| Max CV(bin) 

Measuring models: 100%|██████████| 100/100 [11:15<00:00,  6.76s/it]


Generated models accepted       : 100
Unique datapoints               : 3282

IS Scaling with Depth
  IS   ~ depth^-5.338   (R² = 0.726)

Bin Analysis Summary
  Number of valid IS-bins   : 1
  IS-bins range             : [107.637, 107.637]
  Complexity score range    : [-0.615, 1.213]
  σ range                   : [0.531, 1.476]

Per-Bin Table (one row per valid bin)
|   IS   | simple | complex |   CV    | Ratio σ | pval_ratio |  MI_ACS  | pval_ACS | MI_params | pval_params |
|--------|--------|---------|---------|---------|------------|----------|----------|-----------|-------------|
| 107.64 |     16 |      16 |   0.226 |   1.065 |      0.716 |    0.002 |    0.458 |     0.036 |       0.169 |

Run-Level Metrics Table
| Metric                              | Value                                       |
|-------------------------------------|---------------------------------------------|
| Mean CV(bin) ± SD                   | 0.226 ± nan                                 |
| Max CV(bin) 

Measuring models:  47%|████▋     | 47/100 [06:11<01:50,  2.08s/it]