In [None]:
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP, ss_to_index

import numpy as np
import torch
from torch_geometric.data import Data
import time

In [None]:
# Residue pairs whose CA atoms are closer than this cutoff get an edge
# (same units as the PDB coordinates).
dist_th = 12.0
# Which frequency column (mode) is used as the regression target:
# 0 selects the first mode, 63 the last.
idx_mode = 0

# Seed both RNGs so anything stochastic downstream is reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Identifier used for every file this notebook writes.
task_name = f'all_th{int(dist_th)}_mode{idx_mode}'

In [None]:
# Each non-blank line holds space-separated fields: field 1 is the PDB id and
# field idx_mode+2 is the frequency of the selected mode (field 0 is unused here).
with open("data/all_proteins_freqs_no_ratios.dat") as f:
    # splitlines() plus skipping blank lines keeps the final record even when the
    # file has no trailing newline (the old `len(lines)-1` silently dropped it).
    lines = [line for line in f.read().splitlines() if line.strip()]

N_seq = len(lines)

pdb_ids = []
freqs = []
for line in lines:
    # split() with no argument tolerates repeated spaces between columns.
    fields = line.split()
    pdb_ids.append(fields[1])
    freqs.append(float(fields[idx_mode + 2]))

In [None]:
def load_vocab(filename):
    """Map every line of a vocabulary file to its 0-based line number.

    Each line is stripped of surrounding whitespace and used as a key. If the
    same token appears on several lines, the last occurrence wins.

    Args:
        filename: path to a text file with one token per line.

    Returns:
        dict mapping token -> line index.

    Raises:
        OSError: if the file cannot be opened or read. The message includes the
            filename. The previous handler raised an undefined ``MyIOError``,
            which turned every I/O failure into a NameError.
    """
    with open(filename) as f:
        return {word.strip(): idx for idx, word in enumerate(f)}

# Character vocabulary: each line of food_chars.txt maps to its 0-based line index.
vocab_chars = load_vocab(filename="food_chars.txt")

def get_node_features(model, pdb_path, vocab=None):
    """Encode every residue reported by DSSP as an integer vocabulary id.

    Args:
        model: the Bio.PDB model to analyse (DSSP only handles one model).
        pdb_path: path of the PDB file ``model`` was parsed from; passed to DSSP.
        vocab: mapping from the residue code DSSP reports to an integer id.
            Defaults to the notebook-level ``vocab_chars``.

    Returns:
        1-D ``torch.long`` tensor with one entry per DSSP key, in DSSP key order.

    Raises:
        KeyError: if DSSP reports a residue code that is not in the vocabulary.
    """
    if vocab is None:
        vocab = vocab_chars
    dssp = DSSP(model, pdb_path)
    # Field 1 of each DSSP record is the residue's one-letter amino-acid code.
    codes = [vocab[dssp[key][1]] for key in dssp.keys()]
    return torch.tensor(codes, dtype=torch.long)

