In [1]:
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment
import networkx as nx
import time
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# -----------------------------
# Hungarian 1-to-1 Matching (Your Original)
# -----------------------------
def att_hungarian_1to1(X, T, Y):
    """Estimate the ATT with optimal 1-to-1 matching (Hungarian algorithm).

    Each treated unit is paired with at most one distinct control so that the
    total Euclidean covariate distance over all matched pairs is minimal.

    Parameters
    ----------
    X : ndarray, shape (n_units, p)
        Covariate matrix.
    T : ndarray, shape (n_units,)
        Treatment indicator (1 = treated, 0 = control).
    Y : ndarray, shape (n_units,)
        Observed outcomes.

    Returns
    -------
    att : float
        Mean treated-minus-matched-control outcome difference over the
        matched pairs.
    match : dict[int, list[int]]
        Maps every treated position (index into ``treated_idx``) to a list
        holding its matched control position (index into ``control_idx``).
        The list is empty when the treated unit was left unmatched, which
        happens when there are fewer controls than treated units.
    treated_idx, control_idx : ndarray
        Row indices of the treated and control units in ``X``.
    """
    treated_idx = np.where(T == 1)[0]
    control_idx = np.where(T == 0)[0]
    cost = np.linalg.norm(X[treated_idx][:, None] - X[control_idx][None, :], axis=2)
    row, col = linear_sum_assignment(cost)
    att = np.mean(Y[treated_idx[row]] - Y[control_idx[col]])

    # Report every treated unit, including unmatched ones, so downstream
    # statistics (e.g. the unmatched-treated count) are correct and the dict
    # has the same shape as the one returned by the 1-to-k methods.
    match = {i: [] for i in range(len(treated_idx))}
    for i, j in zip(row, col):
        match[int(i)].append(int(j))
    return att, match, treated_idx, control_idx

# -----------------------------
# Hungarian 1-to-k Matching (Your Original - Min Cost Flow)
# -----------------------------
def att_hungarian_1tok_flow(X, T, Y, k=3):
    """Estimate the ATT with optimal 1-to-k matching via min-cost max-flow.

    Builds the network source -> treated (capacity ``k``) -> control
    (capacity 1) -> sink (capacity 1) and solves for the maximum flow of
    minimum total cost. Each control is therefore used at most once, and each
    treated unit receives at most ``k`` controls.

    Edge costs are Euclidean covariate distances multiplied by 1e6 and rounded
    to integers, because networkx's min-cost-flow solver expects integer
    weights.

    Parameters
    ----------
    X : ndarray, shape (n_units, p)
        Covariate matrix.
    T : ndarray, shape (n_units,)
        Treatment indicator (1 = treated, 0 = control).
    Y : ndarray, shape (n_units,)
        Observed outcomes.
    k : int, default 3
        Maximum number of controls per treated unit.

    Returns
    -------
    att : float
        Mean, over treated units with at least one match, of
        (own outcome - mean outcome of the matched controls).
    matches : dict[int, list[int]]
        Treated position -> ascending list of matched control positions
        (positions index ``treated_idx`` / ``control_idx``).
    treated_idx, control_idx : ndarray
        Row indices of the treated and control units in ``X``.
    """
    treated_idx = np.where(T == 1)[0]
    control_idx = np.where(T == 0)[0]
    n_treated, n_controls = len(treated_idx), len(control_idx)

    # One vectorised distance computation instead of n_treated * n_controls
    # separate norm calls.
    dist = np.linalg.norm(X[treated_idx][:, None] - X[control_idx][None, :], axis=2)
    # Round (rather than truncate) so the integer costs stay as close as
    # possible to the real distances.
    int_cost = np.rint(dist * 1e6).astype(np.int64)

    G = nx.DiGraph()
    source, sink = "s", "t"

    # Source -> treated: each treated unit may take up to k controls.
    for i in range(n_treated):
        G.add_edge(source, f"T{i}", capacity=k, weight=0)

    # Control -> sink: each control may be used once.
    for j in range(n_controls):
        G.add_edge(f"C{j}", sink, capacity=1, weight=0)

    # Treated -> control edges carry the matching cost.
    for i in range(n_treated):
        for j in range(n_controls):
            G.add_edge(f"T{i}", f"C{j}", capacity=1, weight=int(int_cost[i, j]))

    flow = nx.max_flow_min_cost(G, source, sink)

    # The only out-edges of node "T{i}" go to "C{j}" nodes, so read the
    # matching straight from them; sorting keeps ascending control order.
    matches = {
        i: sorted(int(node[1:]) for node, units in flow[f"T{i}"].items() if units > 0)
        for i in range(n_treated)
    }

    Y_control = Y[control_idx]
    att_list = [
        Y[treated_idx[i]] - Y_control[matched_js].mean()
        for i, matched_js in matches.items()
        if matched_js
    ]
    att = np.mean(att_list)
    return att, matches, treated_idx, control_idx

