In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, MessagePassing, global_add_pool
from torch_geometric.data import Data, Batch
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import KFold, StratifiedKFold
import matplotlib.pyplot as plt
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class GNNModel(nn.Module):
    def __init__(self, seq_length=33, node_features=23, edge_features=18, hidden_dim=32, gnn_type='gcn'):
        super(GNNModel, self).__init__()
        
        # Sequence track (CNN)
        self.embedding = nn.Embedding(21, 21)
        self.conv = nn.Conv2d(1, 32, kernel_size=(17, 3), padding=0)
        self.bn = nn.BatchNorm2d(32)
        self.dropout = nn.Dropout(0.4)
        self.pool = nn.MaxPool2d(kernel_size=2)
        
        # Calculate flattened size after conv + pool
        # For the CNN part, we need to calculate the output dimensions
        conv_output_height = seq_length - 17 + 1  # Valid padding
        conv_output_width = 21 - 3 + 1  # Valid padding
        pool_output_height = conv_output_height // 2
        pool_output_width = conv_output_width // 2
        flatten_size = 32 * pool_output_height * pool_output_width
        
        self.seq_fc = nn.Linear(flatten_size, hidden_dim)
        
        # GNN track
        if gnn_type == 'gcn':
            self.gnn_layer = GCNConv(node_features, hidden_dim)
        elif gnn_type == 'gat':
            self.gnn_layer = GATConv(node_features, hidden_dim // 4, heads=4)
        
        # Final classifier
        self.fc1 = nn.Linear(hidden_dim * 2, 64)
        self.bn_fc = nn.BatchNorm1d(64)
        self.dropout_fc = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 1)
        
        # Store GNN type
        self.gnn_type = gnn_type
    
    def forward(self, seq_data, graph_data_batch):
        # Process sequence data
        x_seq = self.embedding(seq_data)  # [batch_size, seq_length, 21]
        x_seq = x_seq.reshape(x_seq.size(0), 1, x_seq.size(1), x_seq.size(2))  # [batch_size, 1, seq_length, 21]
        x_seq = self.conv(x_seq)
        x_seq = self.bn(x_seq)
        x_seq = F.relu(x_seq)
        x_seq = self.dropout(x_seq)
        x_seq = self.pool(x_seq)
        x_seq = x_seq.flatten(1)  # Flatten all dimensions except batch
        x_seq = self.seq_fc(x_seq)
        x_seq = F.relu(x_seq)
        
        # Process graph data
        x = graph_data_batch.x
        edge_index = graph_data_batch.edge_index
        edge_attr = graph_data_batch.edge_attr
        node_mask = graph_data_batch.node_mask
        batch = graph_data_batch.batch
        
        if self.gnn_type == 'gcn':
            # Standard GCNConv doesn't use edge features directly
            x_gnn = self.gnn_layer(x, edge_index)
            # Apply node mask
            x_gnn = x_gnn * node_mask.unsqueeze(-1)
            x_gnn = F.relu(x_gnn)
        elif self.gnn_type == 'gat':
            # Standard GATConv doesn't use edge features directly
            x_gnn = self.gnn_layer(x, edge_index)
            # Apply node mask
            x_gnn = x_gnn * node_mask.unsqueeze(-1)
            x_gnn = F.relu(x_gnn)
        
        # Global pooling over nodes
        x_gnn = global_add_pool(x_gnn, batch)
        
        # Combine tracks
        combined = torch.cat([x_seq, x_gnn], dim=1)
        
        # Final classification
        x = self.fc1(combined)
        x = self.bn_fc(x)
        x = F.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        
        return torch.sigmoid(x)

