# Visual Experiment Studio

This notebook provides a **visual, organized UI** for controlling **all experiment parameters**.

**What you can customize:**
- Phase mode (continuous/discrete) and discrete bits
- Probe bank method: random, Hadamard, Sobol, or Halton
- All system/data/model/training/eval parameters
- Which models to compare
- Which plots to render

> Edit values in the widgets, then press **Run Experiments**.



## How to see the Visual IDE

1. **Run all code cells in order** (imports, model registry, plot registry, widget UI, experiment logic). The tabs and controls will appear right above the "Run Experiments" button, and results are rendered below it.
2. If you are in **JupyterLab**, ensure the `ipywidgets` extension is enabled (installed via `requirements.txt`).
3. The UI is organized in tabs: **System**, **Data**, **Model**, **Training**, **Eval**, **Compare/Plots**.

> If you only see code, it means the widget cell hasn’t been executed yet.


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

from config import get_config
from data_generation import generate_probe_bank, create_dataloaders
from training import train
from evaluation import evaluate_model
from utils import (
    plot_training_history,
    plot_eta_distribution,
    plot_top_m_comparison,
    plot_baseline_comparison,
)
from model import LimitedProbingMLP


In [None]:
# ---- MODEL REGISTRY ----
def build_mlp_default(config):
    """Build the MLP whose architecture follows the model settings in ``config``.

    Args:
        config: Experiment configuration. Reads ``system.K`` plus the
            ``model.hidden_sizes``, ``model.dropout_prob`` and
            ``model.use_batch_norm`` fields.

    Returns:
        A newly constructed ``LimitedProbingMLP``.
    """
    mlp_kwargs = {
        'K': config.system.K,
        'hidden_sizes': config.model.hidden_sizes,
        'dropout_prob': config.model.dropout_prob,
        'use_batch_norm': config.model.use_batch_norm,
    }
    return LimitedProbingMLP(**mlp_kwargs)

def build_mlp_shallow(config):
    """Build a fixed, single-hidden-layer MLP baseline.

    Only ``config.system.K`` is read. The architecture is hard-coded to one
    128-unit hidden layer with dropout 0.0 and batch norm disabled, so the
    ``config.model`` settings have no effect on this model.

    Returns:
        A newly constructed ``LimitedProbingMLP``.
    """
    # Built inside the function so every call gets its own list object.
    fixed_architecture = {
        'hidden_sizes': [128],
        'dropout_prob': 0.0,
        'use_batch_norm': False,
    }
    return LimitedProbingMLP(K=config.system.K, **fixed_architecture)

def build_linear(config):
    """Build a purely linear baseline model.

    The single layer maps ``2 * K`` inputs to ``K`` outputs, where ``K`` is
    read from ``config.system.K``.

    Returns:
        A ``torch.nn.Sequential`` wrapping one ``torch.nn.Linear`` layer.
    """
    n_inputs = 2 * config.system.K
    n_outputs = config.system.K
    linear_layer = torch.nn.Linear(n_inputs, n_outputs)
    return torch.nn.Sequential(linear_layer)

# Maps each selectable model name to the factory that builds it from a config.
MODEL_REGISTRY = dict(
    mlp_default=build_mlp_default,
    mlp_shallow=build_mlp_shallow,
    linear=build_linear,
)


In [None]:
# ---- PLOT REGISTRY ----
def plot_model_metric_bars(runs, metrics):
    """Draw a grouped bar chart comparing scalar metrics across runs.

    One group of bars is drawn per run, with one bar per metric inside
    each group.

    Args:
        runs: Sequence of run dicts. Each must provide ``'name'`` and a
            ``'results'`` object exposing every requested metric as an
            attribute.
        metrics: Names of the attributes to read from each
            ``run['results']``.

    Raises:
        AttributeError: If a results object lacks a requested metric.
    """
    names = [r['name'] for r in runs]
    x = np.arange(len(names))
    # Groups sit one unit apart. Cap the bar width so a whole group fits in
    # 0.8 units and never spills into the neighbouring group; for up to four
    # metrics this is still the original width of 0.2.
    width = min(0.2, 0.8 / len(metrics)) if metrics else 0.2
    _, ax = plt.subplots(figsize=(10, 5))
    for i, metric in enumerate(metrics):
        values = [getattr(r['results'], metric) for r in runs]
        ax.bar(x + i * width, values, width, label=metric)
    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(names, rotation=15)
    ax.set_ylabel('Metric value')
    ax.set_title('Model Comparison')
    if metrics:
        # Without any labelled bars a legend would be empty.
        ax.legend()
    plt.tight_layout()
    plt.show()

