# Fully Customizable RIS Probe-Based ML Notebook

This notebook exposes a **single config dictionary** to control:
- Continuous vs. discrete phases (and quantization bits)
- Which models run
- Which comparisons and plots are generated
- Training/data sizes for fast iteration

To customize a run, edit only the `CONFIG` dict below; the remaining cells read their settings from it.


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from config import get_config
from data_generation import generate_probe_bank, create_dataloaders
from training import train
from evaluation import evaluate_model
from utils import (
    plot_training_history,
    plot_eta_distribution,
    plot_top_m_comparison,
    plot_baseline_comparison,
)
from model import LimitedProbingMLP


In [None]:
# ---- CONFIGURATION (edit only this dict for quick customization) ----
CONFIG = dict(
    # System + data
    system=dict(
        N=32,
        K=64,
        M=8,
        phase_mode='continuous',  # 'continuous' or 'discrete'
        phase_bits=3,             # used only when phase_mode='discrete'
    ),
    data=dict(
        n_train=3000,
        n_val=600,
        n_test=600,
        seed=42,
        normalize_input=True,
    ),
    # Training
    training=dict(
        num_epochs=10,
        batch_size=128,
        learning_rate=1e-3,
    ),
    # Default MLP config (consumed by build_mlp_default)
    model=dict(
        hidden_sizes=[256, 128],
        dropout_prob=0.1,
        use_batch_norm=True,
    ),
    # Comparison controls: each name must be a key of MODEL_REGISTRY
    models_to_run=[
        'mlp_default',
        'mlp_shallow',
        'linear',
    ],
    # Each name must be a key of PLOT_REGISTRY
    plots=[
        'training_history',
        'eta_distribution',
        'top_m_comparison',
        'baseline_comparison',
        'model_metric_bars',
        'eta_boxplot',
    ],
    # Metrics shown in bar comparison; each is read with getattr() from a
    # run's evaluation results
    compare_metrics=['accuracy_top1', 'eta_top1', 'eta_top2'],
)


In [None]:
# ---- MODEL REGISTRY (add your own builders here) ----
def build_mlp_default(config):
    """Build a LimitedProbingMLP from the ``model`` section of *config*.

    Hidden sizes, dropout and batch-norm all come from ``config.model``;
    ``K`` comes from ``config.system``.
    """
    model_cfg = config.model
    return LimitedProbingMLP(
        K=config.system.K,
        hidden_sizes=model_cfg.hidden_sizes,
        dropout_prob=model_cfg.dropout_prob,
        use_batch_norm=model_cfg.use_batch_norm,
    )

def build_mlp_shallow(config):
    """Build a one-hidden-layer (128 units) LimitedProbingMLP.

    Only ``config.system.K`` is taken from *config*. Dropout is disabled,
    batch-norm is off, and the ``model`` section of the config is ignored.
    """
    shallow_overrides = dict(hidden_sizes=[128], dropout_prob=0.0, use_batch_norm=False)
    return LimitedProbingMLP(K=config.system.K, **shallow_overrides)

def build_linear(config):
    """Return a single ``nn.Linear`` layer (wrapped in ``nn.Sequential``) as a fast baseline.

    The layer maps ``2 * K`` input features to ``K`` outputs, where
    ``K = config.system.K``.
    """
    k = config.system.K
    # NOTE(review): assumes each sample fed by the dataloaders has 2*K
    # features, presumably values plus an observation mask; confirm against
    # create_dataloaders.
    layer = torch.nn.Linear(2 * k, k)
    return torch.nn.Sequential(layer)

# Maps each name accepted in CONFIG['models_to_run'] to a builder that takes
# the config object and returns a freshly constructed model.
MODEL_REGISTRY = dict(
    mlp_default=build_mlp_default,
    mlp_shallow=build_mlp_shallow,
    linear=build_linear,
)