def create_edge_features(i, j, dist_map, sequence, ss_string, sasa_vals, K_POS=16):
    """
    Create edge features with proper padding handling using sequence
    Returns a fixed-size numpy array
    """
    edge_features = np.zeros(18, dtype=np.float32)  # Pre-allocate fixed size array
    distance = dist_map[i,j]
    
    # Check if either position is padded using sequence
    is_i_padded = sequence[i] == '-'
    is_j_padded = sequence[j] == '-'
    
    if is_i_padded or is_j_padded:
        return edge_features  # Return zero array for padded positions
    
    feature_idx = 0
    
    # 1. Distance Features
    # Distance bins (4 features)
    if distance <= 4.0:
        edge_features[0] = 1.0
    elif distance <= 8.0:
        edge_features[1] = 1.0
    elif distance <= 12.0:
        edge_features[2] = 1.0
    else:
        edge_features[3] = 1.0
    feature_idx += 4
    
    # Continuous distance feature (1 feature)
    edge_features[feature_idx] = 1/distance
    feature_idx += 1
    
    # 2. Sequential Features (2 features)
    seq_dist = abs(i - j)
    edge_features[feature_idx] = float(seq_dist == 1)  # Is sequential
    edge_features[feature_idx + 1] = seq_dist / 32.0  # Normalized distance
    feature_idx += 2
    
    # 3. K-relative Features (2 features)
    edge_features[feature_idx] = float(i == K_POS or j == K_POS)  # Is K-connected
    edge_features[feature_idx + 1] = min(abs(i - K_POS), abs(j - K_POS)) / 16.0  # Min distance to K
    feature_idx += 2
    
    # 4. Secondary Structure Interaction (6 features)
    ss_pairs = ['HH', 'HE', 'HL', 'EE', 'EL', 'LL']
    ss_pair = ''.join(sorted([ss_string[i], ss_string[j]]))
    for idx, pair in enumerate(ss_pairs):
        edge_features[feature_idx + idx] = float(ss_pair == pair)
    feature_idx += 6
    
    # 5. SASA Interaction (2 features)
    edge_features[feature_idx] = abs(sasa_vals[i] - sasa_vals[j])  # SASA difference
    edge_features[feature_idx + 1] = (sasa_vals[i] + sasa_vals[j]) / 2  # SASA average
    
    return edge_features

def prepare_gnn_data_pyg(df, threshold=8.0):
    """
    Prepare node and edge features using sequence for padding detection
    Returns a list of PyG Data objects
    """
    graph_data_list = []
    K_POS = 16
    
    print("Processing samples for GNN...")
    for idx, row in df.iterrows():
        sequence = row['sequence']
        # Use sequence for padding detection
        is_padded = [pos for pos, aa in enumerate(sequence) if aa == '-']
        
        # Pre-allocate node features array
        sample_nodes = np.zeros((33, 23), dtype=np.float32)
        
        for pos in range(33):
            if pos in is_padded:
                continue  # Keep zeros for padded positions
            
            feature_idx = 0
            
            # Backbone angles
            for angle in ['phi', 'psi', 'omega']:
                angle_vals = np.array(eval(row[angle]))
                angle_rad = np.pi * angle_vals[pos] / 180.0
                sample_nodes[pos, feature_idx:feature_idx+2] = [np.sin(angle_rad), np.cos(angle_rad)]
                feature_idx += 2
            
            # SASA
            sasa_vals = np.array(eval(row['sasa']))
            sample_nodes[pos, feature_idx] = sasa_vals[pos]
            feature_idx += 1
            
            # SS
            ss_val = row['ss'][pos]
            ss_onehot = [1 if ss_val == ss_type else 0 for ss_type in 'HEL']
            sample_nodes[pos, feature_idx:feature_idx+3] = ss_onehot
            feature_idx += 3
            
            # plDDT
            plddt_vals = np.array(eval(row['plDDT']))
            sample_nodes[pos, feature_idx] = plddt_vals[pos]
            feature_idx += 1
            
            # Chi angles
            for chi in ['chi1', 'chi2', 'chi3', 'chi4']:
                chi_vals = np.array(eval(row[chi]))
                chi_rad = np.pi * chi_vals[pos] / 180.0
                sample_nodes[pos, feature_idx:feature_idx+2] = [np.sin(chi_rad), np.cos(chi_rad)]
                feature_idx += 2
            
            # K-relative features
            sample_nodes[pos, feature_idx] = abs(pos - K_POS) / 16.0  # Distance to K
            feature_idx += 1
            sample_nodes[pos, feature_idx] = float(pos == K_POS)  # Is K position
            feature_idx += 1
            sample_nodes[pos, feature_idx] = float(sequence[pos] == 'K')  # Is K amino acid
        
        # Create edges
        dist_map = np.array(eval(row['distance_map'])).reshape(33, 33)
        edges = []
        edge_attrs = []
        
        # Pre-calculate SASA values
        sasa_vals = np.array(eval(row['sasa']))
        
        for i in range(33):
            if i in is_padded:
                continue
                
            for j in range(33):
                if j in is_padded or i == j:
                    continue
                    
                if dist_map[i,j] != -1 and dist_map[i,j] < threshold:
                    edges.append([i, j])
                    edge_attrs.append(
                        create_edge_features(
                            i, j, dist_map, sequence, row['ss'], 
                            sasa_vals, K_POS
                        )
                    )
        
        # Create node mask
        node_mask = np.array([0.0 if pos in is_padded else 1.0 for pos in range(33)], dtype=np.float32)
        
        # Convert to PyG Data object
        if len(edges) == 0:
            # Handle case with no edges
            edge_index = torch.zeros((2, 0), dtype=torch.long)
            edge_attr = torch.zeros((0, 18), dtype=torch.float)
        else:
            # PyG expects edge_index to be of shape [2, num_edges]
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attrs, dtype=torch.float)
        
        x = torch.tensor(sample_nodes, dtype=torch.float)
        node_mask = torch.tensor(node_mask, dtype=torch.float)
        
        # Create PyG Data object
        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            node_mask=node_mask,
            y=torch.tensor([row['label']], dtype=torch.float)
        )
        
        graph_data_list.append(data)
        
        if idx % 1000 == 0:
            print(f"Processed {idx}/{len(df)} samples")
    
    print(f"Created {len(graph_data_list)} graph data objects")
    
    return graph_data_list

