In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import json

from sklearn import preprocessing
import torch
# from torch_geometric.data import Data, Dataset, InMemoryDataset
# from graphein.protein.config import ProteinGraphConfig
# from graphein.ml.conversion import GraphFormatConvertor
import graphein.protein as gp
from functools import partial
from graphein.ml.conversion import GraphFormatConvertor
from graphein.protein.graphs import construct_graph
from graphein.protein.graphs import read_pdb_to_dataframe

from IPython.display import clear_output
clear_output()

In [2]:
# Load the per-residue dataset.
# - Identifier columns are read as strings so that chain IDs such as "1" or PDB
#   codes such as "1e10" are not parsed as numbers (a numeric pdb_chain would
#   break the string concatenation/matching done downstream).
# - low_memory=False makes pandas infer each column's dtype from the whole file
#   instead of per chunk, avoiding mixed-type columns (pandas' DtypeWarning).
df = pd.read_csv('dataset.csv', dtype={'pdb_id': str, 'pdb_chain': str}, low_memory=False)

  exec(code_obj, self.user_global_ns, self.user_ns)


# Graphein

In [3]:
def extract_residues_with_data(graphein_df, data_df):
    """Keep only the rows of a graphein structure table whose residue appears in ``data_df``.

    Residues are matched on the (chain, residue number) pair: ``graphein_df.chain_id`` /
    ``graphein_df.residue_number`` against ``data_df.pdb_chain`` / ``data_df.pdb_resi``,
    with both sides compared as strings.

    Args:
        graphein_df: DataFrame with ``chain_id`` and ``residue_number`` columns
            (in this notebook, the output of ``read_pdb_to_dataframe``). Not modified.
        data_df: DataFrame with ``pdb_chain`` and ``pdb_resi`` columns listing the
            residues to keep.

    Returns:
        A filtered copy of ``graphein_df`` (original index and columns preserved).
    """
    # Match on tuple keys rather than concatenated strings: "A1" + "1" and
    # "A" + "11" would both become "A11" with multi-character chain IDs.
    residues_to_keep = set(zip(data_df.pdb_chain.astype(str),
                               data_df.pdb_resi.astype(str)))
    residue_keys = zip(graphein_df.chain_id.astype(str),
                       graphein_df.residue_number.astype(str))
    # A NumPy boolean array (not a list) so an empty frame still selects rows,
    # not columns.
    keep_mask = np.fromiter((key in residues_to_keep for key in residue_keys),
                            dtype=bool, count=len(graphein_df))
    return graphein_df.loc[keep_mask].copy()

In [4]:
def get_node_features(graph, dssp_df, ss_label_binarizer, features):
    """Concatenate the requested node features column-wise into a single tensor.

    Args:
        graph: converted graph indexable by feature name (``graph[name]`` must
            return a tensor).
        dssp_df: DSSP table supplying the ``rsa`` and ``ss`` columns.
        ss_label_binarizer: fitted binarizer whose ``transform`` one-hot encodes
            ``dssp_df.ss``.
        features: feature names in output-column order. "rsa" and "ss" are read
            from ``dssp_df``; every other name is looked up in ``graph``.

    Returns:
        A 2-D tensor built by concatenating all features along dim 1; 1-D
        features contribute a single column. Assumes every feature has the same
        number of rows (one per node) — TODO confirm for the DSSP-derived ones.
    """
    def _feature_block(name):
        # Fetch one feature and make sure it is 2-D so it can be concatenated.
        if name == "rsa":
            block = torch.tensor(dssp_df.rsa)
        elif name == "ss":
            block = torch.tensor(ss_label_binarizer.transform(dssp_df.ss))
        else:
            block = graph[name]
        return block.unsqueeze(dim=1) if block.dim() == 1 else block

    return torch.concatenate([_feature_block(name) for name in features], dim=1)

