# Hyperparameter Optimization for 2-Arm ANCOVA NPE Model

This notebook performs **Bayesian Optimization** (Optuna) to find an optimal neural network architecture for the 2-arm ANCOVA amortized posterior estimation model.

## Objectives
1. **Minimize calibration error** (mean absolute coverage error)
2. **Minimize parameter count** (model size)

This is a multi-objective optimization with Pareto-optimal solutions.

In [None]:
import os

# Default to the torch backend unless the caller already picked one.
# Keras reads KERAS_BACKEND when it is imported, so this must run before the
# `import keras` line further down in this cell.
if not os.environ.get("KERAS_BACKEND"):
    os.environ["KERAS_BACKEND"] = "torch"
    
from pathlib import Path
import importlib

import numpy as np
np.set_printoptions(suppress=True)  # print floats in fixed-point rather than scientific notation
# Seeded generator; later cells pass it to create_simulator and make_simulate_fn
RNG = np.random.default_rng(2025)

from itertools import product

import keras
import bayesflow as bf

import matplotlib.pyplot as plt

# Import generic infrastructure from the package
from rctbp_bf_training.core.infrastructure import (
    SummaryNetworkConfig,
    InferenceNetworkConfig,
    WorkflowConfig,
    params_dict_to_workflow_config,
    build_summary_network,
    build_inference_network,
)

# Import ANCOVA-specific functions
from rctbp_bf_training.models.ancova.model import (
    ANCOVAConfig,
    create_adapter,
    create_simulator,
    create_ancova_workflow_components,
    get_ancova_adapter_spec,
    create_validation_grid,
    make_simulate_fn,
    get_model_metadata,
    save_model_with_metadata,
    # Legacy imports for backwards compatibility
    build_networks_from_params,
    NetworkConfig,
)
from rctbp_bf_training.core.utils import MovingAverageEarlyStopping

# Instantiate the default ANCOVA configuration (new decoupled structure)
config = ANCOVAConfig()
print(f"Config loaded: {config.to_dict()}")

# Show the two independently configured networks
print("\nNew decoupled network configs:")
for _label, _net_cfg in (
    ("Summary", config.workflow.summary_network),
    ("Inference", config.workflow.inference_network),
):
    print(f"  {_label} network: {_net_cfg}")

## Model Definition

Create the simulator (prior, likelihood, and meta function) and the adapter using the package's factory functions, then run a quick smoke test on a batch of simulated draws.

In [None]:
# Build the simulator and the adapter through the package factories
simulator = create_simulator(config, RNG)
adapter = create_adapter()

# Smoke test: draw one batch and report what the simulator produced
sim_draws = simulator.sample(100)
print("Simulator + Adapter created via factory functions")
print(f"  sim keys: {[*sim_draws.keys()]}")
print(f"  N={sim_draws['N']}, p_alloc={sim_draws['p_alloc']:.2f}")

# Push the same draws through the adapter and show the resulting shapes
processed = adapter(sim_draws)
for _name in ("inference_variables", "summary_variables"):
    print(f"  {_name}: {processed[_name].shape}")

Simulator + Adapter created via factory functions
  sim keys: ['N', 'p_alloc', 'prior_df', 'prior_scale', 'b_covariate', 'b_group', 'outcome', 'covariate', 'group']
  N=979, p_alloc=0.65
  inference_variables: (100, 1)
  summary_variables: (100, 979, 3)


In [None]:
# Demonstrate the new decoupled API
from rctbp_bf_training.core.infrastructure import (
    InferenceNetworkConfig,
    SummaryNetworkConfig,
    WorkflowConfig,
)

# Step 1: each network gets its own, independently tuned configuration
summary_config = SummaryNetworkConfig(summary_dim=12, depth=4, width=96, dropout=0.1)
inference_config = InferenceNetworkConfig(depth=8, hidden_sizes=(256, 256), dropout=0.15)