# -----------------------------
# NEW: Sequential Hungarian
# -----------------------------
def att_hungarian_sequential(X, T, Y, k=3):
    """Estimate the ATT with sequential Hungarian 1-to-k matching.

    Runs up to ``k`` rounds of optimal 1-to-1 matching. Controls used in a
    round are removed before the next round, so no control is matched twice.
    This is a greedy-over-rounds heuristic: each round is optimal on its own,
    but the total distance over all rounds is not minimised jointly (unlike
    the flow and duplication formulations).

    Parameters
    ----------
    X : ndarray, shape (n_units, p)
        Covariate matrix.
    T : ndarray, shape (n_units,)
        Treatment indicator (1 = treated, 0 = control).
    Y : ndarray, shape (n_units,)
        Observed outcomes.
    k : int, default 3
        Number of matching rounds, i.e. the maximum controls per treated unit.

    Returns
    -------
    att : float
        Mean, over treated units with at least one match, of
        (own outcome - mean outcome of the matched controls).
    matches : dict[int, list[int]]
        Treated position -> matched control positions, in round order.
    treated_idx, control_idx : ndarray
        Row indices of the treated and control units in ``X``.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of controls.
    """
    treated_idx = np.where(T == 1)[0]
    control_idx = np.where(T == 0)[0]

    if k > len(control_idx):
        raise ValueError(f"k={k} cannot exceed number of controls ({len(control_idx)})")

    # Distances do not change between rounds: compute them once and select
    # the still-available columns in each round.
    base_cost = np.linalg.norm(X[treated_idx][:, None] - X[control_idx][None, :], axis=2)

    matches = {i: [] for i in range(len(treated_idx))}
    available_controls = set(range(len(control_idx)))

    for _ in range(k):
        if not available_controls:
            break

        available_list = sorted(available_controls)
        row_indices, col_indices = linear_sum_assignment(base_cost[:, available_list])

        # Map column positions of the reduced matrix back to control positions
        # and retire the controls used in this round.
        for treated_pos, reduced_pos in zip(row_indices, col_indices):
            control_pos = available_list[reduced_pos]
            matches[int(treated_pos)].append(control_pos)
            available_controls.remove(control_pos)

    Y_control = Y[control_idx]
    att_list = [
        Y[treated_idx[i]] - Y_control[matched_js].mean()
        for i, matched_js in matches.items()
        if matched_js
    ]
    att = np.mean(att_list)
    return att, matches, treated_idx, control_idx

