# Walmart Application: Effects of Walmart Opening on Local Retail Employment

This notebook replicates the empirical analysis from Lee and Wooldridge (2025),
"A Simple Transformation Approach to Difference-in-Differences Estimation
for Panel Data" (SSRN 4516518), Section 6.

## Data Description

- **Source**: Brown and Butts (2025), based on County Business Patterns (CBP) data
- **Panel**: 1,280 counties over 23 years (1977-1999)
- **Treatment**: First Walmart store opening in a county
- **Outcome**: Log county-level retail employment

## Reference Results (Table A4)

Rolling IPWRA with Heterogeneous Trends:
- ATT(0)  = 0.007 (SE=0.004)
- ATT(1)  = 0.032 (SE=0.005)
- ATT(2)  = 0.025 (SE=0.006)
- ATT(3)  = 0.021 (SE=0.007)
- ATT(4)  = 0.018 (SE=0.009)
- ATT(5)  = 0.017 (SE=0.010)
- ATT(6)  = 0.019 (SE=0.012)
- ATT(7)  = 0.036 (SE=0.013)
- ATT(8)  = 0.041 (SE=0.016)
- ATT(9)  = 0.041 (SE=0.019)
- ATT(10) = 0.037 (SE=0.023)
- ATT(11) = 0.018 (SE=0.030)
- ATT(12) = 0.017 (SE=0.036)
- ATT(13) = 0.047 (SE=0.053)

## 1. Setup

In [None]:
import os
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from lwdid import lwdid

warnings.formatwarning = lambda msg, cat, *a, **kw: f'{cat.__name__}: {msg}\n'

print("lwdid package loaded successfully")

## 2. Data Loading and Descriptive Statistics

Load the Walmart data and verify descriptive statistics match Table 2 of the paper.

In [None]:
df = pd.read_csv('../data/walmart.csv')

print(f"Data shape: {df.shape[0]:,} observations, {df.shape[1]} variables")
print(f"Counties: {df['fips'].nunique():,}")
print(f"Years: {df['year'].min()} - {df['year'].max()}")
print(f"Observations per county: {df.groupby('fips').size().unique()[0]}")

In [None]:
# Treatment cohort distribution (Table 2)
print("Treatment Cohort Distribution (Table 2)")
print("-" * 50)

cohort_dist = df.groupby('g')['fips'].nunique().sort_index()
n_never_treated = cohort_dist.get(np.inf, 0)
n_treated = cohort_dist[cohort_dist.index != np.inf].sum()

print(f"Treated counties: {n_treated}")
print(f"Never-treated counties: {n_never_treated}")
print(f"Treatment cohort range: 1986 - 1999")

In [None]:
# Verify descriptive statistics (Table 2)
stats = {
    'log(Retail Employment)': ('log_retail_emp', 7.754502),
    'Share Poverty (above)': ('share_pop_poverty_78_above', 0.8470385),
    'Share in Manufacturing': ('share_pop_ind_manuf', 0.0998018),
    'Share HS Graduate': ('share_school_some_hs', 0.092258),
}

print(f"{'Variable':<30} {'Data Mean':>12} {'Paper Mean':>12} {'Match':>8}")
print("-" * 65)

all_match = True
for name, (col, paper_val) in stats.items():
    data_val = df[col].mean()
    match = abs(data_val - paper_val) < 0.001
    all_match = all_match and match
    match_str = "\u2713" if match else "\u2717"
    print(f"{name:<30} {data_val:>12.6f} {paper_val:>12.6f} {match_str:>8}")

if all_match:
    print("\nAll descriptive statistics match Table 2 \u2713")
else:
    print("\nWarning: Some statistics do not match exactly")

## 3. Helper Functions

Define the estimation wrapper and WATT (Weighted Average Treatment Effect on the Treated) computation functions.