print("✓ Decoupled network configs created:")
print(f"  Summary: {summary_config}")
print(f"  Inference: {inference_config}")

# Step 2: bundle both configs and let the ANCOVA factory build every component
workflow_config = WorkflowConfig(
    summary_network=summary_config,
    inference_network=inference_config,
)
demo_config = ANCOVAConfig(workflow=workflow_config)
summary_net, inference_net, demo_adapter = create_ancova_workflow_components(demo_config)

print("\n✓ Workflow components created:")
print(f"  Summary network output: {summary_net.summary_dim}")
print(f"  Adapter spec: {get_ancova_adapter_spec().set_keys}")

# The demo objects are not needed afterwards; drop them from the namespace
del summary_net, inference_net, demo_adapter, demo_config

## New Decoupled Architecture Demo

The refactored codebase decouples summary and inference networks for independent configuration and optimization.

## Validation Functions

Functions for simulating data and running validation during optimization.

In [None]:
# Validation helpers all come from the package; the module is reloaded first
from rctbp_bf_training.core import validation as functions_validation
importlib.reload(functions_validation)

from rctbp_bf_training.core.validation import (
    extract_calibration_metrics,
    make_bayesflow_infer_fn,
    run_validation_pipeline,
)

# Already in scope from the first code cell:
#   MovingAverageEarlyStopping  (rctbp_bf_training.core.utils)
#   make_simulate_fn            (rctbp_bf_training.models.ancova.model)

print(
    "Validation functions loaded:",
    "  - MovingAverageEarlyStopping from rctbp_bf_training.core.utils",
    "  - make_simulate_fn from rctbp_bf_training.models.ancova.model",
    "  - run_validation_pipeline from rctbp_bf_training.core.validation",
    sep="\n",
)

# Bayesian Optimization

Import BO infrastructure and define the optimization loop.

In [None]:
# Bayesian-optimization helpers come from the package; reload the module first
from rctbp_bf_training.core import optimization as bo
importlib.reload(bo)

from rctbp_bf_training.core.optimization import (
    HyperparameterSpace,
    cleanup_trial,
    create_study,
    extract_objective_values,
    get_param_count,
    plot_optimization_results,
    sample_hyperparameters,
    summarize_best_trials,
)

print("Bayesian optimization infrastructure loaded from package")
print(f"Optuna available: {bo.OPTUNA_AVAILABLE}")

## Search Space and Optimization Grid

Define the hyperparameter search space and a reduced validation grid for faster optimization.

In [None]:
# Hyperparameter search space (adjust the ranges as needed).
# Tuple entries are (low, high) ranges to optimize; scalar entries stay fixed.
search_space = HyperparameterSpace(
    # --- Training ---
    initial_lr=(1e-5, 5e-3),
    batch_size=(128, 384),
    # --- DeepSet summary network ---
    summary_dim=(4, 16),
    deepset_depth=(1, 4),
    deepset_width=(32, 128),
    deepset_dropout=(0.05, 0.5),
    # --- CouplingFlow inference network ---
    flow_depth=(2, 8),
    flow_hidden=(32, 128),
    flow_dropout=(0.05, 0.5),
    # --- Fixed (not optimized) ---
    decay_rate=0.85,
    patience=15,
    window=15,
)

# Reduced validation grid from the package factory
opt_conditions = create_validation_grid(extended=False)

print(f"Search space defined with {len(search_space.__dataclass_fields__)} parameters")
print(f"Optimization validation grid: {len(opt_conditions)} conditions")

Search space defined with 12 parameters
Optimization validation grid: 16 conditions


## Objective Function

The objective function builds a model with trial hyperparameters, trains it, validates on the reduced grid, and returns (calibration_error, param_count).

