Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions diff_diff/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
import numpy as np
import pandas as pd

from diff_diff.utils import compute_synthetic_weights

# Constants for rank_control_units
_SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar"
_OUTLIER_PENALTY_WEIGHT = 0.3 # Penalty weight for outcome outliers in treatment candidate scoring


def make_treatment_indicator(
data: pd.DataFrame,
Expand Down Expand Up @@ -985,9 +991,6 @@ def rank_control_units(
>>> top_controls = ranking['unit'].tolist()
>>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))]
"""
# Import compute_synthetic_weights from utils
from diff_diff.utils import compute_synthetic_weights

# -------------------------------------------------------------------------
# Input validation
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -1080,6 +1083,11 @@ def rank_control_units(
if len(valid_pre_periods) == 0:
raise ValueError("No data found for specified pre-treatment periods.")

# Filter control_candidates to those present in pivot (handles unbalanced panels)
control_candidates = [c for c in control_candidates if c in pivot.columns]
if len(control_candidates) == 0:
raise ValueError("No control units found in pre-treatment data.")

# Control outcomes: shape (n_pre_periods, n_control_candidates)
Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float)

Expand All @@ -1097,18 +1105,24 @@ def rank_control_units(
Y_control, Y_treated_mean, lambda_reg=lambda_reg
)

# RMSE for each control vs treated mean
# RMSE for each control vs treated mean (use nanmean to handle missing data)
rmse_scores = []
for j in range(len(control_candidates)):
y_c = Y_control[:, j]
rmse = np.sqrt(np.mean((y_c - Y_treated_mean) ** 2))
rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2))
rmse_scores.append(rmse)

# Convert RMSE to similarity score (lower RMSE = higher score)
max_rmse = max(rmse_scores) if rmse_scores else 1.0
outcome_trend_scores = [
1 - (rmse / (max_rmse + 1e-10)) for rmse in rmse_scores
]
min_rmse = min(rmse_scores) if rmse_scores else 0.0
rmse_range = max_rmse - min_rmse

if rmse_range < 1e-10:
# All controls have identical/similar pre-trends (includes single control case)
outcome_trend_scores = [1.0] * len(rmse_scores)
else:
# Normalize so best control gets 1.0, worst gets 0.0
outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores]

# -------------------------------------------------------------------------
# Compute covariate scores (if covariates provided)
Expand All @@ -1126,18 +1140,24 @@ def rank_control_units(
cov_standardized = (cov_data - cov_mean) / cov_std
treated_cov_std = (treated_cov - cov_mean) / cov_std

# Euclidean distance in standardized space
covariate_distances = []
for control_unit in control_candidates:
control_cov_std = cov_standardized.loc[control_unit]
dist = np.sqrt(np.sum((control_cov_std - treated_cov_std) ** 2))
covariate_distances.append(dist)

# Convert distance to similarity score
max_dist = max(covariate_distances) if covariate_distances else 1.0
covariate_scores = [
1 - (d / (max_dist + 1e-10)) for d in covariate_distances
]
# Euclidean distance in standardized space (vectorized)
control_cov_matrix = cov_standardized.loc[control_candidates].values
treated_cov_vector = treated_cov_std.values
covariate_distances = np.sqrt(
np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1)
)

# Convert distance to similarity score (min-max normalization)
max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0
min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0
dist_range = max_dist - min_dist

if dist_range < 1e-10:
# All controls have identical/similar covariate profiles
covariate_scores = [1.0] * len(covariate_distances)
else:
# Normalize so best control (closest) gets 1.0, worst gets 0.0
covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist()
else:
covariate_scores = [np.nan] * len(control_candidates)

Expand Down Expand Up @@ -1271,7 +1291,9 @@ def _suggest_treatment_candidates(
if len(other_means) > 0:
sd = other_means.std()
if sd > 0:
n_similar = int(np.sum(np.abs(other_means - avg_outcome) < 0.5 * sd))
n_similar = int(np.sum(
np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd
))
else:
n_similar = len(other_means)
else:
Expand Down Expand Up @@ -1307,7 +1329,9 @@ def _suggest_treatment_candidates(
else:
outcome_z = pd.Series([0.0] * len(result))

result['treatment_candidate_score'] = (similarity_score - 0.3 * outcome_z).clip(0, 1)
result['treatment_candidate_score'] = (
similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z
).clip(0, 1)

# Return top candidates
result = result.nlargest(n_candidates, 'treatment_candidate_score')
Expand Down
58 changes: 58 additions & 0 deletions tests/test_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,61 @@ def test_weight_parameters(self):
# (just check both work, exact comparison is data-dependent)
assert len(result1) > 0
assert len(result2) > 0

def test_unbalanced_panel(self):
    """Unbalanced panel: a control with no pre-treatment rows is excluded."""
    from diff_diff.prep import rank_control_units

    panel = generate_did_data(n_units=20, n_periods=6, seed=42)

    # Break panel balance: drop every pre-treatment observation
    # (period < 3) for the first control unit.
    controls = panel[panel["treated"] == 0]["unit"].unique()
    dropped_unit = controls[0]
    keep = ~(
        (panel["unit"] == dropped_unit)
        & (panel["period"] < 3)
    )
    unbalanced = panel[keep].copy()

    ranking = rank_control_units(
        unbalanced,
        unit_column="unit",
        time_column="period",
        outcome_column="outcome",
        treatment_column="treated",
    )

    # Ranking still succeeds, and the unit lacking any pre-treatment
    # data does not appear among the ranked controls.
    assert len(ranking) > 0
    assert dropped_unit not in ranking["unit"].values

def test_single_control_unit(self):
    """Edge case: with exactly one control unit, it gets the top score."""
    from diff_diff.prep import rank_control_units

    panel = generate_did_data(n_units=10, n_periods=6, seed=42)

    # Reduce the panel to all treated units plus a single control.
    treated_ids = panel[panel["treated"] == 1]["unit"].unique()
    lone_control = panel[panel["treated"] == 0]["unit"].unique()[0]
    subset = panel[
        panel["unit"].isin(treated_ids) | (panel["unit"] == lone_control)
    ].copy()

    ranking = rank_control_units(
        subset,
        unit_column="unit",
        time_column="period",
        outcome_column="outcome",
        treatment_column="treated",
    )

    # Exactly one ranked row — the lone control — and a degenerate
    # comparison set should yield the best possible quality score.
    assert len(ranking) == 1
    assert ranking["unit"].iloc[0] == lone_control
    assert ranking["quality_score"].iloc[0] == 1.0