def plot_eta_boxplot(runs):
    """Draw one box plot per run of its Top-1 η distribution.

    Args:
        runs: Sequence of run dicts. Each must provide ``'name'`` and a
            ``'results'`` object with an ``eta_top1_distribution`` attribute.
    """
    labels = [r['name'] for r in runs]
    data = [r['results'].eta_top1_distribution for r in runs]
    _, ax = plt.subplots(figsize=(10, 5))
    # Label the boxes through the axis rather than boxplot's ``labels=``
    # keyword: Matplotlib 3.9 deprecated that keyword in favour of
    # ``tick_labels=``, whereas this form works on older and newer releases.
    ax.boxplot(data, showmeans=True)
    ax.set_xticklabels(labels)
    ax.set_ylabel('η (Top-1)')
    ax.set_title('η Distribution by Model')
    plt.tight_layout()
    plt.show()

# Plot name -> callable. The first four entries take a single run dict; the
# last two take ``(runs, metrics)`` and draw one figure for all runs.
PLOT_REGISTRY = dict(
    training_history=lambda run: plot_training_history(run['history']),
    eta_distribution=lambda run: plot_eta_distribution(run['results']),
    top_m_comparison=lambda run: plot_top_m_comparison(run['results']),
    baseline_comparison=lambda run: plot_baseline_comparison(run['results']),
    model_metric_bars=lambda runs, metrics: plot_model_metric_bars(runs, metrics),
    # The box plot ignores the metric list but keeps the shared signature.
    eta_boxplot=lambda runs, _metrics: plot_eta_boxplot(runs),
)


In [None]:
# ---- WIDGET UI ----
# Every widget below is pre-filled from the default configuration.
default_config = get_config()

# System tab: one widget per system field, keyed by the field name that is
# later used as the override key.
system_defaults = default_config.system
system_widgets = dict(
    N=widgets.IntText(value=system_defaults.N, description='N'),
    K=widgets.IntText(value=system_defaults.K, description='K'),
    M=widgets.IntText(value=system_defaults.M, description='M'),
    P_tx=widgets.FloatText(value=system_defaults.P_tx, description='P_tx'),
    sigma_h_sq=widgets.FloatText(value=system_defaults.sigma_h_sq, description='sigma_h_sq'),
    sigma_g_sq=widgets.FloatText(value=system_defaults.sigma_g_sq, description='sigma_g_sq'),
    phase_mode=widgets.Dropdown(
        options=['continuous', 'discrete'],
        value=system_defaults.phase_mode,
        description='phase_mode',
    ),
    phase_bits=widgets.IntText(value=system_defaults.phase_bits, description='phase_bits'),
    probe_bank_method=widgets.Dropdown(
        options=['random', 'hadamard', 'sobol', 'halton'],
        value=system_defaults.probe_bank_method,
        description='probe_method',
    ),
)

# Data tab: dataset split sizes, RNG seed and the input-normalisation switch.
data_defaults = default_config.data
data_widgets = dict(
    n_train=widgets.IntText(value=data_defaults.n_train, description='n_train'),
    n_val=widgets.IntText(value=data_defaults.n_val, description='n_val'),
    n_test=widgets.IntText(value=data_defaults.n_test, description='n_test'),
    seed=widgets.IntText(value=data_defaults.seed, description='seed'),
    normalize_input=widgets.Checkbox(
        value=data_defaults.normalize_input,
        description='normalize_input',
    ),
)

# Model tab. ``hidden_sizes`` is edited as comma-separated text and parsed
# back into a list of ints when the configuration is collected.
model_defaults = default_config.model
hidden_sizes_text = ','.join(str(size) for size in model_defaults.hidden_sizes)
model_widgets = dict(
    hidden_sizes=widgets.Text(value=hidden_sizes_text, description='hidden_sizes'),
    dropout_prob=widgets.FloatText(value=model_defaults.dropout_prob, description='dropout'),
    use_batch_norm=widgets.Checkbox(value=model_defaults.use_batch_norm, description='batch_norm'),
)

