# 05 - Embedding Validation

**Author:** Agna Chan | **Date:** December 2025  
**Repository:** github.com/biohackingmathematician/frontier-pep

---

**CRITICAL DISCLAIMER**

This notebook is for RESEARCH AND EDUCATIONAL PURPOSES ONLY.
No medical advice, dosing guidance, or protocol recommendations are provided.

---

## Purpose

Validate that trained GNN embeddings capture meaningful mechanistic structure:
1. t-SNE visualization - Do peptides cluster by class?
2. Cluster quality metrics (silhouette score, adjusted Rand index) - Quantitative validation
3. Nearest-neighbor analysis - Do a peptide's nearest neighbors in embedding space share its class? (Shared targets/pathways are not yet checked here.)


In [None]:
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, adjusted_rand_score

# Make the project-local package under ../src importable before importing it.
sys.path.insert(0, '../src')

from peptide_atlas import PeptideAtlas, print_disclaimer
from peptide_atlas.constants import PeptideClass, PEPTIDE_CLASS_COLORS

# Emit the peptide_atlas package's disclaimer banner before any analysis output
# (text is defined in peptide_atlas; presumably mirrors the notice in the header cell).
print_disclaimer()


In [None]:
# Load Atlas
atlas = PeptideAtlas.load("../data/processed/")

print(f"Peptides: {atlas.num_peptides}")
print(f"Embedding dim: {atlas.embedding_dim}")
print(f"Has embeddings: {atlas.has_embeddings}")

# Per-peptide metadata in knowledge-graph order. Every downstream cell pairs
# peptide i of these lists with row i of `embeddings`.
# NOTE(review): assumes embeddings.npy was written in atlas.kg.peptides order
# -- confirm against the training/export script.
peptides = list(atlas.kg.peptides)
peptide_names = [p.canonical_name for p in peptides]
peptide_classes = [p.peptide_class.value for p in peptides]
evidence_tiers = [p.evidence_tier.value for p in peptides]

# Load embeddings
embeddings = np.load("../data/processed/embeddings.npy")
print(f"Embeddings shape: {embeddings.shape}")

# Fail fast on a stale or mismatched embeddings file. Without this, extra rows
# would be silently ignored by the plot and mislabel peptides instead of erroring.
assert embeddings.ndim == 2, (
    f"expected a 2-D (n_peptides, dim) array, got shape {embeddings.shape}"
)
assert embeddings.shape[0] == len(peptide_names), (
    f"embeddings.npy has {embeddings.shape[0]} rows but the atlas has "
    f"{len(peptide_names)} peptides; regenerate the embeddings"
)


In [None]:
# t-SNE Projection
print("Computing t-SNE...")

# scikit-learn requires perplexity < n_samples, hence the n - 1 cap for small
# atlases. Base it on the matrix actually being embedded.
n_samples = embeddings.shape[0]
assert n_samples > 1, "t-SNE needs at least two peptides"

tsne = TSNE(
    n_components=2,
    perplexity=min(15, n_samples - 1),
    random_state=42,  # t-SNE is stochastic; pin the layout so re-runs match
    # Iterations stay at scikit-learn's default of 1000, the value this cell
    # previously requested explicitly. `n_iter` was renamed `max_iter` in 1.5
    # and removed in 1.7, so passing it crashes current releases. `max_iter`
    # does not exist before 1.5, so relying on the default works on every version.
)
embeddings_2d = tsne.fit_transform(embeddings)

print(f"t-SNE shape: {embeddings_2d.shape}")


In [None]:
# Visualization
# On a fresh checkout ../outputs does not exist yet. The report cell at the end
# of the notebook creates it only after savefig below has already run, so
# create it here.
OUTPUT_DIR = Path("../outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

FALLBACK_COLOR = "#888888"

fig, ax = plt.subplots(figsize=(14, 10))

unique_classes = sorted(set(peptide_classes))

for pclass in unique_classes:
    # Rows of embeddings_2d in this class. Never empty, because unique_classes
    # is built from peptide_classes itself.
    indices = [i for i, c in enumerate(peptide_classes) if c == pclass]

    try:
        color = PEPTIDE_CLASS_COLORS.get(PeptideClass(pclass), FALLBACK_COLOR)
    except ValueError:
        # PeptideClass(...) rejected the value (presumably an Enum lookup).
        # Draw the class in neutral grey instead of aborting the whole figure.
        color = FALLBACK_COLOR

    ax.scatter(
        embeddings_2d[indices, 0],
        embeddings_2d[indices, 1],
        color=color,
        label=pclass.replace("_", " ").title(),
        s=100,
        alpha=0.7,
        edgecolors='white',
        linewidths=0.5,
    )

    # Label every point with its peptide name.
    for i in indices:
        ax.annotate(
            peptide_names[i],
            (embeddings_2d[i, 0], embeddings_2d[i, 1]),
            fontsize=6,
            alpha=0.7,
        )

ax.set_xlabel("t-SNE Dimension 1")
ax.set_ylabel("t-SNE Dimension 2")
ax.set_title("Peptide Atlas: t-SNE of GNN Embeddings\n(Research Use Only)")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "embedding_tsne.png", dpi=150, bbox_inches='tight')
plt.show()

