## Tests Performed:
### 1. Cluster Separation Quality
### 2. Embedding Space Structure Analysis

In [None]:
import pickle
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import umap
from sklearn.metrics import silhouette_score, davies_bouldin_score
from scipy.spatial.distance import pdist, squareform, cosine
from scipy.stats import mannwhitneyu
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from collections import defaultdict

# Input files: the ADC spreadsheet and the pickled AntiBinder heavy-chain embeddings.
XLSX_PATH = 'data.xlsx'
ANTIBINDER_HEAVY_PKL = 'antibinder_heavy.pkl'

# Spreadsheet column headers.
COL_ADC_ID = "ADC ID"
COL_ANTIBODY_NAME = "Antibody Name"
COL_ADC_NAME = "ADC Name"
COL_HEAVY_SEQ = "Antibody Heavy Chain Sequence"
COL_LIGHT_SEQ = "Antibody Light Chain Sequence"
COL_ANTIGEN_SEQ = "Antigen Sequence"

# How many of the most frequent antigen sequences to analyse, and the global NumPy seed.
TOP_N_ANTIGENS = 5
SEED = 42
np.random.seed(SEED)

# Output files
OUTPUT_DIR = './'

# Display names keyed by frequency rank (0 = most frequent antigen sequence).
# NOTE(review): this assumes the antigen sequences in the spreadsheet rank
# ERBB2 > KIT > FGFR2 > TNFRSF1A > EGFR by sample count — confirm against the data.
ANTIGEN_NAMES = dict(enumerate(["ERBB2", "KIT", "FGFR2", "TNFRSF1A", "EGFR"]))

def load_data():
    """Load the ADC spreadsheet and align each row with its AntiBinder embedding.

    Each row is keyed by its ``COL_ADC_ID`` value when the spreadsheet has that
    column, otherwise by the DataFrame index. The key is looked up in the
    embedding pickle first as-is, then as ``str(key)``. Rows with no embedding
    are counted as missing and skipped. Torch tensors are converted to numpy.

    Returns:
        pandas.DataFrame with columns ``adc_id``, ``antibody_name``,
        ``antigen_seq`` and ``embedding`` (a flattened numpy array). The
        columns are present even when no row could be aligned, so downstream
        column access fails loudly on emptiness rather than with a KeyError.
    """
    print("="*70)
    print("LOADING DATA")
    print("="*70)

    df = pd.read_excel(XLSX_PATH)
    print(f"Loaded {len(df)} samples from {XLSX_PATH}")

    # SECURITY NOTE: pickle.load can execute arbitrary code; only load
    # embedding files that come from a trusted source.
    # NOTE(review): presumably a dict mapping ADC IDs to vectors/tensors — confirm.
    with open(ANTIBINDER_HEAVY_PKL, 'rb') as f:
        antibinder_emb = pickle.load(f)
    print(f"Loaded AntiBinder embeddings: {len(antibinder_emb)} entries")

    has_id_column = COL_ADC_ID in df.columns  # loop-invariant
    aligned_data = []
    missing = 0

    for idx, row in df.iterrows():
        key = row[COL_ADC_ID] if has_id_column else idx

        # Embedding keys may have been stored as the raw ID or as its string form.
        vec = None
        if key in antibinder_emb:
            vec = antibinder_emb[key]
        elif str(key) in antibinder_emb:
            vec = antibinder_emb[str(key)]

        if vec is None:
            missing += 1
            continue

        if isinstance(vec, torch.Tensor):
            vec = vec.detach().cpu().numpy()

        aligned_data.append({
            'adc_id': key,
            'antibody_name': row.get(COL_ANTIBODY_NAME, 'Unknown'),
            'antigen_seq': row[COL_ANTIGEN_SEQ],
            'embedding': np.array(vec).flatten(),
        })

    print(f"Successfully aligned: {len(aligned_data)} samples")
    print(f"Missing embeddings: {missing} samples")
    print("="*70 + "\n")

    # Explicit columns keep the schema stable even when aligned_data is empty.
    return pd.DataFrame(
        aligned_data,
        columns=['adc_id', 'antibody_name', 'antigen_seq', 'embedding'],
    )

