# Visual Experiment Studio (Fully Customizable)

This notebook provides a **visual, widget-based interface** to configure every parameter
in the project, select probe-bank methods (random, Hadamard, Sobol, Halton), choose
models to compare, and render multiple plot types.

Run all cells in order, adjust the controls in the tabs below, then click **Run Experiments**.


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

from config import get_config
from data_generation import generate_probe_bank, create_dataloaders
from training import train
from evaluation import evaluate_model
from utils import (
    plot_training_history,
    plot_eta_distribution,
    plot_top_m_comparison,
    plot_baseline_comparison,
)
from model import LimitedProbingMLP


In [None]:
# ---- MODEL REGISTRY (add your own builders here) ----
def build_mlp_default(config):
    """Build a LimitedProbingMLP whose architecture is taken from ``config.model``.

    Args:
        config: Configuration object exposing ``system.K`` and the ``model``
            fields ``hidden_sizes``, ``dropout_prob`` and ``use_batch_norm``.

    Returns:
        A new ``LimitedProbingMLP`` instance.
    """
    num_probes = config.system.K
    model_cfg = config.model
    return LimitedProbingMLP(
        K=num_probes,
        hidden_sizes=model_cfg.hidden_sizes,
        dropout_prob=model_cfg.dropout_prob,
        use_batch_norm=model_cfg.use_batch_norm,
    )

def build_mlp_shallow(config):
    """Build a single-hidden-layer LimitedProbingMLP with fixed settings.

    Only ``config.system.K`` is read; the hidden layout (one layer of 128
    units), dropout (off) and batch norm (off) are hard-coded here and ignore
    the ``model`` section of the configuration.

    Args:
        config: Configuration object exposing ``system.K``.

    Returns:
        A new ``LimitedProbingMLP`` instance.
    """
    fixed_architecture = {
        'hidden_sizes': [128],
        'dropout_prob': 0.0,
        'use_batch_norm': False,
    }
    return LimitedProbingMLP(K=config.system.K, **fixed_architecture)

def build_linear(config):
    """Build a purely linear baseline: one ``Linear`` layer inside a ``Sequential``.

    Args:
        config: Configuration object exposing ``system.K``.

    Returns:
        ``torch.nn.Sequential`` holding a single layer that maps ``2 * K``
        input features to ``K`` outputs.
    """
    k = config.system.K
    linear_layer = torch.nn.Linear(2 * k, k)
    return torch.nn.Sequential(linear_layer)

# Maps the model names offered in the UI to builder callables; every builder
# takes the config object and returns a freshly constructed model.
MODEL_REGISTRY = dict(
    mlp_default=build_mlp_default,
    mlp_shallow=build_mlp_shallow,
    linear=build_linear,
)


In [None]:
# ---- VISUAL CONFIG PANEL ----
# System tab: one control per key of the 'system' config section.
system_N = widgets.IntText(description='N', value=32)
system_K = widgets.IntText(description='K', value=64)
system_M = widgets.IntText(description='M', value=8)
system_P_tx = widgets.FloatText(description='P_tx', value=1.0)
system_sigma_h = widgets.FloatText(description='sigma_h_sq', value=1.0)
system_sigma_g = widgets.FloatText(description='sigma_g_sq', value=1.0)
system_phase_mode = widgets.Dropdown(
    options=['continuous', 'discrete'],
    value='continuous',
    description='phase_mode',
)
system_phase_bits = widgets.IntText(description='phase_bits', value=3)
system_probe_bank_mode = widgets.Dropdown(
    options=['random', 'hadamard', 'sobol', 'halton'],
    value='random',
    description='probe_bank',
)
system_qmc_scramble = widgets.Checkbox(description='qmc_scramble', value=True)

# Stack the controls vertically, in declaration order.
system_box = widgets.VBox(children=[
    system_N,
    system_K,
    system_M,
    system_P_tx,
    system_sigma_h,
    system_sigma_g,
    system_phase_mode,
    system_phase_bits,
    system_probe_bank_mode,
    system_qmc_scramble,
])

