# Synthetic Control

This notebook demonstrates **synthetic control** impact estimation via [pysyncon](https://github.com/sdfordham/pysyncon) [`Synth`](https://sdfordham.github.io/pysyncon/synth.html).

The synthetic control method constructs a weighted combination of control units as a counterfactual for the treated unit, then estimates the causal effect as the difference between observed and synthetic outcomes in the post-treatment period.
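As a minimal sketch of that arithmetic (not pysyncon's implementation), the snippet below assumes the donor weights are already known. It forms the synthetic outcome as the weighted sum of control outcomes and averages the post-treatment gap to get the ATT. The toy panel, the weights, and the helper name `synthetic_att` are made up for illustration.

```python
import numpy as np


def synthetic_att(y_treated, Y_controls, weights, n_pre):
    """Return (per-period gap, ATT) for a synthetic control with given weights.

    y_treated:  shape (T,) outcomes of the treated unit
    Y_controls: shape (T, J) outcomes of the J donor units
    weights:    shape (J,) non-negative donor weights summing to 1
    n_pre:      number of pre-treatment periods (the first n_pre rows)
    """
    synthetic = Y_controls @ weights        # weighted combination of donors
    gap = y_treated - synthetic             # observed minus synthetic
    return gap, float(gap[n_pre:].mean())   # ATT = mean post-treatment gap


# Toy data: 3 pre-treatment and 2 post-treatment periods, 3 donors.
Y_controls = np.array([
    [10.0, 20.0, 30.0],
    [11.0, 21.0, 29.0],
    [12.0, 22.0, 31.0],
    [13.0, 23.0, 30.0],
    [14.0, 24.0, 32.0],
])
y_treated = np.array([15.0, 16.0, 17.0, 21.0, 22.0])
weights = np.array([0.5, 0.5, 0.0])

gap, att = synthetic_att(y_treated, Y_controls, weights, n_pre=3)
print(gap)  # [0. 0. 0. 3. 3.]
print(att)  # 3.0
```

Here the weights reproduce the treated unit exactly before treatment (zero gap), so the post-treatment gap of 3 is read as the effect.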

## Workflow Overview

1. User provides `products.csv` (each product = a "unit" in the panel)
2. User configures `DATA.ENRICHMENT` for treatment assignment
3. User calls `evaluate_impact(config.yaml)`
4. Engine handles everything internally (adapter, enrichment, transform, model)

## Initial Setup

In [None]:
from pathlib import Path

import pandas as pd
from impact_engine_measure import evaluate_impact, load_results, parse_config_file
from impact_engine_measure.core import apply_transform
from impact_engine_measure.metrics import create_metrics_manager
from impact_engine_measure.models.factory import get_model_adapter
from online_retail_simulator import simulate

## Step 1 — Product Catalog

We use a small catalog (20 products) because synthetic control treats each product as a separate unit: the treated product is matched against a weighted combination of the other products, which form the donor pool.

In [None]:
output_path = Path("output/demo_synthetic_control")
output_path.mkdir(parents=True, exist_ok=True)

job_info = simulate("configs/demo_synthetic_control_catalog.yaml", job_id="catalog")
products = job_info.load_df("products")

print(f"Generated {len(products)} products")
print(f"Products catalog: {job_info.get_store().full_path('products.csv')}")
products.head()

## Step 2 — Engine Configuration

Configure the engine with the following sections:
- `ENRICHMENT` — Quality boost applied to ~10% of products starting Nov 15
- `TRANSFORM` — `prepare_for_synthetic_control` adds a time-aware `treatment` column
- `MODEL` — `synthetic_control` with one designated treated unit

The panel is 20 products × 30 days (Nov 1–30). Enrichment starts Nov 15, giving 14 pre-treatment periods (Nov 1–14) and 16 post-treatment periods (Nov 15–30).

With `seed=42` and `enrichment_fraction=0.1`, product `BU9XOOP3LG` is the treated unit.
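The period split and a time-aware treatment indicator can be sketched with pandas. This is a hypothetical reconstruction: it assumes the `treatment` column added by `prepare_for_synthetic_control` is 1 only for the treated unit on or after the enrichment start date, and it uses made-up product IDs (`P00`, `P01`, …) and a made-up helper `build_toy_panel` rather than anything from the simulator.

```python
import pandas as pd


def build_toy_panel(n_products, start, end, treated_unit, enrichment_start):
    """Long-format panel (one row per product-day) with an assumed treatment flag."""
    product_ids = [f"P{i:02d}" for i in range(n_products)]  # hypothetical IDs
    dates = pd.date_range(start, end, freq="D")
    panel = pd.MultiIndex.from_product(
        [product_ids, dates], names=["product_id", "date"]
    ).to_frame(index=False)
    # Assumption: treated only for the treated unit, from the start date onward.
    is_treated = (panel["product_id"] == treated_unit) & (
        panel["date"] >= pd.Timestamp(enrichment_start)
    )
    panel["treatment"] = is_treated.astype(int)
    return panel


toy_panel = build_toy_panel(20, "2024-11-01", "2024-11-30", "P00", "2024-11-15")
toy_dates = toy_panel["date"].drop_duplicates()
n_pre = int((toy_dates < pd.Timestamp("2024-11-15")).sum())
n_post = int((toy_dates >= pd.Timestamp("2024-11-15")).sum())
print(len(toy_panel), n_pre, n_post, toy_panel["treatment"].sum())  # 600 14 16 16
```

Naming the toy objects `toy_*` avoids overwriting the notebook's `products` DataFrame if the sketch is run in the same kernel.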

In [None]:
config_path = "configs/demo_synthetic_control.yaml"
baseline_config_path = "configs/demo_synthetic_control_baseline.yaml"

## Step 3 — Impact Evaluation

A single call to `evaluate_impact()` runs the whole pipeline:
- Engine creates `CatalogSimulatorAdapter`
- Adapter simulates daily metrics (30-day panel)
- Adapter applies enrichment (quality boost to treated product after Nov 15)
- `prepare_for_synthetic_control` transform adds the `treatment` column
- `SyntheticControlAdapter` builds a `Dataprep` object and fits via pysyncon's `Synth`
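To illustrate what fitting the weights involves, the sketch below solves a simplified version of the problem with SciPy: it finds non-negative donor weights summing to 1 that minimize the squared pre-treatment outcome error. This is a swapped-in technique, not pysyncon's `Synth`, which additionally optimizes a predictor-importance matrix V over the `Dataprep` predictors. The helper name `fit_simplex_weights` and the toy data are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


def fit_simplex_weights(y_pre, X_pre):
    """Donor weights w >= 0 with sum(w) == 1 minimizing ||y_pre - X_pre @ w||^2.

    y_pre: shape (T0,) pre-treatment outcomes of the treated unit
    X_pre: shape (T0, J) pre-treatment outcomes of the J donors
    """
    n_donors = X_pre.shape[1]
    w0 = np.full(n_donors, 1.0 / n_donors)  # start from equal weights

    def loss(w):
        resid = y_pre - X_pre @ w
        return float(resid @ resid)

    def grad(w):
        return -2.0 * X_pre.T @ (y_pre - X_pre @ w)

    result = minimize(
        loss,
        w0,
        jac=grad,
        method="SLSQP",
        bounds=[(0.0, 1.0)] * n_donors,  # non-negative weights
        constraints=[{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}],  # sum to 1
        options={"ftol": 1e-12, "maxiter": 500},
    )
    return result.x


# Toy check: the treated unit is exactly 0.6 * donor 0 + 0.4 * donor 1.
rng = np.random.default_rng(0)
X_pre = rng.normal(loc=50.0, scale=5.0, size=(14, 5))
true_w = np.array([0.6, 0.4, 0.0, 0.0, 0.0])
y_pre = X_pre @ true_w
w_hat = fit_simplex_weights(y_pre, X_pre)
print(np.round(w_hat, 3))
```

The simplex constraint is what keeps the counterfactual an interpolation of real donors rather than an arbitrary regression fit.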

In [None]:
job_info = evaluate_impact(config_path, str(output_path), job_id="results")
print(f"Job ID: {job_info.job_id}")

## Step 4 — Review Results

In [None]:
result = load_results(job_info)

data = result.impact_results["data"]
model_params = data["model_params"]
estimates = data["impact_estimates"]
summary = data["model_summary"]

print("=" * 60)
print("SYNTHETIC CONTROL RESULTS")
print("=" * 60)

print(f"\nModel Type: {result.model_type}")
print(f"Treated Unit: {model_params['treated_unit']}")
print(f"Treatment Time: {model_params['treatment_time']}")

print("\n--- Impact Estimates ---")
print(f"ATT:               {estimates['att']:.4f}")
print(f"Standard Error:    {estimates['se']:.4f}")
print(f"CI Lower:          {estimates['ci_lower']:.4f}")
print(f"CI Upper:          {estimates['ci_upper']:.4f}")
print(f"Cumulative Effect: {estimates['cumulative_effect']:.4f}")

print("\n--- Model Summary ---")
print(f"Pre-treatment periods:  {summary['n_pre_periods']}")
print(f"Post-treatment periods: {summary['n_post_periods']}")
print(f"Control units:          {summary['n_control_units']}")
print(f"MSPE:                   {summary['mspe']:.4f}")
print(f"MAE:                    {summary['mae']:.4f}")

print("\n--- Control Unit Weights ---")
for unit, weight in summary["weights"].items():
    if weight > 0.001:
        print(f"  {unit}: {weight:.4f}")

print("\n" + "=" * 60)
print("Demo Complete!")
print("=" * 60)

## Step 5 — Model Validation

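Compare the model's ATT estimate against the **true per-period revenue difference** for the treated unit: its mean post-treatment revenue in the enriched simulation minus its mean post-treatment revenue in the baseline (counterfactual) simulation.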
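Written out, with $\mathcal{T}_{\text{post}}$ the post-treatment dates (Nov 15–30) and $Y_{1t}$ the treated unit's revenue on day $t$, the true effect is the mean difference between the two simulation runs. Assuming both runs cover the same post-treatment dates, this equals the difference of means that `calculate_true_effect` computes:

$$
\tau_{\text{true}} = \frac{1}{|\mathcal{T}_{\text{post}}|} \sum_{t \in \mathcal{T}_{\text{post}}} \left( Y^{\text{enriched}}_{1t} - Y^{\text{baseline}}_{1t} \right) = \bar{Y}^{\text{enriched}}_{1,\text{post}} - \bar{Y}^{\text{baseline}}_{1,\text{post}}
$$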

In [None]:
def calculate_true_effect(
    baseline_metrics: pd.DataFrame,
    enriched_metrics: pd.DataFrame,
    treated_unit: str,
    treatment_time: str,
) -> dict:
    """Calculate TRUE per-period effect for the treated unit."""
    treatment_date = pd.Timestamp(treatment_time)

    baseline_unit = baseline_metrics[
        (baseline_metrics["product_id"] == treated_unit) & (pd.to_datetime(baseline_metrics["date"]) >= treatment_date)
    ]
    enriched_unit = enriched_metrics[
        (enriched_metrics["product_id"] == treated_unit) & (pd.to_datetime(enriched_metrics["date"]) >= treatment_date)
    ]

    baseline_mean = baseline_unit["revenue"].mean()
    enriched_mean = enriched_unit["revenue"].mean()
    mean_effect = enriched_mean - baseline_mean

    return {
        "baseline_mean": float(baseline_mean),
        "enriched_mean": float(enriched_mean),
        "mean_effect": float(mean_effect),
    }

In [None]:
parsed_baseline = parse_config_file(baseline_config_path)
baseline_manager = create_metrics_manager(parsed_baseline)
baseline_metrics = baseline_manager.retrieve_metrics(products)

parsed_enriched = parse_config_file(config_path)
enriched_manager = create_metrics_manager(parsed_enriched)
enriched_metrics = enriched_manager.retrieve_metrics(products)

print(f"Baseline records: {len(baseline_metrics)}")
print(f"Enriched records: {len(enriched_metrics)}")

In [None]:
treated_unit = model_params["treated_unit"]
true_effect = calculate_true_effect(baseline_metrics, enriched_metrics, treated_unit, "2024-11-15")

true_me = true_effect["mean_effect"]
model_me = estimates["att"]

if true_me != 0:
    recovery_accuracy = (1 - abs(1 - model_me / true_me)) * 100
else:
    recovery_accuracy = 100 if model_me == 0 else 0

print("=" * 60)
print("TRUTH RECOVERY VALIDATION")
print("=" * 60)
print(f"True mean effect:  {true_me:.4f}")
print(f"Model estimate:    {model_me:.4f}")
print(f"Recovery accuracy: {max(0, recovery_accuracy):.1f}%")
print("=" * 60)

### Convergence Analysis

How does the estimate change as the donor pool grows? We vary the number of control units (always keeping the same single treated unit) and compare each ATT estimate against the true effect.

In [None]:
control_sizes = [3, 5, 8, 12, 15, 18]
estimates_list = []
truth_list = []

parsed = parse_config_file(config_path)
measurement_params = parsed["MEASUREMENT"]["PARAMS"]

# enrichment_start is auto-injected during evaluate_impact() but not
# available when calling apply_transform directly — supply it explicitly.
transform_config = {
    "FUNCTION": "prepare_for_synthetic_control",
    "PARAMS": {"enrichment_start": "2024-11-15"},
}

# Determine which product is enriched in the retrieved metrics
# (enrichment assignment may differ from the evaluate_impact run).
enriched_ids = enriched_metrics[enriched_metrics["enriched"]]["product_id"].unique()
convergence_treated = enriched_ids[0]
all_ids = enriched_metrics["product_id"].unique()
# Keep every enriched product out of the donor pool so that no treated
# unit is used to build the synthetic control.
enriched_set = set(enriched_ids)
control_pool = [pid for pid in all_ids if pid not in enriched_set]

for n_controls in control_sizes:
    subset_ids = [convergence_treated] + control_pool[:n_controls]
    enriched_sub = enriched_metrics[enriched_metrics["product_id"].isin(subset_ids)]
    baseline_sub = baseline_metrics[baseline_metrics["product_id"].isin(subset_ids)]

    true_sub = calculate_true_effect(baseline_sub, enriched_sub, convergence_treated, "2024-11-15")
    truth_list.append(true_sub["mean_effect"])

    transformed = apply_transform(enriched_sub, transform_config)
    model = get_model_adapter("synthetic_control")
    model.connect(measurement_params)
    # Use a separate name so the Step 4 `result` is not overwritten.
    fit_result = model.fit(
        data=transformed,
        treatment_time="2024-11-15",
        treated_unit=convergence_treated,
        outcome_column="revenue",
        unit_column="product_id",
        time_column="date",
    )
    estimates_list.append(fit_result.data["impact_estimates"]["att"])

print("Convergence analysis complete.")

In [None]:
from notebook_support import plot_convergence

plot_convergence(
    control_sizes,
    estimates_list,
    truth_list,
    xlabel="Number of Control Units",
    ylabel="ATT",
    title="Synthetic Control: Convergence of Estimate to True Effect",
)