def prepare_sequence_data_pyg(df):
    """Convert sequences to integer encoding for PyTorch"""
    alphabet = 'ARNDCQEGHILKMFPSTWYV-'
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    
    sequences = df['sequence'].values
    encodings = []
    
    for seq in sequences:
        try:
            integer_encoded = [char_to_int[char] for char in seq]
            encodings.append(integer_encoded)
        except Exception as e:
            print(f"Error processing sequence: {e}")
            continue
    
    return torch.tensor(encodings, dtype=torch.long)

class SequenceGraphDataset(torch.utils.data.Dataset):
    def __init__(self, sequence_data, graph_data_list):
        self.sequence_data = sequence_data
        self.graph_data_list = graph_data_list
    
    def __len__(self):
        return len(self.sequence_data)
    
    def __getitem__(self, idx):
        return self.sequence_data[idx], self.graph_data_list[idx]

def collate_fn(batch):
    seq_data, graph_data = zip(*batch)
    seq_data = torch.stack(seq_data)
    graph_data = Batch.from_data_list(graph_data)
    return seq_data, graph_data

def train_epoch(model, data_loader, optimizer, device, class_weights=None):
    model.train()
    total_loss = 0
    for seq_data, graph_data in data_loader:
        seq_data = seq_data.to(device)
        graph_data = graph_data.to(device)
        
        optimizer.zero_grad()
        output = model(seq_data, graph_data)
        target = graph_data.y.view(-1, 1).float()
        
        if class_weights is not None:
            # Apply class weights to loss calculation
            weight = torch.ones_like(target)
            weight[target == 0] = class_weights[0]
            weight[target == 1] = class_weights[1]
            loss = F.binary_cross_entropy(output, target, weight=weight)
        else:
            loss = F.binary_cross_entropy(output, target)
            
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(data_loader)

def eval_model(model, data_loader, device):
    model.eval()
    predictions = []
    targets = []
    
    with torch.no_grad():
        for seq_data, graph_data in data_loader:
            seq_data = seq_data.to(device)
            graph_data = graph_data.to(device)
            
            output = model(seq_data, graph_data)
            pred = (output > 0.5).float().cpu().numpy()
            target = graph_data.y.view(-1, 1).cpu().numpy()
            
            predictions.extend(pred)
            targets.extend(target)
    
    predictions = np.array(predictions).flatten()
    targets = np.array(targets).flatten()
    
    return predictions, targets

