In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.preprocessing import RobustScaler
import matplotlib.pyplot as plt
import random

2025-04-08 16:53:48.231587: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744131228.250903  991900 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744131228.256867  991900 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-08 16:53:48.279978: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
tf.keras.utils.set_random_seed(SEED)  # This sets all random seeds in keras
tf.config.experimental.enable_op_determinism()  # For complete reproducibility

In [3]:
def prepare_structure_data(df):
    """Structure data preparation without contacts"""
    features_list = []
    middle_pos = 16  
    
    # Normalize angles to their circular nature
    def normalize_angles(angle_array, pos):
        angles = np.array([arr[pos] for arr in angle_array])
        angle_rad = np.pi * angles / 180.0
        return np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
    
    # 1. Process angles
    angles = ['phi', 'psi', 'omega']
    for angle in angles:
        angle_arrays = np.array([np.array(eval(x)) for x in df[angle]])
        angle_features = normalize_angles(angle_arrays, middle_pos)
        features_list.append(angle_features)
        print(f"{angle} features shape: {angle_features.shape}")
    
    # 2. Process SASA
    sasa_arrays = np.array([np.array(eval(x)) for x in df['sasa']])
    scaler = RobustScaler()
    sasa_features = []
    for pos in [middle_pos-1, middle_pos, middle_pos+1]:
        sasa_pos = np.array([arr[pos] for arr in sasa_arrays]).reshape(-1, 1)
        sasa_scaled = scaler.fit_transform(sasa_pos)
        sasa_features.append(sasa_scaled)
    sasa_features = np.concatenate(sasa_features, axis=1)
    features_list.append(sasa_features)
    print(f"SASA features shape: {sasa_features.shape}")
    
    # 3. Process chi angles
    chi_angles = ['chi1', 'chi2', 'chi3', 'chi4']
    for chi in chi_angles:
        chi_arrays = np.array([np.array(eval(x)) for x in df[chi]])
        chi_features = normalize_angles(chi_arrays, middle_pos)
        features_list.append(chi_features)
        print(f"{chi} features shape: {chi_features.shape}")
    
    # 4. Process SS (optional)
    ss_arrays = np.array([list(seq) for seq in df['ss']])
    ss_center = ss_arrays[:, middle_pos]
    ss_encoded = np.zeros((len(ss_arrays), 3))
    ss_map = {'H': 0, 'E': 1, 'L': 2}
    for i, ss in enumerate(ss_center):
        ss_encoded[i, ss_map[ss]] = 1
    features_list.append(ss_encoded)
    print(f"SS features shape: {ss_encoded.shape}")
    
    # 5. Process plDDT
    plddt_arrays = np.array([np.array(eval(x)) for x in df['plDDT']])
    plddt_center = np.array([arr[middle_pos] for arr in plddt_arrays]).reshape(-1, 1)
    scaler = RobustScaler()
    plddt_scaled = scaler.fit_transform(plddt_center)
    features_list.append(plddt_scaled)
    print(f"plDDT features shape: {plddt_scaled.shape}")
    
    # Combine all features
    features = np.concatenate(features_list, axis=1)
    print(f"\nFinal combined features shape: {features.shape}")
    print("Feature list lengths:", [f.shape[1] for f in features_list])
    
    return features

