THINGS TO IMPORT

In [None]:
import torch
import torch.nn as nx
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_add_pool
from torch_geometric.data import Data
import graphein.protein as gp
from graphein.protein.config import ProteinGraphConfig
from Bio import PDB
import os
from torch_geometric.data import Data, Dataset
import pandas as pd
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions,
                                             add_delaunay_triangulation)

# Edge construction functions
new_edge_funcs = {"edge_construction_functions": [
    add_peptide_bonds,
    add_aromatic_interactions,
    add_hydrogen_bond_interactions,
    add_disulfide_interactions,
    add_ionic_interactions,
    add_aromatic_sulphur_interactions,
    add_cation_pi_interactions,
    add_delaunay_triangulation
]}

TESTING THE BUILDING OF THE GNN WITH GRAPH ATTENTION LAYERS

GET THE FEATURES FROM THE CSV FILE AND OBTAIN THE GRAPH OBJECT (literally copying from the proofmodel bindingfeatures.py)

In [8]:
from torch_geometric.data import Dataset

class ProteinGraphDataset(Dataset):
    """Custom PyTorch Geometric Dataset to handle protein graphs."""
    def __init__(self, protein_graphs):
        super().__init__()
        self.protein_graphs = protein_graphs  # List of PyG Data objects

    def len(self):
        return len(self.protein_graphs)

    def get(self, idx):
        return self.protein_graphs[idx]

def process_protein_graphs(folder_path):
    """
    Processes all PDB files in a given folder, loads corresponding CSVs,
    and returns a ProteinGraphDataset for PyTorch Geometric.
    """
    protein_graphs = []
    
    # Get all PDB files in folder
    pdb_files = [f for f in os.listdir(folder_path) if f.endswith(".pdb")]

    for pdb_file in pdb_files:
        pdb_path = os.path.join(folder_path, pdb_file)
        csv_path = os.path.join(folder_path, pdb_file.replace(".pdb", ".csv"))

        if not os.path.exists(csv_path):
            print(f"⚠️ Warning: No CSV found for {pdb_file}. Skipping!")
            continue  # Skip this PDB if no matching CSV

        print(f"📌 Processing {pdb_file}...")

        # Load features & labels
        features_dict, labels_dict, num_features = load_residue_features(csv_path)

        # Construct NetworkX graph
        config = ProteinGraphConfig(**new_edge_funcs)
        G_nx = construct_graph(config=config, path=pdb_path)

        # Convert to PyG Data object
        protein_graph_data, reverse_map = networkx_to_pyg(G_nx, features_dict, labels_dict, num_features)

        # Store the graph
        protein_graphs.append(protein_graph_data)

        print(f"✅ Processed {pdb_file} ({len(G_nx.nodes())} nodes, {len(G_nx.edges())} edges)")

    # ✅ Return as a PyTorch Geometric Dataset
    return ProteinGraphDataset(protein_graphs)

# Folder containing PDB & CSV files
folder_path = "testingfiles"  # CHANGE THIS!

# ✅ Process all proteins into a PyG Dataset
protein_graph_dataset = process_protein_graphs(folder_path)

# ✅ Print dataset summary
print(f"\n🔹 Total Processed Proteins: {len(protein_graph_dataset)}")
for i, graph in enumerate(protein_graph_dataset[:3]):  # Show first 3 graphs
    print(f"Graph {i+1}: {graph}")




Output()

📌 Processing 3emh.pdb...


Output()

✅ Processed 3emh.pdb (306 nodes, 2173 edges)
📌 Processing 1h28.pdb...


Output()

✅ Processed 1h28.pdb (1126 nodes, 8540 edges)
📌 Processing 2wik.pdb...


Output()

✅ Processed 2wik.pdb (527 nodes, 3881 edges)
📌 Processing 1fig.pdb...


Output()

✅ Processed 1fig.pdb (431 nodes, 3145 edges)
📌 Processing 2mpa.pdb...


✅ Processed 2mpa.pdb (450 nodes, 3295 edges)

🔹 Total Processed Proteins: 5
Graph 1: Data(x=[306, 71], edge_index=[2, 2173], y=[306])
Graph 2: Data(x=[1126, 71], edge_index=[2, 8540], y=[1126])
Graph 3: Data(x=[527, 71], edge_index=[2, 3881], y=[527])


