In [15]:
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment, milp, LinearConstraint, Bounds
import networkx as nx
import time
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# -----------------------------
# Fixed Leximin Implementation 
# -----------------------------

def leximin_bottleneck_assignment(dist_matrix):
    """
    Bottleneck (min-max) assignment of treated rows to control columns.

    Finds the smallest distance threshold at which a maximum-size one-to-one
    assignment can be built using only pairs whose distance is <= threshold,
    and returns the assignment found at that threshold. Only the largest
    matched distance is minimized; the remaining distances are not further
    optimized.

    Parameters
    ----------
    dist_matrix : array-like, shape (n_treated, n_control)
        Pairwise distances between treated (rows) and control (columns) units.

    Returns
    -------
    col_ind : ndarray
        Column indices produced by ``linear_sum_assignment``. When there are
        at least as many columns as rows, entry ``i`` is the control matched
        to treated row ``i``.
    max_dist : float
        Largest distance among the matched pairs.
    """
    dist_matrix = np.asarray(dist_matrix)

    # np.unique already returns the candidate thresholds in ascending order.
    # NaN can never satisfy `<=`, so it is not a usable threshold.
    thresholds = np.unique(dist_matrix)
    thresholds = thresholds[~np.isnan(thresholds)]

    # Feasibility is monotone in the threshold (raising it only adds allowed
    # pairs), so the smallest feasible threshold can be found by bisection.
    best = None
    lo, hi = 0, len(thresholds) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        allowed = (dist_matrix <= thresholds[mid]).astype(float)

        # Maximizing the number of allowed pairs; the threshold is feasible
        # only if every assigned pair is an allowed one.
        row_ind, col_ind = linear_sum_assignment(-allowed)
        if np.all(allowed[row_ind, col_ind] == 1):
            best = (row_ind, col_ind)
            hi = mid - 1
        else:
            lo = mid + 1

    if best is None:
        # No threshold was feasible: fall back to the min-sum assignment.
        row_ind, col_ind = linear_sum_assignment(dist_matrix)
    else:
        row_ind, col_ind = best

    return col_ind, dist_matrix[row_ind, col_ind].max()

def leximin_integer_programming(dist_matrix):
    """
    Min-max one-to-one matching solved as a mixed-integer linear program.

    Model (x[i, j] binary, z continuous):
        minimize    z
        subject to  sum_j x[i, j] == 1          for every treated unit i
                    sum_i x[i, j] <= 1          for every control unit j
                    dist[i, j] * x[i, j] <= z   for every pair (i, j)

    Only the largest matched distance is minimized; ties among the other
    distances are left to the solver. Problems with more than 20 treated
    units, empty problems, and problems the solver rejects or cannot solve
    are delegated to ``leximin_bottleneck_assignment``.

    Parameters
    ----------
    dist_matrix : array-like, shape (n_treated, n_control)
        Pairwise distances between treated (rows) and control (columns) units.

    Returns
    -------
    matches : ndarray, shape (n_treated,)
        Control index matched to each treated unit.
    max_dist : float
        Largest distance among the matched pairs.
    """
    dist_matrix = np.asarray(dist_matrix, dtype=float)
    n_treated, n_control = dist_matrix.shape
    n_pairs = n_treated * n_control

    if n_treated > 20 or n_pairs == 0:  # Too large (or empty) for the MILP
        return leximin_bottleneck_assignment(dist_matrix)

    # Variable layout: x[i, j] lives at index i * n_control + j, z is last.
    n_vars = n_pairs + 1
    c = np.zeros(n_vars)
    c[-1] = 1.0  # Minimize z

    # Each treated unit is matched exactly once.
    treated_rows = np.hstack([
        np.kron(np.eye(n_treated), np.ones((1, n_control))),
        np.zeros((n_treated, 1)),
    ])
    # Each control unit is used at most once (this is what makes it 1-to-1).
    control_rows = np.hstack([
        np.kron(np.ones((1, n_treated)), np.eye(n_control)),
        np.zeros((n_control, 1)),
    ])
    # dist[i, j] * x[i, j] - z <= 0 links every chosen pair to z.
    link_rows = np.hstack([
        np.diag(dist_matrix.ravel()),
        -np.ones((n_pairs, 1)),
    ])

    # scipy's milp takes constraints as LinearConstraint objects; it has no
    # A_ub / b_ub / A_eq / b_eq keywords.
    constraints = [
        LinearConstraint(treated_rows, 1, 1),
        LinearConstraint(control_rows, -np.inf, 1),
        LinearConstraint(link_rows, -np.inf, 0),
    ]

    # Bounds: x[i, j] in [0, 1], z >= 0
    bounds = Bounds(
        lb=np.zeros(n_vars),
        ub=np.concatenate([np.ones(n_pairs), [np.inf]]),
    )

    # All x variables are integer (hence binary given the bounds); z is not.
    integrality = np.ones(n_vars, dtype=int)
    integrality[-1] = 0

    try:
        result = milp(
            c=c,
            constraints=constraints,
            integrality=integrality,
            bounds=bounds,
        )
    except ValueError:
        # Raised for inputs the solver rejects (e.g. non-finite distances).
        return leximin_bottleneck_assignment(dist_matrix)

    if not result.success:
        return leximin_bottleneck_assignment(dist_matrix)

    # Round before argmax so solver tolerance (0.9999 vs 1.0) cannot matter.
    x_vals = np.round(result.x[:-1]).reshape(n_treated, n_control)
    matches = np.argmax(x_vals, axis=1)
    # Read the bottleneck from the data rather than from the solver's z.
    max_dist = dist_matrix[np.arange(n_treated), matches].max()
    return matches, max_dist