In [None]:
def objective(trial):
    """
    Optuna objective for one trial: build, train, and validate a model.

    Uses the decoupled API: the sampled hyperparameters are converted to a
    ``WorkflowConfig`` and the summary and inference networks are built
    independently from it.

    Reads the notebook-level globals ``search_space``, ``config``,
    ``simulator``, ``adapter``, ``opt_conditions`` and ``RNG``.

    Parameters
    ----------
    trial
        The Optuna trial handed in by ``study.optimize``; used to sample
        hyperparameters and to label log lines and checkpoints.

    Returns
    -------
    tuple
        ``(cal_error, param_count)``. If training raises, the penalty pair
        ``(1.0, 1e9)`` is returned. If only validation raises, ``cal_error``
        is set to ``1.0`` and the real parameter count is returned.
    """
    import gc

    # Sample hyperparameters and build the two networks independently
    params = sample_hyperparameters(trial, search_space)
    workflow_config = params_dict_to_workflow_config(params)
    summary_net = build_summary_network(workflow_config.summary_network)
    inference_net = build_inference_network(workflow_config.inference_network)
    # Equivalent alternative: ANCOVAConfig(workflow=workflow_config) passed to
    # create_ancova_workflow_components(), which also returns an adapter.

    training_cfg = config.workflow.training

    # ExponentialDecay counts optimizer steps (one per batch). To decay once
    # per epoch, decay_steps must be the number of batches per epoch. The
    # previous value, batch_size * 100, is far more steps than one epoch, so
    # with staircase=True the learning rate effectively never decayed.
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=params["initial_lr"],
        decay_steps=training_cfg.batches_per_epoch,
        decay_rate=params["decay_rate"],
        staircase=True,
    )
    opt = keras.optimizers.Adam(learning_rate=lr_schedule)

    # Create workflow
    wf = bf.BasicWorkflow(
        simulator=simulator,
        adapter=adapter,
        inference_network=inference_net,
        summary_network=summary_net,
        optimizer=opt,
        inference_conditions=["N", "p_alloc", "prior_df", "prior_scale"],
        checkpoint_name=f"optuna_trial_{trial.number}",
    )

    # Best-effort explicit compile: a failure is not fatal (as before), but it
    # is now reported instead of being swallowed silently.
    try:
        wf.approximator.compile(optimizer=opt)
    except Exception as e:
        print(f"Trial {trial.number}: explicit compile skipped ({e})")

    early_stop = MovingAverageEarlyStopping(
        window=params["window"],
        patience=params["patience"],
        restore_best_weights=True,
    )

    # Train; a failed trial gets a penalty value for both objectives
    try:
        wf.fit_online(
            epochs=training_cfg.epochs,
            batch_size=params["batch_size"],
            num_batches_per_epoch=training_cfg.batches_per_epoch,
            validation_data=training_cfg.validation_sims,
            callbacks=[early_stop],
        )
    except Exception as e:
        print(f"Trial {trial.number} FAILED: {e}")
        cleanup_trial()
        return 1.0, 1e9

    param_count = get_param_count(wf.approximator)

    # Validate on the reduced grid
    simulate_fn_opt = make_simulate_fn(rng=RNG)
    infer_fn_opt = make_bayesflow_infer_fn(
        wf.approximator,
        param_key="b_group",
        data_keys=["outcome", "covariate", "group"],
        context_keys={"N": int, "p_alloc": float, "prior_df": float, "prior_scale": float},
    )

    try:
        # NOTE(review): true_param_key is "b_arm_treat" while the approximator's
        # param_key is "b_group" - presumably make_simulate_fn exposes the true
        # effect under that name; confirm against the package.
        results = run_validation_pipeline(
            conditions_list=opt_conditions,
            n_sims=500,
            n_post_draws=500,
            simulate_fn=simulate_fn_opt,
            infer_fn=infer_fn_opt,
            true_param_key="b_arm_treat",
            verbose=False,
        )
        cal_error, _ = extract_objective_values(results["metrics"], param_count)
    except Exception as e:
        print(f"Trial {trial.number} validation FAILED: {e}")
        cal_error = 1.0

    print(f"Trial {trial.number}: cal_error={cal_error:.4f}, params={param_count:,}")

    # Release the trial's model before the next trial starts
    cleanup_trial()
    del wf, summary_net, inference_net
    gc.collect()

    return cal_error, param_count