# Data tab: one control per key of the 'data' config section.
data_n_train = widgets.IntText(description='n_train', value=3000)
data_n_val = widgets.IntText(description='n_val', value=600)
data_n_test = widgets.IntText(description='n_test', value=600)
data_seed = widgets.IntText(description='seed', value=42)
data_normalize = widgets.Checkbox(description='normalize_input', value=True)
data_norm_type = widgets.Dropdown(
    options=['mean', 'std', 'log'],
    value='mean',
    description='normalization_type',
)
data_box = widgets.VBox(children=[
    data_n_train,
    data_n_val,
    data_n_test,
    data_seed,
    data_normalize,
    data_norm_type,
])

# Model tab: hidden_sizes is free text, parsed later as comma-separated ints.
model_hidden = widgets.Text(description='hidden_sizes', value='256,128')
model_dropout = widgets.FloatText(description='dropout_prob', value=0.1)
model_batch_norm = widgets.Checkbox(description='use_batch_norm', value=True)
model_box = widgets.VBox(children=[
    model_hidden,
    model_dropout,
    model_batch_norm,
])

# Training tab: one control per key of the 'training' config section.
train_batch = widgets.IntText(value=128, description='batch_size')
train_lr = widgets.FloatText(value=1e-3, description='learning_rate')
train_weight_decay = widgets.FloatText(value=1e-5, description='weight_decay')
train_epochs = widgets.IntText(value=10, description='num_epochs')
train_early = widgets.IntText(value=15, description='early_stopping_patience')
train_eval_interval = widgets.IntText(value=1, description='eval_interval')
# Offer 'cuda' only when PyTorch reports a usable CUDA device, so the dropdown
# cannot select a device that does not exist on this machine. 'cpu' stays the
# default either way.
train_device = widgets.Dropdown(
    options=['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'],
    value='cpu',
    description='device',
)
train_box = widgets.VBox([
    train_batch, train_lr, train_weight_decay, train_epochs,
    train_early, train_eval_interval, train_device
])

# Eval tab: top_m_values is free text, parsed later as comma-separated ints.
eval_top_m = widgets.Text(description='top_m_values', value='1,2,4,8')
eval_box = widgets.VBox(children=[eval_top_m])

# Experiment tab: which registered models to run and which plots to render.
model_select = widgets.SelectMultiple(
    options=list(MODEL_REGISTRY),
    value=('mlp_default',),
    description='models',
)
plot_select = widgets.SelectMultiple(
    options=[
        'training_history',
        'eta_distribution',
        'top_m_comparison',
        'baseline_comparison',
        'model_metric_bars',
        'eta_boxplot',
    ],
    value=('model_metric_bars', 'eta_boxplot'),
    description='plots',
)
# Comma-separated names of the result attributes shown in the metric bar chart.
compare_metrics = widgets.Text(
    description='compare_metrics',
    value='accuracy_top1,eta_top1,eta_top2',
)
experiment_box = widgets.VBox(children=[
    model_select,
    plot_select,
    compare_metrics,
])

# Layout: one tab per settings group; titles are listed in the same order as
# the tab children so each index lines up with its box.
titles = ['System', 'Data', 'Model', 'Training', 'Eval', 'Experiment']
tabs = widgets.Tab(children=[
    system_box,
    data_box,
    model_box,
    train_box,
    eval_box,
    experiment_box,
])
for tab_index, title in enumerate(titles):
    tabs.set_title(tab_index, title)

display(tabs)


In [None]:
# ---- RUNNER + PLOTS ----
def _parse_int_list(text):
    """Parse a comma-separated string into a list of ints, skipping blank items."""
    stripped_pieces = (piece.strip() for piece in text.split(','))
    return [int(piece) for piece in stripped_pieces if piece]

def _parse_float_list(text):
    """Parse a comma-separated string into a list of floats, skipping blank items."""
    parsed = []
    for chunk in text.split(','):
        chunk = chunk.strip()
        if chunk:
            parsed.append(float(chunk))
    return parsed

