Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions diff_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
using the difference-in-differences methodology.
"""

from diff_diff.estimators import DifferenceInDifferences
from diff_diff.results import DiDResults
from diff_diff.estimators import DifferenceInDifferences, MultiPeriodDiD
from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect

__version__ = "0.1.0"
__all__ = ["DifferenceInDifferences", "DiDResults"]
__all__ = [
"DifferenceInDifferences",
"MultiPeriodDiD",
"DiDResults",
"MultiPeriodDiDResults",
"PeriodEffect",
]
333 changes: 332 additions & 1 deletion diff_diff/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from scipy import stats

from diff_diff.results import DiDResults
from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
from diff_diff.utils import (
validate_binary,
compute_robust_se,
Expand Down Expand Up @@ -717,3 +717,334 @@ def _within_transform(
data[f"{var}_demeaned"] = data[var] - unit_means - time_means + grand_mean

return data


class MultiPeriodDiD(DifferenceInDifferences):
"""
Multi-Period Difference-in-Differences estimator.

Extends the standard DiD to handle multiple pre-treatment and
post-treatment time periods, providing period-specific treatment
effects as well as an aggregate average treatment effect.

Parameters
----------
robust : bool, default=True
Whether to use heteroskedasticity-robust standard errors (HC1).
cluster : str, optional
Column name for cluster-robust standard errors.
alpha : float, default=0.05
Significance level for confidence intervals.

Attributes
----------
results_ : MultiPeriodDiDResults
Estimation results after calling fit().
is_fitted_ : bool
Whether the model has been fitted.

Examples
--------
Basic usage with multiple time periods:

>>> import pandas as pd
>>> from diff_diff import MultiPeriodDiD
>>>
>>> # Create sample panel data with 6 time periods
>>> # Periods 0-2 are pre-treatment, periods 3-5 are post-treatment
>>> data = create_panel_data() # Your data
>>>
>>> # Fit the model
>>> did = MultiPeriodDiD()
>>> results = did.fit(
... data,
... outcome='sales',
... treatment='treated',
... time='period',
... post_periods=[3, 4, 5] # Specify which periods are post-treatment
... )
>>>
>>> # View period-specific effects
>>> for period, effect in results.period_effects.items():
... print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})")
>>>
>>> # View average treatment effect
>>> print(f"Average ATT: {results.avg_att:.3f}")

Notes
-----
The model estimates:

Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_t∈post δ_t*(D_i × Post_t) + ε_it

Where:
- D_i is the treatment indicator
- Period_t are time period dummies
- D_i × Post_t are treatment-by-post-period interactions
- δ_t are the period-specific treatment effects

The average ATT is computed as the mean of the δ_t coefficients.
"""

def fit(
self,
data: pd.DataFrame,
outcome: str,
treatment: str,
time: str,
post_periods: list = None,
covariates: list = None,
fixed_effects: list = None,
absorb: list = None,
reference_period: any = None
) -> MultiPeriodDiDResults:
"""
Fit the Multi-Period Difference-in-Differences model.

Parameters
----------
data : pd.DataFrame
DataFrame containing the outcome, treatment, and time variables.
outcome : str
Name of the outcome variable column.
treatment : str
Name of the treatment group indicator column (0/1).
time : str
Name of the time period column (can have multiple values).
post_periods : list
List of time period values that are post-treatment.
All other periods are treated as pre-treatment.
covariates : list, optional
List of covariate column names to include as linear controls.
fixed_effects : list, optional
List of categorical column names to include as fixed effects.
absorb : list, optional
List of categorical column names for high-dimensional fixed effects.
reference_period : any, optional
The reference (omitted) time period for the period dummies.
Defaults to the first pre-treatment period.

Returns
-------
MultiPeriodDiDResults
Object containing period-specific and average treatment effects.