# Functions already imported in cell-1
print("Objective function defined using NEW DECOUPLED API")
print("  - params_dict_to_workflow_config() converts hyperparameters")
print("  - build_summary_network() builds summary net independently")
print("  - build_inference_network() builds inference net independently")

## Run Optimization

Create and run the multi-objective Optuna study. Results are saved to SQLite for resumption.

In [None]:
# Pick up any edits to the optimization module before creating the study
importlib.reload(bo)
from rctbp_bf_training.core.optimization import create_study

# Two objectives, both minimized: [calibration_error, param_count].
# SQLite storage together with load_if_exists allows the study to be resumed.
study = create_study(
    load_if_exists=True,
    storage="sqlite:///optuna_ancova.db",
    directions=["minimize", "minimize"],
    study_name="ancova_npe_optimization",
)

print(f"Study created: {study.study_name}")
print(f"Existing trials: {len(study.trials)}")
print(f"Directions: {study.directions}")

In [8]:
# Trial budget: each trial takes ~5-10 minutes depending on architecture,
# so adjust N_TRIALS to the available compute
N_TRIALS = 30

study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True, gc_after_trial=True)

print("\nOptimization complete!")
print(f"Total trials: {len(study.trials)}")
print(f"Pareto-optimal trials: {len(study.best_trials)}")

  0%|          | 0/30 [00:00<?, ?it/s]INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.


