In [None]:
from Bio.PDB import *
from tqdm import tqdm
import numpy as np

In [None]:
import torch
from torch_geometric.data import *

In [None]:
from utils.path_manage import get_files
# get_files() yields six values; the last two are discarded into `_` and `__`.
data, lookup, ASD_dictionary, BCE_dictionary, _, __ = get_files()
# Half the number of entries in `lookup`, rounded down.
# NOTE(review): presumably `lookup` holds two entries per entity (e.g. a forward and a reverse mapping) - confirm.
entities = len(lookup) // 2

In [None]:
def get_one_hot_dictionary(keys):
    """Map every key to a one-hot list as long as ``keys``.

    The i-th key (in iteration order) is associated with a list of
    ``len(keys)`` integers holding a single 1 at position i and 0 elsewhere.
    ``keys`` must support ``len()`` and iteration.
    """
    width = len(keys)
    return {
        key: [1 if column == position else 0 for column in range(width)]
        for position, key in enumerate(keys)
    }

def get_residue_list(chain):
    """Keep only the residues whose ``resname`` is a key of ``Polypeptide.d3_to_index``.

    ``chain`` may be any iterable of objects exposing ``resname``; the
    iteration order of the input is preserved in the returned list.
    """
    def has_known_name(residue):
        return residue.resname in Polypeptide.d3_to_index

    return list(filter(has_known_name, chain))

def get_covalent_edges(residues):
    """Return ``[i, i - 1]`` index pairs for sequence-adjacent residues.

    A pair is emitted for every position ``i >= 1`` whose residue number
    (``_id[1]``) is exactly one more than the residue number at ``i - 1``.

    Position 0 is skipped on purpose: it has no predecessor, and looking up
    ``residues[0 - 1]`` would compare it with the *last* residue and could
    emit a spurious ``[0, -1]`` edge.

    Parameters
    ----------
    residues : sequence
        Indexable sequence of objects whose ``_id[1]`` is a number.

    Returns
    -------
    list of [int, int]
        Edges as ``[index, index - 1]``; empty for fewer than two residues.
    """
    # NOTE(review): only residue numbers are compared; nothing here checks
    # that the two residues belong to the same chain.
    return [
        [index, index - 1]
        for index in range(1, len(residues))
        if residues[index]._id[1] == residues[index - 1]._id[1] + 1
    ]

def check_proximity(residue_one, residue_two, max_distance = 8) :
    """Tell whether two residues' C-alpha atoms are closer than ``max_distance``.

    The Euclidean distance between ``residue_one["CA"].coord`` and
    ``residue_two["CA"].coord`` is computed and compared with a strict ``<``.
    Returns ``True`` or ``False`` (not the distance itself).

    Original author's note: C-beta is preferable but needs glycine encoding.
    """
    delta = residue_one["CA"].coord - residue_two["CA"].coord
    separation = np.sqrt(np.sum(delta * delta))
    return bool(separation < max_distance)

In [None]:
# Distinct values of the third column of `data`, in set-iteration order.
# NOTE(review): set order is arbitrary; anything cached positionally and later
# zipped back against PDB_keys presumably relies on this order being the same
# in every session - confirm.
graph_keys = [*set(data[:, 2])]
# Protein_ID of each graph key (via `lookup` and `ASD_dictionary`), in the same order as graph_keys.
PDB_keys = [
    ASD_dictionary[lookup[key]]['Protein_ID']
    for key in graph_keys
]
# Download of the structure files into PDB_files/ (left disabled):
# pdbl = PDBList()
# pdbl.download_pdb_files(pdb_codes = PDB_keys, file_format = 'pdb', pdir = 'PDB_files')

In [None]:
# Keys of `Polypeptide.d3_to_index` (the same table get_residue_list filters on);
# they define the one-hot vocabulary.
keys = Polypeptide.d3_to_index.keys()
# resname -> one-hot list; read as a global by get_graph to build node features.
hot_dick = get_one_hot_dictionary(keys)

