diff --git a/TODO.md b/TODO.md index 006aa0e..96e176a 100644 --- a/TODO.md +++ b/TODO.md @@ -61,6 +61,7 @@ Deferred items from PR reviews that were not addressed before merge. |-------|----------|----|----------| | Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low | | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low | +| CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low | --- diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index a74de9f..56a2052 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -116,7 +116,7 @@ def _detect_rank_deficiency( # Compute pivoted QR decomposition: X @ P = Q @ R # P is a permutation matrix, represented as pivot indices - Q, R, pivot = qr(X, mode='economic', pivoting=True) + Q, R, pivot = qr(X, mode="economic", pivoting=True) # Determine rank tolerance # R's qr() uses tol = 1e-07 by default, which is sqrt(eps) ≈ 1.49e-08 @@ -169,8 +169,7 @@ def _format_dropped_columns( return "" if column_names is not None: - names = [column_names[i] if i < len(column_names) else f"column {i}" - for i in dropped_cols] + names = [column_names[i] if i < len(column_names) else f"column {i}" for i in dropped_cols] if len(names) == 1: return f"'{names[0]}'" elif len(names) <= 5: @@ -251,10 +250,12 @@ def _solve_ols_rust( cluster_ids: Optional[np.ndarray] = None, return_vcov: bool = True, return_fitted: bool = False, -) -> Optional[Union[ - Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], - Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]], -]]: +) -> Optional[ + Union[ + Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], + Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]], + ] +]: """ Rust backend implementation of solve_ols for full-rank matrices. 
@@ -447,8 +448,7 @@ def solve_ols( raise ValueError(f"y must be 1-dimensional, got shape {y.shape}") if X.shape[0] != y.shape[0]: raise ValueError( - f"X and y must have same number of observations: " - f"{X.shape[0]} vs {y.shape[0]}" + f"X and y must have same number of observations: " f"{X.shape[0]} vs {y.shape[0]}" ) n, k = X.shape @@ -484,7 +484,8 @@ def solve_ols( if skip_rank_check: if HAS_RUST_BACKEND and _rust_solve_ols is not None: result = _solve_ols_rust( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -494,7 +495,8 @@ def solve_ols( # Fall through to NumPy on numerical instability # Fall through to Python without rank check (user guarantees full rank) return _solve_ols_numpy( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -521,7 +523,8 @@ def solve_ols( # - No Rust → Python backend (works for all cases) if HAS_RUST_BACKEND and _rust_solve_ols is not None and not is_rank_deficient: result = _solve_ols_rust( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -531,7 +534,8 @@ def solve_ols( # signaled us to fall back to Python backend if result is None: return _solve_ols_numpy( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -555,7 +559,8 @@ def solve_ols( # and SVD disagreed about rank. Python's QR will re-detect and # apply R-style NaN handling for dropped columns. 
return _solve_ols_numpy( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -569,7 +574,8 @@ def solve_ols( # Use NumPy implementation for rank-deficient cases (R-style NA handling) # or when Rust backend is not available return _solve_ols_numpy( - X, y, + X, + y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted, @@ -834,9 +840,7 @@ def _compute_robust_vcov_numpy( n_clusters = len(unique_clusters) if n_clusters < 2: - raise ValueError( - f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}" - ) + raise ValueError(f"Need at least 2 clusters for cluster-robust SEs, got {n_clusters}") # Small-sample adjustment adjustment = (n_clusters / (n_clusters - 1)) * ((n - 1) / (n - k)) @@ -871,6 +875,193 @@ def _compute_robust_vcov_numpy( return vcov +# Empirical threshold: coefficients above this magnitude suggest near-separation +# in the logistic model (predicted probabilities collapse to 0/1). +_LOGIT_SEPARATION_COEF_THRESHOLD = 10 +_LOGIT_SEPARATION_PROB_THRESHOLD = 1e-5 + + +def solve_logit( + X: np.ndarray, + y: np.ndarray, + max_iter: int = 25, + tol: float = 1e-8, + check_separation: bool = True, + rank_deficient_action: str = "warn", +) -> Tuple[np.ndarray, np.ndarray]: + """ + Fit logistic regression via IRLS (Fisher scoring). + + Matches R's ``glm(family=binomial)`` algorithm: iteratively reweighted + least squares with working weights ``mu*(1-mu)`` and working response + ``eta + (y-mu)/(mu*(1-mu))``. + + Parameters + ---------- + X : np.ndarray + Feature matrix (n_samples, n_features). Intercept added automatically. + y : np.ndarray + Binary outcome (0/1). + max_iter : int, default 25 + Maximum IRLS iterations (R's ``glm`` default). + tol : float, default 1e-8 + Convergence tolerance on coefficient change (R's ``glm`` default). + check_separation : bool, default True + Whether to check for near-separation and emit warnings. 
+ rank_deficient_action : str, default "warn" + How to handle rank-deficient design matrices: + - "warn": Emit warning and drop columns (default) + - "error": Raise ValueError + - "silent": Drop columns silently + + Returns + ------- + beta : np.ndarray + Fitted coefficients (including intercept as element 0). + probs : np.ndarray + Predicted probabilities. + """ + n, p = X.shape + X_with_intercept = np.column_stack([np.ones(n), X]) + k = p + 1 # number of parameters including intercept + + # Validate rank_deficient_action + valid_actions = {"warn", "error", "silent"} + if rank_deficient_action not in valid_actions: + raise ValueError( + f"rank_deficient_action must be one of {valid_actions}, " + f"got '{rank_deficient_action}'" + ) + + # Check rank deficiency once before iterating + rank_info = _detect_rank_deficiency(X_with_intercept) + rank, dropped_cols, _ = rank_info + if len(dropped_cols) > 0: + col_desc = _format_dropped_columns(dropped_cols) + if rank_deficient_action == "error": + raise ValueError( + f"Rank-deficient design matrix in logistic regression: " + f"dropping {col_desc}. Propensity score estimates may be unreliable." + ) + elif rank_deficient_action == "warn": + warnings.warn( + f"Rank-deficient design matrix in logistic regression: " + f"dropping {col_desc}. 
Propensity score estimates may be unreliable.", + UserWarning, + stacklevel=2, + ) + kept_cols = np.array([i for i in range(k) if i not in dropped_cols]) + X_solve = X_with_intercept[:, kept_cols] + else: + kept_cols = np.arange(k) + X_solve = X_with_intercept + + # IRLS (Fisher scoring) + beta_solve = np.zeros(X_solve.shape[1]) + converged = False + + for iteration in range(max_iter): + eta = X_solve @ beta_solve + # Clip to prevent overflow in exp + eta = np.clip(eta, -500, 500) + mu = 1.0 / (1.0 + np.exp(-eta)) + # Clip mu to prevent zero working weights + mu = np.clip(mu, 1e-10, 1 - 1e-10) + + # Working weights and working response + w = mu * (1.0 - mu) + z = eta + (y - mu) / w + + # Weighted least squares: solve (X'WX) beta = X'Wz + sqrt_w = np.sqrt(w) + Xw = X_solve * sqrt_w[:, None] + zw = z * sqrt_w + beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None) + + # Check convergence + if np.max(np.abs(beta_new - beta_solve)) < tol: + beta_solve = beta_new + converged = True + break + beta_solve = beta_new + + # Final predicted probabilities + eta_final = X_solve @ beta_solve + eta_final = np.clip(eta_final, -500, 500) + probs = 1.0 / (1.0 + np.exp(-eta_final)) + + # Warnings + if not converged: + warnings.warn( + f"Logistic regression did not converge in {max_iter} iterations. 
" + f"Propensity score estimates may be unreliable.", + UserWarning, + stacklevel=2, + ) + + if check_separation: + if np.max(np.abs(beta_solve)) > _LOGIT_SEPARATION_COEF_THRESHOLD: + warnings.warn( + "Large coefficients detected in propensity score model " + f"(max|beta| > {_LOGIT_SEPARATION_COEF_THRESHOLD}), " + "suggesting potential separation.", + UserWarning, + stacklevel=2, + ) + n_extreme = int( + np.sum( + (probs < _LOGIT_SEPARATION_PROB_THRESHOLD) + | (probs > 1 - _LOGIT_SEPARATION_PROB_THRESHOLD) + ) + ) + if n_extreme > 0: + warnings.warn( + f"Near-separation detected in propensity score model: " + f"{n_extreme} of {n} observations have predicted probabilities " + f"within {_LOGIT_SEPARATION_PROB_THRESHOLD} of 0 or 1. ATT estimates may be sensitive to " + f"model specification.", + UserWarning, + stacklevel=2, + ) + + # Expand beta back to full size if columns were dropped + if len(dropped_cols) > 0: + beta_full = np.zeros(k) + beta_full[kept_cols] = beta_solve + else: + beta_full = beta_solve + + return beta_full, probs + + +def _check_propensity_diagnostics( + pscore: np.ndarray, + trim_bound: float = 0.01, +) -> None: + """ + Warn if propensity scores are extreme. + + Parameters + ---------- + pscore : np.ndarray + Predicted probabilities. + trim_bound : float, default 0.01 + Trimming threshold. + """ + n_extreme = int(np.sum((pscore < trim_bound) | (pscore > 1 - trim_bound))) + if n_extreme > 0: + n_total = len(pscore) + pct = 100.0 * n_extreme / n_total + warnings.warn( + f"Propensity scores for {n_extreme} of {n_total} observations " + f"({pct:.1f}%) were outside [{trim_bound}, {1 - trim_bound}] " + f"and will be trimmed. 
This may indicate near-separation in " + f"the propensity score model.", + UserWarning, + stacklevel=2, + ) + + def compute_r_squared( y: np.ndarray, residuals: np.ndarray, @@ -1149,7 +1340,8 @@ def fit( if self.robust or effective_cluster_ids is not None: # Use solve_ols with robust/cluster SEs coefficients, residuals, fitted, vcov = solve_ols( - X, y, + X, + y, cluster_ids=effective_cluster_ids, return_fitted=True, return_vcov=compute_vcov, @@ -1158,7 +1350,8 @@ def fit( else: # Classical OLS - compute vcov separately coefficients, residuals, fitted, _ = solve_ols( - X, y, + X, + y, return_fitted=True, return_vcov=False, rank_deficient_action=self.rank_deficient_action, @@ -1294,6 +1487,7 @@ def get_inference( # Handle zero or negative SE (indicates perfect fit or numerical issues) if se <= 0: import warnings + warnings.warn( f"Standard error is zero or negative (se={se}) for coefficient at index {index}. " "This may indicate perfect multicollinearity or numerical issues.", @@ -1319,6 +1513,7 @@ def get_inference( # Warn if df is non-positive and fall back to normal distribution if effective_df is not None and effective_df <= 0: import warnings + warnings.warn( f"Degrees of freedom is non-positive (df={effective_df}). " "Using normal distribution instead of t-distribution for inference.", @@ -1396,10 +1591,7 @@ def get_all_inference( Inference results for each coefficient in order. 
""" self._check_fitted() - return [ - self.get_inference(i, alpha=alpha, df=df) - for i in range(len(self.coefficients_)) - ] + return [self.get_inference(i, alpha=alpha, df=df) for i in range(len(self.coefficients_))] def r_squared(self, adjusted: bool = False) -> float: """ @@ -1424,9 +1616,7 @@ def r_squared(self, adjusted: bool = False) -> float: self._check_fitted() # Use effective params for adjusted R² to match df correction n_params = self.n_params_effective_ if adjusted else self.n_params_ - return compute_r_squared( - self._y, self.residuals_, adjusted=adjusted, n_params=n_params - ) + return compute_r_squared(self._y, self.residuals_, adjusted=adjusted, n_params=n_params) def predict(self, X: np.ndarray) -> np.ndarray: """ diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 65d4865..6c4560e 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -11,9 +11,13 @@ import numpy as np import pandas as pd from scipy import linalg as scipy_linalg -from scipy import optimize - -from diff_diff.linalg import solve_ols, _detect_rank_deficiency, _format_dropped_columns +from diff_diff.linalg import ( + solve_ols, + solve_logit, + _check_propensity_diagnostics, + _detect_rank_deficiency, + _format_dropped_columns, +) from diff_diff.utils import safe_inference, safe_inference_batch # Import from split modules @@ -41,69 +45,6 @@ PrecomputedData = Dict[str, Any] -def _logistic_regression( - X: np.ndarray, - y: np.ndarray, - max_iter: int = 100, - tol: float = 1e-6, -) -> Tuple[np.ndarray, np.ndarray]: - """ - Fit logistic regression using scipy optimize. - - Parameters - ---------- - X : np.ndarray - Feature matrix (n_samples, n_features). Intercept added automatically. - y : np.ndarray - Binary outcome (0/1). - max_iter : int - Maximum iterations. - tol : float - Convergence tolerance. - - Returns - ------- - beta : np.ndarray - Fitted coefficients (including intercept). - probs : np.ndarray - Predicted probabilities. 
- """ - n, p = X.shape - # Add intercept - X_with_intercept = np.column_stack([np.ones(n), X]) - - def neg_log_likelihood(beta: np.ndarray) -> float: - z = np.dot(X_with_intercept, beta) - # Clip to prevent overflow - z = np.clip(z, -500, 500) - log_lik = np.sum(y * z - np.log(1 + np.exp(z))) - return -log_lik - - def gradient(beta: np.ndarray) -> np.ndarray: - z = np.dot(X_with_intercept, beta) - z = np.clip(z, -500, 500) - probs = 1 / (1 + np.exp(-z)) - return -np.dot(X_with_intercept.T, y - probs) - - # Initialize with zeros - beta_init = np.zeros(p + 1) - - result = optimize.minimize( - neg_log_likelihood, - beta_init, - method='BFGS', - jac=gradient, - options={'maxiter': max_iter, 'gtol': tol} - ) - - beta = result.x - z = np.dot(X_with_intercept, beta) - z = np.clip(z, -500, 500) - probs = 1 / (1 + np.exp(-z)) - - return beta, probs - - def _linear_regression( X: np.ndarray, y: np.ndarray, @@ -137,7 +78,9 @@ def _linear_regression( # Use unified OLS backend (no vcov needed) beta, residuals, _ = solve_ols( - X_with_intercept, y, return_vcov=False, + X_with_intercept, + y, + return_vcov=False, rank_deficient_action=rank_deficient_action, ) @@ -221,6 +164,10 @@ class CallawaySantAnna( event study aggregation. Requires ``n_bootstrap > 0``. When True, results include ``cband_crit_value`` and per-event-time ``cband_conf_int`` entries controlling family-wise error rate. + pscore_trim : float, default=0.01 + Trimming bound for propensity scores. Scores are clipped to + ``[pscore_trim, 1 - pscore_trim]`` before weight computation + in IPW and DR estimation. Must be in ``(0, 0.5)``. 
Attributes ---------- @@ -309,6 +256,7 @@ def __init__( rank_deficient_action: str = "warn", base_period: str = "varying", cband: bool = True, + pscore_trim: float = 0.01, ): import warnings @@ -319,9 +267,10 @@ def __init__( ) if estimation_method not in ["dr", "ipw", "reg"]: raise ValueError( - f"estimation_method must be 'dr', 'ipw', or 'reg', " - f"got '{estimation_method}'" + f"estimation_method must be 'dr', 'ipw', or 'reg', " f"got '{estimation_method}'" ) + if not (0 < pscore_trim < 0.5): + raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}") # Handle bootstrap_weight_type deprecation if bootstrap_weight_type is not None: @@ -329,7 +278,7 @@ def __init__( "bootstrap_weight_type is deprecated and will be removed in v3.0. " "Use bootstrap_weights instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if bootstrap_weights is None: bootstrap_weights = bootstrap_weight_type @@ -352,8 +301,7 @@ def __init__( if base_period not in ["varying", "universal"]: raise ValueError( - f"base_period must be 'varying' or 'universal', " - f"got '{base_period}'" + f"base_period must be 'varying' or 'universal', " f"got '{base_period}'" ) self.control_group = control_group @@ -370,6 +318,7 @@ def __init__( self.base_period = base_period self.cband = cband + self.pscore_trim = pscore_trim self.is_fitted_ = False self.results_: Optional[CallawaySantAnnaResults] = None @@ -418,7 +367,7 @@ def _precompute_structures( # Pre-compute cohort masks (boolean arrays) cohort_masks = {} for g in treatment_groups: - cohort_masks[g] = (unit_cohorts == g) + cohort_masks[g] = unit_cohorts == g # Never-treated mask # np.inf was normalized to 0 in fit(), so the np.inf check is defensive only @@ -437,16 +386,16 @@ def _precompute_structures( is_balanced = not np.any(np.isnan(outcome_matrix)) return { - 'all_units': all_units, - 'unit_to_idx': unit_to_idx, - 'unit_cohorts': unit_cohorts, - 'outcome_matrix': outcome_matrix, - 'period_to_col': period_to_col, - 
'cohort_masks': cohort_masks, - 'never_treated_mask': never_treated_mask, - 'covariate_by_period': covariate_by_period, - 'time_periods': time_periods, - 'is_balanced': is_balanced, + "all_units": all_units, + "unit_to_idx": unit_to_idx, + "unit_cohorts": unit_cohorts, + "outcome_matrix": outcome_matrix, + "period_to_col": period_to_col, + "cohort_masks": cohort_masks, + "never_treated_mask": never_treated_mask, + "covariate_by_period": covariate_by_period, + "time_periods": time_periods, + "is_balanced": is_balanced, } def _compute_att_gt_fast( @@ -464,12 +413,12 @@ def _compute_att_gt_fast( Uses vectorized numpy operations on pre-pivoted outcome matrix instead of repeated pandas filtering. """ - period_to_col = precomputed['period_to_col'] - outcome_matrix = precomputed['outcome_matrix'] - cohort_masks = precomputed['cohort_masks'] - never_treated_mask = precomputed['never_treated_mask'] - unit_cohorts = precomputed['unit_cohorts'] - covariate_by_period = precomputed['covariate_by_period'] + period_to_col = precomputed["period_to_col"] + outcome_matrix = precomputed["outcome_matrix"] + cohort_masks = precomputed["cohort_masks"] + never_treated_mask = precomputed["never_treated_mask"] + unit_cohorts = precomputed["unit_cohorts"] + covariate_by_period = precomputed["covariate_by_period"] # Base period selection based on mode if self.base_period == "universal": @@ -553,7 +502,7 @@ def _compute_att_gt_fast( # Compute cache key for propensity score reuse pscore_key = None if pscore_cache is not None and X_treated is not None: - is_balanced = precomputed.get('is_balanced', False) + is_balanced = precomputed.get("is_balanced", False) if is_balanced and self.control_group == "never_treated": pscore_key = (g, base_period_val) else: @@ -562,7 +511,7 @@ def _compute_att_gt_fast( # Compute cache key for Cholesky reuse (DR outcome regression) cho_key = None if cho_cache is not None and X_control is not None: - is_balanced = precomputed.get('is_balanced', False) + is_balanced 
= precomputed.get("is_balanced", False) if is_balanced and self.control_group == "never_treated": cho_key = base_period_val else: @@ -575,15 +524,21 @@ def _compute_att_gt_fast( ) elif self.estimation_method == "ipw": att_gt, se_gt, inf_func = self._ipw_estimation( - treated_change, control_change, - int(n_treated), int(n_control), - X_treated, X_control, + treated_change, + control_change, + int(n_treated), + int(n_control), + X_treated, + X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, ) else: # doubly robust att_gt, se_gt, inf_func = self._doubly_robust( - treated_change, control_change, X_treated, X_control, + treated_change, + control_change, + X_treated, + X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, cho_cache=cho_cache, @@ -594,16 +549,16 @@ def _compute_att_gt_fast( # precomputed['all_units']) for O(1) downstream lookups instead of # O(n) Python dict lookups. n_t = int(n_treated) - all_units = precomputed['all_units'] + all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] inf_func_info = { - 'treated_idx': treated_positions, - 'control_idx': control_positions, - 'treated_units': all_units[treated_positions], - 'control_units': all_units[control_positions], - 'treated_inf': inf_func[:n_t], - 'control_inf': inf_func[n_t:], + "treated_idx": treated_positions, + "control_idx": control_positions, + "treated_units": all_units[treated_positions], + "control_units": all_units[control_positions], + "treated_inf": inf_func[:n_t], + "control_inf": inf_func[n_t:], } return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info @@ -628,11 +583,11 @@ def _compute_all_att_gt_vectorized( influence_func_info : dict Mapping (g, t) -> influence function info dict. 
""" - period_to_col = precomputed['period_to_col'] - outcome_matrix = precomputed['outcome_matrix'] - cohort_masks = precomputed['cohort_masks'] - never_treated_mask = precomputed['never_treated_mask'] - unit_cohorts = precomputed['unit_cohorts'] + period_to_col = precomputed["period_to_col"] + outcome_matrix = precomputed["outcome_matrix"] + cohort_masks = precomputed["cohort_masks"] + never_treated_mask = precomputed["never_treated_mask"] + unit_cohorts = precomputed["unit_cohorts"] group_time_effects = {} influence_func_info = {} @@ -645,8 +600,7 @@ def _compute_all_att_gt_vectorized( valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ - t for t in time_periods - if t >= g - self.anticipation or t > min_period + t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: @@ -711,26 +665,26 @@ def _compute_all_att_gt_vectorized( inf_control = -(control_change - np.mean(control_change)) / n_c group_time_effects[(g, t)] = { - 'effect': att, - 'se': se, + "effect": att, + "se": se, # t_stat, p_value, conf_int filled by batch inference below - 't_stat': np.nan, - 'p_value': np.nan, - 'conf_int': (np.nan, np.nan), - 'n_treated': n_t, - 'n_control': n_c, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_treated": n_t, + "n_control": n_c, } - all_units = precomputed['all_units'] + all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] influence_func_info[(g, t)] = { - 'treated_idx': treated_positions, - 'control_idx': control_positions, - 'treated_units': all_units[treated_positions], - 'control_units': all_units[control_positions], - 'treated_inf': inf_treated, - 'control_inf': inf_control, + "treated_idx": treated_positions, + "control_idx": control_positions, + "treated_units": all_units[treated_positions], + "control_units": all_units[control_positions], + "treated_inf": inf_treated, + 
"control_inf": inf_control, } atts.append(att) @@ -743,11 +697,9 @@ def _compute_all_att_gt_vectorized( np.array(atts), np.array(ses), alpha=self.alpha ) for idx, key in enumerate(task_keys): - group_time_effects[key]['t_stat'] = float(t_stats[idx]) - group_time_effects[key]['p_value'] = float(p_values[idx]) - group_time_effects[key]['conf_int'] = ( - float(ci_lowers[idx]), float(ci_uppers[idx]) - ) + group_time_effects[key]["t_stat"] = float(t_stats[idx]) + group_time_effects[key]["p_value"] = float(p_values[idx]) + group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx])) return group_time_effects, influence_func_info @@ -772,13 +724,13 @@ def _compute_all_att_gt_covariate_reg( influence_func_info : dict Mapping (g, t) -> influence function info dict. """ - period_to_col = precomputed['period_to_col'] - outcome_matrix = precomputed['outcome_matrix'] - cohort_masks = precomputed['cohort_masks'] - never_treated_mask = precomputed['never_treated_mask'] - unit_cohorts = precomputed['unit_cohorts'] - covariate_by_period = precomputed['covariate_by_period'] - is_balanced = precomputed['is_balanced'] + period_to_col = precomputed["period_to_col"] + outcome_matrix = precomputed["outcome_matrix"] + cohort_masks = precomputed["cohort_masks"] + never_treated_mask = precomputed["never_treated_mask"] + unit_cohorts = precomputed["unit_cohorts"] + covariate_by_period = precomputed["covariate_by_period"] + is_balanced = precomputed["is_balanced"] group_time_effects = {} influence_func_info = {} @@ -795,8 +747,7 @@ def _compute_all_att_gt_covariate_reg( valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ - t for t in time_periods - if t >= g - self.anticipation or t > min_period + t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: @@ -878,14 +829,16 @@ def _compute_all_att_gt_covariate_reg( f"Rank-deficient covariate design (control_key={control_key}): " f"dropped 
columns {col_info}. Rank {rank} < {X_ctrl.shape[1]}. " "Using minimum-norm least-squares solution.", - UserWarning, stacklevel=2, + UserWarning, + stacklevel=2, ) cho = None # Force lstsq path for ALL rank-deficient cases - kept_cols = np.array([i for i in range(X_ctrl.shape[1]) - if i not in dropped_cols]) + kept_cols = np.array( + [i for i in range(X_ctrl.shape[1]) if i not in dropped_cols] + ) else: kept_cols = None # Full rank — use all columns - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): XtX = X_ctrl.T @ X_ctrl try: cho = scipy_linalg.cho_factor(XtX) @@ -947,8 +900,7 @@ def _compute_all_att_gt_covariate_reg( inf_control = -(control_change - np.mean(control_change)) / n_c else: # Build per-pair X_ctrl if control_valid differs from base - if (is_balanced and self.control_group == "never_treated" - and X_ctrl is not None): + if is_balanced and self.control_group == "never_treated" and X_ctrl is not None: pair_X_ctrl = X_ctrl pair_n_c = n_c_base else: @@ -957,9 +909,12 @@ def _compute_all_att_gt_covariate_reg( # Solve for beta beta = None - with np.errstate(all='ignore'): - if (cho is not None and is_balanced - and self.control_group == "never_treated"): + with np.errstate(all="ignore"): + if ( + cho is not None + and is_balanced + and self.control_group == "never_treated" + ): # Use cached Cholesky Xty = pair_X_ctrl.T @ control_change beta = scipy_linalg.cho_solve(cho, Xty) @@ -981,7 +936,8 @@ def _compute_all_att_gt_covariate_reg( if kept_cols is not None: # Reduced solve for rank-deficient design result = scipy_linalg.lstsq( - pair_X_ctrl[:, kept_cols], control_change, + pair_X_ctrl[:, kept_cols], + control_change, cond=1e-07, ) beta = np.zeros(pair_X_ctrl.shape[1]) @@ -989,7 +945,9 @@ def _compute_all_att_gt_covariate_reg( else: # Full-rank lstsq fallback (Cholesky numerical failure) result = scipy_linalg.lstsq( - pair_X_ctrl, control_change, cond=1e-07, + pair_X_ctrl, + control_change, + cond=1e-07, ) beta = result[0] @@ -1001,7 +959,7 
@@ def _compute_all_att_gt_covariate_reg( if not nan_cell: X_treated_w_intercept = np.column_stack([np.ones(n_t), X_treated_pair]) - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): predicted_control = X_treated_w_intercept @ beta treated_residuals = treated_change - predicted_control if np.any(~np.isfinite(predicted_control)): @@ -1010,7 +968,7 @@ def _compute_all_att_gt_covariate_reg( if not nan_cell: att = float(np.mean(treated_residuals)) - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): residuals = control_change - pair_X_ctrl @ beta if np.any(~np.isfinite(residuals)): nan_cell = True @@ -1029,25 +987,25 @@ def _compute_all_att_gt_covariate_reg( inf_control = -residuals / pair_n_c group_time_effects[(g, t)] = { - 'effect': att, - 'se': se, - 't_stat': np.nan, - 'p_value': np.nan, - 'conf_int': (np.nan, np.nan), - 'n_treated': n_t, - 'n_control': n_c, + "effect": att, + "se": se, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_treated": n_t, + "n_control": n_c, } - all_units = precomputed['all_units'] + all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] control_positions = np.where(control_valid)[0] influence_func_info[(g, t)] = { - 'treated_idx': treated_positions, - 'control_idx': control_positions, - 'treated_units': all_units[treated_positions], - 'control_units': all_units[control_positions], - 'treated_inf': inf_treated, - 'control_inf': inf_control, + "treated_idx": treated_positions, + "control_idx": control_positions, + "treated_units": all_units[treated_positions], + "control_units": all_units[control_positions], + "treated_inf": inf_treated, + "control_inf": inf_control, } atts.append(att) @@ -1068,11 +1026,9 @@ def _compute_all_att_gt_covariate_reg( np.array(atts), np.array(ses), alpha=self.alpha ) for idx, key in enumerate(task_keys): - group_time_effects[key]['t_stat'] = float(t_stats[idx]) - group_time_effects[key]['p_value'] = float(p_values[idx]) - 
group_time_effects[key]['conf_int'] = ( - float(ci_lowers[idx]), float(ci_uppers[idx]) - ) + group_time_effects[key]["t_stat"] = float(t_stats[idx]) + group_time_effects[key]["p_value"] = float(p_values[idx]) + group_time_effects[key]["conf_int"] = (float(ci_lowers[idx]), float(ci_uppers[idx])) return group_time_effects, influence_func_info @@ -1126,6 +1082,10 @@ def fit( ValueError If required columns are missing or data validation fails. """ + # Validate pscore_trim (may have been changed via set_params) + if not (0 < self.pscore_trim < 0.5): + raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}") + # Normalize empty covariates list to None if covariates is not None and len(covariates) == 0: covariates = None @@ -1148,10 +1108,10 @@ def fit( # Standardize the first_treat column name for internal use # This avoids hardcoding column names in internal methods - df['first_treat'] = df[first_treat] + df["first_treat"] = df[first_treat] # Never-treated indicator (must precede treatment_groups to exclude np.inf) - df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf) + df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf) # Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated df.loc[df[first_treat] == np.inf, first_treat] = 0 @@ -1160,21 +1120,19 @@ def fit( treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) # Get unique units - unit_info = df.groupby(unit).agg({ - first_treat: 'first', - '_never_treated': 'first' - }).reset_index() + unit_info = ( + df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index() + ) n_treated_units = (unit_info[first_treat] > 0).sum() - n_control_units = (unit_info['_never_treated']).sum() + n_control_units = (unit_info["_never_treated"]).sum() if n_control_units == 0: raise ValueError("No never-treated units found. 
Check 'first_treat' column.") # Pre-compute data structures for efficient ATT(g,t) computation precomputed = self._precompute_structures( - df, outcome, unit, time, first_treat, - covariates, time_periods, treatment_groups + df, outcome, unit, time, first_treat, covariates, time_periods, treatment_groups ) # Compute ATT(g,t) for each group-time combination @@ -1182,18 +1140,17 @@ def fit( if covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case - group_time_effects, influence_func_info = ( - self._compute_all_att_gt_vectorized( - precomputed, treatment_groups, time_periods, min_period - ) + group_time_effects, influence_func_info = self._compute_all_att_gt_vectorized( + precomputed, treatment_groups, time_periods, min_period ) - elif (covariates is not None and self.estimation_method == "reg" - and self.rank_deficient_action != "error"): + elif ( + covariates is not None + and self.estimation_method == "reg" + and self.rank_deficient_action != "error" + ): # Optimized covariate regression path with Cholesky caching - group_time_effects, influence_func_info = ( - self._compute_all_att_gt_covariate_reg( - precomputed, treatment_groups, time_periods, min_period - ) + group_time_effects, influence_func_info = self._compute_all_att_gt_covariate_reg( + precomputed, treatment_groups, time_periods, min_period ) else: # General path: IPW, DR, rank_deficient_action="error", or edge cases @@ -1201,14 +1158,17 @@ def fit( influence_func_info = {} # Propensity score cache for IPW/DR with covariates - pscore_cache = {} if ( - covariates and self.estimation_method in ("ipw", "dr") - ) else None + pscore_cache = {} if (covariates and self.estimation_method in ("ipw", "dr")) else None # Cholesky cache for DR outcome regression component - cho_cache = {} if ( - covariates and self.estimation_method == "dr" - and self.rank_deficient_action != "error" - ) else None + cho_cache = ( + {} + if ( + covariates + and 
self.estimation_method == "dr" + and self.rank_deficient_action != "error" + ) + else None + ) for g in treatment_groups: if self.base_period == "universal": @@ -1216,13 +1176,15 @@ def fit( valid_periods = [t for t in time_periods if t != universal_base] else: valid_periods = [ - t for t in time_periods - if t >= g - self.anticipation or t > min_period + t for t in time_periods if t >= g - self.anticipation or t > min_period ] for t in valid_periods: att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast( - precomputed, g, t, covariates, + precomputed, + g, + t, + covariates, pscore_cache=pscore_cache, cho_cache=cho_cache, ) @@ -1231,13 +1193,13 @@ def fit( t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha) group_time_effects[(g, t)] = { - 'effect': att_gt, - 'se': se_gt, - 't_stat': t_stat, - 'p_value': p_val, - 'conf_int': ci, - 'n_treated': n_treat, - 'n_control': n_ctrl, + "effect": att_gt, + "se": se_gt, + "t_stat": t_stat, + "p_value": p_val, + "conf_int": ci, + "n_treated": n_treat, + "n_control": n_ctrl, } if inf_info is not None: @@ -1253,9 +1215,7 @@ def fit( overall_att, overall_se = self._aggregate_simple( group_time_effects, influence_func_info, df, unit, precomputed ) - overall_t, overall_p, overall_ci = safe_inference( - overall_att, overall_se, alpha=self.alpha - ) + overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) # Compute additional aggregations if requested event_study_effects = None @@ -1263,14 +1223,21 @@ def fit( if aggregate in ["event_study", "all"]: event_study_effects = self._aggregate_event_study( - group_time_effects, influence_func_info, - treatment_groups, time_periods, balance_e, - df, unit, precomputed, + group_time_effects, + influence_func_info, + treatment_groups, + time_periods, + balance_e, + df, + unit, + precomputed, ) if aggregate in ["group", "all"]: group_effects = self._aggregate_by_group( - group_time_effects, influence_func_info, 
treatment_groups, + group_time_effects, + influence_func_info, + treatment_groups, precomputed=precomputed, ) @@ -1299,46 +1266,70 @@ def fit( # Update group-time effects with bootstrap SEs (batched) gt_keys = [gt for gt in group_time_effects if gt in bootstrap_results.group_time_ses] if gt_keys: - gt_effects_arr = np.array([float(group_time_effects[gt]['effect']) for gt in gt_keys]) - gt_ses_arr = np.array([float(bootstrap_results.group_time_ses[gt]) for gt in gt_keys]) - gt_t_stats, _, _, _ = safe_inference_batch(gt_effects_arr, gt_ses_arr, alpha=self.alpha) + gt_effects_arr = np.array( + [float(group_time_effects[gt]["effect"]) for gt in gt_keys] + ) + gt_ses_arr = np.array( + [float(bootstrap_results.group_time_ses[gt]) for gt in gt_keys] + ) + gt_t_stats, _, _, _ = safe_inference_batch( + gt_effects_arr, gt_ses_arr, alpha=self.alpha + ) for idx, gt in enumerate(gt_keys): - group_time_effects[gt]['se'] = bootstrap_results.group_time_ses[gt] - group_time_effects[gt]['conf_int'] = bootstrap_results.group_time_cis[gt] - group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt] - group_time_effects[gt]['t_stat'] = float(gt_t_stats[idx]) + group_time_effects[gt]["se"] = bootstrap_results.group_time_ses[gt] + group_time_effects[gt]["conf_int"] = bootstrap_results.group_time_cis[gt] + group_time_effects[gt]["p_value"] = bootstrap_results.group_time_p_values[gt] + group_time_effects[gt]["t_stat"] = float(gt_t_stats[idx]) # Update event study effects with bootstrap SEs (batched) - if (event_study_effects is not None + if ( + event_study_effects is not None and bootstrap_results.event_study_ses is not None and bootstrap_results.event_study_cis is not None - and bootstrap_results.event_study_p_values is not None): + and bootstrap_results.event_study_p_values is not None + ): es_keys = [e for e in event_study_effects if e in bootstrap_results.event_study_ses] if es_keys: - es_effects_arr = np.array([float(event_study_effects[e]['effect']) for e in 
es_keys]) - es_ses_arr = np.array([float(bootstrap_results.event_study_ses[e]) for e in es_keys]) - es_t_stats, _, _, _ = safe_inference_batch(es_effects_arr, es_ses_arr, alpha=self.alpha) + es_effects_arr = np.array( + [float(event_study_effects[e]["effect"]) for e in es_keys] + ) + es_ses_arr = np.array( + [float(bootstrap_results.event_study_ses[e]) for e in es_keys] + ) + es_t_stats, _, _, _ = safe_inference_batch( + es_effects_arr, es_ses_arr, alpha=self.alpha + ) for idx, e in enumerate(es_keys): - event_study_effects[e]['se'] = bootstrap_results.event_study_ses[e] - event_study_effects[e]['conf_int'] = bootstrap_results.event_study_cis[e] - event_study_effects[e]['p_value'] = bootstrap_results.event_study_p_values[e] - event_study_effects[e]['t_stat'] = float(es_t_stats[idx]) + event_study_effects[e]["se"] = bootstrap_results.event_study_ses[e] + event_study_effects[e]["conf_int"] = bootstrap_results.event_study_cis[e] + event_study_effects[e]["p_value"] = bootstrap_results.event_study_p_values[ + e + ] + event_study_effects[e]["t_stat"] = float(es_t_stats[idx]) # Update group effects with bootstrap SEs (batched) - if (group_effects is not None + if ( + group_effects is not None and bootstrap_results.group_effect_ses is not None and bootstrap_results.group_effect_cis is not None - and bootstrap_results.group_effect_p_values is not None): + and bootstrap_results.group_effect_p_values is not None + ): grp_keys = [g for g in group_effects if g in bootstrap_results.group_effect_ses] if grp_keys: - grp_effects_arr = np.array([float(group_effects[g]['effect']) for g in grp_keys]) - grp_ses_arr = np.array([float(bootstrap_results.group_effect_ses[g]) for g in grp_keys]) - grp_t_stats, _, _, _ = safe_inference_batch(grp_effects_arr, grp_ses_arr, alpha=self.alpha) + grp_effects_arr = np.array( + [float(group_effects[g]["effect"]) for g in grp_keys] + ) + grp_ses_arr = np.array( + [float(bootstrap_results.group_effect_ses[g]) for g in grp_keys] + ) + grp_t_stats, _, 
_, _ = safe_inference_batch( + grp_effects_arr, grp_ses_arr, alpha=self.alpha + ) for idx, g in enumerate(grp_keys): - group_effects[g]['se'] = bootstrap_results.group_effect_ses[g] - group_effects[g]['conf_int'] = bootstrap_results.group_effect_cis[g] - group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g] - group_effects[g]['t_stat'] = float(grp_t_stats[idx]) + group_effects[g]["se"] = bootstrap_results.group_effect_ses[g] + group_effects[g]["conf_int"] = bootstrap_results.group_effect_cis[g] + group_effects[g]["p_value"] = bootstrap_results.group_effect_p_values[g] + group_effects[g]["t_stat"] = float(grp_t_stats[idx]) # Compute simultaneous confidence band CIs if cband is available cband_crit_value = None @@ -1347,11 +1338,11 @@ def fit( if cband_crit_value is not None and event_study_effects is not None: for e, eff_data in event_study_effects.items(): - se_val = eff_data['se'] + se_val = eff_data["se"] if np.isfinite(se_val) and se_val > 0: - eff_data['cband_conf_int'] = ( - eff_data['effect'] - cband_crit_value * se_val, - eff_data['effect'] + cband_crit_value * se_val, + eff_data["cband_conf_int"] = ( + eff_data["effect"] - cband_crit_value * se_val, + eff_data["effect"] + cband_crit_value * se_val, ) # Store results @@ -1374,6 +1365,7 @@ def fit( group_effects=group_effects, bootstrap_results=bootstrap_results, cband_crit_value=cband_crit_value, + pscore_trim=self.pscore_trim, ) self.is_fitted_ = True @@ -1404,7 +1396,8 @@ def _outcome_regression( # Covariate-adjusted outcome regression # Fit regression on control units: E[Delta Y | X, D=0] beta, residuals = _linear_regression( - X_control, control_change, + X_control, + control_change, rank_deficient_action=self.rank_deficient_action, ) @@ -1492,13 +1485,20 @@ def _ipw_estimation( X_all = np.vstack([X_treated, X_control]) D = np.concatenate([np.ones(n_t), np.zeros(n_c)]) - # Estimate propensity scores using logistic regression + # Estimate propensity scores using IRLS logistic regression 
try: - beta_logistic, pscore = _logistic_regression(X_all, D) + beta_logistic, pscore = solve_logit( + X_all, + D, + rank_deficient_action=self.rank_deficient_action, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) # Cache the fitted coefficients if pscore_cache is not None and pscore_key is not None: pscore_cache[pscore_key] = beta_logistic except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise # Fallback to unconditional if logistic regression fails warnings.warn( "Propensity score estimation failed. " @@ -1513,8 +1513,8 @@ def _ipw_estimation( pscore_control = pscore[n_t:] # Clip propensity scores to avoid extreme weights - pscore_control = np.clip(pscore_control, 0.01, 0.99) - pscore_treated = np.clip(pscore_treated, 0.01, 0.99) + pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) + pscore_treated = np.clip(pscore_treated, self.pscore_trim, 1 - self.pscore_trim) # IPW weights for control units: p(X) / (1 - p(X)) # This reweights controls to have same covariate distribution as treated @@ -1529,13 +1529,17 @@ def _ipw_estimation( var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 # Variance of weighted control mean - weighted_var_c = np.sum(weights_control * (control_change - np.sum(weights_control * control_change)) ** 2) + weighted_var_c = np.sum( + weights_control * (control_change - np.sum(weights_control * control_change)) ** 2 + ) se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0 # Influence function inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = -weights_control * (control_change - np.sum(weights_control * control_change)) + inf_control = -weights_control * ( + control_change - np.sum(weights_control * control_change) + ) inf_func = np.concatenate([inf_treated, inf_control]) else: # Unconditional IPW (reduces to difference in means) @@ -1547,7 +1551,11 @@ def _ipw_estimation( var_c = np.var(control_change, 
ddof=1) if n_c > 1 else 0.0 # Adjusted variance for IPW - se = np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) if (n_t > 0 and n_c > 0 and p_treat > 0) else 0.0 + se = ( + np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) + if (n_t > 0 and n_c > 0 and p_treat > 0) + else 0.0 + ) # Influence function (for aggregation) inf_treated = (treated_change - np.mean(treated_change)) / n_t @@ -1622,7 +1630,8 @@ def _doubly_robust( if beta is None: beta, _ = _linear_regression( - X_control, control_change, + X_control, + control_change, rank_deficient_action=self.rank_deficient_action, ) # Zero NaN coefficients for prediction only — dropped columns @@ -1654,17 +1663,30 @@ def _doubly_robust( D = np.concatenate([np.ones(n_t), np.zeros(n_c)]) try: - beta_logistic, pscore = _logistic_regression(X_all, D) + beta_logistic, pscore = solve_logit( + X_all, + D, + rank_deficient_action=self.rank_deficient_action, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) if pscore_cache is not None and pscore_key is not None: pscore_cache[pscore_key] = beta_logistic except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise # Fallback to unconditional if logistic regression fails + warnings.warn( + "Propensity score estimation failed. 
" + "Falling back to unconditional estimation.", + UserWarning, + stacklevel=4, + ) pscore = np.full(len(D), n_t / (n_t + n_c)) pscore_control = pscore[n_t:] # Clip propensity scores - pscore_control = np.clip(pscore_control, 0.01, 0.99) + pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) # IPW weights for control: p(X) / (1 - p(X)) weights_control = pscore_control / (1 - pscore_control) @@ -1685,7 +1707,7 @@ def _doubly_robust( psi_control = (weights_control * (m_control - control_change)) / n_t # Variance is sum of squared influence functions - var_psi = np.sum(psi_treated ** 2) + np.sum(psi_control ** 2) + var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) se = np.sqrt(var_psi) if var_psi > 0 else 0.0 # Full influence function @@ -1722,6 +1744,7 @@ def get_params(self) -> Dict[str, Any]: "rank_deficient_action": self.rank_deficient_action, "base_period": self.base_period, "cband": self.cband, + "pscore_trim": self.pscore_trim, } def set_params(self, **params) -> "CallawaySantAnna": diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index 757b566..8eb46d0 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -37,6 +37,7 @@ class GroupTimeEffect: n_control : int Number of control observations. """ + group: Any time: Any effect: float @@ -92,7 +93,10 @@ class CallawaySantAnnaResults: Effects aggregated by relative time (event study). group_effects : dict, optional Effects aggregated by treatment cohort. + pscore_trim : float + Propensity score trimming bound used during estimation. 
""" + group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] overall_att: float overall_se: float @@ -112,6 +116,7 @@ class CallawaySantAnnaResults: influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None + pscore_trim: float = 0.01 def __repr__(self) -> str: """Concise string representation.""" @@ -156,35 +161,39 @@ def summary(self, alpha: Optional[float] = None) -> str: ] # Overall ATT - lines.extend([ - "-" * 85, - "Overall Average Treatment Effect on the Treated".center(85), - "-" * 85, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " - f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} " - f"{_get_significance_stars(self.overall_p_value):>6}", - "-" * 85, - "", - f"{conf_level}% Confidence Interval: [{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", - "", - ]) + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} " + f"{_get_significance_stars(self.overall_p_value):>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: [{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) # Event study effects if available if self.event_study_effects: ci_label = "Simult. CI" if self.cband_crit_value is not None else "Pointwise CI" - lines.extend([ - "-" * 85, - "Event Study (Dynamic) Effects".center(85), - "-" * 85, - f"{'Rel. Period':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ]) + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) for rel_t in sorted(self.event_study_effects.keys()): eff = self.event_study_effects[rel_t] - sig = _get_significance_stars(eff['p_value']) + sig = _get_significance_stars(eff["p_value"]) lines.append( f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}" @@ -200,17 +209,19 @@ def summary(self, alpha: Optional[float] = None) -> str: # Group effects if available if self.group_effects: - lines.extend([ - "-" * 85, - "Effects by Treatment Cohort".center(85), - "-" * 85, - f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ]) + lines.extend( + [ + "-" * 85, + "Effects by Treatment Cohort".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) for group in sorted(self.group_effects.keys()): eff = self.group_effects[group] - sig = _get_significance_stars(eff['p_value']) + sig = _get_significance_stars(eff["p_value"]) lines.append( f"{group:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}" @@ -218,10 +229,12 @@ def summary(self, alpha: Optional[float] = None) -> str: lines.extend(["-" * 85, ""]) - lines.extend([ - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 85, - ]) + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 
0.1", + "=" * 85, + ] + ) return "\n".join(lines) @@ -246,16 +259,18 @@ def to_dataframe(self, level: str = "group_time") -> pd.DataFrame: if level == "group_time": rows = [] for (g, t), data in self.group_time_effects.items(): - rows.append({ - 'group': g, - 'time': t, - 'effect': data['effect'], - 'se': data['se'], - 't_stat': data['t_stat'], - 'p_value': data['p_value'], - 'conf_int_lower': data['conf_int'][0], - 'conf_int_upper': data['conf_int'][1], - }) + rows.append( + { + "group": g, + "time": t, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + } + ) return pd.DataFrame(rows) elif level == "event_study": @@ -263,18 +278,20 @@ def to_dataframe(self, level: str = "group_time") -> pd.DataFrame: raise ValueError("Event study effects not computed. Use aggregate='event_study'.") rows = [] for rel_t, data in sorted(self.event_study_effects.items()): - cband_ci = data.get('cband_conf_int', (np.nan, np.nan)) - rows.append({ - 'relative_period': rel_t, - 'effect': data['effect'], - 'se': data['se'], - 't_stat': data['t_stat'], - 'p_value': data['p_value'], - 'conf_int_lower': data['conf_int'][0], - 'conf_int_upper': data['conf_int'][1], - 'cband_lower': cband_ci[0], - 'cband_upper': cband_ci[1], - }) + cband_ci = data.get("cband_conf_int", (np.nan, np.nan)) + rows.append( + { + "relative_period": rel_t, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "cband_lower": cband_ci[0], + "cband_upper": cband_ci[1], + } + ) return pd.DataFrame(rows) elif level == "group": @@ -282,19 +299,23 @@ def to_dataframe(self, level: str = "group_time") -> pd.DataFrame: raise ValueError("Group effects not computed. 
Use aggregate='group'.") rows = [] for group, data in sorted(self.group_effects.items()): - rows.append({ - 'group': group, - 'effect': data['effect'], - 'se': data['se'], - 't_stat': data['t_stat'], - 'p_value': data['p_value'], - 'conf_int_lower': data['conf_int'][0], - 'conf_int_upper': data['conf_int'][1], - }) + rows.append( + { + "group": group, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + } + ) return pd.DataFrame(rows) else: - raise ValueError(f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'.") + raise ValueError( + f"Unknown level: {level}. Use 'group_time', 'event_study', or 'group'." + ) @property def is_significant(self) -> bool: diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 55f6480..8b4ca32 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -28,9 +28,7 @@ import numpy as np import pandas as pd -from scipy import optimize - -from diff_diff.linalg import solve_ols +from diff_diff.linalg import solve_ols, solve_logit from diff_diff.results import _get_significance_stars from diff_diff.utils import safe_inference @@ -155,44 +153,52 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.n_clusters is not None: lines.append(f"{'Number of clusters:':<30} {self.n_clusters:>15}") - lines.extend([ - "", - "-" * 75, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", - "-" * 75, - f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", - "-" * 75, - "", - f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", - ]) - - # Show group means if available - if self.group_means: - lines.extend([ + lines.extend( + [ "", "-" * 75, - "Cell Means (Y):", + f"{'Parameter':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", "-" * 75, - ]) + f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", + "-" * 75, + "", + f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", + ] + ) + + # Show group means if available + if self.group_means: + lines.extend( + [ + "", + "-" * 75, + "Cell Means (Y):", + "-" * 75, + ] + ) for cell, mean in self.group_means.items(): lines.append(f" {cell:<35} {mean:>12.4f}") # Show propensity score diagnostics if available if self.pscore_stats: - lines.extend([ - "", - "-" * 75, - "Propensity Score Diagnostics:", - "-" * 75, - ]) + lines.extend( + [ + "", + "-" * 75, + "Propensity Score Diagnostics:", + "-" * 75, + ] + ) for stat, value in self.pscore_stats.items(): lines.append(f" {stat:<35} {value:>12.4f}") - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 75, - ]) + lines.extend( + [ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 75, + ] + ) return "\n".join(lines) @@ -259,66 +265,6 @@ def significance_stars(self) -> str: # ============================================================================= -def _logistic_regression( - X: np.ndarray, - y: np.ndarray, - max_iter: int = 100, - tol: float = 1e-6, -) -> Tuple[np.ndarray, np.ndarray]: - """ - Fit logistic regression using scipy optimize. - - Parameters - ---------- - X : np.ndarray - Feature matrix (n_samples, n_features). Intercept added automatically. - y : np.ndarray - Binary outcome (0/1). - max_iter : int - Maximum iterations. - tol : float - Convergence tolerance. - - Returns - ------- - beta : np.ndarray - Fitted coefficients (including intercept). - probs : np.ndarray - Predicted probabilities. 
- """ - n, p = X.shape - X_with_intercept = np.column_stack([np.ones(n), X]) - - def neg_log_likelihood(beta: np.ndarray) -> float: - z = np.dot(X_with_intercept, beta) - z = np.clip(z, -500, 500) - log_lik = np.sum(y * z - np.log(1 + np.exp(z))) - return -log_lik - - def gradient(beta: np.ndarray) -> np.ndarray: - z = np.dot(X_with_intercept, beta) - z = np.clip(z, -500, 500) - probs = 1 / (1 + np.exp(-z)) - return -np.dot(X_with_intercept.T, y - probs) - - beta_init = np.zeros(p + 1) - - result = optimize.minimize( - neg_log_likelihood, - beta_init, - method='BFGS', - jac=gradient, - options={'maxiter': max_iter, 'gtol': tol} - ) - - beta = result.x - z = np.dot(X_with_intercept, beta) - z = np.clip(z, -500, 500) - probs = 1 / (1 + np.exp(-z)) - - return beta, probs - - # ============================================================================= # Main Estimator Class # ============================================================================= @@ -445,8 +391,7 @@ def __init__( ): if estimation_method not in ("dr", "reg", "ipw"): raise ValueError( - f"estimation_method must be 'dr', 'reg', or 'ipw', " - f"got '{estimation_method}'" + f"estimation_method must be 'dr', 'reg', or 'ipw', " f"got '{estimation_method}'" ) if rank_deficient_action not in ["warn", "error", "silent"]: raise ValueError( @@ -520,9 +465,7 @@ def fit( # Store cluster IDs for SE computation self._cluster_ids = data[self.cluster].values if self.cluster is not None else None if self._cluster_ids is not None and np.any(pd.isna(data[self.cluster])): - raise ValueError( - f"Cluster column '{self.cluster}' contains missing values" - ) + raise ValueError(f"Cluster column '{self.cluster}' contains missing values") # Get covariates if specified X = None @@ -543,17 +486,11 @@ def fit( # Estimate ATT based on method if self.estimation_method == "reg": - att, se, r_squared, pscore_stats = self._regression_adjustment( - y, G, P, T, X - ) + att, se, r_squared, pscore_stats = 
self._regression_adjustment(y, G, P, T, X) elif self.estimation_method == "ipw": - att, se, r_squared, pscore_stats = self._ipw_estimation( - y, G, P, T, X - ) + att, se, r_squared, pscore_stats = self._ipw_estimation(y, G, P, T, X) else: # doubly robust - att, se, r_squared, pscore_stats = self._doubly_robust( - y, G, P, T, X - ) + att, se, r_squared, pscore_stats = self._doubly_robust(y, G, P, T, X) # Compute inference df = n_obs - 8 # Approximate df (8 cell means) @@ -626,13 +563,10 @@ def _validate_data( unique_vals = set(data[col].unique()) if not unique_vals.issubset({0, 1, 0.0, 1.0}): raise ValueError( - f"'{name}' column must be binary (0/1), " - f"got values: {sorted(unique_vals)}" + f"'{name}' column must be binary (0/1), " f"got values: {sorted(unique_vals)}" ) if len(unique_vals) < 2: - raise ValueError( - f"'{name}' column must have both 0 and 1 values" - ) + raise ValueError(f"'{name}' column must have both 0 and 1 values") # Check we have observations in all cells G = data[group].values @@ -832,10 +766,14 @@ def _estimate_ddd_decomposition( # Logistic regression: P(subgroup=4 | X) within {j, 4} ps_estimated = True try: - _, pscore_sub = _logistic_regression( - covX_sub[:, 1:], PA4 + _, pscore_sub = solve_logit( + covX_sub[:, 1:], + PA4, + rank_deficient_action=self.rank_deficient_action, ) except Exception: + if self.rank_deficient_action == "error": + raise pscore_sub = np.full(n_sub, np.mean(PA4)) ps_estimated = False warnings.warn( @@ -846,16 +784,17 @@ def _estimate_ddd_decomposition( stacklevel=3, ) - pscore_sub = np.clip(pscore_sub, self.pscore_trim, - 1 - self.pscore_trim) + pscore_sub = np.clip(pscore_sub, self.pscore_trim, 1 - self.pscore_trim) all_pscores[j] = pscore_sub # Check overlap: count obs at trim bounds # (1e-10 tolerance for floating-point after np.clip) - n_trimmed = int(np.sum( - (pscore_sub <= self.pscore_trim + 1e-10) - | (pscore_sub >= 1 - self.pscore_trim - 1e-10) - )) + n_trimmed = int( + np.sum( + (pscore_sub <= 
self.pscore_trim + 1e-10) + | (pscore_sub >= 1 - self.pscore_trim - 1e-10) + ) + ) frac_trimmed = n_trimmed / len(pscore_sub) if frac_trimmed > 0.05: overlap_issues.append((j, frac_trimmed)) @@ -873,13 +812,14 @@ def _estimate_ddd_decomposition( else: # No covariates: unconditional probability pscore_sub = np.full(n_sub, np.mean(PA4)) - pscore_sub = np.clip(pscore_sub, self.pscore_trim, - 1 - self.pscore_trim) + pscore_sub = np.clip(pscore_sub, self.pscore_trim, 1 - self.pscore_trim) # Check overlap (same logic as covariate branch) - n_trimmed = int(np.sum( - (pscore_sub <= self.pscore_trim + 1e-10) - | (pscore_sub >= 1 - self.pscore_trim - 1e-10) - )) + n_trimmed = int( + np.sum( + (pscore_sub <= self.pscore_trim + 1e-10) + | (pscore_sub >= 1 - self.pscore_trim - 1e-10) + ) + ) frac_trimmed = n_trimmed / len(pscore_sub) if frac_trimmed > 0.05: overlap_issues.append((j, frac_trimmed)) @@ -895,19 +835,33 @@ def _estimate_ddd_decomposition( else: # Fit separate OLS per subgroup-time cell, predict for all or_ctrl_pre = self._fit_predict_mu( - y_sub, covX_sub, sg_sub == j, post_sub == 0, n_sub) + y_sub, covX_sub, sg_sub == j, post_sub == 0, n_sub + ) or_ctrl_post = self._fit_predict_mu( - y_sub, covX_sub, sg_sub == j, post_sub == 1, n_sub) + y_sub, covX_sub, sg_sub == j, post_sub == 1, n_sub + ) or_trt_pre = self._fit_predict_mu( - y_sub, covX_sub, sg_sub == 4, post_sub == 0, n_sub) + y_sub, covX_sub, sg_sub == 4, post_sub == 0, n_sub + ) or_trt_post = self._fit_predict_mu( - y_sub, covX_sub, sg_sub == 4, post_sub == 1, n_sub) + y_sub, covX_sub, sg_sub == 4, post_sub == 1, n_sub + ) # --- Compute DiD ATT and influence function --- att_j, inf_j = self._compute_did_rc( - y_sub, post_sub, PA4, PAa, pscore_sub, covX_sub, - or_ctrl_pre, or_ctrl_post, or_trt_pre, or_trt_post, - hessian, est_method, n_sub, + y_sub, + post_sub, + PA4, + PAa, + pscore_sub, + covX_sub, + or_ctrl_pre, + or_ctrl_post, + or_trt_pre, + or_trt_post, + hessian, + est_method, + n_sub, ) # Track 
non-finite IF values (flag for NaN SE later) @@ -923,9 +877,7 @@ def _estimate_ddd_decomposition( # Emit overlap warning if >5% of observations trimmed in any comparison if overlap_issues: - details = ", ".join( - f"subgroup {j} vs 4: {frac:.0%}" for j, frac in overlap_issues - ) + details = ", ".join(f"subgroup {j} vs 4: {frac:.0%}" for j, frac in overlap_issues) warnings.warn( f"Poor propensity score overlap ({details} of observations " f"trimmed at bounds). IPW/DR estimates may be unreliable.", @@ -944,9 +896,9 @@ def _estimate_ddd_decomposition( w2 = n / n2 w1 = n / n1 - inf_func = (w3 * did_results[3]["inf"] - + w2 * did_results[2]["inf"] - - w1 * did_results[1]["inf"]) + inf_func = ( + w3 * did_results[3]["inf"] + w2 * did_results[2]["inf"] - w1 * did_results[1]["inf"] + ) if self._cluster_ids is not None: # Cluster-robust SE: sum IF within clusters, then Liang-Zeger variance @@ -954,17 +906,15 @@ def _estimate_ddd_decomposition( n_clusters_val = len(unique_clusters) if n_clusters_val < 2: raise ValueError( - f"Need at least 2 clusters for cluster-robust SEs, " - f"got {n_clusters_val}" + f"Need at least 2 clusters for cluster-robust SEs, " f"got {n_clusters_val}" ) - cluster_sums = np.array([ - np.sum(inf_func[self._cluster_ids == c]) for c in unique_clusters - ]) + cluster_sums = np.array( + [np.sum(inf_func[self._cluster_ids == c]) for c in unique_clusters] + ) # V = (G/(G-1)) * (1/n^2) * sum(psi_c^2) - se = float(np.sqrt( - (n_clusters_val / (n_clusters_val - 1)) - * np.sum(cluster_sums**2) / n**2 - )) + se = float( + np.sqrt((n_clusters_val / (n_clusters_val - 1)) * np.sum(cluster_sums**2) / n**2) + ) else: se = float(np.std(inf_func, ddof=1) / np.sqrt(n)) @@ -1002,7 +952,8 @@ def _estimate_ddd_decomposition( y_fit = y[cell_mask] try: beta_rs, _, _ = solve_ols( - X_fit, y_fit, + X_fit, + y_fit, rank_deficient_action=self.rank_deficient_action, ) beta_rs = np.where(np.isnan(beta_rs), 0.0, beta_rs) @@ -1035,7 +986,8 @@ def _fit_predict_mu( try: beta, _, _ 
= solve_ols( - X_fit, y_fit, + X_fit, + y_fit, rank_deficient_action=self.rank_deficient_action, ) # Replace NaN coefficients (dropped columns) with 0 for prediction @@ -1072,17 +1024,26 @@ def _compute_did_rc( Matches R's triplediff::compute_did_rc(). """ if est_method == "ipw": - return self._compute_did_rc_ipw( - y, post, PA4, PAa, pscore, covX, hessian, n) + return self._compute_did_rc_ipw(y, post, PA4, PAa, pscore, covX, hessian, n) elif est_method == "reg": return self._compute_did_rc_reg( - y, post, PA4, PAa, covX, - or_ctrl_pre, or_ctrl_post, or_trt_pre, or_trt_post, n) + y, post, PA4, PAa, covX, or_ctrl_pre, or_ctrl_post, or_trt_pre, or_trt_post, n + ) else: return self._compute_did_rc_dr( - y, post, PA4, PAa, pscore, covX, - or_ctrl_pre, or_ctrl_post, or_trt_pre, or_trt_post, - hessian, n) + y, + post, + PA4, + PAa, + pscore, + covX, + or_ctrl_pre, + or_ctrl_post, + or_trt_pre, + or_trt_post, + hessian, + n, + ) def _compute_did_rc_ipw( self, @@ -1115,24 +1076,21 @@ def _hajek(riesz, y_vals): eta_control_pre, att_control_pre = _hajek(riesz_control_pre, y) eta_control_post, att_control_post = _hajek(riesz_control_post, y) - att = ((att_treat_post - att_treat_pre) - - (att_control_post - att_control_pre)) + att = (att_treat_post - att_treat_pre) - (att_control_post - att_control_pre) # Influence function - inf_treat_pre = (eta_treat_pre - - riesz_treat_pre * att_treat_pre - / np.mean(riesz_treat_pre)) - inf_treat_post = (eta_treat_post - - riesz_treat_post * att_treat_post - / np.mean(riesz_treat_post)) + inf_treat_pre = eta_treat_pre - riesz_treat_pre * att_treat_pre / np.mean(riesz_treat_pre) + inf_treat_post = eta_treat_post - riesz_treat_post * att_treat_post / np.mean( + riesz_treat_post + ) inf_treat = inf_treat_post - inf_treat_pre - inf_control_pre = (eta_control_pre - - riesz_control_pre * att_control_pre - / np.mean(riesz_control_pre)) - inf_control_post = (eta_control_post - - riesz_control_post * att_control_post - / np.mean(riesz_control_post)) 
+ inf_control_pre = eta_control_pre - riesz_control_pre * att_control_pre / np.mean( + riesz_control_pre + ) + inf_control_post = eta_control_post - riesz_control_post * att_control_post / np.mean( + riesz_control_post + ) inf_control = inf_control_post - inf_control_pre # Propensity score correction for influence function @@ -1178,10 +1136,8 @@ def _compute_did_rc_reg( reg_att_treat_post = riesz_treat_post * y reg_att_control = riesz_control * (or_ctrl_post - or_ctrl_pre) - eta_treat_pre = (np.mean(reg_att_treat_pre) - / np.mean(riesz_treat_pre)) - eta_treat_post = (np.mean(reg_att_treat_post) - / np.mean(riesz_treat_post)) + eta_treat_pre = np.mean(reg_att_treat_pre) / np.mean(riesz_treat_pre) + eta_treat_post = np.mean(reg_att_treat_post) / np.mean(riesz_treat_post) eta_control = np.mean(reg_att_control) / np.mean(riesz_control) att = (eta_treat_post - eta_treat_pre) - eta_control @@ -1208,20 +1164,21 @@ def _compute_did_rc_reg( XpX_inv_post = np.linalg.pinv(XpX_post) asy_lin_rep_ols_post = wols_eX_post @ XpX_inv_post - inf_treat_pre = ((reg_att_treat_pre - - riesz_treat_pre * eta_treat_pre) - / np.mean(riesz_treat_pre)) - inf_treat_post = ((reg_att_treat_post - - riesz_treat_post * eta_treat_post) - / np.mean(riesz_treat_post)) + inf_treat_pre = (reg_att_treat_pre - riesz_treat_pre * eta_treat_pre) / np.mean( + riesz_treat_pre + ) + inf_treat_post = (reg_att_treat_post - riesz_treat_post * eta_treat_post) / np.mean( + riesz_treat_post + ) inf_treat = inf_treat_post - inf_treat_pre inf_control_1 = reg_att_control - riesz_control * eta_control M1 = np.mean(riesz_control[:, None] * covX, axis=0) inf_control_2_post = asy_lin_rep_ols_post @ M1 inf_control_2_pre = asy_lin_rep_ols_pre @ M1 - inf_control = ((inf_control_1 + inf_control_2_post - inf_control_2_pre) - / np.mean(riesz_control)) + inf_control = (inf_control_1 + inf_control_2_post - inf_control_2_pre) / np.mean( + riesz_control + ) inf_func = inf_treat - inf_control return att, inf_func @@ -1257,24 +1214,22 
@@ def _compute_did_rc_dr( def _safe_ratio(num, denom): return num / denom if denom > 0 else 0.0 - eta_treat_pre = (riesz_treat_pre * (y - or_ctrl) - * _safe_ratio(1, np.mean(riesz_treat_pre))) - eta_treat_post = (riesz_treat_post * (y - or_ctrl) - * _safe_ratio(1, np.mean(riesz_treat_post))) - eta_control_pre = (riesz_control_pre * (y - or_ctrl) - * _safe_ratio(1, np.mean(riesz_control_pre))) - eta_control_post = (riesz_control_post * (y - or_ctrl) - * _safe_ratio(1, np.mean(riesz_control_post))) + eta_treat_pre = riesz_treat_pre * (y - or_ctrl) * _safe_ratio(1, np.mean(riesz_treat_pre)) + eta_treat_post = ( + riesz_treat_post * (y - or_ctrl) * _safe_ratio(1, np.mean(riesz_treat_post)) + ) + eta_control_pre = ( + riesz_control_pre * (y - or_ctrl) * _safe_ratio(1, np.mean(riesz_control_pre)) + ) + eta_control_post = ( + riesz_control_post * (y - or_ctrl) * _safe_ratio(1, np.mean(riesz_control_post)) + ) # Efficiency correction (OR bias correction) - eta_d_post = (riesz_d * (or_trt_post - or_ctrl_post) - * _safe_ratio(1, np.mean(riesz_d))) - eta_dt1_post = (riesz_dt1 * (or_trt_post - or_ctrl_post) - * _safe_ratio(1, np.mean(riesz_dt1))) - eta_d_pre = (riesz_d * (or_trt_pre - or_ctrl_pre) - * _safe_ratio(1, np.mean(riesz_d))) - eta_dt0_pre = (riesz_dt0 * (or_trt_pre - or_ctrl_pre) - * _safe_ratio(1, np.mean(riesz_dt0))) + eta_d_post = riesz_d * (or_trt_post - or_ctrl_post) * _safe_ratio(1, np.mean(riesz_d)) + eta_dt1_post = riesz_dt1 * (or_trt_post - or_ctrl_post) * _safe_ratio(1, np.mean(riesz_dt1)) + eta_d_pre = riesz_d * (or_trt_pre - or_ctrl_pre) * _safe_ratio(1, np.mean(riesz_d)) + eta_dt0_pre = riesz_dt0 * (or_trt_pre - or_ctrl_pre) * _safe_ratio(1, np.mean(riesz_dt0)) att_treat_pre = float(np.mean(eta_treat_pre)) att_treat_post = float(np.mean(eta_treat_post)) @@ -1285,10 +1240,12 @@ def _safe_ratio(num, denom): att_d_pre = float(np.mean(eta_d_pre)) att_dt0_pre = float(np.mean(eta_dt0_pre)) - att = ((att_treat_post - att_treat_pre) - - (att_control_post - 
att_control_pre) - + (att_d_post - att_dt1_post) - - (att_d_pre - att_dt0_pre)) + att = ( + (att_treat_post - att_treat_pre) + - (att_control_post - att_control_pre) + + (att_d_post - att_dt1_post) + - (att_d_pre - att_dt0_pre) + ) # --- Influence function --- # OLS asymptotic linear representations (control subgroup) @@ -1315,8 +1272,7 @@ def _safe_ratio(num, denom): # OLS representations (treated subgroup) weights_ols_pre_treat = PA4 * (1 - post) wols_x_pre_treat = weights_ols_pre_treat[:, None] * covX - wols_eX_pre_treat = (weights_ols_pre_treat - * (y - or_trt_pre))[:, None] * covX + wols_eX_pre_treat = (weights_ols_pre_treat * (y - or_trt_pre))[:, None] * covX XpX_pre_treat = wols_x_pre_treat.T @ covX / n try: XpX_inv_pre_treat = np.linalg.inv(XpX_pre_treat) @@ -1326,8 +1282,7 @@ def _safe_ratio(num, denom): weights_ols_post_treat = PA4 * post wols_x_post_treat = weights_ols_post_treat[:, None] * covX - wols_eX_post_treat = (weights_ols_post_treat - * (y - or_trt_post))[:, None] * covX + wols_eX_post_treat = (weights_ols_post_treat * (y - or_trt_post))[:, None] * covX XpX_post_treat = wols_x_post_treat.T @ covX / n try: XpX_inv_post_treat = np.linalg.inv(XpX_post_treat) @@ -1346,22 +1301,28 @@ def _safe_ratio(num, denom): m_riesz_treat_pre = np.mean(riesz_treat_pre) m_riesz_treat_post = np.mean(riesz_treat_post) - inf_treat_pre = (eta_treat_pre - riesz_treat_pre * att_treat_pre - / m_riesz_treat_pre) if m_riesz_treat_pre > 0 \ + inf_treat_pre = ( + (eta_treat_pre - riesz_treat_pre * att_treat_pre / m_riesz_treat_pre) + if m_riesz_treat_pre > 0 else np.zeros(n) - inf_treat_post = (eta_treat_post - riesz_treat_post * att_treat_post - / m_riesz_treat_post) if m_riesz_treat_post > 0 \ + ) + inf_treat_post = ( + (eta_treat_post - riesz_treat_post * att_treat_post / m_riesz_treat_post) + if m_riesz_treat_post > 0 else np.zeros(n) + ) # OR correction for treated - M1_post = (-np.mean( - (riesz_treat_post * post)[:, None] * covX, axis=0) - / m_riesz_treat_post) if 
m_riesz_treat_post > 0 \ + M1_post = ( + (-np.mean((riesz_treat_post * post)[:, None] * covX, axis=0) / m_riesz_treat_post) + if m_riesz_treat_post > 0 else np.zeros(covX.shape[1]) - M1_pre = (-np.mean( - (riesz_treat_pre * (1 - post))[:, None] * covX, axis=0) - / m_riesz_treat_pre) if m_riesz_treat_pre > 0 \ + ) + M1_pre = ( + (-np.mean((riesz_treat_pre * (1 - post))[:, None] * covX, axis=0) / m_riesz_treat_pre) + if m_riesz_treat_pre > 0 else np.zeros(covX.shape[1]) + ) inf_treat_or_post = asy_lin_rep_ols_post @ M1_post inf_treat_or_pre = asy_lin_rep_ols_pre @ M1_pre @@ -1369,37 +1330,54 @@ def _safe_ratio(num, denom): m_riesz_control_pre = np.mean(riesz_control_pre) m_riesz_control_post = np.mean(riesz_control_post) - inf_control_pre = (eta_control_pre - - riesz_control_pre * att_control_pre - / m_riesz_control_pre) if m_riesz_control_pre > 0 \ + inf_control_pre = ( + (eta_control_pre - riesz_control_pre * att_control_pre / m_riesz_control_pre) + if m_riesz_control_pre > 0 else np.zeros(n) - inf_control_post = (eta_control_post - - riesz_control_post * att_control_post - / m_riesz_control_post) if m_riesz_control_post > 0 \ + ) + inf_control_post = ( + (eta_control_post - riesz_control_post * att_control_post / m_riesz_control_post) + if m_riesz_control_post > 0 else np.zeros(n) + ) # PS correction for control - M2_pre = (np.mean( - (riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] - * covX, axis=0) - / m_riesz_control_pre) if m_riesz_control_pre > 0 \ + M2_pre = ( + ( + np.mean( + (riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX, axis=0 + ) + / m_riesz_control_pre + ) + if m_riesz_control_pre > 0 else np.zeros(covX.shape[1]) - M2_post = (np.mean( - (riesz_control_post * (y - or_ctrl - att_control_post))[:, None] - * covX, axis=0) - / m_riesz_control_post) if m_riesz_control_post > 0 \ + ) + M2_post = ( + ( + np.mean( + (riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX, axis=0 + ) + / m_riesz_control_post + 
) + if m_riesz_control_post > 0 else np.zeros(covX.shape[1]) + ) inf_control_ps = asy_lin_rep_ps @ (M2_post - M2_pre) # OR correction for control - M3_post = (-np.mean( - (riesz_control_post * post)[:, None] * covX, axis=0) - / m_riesz_control_post) if m_riesz_control_post > 0 \ + M3_post = ( + (-np.mean((riesz_control_post * post)[:, None] * covX, axis=0) / m_riesz_control_post) + if m_riesz_control_post > 0 else np.zeros(covX.shape[1]) - M3_pre = (-np.mean( - (riesz_control_pre * (1 - post))[:, None] * covX, axis=0) - / m_riesz_control_pre) if m_riesz_control_pre > 0 \ + ) + M3_pre = ( + ( + -np.mean((riesz_control_pre * (1 - post))[:, None] * covX, axis=0) + / m_riesz_control_pre + ) + if m_riesz_control_pre > 0 else np.zeros(covX.shape[1]) + ) inf_control_or_post = asy_lin_rep_ols_post @ M3_post inf_control_or_pre = asy_lin_rep_ols_pre @ M3_pre @@ -1408,41 +1386,46 @@ def _safe_ratio(num, denom): m_riesz_dt1 = np.mean(riesz_dt1) m_riesz_dt0 = np.mean(riesz_dt0) - inf_eff1 = ((eta_d_post - riesz_d * att_d_post / m_riesz_d) - if m_riesz_d > 0 else np.zeros(n)) - inf_eff2 = ((eta_dt1_post - riesz_dt1 * att_dt1_post / m_riesz_dt1) - if m_riesz_dt1 > 0 else np.zeros(n)) - inf_eff3 = ((eta_d_pre - riesz_d * att_d_pre / m_riesz_d) - if m_riesz_d > 0 else np.zeros(n)) - inf_eff4 = ((eta_dt0_pre - riesz_dt0 * att_dt0_pre / m_riesz_dt0) - if m_riesz_dt0 > 0 else np.zeros(n)) + inf_eff1 = (eta_d_post - riesz_d * att_d_post / m_riesz_d) if m_riesz_d > 0 else np.zeros(n) + inf_eff2 = ( + (eta_dt1_post - riesz_dt1 * att_dt1_post / m_riesz_dt1) + if m_riesz_dt1 > 0 + else np.zeros(n) + ) + inf_eff3 = (eta_d_pre - riesz_d * att_d_pre / m_riesz_d) if m_riesz_d > 0 else np.zeros(n) + inf_eff4 = ( + (eta_dt0_pre - riesz_dt0 * att_dt0_pre / m_riesz_dt0) + if m_riesz_dt0 > 0 + else np.zeros(n) + ) inf_eff = (inf_eff1 - inf_eff2) - (inf_eff3 - inf_eff4) # OR combination - mom_post = np.mean( - (riesz_d[:, None] / m_riesz_d - - riesz_dt1[:, None] / m_riesz_dt1) * covX, - axis=0, - ) 
if (m_riesz_d > 0 and m_riesz_dt1 > 0) \ + mom_post = ( + np.mean( + (riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, + axis=0, + ) + if (m_riesz_d > 0 and m_riesz_dt1 > 0) else np.zeros(covX.shape[1]) - mom_pre = np.mean( - (riesz_d[:, None] / m_riesz_d - - riesz_dt0[:, None] / m_riesz_dt0) * covX, - axis=0, - ) if (m_riesz_d > 0 and m_riesz_dt0 > 0) \ + ) + mom_pre = ( + np.mean( + (riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, + axis=0, + ) + if (m_riesz_d > 0 and m_riesz_dt0 > 0) else np.zeros(covX.shape[1]) - inf_or_post = ((asy_lin_rep_ols_post_treat - asy_lin_rep_ols_post) - @ mom_post) - inf_or_pre = ((asy_lin_rep_ols_pre_treat - asy_lin_rep_ols_pre) - @ mom_pre) + ) + inf_or_post = (asy_lin_rep_ols_post_treat - asy_lin_rep_ols_post) @ mom_post + inf_or_pre = (asy_lin_rep_ols_pre_treat - asy_lin_rep_ols_pre) @ mom_pre inf_treat_or = inf_treat_or_post + inf_treat_or_pre inf_control_or = inf_control_or_post + inf_control_or_pre inf_or = inf_or_post - inf_or_pre inf_treat = inf_treat_post - inf_treat_pre + inf_treat_or - inf_control = (inf_control_post - inf_control_pre - + inf_control_ps + inf_control_or) + inf_control = inf_control_post - inf_control_pre + inf_control_ps + inf_control_or inf_func = inf_treat - inf_control + inf_eff + inf_or return att, inf_func diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 42548c8..500fc37 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -393,6 +393,17 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation) - Use `base_period="universal"` for methodologically comparable pre-treatment effects - Post-treatment effects match regardless of base_period setting +- Propensity score estimation: + - Algorithm: IRLS (Fisher scoring), matching R's `glm(family=binomial)` default + - **Note:** Uses IRLS 
(Fisher scoring) for propensity score estimation, consistent + with R's `did::att_gt()` which uses `glm(family=binomial)` internally + - Near-separation detection: Warns when predicted probabilities are within 1e-5 + of 0 or 1, or when IRLS fails to converge + - Trimming: Propensity scores clipped to `[pscore_trim, 1-pscore_trim]` (default + 0.01) before weight computation. Warning emitted when scores are trimmed. + - Fallback: If IRLS fails entirely (LinAlgError/ValueError), falls back to + unconditional propensity score with warning. Exception: when + `rank_deficient_action="error"`, the error is re-raised instead of falling back. - Control group with `control_group="not_yet_treated"`: - Always excludes cohort g from controls when computing ATT(g,t) - This applies to both pre-treatment (t < g) and post-treatment (t >= g) periods @@ -1180,7 +1191,8 @@ has no additional effect. - Cluster IDs: must not contain NaN (raises `ValueError`) - Overlap warning: emitted when >5% of observations are trimmed at pscore bounds (IPW/DR only) - Propensity score estimation failure: falls back to unconditional probability P(subgroup=4), - sets hessian=None (skipping PS correction in influence function), emits UserWarning + sets hessian=None (skipping PS correction in influence function), emits UserWarning. + Exception: when `rank_deficient_action="error"`, the error is re-raised instead of falling back. 
- Collinear covariates: detected via pivoted QR in `solve_ols()`, action controlled by `rank_deficient_action` ("warn", "error", "silent") - Non-finite influence function values (e.g., from extreme propensity scores in IPW/DR diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 07efd46..db98005 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -236,14 +236,20 @@ def test_rank_deficient_produces_nan_for_dropped_columns(self): # Non-NaN coefficients should be finite and reasonable finite_coef = coef[~nan_mask] - assert np.all(np.isfinite(finite_coef)), f"Finite coefficients contain non-finite values: {finite_coef}" - assert np.all(np.abs(finite_coef) < 1e6), f"Finite coefficients are unreasonably large: {finite_coef}" + assert np.all( + np.isfinite(finite_coef) + ), f"Finite coefficients contain non-finite values: {finite_coef}" + assert np.all( + np.abs(finite_coef) < 1e6 + ), f"Finite coefficients are unreasonably large: {finite_coef}" # VCoV should have NaN for dropped column's row and column assert vcov is not None dropped_idx = np.where(nan_mask)[0][0] assert np.all(np.isnan(vcov[dropped_idx, :])), "VCoV row for dropped column should be NaN" - assert np.all(np.isnan(vcov[:, dropped_idx])), "VCoV column for dropped column should be NaN" + assert np.all( + np.isnan(vcov[:, dropped_idx]) + ), "VCoV column for dropped column should be NaN" # VCoV for identified coefficients should be finite kept_idx = np.where(~nan_mask)[0] @@ -292,13 +298,14 @@ def test_rank_deficient_column_names_in_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - coef, resid, vcov = solve_ols( - X, y, - column_names=["intercept", "x1", "x2_collinear"] - ) + coef, resid, vcov = solve_ols(X, y, column_names=["intercept", "x1", "x2_collinear"]) assert len(w) == 1 # Column name should appear in warning (not just index) - assert "x2_collinear" in str(w[0].message) or "intercept" in str(w[0].message) or "x1" in str(w[0].message) + 
assert ( + "x2_collinear" in str(w[0].message) + or "intercept" in str(w[0].message) + or "x1" in str(w[0].message) + ) def test_skip_rank_check_bypasses_qr_decomposition(self): """Test that skip_rank_check=True skips QR rank detection. @@ -398,18 +405,18 @@ def test_multiperiod_like_design_full_rank(self): assert np.all(np.isfinite(coef)), f"Full-rank matrix: coefficients should be finite" assert np.all(np.abs(coef) < 1e6), f"Coefficients are unreasonably large: {coef}" # The treatment effect coefficient (last one) should be close to true effect - assert abs(coef[-1] - true_effect) < 2.0, ( - f"Treatment effect {coef[-1]} is too far from true {true_effect}" - ) + assert ( + abs(coef[-1] - true_effect) < 2.0 + ), f"Treatment effect {coef[-1]} is too far from true {true_effect}" else: # If rank-deficient, check that identified coefficients are valid finite_coef = coef[~np.isnan(coef)] assert np.all(np.isfinite(finite_coef)), f"Identified coefficients should be finite" # If treatment effect is identified, check it if not np.isnan(coef[-1]): - assert abs(coef[-1] - true_effect) < 2.0, ( - f"Treatment effect {coef[-1]} is too far from true {true_effect}" - ) + assert ( + abs(coef[-1] - true_effect) < 2.0 + ), f"Treatment effect {coef[-1]} is too far from true {true_effect}" def test_single_cluster_error(self): """Test that single cluster raises error.""" @@ -436,10 +443,12 @@ def test_singleton_clusters_included_in_variance(self): # Create clusters: one large cluster (50 obs), 50 singleton clusters # Total: 51 clusters, 50 of which are singletons - cluster_ids = np.concatenate([ - np.zeros(50), # Large cluster (id=0) - np.arange(1, 51) # 50 singleton clusters (ids 1-50) - ]) + cluster_ids = np.concatenate( + [ + np.zeros(50), # Large cluster (id=0) + np.arange(1, 51), # 50 singleton clusters (ids 1-50) + ] + ) coef, resid, vcov = solve_ols(X, y, cluster_ids=cluster_ids) @@ -454,15 +463,15 @@ def test_singleton_clusters_included_in_variance(self): # Compare to case 
without singletons (only large clusters) # With fewer clusters, variance should be DIFFERENT (not necessarily larger) - cluster_ids_no_singletons = np.concatenate([ - np.zeros(50), # Cluster 0 - np.ones(50) # Cluster 1 - ]) + cluster_ids_no_singletons = np.concatenate( + [np.zeros(50), np.ones(50)] # Cluster 0 # Cluster 1 + ) _, _, vcov_no_singletons = solve_ols(X, y, cluster_ids=cluster_ids_no_singletons) # The two variance estimates should differ (singletons change the calculation) - assert not np.allclose(vcov, vcov_no_singletons), \ - "Singleton clusters should affect variance estimation" + assert not np.allclose( + vcov, vcov_no_singletons + ), "Singleton clusters should affect variance estimation" class TestComputeRobustVcov: @@ -543,12 +552,11 @@ def mock_rust_vcov(*args, **kwargs): # Verify warning was emitted instability_warnings = [ - w for w in caught_warnings - if "numerical instability" in str(w.message).lower() + w for w in caught_warnings if "numerical instability" in str(w.message).lower() ] - assert len(instability_warnings) == 1, ( - f"Expected 1 numerical instability warning, got {len(instability_warnings)}" - ) + assert ( + len(instability_warnings) == 1 + ), f"Expected 1 numerical instability warning, got {len(instability_warnings)}" # Verify fallback produced valid vcov matrix assert vcov.shape == (X.shape[1], X.shape[1]) @@ -730,23 +738,20 @@ def test_is_significant_default_alpha(self): """Test is_significant with default alpha.""" # Significant at 0.05 result = InferenceResult( - coefficient=2.0, se=0.5, t_stat=4.0, p_value=0.001, - conf_int=(1.0, 3.0), alpha=0.05 + coefficient=2.0, se=0.5, t_stat=4.0, p_value=0.001, conf_int=(1.0, 3.0), alpha=0.05 ) assert result.is_significant() is True # Not significant at 0.05 result2 = InferenceResult( - coefficient=0.5, se=0.5, t_stat=1.0, p_value=0.3, - conf_int=(-0.5, 1.5), alpha=0.05 + coefficient=0.5, se=0.5, t_stat=1.0, p_value=0.3, conf_int=(-0.5, 1.5), alpha=0.05 ) assert 
result2.is_significant() is False def test_is_significant_custom_alpha(self): """Test is_significant with custom alpha override.""" result = InferenceResult( - coefficient=2.0, se=0.5, t_stat=4.0, p_value=0.02, - conf_int=(1.0, 3.0), alpha=0.05 + coefficient=2.0, se=0.5, t_stat=4.0, p_value=0.02, conf_int=(1.0, 3.0), alpha=0.05 ) # Significant at 0.05 (default) @@ -759,44 +764,44 @@ def test_significance_stars(self): """Test significance_stars returns correct stars.""" # p < 0.001 -> *** result = InferenceResult( - coefficient=1.0, se=0.1, t_stat=10.0, p_value=0.0001, - conf_int=(0.8, 1.2) + coefficient=1.0, se=0.1, t_stat=10.0, p_value=0.0001, conf_int=(0.8, 1.2) ) assert result.significance_stars() == "***" # p < 0.01 -> ** result2 = InferenceResult( - coefficient=1.0, se=0.2, t_stat=5.0, p_value=0.005, - conf_int=(0.6, 1.4) + coefficient=1.0, se=0.2, t_stat=5.0, p_value=0.005, conf_int=(0.6, 1.4) ) assert result2.significance_stars() == "**" # p < 0.05 -> * result3 = InferenceResult( - coefficient=1.0, se=0.3, t_stat=3.0, p_value=0.03, - conf_int=(0.4, 1.6) + coefficient=1.0, se=0.3, t_stat=3.0, p_value=0.03, conf_int=(0.4, 1.6) ) assert result3.significance_stars() == "*" # p < 0.1 -> . result4 = InferenceResult( - coefficient=1.0, se=0.4, t_stat=2.5, p_value=0.08, - conf_int=(0.2, 1.8) + coefficient=1.0, se=0.4, t_stat=2.5, p_value=0.08, conf_int=(0.2, 1.8) ) assert result4.significance_stars() == "." 
# p >= 0.1 -> "" result5 = InferenceResult( - coefficient=1.0, se=0.5, t_stat=2.0, p_value=0.15, - conf_int=(0.0, 2.0) + coefficient=1.0, se=0.5, t_stat=2.0, p_value=0.15, conf_int=(0.0, 2.0) ) assert result5.significance_stars() == "" def test_to_dict(self): """Test to_dict returns all fields.""" result = InferenceResult( - coefficient=2.5, se=0.5, t_stat=5.0, p_value=0.001, - conf_int=(1.52, 3.48), df=100, alpha=0.05 + coefficient=2.5, + se=0.5, + t_stat=5.0, + p_value=0.001, + conf_int=(1.52, 3.48), + df=100, + alpha=0.05, ) d = result.to_dict() @@ -1329,8 +1334,10 @@ def test_rank_deficient_action_warn_default(self): reg.fit(X, y) # Should have a warning about rank deficiency assert len(w) > 0, "Expected warning about rank deficiency" - assert any("Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() - for x in w), f"Expected rank-deficient warning, got: {[str(x.message) for x in w]}" + assert any( + "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() + for x in w + ), f"Expected rank-deficient warning, got: {[str(x.message) for x in w]}" class TestNumericalStability: @@ -1411,17 +1418,16 @@ def test_did_estimator_produces_valid_results(self): # Create reproducible test data np.random.seed(42) n = 200 - data = pd.DataFrame({ - "unit": np.repeat(range(20), 10), - "time": np.tile(range(10), 20), - "treated": np.repeat([0] * 10 + [1] * 10, 10), - "post": np.tile([0] * 5 + [1] * 5, 20), - }) - # True ATT = 2.0 - data["outcome"] = ( - np.random.randn(n) - + 2.0 * data["treated"] * data["post"] + data = pd.DataFrame( + { + "unit": np.repeat(range(20), 10), + "time": np.tile(range(10), 20), + "treated": np.repeat([0] * 10 + [1] * 10, 10), + "post": np.tile([0] * 5 + [1] * 5, 20), + } ) + # True ATT = 2.0 + data["outcome"] = np.random.randn(n) + 2.0 * data["treated"] * data["post"] # Fit estimator did = DifferenceInDifferences(robust=True) @@ -1444,11 +1450,13 @@ def 
test_twfe_estimator_produces_valid_results(self): n_times = 6 n = n_units * n_times - data = pd.DataFrame({ - "unit": np.repeat(np.arange(n_units), n_times), - "time": np.tile(np.arange(n_times), n_units), - "treated": np.repeat(np.random.binomial(1, 0.5, n_units), n_times), - }) + data = pd.DataFrame( + { + "unit": np.repeat(np.arange(n_units), n_times), + "time": np.tile(np.arange(n_times), n_units), + "treated": np.repeat(np.random.binomial(1, 0.5, n_units), n_times), + } + ) data["post"] = (data["time"] >= 3).astype(int) # Add unit and time effects with true ATT = 1.5 @@ -1462,9 +1470,7 @@ def test_twfe_estimator_produces_valid_results(self): ) twfe = TwoWayFixedEffects() - result = twfe.fit( - data, outcome="y", treatment="treated", time="post", unit="unit" - ) + result = twfe.fit(data, outcome="y", treatment="treated", time="post", unit="unit") # Should produce valid results assert result.se > 0 @@ -1480,10 +1486,12 @@ def test_sun_abraham_estimator_produces_valid_results(self): n_times = 10 n = n_units * n_times - data = pd.DataFrame({ - "unit": np.repeat(np.arange(n_units), n_times), - "time": np.tile(np.arange(n_times), n_units), - }) + data = pd.DataFrame( + { + "unit": np.repeat(np.arange(n_units), n_times), + "time": np.tile(np.arange(n_times), n_units), + } + ) # Staggered treatment timing first_treat_map = {} @@ -1500,9 +1508,7 @@ def test_sun_abraham_estimator_produces_valid_results(self): data["y"] = np.random.randn(n) + data["treated"] * 2.0 sa = SunAbraham(n_bootstrap=0) - result = sa.fit( - data, outcome="y", unit="unit", time="time", first_treat="first_treat" - ) + result = sa.fit(data, outcome="y", unit="unit", time="time", first_treat="first_treat") # Should produce valid results assert result.overall_se > 0 @@ -1510,6 +1516,165 @@ def test_sun_abraham_estimator_produces_valid_results(self): assert len(result.event_study_effects) > 0 +class TestSolveLogit: + """Tests for IRLS logistic regression (solve_logit).""" + + def 
test_irls_coefficients_well_conditioned(self): + """IRLS produces correct coefficients on well-conditioned data.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 500 + X = rng.standard_normal((n, 3)) + beta_true = np.array([0.5, -1.0, 0.8]) + z = X @ beta_true + y = (rng.random(n) < 1 / (1 + np.exp(-z))).astype(float) + + beta, probs = solve_logit(X, y) + # beta[0] is intercept, beta[1:] are coefficients + assert beta.shape == (4,) + assert probs.shape == (n,) + # Coefficients should be close to true values (intercept ~0) + assert np.abs(beta[0]) < 1.0, "Intercept should be near zero" + assert np.allclose( + beta[1:], beta_true, atol=0.3 + ), f"Coefficients {beta[1:]} not close to {beta_true}" + + def test_irls_convergence(self): + """IRLS converges without warnings on standard data.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(123) + n = 200 + X = rng.standard_normal((n, 2)) + y = (rng.random(n) < 0.5).astype(float) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + beta, probs = solve_logit(X, y) + + convergence_warns = [x for x in w if "did not converge" in str(x.message)] + assert len(convergence_warns) == 0 + + def test_irls_non_convergence_warning(self): + """IRLS warns when max_iter=1 prevents convergence.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 100 + X = rng.standard_normal((n, 2)) + y = (rng.random(n) < 0.5).astype(float) + + with pytest.warns(UserWarning, match="did not converge"): + solve_logit(X, y, max_iter=1) + + def test_near_separation_warning(self): + """Warns about near-separation when covariate perfectly predicts outcome.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 200 + # Create near-perfect separation: large coefficient -> probs near 0/1 + X = rng.standard_normal((n, 1)) + y = (X[:, 0] > 0).astype(float) + # Add a tiny bit of 
noise to avoid exact separation + flip_idx = rng.choice(n, size=3, replace=False) + y[flip_idx] = 1 - y[flip_idx] + + with pytest.warns(UserWarning, match="Near-separation detected"): + solve_logit(X, y) + + def test_predicted_probabilities_valid(self): + """Predicted probabilities are in (0, 1).""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 100 + X = rng.standard_normal((n, 2)) + y = (rng.random(n) < 0.5).astype(float) + + _, probs = solve_logit(X, y) + assert np.all(probs > 0) and np.all(probs < 1) + + def test_rank_deficient_design_matrix(self): + """Handles rank-deficient X in logistic regression.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 100 + x1 = rng.standard_normal(n) + # x2 is a duplicate of x1 -> rank deficient + X = np.column_stack([x1, x1]) + y = (rng.random(n) < 0.5).astype(float) + + with pytest.warns(UserWarning, match="Rank-deficient"): + beta, probs = solve_logit(X, y) + + assert beta.shape == (3,) # intercept + 2 features + assert probs.shape == (n,) + + def test_rank_deficient_action_silent(self): + """rank_deficient_action='silent' suppresses warning on rank-deficient X.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 100 + x1 = rng.standard_normal(n) + X = np.column_stack([x1, x1]) # rank deficient + y = (rng.random(n) < 0.5).astype(float) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + beta, probs = solve_logit(X, y, rank_deficient_action="silent") + + rank_warns = [x for x in w if "Rank-deficient" in str(x.message)] + assert len(rank_warns) == 0 + assert beta.shape == (3,) + assert probs.shape == (n,) + + def test_rank_deficient_action_error(self): + """rank_deficient_action='error' raises ValueError on rank-deficient X.""" + from diff_diff.linalg import solve_logit + + rng = np.random.default_rng(42) + n = 100 + x1 = rng.standard_normal(n) + X = 
np.column_stack([x1, x1]) # rank deficient + y = (rng.random(n) < 0.5).astype(float) + + with pytest.raises(ValueError, match="Rank-deficient"): + solve_logit(X, y, rank_deficient_action="error") + + +class TestCheckPropensityDiagnostics: + """Tests for propensity score diagnostic warnings.""" + + def test_no_warning_normal_scores(self): + """No warning when all scores are within bounds.""" + from diff_diff.linalg import _check_propensity_diagnostics + + import warnings + + pscore = np.array([0.3, 0.5, 0.7, 0.4, 0.6]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _check_propensity_diagnostics(pscore, trim_bound=0.01) + user_warns = [x for x in w if issubclass(x.category, UserWarning)] + assert len(user_warns) == 0 + + def test_warning_extreme_scores(self): + """Warns when propensity scores are near 0 or 1.""" + from diff_diff.linalg import _check_propensity_diagnostics + + pscore = np.array([0.001, 0.5, 0.999, 0.3, 0.7]) + with pytest.warns(UserWarning, match="outside"): + _check_propensity_diagnostics(pscore, trim_bound=0.01) + + class TestNoDotRuntimeWarnings: """Verify np.dot replacement avoids Apple M4 BLAS ufunc FPE bug.""" diff --git a/tests/test_methodology_triple_diff.py b/tests/test_methodology_triple_diff.py index 29e9dc0..861c751 100644 --- a/tests/test_methodology_triple_diff.py +++ b/tests/test_methodology_triple_diff.py @@ -51,16 +51,15 @@ def _check_triplediff_available() -> bool: try: result = subprocess.run( [ - "Rscript", "-e", + "Rscript", + "-e", "library(triplediff); library(jsonlite); cat('OK')", ], capture_output=True, text=True, timeout=30, ) - _triplediff_available_cache = ( - result.returncode == 0 and "OK" in result.stdout - ) + _triplediff_available_cache = result.returncode == 0 and "OK" in result.stdout except (subprocess.TimeoutExpired, FileNotFoundError, OSError): _triplediff_available_cache = False return _triplediff_available_cache @@ -83,9 +82,7 @@ def 
require_triplediff(triplediff_available): # Data Helpers # ============================================================================= -_R_RESULTS_PATH = ( - _REPO_ROOT / "benchmarks" / "data" / "synthetic" / "ddd_r_results.json" -) +_R_RESULTS_PATH = _REPO_ROOT / "benchmarks" / "data" / "synthetic" / "ddd_r_results.json" def _load_r_results(): @@ -112,32 +109,34 @@ def _generate_hand_calculable_ddd() -> pd.DataFrame: DiD(control): (3 - 2) = 1 DDD = 4 - 1 = 3.0 """ - data = pd.DataFrame({ - "outcome": [10, 10, 18, 18, 6, 6, 10, 10, 5, 5, 8, 8, 3, 3, 5, 5], - "group": [ 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - "partition":[1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - "time": [ 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - "unit_id": list(range(16)), - }) + data = pd.DataFrame( + { + "outcome": [10, 10, 18, 18, 6, 6, 10, 10, 5, 5, 8, 8, 3, 3, 5, 5], + "group": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + "partition": [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + "time": [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + "unit_id": list(range(16)), + } + ) return data def _load_r_dgp_data(dgp_num: int) -> pd.DataFrame: """Load R-generated DGP data, mapping columns to Python convention.""" - csv_path = ( - _REPO_ROOT / "benchmarks" / "data" / "synthetic" / f"ddd_r_dgp{dgp_num}.csv" - ) + csv_path = _REPO_ROOT / "benchmarks" / "data" / "synthetic" / f"ddd_r_dgp{dgp_num}.csv" if not csv_path.exists(): pytest.skip(f"R DGP{dgp_num} data CSV not available") df = pd.read_csv(csv_path) # Map R columns to Python convention - df = df.rename(columns={ - "y": "outcome", - "state": "group", - "partition": "partition", - "time": "time", - "id": "unit_id", - }) + df = df.rename( + columns={ + "y": "outcome", + "state": "group", + "partition": "partition", + "time": "time", + "id": "unit_id", + } + ) # R uses time in {1, 2}, map to {0, 1} df["time"] = (df["time"] - 1).astype(int) return df @@ -152,7 +151,7 @@ def _run_r_triplediff( escaped_path = 
data_path.replace("\\", "/") xformla = "~cov1+cov2+cov3+cov4" if covariates else "~1" - r_script = f''' + r_script = f""" suppressMessages(library(triplediff)) suppressMessages(library(jsonlite)) @@ -179,7 +178,7 @@ def _run_r_triplediff( ) cat(toJSON(output, pretty = TRUE, digits = 15)) - ''' + """ result = subprocess.run( ["Rscript", "-e", r_script], @@ -217,7 +216,9 @@ def test_att_hand_calculation_no_covariates(self): time="time", ) np.testing.assert_allclose( - results.att, 3.0, atol=1e-10, + results.att, + 3.0, + atol=1e-10, err_msg=f"ATT ({method}) should be 3.0 by hand calculation", ) @@ -231,12 +232,18 @@ def test_att_reg_matches_ols_interaction(self): T = data["time"].values.astype(float) y = data["outcome"].values.astype(float) - X = np.column_stack([ - np.ones(len(y)), - G, P, T, - G * P, G * T, P * T, - G * P * T, - ]) + X = np.column_stack( + [ + np.ones(len(y)), + G, + P, + T, + G * P, + G * T, + P * T, + G * P * T, + ] + ) beta_ols = np.linalg.lstsq(X, y, rcond=None)[0] ols_att = beta_ols[7] # coefficient on G*P*T @@ -251,7 +258,9 @@ def test_att_reg_matches_ols_interaction(self): ) np.testing.assert_allclose( - results.att, ols_att, rtol=1e-6, + results.att, + ols_att, + rtol=1e-6, err_msg="RA ATT should match G*P*T OLS coefficient (no covariates)", ) @@ -272,11 +281,15 @@ def test_all_methods_agree_no_covariates(self): atts[method] = results.att np.testing.assert_allclose( - atts["ipw"], atts["reg"], rtol=1e-6, + atts["ipw"], + atts["reg"], + rtol=1e-6, err_msg="IPW and REG ATT should agree without covariates", ) np.testing.assert_allclose( - atts["dr"], atts["reg"], rtol=1e-6, + atts["dr"], + atts["reg"], + rtol=1e-6, err_msg="DR and REG ATT should agree without covariates", ) @@ -297,11 +310,15 @@ def test_all_methods_se_agree_no_covariates(self): ses[method] = results.se np.testing.assert_allclose( - ses["ipw"], ses["reg"], rtol=1e-4, + ses["ipw"], + ses["reg"], + rtol=1e-4, err_msg="IPW and REG SE should agree without covariates", ) 
np.testing.assert_allclose( - ses["dr"], ses["reg"], rtol=1e-4, + ses["dr"], + ses["reg"], + rtol=1e-4, err_msg="DR and REG SE should agree without covariates", ) @@ -349,7 +366,10 @@ def test_safe_inference_used(self): df = max(df, 1) t_stat, p_value, conf_int = safe_inference( - results.att, results.se, alpha=0.05, df=df, + results.att, + results.se, + alpha=0.05, + df=df, ) np.testing.assert_allclose(results.t_stat, t_stat, rtol=1e-10) @@ -384,7 +404,9 @@ def test_cell_means_match_direct_computation(self): for cell, expected in expected_means.items(): actual = results.group_means[cell] np.testing.assert_allclose( - actual, expected, atol=1e-10, + actual, + expected, + atol=1e-10, err_msg=f"Cell mean mismatch for {cell}", ) @@ -430,12 +452,16 @@ def test_att_no_covariates_matches_r_dgp1(self, r_results, method): # Use atol for near-zero ATTs if abs(r_att) < 0.1: np.testing.assert_allclose( - results.att, r_att, atol=0.05, + results.att, + r_att, + atol=0.05, err_msg=f"ATT ({method} nocov DGP1): Py={results.att:.6f}, R={r_att:.6f}", ) else: np.testing.assert_allclose( - results.att, r_att, rtol=0.01, + results.att, + r_att, + rtol=0.01, err_msg=f"ATT ({method} nocov DGP1): Py={results.att:.6f}, R={r_att:.6f}", ) @@ -456,7 +482,9 @@ def test_se_no_covariates_matches_r_dgp1(self, r_results, method): ) np.testing.assert_allclose( - results.se, r_se, rtol=0.01, + results.se, + r_se, + rtol=0.01, err_msg=f"SE ({method} nocov DGP1): Py={results.se:.6f}, R={r_se:.6f}", ) @@ -480,12 +508,16 @@ def test_att_with_covariates_matches_r_dgp1(self, r_results, method): if abs(r_att) < 0.1: np.testing.assert_allclose( - results.att, r_att, atol=0.05, + results.att, + r_att, + atol=0.05, err_msg=f"ATT ({method} cov DGP1): Py={results.att:.6f}, R={r_att:.6f}", ) else: np.testing.assert_allclose( - results.att, r_att, rtol=0.01, + results.att, + r_att, + rtol=0.01, err_msg=f"ATT ({method} cov DGP1): Py={results.att:.6f}, R={r_att:.6f}", ) @@ -508,7 +540,9 @@ def 
test_se_with_covariates_matches_r_dgp1(self, r_results, method): ) np.testing.assert_allclose( - results.se, r_se, rtol=0.01, + results.se, + r_se, + rtol=0.01, err_msg=f"SE ({method} cov DGP1): Py={results.se:.6f}, R={r_se:.6f}", ) @@ -536,18 +570,24 @@ def test_dr_robust_across_dgp_types(self, r_results, dgp): # ATT check if abs(r_att) < 0.1: np.testing.assert_allclose( - results.att, r_att, atol=0.05, + results.att, + r_att, + atol=0.05, err_msg=f"DR ATT (DGP{dgp} {cov_suffix}): Py={results.att:.6f}, R={r_att:.6f}", ) else: np.testing.assert_allclose( - results.att, r_att, rtol=0.01, + results.att, + r_att, + rtol=0.01, err_msg=f"DR ATT (DGP{dgp} {cov_suffix}): Py={results.att:.6f}, R={r_att:.6f}", ) # SE check np.testing.assert_allclose( - results.se, r_se, rtol=0.01, + results.se, + r_se, + rtol=0.01, err_msg=f"DR SE (DGP{dgp} {cov_suffix}): Py={results.se:.6f}, R={r_se:.6f}", ) @@ -572,13 +612,15 @@ def shared_data_csv(self, tmp_path_factory): csv_path = tmp_dir / "ddd_data.csv" # Map to R column convention - r_data = data.rename(columns={ - "outcome": "y", - "group": "state", - "partition": "partition", - "time": "time", - "unit_id": "id", - }) + r_data = data.rename( + columns={ + "outcome": "y", + "group": "state", + "partition": "partition", + "time": "time", + "unit_id": "id", + } + ) # R expects time in {1, 2} r_data["time"] = r_data["time"] + 1 # Add covariate columns named cov1-cov4 if they exist @@ -610,12 +652,16 @@ def test_live_att_no_cov(self, require_triplediff, shared_data_csv, method): if abs(r_att) < 0.1: np.testing.assert_allclose( - py_result.att, r_att, atol=0.05, + py_result.att, + r_att, + atol=0.05, err_msg=f"Live ATT ({method} nocov): Py={py_result.att:.6f}, R={r_att:.6f}", ) else: np.testing.assert_allclose( - py_result.att, r_att, rtol=0.01, + py_result.att, + r_att, + rtol=0.01, err_msg=f"Live ATT ({method} nocov): Py={py_result.att:.6f}, R={r_att:.6f}", ) @@ -637,7 +683,9 @@ def test_live_se_no_cov(self, require_triplediff, 
shared_data_csv, method): ) np.testing.assert_allclose( - py_result.se, r_se, rtol=0.01, + py_result.se, + r_se, + rtol=0.01, err_msg=f"Live SE ({method} nocov): Py={py_result.se:.6f}, R={r_se:.6f}", ) @@ -665,15 +713,17 @@ def test_small_sample_sizes(self): ) assert np.isfinite(results.att), f"ATT should be finite ({method})" - assert np.isfinite(results.se) and results.se > 0, ( - f"SE should be positive and finite ({method})" - ) + assert ( + np.isfinite(results.se) and results.se > 0 + ), f"SE should be positive and finite ({method})" assert results.n_obs == 24 # 8 cells × 3 def test_zero_treatment_effect(self): """ATT near zero when true effect is zero; inference still valid.""" data = generate_ddd_data( - n_per_cell=200, treatment_effect=0.0, seed=42, + n_per_cell=200, + treatment_effect=0.0, + seed=42, ) ddd = TripleDifference(estimation_method="dr") @@ -686,9 +736,9 @@ def test_zero_treatment_effect(self): ) # ATT should be near zero (within ~2 SE) - assert abs(results.att) < 2 * results.se, ( - f"ATT={results.att:.4f} too far from zero (SE={results.se:.4f})" - ) + assert ( + abs(results.att) < 2 * results.se + ), f"ATT={results.att:.4f} too far from zero (SE={results.se:.4f})" # Inference should still be valid assert np.isfinite(results.t_stat) assert 0 <= results.p_value <= 1 @@ -702,7 +752,7 @@ def test_pscore_trimming_active(self): unit_id = 0 # Heavily imbalanced: 5 in treated eligible, 200 in control ineligible sizes = { - (1, 1): 5, # G=1, P=1 (very small) + (1, 1): 5, # G=1, P=1 (very small) (1, 0): 200, # G=1, P=0 (large) (0, 1): 200, # G=0, P=1 (large) (0, 0): 200, # G=0, P=0 (large) @@ -713,13 +763,15 @@ def test_pscore_trimming_active(self): y = 10 + 2 * g + 1 * p + 0.5 * t + rng.normal(0, 1) if g == 1 and p == 1 and t == 1: y += 3.0 - records.append({ - "outcome": y, - "group": g, - "partition": p, - "time": t, - "unit_id": unit_id, - }) + records.append( + { + "outcome": y, + "group": g, + "partition": p, + "time": t, + "unit_id": unit_id, + } 
+ ) unit_id += 1 data = pd.DataFrame(records) @@ -738,16 +790,32 @@ def test_pscore_trimming_active(self): def test_nan_inference_when_se_zero(self): """All inference fields are NaN when SE is zero or invalid.""" # Create perfectly deterministic data (zero variance in all cells) - data = pd.DataFrame({ - "outcome": [10.0, 10.0, 18.0, 18.0, - 6.0, 6.0, 10.0, 10.0, - 5.0, 5.0, 8.0, 8.0, - 3.0, 3.0, 5.0, 5.0], - "group": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - "partition": [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - "time": [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - "unit_id": list(range(16)), - }) + data = pd.DataFrame( + { + "outcome": [ + 10.0, + 10.0, + 18.0, + 18.0, + 6.0, + 6.0, + 10.0, + 10.0, + 5.0, + 5.0, + 8.0, + 8.0, + 3.0, + 3.0, + 5.0, + 5.0, + ], + "group": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + "partition": [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + "time": [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + "unit_id": list(range(16)), + } + ) ddd = TripleDifference(estimation_method="reg") results = ddd.fit( @@ -761,17 +829,22 @@ def test_nan_inference_when_se_zero(self): # With zero within-cell variance, SE should be zero # and safe_inference should produce NaN t_stat/p_value if results.se == 0.0: - assert_nan_inference({ - "se": results.se, - "t_stat": results.t_stat, - "p_value": results.p_value, - "conf_int": results.conf_int, - }) + assert_nan_inference( + { + "se": results.se, + "t_stat": results.t_stat, + "p_value": results.p_value, + "conf_int": results.conf_int, + } + ) def test_large_treatment_effect(self): """Large treatment effect is detected correctly.""" data = generate_ddd_data( - n_per_cell=100, treatment_effect=50.0, noise_sd=1.0, seed=42, + n_per_cell=100, + treatment_effect=50.0, + noise_sd=1.0, + seed=42, ) ddd = TripleDifference(estimation_method="dr") @@ -784,7 +857,9 @@ def test_large_treatment_effect(self): ) np.testing.assert_allclose( - results.att, 50.0, rtol=0.1, + 
results.att, + 50.0, + rtol=0.1, err_msg=f"ATT={results.att:.2f} should be near 50.0", ) assert results.p_value < 0.001, "Large effect should be highly significant" @@ -792,7 +867,9 @@ def test_large_treatment_effect(self): def test_covariates_reduce_se(self): """Adding relevant covariates reduces SE.""" data = generate_ddd_data( - n_per_cell=200, seed=42, add_covariates=True, + n_per_cell=200, + seed=42, + add_covariates=True, ) # Without covariates @@ -835,7 +912,9 @@ def test_att_converges_to_true_effect(self, n_per_cell): """ATT converges to true effect as sample size increases.""" true_effect = 3.0 data = generate_ddd_data( - n_per_cell=n_per_cell, treatment_effect=true_effect, seed=42, + n_per_cell=n_per_cell, + treatment_effect=true_effect, + seed=42, ) ddd = TripleDifference(estimation_method="dr") @@ -870,9 +949,7 @@ def test_se_decreases_with_sample_size(self): # Quadrupling n should halve SE (approximately) se_ratio = ses[100] / ses[400] - assert 1.3 < se_ratio < 3.0, ( - f"SE ratio (n=100/n=400) = {se_ratio:.2f}, expected ~2.0" - ) + assert 1.3 < se_ratio < 3.0, f"SE ratio (n=100/n=400) = {se_ratio:.2f}, expected ~2.0" # ============================================================================= @@ -911,12 +988,16 @@ def test_att_nocov_all_dgps(self, r_results, dgp, method): if abs(r_att) < 0.1: np.testing.assert_allclose( - results.att, r_att, atol=0.05, + results.att, + r_att, + atol=0.05, err_msg=f"ATT ({method} nocov DGP{dgp})", ) else: np.testing.assert_allclose( - results.att, r_att, rtol=0.01, + results.att, + r_att, + rtol=0.01, err_msg=f"ATT ({method} nocov DGP{dgp})", ) @@ -938,7 +1019,9 @@ def test_se_nocov_all_dgps(self, r_results, dgp, method): ) np.testing.assert_allclose( - results.se, r_se, rtol=0.01, + results.se, + r_se, + rtol=0.01, err_msg=f"SE ({method} nocov DGP{dgp})", ) @@ -963,12 +1046,16 @@ def test_att_cov_all_dgps(self, r_results, dgp, method): if abs(r_att) < 0.1: np.testing.assert_allclose( - results.att, r_att, 
atol=0.05, + results.att, + r_att, + atol=0.05, err_msg=f"ATT ({method} cov DGP{dgp})", ) else: np.testing.assert_allclose( - results.att, r_att, rtol=0.01, + results.att, + r_att, + rtol=0.01, err_msg=f"ATT ({method} cov DGP{dgp})", ) @@ -992,7 +1079,9 @@ def test_se_cov_all_dgps(self, r_results, dgp, method): ) np.testing.assert_allclose( - results.se, r_se, rtol=0.01, + results.se, + r_se, + rtol=0.01, err_msg=f"SE ({method} cov DGP{dgp})", ) @@ -1011,12 +1100,16 @@ def test_get_params_returns_all_parameters(self): params = ddd.get_params() expected_keys = { - "estimation_method", "robust", "cluster", "alpha", - "pscore_trim", "rank_deficient_action", + "estimation_method", + "robust", + "cluster", + "alpha", + "pscore_trim", + "rank_deficient_action", } - assert expected_keys.issubset(params.keys()), ( - f"Missing params: {expected_keys - params.keys()}" - ) + assert expected_keys.issubset( + params.keys() + ), f"Missing params: {expected_keys - params.keys()}" def test_set_params_modifies_attributes(self): """set_params() modifies estimator attributes.""" @@ -1039,8 +1132,15 @@ def test_to_dict_contains_required_fields(self): ) d = results.to_dict() - for key in ["att", "se", "t_stat", "p_value", "n_obs", - "estimation_method", "inference_method"]: + for key in [ + "att", + "se", + "t_stat", + "p_value", + "n_obs", + "estimation_method", + "inference_method", + ]: assert key in d, f"Missing key '{key}' in to_dict()" def test_summary_contains_key_info(self): @@ -1117,14 +1217,13 @@ def test_rank_deficient_action_warn(self): covariates=["age", "age_dup"], ) rank_warnings = [ - x for x in w + x + for x in w if "rank" in str(x.message).lower() or "collinear" in str(x.message).lower() or "dependent" in str(x.message).lower() ] - assert len(rank_warnings) > 0, ( - "Expected rank deficiency warning for collinear covariates" - ) + assert len(rank_warnings) > 0, "Expected rank deficiency warning for collinear covariates" assert np.isfinite(result.att) def 
test_rank_deficient_action_silent(self): @@ -1133,7 +1232,8 @@ def test_rank_deficient_action_silent(self): data["age_dup"] = data["age"] ddd = TripleDifference( - estimation_method="reg", rank_deficient_action="silent", + estimation_method="reg", + rank_deficient_action="silent", ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -1146,14 +1246,13 @@ def test_rank_deficient_action_silent(self): covariates=["age", "age_dup"], ) rank_warnings = [ - x for x in w + x + for x in w if "rank" in str(x.message).lower() or "collinear" in str(x.message).lower() or "dependent" in str(x.message).lower() ] - assert len(rank_warnings) == 0, ( - "Expected no rank deficiency warnings with action='silent'" - ) + assert len(rank_warnings) == 0, "Expected no rank deficiency warnings with action='silent'" assert np.isfinite(result.att) def test_cluster_se_functional(self): @@ -1201,12 +1300,10 @@ def test_low_cell_count_warning(self): partition="partition", time="time", ) - low_count_warnings = [ - x for x in w if "low observation" in str(x.message).lower() - ] - assert len(low_count_warnings) > 0, ( - "Expected low observation count warning for n_per_cell=5" - ) + low_count_warnings = [x for x in w if "low observation" in str(x.message).lower()] + assert ( + len(low_count_warnings) > 0 + ), "Expected low observation count warning for n_per_cell=5" assert np.isfinite(result.att) def test_robust_param_is_noop(self): @@ -1238,8 +1335,7 @@ def test_cluster_single_cluster_raises(self): ddd = TripleDifference(estimation_method="dr", cluster="cluster_id") with pytest.raises(ValueError, match="at least 2 clusters"): - ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time") + ddd.fit(data, outcome="outcome", group="group", partition="partition", time="time") def test_cluster_nan_ids_raises(self): """NaN cluster IDs raise ValueError.""" @@ -1249,8 +1345,7 @@ def test_cluster_nan_ids_raises(self): ddd = 
TripleDifference(estimation_method="dr", cluster="cluster_id") with pytest.raises(ValueError, match="missing values"): - ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time") + ddd.fit(data, outcome="outcome", group="group", partition="partition", time="time") def test_overlap_warning_on_imbalanced_data(self): """Poor overlap triggers warning for IPW/DR.""" @@ -1265,21 +1360,35 @@ def test_overlap_warning_on_imbalanced_data(self): y = 10 + 2 * g + p + 0.5 * t + rng.normal(0, 1) if g == 1 and p == 1 and t == 1: y += 3.0 - records.append({"outcome": y, "group": g, "partition": p, - "time": t, "unit_id": unit_id, - "cov1": rng.normal(0, 1)}) + records.append( + { + "outcome": y, + "group": g, + "partition": p, + "time": t, + "unit_id": unit_id, + "cov1": rng.normal(0, 1), + } + ) unit_id += 1 data = pd.DataFrame(records) ddd = TripleDifference(estimation_method="ipw") with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time", - covariates=["cov1"]) - overlap_warnings = [x for x in w - if "overlap" in str(x.message).lower() - and "trimmed" in str(x.message).lower()] + result = ddd.fit( + data, + outcome="outcome", + group="group", + partition="partition", + time="time", + covariates=["cov1"], + ) + overlap_warnings = [ + x + for x in w + if "overlap" in str(x.message).lower() and "trimmed" in str(x.message).lower() + ] assert len(overlap_warnings) > 0 assert np.isfinite(result.att) @@ -1295,18 +1404,30 @@ def test_no_overlap_warning_for_reg(self): y = 10 + 2 * g + p + 0.5 * t + rng.normal(0, 1) if g == 1 and p == 1 and t == 1: y += 3.0 - records.append({"outcome": y, "group": g, "partition": p, - "time": t, "unit_id": unit_id, - "cov1": rng.normal(0, 1)}) + records.append( + { + "outcome": y, + "group": g, + "partition": p, + "time": t, + "unit_id": unit_id, + "cov1": rng.normal(0, 1), + } + ) unit_id += 1 data = 
pd.DataFrame(records) ddd = TripleDifference(estimation_method="reg") with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time", - covariates=["cov1"]) + result = ddd.fit( + data, + outcome="outcome", + group="group", + partition="partition", + time="time", + covariates=["cov1"], + ) overlap_warnings = [x for x in w if "overlap" in str(x.message).lower()] assert len(overlap_warnings) == 0 @@ -1319,17 +1440,23 @@ def _failing_lr(*args, **kwargs): raise RuntimeError("Forced PS failure for testing") import diff_diff.triple_diff as td_module - monkeypatch.setattr(td_module, "_logistic_regression", _failing_lr) + + monkeypatch.setattr(td_module, "solve_logit", _failing_lr) ddd = TripleDifference(estimation_method=method) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time", - covariates=["age"]) - ps_warnings = [x for x in w - if "propensity score estimation failed" - in str(x.message).lower()] + result = ddd.fit( + data, + outcome="outcome", + group="group", + partition="partition", + time="time", + covariates=["age"], + ) + ps_warnings = [ + x for x in w if "propensity score estimation failed" in str(x.message).lower() + ] assert len(ps_warnings) > 0, "Expected PS fallback warning" assert np.isfinite(result.att) assert np.isfinite(result.se) and result.se > 0 @@ -1349,26 +1476,33 @@ def _did_rc_with_nan(self_inner, *args, **kwargs): return att, inf monkeypatch.setattr( - td_module.TripleDifference, "_compute_did_rc", _did_rc_with_nan, + td_module.TripleDifference, + "_compute_did_rc", + _did_rc_with_nan, ) ddd = TripleDifference(estimation_method=method) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - result = ddd.fit(data, outcome="outcome", group="group", - partition="partition", time="time", - 
covariates=["age"]) - nonfinite_warnings = [ - x for x in w if "non-finite" in str(x.message).lower() - ] + result = ddd.fit( + data, + outcome="outcome", + group="group", + partition="partition", + time="time", + covariates=["age"], + ) + nonfinite_warnings = [x for x in w if "non-finite" in str(x.message).lower()] assert len(nonfinite_warnings) > 0, "Expected non-finite IF warning" assert np.isnan(result.se), "SE should be NaN when IF has non-finite values" - assert_nan_inference({ - "se": result.se, - "t_stat": result.t_stat, - "p_value": result.p_value, - "conf_int": result.conf_int, - }) + assert_nan_inference( + { + "se": result.se, + "t_stat": result.t_stat, + "p_value": result.p_value, + "conf_int": result.conf_int, + } + ) def test_r_squared_respects_rank_deficient_action(self): """r_squared computation uses estimator's rank_deficient_action, not hardcoded 'silent'.""" @@ -1377,36 +1511,69 @@ def test_r_squared_respects_rank_deficient_action(self): # "silent" should suppress ALL rank warnings (both main and r_squared paths) ddd_silent = TripleDifference( - estimation_method="reg", rank_deficient_action="silent", + estimation_method="reg", + rank_deficient_action="silent", ) with warnings.catch_warnings(record=True) as w_silent: warnings.simplefilter("always") result_silent = ddd_silent.fit( - data, outcome="outcome", group="group", - partition="partition", time="time", + data, + outcome="outcome", + group="group", + partition="partition", + time="time", covariates=["age", "age_dup"], ) - rank_silent = [x for x in w_silent - if "rank" in str(x.message).lower() - or "dependent" in str(x.message).lower()] + rank_silent = [ + x + for x in w_silent + if "rank" in str(x.message).lower() or "dependent" in str(x.message).lower() + ] # "warn" should emit rank warnings from both main and r_squared paths ddd_warn = TripleDifference( - estimation_method="reg", rank_deficient_action="warn", + estimation_method="reg", + rank_deficient_action="warn", ) with 
warnings.catch_warnings(record=True) as w_warn: warnings.simplefilter("always") result_warn = ddd_warn.fit( - data, outcome="outcome", group="group", - partition="partition", time="time", + data, + outcome="outcome", + group="group", + partition="partition", + time="time", covariates=["age", "age_dup"], ) - rank_warn = [x for x in w_warn - if "rank" in str(x.message).lower() - or "dependent" in str(x.message).lower()] + rank_warn = [ + x + for x in w_warn + if "rank" in str(x.message).lower() or "dependent" in str(x.message).lower() + ] assert len(rank_silent) == 0, "silent should suppress all rank warnings" assert len(rank_warn) > 0, "warn should emit rank warnings" # Both should produce finite results regardless assert np.isfinite(result_silent.att) assert np.isfinite(result_warn.att) + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_rank_deficient_action_error_raises_in_ps_path(self, method): + """rank_deficient_action='error' raises ValueError in PS-based paths with collinear covariates.""" + data = generate_ddd_data(n_per_cell=50, seed=42, add_covariates=True) + data["age_dup"] = data["age"].copy() + + ddd = TripleDifference( + estimation_method=method, + rank_deficient_action="error", + ) + + with pytest.raises((ValueError, RuntimeError)): + ddd.fit( + data, + outcome="outcome", + group="group", + partition="partition", + time="time", + covariates=["age", "age_dup"], + ) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index 715b88c..84082e9 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -55,11 +55,7 @@ def test_basic_fit(self): cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert cs.is_fitted_ @@ -74,11 +70,7 @@ def test_positive_treatment_effect(self): cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - 
time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should detect positive effect @@ -92,11 +84,7 @@ def test_zero_treatment_effect(self): cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Effect should be close to zero @@ -121,9 +109,9 @@ def test_never_treated_inf_encoding(self): ) # Results should be identical - assert np.isclose(results_inf.overall_att, results_zero.overall_att), ( - f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}" - ) + assert np.isclose( + results_inf.overall_att, results_zero.overall_att + ), f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}" def test_event_study_aggregation(self): """Test event study aggregation.""" @@ -132,11 +120,11 @@ def test_event_study_aggregation(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None @@ -153,11 +141,11 @@ def test_group_aggregation(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='group' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", ) assert results.group_effects is not None @@ -170,11 +158,11 @@ def test_all_aggregation(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='all' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", ) assert results.event_study_effects is 
not None @@ -187,21 +175,13 @@ def test_control_group_options(self): # Never treated only cs1 = CallawaySantAnna(control_group="never_treated") results1 = cs1.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Not yet treated cs2 = CallawaySantAnna(control_group="not_yet_treated") results2 = cs2.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results1.control_group == "never_treated" @@ -219,11 +199,7 @@ def test_estimation_methods(self): for method in methods: cs = CallawaySantAnna(estimation_method=method) results[method] = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # All methods should produce results @@ -237,11 +213,11 @@ def test_summary_output(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) summary = results.summary() @@ -257,34 +233,34 @@ def test_to_dataframe(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='all' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", ) # Group-time DataFrame - df_gt = results.to_dataframe(level='group_time') - assert 'group' in df_gt.columns - assert 'time' in df_gt.columns - assert 'effect' in df_gt.columns + df_gt = results.to_dataframe(level="group_time") + assert "group" in df_gt.columns + assert "time" in df_gt.columns + assert "effect" in df_gt.columns # Event study DataFrame - 
df_es = results.to_dataframe(level='event_study') - assert 'relative_period' in df_es.columns + df_es = results.to_dataframe(level="event_study") + assert "relative_period" in df_es.columns # Group DataFrame - df_g = results.to_dataframe(level='group') - assert 'group' in df_g.columns + df_g = results.to_dataframe(level="group") + assert "group" in df_g.columns def test_get_set_params(self): """Test sklearn-compatible parameter access.""" cs = CallawaySantAnna(alpha=0.10, control_group="not_yet_treated") params = cs.get_params() - assert params['alpha'] == 0.10 - assert params['control_group'] == "not_yet_treated" + assert params["alpha"] == 0.10 + assert params["control_group"] == "not_yet_treated" cs.set_params(alpha=0.05) assert cs.alpha == 0.05 @@ -296,13 +272,7 @@ def test_missing_column_error(self): cs = CallawaySantAnna() with pytest.raises(ValueError, match="Missing columns"): - cs.fit( - data, - outcome='nonexistent', - unit='unit', - time='time', - first_treat='first_treat' - ) + cs.fit(data, outcome="nonexistent", unit="unit", time="time", first_treat="first_treat") def test_no_control_units_error(self): """Test error when no control units exist.""" @@ -312,13 +282,7 @@ def test_no_control_units_error(self): cs = CallawaySantAnna() with pytest.raises(ValueError, match="No never-treated units"): - cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' - ) + cs.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") def test_significance_properties(self): """Test significance-related properties.""" @@ -326,11 +290,7 @@ def test_significance_properties(self): cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # With strong effect, should be significant @@ -346,11 +306,7 @@ def test_repr(self): data = generate_staggered_data() cs = 
CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) repr_str = repr(results) @@ -362,30 +318,22 @@ def test_invalid_level_error(self): data = generate_staggered_data() cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) with pytest.raises(ValueError, match="Unknown level"): - results.to_dataframe(level='invalid') + results.to_dataframe(level="invalid") def test_event_study_not_computed_error(self): """Test error when event study not computed.""" data = generate_staggered_data() cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) with pytest.raises(ValueError, match="Event study effects not computed"): - results.to_dataframe(level='event_study') + results.to_dataframe(level="event_study") def generate_staggered_data_with_covariates( @@ -441,22 +389,24 @@ def generate_staggered_data_with_covariates( # Outcome depends on covariates outcomes = ( - unit_fe_expanded + - time_fe_expanded + - covariate_effect * x1_expanded + # covariate effect - 0.5 * x2_expanded + # second covariate effect - treatment_effect * post + - np.random.randn(len(units)) * 0.5 + unit_fe_expanded + + time_fe_expanded + + covariate_effect * x1_expanded # covariate effect + + 0.5 * x2_expanded # second covariate effect + + treatment_effect * post + + np.random.randn(len(units)) * 0.5 ) - df = pd.DataFrame({ - 'unit': units, - 'time': times, - 'outcome': outcomes, - 'first_treat': first_treat_expanded.astype(int), - 'x1': x1_expanded, - 'x2': x2_expanded, - }) + df = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, 
+ "first_treat": first_treat_expanded.astype(int), + "x1": x1_expanded, + "x2": x2_expanded, + } + ) return df @@ -471,22 +421,18 @@ def test_covariates_are_used(self): # Fit without covariates cs1 = CallawaySantAnna() results1 = cs1.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Fit with covariates cs2 = CallawaySantAnna() results2 = cs2.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) # Both should produce valid results @@ -502,14 +448,14 @@ def test_outcome_regression_with_covariates(self): """Test outcome regression method with covariates.""" data = generate_staggered_data_with_covariates(seed=123) - cs = CallawaySantAnna(estimation_method='reg') + cs = CallawaySantAnna(estimation_method="reg") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) assert results.overall_att is not None @@ -520,14 +466,14 @@ def test_ipw_with_covariates(self): """Test IPW method with covariates.""" data = generate_staggered_data_with_covariates(seed=456) - cs = CallawaySantAnna(estimation_method='ipw') + cs = CallawaySantAnna(estimation_method="ipw") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) assert results.overall_att is not None @@ -538,14 +484,14 @@ def test_doubly_robust_with_covariates(self): """Test doubly robust method with covariates.""" data = generate_staggered_data_with_covariates(seed=789) - cs 
= CallawaySantAnna(estimation_method='dr') + cs = CallawaySantAnna(estimation_method="dr") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) assert results.overall_att is not None @@ -556,18 +502,18 @@ def test_all_methods_with_covariates(self): """Test that all estimation methods work with covariates.""" data = generate_staggered_data_with_covariates(seed=42) - methods = ['reg', 'ipw', 'dr'] + methods = ["reg", "ipw", "dr"] results = {} for method in methods: cs = CallawaySantAnna(estimation_method=method) results[method] = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) # All methods should produce valid results @@ -582,12 +528,12 @@ def test_event_study_with_covariates(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'], - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], + aggregate="event_study", ) assert results.event_study_effects is not None @@ -602,11 +548,11 @@ def test_missing_covariate_error(self): with pytest.raises(ValueError, match="Missing columns"): cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'nonexistent'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "nonexistent"], ) def test_single_covariate(self): @@ -616,11 +562,11 @@ def test_single_covariate(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - 
first_treat='first_treat', - covariates=['x1'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1"], ) assert results.overall_att is not None @@ -633,17 +579,17 @@ def test_treatment_effect_recovery_with_covariates(self): treatment_effect=3.0, covariate_effect=2.0, seed=123, - n_units=200 # More units for better precision + n_units=200, # More units for better precision ) - cs = CallawaySantAnna(estimation_method='dr') + cs = CallawaySantAnna(estimation_method="dr") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) # Effect should be roughly correct (within reasonable bounds) @@ -680,23 +626,25 @@ def test_extreme_propensity_scores(self): post = (times >= first_treat_expanded) & (first_treat_expanded > 0) outcomes = 1.0 + 0.5 * x_strong_expanded + 2.0 * post + np.random.randn(len(units)) * 0.3 - data = pd.DataFrame({ - 'unit': units, - 'time': times, - 'outcome': outcomes, - 'first_treat': first_treat_expanded.astype(int), - 'x_strong': x_strong_expanded, - }) + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + "x_strong": x_strong_expanded, + } + ) # IPW should handle extreme propensity scores via clipping - cs = CallawaySantAnna(estimation_method='ipw') + cs = CallawaySantAnna(estimation_method="ipw") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x_strong'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x_strong"], ) # Should produce valid results (not NaN or inf) @@ -713,6 +661,7 @@ def test_extreme_weights_warning(self, ci_params): - Bootstrap drops invalid samples and adjusts inference accordingly """ import warnings + 
np.random.seed(42) n_boot = ci_params.bootstrap(100) @@ -729,28 +678,27 @@ def test_extreme_weights_warning(self, ci_params): post = (times >= first_treat_expanded) & (first_treat_expanded > 0) outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.1 - data = pd.DataFrame({ - 'unit': units, - 'time': times, - 'outcome': outcomes, - 'first_treat': first_treat_expanded.astype(int), - }) + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + } + ) # Test without bootstrap - ATT should be finite, SE may be NaN for edge cases cs = CallawaySantAnna() results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # ATT point estimate should be finite assert np.isfinite(results.overall_att), "ATT should be finite" # SE is either finite (valid) or NaN (signals invalid inference) - not biased - assert np.isfinite(results.overall_se) or np.isnan(results.overall_se), \ - "SE should be finite or NaN (not inf)" + assert np.isfinite(results.overall_se) or np.isnan( + results.overall_se + ), "SE should be finite or NaN (not inf)" # Test with bootstrap - should drop invalid samples with warning cs_boot = CallawaySantAnna(n_bootstrap=n_boot, seed=42) @@ -758,11 +706,7 @@ def test_extreme_weights_warning(self, ci_params): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") boot_results = cs_boot.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Collect warning messages for inspection @@ -773,13 +717,15 @@ def test_extreme_weights_warning(self, ci_params): # Bootstrap SE based on valid samples - may be finite or NaN assert boot_results.bootstrap_results is not None, "Bootstrap results should exist" - assert 
np.isfinite(boot_results.overall_se) or np.isnan(boot_results.overall_se), \ - "Bootstrap SE should be finite or NaN (not inf)" + assert np.isfinite(boot_results.overall_se) or np.isnan( + boot_results.overall_se + ), "Bootstrap SE should be finite or NaN (not inf)" # If SE is NaN, verify it's due to validity threshold (should have warning) if np.isnan(boot_results.overall_se): - assert any("valid" in msg.lower() or "nan" in msg.lower() for msg in warning_messages), \ - "NaN SE should be accompanied by warning about validity" + assert any( + "valid" in msg.lower() or "nan" in msg.lower() for msg in warning_messages + ), "NaN SE should be accompanied by warning about validity" def test_validity_threshold_nan_se(self): """Test that <50% valid bootstrap samples returns NaN SE with warning. @@ -788,6 +734,7 @@ def test_validity_threshold_nan_se(self): is signaled via NaN rather than biased estimates. """ import warnings + np.random.seed(42) # Create minimal dataset that might trigger edge cases @@ -803,12 +750,14 @@ def test_validity_threshold_nan_se(self): post = (times >= first_treat_expanded) & (first_treat_expanded > 0) outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.5 - data = pd.DataFrame({ - 'unit': units, - 'time': times, - 'outcome': outcomes, - 'first_treat': first_treat_expanded.astype(int), - }) + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + } + ) # Use low n_bootstrap to trigger warning and potentially non-finite samples cs_boot = CallawaySantAnna(n_bootstrap=30, seed=42) @@ -816,25 +765,23 @@ def test_validity_threshold_nan_se(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") boot_results = cs_boot.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) warning_messages = [str(warning.message) for 
warning in w] # Should get the low n_bootstrap warning - assert any("n_bootstrap" in msg for msg in warning_messages), \ - "Should warn about low n_bootstrap" + assert any( + "n_bootstrap" in msg for msg in warning_messages + ), "Should warn about low n_bootstrap" # Bootstrap results should exist assert boot_results.bootstrap_results is not None, "Bootstrap results should exist" # SE constraints: finite or NaN (never inf) - assert np.isfinite(boot_results.overall_se) or np.isnan(boot_results.overall_se), \ - "Bootstrap SE should be finite or NaN (not inf)" + assert np.isfinite(boot_results.overall_se) or np.isnan( + boot_results.overall_se + ), "Bootstrap SE should be finite or NaN (not inf)" def test_near_collinear_covariates(self): """Test that near-collinear covariates are handled gracefully.""" @@ -845,16 +792,16 @@ def test_near_collinear_covariates(self): # of 1e-5 which is above the tolerance but still creates high collinearity. # With noise < 1e-07, the column would be considered linearly dependent. 
np.random.seed(42) - data['x1_copy'] = data['x1'] + np.random.randn(len(data)) * 1e-5 + data["x1_copy"] = data["x1"] + np.random.randn(len(data)) * 1e-5 - cs = CallawaySantAnna(estimation_method='reg') + cs = CallawaySantAnna(estimation_method="reg") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_copy'] # Nearly collinear + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_copy"], # Nearly collinear ) # Should still produce valid results (noise is above tolerance) @@ -866,7 +813,7 @@ def test_missing_values_in_covariates_warning(self): data = generate_staggered_data_with_covariates(seed=42) # Introduce NaN in covariate - data.loc[data['time'] == 2, 'x1'] = np.nan + data.loc[data["time"] == 2, "x1"] = np.nan cs = CallawaySantAnna() @@ -874,11 +821,11 @@ def test_missing_values_in_covariates_warning(self): with pytest.warns(UserWarning, match="Missing values in covariates"): results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) # Should still produce valid results (using unconditional estimation) @@ -893,37 +840,35 @@ def test_dr_covariates_not_yet_treated(self): """ data = generate_staggered_data_with_covariates(seed=42, n_units=200) - for method in ['dr', 'reg']: + for method in ["dr", "reg"]: cs = CallawaySantAnna( estimation_method=method, - control_group='not_yet_treated', + control_group="not_yet_treated", ) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'], + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) - assert np.isfinite(results.overall_att), ( - f"{method}/not_yet_treated: ATT should be finite" - ) - 
assert results.overall_se > 0, ( - f"{method}/not_yet_treated: SE should be positive" - ) - assert len(results.group_time_effects) > 0, ( - f"{method}/not_yet_treated: should have group-time effects" - ) + assert np.isfinite( + results.overall_att + ), f"{method}/not_yet_treated: ATT should be finite" + assert results.overall_se > 0, f"{method}/not_yet_treated: SE should be positive" + assert ( + len(results.group_time_effects) > 0 + ), f"{method}/not_yet_treated: should have group-time effects" # All effects should be finite for (g, t), eff in results.group_time_effects.items(): - assert np.isfinite(eff['effect']), ( - f"{method}/not_yet_treated: effect for ({g},{t}) should be finite" - ) - assert np.isfinite(eff['se']), ( - f"{method}/not_yet_treated: SE for ({g},{t}) should be finite" - ) + assert np.isfinite( + eff["effect"] + ), f"{method}/not_yet_treated: effect for ({g},{t}) should be finite" + assert np.isfinite( + eff["se"] + ), f"{method}/not_yet_treated: SE for ({g},{t}) should be finite" def test_rank_deficient_action_error_raises(self): """Test that rank_deficient_action='error' raises ValueError on collinear data.""" @@ -934,16 +879,16 @@ def test_rank_deficient_action_error_raises(self): cs = CallawaySantAnna( estimation_method="reg", # Use regression method to test OLS path - rank_deficient_action="error" + rank_deficient_action="error", ) with pytest.raises(ValueError, match="rank-deficient"): cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], ) def test_rank_deficient_action_silent_no_warning(self): @@ -957,23 +902,26 @@ def test_rank_deficient_action_silent_no_warning(self): cs = CallawaySantAnna( estimation_method="reg", # Use regression method to test OLS path - rank_deficient_action="silent" + rank_deficient_action="silent", ) with 
warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], ) # No warnings about rank deficiency should be emitted - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "rank-deficient" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() + ] assert len(rank_warnings) == 0, f"Expected no rank warnings, got {rank_warnings}" # Should still get valid results @@ -1000,18 +948,21 @@ def test_rank_deficient_action_warn_emits_warning(self): warnings.simplefilter("always") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'] + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], ) - rank_warnings = [x for x in w if "rank-deficient" in str(x.message).lower() - or "Rank-deficient" in str(x.message)] - assert len(rank_warnings) > 0, ( - "Expected at least one rank-deficiency warning with collinear covariates" - ) + rank_warnings = [ + x + for x in w + if "rank-deficient" in str(x.message).lower() or "Rank-deficient" in str(x.message) + ] + assert ( + len(rank_warnings) > 0 + ), "Expected at least one rank-deficiency warning with collinear covariates" # Should still produce valid results (lstsq fallback) assert results is not None @@ -1025,20 +976,20 @@ def test_empty_covariates_list_behaves_like_none(self): cs_none = CallawaySantAnna(n_bootstrap=0, seed=42) results_none = cs_none.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", covariates=None, ) 
cs_empty = CallawaySantAnna(n_bootstrap=0, seed=42) results_empty = cs_empty.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", covariates=[], ) @@ -1054,7 +1005,7 @@ def test_nan_cell_preserved_not_dropped(self): data = generate_staggered_data_with_covariates(seed=42, n_units=100) # Patch lstsq to return inf for one specific call to simulate numerical failure - original_lstsq = __import__('scipy').linalg.lstsq + original_lstsq = __import__("scipy").linalg.lstsq call_count = [0] def mock_lstsq(*args, **kwargs): @@ -1068,41 +1019,39 @@ def mock_lstsq(*args, **kwargs): # Use rank_deficient_action="warn" to ensure we go through the covariate reg path # and also force lstsq fallback by using collinear covariates - data['x1_dup'] = data['x1'] + data["x1_dup"] = data["x1"] cs = CallawaySantAnna( - n_bootstrap=0, seed=42, estimation_method='reg', - rank_deficient_action='warn', + n_bootstrap=0, + seed=42, + estimation_method="reg", + rank_deficient_action="warn", ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - with patch('scipy.linalg.lstsq', side_effect=mock_lstsq): + with patch("scipy.linalg.lstsq", side_effect=mock_lstsq): results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'], + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], ) # Check that NaN cells are preserved (not dropped) nan_cells = [ - (g, t) for (g, t), eff in results.group_time_effects.items() - if np.isnan(eff['effect']) + (g, t) for (g, t), eff in results.group_time_effects.items() if np.isnan(eff["effect"]) ] # At least one cell should have NaN effect from our mock if call_count[0] > 0: # Verify warning about non-finite regression results - nan_warnings = [ - x for x in w - if "non-finite regression results" in 
str(x.message) - ] + nan_warnings = [x for x in w if "non-finite regression results" in str(x.message)] if nan_cells: assert len(nan_warnings) > 0 # NaN cells should have NaN SE too for g, t in nan_cells: - assert np.isnan(results.group_time_effects[(g, t)]['se']) + assert np.isnan(results.group_time_effects[(g, t)]["se"]) # Overall ATT should still be finite (NaN cells excluded from aggregation) assert np.isfinite(results.overall_att) @@ -1114,7 +1063,7 @@ def test_nan_cell_bootstrap_aggregation_excludes_nan(self, ci_params): data = generate_staggered_data_with_covariates(seed=42, n_units=100) - original_lstsq = __import__('scipy').linalg.lstsq + original_lstsq = __import__("scipy").linalg.lstsq call_count = [0] def mock_lstsq(*args, **kwargs): @@ -1127,38 +1076,39 @@ def mock_lstsq(*args, **kwargs): return (bad_beta,) + result[1:] return result - data['x1_dup'] = data['x1'] + data["x1_dup"] = data["x1"] n_boot = ci_params.bootstrap(199) cs = CallawaySantAnna( - n_bootstrap=n_boot, seed=42, estimation_method='reg', - rank_deficient_action='warn', + n_bootstrap=n_boot, + seed=42, + estimation_method="reg", + rank_deficient_action="warn", ) with warnings.catch_warnings(record=True): warnings.simplefilter("always") - with patch('scipy.linalg.lstsq', side_effect=mock_lstsq): + with patch("scipy.linalg.lstsq", side_effect=mock_lstsq): results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'], - aggregate='all', + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], + aggregate="all", ) # NaN cell should be preserved in group_time_effects nan_cells = [ - (g, t) for (g, t), eff in results.group_time_effects.items() - if np.isnan(eff['effect']) + (g, t) for (g, t), eff in results.group_time_effects.items() if np.isnan(eff["effect"]) ] assert len(nan_cells) > 0, "Expected at least one NaN cell from mock" # Verify poisoned cell is 
post-treatment so overall ATT bootstrap path is exercised post_treatment_nan = [(g, t) for g, t in nan_cells if t >= g - cs.anticipation] - assert len(post_treatment_nan) > 0, ( - "Poisoned cell must be post-treatment to exercise overall ATT bootstrap filtering" - ) + assert ( + len(post_treatment_nan) > 0 + ), "Poisoned cell must be post-treatment to exercise overall ATT bootstrap filtering" # Overall ATT bootstrap inference should be finite (NaN cells excluded) assert np.isfinite(results.overall_att), "overall_att should be finite" @@ -1169,17 +1119,16 @@ def mock_lstsq(*args, **kwargs): # Event study: valid relative times should have finite bootstrap inference if results.event_study_effects: for e, data_es in results.event_study_effects.items(): - if np.isfinite(data_es['effect']): - assert np.isfinite(data_es['se']), f"ES e={e} se should be finite" - assert np.isfinite(data_es['p_value']), f"ES e={e} p_value should be finite" + if np.isfinite(data_es["effect"]): + assert np.isfinite(data_es["se"]), f"ES e={e} se should be finite" + assert np.isfinite(data_es["p_value"]), f"ES e={e} p_value should be finite" # Group effects: valid groups should have finite bootstrap inference if results.group_effects: for g, data_ge in results.group_effects.items(): - if np.isfinite(data_ge['effect']): - assert np.isfinite(data_ge['se']), f"Group {g} se should be finite" - assert np.isfinite(data_ge['p_value']), f"Group {g} p_value should be finite" - + if np.isfinite(data_ge["effect"]): + assert np.isfinite(data_ge["se"]), f"Group {g} se should be finite" + assert np.isfinite(data_ge["p_value"]), f"Group {g} p_value should be finite" def test_balance_e_excludes_nan_anchor_cohort(self, ci_params): """balance_e must exclude cohorts whose anchor-horizon effect is NaN.""" @@ -1188,7 +1137,7 @@ def test_balance_e_excludes_nan_anchor_cohort(self, ci_params): data = generate_staggered_data_with_covariates(seed=42, n_units=100) - original_lstsq = __import__('scipy').linalg.lstsq + 
original_lstsq = __import__("scipy").linalg.lstsq call_count = [0] def mock_lstsq(*args, **kwargs): @@ -1200,53 +1149,53 @@ def mock_lstsq(*args, **kwargs): return (bad_beta,) + result[1:] return result - data['x1_dup'] = data['x1'] + data["x1_dup"] = data["x1"] n_boot = ci_params.bootstrap(199) cs = CallawaySantAnna( - n_bootstrap=n_boot, seed=42, estimation_method='reg', - rank_deficient_action='warn', + n_bootstrap=n_boot, + seed=42, + estimation_method="reg", + rank_deficient_action="warn", ) with warnings.catch_warnings(record=True): warnings.simplefilter("always") - with patch('scipy.linalg.lstsq', side_effect=mock_lstsq): + with patch("scipy.linalg.lstsq", side_effect=mock_lstsq): results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x1_dup'], - aggregate='event_study', + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], + aggregate="event_study", balance_e=0, ) # Confirm the anchor cell is NaN and is specifically the anchor (t - g == 0) - assert np.isnan(results.group_time_effects[(3, 3)]['effect']), \ - "Mock should have poisoned (g=3, t=3)" + assert np.isnan( + results.group_time_effects[(3, 3)]["effect"] + ), "Mock should have poisoned (g=3, t=3)" assert 3 - 3 == 0, "Poisoned cell must be the anchor at balance_e=0" # Cohort g=3 should be excluded from ALL event-study horizons # Only g=5 and g=8 should contribute (<=2 because not all balanced # cohorts have cells at extreme horizons) for e, es_data in results.event_study_effects.items(): - assert es_data['n_groups'] <= 2, ( + assert es_data["n_groups"] <= 2, ( f"Event time e={e} has n_groups={es_data['n_groups']}, " "expected <=2 (cohort g=3 should be excluded due to NaN anchor)" ) # Analytical effects and SEs should be finite for all horizons for e, es_data in results.event_study_effects.items(): - assert np.isfinite(es_data['effect']), \ - f"e={e}: analytical effect should be 
finite" - assert np.isfinite(es_data['se']), \ - f"e={e}: analytical SE should be finite" + assert np.isfinite(es_data["effect"]), f"e={e}: analytical effect should be finite" + assert np.isfinite(es_data["se"]), f"e={e}: analytical SE should be finite" # Bootstrap SEs should also be finite if results.bootstrap_results and results.bootstrap_results.event_study_ses: for e, se in results.bootstrap_results.event_study_ses.items(): - assert np.isfinite(se), \ - f"e={e}: bootstrap SE should be finite" + assert np.isfinite(se), f"e={e}: bootstrap SE should be finite" class TestCallawaySantAnnaRankDeficiencyPaths: @@ -1277,11 +1226,14 @@ def test_dr_rank_deficient_action_warn_emits_warning(self): covariates=["x1", "x1_near"], ) - rank_warnings = [x for x in w if "rank-deficient" in str(x.message).lower() - or "Rank-deficient" in str(x.message)] - assert len(rank_warnings) > 0, ( - "Expected at least one rank-deficiency warning from DR path" - ) + rank_warnings = [ + x + for x in w + if "rank-deficient" in str(x.message).lower() or "Rank-deficient" in str(x.message) + ] + assert ( + len(rank_warnings) > 0 + ), "Expected at least one rank-deficiency warning from DR path" assert results is not None assert results.overall_att is not None @@ -1310,16 +1262,59 @@ def test_reg_nyt_rank_deficient_action_warn(self): covariates=["x1", "x1_dup"], ) - rank_warnings = [x for x in w if "rank-deficient" in str(x.message).lower() - or "Rank-deficient" in str(x.message)] - assert len(rank_warnings) > 0, ( - "Expected at least one rank-deficiency warning from reg nyt path" - ) + rank_warnings = [ + x + for x in w + if "rank-deficient" in str(x.message).lower() or "Rank-deficient" in str(x.message) + ] + assert ( + len(rank_warnings) > 0 + ), "Expected at least one rank-deficiency warning from reg nyt path" assert results is not None assert results.overall_att is not None assert results.overall_se > 0 + def test_ipw_rank_deficient_action_error_raises(self): + """IPW path raises ValueError 
with rank_deficient_action='error' and collinear covariates.""" + data = generate_staggered_data_with_covariates(seed=42) + data["x1_dup"] = data["x1"].copy() + + cs = CallawaySantAnna( + estimation_method="ipw", + rank_deficient_action="error", + ) + + with pytest.raises(ValueError, match="[Rr]ank"): + cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], + ) + + def test_dr_rank_deficient_action_error_raises(self): + """DR path raises ValueError with rank_deficient_action='error' and collinear covariates.""" + data = generate_staggered_data_with_covariates(seed=42) + data["x1_dup"] = data["x1"].copy() + + cs = CallawaySantAnna( + estimation_method="dr", + rank_deficient_action="error", + ) + + with pytest.raises(ValueError, match="[Rr]ank"): + cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x1_dup"], + ) + def test_bootstrap_single_unit_cohort_handles_gracefully(self, ci_params): """Test that bootstrap handles cohort with 1 treated unit without crashing.""" # Build small dataset where one cohort has exactly 1 unit @@ -1354,8 +1349,7 @@ def test_bootstrap_single_unit_cohort_handles_gracefully(self, ci_params): assert results is not None assert results.overall_att is not None # Single-unit cohort (g=5) effects should exist and have finite ATT - g5_effects = {(g, t): eff for (g, t), eff in results.group_time_effects.items() - if g == 5} + g5_effects = {(g, t): eff for (g, t), eff in results.group_time_effects.items() if g == 5} assert len(g5_effects) > 0, "Expected group-time effects for cohort g=5" for (g, t), eff in g5_effects.items(): assert np.isfinite(eff["effect"]), f"g={g},t={t}: ATT should be finite" @@ -1371,11 +1365,7 @@ def test_bootstrap_basic(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' 
+ data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.bootstrap_results is not None @@ -1392,17 +1382,9 @@ def test_bootstrap_weight_types(self, ci_params): weight_types = ["rademacher", "mammen", "webb"] for wt in weight_types: - cs = CallawaySantAnna( - n_bootstrap=n_boot, - bootstrap_weight_type=wt, - seed=42 - ) + cs = CallawaySantAnna(n_bootstrap=n_boot, bootstrap_weight_type=wt, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.bootstrap_results is not None @@ -1417,11 +1399,11 @@ def test_bootstrap_event_study(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.bootstrap_results is not None @@ -1431,8 +1413,8 @@ def test_bootstrap_event_study(self, ci_params): # Check event study effects have bootstrap SEs for e, effect in results.event_study_effects.items(): - assert effect['se'] > 0 - assert effect['conf_int'][0] < effect['conf_int'][1] + assert effect["se"] > 0 + assert effect["conf_int"][0] < effect["conf_int"][1] def test_bootstrap_group_aggregation(self, ci_params): """Test bootstrap with group aggregation.""" @@ -1442,11 +1424,11 @@ def test_bootstrap_group_aggregation(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='group' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", ) assert results.bootstrap_results is not None @@ -1456,8 +1438,8 @@ def test_bootstrap_group_aggregation(self, ci_params): # Check 
group effects have bootstrap SEs for g, effect in results.group_effects.items(): - assert effect['se'] > 0 - assert effect['conf_int'][0] < effect['conf_int'][1] + assert effect["se"] > 0 + assert effect["conf_int"][0] < effect["conf_int"][1] def test_bootstrap_all_aggregations(self, ci_params): """Test bootstrap with all aggregations.""" @@ -1467,11 +1449,11 @@ def test_bootstrap_all_aggregations(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='all' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", ) assert results.bootstrap_results is not None @@ -1485,20 +1467,12 @@ def test_bootstrap_reproducibility(self, ci_params): cs1 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results1 = cs1.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results2 = cs2.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Results should be identical with same seed @@ -1512,20 +1486,12 @@ def test_bootstrap_different_seeds(self, ci_params): cs1 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results1 = cs1.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=456) results2 = cs2.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Results should differ with different seeds @@ -1533,20 +1499,12 @@ def 
test_bootstrap_different_seeds(self, ci_params): def test_bootstrap_p_value_significance(self, ci_params): """Test that strong effect has significant p-value with bootstrap.""" - data = generate_staggered_data( - n_units=100, - treatment_effect=5.0, - seed=42 - ) + data = generate_staggered_data(n_units=100, treatment_effect=5.0, seed=42) n_boot = ci_params.bootstrap(199) cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Strong effect should be significant @@ -1555,20 +1513,12 @@ def test_bootstrap_p_value_significance(self, ci_params): def test_bootstrap_zero_effect_not_significant(self, ci_params): """Test that zero effect is not significant with bootstrap.""" - data = generate_staggered_data( - n_units=50, - treatment_effect=0.0, - seed=42 - ) + data = generate_staggered_data(n_units=50, treatment_effect=0.0, seed=42) n_boot = ci_params.bootstrap(199) cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Zero effect should not be significant at 0.01 level @@ -1582,11 +1532,7 @@ def test_bootstrap_distribution_stored(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.bootstrap_results.bootstrap_distribution is not None @@ -1600,11 +1546,11 @@ def test_bootstrap_with_covariates(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - covariates=['x1', 'x2'] + 
outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], ) assert results.bootstrap_results is not None @@ -1618,28 +1564,23 @@ def test_bootstrap_group_time_effects(self, ci_params): # Without bootstrap cs1 = CallawaySantAnna(n_bootstrap=0) results1 = cs1.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # With bootstrap cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results2 = cs2.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Group-time effects should have same point estimates for gt in results1.group_time_effects: - assert results1.group_time_effects[gt]['effect'] == results2.group_time_effects[gt]['effect'] + assert ( + results1.group_time_effects[gt]["effect"] + == results2.group_time_effects[gt]["effect"] + ) # But SEs may differ (bootstrap vs analytical) - assert results2.group_time_effects[gt]['se'] > 0 + assert results2.group_time_effects[gt]["se"] > 0 def test_bootstrap_invalid_weight_type(self): """Test that invalid weight type raises error.""" @@ -1652,35 +1593,23 @@ def test_bootstrap_invalid_weight_type(self): def test_bootstrap_get_params(self): """Test that get_params includes bootstrap_weights.""" - cs = CallawaySantAnna( - n_bootstrap=99, - bootstrap_weights="mammen", - seed=42 - ) + cs = CallawaySantAnna(n_bootstrap=99, bootstrap_weights="mammen", seed=42) params = cs.get_params() - assert params['n_bootstrap'] == 99 - assert params['bootstrap_weights'] == "mammen" + assert params["n_bootstrap"] == 99 + assert params["bootstrap_weights"] == "mammen" # Deprecated attribute still accessible for backward compat - assert params['bootstrap_weight_type'] == "mammen" - assert params['seed'] == 42 + assert params["bootstrap_weight_type"] == "mammen" + 
assert params["seed"] == 42 def test_bootstrap_with_not_yet_treated(self, ci_params): """Test bootstrap with not_yet_treated control group.""" data = generate_staggered_data(n_units=50, seed=42) n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna( - control_group="not_yet_treated", - n_bootstrap=n_boot, - seed=42 - ) + cs = CallawaySantAnna(control_group="not_yet_treated", n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.bootstrap_results is not None @@ -1694,17 +1623,9 @@ def test_bootstrap_estimation_methods(self, ci_params): methods = ["reg", "ipw", "dr"] for method in methods: - cs = CallawaySantAnna( - estimation_method=method, - n_bootstrap=n_boot, - seed=42 - ) + cs = CallawaySantAnna(estimation_method=method, n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.bootstrap_results is not None @@ -1718,12 +1639,12 @@ def test_bootstrap_with_balanced_event_study(self, ci_params): cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study', - balance_e=0 # Balance at treatment time + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=0, # Balance at treatment time ) assert results.bootstrap_results is not None @@ -1732,8 +1653,8 @@ def test_bootstrap_with_balanced_event_study(self, ci_params): # Check that event study effects have valid bootstrap SEs for e, effect in results.event_study_effects.items(): - assert effect['se'] > 0 - assert effect['conf_int'][0] < effect['conf_int'][1] + assert effect["se"] > 0 + assert 
effect["conf_int"][0] < effect["conf_int"][1] def test_bootstrap_low_iterations_warning(self): """Test that low n_bootstrap triggers a warning.""" @@ -1742,13 +1663,7 @@ def test_bootstrap_low_iterations_warning(self): cs = CallawaySantAnna(n_bootstrap=30, seed=42) with pytest.warns(UserWarning, match="n_bootstrap=30 is low"): - cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' - ) + cs.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") # ============================================================================= @@ -1788,23 +1703,19 @@ def test_single_cohort_basic(self): y += np.random.normal(0, 0.5) - data.append({ - 'unit': unit, - 'time': t, - 'outcome': y, - 'first_treat': first_treat, - }) + data.append( + { + "unit": unit, + "time": t, + "outcome": y, + "first_treat": first_treat, + } + ) df = pd.DataFrame(data) cs = CallawaySantAnna() - results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' - ) + results = cs.fit(df, outcome="outcome", unit="unit", time="time", first_treat="first_treat") # Should produce valid results assert results.overall_att is not None @@ -1846,23 +1757,25 @@ def test_single_cohort_event_study(self): y += np.random.normal(0, 0.4) - data.append({ - 'unit': unit, - 'time': t, - 'outcome': y, - 'first_treat': first_treat, - }) + data.append( + { + "unit": unit, + "time": t, + "outcome": y, + "first_treat": first_treat, + } + ) df = pd.DataFrame(data) cs = CallawaySantAnna() results = cs.fit( df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None @@ -1877,8 +1790,10 @@ def test_single_cohort_event_study(self): post_periods = [e for e in rel_periods if e >= 0] if post_periods: # At least some post-periods 
should show positive effect - post_effects = [results.event_study_effects[e]['effect'] for e in post_periods] - assert any(e > 0.5 for e in post_effects), f"Expected positive post-period effects, got {post_effects}" + post_effects = [results.event_study_effects[e]["effect"] for e in post_periods] + assert any( + e > 0.5 for e in post_effects + ), f"Expected positive post-period effects, got {post_effects}" def test_single_cohort_with_bootstrap(self, ci_params): """Test bootstrap inference with single cohort.""" @@ -1901,27 +1816,26 @@ def test_single_cohort_with_bootstrap(self, ci_params): if first_treat > 0 and t >= first_treat: y += 3.0 - data.append({ - 'unit': unit, - 'time': t, - 'outcome': y, - 'first_treat': first_treat, - }) + data.append( + { + "unit": unit, + "time": t, + "outcome": y, + "first_treat": first_treat, + } + ) df = pd.DataFrame(data) cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) - results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' - ) + results = cs.fit(df, outcome="outcome", unit="unit", time="time", first_treat="first_treat") assert results.bootstrap_results is not None assert results.bootstrap_results.overall_att_se > 0 - assert results.bootstrap_results.overall_att_ci[0] < results.bootstrap_results.overall_att_ci[1] + assert ( + results.bootstrap_results.overall_att_ci[0] + < results.bootstrap_results.overall_att_ci[1] + ) def test_single_cohort_not_yet_treated_control(self): """Test single cohort with not_yet_treated control group. 
@@ -1947,31 +1861,25 @@ def test_single_cohort_not_yet_treated_control(self): if first_treat > 0 and t >= first_treat: y += 2.0 - data.append({ - 'unit': unit, - 'time': t, - 'outcome': y, - 'first_treat': first_treat, - }) + data.append( + { + "unit": unit, + "time": t, + "outcome": y, + "first_treat": first_treat, + } + ) df = pd.DataFrame(data) - cs_never = CallawaySantAnna(control_group='never_treated') + cs_never = CallawaySantAnna(control_group="never_treated") results_never = cs_never.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + df, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) - cs_not_yet = CallawaySantAnna(control_group='not_yet_treated') + cs_not_yet = CallawaySantAnna(control_group="not_yet_treated") results_not_yet = cs_not_yet.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + df, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Both should produce valid results @@ -1995,28 +1903,20 @@ def test_analytical_se_vs_bootstrap_se(self, ci_params): n_cohorts=3, treatment_effect=3.0, never_treated_frac=0.3, - seed=42 + seed=42, ) n_boot = ci_params.bootstrap(499, min_n=249) # Run with analytical SE (n_bootstrap=0) cs_analytical = CallawaySantAnna(n_bootstrap=0, seed=42) results_analytical = cs_analytical.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Run with bootstrap SE (n_bootstrap=499) cs_bootstrap = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results_bootstrap = cs_bootstrap.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Point estimates should match exactly @@ -2024,9 +1924,10 @@ def test_analytical_se_vs_bootstrap_se(self, ci_params): # SEs should be similar 
(within 15% with enough bootstrap iterations, # wider tolerance when min_n cap reduces iterations in pure Python mode) - rel_diff = abs( - results_analytical.overall_se - results_bootstrap.overall_se - ) / results_bootstrap.overall_se + rel_diff = ( + abs(results_analytical.overall_se - results_bootstrap.overall_se) + / results_bootstrap.overall_se + ) threshold = 0.40 if n_boot < 100 else 0.15 assert rel_diff < threshold, ( f"Analytical SE ({results_analytical.overall_se:.4f}) differs from " @@ -2047,16 +1948,12 @@ def test_analytical_se_accounts_for_covariance(self): n_cohorts=2, treatment_effect=2.0, never_treated_frac=0.4, # Larger never-treated pool = more sharing - seed=123 + seed=123, ) cs = CallawaySantAnna(n_bootstrap=0) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # The SE should be non-zero and positive @@ -2067,15 +1964,15 @@ def test_analytical_se_accounts_for_covariance(self): weights = [] variances = [] for (g, t), effect in gt_effects.items(): - weights.append(effect['n_treated']) - variances.append(effect['se'] ** 2) + weights.append(effect["n_treated"]) + variances.append(effect["se"] ** 2) weights = np.array(weights, dtype=float) weights = weights / weights.sum() variances = np.array(variances) # Independence SE formula (the old incorrect formula) - independence_var = np.sum(weights ** 2 * variances) + independence_var = np.sum(weights**2 * variances) independence_se = np.sqrt(independence_var) # The actual SE (with covariance) should generally be larger @@ -2106,29 +2003,25 @@ def test_analytical_se_single_gt_pair(self): y += 2.0 y += np.random.normal(0, 0.5) - data.append({ - 'unit': unit, - 'time': t, - 'outcome': y, - 'first_treat': first_treat, - }) + data.append( + { + "unit": unit, + "time": t, + "outcome": y, + "first_treat": first_treat, + } + ) df = pd.DataFrame(data) # Use only the first 
post-treatment period cs = CallawaySantAnna(n_bootstrap=0) - results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' - ) + results = cs.fit(df, outcome="outcome", unit="unit", time="time", first_treat="first_treat") # If there's only one (g,t) pair, overall SE should match individual SE if len(results.group_time_effects) == 1: gt_key = list(results.group_time_effects.keys())[0] - individual_se = results.group_time_effects[gt_key]['se'] + individual_se = results.group_time_effects[gt_key]["se"] # Should be close (may not be exact due to normalization) assert abs(results.overall_se - individual_se) < individual_se * 0.01 @@ -2140,7 +2033,7 @@ def test_event_study_analytical_se(self, ci_params): n_cohorts=3, treatment_effect=2.5, never_treated_frac=0.3, - seed=42 + seed=42, ) n_boot = ci_params.bootstrap(499, min_n=199) @@ -2148,22 +2041,22 @@ def test_event_study_analytical_se(self, ci_params): cs_analytical = CallawaySantAnna(n_bootstrap=0, seed=42) results_analytical = cs_analytical.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) # Bootstrap cs_bootstrap = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results_bootstrap = cs_bootstrap.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) # Event study effects should exist @@ -2175,8 +2068,8 @@ def test_event_study_analytical_se(self, ci_params): threshold = 0.40 if n_boot < 100 else 0.20 for e in results_analytical.event_study_effects: if e in results_bootstrap.event_study_effects: - se_analytical = results_analytical.event_study_effects[e]['se'] - se_bootstrap = results_bootstrap.event_study_effects[e]['se'] + se_analytical = 
results_analytical.event_study_effects[e]["se"] + se_bootstrap = results_bootstrap.event_study_effects[e]["se"] if se_bootstrap > 0: rel_diff = abs(se_analytical - se_bootstrap) / se_bootstrap @@ -2195,10 +2088,10 @@ class TestCallawaySantAnnaNonStandardColumnNames: def generate_data_with_custom_names( self, - outcome_name: str = 'y', - unit_name: str = 'id', - time_name: str = 'period', - first_treat_name: str = 'treatment_start', + outcome_name: str = "y", + unit_name: str = "id", + time_name: str = "period", + first_treat_name: str = "treatment_start", n_units: int = 100, n_periods: int = 10, seed: int = 42, @@ -2213,8 +2106,8 @@ def generate_data_with_custom_names( # 30% never-treated, rest treated at period 4 or 6 n_never = int(n_units * 0.3) first_treat = np.zeros(n_units) - first_treat[n_never:n_never + (n_units - n_never) // 2] = 4 - first_treat[n_never + (n_units - n_never) // 2:] = 6 + first_treat[n_never : n_never + (n_units - n_never) // 2] = 4 + first_treat[n_never + (n_units - n_never) // 2 :] = 6 first_treat_expanded = np.repeat(first_treat, n_periods) # Generate outcomes @@ -2223,26 +2116,22 @@ def generate_data_with_custom_names( post = (times >= first_treat_expanded) & (first_treat_expanded > 0) outcomes = unit_fe + time_fe + 2.5 * post + np.random.randn(len(units)) * 0.5 - return pd.DataFrame({ - outcome_name: outcomes, - unit_name: units, - time_name: times, - first_treat_name: first_treat_expanded.astype(int), - }) + return pd.DataFrame( + { + outcome_name: outcomes, + unit_name: units, + time_name: times, + first_treat_name: first_treat_expanded.astype(int), + } + ) def test_non_standard_first_treat_name(self): """Test with non-standard first_treat column name.""" - data = self.generate_data_with_custom_names( - first_treat_name='treatment_cohort' - ) + data = self.generate_data_with_custom_names(first_treat_name="treatment_cohort") cs = CallawaySantAnna() results = cs.fit( - data, - outcome='y', - unit='id', - time='period', - 
first_treat='treatment_cohort' + data, outcome="y", unit="id", time="period", first_treat="treatment_cohort" ) assert results.overall_att is not None @@ -2254,19 +2143,19 @@ def test_non_standard_first_treat_name(self): def test_non_standard_all_column_names(self): """Test with all non-standard column names.""" data = self.generate_data_with_custom_names( - outcome_name='response_var', - unit_name='entity_id', - time_name='time_period', - first_treat_name='treatment_timing', + outcome_name="response_var", + unit_name="entity_id", + time_name="time_period", + first_treat_name="treatment_timing", ) cs = CallawaySantAnna() results = cs.fit( data, - outcome='response_var', - unit='entity_id', - time='time_period', - first_treat='treatment_timing' + outcome="response_var", + unit="entity_id", + time="time_period", + first_treat="treatment_timing", ) assert results.overall_att is not None @@ -2276,19 +2165,12 @@ def test_non_standard_all_column_names(self): def test_non_standard_names_with_bootstrap(self, ci_params): """Test non-standard column names with bootstrap inference.""" data = self.generate_data_with_custom_names( - first_treat_name='g', # Short name like R's `did` package uses - n_units=50 + first_treat_name="g", n_units=50 # Short name like R's `did` package uses ) n_boot = ci_params.bootstrap(99) cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) - results = cs.fit( - data, - outcome='y', - unit='id', - time='period', - first_treat='g' - ) + results = cs.fit(data, outcome="y", unit="id", time="period", first_treat="g") assert results.bootstrap_results is not None assert results.overall_se > 0 @@ -2296,19 +2178,16 @@ def test_non_standard_names_with_bootstrap(self, ci_params): def test_non_standard_names_with_event_study(self): """Test non-standard column names with event study aggregation.""" - data = self.generate_data_with_custom_names( - first_treat_name='cohort', - n_periods=12 - ) + data = self.generate_data_with_custom_names(first_treat_name="cohort", 
n_periods=12) cs = CallawaySantAnna() results = cs.fit( data, - outcome='y', - unit='id', - time='period', - first_treat='cohort', - aggregate='event_study' + outcome="y", + unit="id", + time="period", + first_treat="cohort", + aggregate="event_study", ) assert results.event_study_effects is not None @@ -2317,21 +2196,19 @@ def test_non_standard_names_with_event_study(self): def test_non_standard_names_with_covariates(self): """Test non-standard column names with covariate adjustment.""" # Generate data with covariates - data = self.generate_data_with_custom_names( - first_treat_name='treatment_time' - ) + data = self.generate_data_with_custom_names(first_treat_name="treatment_time") # Add covariates with custom names - data['covariate_x'] = np.random.randn(len(data)) - data['covariate_z'] = np.random.binomial(1, 0.5, len(data)) + data["covariate_x"] = np.random.randn(len(data)) + data["covariate_z"] = np.random.binomial(1, 0.5, len(data)) - cs = CallawaySantAnna(estimation_method='dr') + cs = CallawaySantAnna(estimation_method="dr") results = cs.fit( data, - outcome='y', - unit='id', - time='period', - first_treat='treatment_time', - covariates=['covariate_x', 'covariate_z'] + outcome="y", + unit="id", + time="period", + first_treat="treatment_time", + covariates=["covariate_x", "covariate_z"], ) assert results.overall_att is not None @@ -2339,21 +2216,13 @@ def test_non_standard_names_with_covariates(self): def test_non_standard_names_with_not_yet_treated(self): """Test non-standard column names with not_yet_treated control group.""" - data = self.generate_data_with_custom_names( - first_treat_name='adoption_period' - ) + data = self.generate_data_with_custom_names(first_treat_name="adoption_period") - cs = CallawaySantAnna(control_group='not_yet_treated') - results = cs.fit( - data, - outcome='y', - unit='id', - time='period', - first_treat='adoption_period' - ) + cs = CallawaySantAnna(control_group="not_yet_treated") + results = cs.fit(data, outcome="y", 
unit="id", time="period", first_treat="adoption_period") assert results.overall_att is not None - assert results.control_group == 'not_yet_treated' + assert results.control_group == "not_yet_treated" def test_non_standard_names_matches_standard_names(self): """Verify results are identical regardless of column naming.""" @@ -2362,32 +2231,24 @@ def test_non_standard_names_matches_standard_names(self): # Generate identical data with different column names data_standard = generate_staggered_data(n_units=80, seed=42) - data_custom = data_standard.rename(columns={ - 'outcome': 'y', - 'unit': 'entity', - 'time': 't', - 'first_treat': 'g', - }) + data_custom = data_standard.rename( + columns={ + "outcome": "y", + "unit": "entity", + "time": "t", + "first_treat": "g", + } + ) # Fit with standard names cs1 = CallawaySantAnna(seed=123) results1 = cs1.fit( - data_standard, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data_standard, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Fit with custom names cs2 = CallawaySantAnna(seed=123) - results2 = cs2.fit( - data_custom, - outcome='y', - unit='entity', - time='t', - first_treat='g' - ) + results2 = cs2.fit(data_custom, outcome="y", unit="entity", time="t", first_treat="g") # Results should be identical assert abs(results1.overall_att - results2.overall_att) < 1e-10 @@ -2396,18 +2257,20 @@ def test_non_standard_names_matches_standard_names(self): def test_column_name_with_spaces(self): """Test column names containing spaces.""" data = self.generate_data_with_custom_names() - data = data.rename(columns={ - 'y': 'outcome variable', - 'treatment_start': 'treatment period', - }) + data = data.rename( + columns={ + "y": "outcome variable", + "treatment_start": "treatment period", + } + ) cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome variable', - unit='id', - time='period', - first_treat='treatment period' + outcome="outcome variable", + unit="id", + 
time="period", + first_treat="treatment period", ) assert results.overall_att is not None @@ -2416,17 +2279,15 @@ def test_column_name_with_spaces(self): def test_column_name_with_special_characters(self): """Test column names with underscores and numbers.""" data = self.generate_data_with_custom_names() - data = data.rename(columns={ - 'treatment_start': 'first_treat_2024', - }) + data = data.rename( + columns={ + "treatment_start": "first_treat_2024", + } + ) cs = CallawaySantAnna() results = cs.fit( - data, - outcome='y', - unit='id', - time='period', - first_treat='first_treat_2024' + data, outcome="y", unit="id", time="period", first_treat="first_treat_2024" ) assert results.overall_att is not None @@ -2455,83 +2316,49 @@ def test_varying_pre_treatment_effects(self): """Varying mode computes pre-treatment ATT(g,t) for t < g.""" # Generate data with enough pre-treatment periods data = generate_staggered_data( - n_units=100, - n_periods=10, - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=100, n_periods=10, n_cohorts=2, treatment_effect=2.0, seed=42 ) cs = CallawaySantAnna(base_period="varying") results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should have pre-treatment effects (t < g) - pre_treatment_effects = [ - (g, t) for (g, t) in results.group_time_effects.keys() - if t < g - ] + pre_treatment_effects = [(g, t) for (g, t) in results.group_time_effects.keys() if t < g] assert len(pre_treatment_effects) > 0, "Should compute pre-treatment effects" def test_universal_pre_treatment_effects(self): """Universal mode computes pre-treatment ATT(g,t) for t < g.""" data = generate_staggered_data( - n_units=100, - n_periods=10, - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=100, n_periods=10, n_cohorts=2, treatment_effect=2.0, seed=42 ) cs = CallawaySantAnna(base_period="universal") results = cs.fit( - data, - 
outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should have pre-treatment effects (t < g) - pre_treatment_effects = [ - (g, t) for (g, t) in results.group_time_effects.keys() - if t < g - ] + pre_treatment_effects = [(g, t) for (g, t) in results.group_time_effects.keys() if t < g] assert len(pre_treatment_effects) > 0, "Should compute pre-treatment effects" def test_post_treatment_identical(self): """Post-treatment ATT(g,t) identical for both modes.""" data = generate_staggered_data( - n_units=100, - n_periods=10, - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=100, n_periods=10, n_cohorts=2, treatment_effect=2.0, seed=42 ) # Fit with varying cs_v = CallawaySantAnna(base_period="varying") res_v = cs_v.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Fit with universal cs_u = CallawaySantAnna(base_period="universal") res_u = cs_u.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Post-treatment effects should be identical @@ -2539,7 +2366,7 @@ def test_post_treatment_identical(self): if t >= g: # Post-treatment if (g, t) in res_u.group_time_effects: eff_u = res_u.group_time_effects[(g, t)] - assert abs(eff_v['effect'] - eff_u['effect']) < 1e-10, ( + assert abs(eff_v["effect"] - eff_u["effect"]) < 1e-10, ( f"Post-treatment ATT({g},{t}) differs: " f"varying={eff_v['effect']:.6f}, universal={eff_u['effect']:.6f}" ) @@ -2547,21 +2374,17 @@ def test_post_treatment_identical(self): def test_event_study_negative_periods(self): """Event study includes negative relative periods.""" data = generate_staggered_data( - n_units=100, - n_periods=12, - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=100, 
n_periods=12, n_cohorts=2, treatment_effect=2.0, seed=42 ) cs = CallawaySantAnna(base_period="varying") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None @@ -2569,9 +2392,9 @@ def test_event_study_negative_periods(self): # Should have negative relative periods rel_periods = list(results.event_study_effects.keys()) negative_periods = [e for e in rel_periods if e < 0] - assert len(negative_periods) > 0, ( - f"Event study should include negative periods, got {rel_periods}" - ) + assert ( + len(negative_periods) > 0 + ), f"Event study should include negative periods, got {rel_periods}" def test_base_period_in_results(self): """base_period is stored in results and shown in summary.""" @@ -2579,11 +2402,7 @@ def test_base_period_in_results(self): cs = CallawaySantAnna(base_period="universal") results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert results.base_period == "universal" @@ -2594,25 +2413,13 @@ def test_base_period_in_results(self): def test_pre_treatment_bootstrap(self, ci_params): """Bootstrap handles pre-treatment effects.""" data = generate_staggered_data( - n_units=60, - n_periods=8, - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=60, n_periods=8, n_cohorts=2, treatment_effect=2.0, seed=42 ) n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna( - base_period="varying", - n_bootstrap=n_boot, - seed=42 - ) + cs = CallawaySantAnna(base_period="varying", n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) assert 
results.bootstrap_results is not None @@ -2620,8 +2427,8 @@ def test_pre_treatment_bootstrap(self, ci_params): # Pre-treatment effects should have valid bootstrap SEs for (g, t), eff in results.group_time_effects.items(): if t < g: # Pre-treatment - assert eff['se'] > 0, f"Pre-treatment ATT({g},{t}) should have positive SE" - assert np.isfinite(eff['se']), f"Pre-treatment ATT({g},{t}) SE should be finite" + assert eff["se"] > 0, f"Pre-treatment ATT({g},{t}) should have positive SE" + assert np.isfinite(eff["se"]), f"Pre-treatment ATT({g},{t}) SE should be finite" def test_pre_treatment_near_zero_under_parallel_trends(self): """Pre-treatment effects should be near zero when parallel trends holds.""" @@ -2631,29 +2438,22 @@ def test_pre_treatment_near_zero_under_parallel_trends(self): n_periods=10, n_cohorts=2, treatment_effect=3.0, # Only post-treatment effect - seed=123 + seed=123, ) cs = CallawaySantAnna(base_period="varying") results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Pre-treatment effects should be close to zero - pre_effects = [ - eff['effect'] for (g, t), eff in results.group_time_effects.items() - if t < g - ] + pre_effects = [eff["effect"] for (g, t), eff in results.group_time_effects.items() if t < g] if pre_effects: # Mean of pre-treatment effects should be close to 0 mean_pre = np.mean(pre_effects) - assert abs(mean_pre) < 1.0, ( - f"Pre-treatment effects mean={mean_pre:.3f} should be near zero" - ) + assert ( + abs(mean_pre) < 1.0 + ), f"Pre-treatment effects mean={mean_pre:.3f} should be near zero" def test_set_params_base_period(self): """set_params() can change base_period.""" @@ -2676,28 +2476,20 @@ def test_varying_mode_no_fallback_to_nonconsecutive(self): """Varying mode skips pre-treatment effects where t-1 doesn't exist.""" # Create data where first period (e.g., period 1) has no t-1 predecessor data = 
generate_staggered_data( - n_units=100, - n_periods=6, # periods 1-6 - n_cohorts=2, - treatment_effect=2.0, - seed=42 + n_units=100, n_periods=6, n_cohorts=2, treatment_effect=2.0, seed=42 # periods 1-6 ) # Identify the earliest time period in data - min_period = data['time'].min() + min_period = data["time"].min() cs = CallawaySantAnna(base_period="varying") results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # In varying mode, ATT(g, min_period) should NOT be computed for # any cohort g because t-1 (period 0) doesn't exist - for (g, t) in results.group_time_effects.keys(): + for g, t in results.group_time_effects.keys(): if t == min_period: # This should not happen - the (g, min_period) pair should be skipped pytest.fail( @@ -2722,12 +2514,9 @@ def test_no_post_treatment_effects_returns_nan_with_warning(self): # Data only goes to period 5, so no post-treatment periods exist first_treat = n_periods + 1 if unit < n_units // 2 else 0 outcome = np.random.randn() - data.append({ - 'unit': unit, - 'time': t, - 'outcome': outcome, - 'first_treat': first_treat - }) + data.append( + {"unit": unit, "time": t, "outcome": outcome, "first_treat": first_treat} + ) df = pd.DataFrame(data) @@ -2736,21 +2525,15 @@ def test_no_post_treatment_effects_returns_nan_with_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + df, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should have emitted a warning about no post-treatment effects warning_messages = [str(warning.message) for warning in w] - has_warning = any( - "No post-treatment effects" in msg for msg in warning_messages - ) - assert has_warning, ( - f"Expected warning about no post-treatment effects, got: {warning_messages}" - 
) + has_warning = any("No post-treatment effects" in msg for msg in warning_messages) + assert ( + has_warning + ), f"Expected warning about no post-treatment effects, got: {warning_messages}" # Overall ATT should be NaN assert np.isnan(results.overall_att), ( @@ -2758,19 +2541,20 @@ def test_no_post_treatment_effects_returns_nan_with_warning(self): f"got {results.overall_att}" ) # All inference fields should also be NaN - assert np.isnan(results.overall_se), ( - f"Expected NaN for overall_se, got {results.overall_se}" - ) - assert np.isnan(results.overall_t_stat), ( - f"Expected NaN for overall_t_stat, got {results.overall_t_stat}" - ) - assert np.isnan(results.overall_p_value), ( - f"Expected NaN for overall_p_value, got {results.overall_p_value}" - ) + assert np.isnan( + results.overall_se + ), f"Expected NaN for overall_se, got {results.overall_se}" + assert np.isnan( + results.overall_t_stat + ), f"Expected NaN for overall_t_stat, got {results.overall_t_stat}" + assert np.isnan( + results.overall_p_value + ), f"Expected NaN for overall_p_value, got {results.overall_p_value}" def test_no_post_treatment_effects_bootstrap_returns_nan(self, ci_params): """Bootstrap returns NaN inference when no post-treatment effects exist.""" import warnings + n_boot = ci_params.bootstrap(99) # Create data where treatment happens after the data ends @@ -2783,12 +2567,9 @@ def test_no_post_treatment_effects_bootstrap_returns_nan(self, ci_params): for t in range(1, n_periods + 1): first_treat = n_periods + 1 if unit < n_units // 2 else 0 outcome = np.random.randn() - data.append({ - 'unit': unit, - 'time': t, - 'outcome': outcome, - 'first_treat': first_treat - }) + data.append( + {"unit": unit, "time": t, "outcome": outcome, "first_treat": first_treat} + ) df = pd.DataFrame(data) @@ -2797,18 +2578,12 @@ def test_no_post_treatment_effects_bootstrap_returns_nan(self, ci_params): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") results = cs.fit( - df, - 
outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + df, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should have warning about no post-treatment effects warning_messages = [str(warning.message) for warning in w] - has_warning = any( - "No post-treatment effects" in msg for msg in warning_messages - ) + has_warning = any("No post-treatment effects" in msg for msg in warning_messages) assert has_warning, f"Expected warning, got: {warning_messages}" # All overall inference fields should be NaN @@ -2831,6 +2606,7 @@ def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): but pre-treatment effects should still get bootstrap SEs (not analytical). """ import warnings + n_boot = ci_params.bootstrap(99) # Create data where all treatment happens after the data ends @@ -2846,12 +2622,9 @@ def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): first_treat = 10 if unit < n_units // 2 else 0 for t in range(1, n_periods + 1): outcome = np.random.randn() + (0.5 * t) # Some time trend - data.append({ - 'unit': unit, - 'time': t, - 'outcome': outcome, - 'first_treat': first_treat - }) + data.append( + {"unit": unit, "time": t, "outcome": outcome, "first_treat": first_treat} + ) df = pd.DataFrame(data) @@ -2861,18 +2634,12 @@ def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + df, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Should have warning about no post-treatment effects warning_messages = [str(warning.message) for warning in w] - has_warning = any( - "No post-treatment effects" in msg for msg in warning_messages - ) + has_warning = any("No post-treatment effects" in msg for msg in warning_messages) assert has_warning, f"Expected warning about no post-treatment effects" # 
Verify overall ATT is NaN @@ -2880,10 +2647,7 @@ def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): assert np.isnan(results.overall_se), "overall_se should be NaN" # Verify we have pre-treatment effects - pre_treatment_effects = [ - (g, t) for (g, t) in results.group_time_effects.keys() - if t < g - ] + pre_treatment_effects = [(g, t) for (g, t) in results.group_time_effects.keys() if t < g] assert len(pre_treatment_effects) > 0, "Should have pre-treatment effects" # Key test: bootstrap should have computed SEs for the pre-treatment effects @@ -2895,17 +2659,17 @@ def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): assert bootstrap_se is not None, f"Bootstrap SE missing for {gt}" # Bootstrap SE should be finite (it was computed, not analytical fallback) # Note: in the old code, these would be analytical SEs, not bootstrap - assert np.isfinite(bootstrap_se), ( - f"Bootstrap SE for {gt} should be finite, got {bootstrap_se}" - ) + assert np.isfinite( + bootstrap_se + ), f"Bootstrap SE for {gt} should be finite, got {bootstrap_se}" # Also verify overall bootstrap statistics are NaN - assert np.isnan(results.bootstrap_results.overall_att_se), ( - "Overall ATT SE should be NaN when no post-treatment" - ) - assert np.isnan(results.bootstrap_results.overall_att_p_value), ( - "Overall ATT p-value should be NaN when no post-treatment" - ) + assert np.isnan( + results.bootstrap_results.overall_att_se + ), "Overall ATT SE should be NaN when no post-treatment" + assert np.isnan( + results.bootstrap_results.overall_att_p_value + ), "Overall ATT p-value should be NaN when no post-treatment" def test_not_yet_treated_excludes_cohort_from_controls(self): """Not-yet-treated control excludes treated cohort g for pre-treatment periods. 
@@ -2950,33 +2714,23 @@ def test_not_yet_treated_excludes_cohort_from_controls(self): effect = 2.0 outcome = np.random.randn() + effect - data.append({ - 'unit': unit, - 'time': t, - 'outcome': outcome, - 'first_treat': first_treat - }) + data.append( + {"unit": unit, "time": t, "outcome": outcome, "first_treat": first_treat} + ) df = pd.DataFrame(data) # Fit with not_yet_treated control group cs = CallawaySantAnna( - control_group="not_yet_treated", - base_period="varying" # To get pre-treatment effects - ) - results = cs.fit( - df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + control_group="not_yet_treated", base_period="varying" # To get pre-treatment effects ) + results = cs.fit(df, outcome="outcome", unit="unit", time="time", first_treat="first_treat") # Check the group-time effects for pre-treatment ATT(g=7, t) where t < 7 # These should have been computed using valid controls only for (g, t), eff in results.group_time_effects.items(): if g == 7 and t < g: # Pre-treatment for cohort 7 - n_control = eff['n_control'] + n_control = eff["n_control"] # Control should include: # - 30 never-treated units # - 30 units from cohort g=4 (if t < 4, they're not yet treated either) @@ -2997,9 +2751,9 @@ def test_not_yet_treated_excludes_cohort_from_controls(self): ) # Also verify we have a reasonable number of controls - assert n_control >= 30, ( - f"ATT(g=7, t={t}): n_control={n_control} should be >= 30 (never-treated)." - ) + assert ( + n_control >= 30 + ), f"ATT(g=7, t={t}): n_control={n_control} should be >= 30 (never-treated)." 
class TestCallawaySantAnnaAnticipation: @@ -3013,35 +2767,23 @@ def test_group_effects_with_anticipation(self): """ # Generate staggered data with a clear treatment effect data = generate_staggered_data( - n_units=100, - n_periods=12, - n_cohorts=2, - treatment_effect=3.0, - seed=42 + n_units=100, n_periods=12, n_cohorts=2, treatment_effect=3.0, seed=42 ) # Get treatment groups - groups = sorted(data[data['first_treat'] > 0]['first_treat'].unique()) + groups = sorted(data[data["first_treat"] > 0]["first_treat"].unique()) assert len(groups) >= 1, "Need at least one treatment group" # Fit without anticipation cs_no_antic = CallawaySantAnna(anticipation=0) res_no_antic = cs_no_antic.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Fit with anticipation=1 cs_antic = CallawaySantAnna(anticipation=1) res_antic = cs_antic.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # With anticipation=1, group effects should include period g-1 @@ -3049,11 +2791,11 @@ def test_group_effects_with_anticipation(self): for g in groups: # Count effects included in group aggregation no_antic_effects = [ - (gg, t) for (gg, t) in res_no_antic.group_time_effects.keys() - if gg == g and t >= g + (gg, t) for (gg, t) in res_no_antic.group_time_effects.keys() if gg == g and t >= g ] antic_effects = [ - (gg, t) for (gg, t) in res_antic.group_time_effects.keys() + (gg, t) + for (gg, t) in res_antic.group_time_effects.keys() if gg == g and t >= g - 1 # anticipation=1 ] @@ -3074,34 +2816,27 @@ def test_group_effects_anticipation_boundary(self): n_periods=10, n_cohorts=1, # Single cohort for cleaner test treatment_effect=2.0, - seed=123 + seed=123, ) # Get the single treatment group - g = data[data['first_treat'] > 0]['first_treat'].iloc[0] + g = 
data[data["first_treat"] > 0]["first_treat"].iloc[0] # Fit with anticipation=2 cs = CallawaySantAnna(anticipation=2) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Check group effects exist if results.group_effects is not None and g in results.group_effects: # The group effect for g should aggregate periods t >= g - 2 # Verify by checking which group-time effects exist - gt_for_group = [ - (gg, t) for (gg, t) in results.group_time_effects.keys() - if gg == g - ] + gt_for_group = [(gg, t) for (gg, t) in results.group_time_effects.keys() if gg == g] # There should be effects at t = g - anticipation = g - 2 # (if the data has that period) - min_period = data['time'].min() + min_period = data["time"].min() if g - 2 >= min_period: # Period g-2 should be computed as an ATT(g,t) has_antic_period = any(t == g - 2 for _, t in gt_for_group) @@ -3121,19 +2856,23 @@ def test_not_yet_treated_with_anticipation_excludes_anticipation_window(self): the treatment effect (~3.0) instead of near zero. 
""" data = generate_staggered_data( - n_units=100, n_periods=10, n_cohorts=2, - treatment_effect=3.0, seed=42, + n_units=100, + n_periods=10, + n_cohorts=2, + treatment_effect=3.0, + seed=42, ) cs = CallawaySantAnna(anticipation=1, control_group="not_yet_treated") result = cs.fit( - data, outcome="outcome", unit="unit", - time="time", first_treat="first_treat", + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", ) - groups = sorted( - g for g in data[data["first_treat"] > 0]["first_treat"].unique() - ) + groups = sorted(g for g in data[data["first_treat"] > 0]["first_treat"].unique()) for g in groups: for (gg, t), eff in result.group_time_effects.items(): @@ -3155,37 +2894,29 @@ def test_invalid_se_produces_nan_tstat_overall(self, ci_params): # Create data that will result in no valid post-treatment effects # This should produce NaN for overall statistics data = generate_staggered_data( - n_units=50, - n_periods=5, - n_cohorts=1, - treatment_effect=2.0, - seed=789 + n_units=50, n_periods=5, n_cohorts=1, treatment_effect=2.0, seed=789 ) n_boot = ci_params.bootstrap(50) # Modify first_treat so all treatment happens after data ends - data['first_treat'] = data['first_treat'].replace( - data['first_treat'].unique()[data['first_treat'].unique() > 0], - data['time'].max() + 10 + data["first_treat"] = data["first_treat"].replace( + data["first_treat"].unique()[data["first_treat"].unique() > 0], data["time"].max() + 10 ) import warnings + with warnings.catch_warnings(record=True): warnings.simplefilter("always") cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Overall t_stat should be NaN when SE is invalid if np.isnan(results.overall_se) or results.overall_se == 0: - assert np.isnan(results.overall_t_stat), ( - "overall_t_stat should be NaN when SE 
is invalid" - ) + assert np.isnan( + results.overall_t_stat + ), "overall_t_stat should be NaN when SE is invalid" def test_per_effect_tstat_consistency(self, ci_params): """Per-effect t_stat uses same NaN logic as overall t_stat. @@ -3194,36 +2925,27 @@ def test_per_effect_tstat_consistency(self, ci_params): """ # Generate normal data data = generate_staggered_data( - n_units=60, - n_periods=8, - n_cohorts=2, - treatment_effect=2.0, - seed=456 + n_units=60, n_periods=8, n_cohorts=2, treatment_effect=2.0, seed=456 ) n_boot = ci_params.bootstrap(100) cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( - data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat' + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" ) # Check all group-time effects for (g, t), effect_data in results.group_time_effects.items(): - se = effect_data['se'] - t_stat = effect_data['t_stat'] + se = effect_data["se"] + t_stat = effect_data["t_stat"] if not np.isfinite(se) or se == 0: assert np.isnan(t_stat), ( - f"t_stat for ({g}, {t}) should be NaN when SE={se}, " - f"got t_stat={t_stat}" + f"t_stat for ({g}, {t}) should be NaN when SE={se}, " f"got t_stat={t_stat}" ) else: # t_stat should be effect / se - expected = effect_data['effect'] / se + expected = effect_data["effect"] / se assert np.isclose(t_stat, expected), ( f"t_stat for ({g}, {t}) should be effect/SE, " f"expected {expected}, got {t_stat}" @@ -3232,24 +2954,22 @@ def test_per_effect_tstat_consistency(self, ci_params): # Check event study effects if present if results.event_study_effects is not None: for e, effect_data in results.event_study_effects.items(): - se = effect_data['se'] - t_stat = effect_data['t_stat'] + se = effect_data["se"] + t_stat = effect_data["t_stat"] if not np.isfinite(se) or se == 0: - assert np.isnan(t_stat), ( - f"event study t_stat for e={e} should be NaN when SE={se}" - ) + assert np.isnan( + t_stat + ), f"event study t_stat for e={e} 
should be NaN when SE={se}" # Check group effects if present if results.group_effects is not None: for g, effect_data in results.group_effects.items(): - se = effect_data['se'] - t_stat = effect_data['t_stat'] + se = effect_data["se"] + t_stat = effect_data["t_stat"] if not np.isfinite(se) or se == 0: - assert np.isnan(t_stat), ( - f"group t_stat for g={g} should be NaN when SE={se}" - ) + assert np.isnan(t_stat), f"group t_stat for g={g} should be NaN when SE={se}" def test_aggregated_tstat_nan_when_se_zero(self): """Aggregated t_stat (event-study and group) is NaN when SE is zero or non-finite. @@ -3268,12 +2988,9 @@ def test_aggregated_tstat_nan_when_se_zero(self): first_treat = 3 if unit < n_units // 2 else 0 for t in range(1, n_periods + 1): outcome = np.random.randn() - data.append({ - 'unit': unit, - 'time': t, - 'outcome': outcome, - 'first_treat': first_treat - }) + data.append( + {"unit": unit, "time": t, "outcome": outcome, "first_treat": first_treat} + ) df = pd.DataFrame(data) @@ -3281,20 +2998,20 @@ def test_aggregated_tstat_nan_when_se_zero(self): cs = CallawaySantAnna(n_bootstrap=0) results = cs.fit( df, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='all' # Get both event study and group effects + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", # Get both event study and group effects ) # Check that t_stat computation follows the correct pattern: # t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan if results.event_study_effects: for e, data in results.event_study_effects.items(): - se = data['se'] - t_stat = data['t_stat'] - effect = data['effect'] + se = data["se"] + t_stat = data["t_stat"] + effect = data["effect"] if not np.isfinite(se) or se <= 0: assert np.isnan(t_stat), ( @@ -3310,9 +3027,9 @@ def test_aggregated_tstat_nan_when_se_zero(self): if results.group_effects: for g, data in results.group_effects.items(): - se = data['se'] - 
t_stat = data['t_stat'] - effect = data['effect'] + se = data["se"] + t_stat = data["t_stat"] + effect = data["effect"] if not np.isfinite(se) or se <= 0: assert np.isnan(t_stat), ( @@ -3322,8 +3039,7 @@ def test_aggregated_tstat_nan_when_se_zero(self): else: expected_t = effect / se assert np.isclose(t_stat, expected_t, rtol=1e-10), ( - f"Group t_stat for g={g} should be effect/SE={expected_t}, " - f"got {t_stat}" + f"Group t_stat for g={g} should be effect/SE={expected_t}, " f"got {t_stat}" ) def test_event_study_universal_includes_reference_period(self): @@ -3333,11 +3049,11 @@ def test_event_study_universal_includes_reference_period(self): cs = CallawaySantAnna(base_period="universal") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None, "event_study_effects should not be None" @@ -3350,15 +3066,19 @@ def test_event_study_universal_includes_reference_period(self): ref = results.event_study_effects[-1] # Effect is 0 by construction (normalization) - assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + assert ref["effect"] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" # Inference fields are NaN - this is a normalization constraint, not an estimated effect - assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" - assert np.isnan(ref['t_stat']), f"Reference period t_stat should be NaN, got {ref['t_stat']}" - assert np.isnan(ref['p_value']), f"Reference period p_value should be NaN, got {ref['p_value']}" - assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( - f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" - ) - assert ref['n_groups'] == 0, f"Reference period n_groups should be 0, got {ref['n_groups']}" + 
assert np.isnan(ref["se"]), f"Reference period SE should be NaN, got {ref['se']}" + assert np.isnan( + ref["t_stat"] + ), f"Reference period t_stat should be NaN, got {ref['t_stat']}" + assert np.isnan( + ref["p_value"] + ), f"Reference period p_value should be NaN, got {ref['p_value']}" + assert np.isnan(ref["conf_int"][0]) and np.isnan( + ref["conf_int"][1] + ), f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" + assert ref["n_groups"] == 0, f"Reference period n_groups should be 0, got {ref['n_groups']}" def test_event_study_varying_excludes_reference_period(self): """Test that varying base period does NOT artificially add e=-1 with effect=0.""" @@ -3367,11 +3087,11 @@ def test_event_study_varying_excludes_reference_period(self): cs = CallawaySantAnna(base_period="varying") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None, "event_study_effects should not be None" @@ -3380,9 +3100,9 @@ def test_event_study_varying_excludes_reference_period(self): # The key is we don't artificially add a 0-effect entry if -1 in results.event_study_effects: # If it exists, it should be an actual computed effect, not 0.0 with n_groups=0 - assert results.event_study_effects[-1]['n_groups'] > 0, ( - "Varying mode should not artificially add e=-1 with n_groups=0" - ) + assert ( + results.event_study_effects[-1]["n_groups"] > 0 + ), "Varying mode should not artificially add e=-1 with n_groups=0" def test_event_study_universal_with_anticipation(self): """Test reference period with anticipation > 0.""" @@ -3391,11 +3111,11 @@ def test_event_study_universal_with_anticipation(self): cs = CallawaySantAnna(base_period="universal", anticipation=1) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - 
first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) assert results.event_study_effects is not None, "event_study_effects should not be None" @@ -3406,12 +3126,12 @@ def test_event_study_universal_with_anticipation(self): f"got periods: {list(results.event_study_effects.keys())}" ) ref = results.event_study_effects[-2] - assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + assert ref["effect"] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" # Inference fields are NaN - normalization constraint - assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" - assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( - f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" - ) + assert np.isnan(ref["se"]), f"Reference period SE should be NaN, got {ref['se']}" + assert np.isnan(ref["conf_int"][0]) and np.isnan( + ref["conf_int"][1] + ), f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" def test_event_study_universal_no_effects_raises_error(self): """Test that estimator raises error when no effects can be computed. 
@@ -3423,22 +3143,24 @@ def test_event_study_universal_no_effects_raises_error(self): # Create minimal data with only never-treated units # This ensures no ATT(g,t) can be computed (no treatment groups) - data = pd.DataFrame({ - 'unit': [1, 1, 2, 2, 3, 3], - 'time': [1, 2, 1, 2, 1, 2], - 'outcome': [1.0, 1.1, 1.2, 1.3, 1.4, 1.5], - 'first_treat': [0, 0, 0, 0, 0, 0], # All never-treated - }) + data = pd.DataFrame( + { + "unit": [1, 1, 2, 2, 3, 3], + "time": [1, 2, 1, 2, 1, 2], + "outcome": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5], + "first_treat": [0, 0, 0, 0, 0, 0], # All never-treated + } + ) cs = CallawaySantAnna(base_period="universal") with pytest.raises(ValueError, match="Could not estimate any group-time effects"): cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) @@ -3472,9 +3194,378 @@ def test_nan_se_group_time_ci_is_nan(self): for (g, t), eff in results.group_time_effects.items(): se = eff["se"] if not (np.isfinite(se) and se > 0): - assert_nan_inference({ - "se": se, - "t_stat": eff["t_stat"], - "p_value": eff["p_value"], - "conf_int": eff["conf_int"], - }) + assert_nan_inference( + { + "se": se, + "t_stat": eff["t_stat"], + "p_value": eff["p_value"], + "conf_int": eff["conf_int"], + } + ) + + +class TestPscoreTrimParameter: + """Tests for the pscore_trim parameter.""" + + def test_get_params_includes_pscore_trim(self): + """pscore_trim is included in get_params().""" + cs = CallawaySantAnna(pscore_trim=0.05) + params = cs.get_params() + assert "pscore_trim" in params + assert params["pscore_trim"] == 0.05 + + def test_set_params_pscore_trim(self): + """pscore_trim can be set via set_params().""" + cs = CallawaySantAnna() + cs.set_params(pscore_trim=0.1) + assert cs.pscore_trim == 0.1 + + def test_set_params_invalid_pscore_trim_rejected_at_fit(self): + """Invalid pscore_trim via 
set_params() raises ValueError at fit().""" + np.random.seed(42) + n_units, n_periods = 50, 6 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + first_treat = np.zeros(n_units) + first_treat[n_units // 2 :] = 3 + first_treat_expanded = np.repeat(first_treat, n_periods) + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.5 + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + } + ) + + for bad_val in [0.0, -0.1, 0.5]: + cs = CallawaySantAnna(estimation_method="ipw") + cs.set_params(pscore_trim=bad_val) + with pytest.raises(ValueError, match="pscore_trim must be in"): + cs.fit(data, outcome="outcome", unit="unit", time="time", first_treat="first_treat") + + def test_default_pscore_trim(self): + """Default pscore_trim is 0.01.""" + cs = CallawaySantAnna() + assert cs.pscore_trim == 0.01 + + def test_pscore_trim_negative_raises(self): + """pscore_trim < 0 raises ValueError.""" + with pytest.raises(ValueError, match="pscore_trim must be in"): + CallawaySantAnna(pscore_trim=-0.1) + + def test_pscore_trim_at_half_raises(self): + """pscore_trim == 0.5 raises ValueError.""" + with pytest.raises(ValueError, match="pscore_trim must be in"): + CallawaySantAnna(pscore_trim=0.5) + + def test_pscore_trim_above_half_raises(self): + """pscore_trim > 0.5 raises ValueError.""" + with pytest.raises(ValueError, match="pscore_trim must be in"): + CallawaySantAnna(pscore_trim=0.6) + + def test_pscore_trim_zero_raises(self): + """pscore_trim=0.0 raises ValueError (would cause division by zero in IPW weights).""" + with pytest.raises(ValueError, match="pscore_trim must be in"): + CallawaySantAnna(pscore_trim=0.0) + + def test_pscore_trim_in_results(self): + """results.pscore_trim matches the estimator's setting after fit().""" + np.random.seed(42) + n_units, n_periods = 50, 6 + 
units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + first_treat = np.zeros(n_units) + first_treat[n_units // 2 :] = 3 + first_treat_expanded = np.repeat(first_treat, n_periods) + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + outcomes = 1.0 + 2.0 * post + np.random.randn(len(units)) * 0.5 + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + } + ) + cs = CallawaySantAnna(pscore_trim=0.05, estimation_method="reg") + results = cs.fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + assert results.pscore_trim == 0.05 + + def test_nondefault_pscore_trim_ipw(self): + """IPW with pscore_trim=0.1 produces finite results.""" + np.random.seed(42) + n_units, n_periods = 80, 6 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + x = np.random.randn(n_units) + x_expanded = np.repeat(x, n_periods) + first_treat = np.zeros(n_units) + first_treat[n_units // 2 :] = 3 + first_treat_expanded = np.repeat(first_treat, n_periods) + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + outcomes = 1.0 + 0.5 * x_expanded + 2.0 * post + np.random.randn(len(units)) * 0.5 + data = pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + "x": x_expanded, + } + ) + cs = CallawaySantAnna(estimation_method="ipw", pscore_trim=0.1) + results = cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x"], + ) + assert np.isfinite(results.overall_att) + assert results.pscore_trim == 0.1 + + def test_nondefault_pscore_trim_dr(self): + """DR with pscore_trim=0.1 produces finite results.""" + np.random.seed(42) + n_units, n_periods = 80, 6 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), 
class TestIRLSPropensityScore:
    """Tests covering the IRLS propensity-score path of the CS estimator."""

    def test_near_separation_warning_ipw(self):
        """IPW estimation under near-separation should surface warnings."""
        np.random.seed(42)
        n_units, n_periods = 100, 8

        unit_ids = np.repeat(np.arange(n_units), n_periods)
        period_ids = np.tile(np.arange(n_periods), n_units)

        # Covariate that strongly predicts treatment status.
        x_strong = np.random.randn(n_units)
        x_strong_long = np.repeat(x_strong, n_periods)

        # Treatment assignment follows the covariate's sign exactly.
        first_treat = np.zeros(n_units)
        first_treat[x_strong > 0] = 4
        first_treat_long = np.repeat(first_treat, n_periods)

        post = (period_ids >= first_treat_long) & (first_treat_long > 0)
        outcomes = 1.0 + x_strong_long + 2.0 * post + np.random.randn(len(unit_ids)) * 0.5

        panel = pd.DataFrame(
            {
                "unit": unit_ids,
                "time": period_ids,
                "outcome": outcomes,
                "first_treat": first_treat_long.astype(int),
                "x_strong": x_strong_long,
            }
        )

        estimator = CallawaySantAnna(estimation_method="ipw")

        import warnings

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            results = estimator.fit(
                panel,
                outcome="outcome",
                unit="unit",
                time="time",
                first_treat="first_treat",
                covariates=["x_strong"],
            )

        # At least one warning should mention the propensity-score trouble.
        keywords = ("propensity", "separation", "trimmed")
        flagged = [
            msg for msg in caught if any(kw in str(msg.message).lower() for kw in keywords)
        ]
        assert len(flagged) > 0, "Expected propensity score warnings"
        # The ATT estimate must stay finite rather than blowing up.
        assert results.overall_att is not None
        assert np.isfinite(results.overall_att)

    def test_near_separation_att_not_inflated(self):
        """IRLS keeps the ATT sane when a covariate nearly separates groups.

        Key regression test for the reported bug: BFGS-based logit produced
        a wildly inflated ATT (~2.38 vs 0.45-1.15 in reference packages)
        under near-separation conditions.
        """
        np.random.seed(123)
        n_units, n_periods = 200, 8
        true_effect = 2.0

        unit_ids = np.repeat(np.arange(n_units), n_periods)
        period_ids = np.tile(np.arange(n_periods), n_units)

        # Large-scale covariate engineered to create near-separation.
        x = np.random.randn(n_units) * 3
        x_long = np.repeat(x, n_periods)

        # Treatment is correlated with the covariate but not deterministic.
        treat_prob = 1 / (1 + np.exp(-x))
        first_treat = np.zeros(n_units)
        first_treat[treat_prob > 0.5] = 4
        first_treat_long = np.repeat(first_treat, n_periods)

        post = (period_ids >= first_treat_long) & (first_treat_long > 0)
        outcomes = 1.0 + x_long * 0.5 + true_effect * post + np.random.randn(len(unit_ids)) * 0.5

        panel = pd.DataFrame(
            {
                "unit": unit_ids,
                "time": period_ids,
                "outcome": outcomes,
                "first_treat": first_treat_long.astype(int),
                "x": x_long,
            }
        )

        estimator = CallawaySantAnna(estimation_method="dr")

        import warnings

        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            results = estimator.fit(
                panel,
                outcome="outcome",
                unit="unit",
                time="time",
                first_treat="first_treat",
                covariates=["x"],
            )

        # Point estimate should land near the true effect, not explode.
        assert results.overall_att is not None
        assert (
            abs(results.overall_att - true_effect) < 3.0
        ), f"ATT={results.overall_att} too far from true effect {true_effect}"

    def test_dr_fallback_warning(self):
        """A failing propensity fit in the DR path triggers the fallback warning."""
        from unittest.mock import patch

        panel = generate_staggered_data_with_covariates(seed=42)

        estimator = CallawaySantAnna(estimation_method="dr")

        # Force the logit solver to fail so the DR path must fall back.
        with patch("diff_diff.staggered.solve_logit", side_effect=ValueError("test")):
            import warnings

            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                results = estimator.fit(
                    panel,
                    outcome="outcome",
                    unit="unit",
                    time="time",
                    first_treat="first_treat",
                    covariates=["x1"],
                )

        matching = [m for m in caught if "Falling back to unconditional" in str(m.message)]
        assert len(matching) > 0, "Expected fallback warning in DR path"
        assert results.overall_att is not None

    def test_large_scale_covariate_stability(self):
        """IRLS tolerates very large covariate scales without wild ATT inflation.

        Mimics scenario from Dias & Fontes (2024) audit where covariates
        like poptotaltrend (population in millions) caused near-separation.
        """
        np.random.seed(456)
        n_units, n_periods = 150, 8
        true_effect = 1.5

        unit_ids = np.repeat(np.arange(n_units), n_periods)
        period_ids = np.tile(np.arange(n_periods), n_units)

        # Covariate on the scale of population totals (millions).
        x_large = np.random.randn(n_units) * 1e6
        x_long = np.repeat(x_large, n_periods)

        # Treatment mildly tied to the covariate via a median split.
        first_treat = np.zeros(n_units)
        first_treat[x_large > np.median(x_large)] = 4
        first_treat_long = np.repeat(first_treat, n_periods)

        post = (period_ids >= first_treat_long) & (first_treat_long > 0)
        outcomes = 5.0 + x_long * 1e-7 + true_effect * post + np.random.randn(len(unit_ids)) * 0.5

        panel = pd.DataFrame(
            {
                "unit": unit_ids,
                "time": period_ids,
                "outcome": outcomes,
                "first_treat": first_treat_long.astype(int),
                "x_large": x_long,
            }
        )

        estimator = CallawaySantAnna(estimation_method="dr")

        import warnings

        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            results = estimator.fit(
                panel,
                outcome="outcome",
                unit="unit",
                time="time",
                first_treat="first_treat",
                covariates=["x_large"],
            )

        assert results.overall_att is not None
        assert np.isfinite(results.overall_att)
        # Estimate should fall within a generous band around the truth.
        assert (
            abs(results.overall_att - true_effect) < 5.0
        ), f"ATT={results.overall_att} too far from true effect {true_effect}"