In [None]:
pip install torch

In [None]:
pip install torch-geometric

In [5]:
import pandas as pd
import networkx as nx
import logging
#import matplotlib.pyplot as plt
from Bio.PDB import PDBParser
from Bio.PDB.SASA import ShrakeRupley
import numpy as np
import csv
import os

import torch
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool



# Interaction criteria: each entry names an interaction and gives the pair of
# atom types that must be present plus the allowed distance window (same unit
# as the PDB coordinates).
#
# NOTE: the two repulsive rules need distinct keys. They were previously both
# written under "REPULSIVE"; a dict literal keeps only the last duplicate, so
# the POS/POS rule was silently dropped.
interaction_criteria = {
    "ARM_STACK": {"atomic_type1": "ARM", "atomic_type2": "ARM", "min_dist": 1.5, "max_dist": 3.5},
    "H_BOND": {"atomic_type1": "ACP", "atomic_type2": "DON", "min_dist": 2.0, "max_dist": 3.0},
    "HYDROPHOBIC": {"atomic_type1": "HPB", "atomic_type2": "HPB", "min_dist": 2.0, "max_dist": 3.8},
    "REPULSIVE_POS": {"atomic_type1": "POS", "atomic_type2": "POS", "min_dist": 2.0, "max_dist": 6.0},
    "REPULSIVE_NEG": {"atomic_type1": "NEG", "atomic_type2": "NEG", "min_dist": 2.0, "max_dist": 6.0},
    "SALT_BRIDGE": {"atomic_type1": "POS", "atomic_type2": "NEG", "min_dist": 2.0, "max_dist": 6.0},
    "SS_BRIDGE": {"atomic_type1": "SG", "atomic_type2": "SG", "min_dist": 2.0, "max_dist": 2.2},
}

# Human-readable names for atom-type codes and interaction names.
# "REPULSIVE" is kept alongside the two specific repulsive keys so existing
# lookups by the old name keep working.
type_full_names = {
    "ACP": "Acceptor",
    "DON": "Donor",
    "POS": "Positive",
    "NEG": "Negative",
    "HPB": "Hydrophobic",
    "ARM": "Aromatic",
    "HYDROPHOBIC": "Hydrophobic",
    "SALT_BRIDGE": "Salt bridge",
    "ARM_STACK": "Aromatic",
    "H_BOND": "Hydrogen bond",
    "REPULSIVE": "Repulsive",
    "REPULSIVE_POS": "Repulsive",
    "REPULSIVE_NEG": "Repulsive",
    "SS_BRIDGE": "Disulfide Bridge"
}

# Amino acid to atom types mapping.
# For each residue name: atom name -> comma-separated atom-type codes, or None
# when the atom takes part in no interaction rule.
# CYS SG additionally carries the "SG" tag because the SS_BRIDGE rule in
# interaction_criteria matches on atom type "SG"; without the tag no disulfide
# bridge could ever be detected.
amino_acid_atoms = {
    "ALA": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB"},
    "ARG": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", "CD": None,
            "NE": "POS,DON", "CZ": "POS", "NH1": "POS,DON", "NH2": "POS,DON"},
    "ASN": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": None,
            "OD1": "ACP", "ND2": "DON"},
    "ASP": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": None,
            "OD1": "NEG,ACP", "OD2": "NEG,ACP"},
    "CYS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "SG": "DON,ACP,SG"},
    "GLN": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "CD": None, "OE1": "ACP", "NE2": "DON"},
    "GLU": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "CD": None, "OE1": "NEG,ACP", "OE2": "NEG,ACP"},
    "GLY": {"N": "DON", "CA": None, "C": None, "O": "ACP"},
    "HIS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "ARM",
            "ND1": "ARM,POS", "CD2": "ARM", "CE1": "ARM", "NE2": "ARM,POS"},
    "ILE": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG1": "HPB",
            "CG2": "HPB", "CD1": "HPB"},
    "LEU": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "CD1": "HPB", "CD2": "HPB"},
    "LYS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "CD": "HPB", "CE": None, "NZ": "POS,DON"},
    "MET": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "SD": "ACP", "CE": "HPB"},
    "PHE": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM",
            "CD1": "HPB,ARM", "CD2": "HPB,ARM", "CE1": "HPB,ARM", "CE2": "HPB,ARM",
            "CZ": "HPB,ARM"},
    "PRO": {"N": None, "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB",
            "CD": None},
    "SER": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": None, "OG": "DON,ACP"},
    "THR": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": None, "OG1": "DON,ACP",
            "CG2": "HPB"},
    "TRP": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM",
            "CD1": "ARM", "CD2": "HPB,ARM", "NE1": "ARM,DON", "CE2": "ARM",
            "CE3": "HPB,ARM", "CZ2": "HPB,ARM", "CZ3": "HPB,ARM", "CH2": "HPB,ARM"},
    "TYR": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM",
            "CD1": "HPB,ARM", "CD2": "HPB,ARM", "CE1": "HPB,ARM", "CE2": "HPB,ARM",
            "CZ": "ARM", "OH": "DON,ACP"},
    "VAL": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB",
            "CG1": "HPB", "CG2": "HPB"},
}