In [None]:
def get_graph(residues, undirected=False):
    """Build a torch_geometric ``Data`` graph from an iterable of residues.

    Nodes are the residues kept by ``get_residue_list``; each node feature is
    the one-hot list ``hot_dick[res.resname]``.

    Edges come from two sources:

    * covalent edges, as returned by ``get_covalent_edges``;
    * proximal edges ``[i, j]`` for residues at least 5 list positions apart
      (``j >= i + 5``) for which ``residue_j > residue_i`` holds and
      ``check_proximity`` is true with its default cutoff.

    Parameters
    ----------
    residues : iterable
        Residue objects (anything accepted by ``get_residue_list``).
    undirected : bool, optional
        When False (the default) every edge is stored once, in the direction
        described above. When True the reverse of every edge is appended too.

    Returns
    -------
    Data
        ``x`` is a float tensor built from the per-residue one-hot lists and
        ``edge_index`` is a long tensor of shape ``(2, num_edges)``, including
        ``(2, 0)`` when no edge was found.
    """
    processed_residues = get_residue_list(residues)
    covalent_edges = get_covalent_edges(processed_residues)

    proximal_edges = []
    residue_count = len(processed_residues)
    for x_index, x in enumerate(processed_residues):
        # Start directly at x_index + 5 instead of scanning every residue and
        # discarding the ones that are too close in the list.
        for y_index in range(x_index + 5, residue_count):
            y = processed_residues[y_index]
            # NOTE(review): `y > x` compares the residue objects themselves
            # (for Bio.PDB entities presumably by their ids). With
            # y_index > x_index already guaranteed this is either redundant or
            # silently drops pairs - confirm the intent.
            if y > x and check_proximity(x, y):
                proximal_edges.append([x_index, y_index])

    all_edges = covalent_edges + proximal_edges
    if undirected:
        all_edges = all_edges + [[target, source] for source, target in all_edges]

    node_features = [hot_dick[res.resname] for res in processed_residues]
    # Explicit (num_edges, 2) shape so that an empty edge list still
    # transposes to (2, 0) rather than staying a 1-D empty tensor.
    edge_index = torch.tensor(all_edges, dtype=torch.long).reshape(len(all_edges), 2)
    node_data = torch.tensor(node_features, dtype=torch.float)
    graph = Data(x = node_data, edge_index=edge_index.t().contiguous())
    return graph

## NOTE: the graph is directed (each edge is stored in one direction only), which probably breaks it

In [None]:
parser = PDBParser()
graph_list = []

# One entry per graph key, in order: either a Data graph or the placeholder
# string '<PDB> missing' when the structure could not be turned into a graph.
for PDB,graph in tqdm(zip(PDB_keys, graph_keys), total=min(len(PDB_keys), len(graph_keys))):
    try:
        structure = parser.get_structure('{}'.format(graph), 'PDB_files/pdb{}.ent'.format(PDB.lower()))
        # Use the first model only: iterating the whole structure would yield
        # the residues of every model in multi-model files.
        model = structure[0]
        residues = model.get_residues()
        output_graph = get_graph(residues)
        graph_list.append(output_graph)
    except Exception as error:
        # Best-effort: any failure (absent file, parse error, incomplete
        # residue, ...) is recorded as a placeholder so positions stay aligned
        # with graph_keys. Catching Exception (not a bare except) keeps
        # Ctrl-C / kernel interrupts working, and the reason is printed so a
        # real bug is not mistaken for a missing file.
        print(PDB, ' missing', '({}: {})'.format(type(error).__name__, error))
        graph_list.append('{} missing'.format(PDB))



In [None]:
# import pickle
# with open('data/protein_graph_list.pickle', 'wb') as item:
#     pickle.dump(graph_list, item)

In [None]:
import pickle
# Reload the graph list from the same path the (commented-out) dump cell writes to.
# NOTE(review): pickle.load can execute arbitrary code - only load files you produced yourself.
with open('data/protein_graph_list.pickle', 'rb') as item:
    stored_graphs = pickle.load(item)