# -----------------------------
# Hungarian Methods (from earlier)
# -----------------------------

def hungarian_1to1(dist_matrix):
    """Return the column indices of the minimum-total-cost assignment."""
    return linear_sum_assignment(dist_matrix)[1]

def hungarian_sequential(dist_matrix, k=3):
    """
    1-to-k matching by repeated minimum-cost assignment.

    Runs up to ``k`` rounds (capped at the number of controls). Each round
    solves a min-sum assignment over the controls not yet used and removes
    the chosen controls from the pool.

    Returns a dict mapping each treated row index to its list of matched
    control column indices, in the order they were assigned.
    """
    n_treated, n_control = dist_matrix.shape
    rounds = min(k, n_control)

    matches = {t: [] for t in range(n_treated)}
    unused = set(range(n_control))

    for _ in range(rounds):
        if not unused:
            break

        # Columns of the sub-problem, in ascending original-index order.
        candidates = sorted(unused)
        rows, cols = linear_sum_assignment(dist_matrix[:, candidates])

        for t, pos in zip(rows, cols):
            control = candidates[pos]  # map back to the original column
            matches[t].append(control)
            unused.remove(control)

    return matches

def hungarian_duplication(dist_matrix, k=3):
    """
    1-to-k matching by stacking copies of the cost matrix.

    The distance matrix is repeated ``k`` times along the rows (``k`` capped
    at the number of controls) and a single min-sum assignment is solved on
    the stacked matrix. Stacked row ``r`` belongs to treated unit
    ``r % n_treated``.

    Returns a dict mapping each treated row index to the list of control
    column indices assigned to any of its copies.
    """
    n_treated, n_control = dist_matrix.shape
    copies = min(k, n_control)

    stacked = np.tile(dist_matrix, (copies, 1))
    stacked_rows, controls = linear_sum_assignment(stacked)

    matches = {t: [] for t in range(n_treated)}
    for r, c in zip(stacked_rows, controls):
        matches[r % n_treated].append(c)

    return matches

def hungarian_flow(dist_matrix, k=3):
    """
    1-to-k matching as a min-cost max-flow problem.

    Network: source -> each treated node (capacity ``k``), each treated node
    -> each control node (capacity 1, weight = distance scaled by 1e6 and
    truncated to an integer), each control node -> sink (capacity 1).

    Returns a dict mapping each treated row index to the control column
    indices that carry positive flow from it, in ascending order.
    """
    n_treated, n_control = dist_matrix.shape
    source, sink = "s", "t"
    treated_nodes = [f"T{i}" for i in range(n_treated)]
    control_nodes = [f"C{j}" for j in range(n_control)]

    G = nx.DiGraph()
    for t_node in treated_nodes:
        G.add_edge(source, t_node, capacity=k, weight=0)
    for c_node in control_nodes:
        G.add_edge(c_node, sink, capacity=1, weight=0)
    for i, t_node in enumerate(treated_nodes):
        for j, c_node in enumerate(control_nodes):
            # Weights are passed as integers, hence the fixed-point scaling.
            G.add_edge(t_node, c_node, capacity=1,
                       weight=int(dist_matrix[i, j] * 1e6))

    flow = nx.max_flow_min_cost(G, source, sink)

    return {
        i: [j for j, c_node in enumerate(control_nodes)
            if flow[t_node].get(c_node, 0) > 0]
        for i, t_node in enumerate(treated_nodes)
    }

