Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions diff_diff/bootstrap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,53 @@
"compute_bootstrap_pvalue",
"compute_effect_bootstrap_stats",
"compute_effect_bootstrap_stats_batch",
"warn_bootstrap_failure_rate",
]


def warn_bootstrap_failure_rate(
    n_success: int,
    n_attempted: int,
    context: str,
    threshold: float = 0.05,
    stacklevel: int = 3,
) -> None:
    """Warn once when a replicate loop's failure rate crosses ``threshold``.

    Successor to the hard-coded ``< N successes`` check, which allowed
    high-failure runs (e.g. 11 successes out of 200 attempts) to pass
    without any signal. The degenerate ``n_attempted == 0`` case is left
    to the caller and emits nothing here; an all-failed run
    (``n_success == 0`` with ``n_attempted > 0``) does trigger the warning.

    Parameters
    ----------
    n_success : int
        Count of replicates that yielded a finite estimate.
    n_attempted : int
        Total replicates attempted (``self.n_bootstrap``).
    context : str
        Short caller label (e.g. ``"TROP global bootstrap"``).
    threshold : float, default=0.05
        Warn only when the failure rate exceeds this value. The default
        mirrors the SyntheticDiD bootstrap and placebo guards.
    stacklevel : int, default=3
        Forwarded to :func:`warnings.warn`.
    """
    # Nothing to report for an empty loop or a fully successful one.
    if n_attempted > 0 and n_success < n_attempted:
        # Same arithmetic as the comparison callers rely on: complement of
        # the success ratio, compared strictly against the threshold.
        failure_rate = 1.0 - (n_success / n_attempted)
        if failure_rate > threshold:
            warnings.warn(
                f"Only {n_success}/{n_attempted} bootstrap iterations succeeded in "
                f"{context} ({failure_rate:.1%} failure rate). "
                "Standard errors may be unreliable.",
                UserWarning,
                stacklevel=stacklevel,
            )


def generate_bootstrap_weights(
n_units: int,
weight_type: str,
Expand Down
60 changes: 36 additions & 24 deletions diff_diff/trop_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_rust_bootstrap_trop_variance_global,
_rust_loocv_grid_search_global,
)
from diff_diff.bootstrap_utils import warn_bootstrap_failure_rate
from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
from diff_diff.trop_results import TROPResults
from diff_diff.utils import safe_inference, warn_if_not_converged
Expand Down Expand Up @@ -169,7 +170,11 @@ def _solve_global_model(
L = np.zeros((n_periods, n_units))
else:
mu, alpha, beta, L = self._solve_global_with_lowrank(
Y, delta, lambda_nn, self.max_iter, self.tol,
Y,
delta,
lambda_nn,
self.max_iter,
self.tol,
_nonconvergence_tracker=_nonconvergence_tracker,
)
return mu, alpha, beta, L
Expand Down Expand Up @@ -284,7 +289,9 @@ def _loocv_score_global(

try:
mu, alpha, beta, L = self._solve_global_model(
Y, delta_ex, lambda_nn,
Y,
delta_ex,
lambda_nn,
_nonconvergence_tracker=nonconverg_tracker,
)

Expand Down Expand Up @@ -988,13 +995,13 @@ def _bootstrap_variance_global(
unit_weight_arr,
)

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
UserWarning,
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP global bootstrap (Rust)",
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])

return float(se), np.array(bootstrap_estimates)

Expand Down Expand Up @@ -1073,12 +1080,13 @@ def _bootstrap_variance_global(
self.tol,
)

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP global bootstrap",
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])

se = np.std(bootstrap_estimates, ddof=1)
return float(se), bootstrap_estimates
Expand Down Expand Up @@ -1236,7 +1244,9 @@ def _bootstrap_rao_wu_global(
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
)
mu, alpha, beta, L = self._solve_global_model(
Y, delta, lambda_nn,
Y,
delta,
lambda_nn,
_nonconvergence_tracker=nonconverg_tracker,
)

Expand All @@ -1261,13 +1271,13 @@ def _bootstrap_rao_wu_global(
self.tol,
)

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
UserWarning,
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP global Rao-Wu bootstrap",
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])