def test1_cluster_quality(df, save_path=None):
    """Measure how well embeddings separate by antigen (Test 1).

    Keeps the rows whose ``antigen_seq`` is among the ``TOP_N_ANTIGENS`` most
    frequent sequences and labels them by frequency rank (0 = most frequent).
    Then it reports:

    * the silhouette score in the original embedding space (cosine metric);
    * the Davies-Bouldin index in the original space. scikit-learn's
      implementation is Euclidean-only, so it is not on the same metric as
      the cosine silhouette above;
    * the silhouette score of a 2-D UMAP projection (Euclidean).

    A UMAP scatter plot coloured by antigen is drawn on a new figure.

    Args:
        df: Output of ``load_data()``; needs ``antigen_seq`` and ``embedding``.
        save_path: Optional file path (e.g. ``OUTPUT_DIR +
            'test1_cluster_quality.png'``). When given, the figure is saved
            there at 600 dpi. By default nothing is written to disk.

    Returns:
        dict with keys ``silhouette_hd``, ``davies_bouldin`` and ``silhouette_2d``.

    Raises:
        ValueError: if fewer than two distinct antigens are present. The
            clustering metrics are undefined for a single cluster.
    """
    print("="*70)
    print("TEST 1: CLUSTER SEPARATION QUALITY")
    print("="*70)

    # Top antigens by sample count; the rank i is used as the cluster label.
    antigen_counts = df['antigen_seq'].value_counts()
    top_antigens = antigen_counts.head(TOP_N_ANTIGENS).index.tolist()
    if len(top_antigens) < 2:
        raise ValueError(
            f"Cluster metrics need at least 2 antigens; found {len(top_antigens)}."
        )
    df_top = df[df['antigen_seq'].isin(top_antigens)].copy()

    antigen_to_id = {seq: i for i, seq in enumerate(top_antigens)}
    df_top['antigen_id'] = df_top['antigen_seq'].map(antigen_to_id)
    # Fall back to a generic name when TOP_N_ANTIGENS exceeds ANTIGEN_NAMES.
    names = {i: ANTIGEN_NAMES.get(i, f"Antigen {i}") for i in range(len(top_antigens))}

    print(f"Analyzing {len(df_top)} samples across {len(top_antigens)} antigens:")
    for i, seq in enumerate(top_antigens):
        count = (df_top['antigen_seq'] == seq).sum()
        print(f"  {names[i]}: {count} samples")

    X = np.array(df_top['embedding'].tolist())
    labels = df_top['antigen_id'].values

    sil_score = silhouette_score(X, labels, metric='cosine')
    db_score = davies_bouldin_score(X, labels)  # Euclidean (no metric option)

    print("\nClustering Quality Metrics (High-Dimensional Space):")
    print(f"  Silhouette Score: {sil_score:.4f}")
    print(f"  Davies-Bouldin Index: {db_score:.4f}")

    print("\nPerforming UMAP projection...")
    reducer = umap.UMAP(n_components=2, random_state=SEED, n_neighbors=15, min_dist=0.1)
    X_umap = reducer.fit_transform(X)

    sil_2d = silhouette_score(X_umap, labels, metric='euclidean')
    print(f"  Silhouette Score (2D UMAP): {sil_2d:.4f}")

    fig, ax = plt.subplots(figsize=(10, 8))

    for antigen_id in sorted(df_top['antigen_id'].unique()):
        mask = labels == antigen_id
        subset = X_umap[mask]
        ax.scatter(subset[:, 0], subset[:, 1],
                   label=f"{names[antigen_id]} (n={mask.sum()})",
                   alpha=0.7, s=80, edgecolor='white', linewidth=0.5)

    ax.set_xlabel('UMAP Dimension 1', fontsize=12)
    ax.set_ylabel('UMAP Dimension 2', fontsize=12)
    ax.set_title('AntiBinder Embeddings: Domain Generalization to ADC Antigens\n' +
                 f'(Trained on MET, Applied to {len(top_antigens)} Novel Targets)',
                 fontsize=13, weight='bold')
    ax.legend(loc='best', frameon=True, fontsize=10)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    # Only report a save when a file was actually written.
    if save_path is not None:
        fig.savefig(save_path, dpi=600, bbox_inches='tight')
        print(f"\nSaved: {save_path}")
    print("="*70 + "\n")

    return {
        'silhouette_hd': sil_score,
        'davies_bouldin': db_score,
        'silhouette_2d': sil_2d
    }

