In [193]:
import os
import math

import numba
import numpy as np
import atomium
import Bio
import torch as th

from scipy.spatial import distance_matrix
protein_letters_3to1 = Bio.SeqUtils.IUPACData.protein_letters_3to1_extended
protein_letters_3to1 = {k.upper() : v for k,v in protein_letters_3to1.items()}

In [239]:
@numba.njit(parallel=True)
def numba_jit_scalar_distance_parallel(xyz):
    rows = xyz.shape[0]
    output = np.empty((rows, rows), dtype=np.float32)
    for i in numba.prange(rows):
        cols = rows - i
        for j in numba.prange(cols):
            tmp = 0.0
            tmp += (xyz[i, 0] - xyz[j, 0])**2
            tmp += (xyz[i, 1] - xyz[j, 1])**2
            tmp += (xyz[i, 2] - xyz[j, 2])**2
            tmp = math.sqrt(tmp) 
            output[i,j] = tmp
            output[j,i] = tmp
    return output

In [107]:
def get_atom_xyz(atoms, atom_name):
    for a in atoms:
        if a.name == atom_name:
            return a.location
    return (np.nan, np.nan, np.nan)

def get_ss_label(residue):
    '''
    E, H or C label from atomium
    '''
    if residue.helix:
        return 'H'
    elif residue.strand:
        return 'E'
    else:
        return 'C'

In [237]:
def parse_graph_data_numba(path_pdb, chain):
    
    if not os.path.isfile(path_pdb):
        FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa' : protein_letters_3to1[r.name],
                                    'charge' : r.charge,
                                    'CA' : get_atom_xyz(r_atoms, 'CA'),
                                    'CB' : get_atom_xyz(r_atoms, 'CB'),
                                    'ss_label' : get_ss_label(r)
                                   }

        ca_xyz = np.asarray(list(map(lambda v : v['CA'], preparation_dict.values())), dtype=np.float32)
        sequence = list(map(lambda v : v['aa'], preparation_dict.values()))
        ca_ca_matrix = numba_jit_scalar_distance_parallel(ca_xyz)
    return ca_ca_matrix, sequence

In [176]:
def parse_graph_data_torch(path_pdb, chain):
    
    if not os.path.isfile(path_pdb):
        FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa' : protein_letters_3to1[r.name],
                                    'charge' : r.charge,
                                    'CA' : get_atom_xyz(r_atoms, 'CA'),
                                    'CB' : get_atom_xyz(r_atoms, 'CB'),
                                    'ss_label' : get_ss_label(r)
                                   }

        ca_xyz = th.FloatTensor(list(map(lambda v : v['CA'], preparation_dict.values())))
        sequence = list(map(lambda v : v['aa'], preparation_dict.values()))

        ca_ca_matrix = th.cdist(ca_xyz, ca_xyz)
    return ca_ca_matrix, sequence

In [184]:
def parse_graph_data(path_pdb, chain):
    
    if not os.path.isfile(path_pdb):
        FileNotFoundError('no such file', path_pdb)
    file = atomium.open(path_pdb)
    chain = file.model.chain(chain)
    preparation_dict = dict()
    for i, r in enumerate(chain.residues()):
        r_atoms = r.atoms()
        preparation_dict[i] = {'aa' : protein_letters_3to1[r.name],
                                    'charge' : r.charge,
                                    'CA' : get_atom_xyz(r_atoms, 'CA'),
                                    'CB' : get_atom_xyz(r_atoms, 'CB'),
                                    'ss_label' : get_ss_label(r)
                                   }

        ca_xyz = np.asarray(list(map(lambda v : v['CA'], preparation_dict.values())), dtype=np.float32)
        sequence = list(map(lambda v : v['aa'], preparation_dict.values()))

        ca_ca_matrix = distance_matrix(ca_xyz, ca_xyz)
    return ca_ca_matrix, sequence

In [169]:
path = '/home/db/localpdb/mirror/ea/pdb6eac.ent.gz'
chain = 'A'

In [185]:
%timeit parse_graph_data(path, chain)

4.11 s ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [186]:
%timeit parse_graph_data_torch(path, chain)

2.99 s ± 7.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [240]:
%timeit parse_graph_data_numba(path, chain)

3.05 s ± 9.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [241]:
a,b = parse_graph_data_torch(path, chain)

In [244]:
a.element_size()*a.nelement() / 1024

885.0625