DATA LOADER

In [9]:
from torch_geometric.data import DataLoader

# Split dataset into train/test sets
train_size = int(0.8 * len(protein_graph_dataset))
test_size = len(protein_graph_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(protein_graph_dataset, [train_size, test_size])

# ✅ Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)




NEURAL NETWORK

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool

# ✅ Extend Dataset class to handle protein graphs
class ProteinGraphDataset(Dataset):
    def __init__(self, protein_graphs):
        super().__init__()
        self.protein_graphs = protein_graphs  # List of Data objects

    def len(self):
        return len(self.protein_graphs)

    def get(self, idx):
        return self.protein_graphs[idx]

# ✅ Create Virtual Node GAT Model
class GATVirtualNode(nn.Module):
    def __init__(self, in_features, hidden_dim, num_heads, num_classes, dropout=0.2):
        super(GATVirtualNode, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Virtual Node (Global Graph Representation)
        self.virtual_node_embedding = nn.Parameter(torch.zeros(1, hidden_dim))
        
        # Graph Attention Layers
        self.gat1 = GATConv(in_features, hidden_dim // num_heads, heads=num_heads)
        self.gat2 = GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads)

        # MLP Classifier for **node classification**
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)  # Output for **each node**
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Apply first GAT layer
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Apply second GAT layer
        x = self.gat2(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Virtual Node Aggregation
        virtual_node = self.virtual_node_embedding.expand(x.size(0), -1)
        x = x + virtual_node

        # **Return node-level predictions**
        return self.mlp(x)  # Output has same shape as `data.y`


# ✅ Train the model
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        out = model(data)  # **Now out has shape (num_nodes, num_classes)**
        loss = criterion(out, data.y)  # **Matches (num_nodes, num_classes) vs (num_nodes,)**

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)


# ✅ Evaluate the model
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data)  # Shape: (num_nodes, num_classes)
            pred = out.argmax(dim=1)  # Get class with highest probability

            correct += (pred == data.y).sum().item()  # Compare node-wise
            total += data.y.size(0)

    return correct / total  # Node-level accuracy


# ✅ Load Protein Dataset
protein_graph_dataset = ProteinGraphDataset(protein_graph_dataset)  # Wrap list into PyG Dataset

# ✅ Split dataset into train/test
train_size = int(0.8 * len(protein_graph_dataset))
test_size = len(protein_graph_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(protein_graph_dataset, [train_size, test_size])

# ✅ Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# ✅ Define Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GATVirtualNode(in_features=protein_graph_dataset[0].x.shape[1], hidden_dim=128, num_heads=4, num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# ✅ Training Loop
num_epochs = 30
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}")

print("✅ Training Completed!")


Epoch 1/30 - Loss: 0.4721, Test Accuracy: 0.8933
Epoch 2/30 - Loss: 0.3478, Test Accuracy: 0.8933
Epoch 3/30 - Loss: 0.2969, Test Accuracy: 0.8933
Epoch 4/30 - Loss: 0.2919, Test Accuracy: 0.8933
Epoch 5/30 - Loss: 0.3056, Test Accuracy: 0.8933
Epoch 6/30 - Loss: 0.3203, Test Accuracy: 0.8933
Epoch 7/30 - Loss: 0.3134, Test Accuracy: 0.8933
Epoch 8/30 - Loss: 0.3202, Test Accuracy: 0.8933
Epoch 9/30 - Loss: 0.2990, Test Accuracy: 0.8933
Epoch 10/30 - Loss: 0.2969, Test Accuracy: 0.8933
Epoch 11/30 - Loss: 0.2786, Test Accuracy: 0.8933
Epoch 12/30 - Loss: 0.2790, Test Accuracy: 0.8933
Epoch 13/30 - Loss: 0.2796, Test Accuracy: 0.8933
Epoch 14/30 - Loss: 0.2720, Test Accuracy: 0.8933
Epoch 15/30 - Loss: 0.2787, Test Accuracy: 0.8933
Epoch 16/30 - Loss: 0.2805, Test Accuracy: 0.8933
Epoch 17/30 - Loss: 0.2835, Test Accuracy: 0.8933
Epoch 18/30 - Loss: 0.2811, Test Accuracy: 0.8933
Epoch 19/30 - Loss: 0.2816, Test Accuracy: 0.8933
Epoch 20/30 - Loss: 0.2773, Test Accuracy: 0.8933
Epoch 21/

