Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions diff_diff/continuous_did_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
the dose-response curve estimation in ContinuousDiD.
"""

import warnings

import numpy as np
from scipy.interpolate import BSpline

Expand Down Expand Up @@ -140,9 +142,12 @@ def bspline_derivative_design_matrix(x, knots, degree, include_intercept=True):

# Check if knot vector is degenerate (all identical, e.g. single dose)
if knots[0] == knots[-1]:
# All knots identical: derivatives are all zero
# All knots identical: derivatives are all zero — this is a
# mathematically well-defined degenerate case (single dose value
# means no dose variation to differentiate), handled silently.
pass
else:
failed_basis_indices = []
for j in range(n_basis):
c = np.zeros(n_basis)
c[j] = 1.0
Expand All @@ -151,8 +156,29 @@ def bspline_derivative_design_matrix(x, knots, degree, include_intercept=True):
deriv_j = spline_j.derivative()
dB[:, j] = deriv_j(x_clamped)
except ValueError:
# Degenerate knot vector: derivative is zero
pass
# Finding #12 (axis C, silent-failures audit): silent pass
# on ValueError meant a malformed knot vector (too few
# knots for the degree, non-monotonic, etc.) quietly set
# whole columns of the derivative design matrix to zero.
# Downstream ContinuousDiD inference then used a silently
# biased dPsi matrix. Track affected basis indices so we
# can surface ONE aggregate warning.
failed_basis_indices.append(j)

if failed_basis_indices:
warnings.warn(
f"B-spline derivative construction failed for "
f"{len(failed_basis_indices)} of {n_basis} basis function(s) "
f"(indices {failed_basis_indices}); their derivative columns "
f"are zero. This typically indicates a malformed knot vector "
f"(too few knots for the chosen degree, non-monotonic, or "
f"repeated interior knots). Both ACRT point estimates and "
f"analytical/bootstrap inference depend on this derivative "
f"matrix, so both may be biased. Consider increasing the "
f"number of distinct doses or reducing the B-spline degree.",
UserWarning,
stacklevel=2,
)

if include_intercept:
# Drop first column (intercept derivative = 0), prepend zeros
Expand Down
27 changes: 17 additions & 10 deletions diff_diff/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,24 @@ def __post_init__(self) -> None:
)

def _build_survey_design(self) -> Any:
"""Return cached SurveyDesign (built once, reused across simulations)."""
if not hasattr(self, "_cached_survey_design"):
if self.survey_design is not None:
self._cached_survey_design = self.survey_design
else:
from diff_diff.survey import SurveyDesign
"""Return a SurveyDesign for this config.