def print_metrics(y_true, y_pred):
    """Print comprehensive evaluation metrics"""
    cm = confusion_matrix(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    balanced_acc = balanced_accuracy_score(y_true, y_pred)
    mcc = matthews_corrcoef(y_true, y_pred)
    
    if cm.shape[0] > 1 and cm.shape[1] > 1:  # Ensure there are both positive and negative classes
        sensitivity = cm[1][1]/(cm[1][1]+cm[1][0]) if (cm[1][1]+cm[1][0]) > 0 else 0
        specificity = cm[0][0]/(cm[0][0]+cm[0][1]) if (cm[0][0]+cm[0][1]) > 0 else 0
    else:
        sensitivity = specificity = 0
    
    print(f"Accuracy: {acc:.4f}")
    print(f"Balanced Accuracy: {balanced_acc:.4f}")
    print(f"MCC: {mcc:.4f}")
    print(f"Sensitivity: {sensitivity:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print("Confusion Matrix:")
    print(cm)
    
    return {
        'accuracy': acc,
        'balanced_accuracy': balanced_acc,
        'mcc': mcc,
        'sensitivity': sensitivity,
        'specificity': specificity
    }

def train_with_cv(train_df, test_df, threshold=8.0, gnn_type='gcn', epochs=50, batch_size=32, lr=1e-3, n_folds=5):
    """
    Training function with cross-validation for PyTorch Geometric implementation
    Args:
        train_df: training dataframe
        test_df: test dataframe
        threshold: distance threshold for edge creation
        gnn_type: 'simple', 'gcn', or 'gat'
        epochs: maximum number of training epochs
        batch_size: batch size for training
        lr: learning rate
        n_folds: number of cross-validation folds
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Prepare sequence data
    print("Preparing sequence data...")
    train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)
    X_train_seq = prepare_sequence_data_pyg(train_df)
    X_test_seq = prepare_sequence_data_pyg(test_df)
    
    # Prepare graph data
    print("Preparing GNN data...")
    train_graph_data = prepare_gnn_data_pyg(train_df, threshold=threshold)
    test_graph_data = prepare_gnn_data_pyg(test_df, threshold=threshold)
    
    y_train = train_df['label'].values
    y_test = test_df['label'].values
    
    # Print class distribution
    print("\nClass distribution:")
    print("Train:", np.bincount(y_train))
    print("Test:", np.bincount(y_test))
    
    # Cross-validation
    kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    metrics = {
        'accuracy': [], 'balanced_accuracy': [], 
        'mcc': [], 'sensitivity': [], 'specificity': []
    }
    test_predictions = []
    
    # Create full test dataset
    test_dataset = SequenceGraphDataset(X_test_seq, test_graph_data)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        collate_fn=collate_fn
    )
    
    for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train_seq, y_train), 1):
        print(f"\nFold {fold}/{n_folds}")
        
        # Calculate class weights
        total = len(train_idx)
        pos = np.sum(y_train[train_idx] == 1)
        neg = np.sum(y_train[train_idx] == 0)
        class_weights = {
            0: total / (2 * neg),
            1: total / (2 * pos)
        }
        class_weights_tensor = (class_weights[0], class_weights[1])
        
        # Create datasets for this fold
        train_fold_seq = X_train_seq[train_idx]
        val_fold_seq = X_train_seq[val_idx]
        
        train_fold_graph = [train_graph_data[i] for i in train_idx]
        val_fold_graph = [train_graph_data[i] for i in val_idx]
        
        train_fold_dataset = SequenceGraphDataset(train_fold_seq, train_fold_graph)
        val_fold_dataset = SequenceGraphDataset(val_fold_seq, val_fold_graph)
        
        train_loader = torch.utils.data.DataLoader(
            train_fold_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            collate_fn=collate_fn
        )
        
        val_loader = torch.utils.data.DataLoader(
            val_fold_dataset, 
            batch_size=batch_size, 
            shuffle=False,
            collate_fn=collate_fn
        )
        
        # Create model
        model = GNNModel(
            seq_length=33,
            node_features=23,
            edge_features=18,
            hidden_dim=32,
            gnn_type=gnn_type
        ).to(device)
        
        # Optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5,
            patience=4,
            min_lr=1e-6
        )
        
        # Early stopping
        best_val_loss = float('inf')
        early_stop_counter = 0
        early_stop_patience = 7
        best_model_state = None
        
        # Training loop
        for epoch in range(epochs):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, device, class_weights_tensor)
            
            # Validate
            val_pred, val_true = eval_model(model, val_loader, device)
            val_loss = F.binary_cross_entropy(
                torch.tensor(val_pred, dtype=torch.float), 
                torch.tensor(val_true, dtype=torch.float)
            ).item()
            
            val_acc = accuracy_score(val_true, (val_pred > 0.5).astype(int))

            
            # Update scheduler
            scheduler.step(val_loss)
            
            # Check for early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stop_counter = 0
                # Save best model
                best_model_state = model.state_dict().copy()
            else:
                early_stop_counter += 1
                
            if early_stop_counter >= early_stop_patience:
                print(f"Early stopping after {epoch+1} epochs")
                break
                
            # Print progress
            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        # Load best model
        model.load_state_dict(best_model_state)
        
        # Evaluate on validation set
        val_pred, val_true = eval_model(model, val_loader, device)
        
        # Calculate metrics
        print(f"\nFold {fold} Results:")
        fold_metrics = print_metrics(val_true, val_pred)
        
        # Store metrics
        for key in metrics:
            metrics[key].append(fold_metrics[key])
        
        # Predict on test set
        test_pred, test_true = eval_model(model, test_loader, device)
        test_predictions.append(test_pred)
    
    # Print average CV results
    print(f"\nAverage Cross-validation Results for {gnn_type.upper()}:")
    cv_results = {}
    for metric in metrics:
        mean = np.mean(metrics[metric])
        std = np.std(metrics[metric])
        print(f"{metric}: {mean:.4f} ± {std:.4f}")
        cv_results[metric] = {'mean': mean, 'std': std}
    
    # Ensemble predictions on test set
    test_pred_avg = np.mean(test_predictions, axis=0)
    test_pred_binary = (test_pred_avg > 0.5).astype(int)
    
    # Print final test results
    print(f"\nFinal Test Set Results for {gnn_type.upper()}:")
    test_results = print_metrics(y_test, test_pred_binary)
    
    return {
        'cv_results': cv_results,
        'test_results': test_results,
        'model': model
    }

if __name__ == "__main__":
    # Example usage
    train_df = pd.read_csv("../data/processed_features_fixed_train_contactmap.csv")
    test_df = pd.read_csv("../data/processed_features_fixed_test_contactmap.csv")
    
    # Minimal data
    # train_df = pd.read_csv("../data/processed_features_fixed_train_contactmap_copy.csv")
    # test_df = pd.read_csv("../data/processed_features_fixed_test_contactmap_copy.csv")
    
    # Train with GCN (basic implementation)
    results = train_with_cv(train_df, test_df, threshold=8.0, gnn_type='gcn')
    
    # # To try all GNN types:
    # results = train_with_cv(train_df, test_df, threshold=8.0, gnn_type=None)

Using device: cuda
Preparing sequence data...
Preparing GNN data...
Processing samples for GNN...
Processed 0/8853 samples
Processed 1000/8853 samples
Processed 2000/8853 samples
Processed 3000/8853 samples
Processed 4000/8853 samples
Processed 5000/8853 samples
Processed 6000/8853 samples
Processed 7000/8853 samples
Processed 8000/8853 samples
Created 8853 graph data objects
Processing samples for GNN...
Processed 0/2737 samples
Processed 1000/2737 samples
Processed 2000/2737 samples
Created 2737 graph data objects

Class distribution:
Train: [4261 4592]
Test: [2497  240]

Fold 1/5
Epoch 5/50, Train Loss: 0.5680, Val Loss: 32.5805, Val Acc: 0.6742
Epoch 10/50, Train Loss: 0.5109, Val Loss: 32.4111, Val Acc: 0.6759
Epoch 15/50, Train Loss: 0.4582, Val Loss: 30.6606, Val Acc: 0.6934
Epoch 20/50, Train Loss: 0.4441, Val Loss: 29.7007, Val Acc: 0.7030
Epoch 25/50, Train Loss: 0.4134, Val Loss: 29.6443, Val Acc: 0.7036
Epoch 30/50, Train Loss: 0.3968, Val Loss: 27.2163, Val Acc: 0.7278
Epo

In [8]:
results = train_with_cv(train_df, test_df, threshold=8.0, gnn_type="gat")

Using device: cuda
Preparing sequence data...
Preparing GNN data...
Processing samples for GNN...
Processed 0/8853 samples
Processed 1000/8853 samples
Processed 2000/8853 samples
Processed 3000/8853 samples
Processed 4000/8853 samples
Processed 5000/8853 samples
Processed 6000/8853 samples
Processed 7000/8853 samples
Processed 8000/8853 samples
Created 8853 graph data objects
Processing samples for GNN...
Processed 0/2737 samples
Processed 1000/2737 samples
Processed 2000/2737 samples
Created 2737 graph data objects

Class distribution:
Train: [4261 4592]
Test: [2497  240]

Fold 1/5
Epoch 5/50, Train Loss: 0.5648, Val Loss: 33.0322, Val Acc: 0.6697
Epoch 10/50, Train Loss: 0.5021, Val Loss: 33.4274, Val Acc: 0.6657
Epoch 15/50, Train Loss: 0.4657, Val Loss: 27.5551, Val Acc: 0.7244
Epoch 20/50, Train Loss: 0.4237, Val Loss: 26.1434, Val Acc: 0.7386
Epoch 25/50, Train Loss: 0.3993, Val Loss: 26.2564, Val Acc: 0.7374
Early stopping after 30 epochs

Fold 1 Results:
Accuracy: 0.7357
Balanc