# 1. **Surface Area Residue Calculation using SASA**
def calculate_sasa(pdb_file, relative_threshold=0.25):
    """
    Identify the surface residues of a PDB structure from per-residue SASA.

    The Solvent Accessible Surface Area is computed per atom with the
    Shrake-Rupley method on the first model of the structure, summed per
    residue, and a residue is reported as "surface" when its SASA is at least
    ``relative_threshold`` times the largest per-residue SASA in the model.

    Parameters
    ----------
    pdb_file : str
        Path of the PDB file (also used as the structure id).
    relative_threshold : float, optional
        Fraction of the maximum per-residue SASA a residue must reach to count
        as surface. Defaults to 0.25.

    Returns
    -------
    list
        The Bio.PDB residue objects classified as surface residues, in the
        order they appear in the first model. Empty when the model holds no
        residue other than the skipped ones.
    """
    # QUIET=True for consistency with parse_pdb (no construction warnings).
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_file, pdb_file)

    # Only the first model is inspected below, so SASA is computed on that
    # model alone; atoms of other models must not shield its atoms.
    first_model = structure[0]
    sr = ShrakeRupley()
    sr.compute(first_model, level="A")  # per-atom values end up in atom.sasa

    residue_sasa = {}
    for chain in first_model:
        for residue in chain:
            # Skip waters and the ACE capping group by residue name.
            if residue.get_resname() in ("WAT", "HOH", "ACE"):
                continue
            residue_sasa[residue] = sum(
                atom.sasa for atom in residue if hasattr(atom, 'sasa')
            )

    # max() on an empty collection raises ValueError; nothing to classify.
    if not residue_sasa:
        return []

    surface_threshold = relative_threshold * max(residue_sasa.values())
    return [
        residue
        for residue, sasa in residue_sasa.items()
        if sasa >= surface_threshold
    ]