def test2_distance_analysis(df, save_path=None):
    """Compare cosine distances within vs. between antigen groups (Test 2).

    Restricts ``df`` to the ``TOP_N_ANTIGENS`` most frequent antigen sequences
    and computes all pairwise cosine distances between embeddings. It splits
    them into intra-antigen pairs (same antigen) and inter-antigen pairs, then
    prints summary statistics, a one-sided Mann-Whitney U test
    (alternative: intra < inter) and the inter/intra mean ratio. Finally it
    draws a histogram and a box plot of both groups.

    Args:
        df: Output of ``load_data()``; needs ``antigen_seq`` and ``embedding``.
        save_path: Optional file path (e.g. ``OUTPUT_DIR +
            'test2_distance_analysis.png'``). When given, the figure is saved
            there at 600 dpi. By default nothing is written to disk.

    Returns:
        dict with ``intra_mean``, ``inter_mean`` and ``p_value``.

    Raises:
        ValueError: if fewer than two antigens are present, or if every
            selected antigen has exactly one sample (no intra-antigen pairs).
    """
    print("="*70)
    print("TEST 2: EMBEDDING SPACE STRUCTURE ANALYSIS")
    print("="*70)

    antigen_counts = df['antigen_seq'].value_counts()
    top_antigens = antigen_counts.head(TOP_N_ANTIGENS).index.tolist()
    if len(top_antigens) < 2:
        raise ValueError(
            f"Distance analysis needs at least 2 antigens; found {len(top_antigens)}."
        )
    df_top = df[df['antigen_seq'].isin(top_antigens)].copy()

    antigen_to_id = {seq: i for i, seq in enumerate(top_antigens)}
    df_top['antigen_id'] = df_top['antigen_seq'].map(antigen_to_id)

    X = np.array(df_top['embedding'].tolist())
    labels = df_top['antigen_id'].values

    print("Computing pairwise distances...")
    # pdist returns the condensed upper triangle in row-major (i < j) order,
    # the same order as np.triu_indices(n, k=1). That lets the pair labels be
    # compared in one vectorized step instead of a Python double loop.
    pair_dists = pdist(X, metric='cosine')
    rows, cols = np.triu_indices(len(labels), k=1)
    same_antigen = labels[rows] == labels[cols]
    intra_distances = pair_dists[same_antigen]
    inter_distances = pair_dists[~same_antigen]

    # With >= 2 antigens there is always an inter pair; intra pairs need an
    # antigen with at least two samples.
    if intra_distances.size == 0:
        raise ValueError(
            "No intra-antigen pairs: every selected antigen has a single sample."
        )

    print("\nDistance Statistics:")
    print(f"  Intra-antigen distances: {len(intra_distances)} pairs")
    print(f"    Mean: {intra_distances.mean():.4f}")
    print(f"    Std:  {intra_distances.std():.4f}")
    print(f"  Inter-antigen distances: {len(inter_distances)} pairs")
    print(f"    Mean: {inter_distances.mean():.4f}")
    print(f"    Std:  {inter_distances.std():.4f}")

    u_stat, p_value = mannwhitneyu(intra_distances, inter_distances, alternative='less')
    print("\nMann-Whitney U test (intra < inter):")
    print(f"  U-statistic: {u_stat:.2e}")
    print(f"  p-value: {p_value:.2e}")

    ratio = inter_distances.mean() / intra_distances.mean()
    print(f"  Inter/Intra Ratio: {ratio:.2f}x (Inter-cluster distances are {ratio:.2f} times larger)")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].hist(intra_distances, bins=50, alpha=0.7, label='Intra-antigen', density=True, color='blue')
    axes[0].hist(inter_distances, bins=50, alpha=0.7, label='Inter-antigen', density=True, color='red')
    axes[0].axvline(intra_distances.mean(), color='blue', linestyle='--', linewidth=2, label='Intra mean')
    axes[0].axvline(inter_distances.mean(), color='red', linestyle='--', linewidth=2, label='Inter mean')
    axes[0].set_xlabel('Cosine Distance', fontsize=11)
    axes[0].set_ylabel('Density', fontsize=11)
    axes[0].set_title('Distance Distribution', fontsize=12, weight='bold')
    axes[0].legend()
    axes[0].grid(alpha=0.3)

    data_box = pd.DataFrame({
        'Distance': np.concatenate([intra_distances, inter_distances]),
        'Type': ['Intra-antigen']*len(intra_distances) + ['Inter-antigen']*len(inter_distances)
    })
    sns.boxplot(data=data_box, x='Type', y='Distance', ax=axes[1], palette=['blue', 'red'])
    # Report the p-value actually obtained rather than a hard-coded threshold.
    p_label = "p < 0.001" if p_value < 1e-3 else f"p = {p_value:.3g}"
    axes[1].set_title(f'Distance Comparison\n({p_label})',
                      fontsize=12, weight='bold')
    axes[1].set_ylabel('Cosine Distance', fontsize=11)
    axes[1].grid(alpha=0.3, axis='y')

    plt.tight_layout()
    # Only report a save when a file was actually written.
    if save_path is not None:
        fig.savefig(save_path, dpi=600, bbox_inches='tight')
        print(f"\nSaved: {save_path}")
    print("="*70 + "\n")

    return {
        'intra_mean': intra_distances.mean(),
        'inter_mean': inter_distances.mean(),
        'p_value': p_value,
    }


