# Model Selection and Parameter Tuning

This notebook demonstrates two key capabilities of the measurement framework:

1. **Model swappability** — Given the same data, switch between cross-sectional models by overriding a single config entry.
2. **Parameter sensitivity** — For a given model, investigate how tuning parameters affect the treatment effect estimate.

All three cross-sectional models share the same data-generation process: a single-day simulation with `product_detail_boost` enrichment applied to 50% of products.

## Workflow Overview

1. Generate a shared product catalog and define model overrides
2. For each model, override `MEASUREMENT` in the base config, write a temporary YAML file, and call `evaluate_impact()`
3. Compare treatment effect estimates against ground truth
4. Sweep tuning parameters for subclassification and nearest neighbour matching

## Initial Setup

In [None]:
import copy
from pathlib import Path

import pandas as pd
import yaml
from impact_engine_measure import evaluate_impact, load_results, parse_config_file
from impact_engine_measure.metrics import create_metrics_manager
from online_retail_simulator import simulate

## Step 1 — Shared Data

All models will use the same product catalog and enriched metrics.
Differences in estimates come only from the model, not the data.

In [None]:
output_path = Path("output/demo_model_selection")
output_path.mkdir(parents=True, exist_ok=True)

job_info = simulate("configs/demo_model_selection_catalog.yaml", job_id="catalog")
products = job_info.load_df("products")

print(f"Generated {len(products)} products")
products.head()

### Ground Truth

Compute the true treatment effect on the treated (ATT) by comparing each treated product's revenue in the enriched simulation with its revenue in the baseline (counterfactual) simulation.
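Because the baseline run is the counterfactual for the very same products, the true ATT reduces to an average of per-product differences over the treated set. The stdlib sketch below uses small hypothetical dictionaries (product ID to mean revenue) in place of the simulator's DataFrames; `true_att` is an illustrative name, not part of `impact_engine_measure`.

```python
from statistics import mean


def true_att(enriched: dict[str, float], baseline: dict[str, float], treated: set[str]) -> float:
    """Average per-product revenue difference over treated products.

    `enriched` and `baseline` map product_id -> mean revenue for that product
    under each scenario; both must contain every treated product.
    """
    return mean(enriched[pid] - baseline[pid] for pid in treated)


# Hypothetical toy values, not simulator output.
enriched = {"p1": 12.0, "p2": 8.0, "p3": 5.0}
baseline = {"p1": 10.0, "p2": 7.0, "p3": 5.0}
treated = {"p1", "p2"}
print(true_att(enriched, baseline, treated))  # 1.5
```

Averaging the per-product differences gives the same number as differencing the two per-product means, which is what `calculate_true_effect` computes, provided both sides cover the same treated products.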

In [None]:
config_path = "configs/demo_model_selection.yaml"
baseline_config_path = "configs/demo_model_selection_baseline.yaml"

parsed_enriched = parse_config_file(config_path)
enriched_manager = create_metrics_manager(parsed_enriched)
enriched_metrics = enriched_manager.retrieve_metrics(products)

parsed_baseline = parse_config_file(baseline_config_path)
baseline_manager = create_metrics_manager(parsed_baseline)
baseline_metrics = baseline_manager.retrieve_metrics(products)

print(f"Enriched records: {len(enriched_metrics)}")
print(f"Baseline records: {len(baseline_metrics)}")

In [None]:
def calculate_true_effect(
    baseline_metrics: pd.DataFrame,
    enriched_metrics: pd.DataFrame,
) -> dict:
    """Calculate TRUE ATT by comparing per-product revenue for treated products."""
    treated_ids = enriched_metrics[enriched_metrics["enriched"]]["product_id"].unique()

    enriched_treated = enriched_metrics[enriched_metrics["product_id"].isin(treated_ids)]
    baseline_treated = baseline_metrics[baseline_metrics["product_id"].isin(treated_ids)]

    enriched_mean = enriched_treated.groupby("product_id")["revenue"].mean().mean()
    baseline_mean = baseline_treated.groupby("product_id")["revenue"].mean().mean()
    treatment_effect = enriched_mean - baseline_mean

    return {
        "enriched_mean": float(enriched_mean),
        "baseline_mean": float(baseline_mean),
        "treatment_effect": float(treatment_effect),
    }


true_effect = calculate_true_effect(baseline_metrics, enriched_metrics)
true_te = true_effect["treatment_effect"]

print("=" * 60)
print("TRUE TREATMENT EFFECT (GROUND TRUTH)")
print("=" * 60)
print(f"Enriched mean revenue:  {true_effect['enriched_mean']:.4f}")
print(f"Baseline mean revenue:  {true_effect['baseline_mean']:.4f}")
print(f"True treatment effect:  {true_te:.4f}")
print("=" * 60)

## Part 1: Model Swappability

We load one base config and override `MEASUREMENT` for each model.
Each iteration writes a temporary YAML and calls `evaluate_impact()`.
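`run_with_override` starts from `copy.deepcopy(base_config)`. Replacing the whole `MEASUREMENT` entry would be safe even with a shallow copy, but any in-place edit of a nested value (such as one key inside `PARAMS`) would leak back into `base_config` and into every later iteration. A toy dictionary, shaped like the notebook's config but with placeholder values, shows the difference:

```python
import copy

# Toy config shaped like the notebook's; the values are placeholders.
base = {"MEASUREMENT": {"MODEL": "subclassification", "PARAMS": {"n_strata": 5}}}

# Shallow copy: the nested MEASUREMENT dict is shared with `base`.
shallow = copy.copy(base)
shallow["MEASUREMENT"]["PARAMS"]["n_strata"] = 10
print(base["MEASUREMENT"]["PARAMS"]["n_strata"])  # 10, the base config was mutated

# Restore the base value, then deep copy: every nested level is duplicated.
base["MEASUREMENT"]["PARAMS"]["n_strata"] = 5
deep = copy.deepcopy(base)
deep["MEASUREMENT"]["PARAMS"]["n_strata"] = 20
print(base["MEASUREMENT"]["PARAMS"]["n_strata"])  # 5, base is untouched
```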

In [None]:
def run_with_override(base_config, measurement_override, storage_url, job_id):
    """Override MEASUREMENT in base config, write temp YAML, run evaluate_impact()."""
    config = copy.deepcopy(base_config)
    config["MEASUREMENT"] = measurement_override

    tmp_config_path = Path(storage_url) / f"config_{job_id}.yaml"
    tmp_config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(tmp_config_path, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    job_info = evaluate_impact(str(tmp_config_path), storage_url, job_id=job_id)
    result = load_results(job_info)
    return result.impact_results

In [None]:
base_config = parse_config_file(config_path)

model_overrides = {
    "Experiment (OLS)": {
        "MODEL": "experiment",
        "PARAMS": {"formula": "revenue ~ enriched + price"},
    },
    "Subclassification": {
        "MODEL": "subclassification",
        "PARAMS": {
            "treatment_column": "enriched",
            "covariate_columns": ["price"],
            "n_strata": 5,
            "estimand": "att",
            "dependent_variable": "revenue",
        },
    },
    "Nearest Neighbour Matching": {
        "MODEL": "nearest_neighbour_matching",
        "PARAMS": {
            "treatment_column": "enriched",
            "covariate_columns": ["price"],
            "dependent_variable": "revenue",
            "caliper": 0.2,
            "replace": True,
            "ratio": 1,
        },
    },
}


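def extract_te(results):
    """Extract the treatment effect from model results regardless of model type."""
    estimates = results["data"]["impact_estimates"]
    model_type = results["model_type"]
    if model_type == "experiment":
        # The OLS coefficient name depends on how the formula encodes `enriched`
        # (categorical level vs. numeric column). Fail loudly instead of silently
        # reporting a zero effect when neither name is present.
        params = estimates["params"]
        for key in ("enriched[T.True]", "enriched"):
            if key in params:
                return params[key]
        raise KeyError(f"No 'enriched' coefficient in OLS params: {list(params.keys())}")
    elif model_type == "nearest_neighbour_matching":
        return estimates["att"]
    else:
        return estimates["treatment_effect"]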

In [None]:
model_results = {}
model_estimates = {}

for name, measurement in model_overrides.items():
    job_id = measurement["MODEL"]
    results = run_with_override(base_config, measurement, str(output_path), job_id)
    model_results[name] = results
    model_estimates[name] = extract_te(results)
    print(f"{name}: treatment effect = {model_estimates[name]:.4f}")

print(f"\nTrue effect: {true_te:.4f}")

In [None]:
comparison = pd.DataFrame(
    [
        {
            "Model": name,
            "Estimate": est,
            "True Effect": true_te,
            "Absolute Error": abs(est - true_te),
            "Recovery (%)": max(0, (1 - abs(1 - est / true_te)) * 100) if true_te != 0 else 0,
        }
        for name, est in model_estimates.items()
    ]
)

print("=" * 80)
print("CROSS-SECTIONAL MODEL COMPARISON")
print("=" * 80)
print(comparison.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
print("=" * 80)

In [None]:
from notebook_support import plot_model_comparison

plot_model_comparison(
    model_names=list(model_estimates.keys()),
    estimates=list(model_estimates.values()),
    true_effect=true_te,
    ylabel="Treatment Effect",
    title="Cross-Sectional Models: Estimates vs Truth",
)

## Part 2: Parameter Sensitivity

For a given model and data, how sensitive is the treatment effect estimate to tuning parameters?
We use the same override pattern, varying one parameter at a time.
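The sweeps below spell out each `MEASUREMENT` dictionary inline. If the pattern grows, a small helper can generate the overrides from one base parameter set. `sweep_overrides` is a hypothetical name, and the sketch only builds dictionaries; it does not call the framework:

```python
import copy
from typing import Any, Iterable


def sweep_overrides(
    model: str, base_params: dict[str, Any], param: str, values: Iterable[Any]
) -> list[dict[str, Any]]:
    """Build one MEASUREMENT override per value, varying only `param` (hypothetical helper)."""
    overrides = []
    for value in values:
        # Deep copy so list-valued entries (e.g. covariate_columns) are not shared.
        params = copy.deepcopy(base_params)
        params[param] = value
        overrides.append({"MODEL": model, "PARAMS": params})
    return overrides


subclass_params = {
    "treatment_column": "enriched",
    "covariate_columns": ["price"],
    "n_strata": 5,
    "estimand": "att",
    "dependent_variable": "revenue",
}
overrides = sweep_overrides("subclassification", subclass_params, "n_strata", [2, 5, 10])
print([o["PARAMS"]["n_strata"] for o in overrides])  # [2, 5, 10]
print(subclass_params["n_strata"])  # 5
```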

### 2a. Subclassification: `n_strata`

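More strata means finer partitioning of the covariate space.
Finer strata leave smaller price differences between treated and control units within each stratum, which reduces bias, but they also make it more likely that some strata lack either treated or control units (no common support) and must be dropped.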
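To make the trade-off concrete, here is a minimal subclassification ATT estimator on a single covariate. It is a sketch under stated assumptions, not the framework's `subclassification` model: strata are equal-count bins of the raw covariate by rank (the library may stratify differently, for example on a propensity score), each usable stratum's treated-minus-control mean difference is weighted by its number of treated units, and strata missing either group are dropped. `subclassification_att` and the toy data are hypothetical, with revenue rising in price and a true effect of 2.

```python
from statistics import mean


def subclassification_att(
    x: list[float], treated: list[bool], y: list[float], n_strata: int
) -> tuple[float, int, int]:
    """Return (ATT, strata used, strata dropped) using equal-count strata on `x`."""
    # Assign units to strata by their rank on the covariate.
    order = sorted(range(len(x)), key=lambda i: x[i])
    strata: list[list[int]] = [[] for _ in range(n_strata)]
    for rank, i in enumerate(order):
        strata[rank * n_strata // len(x)].append(i)

    weighted_sum, n_treated_used, dropped = 0.0, 0, 0
    for members in strata:
        t = [y[i] for i in members if treated[i]]
        c = [y[i] for i in members if not treated[i]]
        if not t or not c:
            # No common support in this stratum: drop it.
            dropped += 1
            continue
        # Weight each stratum's difference by its treated count (ATT weighting).
        weighted_sum += len(t) * (mean(t) - mean(c))
        n_treated_used += len(t)
    if n_treated_used == 0:
        raise ValueError("no stratum contains both treated and control units")
    return weighted_sum / n_treated_used, n_strata - dropped, dropped


price = [1, 2, 3, 4, 5, 6, 7, 8]
treated = [False, True, False, False, True, False, True, True]
# Toy outcome: revenue rises with price, plus a true effect of 2 for treated units.
revenue = [p + (2 if t else 0) for p, t in zip(price, treated)]

for k in (2, 4):
    att, used, dropped = subclassification_att(price, treated, revenue, k)
    print(k, round(att, 4), used, dropped)
# 2 2.3333 2 0
# 4 2.0 2 2
```

With two strata, price differences inside each stratum still bias the estimate upward. With four, the estimate comes out at 2.0 on this toy data, but two strata are dropped, one of which holds treated units with no comparable control, so those units no longer contribute to the ATT.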

In [None]:
n_strata_values = [2, 3, 5, 10, 20, 50, 100]
subclass_estimates = []
strata_used = []
strata_dropped = []

for n in n_strata_values:
    measurement = {
        "MODEL": "subclassification",
        "PARAMS": {
            "treatment_column": "enriched",
            "covariate_columns": ["price"],
            "n_strata": n,
            "estimand": "att",
            "dependent_variable": "revenue",
        },
    }
    results = run_with_override(base_config, measurement, str(output_path), f"subclass_strata_{n}")
    estimates = results["data"]["impact_estimates"]

    subclass_estimates.append(estimates["treatment_effect"])
    strata_used.append(estimates["n_strata"])
    strata_dropped.append(estimates["n_strata_dropped"])

subclass_sensitivity = pd.DataFrame(
    {
        "n_strata (requested)": n_strata_values,
        "Strata Used": strata_used,
        "Strata Dropped": strata_dropped,
        "Treatment Effect": subclass_estimates,
        "Absolute Error": [abs(est - true_te) for est in subclass_estimates],
    }
)

print("Subclassification: n_strata Sensitivity")
print("-" * 70)
print(subclass_sensitivity.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

In [None]:
from notebook_support import plot_parameter_sensitivity

plot_parameter_sensitivity(
    param_values=n_strata_values,
    estimates=subclass_estimates,
    true_effect=true_te,
    xlabel="Number of Strata (n_strata)",
    ylabel="Treatment Effect",
    title="Subclassification: Sensitivity to n_strata",
)

### 2b. Nearest Neighbour Matching: `caliper`

The caliper controls the maximum allowed distance between a treated unit and its matched control.
Smaller values enforce closer matches but discard treated units that have no control within the caliper, while larger values retain more treated units at the cost of worse covariate balance.
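A minimal 1:1 nearest neighbour matching sketch with replacement illustrates the caliper trade-off. Assumptions: matching is on a single raw covariate and the caliper is measured in that covariate's units; the framework's `caliper` may use a different scale (for example, standard deviations of a propensity score), which this notebook does not state. `nn_match_att` and the toy data are hypothetical.

```python
from statistics import mean


def nn_match_att(
    x: list[float], treated: list[bool], y: list[float], caliper: float
) -> tuple[float, int]:
    """Return (ATT, number of matched treated units) for 1:1 matching with replacement."""
    controls = [i for i, t in enumerate(treated) if not t]
    diffs = []
    for i, t in enumerate(treated):
        if not t:
            continue
        # Nearest control on the covariate; with replacement, a control can be reused.
        j = min(controls, key=lambda c: abs(x[c] - x[i]))
        # Keep the pair only if it falls inside the caliper.
        if abs(x[j] - x[i]) <= caliper:
            diffs.append(y[i] - y[j])
    if not diffs:
        raise ValueError("no treated unit has a control within the caliper")
    return mean(diffs), len(diffs)


price = [1.0, 1.1, 2.0, 2.1, 3.0, 4.0, 5.0, 6.0]
treated = [False, True, True, False, False, True, False, True]
# Toy outcome: revenue rises with price, plus a true effect of 2.0 for treated units.
revenue = [p + (2.0 if t else 0.0) for p, t in zip(price, treated)]

for cal in (0.2, 1.0):
    att, n = nn_match_att(price, treated, revenue, cal)
    print(cal, round(att, 4), n)
# 0.2 2.0 2
# 1.0 2.5 4
```

With the tight caliper only the two close pairs survive and the estimate sits at the true 2.0; widening it admits two pairs whose prices differ by a full unit, which pulls the ATT up to 2.5. A caliper that excludes every treated unit raises `ValueError` here rather than returning a number.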

In [None]:
caliper_values = [0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
matching_estimates = []
n_matched_att_list = []

for cal in caliper_values:
    measurement = {
        "MODEL": "nearest_neighbour_matching",
        "PARAMS": {
            "treatment_column": "enriched",
            "covariate_columns": ["price"],
            "dependent_variable": "revenue",
            "caliper": cal,
            "replace": True,
            "ratio": 1,
        },
    }
    results = run_with_override(base_config, measurement, str(output_path), f"matching_caliper_{cal}")
    estimates = results["data"]["impact_estimates"]
    summary = results["data"]["model_summary"]

    matching_estimates.append(estimates["att"])
    n_matched_att_list.append(summary["n_matched_att"])

matching_sensitivity = pd.DataFrame(
    {
        "Caliper": caliper_values,
        "N Matched (ATT)": n_matched_att_list,
        "Treatment Effect (ATT)": matching_estimates,
        "Absolute Error": [abs(est - true_te) for est in matching_estimates],
    }
)

print("Nearest Neighbour Matching: Caliper Sensitivity")
print("-" * 70)
print(matching_sensitivity.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

In [None]:
plot_parameter_sensitivity(
    param_values=caliper_values,
    estimates=matching_estimates,
    true_effect=true_te,
    xlabel="Caliper",
    ylabel="Treatment Effect (ATT)",
    title="Nearest Neighbour Matching: Sensitivity to Caliper",
)

## Key Takeaways

**Model swappability.**
- All three models recover the true treatment effect from the same simulated data.
- Switching models requires only changing the `MEASUREMENT` entry in the config.
- The `evaluate_impact()` interface stays the same regardless of the model.

**Parameter sensitivity.**
- **Subclassification** is relatively stable across `n_strata` values. Very low values may leave residual price imbalance within strata, while very high values may drop strata that lack common support.
- **Nearest neighbour matching** is more sensitive to `caliper`. Very small calipers may discard too many units, while very large calipers degrade match quality.