<a href="https://colab.research.google.com/github/nuriamontala/PYT-SBI/blob/main/model_%26_prediction_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_add_pool
from torch_geometric.data import Data
import graphein.protein as gp
from graphein.protein.config import ProteinGraphConfig
from Bio import PDB
from torch_geometric.loader import DataLoader
import os
from torch.utils.data import Dataset, random_split, Subset
import torch.optim as optim
from torch_geometric.data import Data, Dataset
import pandas as pd
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions)

# Edge construction functions
new_edge_funcs = {"edge_construction_functions": [
    add_peptide_bonds,
    add_aromatic_interactions,
    add_hydrogen_bond_interactions,
    add_disulfide_interactions,
    add_ionic_interactions,
    add_aromatic_sulphur_interactions,
    add_cation_pi_interactions
]}

In [15]:
# ✅ Load Residue Features from CSV
def load_residue_features(csv_path):
    """Loads residue features and binding site labels from CSV."""
    df = pd.read_csv(csv_path)

    # Convert first column (Residue ID) into tuples (residue_number, chain)
    df.iloc[:, 0] = df.iloc[:, 0].apply(eval)  # Convert string "(1, 'L')" → tuple (1, 'L')

    residue_ids = df.iloc[:, 0]  # Residue ID as tuple
    features = df.iloc[:, 1:-1].values  # Feature columns (excluding label)
    labels = df.iloc[:, -1].values  # Last column = binding site labels

    # Create lookup dictionaries
    features_dict = {res_id: feat for res_id, feat in zip(residue_ids, features)}
    labels_dict = {res_id: label for res_id, label in zip(residue_ids, labels)}

    return features_dict, labels_dict, features.shape[1]  # Return num_features

# ✅ Convert NetworkX Graph → PyTorch Geometric Graph
def networkx_to_pyg(G_nx, features_dict, labels_dict, num_features):
    """Converts a NetworkX protein graph to a PyTorch Geometric Data object, removing nodes without features."""
    node_map = {}  # Maps node (residue, chain) to index
    reverse_map = {}  # Reverse lookup: PyG index → (residue_number, chain)
    node_features = []
    y = []
    valid_nodes = set()  # Stores nodes that have valid features

    for i, (node, attr) in enumerate(G_nx.nodes(data=True)):
        res_id = (attr.get("residue_number"), attr.get("chain_id"))  # Standardized format

        # Only include nodes that have valid features
        if res_id not in features_dict:
            continue  # Skip nodes with missing features

        valid_nodes.add(node)
        node_map[node] = len(node_features)  # Assign PyG-compatible node index
        reverse_map[len(node_features)] = res_id  # Store mapping back to residue identifier

        # Retrieve features
        node_features.append(features_dict[res_id])

        # Retrieve binding site labels
        y.append(labels_dict.get(res_id, 0))

    x = torch.tensor(node_features, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.long)

    # Extract edges but only keep edges where **both** nodes are in valid_nodes
    edges = []
    for u, v in G_nx.edges():
        if u in valid_nodes and v in valid_nodes:  # Ensure both nodes exist in the filtered set
            edges.append((node_map[u], node_map[v]))

    if not edges:  # If no valid edges, return None (to avoid empty graphs)
        print(f"⚠️ Graph contains no valid edges after filtering. Skipping!")
        return None, None

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index, y=y), reverse_map

# ✅ Process All PDB + CSV Files in a Folder
def process_protein_graphs(folder_path, save_path="protein_graphs"):
    """Processes all PDB files and saves graphs to disk."""
    os.makedirs(save_path, exist_ok=True)  # Create folder if it doesn't exist

    # Get all PDB files in the folder
    pdb_files = [f for f in os.listdir(folder_path) if f.endswith(".pdb")]

    for i, pdb_file in enumerate(pdb_files):
        pdb_path = os.path.join(folder_path, pdb_file)
        csv_path = os.path.join(folder_path, pdb_file.replace(".pdb", ".csv"))

        if not os.path.exists(csv_path):
            print(f"⚠️ Warning: No CSV found for {pdb_file}. Skipping!")
            continue  # Skip this PDB if no matching CSV

        print(f"📌 Processing {pdb_file}...")

        # Load features & labels
        features_dict, labels_dict, num_features = load_residue_features(csv_path)

        # Construct NetworkX graph
        config = ProteinGraphConfig(**new_edge_funcs)
        G_nx = construct_graph(config=config, path=pdb_path)

        # Convert to PyG Data object
        protein_graph_data, reverse_map = networkx_to_pyg(G_nx, features_dict, labels_dict, num_features)

        if protein_graph_data:  # Ensure we have a valid graph
            graph_file = os.path.join(save_path, f"graph_{i}.pt")
            torch.save(protein_graph_data, graph_file)
            print(f"✅ Saved graph {i} to {graph_file}")

    print(f"✅ All graphs processed and saved in {save_path}")