# -----------------------------
# NEW: Graph Duplication Hungarian
# -----------------------------
def att_hungarian_duplication(X, T, Y, k=3):
    """Estimate the ATT with 1-to-k matching by duplicating treated units.

    The treated-by-control distance matrix is stacked ``k`` times and a
    single rectangular assignment problem is solved. Each treated copy gets at
    most one control and each control is used at most once, so every treated
    unit receives at most ``k`` distinct controls. This should reach the same
    optimal total distance as the min-cost-flow version, although ties may be
    broken differently.

    Parameters
    ----------
    X : ndarray, shape (n_units, p)
        Covariate matrix.
    T : ndarray, shape (n_units,)
        Treatment indicator (1 = treated, 0 = control).
    Y : ndarray, shape (n_units,)
        Observed outcomes.
    k : int, default 3
        Number of copies of each treated unit (maximum controls per unit).

    Returns
    -------
    att : float
        Mean, over treated units with at least one match, of
        (own outcome - mean outcome of the matched controls).
    matches : dict[int, list[int]]
        Treated position -> matched control positions.
    treated_idx, control_idx : ndarray
        Row indices of the treated and control units in ``X``.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of controls.
    """
    treated_idx = np.where(T == 1)[0]
    control_idx = np.where(T == 0)[0]

    if k > len(control_idx):
        raise ValueError(f"k={k} cannot exceed number of controls ({len(control_idx)})")

    n_treated = len(treated_idx)

    # Expanded cost has shape (k * n_treated, n_controls). Row r is copy
    # r // n_treated of treated unit r % n_treated.
    base_cost = np.linalg.norm(X[treated_idx][:, None] - X[control_idx][None, :], axis=2)
    expanded_cost = np.tile(base_cost, (k, 1))

    row_indices, col_indices = linear_sum_assignment(expanded_cost)

    # Fold the duplicated rows back onto their original treated unit.
    matches = {i: [] for i in range(n_treated)}
    for expanded_row, control in zip(row_indices, col_indices):
        matches[int(expanded_row % n_treated)].append(int(control))

    Y_control = Y[control_idx]
    att_list = [
        Y[treated_idx[i]] - Y_control[matched_js].mean()
        for i, matched_js in matches.items()
        if matched_js
    ]
    att = np.mean(att_list)
    return att, matches, treated_idx, control_idx

# -----------------------------
# NEW: Greedy k-Nearest Neighbors (Baseline)
# -----------------------------
def att_greedy_knn(X, T, Y, k=3):
    """Greedy k-nearest-neighbour matching, used as a comparison baseline.

    Treated units are visited in position order. Each one takes the ``k``
    Euclidean-closest controls among those no earlier treated unit has used.
    If fewer than ``k`` unused controls remain, the whole control pool becomes
    eligible again, so controls can then be reused.

    Returns ``(att, matches, treated_idx, control_idx)`` in the same format as
    the Hungarian-based functions.
    """
    treated_idx = np.where(T == 1)[0]
    control_idx = np.where(T == 0)[0]
    n_ctrl = len(control_idx)

    matches = {}
    taken = set()
    for pos, unit in enumerate(treated_idx):
        pool = [c for c in range(n_ctrl) if c not in taken]
        if len(pool) < k:
            # Not enough fresh controls left: fall back to matching with replacement.
            pool = list(range(n_ctrl))

        gaps = np.linalg.norm(X[unit] - X[control_idx[pool]], axis=1)
        chosen = [pool[c] for c in np.argsort(gaps)[:k]]
        matches[pos] = chosen
        taken.update(chosen)

    y_ctrl = Y[control_idx]
    effects = []
    for pos, picked in matches.items():
        if picked:
            effects.append(Y[treated_idx[pos]] - y_ctrl[picked].mean())

    att = np.mean(effects)
    return att, matches, treated_idx, control_idx

