# Analysis Notebook

Use Python code cells to run analyses inline.


In [None]:
import pandas as pd
import numpy as np
# Smoke test: reaching this line means the imports in this cell succeeded.
print("Notebook ready")


# Speed Dating → Logit Utilities → Gale–Shapley Stable Matching

This notebook is a **fully worked end-to-end exercise** that links:

1. A **logit model** for individual yes/no decisions (utility estimation)
2. **Predicted probabilities** to build **preference rankings**
3. The **Gale–Shapley** algorithm to compute **stable matchings**
4. A comparison to the observed **`match`** outcomes in the data

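**Data file:** `../speed-dating.csv` (raw). The first run also writes a processed copy with reconstructed IDs to `data/speed-dating-ethnicity-with-ids.csv`; later runs load that copy when it exists.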

> Note on IDs: this cleaned dataset does not include explicit participant IDs.
> We reconstruct consistent IDs within each `wave` using stable demographic variables
> and the stated preference-weight questions.


In [None]:
# --- Setup ---
from pathlib import Path
import os

# Make matplotlib work in restricted environments (no writable home cache)
os.environ.setdefault('MPLCONFIGDIR', str(Path('.latex-interface') / 'mpl-cache'))

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

import networkx as nx

from IPython.display import display

# Display options: show wide frames without column truncation.
pd.set_option(
    'display.max_columns', 200,
    'display.width', 120,
)

# Seed shared by every stochastic step (e.g. the train/test split) for reproducibility.
RANDOM_SEED = 0


## 0) Load the data

Each row is one speed-date **from one participant's perspective**.
Key variables we will use:

- `decision`: whether the participant wants to see the partner again (0/1)
- `decision_o`: the partner's decision (0/1)
- `match`: 1 if both said yes
- `wave`: a separate speed-dating market/event
- `gender`: participant gender (male/female)

We will use `wave = 1` as the running example (10 men, 10 women → 100 pairs).


In [None]:
RAW_DATA_PATH = Path('../speed-dating.csv')
PROCESSED_DATA_PATH = Path('data/speed-dating-ethnicity-with-ids.csv')

# Renames applied to the raw file: race-related names -> ethnicity terminology.
_RACE_TO_ETHNICITY = {
    'race': 'ethnicity',
    'race_o': 'ethnicity_o',
    'samerace': 'same_ethnicity',
    'importance_same_race': 'importance_same_ethnicity',
    'd_importance_same_race': 'd_importance_same_ethnicity',
}

# Prefer the processed (shareable student) file when present. Otherwise read the
# raw file; the processed dataset is created later, in the ID-reconstruction cell.
if not PROCESSED_DATA_PATH.exists():
    assert RAW_DATA_PATH.exists(), f"Missing raw data file: {RAW_DATA_PATH.resolve()}"
    df = pd.read_csv(RAW_DATA_PATH).rename(columns=_RACE_TO_ETHNICITY)
    print('Loaded raw dataset:', RAW_DATA_PATH)
else:
    df = pd.read_csv(PROCESSED_DATA_PATH)
    print('Loaded processed dataset:', PROCESSED_DATA_PATH)

print('Shape:', df.shape)
print('Columns:', len(df.columns))
df.head(3)


In [None]:
# Quick sanity checks: gender counts (incl. missing), number of waves, base rates.
print(df['gender'].value_counts(dropna=False))
print('Waves:', df['wave'].nunique())
for _label, _col in (('Decision mean:', 'decision'), ('Match mean:', 'match')):
    print(_label, df[_col].mean())


## 0.1) Reconstruct participant IDs

To run Gale–Shapley we need consistent IDs for the two sides of the market.

This dataset contains **stable preference-weight questions** (they sum to 100):

- Participant's own weights: `*_important`
- Partner's weights (same questions but for the partner): `pref_o_*`

So we can identify a person within a wave by:

- `wave`, `gender`, `age`, `ethnicity`
- their preference-weight vector (`*_important`)

And we can identify the partner in the row by the *same* information using
`age_o`, `ethnicity_o`, and `pref_o_*`.


In [None]:
# Respondent's own preference weights. Column names are used verbatim as they
# appear in the data (including 'intellicence_important' / 'ambtition_important').
imp_cols = [
    'attractive_important',
    'sincere_important',
    'intellicence_important',
    'funny_important',
    'ambtition_important',
    'shared_interests_important',
]

# The partner's weights on the same six questions, in the same order as imp_cols.
pref_cols = [
    'pref_o_attractive',
    'pref_o_sincere',
    'pref_o_intelligence',
    'pref_o_funny',
    'pref_o_ambitious',
    'pref_o_shared_interests',
]

# Decimal places kept for numeric signature fields (guards against float noise).
SIG_NUMERIC_ROUND = 6


def _signature(frame, numeric_cols):
    """Join every column of `frame` into one '|'-separated string per row.

    Columns listed in `numeric_cols` are coerced to float (non-numeric entries
    become NaN -> 'nan') and rounded to SIG_NUMERIC_ROUND decimals before
    stringification. Without this, the two sides of a date can disagree on
    formatting, e.g. an int64 column renders '21' while a float64 column that
    contains NaNs renders '21.0', and the partner lookup silently fails.
    Other columns are stringified as-is.
    """
    parts = frame.copy()
    for col in numeric_cols:
        parts[col] = (
            pd.to_numeric(parts[col], errors='coerce')
            .astype(float)
            .round(SIG_NUMERIC_ROUND)
        )
    return parts.astype(str).agg('|'.join, axis=1)


# If the processed dataset is loaded, it already has person_id/partner_id.
# If not, reconstruct IDs from stable demographics + preference-weight vectors.
df = df.copy()

if 'person_id' not in df.columns or 'partner_id' not in df.columns:
    sig_numeric_cols = ['wave', 'age'] + imp_cols

    # Respondent signature → person_id
    person_sig_cols = ['wave', 'gender', 'age', 'ethnicity'] + imp_cols
    df['person_sig'] = _signature(df[person_sig_cols], sig_numeric_cols)
    df['person_id'] = pd.factorize(df['person_sig'])[0]

    sig_to_id = (
        df[['person_sig', 'person_id']]
        .drop_duplicates('person_sig')
        .set_index('person_sig')['person_id']
        .to_dict()
    )

    # Partner signature → partner_id (mapped into the same ID space). The frame is
    # laid out exactly like person_sig_cols, but filled from the partner's columns.
    partner_sig_df = pd.DataFrame({
        'wave': df['wave'],
        'gender': df['gender'].map({'male': 'female', 'female': 'male'}),
        'age': df['age_o'],
        'ethnicity': df['ethnicity_o'],
    })
    for imp, pref in zip(imp_cols, pref_cols):
        partner_sig_df[imp] = df[pref]

    df['partner_sig'] = _signature(partner_sig_df, sig_numeric_cols)
    # Partners whose signature matches no respondent get NaN here.
    df['partner_id'] = df['partner_sig'].map(sig_to_id)
else:
    # Ensure consistent dtypes when loading from CSV
    df['person_id'] = df['person_id'].astype(int)
    df['partner_id'] = df['partner_id'].astype(int)

print('Rows with missing partner_id:', df['partner_id'].isna().sum())
print('Unique participants (reconstructed):', df[['wave','gender','person_id']].drop_duplicates().shape[0])

# Create the analysis dataset we share with students: keep only rows with both IDs
# and save it with ethnicity terminology.
df_clean = df.dropna(subset=['person_id', 'partner_id']).copy()
df_clean['person_id'] = df_clean['person_id'].astype(int)
df_clean['partner_id'] = df_clean['partner_id'].astype(int)

PROCESSED_DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
df_clean.to_csv(PROCESSED_DATA_PATH, index=False)
print('Wrote processed analysis dataset:', PROCESSED_DATA_PATH)


In [None]:
# Participants per wave (reconstructed): one row per (wave, gender, person_id).
people = (
    df.loc[:, ['wave', 'gender', 'person_id', 'age', 'ethnicity']]
    .drop_duplicates(['wave', 'gender', 'person_id'])
)
counts = people.groupby(['wave', 'gender']).size().unstack(fill_value=0)
# A wave is a balanced market when it has as many women as men.
counts['balanced'] = counts.get('female', 0) == counts.get('male', 0)
counts.sort_values(by=['balanced', 'female'], ascending=[False, False]).head(12)


## 0.2) Choose a wave (a matching market)

We will work with one wave as a clean two-sided matching market.
You can change `WAVE` to explore other markets.


In [None]:
WAVE = 1

# Rows where both sides have reconstructed IDs. partner_id comes from a mapping
# (float with NaN), so both ID columns are cast back to int once NaNs are gone.
df_clean = (
    df.dropna(subset=['person_id', 'partner_id'])
    .astype({'person_id': int, 'partner_id': int})
)
df_wave = df_clean.loc[df_clean['wave'] == WAVE].copy()

is_female = df_wave['gender'] == 'female'
is_male = df_wave['gender'] == 'male'

print('Wave rows:', df_wave.shape[0])
print('Unique women:', df_wave.loc[is_female, 'person_id'].nunique())
print('Unique men  :', df_wave.loc[is_male, 'person_id'].nunique())

# Sorted ID lists per side; short labels (M01.., W01..) follow this order.
men_ids = sorted(df_wave.loc[is_male, 'person_id'].unique().tolist())
women_ids = sorted(df_wave.loc[is_female, 'person_id'].unique().tolist())

id_to_label = {}
for prefix, side_ids in (('M', men_ids), ('W', women_ids)):
    id_to_label.update({pid: f"{prefix}{n:02d}" for n, pid in enumerate(side_ids, start=1)})

people_wave = (
    people.loc[people['wave'] == WAVE]
    .assign(label=lambda d: d['person_id'].map(id_to_label))
    .sort_values(['gender', 'label'])
)
people_wave.head(10)


## 1) Step 1 — Estimate utilities with a logit model

We model the probability that participant *i* says yes to partner *j* as:

\[\Pr(\texttt{decision}_{ij}=1\mid X_{ij}) = \Lambda(X_{ij}'\beta)\]

where \(\Lambda(\cdot)\) is the logistic CDF.

- The *linear index* \(U_{ij} = X_{ij}'\beta\), with coefficient vector \(\beta\), is an empirical **utility score**.
- The predicted probability \(\hat p_{ij}\) is a convenient **ranking score**.

We estimate **separate models** for women and men (preferences may differ).


In [None]:
# Regressors for the respondent's yes/no decision. Names are used verbatim as in
# the data (including 'sinsere_o' / 'ambitous_o').
# NOTE(review): elsewhere the `_o` suffix marks the partner's side (age_o,
# ethnicity_o, decision_o, pref_o_*), so these may be the partner's ratings of the
# respondent rather than the respondent's ratings of the partner — confirm against
# the codebook that they are the intended drivers of `decision`.
FEATURES = [
    'attractive_o',
    'sinsere_o',
    'intelligence_o',
    'funny_o',
    'ambitous_o',
    'shared_interests_o',
    'd_age',
    'same_ethnicity',
]
TARGET = 'decision'

# Coefficients are estimated on all waves pooled (stabler estimates) and later
# applied to a single wave.


def fit_logit_for_gender(df_in, gender):
    """Fit a median-imputed logistic regression of TARGET on FEATURES for one gender.

    Rows with a missing TARGET are dropped. A stratified 75/25 split (seeded with
    RANDOM_SEED) is used: the pipeline is fit on the 75% part and scored on the
    held-out 25%.

    Returns
    -------
    (Pipeline, dict)
        The fitted imputer + logit pipeline, and a dict with the gender, train/test
        sizes, held-out AUC, accuracy at a 0.5 cutoff and the held-out base rate.
    """
    subset = df_in.loc[df_in['gender'] == gender].dropna(subset=[TARGET])
    X_all = subset[FEATURES]
    y_all = subset[TARGET].astype(int)

    X_tr, X_te, y_tr, y_te = train_test_split(
        X_all, y_all, test_size=0.25, random_state=RANDOM_SEED, stratify=y_all,
    )

    pipeline = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='median')),
        ('logit', LogisticRegression(max_iter=1000, solver='lbfgs')),
    ])
    pipeline.fit(X_tr, y_tr)

    proba_te = pipeline.predict_proba(X_te)[:, 1]
    return pipeline, {
        'gender': gender,
        'n_train': len(X_tr),
        'n_test': len(X_te),
        'auc': roc_auc_score(y_te, proba_te),
        'accuracy@0.5': accuracy_score(y_te, (proba_te >= 0.5).astype(int)),
        'base_rate': float(y_te.mean()),
    }


# One model per side of the market, fit in the order female, male.
fits = {g: fit_logit_for_gender(df_clean, g) for g in ('female', 'male')}
models = {g: fitted for g, (fitted, _) in fits.items()}
metrics = [fit_metrics for _, fit_metrics in fits.values()]

pd.DataFrame(metrics)


In [None]:
# Coefficient tables. sklearn's LogisticRegression applies an L2 penalty by
# default, so these are regularized estimates rather than plain MLE.

def coef_table(pipeline):
    """Return the fitted logit's coefficients as a two-column (feature, coef) frame.

    One row per entry of FEATURES (in that order), followed by an 'intercept' row.
    """
    logit = pipeline.named_steps['logit']
    rows = list(zip(FEATURES, logit.coef_.ravel()))
    rows.append(('intercept', float(logit.intercept_[0])))
    return pd.DataFrame(rows, columns=['feature', 'coef'])

coef_female, coef_male = (coef_table(models[g]) for g in ('female', 'male'))

coef_female


In [None]:
# Attach the predicted utility (the logit's linear index) and the predicted
# probability to every row. Rows whose gender has no fitted model keep NaN.
df_pred = df_clean.assign(utility_hat=np.nan, p_hat=np.nan)

for g, model in models.items():
    rows_g = df_pred['gender'] == g
    X_g = df_pred.loc[rows_g, FEATURES]
    df_pred.loc[rows_g, 'utility_hat'] = model.decision_function(X_g)
    df_pred.loc[rows_g, 'p_hat'] = model.predict_proba(X_g)[:, 1]

p_lo, p_hi = df_pred['p_hat'].min(), df_pred['p_hat'].max()
print('p_hat range:', float(p_lo), float(p_hi))
df_pred[['wave', 'gender', 'decision', 'p_hat', 'utility_hat']].head(5)


## 2) Step 2 — Build preference rankings and the Gale–Shapley matrices

For a given wave, we can build (estimated) preferences by sorting partners by \(\hat p_{ij}\).

- Men’s preference list over women: sort women by men’s predicted probabilities
- Women’s preference list over men: sort men by women’s predicted probabilities

We will construct:

- A men→women probability matrix \(P^{M}\)
- A women→men probability matrix \(P^{W}\)
- Two preference dictionaries (one for each side) to feed into Gale–Shapley


In [None]:
df_wave_pred = df_pred.loc[df_pred['wave'] == WAVE].copy()

# Split the wave by respondent side:
#   men as respondents   -> p(m says yes to w)
#   women as respondents -> p(w says yes to m)
men_rows = df_wave_pred.loc[df_wave_pred['gender'] == 'male'].copy()
women_rows = df_wave_pred.loc[df_wave_pred['gender'] == 'female'].copy()

# Respondent x partner matrices; duplicate (respondent, partner) rows are averaged.
_pivot_kw = dict(index='person_id', columns='partner_id', values='p_hat', aggfunc='mean')
P_m = men_rows.pivot_table(**_pivot_kw)
P_w = women_rows.pivot_table(**_pivot_kw)

print('P_m shape (men x women):', P_m.shape)
print('P_w shape (women x men):', P_w.shape)

# Relabel (M01.., W01..) and sort both axes for nicer display
P_m_labeled, P_w_labeled = (
    P.rename(index=id_to_label, columns=id_to_label).sort_index().sort_index(axis=1)
    for P in (P_m, P_w)
)

P_m_labeled.iloc[:5, :5]


In [None]:
# Preference lists (highest probability = best)

def prefs_from_matrix(P):
    """Turn a respondent x partner score matrix into ranked preference lists.

    Returns {respondent: [partner, ...]} with partners ordered by descending
    score (row index = respondent, column labels = partners). Missing (NaN)
    scores sort to the end of each list.
    """
    return {
        respondent: scores.sort_values(ascending=False).index.tolist()
        for respondent, scores in P.iterrows()
    }

prefs_men, prefs_women = prefs_from_matrix(P_m), prefs_from_matrix(P_w)

# Show the top of one example list per side, using display labels.
first_man, first_woman = men_ids[0], women_ids[0]
print('Example man', id_to_label[first_man], 'top-5 women:',
      [id_to_label[x] for x in prefs_men[first_man][:5]])
print('Example woman', id_to_label[first_woman], 'top-5 men :',
      [id_to_label[x] for x in prefs_women[first_woman][:5]])


In [None]:
# Optional: visualize men's probability matrix as a heatmap (explicit fig/ax API)
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(P_m_labeled, cmap='viridis', cbar_kws={'label': 'p_hat (man says yes)'}, ax=ax)
ax.set_title(f'Wave {WAVE}: Men → Women predicted probabilities')
fig.tight_layout()
plt.show()


## 3) Step 3 — Gale–Shapley stable matching

We now compute the stable matching equilibrium under two regimes:

- **Women propose** (women-optimal stable matching)
- **Men propose** (men-optimal stable matching)

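Gale–Shapley always returns a stable matching (no blocking pair). When both sides have the same number of agents and every preference list is complete, everyone is matched. In an unbalanced wave, the surplus agents on the larger side stay unmatched.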


In [None]:
def gale_shapley(proposer_prefs, acceptor_prefs):
    """Deferred-acceptance (Gale–Shapley) matching in which `proposer_prefs` propose.

    Parameters
    ----------
    proposer_prefs : dict
        proposer -> list of acceptors, most preferred first.
    acceptor_prefs : dict
        acceptor -> list of proposers, most preferred first.

    Returns
    -------
    dict
        {proposer: acceptor}, with None for a proposer who exhausted their list
        without being held (e.g. the surplus side of an unbalanced market).

    Notes
    -----
    An acceptor ranks any proposer absent from its list below every listed
    proposer. An acceptor that has no list at all in `acceptor_prefs` (it only
    appears on proposers' lists) holds its first proposer and rejects everyone
    after, instead of raising KeyError.
    """
    from collections import deque

    # acceptor -> {proposer: position}; a lower position means more preferred.
    acceptor_rank = {
        a: {p: r for r, p in enumerate(pref_list)}
        for a, pref_list in acceptor_prefs.items()
    }
    unranked = float('inf')

    free = deque(proposer_prefs)                 # FIFO queue of proposers without a held offer
    next_idx = dict.fromkeys(proposer_prefs, 0)  # next list position each proposer will try
    engaged = {}                                 # acceptor -> proposer currently held

    while free:
        p = free.popleft()
        prefs = proposer_prefs[p]

        if next_idx[p] >= len(prefs):
            continue  # list exhausted: p stays unmatched

        a = prefs[next_idx[p]]
        next_idx[p] += 1

        if a not in engaged:
            engaged[a] = p
            continue

        current = engaged[a]
        ranks = acceptor_rank.get(a, {})
        # Strict comparison: on a tie (e.g. both unranked) the acceptor keeps `current`.
        if ranks.get(p, unranked) < ranks.get(current, unranked):
            engaged[a] = p
            free.append(current)
        else:
            free.append(p)

    match = dict.fromkeys(proposer_prefs)
    for a, p in engaged.items():
        match[p] = a
    return match


def matching_to_frame(match, proposer_label='proposer', acceptor_label='acceptor', labels=None):
    """Tabulate a {proposer: acceptor} matching as a DataFrame.

    Parameters
    ----------
    match : dict
        Output of gale_shapley; the acceptor is None for an unmatched proposer.
    proposer_label, acceptor_label : str
        Names of the two ID columns; '<name>_label' columns hold display labels.
    labels : dict, optional
        ID -> display label. Defaults to the running wave's `id_to_label`; pass
        another mapping (e.g. `id_to_label_9`) to tabulate a different wave.
        IDs missing from the mapping fall back to str(id).

    Returns
    -------
    pandas.DataFrame
        One row per proposer, in `match` iteration order. The acceptor label is
        None for an unmatched proposer.
    """
    if labels is None:
        labels = id_to_label
    rows = [
        {
            proposer_label: p,
            acceptor_label: a,
            proposer_label + '_label': labels.get(p, str(p)),
            acceptor_label + '_label': None if a is None else labels.get(a, str(a)),
        }
        for p, a in match.items()
    ]
    return pd.DataFrame(rows)


def rank_in_prefs(prefs_dict, agent, partner):
    """Return the 1-based position of `partner` in `agent`'s preference list.

    Returns None when `partner` is None, when `agent` has no list in
    `prefs_dict` (e.g. an unmatched slot that surfaced as NaN in a DataFrame
    row of an unbalanced wave), or when `partner` is not on the list.
    Previously a missing/NaN agent raised KeyError.
    """
    if partner is None:
        return None
    try:
        return prefs_dict[agent].index(partner) + 1
    except (KeyError, ValueError):
        return None

def _with_ranks(res, side_a, prefs_a, side_b, prefs_b):
    """Add '<a>_rank_of_<b>' then '<b>_rank_of_<a>' rank columns (1 = top choice)."""
    res[f'{side_a}_rank_of_{side_b}'] = res.apply(
        lambda r: rank_in_prefs(prefs_a, r[side_a], r[side_b]), axis=1)
    res[f'{side_b}_rank_of_{side_a}'] = res.apply(
        lambda r: rank_in_prefs(prefs_b, r[side_b], r[side_a]), axis=1)
    return res

# Women propose
match_women_propose = gale_shapley(prefs_women, prefs_men)
res_wp = _with_ranks(
    matching_to_frame(match_women_propose, proposer_label='woman', acceptor_label='man'),
    'woman', prefs_women, 'man', prefs_men,
)

# Men propose
match_men_propose = gale_shapley(prefs_men, prefs_women)
res_mp = _with_ranks(
    matching_to_frame(match_men_propose, proposer_label='man', acceptor_label='woman'),
    'man', prefs_men, 'woman', prefs_women,
)

for title, res_frame in (
    ('Women-propose matching (first 10 rows):', res_wp),
    ('Men-propose matching (first 10 rows):', res_mp),
):
    print(title)
    display(res_frame.head(10))


In [None]:
# Welfare-style summary: average ranks (lower is better)

def summarize_matching(res, proposer_side):
    """Average partner ranks (1 = top choice; lower is better) for one matching.

    Parameters
    ----------
    res : pandas.DataFrame
        Must contain 'woman_rank_of_man' and 'man_rank_of_woman' columns.
    proposer_side : {'women', 'men'}
        The proposing side; its average is listed first.

    Returns
    -------
    pandas.Series
        Two averages. Missing ranks (unmatched agents) are skipped by `mean`.

    Raises
    ------
    ValueError
        If proposer_side is neither 'women' nor 'men' (this used to return None silently).
    """
    if proposer_side not in ('women', 'men'):
        raise ValueError(f"proposer_side must be 'women' or 'men', got {proposer_side!r}")

    woman_avg = ('avg woman rank (of man)', res['woman_rank_of_man'].mean())
    man_avg = ('avg man rank (of woman)', res['man_rank_of_woman'].mean())
    ordered = [woman_avg, man_avg] if proposer_side == 'women' else [man_avg, woman_avg]
    return pd.Series(dict(ordered))

# Column name -> (matching result, proposing side)
_summary_inputs = {
    'women propose': (res_wp, 'women'),
    'men propose': (res_mp, 'men'),
}
summary = pd.DataFrame({
    col: summarize_matching(res_frame, side)
    for col, (res_frame, side) in _summary_inputs.items()
})
summary


## 4) Step 4 — Compare to the observed `match` outcomes

`match = 1` means both sides said yes after their speed-date.

Important conceptual note:

- The experiment allows **many matches per person** (a person can match with multiple partners).
- Gale–Shapley produces a **one-to-one** stable matching.

So we compare in two ways:

1. **Overlap with mutual matches:** Is the GS partner also a mutual match (`match=1`)?
2. **Best one-to-one matching inside the mutual-match network:** maximum-cardinality bipartite matching using only edges with `match=1`.


In [None]:
# One row per (man, woman) pair, taken from this wave's male-respondent rows
edges = men_rows.loc[:, ['person_id', 'partner_id', 'match']].rename(
    columns={'person_id': 'man_id', 'partner_id': 'woman_id'}
)
edges = edges.drop_duplicates(subset=['man_id', 'woman_id'])
# (man_id, woman_id) -> observed match flag
edge_match = edges.set_index(['man_id', 'woman_id'])['match']


def overlap_with_match(df_pairs, edge_lookup=None):
    """Count how many (man_id, woman_id) pairs were observed mutual matches.

    Parameters
    ----------
    df_pairs : pandas.DataFrame
        Must contain 'man_id' and 'woman_id' columns.
    edge_lookup : mapping, optional
        Anything with `.get((man_id, woman_id), default)` returning the match flag
        (a dict or a MultiIndex Series). Defaults to the running wave's `edge_match`;
        pass e.g. `edge_match_9` to score another wave. Pairs absent from the
        lookup count as non-matches.

    Returns
    -------
    tuple
        (hits, n_pairs, hit_rate); hit_rate is NaN when df_pairs is empty.
    """
    if edge_lookup is None:
        edge_lookup = edge_match
    flags = [
        int(edge_lookup.get((m, w), 0) == 1)
        for m, w in df_pairs[['man_id', 'woman_id']].itertuples(index=False, name=None)
    ]
    hits, n = sum(flags), len(flags)
    return hits, n, (hits / n if n else float('nan'))

# (man_id, woman_id) tables for each GS matching, skipping unmatched proposers
_id_cols = {'man': 'man_id', 'woman': 'woman_id'}
wp_pairs = res_wp.dropna(subset=['man']).rename(columns=_id_cols).loc[:, ['man_id', 'woman_id']]
mp_pairs = res_mp.dropna(subset=['woman']).rename(columns=_id_cols).loc[:, ['man_id', 'woman_id']]

wp_hit, wp_n, wp_rate = overlap_with_match(wp_pairs)
mp_hit, mp_n, mp_rate = overlap_with_match(mp_pairs)

print(f"Wave {WAVE}: GS women-propose overlap: {wp_hit}/{wp_n} = {wp_rate:.2%}")
print(f"Wave {WAVE}: GS men-propose   overlap: {mp_hit}/{mp_n} = {mp_rate:.2%}")
n_mutual = int(edges['match'].sum())
print('Total mutual matches in wave (edges with match=1):', n_mutual)


In [None]:
# Maximum-cardinality one-to-one matching that uses only mutual-match edges.
# Node names carry a side prefix so man and woman IDs cannot collide.
men_nodes = ['M_' + str(m) for m in men_ids]
women_nodes = ['W_' + str(w) for w in women_ids]

B = nx.Graph()
B.add_nodes_from(men_nodes, bipartite=0)
B.add_nodes_from(women_nodes, bipartite=1)
B.add_edges_from(
    (f"M_{m}", f"W_{w}") for (m, w), val in edge_match.items() if val == 1
)

mm = nx.algorithms.bipartite.matching.maximum_matching(B, top_nodes=men_nodes)

# The returned mapping can list a pair in both directions; keep man -> woman only.
mm_pairs = {
    (int(u.split('_', 1)[1]), int(v.split('_', 1)[1]))
    for u, v in mm.items()
    if u.startswith('M_') and v.startswith('W_')
}

print('Max-cardinality matching size inside mutual matches:', len(mm_pairs))


In [None]:
# Compare GS pair-sets to the max-cardinality matching from observed mutual matches

def pair_set(df_pairs):
    """Return the distinct (man_id, woman_id) rows of `df_pairs` as a set of tuples."""
    return {tuple(row) for row in df_pairs[['man_id', 'woman_id']].values.tolist()}

S_wp, S_mp = pair_set(wp_pairs), pair_set(mp_pairs)
S_mm = mm_pairs  # benchmark: max matching inside observed mutual matches

# Every (man_id, woman_id) edge whose observed match flag is 1
mutual_edges = set(edge_match.loc[edge_match.eq(1)].index.tolist())

def jaccard(A, B):
    """Jaccard similarity |A ∩ B| / |A ∪ B| of two sets; 1.0 when both are empty."""
    union = A | B
    return len(A & B) / len(union) if union else 1.0

# One row per matching: its size, how many of its pairs are observed mutual
# matches, and its Jaccard overlap with the benchmark matching.
_comparison_rows = []
for _name, _pairs in (('GS (women propose)', S_wp), ('GS (men propose)', S_mp)):
    _comparison_rows.append({
        'matching': _name,
        'pairs': len(_pairs),
        'GS_pairs_that_are_mutual_matches': len(_pairs & mutual_edges),
        'jaccard_with_max_mutual_matching': jaccard(_pairs, S_mm),
    })
# The benchmark consists only of mutual matches and coincides with itself.
_comparison_rows.append({
    'matching': 'Max matching within match=1 edges',
    'pairs': len(S_mm),
    'GS_pairs_that_are_mutual_matches': len(S_mm),
    'jaccard_with_max_mutual_matching': 1.0,
})
comparison = pd.DataFrame(_comparison_rows)
comparison


## 5) Export tables/figures for LaTeX

We export a few key tables and figures into `tables/` and `images/` so a LaTeX document can include them.


In [None]:
Path('tables').mkdir(exist_ok=True)
Path('images').mkdir(exist_ok=True)

# Participants of the running wave (display labels rather than internal IDs)
participants_out = people_wave[['label','gender','age','ethnicity']].copy()
participants_out.to_csv('tables/wave_participants.csv', index=False)
participants_out.to_latex('tables/wave_participants.tex', index=False, escape=True)


# Logit fit metrics
metrics_df = pd.DataFrame(metrics)
metrics_df.to_csv('tables/logit_metrics.csv', index=False)
metrics_df.to_latex('tables/logit_metrics.tex', index=False, escape=True, float_format='%.3f')

# Logit coefficient tables, one per gender
coef_female.to_csv('tables/logit_coef_female.csv', index=False)
coef_male.to_csv('tables/logit_coef_male.csv', index=False)
coef_female.to_latex('tables/logit_coef_female.tex', index=False, escape=True)
coef_male.to_latex('tables/logit_coef_male.tex', index=False, escape=True)

# Stable matchings with each side's rank of its partner, sorted by proposer label
wp_out = res_wp[['woman_label','man_label','woman_rank_of_man','man_rank_of_woman']].sort_values('woman_label')
mp_out = res_mp[['man_label','woman_label','man_rank_of_woman','woman_rank_of_man']].sort_values('man_label')

wp_out.to_csv('tables/gs_women_propose.csv', index=False)
mp_out.to_csv('tables/gs_men_propose.csv', index=False)
wp_out.to_latex('tables/gs_women_propose.tex', index=False, escape=True)
mp_out.to_latex('tables/gs_men_propose.tex', index=False, escape=True)

# Average-rank summary, using the same 3-decimal float_format as the Wave 9
# summary (tables/gs_rank_summary_wave09.tex) so both tables render alike.
summary.to_latex('tables/gs_rank_summary.tex', escape=True, float_format='%.3f')
comparison_pretty = comparison.rename(columns={
    'matching': 'Matching',
    'pairs': 'Pairs',
    'GS_pairs_that_are_mutual_matches': 'Pairs with match=1',
    'jaccard_with_max_mutual_matching': 'Jaccard vs benchmark',
})
comparison_pretty.to_latex(
    'tables/gs_vs_match_comparison.tex',
    index=False,
    escape=True,
    float_format='%.3f',
)


# Men -> women probability heatmap, saved to disk (not displayed)
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(P_m_labeled, cmap='viridis', cbar_kws={'label': 'p_hat (man says yes)'}, ax=ax)
ax.set_title(f'Wave {WAVE}: Men → Women predicted probabilities')
fig.tight_layout()
fig.savefig('images/prob_matrix_men.png', dpi=200)
plt.close(fig)



# Preference (ranking) matrices for Gale–Shapley (as labels)
def prefs_table(prefs_dict, agent_ids, partner_prefix, labels=None):
    """Tabulate ranked preference lists with display labels.

    Parameters
    ----------
    prefs_dict : dict
        agent ID -> list of partner IDs, most preferred first.
    agent_ids : iterable
        Agents to include, one output row each, in this order.
    partner_prefix : str
        Currently unused (the labels already carry the M/W prefix); kept so
        existing calls keep working.
    labels : dict, optional
        ID -> display label. Defaults to the running wave's `id_to_label`.
        IDs missing from the mapping fall back to str(id).

    Returns
    -------
    pandas.DataFrame
        Columns 'agent', 'rank_01', 'rank_02', ...; an empty frame with only the
        'agent' column when `agent_ids` is empty.
    """
    if labels is None:
        labels = id_to_label

    def _label(x):
        return labels.get(x, str(x))

    rows = [[_label(a)] + [_label(p) for p in prefs_dict[a]] for a in agent_ids]
    width = max((len(r) for r in rows), default=1)
    cols = ['agent'] + [f'rank_{k:02d}' for k in range(1, width)]
    return pd.DataFrame(rows, columns=cols)

prefs_men_table = prefs_table(prefs_men, men_ids, 'W')
prefs_women_table = prefs_table(prefs_women, women_ids, 'M')

prefs_men_table.to_csv('tables/prefs_men.csv', index=False)
prefs_women_table.to_csv('tables/prefs_women.csv', index=False)

# Keep LaTeX tables compact: first 10 agents, top-5 ranks only. The rank
# columns are taken from those that exist, so a market with fewer than 5
# partners exports what it has instead of raising KeyError.
TOP_K_LATEX = 5
for table, tex_path in (
    (prefs_men_table, 'tables/prefs_men_top5.tex'),
    (prefs_women_table, 'tables/prefs_women_top5.tex'),
):
    rank_cols = [c for c in table.columns if c.startswith('rank_')][:TOP_K_LATEX]
    table.head(10)[['agent'] + rank_cols].to_latex(tex_path, index=False, escape=True)

# Probability matrices
P_m_labeled.to_csv('tables/prob_matrix_men.csv')
P_w_labeled.to_csv('tables/prob_matrix_women.csv')



# -----------------
# Wave 9: export the same key outputs as for the running wave.
# Wave 1 can have a unique stable matching; Wave 9 is a good example where the
# women-propose and men-propose stable matchings can differ.
# -----------------
WAVE_9 = 9

df_wave_pred_9 = df_pred.loc[df_pred['wave'] == WAVE_9].copy()
men_rows_9 = df_wave_pred_9.loc[df_wave_pred_9['gender'] == 'male'].copy()
women_rows_9 = df_wave_pred_9.loc[df_wave_pred_9['gender'] == 'female'].copy()

men_ids_9 = sorted(men_rows_9['person_id'].unique().tolist())
women_ids_9 = sorted(women_rows_9['person_id'].unique().tolist())

# Wave-9 display labels, numbered independently of the running wave (M01.., W01..)
id_to_label_9 = {pid: f"M{n:02d}" for n, pid in enumerate(men_ids_9, start=1)}
id_to_label_9.update((pid, f"W{n:02d}") for n, pid in enumerate(women_ids_9, start=1))

people_wave_9 = (
    people.loc[people['wave'] == WAVE_9]
    .assign(label=lambda d: d['person_id'].map(id_to_label_9))
    .sort_values(['gender', 'label'])
)

participants_out_9 = people_wave_9.loc[:, ['label', 'gender', 'age', 'ethnicity']].copy()
participants_out_9.to_csv('tables/wave09_participants.csv', index=False)
participants_out_9.to_latex('tables/wave09_participants.tex', index=False, escape=True)

# Preference probability matrices (Wave 9): respondent x partner, mean p_hat,
# then the ranked preference lists derived from them.
P_m_9, P_w_9 = (
    rows.pivot_table(index='person_id', columns='partner_id', values='p_hat', aggfunc='mean')
    for rows in (men_rows_9, women_rows_9)
)

prefs_men_9, prefs_women_9 = prefs_from_matrix(P_m_9), prefs_from_matrix(P_w_9)

# Gale–Shapley (Wave 9), tabulated with Wave-9 labels and each side's partner rank
def _gs_frame_9(match, proposer, proposer_prefs, acceptor, acceptor_prefs):
    """Tabulate a Wave-9 matching: IDs, labels and both sides' rank of their partner."""
    frame = pd.DataFrame({proposer: list(match.keys()), acceptor: list(match.values())})
    frame[f'{proposer}_label'] = frame[proposer].map(id_to_label_9)
    frame[f'{acceptor}_label'] = frame[acceptor].map(id_to_label_9)
    frame[f'{proposer}_rank_of_{acceptor}'] = frame.apply(
        lambda r: rank_in_prefs(proposer_prefs, r[proposer], r[acceptor]), axis=1)
    frame[f'{acceptor}_rank_of_{proposer}'] = frame.apply(
        lambda r: rank_in_prefs(acceptor_prefs, r[acceptor], r[proposer]), axis=1)
    return frame

match_women_propose_9 = gale_shapley(prefs_women_9, prefs_men_9)
res_wp_9 = _gs_frame_9(match_women_propose_9, 'woman', prefs_women_9, 'man', prefs_men_9)

match_men_propose_9 = gale_shapley(prefs_men_9, prefs_women_9)
res_mp_9 = _gs_frame_9(match_men_propose_9, 'man', prefs_men_9, 'woman', prefs_women_9)

# Wave-9 matchings, one row per proposer sorted by label
wp_out_9 = (
    res_wp_9.loc[:, ['woman_label', 'man_label', 'woman_rank_of_man', 'man_rank_of_woman']]
    .sort_values('woman_label')
)
mp_out_9 = (
    res_mp_9.loc[:, ['man_label', 'woman_label', 'man_rank_of_woman', 'woman_rank_of_man']]
    .sort_values('man_label')
)

wp_out_9.to_csv('tables/gs_women_propose_wave09.csv', index=False)
mp_out_9.to_csv('tables/gs_men_propose_wave09.csv', index=False)
wp_out_9.to_latex('tables/gs_women_propose_wave09.tex', index=False, escape=True)
mp_out_9.to_latex('tables/gs_men_propose_wave09.tex', index=False, escape=True)

# Average ranks under each proposing side (lower is better)
_summary_inputs_9 = {
    'women propose': (res_wp_9, 'women'),
    'men propose': (res_mp_9, 'men'),
}
summary_9 = pd.DataFrame({
    col: summarize_matching(res_frame, side)
    for col, (res_frame, side) in _summary_inputs_9.items()
})
summary_9.to_latex('tables/gs_rank_summary_wave09.tex', escape=True, float_format='%.3f')

# Compare to observed match outcomes (Wave 9)
edges_9 = (
    men_rows_9.loc[:, ['person_id', 'partner_id', 'match']]
    .rename(columns={'person_id': 'man_id', 'partner_id': 'woman_id'})
    .drop_duplicates(subset=['man_id', 'woman_id'])
)
edge_match_9 = edges_9.set_index(['man_id', 'woman_id'])['match']
mutual_edges_9 = set(edge_match_9.loc[edge_match_9.eq(1)].index.tolist())

# GS pair-sets as (man, woman) tuples, skipping unmatched proposers
wp_pairs_9 = {tuple(r) for r in res_wp_9.dropna(subset=['man'])[['man', 'woman']].values.tolist()}
mp_pairs_9 = {tuple(r) for r in res_mp_9.dropna(subset=['woman'])[['man', 'woman']].values.tolist()}

# Benchmark: maximum-cardinality matching that uses only match=1 edges
men_nodes_9 = [f"M_{m}" for m in men_ids_9]
women_nodes_9 = [f"W_{w}" for w in women_ids_9]
B_9 = nx.Graph()
B_9.add_nodes_from(men_nodes_9, bipartite=0)
B_9.add_nodes_from(women_nodes_9, bipartite=1)
B_9.add_edges_from(
    (f"M_{int(m)}", f"W_{int(w)}") for (m, w), val in edge_match_9.items() if val == 1
)

mm_9 = nx.algorithms.bipartite.matching.maximum_matching(B_9, top_nodes=men_nodes_9)
# Keep only the man -> woman orientation of each matched pair
mm_pairs_9 = {
    (int(u.split('_', 1)[1]), int(v.split('_', 1)[1]))
    for u, v in mm_9.items()
    if u.startswith('M_') and v.startswith('W_')
}

# Wave-9 comparison table: one row per matching (mirrors the running-wave table)
_comparison_rows_9 = [
    {
        'matching': name,
        'pairs': len(pairs),
        'GS_pairs_that_are_mutual_matches': len(pairs & mutual_edges_9),
        'jaccard_with_max_mutual_matching': jaccard(pairs, mm_pairs_9),
    }
    for name, pairs in (('GS (women propose)', wp_pairs_9), ('GS (men propose)', mp_pairs_9))
]
_comparison_rows_9.append({
    'matching': 'Max matching within match=1 edges',
    'pairs': len(mm_pairs_9),
    'GS_pairs_that_are_mutual_matches': len(mm_pairs_9),
    'jaccard_with_max_mutual_matching': 1.0,
})
comparison_9 = pd.DataFrame(_comparison_rows_9)

# Human-readable column headers for the LaTeX table
_pretty_headers_9 = {
    'matching': 'Matching',
    'pairs': 'Pairs',
    'GS_pairs_that_are_mutual_matches': 'Pairs with match=1',
    'jaccard_with_max_mutual_matching': 'Jaccard vs benchmark',
}
comparison_9_pretty = comparison_9.rename(columns=_pretty_headers_9)
comparison_9_pretty.to_latex(
    'tables/gs_vs_match_comparison_wave09.tex',
    index=False,
    escape=True,
    float_format='%.3f',
)

# Optional: Wave 9 heatmap (men → women), written to disk rather than shown
P_m_9_labeled = (
    P_m_9.rename(index=id_to_label_9, columns=id_to_label_9)
    .sort_index()
    .sort_index(axis=1)
)
fig_9, ax_9 = plt.subplots(figsize=(9, 8))
sns.heatmap(P_m_9_labeled, cmap='viridis', cbar_kws={'label': 'p_hat (man says yes)'}, ax=ax_9)
ax_9.set_title(f'Wave {WAVE_9}: Men → Women predicted probabilities')
fig_9.tight_layout()
fig_9.savefig('images/prob_matrix_men_wave09.png', dpi=200)
plt.close(fig_9)

print('Also wrote Wave 9 tables (suffix _wave09) and images/prob_matrix_men_wave09.png')

# Package everything for students
import zipfile
import subprocess

# Build the PDFs so the package is self-contained. pdflatex runs twice per
# document so cross-references resolve on the second pass.
# NOTE(review): references.bib is packaged, but bibtex/biber is never run here;
# if main.tex cites from it, citations will stay unresolved — confirm.
build_dir = Path('.latex-interface/build')
build_dir.mkdir(parents=True, exist_ok=True)


def _build_pdf(tex_name, failure_message, n_passes=2, log_tail_lines=20):
    """Compile `tex_name` into build_dir with pdflatex (best-effort).

    On success prints the PDF path and returns True. If pdflatex is missing or
    exits non-zero, prints `failure_message`, the error and the last
    `log_tail_lines` lines of pdflatex's captured output (otherwise invisible
    because of capture_output), then returns False instead of raising.
    Output is decoded with errors='replace' so non-UTF-8 bytes in the TeX log
    cannot raise a decode error.
    """
    cmd = [
        'pdflatex',
        '-interaction=nonstopmode',
        '-halt-on-error',
        '-output-directory',
        str(build_dir),
        tex_name,
    ]
    try:
        for _ in range(n_passes):
            subprocess.run(cmd, check=True, capture_output=True, text=True, errors='replace')
    except (OSError, subprocess.SubprocessError) as e:
        print(failure_message)
        print(e)
        log = getattr(e, 'stdout', None)  # set for CalledProcessError, absent for OSError
        if log:
            print('\n'.join(log.splitlines()[-log_tail_lines:]))
        return False
    print('Built PDF:', build_dir / Path(tex_name).with_suffix('.pdf').name)
    return True


_build_pdf('main.tex', 'WARNING: PDF build failed; continuing without rebuilding PDF.')

# Build a standalone Gale–Shapley explainer PDF as well
_build_pdf(
    'gale_shapley_explained.tex',
    'WARNING: Gale–Shapley explainer PDF build failed; continuing without it.',
)

# Zip the student package. Source files keep their relative paths inside the
# archive. Built PDFs go to the archive root, because their build directory
# (.latex-interface/build) is a hidden folder that students may not see.
package_path = Path('student_package_speed_dating_gs.zip')
package_sources = [
    Path('Initial.ipynb'),  # NOTE(review): assumes this notebook is saved under this name
    Path('main.tex'),
    Path('gale_shapley_explained.tex'),
    Path('references.bib'),
    PROCESSED_DATA_PATH,
]
built_pdfs = [build_dir / 'main.pdf', build_dir / 'gale_shapley_explained.pdf']

with zipfile.ZipFile(package_path, 'w', compression=zipfile.ZIP_DEFLATED) as z:
    for p in package_sources:
        if p.exists():  # best-effort: skip anything missing from this checkout
            z.write(p, arcname=str(p))

    # Sorted so the archive's entry order does not depend on filesystem order.
    for p in sorted(Path('tables').glob('*')):
        if p.is_file():
            z.write(p, arcname=str(p))

    for p in sorted(Path('images').glob('*.png')):
        if p.is_file():
            z.write(p, arcname=str(p))

    for pdf in built_pdfs:
        if pdf.exists():  # absent when the PDF build above failed
            z.write(pdf, arcname=pdf.name)

print('Wrote student package:', package_path)

print('Wrote tables/*.tex, tables/*.csv and images/*.png')
