From d9101766b63156e88c1076839177ea43df365d3f Mon Sep 17 00:00:00 2001
From: Claude
Date: Fri, 2 Jan 2026 15:39:00 +0000
Subject: [PATCH] Add multi-period DiD support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Extend the library to optionally work across multiple time periods:
- Add MultiPeriodDiD estimator that handles multiple pre/post periods
- Add MultiPeriodDiDResults for period-specific and average ATT
- Add PeriodEffect dataclass for individual period treatment effects
- Support for period dummies and treatment × period interactions
- Compute average ATT with proper covariance-adjusted standard errors
- Allow custom reference period selection
- Auto-infer post-periods when not specified (last half of periods)
- Full support for covariates, fixed effects, and absorbed FE
- Cluster-robust and heteroskedasticity-robust standard errors
- Comprehensive test suite with 28 new tests (70 total passing)
---
 diff_diff/__init__.py    |  12 +-
 diff_diff/estimators.py  | 333 ++++++++++++++++++++-
 diff_diff/results.py     | 293 +++++++++++++++++++
 tests/test_estimators.py | 605 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 1238 insertions(+), 5 deletions(-)

diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py
index 66f8a032..444ab801 100644
--- a/diff_diff/__init__.py
+++ b/diff_diff/__init__.py
@@ -5,8 +5,14 @@ using the difference-in-differences methodology.
""" -from diff_diff.estimators import DifferenceInDifferences -from diff_diff.results import DiDResults +from diff_diff.estimators import DifferenceInDifferences, MultiPeriodDiD +from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect __version__ = "0.1.0" -__all__ = ["DifferenceInDifferences", "DiDResults"] +__all__ = [ + "DifferenceInDifferences", + "MultiPeriodDiD", + "DiDResults", + "MultiPeriodDiDResults", + "PeriodEffect", +] diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 18d4b201..420ae60a 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -8,7 +8,7 @@ import pandas as pd from scipy import stats -from diff_diff.results import DiDResults +from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect from diff_diff.utils import ( validate_binary, compute_robust_se, @@ -717,3 +717,334 @@ def _within_transform( data[f"{var}_demeaned"] = data[var] - unit_means - time_means + grand_mean return data + + +class MultiPeriodDiD(DifferenceInDifferences): + """ + Multi-Period Difference-in-Differences estimator. + + Extends the standard DiD to handle multiple pre-treatment and + post-treatment time periods, providing period-specific treatment + effects as well as an aggregate average treatment effect. + + Parameters + ---------- + robust : bool, default=True + Whether to use heteroskedasticity-robust standard errors (HC1). + cluster : str, optional + Column name for cluster-robust standard errors. + alpha : float, default=0.05 + Significance level for confidence intervals. + + Attributes + ---------- + results_ : MultiPeriodDiDResults + Estimation results after calling fit(). + is_fitted_ : bool + Whether the model has been fitted. 
+ + Examples + -------- + Basic usage with multiple time periods: + + >>> import pandas as pd + >>> from diff_diff import MultiPeriodDiD + >>> + >>> # Create sample panel data with 6 time periods + >>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment + >>> data = create_panel_data() # Your data + >>> + >>> # Fit the model + >>> did = MultiPeriodDiD() + >>> results = did.fit( + ... data, + ... outcome='sales', + ... treatment='treated', + ... time='period', + ... post_periods=[3, 4, 5] # Specify which periods are post-treatment + ... ) + >>> + >>> # View period-specific effects + >>> for period, effect in results.period_effects.items(): + ... print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})") + >>> + >>> # View average treatment effect + >>> print(f"Average ATT: {results.avg_att:.3f}") + + Notes + ----- + The model estimates: + + Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_t∈post δ_t*(D_i × Post_t) + ε_it + + Where: + - D_i is the treatment indicator + - Period_t are time period dummies + - D_i × Post_t are treatment-by-post-period interactions + - δ_t are the period-specific treatment effects + + The average ATT is computed as the mean of the δ_t coefficients. + """ + + def fit( + self, + data: pd.DataFrame, + outcome: str, + treatment: str, + time: str, + post_periods: list = None, + covariates: list = None, + fixed_effects: list = None, + absorb: list = None, + reference_period: any = None + ) -> MultiPeriodDiDResults: + """ + Fit the Multi-Period Difference-in-Differences model. + + Parameters + ---------- + data : pd.DataFrame + DataFrame containing the outcome, treatment, and time variables. + outcome : str + Name of the outcome variable column. + treatment : str + Name of the treatment group indicator column (0/1). + time : str + Name of the time period column (can have multiple values). + post_periods : list + List of time period values that are post-treatment. + All other periods are treated as pre-treatment. 
+ covariates : list, optional + List of covariate column names to include as linear controls. + fixed_effects : list, optional + List of categorical column names to include as fixed effects. + absorb : list, optional + List of categorical column names for high-dimensional fixed effects. + reference_period : any, optional + The reference (omitted) time period for the period dummies. + Defaults to the first pre-treatment period. + + Returns + ------- + MultiPeriodDiDResults + Object containing period-specific and average treatment effects. + + Raises + ------ + ValueError + If required parameters are missing or data validation fails. + """ + # Validate basic inputs + if outcome is None or treatment is None or time is None: + raise ValueError( + "Must provide 'outcome', 'treatment', and 'time'" + ) + + # Validate columns exist + self._validate_data(data, outcome, treatment, time, covariates) + + # Validate treatment is binary + validate_binary(data[treatment].values, "treatment") + + # Get all unique time periods + all_periods = sorted(data[time].unique()) + + if len(all_periods) < 2: + raise ValueError("Time variable must have at least 2 unique periods") + + # Determine pre and post periods + if post_periods is None: + # Default: last half of periods are post-treatment + mid_point = len(all_periods) // 2 + post_periods = all_periods[mid_point:] + pre_periods = all_periods[:mid_point] + else: + post_periods = list(post_periods) + pre_periods = [p for p in all_periods if p not in post_periods] + + if len(post_periods) == 0: + raise ValueError("Must have at least one post-treatment period") + + if len(pre_periods) == 0: + raise ValueError("Must have at least one pre-treatment period") + + # Validate post_periods are in the data + for p in post_periods: + if p not in all_periods: + raise ValueError(f"Post-period '{p}' not found in time column") + + # Determine reference period (omitted dummy) + if reference_period is None: + reference_period = pre_periods[0] + elif 
reference_period not in all_periods: + raise ValueError(f"Reference period '{reference_period}' not found in time column") + + # Validate fixed effects and absorb columns + if fixed_effects: + for fe in fixed_effects: + if fe not in data.columns: + raise ValueError(f"Fixed effect column '{fe}' not found in data") + if absorb: + for ab in absorb: + if ab not in data.columns: + raise ValueError(f"Absorb column '{ab}' not found in data") + + # Handle absorbed fixed effects (within-transformation) + working_data = data.copy() + n_absorbed_effects = 0 + + if absorb: + vars_to_demean = [outcome] + (covariates or []) + for ab_var in absorb: + n_absorbed_effects += working_data[ab_var].nunique() - 1 + for var in vars_to_demean: + group_means = working_data.groupby(ab_var)[var].transform("mean") + working_data[var] = working_data[var] - group_means + + # Extract outcome and treatment + y = working_data[outcome].values.astype(float) + d = working_data[treatment].values.astype(float) + t = working_data[time].values + + # Build design matrix + # Start with intercept and treatment main effect + X = np.column_stack([np.ones(len(y)), d]) + var_names = ["const", treatment] + + # Add period dummies (excluding reference period) + non_ref_periods = [p for p in all_periods if p != reference_period] + period_dummy_indices = {} # Map period -> column index in X + + for period in non_ref_periods: + period_dummy = (t == period).astype(float) + X = np.column_stack([X, period_dummy]) + var_names.append(f"period_{period}") + period_dummy_indices[period] = X.shape[1] - 1 + + # Add treatment × post-period interactions + # These are our coefficients of interest + interaction_indices = {} # Map post-period -> column index in X + + for period in post_periods: + interaction = d * (t == period).astype(float) + X = np.column_stack([X, interaction]) + var_names.append(f"{treatment}:period_{period}") + interaction_indices[period] = X.shape[1] - 1 + + # Add covariates if provided + if covariates: + for 
cov in covariates: + X = np.column_stack([X, working_data[cov].values.astype(float)]) + var_names.append(cov) + + # Add fixed effects as dummy variables + if fixed_effects: + for fe in fixed_effects: + dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True) + for col in dummies.columns: + X = np.column_stack([X, dummies[col].values.astype(float)]) + var_names.append(col) + + # Fit OLS + coefficients, residuals, fitted, r_squared = self._fit_ols(X, y) + + # Compute standard errors + if self.cluster is not None: + cluster_ids = data[self.cluster].values + vcov = compute_robust_se(X, residuals, cluster_ids) + elif self.robust: + vcov = compute_robust_se(X, residuals) + else: + n = len(y) + k = X.shape[1] + mse = np.sum(residuals ** 2) / (n - k) + vcov = mse * np.linalg.inv(X.T @ X) + + # Degrees of freedom + df = len(y) - X.shape[1] - n_absorbed_effects + + # Extract period-specific treatment effects + period_effects = {} + effect_values = [] + effect_indices = [] + + for period in post_periods: + idx = interaction_indices[period] + effect = coefficients[idx] + se = np.sqrt(vcov[idx, idx]) + t_stat = effect / se + p_value = compute_p_value(t_stat, df=df) + conf_int = compute_confidence_interval(effect, se, self.alpha, df=df) + + period_effects[period] = PeriodEffect( + period=period, + effect=effect, + se=se, + t_stat=t_stat, + p_value=p_value, + conf_int=conf_int + ) + effect_values.append(effect) + effect_indices.append(idx) + + # Compute average treatment effect + # Average ATT = mean of period-specific effects + avg_att = np.mean(effect_values) + + # Standard error of average: need to account for covariance + # Var(avg) = (1/n^2) * sum of all elements in the sub-covariance matrix + n_post = len(post_periods) + sub_vcov = vcov[np.ix_(effect_indices, effect_indices)] + avg_var = np.sum(sub_vcov) / (n_post ** 2) + avg_se = np.sqrt(avg_var) + + avg_t_stat = avg_att / avg_se if avg_se > 0 else 0.0 + avg_p_value = compute_p_value(avg_t_stat, df=df) + 
avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df) + + # Count observations + n_treated = int(np.sum(d)) + n_control = int(np.sum(1 - d)) + + # Create coefficient dictionary + coef_dict = {name: coef for name, coef in zip(var_names, coefficients)} + + # Store results + self.results_ = MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=avg_att, + avg_se=avg_se, + avg_t_stat=avg_t_stat, + avg_p_value=avg_p_value, + avg_conf_int=avg_conf_int, + n_obs=len(y), + n_treated=n_treated, + n_control=n_control, + pre_periods=pre_periods, + post_periods=post_periods, + alpha=self.alpha, + coefficients=coef_dict, + vcov=vcov, + residuals=residuals, + fitted_values=fitted, + r_squared=r_squared, + ) + + self._coefficients = coefficients + self._vcov = vcov + self.is_fitted_ = True + + return self.results_ + + def summary(self) -> str: + """ + Get summary of estimation results. + + Returns + ------- + str + Formatted summary. + """ + if not self.is_fitted_: + raise RuntimeError("Model must be fitted before calling summary()") + return self.results_.summary() diff --git a/diff_diff/results.py b/diff_diff/results.py index b12a9953..54779889 100644 --- a/diff_diff/results.py +++ b/diff_diff/results.py @@ -168,3 +168,296 @@ def significance_stars(self) -> str: elif self.p_value < 0.1: return "." return "" + + +def _get_significance_stars(p_value: float) -> str: + """Return significance stars based on p-value.""" + if p_value < 0.001: + return "***" + elif p_value < 0.01: + return "**" + elif p_value < 0.05: + return "*" + elif p_value < 0.1: + return "." + return "" + + +@dataclass +class PeriodEffect: + """ + Treatment effect for a single time period. + + Attributes + ---------- + period : any + The time period identifier. + effect : float + The treatment effect estimate for this period. + se : float + Standard error of the effect estimate. + t_stat : float + T-statistic for the effect estimate. 
+ p_value : float + P-value for the null hypothesis that effect = 0. + conf_int : tuple[float, float] + Confidence interval for the effect. + """ + + period: any + effect: float + se: float + t_stat: float + p_value: float + conf_int: tuple + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.p_value) + return ( + f"PeriodEffect(period={self.period}, effect={self.effect:.4f}{sig}, " + f"SE={self.se:.4f}, p={self.p_value:.4f})" + ) + + @property + def is_significant(self) -> bool: + """Check if the effect is statistically significant at 0.05 level.""" + return bool(self.p_value < 0.05) + + @property + def significance_stars(self) -> str: + """Return significance stars based on p-value.""" + return _get_significance_stars(self.p_value) + + +@dataclass +class MultiPeriodDiDResults: + """ + Results from a Multi-Period Difference-in-Differences estimation. + + Provides access to period-specific treatment effects as well as + an aggregate average treatment effect. + + Attributes + ---------- + period_effects : dict[any, PeriodEffect] + Dictionary mapping period identifiers to their PeriodEffect objects. + avg_att : float + Average Treatment effect on the Treated across all post-periods. + avg_se : float + Standard error of the average ATT. + avg_t_stat : float + T-statistic for the average ATT. + avg_p_value : float + P-value for the null hypothesis that average ATT = 0. + avg_conf_int : tuple[float, float] + Confidence interval for the average ATT. + n_obs : int + Number of observations used in estimation. + n_treated : int + Number of treated observations. + n_control : int + Number of control observations. + pre_periods : list + List of pre-treatment period identifiers. + post_periods : list + List of post-treatment period identifiers. 
+ """ + + period_effects: dict + avg_att: float + avg_se: float + avg_t_stat: float + avg_p_value: float + avg_conf_int: tuple + n_obs: int + n_treated: int + n_control: int + pre_periods: list + post_periods: list + alpha: float = 0.05 + coefficients: Optional[dict] = field(default=None) + vcov: Optional[np.ndarray] = field(default=None) + residuals: Optional[np.ndarray] = field(default=None) + fitted_values: Optional[np.ndarray] = field(default=None) + r_squared: Optional[float] = field(default=None) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.avg_p_value) + return ( + f"MultiPeriodDiDResults(avg_ATT={self.avg_att:.4f}{sig}, " + f"SE={self.avg_se:.4f}, " + f"n_post_periods={len(self.post_periods)})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate a formatted summary of the estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level for confidence intervals. Defaults to the + alpha used during estimation. + + Returns + ------- + str + Formatted summary table. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 80, + "Multi-Period Difference-in-Differences Estimation Results".center(80), + "=" * 80, + "", + f"{'Observations:':<25} {self.n_obs:>10}", + f"{'Treated observations:':<25} {self.n_treated:>10}", + f"{'Control observations:':<25} {self.n_control:>10}", + f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}", + f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}", + ] + + if self.r_squared is not None: + lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}") + + # Period-specific effects + lines.extend([ + "", + "-" * 80, + "Period-Specific Treatment Effects".center(80), + "-" * 80, + f"{'Period':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 80, + ]) + + for period in self.post_periods: + pe = self.period_effects[period] + stars = pe.significance_stars + lines.append( + f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} " + f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}" + ) + + # Average effect + lines.extend([ + "-" * 80, + "", + "-" * 80, + "Average Treatment Effect (across post-periods)".center(80), + "-" * 80, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10}", + "-" * 80, + f"{'Avg ATT':<15} {self.avg_att:>12.4f} {self.avg_se:>12.4f} " + f"{self.avg_t_stat:>10.3f} {self.avg_p_value:>10.4f}", + "-" * 80, + "", + f"{conf_level}% Confidence Interval: [{self.avg_conf_int[0]:.4f}, {self.avg_conf_int[1]:.4f}]", + ]) + + # Add significance codes + lines.extend([ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 80, + ]) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print the summary to stdout.""" + print(self.summary(alpha)) + + def get_effect(self, period) -> PeriodEffect: + """ + Get the treatment effect for a specific period. + + Parameters + ---------- + period : any + The period identifier. + + Returns + ------- + PeriodEffect + The treatment effect for the specified period. + + Raises + ------ + KeyError + If the period is not found in post-treatment periods. + """ + if period not in self.period_effects: + raise KeyError( + f"Period '{period}' not found. " + f"Available post-periods: {list(self.period_effects.keys())}" + ) + return self.period_effects[period] + + def to_dict(self) -> dict: + """ + Convert results to a dictionary. + + Returns + ------- + dict + Dictionary containing all estimation results. 
+ """ + result = { + "avg_att": self.avg_att, + "avg_se": self.avg_se, + "avg_t_stat": self.avg_t_stat, + "avg_p_value": self.avg_p_value, + "avg_conf_int_lower": self.avg_conf_int[0], + "avg_conf_int_upper": self.avg_conf_int[1], + "n_obs": self.n_obs, + "n_treated": self.n_treated, + "n_control": self.n_control, + "n_pre_periods": len(self.pre_periods), + "n_post_periods": len(self.post_periods), + "r_squared": self.r_squared, + } + + # Add period-specific effects + for period, pe in self.period_effects.items(): + result[f"effect_period_{period}"] = pe.effect + result[f"se_period_{period}"] = pe.se + result[f"pval_period_{period}"] = pe.p_value + + return result + + def to_dataframe(self) -> pd.DataFrame: + """ + Convert period-specific effects to a pandas DataFrame. + + Returns + ------- + pd.DataFrame + DataFrame with one row per post-treatment period. + """ + rows = [] + for period, pe in self.period_effects.items(): + rows.append({ + "period": period, + "effect": pe.effect, + "se": pe.se, + "t_stat": pe.t_stat, + "p_value": pe.p_value, + "conf_int_lower": pe.conf_int[0], + "conf_int_upper": pe.conf_int[1], + "is_significant": pe.is_significant, + }) + return pd.DataFrame(rows) + + @property + def is_significant(self) -> bool: + """Check if the average ATT is statistically significant at the alpha level.""" + return bool(self.avg_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Return significance stars for the average ATT based on p-value.""" + return _get_significance_stars(self.avg_p_value) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 62fdbda0..ec990a4e 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -4,7 +4,13 @@ import pandas as pd import pytest -from diff_diff import DifferenceInDifferences, DiDResults +from diff_diff import ( + DifferenceInDifferences, + DiDResults, + MultiPeriodDiD, + MultiPeriodDiDResults, + PeriodEffect, +) @pytest.fixture @@ -958,3 +964,600 @@ def 
test_cluster_robust_se(self): # SEs should be different (cluster-robust typically larger) assert results_cluster.se != results_no_cluster.se + + +class TestMultiPeriodDiD: + """Tests for MultiPeriodDiD estimator.""" + + @pytest.fixture + def multi_period_data(self): + """Create panel data with multiple time periods and known ATT.""" + np.random.seed(42) + n_units = 100 + n_periods = 6 # 3 pre-treatment, 3 post-treatment + + data = [] + for unit in range(n_units): + is_treated = unit < n_units // 2 + unit_effect = np.random.normal(0, 1) + + for period in range(n_periods): + # Common time trend + time_effect = period * 0.5 + + y = 10.0 + unit_effect + time_effect + + # Treatment effect: 3.0 in post-periods (periods 3, 4, 5) + if is_treated and period >= 3: + y += 3.0 + + y += np.random.normal(0, 0.5) + + data.append({ + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + return pd.DataFrame(data) + + @pytest.fixture + def heterogeneous_effects_data(self): + """Create data with different treatment effects per period.""" + np.random.seed(42) + n_units = 100 + n_periods = 6 + + # Different true effects per post-period + true_effects = {3: 2.0, 4: 3.0, 5: 4.0} + + data = [] + for unit in range(n_units): + is_treated = unit < n_units // 2 + unit_effect = np.random.normal(0, 1) + + for period in range(n_periods): + time_effect = period * 0.5 + y = 10.0 + unit_effect + time_effect + + # Period-specific treatment effects + if is_treated and period in true_effects: + y += true_effects[period] + + y += np.random.normal(0, 0.5) + + data.append({ + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + return pd.DataFrame(data), true_effects + + def test_basic_fit(self, multi_period_data): + """Test basic model fitting with multiple periods.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + assert 
isinstance(results, MultiPeriodDiDResults) + assert did.is_fitted_ + assert results.n_obs == 600 # 100 units * 6 periods + assert len(results.period_effects) == 3 # 3 post-periods + assert len(results.pre_periods) == 3 + assert len(results.post_periods) == 3 + + def test_avg_att_close_to_true(self, multi_period_data): + """Test that average ATT is close to true effect.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # True ATT is 3.0 + assert abs(results.avg_att - 3.0) < 0.5 + assert results.avg_att > 0 + + def test_period_specific_effects(self, heterogeneous_effects_data): + """Test that period-specific effects are estimated correctly.""" + data, true_effects = heterogeneous_effects_data + + did = MultiPeriodDiD() + results = did.fit( + data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # Each period-specific effect should be close to truth + for period, true_effect in true_effects.items(): + estimated = results.period_effects[period].effect + assert abs(estimated - true_effect) < 0.5, \ + f"Period {period}: expected ~{true_effect}, got {estimated}" + + def test_period_effects_have_all_stats(self, multi_period_data): + """Test that period effects contain all statistics.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + for period, pe in results.period_effects.items(): + assert isinstance(pe, PeriodEffect) + assert hasattr(pe, 'effect') + assert hasattr(pe, 'se') + assert hasattr(pe, 't_stat') + assert hasattr(pe, 'p_value') + assert hasattr(pe, 'conf_int') + assert pe.se > 0 + assert len(pe.conf_int) == 2 + assert pe.conf_int[0] < pe.conf_int[1] + + def test_get_effect_method(self, multi_period_data): + """Test get_effect method.""" + did = MultiPeriodDiD() + results = did.fit( + 
multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # Valid period + effect = results.get_effect(4) + assert isinstance(effect, PeriodEffect) + assert effect.period == 4 + + # Invalid period + with pytest.raises(KeyError): + results.get_effect(0) # Pre-period + + def test_auto_infer_post_periods(self, multi_period_data): + """Test automatic inference of post-periods.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period" + # post_periods not specified - should infer last half + ) + + # With 6 periods, should infer periods 3, 4, 5 as post + assert results.pre_periods == [0, 1, 2] + assert results.post_periods == [3, 4, 5] + + def test_custom_reference_period(self, multi_period_data): + """Test custom reference period.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2 # Use period 2 as reference + ) + + # Should work and give reasonable results + assert results is not None + assert did.is_fitted_ + # Reference period should not be in coefficients as a dummy + assert "period_2" not in results.coefficients + + def test_with_covariates(self, multi_period_data): + """Test multi-period DiD with covariates.""" + # Add a covariate + multi_period_data["size"] = np.random.normal(100, 10, len(multi_period_data)) + + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + covariates=["size"] + ) + + assert results is not None + assert "size" in results.coefficients + + def test_with_fixed_effects(self): + """Test multi-period DiD with fixed effects.""" + np.random.seed(42) + n_units = 50 + n_periods = 6 + n_states = 5 + + data = [] + for unit in range(n_units): + state = unit % n_states + is_treated = unit < n_units // 2 + 
state_effect = state * 2.0 + + for period in range(n_periods): + y = 10.0 + state_effect + period * 0.5 + if is_treated and period >= 3: + y += 3.0 + y += np.random.normal(0, 0.5) + + data.append({ + "unit": unit, + "state": f"state_{state}", + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + df = pd.DataFrame(data) + + did = MultiPeriodDiD() + results = did.fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + fixed_effects=["state"] + ) + + assert results is not None + assert did.is_fitted_ + # ATT should still be close to 3.0 + assert abs(results.avg_att - 3.0) < 1.0 + + def test_with_absorbed_fe(self, multi_period_data): + """Test multi-period DiD with absorbed fixed effects.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + absorb=["unit"] + ) + + assert results is not None + assert did.is_fitted_ + assert abs(results.avg_att - 3.0) < 1.0 + + def test_cluster_robust_se(self, multi_period_data): + """Test cluster-robust standard errors.""" + did_cluster = MultiPeriodDiD(cluster="unit") + did_robust = MultiPeriodDiD(robust=True) + + results_cluster = did_cluster.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + results_robust = did_robust.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # ATT should be similar + assert abs(results_cluster.avg_att - results_robust.avg_att) < 0.01 + + # SEs should be different + assert results_cluster.avg_se != results_robust.avg_se + + def test_summary_output(self, multi_period_data): + """Test that summary produces string output.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + summary = 
results.summary() + assert isinstance(summary, str) + assert "Multi-Period" in summary + assert "Period-Specific" in summary + assert "Average Treatment Effect" in summary + assert "Avg ATT" in summary + + def test_to_dict(self, multi_period_data): + """Test conversion to dictionary.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + result_dict = results.to_dict() + assert "avg_att" in result_dict + assert "avg_se" in result_dict + assert "n_pre_periods" in result_dict + assert "n_post_periods" in result_dict + + def test_to_dataframe(self, multi_period_data): + """Test conversion to DataFrame.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + df = results.to_dataframe() + assert isinstance(df, pd.DataFrame) + assert len(df) == 3 # 3 post-periods + assert "period" in df.columns + assert "effect" in df.columns + assert "p_value" in df.columns + + def test_is_significant_property(self, multi_period_data): + """Test is_significant property.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # With true effect of 3.0, should be significant + assert isinstance(results.is_significant, bool) + assert results.is_significant + + def test_significance_stars(self, multi_period_data): + """Test significance stars property.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # Should have significance stars + assert results.significance_stars in ["*", "**", "***"] + + def test_repr(self, multi_period_data): + """Test string representation.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + 
treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + repr_str = repr(results) + assert "MultiPeriodDiDResults" in repr_str + assert "avg_ATT=" in repr_str + + def test_period_effect_repr(self, multi_period_data): + """Test PeriodEffect string representation.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + pe = results.period_effects[3] + repr_str = repr(pe) + assert "PeriodEffect" in repr_str + assert "period=" in repr_str + assert "effect=" in repr_str + + def test_invalid_post_period(self, multi_period_data): + """Test error when post_period not in data.""" + did = MultiPeriodDiD() + with pytest.raises(ValueError, match="not found in time column"): + did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 99] # 99 doesn't exist + ) + + def test_no_pre_periods_error(self, multi_period_data): + """Test error when all periods are post-treatment.""" + did = MultiPeriodDiD() + with pytest.raises(ValueError, match="at least one pre-treatment period"): + did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[0, 1, 2, 3, 4, 5] # All periods + ) + + def test_no_post_periods_error(self): + """Test error when no post-treatment periods.""" + data = pd.DataFrame({ + "outcome": [10, 11, 12, 13], + "treated": [1, 1, 0, 0], + "period": [0, 1, 0, 1], + }) + + did = MultiPeriodDiD() + with pytest.raises(ValueError, match="at least one post-treatment period"): + did.fit( + data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[] + ) + + def test_invalid_treatment_values(self, multi_period_data): + """Test error on non-binary treatment.""" + multi_period_data["treated"] = multi_period_data["treated"] * 2 # Makes values 0, 2 + + did = MultiPeriodDiD() + with pytest.raises(ValueError, match="binary"): + did.fit( + 
multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + def test_unfitted_model_error(self): + """Test error when accessing results before fitting.""" + did = MultiPeriodDiD() + with pytest.raises(RuntimeError, match="fitted"): + did.summary() + + def test_confidence_interval_contains_estimate(self, multi_period_data): + """Test that confidence intervals contain the estimates.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # Average ATT CI + lower, upper = results.avg_conf_int + assert lower < results.avg_att < upper + + # Period-specific CIs + for pe in results.period_effects.values(): + lower, upper = pe.conf_int + assert lower < pe.effect < upper + + def test_two_periods_works(self): + """Test that MultiPeriodDiD works with just 2 periods (edge case).""" + np.random.seed(42) + data = [] + for unit in range(50): + is_treated = unit < 25 + for period in [0, 1]: + y = 10.0 + (3.0 if is_treated and period == 1 else 0) + y += np.random.normal(0, 0.5) + data.append({ + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + df = pd.DataFrame(data) + + did = MultiPeriodDiD() + results = did.fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[1] + ) + + assert len(results.period_effects) == 1 + assert len(results.pre_periods) == 1 + assert abs(results.avg_att - 3.0) < 1.0 + + def test_many_periods(self): + """Test with many time periods.""" + np.random.seed(42) + n_periods = 20 + data = [] + for unit in range(50): + is_treated = unit < 25 + for period in range(n_periods): + y = 10.0 + period * 0.1 + if is_treated and period >= 10: + y += 2.5 + y += np.random.normal(0, 0.3) + data.append({ + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + }) + + df = pd.DataFrame(data) + + did = MultiPeriodDiD() + 
results = did.fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=list(range(10, 20)) + ) + + assert len(results.period_effects) == 10 + assert len(results.pre_periods) == 10 + assert abs(results.avg_att - 2.5) < 0.5 + + def test_r_squared_reported(self, multi_period_data): + """Test that R-squared is reported.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + assert results.r_squared is not None + assert 0 <= results.r_squared <= 1 + + def test_coefficients_dict(self, multi_period_data): + """Test that coefficients dictionary contains expected keys.""" + did = MultiPeriodDiD() + results = did.fit( + multi_period_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5] + ) + + # Should have treatment, period dummies, and interactions + assert "treated" in results.coefficients + assert "const" in results.coefficients + # Period dummies (excluding reference) + assert any("period_" in k for k in results.coefficients) + # Treatment interactions + assert any("treated:period_" in k for k in results.coefficients)