# Visual Experiment IDE (Fully Customizable)

Use the UI below to configure **all parameters** in the project, choose probe-bank methods (random / Hadamard / Sobol / Halton), select models, and render plots.

**Tip:** If widgets are missing, install them with `pip install ipywidgets` and restart the kernel.


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

from config import get_config
from data_generation import generate_probe_bank, create_dataloaders
from training import train
from evaluation import evaluate_model
from utils import (
    plot_training_history,
    plot_eta_distribution,
    plot_top_m_comparison,
    plot_baseline_comparison,
)
from model import LimitedProbingMLP


In [None]:
# ---- DEFAULT CONFIG (all parameters exposed) ----
# Baseline values used to seed every widget in the UI cell. The layout
# (sections 'system' / 'data' / 'model' / 'training' / 'eval' plus four
# top-level lists) matches the dict returned by build_config_from_widgets().
DEFAULT_CONFIG = {
    'system': {
        'N': 32,
        'K': 64,
        'M': 8,
        'P_tx': 1.0,
        'sigma_h_sq': 1.0,
        'sigma_g_sq': 1.0,
        'phase_mode': 'continuous',  # continuous / discrete
        'phase_bits': 3,
        # run_comparison() sets this key per run from 'probe_bank_methods'.
        'probe_bank_method': 'random',  # random / hadamard / sobol / halton
    },
    'data': {
        'n_train': 3000,
        'n_val': 600,
        'n_test': 600,
        # Also used to seed torch / NumPy and the probe bank in run_single_model().
        'seed': 42,
        'normalize_input': True,
        'normalization_type': 'mean',
    },
    'model': {
        # Only build_mlp_default() reads these; the other registry builders
        # hard-code their architecture.
        'hidden_sizes': [256, 128],
        'dropout_prob': 0.1,
        'use_batch_norm': True,
    },
    'training': {
        'batch_size': 128,
        'learning_rate': 1e-3,
        'weight_decay': 1e-5,
        'num_epochs': 10,
        # NOTE(review): patience (15) exceeds num_epochs (10), so early stopping
        # presumably never triggers with these defaults — confirm in training.train.
        'early_stopping_patience': 15,
        'eval_interval': 1,
        # Evaluated once, when this cell runs: CUDA if torch reports it available.
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    },
    'eval': {
        'top_m_values': [1, 2, 4, 8],
    },
    # Probe-bank methods iterated by run_comparison().
    'probe_bank_methods': ['random'],
    # Keys of MODEL_REGISTRY to train for each probe-bank method.
    'models_to_run': ['mlp_default', 'mlp_shallow', 'linear'],
    # Keys of PLOT_REGISTRY to render after the runs finish.
    'plots': [
        'training_history',
        'eta_distribution',
        'top_m_comparison',
        'baseline_comparison',
        'model_metric_bars',
        'eta_boxplot',
    ],
    # Attribute names read from each run's evaluation results by
    # plot_model_metric_bars().
    'compare_metrics': ['accuracy_top1', 'eta_top1', 'eta_top2'],
}


In [None]:
# ---- MODEL REGISTRY ----
def build_mlp_default(config):
    """Build the fully configurable MLP.

    Forwards ``config.system.K`` together with the ``hidden_sizes``,
    ``dropout_prob`` and ``use_batch_norm`` fields of ``config.model`` to
    ``LimitedProbingMLP`` and returns the resulting model.
    """
    mlp_kwargs = {
        'K': config.system.K,
        'hidden_sizes': config.model.hidden_sizes,
        'dropout_prob': config.model.dropout_prob,
        'use_batch_norm': config.model.use_batch_norm,
    }
    return LimitedProbingMLP(**mlp_kwargs)

def build_mlp_shallow(config):
    """Build a fixed single-hidden-layer MLP.

    Only ``config.system.K`` is read from the config; the architecture is
    hard-coded to one hidden layer of 128 units with no dropout and no
    batch normalisation.
    """
    num_probes = config.system.K
    shallow_kwargs = dict(hidden_sizes=[128], dropout_prob=0.0, use_batch_norm=False)
    return LimitedProbingMLP(K=num_probes, **shallow_kwargs)