# -----------------------------
# Enhanced Data Generator
# -----------------------------
def generate_data(n_treated=100, n_control=300, p=5, tau=2.0, hetero=False,
                  noise_level=1.0, confounding_strength=0.5, seed=42,
                  selection_shift=0.0):
    """Simulate a matching dataset with known unit-level treatment effects.

    Covariates are i.i.d. standard normal and the first ``n_treated`` rows
    are treated. The untreated outcome is linear in the covariates,
    ``Y0 = X @ beta + N(0, noise_level)`` with each ``beta_j`` drawn from
    ``N(confounding_strength, 0.1)``, and ``Y1 = Y0 + tau_x``.

    With the default ``selection_shift=0.0`` treatment is independent of
    ``X``, so there is no confounding: the covariates only add outcome
    variance. A non-zero ``selection_shift`` is added to every covariate of
    the treated rows. Treated and control units then differ systematically in
    ``X``, and because ``Y0`` depends on ``X`` this confounds the naive
    difference in means.

    Parameters
    ----------
    n_treated, n_control : int
        Group sizes.
    p : int
        Number of covariates.
    tau : float
        Baseline treatment effect.
    hetero : bool
        If True, ``tau_x = tau * (1 + 0.5 * X[:, 0])``; otherwise constant ``tau``.
    noise_level : float
        Standard deviation of the outcome noise.
    confounding_strength : float
        Mean of the outcome coefficients ``beta``.
    seed : int
        Seed of the private random generator.
    selection_shift : float, default 0.0
        Covariate mean shift applied to treated units (0 keeps the original,
        unconfounded design and reproduces its exact draws).

    Returns
    -------
    X : ndarray, shape (n_treated + n_control, p)
    T : ndarray of float, 1.0 for treated and 0.0 for control
    Y : ndarray, observed outcomes
    tau_x : ndarray, true unit-level treatment effects
    """
    # A private RandomState yields exactly the stream the global
    # np.random.seed(seed) would, without mutating global RNG state.
    rng = np.random.RandomState(seed)
    n_units = n_treated + n_control

    X = rng.normal(0, 1, size=(n_units, p))

    # Treatment assignment: the first n_treated units are treated.
    T = np.zeros(n_units)
    T[:n_treated] = 1

    # Optional selection on covariates; it does not consume random draws, so
    # the remaining draws are unchanged.
    if selection_shift:
        X[:n_treated] += selection_shift

    beta = rng.normal(confounding_strength, 0.1, p)
    Y0 = X @ beta + rng.normal(0, noise_level, n_units)

    if hetero:
        tau_x = tau * (1 + 0.5 * X[:, 0])  # effect varies with the first covariate
    else:
        tau_x = np.full(n_units, tau)

    Y1 = Y0 + tau_x
    Y = T * Y1 + (1 - T) * Y0

    return X, T, Y, tau_x

# -----------------------------
# Enhanced Evaluation Metrics
# -----------------------------
def comprehensive_evaluation(X, T, Y, matches, treated_idx, control_idx, tau_x, method_name):
    """Summarise the quality of a matching.

    Parameters
    ----------
    X : ndarray, shape (n_units, p)
        Covariate matrix.
    T : ndarray, shape (n_units,)
        Treatment indicator (1 = treated, 0 = control).
    Y : ndarray, shape (n_units,)
        Observed outcomes.
    matches : dict[int, list[int]]
        Treated position -> matched control positions, as returned by the
        ``att_*`` matching functions.
    treated_idx, control_idx : ndarray
        Map treated / control positions to row indices in ``X``.
    tau_x : ndarray, shape (n_units,)
        True unit-level treatment effects.
    method_name : str
        Label stored under ``'method'``.

    Returns
    -------
    dict
        - ATT: ``att_estimate``, ``att_true``, ``att_bias`` and ``att_mse``
          (the squared bias for this single dataset).
        - Covariate balance: mean absolute treated-minus-matched-mean
          covariate difference, overall, maximum and per covariate.
        - Match-count statistics.
        - Treated-to-control distance statistics.

        Metrics that cannot be computed because nothing was matched are
        ``nan`` (ATT, match averages) or ``inf`` (balance, distances).
    """
    results = {'method': method_name}

    # Hoisted so each treated unit only indexes its own matched controls.
    X_control = X[control_idx]
    Y_control = Y[control_idx]

    # Single pass over the matched treated units collects everything needed.
    effects, balance_diffs, distances = [], [], []
    for i, matched_js in matches.items():
        if not matched_js:
            continue
        treated_x = X[treated_idx[i]]
        matched_x = X_control[matched_js]
        effects.append(Y[treated_idx[i]] - Y_control[matched_js].mean())
        balance_diffs.append(treated_x - matched_x.mean(axis=0))
        distances.extend(np.linalg.norm(treated_x - matched_x, axis=1))

    # 1. ATT estimation
    results['att_estimate'] = np.mean(effects) if effects else np.nan
    results['att_true'] = np.mean(tau_x[T == 1])
    results['att_bias'] = results['att_estimate'] - results['att_true']
    results['att_mse'] = results['att_bias'] ** 2

    # 2. Covariate balance
    if balance_diffs:
        abs_balance = np.abs(np.array(balance_diffs))
        results['mean_balance'] = abs_balance.mean()
        results['max_balance'] = abs_balance.max()
        results['balance_by_covariate'] = abs_balance.mean(axis=0)
    else:
        results['mean_balance'] = np.inf
        results['max_balance'] = np.inf
        results['balance_by_covariate'] = np.full(X.shape[1], np.inf)

    # 3. Matching statistics (guarded: np.min of an empty list raises)
    match_sizes = [len(matched_js) for matched_js in matches.values()]
    if match_sizes:
        results['avg_matches_per_treated'] = np.mean(match_sizes)
        results['min_matches_per_treated'] = np.min(match_sizes)
    else:
        results['avg_matches_per_treated'] = np.nan
        results['min_matches_per_treated'] = np.nan
    results['unmatched_treated'] = sum(size == 0 for size in match_sizes)

    # 4. Distance statistics
    if distances:
        results['mean_distance'] = np.mean(distances)
        results['median_distance'] = np.median(distances)
        results['max_distance'] = np.max(distances)
    else:
        results['mean_distance'] = np.inf
        results['median_distance'] = np.inf
        results['max_distance'] = np.inf

    return results