In [None]:
def estimate_rolling_ipwra(df, rolling_method, controls, control_group='not_yet_treated',
                           include_pretreatment=True, verbose=True):
    """
    Estimate ATT using Rolling IPWRA method.

    Parameters
    ----------
    df : pd.DataFrame
        Panel data
    rolling_method : str
        'demean' or 'detrend'
    controls : list
        Control variables
    control_group : str
        'never_treated', 'not_yet_treated', or 'all_others'
    include_pretreatment : bool
        Whether to compute pre-treatment effects (set False inside bootstrap for speed)
    verbose : bool
        Whether to print progress info

    Returns
    -------
    LWDIDResults
        Estimation results
    """
    if verbose:
        print(f"Estimating Rolling IPWRA with {rolling_method} (control: {control_group})...")

    results = lwdid(
        data=df,
        y='log_retail_emp',
        ivar='fips',
        tvar='year',
        gvar='g',
        rolling=rolling_method,
        estimator='ipwra',
        controls=controls,
        control_group=control_group,
        aggregate='none',
        alpha=0.05,
        include_pretreatment=include_pretreatment,
    )

    return results

In [None]:
def compute_watt(results, df):
    """
    Compute Weighted Average Treatment Effects on the Treated (WATT) by event time.

    WATT(r) = sum_g w(g,r) * ATT(g, g+r)
    where w(g,r) = N_g / N_Gr is the share of treated units in cohort g.

    Parameters
    ----------
    results : LWDIDResults
        Estimation results with cohort-time effects
    df : pd.DataFrame
        Original data for computing weights

    Returns
    -------
    pd.DataFrame
        WATT by event time
    """
    att_ct = results.att_by_cohort_time.copy()

    if att_ct is None or len(att_ct) == 0:
        return pd.DataFrame()

    # Get cohort sizes for weighting
    cohort_sizes = df[df['g'] != np.inf].groupby('g')['fips'].nunique().to_dict()

    # Add weights
    att_ct['weight'] = att_ct['cohort'].map(cohort_sizes)
    att_ct['weight'] = att_ct['weight'].fillna(0)

    # Aggregate by event time
    watt_list = []

    for event_time in sorted(att_ct['event_time'].unique()):
        subset = att_ct[att_ct['event_time'] == event_time].copy()
        subset = subset[subset['att'].notna()]

        if len(subset) == 0:
            continue

        # Normalize weights
        total_weight = subset['weight'].sum()
        if total_weight == 0:
            continue

        subset['norm_weight'] = subset['weight'] / total_weight

        # Weighted ATT
        watt = (subset['att'] * subset['norm_weight']).sum()

        # Weighted SE (conservative: assumes independence)
        watt_se = np.sqrt((subset['se']**2 * subset['norm_weight']**2).sum())

        # Number of cohorts contributing
        n_cohorts = len(subset)
        n_total = subset['n_treated'].sum() + subset['n_control'].sum()

        watt_list.append({
            'event_time': int(event_time),
            'watt': watt,
            'se': watt_se,
            'ci_lower': watt - 1.96 * watt_se,
            'ci_upper': watt + 1.96 * watt_se,
            'n_cohorts': n_cohorts,
            'n_total': n_total,
        })

    return pd.DataFrame(watt_list)

## 4. Bootstrap Standard Errors

Define the cluster bootstrap procedure for computing WATT standard errors (paper-style: bootstrap over units).

In [None]:
def _bootstrap_resample_units(df, ivar, seed, rep):
    """
    Cluster bootstrap at the unit level (resample units with replacement).
    Duplicated units are assigned new synthetic unit IDs to keep (ivar, tvar) unique.
    """
    rng = np.random.default_rng(seed + rep)
    unit_to_idx = df.groupby(ivar, sort=False).indices
    unit_ids = np.array(list(unit_to_idx.keys()))
    sampled_ids = rng.choice(unit_ids, size=len(unit_ids), replace=True)

    idx_arrays = [unit_to_idx[u] for u in sampled_ids]
    boot_idx = np.concatenate(idx_arrays)

    rep_counts = [len(unit_to_idx[u]) for u in sampled_ids]
    new_unit_ids = np.repeat(np.arange(len(sampled_ids)), rep_counts)

    boot_df = df.iloc[boot_idx].copy()
    boot_df[ivar] = new_unit_ids
    return boot_df


