In [1]:
import pandas as pd
import networkx as nx
import logging
#import matplotlib.pyplot as plt
from Bio.PDB import PDBParser
from Bio.PDB.SASA import ShrakeRupley
import numpy as np
import csv
import os
import pickle
from Bio.PDB import Residue

import torch
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool



# Define interaction criteria.
# Each entry names the two atom-type labels that must be paired and the
# distance window (min_dist..max_dist) inside which the pair counts as that
# interaction. NOTE(review): distances are presumably in angstroms — confirm
# against the code that builds the graphs.
interaction_criteria = {
    "ARM_STACK": {"atomic_type1": "ARM", "atomic_type2": "ARM", "min_dist": 1.5, "max_dist": 3.5},
    "H_BOND": {"atomic_type1": "ACP", "atomic_type2": "DON", "min_dist": 2.0, "max_dist": 3.0},
    "HYDROPHOBIC": {"atomic_type1": "HPB", "atomic_type2": "HPB", "min_dist": 2.0, "max_dist": 3.8},
    # NOTE(review): the key "REPULSIVE" is written twice in this literal. Python
    # keeps only the last value for a repeated key, so the POS/POS entry on the
    # next line is discarded at runtime and only the NEG/NEG rule survives.
    # Keeping both rules would require distinct keys (and matching entries in
    # `type_full_names`), which changes what callers see — decide deliberately.
    "REPULSIVE": {"atomic_type1": "POS", "atomic_type2": "POS", "min_dist": 2.0, "max_dist": 6.0},
    "REPULSIVE": {"atomic_type1": "NEG", "atomic_type2": "NEG", "min_dist": 2.0, "max_dist": 6.0},
    "SALT_BRIDGE": {"atomic_type1": "POS", "atomic_type2": "NEG", "min_dist": 2.0, "max_dist": 6.0},
    # "SG" is an atom name rather than one of the type labels used elsewhere
    # in this table.
    "SS_BRIDGE": {"atomic_type1": "SG", "atomic_type2": "SG", "min_dist": 2.0, "max_dist": 2.2},
}

# Display names, keyed first by atom-type abbreviation and then by
# interaction-type key (insertion order matches that grouping).
type_full_names = dict(
    # Atom-type abbreviations
    ACP="Acceptor",
    DON="Donor",
    POS="Positive",
    NEG="Negative",
    HPB="Hydrophobic",
    ARM="Aromatic",
    # Interaction types
    HYDROPHOBIC="Hydrophobic",
    SALT_BRIDGE="Salt bridge",
    ARM_STACK="Aromatic",
    H_BOND="Hydrogen bond",
    REPULSIVE="Repulsive",
    SS_BRIDGE="Disulfide Bridge",
)

# Amino acid to atom types mapping (same as given in the provided data).
# Keys are three-letter residue names. Each value maps an atom name to its
# type label: one of the abbreviations ACP, DON, POS, NEG, HPB, ARM, several
# of them joined by a comma (e.g. "POS,DON"), or None when the atom carries
# no type.
# NOTE(review): `encode_atom_categories` looks each label up verbatim in
# `atom_encoding`, which only has the single abbreviations as keys, so every
# comma-joined label here is encoded as 0 (the same code as None).
amino_acid_atoms = {
    "ALA": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB"},
    "ARG": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", "CD": None, 
            "NE": "POS,DON", "CZ": "POS", "NH1": "POS,DON", "NH2": "POS,DON"},
    "ASN": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": None, 
            "OD1": "ACP", "ND2": "DON"},
    "ASP": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": None, 
            "OD1": "NEG,ACP", "OD2": "NEG,ACP"},
    "CYS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "SG": "DON,ACP"},
    "GLN": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "CD": None, "OE1": "ACP", "NE2": "DON"},
    "GLU": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "CD": None, "OE1": "NEG,ACP", "OE2": "NEG,ACP"},
    "GLY": {"N": "DON", "CA": None, "C": None, "O": "ACP"},
    "HIS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "ARM", 
            "ND1": "ARM,POS", "CD2": "ARM", "CE1": "ARM", "NE2": "ARM,POS"},
    "ILE": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG1": "HPB", 
            "CG2": "HPB", "CD1": "HPB"},
    "LEU": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "CD1": "HPB", "CD2": "HPB"},
    "LYS": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "CD": "HPB", "CE": None, "NZ": "POS,DON"},
    "MET": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "SD": "ACP", "CE": "HPB"},
    "PHE": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM", 
            "CD1": "HPB,ARM", "CD2": "HPB,ARM", "CE1": "HPB,ARM", "CE2": "HPB,ARM", 
            "CZ": "HPB,ARM"},
    # PRO is the only residue here whose backbone N is not typed as a donor.
    "PRO": {"N": None, "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB", 
            "CD": None},
    "SER": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": None, "OG": "DON,ACP"},
    "THR": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": None, "OG1": "DON,ACP", 
            "CG2": "HPB"},
    "TRP": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM", 
            "CD1": "ARM", "CD2": "HPB,ARM", "NE1": "ARM,DON", "CE2": "ARM", 
            "CE3": "HPB,ARM", "CZ2": "HPB,ARM", "CZ3": "HPB,ARM", "CH2": "HPB,ARM"},
    "TYR": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", "CG": "HPB,ARM", 
            "CD1": "HPB,ARM", "CD2": "HPB,ARM", "CE1": "HPB,ARM", "CE2": "HPB,ARM", 
            "CZ": "ARM", "OH": "DON,ACP"},
    "VAL": {"N": "DON", "CA": None, "C": None, "O": "ACP", "CB": "HPB", 
            "CG1": "HPB", "CG2": "HPB"},
}