def pdb2graph(path, pdb_id, dist_threshold=None):
    """Build a residue contact graph for one PDB entry.

    Nodes are the residues DSSP processed, in the order the model lists them.
    Their features come from ``get_node_features``. An edge is added for every
    ordered pair of distinct nodes whose CA atoms are closer than the cutoff,
    so each contact appears in both directions.

    Args:
        path: directory containing ``<pdb_id>.pdb``.
        pdb_id: PDB identifier; also used as the structure id for the parser.
        dist_threshold: CA-CA cutoff (same units as the PDB coordinates).
            Defaults to the notebook-level ``dist_th``.

    Returns:
        (node_features, edge_index, edge_attr), where
          - node_features is the 1-D long tensor from ``get_node_features``;
          - edge_index is a (2, E) long tensor of (source, target) node indices,
            sorted by source and then by target;
          - edge_attr is an (E,) float tensor holding each edge's CA-CA distance.

    Raises:
        ValueError: if the number of DSSP-annotated residues differs from the
            number of node features, so nodes cannot be matched to residues.
    """
    if dist_threshold is None:
        dist_threshold = dist_th

    pdb_path = path + "/" + pdb_id + ".pdb"
    structure = PDBParser().get_structure(pdb_id, pdb_path)

    # DSSP can only handle one model, and will only run calculations on the first model in the provided PDB file.
    model = structure[0]

    node_features = get_node_features(model, pdb_path)

    # Node i must be the residue behind node_features[i]. Biopython's DSSP
    # stores its results in residue.xtra (e.g. "SS_DSSP") for each residue it
    # processed, so keep exactly those residues. Previously every residue,
    # including waters, ligands and ions, took up an index, which shifted all
    # later residues away from their features. Ions whose atom is named "CA"
    # (calcium) even received edges.
    # NOTE(review): assumes DSSP lists residues in model iteration order; the
    # length check below only catches count mismatches.
    residues = [
        residue
        for chain in model
        for residue in chain
        if "SS_DSSP" in residue.xtra
    ]
    if len(residues) != len(node_features):
        raise ValueError(
            f"{pdb_id}: {len(residues)} DSSP-annotated residues "
            f"but {len(node_features)} node features"
        )

    # Residues without a CA atom stay in the node list but get no edges.
    ca_nodes = np.array(
        [idx for idx, residue in enumerate(residues) if "CA" in residue],
        dtype=np.int64,
    )
    ca_coords = np.array(
        [residues[idx]["CA"].coord for idx in ca_nodes], dtype=np.float32
    ).reshape(-1, 3)

    # Seed the lists with empty arrays so the concatenations below also work
    # when there are no edges at all.
    sources = [np.empty(0, dtype=np.int64)]
    targets = [np.empty(0, dtype=np.int64)]
    distances = [np.empty(0, dtype=np.float32)]
    for row, node in enumerate(ca_nodes):
        # Compute one row of the distance matrix at a time. This keeps memory
        # linear in the residue count and avoids a Python-level inner loop.
        dist = np.linalg.norm(ca_coords - ca_coords[row], axis=1)
        close = dist < dist_threshold
        close[row] = False  # no self-loops
        cols = np.flatnonzero(close)
        sources.append(np.full(cols.size, node, dtype=np.int64))
        targets.append(ca_nodes[cols])
        distances.append(dist[cols].astype(np.float32))

    edge_index = torch.as_tensor(
        np.stack([np.concatenate(sources), np.concatenate(targets)]),
        dtype=torch.long,
    )
    edge_attr = torch.as_tensor(np.concatenate(distances), dtype=torch.float)

    return node_features, edge_index, edge_attr

In [None]:
import warnings

data_list = []
excluded_pdb_ids = []  # exclude problematic graphs from the dataset
exclusion_reasons = {}  # pdb_id -> why it was excluded, for later inspection

percentage_inc = 1  # print progress every `percentage_inc` percent of the proteins
# max(1, ...) avoids `idx % 0` (ZeroDivisionError) when there are fewer than
# 100 / percentage_inc proteins.
report_every = max(1, int(len(pdb_ids) / 100 * percentage_inc))

t = time.time()
percentage = 0
# Parsing and DSSP can emit many warnings. Silence them for this loop only,
# not for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    for idx, pdb_id in enumerate(pdb_ids):
        try:
            node_features, edge_index, edge_attr = pdb2graph('../PDB/pdb', pdb_id)
            data = Data(edge_attr=edge_attr,
                        edge_index=edge_index,
                        x=node_features,
                        y=torch.tensor([freqs[idx]], dtype=torch.float))

            if data.edge_index.numel() == 0:
                reason = 'no edges'
            elif data.x.shape[0] != data.edge_index.max().item() + 1:
                reason = 'node count does not match edge indices'
            elif data.contains_isolated_nodes():
                reason = 'isolated nodes'
            else:
                reason = None
        except Exception as err:
            # Keep processing the remaining proteins, but record the cause.
            # A bare `except:` would also swallow KeyboardInterrupt.
            reason = f'{type(err).__name__}: {err}'

        if reason is None:
            data_list.append(data)
        else:
            excluded_pdb_ids.append(pdb_id)
            exclusion_reasons[pdb_id] = reason

        if idx % report_every == 0:
            print(f'Data generated: {percentage}%, Time: {time.time()-t:.4f}')
            percentage += percentage_inc
            t = time.time()

In [None]:
# Number of graphs that survived all the filters above.
N_data = len(data_list)
print('Number of proteins: {}'.format(N_data))

In [None]:
# Persist the list of graphs under the task-specific name.
torch.save(data_list, f'data/{task_name}.pt')

In [None]:
# One scalar target per kept graph, in data_list order.
freqs_processed = torch.tensor(
    [graph.y.item() for graph in data_list], dtype=torch.float
)
torch.save(freqs_processed, f'data/{task_name}_freqs.pt')