def compute_watt_bootstrap_se(df, rolling_method, controls, control_group,
                               *, n_bootstrap=100, seed=12345):
    """
    Compute WATT and bootstrap SE (paper-style: bootstrap reps over units).

    This is computationally expensive: it re-runs the full staggered pipeline
    n_bootstrap times. Enable only when you explicitly want paper-style SEs.
    """
    # Point estimate on original sample (no pre-treatment needed)
    base_results = estimate_rolling_ipwra(
        df, rolling_method, controls, control_group=control_group,
        include_pretreatment=False, verbose=False
    )
    watt_point = compute_watt(base_results, df)
    if len(watt_point) == 0:
        return watt_point

    event_times = watt_point['event_time'].tolist()
    rep_matrix = {et: [] for et in event_times}

    for b in range(n_bootstrap):
        if (b + 1) % 10 == 0 or b == 0:
            print(f"  Bootstrap rep {b + 1}/{n_bootstrap}...")
        boot_df = _bootstrap_resample_units(df, ivar='fips', seed=seed, rep=b)
        boot_results = estimate_rolling_ipwra(
            boot_df, rolling_method, controls, control_group=control_group,
            include_pretreatment=False, verbose=False
        )
        boot_watt = compute_watt(boot_results, boot_df)

        for et in event_times:
            vals = boot_watt.loc[boot_watt['event_time'] == et, 'watt'].values
            rep_matrix[et].append(float(vals[0]) if len(vals) else np.nan)

    # Replace SE/CI with bootstrap-based values
    se_boot = []
    for et in event_times:
        arr = np.asarray(rep_matrix[et], dtype=float)
        arr = arr[np.isfinite(arr)]
        if len(arr) < 2:
            se_boot.append(np.nan)
        else:
            se_boot.append(float(np.std(arr, ddof=1)))

    watt_point = watt_point.copy()
    watt_point['se'] = se_boot
    watt_point['ci_lower'] = watt_point['watt'] - 1.96 * watt_point['se']
    watt_point['ci_upper'] = watt_point['watt'] + 1.96 * watt_point['se']
    watt_point['se_method'] = f'bootstrap({n_bootstrap})'
    return watt_point

## 5. Define Control Variables

Control variables matching Table 2 of the paper.

In [None]:
controls = [
    'share_pop_poverty_78_above',  # Share above poverty line
    'share_pop_ind_manuf',          # Share in manufacturing
    'share_school_some_hs',         # Share with HS education
]

## 6. Rolling IPWRA with Demeaning

Table A4 Rolling IPWRA (demean) column uses `control_group='all_others'`, where the treatment
indicator is defined as 1{g_i=g} and the control group includes all non-cohort units (including
already-treated ones).

In [None]:
results_demean = estimate_rolling_ipwra(df, 'demean', controls, control_group='all_others')

print("\nCohort-Time ATT Estimates (first 10 rows):")
results_demean.att_by_cohort_time.head(10)

## 7. Rolling IPWRA with Detrending (Heterogeneous Trends)

The heterogeneous trends (detrend) column uses the standard staggered DiD control group:
not-yet-treated + never-treated.

In [None]:
results_detrend = estimate_rolling_ipwra(df, 'detrend', controls, control_group='not_yet_treated')

print("\nCohort-Time ATT Estimates (first 10 rows):")
results_detrend.att_by_cohort_time.head(10)

## 8. Weighted ATT by Event Time

The paper uses bootstrap SE (100 reps) for WATT standard errors and confidence intervals.

Set `WALMART_FAST=1` environment variable to skip bootstrap and use analytical SE (for debugging only).

In [None]:
skip_bootstrap = os.getenv('WALMART_FAST', '0') == '1'

if skip_bootstrap:
    print("[FAST MODE] Skipping bootstrap, using analytical SE (debug only)")
    watt_demean = compute_watt(results_demean, df)
    watt_detrend = compute_watt(results_detrend, df)