if __name__ == "__main__":
    # Entry point for Tests 1-2. Load the spreadsheet, align it with the
    # AntiBinder heavy-chain embeddings, then run both analyses. Each test
    # returns a dict of metrics, collected in `results` under 'test1'/'test2'.
    print("\n" + "="*70)
    print("ANTIBINDER DOMAIN GENERALIZATION ANALYSIS")
    print("Validating Transfer from MET to Diverse ADC Antigens")
    print("="*70 + "\n")
    
    # One row per spreadsheet entry that has an embedding (see load_data).
    df = load_data()
    
    results = {}
    
    results['test1'] = test1_cluster_quality(df)
    results['test2'] = test2_distance_analysis(df)
    

## Tests Performed:
### 3. t-SNE Visualization
### 4. Y-Scrambling

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.utils.data import DataLoader, ConcatDataset

from ADCNet import PredictModel
from AB_Data import AB_Data
import os


# NOTE(review): presumably set so cuBLAS can run deterministically together
# with torch.use_deterministic_algorithms(True) below. Per the PyTorch
# reproducibility notes, it must be set before the first cuBLAS call.
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Spreadsheet plus the pre-computed embedding pickles handed to AB_Data.
DATA_FILE = "data/data.xlsx"
EMBED_PATHS = [
    f"Embeddings/{fname}"
    for fname in ("antibinder_heavy.pkl", "Light.pkl", "Antigen.pkl")
]

import torch, numpy as np, random

SEED = 42

# Seed each RNG the notebook uses (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Require deterministic kernels and disable cuDNN autotuning.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Weights for the "before" snapshot (run_tsne) and the trained checkpoint
# (run_tsne and run_yscramble_test).
PATH_BEFORE_TRAINING = "model_weights/modified_model.pth"
PATH_AFTER_TRAINING = "ckpts/ADC_1_best_model.pth"

BATCH_SIZE = 32

# Appended to by get_fc2_hook; cleared at the start of every extract_data call.
features_storage = []

def get_fc2_hook(module, input, output):
    """Forward hook for ``model.fc2``: store a detached CPU copy of its output.

    ``extract_data`` registers this hook so that fc2's activations can be
    collected without modifying the model. The author describes them as the
    learned embedding, taken before the final prediction layer. The arguments
    follow PyTorch's forward-hook signature, and only ``output`` is used.
    Nothing is returned, so the module's output passes through unchanged.
    """
    activation = output.detach()
    features_storage.append(activation.cpu())