In [None]:
# ---- EXPERIMENT HELPERS ----
def build_config(overrides):
    """Build the project config object from a dict of section overrides.

    Every key of *overrides* is forwarded as a keyword argument to
    ``get_config``.
    """
    config = get_config(**overrides)
    return config

def run_single_model(model_name, overrides):
    """Generate data, then train and evaluate one registered model.

    Args:
        model_name: key into ``MODEL_REGISTRY`` selecting the model builder.
        overrides: dict of config-section overrides, forwarded to
            ``get_config`` via ``build_config``.

    Returns:
        dict with keys ``'name'``, ``'model'`` (as returned by ``train``),
        ``'history'`` (training history from ``train``), ``'results'``
        (from ``evaluate_model`` on the test split) and ``'config'``.

    Raises:
        KeyError: if ``model_name`` is not in ``MODEL_REGISTRY``. This is
            checked before any data is generated, so a typo fails immediately.
    """
    # Resolve the builder first: an unknown name should not cost a full
    # probe-bank + dataset generation before failing.
    try:
        builder = MODEL_REGISTRY[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; registered models: {sorted(MODEL_REGISTRY)}"
        ) from None

    config = build_config(overrides)

    # Set seeds for reproducibility
    torch.manual_seed(config.data.seed)
    np.random.seed(config.data.seed)

    probe_bank = generate_probe_bank(
        N=config.system.N,
        K=config.system.K,
        seed=config.data.seed,
        phase_mode=config.system.phase_mode,
        phase_bits=config.system.phase_bits,
    )

    train_loader, val_loader, test_loader, metadata = create_dataloaders(
        config, probe_bank
    )

    # Construct the model only after data generation. This keeps the original
    # RNG call order, which weight initialisation may depend on.
    model = builder(config)
    model, history = train(model, train_loader, val_loader, config, metadata)

    results = evaluate_model(
        model,
        test_loader,
        config,
        metadata['test_powers_full'],
        metadata['test_labels'],
        metadata['test_observed_indices'],
        metadata['test_optimal_powers'],
    )

    return {
        'name': model_name,
        'model': model,
        'history': history,
        'results': results,
        'config': config,
    }

def run_comparison(config_dict):
    """Run ``run_single_model`` for every name in ``config_dict['models_to_run']``.

    Args:
        config_dict: dict shaped like ``CONFIG``. The ``'system'``, ``'data'``,
            ``'training'`` and ``'model'`` sections are forwarded as config
            overrides; ``'models_to_run'`` selects the models, in order.

    Returns:
        list of the per-model result dicts, in ``models_to_run`` order.

    Raises:
        KeyError: if any requested model is missing from ``MODEL_REGISTRY``.
            All names are checked before anything is trained, so a typo in the
            last entry does not waste the earlier runs.
    """
    import copy

    requested = config_dict['models_to_run']
    unknown = [name for name in requested if name not in MODEL_REGISTRY]
    if unknown:
        raise KeyError(
            f"Unknown model(s) in 'models_to_run': {unknown}; "
            f"registered models: {sorted(MODEL_REGISTRY)}"
        )

    sections = ('system', 'data', 'training', 'model')
    results = []
    for model_name in requested:
        # Each run gets its own deep copy of the sections. Nothing done with
        # the overrides (e.g. by get_config, not visible here) can then leak
        # into later runs or back into config_dict.
        overrides = {key: copy.deepcopy(config_dict[key]) for key in sections}
        results.append(run_single_model(model_name, overrides))
    return results


In [None]:
# ---- RUN EXPERIMENTS ----
# Train + evaluate every model listed in CONFIG['models_to_run'].
all_runs = run_comparison(CONFIG)
print('Finished models:', [finished['name'] for finished in all_runs])


