<a href="https://colab.research.google.com/github/nuriamontala/PYT-SBI/blob/main/model_%26_prediction_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
import torch
import torch.nn as nx
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_add_pool
from torch_geometric.data import Data
import graphein.protein as gp
from graphein.protein.config import ProteinGraphConfig
from Bio import PDB
import os
from torch_geometric.data import Data, Dataset
import pandas as pd
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions,
                                             add_delaunay_triangulation)

# Edge construction functions
new_edge_funcs = {"edge_construction_functions": [
    add_peptide_bonds,
    add_aromatic_interactions,
    add_hydrogen_bond_interactions,
    add_disulfide_interactions,
    add_ionic_interactions,
    add_aromatic_sulphur_interactions,
    add_cation_pi_interactions,
    add_delaunay_triangulation
]}


In [61]:
# ✅ Load Residue Features from CSV
def load_residue_features(csv_path):
    """Loads residue features and binding site labels from CSV."""
    df = pd.read_csv(csv_path)

    # Convert first column (Residue ID) into tuples (residue_number, chain)
    df.iloc[:, 0] = df.iloc[:, 0].apply(eval)  # Convert string "(1, 'L')" → tuple (1, 'L')

    residue_ids = df.iloc[:, 0]  # Residue ID as tuple
    features = df.iloc[:, 1:-1].values  # Feature columns (excluding label)
    labels = df.iloc[:, -1].values  # Last column = binding site labels

    # Create lookup dictionaries
    features_dict = {res_id: feat for res_id, feat in zip(residue_ids, features)}
    labels_dict = {res_id: label for res_id, label in zip(residue_ids, labels)}

    return features_dict, labels_dict, features.shape[1]  # Return num_features

# ✅ Convert NetworkX Graph → PyTorch Geometric Graph
def networkx_to_pyg(G_nx, features_dict, labels_dict, num_features):
    """Converts a NetworkX protein graph to a PyTorch Geometric Data object, removing nodes without features."""
    node_map = {}  # Maps node (residue, chain) to index
    reverse_map = {}  # Reverse lookup: PyG index → (residue_number, chain)
    node_features = []
    y = []
    valid_nodes = set()  # Stores nodes that have valid features

    for i, (node, attr) in enumerate(G_nx.nodes(data=True)):
        res_id = (attr.get("residue_number"), attr.get("chain_id"))  # Standardized format

        # Only include nodes that have valid features
        if res_id not in features_dict:
            continue  # Skip nodes with missing features

        valid_nodes.add(node)
        node_map[node] = len(node_features)  # Assign PyG-compatible node index
        reverse_map[len(node_features)] = res_id  # Store mapping back to residue identifier

        # Retrieve features
        node_features.append(features_dict[res_id])

        # Retrieve binding site labels
        y.append(labels_dict.get(res_id, 0))

    x = torch.tensor(node_features, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.long)

    # Extract edges but only keep edges where **both** nodes are in valid_nodes
    edges = []
    for u, v in G_nx.edges():
        if u in valid_nodes and v in valid_nodes:  # Ensure both nodes exist in the filtered set
            edges.append((node_map[u], node_map[v]))

    if not edges:  # If no valid edges, return None (to avoid empty graphs)
        print(f"⚠️ Graph contains no valid edges after filtering. Skipping!")
        return None, None

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index, y=y), reverse_map

# ✅ Process All PDB + CSV Files in a Folder
def process_protein_graphs(folder_path):
    """Processes all PDB files in a given folder & returns a ProteinGraphDataset."""
    protein_graphs = []

    # Get all PDB files in folder
    pdb_files = [f for f in os.listdir(folder_path) if f.endswith(".pdb")]

    for pdb_file in pdb_files:
        pdb_path = os.path.join(folder_path, pdb_file)
        csv_path = os.path.join(folder_path, pdb_file.replace(".pdb", ".csv"))

        if not os.path.exists(csv_path):
            print(f"⚠️ Warning: No CSV found for {pdb_file}. Skipping!")
            continue  # Skip this PDB if no matching CSV

        print(f"📌 Processing {pdb_file}...")

        # Load features & labels
        features_dict, labels_dict, num_features = load_residue_features(csv_path)

        # Construct NetworkX graph
        config = ProteinGraphConfig(**new_edge_funcs)
        G_nx = construct_graph(config=config, path=pdb_path)

        # Convert to PyG Data object
        protein_graph_data, reverse_map = networkx_to_pyg(G_nx, features_dict, labels_dict, num_features)

        # Store the graph
        protein_graphs.append(protein_graph_data)

        print(f"✅ Processed {pdb_file} ({len(G_nx.nodes())} nodes, {len(G_nx.edges())} edges)")

    return ProteinGraphDataset(protein_graphs)

# ✅ Define PyTorch Geometric Dataset
class ProteinGraphDataset(Dataset):
    """Custom PyTorch Geometric Dataset to handle protein graphs."""
    def __init__(self, protein_graphs):
        super().__init__()
        self.protein_graphs = protein_graphs  # List of Data objects

    def len(self):
        return len(self.protein_graphs)

    def get(self, idx):
        return self.protein_graphs[idx]

