# PET Challenge 2025 — Approach Comparison

## Two Scoring Approaches

| Aspect | Conservation (MSA + Table S6) | ESM PLM (Zero-Shot) |
|--------|-------------------------------|---------------------|
| **Method** | Penalize mutations at conserved positions | ESM2 log-likelihood ratio per mutation |
| **Compute** | CPU-only, ~1 min (MAFFT) | GPU required, ~30 min (ESM2-650M) |
| **Biological prior** | Conservation = function | Evolutionary plausibility |
| **Notebook** | `Conservation_Scoring_Pipeline.ipynb` | `PET_Challenge_2025_Pipeline_v2.ipynb` |
| **Output** | `results/submission_conservation.csv` | `results/submission_zero_shot_v5.csv` |

## Evaluation Strategy

Since ground truth is unavailable before the competition deadline, we evaluate using:

1. **Inter-approach agreement** — Spearman/Kendall correlation, top-K overlap, rank-rank plots
2. **WT vs mutant separation** — Cohen's d, Mann-Whitney AUC (WT should score higher on average)
3. **IsPETase validation** — 10 single-point mutants with experimental delta-Tm (Son 2019)
4. **Score distribution quality** — entropy, spread, discrimination capacity
5. **Ensemble exploration** — alpha-sweep to find optimal blending

**Competition metric:** NDCG (rank-based) — so rank-order quality is what matters most.


In [None]:
import os
import numpy as np
import pandas as pd
from scipy import stats
from scipy.spatial.distance import jensenshannon
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# Global plot styling applied to every figure produced by this notebook.
matplotlib.rcParams.update({'font.size': 11, 'figure.dpi': 120})

# Paths — detect project root (works in Colab, VS Code, and local)
# Candidates are probed in order: the current working directory, a
# `pet-challenge-2025/` subdirectory of it, then the absolute Colab checkout
# path. For the latter two, the working directory is switched to the detected
# root so that relative paths resolve from the repo root.
if os.path.exists('data/petase_challenge_data'):
    PROJECT_ROOT = os.getcwd()
elif os.path.exists('pet-challenge-2025/data/petase_challenge_data'):
    PROJECT_ROOT = os.path.join(os.getcwd(), 'pet-challenge-2025')
    os.chdir(PROJECT_ROOT)
elif os.path.exists('/content/pet-challenge-2025/data/petase_challenge_data'):
    PROJECT_ROOT = '/content/pet-challenge-2025'
    os.chdir(PROJECT_ROOT)
else:
    # Fail early: every later cell depends on these directories.
    raise FileNotFoundError(
        'Cannot find data/petase_challenge_data/. '
        'Run this notebook from the repo root or clone first.'
    )
DATA_DIR = os.path.join(PROJECT_ROOT, 'data', 'petase_challenge_data')
RESULTS_DIR = os.path.join(PROJECT_ROOT, 'results')  # submissions are read from here; figures are saved here

# Inputs: the two submissions being compared, the wild-type table used to flag
# WT rows, and the feature/mutation tables used for the IsPETase validation.
CONS_CSV = os.path.join(RESULTS_DIR, 'submission_conservation.csv')
ESM_CSV  = os.path.join(RESULTS_DIR, 'submission_zero_shot_v5.csv')
WT_CSV   = os.path.join(DATA_DIR, 'pet-2025-wildtype-cds.csv')
TEST_CSV = os.path.join(DATA_DIR, 'predictive-pet-zero-shot-test-2025.csv')
FEAT_CSV = os.path.join(PROJECT_ROOT, 'data', 'features_matrix.csv')
MUT_CSV  = os.path.join(PROJECT_ROOT, 'data', 'mutations_dataset.csv')

# Target columns (short aliases)
# Maps a short key to the exact column name looked up in the submission frames.
TARGETS = {
    'act1': 'activity_1 (\u03bcmol [TPA]/min\u00b7mg [E])',
    'act2': 'activity_2 (\u03bcmol [TPA]/min\u00b7mg [E])',
    'expr': 'expression (mg/mL)',
}
# (lo, hi) per target: used for the diagonal reference line in the scatter
# plots and as the output range when ensemble ranks are rescaled.
TARGET_RANGES = {'act1': (0, 5), 'act2': (0, 5), 'expr': (0, 3)}
# Human-readable names used in plot titles and printed tables.
TARGET_LABELS = {'act1': 'Activity 1 (pH 5.5)', 'act2': 'Activity 2 (pH 9.0)', 'expr': 'Expression'}

print("Approach Comparison Notebook — PET Challenge 2025")
print(f"Project root: {PROJECT_ROOT}")

In [None]:
# --- File existence checks ---
# The conservation submission is mandatory; the ESM submission is optional and
# gates every comparison cell further down via `esm_available`.
cons_available = os.path.exists(CONS_CSV)
esm_available  = os.path.exists(ESM_CSV)
print(f"Conservation submission: {'FOUND' if cons_available else 'MISSING'} — {CONS_CSV}")
print(f"ESM submission:          {'FOUND' if esm_available else 'MISSING'} — {ESM_CSV}")

if not cons_available:
    raise FileNotFoundError("Conservation submission is required. Run Conservation_Scoring_Pipeline first.")

# --- Load submissions ---
cons = pd.read_csv(CONS_CSV)
if cons.empty:
    # An empty frame would otherwise surface later as a division by zero in
    # the WT-fraction print and as empty plots.
    raise ValueError(f"Conservation submission has no rows: {CONS_CSV}")
print(f"\nConservation: {len(cons)} sequences, columns: {list(cons.columns)}")

if esm_available:
    esm = pd.read_csv(ESM_CSV)
    print(f"ESM:          {len(esm)} sequences, columns: {list(esm.columns)}")
    # Verify sequence alignment. Every later cell compares the two frames
    # positionally, so a mismatch must stop the run. This is an explicit raise
    # rather than an `assert` because asserts are removed under `python -O`.
    if list(cons['sequence']) != list(esm['sequence']):
        raise ValueError("Sequence order mismatch between submissions!")
    print("Sequence order: MATCHED")
else:
    esm = None
    print("\nESM submission not found — comparison cells will show conservation-only analysis.")
    print("To enable full comparison, copy submission_zero_shot_v5.csv from Colab to results/")

# --- Load WT sequences for WT/mutant classification ---
wt_df = pd.read_csv(WT_CSV)
wt_set = set(wt_df['Wt AA Sequence'].values)
print(f"\nWT scaffolds loaded: {len(wt_set)}")

# Boolean mask over the test rows: True where the sequence is exactly one of
# the wild-type scaffolds, False for everything else (treated as mutant).
is_wt = cons['sequence'].isin(wt_set).values
n_wt = is_wt.sum()
n_mut = (~is_wt).sum()
print(f"Test set: {n_wt} WT sequences, {n_mut} mutant sequences")
print(f"WT fraction: {n_wt/len(cons)*100:.1f}%")


## 1. Submission-Level Comparison

Direct comparison of the two approaches on the 4988 test sequences.


In [None]:
# Per-target scatter of Conservation vs ESM scores, with WT and mutant rows
# drawn as separate layers, plus rank-correlation statistics in the title.
if not esm_available:
    print("SKIPPING — ESM submission not available.")
else:
    fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
    for idx, (key, col) in enumerate(TARGETS.items()):
        ax = axes[idx]
        c_scores = cons[col].values
        e_scores = esm[col].values

        # Color by WT/mutant. Mutants are drawn first so the sparser WT points
        # stay visible on top.
        ax.scatter(c_scores[~is_wt], e_scores[~is_wt], s=6, alpha=0.15, c='coral', label='Mutant')
        ax.scatter(c_scores[is_wt], e_scores[is_wt], s=12, alpha=0.4, c='steelblue', label='WT')

        # Diagonal reference spanning the target's nominal score range
        lo, hi = TARGET_RANGES[key]
        ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.3, lw=1)

        # Statistics: only the coefficients are shown, so p-values are discarded
        rho, _ = stats.spearmanr(c_scores, e_scores)
        tau, _ = stats.kendalltau(c_scores, e_scores)
        ax.set_title(f"{TARGET_LABELS[key]}\nSpearman $\\rho$={rho:.3f}, Kendall $\\tau$={tau:.3f}", fontsize=11)
        ax.set_xlabel('Conservation score')
        ax.set_ylabel('ESM score')
        ax.legend(fontsize=8, loc='lower right')

    plt.suptitle('ESM vs Conservation — Per-Target Score Comparison', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, 'comparison_scatter.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: results/comparison_scatter.png")


In [None]:
# Top-K overlap between the two approaches, per target: how many of the K
# best-scored sequences are shared, and the Jaccard index of the two sets.
if not esm_available:
    print("SKIPPING — ESM submission not available.")
else:
    K_values = [20, 50, 100, 200, 500]
    fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

    for idx, (key, col) in enumerate(TARGETS.items()):
        ax = axes[idx]
        # Row indices ordered from highest to lowest score. Slicing this order
        # yields exactly K rows even when many scores are tied; thresholding
        # average ranks (`rankdata(...) <= K`) can return more or fewer than K
        # rows under ties, which breaks the "K - shared" bar and the Jaccard
        # index. The stable sort breaks ties by row order for both approaches.
        c_order = np.argsort(-cons[col].values, kind='stable')
        e_order = np.argsort(-esm[col].values, kind='stable')

        shared_counts = []
        unique_counts = []
        jaccard_vals = []
        for K in K_values:
            c_topk = set(c_order[:K].tolist())
            e_topk = set(e_order[:K].tolist())
            overlap = len(c_topk & e_topk)
            union = len(c_topk | e_topk)
            jaccard = overlap / union if union > 0 else 0
            shared_counts.append(overlap)
            # Both sets have the same size (min(K, n_rows)), so this count
            # applies to each approach individually.
            unique_counts.append(len(c_topk) - overlap)
            jaccard_vals.append(jaccard)

        x = np.arange(len(K_values))
        width = 0.35
        ax.bar(x - width/2, shared_counts, width, color='steelblue', edgecolor='white', label='Shared')
        ax.bar(x + width/2, unique_counts, width,
               color='coral', edgecolor='white', label='Unique (each)')

        # Annotate Jaccard just below y = K for each group of bars
        for i, (j, k) in enumerate(zip(jaccard_vals, K_values)):
            ax.text(i, k * 0.95, f'J={j:.2f}', ha='center', va='top', fontsize=8, fontweight='bold')

        ax.set_xticks(x)
        ax.set_xticklabels([f'Top {K}' for K in K_values], fontsize=9)
        ax.set_ylabel('Number of sequences')
        ax.set_title(TARGET_LABELS[key], fontsize=11)
        ax.legend(fontsize=8, loc='upper left')

    plt.suptitle('Top-K Overlap Between Approaches', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, 'comparison_topk_overlap.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: results/comparison_topk_overlap.png")


In [None]:
def cohens_d(group1, group2):
    """Return Cohen's d (mean of group1 minus mean of group2, in pooled-SD units).

    The pooled standard deviation is built from the sample variances
    (``ddof=1``) weighted by their degrees of freedom. When the pooled
    standard deviation is not strictly positive, 0 is returned.
    """
    size_a = len(group1)
    size_b = len(group2)
    sample_var_a = np.var(group1, ddof=1)
    sample_var_b = np.var(group2, ddof=1)
    weighted_var_sum = (size_a - 1) * sample_var_a + (size_b - 1) * sample_var_b
    pooled_std = np.sqrt(weighted_var_sum / (size_a + size_b - 2))
    if pooled_std > 0:
        return (np.mean(group1) - np.mean(group2)) / pooled_std
    return 0

def mann_whitney_auc(group1, group2):
    """Return the Mann-Whitney U statistic of group1 scaled to an AUC.

    U is divided by the number of (group1, group2) pairs, giving the
    fraction of pairs in which the group1 value is the larger one
    (the notebook reads this as P(WT > Mutant)).
    """
    outcome = stats.mannwhitneyu(group1, group2, alternative='greater')
    pair_count = len(group1) * len(group2)
    return outcome.statistic / pair_count

# WT vs mutant score separation: one row of histograms per available approach,
# one column per target, with Cohen's d and Mann-Whitney AUC in each title.
approaches = [('Conservation', cons)] + ([('ESM', esm)] if esm_available else [])

n_rows = len(approaches)
# squeeze=False keeps `axes` two-dimensional even when only one approach is present.
fig, axes = plt.subplots(n_rows, 3, figsize=(16, 4.5 * n_rows), squeeze=False)

# (WT colour, mutant colour) per approach
colors_map = {'Conservation': ('steelblue', 'coral'), 'ESM': ('forestgreen', 'darkorange')}
sep_stats = []

for row_idx, (name, df) in enumerate(approaches):
    c_wt, c_mut = colors_map[name]
    for col_idx, (key, col) in enumerate(TARGETS.items()):
        ax = axes[row_idx, col_idx]
        column_scores = df[col].values
        wt_scores = column_scores[is_wt]
        mut_scores = column_scores[~is_wt]

        for values, shade, opacity, tag in ((wt_scores, c_wt, 0.7, 'WT'),
                                            (mut_scores, c_mut, 0.5, 'Mutant')):
            ax.hist(values, bins=30, alpha=opacity, color=shade, edgecolor='white',
                    label=tag, density=True)

        d = cohens_d(wt_scores, mut_scores)
        auc = mann_whitney_auc(wt_scores, mut_scores)
        sep_stats.append({
            'Approach': name,
            'Target': TARGET_LABELS[key],
            'Cohen_d': d,
            'MW_AUC': auc,
            'WT_mean': np.mean(wt_scores),
            'Mut_mean': np.mean(mut_scores),
        })

        ax.set_title(f"{name} — {TARGET_LABELS[key]}\nCohen's d={d:.2f}, AUC={auc:.3f}", fontsize=10)
        ax.set_xlabel('Score')
        ax.set_ylabel('Density')
        ax.legend(fontsize=8)

plt.suptitle('WT vs Mutant Score Separation', fontsize=13, fontweight='bold')
plt.tight_layout()
separation_png = os.path.join(RESULTS_DIR, 'comparison_wt_separation.png')
plt.savefig(separation_png, dpi=150, bbox_inches='tight')
plt.show()

# Kept as a frame because the summary cell aggregates Cohen's d from it.
sep_df = pd.DataFrame(sep_stats)
print("\nWT/Mutant Separation Statistics:")
print(sep_df.to_string(index=False, float_format='%.3f'))
print("\nSaved: results/comparison_wt_separation.png")


## 2. IsPETase Validation

The 10 single-point IsPETase mutants with experimentally measured delta-Tm (Son 2019) provide a small but direct validation set. **Important:** IsPETase is *not* one of the 313 challenge scaffolds, so these mutants are not in the 4988 test set.

- **Conservation:** Score each mutant directly via Table S6 frequencies
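- **ESM (proxy):** ESM is not scored on these variants in this notebook. As a stand-in, a Ridge regression is evaluated with leave-one-out cross-validation (LOOCV) on the 31-variant `features_matrix.csv`, predicting Tm. Treat the result as an estimate of feature-based ML prediction quality, not as a direct ESM measurement.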


In [None]:
# Table S6 from Buchholz et al. (Proteins, 2022)
# {ispetase_position: (consensus_aa, conservation_frequency_pct)}
# Keys are IsPETase residue positions; the value pairs the consensus amino acid
# with its conservation frequency in percent. Positions that are absent from
# this dict are scored as "not conserved" by the scoring code below.
TABLE_S6 = {
    32: ("Y", 74), 34: ("R", 84), 35: ("G", 91), 36: ("P", 92),
    38: ("P", 95), 39: ("T", 87), 42: ("S", 73), 45: ("A", 87),
    48: ("G", 97), 49: ("P", 71), 57: ("V", 91), 62: ("G", 93),
    63: ("F", 93), 64: ("G", 83), 65: ("G", 86), 66: ("G", 93),
    67: ("T", 76), 68: ("I", 79), 69: ("Y", 84), 70: ("Y", 91),
    71: ("P", 98), 72: ("T", 85), 74: ("T", 81), 76: ("G", 90),
    77: ("T", 84), 78: ("F", 74), 79: ("G", 90), 80: ("A", 77),
    85: ("P", 99), 86: ("G", 100), 88: ("T", 76), 92: ("S", 70),
    96: ("W", 93), 98: ("G", 89), 99: ("P", 82), 100: ("R", 81),
    101: ("L", 81), 102: ("A", 97), 103: ("S", 96), 105: ("G", 99),
    106: ("F", 97), 107: ("V", 96), 108: ("V", 94), 111: ("I", 84),
    113: ("T", 96), 118: ("D", 98), 120: ("P", 87), 122: ("S", 71),
    123: ("R", 99), 124: ("G", 73), 126: ("Q", 92), 127: ("L", 83),
    128: ("L", 78), 129: ("A", 88), 130: ("A", 96), 131: ("L", 88),
    132: ("D", 82), 133: ("Y", 77), 134: ("L", 85), 138: ("S", 83),
    145: ("V", 82), 146: ("R", 71), 148: ("R", 81), 150: ("D", 94),
    153: ("R", 94), 154: ("L", 85), 156: ("V", 89), 158: ("G", 100),
    159: ("H", 87), 160: ("S", 100), 161: ("M", 94), 162: ("G", 100),
    163: ("G", 99), 164: ("G", 96), 165: ("G", 97), 167: ("L", 88),
    169: ("A", 89), 170: ("A", 82), 173: ("R", 76), 174: ("P", 76),
    176: ("L", 84), 178: ("A", 95), 179: ("A", 78), 181: ("P", 80),
    182: ("L", 79), 184: ("P", 76), 185: ("W", 77), 197: ("P", 97),
    198: ("T", 93), 202: ("G", 75), 206: ("D", 100), 209: ("A", 87),
    211: ("V", 70), 214: ("H", 77), 217: ("P", 79), 218: ("F", 74),
    219: ("Y", 96), 221: ("S", 70), 228: ("A", 77), 229: ("Y", 83),
    231: ("E", 91), 232: ("L", 76), 235: ("A", 76), 237: ("H", 100),
    240: ("P", 74), 244: ("N", 74), 257: ("W", 90), 258: ("L", 80),
    259: ("K", 94), 260: ("R", 78), 261: ("F", 79), 263: ("D", 94),
    265: ("D", 97), 266: ("T", 76), 267: ("R", 96), 268: ("Y", 92),
    270: ("Q", 77), 271: ("F", 96), 272: ("L", 86), 273: ("C", 95),
    274: ("P", 82),
}

# Load mutations dataset — filter single-point IsPETase mutants
# `mut_df` is read below for its 'variant_name', 'mutation' and 'delta_tm'
# columns; `features_df` supplies 'variant_name' and 'N_mutations'.
mut_df = pd.read_csv(MUT_CSV)
features_df = pd.read_csv(FEAT_CSV)

# Single-point mutants: N_mutations == 1 in features_matrix
# The two tables are joined on 'variant_name' via an isin() filter.
single_mask = features_df['N_mutations'] == 1
single_variants = features_df.loc[single_mask, 'variant_name'].values
single_df = mut_df[mut_df['variant_name'].isin(single_variants)].copy()
print(f"Single-point IsPETase mutants: {len(single_df)}")
print(f"Variants: {list(single_df['variant_name'].values)}")

# Parse mutation: e.g. "D186H" -> (pos=186, wt_aa='D', mut_aa='H')
def parse_single_mutation(mut_str):
    """Split a single-point mutation label into its components.

    Example: ``"D186H"`` -> ``(186, 'D', 'H')``. Surrounding whitespace (as
    often left by CSV exports) is ignored.

    Args:
        mut_str: Label of the form <wild-type residue><position><mutant residue>.

    Returns:
        Tuple ``(pos, wt_aa, mut_aa)`` with ``pos`` as an int.

    Raises:
        TypeError: If ``mut_str`` is not a string (e.g. a NaN from a blank cell).
        ValueError: If the label is too short or its middle part is not an
            integer, as happens for multi-mutation labels like "S121E/D186H".
    """
    if not isinstance(mut_str, str):
        raise TypeError(f"Mutation label must be a string, got {type(mut_str).__name__}")
    label = mut_str.strip()
    # Shortest valid label is one residue, one digit, one residue.
    if len(label) < 3:
        raise ValueError(f"Malformed single-point mutation: {mut_str!r}")
    try:
        pos = int(label[1:-1])
    except ValueError as err:
        raise ValueError(f"Malformed single-point mutation: {mut_str!r}") from err
    return pos, label[0], label[-1]

# Conservation score for each mutant
def _table_s6_score(mutation):
    """Score one single-point mutation from the Table S6 frequencies.

    A residue equal to the position's consensus gets the tabulated frequency;
    any other residue gets an equal share of the remainder (spread over the
    19 other amino acids). The score is mutant frequency minus WT frequency,
    so moving away from the consensus is negative. Positions missing from
    Table S6 (not conserved, < 70%) score 0.0.
    """
    site, native_aa, variant_aa = parse_single_mutation(mutation)
    if site not in TABLE_S6:
        return 0.0
    consensus_aa, freq_pct = TABLE_S6[site]
    wt_freq = freq_pct / 100.0 if native_aa == consensus_aa else (100 - freq_pct) / (19 * 100.0)
    mut_freq = freq_pct / 100.0 if variant_aa == consensus_aa else (100 - freq_pct) / (19 * 100.0)
    return -(wt_freq - mut_freq)


cons_scores = [_table_s6_score(label) for label in single_df['mutation']]

single_df = single_df.copy()
single_df['cons_score'] = cons_scores
single_df['delta_tm'] = single_df['delta_tm'].astype(float)

# Print table
print("\n--- IsPETase Single-Point Mutant Validation ---")
print(f"{'Mutation':<12} {'Pos':>4} {'In S6?':>6} {'Cons Score':>10} {'delta-Tm':>9}")
print("-" * 50)
table_rows = zip(single_df['mutation'], single_df['cons_score'], single_df['delta_tm'])
for label, cons_value, delta_value in table_rows:
    site = parse_single_mutation(label)[0]
    in_s6 = 'Yes' if site in TABLE_S6 else 'No'
    print(f"{label:<12} {site:>4} {in_s6:>6} {cons_value:>10.4f} {delta_value:>9.2f}")

# Spearman correlation
rho, pval = stats.spearmanr(single_df['cons_score'], single_df['delta_tm'])
print(f"\nSpearman rho = {rho:.3f}, p = {pval:.4f}")
print(f"(Positive rho means: higher conservation score ~ higher delta-Tm — correct direction)")


In [None]:
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import LeaveOneOut

# --- ML LOOCV on features_matrix (proxy for ESM prediction quality) ---
# Every column other than the identifier and the target is used as a feature.
feat_df = pd.read_csv(FEAT_CSV)
non_feature_cols = ('variant_name', 'Tm')
feature_cols = [column for column in feat_df.columns if column not in non_feature_cols]
X = feat_df[feature_cols].values
y = feat_df['Tm'].values
names = feat_df['variant_name'].values

# One prediction per variant, each made by a model that never saw that variant.
# The scaler is refit on the training rows of every fold.
y_pred_loocv = np.zeros(len(y))
for fit_rows, held_out in LeaveOneOut().split(X):
    fold_scaler = StandardScaler().fit(X[fit_rows])
    fold_model = Ridge(alpha=10.0)
    fold_model.fit(fold_scaler.transform(X[fit_rows]), y[fit_rows])
    y_pred_loocv[held_out] = fold_model.predict(fold_scaler.transform(X[held_out]))

ml_rho, ml_pval = stats.spearmanr(y, y_pred_loocv)
residuals = y - y_pred_loocv
ml_rmse = np.sqrt(np.mean(residuals ** 2))
print(f"ML LOOCV (Ridge alpha=10, {len(y)} variants):")
print(f"  Spearman rho = {ml_rho:.3f}, p = {ml_pval:.6f}")
print(f"  RMSE = {ml_rmse:.2f} C")

# --- Side-by-side validation plot ---
# Left panel: conservation score against measured delta-Tm for the single-point
# mutants. Right panel: LOOCV-predicted Tm against measured Tm for every row
# of the features matrix.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Conservation score vs delta-Tm (single-point mutants)
ax1.scatter(single_df['cons_score'], single_df['delta_tm'],
            s=80, c='steelblue', edgecolors='white', zorder=5, linewidths=0.5)
# Label every point with its mutation string.
for _, row in single_df.iterrows():
    ax1.annotate(row['mutation'], (row['cons_score'], row['delta_tm']),
                 fontsize=8, xytext=(6, 4), textcoords='offset points')

# Recomputed here under distinct names; the summary cell reads `rho_cons`.
rho_cons, p_cons = stats.spearmanr(single_df['cons_score'], single_df['delta_tm'])
ax1.set_xlabel('Conservation Score')
ax1.set_ylabel('$\\Delta T_m$ ($\\degree$C)')
ax1.set_title(f'Conservation vs $\\Delta T_m$\nSpearman $\\rho$ = {rho_cons:.3f} (p = {p_cons:.3f})', fontsize=11)
# Zero lines split the panel into quadrants by the sign of score and delta-Tm.
ax1.axhline(0, ls='--', color='gray', alpha=0.4)
ax1.axvline(0, ls='--', color='gray', alpha=0.4)
# Proxy handle: the legend only reports the sample size.
legend1 = [Line2D([0], [0], marker='o', color='w', markerfacecolor='steelblue',
                  markersize=8, label=f'n = {len(single_df)} mutants')]
ax1.legend(handles=legend1, fontsize=8, loc='lower right')

# Right: ML predicted Tm vs actual Tm (LOOCV, all 31 variants)
ax2.scatter(y, y_pred_loocv, s=60, c='forestgreen', edgecolors='white', zorder=5, linewidths=0.5)
# Shared limits (padded by 2) on both axes so the identity line runs corner to corner.
lims = [min(y.min(), y_pred_loocv.min()) - 2, max(y.max(), y_pred_loocv.max()) + 2]
ax2.plot(lims, lims, 'r--', alpha=0.5, lw=1.5, label='Perfect prediction')
ax2.set_xlim(lims)
ax2.set_ylim(lims)
ax2.set_xlabel('Actual $T_m$ ($\\degree$C)')
ax2.set_ylabel('Predicted $T_m$ ($\\degree$C, LOOCV)')
ax2.set_title(f'ML LOOCV: Ridge on Features Matrix\n$\\rho$ = {ml_rho:.3f}, RMSE = {ml_rmse:.1f}$\\degree$C', fontsize=11)
# Explicit handles are passed, so the 'Perfect prediction' label above is not shown.
legend2 = [Line2D([0], [0], marker='o', color='w', markerfacecolor='forestgreen',
                  markersize=8, label=f'n = {len(y)} variants'),
           Line2D([0], [0], ls='--', color='r', alpha=0.5, label='y = x')]
ax2.legend(handles=legend2, fontsize=8, loc='lower right')

plt.suptitle('IsPETase Validation — Conservation (direct) vs ML Feature Proxy', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'comparison_ispetase_validation.png'), dpi=150, bbox_inches='tight')
plt.show()
print("Saved: results/comparison_ispetase_validation.png")


## 3. Score Distribution Analysis

Compare how each approach distributes scores across the 4988 test sequences.
Good ranking requires spread (not all scores clustered together) and discrimination
(different sequences get distinguishably different scores).


In [None]:
# Score-distribution comparison. Top row: overlaid score histograms per target.
# Bottom row: rank-percentile agreement between the two approaches (only when
# the ESM submission is available). Per-approach spread statistics are
# collected into `ent_df`, which the summary cell reads.
approaches_avail = [('Conservation', cons, 'steelblue')]
if esm_available:
    approaches_avail.append(('ESM', esm, 'forestgreen'))

fig, axes = plt.subplots(2, 3, figsize=(16, 9))

entropy_stats = []
for col_idx, (key, col) in enumerate(TARGETS.items()):
    ax_top = axes[0, col_idx]
    ax_bot = axes[1, col_idx]

    for name, df, color in approaches_avail:
        scores = df[col].values

        # Top row: overlaid histograms / KDE
        ax_top.hist(scores, bins=50, alpha=0.5, color=color, edgecolor='white',
                    density=True, label=name)

        # Ranking entropy: how uniformly spread are the ranks?
        # NOTE(review): with scipy's default tie handling, tied scores share one
        # averaged rank, so heavy ties should pile into fewer bins and lower
        # this value — confirm that this is the intended reading.
        ranks = stats.rankdata(scores)
        norm_ranks = ranks / len(ranks)
        # Bin into 50 bins and compute entropy
        hist_counts, _ = np.histogram(norm_ranks, bins=50, range=(0, 1))
        hist_probs = hist_counts / hist_counts.sum()
        # Empty bins are dropped before the entropy is computed.
        hist_probs = hist_probs[hist_probs > 0]
        rank_entropy = stats.entropy(hist_probs) / np.log(50)  # normalize to [0, 1]
        n_unique = len(np.unique(scores))  # number of distinct score values
        entropy_stats.append({'Approach': name, 'Target': TARGET_LABELS[key],
                              'Rank Entropy': rank_entropy, 'Unique Scores': n_unique,
                              'Std': np.std(scores), 'IQR': np.percentile(scores, 75) - np.percentile(scores, 25)})

    ax_top.set_title(TARGET_LABELS[key], fontsize=11)
    ax_top.set_xlabel('Score')
    ax_top.set_ylabel('Density')
    ax_top.legend(fontsize=8)

    # Bottom row: rank-rank percentile plots
    if esm_available:
        # Ascending ranks scaled to percent; the highest score maps to 100.
        c_pctile = stats.rankdata(cons[col].values) / len(cons) * 100
        e_pctile = stats.rankdata(esm[col].values) / len(esm) * 100
        ax_bot.scatter(c_pctile, e_pctile, s=3, alpha=0.15, c='gray')
        ax_bot.plot([0, 100], [0, 100], 'r--', alpha=0.5, lw=1)
        rho_rr, _ = stats.spearmanr(c_pctile, e_pctile)
        ax_bot.set_title(f'Rank-Rank ($\\rho$ = {rho_rr:.3f})', fontsize=10)
        ax_bot.set_xlabel('Conservation percentile')
        ax_bot.set_ylabel('ESM percentile')
        # Proxy handles: the identity line and the sample size.
        legend_rr = [Line2D([0], [0], ls='--', color='r', alpha=0.5, label='y = x'),
                     Line2D([0], [0], marker='o', color='w', markerfacecolor='gray',
                            markersize=5, label=f'n = {len(cons)}')]
        ax_bot.legend(handles=legend_rr, fontsize=8, loc='lower right')
    else:
        # Without ESM there is nothing to compare; show a placeholder panel.
        ax_bot.text(0.5, 0.5, 'ESM not available', ha='center', va='center',
                    transform=ax_bot.transAxes, fontsize=12, color='gray')
        ax_bot.set_axis_off()

plt.suptitle('Score Distributions and Rank Agreement', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'comparison_distributions.png'), dpi=150, bbox_inches='tight')
plt.show()

ent_df = pd.DataFrame(entropy_stats)
print("\nDistribution Statistics:")
print(ent_df.to_string(index=False, float_format='%.4f'))
print("\nRank entropy = 1.0 means perfectly uniform rank distribution (ideal).")
print("Saved: results/comparison_distributions.png")


## 4. Ensemble Exploration

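Blend Conservation and ESM scores on normalized ranks (each rank is divided by the number of sequences, so both inputs lie in (0, 1]):

$$\text{ensemble} = \alpha \cdot \text{rank}_{\text{Conservation}} + (1 - \alpha) \cdot \text{rank}_{\text{ESM}}$$

Sweep $\alpha$ from 0 (pure ESM) to 1 (pure Conservation) in 51 steps.
Evaluate WT/mutant separation (Cohen's d) and ranking entropy at each $\alpha$; the "optimal" $\alpha$ reported per target is the one that maximizes |Cohen's d|.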


In [None]:
# Alpha sweep over rank-blended ensembles. For each target, blends the
# normalized ranks of the two approaches at 51 weights and records WT/mutant
# separation (Cohen's d) and rank entropy. Defines `optimal_alphas`, which the
# ensemble-generation and summary cells read.
if not esm_available:
    print("SKIPPING — ESM submission not available for ensemble exploration.")
else:
    alphas = np.linspace(0, 1, 51)

    fig, axes = plt.subplots(2, 3, figsize=(16, 9))
    optimal_alphas = {}

    for col_idx, (key, col) in enumerate(TARGETS.items()):
        # Ascending ranks divided by N, putting both approaches on a common scale.
        c_ranks = stats.rankdata(cons[col].values)
        e_ranks = stats.rankdata(esm[col].values)
        c_norm = c_ranks / len(c_ranks)
        e_norm = e_ranks / len(e_ranks)

        d_vals = []
        ent_vals = []
        for alpha in alphas:
            # alpha weights Conservation; (1 - alpha) weights ESM.
            blend = alpha * c_norm + (1 - alpha) * e_norm
            wt_blend = blend[is_wt]
            mut_blend = blend[~is_wt]
            d = cohens_d(wt_blend, mut_blend)
            d_vals.append(d)

            # Rank entropy of blended scores
            # Same recipe as the distribution-analysis cell: 50 bins, log(50) normalization.
            blend_ranks = stats.rankdata(blend)
            norm_br = blend_ranks / len(blend_ranks)
            hist_c, _ = np.histogram(norm_br, bins=50, range=(0, 1))
            hist_p = hist_c / hist_c.sum()
            hist_p = hist_p[hist_p > 0]
            ent_vals.append(stats.entropy(hist_p) / np.log(50))

        d_vals = np.array(d_vals)
        ent_vals = np.array(ent_vals)

        # Optimal alpha: maximize |Cohen's d|
        # NOTE(review): the absolute value also rewards blends where WT scores
        # *lower* than mutants (negative d); confirm whether a signed maximum
        # was intended.
        best_idx = np.argmax(np.abs(d_vals))
        optimal_alphas[key] = alphas[best_idx]

        # Top row: Cohen's d vs alpha
        ax_d = axes[0, col_idx]
        ax_d.plot(alphas, d_vals, 'b-', lw=2, label="Cohen's d")
        ax_d.axvline(alphas[best_idx], ls='--', color='red', alpha=0.7,
                     label=f'Optimal $\\alpha$={alphas[best_idx]:.2f}')
        ax_d.set_xlabel('$\\alpha$ (Conservation weight)')
        ax_d.set_ylabel("Cohen's d (WT vs Mutant)")
        ax_d.set_title(TARGET_LABELS[key], fontsize=11)
        ax_d.legend(fontsize=8)

        # Bottom row: Rank entropy vs alpha
        # The same optimal-alpha marker is repeated here for cross-reference.
        ax_e = axes[1, col_idx]
        ax_e.plot(alphas, ent_vals, 'g-', lw=2, label='Rank entropy')
        ax_e.axvline(alphas[best_idx], ls='--', color='red', alpha=0.7,
                     label=f'Optimal $\\alpha$={alphas[best_idx]:.2f}')
        ax_e.set_xlabel('$\\alpha$ (Conservation weight)')
        ax_e.set_ylabel('Normalized rank entropy')
        ax_e.set_title(TARGET_LABELS[key], fontsize=11)
        ax_e.legend(fontsize=8)

    plt.suptitle('Ensemble Alpha Sweep — WT/Mutant Separation vs Ranking Quality',
                 fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, 'comparison_ensemble_sweep.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print("\nOptimal alpha per target (maximizes |Cohen's d|):")
    for key, alpha in optimal_alphas.items():
        print(f"  {TARGET_LABELS[key]}: alpha = {alpha:.2f}")
    print("\nSaved: results/comparison_ensemble_sweep.png")


In [None]:
# Build two rank-blended ensembles (equal weights, and the per-target weight
# found by the alpha sweep) and print the Spearman correlation between every
# pair of approaches, per target. Requires `optimal_alphas` from the sweep cell.
if not esm_available:
    print("SKIPPING — ESM submission not available.")
else:
    def _rescale_by_rank(blend, lo, hi):
        """Spread blended scores evenly over [lo, hi] by rank, keeping their order."""
        span = len(blend) - 1
        if span == 0:
            # A single row has no rank spread; avoid the 0/0 division.
            return np.full(len(blend), float(lo))
        return lo + (stats.rankdata(blend) - 1) / span * (hi - lo)

    # Generate ensembles: copies of the conservation frame whose target
    # columns are overwritten below.
    ensemble_50 = cons.copy()
    ensemble_opt = cons.copy()

    corr_data = {}
    for key, col in TARGETS.items():
        # Normalized ascending ranks put both approaches on a common scale.
        c_ranks = stats.rankdata(cons[col].values)
        e_ranks = stats.rankdata(esm[col].values)
        c_norm = c_ranks / len(c_ranks)
        e_norm = e_ranks / len(e_ranks)
        lo, hi = TARGET_RANGES[key]

        # alpha = 0.5 ensemble
        blend_50 = 0.5 * c_norm + 0.5 * e_norm
        ensemble_50[col] = _rescale_by_rank(blend_50, lo, hi)

        # Per-target optimal alpha
        alpha_opt = optimal_alphas[key]
        blend_opt = alpha_opt * c_norm + (1 - alpha_opt) * e_norm
        ensemble_opt[col] = _rescale_by_rank(blend_opt, lo, hi)

        # Spearman between all pairs. The optimal-alpha ensemble is included
        # so that it is actually reported rather than computed and dropped.
        corr_data[key] = {
            'Conservation': cons[col].values,
            'ESM': esm[col].values,
            'Ensemble_50': ensemble_50[col].values,
            'Ensemble_opt': ensemble_opt[col].values,
        }

    # Print correlation matrix per target
    print("Inter-Approach Spearman Correlation:\n")
    for key in TARGETS:
        data = corr_data[key]
        names_c = list(data.keys())
        n = len(names_c)
        # Symmetric matrix with a unit diagonal; only the upper triangle is computed.
        matrix = np.ones((n, n))
        for i in range(n):
            for j in range(i + 1, n):
                rho, _ = stats.spearmanr(data[names_c[i]], data[names_c[j]])
                matrix[i, j] = rho
                matrix[j, i] = rho
        corr_df = pd.DataFrame(matrix, index=names_c, columns=names_c)
        print(f"--- {TARGET_LABELS[key]} ---")
        print(corr_df.to_string(float_format='%.3f'))
        print()


In [None]:
# Cross-target consistency: for each pair of targets, scatter one target's
# scores against the other's for every available approach and annotate the
# Spearman correlation.
pairs = [('act1', 'act2'), ('act1', 'expr'), ('act2', 'expr')]
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

# (legend label, annotation prefix, submission frame, colour)
series = [('Conservation', 'Cons', cons, 'steelblue')]
if esm_available:
    series.append(('ESM', 'ESM', esm, 'forestgreen'))

for ax, (k1, k2) in zip(axes, pairs):
    col1, col2 = TARGETS[k1], TARGETS[k2]

    note_lines = []
    legend_elems = []
    for label, prefix, frame, colour in series:
        xs, ys = frame[col1].values, frame[col2].values
        ax.scatter(xs, ys, s=4, alpha=0.12, c=colour, label=label)

        # Spearman annotations
        rho_pair, _ = stats.spearmanr(xs, ys)
        note_lines.append(f"{prefix} $\\rho$={rho_pair:.3f}")

        # Opaque proxy marker, since the scatter points are nearly transparent.
        legend_elems.append(Line2D([0], [0], marker='o', color='w', markerfacecolor=colour,
                                   markersize=6, label=label))

    ax.text(0.03, 0.97, "\n".join(note_lines), transform=ax.transAxes, fontsize=9, va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    ax.set_xlabel(TARGET_LABELS[k1])
    ax.set_ylabel(TARGET_LABELS[k2])
    ax.set_title(f'{TARGET_LABELS[k1]} vs {TARGET_LABELS[k2]}', fontsize=10)
    ax.legend(handles=legend_elems, fontsize=8, loc='lower right')

plt.suptitle('Cross-Target Consistency', fontsize=13, fontweight='bold')
plt.tight_layout()
cross_target_png = os.path.join(RESULTS_DIR, 'comparison_cross_target.png')
plt.savefig(cross_target_png, dpi=150, bbox_inches='tight')
plt.show()
print("Saved: results/comparison_cross_target.png")


## 5. Summary and Recommendation

Aggregate all metrics into a winner-per-criterion table.


In [None]:
# Aggregate the metrics computed by the earlier cells into a
# winner-per-criterion table and print a recommendation. Reads `rho_cons`,
# `ml_rho`, `single_df`, `y`, `sep_df`, `ent_df` and (when ESM is available)
# `optimal_alphas` from previous cells.
def _higher_is_better(cons_value, esm_value, esm_label='ESM'):
    """Name the approach with the larger metric.

    Returns 'Tie/N/A' when the values are equal or either one is NaN, so an
    undecidable criterion is never silently awarded to one side.
    """
    if np.isnan(cons_value) or np.isnan(esm_value) or cons_value == esm_value:
        return 'Tie/N/A'
    return 'Conservation' if cons_value > esm_value else esm_label


summary_rows = []

# Criterion 1: IsPETase validation (Conservation direct)
# Signed comparison: for both columns a positive rho is the correct direction,
# so a strongly negative correlation must not win on magnitude. The two rhos
# come from different tasks (delta-Tm of single mutants vs LOOCV Tm), so this
# row is indicative only.
summary_rows.append({
    'Criterion': 'IsPETase delta-Tm correlation',
    'Conservation': f'rho={rho_cons:.3f} (direct, n={len(single_df)})',
    'ESM (proxy)': f'rho={ml_rho:.3f} (LOOCV Ridge, n={len(y)})',
    'Winner': _higher_is_better(rho_cons, ml_rho, esm_label='ESM proxy')
})

# Criterion 2: WT/mutant separation (average Cohen's d)
cons_d_avg = sep_df[sep_df['Approach'] == 'Conservation']['Cohen_d'].abs().mean()
if esm_available:
    esm_d_avg = sep_df[sep_df['Approach'] == 'ESM']['Cohen_d'].abs().mean()
    summary_rows.append({
        'Criterion': 'WT/Mutant separation (avg |d|)',
        'Conservation': f'{cons_d_avg:.3f}',
        'ESM (proxy)': f'{esm_d_avg:.3f}',
        'Winner': _higher_is_better(cons_d_avg, esm_d_avg)
    })
else:
    summary_rows.append({
        'Criterion': 'WT/Mutant separation (avg |d|)',
        'Conservation': f'{cons_d_avg:.3f}',
        'ESM (proxy)': 'N/A',
        'Winner': 'Conservation (only available)'
    })

# Criterion 3: Ranking entropy (higher = better discrimination)
cons_ent = ent_df[ent_df['Approach'] == 'Conservation']['Rank Entropy'].mean()
if esm_available:
    esm_ent = ent_df[ent_df['Approach'] == 'ESM']['Rank Entropy'].mean()
    summary_rows.append({
        'Criterion': 'Ranking entropy (avg)',
        'Conservation': f'{cons_ent:.4f}',
        'ESM (proxy)': f'{esm_ent:.4f}',
        'Winner': _higher_is_better(cons_ent, esm_ent)
    })
else:
    summary_rows.append({
        'Criterion': 'Ranking entropy (avg)',
        'Conservation': f'{cons_ent:.4f}',
        'ESM (proxy)': 'N/A',
        'Winner': 'Conservation (only available)'
    })

# Criterion 4: Compute cost (fixed, qualitative)
summary_rows.append({
    'Criterion': 'Compute requirements',
    'Conservation': 'CPU, ~1 min',
    'ESM (proxy)': 'GPU (A100), ~30 min',
    'Winner': 'Conservation'
})

# Criterion 5: Biological interpretability (fixed, qualitative)
summary_rows.append({
    'Criterion': 'Interpretability',
    'Conservation': 'High (per-position conservation)',
    'ESM (proxy)': 'Medium (log-likelihood ratios)',
    'Winner': 'Conservation'
})

summary = pd.DataFrame(summary_rows)
print("=" * 80)
print("APPROACH COMPARISON SUMMARY")
print("=" * 80)
print(summary.to_string(index=False))

# Count wins. Only contested criteria count: rows marked
# 'Conservation (only available)' or 'Tie/N/A' fall into the Ties/N/A bucket
# instead of inflating the Conservation tally.
cons_wins = sum(1 for r in summary_rows if r['Winner'] == 'Conservation')
esm_wins = sum(1 for r in summary_rows if r['Winner'] in ('ESM', 'ESM proxy'))
ties = len(summary_rows) - cons_wins - esm_wins

print(f"\n--- Score: Conservation {cons_wins}, ESM {esm_wins}, Ties/N/A {ties} ---\n")

if cons_wins > esm_wins:
    rec = "Conservation"
    reason = ("Conservation scoring wins on more criteria. It provides direct biological "
              "interpretability, requires no GPU, and correlates with experimental delta-Tm data. "
              "Recommended as primary submission.")
elif esm_wins > cons_wins:
    rec = "ESM"
    reason = ("ESM scoring wins on more criteria, likely providing better ranking through "
              "evolutionary plausibility signals captured by the protein language model.")
else:
    rec = "Ensemble or Conservation"
    reason = ("Results are mixed. Consider submitting Conservation as the safer choice "
              "(interpretable, validated) or an ensemble at alpha=0.5 for potential synergy.")

print(f"RECOMMENDATION: Submit **{rec}**")
print(f"Rationale: {reason}")

if esm_available:
    print(f"\nIf ensemble is desired, per-target optimal alphas:")
    for key, alpha in optimal_alphas.items():
        print(f"  {TARGET_LABELS[key]}: alpha = {alpha:.2f}")
