## Hungarian 1-to-Many

In [1]:
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment
import networkx as nx

In [2]:
# -----------------------------
# Hungarian 1-to-1 Matching
# -----------------------------
def att_hungarian_1to1(X, T, Y):
    """Estimate the ATT via optimal one-to-one matching on Euclidean distance.

    Treated rows (``T == 1``) are paired with control rows (``T == 0``) by
    solving a linear sum assignment problem on the pairwise Euclidean
    distances between their covariate vectors.

    Parameters
    ----------
    X : array of covariates, one row per unit.
    T : array of treatment indicators (1 = treated, 0 = control).
    Y : array of observed outcomes, aligned with ``X`` and ``T``.

    Returns
    -------
    att : mean of (treated outcome - matched control outcome) over all pairs.
    match : dict mapping a position in ``treated_idx`` to a one-element list
        holding the matched position in ``control_idx``.
    treated_idx, control_idx : row indices of the treated / control units.
    """
    treated_idx = np.nonzero(T == 1)[0]
    control_idx = np.nonzero(T == 0)[0]

    # Broadcast to a (n_treated, n_control, p) difference tensor, then reduce
    # over the covariate axis to get the pairwise distance matrix.
    treated_x = X[treated_idx]
    control_x = X[control_idx]
    pairwise_diff = treated_x[:, np.newaxis] - control_x[np.newaxis, :]
    distances = np.linalg.norm(pairwise_diff, axis=2)

    treated_pos, control_pos = linear_sum_assignment(distances)

    pair_effects = Y[treated_idx[treated_pos]] - Y[control_idx[control_pos]]
    att = np.mean(pair_effects)

    # Positions are relative to treated_idx / control_idx, not to rows of X.
    match = {}
    for t_pos, c_pos in zip(treated_pos, control_pos):
        match[t_pos] = [c_pos]
    return att, match, treated_idx, control_idx

# -----------------------------
# Hungarian-style 1-to-k Matching (solved as min-cost max-flow)
# -----------------------------
def att_hungarian_1tok(X, T, Y, k=3):
    """Estimate the ATT by matching each treated unit to up to ``k`` controls.

    The matching is posed as a min-cost max-flow problem on a directed graph:
    source -> treated node (capacity ``k``), treated -> control (capacity 1,
    weight = scaled Euclidean distance), control -> sink (capacity 1). Each
    control can therefore be used at most once, and each treated unit can
    receive at most ``k`` controls.

    Parameters
    ----------
    X : array of covariates, one row per unit.
    T : array of treatment indicators (1 = treated, 0 = control).
    Y : array of observed outcomes, aligned with ``X`` and ``T``.
    k : maximum number of controls matched to one treated unit.

    Returns
    -------
    att : mean over matched treated units of (treated outcome - mean outcome
        of its matched controls). Treated units with no match are skipped.
    matches : dict mapping a position in ``treated_idx`` to the list of
        matched positions in ``control_idx`` (possibly empty).
    treated_idx, control_idx : row indices of the treated / control units.
    """
    treated_idx = np.nonzero(T == 1)[0]
    control_idx = np.nonzero(T == 0)[0]

    treated_nodes = [f"T{i}" for i in range(len(treated_idx))]
    control_nodes = [f"C{j}" for j in range(len(control_idx))]
    source, sink = "s", "t"

    graph = nx.DiGraph()
    for t_node in treated_nodes:
        graph.add_edge(source, t_node, capacity=k, weight=0)
    for c_node in control_nodes:
        graph.add_edge(c_node, sink, capacity=1, weight=0)
    for t_node, t_row in zip(treated_nodes, treated_idx):
        for c_node, c_row in zip(control_nodes, control_idx):
            dist = np.linalg.norm(X[t_row] - X[c_row])
            # Distances are scaled by 1e6 and truncated so edge weights are ints.
            graph.add_edge(t_node, c_node, capacity=1, weight=int(dist * 1e6))

    flow = nx.max_flow_min_cost(graph, source, sink)

    # A treated -> control edge carrying positive flow is a selected match.
    matches = {
        i: [
            j
            for j, c_node in enumerate(control_nodes)
            if flow[t_node].get(c_node, 0) > 0
        ]
        for i, t_node in enumerate(treated_nodes)
    }

    att_list = [
        Y[treated_idx[i]] - Y[control_idx][matched_js].mean()
        for i, matched_js in matches.items()
        if matched_js
    ]
    att = np.mean(att_list)
    return att, matches, treated_idx, control_idx