def leximin_1tok(dist_matrix, k=3):
    """
    1-to-k matching built from up to ``k`` successive bottleneck assignments.

    Each round runs ``leximin_bottleneck_assignment`` on the controls that
    have not been used yet and removes the chosen controls from the pool.
    Rounds are solved independently; this is a greedy approximation, not a
    joint optimization over all ``k`` matches.

    When fewer controls than treated units remain, the round is solved on
    the transposed matrix so that every returned index is paired with the
    correct treated unit; only some treated units receive a match in that
    round.

    Returns a dict mapping each treated row index to its list of matched
    control column indices, in the order they were assigned.
    """
    n_treated, n_control = dist_matrix.shape

    matches = {i: [] for i in range(n_treated)}
    if n_treated == 0:
        return matches

    available_controls = set(range(n_control))

    for _ in range(k):
        if not available_controls:
            break

        available_list = sorted(available_controls)
        reduced_matrix = dist_matrix[:, available_list]

        if len(available_list) >= n_treated:
            # Every treated row is assigned, so entry i belongs to row i.
            assigned, _ = leximin_bottleneck_assignment(reduced_matrix)
            pairs = list(enumerate(assigned))
        else:
            # Too few controls: the helper would return one entry per
            # control without saying which treated row each belongs to.
            # Solving the transpose makes controls the rows, so entry c is
            # the treated unit matched to reduced control c.
            assigned, _ = leximin_bottleneck_assignment(reduced_matrix.T)
            pairs = [(treated, ctrl) for ctrl, treated in enumerate(assigned)]

        for treated_idx, reduced_control_idx in pairs:
            original_control_idx = available_list[reduced_control_idx]
            matches[int(treated_idx)].append(original_control_idx)
            available_controls.remove(original_control_idx)

    return matches

# -----------------------------
# Enhanced Data Generator
# -----------------------------
def generate_matching_data(n_treated=50, n_control=150, p=5, tau=2.0, 
                          hetero=False, noise_level=1.0, confounding_strength=0.5, 
                          seed=42, imbalance_type="none"):
    """
    Simulate covariates, treatment, and outcomes for a matching study.

    Parameters
    ----------
    n_treated, n_control : int
        Number of treated and control units.
    p : int
        Number of covariates.
    tau : float
        Base treatment effect.
    hetero : bool
        If True the effect is ``tau * (1 + 0.5 * X[:, 0])``; otherwise it is
        the constant ``tau``.
    noise_level : float
        Standard deviation of the outcome noise.
    confounding_strength : float
        Mean of the covariate coefficients (their standard deviation is 0.1).
    seed : int
        Seed passed to ``np.random.seed`` (this reseeds NumPy's global RNG).
    imbalance_type : str
        ``"clustered"``: treated units are drawn around mean 1 on the first
        two covariates with standard deviation 0.5, controls around 0 with
        standard deviation 1.
        ``"sparse"``: the last ``n_treated // 10`` treated units are shifted
        by draws from N(3, 0.5) on every covariate.
        Any other value: treated and controls are both standard normal.

    Returns
    -------
    X : ndarray, shape (n_treated + n_control, p)
        Covariates, treated units first.
    T : ndarray
        Treatment indicator (1 for treated, 0 for control).
    Y : ndarray
        Observed outcome.
    tau_x : ndarray
        Unit-level treatment effect.
    """
    np.random.seed(seed)
    
    # Generate covariates based on imbalance type
    if imbalance_type == "clustered":
        # Build the mean vectors from p instead of hard-coding five entries,
        # so any number of covariates works. For p == 5 this is [1,1,0,0,0].
        treated_mean = np.zeros(p)
        treated_mean[:2] = 1.0
        X_treated = np.random.normal(treated_mean, 0.5, (n_treated, p))
        X_control = np.random.normal(np.zeros(p), 1.0, (n_control, p))
    elif imbalance_type == "sparse":
        X_treated = np.random.normal(0, 1, (n_treated, p))
        X_control = np.random.normal(0, 1, (n_control, p))
        # Make 10% of treated units outliers
        outlier_count = n_treated // 10
        # Guard: with outlier_count == 0, `X[-0:]` would select EVERY row and
        # the in-place add would fail on the shape mismatch.
        if outlier_count > 0:
            X_treated[-outlier_count:] += np.random.normal(3, 0.5, (outlier_count, p))
    else:  # "none"
        X_treated = np.random.normal(0, 1, (n_treated, p))
        X_control = np.random.normal(0, 1, (n_control, p))
    
    X = np.vstack([X_treated, X_control])
    T = np.concatenate([np.ones(n_treated), np.zeros(n_control)])
    
    # Untreated potential outcome depends on X, which is what confounds
    beta = np.random.normal(confounding_strength, 0.1, p)
    Y0 = X @ beta + np.random.normal(0, noise_level, len(X))
    
    # Treatment effect
    if hetero:
        tau_x = tau * (1 + 0.5 * X[:, 0])
    else:
        tau_x = np.full(len(X), tau)
    
    Y1 = Y0 + tau_x
    Y = T * Y1 + (1 - T) * Y0
    
    return X, T, Y, tau_x