else:
    reps = int(os.getenv('WALMART_WATT_BOOTSTRAP_REPS', '100'))
    seed = int(os.getenv('WALMART_WATT_BOOTSTRAP_SEED', '12345'))
    print(f"Bootstrap WATT SE (paper config): reps={reps}, seed={seed}")
    print("-" * 70)
    print("\nDemeaning bootstrap:")
    watt_demean = compute_watt_bootstrap_se(
        df, 'demean', controls, control_group='all_others',
        n_bootstrap=reps, seed=seed
    )
    print("\nDetrending bootstrap:")
    watt_detrend = compute_watt_bootstrap_se(
        df, 'detrend', controls, control_group='not_yet_treated',
        n_bootstrap=reps, seed=seed
    )

print("\nWATT with Demeaning:")
print(watt_demean.to_string(index=False))

print("\nWATT with Detrending:")
print(watt_detrend.to_string(index=False))

## 9. Comparison with Paper Results (Table A4)

Compare estimated WATT with the reference values from Table A4.

In [None]:
# Paper reference values (Table A4, last column: Rolling IPWRA with Het. Trends)
paper_detrend = {
    0: (0.007, 0.004), 1: (0.032, 0.005), 2: (0.025, 0.006),
    3: (0.021, 0.007), 4: (0.018, 0.009), 5: (0.017, 0.010),
    6: (0.019, 0.012), 7: (0.036, 0.013), 8: (0.041, 0.016),
    9: (0.041, 0.019), 10: (0.037, 0.023), 11: (0.018, 0.030),
    12: (0.017, 0.036), 13: (0.047, 0.053),
}

# Paper reference values for demeaning (Table A4, column 3)
paper_demean = {
    0: (0.018, 0.004), 1: (0.045, 0.004), 2: (0.038, 0.004),
    3: (0.032, 0.004), 4: (0.031, 0.004), 5: (0.036, 0.005),
    6: (0.040, 0.005), 7: (0.054, 0.006), 8: (0.062, 0.008),
    9: (0.063, 0.010), 10: (0.081, 0.013), 11: (0.083, 0.018),
    12: (0.080, 0.026), 13: (0.107, 0.039),
}

# Compare detrend results
print("Rolling IPWRA with Detrending (Heterogeneous Trends)")
print("-" * 70)
print(f"{'r':>3} | {'Python':>10} | {'Paper':>10} | {'Diff':>10} | {'Rating':>10}")
print("-" * 70)

detrend_diffs = []
for _, row in watt_detrend.iterrows():
    r = int(row['event_time'])
    if r in paper_detrend:
        paper_att, _ = paper_detrend[r]
        diff = row['watt'] - paper_att
        detrend_diffs.append(abs(diff))
        rating = "Close" if abs(diff) < 0.005 else ("Near" if abs(diff) < 0.01 else "Far")
        print(f"{r:>3} | {row['watt']:>10.4f} | {paper_att:>10.4f} | {diff:>+10.4f} | {rating:>10}")

if detrend_diffs:
    print("-" * 70)
    print(f"Mean absolute difference: {np.mean(detrend_diffs):.4f}")

In [None]:
# Compare demean results
print("Rolling IPWRA with Demeaning")
print("-" * 70)
print(f"{'r':>3} | {'Python':>10} | {'Paper':>10} | {'Diff':>10} | {'Rating':>10}")
print("-" * 70)

demean_diffs = []
for _, row in watt_demean.iterrows():
    r = int(row['event_time'])
    if r in paper_demean:
        paper_att, _ = paper_demean[r]
        diff = row['watt'] - paper_att
        demean_diffs.append(abs(diff))
        rating = "Close" if abs(diff) < 0.01 else ("Near" if abs(diff) < 0.03 else "Far")
        print(f"{r:>3} | {row['watt']:>10.4f} | {paper_att:>10.4f} | {diff:>+10.4f} | {rating:>10}")

if demean_diffs:
    print("-" * 70)
    print(f"Mean absolute difference: {np.mean(demean_diffs):.4f}")

## 10. Event Study Visualization (Figure 1)

Generate event study plots similar to Figure 1 in the paper.
Post-treatment uses bootstrap SE (consistent with the paper); pre-treatment uses analytical SE.