# Step 1: Parse PDB File
def parse_pdb(pdb_file):
    """
    Read a PDB file into a dictionary of residues with typed atoms.

    Only the first model is read. This matches calculate_sasa, which also
    works on the first model, and keeps multi-model files from stacking several
    copies of every atom into the same residue entry.

    NOTE: residues are keyed by "<resname>_<resseq>" only, so residues from
    different chains that share both name and number end up in one entry.

    Parameters
    ----------
    pdb_file : str
        Path of the PDB file (also used as the structure id).

    Returns
    -------
    dict
        ``{res_key: {'residue': str, 'res_id': int, 'atoms': [atom, ...]}}``
        where each atom is a dict with 'atom_name', 'element', 'coord' and
        'atom_type'. 'atom_type' is a list of type codes taken from
        amino_acid_atoms, or None when the atom has no type.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_file, pdb_file)
    residues = {}

    first_model = next(iter(structure), None)
    if first_model is None:  # structure without any model
        return residues

    for chain in first_model:
        for residue in chain:
            res_name = residue.get_resname()
            res_id = residue.get_id()[1]  # Residue sequence number
            res_key = f"{res_name}_{res_id}"
            entry = residues.setdefault(
                res_key, {'residue': res_name, 'res_id': res_id, 'atoms': []}
            )
            # Residues missing from the table get untyped atoms.
            type_table = amino_acid_atoms.get(res_name, {})
            for atom in residue:
                atom_name = atom.get_name()
                type_label = type_table.get(atom_name)
                entry['atoms'].append({
                    'atom_name': atom_name,
                    'element': atom.element,
                    'coord': atom.get_coord(),
                    # "POS,DON" -> ['POS', 'DON']; None stays None
                    'atom_type': type_label.split(',') if type_label else None
                })
    return residues

# Step 2: Compute Residue-Residue Interactions
def compute_residue_interactions(residues, interaction_criteria):
    """
    Find all atom-level interactions between every pair of distinct residues.

    For each unordered residue pair, every pair of typed atoms is tested
    against every rule in ``interaction_criteria``. A rule matches when the
    inter-atomic distance lies in ``[min_dist, max_dist]`` and the two atoms
    carry the rule's two atom types in either orientation. The orientation
    check matters for asymmetric rules (e.g. acceptor/donor): each residue
    pair is visited once, so a one-sided check would miss the case where the
    first residue holds the donor and the second the acceptor.

    Parameters
    ----------
    residues : dict
        Mapping ``res_key -> {'atoms': [{'coord': array, 'atom_type': list or None, ...}]}``
        as produced by parse_pdb.
    interaction_criteria : dict
        Mapping ``name -> {'atomic_type1', 'atomic_type2', 'min_dist', 'max_dist'}``.

    Returns
    -------
    list of dict
        One ``{'res1', 'res2', 'interaction_type', 'distance'}`` record per
        (atom pair, rule) match; ``res1`` precedes ``res2`` in ``residues``.
    """
    residue_keys = list(residues.keys())

    # Atoms without a type can never satisfy a rule; drop them up front so no
    # distance is computed for them.
    typed_atoms = {
        key: [atom for atom in residues[key]['atoms'] if atom['atom_type']]
        for key in residue_keys
    }

    interactions = []
    for i in range(len(residue_keys)):
        atoms1 = typed_atoms[residue_keys[i]]
        if not atoms1:
            continue
        for j in range(i + 1, len(residue_keys)):
            atoms2 = typed_atoms[residue_keys[j]]
            for atom1 in atoms1:
                types1 = atom1['atom_type']
                for atom2 in atoms2:
                    types2 = atom2['atom_type']
                    distance = np.linalg.norm(atom1['coord'] - atom2['coord'])
                    for interaction, criteria in interaction_criteria.items():
                        if not (criteria['min_dist'] <= distance <= criteria['max_dist']):
                            continue
                        type1 = criteria['atomic_type1']
                        type2 = criteria['atomic_type2']
                        # Accept either orientation, but record the pair once.
                        if ((type1 in types1 and type2 in types2)
                                or (type2 in types1 and type1 in types2)):
                            interactions.append({
                                'res1': residue_keys[i],
                                'res2': residue_keys[j],
                                'interaction_type': interaction,
                                'distance': distance
                            })
    return interactions


# Step 3: Generate Graph Network
def generate_residue_graph(residues, residue_interactions):
    """
    Build an undirected residue graph from parsed residues and interactions.

    Every residue becomes a node (attributes: 'residue', 'res_id',
    'atom_types'). Every interaction record adds or updates the edge between
    its two residues.

    An nx.Graph holds a single edge per node pair, so when two residues
    interact several times the 'interaction_type' / 'distance' attributes hold
    the most recently added interaction. To avoid losing the others, every
    (interaction_type, distance) pair is also collected, in input order, in
    the edge attribute 'interactions'.

    Parameters
    ----------
    residues : dict
        Mapping ``res_key -> {'residue', 'res_id', 'atoms'}`` (see parse_pdb).
    residue_interactions : iterable of dict
        Records with 'res1', 'res2', 'interaction_type' and 'distance'.

    Returns
    -------
    networkx.Graph
    """
    G = nx.Graph()
    # Add residues as nodes
    for res_key, res_info in residues.items():
        G.add_node(res_key, residue=res_info['residue'], res_id=res_info['res_id'],
                   atom_types=[atom['atom_type'] for atom in res_info['atoms'] if atom['atom_type']])

    # Add edges based on interactions
    for interaction in residue_interactions:
        res1 = interaction['res1']
        res2 = interaction['res2']
        interaction_type = interaction['interaction_type']
        distance = interaction['distance']
        G.add_edge(res1, res2, interaction_type=interaction_type, distance=distance)
        # Keep the full history for this residue pair next to the summary attributes.
        G[res1][res2].setdefault('interactions', []).append((interaction_type, distance))

    return G


def map_surface_residues_to_graph(graph, surface_residues):
    """
    Translate surface residue objects into the node keys used by the graph.

    Each residue is turned into the key "<resname>_<resseq>"; keys that are not
    nodes of ``graph`` are dropped. The input order is preserved.
    """
    return [
        node_key
        for node_key in (
            f"{residue.get_resname()}_{residue.id[1]}" for residue in surface_residues
        )
        if node_key in graph.nodes
    ]


def get_adjacent_nodes_and_edges(graph, surface_node):
    """
    Return the 1-hop neighbourhood of ``surface_node``.

    Returns a tuple ``(neighbors, edges)``: the list of adjacent nodes and the
    list of edges incident to ``surface_node`` including their attribute data.
    """
    one_hop_nodes = [*graph.neighbors(surface_node)]
    incident_edges = [*graph.edges(surface_node, data=True)]
    return one_hop_nodes, incident_edges



def get_adjacent_adjacent_nodes_and_edges(graph, neighbor_nodes):
    """
    Collect the neighbours (and incident edges) of a list of nodes.

    Called with the 1-hop neighbours of a node, this yields its 2-hop
    neighbourhood.

    Parameters
    ----------
    graph : graph object
        Must provide ``neighbors(node)`` and ``edges(node, data=True)``.
    neighbor_nodes : iterable
        Nodes whose neighbourhoods are gathered.

    Returns
    -------
    (list, list)
        The unique neighbouring nodes in first-seen order, and the incident
        edges (with data) of every node in ``neighbor_nodes``, concatenated.
    """
    # A dict is used as an ordered set: going through set() would make the
    # order of string nodes depend on the per-process hash seed, and therefore
    # make downstream node ordering irreproducible between runs.
    ordered_nodes = {}
    collected_edges = []

    for neighbor in neighbor_nodes:
        ordered_nodes.update(dict.fromkeys(graph.neighbors(neighbor)))
        collected_edges.extend(graph.edges(neighbor, data=True))

    return list(ordered_nodes), collected_edges


# Integer code for every single atom-type label: the six labels are numbered
# 1..6 in the order listed, and None (no type) maps to 0.
atom_encoding = {
    **{label: code for code, label in enumerate(('ACP', 'DON', 'POS', 'NEG', 'HPB', 'ARM'), start=1)},
    None: 0,
}
  
    
def encode_atom_categories(atom_categories, atom_encoding, max_length):
    """
    Encode the atom-type labels of one residue as a fixed-length list of ints.

    Labels found directly in ``atom_encoding`` (single types and None) keep
    their code. Composite labels such as "POS,DON" are not keys of
    ``atom_encoding``; a plain lookup would turn them into 0, which is
    indistinguishable from "no type". They are therefore split on commas and
    the component codes are combined as digits in base ``max(code) + 1``, so
    every composite label gets its own non-zero code while single labels are
    unchanged. Components unknown to ``atom_encoding`` are ignored.

    Parameters
    ----------
    atom_categories : dict
        Mapping atom name -> label (str or None); only the values are used,
        in iteration order.
    atom_encoding : dict
        Mapping single label -> integer code.
    max_length : int
        Length to right-pad the result to with zeros. Longer inputs are not
        truncated.

    Returns
    -------
    list of int
    """
    base = max(atom_encoding.values(), default=0) + 1

    def _encode_label(label):
        if label in atom_encoding:
            return atom_encoding[label]
        if not isinstance(label, str):
            return 0
        combined = 0
        for part in label.split(','):
            part_code = atom_encoding.get(part.strip(), 0)
            if part_code:
                combined = combined * base + part_code
        return combined

    encoded = [_encode_label(value) for value in atom_categories.values()]
    # Pad encoded list to ensure it has `max_length`
    return encoded + [0] * (max_length - len(encoded))


# Helper function to extract node features and edges for each mapped surface node and adjacent nodes
def extract_subgraph(graph, mapped_surface_nodes):
    """
    Build node-feature and edge-index tensors for the surface neighbourhood.

    The subgraph consists of every surface node present in ``graph`` plus its
    1-hop and 2-hop neighbours, together with the edges incident to the surface
    nodes and to their 1-hop neighbours.

    Each distinct residue gets exactly one row in the node tensor, and the edge
    tensor refers to those row numbers. (Residue sequence numbers are not valid
    indices into the node tensor, and a residue reached from several surface
    nodes must not be added more than once.) Each undirected edge is stored
    once per direction.

    Parameters
    ----------
    graph : networkx.Graph
        Residue graph whose nodes are "<resname>_<resseq>" keys.
    mapped_surface_nodes : iterable of str
        Node keys of the surface residues; keys absent from the graph are skipped.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``nodes_tensor``: float tensor of shape (num_nodes, 1 + F), where a row is
        ``[res_id, *encoded atom types]`` and F is the largest atom count in
        amino_acid_atoms. ``edges_tensor``: long tensor of shape (2, num_edges).
        Both are correctly shaped empty tensors when there is nothing to return.
    """
    max_atom_categories_length = max(len(atom_map) for atom_map in amino_acid_atoms.values())

    row_of = {}            # node key -> row index in the node tensor
    feature_rows = []
    undirected_edges = {}  # ordered set of (low_row, high_row)

    def _row(node_key):
        # Row index of node_key, creating its feature row on first use.
        row = row_of.get(node_key)
        if row is None:
            res_name, res_id = node_key.rsplit('_', 1)
            # Residues missing from the table get an all-zero type encoding
            # instead of raising KeyError.
            encoded = encode_atom_categories(
                amino_acid_atoms.get(res_name, {}), atom_encoding, max_atom_categories_length
            )
            row = len(feature_rows)
            row_of[node_key] = row
            feature_rows.append(torch.tensor([int(res_id)] + encoded, dtype=torch.float))
        return row

    for surface_node in mapped_surface_nodes:
        if surface_node not in graph:
            continue
        _row(surface_node)

        adjacent_nodes, adjacent_edges = get_adjacent_nodes_and_edges(graph, surface_node)
        adjacent_adjacent_nodes, adjacent_adjacent_edges = get_adjacent_adjacent_nodes_and_edges(
            graph, adjacent_nodes
        )

        # 2-hop nodes are sorted so row order does not depend on the order in
        # which the helper happens to return them.
        for node_key in adjacent_nodes + sorted(adjacent_adjacent_nodes):
            _row(node_key)

        for edge in adjacent_edges + adjacent_adjacent_edges:
            a, b = _row(edge[0]), _row(edge[1])
            if a != b:
                undirected_edges[(min(a, b), max(a, b))] = None

    if feature_rows:
        nodes_tensor = torch.stack(feature_rows)
    else:
        nodes_tensor = torch.empty((0, 1 + max_atom_categories_length), dtype=torch.float)

    directed_pairs = []
    for a, b in undirected_edges:
        directed_pairs.append((a, b))
        directed_pairs.append((b, a))

    if directed_pairs:
        edges_tensor = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).t() would have shape (0,), not (2, 0)
        edges_tensor = torch.empty((2, 0), dtype=torch.long)

    return nodes_tensor, edges_tensor


# Define function to create a Data object for each protein graph from parsed PDB data
def create_protein_data(graph_p1, graph_p2, surface_residues_p1, surface_residues_p2, class_label):
    """
    Combine the surface subgraphs of two proteins into one labelled Data object.

    The node features of both subgraphs are stacked (protein 1 first), and the
    edge indices of protein 2 are shifted by the number of protein-1 nodes so
    that they keep pointing at the right rows of the stacked feature matrix.

    Parameters
    ----------
    graph_p1, graph_p2 : networkx.Graph
        Residue graphs of the two proteins.
    surface_residues_p1, surface_residues_p2 : list
        Surface residue objects of the two proteins (see calculate_sasa).
    class_label : int
        Class of the protein pair.

    Returns
    -------
    torch_geometric.data.Data
        With ``x`` (stacked node features), ``edge_index`` (2 x E, long) and
        ``y`` (tensor holding the single label).
    """
    # Map surface residues to graph nodes and retrieve subgraphs
    mapped_surface_nodes_p1 = map_surface_residues_to_graph(graph_p1, surface_residues_p1)
    mapped_surface_nodes_p2 = map_surface_residues_to_graph(graph_p2, surface_residues_p2)

    # Get nodes and edges for subgraphs
    nodes_p1, edges_p1 = extract_subgraph(graph_p1, mapped_surface_nodes_p1)
    nodes_p2, edges_p2 = extract_subgraph(graph_p2, mapped_surface_nodes_p2)

    def _as_edge_index(edges):
        # A subgraph without edges may come back as a tensor of shape (0,),
        # which cannot be concatenated along dim=1; normalise it to (2, 0).
        return edges if edges.numel() else edges.new_empty((2, 0))

    edges_p1 = _as_edge_index(edges_p1)
    edges_p2 = _as_edge_index(edges_p2)

    # Concatenate node features
    x = torch.cat((nodes_p1, nodes_p2), dim=0)

    # Offset edges for p2 and concatenate
    edge_index_p2 = edges_p2 + nodes_p1.size(0)  # Offset edges for p2 nodes
    edge_index = torch.cat((edges_p1, edge_index_p2), dim=1)

    # Define label for the protein
    y = torch.tensor([class_label], dtype=torch.long)

    return Data(x=x, edge_index=edge_index, y=y)

In [6]:
# Graph-level classifier: two GCN layers, mean pooling, linear head.
class GNN(torch.nn.Module):
    """Two-layer GCN followed by global mean pooling and a linear classifier."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Message-passing layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Classification head applied to the pooled graph embedding
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """Return one row of class scores per graph in the batch."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        # Average node embeddings per graph, as assigned by `batch`
        graph_embedding = global_mean_pool(hidden, batch)
        return self.fc(graph_embedding)

# Run one training epoch
def train(model, loader, optimizer, criterion):
    """
    Train ``model`` for one pass over ``loader``.

    Every batch goes through zero_grad -> forward -> loss -> backward -> step.
    Returns the sum of the batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = criterion(prediction, batch.y)
        batch_loss.backward()
        optimizer.step()
        running_loss = running_loss + batch_loss.item()
    num_batches = len(loader)
    return running_loss / num_batches