# -----------------------------
# Comprehensive Evaluation
# -----------------------------
def evaluate_matching_method(X, T, Y, matches, treated_idx, control_idx, tau_x, method_name):
    """
    Score a matching by ATT error, covariate balance, and match distance.

    ``matches`` is either a dict (1-to-k: treated position -> list of control
    positions; treated units with an empty list are skipped) or a sequence
    (1-to-1: entry ``i`` is the control position matched to treated
    position ``i``). Positions index into ``treated_idx`` / ``control_idx``,
    which in turn index rows of ``X`` and ``Y``.

    Returns a dict with the method name, average matches per treated unit,
    ATT estimate / truth / bias / squared bias, mean and max absolute
    covariate difference, and mean / max / median Euclidean match distance.
    """
    effects, gaps, dists = [], [], []

    if isinstance(matches, dict):
        # 1-to-k: compare each treated unit with the mean of its controls.
        control_Y = Y[control_idx]
        control_X = X[control_idx]
        for t_pos, c_positions in matches.items():
            if not c_positions:
                continue
            row = treated_idx[t_pos]
            x_t = X[row]
            effects.append(Y[row] - control_Y[c_positions].mean())
            gaps.append(x_t - control_X[c_positions].mean(axis=0))
            dists.extend(
                np.linalg.norm(x_t - X[control_idx[c]]) for c in c_positions
            )
        avg_matches = np.mean([len(m) for m in matches.values()])
    else:
        # 1-to-1: one control per treated unit.
        for t_pos, c_pos in enumerate(matches):
            t_row, c_row = treated_idx[t_pos], control_idx[c_pos]
            gap = X[t_row] - X[c_row]
            effects.append(Y[t_row] - Y[c_row])
            gaps.append(gap)
            dists.append(np.linalg.norm(gap))
        avg_matches = 1.0

    results = {'method': method_name, 'avg_matches_per_treated': avg_matches}

    # ATT metrics
    if effects:
        estimate = np.mean(effects)
        truth = np.mean(tau_x[T == 1])
        bias = estimate - truth
        results.update(att_estimate=estimate, att_true=truth,
                       att_bias=bias, att_mse=bias ** 2)
    else:
        results.update(att_estimate=np.nan, att_true=np.nan,
                       att_bias=np.nan, att_mse=np.nan)

    # Balance metrics
    if gaps:
        abs_gaps = np.abs(np.array(gaps))
        results['mean_balance'] = abs_gaps.mean()
        results['max_balance'] = abs_gaps.max()
    else:
        results['mean_balance'] = np.inf
        results['max_balance'] = np.inf

    # Distance metrics
    if dists:
        results['mean_distance'] = np.mean(dists)
        results['max_distance'] = np.max(dists)
        results['median_distance'] = np.median(dists)
    else:
        results['mean_distance'] = np.inf
        results['max_distance'] = np.inf
        results['median_distance'] = np.inf

    return results