In [None]:
from lwdid.staggered.aggregation import aggregate_to_event_time, event_time_effects_to_dataframe

fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

for idx, (results_obj, watt_post, panel_title) in enumerate([
    (results_demean, watt_demean, '(b) Rolling IPWRA with unit-specific demeaning'),
    (results_detrend, watt_detrend, '(c) Rolling IPWRA with unit-specific detrending'),
]):
    ax = axes[idx]

    # Pre-treatment: aggregate from att_pre_treatment (analytical SE)
    pre_plot = pd.DataFrame()
    if results_obj.include_pretreatment and results_obj.att_pre_treatment is not None:
        pre_ct = results_obj.att_pre_treatment.copy()
        pre_effects = aggregate_to_event_time(
            pre_ct, results_obj.cohort_sizes, alpha=0.05, df_strategy='conservative'
        )
        pre_plot = event_time_effects_to_dataframe(pre_effects)
        pre_plot = pre_plot[pre_plot['event_time'] < 0].sort_values('event_time')

    # Post-treatment: use WATT (bootstrap SE or analytical SE)
    post_plot = watt_post[watt_post['event_time'] >= 0].copy().sort_values('event_time')

    # Pre-treatment (blue error bars)
    if len(pre_plot) > 0:
        ax.errorbar(
            pre_plot['event_time'], pre_plot['att'],
            yerr=[pre_plot['att'] - pre_plot['ci_lower'], pre_plot['ci_upper'] - pre_plot['att']],
            fmt='o-', color='steelblue', capsize=2, markersize=4,
            linewidth=1.2, label='Pre-treatment',
        )

    # Post-treatment (red error bars)
    if len(post_plot) > 0:
        ax.errorbar(
            post_plot['event_time'], post_plot['watt'],
            yerr=1.96 * post_plot['se'],
            fmt='o-', color='firebrick', capsize=2, markersize=4,
            linewidth=1.2, label='Post-treatment',
        )

    ax.axhline(y=0, color='gray', linestyle='-', linewidth=0.6, alpha=0.7)
    ax.axvline(x=-0.5, color='gray', linestyle=':', linewidth=0.8, alpha=0.5)
    ax.set_xlabel('Time To Treatment', fontsize=10)
    ax.set_ylabel('WATT', fontsize=10)
    ax.set_title(panel_title, fontsize=11)
    ax.legend(loc='upper left', fontsize=8)
    ax.grid(False)

plt.tight_layout()
plt.savefig('walmart_event_study.png', dpi=150, bbox_inches='tight')
plt.show()
print("Figure saved to: walmart_event_study.png")

## 11. Summary

In [None]:
print("Key Findings (Detrending - Heterogeneous Trends):")
if len(watt_detrend) > 0:
    post_watt = watt_detrend[watt_detrend['event_time'] >= 0]
    if len(post_watt) > 0:
        att_0 = post_watt[post_watt['event_time'] == 0]['watt'].values
        att_1 = post_watt[post_watt['event_time'] == 1]['watt'].values

        if len(att_0) > 0:
            print(f"  ATT(0) = {att_0[0]:.4f} (Instantaneous effect)")
        if len(att_1) > 0:
            print(f"  ATT(1) = {att_1[0]:.4f} (One year after opening)")
            pct_effect = (np.exp(att_1[0]) - 1) * 100
            print(f"         = {pct_effect:.1f}% increase in retail employment")

print()
print("Interpretation:")
print("  The heterogeneous trends estimator shows more modest effects")
print("  compared to estimators that don't account for county-specific trends.")
print("  This suggests pre-existing trends may have inflated earlier estimates.")
print()
print("Replication Analysis:")
print("  1. Detrend results closely match paper Table A4 (last column)")
print("  2. Demean results with control_group='all_others' closely match Table A4 (column 3)")
print("  3. Key qualitative findings consistent with paper:")
print("     - Detrending produces smaller, more conservative estimates")
print("     - Pre-treatment trends are flatter with detrending")
print("     - Effect of Walmart opening is positive but modest (~3%)")