In [2]:
def load_data(filename):
    """
    Reads one pickled object from disk and returns it.

    Args:
        filename (str): Path of the pickle file to read.

    Returns:
        object: Whatever object was stored in the file.

    Note:
        Unpickling can execute arbitrary code, so only point this at files
        you produced yourself or otherwise trust.
    """
    with open(filename, 'rb') as pickle_file:
        loaded = pickle.load(pickle_file)
    return loaded


    
def parse_residue_id(residue_id):
    """
    Builds a Biopython Residue from an underscore-separated residue ID.

    Args:
        residue_id (str): Residue ID in the format "CHAIN_RESNAME_RESSEQ".
            It must split into exactly three parts on "_", and the last part
            must be an integer literal.

    Returns:
        Residue: A Residue whose id is (' ', RESSEQ, ' ') and whose residue
        name is RESNAME.

    Raises:
        ValueError: If the ID does not have exactly three "_"-separated
            parts or RESSEQ is not an integer.

    Note:
        NOTE(review): the chain is handed to the Residue constructor as its
        third positional argument, which Biopython names ``segid`` — confirm
        that storing the chain there is intended.
    """
    chain_id, residue_name, sequence_number = residue_id.split("_")
    # Blank hetero-flag and blank insertion code around the sequence number.
    biopython_id = (' ', int(sequence_number), ' ')
    return Residue.Residue(biopython_id, residue_name, chain_id)



def filter_surface_residues(protein_sasa_results, sasa_threshold):
    """
    Filters surface residues based on SASA results for multiple protein structures.

    A residue is kept when its SASA is at least ``sasa_threshold`` times the
    largest SASA value found in the same structure, so the cut-off is relative
    to each structure rather than absolute.

    Args:
        protein_sasa_results (dict): SASA results for multiple protein structures,
                                     keyed by structure ID (PDB file name),
                                     with residues and their SASA values.
        sasa_threshold (float): Fraction of the per-structure maximum SASA a
                                residue must reach to count as surface.

    Returns:
        dict: Filtered residues for each protein structure,
              keyed by structure ID, containing only residues meeting the surface
              threshold (as objects returned by ``parse_residue_id``). A structure
              with no SASA entries maps to an empty list.
    """
    filtered_results = {}

    for structure_id, residues_dict in protein_sasa_results.items():
        if not residues_dict:
            # max() raises ValueError on an empty collection; a structure
            # without SASA entries simply has no surface residues.
            filtered_results[structure_id] = []
            continue

        # The cut-off scales with the most exposed residue of this structure.
        surface_threshold = sasa_threshold * max(residues_dict.values())

        filtered_results[structure_id] = [
            parse_residue_id(residue_id)
            for residue_id, sasa_value in residues_dict.items()
            if sasa_value >= surface_threshold
        ]

    return filtered_results


