# Wang 2023 dataset

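This notebook uses the Wang 2023 dataset, curated by the Kleinstein group and available on their [Bitbucket](https://bitbucket.org/kleinstein/projects/src/master/Wang2023/). It compares how well sequence embeddings from two DASM models and from AbLang predict the `binds` label when used as features for an SVM classifier.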

In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.preprocessing import StandardScaler

from dnsmex.local import localify
from netam.framework import load_crepe
import dnsmex.wang2023_helper as helper
from dnsmex.ablang_wrapper import AbLangWrapper


# NOTE(review): localify presumably expands the DATA_DIR placeholder to a
# machine-specific path -- confirm in dnsmex.local.
data_dir = localify("DATA_DIR/Wang2023/data")

# Load and prepare the dataset.
# NOTE(review): max_seq_len is not referenced again in the cells shown here.
wang_df, max_seq_len = helper.filtered_wang_specificity(data_dir)

# Debugging aid: uncomment the next line to iterate on a random 1000-row subsample.
# wang_df = wang_df.sample(1000).reset_index(drop=True)

# Seeded 80/20 split: train is a random 80% sample, test is every remaining row.
# Rows are removed from the test set by index label, so this assumes wang_df
# has unique index labels -- TODO confirm for the frame returned by the helper.
train_df = wang_df.sample(frac=0.8, random_state=42)
test_df = wang_df.drop(train_df.index)

# Report (rows, columns) of each split as a quick sanity check.
print(train_df.shape)
print(test_df.shape)

Discarding rows with total_len > 300. There is 1
(12430, 9)
(3107, 9)


In [2]:
# Container for all results: maps a model name to the results dict returned by
# train_and_evaluate_binding_predictor. Filled in by the evaluate_* functions.
all_results = {}

def train_and_evaluate_binding_predictor(
    train_embeddings,
    test_embeddings,
    train_binds,
    test_binds,
    random_state=42
):
    """Fit an RBF-kernel SVM on standardized embeddings and score binding prediction.

    Args:
        train_embeddings: Numpy array of shape (n_train, n_features)
        test_embeddings: Numpy array of shape (n_test, n_features)
        train_binds: Binary labels for training data
        test_binds: Binary labels for test data
        random_state: Random seed passed to the SVM for reproducibility

    Returns:
        dict: The fitted classifier ('model'), the fitted scaler ('scaler'),
        AUROC and AUPRC on both splits ('train_auroc', 'test_auroc',
        'train_auprc', 'test_auprc'), and the number of embedding features
        ('embedding_dim').
    """
    # Standardize features; the scaler is fit on the training split only and
    # then applied unchanged to the test split.
    standardizer = StandardScaler()
    train_features = standardizer.fit_transform(train_embeddings)
    test_features = standardizer.transform(test_embeddings)

    # probability=True enables predict_proba; class_weight='balanced' is there
    # to handle potential class imbalance.
    svm_options = dict(
        kernel='rbf',
        random_state=random_state,
        probability=True,
        class_weight='balanced',
    )
    classifier = SVC(**svm_options)
    classifier.fit(train_features, train_binds)

    # Column 1 of predict_proba is the score used for ranking-based metrics.
    train_scores = classifier.predict_proba(train_features)[:, 1]
    test_scores = classifier.predict_proba(test_features)[:, 1]

    # Assemble the summary in a fixed key order.
    summary = {'model': classifier, 'scaler': standardizer}
    summary['train_auroc'] = roc_auc_score(train_binds, train_scores)
    summary['test_auroc'] = roc_auc_score(test_binds, test_scores)
    summary['train_auprc'] = average_precision_score(train_binds, train_scores)
    summary['test_auprc'] = average_precision_score(test_binds, test_scores)
    summary['embedding_dim'] = train_embeddings.shape[1]
    return summary

def evaluate_dasm_model(model_path, model_name):
    """Evaluate a DASM model as a feature source for binding prediction.

    Loads the crepe at ``localify(model_path)``, turns every heavy/light pair
    in the notebook-level ``train_df`` and ``test_df`` into one fixed-length
    vector by averaging its representation over axis 0, trains and scores the
    SVM binding predictor on those vectors, and stores the outcome in the
    notebook-level ``all_results`` dict under ``model_name``.

    Args:
        model_path: Model location; passed through ``localify`` before loading.
        model_name: Label used in the progress message and as the key in
            ``all_results``.

    Returns:
        dict: The results dictionary produced by
        ``train_and_evaluate_binding_predictor``.
    """
    crepe = load_crepe(localify(model_path))

    def mean_rep_of(df):
        """Return a 2-D tensor with one averaged representation per row of df."""
        seqs = df[["heavy", "light"]].values.tolist()
        reps = crepe.represent_sequences(seqs)
        # NOTE(review): each element is presumably (positions, features), so the
        # mean over axis 0 pools across positions -- confirm against netam.
        return torch.stack([rep.mean(axis=0) for rep in reps])

    train_mean_rep = mean_rep_of(train_df)
    test_mean_rep = mean_rep_of(test_df)

    print(f"Generated embeddings for {model_name}. Shape: {train_mean_rep.shape}")

    # Tensor.numpy() raises for tensors on a non-CPU device or attached to the
    # autograd graph; detach and move to CPU first so the conversion always works.
    train_mean_rep_array = train_mean_rep.detach().cpu().numpy()
    test_mean_rep_array = test_mean_rep.detach().cpu().numpy()

    results = train_and_evaluate_binding_predictor(
        train_mean_rep_array,
        test_mean_rep_array,
        train_df["binds"],
        test_df["binds"]
    )

    all_results[model_name] = results
    return results

def evaluate_ablang():
    """Evaluate AbLang seqcoding embeddings as features for binding prediction.

    Computes ``AbLangWrapper.seqcoding`` for every heavy/light pair in the
    notebook-level ``train_df`` and ``test_df``, trains and scores the SVM
    binding predictor on the result, and stores the outcome in the
    notebook-level ``all_results`` dict under ``"AbLang seqcoding"``.

    Returns:
        dict: The results dictionary produced by
        ``train_and_evaluate_binding_predictor``.
    """
    model_name = "AbLang seqcoding"
    ablang_wrapper = AbLangWrapper()

    def seqcoding_of(df):
        """Return the seqcoding output for the heavy/light pairs in df."""
        seqs = df[["heavy", "light"]].values.tolist()
        return ablang_wrapper.seqcoding(seqs)

    train_seqcoding = seqcoding_of(train_df)
    test_seqcoding = seqcoding_of(test_df)

    print(f"Generated embeddings for {model_name}. Shape: {train_seqcoding.shape}")

    # Tensor.numpy() raises for tensors on a non-CPU device or attached to the
    # autograd graph; detach and move to CPU first so the conversion always works.
    train_seqcoding_array = train_seqcoding.detach().cpu().numpy()
    test_seqcoding_array = test_seqcoding.detach().cpu().numpy()

    results = train_and_evaluate_binding_predictor(
        train_seqcoding_array,
        test_seqcoding_array,
        train_df["binds"],
        test_df["binds"]
    )

    all_results[model_name] = results
    return results

# Run evaluations for each model. Each call trains an SVM on that model's
# embeddings and stores the metrics in all_results under the given name.
# First DASM model (128-dimensional embeddings in the recorded run)
evaluate_dasm_model(
    "DASM_GRID_DIR/dasm_1m-v1jaffeCC+v1tangCC-mh-0", 
    "DASM 1M (jaffeCC+tangCC)"
)

# Second DASM model (256-dimensional embeddings in the recorded run)
evaluate_dasm_model(
    "DASM_TRAINED_MODELS_DIR/dasm_4m-v1jaffeCC+v1tangCC-joint",
    "DASM 4M (jaffeCC+tangCC-joint)"
)

# AbLang model (480-dimensional embeddings in the recorded run). This is the
# slow step: the recorded seqcoding progress bars show roughly 10 minutes for
# the train split and 2.5 minutes for the test split.
evaluate_ablang()

# Convert results to DataFrame for easy viewing
def create_results_df(results_dict):
    """Build a metrics-by-model summary table from per-model result dicts.

    Args:
        results_dict: Mapping of model name to a results dict containing at
            least 'embedding_dim', 'train_auroc', 'test_auroc', 'train_auprc'
            and 'test_auprc'.

    Returns:
        pandas.DataFrame: One row per metric and one column per model, in the
        iteration order of ``results_dict``. 'embedding_dim' is kept as-is;
        the four scores are stored as strings formatted to three decimals.
    """
    score_keys = ('train_auroc', 'test_auroc', 'train_auprc', 'test_auprc')
    row_labels = ['embedding_dim', *score_keys]
    table = pd.DataFrame(index=row_labels)

    for name in results_dict:
        entry = results_dict[name]
        column = {'embedding_dim': entry['embedding_dim']}
        # Scores become fixed-precision strings so the table displays uniformly.
        column.update((key, format(entry[key], '.3f')) for key in score_keys)
        table[name] = pd.Series(column)

    return table

# Display results table: one column per evaluated model, one row per metric.
results_df = create_results_df(all_results)
display(results_df)

Generated embeddings for DASM 1M (jaffeCC+tangCC). Shape: torch.Size([12430, 128])
Generated embeddings for DASM 4M (jaffeCC+tangCC-joint). Shape: torch.Size([12430, 256])


  torch.load(
Running seqcoding: 100%|██████████| 125/125 [10:24<00:00,  5.00s/it]
Running seqcoding: 100%|██████████| 32/32 [02:33<00:00,  4.80s/it]


Generated embeddings for AbLang seqcoding. Shape: torch.Size([12430, 480])


Unnamed: 0,DASM 1M (jaffeCC+tangCC),DASM 4M (jaffeCC+tangCC-joint),AbLang seqcoding
embedding_dim,128.0,256.0,480.0
train_auroc,0.904,0.911,0.936
test_auroc,0.854,0.852,0.899
train_auprc,0.912,0.921,0.938
test_auprc,0.856,0.858,0.887