# Function to test the model
def test(model, loader):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

In [9]:
# Logging setup: module logger plus root configuration that emits INFO and
# above as "<timestamp> - <level> - <message>".
logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

def main(pairs_csv='demo.csv', pdb_dir='files', output_path='data_list3.pt'):
    """
    Build one graph Data object per protein pair and save the list to disk.

    Reads the pair table, turns every protein into a residue-interaction graph
    plus its list of surface residues, combines each pair into a Data object
    with create_protein_data, and stores the resulting list with torch.save.

    Parameters
    ----------
    pairs_csv : str, optional
        CSV file with columns 'P1', 'P2' (PDB file names) and 'class' (label).
    pdb_dir : str, optional
        Directory that contains the PDB files named in the CSV.
    output_path : str, optional
        Destination of the saved list of Data objects.
    """
    # Load the CSV dataset containing protein pairs and class labels
    df_pairs = pd.read_csv(pairs_csv)
    logger.info("Dataset loaded successfully")

    # A protein can appear in many pairs. Graph construction compares every
    # atom pair of every residue pair, so each protein is processed once and
    # reused. (This keeps all processed proteins in memory for the whole run.)
    protein_cache = {}

    def _load_protein(pdb_name):
        # Return (graph, surface_residues) for one PDB file, computing it on first use.
        if pdb_name not in protein_cache:
            pdb_path = os.path.join(pdb_dir, pdb_name)
            surface_residues = calculate_sasa(pdb_path)
            residues = parse_pdb(pdb_path)
            interactions = compute_residue_interactions(residues, interaction_criteria)
            graph = generate_residue_graph(residues, interactions)
            protein_cache[pdb_name] = (graph, surface_residues)
        return protein_cache[pdb_name]

    # Iterate over protein pairs and generate feature vectors
    data_list = []

    logger.info("Start creating protein data object")
    for _, row in df_pairs.iterrows():
        graph_p1, surface_residues_p1 = _load_protein(row['P1'])
        graph_p2, surface_residues_p2 = _load_protein(row['P2'])

        # Create the Data object
        protein_data = create_protein_data(
            graph_p1, graph_p2, surface_residues_p1, surface_residues_p2, row['class']
        )
        data_list.append(protein_data)

    # Save the data_list to a file
    torch.save(data_list, output_path)
    logger.info("Saved %d protein-pair graphs to %s", len(data_list), output_path)
    
    