# ✅ Define Virtual Node GAT Model
class GATVirtualNode(nn.Module):
    def __init__(self, in_features, hidden_dim, num_heads, num_classes, dropout=0.2):
        super(GATVirtualNode, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Virtual Node (Global Graph Representation)
        self.virtual_node_embedding = nn.Parameter(torch.zeros(1, hidden_dim))

        # Graph Attention Layers
        self.gat1 = GATConv(in_features, hidden_dim // num_heads, heads=num_heads)
        self.gat2 = GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads)

        # MLP Classifier for **node classification**
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)  # Output for **each node**
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Apply first GAT layer
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Apply second GAT layer
        x = self.gat2(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Virtual Node Aggregation
        virtual_node = self.virtual_node_embedding.expand(x.size(0), -1)
        x = x + virtual_node

        # **Return node-level predictions**
        return self.mlp(x)  # Output has same shape as `data.y`

# ✅ Load & Prepare Dataset
folder_path = "testingfiles"  # CHANGE THIS!
protein_graph_dataset = process_protein_graphs(folder_path)

# ✅ Split dataset into train/test
train_size = int(0.8 * len(protein_graph_dataset))
test_size = len(protein_graph_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(protein_graph_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# ✅ Define Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GATVirtualNode(in_features=protein_graph_dataset[0].x.shape[1], hidden_dim=128, num_heads=4, num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

model_save_path = "trained_gat_model.pth"  # Change if needed

# ✅ Training Loop (with saving)
num_epochs = 30
best_acc = 0.0  # Track best accuracy to save the best model

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ✅ Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), model_save_path)
        print(f"✅ Model saved at epoch {epoch+1} with accuracy {test_acc:.4f}!")

print("✅ Training Completed!")



Output()

📌 Processing 2asu.pdb...


Output()

✅ Processed 2asu.pdb (229 nodes, 1601 edges)
📌 Processing 6h5w.pdb...


Output()

✅ Processed 6h5w.pdb (585 nodes, 4342 edges)
📌 Processing 1k03.pdb...


Output()

Output()

✅ Processed 5f1u.pdb (1184 nodes, 8827 edges)
📌 Processing 1awh.pdb...


✅ Processed 1awh.pdb (590 nodes, 4414 edges)
Epoch 1/30 - Loss: 0.6751, Test Accuracy: 0.9145
✅ Model saved at epoch 1 with accuracy 0.9145!
Epoch 2/30 - Loss: 0.6503, Test Accuracy: 0.9145




Epoch 3/30 - Loss: 0.6271, Test Accuracy: 0.9145
Epoch 4/30 - Loss: 0.6053, Test Accuracy: 0.9145
Epoch 5/30 - Loss: 0.5828, Test Accuracy: 0.9145
Epoch 6/30 - Loss: 0.5579, Test Accuracy: 0.9145
Epoch 7/30 - Loss: 0.5321, Test Accuracy: 0.9145
Epoch 8/30 - Loss: 0.5043, Test Accuracy: 0.9145
Epoch 9/30 - Loss: 0.4751, Test Accuracy: 0.9145
Epoch 10/30 - Loss: 0.4457, Test Accuracy: 0.9145
Epoch 11/30 - Loss: 0.4190, Test Accuracy: 0.9145
Epoch 12/30 - Loss: 0.3950, Test Accuracy: 0.9145
Epoch 13/30 - Loss: 0.3838, Test Accuracy: 0.9145
Epoch 14/30 - Loss: 0.3749, Test Accuracy: 0.9145
Epoch 15/30 - Loss: 0.3837, Test Accuracy: 0.9145
Epoch 16/30 - Loss: 0.3977, Test Accuracy: 0.9145
Epoch 17/30 - Loss: 0.4010, Test Accuracy: 0.9145
Epoch 18/30 - Loss: 0.4101, Test Accuracy: 0.9145
Epoch 19/30 - Loss: 0.4073, Test Accuracy: 0.9145
Epoch 20/30 - Loss: 0.4048, Test Accuracy: 0.9145
Epoch 21/30 - Loss: 0.4001, Test Accuracy: 0.9145
Epoch 22/30 - Loss: 0.3888, Test Accuracy: 0.9145
Epoch 2

In [62]:
import torch_geometric
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

# Select a graph from the dataset
graph = protein_graph_dataset[0]  # Change index as needed

# Print graph details
print("🔹 Graph Summary")
print(graph)  # Prints PyG Data object details

print("\n🔹 Node Feature Matrix (x):")
print(graph.x)  # Shape: [num_nodes, num_features]

print("\n🔹 Node Labels (y):")
print(graph.y)  # Shape: [num_nodes] (labels for each node)

print("\n🔹 Edge Index:")
print(graph.edge_index)  # Shape: [2, num_edges] (Adjacency list)


🔹 Graph Summary
Data(x=[224, 68], edge_index=[2, 1531], y=[224])

🔹 Node Feature Matrix (x):
tensor([[0.8117, 0.9025, 0.7978,  ..., 0.1947, 0.1721, 0.2992],
        [0.8117, 0.9025, 0.7978,  ..., 0.1947, 0.1721, 0.2992],
        [0.0000, 0.7525, 0.0000,  ..., 0.1947, 0.1721, 0.2992],
        ...,
        [0.9416, 0.0700, 0.9438,  ..., 0.1947, 0.1721, 0.2992],
        [0.8117, 0.9025, 0.7978,  ..., 0.1947, 0.1721, 0.2992],
        [0.8669, 0.8875, 0.8090,  ..., 0.1947, 0.1721, 0.2992]])

🔹 Node Labels (y):
tensor([0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0