# -----------------------------
# Method Comparison Framework
# -----------------------------
def compare_all_methods(X, T, Y, tau_x, k=3, verbose=True):
    """Run every matching method on one dataset and evaluate each.

    Parameters
    ----------
    X, T, Y : ndarray
        Covariates, treatment indicator and observed outcomes.
    tau_x : ndarray
        True unit-level treatment effects (used for bias metrics).
    k : int, default 3
        Controls per treated unit for the 1-to-k methods.
    verbose : bool, default True
        Print progress and error messages.

    Returns
    -------
    pandas.DataFrame
        One row per method: the ``comprehensive_evaluation`` metrics plus
        ``runtime`` (seconds spent in the matching function) and ``success``.
        Failed methods carry ``error`` instead of metrics.
    """
    methods = {
        '1-to-1 Hungarian': lambda: att_hungarian_1to1(X, T, Y),
        f'1-to-{k} Flow': lambda: att_hungarian_1tok_flow(X, T, Y, k),
        f'1-to-{k} Sequential': lambda: att_hungarian_sequential(X, T, Y, k),
        f'1-to-{k} Duplication': lambda: att_hungarian_duplication(X, T, Y, k),
        f'1-to-{k} Greedy kNN': lambda: att_greedy_knn(X, T, Y, k),
    }

    results = []

    for method_name, method_func in methods.items():
        if verbose:
            print(f"Running {method_name}...")

        # perf_counter is monotonic and high-resolution; time.time() can
        # report 0.0 for fast methods and can jump with clock adjustments.
        start_time = time.perf_counter()
        try:
            _, matches, treated_idx, control_idx = method_func()
            runtime = time.perf_counter() - start_time
            eval_results = comprehensive_evaluation(
                X, T, Y, matches, treated_idx, control_idx, tau_x, method_name
            )
        except Exception as e:  # benchmark boundary: record the failure, keep comparing
            if verbose:
                print(f"  Error: {str(e)}")
            eval_results = {
                'method': method_name,
                'runtime': time.perf_counter() - start_time,
                'success': False,
                'error': str(e)
            }
        else:
            eval_results['runtime'] = runtime
            eval_results['success'] = True

        results.append(eval_results)

    return pd.DataFrame(results)