print("Saved: outputs/embedding_tsne.png")


In [None]:
# Cluster Metrics
print("=== Cluster Purity Metrics ===")

# Encode the string class of each peptide as an integer label for sklearn.
class_to_int = {c: i for i, c in enumerate(unique_classes)}
class_labels = np.array([class_to_int[c] for c in peptide_classes])
n_classes = len(unique_classes)
n_labeled = len(class_labels)

# Silhouette score of the raw embeddings, using peptide class as the cluster
# label. sklearn only defines it for 2 <= n_labels <= n_samples - 1 and raises
# ValueError otherwise. That includes the case where every peptide is its own class.
if 1 < n_classes < n_labeled:
    silhouette = silhouette_score(embeddings, class_labels)
    print(f"Silhouette Score: {silhouette:.3f}")
    print("  (>0.5 = good, >0.7 = excellent)")
else:
    silhouette = 0.0  # placeholder so the summary/report cells stay numeric
    print(f"Silhouette not applicable ({n_classes} classes for {n_labeled} peptides)")

# K-means with k = number of classes, scored against the true class labels
# with the Adjusted Rand Index.
if n_classes > 1:
    kmeans = KMeans(n_clusters=n_classes, random_state=42, n_init=10)
    kmeans_labels = kmeans.fit_predict(embeddings)
    ari = adjusted_rand_score(class_labels, kmeans_labels)
    print(f"Adjusted Rand Index: {ari:.3f}")
    print("  (1.0 = perfect, 0.0 = random)")
else:
    ari = 0.0


In [None]:
# Nearest Neighbor Analysis
print("\n=== Nearest Neighbor Analysis ===")

def get_k_nearest(emb, idx, k=3):
    """Return the ``k`` rows of ``emb`` closest to row ``idx`` (Euclidean), nearest first.

    Args:
        emb: 2-D numpy array of points, one point per row.
        idx: Row index of the query point. The query itself is never returned
            as one of its own neighbours.
        k: Maximum number of neighbours to return. Fewer are returned when
            ``emb`` has fewer than ``k + 1`` rows.

    Returns:
        List of ``(row_index, distance)`` tuples in ascending distance order.
    """
    dists = np.linalg.norm(emb - emb[idx], axis=1)
    # Drop the query by index instead of assuming it sorts first. With
    # duplicate embeddings (distance 0), another row can tie with and precede
    # the query. The stable sort keeps tied rows in index order.
    order = np.argsort(dists, kind="stable")
    order = order[order != idx]
    return [(int(i), float(dists[i])) for i in order[:k]]

# One boolean per (peptide, neighbour) pair, covering each peptide's 3 nearest
# neighbours in embedding space. True when the pair shares a peptide class.
neighbor_matches = [
    peptide_classes[nbr] == peptide_classes[query_idx]
    for query_idx in range(len(peptide_names))
    for nbr, _dist in get_k_nearest(embeddings, query_idx, k=3)
]
same_class = sum(neighbor_matches)
total = len(neighbor_matches)

nn_accuracy = (same_class / total) if total else 0
print(f"NN Same-Class Rate: {nn_accuracy:.1%}")
print(f"  ({same_class}/{total} neighbors share class)")


In [None]:
# Summary
# Pass gate behind the Status line.
# NOTE(review): these bars are looser than the "Interpretation" guide values
# printed below (0.2 / 40%). ARI is reported but is not part of the gate.
# Confirm both are intended.
PASS_SILHOUETTE = 0.1
PASS_NN_RATE = 0.3

print("\n" + "="*50)
print("VALIDATION SUMMARY")
print("="*50)

print(f"""
Metrics:
  Silhouette Score:    {silhouette:.3f}
  Adjusted Rand Index: {ari:.3f}
  NN Same-Class Rate:  {nn_accuracy:.1%}

Interpretation:
  Silhouette > 0.2 = meaningful clustering
  ARI > 0.2 = class structure captured
  NN Rate > 40% = similar peptides cluster
""")

passed = silhouette > PASS_SILHOUETTE or nn_accuracy > PASS_NN_RATE
# Show the gate actually applied, so a PASS is not mistaken for meeting the
# guide thresholds above.
print(f"Pass criterion: Silhouette > {PASS_SILHOUETTE} OR NN Rate > {PASS_NN_RATE:.0%}")
print(f"Status: {'PASS' if passed else 'NEEDS REVIEW'}")


In [None]:
# Save report
import json
from pathlib import Path

# Persist the headline metrics as JSON in the same folder as the t-SNE figure.
reports_dir = Path("../outputs")
reports_dir.mkdir(exist_ok=True)

report = dict(
    silhouette_score=float(silhouette),
    adjusted_rand_index=float(ari),
    nn_same_class_rate=float(nn_accuracy),
    num_peptides=len(peptide_names),
    num_classes=len(unique_classes),
    embedding_dim=int(embeddings.shape[1]),
)

report_path = reports_dir / "embedding_validation.json"
with report_path.open("w") as fh:
    json.dump(report, fh, indent=2)

print("Saved: outputs/embedding_validation.json")