In [5]:
def generate_protein_graph(pdb_code, data_df, label, config, features, ss_label_binarizer, convertor):
    """Build the node features, coordinates, edges and labels for one PDB entry.

    The structure is read with graphein's ``read_pdb_to_dataframe``. It is
    restricted to the residues that have a row in ``data_df``, turned into a
    graph with ``config`` and converted with ``convertor``.

    Args:
        pdb_code: PDB identifier; rows of ``data_df`` with ``pdb_id == pdb_code`` are used.
        data_df: per-residue table with ``pdb_id``, ``pdb_chain``, ``pdb_resi``
            and ``label`` columns.
        label: name of the target column in ``data_df``.
        config: graphein ``ProteinGraphConfig`` passed to ``construct_graph``.
        features: feature names forwarded to ``get_node_features``.
        ss_label_binarizer: fitted binarizer used to one-hot the DSSP ``ss`` column.
        convertor: graphein ``GraphFormatConvertor`` applied to the networkx graph.

    Returns:
        ``(node_features, coords, edge_map, labels)``.

    Raises:
        ValueError: if the number of labels differs from the number of rows in
            ``node_features`` (a label/node misalignment).
    """
    entry_df = data_df[data_df.pdb_id == pdb_code].copy()
    raw_df = read_pdb_to_dataframe(pdb_code=pdb_code)
    graphein_df = extract_residues_with_data(raw_df, entry_df)
    nx_graph = construct_graph(config=config, df=graphein_df)
    # Read the DSSP table from the networkx graph-level attributes before converting.
    dssp_df = nx_graph.graph['dssp_df']
    graph = convertor(nx_graph)

    node_features = get_node_features(graph, dssp_df, ss_label_binarizer, features)
    # Public item access instead of the private ``_store`` attribute.
    coords = graph['coords']
    edge_map = graph['edge_index']
    labels = torch.tensor(list(entry_df[label]))

    # NOTE(review): labels follow data_df row order, not graph node order; this
    # assumes both orders agree — TODO confirm. At least reject count mismatches
    # instead of silently storing a misaligned sample.
    n_nodes = node_features.shape[0]
    if len(labels) != n_nodes:
        raise ValueError(
            f"{pdb_code}: {len(labels)} labels but {n_nodes} graph nodes"
        )

    return node_features, coords, edge_map, labels

In [6]:
# Name of the per-residue target column in df (passed to generate_protein_graph).
label = "s2calc"

In [7]:
# graphein graph construction settings: which residue-residue edges to add,
# which graph-level annotations to compute (RSA and secondary structure, from
# DSSP as configured below), and which node features to attach.
edge_funcs = {"edge_construction_functions": [gp.add_peptide_bonds,
                                              gp.add_hydrophobic_interactions,
                                              gp.add_disulfide_interactions,
                                              gp.add_hydrogen_bond_interactions,
                                              gp.add_ionic_interactions,
                                              gp.add_aromatic_interactions,
                                              gp.add_aromatic_sulphur_interactions,
                                              gp.add_cation_pi_interactions,
                                              gp.add_vdw_interactions,
                                              gp.add_pi_stacking_interactions,
                                              gp.add_backbone_carbonyl_carbonyl_interactions]}

graph_metadata = {"graph_metadata_functions" : [gp.rsa,
                                                gp.secondary_structure]}
node_metadata = {"node_metadata_functions" : [gp.amino_acid_one_hot,
                                              gp.meiler_embedding,
                                              partial(gp.expasy_protein_scale, add_separate=True)],
                 "dssp_config": gp.DSSPConfig()}

config = gp.ProteinGraphConfig(**{**edge_funcs, **graph_metadata, **node_metadata})

# Node features, in the column order get_node_features concatenates them.
features = ["amino_acid_one_hot", "meiler",
            "bulkiness", "averageflexibility",
            "isoelectric_points", "pka_rgroup",
            "polaritygrantham", "hphob_black",
            "rsa", "ss"]

# Features get_node_features reads from the DSSP table rather than the
# converted graph, so they are not requested from the convertor.
DSSP_FEATURES = ("rsa", "ss")

# Derive the convertor columns from `features` so the two lists cannot drift
# apart; coords and edge_index are additionally read by generate_protein_graph.
convertor_columns = [name for name in features if name not in DSSP_FEATURES] + ["coords", "edge_index"]
convertor = GraphFormatConvertor(src_format="nx", dst_format="pyg", columns=convertor_columns)

# DSSP secondary-structure codes; ss_lb one-hot encodes the `ss` feature.
secondary_structures = ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-']
ss_lb = preprocessing.LabelBinarizer().fit(secondary_structures)

In [8]:
# Build tensors for every structure. unique() keeps first-appearance order, so
# the processing order (and pdb_graphs' key order) is the same on every run,
# unlike iterating a set of strings.
pdb_ids = list(df.pdb_id.unique())
pdb_graphs = {}
# pdb_id -> "ErrorType: message" for entries that could not be processed.
failed_pdb_ids = {}
issue_count = 0
for pdb_id in tqdm(pdb_ids):
    try:
        pdb_graphs[pdb_id] = generate_protein_graph(pdb_id, df, label, config,
                                                    features, ss_lb, convertor)
    except (ValueError, RuntimeError, IndexError) as err:
        # Skip this entry but keep the reason so failures can be inspected later.
        failed_pdb_ids[pdb_id] = f"{type(err).__name__}: {err}"
        issue_count += 1
    # Clear output emitted while processing this entry; wait=True defers the
    # clear until new output arrives, avoiding progress-bar flicker.
    clear_output(wait=True)

print(f"Built {len(pdb_graphs)} graphs; {issue_count} of {len(pdb_ids)} PDB entries failed.")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1034/1034 [23:23<00:00,  1.36s/it]


In [9]:
# import pickle as pkl
# pkl.dump(pdb_graphs, open('pdb_graphs.pkl', 'wb'))