def build_config_from_widgets():
    """Snapshot the current widget values into a nested configuration dict.

    Returns:
        A dict with the sections 'system', 'data', 'model', 'training' and
        'eval', plus the experiment selections 'models_to_run', 'plots' and
        'compare_metrics'. The free-text fields (hidden sizes, top-m values,
        metric names) are split on commas, with blank entries dropped.
    """
    system_section = dict(
        N=system_N.value,
        K=system_K.value,
        M=system_M.value,
        P_tx=system_P_tx.value,
        sigma_h_sq=system_sigma_h.value,
        sigma_g_sq=system_sigma_g.value,
        phase_mode=system_phase_mode.value,
        phase_bits=system_phase_bits.value,
        probe_bank_mode=system_probe_bank_mode.value,
        qmc_scramble=system_qmc_scramble.value,
    )
    data_section = dict(
        n_train=data_n_train.value,
        n_val=data_n_val.value,
        n_test=data_n_test.value,
        seed=data_seed.value,
        normalize_input=data_normalize.value,
        normalization_type=data_norm_type.value,
    )
    model_section = dict(
        hidden_sizes=_parse_int_list(model_hidden.value),
        dropout_prob=model_dropout.value,
        use_batch_norm=model_batch_norm.value,
    )
    training_section = dict(
        batch_size=train_batch.value,
        learning_rate=train_lr.value,
        weight_decay=train_weight_decay.value,
        num_epochs=train_epochs.value,
        early_stopping_patience=train_early.value,
        eval_interval=train_eval_interval.value,
        device=train_device.value,
    )
    eval_section = dict(top_m_values=_parse_int_list(eval_top_m.value))

    selected_models = list(model_select.value)
    selected_plots = list(plot_select.value)

    # Metric names are kept as strings; only surrounding whitespace is removed.
    metric_names = []
    for raw_name in compare_metrics.value.split(','):
        cleaned = raw_name.strip()
        if cleaned:
            metric_names.append(cleaned)

    return {
        'system': system_section,
        'data': data_section,
        'model': model_section,
        'training': training_section,
        'eval': eval_section,
        'models_to_run': selected_models,
        'plots': selected_plots,
        'compare_metrics': metric_names,
    }

def run_single_model(model_name, overrides):
    """Train and evaluate one registered model, returning everything produced.

    The run is seeded (torch and NumPy) from ``config.data.seed`` before the
    probe bank and dataloaders are created.

    Args:
        model_name: Key into ``MODEL_REGISTRY`` selecting the model builder.
        overrides: Keyword arguments forwarded to ``get_config``.

    Returns:
        Dict with keys 'name', 'model' (as returned by ``train``), 'history',
        'results' (as returned by ``evaluate_model``) and 'config'.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_REGISTRY``.
    """
    config = get_config(**overrides)

    seed = config.data.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    system = config.system
    probe_bank = generate_probe_bank(
        N=system.N,
        K=system.K,
        seed=seed,
        phase_mode=system.phase_mode,
        phase_bits=system.phase_bits,
        probe_bank_mode=system.probe_bank_mode,
        qmc_scramble=system.qmc_scramble,
    )

    loaders_and_metadata = create_dataloaders(config, probe_bank)
    train_loader, val_loader, test_loader, metadata = loaders_and_metadata

    builder = MODEL_REGISTRY[model_name]
    untrained_model = builder(config)
    trained_model, history = train(
        untrained_model, train_loader, val_loader, config, metadata
    )

    # The test-split arrays are passed to evaluate_model positionally, in
    # exactly this order.
    test_metadata_keys = (
        'test_powers_full',
        'test_labels',
        'test_observed_indices',
        'test_optimal_powers',
    )
    test_arrays = [metadata[key] for key in test_metadata_keys]
    results = evaluate_model(trained_model, test_loader, config, *test_arrays)

    return dict(
        name=model_name,
        model=trained_model,
        history=history,
        results=results,
        config=config,
    )

def run_comparison(config_dict):
    """Run every selected model with the same configuration overrides.

    Args:
        config_dict: Dict with the sections 'system', 'data', 'training',
            'model' and 'eval', plus 'models_to_run' (an iterable of model
            names).

    Returns:
        List of per-model run dicts from ``run_single_model``, in the order
        the models are listed.
    """
    section_names = ('system', 'data', 'training', 'model', 'eval')
    overrides = {section: config_dict[section] for section in section_names}
    return [
        run_single_model(model_name, overrides)
        for model_name in config_dict['models_to_run']
    ]