# -----------------------------
# Enhanced Results Display
# -----------------------------
def display_results(df_results, detailed=True):
    """Pretty-print a comparison table produced by ``compare_all_methods``.

    Only rows whose ``success`` flag equals True are reported. The function
    prints the ATT estimate and bias of each method, then a table of balance,
    matches per treated unit, distance and runtime. When ``detailed`` is set,
    it also prints the best method per criterion and each method's
    per-covariate balance vector. Returns None.
    """
    banner = "=" * 80
    rule = "-" * 50

    print(banner)
    print("COMPREHENSIVE HUNGARIAN MATCHING COMPARISON")
    print(banner)

    ok = df_results[df_results['success'].eq(True)].copy()
    if not len(ok):
        print("No methods completed successfully!")
        return

    # --- ATT estimates -------------------------------------------------
    print("\n📊 ATT ESTIMATION RESULTS")
    print(rule)
    true_att = ok['att_true'].iloc[0]
    print(f"True ATT: {true_att:.4f}")
    print()

    for _, rec in ok.iterrows():
        bias = rec['att_bias']
        if true_att != 0:
            bias_pct = 100 * abs(bias) / abs(true_att)
        else:
            bias_pct = 0
        print(
            f"{rec['method']:<20}: {rec['att_estimate']:>7.4f} "
            f"(bias: {bias:>+7.4f}, {bias_pct:>5.1f}%)"
        )

    # --- Balance / quality table ---------------------------------------
    print("\n🎯 BALANCE & QUALITY METRICS")
    print(rule)
    print(f"{'Method':<20} {'Balance':<8} {'Matches':<8} {'Distance':<8} {'Time(s)':<8}")
    print(rule)

    for _, rec in ok.iterrows():
        print(
            f"{rec['method']:<20} "
            f"{rec['mean_balance']:<8.4f} "
            f"{rec['avg_matches_per_treated']:<8.2f} "
            f"{rec['mean_distance']:<8.4f} "
            f"{rec['runtime']:<8.3f}"
        )

    if not detailed:
        return

    # --- Winners and per-covariate balance -----------------------------
    print("\n📈 DETAILED STATISTICS")
    print(rule)

    best_att, best_balance, fastest = (
        ok.loc[ok[column].idxmin(), 'method']
        for column in ('att_mse', 'mean_balance', 'runtime')
    )

    print(f"🏆 Best ATT Estimation: {best_att}")
    print(f"⚖️  Best Balance:        {best_balance}")
    print(f"⚡ Fastest:             {fastest}")

    print("\n📊 BALANCE BY COVARIATE")
    print(rule)
    for _, rec in ok.iterrows():
        if 'balance_by_covariate' not in rec:
            continue
        per_cov = rec['balance_by_covariate']
        if not hasattr(per_cov, '__len__'):
            continue
        formatted = " ".join(f"{v:.3f}" for v in per_cov)
        print(f"{rec['method']:<20}: [{formatted}]")

# -----------------------------
# Monte Carlo Simulation
# -----------------------------
def monte_carlo_comparison(n_simulations=50, k=3, **data_params):
    """Run ``compare_all_methods`` on many simulated datasets and summarise.

    Parameters
    ----------
    n_simulations : int, default 50
        Number of independently generated datasets.
    k : int, default 3
        Controls per treated unit, forwarded to ``compare_all_methods``.
    **data_params
        Keyword arguments forwarded to ``generate_data``. If ``seed`` is
        given it is used as the base seed; replication ``sim`` (0-based)
        uses ``seed + sim``. The default base seed is 42.

    Returns
    -------
    pandas.DataFrame
        All per-method, per-replication rows with an added ``simulation``
        column. A grouped summary (bias, MSE, balance, runtime) is printed.

    Raises
    ------
    ValueError
        If ``n_simulations`` is not positive.
    """
    if n_simulations <= 0:
        raise ValueError(f"n_simulations must be positive, got {n_simulations}")

    print(f"Running Monte Carlo simulation with {n_simulations} replications...")

    # Work on a copy and honour a caller-supplied seed instead of silently
    # overwriting it.
    params = dict(data_params)
    base_seed = params.pop('seed', 42)

    all_results = []
    for sim in range(n_simulations):
        if sim % 10 == 0:
            print(f"  Simulation {sim+1}/{n_simulations}")

        X, T, Y, tau_x = generate_data(seed=base_seed + sim, **params)

        sim_results = compare_all_methods(X, T, Y, tau_x, k=k, verbose=False)
        sim_results['simulation'] = sim
        all_results.append(sim_results)

    combined_df = pd.concat(all_results, ignore_index=True)

    print(f"\n📊 MONTE CARLO SUMMARY ({n_simulations} simulations)")
    print("=" * 60)

    successful_df = combined_df[combined_df['success'].eq(True)]
    if successful_df.empty:
        # Metric columns are absent when every method failed every time.
        print("No methods completed successfully!")
        return combined_df

    summary = successful_df.groupby('method').agg({
        'att_bias': ['mean', 'std'],
        'att_mse': 'mean',
        'mean_balance': ['mean', 'std'],
        'runtime': ['mean', 'std']
    }).round(4)

    print(summary)

    return combined_df