def extract_data(model, dataloader, device, description):
    """Run ``model`` over ``dataloader``; return fc2 features, labels and predictions.

    A forward hook (``get_fc2_hook``) on ``model.fc2`` appends each batch's
    fc2 output to the module-level ``features_storage`` list, which is cleared
    first. The hook is removed in a ``finally`` block. Otherwise a failed run
    would leave it attached and it would keep appending on later forward
    passes, including duplicates on retries.

    Args:
        model: model exposing ``fc2``; switched to eval mode and moved to ``device``.
        dataloader: yields 12-tuples ``(x1, x1maccs, x2, x2maccs, t1, t2, t3,
            aac1, aac2, aac3, t4, y)``.
        device: torch device used for inference.
        description: label used only in the progress message.

    Returns:
        ``(X, y, preds)`` as numpy arrays: the concatenated fc2 features, the
        labels as yielded by the loader, and ``sigmoid(logits)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    print(f"[{description}] Extracting embeddings...")
    model.eval()
    model.to(device)

    hook_handle = model.fc2.register_forward_hook(get_fc2_hook)
    features_storage.clear()

    all_labels = []
    all_preds = []

    try:
        with torch.no_grad():
            for batch in dataloader:
                x1, x1maccs, x2, x2maccs, t1, t2, t3, aac1, aac2, aac3, t4, y = batch

                x1, x2 = x1.to(device), x2.to(device)
                x1maccs, x2maccs = x1maccs.to(device), x2maccs.to(device)
                t1, t2, t3, t4 = t1.to(device), t2.to(device), t3.to(device), t4.to(device)
                aac1, aac2, aac3 = aac1.to(device), aac2.to(device), aac3.to(device)

                # Promote 1-D tensors (N,) to column vectors (N, 1). Only these
                # five inputs are reshaped; aac2/aac3 are passed through as-is.
                t1, t2, t3, t4, aac1 = (
                    t.unsqueeze(1) if t.dim() == 1 else t
                    for t in (t1, t2, t3, t4, aac1)
                )

                logits = model(x1, x1maccs, x2, x2maccs, t1, t2, t3, aac1, aac2, aac3, t4)
                preds = torch.sigmoid(logits)

                all_preds.append(preds.cpu())
                all_labels.append(y)
    finally:
        # Always detach the hook, even if a batch raised.
        hook_handle.remove()

    if not all_labels:
        raise ValueError(f"[{description}] dataloader yielded no batches.")

    X = torch.cat(features_storage).numpy()
    y = torch.cat(all_labels).numpy()
    preds = torch.cat(all_preds).numpy()

    return X, y, preds

def run_tsne(full_loader, save_path=None):
    """Plot t-SNE maps of fc2 embeddings before vs. after training, coloured by label.

    ``PATH_BEFORE_TRAINING`` is loaded with ``strict=False``, so parameters
    missing from that file keep their initial values. If that load fails
    entirely, the randomly initialised model is used. Features are extracted,
    then ``PATH_AFTER_TRAINING`` is loaded into the same model and features
    are extracted again. If the trained checkpoint cannot be loaded, an error
    is printed and the function returns without plotting.

    Each feature set gets its own t-SNE fit, so the two panels are not in a
    shared coordinate frame. Only the cluster structure is comparable.

    Args:
        full_loader: DataLoader yielding the batches ``extract_data`` expects.
        save_path: Optional file path; when given, the figure is saved there
            at 600 dpi before being shown.
    """
    model = PredictModel().to(DEVICE)
    try:
        model.load_state_dict(torch.load(PATH_BEFORE_TRAINING, map_location=DEVICE), strict=False)
        print("Loaded 'Before' weights.")
    except Exception as e:
        # Best effort: a random-init model is an acceptable "before" baseline.
        print(f"Warning: Could not load 'Before' weights. Using random init. ({e})")

    X_before, y_before, _ = extract_data(model, full_loader, DEVICE, "Before Training")

    try:
        model.load_state_dict(torch.load(PATH_AFTER_TRAINING, map_location=DEVICE))
        print("Loaded 'After' weights.")
    except Exception as e:
        print(f"CRITICAL ERROR: Could not load trained checkpoint! {e}")
        return

    X_after, y_after, _ = extract_data(model, full_loader, DEVICE, "After Training")

    n_samples = len(y_before)
    # Perplexity ~ n/10, capped at 30 and floored at 1. TSNE rejects a
    # perplexity of 0, which n // 10 gives for fewer than 10 samples.
    perp = max(1, min(30, n_samples // 10))
    print(f"Running t-SNE on {n_samples} samples (Perplexity: {perp})...")

    tsne = TSNE(n_components=2, perplexity=perp, random_state=SEED, init='pca', learning_rate='auto')
    Z_before = tsne.fit_transform(X_before)
    Z_after = tsne.fit_transform(X_after)

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    colors = {0: 'orange', 1: 'green'}
    labels = {0: 'Negative', 1: 'Positive'}
    panels = [
        (axes[0], Z_before, y_before, "Before Training\n(Untrained Space)"),
        (axes[1], Z_after, y_after, "After Training\n(ABFormer Learned Space)"),
    ]

    for ax, Z, y, title in panels:
        for label_val in (0, 1):
            mask = (y == label_val)
            ax.scatter(Z[mask, 0], Z[mask, 1], c=colors[label_val], label=labels[label_val],
                       alpha=0.6, s=40, edgecolors='w', linewidth=0.5)
        ax.set_title(title, fontsize=14)
        ax.legend()
        ax.axis('off')

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=600)
        print(f"SAVED: {save_path}")
    plt.show()
    plt.close()

def run_yscramble_test(full_loader, save_path=None):
    """Compare AUC on the true labels with AUC on randomly permuted labels.

    Loads ``PATH_AFTER_TRAINING`` and extracts fc2 features and sigmoid
    predictions via ``extract_data``. The same predictions are then scored
    against (a) the true labels and (b) one seeded random permutation of them.

    The model is NOT retrained on the permuted labels. The scrambled AUC is
    therefore a chance-level reference for the metric, not a retrain-based
    Y-randomization test.

    One t-SNE of the features is drawn twice: coloured by true labels and by
    scrambled labels.

    Args:
        full_loader: DataLoader yielding the batches ``extract_data`` expects.
        save_path: Optional file path; when given, the figure is saved there
            at 600 dpi before being shown.
    """
    print("\n" + "="*50)
    print("TASK 2: Y-Scrambling Validation Test")
    print("="*50)

    model = PredictModel().to(DEVICE)
    model.load_state_dict(torch.load(PATH_AFTER_TRAINING, map_location=DEVICE))

    X, y_true, y_preds = extract_data(model, full_loader, DEVICE, "Y-Scramble Evaluation")

    # Use a private RandomState instead of reseeding the global NumPy RNG, so
    # this test does not reset RNG state for the rest of the notebook.
    # RandomState(seed).permutation draws the same stream as np.random.seed(seed).
    y_scrambled = np.random.RandomState(SEED).permutation(y_true)

    auc_true = roc_auc_score(y_true, y_preds)
    auc_scram = roc_auc_score(y_scrambled, y_preds)

    print("-" * 40)
    print(f"TRUE LABELS AUC: {auc_true:.4f}")
    print(f"SCRAMBLED Y  AUC: {auc_scram:.4f}")
    print("-" * 40)

    # Perplexity ~ n/10, capped at 30 and floored at 1 (TSNE rejects 0).
    perp = max(1, min(30, len(y_true) // 10))
    tsne = TSNE(n_components=2, perplexity=perp, random_state=SEED, init='pca', learning_rate='auto')
    Z = tsne.fit_transform(X)

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    colors = {0: 'orange', 1: 'green'}
    labels = {0: 'Negative', 1: 'Positive'}
    panels = [
        (axes[0], y_true, f"True Labels\nAUC: {auc_true:.3f}"),
        (axes[1], y_scrambled, f"Y-Scrambled Labels\nAUC: {auc_scram:.3f} (Random)"),
    ]

    for ax, y, title in panels:
        for label_val in (0, 1):
            mask = (y == label_val)
            ax.scatter(Z[mask, 0], Z[mask, 1], c=colors[label_val], label=labels[label_val],
                       alpha=0.6, s=40, edgecolors='w', linewidth=0.5)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend()
        ax.axis('off')

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=600)
        print(f"SAVED: {save_path}")
    plt.show()
    plt.close()


if __name__ == "__main__":
    # Entry point for Tests 3-4. Build the dataset, concatenate the train,
    # validation and test splits into a single loader, then run the t-SNE
    # comparison and the label-permutation test.
    # NOTE(review): since all splits are merged, the AUCs printed by
    # run_yscramble_test include training samples and will be optimistic.
    print("Loading Dataset (Merging Train+Val+Test for visualization)...")
    data_handler = AB_Data(DATA_FILE, EMBED_PATHS)
    
    # The returned loaders are unused here. NOTE(review): presumably this call
    # is what populates data_handler.train_dataset / valid_dataset /
    # test_dataset (split with seed=1) — confirm in AB_Data.
    train, val, test = data_handler.get_dataloaders(batch_size=BATCH_SIZE, seed=1)
    
    full_dataset = ConcatDataset([
        data_handler.train_dataset, 
        data_handler.valid_dataset, 
        data_handler.test_dataset
    ])
    # shuffle=False gives a fixed, deterministic sample order.
    full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    run_tsne(full_loader)
    run_yscramble_test(full_loader)
    