In [None]:
# ---- PLOT REGISTRY ----
def plot_model_metric_bars(runs, metrics, max_group_width=0.8):
    """Draw a grouped bar chart comparing scalar metrics across runs.

    Args:
        runs: list of run dicts (as built by ``run_single_model``). Each
            run's ``'results'`` object must expose every name in ``metrics``
            as an attribute.
        metrics: attribute names read from each run's results. One bar is
            drawn per metric inside each model's group.
        max_group_width: maximum total width of one model's group of bars,
            in x-axis units (group centres are 1 apart).

    Bars keep the original 0.2 width but shrink when needed, so a group never
    exceeds ``max_group_width``. Otherwise, with many metrics, one model's
    bars would overlap the next model's group.
    """
    metrics = list(metrics)
    if not metrics:
        # Nothing to draw; avoids a zero-width layout and an empty legend.
        return
    names = [r['name'] for r in runs]
    x = np.arange(len(names))
    width = min(0.2, max_group_width / len(metrics))
    fig, ax = plt.subplots(figsize=(10, 5))
    for i, metric in enumerate(metrics):
        values = [getattr(r['results'], metric) for r in runs]
        ax.bar(x + i * width, values, width, label=metric)
    # Centre each model's tick label under its group of bars.
    ax.set_xticks(x + width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(names, rotation=15)
    ax.set_ylabel('Metric value')
    ax.set_title('Model Comparison')
    ax.legend()
    plt.tight_layout()
    plt.show()

def plot_eta_boxplot(runs):
    """Draw one box (with mean marker) per run of its Top-1 η distribution.

    Args:
        runs: list of run dicts. Each run's ``'results'`` must expose
            ``eta_top1_distribution``, presumably a 1-D sequence of per-sample
            η values (confirm against evaluate_model).
    """
    labels = [r['name'] for r in runs]
    data = [r['results'].eta_top1_distribution for r in runs]
    fig, ax = plt.subplots(figsize=(10, 5))
    # Matplotlib 3.9 renamed boxplot's ``labels`` argument to ``tick_labels``
    # and deprecated the old name. Setting the tick labels explicitly works on
    # both old and new versions; boxes are placed at x = 1..n by default.
    ax.boxplot(data, showmeans=True)
    ax.set_xticks(np.arange(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_ylabel('η (Top-1)')
    ax.set_title('η Distribution by Model')
    plt.tight_layout()
    plt.show()

# Maps each name accepted in CONFIG['plots'] to a plotting callable.
# 'model_metric_bars' and 'eta_boxplot' receive the full list of runs; every
# other entry receives a single run dict (see the RENDER PLOTS cell).
PLOT_REGISTRY = dict(
    training_history=lambda one_run: plot_training_history(one_run['history']),
    eta_distribution=lambda one_run: plot_eta_distribution(one_run['results']),
    top_m_comparison=lambda one_run: plot_top_m_comparison(one_run['results']),
    baseline_comparison=lambda one_run: plot_baseline_comparison(one_run['results']),
    # CONFIG is looked up when the plot is drawn, not when the registry is built.
    model_metric_bars=lambda every_run: plot_model_metric_bars(every_run, CONFIG['compare_metrics']),
    eta_boxplot=plot_eta_boxplot,
)


In [None]:
# ---- RENDER PLOTS ----
# Validate every requested name first, so a typo in CONFIG['plots'] fails
# before any figure is drawn and the message lists the valid choices.
unknown_plots = [name for name in CONFIG['plots'] if name not in PLOT_REGISTRY]
if unknown_plots:
    raise KeyError(
        f"Unknown plot(s) in CONFIG['plots']: {unknown_plots}; "
        f"registered plots: {sorted(PLOT_REGISTRY)}"
    )

for plot_name in CONFIG['plots']:
    if plot_name in {'model_metric_bars', 'eta_boxplot'}:
        # cross-model plots: a single figure covering all runs
        PLOT_REGISTRY[plot_name](all_runs)
    else:
        # per-model plots: one call per run
        for run in all_runs:
            PLOT_REGISTRY[plot_name](run)