Reflects the live ``self.survey_design`` value every call (no
caching). Finding #28 (axis J, silent-failures audit): the
previous ``_cached_survey_design`` was populated on first call
and never invalidated on mutation, so ``config.survey_design =
other_design`` silently kept returning the original. Since the
default ``SurveyDesign(...)`` construction is microseconds and
user-provided designs are just reference copies, there's no cache
cost worth keeping.
"""
if self.survey_design is not None:
return self.survey_design
from diff_diff.survey import SurveyDesign

self._cached_survey_design = SurveyDesign(
weights="weight", strata="stratum", psu="psu", fpc="fpc"
)
return self._cached_survey_design
return SurveyDesign(
weights="weight", strata="stratum", psu="psu", fpc="fpc"
)

@property
def min_viable_n(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ See `docs/methodology/continuous-did.md` Section 4 for full details.
not-yet-treated controls. When `anticipation=0` (default), behavior is
unchanged.
- **Boundary knots**: Knots are built once from all treated doses (global, not per-cell) to ensure a common basis across (g,t) cells for aggregation. Evaluation grid is clamped to training-dose boundary knots (`range(dose)`). R's `contdid` v0.1.0 has an inconsistency where `splines2::bSpline(dvals)` uses `range(dvals)` instead of `range(dose)`, which can produce extrapolation artifacts at dose grid extremes. Our approach avoids extrapolation and is methodologically sound.
- **Note:** `bspline_derivative_design_matrix` previously swallowed `ValueError` from `scipy.interpolate.BSpline` in the per-basis derivative loop, leaving affected columns of the derivative design matrix as zero with no user-facing signal. It now aggregates the failed basis indices and emits ONE `UserWarning` naming them. Both ACRT point estimates and analytical/bootstrap inference read the same `dPsi` matrix (see `continuous_did.py:1026-1046` and the bootstrap ACRT path at `continuous_did.py:1524-1561`), so both may be biased on a partial derivative-construction failure — the warning wording makes that explicit. The all-identical-knot degenerate case (single dose value) remains silently handled — derivatives there are mathematically zero. Axis-C finding #12 in the Phase 2 silent-failures audit.

### Implementation Checklist

Expand Down Expand Up @@ -2582,6 +2583,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE²
- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. When rounding occurs, all result objects (`SimulationPowerResults`, `SimulationMDEResults`, `SimulationSampleSizeResults`) set `effective_n_units` to the actual sample size used; it is `None` when no rounding occurred. `simulate_sample_size()` snaps bisection candidates to multiples of 8 so that `required_n` is always a realizable DDD sample size. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`).
- **Note:** The analytical power methods (`PowerAnalysis.power/mde/sample_size` and the `compute_power/compute_mde/compute_sample_size` convenience functions) accept a `deff` parameter (survey design effect, default 1.0). This inflates variance multiplicatively: `Var(ATT) *= deff`, and inflates required sample size: `n_total *= deff`. The `deff` parameter is **not redundant** with `rho` (intra-cluster correlation): `rho` models within-unit serial correlation in panel data via the Moulton factor `1 + (T-1)*rho`, while `deff` models the survey design effect from stratified multi-stage sampling (clustering + unequal weighting). A survey panel study may need both. Values `deff > 0` are accepted; `deff < 1.0` (net variance reduction, e.g., from stratification gain) emits a warning.
- **Note:** `simulate_power()` catches a narrow set of exception types — `ValueError`, `numpy.linalg.LinAlgError`, `KeyError`, `RuntimeError`, `ZeroDivisionError` — raised inside the per-simulation fit and result-extraction block, increments a per-effect failure counter, and skips the replicate. Programming errors (`TypeError`, `AttributeError`, `NameError`, `IndexError`, etc.) are allowed to propagate so that bugs in the estimator or custom result extractor surface loudly instead of being absorbed as simulation failures. The primary-effect failure count is surfaced on the result object as `SimulationPowerResults.n_simulation_failures`; a `UserWarning` still fires when the failure rate exceeds 10% for any effect size, and all-failed runs raise `RuntimeError`. This replaces the prior bare `except Exception` that swallowed root causes and kept the counter internal to the function (axis C — silent fallback — under the Phase 2 audit).
- **Note:** `SurveyPowerConfig._build_survey_design()` no longer caches its return value in `self._cached_survey_design`. Reassigning `config.survey_design` (either replacing a user-supplied `SurveyDesign` with another, or toggling between `None` and a user-supplied design) after the first call used to silently return the stale cached design; the method now returns the live `self.survey_design` (or the default construction when `None`) every call. Other config fields (`n_strata`, `icc`, `weight_variation`, etc.) never influenced the returned design, so the staleness surface was specifically `survey_design` reassignment. Construction is microseconds — the cache never earned its complexity. Axis-J finding #28 in the Phase 2 silent-failures audit.
- **Note:** The simulation-based power functions (`simulate_power/simulate_mde/simulate_sample_size`) accept a `survey_config` parameter (`SurveyPowerConfig` dataclass). When set, the simulation loop uses `generate_survey_did_data` instead of the default registry DGP, and automatically injects `SurveyDesign(weights="weight", strata="stratum", psu="psu", fpc="fpc")` into the estimator's `fit()` call. Supported estimators: DifferenceInDifferences, TwoWayFixedEffects, MultiPeriodDiD, CallawaySantAnna, SunAbraham, ImputationDiD, TwoStageDiD, StackedDiD, EfficientDiD. Unsupported (raises `ValueError`): TROP, SyntheticDiD, TripleDifference (generate_survey_did_data produces staggered cohort data incompatible with factor-model/DDD DGPs). `survey_config` and `data_generator` are mutually exclusive. `data_generator_kwargs` may not contain keys managed by `SurveyPowerConfig` (n_strata, psu_per_stratum, etc.) but may contain passthrough DGP params (unit_fe_sd, add_covariates, strata_sizes). Repeated cross-section survey power (`panel=False`) is only supported for `CallawaySantAnna(panel=False)` with a matching `data_generator_kwargs={"panel": False}`; both mismatch directions are rejected. `estimator_kwargs` may not contain `survey_design` when `survey_config` is set (use `SurveyPowerConfig(survey_design=...)` instead). Estimator settings that require a multi-cohort DGP (`control_group="not_yet_treated"`, `control_group="last_cohort"`, `clean_control="strict"`) are rejected because the survey DGP uses a single cohort; use the custom `data_generator` path for these configurations. `simulate_sample_size` raises the bisection floor to `n_strata * psu_per_stratum * 2` to ensure viable survey structure and rejects `strata_sizes` in `data_generator_kwargs` (it depends on `n_units` which varies during bisection).

**Reference implementation(s):**
Expand Down
126 changes: 126 additions & 0 deletions tests/test_continuous_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,132 @@ def test_linear_basis(self):
assert B.shape[1] == 2 # intercept + 1 basis fn


# ---------------------------------------------------------------------------
# Finding #12 (axis C, silent-failures audit). Previously
# `bspline_derivative_design_matrix` silently swallowed ValueError in the
# per-basis derivative loop, leaving affected columns of the derivative
# design matrix as zero with no user-visible signal. ContinuousDiD's
# analytical inference then fed a biased dPsi into downstream SE
# computation. The fix aggregates failed-basis indices and emits ONE
# UserWarning naming them.
# ---------------------------------------------------------------------------


class TestBSplineDerivativeDegenerateBasis:
    """Regression tests for the degenerate/failure paths of
    ``bspline_derivative_design_matrix``.

    Covers three contracts: (1) an all-identical knot vector stays silent
    (derivatives are mathematically zero), (2) a per-basis ``ValueError``
    from ``BSpline`` triggers exactly one aggregate ``UserWarning`` naming
    the failed basis indices while leaving unaffected columns intact, and
    (3) a well-formed knot vector emits no warning at all.
    """

    def test_single_dose_is_silent(self):
        """All-identical knots (single dose value) is a well-defined
        degenerate case — derivatives are mathematically zero and the
        function returns silently. Regression-guard the existing contract."""
        x = np.array([3.0, 3.0, 3.0, 3.0])
        knots = np.array([3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0])  # all identical
        import warnings as _w

        with _w.catch_warnings(record=True) as caught:
            _w.simplefilter("always")
            dB = bspline_derivative_design_matrix(x, knots, degree=3, include_intercept=True)
        # Filter on the message substring so unrelated warnings (e.g. from
        # numpy) cannot produce a false failure.
        deriv_warnings = [
            w for w in caught if "B-spline derivative construction failed" in str(w.message)
        ]
        assert deriv_warnings == [], (
            "All-identical knots should be handled silently (mathematically "
            "well-defined zero-derivative case); warning fired unexpectedly: "
            f"{[str(w.message) for w in deriv_warnings]}"
        )
        np.testing.assert_array_equal(dB, np.zeros_like(dB))

    def test_valueerror_from_bspline_emits_aggregate_warning(self):
        """When BSpline construction raises ValueError for some basis
        functions (malformed knot vector, etc.), the new aggregate
        UserWarning must fire naming the affected indices."""
        from unittest.mock import patch

        import diff_diff.continuous_did_bspline as bspline_mod

        dose = np.linspace(1, 5, 30)
        knots, deg = build_bspline_basis(dose, degree=3, num_knots=1)
        x = np.linspace(1.5, 4.5, 20)

        # Force ValueError on basis indices 1 and 3 only; the rest run
        # through normally. This is the partial-failure mode the audit
        # called out.
        real_bspline = bspline_mod.BSpline
        # NOTE(review): call_counter is diagnostic state only — nothing
        # asserts on it below.
        call_counter = {"n": 0}

        def flaky_bspline(knots, c, degree):
            # c is a one-hot vector; the index set to 1 is the basis j
            j = int(np.argmax(c))
            call_counter["n"] += 1
            if j in (1, 3):
                raise ValueError(f"forced test failure for basis j={j}")
            return real_bspline(knots, c, degree)

        import warnings as _w

        # patch.object with side_effect= installs a MagicMock whose calls
        # are delegated to flaky_bspline, preserving positional call order.
        with patch.object(bspline_mod, "BSpline", side_effect=flaky_bspline):
            with _w.catch_warnings(record=True) as caught:
                _w.simplefilter("always")
                dB = bspline_derivative_design_matrix(
                    x, knots, degree=deg, include_intercept=True
                )

        deriv_warnings = [
            w for w in caught if "B-spline derivative construction failed" in str(w.message)
        ]
        assert len(deriv_warnings) == 1, (
            f"Expected exactly one aggregate warning, got {len(deriv_warnings)}: "
            f"{[str(w.message) for w in deriv_warnings]}"
        )
        msg = str(deriv_warnings[0].message)
        # Message must name the failed basis indices so the user can debug.
        assert "[1, 3]" in msg, f"Expected indices [1, 3] in warning; got: {msg}"
        assert "2 of" in msg, f"Expected failure count '2 of ...' in warning; got: {msg}"
        # Affected columns should be zero.
        # With include_intercept=True, column 0 is always zero (intercept
        # derivative) and basis index j is at dB column j (the drop-first
        # then prepend-zeros logic keeps the same per-j mapping for j>=1).
        np.testing.assert_array_equal(dB[:, 1], np.zeros(len(x)))  # failed basis j=1
        np.testing.assert_array_equal(dB[:, 3], np.zeros(len(x)))  # failed basis j=3

        # Unaffected columns must match the un-patched baseline exactly
        # (except columns 1 and 3 which were forced to zero). This guards
        # a regression that would zero or corrupt the entire derivative
        # matrix on any ValueError. Computed outside the patch context so
        # the real BSpline is used.
        dB_baseline = bspline_derivative_design_matrix(
            x, knots, degree=deg, include_intercept=True
        )
        for col in range(dB.shape[1]):
            if col in (1, 3):
                continue
            np.testing.assert_array_equal(
                dB[:, col],
                dB_baseline[:, col],
                err_msg=f"Unaffected column {col} diverges from baseline",
            )
        # At least one non-intercept, non-failed column must be non-zero,
        # confirming the function still produces meaningful derivatives.
        non_failed_cols = [c for c in range(1, dB.shape[1]) if c not in (1, 3)]
        assert any(np.any(dB[:, c] != 0) for c in non_failed_cols), (
            "Expected at least one unaffected non-intercept column to have "
            "non-zero derivatives; got all-zero dB outside failed cols."
        )

    def test_clean_knots_emit_no_warning(self):
        """Well-formed knot vector → no ValueError path taken → no
        warning. Regression-guard the happy path."""
        dose = np.linspace(1, 5, 50)
        knots, deg = build_bspline_basis(dose, degree=3, num_knots=2)
        x = np.linspace(1.5, 4.5, 30)
        import warnings as _w

        with _w.catch_warnings(record=True) as caught:
            _w.simplefilter("always")
            bspline_derivative_design_matrix(x, knots, deg, include_intercept=True)
        deriv_warnings = [
            w for w in caught if "B-spline derivative construction failed" in str(w.message)
        ]
        assert deriv_warnings == []


class TestDoseGrid:
"""Test dose grid computation."""

Expand Down
Loading
Loading