# -----------------------------
# Main Script
# -----------------------------
if __name__ == "__main__":
    print("🚀 Enhanced Hungarian Matching Analysis")
    print("=" * 60)

    # Generate data
    X, T, Y, tau_x = generate_data(n_treated=50, n_control=150, p=5, tau=2.0, hetero=True)

    # T is a float array; count as ints so the summary reads "50 treated".
    n_treated = int(T.sum())
    n_controls = len(T) - n_treated
    print(f"Dataset: {len(X)} units ({n_treated} treated, {n_controls} controls)")
    print(f"Covariates: {X.shape[1]}")
    print(f"True ATT: {np.mean(tau_x[T==1]):.4f}")

    # NOTE: here n_control == k * n_treated (150 == 3 * 50), so every 1-to-k
    # method gives each treated unit exactly k controls and uses every control
    # once. Their ATT estimates then coincide by construction: the treated
    # mean minus the overall control mean. Only balance and distance
    # distinguish them. Use more controls to compare their ATT estimates.
    results_df = compare_all_methods(X, T, Y, tau_x, k=3)
    display_results(results_df, detailed=True)

    # Optional: Monte Carlo simulation
    print("\n" + "=" * 60)
    print("🎲 Monte Carlo Simulation (optional - set to True to run)")
    run_mc = False  # Set to True to run Monte Carlo

    if run_mc:
        mc_results = monte_carlo_comparison(
            n_simulations=20,
            n_treated=50,
            n_control=150,
            p=5,
            tau=2.0,
            hetero=True
        )

        # Save results
        mc_results.to_csv('hungarian_comparison_results.csv', index=False)
        print("Results saved to 'hungarian_comparison_results.csv'")

🚀 Enhanced Hungarian Matching Analysis
Dataset: 200 units (50.0 treated, 150.0 controls)
Covariates: 5
True ATT: 2.0944
Running 1-to-1 Hungarian...
Running 1-to-3 Flow...
Running 1-to-3 Sequential...
Running 1-to-3 Duplication...
Running 1-to-3 Greedy kNN...
COMPREHENSIVE HUNGARIAN MATCHING COMPARISON

📊 ATT ESTIMATION RESULTS
--------------------------------------------------
True ATT: 2.0944

1-to-1 Hungarian    :  2.4538 (bias: +0.3594,  17.2%)
1-to-3 Flow         :  2.2553 (bias: +0.1609,   7.7%)
1-to-3 Sequential   :  2.2553 (bias: +0.1609,   7.7%)
1-to-3 Duplication  :  2.2553 (bias: +0.1609,   7.7%)
1-to-3 Greedy kNN   :  2.2553 (bias: +0.1609,   7.7%)

🎯 BALANCE & QUALITY METRICS
--------------------------------------------------
Method               Balance  Matches  Distance Time(s) 
--------------------------------------------------
1-to-1 Hungarian     0.3828   1.00     1.0344   0.000   
1-to-3 Flow          0.3580   3.00     1.3881   0.273   
1-to-3 Sequential    0.3728   