# Lalonde DR Analysis Workflow with Veldra

This notebook demonstrates a scenario-driven, doubly robust (DR) causal analysis of the Lalonde job-training program data.
Goal: estimate the treatment effect on 1978 earnings (`re78`), using the average treatment effect on the treated (**ATT**) as the default estimand.


## Prerequisites

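- On first run, the notebook downloads the Lalonde data (the MatchIt version hosted by Rdatasets) from a public URL, so network access is required.
- The normalized data is cached as Parquet under `examples/out/notebook_lalonde_dr/` and reused on subsequent runs; delete the cached file to force a fresh download.
- DR estimation goes through the public `veldra.api.estimate_dr` entry point (no direct calls into core modules).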


In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display


def _resolve_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that looks like the repo root.

    A directory qualifies when it contains both ``pyproject.toml`` and an
    ``examples`` directory. When no ancestor qualifies, the resolved ``start``
    is returned so the notebook still runs from an arbitrary working directory.
    """
    current = start.resolve()
    for base in (current, *current.parents):
        if (base / "pyproject.toml").exists() and (base / "examples").exists():
            return base
    return current


ROOT = _resolve_repo_root(Path.cwd())
# Put the repo root on sys.path *before* importing project modules: a sys.path
# entry added after an import cannot influence that import.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from veldra.api import estimate_dr  # noqa: E402  (must follow the sys.path setup above)

# All notebook outputs (cached data, analysis summary) live under this directory.
OUT_DIR = ROOT / "examples" / "out" / "notebook_lalonde_dr"
OUT_DIR.mkdir(parents=True, exist_ok=True)

CACHE_PATH = OUT_DIR / "lalonde_raw.parquet"  # normalized data; also the DR input file
SUMMARY_PATH = OUT_DIR / "lalonde_analysis_summary.json"

LALONDE_URL = "https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/MatchIt/lalonde.csv"
TARGET_COL = "re78"
TREATMENT_COL = "treat"

# Columns kept after normalization; this order is the column order of the cached frame.
REQUIRED_COLUMNS = [
    "treat",
    "re78",
    "age",
    "educ",
    "black",
    "hispan",
    "married",
    "nodegree",
    "re74",
    "re75",
]


In [None]:
def _normalize_lalonde(raw: pd.DataFrame) -> pd.DataFrame:
    """Normalize raw Lalonde data into the numeric schema used by this notebook.

    Accepts either a categorical ``race`` column (expanded into ``black`` /
    ``hispan`` 0/1 indicators) or pre-existing ``black`` / ``hispan`` columns.

    Args:
        raw: Raw Lalonde data. It is not modified.

    Returns:
        A new frame containing exactly ``REQUIRED_COLUMNS`` (in that order) with a
        fresh ``RangeIndex``. Rows whose core numeric columns are missing or
        non-numeric are dropped; missing race indicators are treated as 0.

    Raises:
        ValueError: if required source columns are absent, or if the treatment
            column is not coded 0/1 (``astype(int)`` would otherwise silently
            truncate values such as 0.5).
    """
    # Core columns that must parse as numbers; rows failing to parse are dropped.
    numeric_cols = ["treat", "married", "nodegree", "age", "educ", "re74", "re75", "re78"]
    frame = raw.copy()

    missing = [col for col in numeric_cols if col not in frame.columns]
    if "race" not in frame.columns:
        # Without `race`, the indicator columns must already be present.
        missing += [col for col in ("black", "hispan") if col not in frame.columns]
    if missing:
        raise ValueError(f"Raw Lalonde data is missing columns: {missing}")

    if "race" in frame.columns:
        race = frame["race"].astype(str).str.lower()
        frame["black"] = (race == "black").astype(int)
        frame["hispan"] = (race == "hispan").astype(int)

    for col in numeric_cols:
        frame[col] = pd.to_numeric(frame[col], errors="coerce")
    frame = frame.dropna(subset=numeric_cols)

    if not frame["treat"].isin([0, 1]).all():
        raise ValueError("Treatment column 'treat' must be coded 0/1.")

    for col in ("treat", "married", "nodegree"):
        frame[col] = frame[col].astype(int)
    for col in ("black", "hispan"):
        frame[col] = frame[col].fillna(0).astype(int)

    return frame[REQUIRED_COLUMNS].reset_index(drop=True)


# Download + normalize only when no local copy exists; otherwise reuse the cache.
if not CACHE_PATH.exists():
    raw_df = pd.read_csv(LALONDE_URL)
    lalonde_df = _normalize_lalonde(raw_df)
    lalonde_df.to_parquet(CACHE_PATH, index=False)
    cache_mode = "url"
else:
    lalonde_df = pd.read_parquet(CACHE_PATH)
    cache_mode = "cache"

print(f"cache_mode={cache_mode}")
print(f"cache_path={CACHE_PATH}")
display(lalonde_df.head())


In [None]:
# Guard the schema first: a stale cache file may not match REQUIRED_COLUMNS.
present_cols = set(lalonde_df.columns)
missing = [col for col in REQUIRED_COLUMNS if col not in present_cols]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

treat_series = lalonde_df[TREATMENT_COL]
print("n_rows:", len(lalonde_df))
print("treated:", int(treat_series.sum()))
print("control:", int(treat_series.eq(0).sum()))

# Per-arm mean/std of key covariates and the outcome (raw, unadjusted).
summary_cols = ["age", "educ", "re74", "re75", "re78"]
group_summary = lalonde_df.groupby(TREATMENT_COL)[summary_cols].agg(["mean", "std"])
display(group_summary)


In [None]:
RANDOM_SEED = 42

data_config = {
    "path": str(CACHE_PATH),
    "target": TARGET_COL,
    "id_cols": [],
    "drop_cols": [],
}
causal_config = {
    "method": "dr",
    "treatment_col": TREATMENT_COL,
    "estimand": "att",  # explicit default
    "propensity_calibration": "platt",  # explicit default
    "propensity_clip": 0.01,
    "cross_fit": True,
}
config = {
    "config_version": 1,
    "task": {"type": "regression"},
    "data": data_config,
    "split": {"type": "kfold", "n_splits": 5, "seed": RANDOM_SEED},
    "train": {"seed": RANDOM_SEED},
    "causal": causal_config,
    # NOTE(review): relative path, unlike the ROOT-anchored paths above; whether veldra
    # resolves it against the cwd or elsewhere is not visible here — confirm.
    "export": {"artifact_dir": "artifacts"},
}

result = estimate_dr(config)
print("run_id:", result.run_id)
print("estimate:", result.estimate)
print("95% CI:", result.ci_lower, result.ci_upper)
display(pd.DataFrame([result.metrics]))


In [None]:
# Artifacts written by estimate_dr, located via the run metadata.
summary_path = Path(result.metadata["summary_path"])
obs_path = Path(result.metadata["observation_path"])

summary = json.loads(summary_path.read_text(encoding="utf-8"))  # kept for interactive inspection
obs = pd.read_parquet(obs_path)

# Point estimates per method (None when a metric is absent), then the DR confidence bounds.
estimate_rows = [{"metric": name, "value": result.metrics.get(name)} for name in ("naive", "ipw", "dr")]
estimate_rows += [
    {"metric": "ci_lower", "value": result.ci_lower},
    {"metric": "ci_upper", "value": result.ci_upper},
]
estimate_table = pd.DataFrame(estimate_rows)

display(estimate_table)
display(obs.head())


In [None]:
plot_df = estimate_table.loc[estimate_table["metric"].isin(["naive", "ipw", "dr"])].copy()
bar_colors = ["#9ca3af", "#60a5fa", "#10b981"]  # naive, ipw, dr
ci_below = result.estimate - result.ci_lower
ci_above = result.ci_upper - result.estimate

fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(plot_df["metric"], plot_df["value"], color=bar_colors)
# "dr" is the third bar (x position 2); overlay the CI around result.estimate there.
ax.errorbar(
    x=[2],
    y=[result.estimate],
    yerr=[[ci_below], [ci_above]],
    fmt="o",
    color="black",
    capsize=5,
)
ax.set_title("Naive vs IPW vs DR (ATT)")
ax.set_ylabel("Estimated effect on re78")
ax.grid(axis="y", alpha=0.2)
plt.show()


In [None]:
# Propensity overlap by treatment arm for the two score columns in the observation
# table (presumably pre- vs post-calibration scores — confirm against veldra).
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
is_treated = obs["treatment"] == 1
is_control = obs["treatment"] == 0
for panel, col in zip(axes, ["e_raw", "e_hat"]):
    # One shared set of bin edges per panel: letting each group pick its own
    # `bins=30` gives the treated/control overlays different bin widths, which
    # makes the overlap comparison misleading.
    shared_edges = np.histogram_bin_edges(obs[col].dropna().to_numpy(dtype=float), bins=30)
    panel.hist(obs.loc[is_treated, col], bins=shared_edges, alpha=0.6, label="treated")
    panel.hist(obs.loc[is_control, col], bins=shared_edges, alpha=0.6, label="control")
    panel.set_title(f"{col} distribution")
    panel.set_xlabel(col)
    panel.grid(alpha=0.2)
axes[0].set_ylabel("count")
axes[1].legend()
plt.tight_layout()
plt.show()


In [None]:
def _weighted_mean(x: np.ndarray, w: np.ndarray) -> float:
    """Weighted average ``sum(x * w) / sum(w)``.

    Returns NaN instead of dividing when the weights do not sum to a positive
    number (e.g. an empty group or all-zero weights).
    """
    weight_total = float(np.sum(w))
    return float(np.nan) if weight_total <= 0 else float(np.sum(x * w) / weight_total)


def _weighted_var(x: np.ndarray, w: np.ndarray) -> float:
    """Weighted population variance ``sum(w * (x - mu)**2) / sum(w)``.

    ``mu`` is the weighted mean from ``_weighted_mean``. Returns NaN when that
    mean is NaN or when the weights do not sum to a positive number.
    """
    center = _weighted_mean(x, w)
    if np.isnan(center):
        return float(np.nan)
    weight_total = float(np.sum(w))
    return float(np.nan) if weight_total <= 0 else float(np.sum(w * (x - center) ** 2) / weight_total)


def _smd(x_t: np.ndarray, x_c: np.ndarray) -> float:
    """Unweighted standardized mean difference (treated minus control).

    The denominator is ``sqrt((var_t + var_c) / 2)`` using ``np.var`` with its
    default ``ddof=0``. Returns ``0.0`` when that pooled SD is zero or NaN.
    """
    pooled_sd = np.sqrt((np.var(x_t) + np.var(x_c)) / 2.0)
    if not pooled_sd > 0:
        return 0.0
    return float((np.mean(x_t) - np.mean(x_c)) / pooled_sd)


def _smd_weighted(x_t: np.ndarray, w_t: np.ndarray, x_c: np.ndarray, w_c: np.ndarray) -> float:
    """Weighted standardized mean difference (treated minus control).

    Means and variances come from ``_weighted_mean`` / ``_weighted_var`` using
    each arm's own weights; the denominator is ``sqrt((var_t + var_c) / 2)``.
    Returns ``0.0`` when that pooled SD is NaN or not positive.
    """
    mean_gap = _weighted_mean(x_t, w_t) - _weighted_mean(x_c, w_c)
    pooled_sd = np.sqrt((_weighted_var(x_t, w_t) + _weighted_var(x_c, w_c)) / 2.0)
    if np.isnan(pooled_sd) or pooled_sd <= 0:
        return 0.0
    return float(mean_gap / pooled_sd)


balance_cols = ["age", "educ", "black", "hispan", "married", "nodegree", "re74", "re75"]
analysis_df = lalonde_df.copy()

# The observation table's weights are attached *positionally*, so it must describe the
# same rows in the same order as lalonde_df. Fail loudly rather than silently computing
# balance with misaligned weights.
if len(obs) != len(analysis_df):
    raise ValueError(
        f"Observation table has {len(obs)} rows but lalonde_df has {len(analysis_df)}; "
        "cannot attach weights positionally."
    )
if "treatment" in obs.columns and not np.array_equal(
    obs["treatment"].to_numpy(dtype=float), analysis_df[TREATMENT_COL].to_numpy(dtype=float)
):
    raise ValueError("Observation table treatment column does not match lalonde_df row order.")
# NOTE(review): assumed to be ATT weights (per the column name and estimand) — confirm in veldra.
analysis_df["w_att"] = obs["weight"].to_numpy(dtype=float)

# Arm masks and weights are the same for every covariate, so compute them once.
t_mask = analysis_df[TREATMENT_COL] == 1
c_mask = analysis_df[TREATMENT_COL] == 0
w_t = analysis_df.loc[t_mask, "w_att"].to_numpy(dtype=float)
w_c = analysis_df.loc[c_mask, "w_att"].to_numpy(dtype=float)

records = []
for col in balance_cols:
    x_t = analysis_df.loc[t_mask, col].to_numpy(dtype=float)
    x_c = analysis_df.loc[c_mask, col].to_numpy(dtype=float)
    records.append(
        {
            "feature": col,
            "smd_unweighted": _smd(x_t, x_c),
            "smd_weighted": _smd_weighted(x_t, w_t, x_c, w_c),
        }
    )

# Largest raw imbalance first.
balance_df = pd.DataFrame(records).sort_values("smd_unweighted", key=np.abs, ascending=False)
display(balance_df)

fig, ax = plt.subplots(figsize=(8, 5))
y = np.arange(len(balance_df))
ax.scatter(balance_df["smd_unweighted"], y, label="unweighted", color="#ef4444")
ax.scatter(balance_df["smd_weighted"], y, label="weighted (ATT)", color="#2563eb")
ax.axvline(0.0, color="black", linewidth=1)
# Dashed lines mark the common |SMD| = 0.1 balance threshold.
ax.axvline(0.1, color="gray", linestyle="--", linewidth=1)
ax.axvline(-0.1, color="gray", linestyle="--", linewidth=1)
ax.set_yticks(y)
ax.set_yticklabels(balance_df["feature"])
ax.set_xlabel("Standardized Mean Difference")
ax.set_title("Covariate balance: before vs after ATT weighting")
ax.legend()
ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()


In [None]:
# Persist the key results and artifact locations for downstream use.
analysis_summary = {
    "run_id": result.run_id,
    "estimand": result.estimand,
    "estimate": result.estimate,
    "ci_lower": result.ci_lower,
    "ci_upper": result.ci_upper,
    "metrics": result.metrics,
    "summary_path": result.metadata.get("summary_path"),
    "observation_path": result.metadata.get("observation_path"),
    "cache_path": str(CACHE_PATH),
}

SUMMARY_PATH.write_text(json.dumps(analysis_summary, indent=2, sort_keys=True), encoding="utf-8")
# Same `name=value` format as the cache_mode/cache_path prints (no stray space after '=').
print(f"analysis_summary_path={SUMMARY_PATH}")


## Interpretation Notes

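- Comparing the `naive`, `ipw`, and `dr` estimates shows how much the confounding adjustment moves the estimated effect.
- If the weighted SMD remains large (e.g., |SMD| > 0.1, the dashed lines in the balance plot) for several covariates, residual bias from poor overlap or a misspecified model remains a risk.
- For production use, complement this analysis with sensitivity checks and cohort-level diagnostics.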