def build_linear(config):
    """Build the linear baseline: a single ``Linear`` layer in a ``Sequential``.

    The layer maps ``2 * K`` input features to ``K`` outputs, where ``K`` is
    ``config.system.K``.
    """
    num_probes = config.system.K
    layer = torch.nn.Linear(2 * num_probes, num_probes)
    return torch.nn.Sequential(layer)

# Maps a model name (as offered in the 'models' selector and used by
# run_single_model) to a builder that takes the config object and returns
# a freshly constructed, untrained model.
MODEL_REGISTRY = {
    'mlp_default': build_mlp_default,
    'mlp_shallow': build_mlp_shallow,
    'linear': build_linear,
}


In [None]:
# ---- UI BUILDERS ----
def _list_to_str(values):
    """Render an iterable as comma-separated text (no spaces) for a Text widget."""
    return ','.join(map(str, values))

def _str_to_int_list(text):
    """Parse comma-separated text into a list of ints.

    Whitespace around entries is ignored and empty entries are skipped.
    Raises ``ValueError`` if a non-empty entry is not a valid integer.
    """
    parsed = []
    for raw in text.split(','):
        token = raw.strip()
        if token:
            parsed.append(int(token))
    return parsed

# One dict of input widgets per config section. Each key matches the
# corresponding key in DEFAULT_CONFIG, and each widget starts at that default.
# NOTE(review): long descriptions such as 'early_stopping_patience' may be
# truncated by ipywidgets' default description width — consider passing
# style={'description_width': 'initial'}.

# System parameters ('system' section).
system_widgets = {
    'N': widgets.IntText(value=DEFAULT_CONFIG['system']['N'], description='N'),
    'K': widgets.IntText(value=DEFAULT_CONFIG['system']['K'], description='K'),
    'M': widgets.IntText(value=DEFAULT_CONFIG['system']['M'], description='M'),
    'P_tx': widgets.FloatText(value=DEFAULT_CONFIG['system']['P_tx'], description='P_tx'),
    'sigma_h_sq': widgets.FloatText(value=DEFAULT_CONFIG['system']['sigma_h_sq'], description='sigma_h_sq'),
    'sigma_g_sq': widgets.FloatText(value=DEFAULT_CONFIG['system']['sigma_g_sq'], description='sigma_g_sq'),
    'phase_mode': widgets.Dropdown(options=['continuous', 'discrete'], value=DEFAULT_CONFIG['system']['phase_mode'], description='phase_mode'),
    'phase_bits': widgets.IntText(value=DEFAULT_CONFIG['system']['phase_bits'], description='phase_bits'),
    'probe_bank_method': widgets.Dropdown(options=['random', 'hadamard', 'sobol', 'halton'], value=DEFAULT_CONFIG['system']['probe_bank_method'], description='bank_method'),
}

# Dataset sizes, seed and input normalisation ('data' section).
data_widgets = {
    'n_train': widgets.IntText(value=DEFAULT_CONFIG['data']['n_train'], description='n_train'),
    'n_val': widgets.IntText(value=DEFAULT_CONFIG['data']['n_val'], description='n_val'),
    'n_test': widgets.IntText(value=DEFAULT_CONFIG['data']['n_test'], description='n_test'),
    'seed': widgets.IntText(value=DEFAULT_CONFIG['data']['seed'], description='seed'),
    'normalize_input': widgets.Checkbox(value=DEFAULT_CONFIG['data']['normalize_input'], description='normalize_input'),
    'normalization_type': widgets.Dropdown(options=['mean', 'std', 'log'], value=DEFAULT_CONFIG['data']['normalization_type'], description='normalization_type'),
}

# MLP hyperparameters ('model' section); hidden_sizes is edited as
# comma-separated text and parsed back with _str_to_int_list.
model_widgets = {
    'hidden_sizes': widgets.Text(value=_list_to_str(DEFAULT_CONFIG['model']['hidden_sizes']), description='hidden_sizes'),
    'dropout_prob': widgets.FloatText(value=DEFAULT_CONFIG['model']['dropout_prob'], description='dropout_prob'),
    'use_batch_norm': widgets.Checkbox(value=DEFAULT_CONFIG['model']['use_batch_norm'], description='use_batch_norm'),
}