se = np.std(bootstrap_estimates, ddof=1)
return float(se), bootstrap_estimates
Expand Down Expand Up @@ -1325,7 +1335,9 @@ def _fit_global_with_fixed_lambda(

# Fit model on control data and extract post-hoc tau
mu, alpha, beta, L = self._solve_global_model(
Y, delta, lambda_nn,
Y,
delta,
lambda_nn,
_nonconvergence_tracker=_nonconvergence_tracker,
)
att, _, _ = self._extract_posthoc_tau(
Expand Down
57 changes: 32 additions & 25 deletions diff_diff/trop_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_rust_bootstrap_trop_variance,
_rust_unit_distance_matrix,
)
from diff_diff.bootstrap_utils import warn_bootstrap_failure_rate
from diff_diff.trop_results import _PrecomputedStructures
from diff_diff.utils import warn_if_not_converged

Expand Down Expand Up @@ -727,8 +728,10 @@ def _estimate_model(
_nonconvergence_tracker.append(1)
else:
warn_if_not_converged(
converged, "TROP local alternating minimization",
self.max_iter, self.tol,
converged,
"TROP local alternating minimization",
self.max_iter,
self.tol,
)

return alpha, beta, L
Expand Down Expand Up @@ -982,13 +985,14 @@ def _bootstrap_variance(
unit_weight_arr,
)

if len(bootstrap_estimates) >= 10:
if len(bootstrap_estimates) > 0:
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP local bootstrap (Rust)",
)
return float(se), bootstrap_estimates
# Fall through to Python if too few bootstrap samples
logger.debug(
"Rust bootstrap returned only %d samples, falling back to Python",
len(bootstrap_estimates),
)
logger.debug("Rust bootstrap returned 0 samples, falling back to Python")
except Exception as e:
logger.debug("Rust bootstrap variance failed, falling back to Python: %s", e)
warnings.warn(
Expand Down Expand Up @@ -1068,14 +1072,13 @@ def _bootstrap_variance(
self.tol,
)

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
"Standard errors may be unreliable.",
UserWarning,
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP local bootstrap",
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])

se = np.std(bootstrap_estimates, ddof=1)
return float(se), bootstrap_estimates
Expand Down Expand Up @@ -1236,14 +1239,13 @@ def _bootstrap_rao_wu_local(
self.tol,
)

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded. "
"Standard errors may be unreliable.",
UserWarning,
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])
warn_bootstrap_failure_rate(
n_success=len(bootstrap_estimates),
n_attempted=self.n_bootstrap,
context="TROP local Rao-Wu bootstrap",
)
if len(bootstrap_estimates) == 0:
return np.nan, np.array([])