def plot_model_metric_bars(runs, metrics):
    """Draw a grouped bar chart comparing metrics across model runs.

    Each model gets one group on the x-axis; within a group there is one bar
    per requested metric.

    Args:
        runs: Sequence of run dicts, each with a 'name' and a 'results' object
            that exposes every requested metric as an attribute.
        metrics: Names of the result attributes to plot.

    Raises:
        AttributeError: If a results object lacks one of the requested metrics.
    """
    names = [r['name'] for r in runs]
    x = np.arange(len(names))
    # Groups sit one unit apart on the x-axis. Keep the historical bar width of
    # 0.2, but shrink it once a whole group would be wider than 0.8 so that
    # bars of neighbouring models never overlap when many metrics are chosen.
    width = min(0.2, 0.8 / max(len(metrics), 1))
    fig, ax = plt.subplots(figsize=(10, 5))
    for i, metric in enumerate(metrics):
        values = [getattr(r['results'], metric) for r in runs]
        ax.bar(x + i * width, values, width, label=metric)
    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(names, rotation=15)
    ax.set_ylabel('Metric value')
    ax.set_title('Model Comparison')
    if metrics:
        # With nothing plotted, legend() only emits a "no labelled artists" warning.
        ax.legend()
    plt.tight_layout()
    plt.show()

def plot_eta_boxplot(runs):
    """Draw one box per model showing its Top-1 η distribution.

    Args:
        runs: Sequence of run dicts, each with a 'name' and a 'results' object
            exposing ``eta_top1_distribution``.
    """
    labels = [r['name'] for r in runs]
    data = [r['results'].eta_top1_distribution for r in runs]
    fig, ax = plt.subplots(figsize=(10, 5))
    # boxplot(labels=...) is deprecated since Matplotlib 3.9 (renamed to
    # tick_labels). Setting the tick labels separately gives the same axis on
    # both old and new versions. Boxes are drawn at positions 1..len(data).
    ax.boxplot(data, showmeans=True)
    ax.set_xticks(list(range(1, len(labels) + 1)))
    ax.set_xticklabels(labels)
    ax.set_ylabel('η (Top-1)')
    ax.set_title('η Distribution by Model')
    plt.tight_layout()
    plt.show()

# Plot name -> callable. The first four entries take a single run dict; the
# last two take the whole list of runs ('model_metric_bars' also takes the
# list of metric names).
PLOT_REGISTRY = dict(
    training_history=lambda one_run: plot_training_history(one_run['history']),
    eta_distribution=lambda one_run: plot_eta_distribution(one_run['results']),
    top_m_comparison=lambda one_run: plot_top_m_comparison(one_run['results']),
    baseline_comparison=lambda one_run: plot_baseline_comparison(one_run['results']),
    model_metric_bars=lambda all_runs, metric_names: plot_model_metric_bars(all_runs, metric_names),
    eta_boxplot=plot_eta_boxplot,
)

# Trigger button plus the output area that captures everything a run displays.
run_button = widgets.Button(
    button_style='success',
    description='Run Experiments',
)
output = widgets.Output()

def on_run_clicked(_):
    """Button callback: run the selected models and render the selected plots.

    Everything displayed during the run is captured by the ``output`` widget,
    which is cleared first. The callback argument (passed by the button) is
    unused.
    """
    with output:
        clear_output()
        settings = build_config_from_widgets()
        completed_runs = run_comparison(settings)

        for plot_name in settings['plots']:
            # Cross-model plots receive the full list of runs at once.
            if plot_name == 'model_metric_bars':
                PLOT_REGISTRY[plot_name](completed_runs, settings['compare_metrics'])
                continue
            if plot_name == 'eta_boxplot':
                PLOT_REGISTRY[plot_name](completed_runs)
                continue
            # Every other plot is drawn once per model run.
            for single_run in completed_runs:
                PLOT_REGISTRY[plot_name](single_run)

# Hook the callback up, then show the button with its output area below it.
run_button.on_click(callback=on_run_clicked)
display(run_button, output)