In [3]:
# -----------------------------
# Data Generator
# -----------------------------
def generate_data(n_treated=100, n_control=300, p=5, tau=2.0, hetero=False, seed=42):
    """Simulate covariates, treatment assignment, and outcomes for matching demos.

    The first ``n_treated`` rows are treated and the remaining ``n_control``
    rows are controls. Covariates are i.i.d. standard normal; the untreated
    outcome is linear in the covariates plus standard-normal noise.

    A private ``RandomState`` seeded with ``seed`` is used, so calling this
    function does not reseed or advance NumPy's global random generator. The
    draws match what ``np.random.seed(seed)`` followed by the same
    ``np.random.normal`` calls would produce.

    Parameters
    ----------
    n_treated : number of treated units.
    n_control : number of control units.
    p : number of covariates.
    tau : treatment effect (constant), or the slope on the first covariate
        when ``hetero`` is true.
    hetero : if true, the unit-level effect is ``tau * X[:, 0]``.
    seed : seed for the private random generator.

    Returns
    -------
    X : (n_treated + n_control, p) covariate matrix.
    T : float array of treatment indicators (1.0 treated, 0.0 control).
    Y : observed outcomes.
    tau_x : unit-level treatment effects for every row.
    """
    rng = np.random.RandomState(seed)
    n_total = n_treated + n_control

    X = rng.normal(0, 1, size=(n_total, p))
    T = np.zeros(n_total)
    T[:n_treated] = 1

    # Draw order (X, coefficients, noise) is kept so results for a given seed
    # stay the same as with global seeding.
    beta = rng.normal(0.5, 0.1, p)
    noise = rng.normal(0, 1, X.shape[0])
    Y0 = X @ beta + noise

    tau_x = tau * X[:, 0] if hetero else np.full(X.shape[0], tau)
    Y1 = Y0 + tau_x
    Y = T * Y1 + (1 - T) * Y0
    return X, T, Y, tau_x

# -----------------------------
# Covariate Balance Calculator
# -----------------------------
def covariate_balance(X, matches, treated_idx, control_idx):
    """Return the mean absolute covariate gap between treated units and matches.

    For every treated unit with at least one matched control, the difference
    between its covariate vector and the average covariate vector of its
    matched controls is taken; the absolute differences are then averaged
    over treated units, covariate by covariate.

    Parameters
    ----------
    X : array of covariates, one row per unit.
    matches : dict mapping a position in ``treated_idx`` to a list of
        positions in ``control_idx``; empty lists are skipped.
    treated_idx, control_idx : row indices of the treated / control units.

    Returns
    -------
    Mean absolute difference, reduced along axis 0 of the stacked per-unit
    differences.
    """
    control_rows = X[control_idx]
    gaps = [
        X[treated_idx[t_pos]] - control_rows[c_positions].mean(axis=0)
        for t_pos, c_positions in matches.items()
        if c_positions
    ]
    return np.abs(np.array(gaps)).mean(axis=0)

In [4]:
# -----------------------------
# Main Script
# -----------------------------
# Demo entry point: simulate data, run both matching estimators, and report
# each ATT estimate and its covariate balance next to the true ATT.
if __name__ == "__main__":
    # Defaults: 100 treated, 300 controls, 5 covariates, constant effect tau=2.0.
    X, T, Y, tau_x = generate_data()
    # The true ATT is known because the unit-level effects are simulated.
    att_true = np.mean(tau_x[T == 1])

    # One-to-one matching and its per-covariate balance.
    att1, match1, t_idx1, c_idx1 = att_hungarian_1to1(X, T, Y)
    bal1 = covariate_balance(X, match1, t_idx1, c_idx1)

    # One-to-k matching with up to 3 controls per treated unit.
    attk, matchk, t_idxk, c_idxk = att_hungarian_1tok(X, T, Y, k=3)
    balk = covariate_balance(X, matchk, t_idxk, c_idxk)

    # "Mean cov. diff" averages the per-covariate balance into one number.
    print(f"True ATT:           {att_true:.4f}")
    print(f"Hungarian 1-to-1:   {att1:.4f} | Mean cov. diff: {bal1.mean():.4f}")
    print(f"Hungarian 1-to-3:   {attk:.4f} | Mean cov. diff: {balk.mean():.4f}")

True ATT:           2.0000
Hungarian 1-to-1:   1.7935 | Mean cov. diff: 0.3461
Hungarian 1-to-3:   1.8459 | Mean cov. diff: 0.3122