# Optimiser / loop settings ('training' section).
training_widgets = {
    'batch_size': widgets.IntText(value=DEFAULT_CONFIG['training']['batch_size'], description='batch_size'),
    'learning_rate': widgets.FloatText(value=DEFAULT_CONFIG['training']['learning_rate'], description='learning_rate'),
    'weight_decay': widgets.FloatText(value=DEFAULT_CONFIG['training']['weight_decay'], description='weight_decay'),
    'num_epochs': widgets.IntText(value=DEFAULT_CONFIG['training']['num_epochs'], description='num_epochs'),
    'early_stopping_patience': widgets.IntText(value=DEFAULT_CONFIG['training']['early_stopping_patience'], description='early_stopping_patience'),
    'eval_interval': widgets.IntText(value=DEFAULT_CONFIG['training']['eval_interval'], description='eval_interval'),
    # 'cuda' is always offered, even when the default resolved to 'cpu'.
    'device': widgets.Dropdown(options=['cpu', 'cuda'], value=DEFAULT_CONFIG['training']['device'], description='device'),
}

# Evaluation settings ('eval' section); comma-separated integers.
eval_widgets = {
    'top_m_values': widgets.Text(value=_list_to_str(DEFAULT_CONFIG['eval']['top_m_values']), description='top_m_values'),
}

# Experiment-level selections: which probe-bank methods and models to run,
# which plots to draw, and which result metrics to compare.
comparison_widgets = {
    'probe_bank_methods': widgets.SelectMultiple(
        options=['random', 'hadamard', 'sobol', 'halton'],
        value=tuple(DEFAULT_CONFIG['probe_bank_methods']),
        description='probe_methods'
    ),
    'models_to_run': widgets.SelectMultiple(
        options=sorted(MODEL_REGISTRY.keys()),
        value=tuple(DEFAULT_CONFIG['models_to_run']),
        description='models'
    ),
    'plots': widgets.SelectMultiple(
        options=[
            'training_history', 'eta_distribution', 'top_m_comparison',
            'baseline_comparison', 'model_metric_bars', 'eta_boxplot',
        ],
        value=tuple(DEFAULT_CONFIG['plots']),
        description='plots'
    ),
    # Comma-separated metric names (strings, not integers).
    'compare_metrics': widgets.Text(value=_list_to_str(DEFAULT_CONFIG['compare_metrics']), description='compare_metrics'),
}

def _as_box(widget_dict):
    """Stack the widgets of one section vertically, in dict insertion order."""
    stacked = [*widget_dict.values()]
    return widgets.VBox(stacked)

# One tab per widget group; titles and children are listed in the same order.
tab_titles = ['System', 'Data', 'Model', 'Training', 'Eval', 'Compare/Plots']
tab_children = [
    _as_box(group)
    for group in (
        system_widgets,
        data_widgets,
        model_widgets,
        training_widgets,
        eval_widgets,
        comparison_widgets,
    )
]

tabs = widgets.Tab()
tabs.children = tab_children
for i, title in enumerate(tab_titles):
    tabs.set_title(i, title)

display(tabs)


In [None]:
# ---- CONFIG BUILDER ----
def build_config_from_widgets():
    """Snapshot the current widget values into a nested config dict.

    The 'system', 'data', 'model' and 'training' sections contain one entry
    per widget in the matching ``*_widgets`` dict (same keys). Text fields
    holding lists are parsed: ``hidden_sizes`` and ``top_m_values`` become
    lists of ints, ``compare_metrics`` becomes a list of metric-name strings.

    Returns:
        dict: sections 'system', 'data', 'model', 'training', 'eval' plus the
        top-level lists 'probe_bank_methods', 'models_to_run', 'plots' and
        'compare_metrics'.

    Raises:
        ValueError: if ``hidden_sizes`` or ``top_m_values`` contains a
            non-empty entry that is not an integer.
    """
    def _values(widget_dict):
        # Read the current value of every widget in one section.
        return {key: widget.value for key, widget in widget_dict.items()}

    model_section = _values(model_widgets)
    model_section['hidden_sizes'] = _str_to_int_list(model_section['hidden_sizes'])

    # Metric names are strings such as 'accuracy_top1'; parsing them with
    # int() (as _str_to_int_list does) would raise ValueError.
    metrics_text = comparison_widgets['compare_metrics'].value
    compare_metrics = [
        name.strip() for name in metrics_text.split(',') if name.strip()
    ]

    return {
        'system': _values(system_widgets),
        'data': _values(data_widgets),
        'model': model_section,
        'training': _values(training_widgets),
        'eval': {
            'top_m_values': _str_to_int_list(eval_widgets['top_m_values'].value),
        },
        'probe_bank_methods': list(comparison_widgets['probe_bank_methods'].value),
        'models_to_run': list(comparison_widgets['models_to_run'].value),
        'plots': list(comparison_widgets['plots'].value),
        'compare_metrics': compare_metrics,
    }