# -----------------------------
# Complete Comparison Framework
# -----------------------------
def complete_matching_comparison(X, T, Y, tau_x, k=3, verbose=True):
    """
    Run every matching method on one dataset and tabulate the results.

    Builds the treated-by-control Euclidean distance matrix, runs the three
    1-to-1 methods followed by the four 1-to-k methods, times each one, and
    scores it with ``evaluate_matching_method``. A method that raises is
    recorded with ``success=False`` and the error text instead of metrics.

    Returns a DataFrame with one row per method, in execution order.
    """
    treated_idx = np.nonzero(T == 1)[0]
    control_idx = np.nonzero(T == 0)[0]

    # Pairwise Euclidean distances via broadcasting: (n_t, 1, p) - (1, n_c, p)
    treated_X = X[treated_idx]
    control_X = X[control_idx]
    dist_matrix = np.linalg.norm(
        treated_X[:, None] - control_X[None, :], axis=2
    )

    # Zero-argument runners, keyed by display name; 1-to-1 methods first.
    runners = {
        'Hungarian 1-to-1': lambda: hungarian_1to1(dist_matrix),
        'Leximin 1-to-1 (Bottleneck)': lambda: leximin_bottleneck_assignment(dist_matrix)[0],
        'Leximin 1-to-1 (Integer)': lambda: leximin_integer_programming(dist_matrix)[0],
        f'Hungarian Sequential 1-to-{k}': lambda: hungarian_sequential(dist_matrix, k),
        f'Hungarian Duplication 1-to-{k}': lambda: hungarian_duplication(dist_matrix, k),
        f'Hungarian Flow 1-to-{k}': lambda: hungarian_flow(dist_matrix, k),
        f'Leximin 1-to-{k}': lambda: leximin_1tok(dist_matrix, k),
    }

    records = []
    for method_name, run in runners.items():
        if verbose:
            print(f"Running {method_name}...")

        started = time.time()
        try:
            matches = run()
            # Stop the clock before scoring so only the matching is timed.
            elapsed = time.time() - started

            record = evaluate_matching_method(
                X, T, Y, matches, treated_idx, control_idx, tau_x, method_name
            )
            record['runtime'] = elapsed
            record['success'] = True
        except Exception as exc:
            # One failing method must not abort the whole comparison.
            if verbose:
                print(f"  Error: {str(exc)}")
            record = {
                'method': method_name,
                'runtime': time.time() - started,
                'success': False,
                'error': str(exc)
            }

        records.append(record)

    return pd.DataFrame(records)

# -----------------------------
# Results Display
# -----------------------------
def _print_results_table(rows, min_name_width):
    """Print one results table, widening the method column to fit the longest name."""
    name_width = max(min_name_width, max(len(str(m)) for m in rows['method']))
    header = (f"{'Method':<{name_width}} {'ATT Est':<8} {'Bias':<8} "
              f"{'Balance':<8} {'Max Dist':<10} {'Time(s)':<8}")
    print(header)
    print("-" * max(70, len(header)))

    for _, row in rows.iterrows():
        print(f"{row['method']:<{name_width}} "
              f"{row['att_estimate']:<8.4f} "
              f"{row['att_bias']:<8.4f} "
              f"{row['mean_balance']:<8.4f} "
              f"{row['max_distance']:<10.4f} "
              f"{row['runtime']:<8.4f}")