In [4]:
def create_structure_model(input_dim):
    """Create a standalone model for structural features
    
    Args:
        input_dim: The dimensionality of the structural features
        
    Returns:
        A compiled Keras model
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    
    return model

In [5]:
def print_metrics(y_true, y_pred):
    """
    Print comprehensive evaluation metrics
    
    Parameters:
    y_true: array-like of true labels
    y_pred: array-like of predicted labels
    """
    # Calculate confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Calculate metrics
    acc = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    sensitivity = cm[1][1] / (cm[1][1] + cm[1][0])  # True Positive Rate
    specificity = cm[0][0] / (cm[0][0] + cm[0][1])  # True Negative Rate
    
    # Print results
    print(f"Accuracy: {acc:.4f}")
    print(f"Balanced Accuracy: {balanced_acc:.4f}")
    print(f"MCC: {mcc:.4f}")
    print(f"Sensitivity: {sensitivity:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print("Confusion Matrix:")
    print(cm)

In [None]:
from xgboost import XGBClassifier


def train_and_evaluate_seq_only_xgboost():
    """Train and evaluate structure-based model using XGBoost"""
    print("Loading structural data...")
    
    # Load data
    train_df = pd.read_csv("../../../../data/train/structure/processed_features_train.csv")
    test_df = pd.read_csv("../../../../data/test/structure/processed_features_test.csv")
    
    # Extract labels
    y_train = train_df['label'].values
    y_test = test_df['label'].values
    
    # Prepare feature data
    X_train = prepare_structure_data(train_df)
    X_test = prepare_structure_data(test_df)
    
    # Shuffle training data (important since negatives come first, then positives)
    shuffle_idx = np.random.RandomState(42).permutation(len(y_train))
    X_train = X_train[shuffle_idx]
    y_train = y_train[shuffle_idx]
    
    # Calculate positive class weight for XGBoost
    pos_weight = np.sum(y_train == 0) / np.sum(y_train == 1)
    
    print(f"Training data shape: {X_train.shape}")
    print(f"Testing data shape: {X_test.shape}")
    print(f"Positive class weight: {pos_weight}")
    
    # Cross-validation setup
    kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    metrics = {'acc': [], 'balanced_acc': [], 'mcc': [], 'sn': [], 'sp': []}
    test_predictions = []
    
    for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train, y_train), 1):
        print(f"\nFold {fold}/5")
        
        # Create XGBoost model
        
        model = XGBClassifier(
            n_estimators=100,
            max_depth=5,
            learning_rate=0.1,
            min_child_weight=1,
            gamma=0,
            subsample=0.8,
            colsample_bytree=0.8,
            objective='binary:logistic',
            scale_pos_weight=pos_weight,
            random_state=42,
            use_label_encoder=False,
            eval_metric='auc',
            early_stopping_rounds=10,
        )
        
        # Train model
        eval_set = [(X_train[train_idx], y_train[train_idx]), (X_train[val_idx], y_train[val_idx])]
        model.fit(
            X_train[train_idx], 
            y_train[train_idx],
            eval_set=eval_set,
            verbose=True
        )
        
        # Evaluate on validation set
        val_pred = model.predict_proba(X_train[val_idx])[:, 1]
        val_pred_binary = (val_pred > 0.5).astype(int)
        
        # Calculate metrics
        cm = confusion_matrix(y_train[val_idx], val_pred_binary)
        metrics['acc'].append(accuracy_score(y_train[val_idx], val_pred_binary))
        metrics['balanced_acc'].append(balanced_accuracy_score(y_train[val_idx], val_pred_binary))
        metrics['mcc'].append(matthews_corrcoef(y_train[val_idx], val_pred_binary))
        metrics['sn'].append(cm[1][1]/(cm[1][1]+cm[1][0]))
        metrics['sp'].append(cm[0][0]/(cm[0][0]+cm[0][1]))
        
        # Predict on test set
        test_pred = model.predict_proba(X_test)[:, 1]
        test_predictions.append(test_pred)
        
        print(f"\nFold {fold} Results:")
        print(f"Accuracy: {metrics['acc'][-1]:.4f}")
        print(f"Balanced Accuracy: {metrics['balanced_acc'][-1]:.4f}")
        print(f"MCC: {metrics['mcc'][-1]:.4f}")
        print(f"Sensitivity: {metrics['sn'][-1]:.4f}")
        print(f"Specificity: {metrics['sp'][-1]:.4f}")
        
        # Feature importance for this fold
        feature_importance = model.feature_importances_
        print(f"\nTop 5 important features for fold {fold}:")
        top_indices = np.argsort(feature_importance)[-5:]
        for i in top_indices[::-1]:
            print(f"Feature {i}: {feature_importance[i]:.4f}")
    
    # Print average cross-validation results
    print("\nAverage Cross-validation Results:")
    for metric in metrics:
        print(f"{metric.upper()}: {np.mean(metrics[metric]):.4f} ± {np.std(metrics[metric]):.4f}")
    
    # Ensemble predictions on test set
    test_pred_avg = np.mean(test_predictions, axis=0)
    test_pred_binary = (test_pred_avg > 0.5).astype(int)
    
    # Calculate final test metrics
    cm_test = confusion_matrix(y_test, test_pred_binary)
    
    print("\nFinal Test Set Results:")
    print(f"Accuracy: {accuracy_score(y_test, test_pred_binary):.4f}")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_test, test_pred_binary):.4f}")
    print(f"MCC: {matthews_corrcoef(y_test, test_pred_binary):.4f}")
    print(f"Sensitivity: {cm_test[1][1]/(cm_test[1][1]+cm_test[1][0]):.4f}")
    print(f"Specificity: {cm_test[0][0]/(cm_test[0][0]+cm_test[0][1]):.4f}")
    print("Confusion Matrix:")
    print(cm_test)
    
    return model, test_pred_avg

In [None]:
model, test_probs = train_and_evaluate_seq_only_xgboost()

Loading structural data...
phi features shape: (8853, 2)
psi features shape: (8853, 2)
omega features shape: (8853, 2)
SASA features shape: (8853, 3)
chi1 features shape: (8853, 2)
chi2 features shape: (8853, 2)
chi3 features shape: (8853, 2)
chi4 features shape: (8853, 2)
SS features shape: (8853, 3)
plDDT features shape: (8853, 1)

Final combined features shape: (8853, 21)
Feature list lengths: [2, 2, 2, 3, 2, 2, 2, 2, 3, 1]
phi features shape: (2737, 2)
psi features shape: (2737, 2)
omega features shape: (2737, 2)
SASA features shape: (2737, 3)
chi1 features shape: (2737, 2)
chi2 features shape: (2737, 2)
chi3 features shape: (2737, 2)
chi4 features shape: (2737, 2)
SS features shape: (2737, 3)
plDDT features shape: (2737, 1)

Final combined features shape: (2737, 21)
Feature list lengths: [2, 2, 2, 3, 2, 2, 2, 2, 3, 1]
Training data shape: (8853, 21)
Testing data shape: (2737, 21)
Positive class weight: 0.927918118466899

Fold 1/5
[0]	validation_0-auc:0.67742	validation_1-auc:0.620

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.



[6]	validation_0-auc:0.72286	validation_1-auc:0.65817
[7]	validation_0-auc:0.72494	validation_1-auc:0.65772
[8]	validation_0-auc:0.72866	validation_1-auc:0.66127
[9]	validation_0-auc:0.73249	validation_1-auc:0.66178
[10]	validation_0-auc:0.73697	validation_1-auc:0.66084
[11]	validation_0-auc:0.74038	validation_1-auc:0.66144
[12]	validation_0-auc:0.74230	validation_1-auc:0.66236
[13]	validation_0-auc:0.74511	validation_1-auc:0.66358
[14]	validation_0-auc:0.74659	validation_1-auc:0.66280
[15]	validation_0-auc:0.74809	validation_1-auc:0.66318
[16]	validation_0-auc:0.75081	validation_1-auc:0.66410
[17]	validation_0-auc:0.75178	validation_1-auc:0.66507
[18]	validation_0-auc:0.75568	validation_1-auc:0.66503
[19]	validation_0-auc:0.75926	validation_1-auc:0.66422
[20]	validation_0-auc:0.76258	validation_1-auc:0.66466
[21]	validation_0-auc:0.76524	validation_1-auc:0.66490
[22]	validation_0-auc:0.76882	validation_1-auc:0.66471
[23]	validation_0-auc:0.77144	validation_1-auc:0.66599
[24]	validatio

Parameters: { "use_label_encoder" } are not used.



[0]	validation_0-auc:0.67480	validation_1-auc:0.61505
[1]	validation_0-auc:0.69452	validation_1-auc:0.63430
[2]	validation_0-auc:0.70226	validation_1-auc:0.63960
[3]	validation_0-auc:0.70963	validation_1-auc:0.64003
[4]	validation_0-auc:0.71834	validation_1-auc:0.64348
[5]	validation_0-auc:0.72221	validation_1-auc:0.64468
[6]	validation_0-auc:0.72552	validation_1-auc:0.64822
[7]	validation_0-auc:0.72937	validation_1-auc:0.64816
[8]	validation_0-auc:0.73281	validation_1-auc:0.64838
[9]	validation_0-auc:0.73662	validation_1-auc:0.64876
[10]	validation_0-auc:0.73985	validation_1-auc:0.64865
[11]	validation_0-auc:0.74040	validation_1-auc:0.64942
[12]	validation_0-auc:0.74429	validation_1-auc:0.65120
[13]	validation_0-auc:0.74543	validation_1-auc:0.65257
[14]	validation_0-auc:0.74747	validation_1-auc:0.65402
[15]	validation_0-auc:0.74894	validation_1-auc:0.65440
[16]	validation_0-auc:0.75109	validation_1-auc:0.65404
[17]	validation_0-auc:0.75492	validation_1-auc:0.65435
[18]	validation_0-au

Parameters: { "use_label_encoder" } are not used.

Parameters: { "use_label_encoder" } are not used.



[12]	validation_0-auc:0.74462	validation_1-auc:0.65189
[13]	validation_0-auc:0.74613	validation_1-auc:0.65290
[14]	validation_0-auc:0.74764	validation_1-auc:0.65310
[15]	validation_0-auc:0.74883	validation_1-auc:0.65537
[16]	validation_0-auc:0.75103	validation_1-auc:0.65480
[17]	validation_0-auc:0.75372	validation_1-auc:0.65485
[18]	validation_0-auc:0.75482	validation_1-auc:0.65657
[19]	validation_0-auc:0.75905	validation_1-auc:0.65752
[20]	validation_0-auc:0.76103	validation_1-auc:0.65790
[21]	validation_0-auc:0.76286	validation_1-auc:0.65732
[22]	validation_0-auc:0.76433	validation_1-auc:0.65757
[23]	validation_0-auc:0.76667	validation_1-auc:0.65821
[24]	validation_0-auc:0.76904	validation_1-auc:0.65786
[25]	validation_0-auc:0.77152	validation_1-auc:0.65726
[26]	validation_0-auc:0.77341	validation_1-auc:0.65659
[27]	validation_0-auc:0.77467	validation_1-auc:0.65545
[28]	validation_0-auc:0.77646	validation_1-auc:0.65644
[29]	validation_0-auc:0.77752	validation_1-auc:0.65636
[30]	valid