# Training tab: optimisation hyper-parameters and schedule controls.
training_defaults = default_config.training
training_widgets = dict(
    batch_size=widgets.IntText(value=training_defaults.batch_size, description='batch_size'),
    learning_rate=widgets.FloatText(value=training_defaults.learning_rate, description='lr'),
    weight_decay=widgets.FloatText(value=training_defaults.weight_decay, description='weight_decay'),
    num_epochs=widgets.IntText(value=training_defaults.num_epochs, description='epochs'),
    early_stopping_patience=widgets.IntText(
        value=training_defaults.early_stopping_patience,
        description='early_stop',
    ),
    eval_interval=widgets.IntText(value=training_defaults.eval_interval, description='eval_interval'),
)

# Eval tab. ``top_m_values`` is edited as comma-separated text and parsed
# back into a list of ints when the configuration is collected.
top_m_text = ','.join(str(m) for m in default_config.eval.top_m_values)
eval_widgets = dict(
    top_m_values=widgets.Text(value=top_m_text, description='top_m_values'),
)

# Compare/Plots tab: which models to train, which plots to draw, and which
# result attributes the metric bar chart compares.
# Both option lists are derived from their registries so that a newly
# registered model or plot shows up here without editing this cell.
registered_models = sorted(MODEL_REGISTRY)
registered_plots = list(PLOT_REGISTRY)  # keeps the registry's own order

comparison_widgets = {
    'models_to_run': widgets.SelectMultiple(
        options=registered_models,
        value=tuple(registered_models),  # every model selected by default
        description='models'
    ),
    'plots': widgets.SelectMultiple(
        options=registered_plots,
        value=('model_metric_bars', 'eta_boxplot'),
        description='plots'
    ),
    # Comma-separated attribute names, split when the config is collected.
    'compare_metrics': widgets.Text(
        value='accuracy_top1,eta_top1,eta_top2',
        description='metrics'
    ),
}

# (tab title, widget group) pairs, in the order the tabs are shown.
tab_sections = (
    ('System', system_widgets),
    ('Data', data_widgets),
    ('Model', model_widgets),
    ('Training', training_widgets),
    ('Eval', eval_widgets),
    ('Compare/Plots', comparison_widgets),
)

tabs = widgets.Tab()
# Each tab stacks its group's widgets vertically.
tabs.children = [
    widgets.VBox(list(section_widgets.values()))
    for _title, section_widgets in tab_sections
]
for tab_index, (tab_title, _section) in enumerate(tab_sections):
    tabs.set_title(tab_index, tab_title)

run_button = widgets.Button(description='Run Experiments', button_style='success')
output = widgets.Output()

# Layout, top to bottom: parameter tabs, run button, results area.
display(tabs, run_button, output)


In [None]:
# ---- EXPERIMENT LOGIC ----
def parse_int_list(text):
    """Parse a comma-separated string into a list of ints.

    Whitespace around each item is ignored and empty items are skipped, so
    an empty string yields an empty list.

    Raises:
        ValueError: If a non-empty item is not a valid integer literal.
    """
    tokens = (piece.strip() for piece in text.split(','))
    return [int(token) for token in tokens if token]

def parse_float_list(text):
    """Parse a comma-separated string into a list of floats.

    Whitespace around each item is ignored and empty items are skipped, so
    an empty string yields an empty list.

    Raises:
        ValueError: If a non-empty item cannot be converted to ``float``.
    """
    tokens = (piece.strip() for piece in text.split(','))
    return [float(token) for token in tokens if token]

def collect_config():
    """Snapshot the current widget values into one override dictionary.

    Returns:
        A dict with one nested dict per config section (``'system'``,
        ``'data'``, ``'model'``, ``'training'``, ``'eval'``) mapping field
        name to widget value, plus the top-level keys ``'models_to_run'``,
        ``'plots'`` and ``'compare_metrics'`` (each a list of strings).

    Raises:
        ValueError: If ``hidden_sizes`` or ``top_m_values`` contains an item
            that is not an integer.
    """
    widget_groups = {
        'system': system_widgets,
        'data': data_widgets,
        'model': model_widgets,
        'training': training_widgets,
        'eval': eval_widgets,
    }
    collected = {
        section: {field: widget.value for field, widget in group.items()}
        for section, group in widget_groups.items()
    }

    # These two fields are edited as comma-separated text; convert them to
    # the integer lists the configuration expects.
    collected['model']['hidden_sizes'] = parse_int_list(collected['model']['hidden_sizes'])
    collected['eval']['top_m_values'] = parse_int_list(collected['eval']['top_m_values'])

    # Comparison settings live at the top level next to the section dicts.
    collected['models_to_run'] = list(comparison_widgets['models_to_run'].value)
    collected['plots'] = list(comparison_widgets['plots'].value)
    metric_names = (
        part.strip()
        for part in comparison_widgets['compare_metrics'].value.split(',')
    )
    collected['compare_metrics'] = [name for name in metric_names if name]

    return collected

