From 4721bfd41191279fcf7dc2a5c2e2adf9691c4481 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 3 Jan 2026 15:53:22 +0000 Subject: [PATCH] Address code review feedback for rank_control_units - Move import to module level for efficiency - Add filtering for control units missing from pivot (unbalanced panels) - Use nanmean for RMSE to handle missing data - Fix edge case scoring when all controls have similar RMSE (min-max normalization) - Vectorize covariate distance computation for speed - Extract magic numbers to named constants - Add tests for unbalanced panels and single control unit edge case --- diff_diff/prep.py | 68 +++++++++++++++++++++++++++++++--------------- tests/test_prep.py | 58 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 22 deletions(-) diff --git a/diff_diff/prep.py b/diff_diff/prep.py index a11c2faa..283fa237 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -11,6 +11,12 @@ import numpy as np import pandas as pd +from diff_diff.utils import compute_synthetic_weights + +# Constants for rank_control_units +_SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar" +_OUTLIER_PENALTY_WEIGHT = 0.3 # Penalty weight for outcome outliers in treatment candidate scoring + def make_treatment_indicator( data: pd.DataFrame, @@ -985,9 +991,6 @@ def rank_control_units( >>> top_controls = ranking['unit'].tolist() >>> filtered = data[(data['treated'] == 1) | (data['unit'].isin(top_controls))] """ - # Import compute_synthetic_weights from utils - from diff_diff.utils import compute_synthetic_weights - # ------------------------------------------------------------------------- # Input validation # ------------------------------------------------------------------------- @@ -1080,6 +1083,11 @@ def rank_control_units( if len(valid_pre_periods) == 0: raise ValueError("No data found for specified pre-treatment periods.") + # Filter control_candidates to those present in pivot (handles unbalanced panels) + control_candidates = [c for c in control_candidates if c in pivot.columns] + if len(control_candidates) == 0: + raise ValueError("No control units found in pre-treatment data.") + # Control outcomes: shape (n_pre_periods, n_control_candidates) Y_control = pivot.loc[valid_pre_periods, control_candidates].values.astype(float) @@ -1097,18 +1105,24 @@ def rank_control_units( Y_control, Y_treated_mean, lambda_reg=lambda_reg ) - # RMSE for each control vs treated mean + # RMSE for each control vs treated mean (use nanmean to handle missing data) rmse_scores = [] for j in range(len(control_candidates)): y_c = Y_control[:, j] - rmse = np.sqrt(np.mean((y_c - Y_treated_mean) ** 2)) + rmse = np.sqrt(np.nanmean((y_c - Y_treated_mean) ** 2)) rmse_scores.append(rmse) # Convert RMSE to similarity score (lower RMSE = higher score) max_rmse = max(rmse_scores) if rmse_scores else 1.0 - outcome_trend_scores = [ - 1 - (rmse / (max_rmse + 1e-10)) for rmse in rmse_scores - ] + min_rmse = min(rmse_scores) if rmse_scores else 0.0 + rmse_range = max_rmse - min_rmse + + if rmse_range < 1e-10: + # All controls have identical/similar pre-trends (includes single control case) + outcome_trend_scores = [1.0] * len(rmse_scores) + else: + # Normalize so best control gets 1.0, worst gets 0.0 + outcome_trend_scores = [1 - (rmse - min_rmse) / rmse_range for rmse in rmse_scores] # ------------------------------------------------------------------------- # Compute covariate scores (if covariates provided) @@ -1126,18 +1140,24 @@ def rank_control_units( cov_standardized = (cov_data - cov_mean) / cov_std treated_cov_std = (treated_cov - cov_mean) / cov_std - # Euclidean distance in standardized space - covariate_distances = [] - for control_unit in control_candidates: - control_cov_std = cov_standardized.loc[control_unit] - dist = np.sqrt(np.sum((control_cov_std - treated_cov_std) ** 2)) - covariate_distances.append(dist) - - # Convert distance to similarity score - max_dist = max(covariate_distances) if covariate_distances else 1.0 - covariate_scores = [ - 1 - (d / (max_dist + 1e-10)) for d in covariate_distances - ] + # Euclidean distance in standardized space (vectorized) + control_cov_matrix = cov_standardized.loc[control_candidates].values + treated_cov_vector = treated_cov_std.values + covariate_distances = np.sqrt( + np.sum((control_cov_matrix - treated_cov_vector) ** 2, axis=1) + ) + + # Convert distance to similarity score (min-max normalization) + max_dist = covariate_distances.max() if len(covariate_distances) > 0 else 1.0 + min_dist = covariate_distances.min() if len(covariate_distances) > 0 else 0.0 + dist_range = max_dist - min_dist + + if dist_range < 1e-10: + # All controls have identical/similar covariate profiles + covariate_scores = [1.0] * len(covariate_distances) + else: + # Normalize so best control (closest) gets 1.0, worst gets 0.0 + covariate_scores = (1 - (covariate_distances - min_dist) / dist_range).tolist() else: covariate_scores = [np.nan] * len(control_candidates) @@ -1271,7 +1291,9 @@ def _suggest_treatment_candidates( if len(other_means) > 0: sd = other_means.std() if sd > 0: - n_similar = int(np.sum(np.abs(other_means - avg_outcome) < 0.5 * sd)) + n_similar = int(np.sum( + np.abs(other_means - avg_outcome) < _SIMILARITY_THRESHOLD_SD * sd + )) else: n_similar = len(other_means) else: @@ -1307,7 +1329,9 @@ def _suggest_treatment_candidates( else: outcome_z = pd.Series([0.0] * len(result)) - result['treatment_candidate_score'] = (similarity_score - 0.3 * outcome_z).clip(0, 1) + result['treatment_candidate_score'] = ( + similarity_score - _OUTLIER_PENALTY_WEIGHT * outcome_z + ).clip(0, 1) # Return top candidates result = result.nlargest(n_candidates, 'treatment_candidate_score') diff --git a/tests/test_prep.py b/tests/test_prep.py index 7049962d..057b919c 100644 --- a/tests/test_prep.py +++ b/tests/test_prep.py @@ -734,3 +734,61 @@ def test_weight_parameters(self): # (just check both work, exact comparison is data-dependent) assert len(result1) > 0 assert len(result2) > 0 + + def test_unbalanced_panel(self): + """Test handling of unbalanced panels with missing data.""" + from diff_diff.prep import rank_control_units + + data = generate_did_data(n_units=20, n_periods=6, seed=42) + + # Remove some observations to create unbalanced panel + # Remove all pre-period data for one control unit + control_units = data[data["treated"] == 0]["unit"].unique() + unit_to_partially_remove = control_units[0] + mask = ~( + (data["unit"] == unit_to_partially_remove) & + (data["period"] < 3) + ) + unbalanced_data = data[mask].copy() + + result = rank_control_units( + unbalanced_data, + unit_column="unit", + time_column="period", + outcome_column="outcome", + treatment_column="treated" + ) + + # Should still work and exclude the unit with no pre-treatment data + assert len(result) > 0 + # The unit with missing pre-treatment data should not be in results + assert unit_to_partially_remove not in result["unit"].values + + def test_single_control_unit(self): + """Test edge case with only one control unit.""" + from diff_diff.prep import rank_control_units + + data = generate_did_data(n_units=10, n_periods=6, seed=42) + + # Keep only one control unit + treated_units = data[data["treated"] == 1]["unit"].unique() + control_units = data[data["treated"] == 0]["unit"].unique() + single_control = control_units[0] + + filtered_data = data[ + (data["unit"].isin(treated_units)) | + (data["unit"] == single_control) + ].copy() + + result = rank_control_units( + filtered_data, + unit_column="unit", + time_column="period", + outcome_column="outcome", + treatment_column="treated" + ) + + assert len(result) == 1 + assert result["unit"].iloc[0] == single_control + # Single control should get score of 1.0 (best possible) + assert result["quality_score"].iloc[0] == 1.0