# ✅ Run this once to generate graphs
process_protein_graphs("testingfiles/testingfiles")


Output()

📌 Processing 3c6w.pdb...


Output()

  x = torch.tensor(node_features, dtype=torch.float)


✅ Saved graph 0 to protein_graphs/graph_0.pt
📌 Processing 3wns.pdb...


Output()

✅ Saved graph 1 to protein_graphs/graph_1.pt
📌 Processing 5mwz.pdb...


Output()

✅ Saved graph 2 to protein_graphs/graph_2.pt
📌 Processing 6cyh.pdb...


Output()

✅ Saved graph 3 to protein_graphs/graph_3.pt
📌 Processing 4znx.pdb...


Output()

✅ Saved graph 4 to protein_graphs/graph_4.pt
📌 Processing 3jzi.pdb...


Output()

✅ Saved graph 5 to protein_graphs/graph_5.pt
📌 Processing 3fhe.pdb...


Output()

✅ Saved graph 6 to protein_graphs/graph_6.pt
📌 Processing 4inb.pdb...


Output()

✅ Saved graph 7 to protein_graphs/graph_7.pt
📌 Processing 5j4y.pdb...


Output()

✅ Saved graph 8 to protein_graphs/graph_8.pt
📌 Processing 5izq.pdb...


KeyboardInterrupt: 

In [16]:
def train(model, train_loader, optimizer, criterion, device):
    """Trains the model for one epoch."""
    model.train()  # Set model to training mode
    total_loss = 0.0

    for data in train_loader:
        data = data.to(device)  # Move batch to GPU (if available)

        optimizer.zero_grad()  # Reset gradients
        out = model(data)  # Forward pass
        loss = criterion(out, data.y)  # Compute loss

        loss.backward()  # Backpropagation
        optimizer.step()  # Update model parameters

        total_loss += loss.item()

    return total_loss / len(train_loader)  # Return average loss


def evaluate(model, test_loader, device):
    """Evaluates the model on the test dataset and returns accuracy."""
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # No gradient calculation during evaluation
        for data in test_loader:
            data = data.to(device)
            out = model(data)  # Forward pass
            pred = out.argmax(dim=1)  # Get class with highest probability

            correct += (pred == data.y).sum().item()
            total += data.y.size(0)

    return correct / total  # Return accuracy

In [19]:
# ✅ Define PyTorch Geometric Dataset
class LazyProteinDataset(Dataset):
    """Lazy dataset that loads graphs from disk one at a time."""
    def __init__(self, graph_folder):
        self.graph_folder = graph_folder
        self.graph_files = [os.path.join(graph_folder, f) for f in os.listdir(graph_folder) if f.endswith(".pt")]

    def __len__(self):  # ✅ Correctly defined __len__()
        return len(self.graph_files)

    def __getitem__(self, idx):  # ✅ Fix: Use __getitem__() instead of get()
        return torch.load(self.graph_files[idx])  # Load graph from disk



# ✅ Load Dataset from Saved Graphs
graph_folder = "protein_graphs"  # Set the folder where graphs were saved
protein_graph_dataset = LazyProteinDataset(graph_folder)

# ✅ Split dataset into train/test
train_size = int(0.8 * len(protein_graph_dataset))
test_size = len(protein_graph_dataset) - train_size

indices = list(range(len(protein_graph_dataset)))
train_indices, test_indices = torch.utils.data.random_split(indices, [train_size, test_size])

train_dataset = Subset(protein_graph_dataset, train_indices)
test_dataset = Subset(protein_graph_dataset, test_indices)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