# Button + output area that print the config the widgets currently describe.
preview_button = widgets.Button(description='Preview Config')
config_preview = widgets.Output()

def show_config_preview(_=None):
    """Replace the preview area's content with the current widget config.

    The single argument is ignored; it only absorbs the button instance
    passed by ``on_click``.
    """
    config_preview.clear_output()
    with config_preview:
        snapshot = build_config_from_widgets()
        print(snapshot)

preview_button.on_click(show_config_preview)
display(preview_button, config_preview)


In [None]:
# ---- EXPERIMENT HELPERS ----
def run_single_model(model_name, overrides):
    """Train and evaluate one registered model under one configuration.

    Args:
        model_name: key into ``MODEL_REGISTRY`` selecting the model builder.
        overrides: mapping expanded as keyword arguments to ``get_config``.

    Returns:
        dict with keys 'name' (``"<model_name> | <probe_bank_method>"``),
        'model' and 'history' (both as returned by ``train``), 'results'
        (as returned by ``evaluate_model``) and 'config'.
    """
    config = get_config(**overrides)

    # Seed both RNG libraries from the configured data seed.
    seed = config.data.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    system = config.system
    probe_bank = generate_probe_bank(
        N=system.N,
        K=system.K,
        seed=seed,
        phase_mode=system.phase_mode,
        phase_bits=system.phase_bits,
        probe_bank_method=system.probe_bank_method,
    )

    loaders_and_metadata = create_dataloaders(config, probe_bank)
    train_loader, val_loader, test_loader, metadata = loaders_and_metadata

    untrained = MODEL_REGISTRY[model_name](config)
    model, history = train(untrained, train_loader, val_loader, config, metadata)

    # Test-set entries of the metadata, in the positional order passed to
    # evaluate_model.
    test_metadata_keys = (
        'test_powers_full',
        'test_labels',
        'test_observed_indices',
        'test_optimal_powers',
    )
    results = evaluate_model(
        model,
        test_loader,
        config,
        *[metadata[key] for key in test_metadata_keys],
    )

    run_label = f"{model_name} | {config.system.probe_bank_method}"
    return {
        'name': run_label,
        'model': model,
        'history': history,
        'results': results,
        'config': config,
    }

def run_comparison(config_dict):
    """Run every selected model under every selected probe-bank method.

    Args:
        config_dict: nested config as produced by ``build_config_from_widgets``.
            Must contain the sections 'system', 'data', 'training', 'model'
            and 'eval', plus the lists 'probe_bank_methods' and
            'models_to_run'. It is not modified.

    Returns:
        list: one ``run_single_model`` result per (method, model) pair,
        ordered by method first, then by model.
    """
    sections = ('system', 'data', 'training', 'model', 'eval')

    methods = list(config_dict['probe_bank_methods'])
    if not methods:
        # Nothing selected in the multi-select: fall back to the single
        # system-level method instead of silently running nothing.
        fallback = config_dict['system'].get('probe_bank_method')
        methods = [] if fallback is None else [fallback]

    results = []
    for method in methods:
        for model_name in config_dict['models_to_run']:
            # Fresh per-run copy of every section, so neither the caller's
            # dict nor a previous run's overrides are mutated.
            overrides = {name: dict(config_dict[name]) for name in sections}
            overrides['system']['probe_bank_method'] = method
            results.append(run_single_model(model_name, overrides))
    return results


