diff_diff/trop.py (18 changes: 16 additions & 2 deletions)
@@ -37,7 +37,7 @@
_PrecomputedStructures,
TROPResults,
)
-from diff_diff.utils import safe_inference
+from diff_diff.utils import safe_inference, warn_if_not_converged


class TROP(TROPLocalMixin, TROPGlobalMixin):
@@ -748,6 +748,8 @@ def fit(

# Use pre-computed treated observations
treated_observations = self._precomputed["treated_observations"]
+        nonconverg_tracker: list = []
+        n_fits_attempted = 0

for t, i in treated_observations:
unit_id = idx_to_unit[i]
@@ -765,8 +767,10 @@
)

# Fit model with these weights
+            n_fits_attempted += 1
alpha_hat, beta_hat, L_hat = self._estimate_model(
-                Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
+                Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods,
+                _nonconvergence_tracker=nonconverg_tracker,
)

# Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
@@ -782,6 +786,16 @@
beta_estimates.append(beta_hat)
L_estimates.append(L_hat)

+        if nonconverg_tracker:
+            warn_if_not_converged(
+                False,
+                f"TROP local per-treated-observation fit: "
+                f"{len(nonconverg_tracker)} of {n_fits_attempted} "
+                f"fits did not converge",
+                self.max_iter,
+                self.tol,
+            )

# Count valid treated observations
n_valid_treated = len(tau_values)
if n_valid_treated == 0:
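The change threads a shared mutable list through the solvers: each callee appends an entry when a fit fails to converge instead of warning immediately, and the caller emits one aggregated warning after its loop. A minimal self-contained sketch of that shape, with a hypothetical `fit_one` and an assumed body for `warn_if_not_converged` inferred from its call sites (the real helper lives in `diff_diff.utils` and may differ):

```python
import warnings
from typing import List, Optional


def warn_if_not_converged(converged: bool, detail: str, max_iter: int, tol: float) -> None:
    # Assumed shape: one warning, annotated with the solver settings.
    if not converged:
        warnings.warn(f"{detail} (max_iter={max_iter}, tol={tol})", UserWarning)


def fit_one(x: float, tracker: Optional[List[int]] = None) -> float:
    # Hypothetical per-observation fit that "converges" only for x >= 0.
    converged = x >= 0
    if not converged and tracker is not None:
        tracker.append(1)  # record silently; the caller warns once at the end
    return max(x, 0.0)


tracker: List[int] = []
estimates = [fit_one(x, tracker) for x in [1.0, -2.0, 3.0, -4.0]]
if tracker:
    warn_if_not_converged(
        False, f"{len(tracker)} of {len(estimates)} fits did not converge", 100, 1e-6
    )
```

Passing the list as a private keyword argument keeps the public signatures stable while letting deeply nested solvers report upward without raising.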
diff_diff/trop_global.py (76 changes: 70 additions & 6 deletions)
@@ -26,7 +26,7 @@
)
from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
from diff_diff.trop_results import TROPResults
-from diff_diff.utils import safe_inference
+from diff_diff.utils import safe_inference, warn_if_not_converged


class TROPGlobalMixin:
@@ -156,6 +156,7 @@ def _solve_global_model(
Y: np.ndarray,
delta: np.ndarray,
lambda_nn: float,
+        _nonconvergence_tracker: Optional[List[int]] = None,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
"""
Dispatch to no-lowrank or with-lowrank solver based on lambda_nn.
@@ -168,7 +169,8 @@
L = np.zeros((n_periods, n_units))
else:
mu, alpha, beta, L = self._solve_global_with_lowrank(
-                Y, delta, lambda_nn, self.max_iter, self.tol
+                Y, delta, lambda_nn, self.max_iter, self.tol,
+                _nonconvergence_tracker=_nonconvergence_tracker,
)
return mu, alpha, beta, L
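For reference, the objective `_solve_global_model` dispatches on can be written as below; the notation is reconstructed from the code comments (δ are the control-observation weights), so read it as a sketch rather than the paper's exact formulation:

$$\min_{\mu,\,\alpha,\,\beta,\,L}\;\sum_{t,i}\delta_{ti}\bigl(Y_{ti}-\mu-\alpha_i-\beta_t-L_{ti}\bigr)^2\;+\;\lambda_{\mathrm{nn}}\lVert L\rVert_*$$

When `lambda_nn == 0` the nuclear-norm term vanishes, `L` is fixed at zero, and the problem reduces to weighted two-way fixed effects; otherwise the alternating-minimization solver below handles the penalized `L`.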

@@ -273,14 +275,18 @@ def _loocv_score_global(

tau_sq_sum = 0.0
n_valid = 0
+        nonconverg_tracker: List[int] = []

for t_ex, i_ex in control_obs:
# Create modified delta with excluded observation zeroed out
delta_ex = delta.copy()
delta_ex[t_ex, i_ex] = 0.0

try:
-                mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn)
+                mu, alpha, beta, L = self._solve_global_model(
+                    Y, delta_ex, lambda_nn,
+                    _nonconvergence_tracker=nonconverg_tracker,
+                )

# Pseudo treatment effect: tau = Y - mu - alpha - beta - L
if np.isfinite(Y[t_ex, i_ex]):
@@ -292,6 +298,16 @@
# Any failure means this lambda combination is invalid per Equation 5
return np.inf

+        if nonconverg_tracker:
+            warn_if_not_converged(
+                False,
+                f"TROP global LOOCV: {len(nonconverg_tracker)} of {len(control_obs)} "
+                f"per-observation fits did not converge "
+                f"(\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}))",
+                self.max_iter,
+                self.tol,
+            )

if n_valid == 0:
return np.inf
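The loop above scores a λ triple by zeroing one control observation's weight, re-fitting, and measuring the pseudo treatment effect at the held-out cell; any solver failure maps the whole combination to `np.inf`. A self-contained sketch of that scoring shape, with `fit` as a hypothetical stand-in for `_solve_global_model` and the score taken as the mean squared pseudo-effect:

```python
import numpy as np
from typing import Callable, List, Tuple


def loocv_score(
    Y: np.ndarray,
    delta: np.ndarray,
    control_obs: List[Tuple[int, int]],
    fit: Callable[[np.ndarray, np.ndarray], np.ndarray],  # returns fitted values
) -> float:
    tau_sq_sum, n_valid = 0.0, 0
    for t, i in control_obs:
        delta_ex = delta.copy()
        delta_ex[t, i] = 0.0              # leave this observation out of the fit
        try:
            Y_hat = fit(Y, delta_ex)
        except Exception:
            return np.inf                 # any failure invalidates this lambda
        if np.isfinite(Y[t, i]):
            tau = Y[t, i] - Y_hat[t, i]   # pseudo treatment effect
            tau_sq_sum += tau**2
            n_valid += 1
    return tau_sq_sum / n_valid if n_valid else np.inf


# Toy usage with a weighted-mean "model"
rng = np.random.default_rng(0)
Y = rng.normal(size=(4, 3))
delta = np.ones_like(Y)
score = loocv_score(
    Y, delta, [(0, 0), (1, 2)],
    fit=lambda Y, d: np.full_like(Y, (d * Y).sum() / d.sum()),
)
```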

@@ -395,6 +411,7 @@ def _solve_global_with_lowrank(
lambda_nn: float,
max_iter: int = 100,
tol: float = 1e-6,
+        _nonconvergence_tracker: Optional[List[int]] = None,
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
"""
Solve TWFE + low-rank on control data via alternating minimization.
@@ -445,6 +462,9 @@
# Initialize L = 0
L = np.zeros((n_periods, n_units))

+        _FISTA_MAX_ITER = 20
+        inner_nonconverged_count = 0
+        outer_converged = False
for iteration in range(max_iter):
L_old = L.copy()

@@ -463,7 +483,8 @@
L_inner_prev = L_inner # share reference initially (no copy needed)
t_fista = 1.0

-            for _ in range(20):
+            inner_converged = False
+            for _ in range(_FISTA_MAX_ITER):
# FISTA momentum
t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
momentum = (t_fista - 1.0) / t_fista_new
@@ -479,14 +500,29 @@

# Convergence check (L_inner_prev holds the pre-SVD value)
if np.max(np.abs(L_inner - L_inner_prev)) < tol:
+                    inner_converged = True
break
+            if not inner_converged:
+                inner_nonconverged_count += 1

L = L_inner

# Outer convergence check
if np.max(np.abs(L - L_old)) < tol:
+                outer_converged = True
break

+        if not outer_converged:
+            if _nonconvergence_tracker is not None:
+                _nonconvergence_tracker.append(inner_nonconverged_count)
+            else:
+                detail = (
+                    f"TROP global alternating minimization "
+                    f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} "
+                    f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})"
+                )
+                warn_if_not_converged(False, detail, max_iter, tol)

# Final re-solve with converged L (match Rust behavior)
Y_adj = Y_safe - L
mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked)
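The inner loop is a standard FISTA proximal-gradient scheme: a gradient step on the smooth squared-error term, then singular-value soft-thresholding (the proximal operator of the nuclear norm), with the momentum scalar updated as t_new = (1 + sqrt(1 + 4*t^2)) / 2. A standalone sketch assuming a 0/1 observation mask and unit step size; the real `_soft_threshold_svd` and masking details live in `diff_diff.trop_local` and may differ:

```python
import numpy as np


def soft_threshold_svd(X: np.ndarray, tau: float) -> np.ndarray:
    # Prox of tau * ||.||_*: soft-threshold the singular values.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    return (U * np.maximum(s - tau, 0.0)) @ Vt


def fista_lowrank(R, W, lam, max_iter=20, tol=1e-6):
    """min_L 0.5 * ||W * (R - L)||_F^2 + lam * ||L||_*  (0/1 mask W, step 1)."""
    L = np.zeros_like(R)
    Z, t = L, 1.0                                       # momentum point and scalar
    converged = False
    for _ in range(max_iter):
        L_prev = L
        L = soft_threshold_svd(Z - W * (Z - R), lam)    # gradient step, then prox
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t**2)) / 2.0
        Z = L + ((t - 1.0) / t_new) * (L - L_prev)      # momentum extrapolation
        t = t_new
        if np.max(np.abs(L - L_prev)) < tol:
            converged = True
            break
    return L, converged


# Toy usage: recover a rank-1 matrix from ~70% observed entries
rng = np.random.default_rng(0)
truth = np.outer(rng.normal(size=8), rng.normal(size=6))
W = (rng.random(truth.shape) < 0.7).astype(float)
L_hat, ok = fista_lowrank(truth, W, lam=0.1)
```

Tracking `converged` flags at both loop levels, as the diff does, distinguishes "inner FISTA hit its iteration cap" from "outer alternation hit `max_iter`", which is exactly what the aggregated warning reports.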
@@ -984,6 +1020,7 @@ def _bootstrap_variance_global(
n_control_units = len(control_units)

bootstrap_estimates_list: List[float] = []
+        nonconverg_tracker: List[int] = []

for _ in range(self.n_bootstrap):
# Stratified sampling
Expand Down Expand Up @@ -1018,6 +1055,7 @@ def _bootstrap_variance_global(
optimal_lambda,
treated_periods,
survey_design=survey_design,
+                _nonconvergence_tracker=nonconverg_tracker,
)
if np.isfinite(tau):
bootstrap_estimates_list.append(tau)
@@ -1026,6 +1064,15 @@

bootstrap_estimates = np.array(bootstrap_estimates_list)

+        if nonconverg_tracker:
+            warn_if_not_converged(
+                False,
+                f"TROP global bootstrap: {len(nonconverg_tracker)} of "
+                f"{self.n_bootstrap} replicate fits did not converge",
+                self.max_iter,
+                self.tol,
+            )

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning
Expand Down Expand Up @@ -1169,6 +1216,7 @@ def _bootstrap_rao_wu_global(
)

bootstrap_estimates_list: List[float] = []
+        nonconverg_tracker: List[int] = []

for _ in range(self.n_bootstrap):
try:
@@ -1187,7 +1235,10 @@
delta = self._compute_global_weights(
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
)
-                mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
+                mu, alpha, beta, L = self._solve_global_model(
+                    Y, delta, lambda_nn,
+                    _nonconvergence_tracker=nonconverg_tracker,
+                )

# Extract weighted ATT using Rao-Wu rescaled weights
att, _, _ = self._extract_posthoc_tau(
@@ -1201,6 +1252,15 @@

bootstrap_estimates = np.array(bootstrap_estimates_list)

+        if nonconverg_tracker:
+            warn_if_not_converged(
+                False,
+                f"TROP global Rao-Wu bootstrap: {len(nonconverg_tracker)} of "
+                f"{self.n_bootstrap} replicate fits did not converge",
+                self.max_iter,
+                self.tol,
+            )

if len(bootstrap_estimates) < 10:
warnings.warn(
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
@@ -1222,6 +1282,7 @@ def _fit_global_with_fixed_lambda(
fixed_lambda: Tuple[float, float, float],
treated_periods: int,
survey_design=None,
+        _nonconvergence_tracker: Optional[List[int]] = None,
) -> float:
"""
Fit global model with fixed tuning parameters.
@@ -1263,7 +1324,10 @@
)

# Fit model on control data and extract post-hoc tau
-        mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
+        mu, alpha, beta, L = self._solve_global_model(
+            Y, delta, lambda_nn,
+            _nonconvergence_tracker=_nonconvergence_tracker,
+        )
att, _, _ = self._extract_posthoc_tau(
Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr
)