FULL CODE

In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import networkx as nx
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions,
                                             add_delaunay_triangulation)

# ✅ Define Edge Functions
new_edge_funcs = {"edge_construction_functions": [add_peptide_bonds,
                                                  add_aromatic_interactions,
                                                  add_hydrogen_bond_interactions,
                                                  add_disulfide_interactions,
                                                  add_ionic_interactions,
                                                  add_aromatic_sulphur_interactions,
                                                  add_cation_pi_interactions,
                                                  add_delaunay_triangulation]}

# ✅ Load Residue Features from CSV
def load_residue_features(csv_path):
    """Loads residue features and binding site labels from CSV."""
    df = pd.read_csv(csv_path)

    # Convert first column (Residue ID) into tuples (residue_number, chain)
    df.iloc[:, 0] = df.iloc[:, 0].apply(eval)  # Convert string "(1, 'L')" → tuple (1, 'L')

    residue_ids = df.iloc[:, 0]  # Residue ID as tuple
    features = df.iloc[:, 1:-1].values  # Feature columns (excluding label)
    labels = df.iloc[:, -1].values  # Last column = binding site labels

    # Create lookup dictionaries
    features_dict = {res_id: feat for res_id, feat in zip(residue_ids, features)}
    labels_dict = {res_id: label for res_id, label in zip(residue_ids, labels)}

    return features_dict, labels_dict, features.shape[1]  # Return num_features

# ✅ Convert NetworkX Graph → PyTorch Geometric Graph
def networkx_to_pyg(G_nx, features_dict, labels_dict, num_features):
    """Converts a NetworkX protein graph to a PyTorch Geometric Data object with features & labels."""
    node_map = {}  # Maps node (residue, chain) to index
    reverse_map = {}  # Reverse lookup: PyG index → (residue_number, chain)
    node_features = []
    y = []

    for i, (node, attr) in enumerate(G_nx.nodes(data=True)):
        res_id = (attr.get("residue_number"), attr.get("chain_id"))  # Standardized format

        node_map[node] = i  # Assign PyG-compatible node index
        reverse_map[i] = res_id  # Store mapping back to residue identifier

        # Retrieve features (default to zero vector if missing)
        features = features_dict.get(res_id, [0] * num_features)
        node_features.append(features)

        # Retrieve binding site labels (default = non-binding)
        y.append(labels_dict.get(res_id, 0))

    x = torch.tensor(node_features, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.long)

    # Extract edges
    edges = []
    for u, v in G_nx.edges():
        try:
            edges.append((node_map[u], node_map[v]))  # Use fixed node IDs
        except KeyError:
            print(f"⚠️ Skipping edge ({u}, {v}) due to missing node mapping!")

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index, y=y), reverse_map

# ✅ Process All PDB + CSV Files in a Folder
def process_protein_graphs(folder_path):
    """Processes all PDB files in a given folder & returns a ProteinGraphDataset."""
    protein_graphs = []
    
    # Get all PDB files in folder
    pdb_files = [f for f in os.listdir(folder_path) if f.endswith(".pdb")]

    for pdb_file in pdb_files:
        pdb_path = os.path.join(folder_path, pdb_file)
        csv_path = os.path.join(folder_path, pdb_file.replace(".pdb", ".csv"))

        if not os.path.exists(csv_path):
            print(f"⚠️ Warning: No CSV found for {pdb_file}. Skipping!")
            continue  # Skip this PDB if no matching CSV

        print(f"📌 Processing {pdb_file}...")

        # Load features & labels
        features_dict, labels_dict, num_features = load_residue_features(csv_path)

        # Construct NetworkX graph
        config = ProteinGraphConfig(**new_edge_funcs)
        G_nx = construct_graph(config=config, path=pdb_path)

        # Convert to PyG Data object
        protein_graph_data, reverse_map = networkx_to_pyg(G_nx, features_dict, labels_dict, num_features)

        # Store the graph
        protein_graphs.append(protein_graph_data)

        print(f"✅ Processed {pdb_file} ({len(G_nx.nodes())} nodes, {len(G_nx.edges())} edges)")

    return ProteinGraphDataset(protein_graphs)