In [None]:
# ---- PLOT REGISTRY ----
def plot_model_metric_bars(runs, metrics):
    """Draw a grouped bar chart comparing result metrics across runs.

    Args:
        runs: list of run dicts; each needs a 'name' and a 'results' object
            exposing every requested metric as an attribute.
        metrics: list of attribute names to read from each run's results;
            one bar per metric within each run's group.

    Prints a notice and draws nothing when ``runs`` or ``metrics`` is empty.
    """
    if not runs or not metrics:
        print('No runs or metrics selected; skipping model metric bars.')
        return

    names = [r['name'] for r in runs]
    x = np.arange(len(names))
    # Keep the original 0.2 width for up to four metrics, but shrink it for
    # longer lists so a group never spans more than 0.8 of its slot and
    # neighbouring groups cannot overlap.
    width = min(0.2, 0.8 / len(metrics))

    _, ax = plt.subplots(figsize=(10, 5))
    for i, metric in enumerate(metrics):
        values = [getattr(r['results'], metric) for r in runs]
        ax.bar(x + i * width, values, width, label=metric)

    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * (len(metrics) - 1) / 2)
    ax.set_xticklabels(names, rotation=15)
    ax.set_ylabel('Metric value')
    ax.set_title('Model Comparison')
    ax.legend()
    plt.tight_layout()
    plt.show()

def plot_eta_boxplot(runs):
    """Draw one box per run from its ``results.eta_top1_distribution``.

    Args:
        runs: list of run dicts; each needs a 'name' and a 'results' object
            with an ``eta_top1_distribution`` attribute.

    Prints a notice and draws nothing when ``runs`` is empty.
    """
    if not runs:
        print('No runs to plot; skipping η boxplot.')
        return

    labels = [r['name'] for r in runs]
    data = [r['results'].eta_top1_distribution for r in runs]

    _, ax = plt.subplots(figsize=(10, 5))
    ax.boxplot(data, showmeans=True)
    # Label the boxes through the tick API rather than boxplot(labels=...),
    # which Matplotlib 3.9 deprecated in favour of tick_labels. Boxes sit at
    # the default positions 1..n.
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_ylabel('η (Top-1)')
    ax.set_title('η Distribution by Model')
    plt.tight_layout()
    plt.show()

# Plot name -> callable. The first four entries take a single run dict; the
# last two take the whole list of runs ('model_metric_bars' additionally
# takes the list of metric names).
PLOT_REGISTRY = dict(
    training_history=lambda run: plot_training_history(run['history']),
    eta_distribution=lambda run: plot_eta_distribution(run['results']),
    top_m_comparison=lambda run: plot_top_m_comparison(run['results']),
    baseline_comparison=lambda run: plot_baseline_comparison(run['results']),
    model_metric_bars=lambda runs, metrics: plot_model_metric_bars(runs, metrics),
    eta_boxplot=plot_eta_boxplot,
)


In [None]:
# ---- RUN + PLOT (click) ----
# Button + output area that run the configured experiments and draw the plots.
run_output = widgets.Output()
run_button = widgets.Button(description='Run Experiments', button_style='success')

def on_run_clicked(_):
    """Run all selected experiments, then render every selected plot.

    Output (prints and figures) is captured in ``run_output``, which is
    cleared first. The argument is the clicked button and is ignored.
    """
    run_output.clear_output()
    with run_output:
        config_dict = build_config_from_widgets()
        runs = run_comparison(config_dict)
        print('Finished runs:', [r['name'] for r in runs])

        for plot_name in config_dict['plots']:
            if plot_name == 'model_metric_bars':
                # Cross-run plot that also needs the metric names.
                PLOT_REGISTRY[plot_name](runs, config_dict['compare_metrics'])
            elif plot_name == 'eta_boxplot':
                # Cross-run plot over all runs at once.
                PLOT_REGISTRY[plot_name](runs)
            else:
                # Per-run plots: one figure for each run.
                for run in runs:
                    PLOT_REGISTRY[plot_name](run)

run_button.on_click(on_run_clicked)
display(run_button, run_output)