In [10]:
# Entry point for the dataset-building step: calls main().
# NOTE: `main` is defined more than once in this notebook; this call runs
# whichever definition was executed most recently in the kernel.
if __name__ == "__main__":
    main()

2024-11-09 18:06:58,154 - INFO - Dataset loaded successfully
2024-11-09 18:06:58,155 - INFO - Start creating protein data object


In [None]:
import torch

# Load the two .pt files.
# The files hold pickled Python lists of graph Data objects, not plain tensors,
# so they cannot be read with the weights-only loader (the default in recent
# PyTorch releases). weights_only=False unpickles arbitrary objects: only load
# files you created yourself.
data_list_1 = torch.load('data_list1.pt', weights_only=False)
data_list_2 = torch.load('data_list2.pt', weights_only=False)

# Merge the lists
merged_data_list = data_list_1 + data_list_2  # Concatenates the two lists

# Save the merged list to a new .pt file
torch.save(merged_data_list, 'merged_data_list.pt')

print("Merged data saved to 'merged_data_list.pt'")


In [None]:
from torch_geometric.data import DataLoader

# Logging setup for the training step: module logger plus root configuration
# that emits INFO and above as "<timestamp> - <level> - <message>".
logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

def main(data_path='data_list.pt', hidden_dim=64, output_dim=3, epochs=50,
         batch_size=32, learning_rate=0.01):
    """
    Train the GNN classifier on a saved list of graph Data objects.

    The node feature width is read from the loaded data instead of being
    hard-coded, so the first GCN layer always matches the features on disk.

    Parameters
    ----------
    data_path : str, optional
        File holding the list of Data objects written with torch.save.
    hidden_dim : int, optional
        Width of the two GCN layers.
    output_dim : int, optional
        Number of classes; labels are expected in ``range(output_dim)``.
    epochs, batch_size, learning_rate : optional
        Usual training hyperparameters (Adam optimizer, cross-entropy loss).

    Raises
    ------
    ValueError
        If the loaded list is empty.
    """
    logger.info("Start defining hyperparameters")

    logger.info("Start loading into DataLoader")
    # The file holds pickled Data objects, which the weights-only loader (the
    # default in recent PyTorch releases) cannot read. weights_only=False
    # unpickles arbitrary objects: only load files you created yourself.
    loaded_data_list = torch.load(data_path, weights_only=False)
    if not loaded_data_list:
        raise ValueError(f"{data_path!r} contains no graphs to train on")
    loader = DataLoader(loaded_data_list, batch_size=batch_size, shuffle=True)

    # Feature dimension of nodes, taken from the data itself
    input_dim = loaded_data_list[0].x.size(1)

    logger.info("Start Initializing the model")
    # Initialize the model, optimizer, and loss function
    model = GNN(input_dim, hidden_dim, output_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    # Example: Iterating through DataLoader and printing batch info
    for batch in loader:
        print("Batch node features shape:", batch.x.shape)
        print("Batch edge indices shape:", batch.edge_index.shape)
        print("Batch labels:", batch.y)
        print("------------")

    logger.info("Start training loop")
    # Training loop
    for epoch in range(epochs):
        loss = train(model, loader, optimizer, criterion)
        # NOTE: accuracy is measured on the same loader used for training,
        # so it is a training accuracy, not a held-out estimate.
        accuracy = test(model, loader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

    print("Training complete.")

In [None]:
# Entry point for the training step: calls main().
# NOTE: `main` is defined more than once in this notebook; this call runs
# whichever definition was executed most recently in the kernel.
if __name__ == "__main__":
    main()