# ✅ Define PyTorch Geometric Dataset
class ProteinGraphDataset(Dataset):
    """Custom PyTorch Geometric Dataset to handle protein graphs."""
    def __init__(self, protein_graphs):
        super().__init__()
        self.protein_graphs = protein_graphs  # List of Data objects

    def len(self):
        return len(self.protein_graphs)

    def get(self, idx):
        return self.protein_graphs[idx]

# ✅ Define Virtual Node GAT Model
class GATVirtualNode(nn.Module):
    def __init__(self, in_features, hidden_dim, num_heads, num_classes, dropout=0.2):
        super(GATVirtualNode, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Virtual Node (Global Graph Representation)
        self.virtual_node_embedding = nn.Parameter(torch.zeros(1, hidden_dim))
        
        # Graph Attention Layers
        self.gat1 = GATConv(in_features, hidden_dim // num_heads, heads=num_heads)
        self.gat2 = GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads)

        # MLP Classifier for **node classification**
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)  # Output for **each node**
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Apply first GAT layer
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Apply second GAT layer
        x = self.gat2(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Virtual Node Aggregation
        virtual_node = self.virtual_node_embedding.expand(x.size(0), -1)
        x = x + virtual_node

        # **Return node-level predictions**
        return self.mlp(x)  # Output has same shape as `data.y`



# ✅ Load & Prepare Dataset
folder_path = "testingfiles"  # CHANGE THIS!
protein_graph_dataset = process_protein_graphs(folder_path)

# ✅ Split dataset into train/test
train_size = int(0.8 * len(protein_graph_dataset))
test_size = len(protein_graph_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(protein_graph_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# ✅ Define Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GATVirtualNode(in_features=protein_graph_dataset[0].x.shape[1], hidden_dim=128, num_heads=4, num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# ✅ Training Loop
num_epochs = 30
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_acc = evaluate(model, test_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}")

print("✅ Training Completed!")


Output()

📌 Processing 3emh.pdb...


Output()

✅ Processed 3emh.pdb (306 nodes, 2173 edges)
📌 Processing 1h28.pdb...


Output()

✅ Processed 1h28.pdb (1126 nodes, 8540 edges)
📌 Processing 2wik.pdb...


Output()

✅ Processed 2wik.pdb (527 nodes, 3881 edges)
📌 Processing 1fig.pdb...


Output()

✅ Processed 1fig.pdb (431 nodes, 3145 edges)
📌 Processing 2mpa.pdb...


✅ Processed 2mpa.pdb (450 nodes, 3295 edges)
Epoch 1/30 - Loss: 1.0321, Test Accuracy: 0.8933




Epoch 2/30 - Loss: 0.5381, Test Accuracy: 0.8933
Epoch 3/30 - Loss: 0.3500, Test Accuracy: 0.8933
Epoch 4/30 - Loss: 0.3196, Test Accuracy: 0.8933
Epoch 5/30 - Loss: 0.3284, Test Accuracy: 0.8933
Epoch 6/30 - Loss: 0.3524, Test Accuracy: 0.8933
Epoch 7/30 - Loss: 0.3709, Test Accuracy: 0.8933
Epoch 8/30 - Loss: 0.3838, Test Accuracy: 0.8933
Epoch 9/30 - Loss: 0.3805, Test Accuracy: 0.8933
Epoch 10/30 - Loss: 0.3700, Test Accuracy: 0.8933
Epoch 11/30 - Loss: 0.3540, Test Accuracy: 0.8933
Epoch 12/30 - Loss: 0.3375, Test Accuracy: 0.8933
Epoch 13/30 - Loss: 0.3255, Test Accuracy: 0.8933
Epoch 14/30 - Loss: 0.3152, Test Accuracy: 0.8933
Epoch 15/30 - Loss: 0.2957, Test Accuracy: 0.8933
Epoch 16/30 - Loss: 0.3010, Test Accuracy: 0.8933
Epoch 17/30 - Loss: 0.2905, Test Accuracy: 0.8933
Epoch 18/30 - Loss: 0.2894, Test Accuracy: 0.8933
Epoch 19/30 - Loss: 0.2873, Test Accuracy: 0.8933
Epoch 20/30 - Loss: 0.2919, Test Accuracy: 0.8933
Epoch 21/30 - Loss: 0.2953, Test Accuracy: 0.8933
Epoch 22

TESTING