Raises
------
ValueError
If required parameters are missing or data validation fails.
"""
# Validate basic inputs
if outcome is None or treatment is None or time is None:
raise ValueError(
"Must provide 'outcome', 'treatment', and 'time'"
)

# Validate columns exist
self._validate_data(data, outcome, treatment, time, covariates)

# Validate treatment is binary
validate_binary(data[treatment].values, "treatment")

# Get all unique time periods
all_periods = sorted(data[time].unique())

if len(all_periods) < 2:
raise ValueError("Time variable must have at least 2 unique periods")

# Determine pre and post periods
if post_periods is None:
# Default: last half of periods are post-treatment
mid_point = len(all_periods) // 2
post_periods = all_periods[mid_point:]
pre_periods = all_periods[:mid_point]
else:
post_periods = list(post_periods)
pre_periods = [p for p in all_periods if p not in post_periods]

if len(post_periods) == 0:
raise ValueError("Must have at least one post-treatment period")

if len(pre_periods) == 0:
raise ValueError("Must have at least one pre-treatment period")

# Validate post_periods are in the data
for p in post_periods:
if p not in all_periods:
raise ValueError(f"Post-period '{p}' not found in time column")

# Determine reference period (omitted dummy)
if reference_period is None:
reference_period = pre_periods[0]
elif reference_period not in all_periods:
raise ValueError(f"Reference period '{reference_period}' not found in time column")

# Validate fixed effects and absorb columns
if fixed_effects:
for fe in fixed_effects:
if fe not in data.columns:
raise ValueError(f"Fixed effect column '{fe}' not found in data")
if absorb:
for ab in absorb:
if ab not in data.columns:
raise ValueError(f"Absorb column '{ab}' not found in data")

# Handle absorbed fixed effects (within-transformation)
working_data = data.copy()
n_absorbed_effects = 0

if absorb:
vars_to_demean = [outcome] + (covariates or [])
for ab_var in absorb:
n_absorbed_effects += working_data[ab_var].nunique() - 1
for var in vars_to_demean:
group_means = working_data.groupby(ab_var)[var].transform("mean")
working_data[var] = working_data[var] - group_means

# Extract outcome and treatment
y = working_data[outcome].values.astype(float)
d = working_data[treatment].values.astype(float)
t = working_data[time].values

# Build design matrix
# Start with intercept and treatment main effect
X = np.column_stack([np.ones(len(y)), d])
var_names = ["const", treatment]

# Add period dummies (excluding reference period)
non_ref_periods = [p for p in all_periods if p != reference_period]
period_dummy_indices = {} # Map period -> column index in X

for period in non_ref_periods:
period_dummy = (t == period).astype(float)
X = np.column_stack([X, period_dummy])
var_names.append(f"period_{period}")
period_dummy_indices[period] = X.shape[1] - 1

# Add treatment × post-period interactions
# These are our coefficients of interest
interaction_indices = {} # Map post-period -> column index in X

for period in post_periods:
interaction = d * (t == period).astype(float)
X = np.column_stack([X, interaction])
var_names.append(f"{treatment}:period_{period}")
interaction_indices[period] = X.shape[1] - 1

# Add covariates if provided
if covariates:
for cov in covariates:
X = np.column_stack([X, working_data[cov].values.astype(float)])
var_names.append(cov)

# Add fixed effects as dummy variables
if fixed_effects:
for fe in fixed_effects:
dummies = pd.get_dummies(working_data[fe], prefix=fe, drop_first=True)
for col in dummies.columns:
X = np.column_stack([X, dummies[col].values.astype(float)])
var_names.append(col)

# Fit OLS
coefficients, residuals, fitted, r_squared = self._fit_ols(X, y)

# Compute standard errors
if self.cluster is not None:
cluster_ids = data[self.cluster].values
vcov = compute_robust_se(X, residuals, cluster_ids)
elif self.robust:
vcov = compute_robust_se(X, residuals)
else:
n = len(y)
k = X.shape[1]
mse = np.sum(residuals ** 2) / (n - k)
vcov = mse * np.linalg.inv(X.T @ X)

# Degrees of freedom
df = len(y) - X.shape[1] - n_absorbed_effects

# Extract period-specific treatment effects
period_effects = {}
effect_values = []
effect_indices = []

for period in post_periods:
idx = interaction_indices[period]
effect = coefficients[idx]
se = np.sqrt(vcov[idx, idx])
t_stat = effect / se
p_value = compute_p_value(t_stat, df=df)
conf_int = compute_confidence_interval(effect, se, self.alpha, df=df)

period_effects[period] = PeriodEffect(
period=period,
effect=effect,
se=se,
t_stat=t_stat,
p_value=p_value,
conf_int=conf_int
)
effect_values.append(effect)
effect_indices.append(idx)

# Compute average treatment effect
# Average ATT = mean of period-specific effects
avg_att = np.mean(effect_values)

# Standard error of average: need to account for covariance
# Var(avg) = (1/n^2) * sum of all elements in the sub-covariance matrix
n_post = len(post_periods)
sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
avg_var = np.sum(sub_vcov) / (n_post ** 2)
avg_se = np.sqrt(avg_var)

avg_t_stat = avg_att / avg_se if avg_se > 0 else 0.0
avg_p_value = compute_p_value(avg_t_stat, df=df)
avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)

# Count observations
n_treated = int(np.sum(d))
n_control = int(np.sum(1 - d))

# Create coefficient dictionary
coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}

# Store results
self.results_ = MultiPeriodDiDResults(
period_effects=period_effects,
avg_att=avg_att,
avg_se=avg_se,
avg_t_stat=avg_t_stat,
avg_p_value=avg_p_value,
avg_conf_int=avg_conf_int,
n_obs=len(y),
n_treated=n_treated,
n_control=n_control,
pre_periods=pre_periods,
post_periods=post_periods,
alpha=self.alpha,
coefficients=coef_dict,
vcov=vcov,
residuals=residuals,
fitted_values=fitted,
r_squared=r_squared,
)

self._coefficients = coefficients
self._vcov = vcov
self.is_fitted_ = True

return self.results_

def summary(self) -> str:
"""
Get summary of estimation results.

Returns
-------
str
Formatted summary.
"""
if not self.is_fitted_:
raise RuntimeError("Model must be fitted before calling summary()")
return self.results_.summary()
Loading