Hello!
This is a generic scipt for extracting predictions for KS structures from pretrained models. Some information you will need to get this going:
1. This notebook was written for Python 3.8.17 for use in a Linux environment.    
2. Packages that must be installed prior to use are listed below. You can install these using pip in the terminal. 
   numpy==1.23.5
   pandas==1.5.3
   torch==2.0.1
   scikit-learn==1.3.0
   matplotlib==3.7.2
   torch-geometric==2.3.1
   graphein==1.7.0

3. The input for this script is KS homodimers as .pdb files. You can generate these using colabfold elsewhere.


In [3]:
### Variables ###
input_structures_folder = "/path/to/structure/folder/"  # Path to protein structure folder. Please ensure that only .pdb structures are contained in this file
graph_storage_folder = "/path/to/empty/graph/storage/"  # Please designate an empty fold for storing graph structures
model_path = "20250621_DHvsER_model6.pth"               # Trained model path. 
output_folder = "/path/to/results/folder/"              # Must end with "/"
output_file_name = "name.csv"                           # Must end with ".csv"
num_cores = 7                                           # Number of CPU cores to use for graph construction

In [7]:
### Create graphs from protein structures ###

import os
from functools import partial
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import graphein
from graphein.protein.edges.distance import add_k_nn_edges
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot
from graphein.ml import GraphFormatConvertor, ProteinGraphDataset

# Load structure file paths
paths = [os.path.join(input_structures_folder, f) 
         for f in os.listdir(input_structures_folder)]
print(f"Number of files found: {len(paths)}")

# Graphein config
config = graphein.protein.ProteinGraphConfig(
    edge_construction_functions=[partial(add_k_nn_edges, k=4, long_interaction_threshold=0)],
    node_metadata_functions=[amino_acid_one_hot]
)

# Conversion config
convertor = GraphFormatConvertor(
    src_format="nx", dst_format="pyg", verbose="all_info",
    columns=["edge_index", 
             "amino_acid_one_hot", 
             "node_id", "chain_id",
             "residue_name", 
             "residue_number", 
             "atom_type", 
             "element_symbol",
             "coords", 
             "b_factor", 
             "kind", 
             "name", 
             "chain_ids"]
)

# Build dataset
dataset = ProteinGraphDataset(
    root=graph_storage_folder, paths=paths,
    graphein_config=config, graph_format_convertor=convertor,
    num_cores=num_cores
)

# Convert to PyG Data objects
data_list = []
for g in dataset:
    data = Data(
        edge_index=g.edge_index,
        node_id=g.node_id,
        coords=g.coords,
        name=g.name,
        num_nodes=g.num_nodes,
        x=g.amino_acid_one_hot.view(len(g.node_id), 20)  # reshape one-hot encoding
    )
    data_list.append(data)

loader = DataLoader(data_list, batch_size=64)
print("Finished making graphs.")

Number of files found: 1262
Finished making graphs.


In [6]:
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_max_pool

# Define a GNN model
class MyGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x.float(), edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_max_pool(x, batch)
        return F.softmax(self.lin(x), dim=1)

# Load trained model
input_dim, hidden_dim, output_dim = 20, 64, 2
model = MyGNN(input_dim, hidden_dim, output_dim)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Run predictions
predictions = []
with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        predictions.append(out.cpu().numpy())

# Combine predictions
predictions = np.concatenate(predictions, axis=0)

# Create results DataFrame
df = pd.DataFrame({
    "Class 0 Probability": predictions[:, 0],
    "Class 1 Probability": predictions[:, 1],
    "File Paths": [os.path.splitext(os.path.basename(p))[0] for p in paths]
})

result_df = df.sort_values(by="File Paths")

# Save results
os.makedirs(output_folder, exist_ok=True)
save_path = os.path.join(output_folder, output_file_name)
result_df.to_csv(save_path, index=False)

print(result_df)

print(f"Results saved to: {save_path}")

      Class 0 Probability  Class 1 Probability           File Paths
193              0.929503             0.070497  SeEryAIII.KS1_A135A
1050             0.923521             0.076479  SeEryAIII.KS1_A135D
171              0.924780             0.075220  SeEryAIII.KS1_A135E
914              0.904220             0.095780  SeEryAIII.KS1_A135G
778              0.935340             0.064660  SeEryAIII.KS1_A135I
...                   ...                  ...                  ...
648              0.969999             0.030001      SeEryAIII_Mod.2
365              0.997945             0.002055       SeEryAII_Mod.1
350              0.982858             0.017142       SeEryAII_Mod.2
90               0.949354             0.050646        SeEryAI_Mod.1
493              0.977907             0.022093        SeEryAI_Mod.2

[1262 rows x 3 columns]
Results saved to: /home/q31032mw/Dropbox (The University of Manchester)/Max/17_ML_Project/GitHub/01 Scripts/03 Graph Networks/05 SeEryAIII_KS1_Mutation/2025072