def display_comprehensive_results(df_results):
    """
    Print the comparison tables and the Hungarian-vs-Leximin summary.

    Only rows with ``success == True`` are shown. Rows are split into a
    1-to-1 table (``avg_matches_per_treated == 1.0``) and a 1-to-k table
    (``avg_matches_per_treated > 1.0``). The summary compares the Hungarian
    1-to-1 max distance with the best Leximin 1-to-1 max distance, and the
    best Hungarian 1-to-k absolute ATT bias with the Leximin 1-to-k one.

    Parameters
    ----------
    df_results : pandas.DataFrame
        Output of ``complete_matching_comparison``.
    """
    print("=" * 90)
    print("COMPREHENSIVE MATCHING COMPARISON: Hungarian vs Leximin")
    print("=" * 90)

    successful = df_results[df_results['success'] == True].copy()

    if successful.empty:
        print("No methods completed successfully!")
        return

    # Separate 1-to-1 and 1-to-k results
    one_to_one = successful[successful['avg_matches_per_treated'] == 1.0]
    one_to_k = successful[successful['avg_matches_per_treated'] > 1.0]

    print(f"\n📊 1-TO-1 MATCHING RESULTS")
    print("-" * 70)
    true_att = successful['att_true'].iloc[0]
    print(f"True ATT: {true_att:.4f}")
    print()

    if len(one_to_one) > 0:
        _print_results_table(one_to_one, 25)

    print(f"\n📊 1-TO-K MATCHING RESULTS")
    print("-" * 70)

    if len(one_to_k) > 0:
        _print_results_table(one_to_k, 30)

    # Key comparisons
    print(f"\n🎯 KEY INSIGHTS: Hungarian vs Leximin")
    print("-" * 70)

    # Compare 1-to-1 methods
    if len(one_to_one) >= 2:
        hungarian_rows = one_to_one[one_to_one['method'].str.contains('Hungarian 1-to-1')]
        leximin_rows = one_to_one[one_to_one['method'].str.contains('Leximin')]

        if len(hungarian_rows) > 0 and len(leximin_rows) > 0:
            h_max = hungarian_rows['max_distance'].iloc[0]
            l_max = leximin_rows['max_distance'].min()  # Best leximin
            gain = h_max - l_max
            # A zero Hungarian max distance leaves nothing to improve on;
            # report 0% rather than dividing by zero.
            gain_pct = 100 * gain / h_max if h_max != 0 else 0.0

            print(f"1-to-1 Max Distance:")
            print(f"  Hungarian: {h_max:.4f}")
            print(f"  Leximin:   {l_max:.4f}")
            print(f"  Improvement: {gain:.4f} ({gain_pct:.1f}%)")

    # Compare 1-to-k methods
    if len(one_to_k) >= 2:
        hungarian_k_rows = one_to_k[one_to_k['method'].str.contains('Hungarian')]
        leximin_k_rows = one_to_k[one_to_k['method'].str.contains('Leximin')]

        if len(hungarian_k_rows) > 0 and len(leximin_k_rows) > 0:
            h_bias = hungarian_k_rows['att_bias'].abs().min()  # Best Hungarian
            l_bias = leximin_k_rows['att_bias'].abs().iloc[0]

            print(f"\n1-to-k ATT Bias (absolute):")
            print(f"  Best Hungarian: {h_bias:.4f}")
            print(f"  Leximin:        {l_bias:.4f}")
            
            
# -----------------------------
# Main Analysis
# -----------------------------
def run_comprehensive_analysis():
    """
    Run the Hungarian-vs-Leximin comparison on three simulated scenarios.

    For each of the "Balanced", "Clustered", and "Sparse" scenarios
    (40 treated, 120 controls) this generates data, runs every matching
    method with k=3, and prints the results.

    Returns the results DataFrame of the last scenario only.
    """
    print("🚀 COMPREHENSIVE HUNGARIAN vs LEXIMIN ANALYSIS")
    print("=" * 70)

    # Every scenario uses the same sample sizes; only the imbalance differs.
    sizes = {"n_treated": 40, "n_control": 120}
    scenarios = [
        (label, {"imbalance_type": kind, **sizes})
        for label, kind in (
            ("Balanced", "none"),
            ("Clustered", "clustered"),
            ("Sparse", "sparse"),
        )
    ]

    rule = "=" * 70
    for scenario_name, params in scenarios:
        print("\n" + rule)
        print(f"SCENARIO: {scenario_name}")
        print(rule)

        X, T, Y, tau_x = generate_matching_data(**params)
        results_df = complete_matching_comparison(X, T, Y, tau_x, k=3, verbose=False)
        display_comprehensive_results(results_df)

    return results_df

# Script entry point: run the full scenario sweep only when executed directly.
# `final_results` holds the DataFrame returned by the sweep, which is the
# results table of the last scenario.
if __name__ == "__main__":
    final_results = run_comprehensive_analysis()

🚀 COMPREHENSIVE HUNGARIAN vs LEXIMIN ANALYSIS

SCENARIO: Balanced
COMPREHENSIVE MATCHING COMPARISON: Hungarian vs Leximin

📊 1-TO-1 MATCHING RESULTS
----------------------------------------------------------------------
True ATT: 2.0000

Method                    ATT Est  Bias     Balance  Max Dist   Time(s) 
----------------------------------------------------------------------
Hungarian 1-to-1          1.5895   -0.4105  0.4021   2.0027     0.0014  
Leximin 1-to-1 (Bottleneck) 1.7512   -0.2488  0.5917   2.0027     0.0312  
Leximin 1-to-1 (Integer)  1.7512   -0.2488  0.5917   2.0027     0.0283  

📊 1-TO-K MATCHING RESULTS
----------------------------------------------------------------------
Method                         ATT Est  Bias     Balance  Max Dist   Time(s) 
----------------------------------------------------------------------
Hungarian Sequential 1-to-3    1.7982   -0.2018  0.3906   3.8734     0.0007  
Hungarian Duplication 1-to-3   1.7982   -0.2018  0.3946   3.1331     0.0