se = np.std(bootstrap_estimates, ddof=1)
return float(se), bootstrap_estimates
Expand Down Expand Up @@ -1338,7 +1340,12 @@ def _fit_with_fixed_lambda(

# Fit model with these weights
alpha, beta, L = self._estimate_model(
Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods,
Y,
control_mask,
weight_matrix,
lambda_nn,
n_units,
n_periods,
_nonconvergence_tracker=_nonconvergence_tracker,
)

Expand Down
1 change: 1 addition & 0 deletions docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -1971,6 +1971,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
- Wrong D specification: if user provides event-style D (only first treatment period),
the absorbing-state validation will raise ValueError with helpful guidance
- **Bootstrap minimum**: `n_bootstrap` must be >= 2 (enforced via `ValueError`). TROP uses bootstrap for all variance estimation — there is no analytical SE formula.
- **Note:** TROP bootstrap loops (`_bootstrap_variance`, `_bootstrap_rao_wu`, and their global counterparts, including both Rust happy paths — local and global) emit a proportional `UserWarning` via `diff_diff.bootstrap_utils.warn_bootstrap_failure_rate` when the replicate failure rate exceeds 5%. The previous hard-coded `< 10 successes` threshold let high-failure runs (e.g. 11 of 200) pass silently; this was classified as a silent failure under the Phase 2 audit (axis D — degenerate-replicate handling). The 5% threshold matches the existing SyntheticDiD bootstrap and placebo guards. When zero replicates succeed, SE is set to `NaN` (unchanged). The local Rust path previously also used `len >= 10` as a Python-fallback trigger; it now accepts any non-zero Rust result and emits the proportional warning instead of path-switching silently.
- **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages
- **Inference CI distribution**: After `safe_inference()` migration, CI uses t-distribution (df = max(1, n_treated_obs - 1)), consistent with p_value. Previously CI used normal-distribution while p_value used t-distribution (inconsistent). This is a minor behavioral change; CIs may be slightly wider for small n_treated_obs.
- **Note:** Both the `local` alternating-minimization solver (`_estimate_model`) and the `global` alternating-minimization solver (`_solve_global_with_lowrank`, including its hard-coded inner FISTA loop of 20 iterations) emit `UserWarning` via `diff_diff.utils.warn_if_not_converged` when the outer loop exhausts `max_iter` without reaching `tol`. The global-method warning surfaces the inner-FISTA non-convergence count as diagnostic context. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the convention used across other iterative solvers in the library.
Expand Down
77 changes: 66 additions & 11 deletions tests/test_bootstrap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from diff_diff.bootstrap_utils import (
compute_effect_bootstrap_stats,
compute_effect_bootstrap_stats_batch,
warn_bootstrap_failure_rate,
)


Expand Down Expand Up @@ -76,9 +77,7 @@ def test_nonfinite_original_effect_with_finite_boot_dist(self, bad_value):
def test_bootstrap_stats_normal_case(self):
"""Normal case with varied values: all fields finite."""
boot_dist = np.arange(100.0)
se, ci, p_value = compute_effect_bootstrap_stats(
original_effect=50.0, boot_dist=boot_dist
)
se, ci, p_value = compute_effect_bootstrap_stats(original_effect=50.0, boot_dist=boot_dist)
assert np.isfinite(se)
assert se > 0
assert np.isfinite(ci[0])
Expand All @@ -102,9 +101,7 @@ def test_batch_warns_insufficient_valid_samples(self):

effects = np.array([1.0, 2.0, 3.0])
with pytest.warns(RuntimeWarning, match="too few valid"):
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(
effects, matrix
)
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(effects, matrix)
# Effect 1 (index 1) should be NaN
assert np.isnan(ses[1])
# Other effects should be finite
Expand All @@ -119,9 +116,7 @@ def test_batch_warns_zero_se(self):

effects = np.array([5.0, 5.0])
with pytest.warns(RuntimeWarning, match="non-finite or zero"):
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(
effects, matrix
)
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(effects, matrix)
assert np.isnan(ses[0])
assert np.isnan(ses[1])

Expand All @@ -135,6 +130,66 @@ def test_batch_no_warning_for_normal_case(self):

with warnings.catch_warnings():
warnings.simplefilter("error", RuntimeWarning)
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(
effects, matrix
ses, ci_lo, ci_hi, pvals = compute_effect_bootstrap_stats_batch(effects, matrix)


class TestWarnBootstrapFailureRate:
    """Proportional failure-rate guard for replicate loops (axis-D)."""

    def test_warns_above_threshold(self):
        """A 94.5% failure rate (11/200 successes) must produce a warning."""
        with pytest.warns(UserWarning, match=r"11/200 bootstrap iterations"):
            warn_bootstrap_failure_rate(
                n_success=11,
                n_attempted=200,
                context="test case",
            )

    def test_warning_message_includes_context(self):
        """The caller-supplied context label appears verbatim in the message."""
        with pytest.warns(UserWarning, match="TROP global bootstrap") as caught:
            warn_bootstrap_failure_rate(
                n_success=50, n_attempted=200, context="TROP global bootstrap"
            )
        assert len(caught) == 1
        assert "75.0% failure rate" in str(caught[0].message)

    def test_silent_below_threshold(self):
        """A 4% failure rate sits under the default 5% threshold: no warning."""
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            warn_bootstrap_failure_rate(
                n_success=960, n_attempted=1000, context="test case"
            )

    def test_silent_on_full_success(self):
        """Every replicate succeeding must stay silent."""
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            warn_bootstrap_failure_rate(
                n_success=200, n_attempted=200, context="test case"
            )

    def test_silent_when_n_attempted_zero(self):
        """The empty degenerate call must not warn or divide by zero."""
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            warn_bootstrap_failure_rate(
                n_success=0, n_attempted=0, context="test case"
            )

    def test_custom_threshold(self):
        """threshold=0.75 suppresses a 50% failure rate; threshold=0.25 does not."""
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            warn_bootstrap_failure_rate(
                n_success=100, n_attempted=200, context="test case", threshold=0.75
            )

        with pytest.warns(UserWarning, match="50.0% failure rate"):
            warn_bootstrap_failure_rate(
                n_success=100, n_attempted=200, context="test case", threshold=0.25
            )

    def test_all_failed_warns(self):
        """Zero successes out of N attempts fires the warning (caller returns NaN)."""
        with pytest.warns(UserWarning, match=r"0/50 bootstrap iterations"):
            warn_bootstrap_failure_rate(
                n_success=0, n_attempted=50, context="test case"
            )
Loading
Loading