# Testing Parallel Trends and DiD Diagnostics

The **parallel trends assumption** is the key identifying assumption for Difference-in-Differences (DiD). It states that, in the absence of treatment, the average outcomes of the treated and control groups would have followed the same trend over time.
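In standard potential-outcomes notation (textbook notation, not something defined by `diff_diff`), let $Y_{it}(0)$ be unit $i$'s untreated outcome in period $t$, and let $D_i = 1$ for treated units. Parallel trends says the expected change in untreated outcomes is the same for both groups. This is what allows the control group's change to stand in for the treated group's counterfactual change in the 2x2 DiD estimator, where $\bar{Y}_{g,s}$ is the mean outcome of group $g$ in phase $s$:

$$
\mathbb{E}\left[Y_{it}(0) - Y_{i,t-1}(0) \mid D_i = 1\right] = \mathbb{E}\left[Y_{it}(0) - Y_{i,t-1}(0) \mid D_i = 0\right]
$$

$$
\widehat{\text{ATT}}_{\text{DiD}} = \left(\bar{Y}_{1,\text{post}} - \bar{Y}_{1,\text{pre}}\right) - \left(\bar{Y}_{0,\text{post}} - \bar{Y}_{0,\text{pre}}\right)
$$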

This notebook covers:
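1. Visual inspection of parallel trends
2. Statistical tests for parallel trends (pre-treatment slope comparison)
3. Distributional comparison (Wasserstein)
4. Equivalence testing (TOST)
5. Placebo tests and diagnostics
6. Event studies as a pre-trends check
7. What to do when parallel trends fails (including sensitivity analysis)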

In [None]:
import numpy as np
import pandas as pd
from diff_diff import DifferenceInDifferences, MultiPeriodDiD
from diff_diff.utils import (
    check_parallel_trends,
    check_parallel_trends_robust,
    equivalence_test_trends
)
from diff_diff.diagnostics import (
    run_placebo_test,
    placebo_timing_test,
    placebo_group_test,
    permutation_test,
    run_all_placebo_tests
)

# For plots
try:
    import matplotlib.pyplot as plt
    plt.style.use('seaborn-v0_8-whitegrid')
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
    print("matplotlib not installed - visualization examples will be skipped")

## 1. Create Example Data

We'll create two datasets:
- One where parallel trends **holds**
- One where parallel trends is **violated**

In [None]:
def generate_panel_data(n_units=100, n_periods=8, parallel=True, seed=42):
    """
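    Generate panel data with or without parallel trends.
    
    Parameters
    ----------
    n_units : int
        Number of units; the first half is treated.
    n_periods : int
        Number of periods; treatment starts at period n_periods // 2.
    parallel : bool
        If True, treated and control share the same trend (slope 1 per period).
        If False, treated has a steeper trend (slope 2) in every period.
    seed : int
        Random seed passed to np.random.seed.
    
    Returns
    -------
    pd.DataFrame
        Long-format panel; the true ATT is 5.0.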
    """
    np.random.seed(seed)
    
    treatment_time = n_periods // 2
    
    data = []
    for unit in range(n_units):
        is_treated = unit < n_units // 2
        unit_effect = np.random.normal(0, 2)
        
        for period in range(n_periods):
            # Base trend
            if parallel:
                # Same trend for both groups
                time_effect = period * 1.0
            else:
                # Different trends
                if is_treated:
                    time_effect = period * 2.0  # Steeper trend for treated
                else:
                    time_effect = period * 1.0
            
            y = 10.0 + unit_effect + time_effect
            
            # Treatment effect in post-period
            post = period >= treatment_time
            if is_treated and post:
                y += 5.0  # True ATT
            
            y += np.random.normal(0, 0.5)
            
            data.append({
                'unit': unit,
                'period': period,
                'treated': int(is_treated),
                'post': int(post),
                'outcome': y
            })
    
    return pd.DataFrame(data)

# Generate both datasets
df_parallel = generate_panel_data(parallel=True)
df_nonparallel = generate_panel_data(parallel=False)

print("Generated two datasets:")
print("  - df_parallel: Parallel trends holds")
print("  - df_nonparallel: Parallel trends violated")

## 2. Visual Inspection

The first step is always to **plot the data**. Look for:
- Similar slopes in pre-treatment periods
- Divergence only after treatment begins

In [None]:
def plot_trends(df, title, ax):
    """Plot mean outcomes by group over time."""
    means = df.groupby(['period', 'treated'])['outcome'].mean().unstack()
    
    treatment_time = df[df['post'] == 1]['period'].min()
    
    ax.plot(means.index, means[0], 'o-', label='Control', color='blue')
    ax.plot(means.index, means[1], 's-', label='Treated', color='red')
    ax.axvline(x=treatment_time - 0.5, color='gray', linestyle='--', 
               label='Treatment')
    ax.set_xlabel('Period')
    ax.set_ylabel('Mean Outcome')
    ax.set_title(title)
    ax.legend()

if HAS_MATPLOTLIB:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    plot_trends(df_parallel, 'Parallel Trends Holds', axes[0])
    plot_trends(df_nonparallel, 'Parallel Trends Violated', axes[1])
    
    plt.tight_layout()
    plt.show()
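Plots can be hard to judge when the groups sit at different levels. A numeric companion is to tabulate the period-to-period change in each group's mean outcome: under parallel trends, the treated-minus-control difference in these changes (`gap_change`) should hover near zero in the pre-treatment periods. The helper `pre_period_changes` is a hypothetical sketch, not part of `diff_diff`, and the noise-free toy data is made up for illustration.

```python
import pandas as pd

def pre_period_changes(df, outcome="outcome", time="period", group="treated"):
    """Hypothetical helper: period-to-period change in each group's mean outcome."""
    # Rows are periods, columns are group labels 0 (control) and 1 (treated)
    means = df.groupby([time, group])[outcome].mean().unstack()
    # First differences approximate each group's slope between consecutive periods
    changes = means.diff().dropna()
    # Treated slope minus control slope; close to 0 under parallel trends
    changes["gap_change"] = changes[1] - changes[0]
    return changes

# Made-up, noise-free toy data: treated slope 2, control slope 1, periods 0-3
rows = [
    {"period": t, "treated": g, "outcome": 10.0 + (2.0 if g else 1.0) * t}
    for t in range(4)
    for g in (0, 1)
    for _ in range(3)
]
toy = pd.DataFrame(rows)
print(pre_period_changes(toy))
```

Here every `gap_change` equals 1.0, the extra per-period slope built into the treated group.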

## 3. Simple Parallel Trends Test

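The `check_parallel_trends()` function estimates a linear pre-treatment trend for each group and tests whether the two slopes differ, reporting the slope difference, its standard error, a t-statistic, and a p-value.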

In [None]:
# Test for parallel trends (parallel case)
results_pt_parallel = check_parallel_trends(
    df_parallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    pre_periods=[0, 1, 2, 3]  # Pre-treatment periods
)

print("Parallel Trends Test (parallel case):")
print("=" * 50)
print(f"Treated trend: {results_pt_parallel['treated_trend']:.4f} "
      f"(SE: {results_pt_parallel['treated_trend_se']:.4f})")
print(f"Control trend: {results_pt_parallel['control_trend']:.4f} "
      f"(SE: {results_pt_parallel['control_trend_se']:.4f})")
print(f"Difference: {results_pt_parallel['trend_difference']:.4f} "
      f"(SE: {results_pt_parallel['trend_difference_se']:.4f})")
print(f"t-statistic: {results_pt_parallel['t_statistic']:.4f}")
print(f"p-value: {results_pt_parallel['p_value']:.4f}")
print(f"\nParallel trends plausible: {results_pt_parallel['parallel_trends_plausible']}")

In [None]:
# Test for parallel trends (non-parallel case)
results_pt_nonparallel = check_parallel_trends(
    df_nonparallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    pre_periods=[0, 1, 2, 3]
)

print("\nParallel Trends Test (non-parallel case):")
print("=" * 50)
print(f"Treated trend: {results_pt_nonparallel['treated_trend']:.4f}")
print(f"Control trend: {results_pt_nonparallel['control_trend']:.4f}")
print(f"Difference: {results_pt_nonparallel['trend_difference']:.4f}")
print(f"p-value: {results_pt_nonparallel['p_value']:.4f}")
print(f"\nParallel trends plausible: {results_pt_nonparallel['parallel_trends_plausible']}")
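A slope comparison like this can be sketched in a few lines of NumPy. `compare_pre_trends` is a hypothetical helper whose output keys merely mirror the dictionary printed above. It fits an OLS slope per group and treats every observation as independent (no clustering by unit), so we make no claim that it reproduces the standard errors of `check_parallel_trends()`. The simulated inputs are made up.

```python
import numpy as np

def compare_pre_trends(t_treated, y_treated, t_control, y_control):
    """Hypothetical sketch: OLS slope per group and a t-statistic for their difference.

    Observations are treated as independent, so the SEs are illustrative only.
    """
    def slope_and_se(t, y):
        t = np.asarray(t, dtype=float)
        y = np.asarray(y, dtype=float)
        t_c = t - t.mean()                               # centred time
        beta = (t_c @ (y - y.mean())) / (t_c @ t_c)      # OLS slope
        resid = y - y.mean() - beta * t_c                # residuals around the fit
        sigma2 = resid @ resid / (len(y) - 2)            # residual variance
        return beta, np.sqrt(sigma2 / (t_c @ t_c))       # slope and its SE

    b_t, se_t = slope_and_se(t_treated, y_treated)
    b_c, se_c = slope_and_se(t_control, y_control)
    diff = b_t - b_c
    se_diff = np.sqrt(se_t**2 + se_c**2)                 # groups assumed independent
    return {"treated_trend": b_t, "control_trend": b_c,
            "trend_difference": diff, "t_statistic": diff / se_diff}

# Made-up data: 4 pre-periods, 50 observations each, treated slope 2, control slope 1
rng = np.random.default_rng(0)
t = np.repeat(np.arange(4), 50)
y_treated = 10 + 2.0 * t + rng.normal(0, 0.5, t.size)
y_control = 10 + 1.0 * t + rng.normal(0, 0.5, t.size)
res = compare_pre_trends(t, y_treated, t, y_control)
print({k: round(float(v), 3) for k, v in res.items()})
```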

## 4. Robust Parallel Trends Test (Wasserstein)

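The `check_parallel_trends_robust()` function uses the Wasserstein (Earth Mover's) distance to compare the **full distribution** of pre-treatment outcome changes in the two groups, not just their means. It also reports a Kolmogorov-Smirnov (KS) test, the mean difference, and the variance ratio.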

In [None]:
# Robust test (parallel case)
results_robust_parallel = check_parallel_trends_robust(
    df_parallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    unit='unit',
    pre_periods=[0, 1, 2, 3],
    n_permutations=999,
    seed=42
)

print("Robust Parallel Trends Test (parallel case):")
print("=" * 50)
print(f"Wasserstein distance: {results_robust_parallel['wasserstein_distance']:.4f}")
print(f"Wasserstein (normalized): {results_robust_parallel['wasserstein_normalized']:.4f}")
print(f"Wasserstein p-value: {results_robust_parallel['wasserstein_p_value']:.4f}")
print(f"KS statistic: {results_robust_parallel['ks_statistic']:.4f}")
print(f"KS p-value: {results_robust_parallel['ks_p_value']:.4f}")
print(f"Mean difference: {results_robust_parallel['mean_difference']:.4f}")
print(f"Variance ratio: {results_robust_parallel['variance_ratio']:.4f}")
print(f"\nParallel trends plausible: {results_robust_parallel['parallel_trends_plausible']}")

In [None]:
# Robust test (non-parallel case)
results_robust_nonparallel = check_parallel_trends_robust(
    df_nonparallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    unit='unit',
    pre_periods=[0, 1, 2, 3],
    n_permutations=999,
    seed=42
)

print("\nRobust Parallel Trends Test (non-parallel case):")
print("=" * 50)
print(f"Wasserstein distance: {results_robust_nonparallel['wasserstein_distance']:.4f}")
print(f"Wasserstein p-value: {results_robust_nonparallel['wasserstein_p_value']:.4f}")
print(f"\nParallel trends plausible: {results_robust_nonparallel['parallel_trends_plausible']}")

In [None]:
if HAS_MATPLOTLIB:
    # Visualize the distribution of outcome changes
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    for i, (results, title) in enumerate([
        (results_robust_parallel, 'Parallel Trends'),
        (results_robust_nonparallel, 'Non-Parallel Trends')
    ]):
        ax = axes[i]
        ax.hist(results['treated_changes'], bins=20, alpha=0.5, 
                label='Treated', color='red')
        ax.hist(results['control_changes'], bins=20, alpha=0.5, 
                label='Control', color='blue')
        ax.set_xlabel('Outcome Change')
        ax.set_ylabel('Frequency')
        ax.set_title(f'{title}\n(Wasserstein p={results["wasserstein_p_value"]:.3f})')
        ax.legend()
    
    plt.tight_layout()
    plt.show()
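We have not inspected how `check_parallel_trends_robust()` is implemented. The sketch below only illustrates the two ingredients its output suggests: a 1-D Wasserstein distance between the treated and control samples of outcome changes, and a permutation p-value obtained by reshuffling group labels. `wasserstein_1d` and `permutation_p_value` are hypothetical helpers (in practice, `scipy.stats.wasserstein_distance` computes the same 1-D distance), and the input samples are simulated rather than taken from this notebook.

```python
import numpy as np

def wasserstein_1d(u, v):
    """First Wasserstein (Earth Mover's) distance between two 1-D samples.

    Integrates |F_u(x) - F_v(x)| over x using the empirical CDFs, mirroring the
    CDF-based formula used by scipy.stats.wasserstein_distance.
    """
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    grid = np.sort(np.concatenate([u, v]))
    deltas = np.diff(grid)                                  # widths between adjacent points
    cdf_u = np.searchsorted(u, grid[:-1], side="right") / u.size
    cdf_v = np.searchsorted(v, grid[:-1], side="right") / v.size
    return float(np.sum(np.abs(cdf_u - cdf_v) * deltas))

def permutation_p_value(u, v, n_permutations=999, seed=0):
    """Share of random relabelings whose distance is at least the observed one."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    rng = np.random.default_rng(seed)
    pooled = np.concatenate([u, v])
    observed = wasserstein_1d(u, v)
    exceed = 0
    for _ in range(n_permutations):
        shuffled = rng.permutation(pooled)
        if wasserstein_1d(shuffled[:u.size], shuffled[u.size:]) >= observed:
            exceed += 1
    # Add-one correction counts the observed labelling and keeps p > 0
    return (exceed + 1) / (n_permutations + 1)

# Simulated pre-period changes (made-up numbers): treated trend is steeper
rng = np.random.default_rng(42)
treated_changes = rng.normal(2.0, 0.7, size=50)
control_changes = rng.normal(1.0, 0.7, size=50)
print(wasserstein_1d(treated_changes, control_changes))
print(permutation_p_value(treated_changes, control_changes, n_permutations=199))
```

Because the distance integrates the whole gap between the two CDFs, it also reacts to differences in spread or shape, which a comparison of means alone would miss.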

## 5. Equivalence Testing (TOST)

Standard hypothesis tests have **low power** to detect violations of parallel trends, and failing to reject the null of equal trends is not evidence that the trends are parallel. A better approach is **equivalence testing** using the Two One-Sided Tests (TOST) procedure.

Instead of asking "Can we reject that trends are different?", we ask:
"Can we confirm that the trend difference is smaller than some practically meaningful threshold?"

In [None]:
# Equivalence test (parallel case)
results_equiv_parallel = equivalence_test_trends(
    df_parallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    unit='unit',
    pre_periods=[0, 1, 2, 3],
    equivalence_margin=0.5  # |trend difference| < 0.5 counts as "equivalent"
)

print("Equivalence Test (parallel case):")
print("=" * 50)
print(f"Mean difference: {results_equiv_parallel['mean_difference']:.4f}")
print(f"SE: {results_equiv_parallel['se_difference']:.4f}")
print(f"Equivalence margin: +/- {results_equiv_parallel['equivalence_margin']:.4f}")
print(f"TOST p-value: {results_equiv_parallel['tost_p_value']:.4f}")
print(f"\nTrends are equivalent (at alpha=0.05): {results_equiv_parallel['equivalent']}")

In [None]:
# Equivalence test (non-parallel case)
results_equiv_nonparallel = equivalence_test_trends(
    df_nonparallel,
    outcome='outcome',
    time='period',
    treatment_group='treated',
    unit='unit',
    pre_periods=[0, 1, 2, 3],
    equivalence_margin=0.5
)

print("\nEquivalence Test (non-parallel case):")
print("=" * 50)
print(f"Mean difference: {results_equiv_nonparallel['mean_difference']:.4f}")
print(f"TOST p-value: {results_equiv_nonparallel['tost_p_value']:.4f}")
print(f"\nTrends are equivalent: {results_equiv_nonparallel['equivalent']}")
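TOST flips the null hypothesis: it tests $H_0: |\delta| \geq m$ against $H_1: |\delta| < m$, where $\delta$ is the trend difference and $m$ the equivalence margin. It runs one one-sided test against each bound and declares equivalence only if both reject, so the TOST p-value is the larger of the two. Below is a minimal sketch under a normal approximation; `tost_p_value` is a hypothetical helper, `equivalence_test_trends()` may instead use a t distribution or a different standard error, and the inputs are made-up numbers.

```python
import math

def normal_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def tost_p_value(diff, se, margin):
    """TOST p-value under a normal approximation (hypothetical helper).

    H0: |true difference| >= margin  vs  H1: |true difference| < margin.
    """
    # One-sided test of H0: difference <= -margin
    p_lower = 1.0 - normal_cdf((diff + margin) / se)
    # One-sided test of H0: difference >= +margin
    p_upper = normal_cdf((diff - margin) / se)
    # Equivalence requires rejecting both, so report the larger p-value
    return max(p_lower, p_upper)

alpha = 0.05
for diff, se in [(0.05, 0.10), (1.00, 0.10), (0.30, 0.40)]:
    p = tost_p_value(diff, se, margin=0.5)
    print(f"diff={diff:.2f}, se={se:.2f}: TOST p={p:.4f}, equivalent={p < alpha}")
```

The third case shows why precision matters: the point estimate lies inside the margin, but its standard error is too large to establish equivalence.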

## 6. Placebo Tests

Placebo tests check whether we would detect "effects" where none should exist. Types of placebo tests:

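1. **Timing placebo**: Pretend treatment happened earlier, within the pre-treatment period
2. **Group placebo**: Estimate DiD on never-treated units only, with some of them randomly assigned a fake treatment
3. **Permutation test**: Randomly reassign treatment many times to build a null distribution of effects, then check how extreme the observed effect is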

In [None]:
# First, fit the main model
did = DifferenceInDifferences()
main_results = did.fit(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='post'
)

print("Main DiD Results:")
print(f"ATT: {main_results.att:.4f} (SE: {main_results.se:.4f})")
print(f"p-value: {main_results.p_value:.4f}")
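For a 2x2 design, the DiD point estimate is the difference of four cell means, and this equals the interaction coefficient from an OLS regression of the outcome on `treated`, `post`, and `treated × post`. The helper `did_2x2` below is hypothetical and the toy numbers are made up; we assume, without having checked the library source, that `DifferenceInDifferences` returns the same point estimate in a balanced 2x2 setting (its standard errors are a separate matter).

```python
import numpy as np
import pandas as pd

def did_2x2(df, outcome="outcome", treatment="treated", time="post"):
    """Hypothetical helper: 2x2 DiD point estimate from four cell means."""
    m = df.groupby([treatment, time])[outcome].mean()
    treated_change = m.loc[(1, 1)] - m.loc[(1, 0)]
    control_change = m.loc[(0, 1)] - m.loc[(0, 0)]
    return treated_change - control_change

# Made-up toy data: one observation per cell
toy = pd.DataFrame({
    "treated": [0, 0, 1, 1],
    "post":    [0, 1, 0, 1],
    "outcome": [10.0, 12.0, 11.0, 16.0],
})
print(did_2x2(toy))  # (16 - 11) - (12 - 10) = 3.0

# Same number from OLS: outcome ~ 1 + treated + post + treated:post
X = np.column_stack([
    np.ones(len(toy)),
    toy["treated"],
    toy["post"],
    toy["treated"] * toy["post"],
])
beta, *_ = np.linalg.lstsq(X, toy["outcome"].to_numpy(), rcond=None)
print(beta[3])  # interaction coefficient
```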

In [None]:
# Placebo timing test
# Estimate DiD with a fake treatment time in pre-period
placebo_timing = placebo_timing_test(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='period',
    placebo_time=2,  # Pretend treatment at period 2
    actual_treatment_time=4
)

print("\nPlacebo Timing Test:")
print("=" * 50)
print(f"Placebo ATT: {placebo_timing.effect:.4f}")
print(f"SE: {placebo_timing.se:.4f}")
print(f"p-value: {placebo_timing.p_value:.4f}")
print(f"\nPass (effect not significant): {placebo_timing.passed}")
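The idea behind a timing placebo can be sketched directly: drop the periods from the real treatment onward, pretend treatment started at `placebo_time`, and re-estimate. `placebo_timing_estimate` is a hypothetical helper, not necessarily how `placebo_timing_test()` works internally, and the noise-free toy data is made up. With parallel pre-trends the placebo estimate is zero; giving the treated group an extra unit of slope makes it pick up the pre-trend gap.

```python
import pandas as pd

def placebo_timing_estimate(df, placebo_time, actual_treatment_time,
                            outcome="outcome", treatment="treated", time="period"):
    """Hypothetical sketch of a timing placebo.

    Drops all periods from the actual treatment onward, pretends treatment
    began at placebo_time, and returns the resulting 2x2 DiD estimate.
    """
    pre = df[df[time] < actual_treatment_time].copy()
    pre["fake_post"] = (pre[time] >= placebo_time).astype(int)
    m = pre.groupby([treatment, "fake_post"])[outcome].mean()
    return (m.loc[(1, 1)] - m.loc[(1, 0)]) - (m.loc[(0, 1)] - m.loc[(0, 0)])

# Made-up toy data: level gap of 2, common slope of 1, true effect of 5 from period 4
rows = [
    {"period": t, "treated": g,
     "outcome": 10.0 + 2.0 * g + 1.0 * t + (5.0 if g and t >= 4 else 0.0)}
    for t in range(8)
    for g in (0, 1)
]
toy = pd.DataFrame(rows)
print(placebo_timing_estimate(toy, placebo_time=2, actual_treatment_time=4))  # 0.0

# Give the treated group one extra unit of slope: the placebo now picks up the pre-trend gap
toy_np = toy.assign(outcome=toy["outcome"] + toy["treated"] * toy["period"])
print(placebo_timing_estimate(toy_np, placebo_time=2, actual_treatment_time=4))  # 2.0
```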

In [None]:
# Placebo group test
# Estimate DiD using only never-treated units (random placebo assignment)
placebo_group = placebo_group_test(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='post',
    unit='unit',
    seed=42
)

print("\nPlacebo Group Test:")
print("=" * 50)
print(f"Placebo ATT: {placebo_group.effect:.4f}")
print(f"SE: {placebo_group.se:.4f}")
print(f"p-value: {placebo_group.p_value:.4f}")
print(f"\nPass: {placebo_group.passed}")
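A group placebo can be sketched the same way: restrict to never-treated units, randomly label half of them as treated, and estimate a 2x2 DiD. Since none of these units received treatment, a sizable estimate would point to heterogeneous trends among the controls. `placebo_group_estimate` and its toy data are hypothetical; we do not know how `placebo_group_test()` chooses its placebo group.

```python
import numpy as np
import pandas as pd

def placebo_group_estimate(df, seed=0, outcome="outcome", time="post", unit="unit"):
    """Hypothetical sketch of a group placebo.

    Keeps only never-treated units, labels a random half of them as fake
    treated, and returns the resulting 2x2 DiD estimate.
    """
    rng = np.random.default_rng(seed)
    controls = df[df["treated"] == 0].copy()
    units = controls[unit].unique()
    fake_treated = rng.choice(units, size=len(units) // 2, replace=False)
    controls["fake_treated"] = controls[unit].isin(fake_treated).astype(int)
    m = controls.groupby(["fake_treated", time])[outcome].mean()
    return (m.loc[(1, 1)] - m.loc[(1, 0)]) - (m.loc[(0, 1)] - m.loc[(0, 0)])

# Made-up toy data: 10 never-treated units with different levels but a common slope
rows = [
    {"unit": u, "treated": 0, "post": int(t >= 2), "outcome": 10.0 + u + 1.0 * t}
    for u in range(10)
    for t in range(4)
]
est = placebo_group_estimate(pd.DataFrame(rows), seed=42)
print(est)  # approximately 0: no fake "effect" without a real difference in trends
```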

In [None]:
# Permutation test
perm_results = permutation_test(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='post',
    n_permutations=999,
    seed=42
)

print("\nPermutation Test:")
print("=" * 50)
print(f"Observed ATT: {perm_results.effect:.4f}")
print(f"Permutation p-value: {perm_results.p_value:.4f}")
print(f"Number of permutations: {perm_results.n_permutations}")

In [None]:
if HAS_MATPLOTLIB:
    # Visualize permutation distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    
    ax.hist(perm_results.permuted_effects, bins=30, alpha=0.7, 
            edgecolor='black', label='Permuted effects')
    ax.axvline(x=perm_results.effect, color='red', linewidth=2, 
               linestyle='--', label=f'Observed = {perm_results.effect:.2f}')
    ax.axvline(x=0, color='gray', linewidth=1, linestyle=':')
    
    ax.set_xlabel('Effect')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Permutation Test Distribution\n(p-value = {perm_results.p_value:.3f})')
    ax.legend()
    plt.tight_layout()
    plt.show()
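In panel data, treatment is assigned to units, so a permutation test should reshuffle the treated label across whole units rather than across individual rows. The `permutation_test()` call above takes no `unit` argument, and we have not checked how it handles this. The following hypothetical sketch (`did_for_units` and `unit_permutation_test` are illustrative names) shows the unit-level version on made-up data.

```python
import numpy as np
import pandas as pd

def did_for_units(df, treated_units):
    """2x2 DiD treating the given unit ids as the treated group."""
    is_treated = df["unit"].isin(treated_units).astype(int).rename("is_treated")
    m = df.groupby([is_treated, df["post"]])["outcome"].mean()
    return (m.loc[(1, 1)] - m.loc[(1, 0)]) - (m.loc[(0, 1)] - m.loc[(0, 0)])

def unit_permutation_test(df, n_permutations=499, seed=0):
    """Hypothetical sketch: reassign the treated label across whole units."""
    rng = np.random.default_rng(seed)
    units = df["unit"].unique()
    treated_units = df.loc[df["treated"] == 1, "unit"].unique()
    observed = did_for_units(df, treated_units)
    permuted = np.empty(n_permutations)
    for i in range(n_permutations):
        fake = rng.choice(units, size=len(treated_units), replace=False)
        permuted[i] = did_for_units(df, fake)
    # Two-sided p-value with the usual add-one correction
    p_value = (np.sum(np.abs(permuted) >= abs(observed)) + 1) / (n_permutations + 1)
    return observed, p_value, permuted

# Made-up panel: 20 units, 4 periods, effect of 5 for units 0-9 in periods 2-3
rng = np.random.default_rng(1)
rows = []
for u in range(20):
    alpha = rng.normal(0, 2)                       # unit fixed effect
    for t in range(4):
        treated, post = int(u < 10), int(t >= 2)
        y = 10 + alpha + t + 5.0 * treated * post + rng.normal(0, 0.5)
        rows.append({"unit": u, "period": t, "treated": treated,
                     "post": post, "outcome": y})
toy = pd.DataFrame(rows)

obs, p, _ = unit_permutation_test(toy, n_permutations=199, seed=0)
print(f"Observed DiD: {obs:.2f}, permutation p-value: {p:.3f}")
```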

## 7. Comprehensive Diagnostics

Run all placebo tests at once with `run_all_placebo_tests()`.

In [None]:
# Run comprehensive diagnostics
all_tests = run_all_placebo_tests(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='period',
    unit='unit',
    treatment_time=4,
    n_permutations=499,
    seed=42
)

print("Comprehensive Placebo Test Results:")
print("=" * 60)
print(f"{'Test':<25} {'Effect':>10} {'p-value':>10} {'Pass':>10}")
print("-" * 60)

for test_name, result in all_tests.items():
    print(f"{test_name:<25} {result.effect:>10.4f} {result.p_value:>10.4f} {str(result.passed):>10}")

## 8. Event Study as a Parallel Trends Check

An **event study** estimates a separate effect for each period relative to a reference period (here, period 3). If parallel trends holds, the pre-treatment coefficients should be close to zero.

In [None]:
# Event study
mp_did = MultiPeriodDiD()
event_results = mp_did.fit(
    df_parallel,
    outcome='outcome',
    treatment='treated',
    time='period',
    post_periods=[4, 5, 6, 7],
    reference_period=3  # Use period 3 as reference
)

print(event_results.summary())

In [None]:
from diff_diff.visualization import plot_event_study

if HAS_MATPLOTLIB:
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_event_study(
        results=event_results,
        ax=ax,
        title='Event Study: Check Pre-trends',
        xlabel='Period',
        ylabel='Effect'
    )
    plt.tight_layout()
    plt.show()
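Without covariates, a per-period event-study coefficient comes down to how much the treated-minus-control gap in mean outcomes has moved since the reference period. The helper `event_study_gaps` below is a hypothetical sketch of that calculation on made-up, noise-free data; we assume, but have not verified, that `MultiPeriodDiD` produces comparable point estimates in a balanced panel without covariates, and the sketch computes no standard errors.

```python
import pandas as pd

def event_study_gaps(df, reference_period, outcome="outcome",
                     treatment="treated", time="period"):
    """Hypothetical helper: treated-minus-control mean gap per period,
    measured relative to the gap in the reference period."""
    means = df.groupby([time, treatment])[outcome].mean().unstack()
    gap = means[1] - means[0]
    return gap - gap.loc[reference_period]

# Made-up, noise-free toy data: parallel pre-trends, true effect of 5 from period 4
rows = [
    {"period": t, "treated": g,
     "outcome": 10.0 + 2.0 * g + 1.0 * t + (5.0 if g and t >= 4 else 0.0)}
    for t in range(8)
    for g in (0, 1)
]
gaps = event_study_gaps(pd.DataFrame(rows), reference_period=3)
print(gaps)  # 0 for periods 0-3, 5 for periods 4-7
```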

## 9. What to Do If Parallel Trends Fails?

If parallel trends is violated, consider:

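1. **Add covariates** that might explain differential trends
2. **Use Synthetic DiD**, which reweights control units and pre-treatment periods to better match the treated group's pre-treatment trajectory (this helps only if some weighted combination of controls can track that trajectory)
3. **Use bounds/sensitivity analysis** (Rambachan-Roth), which asks how large a violation of parallel trends would have to be to overturn the conclusion
4. **Consider alternative designs** (regression discontinuity, instrumental variables, etc.)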

In [None]:
# Example: Compare standard DiD vs Synthetic DiD on non-parallel data
from diff_diff import SyntheticDiD

# Standard DiD (biased when trends differ)
did_np = DifferenceInDifferences()
results_did_np = did_np.fit(
    df_nonparallel,
    outcome='outcome',
    treatment='treated',
    time='post'
)

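# Synthetic DiD (may be less biased when some controls share the treated trend;
# here every control unit has slope 1, so reweighting cannot fully remove the bias)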
sdid = SyntheticDiD(n_bootstrap=99, seed=42)
results_sdid = sdid.fit(
    df_nonparallel,
    outcome='outcome',
    treatment='treated',
    unit='unit',
    time='period',
    post_periods=[4, 5, 6, 7]
)

true_att = 5.0
print("Comparison on Non-Parallel Trends Data")
print("=" * 50)
print(f"True ATT: {true_att}")
print()
print("Standard DiD:")
print(f"  ATT: {results_did_np.att:.4f} (Bias: {results_did_np.att - true_att:.4f})")
print()
print("Synthetic DiD:")
print(f"  ATT: {results_sdid.att:.4f} (Bias: {results_sdid.att - true_att:.4f})")
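For this data-generating process, the expected bias of the standard 2x2 estimate can be worked out by hand. With the default arguments of `generate_panel_data(parallel=False)`, the treated group's time effect is $2t$ and the control group's is $t$; the pre-treatment periods 0-3 average $\bar{t}_{\text{pre}} = 1.5$ and the post-treatment periods 4-7 average $\bar{t}_{\text{post}} = 5.5$. Unit effects cancel in the before-after differences and the noise has mean zero, so

$$
\mathbb{E}\left[\widehat{\text{ATT}}_{\text{DiD}}\right] = 5 + (2 - 1)\left(\bar{t}_{\text{post}} - \bar{t}_{\text{pre}}\right) = 5 + (5.5 - 1.5) = 9,
$$

which means the standard DiD estimate should land near 9, a bias of about 4, up to sampling noise.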

## Summary

**Key takeaways for parallel trends testing:**

1. **Always visualize** the data first

2. **Simple tests** (`check_parallel_trends`):
   - Compare pre-treatment slopes
   - Easy to interpret, but only capture linear mean trends and may have low power

3. **Robust tests** (`check_parallel_trends_robust`):
   - Compare full distributions with Wasserstein distance
   - Can detect differences in the shape or spread of outcome changes, not only in means

4. **Equivalence testing** (`equivalence_test_trends`):
   - Tests whether differences are practically small
   - Provides positive evidence of similar trends, rather than relying on a failure to reject

5. **Placebo tests**:
   - Timing: Fake treatment in pre-period
   - Group: DiD on never-treated only
   - Permutation: Randomize treatment assignment

6. **Event studies**: pre-treatment coefficients should be close to zero

7. **If parallel trends fails**, consider:
   - Adding covariates
   - Synthetic DiD
   - Sensitivity analysis
   - Alternative identification strategies