In [3]:
def load_all_graphs(file_path):
    """
    Unpickles the saved graph collection and reports where it came from.

    Args:
        file_path (str): Path to the file containing saved graphs.

    Returns:
        dict: A dictionary of graphs, keyed by structure IDs (whatever object
        was pickled is returned as-is).

    Note:
        Unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(file_path, 'rb') as handle:
        graphs_by_structure = pickle.load(handle)

    print(f"All graphs loaded from {file_path}")
    return graphs_by_structure


In [4]:
def map_surface_residues_to_graph(graph, surface_residues):
    """
    Matches surface residues to graph nodes by their "RESNAME_RESSEQ" key.

    The key is built from ``get_resname()`` and the second element of the
    residue's ``id``; the chain is not part of the key.

    Args:
        graph: Object exposing a ``nodes`` container that supports ``in``.
        surface_residues (iterable): Residue-like objects.

    Returns:
        tuple: ``(mapped_surface_nodes, residue_dict)`` where the first item
        lists the keys found in the graph (in input order, repeats kept) and
        the second maps each found key to its residue object.
    """
    keyed_residues = (
        (f"{residue.get_resname()}_{residue.id[1]}", residue)
        for residue in surface_residues
    )
    # Keep only residues whose key is an actual node of the graph.
    present = [(key, residue) for key, residue in keyed_residues if key in graph.nodes]

    mapped_surface_nodes = [key for key, _ in present]
    residue_dict = dict(present)
    return mapped_surface_nodes, residue_dict


def get_adjacent_nodes_and_edges(graph, surface_node):
    """
    Collects the direct (1-hop) neighbourhood of a node.

    Args:
        graph: Graph exposing ``neighbors(node)`` and ``edges(node, data=True)``.
        surface_node: Node whose neighbourhood is requested.

    Returns:
        tuple: ``(neighbors, edges)`` — a list of neighbouring nodes and a list
        of the node's edges including their attribute data.
    """
    one_hop_nodes = [*graph.neighbors(surface_node)]
    # data=True keeps the attribute dict stored on each edge.
    incident_edges = [*graph.edges(surface_node, data=True)]
    return one_hop_nodes, incident_edges



def get_adjacent_adjacent_nodes_and_edges(graph, neighbor_nodes):
    """
    Collects the neighbourhood one step beyond a set of neighbour nodes.

    Args:
        graph: Graph exposing ``neighbors(node)`` and ``edges(node, data=True)``.
        neighbor_nodes (iterable): Nodes whose own neighbours are gathered.

    Returns:
        tuple: ``(nodes, edges)`` — the de-duplicated neighbours of all given
        nodes as a list (built from a set, so order is not meaningful), and
        the concatenated edge lists of the given nodes (duplicates kept).
    """
    second_hop_nodes = set()
    second_hop_edges = []

    for first_hop in neighbor_nodes:
        second_hop_nodes.update(graph.neighbors(first_hop))
        second_hop_edges.extend(graph.edges(first_hop, data=True))

    return list(second_hop_nodes), second_hop_edges


# Integer code for each atom-type label: the six labels are numbered 1..6 in
# the order listed, and None (no type) is coded as 0.
atom_encoding = {
    label: code
    for code, label in enumerate(('ACP', 'DON', 'POS', 'NEG', 'HPB', 'ARM'), start=1)
}
atom_encoding[None] = 0  # Use 0 or any encoding for `None`
  
    
def encode_atom_categories(atom_categories, atom_encoding, max_length):
    """
    Turns a residue's atom-type labels into a zero-padded list of codes.

    Args:
        atom_categories (dict): Atom name -> type label; only the values are
            used, in the dict's iteration order.
        atom_encoding (dict): Type label -> integer code.
        max_length (int): Length to pad the result to.

    Returns:
        list: One code per atom followed by zeros up to ``max_length``. The
        list is not truncated when there are more atoms than ``max_length``.

    Note:
        Any label that is not a key of ``atom_encoding`` is coded as 0. The
        lookup is verbatim, so a comma-joined label such as "POS,DON" is only
        recognised if that exact string is a key.
    """
    codes = []
    for label in atom_categories.values():
        codes.append(atom_encoding.get(label, 0))

    # A non-positive repeat count yields an empty padding list.
    padding = [0] * (max_length - len(codes))
    return codes + padding


# Helper function to extract node features and edges for each mapped surface node and adjacent nodes
def extract_subgraph(graph, mapped_surface_nodes):
    """
    Builds node-feature and edge-index tensors for the neighbourhood of the
    given surface nodes (each surface node, its neighbours, and their
    neighbours).

    Every distinct node contributes exactly one feature row, and every edge is
    expressed as a pair of ROW INDICES into that feature matrix. The caller
    (``create_protein_data``) shifts the second protein's edges by the first
    protein's row count, which is only meaningful when edges are row-indexed;
    encoding edges as residue sequence numbers, as was done previously, makes
    them point at unrelated or non-existent rows.

    Args:
        graph: Graph whose nodes are keyed "RESNAME_RESSEQ"; it must support
            ``in`` plus whatever ``get_adjacent_nodes_and_edges`` and
            ``get_adjacent_adjacent_nodes_and_edges`` call on it.
        mapped_surface_nodes (iterable): Node keys to expand. Keys that are
            not in the graph are skipped.

    Returns:
        tuple: ``(nodes_tensor, edges_tensor)``.
            nodes_tensor is a float tensor of shape (num_nodes, 1 + L), where
            L is the largest atom count in ``amino_acid_atoms``; column 0 is
            the residue sequence number and the rest are encoded atom types.
            edges_tensor is a long tensor of shape (2, num_edges) holding
            (source_row, target_row) columns, de-duplicated, in the direction
            the graph helpers returned them.
            With nothing to extract the shapes are (0, 1 + L) and (2, 0).

    Raises:
        KeyError: If a node's residue name is missing from ``amino_acid_atoms``.
        ValueError: If a node key does not split into exactly two parts on
            "_" or its second part is not an integer.
    """
    max_atom_categories_length = max(len(atoms) for atoms in amino_acid_atoms.values())

    row_of_node = {}    # node key -> row index in the feature matrix
    feature_rows = []   # one feature list per distinct node, first-seen order
    edge_pairs = {}     # (source_row, target_row) -> None; dict as ordered set

    def register(node_key):
        """Return the row index of node_key, creating its feature row on first sight."""
        row = row_of_node.get(node_key)
        if row is None:
            res_name, res_id = node_key.split('_')
            encoded_atom_categories = encode_atom_categories(
                amino_acid_atoms[res_name], atom_encoding, max_atom_categories_length
            )
            row = len(feature_rows)
            row_of_node[node_key] = row
            feature_rows.append([int(res_id)] + encoded_atom_categories)
        return row

    for surface_node in mapped_surface_nodes:
        if surface_node not in graph:
            # Nothing to look up; asking the graph for neighbours of an
            # unknown node would fail.
            continue
        register(surface_node)

        adjacent_nodes, adjacent_edges = get_adjacent_nodes_and_edges(graph, surface_node)
        adjacent_adjacent_nodes, adjacent_adjacent_edges = get_adjacent_adjacent_nodes_and_edges(graph, adjacent_nodes)

        for adj_node in adjacent_nodes + adjacent_adjacent_nodes:
            if adj_node in graph:
                register(adj_node)

        # Overlapping neighbourhoods return the same edge many times; keying
        # on the row pair keeps a single copy of each directed pair.
        for edge in adjacent_edges + adjacent_adjacent_edges:
            edge_pairs[(register(edge[0]), register(edge[1]))] = None

    feature_width = 1 + max_atom_categories_length
    if feature_rows:
        nodes_tensor = torch.tensor(feature_rows, dtype=torch.float)
    else:
        nodes_tensor = torch.empty((0, feature_width), dtype=torch.float)

    if edge_pairs:
        # Ensure edges are in long format, laid out as (2, num_edges)
        edges_tensor = torch.tensor(list(edge_pairs), dtype=torch.long).t().contiguous()
    else:
        # Keep the two-row layout so the result can still be concatenated.
        edges_tensor = torch.empty((2, 0), dtype=torch.long)

    return nodes_tensor, edges_tensor


# Define function to create a Data object for each protein graph from parsed PDB data
def create_protein_data(graph_p1, graph_p2, surface_residues_p1, surface_residues_p2, class_label):
    """
    Combines the surface subgraphs of two proteins into a single labelled
    ``Data`` object.

    Args:
        graph_p1, graph_p2: Residue graphs of the two proteins.
        surface_residues_p1, surface_residues_p2: Surface residues of each
            protein, as accepted by ``map_surface_residues_to_graph``.
        class_label: Label stored (as a one-element long tensor) in ``y``.

    Returns:
        Data: ``x`` is the first protein's node rows followed by the second's;
        ``edge_index`` is the first protein's edges followed by the second's
        edges shifted by the first protein's row count; ``y`` is the label.
    """
    proteins = ((graph_p1, surface_residues_p1), (graph_p2, surface_residues_p2))

    # Only the node keys are needed here; the residue lookup dict is discarded.
    surface_keys = [
        map_surface_residues_to_graph(graph, residues)[0]
        for graph, residues in proteins
    ]
    (x_first, edges_first), (x_second, edges_second) = [
        extract_subgraph(graph, keys)
        for (graph, _), keys in zip(proteins, surface_keys)
    ]

    x = torch.cat((x_first, x_second), dim=0)

    # The second protein's rows start right after the first protein's rows.
    row_shift = x_first.size(0)
    edge_index = torch.cat((edges_first, edges_second + row_shift), dim=1)

    y = torch.tensor([class_label], dtype=torch.long)
    return Data(x=x, edge_index=edge_index, y=y)


In [10]:
# Configure logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def main():
    """
    Builds one graph ``Data`` object per protein pair and saves the list.

    Steps:
      1. Load the pickled SASA results from ``sasa_demo/all_sasa_results.pkl``
         and keep, per structure, the residues whose SASA is at least 90% of
         that structure's maximum.
      2. Load the pickled residue graphs from
         ``graphs_demo/all_graph_results.pkl``.
      3. Read ``demo.csv`` (columns ``P1``, ``P2``, ``class``) and build a
         ``Data`` object for every row with ``create_protein_data``.
      4. Save the resulting list to ``data_list3.pt`` with ``torch.save``.

    Raises:
        KeyError: If a protein named in the CSV has no SASA entry or graph.
    """
    sasa_threshold = 0.90
    output_graphs = "graphs_demo"
    output_sasa = "sasa_demo"
    sasas = "all_sasa_results.pkl"
    graphs = "all_graph_results.pkl"

    # Surface residues per structure, derived from the stored SASA values.
    sasa_results = load_data(os.path.join(output_sasa, sasas))
    surface_residues = filter_surface_residues(sasa_results, sasa_threshold)

    loaded_graphs = load_all_graphs(os.path.join(output_graphs, graphs))

    df_pairs = pd.read_csv('demo.csv')  # Load your dataset here (replace demo.csv with actual file)
    logger.info("Dataset loaded successfully")

    data_list = []

    logger.info("Start creating protein data object")
    for _, row in df_pairs.iterrows():
        protein1 = row['P1']  # Assuming 'P1' is the protein 1 column
        protein2 = row['P2']  # Assuming 'P2' is the protein 2 column
        class_label = row['class']

        surface_residues_list_p1 = surface_residues[protein1]
        graph_ntw_p1 = loaded_graphs[protein1]

        surface_residues_list_p2 = surface_residues[protein2]
        graph_ntw_p2 = loaded_graphs[protein2]

        # create_protein_data maps the surface residues onto each graph
        # itself, so no separate mapping pass is needed here.
        protein_data = create_protein_data(
            graph_ntw_p1,
            graph_ntw_p2,
            surface_residues_list_p1,
            surface_residues_list_p2,
            class_label,
        )
        data_list.append(protein_data)

    # Save the data_list to a file
    torch.save(data_list, 'data_list3.pt')
    


In [11]:
# Entry point: build and save the dataset only when this code runs as the
# top-level module, not when it is imported.
if __name__ == "__main__":
    main()

2024-11-27 18:26:56,615 - INFO - Dataset loaded successfully
2024-11-27 18:26:56,615 - INFO - Start creating protein data object


All graphs loaded from graphs_demo\all_graph_results.pkl
Graph with 423 nodes and 725 edges
Graph with 907 nodes and 1376 edges