Epoch 1/200
[1m 1/50[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m27s[0m 566ms/step - loss: 1.4189

Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:836.)
  value = float(value)


[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 197ms/step - loss: -0.2195 - val_loss: -1.4871 - moving_avg_val_loss: -1.4871
Epoch 2/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 324ms/step - loss: -0.6678 - val_loss: -1.8209 - moving_avg_val_loss: -1.6540
Epoch 3/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 293ms/step - loss: -1.9652 - val_loss: -2.9856 - moving_avg_val_loss: -2.0978
Epoch 4/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 301ms/step - loss: -2.9115 - val_loss: -3.2123 - moving_avg_val_loss: -2.3765
Epoch 5/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 213ms/step - loss: -2.4414 - val_loss: -2.8836 - moving_avg_val_loss: -2.4779
Epoch 6/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 347ms/step - loss: -2.5338 - val_loss: -2.8841 - moving_avg_val_loss: -2.5456
Epoch 7/200
[1m50/50[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 217ms

  0%|          | 0/30 [04:08<?, ?it/s]

[W 2025-12-21 14:54:37,269] Trial 101 failed with parameters: {'summary_dim': 6, 'deepset_width': 96, 'deepset_depth': 1, 'deepset_dropout': 0.07661209538663485, 'flow_depth': 8, 'flow_hidden': 64, 'flow_dropout': 0.07201038490553535, 'initial_lr': 0.0003746261155838961, 'batch_size': 256} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "c:\Users\Matze\anaconda3\envs\py311_sbi\Lib\site-packages\optuna\study\_optimize.py", line 205, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\Matze\AppData\Local\Temp\ipykernel_59168\3415015791.py", line 41, in objective
    history = wf.fit_online(
              ^^^^^^^^^^^^^^
  File "c:\Users\Matze\anaconda3\envs\py311_sbi\Lib\site-packages\bayesflow\workflows\basic_workflow.py", line 789, in fit_online
    return self._fit(
           ^^^^^^^^^^
  File "c:\Users\Matze\anaconda3\envs\py311_sbi\Lib\site-packages\bayesflow\workflows\basic_workflow.py", l

  0%|          | 0/30 [04:08<?, ?it/s]


KeyboardInterrupt: 

## Analyze Results

Visualize the Pareto front and extract the best configurations.

In [None]:
# Reload the optimization module, then draw the optimization results for the study
importlib.reload(bo)
from rctbp_bf_training.core.optimization import plot_optimization_results

fig = plot_optimization_results(study=study) if False else plot_optimization_results(study)
plt.show()

In [None]:
def _report_trial(title, row):
    """Print the trial number, calibration error and parameter count of one row."""
    print(f"\n{title} (trial {int(row['trial'])}):")
    print(f"   Cal error: {row['cal_error']:.4f}")
    print(f"   Params: {int(row['param_count']):,}")


# Tabulate the Pareto-front trials
best_configs = summarize_best_trials(study)
display(best_configs)

# Report the recommended configuration for each objective
_rule = "=" * 60
print("\n" + _rule)
print("RECOMMENDED CONFIGURATIONS")
print(_rule)

if len(best_configs) > 0:
    # Row 0 is taken as the best-calibrated trial.
    # NOTE(review): this assumes summarize_best_trials orders rows by cal_error - confirm.
    best_cal = best_configs.iloc[0]
    _report_trial("📊 Best Calibration", best_cal)

    # Smallest model, reported only when it is a different trial
    if "param_count" in best_configs.columns:
        best_size = best_configs.sort_values("param_count").iloc[0]
        if best_size["trial"] != best_cal["trial"]:
            _report_trial("📦 Smallest Model", best_size)

## Apply Best Configuration

Copy the best hyperparameters to the main configuration cells above to retrain with the optimal architecture.

In [None]:
# Select which Pareto-optimal trial to use
# Options: choose by index (0 = best calibration) or by trial number
SELECTED_TRIAL_IDX = 0  # Index in best_configs DataFrame

if len(best_configs) == 0:
    # Previously the cell finished silently here and `params` stayed undefined,
    # so the training cell further down failed with an unexplained NameError.
    print("No Pareto-optimal trials available - run the optimization before selecting a configuration.")
else:
    selected = best_configs.iloc[SELECTED_TRIAL_IDX]
    trial_num = int(selected["trial"])

    # Look up the full trial by number; fail with a readable message if it is missing
    trial = next((t for t in study.trials if t.number == trial_num), None)
    if trial is None:
        raise LookupError(f"Trial {trial_num} is listed in best_configs but not present in the study")

    # Copy so later edits to `params` cannot alter the trial object's own dict.
    # Note: trial.params holds only the sampled hyperparameters, not the fixed
    # search-space settings (decay_rate, patience, window).
    params = dict(trial.params)

    print(f"Selected Trial {trial_num}")
    print(f"Calibration Error: {selected['cal_error']:.4f}")
    print(f"Parameter Count: {int(selected['param_count']):,}")
    print("\nHyperparameters to use:")
    print("-" * 40)

    # Print as copy-pasteable configuration
    print(f"""
# Copy these values to the HYPERPARAMETERS cell above:
SUMMARY_DIM = {params['summary_dim']}
DEEPSET_DEPTH = {params['deepset_depth']}
DEEPSET_WIDTH = {params['deepset_width']}
DEEPSET_DROPOUT = {params['deepset_dropout']:.4f}

FLOW_DEPTH = {params['flow_depth']}
FLOW_HIDDEN = {params['flow_hidden']}
FLOW_DROPOUT = {params['flow_dropout']:.4f}

INITIAL_LR = {params['initial_lr']:.6f}
BATCH_SIZE = {params['batch_size']}
""")

## Train Until Threshold

Train the best configuration repeatedly until it meets the calibration error threshold or the maximum number of attempts is reached.

In [None]:
# Settings for the train-until-threshold loop
CAL_ERROR_THRESHOLD, MAX_ATTEMPTS = 0.05, 40
FULL_EPOCHS, FULL_BATCHES = 50, 100

# Extended validation grid from the package factory
final_conditions = create_validation_grid(extended=True)

print(f"Threshold: {CAL_ERROR_THRESHOLD}")
print(f"Max attempts: {MAX_ATTEMPTS}")
print(f"Final validation grid: {len(final_conditions)} conditions")

In [None]:
def train_until_threshold(params, threshold, max_attempts, epochs=50, batches_per_epoch=100):
    """
    Retrain one hyperparameter configuration until it is calibrated well enough.

    Each attempt builds fresh networks from ``params`` with the decoupled API,
    trains them online, validates on ``final_conditions`` and stops as soon as
    the calibration error is at or below ``threshold``.

    Reads the notebook-level globals ``search_space``, ``config``,
    ``simulator``, ``adapter``, ``final_conditions`` and ``RNG``.

    Parameters
    ----------
    params : dict
        Hyperparameters. Must contain ``initial_lr`` and ``batch_size`` plus
        whatever ``params_dict_to_workflow_config`` requires. ``decay_rate``,
        ``window`` and ``patience`` are optional: Optuna's ``trial.params``
        holds only sampled values, so when they are missing the fixed settings
        from ``search_space`` are used, matching the optimization trials.
    threshold : float
        Largest acceptable calibration error.
    max_attempts : int
        Number of training attempts before giving up.
    epochs : int
        Epochs per attempt.
    batches_per_epoch : int
        Batches (optimizer steps) per epoch.

    Returns
    -------
    tuple
        ``(workflow, cal_error, attempt, workflow_config)`` for the first
        attempt that meets the threshold, otherwise ``(None, None, None, None)``.
    """
    import gc

    # Resolve the fixed settings once. The old hard-coded fallbacks (10 / 10)
    # disagreed with the window and patience used while optimizing.
    decay_rate = params.get("decay_rate", getattr(search_space, "decay_rate", 0.85))
    es_window = params.get("window", getattr(search_space, "window", 10))
    es_patience = params.get("patience", getattr(search_space, "patience", 10))

    for attempt in range(1, max_attempts + 1):
        print(f"\n{'='*60}\nATTEMPT {attempt}/{max_attempts}\n{'='*60}")

        # Convert params to a WorkflowConfig and build the decoupled networks
        workflow_config = params_dict_to_workflow_config(params)
        summary_net = build_summary_network(workflow_config.summary_network)
        inference_net = build_inference_network(workflow_config.inference_network)

        print(f"Networks built:")
        print(f"  Summary: dim={workflow_config.summary_network.summary_dim}, "
              f"depth={workflow_config.summary_network.depth}, "
              f"width={workflow_config.summary_network.width}")
        print(f"  Inference: depth={workflow_config.inference_network.depth}, "
              f"hidden={workflow_config.inference_network.hidden_sizes}")

        # ExponentialDecay counts optimizer steps (one per batch), so a
        # once-per-epoch decay needs decay_steps = batches_per_epoch. The
        # previous batch_size * batches_per_epoch delayed the first decay by a
        # factor of batch_size.
        lr_schedule = keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=params["initial_lr"],
            decay_steps=batches_per_epoch,
            decay_rate=decay_rate,
            staircase=True,
        )
        opt = keras.optimizers.Adam(learning_rate=lr_schedule)

        # Create workflow
        wf = bf.BasicWorkflow(
            simulator=simulator,
            adapter=adapter,
            inference_network=inference_net,
            summary_network=summary_net,
            optimizer=opt,
            inference_conditions=["N", "p_alloc", "prior_df", "prior_scale"],
            checkpoint_name=f"best_model_attempt_{attempt}",
        )

        early_stop = MovingAverageEarlyStopping(
            window=es_window,
            patience=es_patience,
            restore_best_weights=True,
        )

        # cal_error stays None when training or validation fails, which routes
        # the attempt to the shared cleanup at the bottom of the loop.
        cal_error = None
        try:
            wf.fit_online(
                epochs=epochs,
                batch_size=params["batch_size"],
                num_batches_per_epoch=batches_per_epoch,
                validation_data=config.workflow.training.validation_sims,
                callbacks=[early_stop],
            )
        except Exception as e:
            print(f"Training failed: {e}")
        else:
            # Validate using the imported factory functions
            simulate_fn = make_simulate_fn(rng=RNG)
            infer_fn = make_bayesflow_infer_fn(
                wf.approximator,
                param_key="b_group",
                data_keys=["outcome", "covariate", "group"],
                context_keys={"N": int, "p_alloc": float, "prior_df": float, "prior_scale": float},
            )

            try:
                # NOTE(review): true_param_key is "b_arm_treat" while param_key
                # is "b_group" - presumably make_simulate_fn exposes the true
                # effect under that name; confirm against the package.
                results = run_validation_pipeline(
                    conditions_list=final_conditions,
                    n_sims=1000,
                    n_post_draws=1000,
                    simulate_fn=simulate_fn,
                    infer_fn=infer_fn,
                    true_param_key="b_arm_treat",
                    verbose=True,
                )
                param_count = get_param_count(wf.approximator)
                cal_error, _ = extract_objective_values(results["metrics"], param_count)
            except Exception as e:
                print(f"Validation failed: {e}")

        if cal_error is not None:
            print(f"\nAttempt {attempt}: cal_error={cal_error:.4f}, threshold={threshold:.4f}")

            if cal_error <= threshold:
                print("✓ SUCCESS! Threshold met.")
                return wf, cal_error, attempt, workflow_config

        # Failed or insufficiently calibrated attempt: free the model before retrying
        del wf, summary_net, inference_net
        gc.collect()

    return None, None, None, None

print("train_until_threshold defined (using new decoupled API)")

In [None]:
# Retrain the selected configuration until the calibration threshold is met.
# `params` comes from the "Apply Best Configuration" cell.
(
    best_workflow,
    final_cal_error,
    successful_attempt,
    best_workflow_config,
) = train_until_threshold(
    params,
    CAL_ERROR_THRESHOLD,
    MAX_ATTEMPTS,
    epochs=FULL_EPOCHS,
    batches_per_epoch=FULL_BATCHES,
)

if best_workflow is not None:
    _rule = "=" * 60
    print(f"\n{_rule}")
    print("FINAL MODEL READY")
    print(_rule)
    print(f"Achieved calibration error: {final_cal_error:.4f}")
    print(f"Successful on attempt: {successful_attempt}")
    print(f"Model parameters: {get_param_count(best_workflow.approximator):,}")
    print("\nFinal network configurations:")
    print(f"  Summary: {best_workflow_config.summary_network}")
    print(f"  Inference: {best_workflow_config.inference_network}")

In [None]:
# Persist the trained model together with its metadata (new API)
if best_workflow is None:
    print("No model to save")
else:
    # Default prior/meta settings combined with the optimized workflow config
    config_with_optimized = ANCOVAConfig(
        prior=config.prior,
        meta=config.meta,
        workflow=best_workflow_config,
    )

    # Collect the validation summary, then build the metadata record
    _validation_summary = {
        "calibration_error": final_cal_error,
        "successful_attempt": successful_attempt,
        "param_count": get_param_count(best_workflow.approximator),
    }
    metadata = get_model_metadata(
        config=config_with_optimized,
        validation_results=_validation_summary,
    )

    # Write the approximator and metadata via the infrastructure's save function
    save_path = Path("checkpoints") / "ancova_cont_2arms_optimized"
    saved_path = save_model_with_metadata(best_workflow.approximator, save_path, metadata)

    print(f"✓ Model saved to: {saved_path}")
    print(f"✓ Metadata saved to: {saved_path.with_suffix('.json')}")
    print("\nSaved configuration:")
    print(f"  Summary network: dim={best_workflow_config.summary_network.summary_dim}, "
          f"depth={best_workflow_config.summary_network.depth}")
    print(f"  Inference network: depth={best_workflow_config.inference_network.depth}, "
          f"hidden={best_workflow_config.inference_network.hidden_sizes}")