In [None]:
def get_adj_mask(max_nodes, graph):
    """Build a fixed-size dense adjacency matrix and mask for ``graph``.

    Parameters
    ----------
    max_nodes : int
        Side length of the padded square matrices. Callers are expected to
        pass only graphs with ``graph.num_nodes <= max_nodes``.
    graph : Data
        Graph providing ``num_nodes``, ``edge_index`` (a ``(2, num_edges)``
        tensor) and ``x``.

    Returns
    -------
    Data
        ``x`` is ``graph.x`` unchanged (it is not padded); ``adj`` is a
        ``(max_nodes, max_nodes)`` NumPy array holding 1 at ``[i, j]`` and
        ``[j, i]`` for every edge ``(i, j)`` and 0 elsewhere; ``mask`` is a
        boolean ``(max_nodes, max_nodes)`` NumPy array that is True exactly
        on the leading ``num_nodes x num_nodes`` block.
    """
    num_nodes = graph.num_nodes
    mask = np.zeros([max_nodes, max_nodes], dtype=bool)
    # One 2-D slice. Chained `[0:n][0:n]` slices the rows twice and would
    # mark the first n rows True across *all* columns.
    mask[:num_nodes, :num_nodes] = True

    adjacency = np.zeros([max_nodes, max_nodes])  # float64 by default; check if an int dtype is needed
    edge_index = graph.edge_index
    if edge_index.numel() > 0:
        sources, targets = edge_index.cpu().numpy()
        # Fill both directions at once so the matrix is symmetric.
        adjacency[sources, targets] = 1
        adjacency[targets, sources] = 1
    # Read features from the `graph` argument, not from a same-named global
    # that happens to exist in the calling loop.
    return Data(x = graph.x, adj = adjacency, mask = mask)
    

max_graph_size = 2000
graphs_with_masks = []
# One entry per stored graph, in order: a Data object carrying adj/mask, or a
# placeholder string ('<PDB> too big' / '<PDB> missing').
for store, PDB_key in tqdm(zip(stored_graphs, PDB_keys), total=min(len(stored_graphs), len(PDB_keys))):
    if isinstance(store, str):
        # A placeholder string stored in place of a graph: propagate it
        # explicitly instead of relying on an AttributeError being swallowed.
        graphs_with_masks.append('{} missing'.format(PDB_key))
        continue
    try:
        if store.num_nodes > max_graph_size:
            graphs_with_masks.append('{} too big'.format(PDB_key))
        else:
            graphs_with_masks.append(get_adj_mask(max_graph_size, store))
    except Exception:
        # Best-effort fallback kept from the original loop; Exception (not a
        # bare except) so that Ctrl-C / kernel interrupts still stop the run.
        graphs_with_masks.append('{} missing'.format(PDB_key))


In [None]:
# NOTE(review): this cell rebuilds graphs_with_masks from scratch, as the
# previous cell already does; one of the two can probably be dropped.
max_graph_size = 2000
graphs_with_masks = []
# One entry per stored graph, in order: a Data object carrying adj/mask, or a
# placeholder string ('<PDB> too big' / '<PDB> missing').
for store, PDB_key in tqdm(zip(stored_graphs, PDB_keys), total=min(len(stored_graphs), len(PDB_keys))):
    if isinstance(store, str):
        # A placeholder string stored in place of a graph: propagate it
        # explicitly instead of relying on an AttributeError being swallowed.
        graphs_with_masks.append('{} missing'.format(PDB_key))
        continue
    try:
        if store.num_nodes > max_graph_size:
            graphs_with_masks.append('{} too big'.format(PDB_key))
        else:
            graphs_with_masks.append(get_adj_mask(max_graph_size, store))
    except Exception:
        # Best-effort fallback kept from the original loop; Exception (not a
        # bare except) so that Ctrl-C / kernel interrupts still stop the run.
        graphs_with_masks.append('{} missing'.format(PDB_key))



In [None]:
# Total number of entries, then how many of them are not placeholder strings.
print(len(graphs_with_masks))
print(sum(1 for entry in graphs_with_masks if not isinstance(entry, str)))


In [None]:
import pickle
# NOTE(review): only the first element (graphs_with_masks[0]) is written, not
# the whole list - confirm this is intentional.
with open('data/protein_graphs_with_masks.pickle', 'wb') as item:
    pickle.dump(graphs_with_masks[0], item)

In [None]:
# this is apparently 25 GB