def run_single_model(model_name, overrides):
    """Train and evaluate one registered model with the given overrides.

    Args:
        model_name: Key into ``MODEL_REGISTRY`` selecting the model factory.
        overrides: Dict holding the ``'system'``, ``'data'``, ``'model'``,
            ``'training'`` and ``'eval'`` override dicts passed on to
            ``get_config``; any other keys are ignored.

    Returns:
        A run dict with the keys ``'name'``, ``'model'``, ``'history'``,
        ``'results'`` and ``'config'``.

    Raises:
        KeyError: If ``model_name`` is not registered or ``overrides`` lacks
            one of the section keys.
    """
    section_names = ('system', 'data', 'model', 'training', 'eval')
    config = get_config(**{name: overrides[name] for name in section_names})

    # Seed both torch and numpy from the configured data seed before any
    # data or model is created.
    torch.manual_seed(config.data.seed)
    np.random.seed(config.data.seed)

    system_cfg = config.system
    probe_bank = generate_probe_bank(
        N=system_cfg.N,
        K=system_cfg.K,
        seed=config.data.seed,
        phase_mode=system_cfg.phase_mode,
        phase_bits=system_cfg.phase_bits,
        probe_bank_method=system_cfg.probe_bank_method,
    )

    train_loader, val_loader, test_loader, metadata = create_dataloaders(config, probe_bank)

    build_model = MODEL_REGISTRY[model_name]
    trained_model, history = train(build_model(config), train_loader, val_loader, config, metadata)

    # Test-set arrays handed to the evaluator, in its positional order.
    test_fields = (
        'test_powers_full',
        'test_labels',
        'test_observed_indices',
        'test_optimal_powers',
    )
    results = evaluate_model(
        trained_model,
        test_loader,
        config,
        *(metadata[field] for field in test_fields),
    )

    return dict(
        name=model_name,
        model=trained_model,
        history=history,
        results=results,
        config=config,
    )

def run_comparison(config_dict):
    """Run every selected model, one after another, with shared overrides.

    Args:
        config_dict: Override dict whose ``'models_to_run'`` entry lists the
            model names; the whole dict is passed to each run.

    Returns:
        The list of run dicts, in the order the models were listed.
    """
    selected_models = config_dict['models_to_run']
    return [run_single_model(name, config_dict) for name in selected_models]

def render_plots(runs, plots, compare_metrics):
    """Render the requested plots for a set of finished runs.

    Args:
        runs: List of run dicts.
        plots: Names of the plots to draw, looked up in ``PLOT_REGISTRY``.
        compare_metrics: Metric names passed to the cross-run plots.
    """
    # These plots draw all runs in one figure; every other plot is drawn
    # once per run.
    cross_run_plots = frozenset(('model_metric_bars', 'eta_boxplot'))
    for plot_name in plots:
        if plot_name not in cross_run_plots:
            for single_run in runs:
                PLOT_REGISTRY[plot_name](single_run)
            continue
        PLOT_REGISTRY[plot_name](runs, compare_metrics)

def on_run_clicked(_):
    """Handle a click on the Run Experiments button.

    Clears the previous results, then collects the widget values, runs every
    selected model and renders the selected plots, all captured inside the
    ``output`` widget.

    Args:
        _: Argument supplied by the click callback (unused).
    """
    # A run can take a long time; disable the button meanwhile so extra
    # clicks do not queue further full runs behind the current one.
    run_button.disabled = True
    try:
        output.clear_output()
        with output:
            config_dict = collect_config()
            runs = run_comparison(config_dict)
            render_plots(runs, config_dict['plots'], config_dict['compare_metrics'])
    finally:
        # Re-enable even when the run failed so the user can adjust the
        # inputs and try again.
        run_button.disabled = False

# NOTE(review): re-running this cell appears to register an additional
# handler on the same button; re-run the widget cell first to get a fresh
# button — confirm against the ipywidgets version in use.
run_button.on_click(on_run_clicked)