# ✅ Define Virtual Node GAT Model
class GATVirtualNode(nn.Module):
    def __init__(self, in_features, hidden_dim, num_heads, num_classes, dropout=0.2):
        super(GATVirtualNode, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Virtual Node (Global Graph Representation)
        self.virtual_node_embedding = nn.Parameter(torch.zeros(1, hidden_dim))

        # Graph Attention Layers
        self.gat1 = GATConv(in_features, hidden_dim // num_heads, heads=num_heads)
        self.gat2 = GATConv(hidden_dim, hidden_dim // num_heads, heads=num_heads)

        # MLP Classifier for **node classification**
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)  # Output for **each node**
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Apply first GAT layer
        x = self.gat1(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Apply second GAT layer
        x = self.gat2(x, edge_index)
        x = torch.relu(x)
        x = self.dropout(x)

        # Virtual Node Aggregation
        virtual_node = self.virtual_node_embedding.expand(x.size(0), -1)
        x = x + virtual_node

        # **Return node-level predictions**
        return self.mlp(x)  # Output has same shape as `data.y`

# ✅ Define Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GATVirtualNode(in_features=protein_graph_dataset[0].x.shape[1], hidden_dim=128, num_heads=4, num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

model_save_path = "trained_gat_model.pth"  # Change if needed

# ✅ Training Loop (with saving)
num_epochs = 30
best_acc = 0.0  # Track best accuracy to save the best model

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_acc = evaluate(model, test_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ✅ Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), model_save_path)
        print(f"✅ Model saved at epoch {epoch+1} with accuracy {test_acc:.4f}!")

print("✅ Training Completed!")





  return torch.load(self.graph_files[idx])  # Load graph from disk


Epoch 1/30 - Loss: 0.4202, Test Accuracy: 0.8784
✅ Model saved at epoch 1 with accuracy 0.8784!
Epoch 2/30 - Loss: 0.4010, Test Accuracy: 0.8784
Epoch 3/30 - Loss: 0.3944, Test Accuracy: 0.8784
Epoch 4/30 - Loss: 0.3936, Test Accuracy: 0.8783
Epoch 5/30 - Loss: 0.3893, Test Accuracy: 0.8786
✅ Model saved at epoch 5 with accuracy 0.8786!
Epoch 6/30 - Loss: 0.3872, Test Accuracy: 0.8788
✅ Model saved at epoch 6 with accuracy 0.8788!
Epoch 7/30 - Loss: 0.3928, Test Accuracy: 0.8784
Epoch 8/30 - Loss: 0.3896, Test Accuracy: 0.8795
✅ Model saved at epoch 8 with accuracy 0.8795!
Epoch 9/30 - Loss: 0.3813, Test Accuracy: 0.8798
✅ Model saved at epoch 9 with accuracy 0.8798!
Epoch 10/30 - Loss: 0.3901, Test Accuracy: 0.8786
Epoch 11/30 - Loss: 0.3885, Test Accuracy: 0.8786
Epoch 12/30 - Loss: 0.3837, Test Accuracy: 0.8801
✅ Model saved at epoch 12 with accuracy 0.8801!
Epoch 13/30 - Loss: 0.3845, Test Accuracy: 0.8787
Epoch 14/30 - Loss: 0.3853, Test Accuracy: 0.8798
Epoch 15/30 - Loss: 0.3834

Output()

  model.load_state_dict(torch.load(model_path, map_location=device))


📌 Processing 1a09.pdb...



🔹 Binding Site Probability Predictions 🔹
Residue (144, 'A'): Probability of Binding Site = nan
Residue (145, 'A'): Probability of Binding Site = nan
Residue (146, 'A'): Probability of Binding Site = nan
Residue (147, 'A'): Probability of Binding Site = nan
Residue (148, 'A'): Probability of Binding Site = nan
Residue (149, 'A'): Probability of Binding Site = nan
Residue (150, 'A'): Probability of Binding Site = nan
Residue (151, 'A'): Probability of Binding Site = nan
Residue (152, 'A'): Probability of Binding Site = nan
Residue (153, 'A'): Probability of Binding Site = nan
Residue (154, 'A'): Probability of Binding Site = nan
Residue (155, 'A'): Probability of Binding Site = nan
Residue (156, 'A'): Probability of Binding Site = nan
Residue (157, 'A'): Probability of Binding Site = nan
Residue (158, 'A'): Probability of Binding Site = nan
Residue (159, 'A'): Probability of Binding Site = nan
Residue (160, 'A'): Probability of Binding Site = nan
Residue (161, 'A'): Probability of Bindi