In [None]:
import src
from src import oic_dwic
from src import ecif
from src import utils

In [None]:
# Distance-weighted inter-atomic contact (DWIC) feature generator, configured
# to read complexes from ./test/ and to use the DWIC amino-acid classes.
dwic = oic_dwic.InterAtomicContact(
    pathfiles="./test/",
    filename="dwic_fv.csv",
    ligand_format="mol2",
    amino_acid_classes=utils.amino_acid_classes_DWIC,
    cutoff=12.0,
    feature_type="DWIC",
    exp=2,
)

# Same generator class configured for plain OIC features: OIC amino-acid
# classes and no exponent (exp=None).
oic = oic_dwic.InterAtomicContact(
    pathfiles="./test/",
    filename="oic_fv.csv",
    ligand_format="mol2",
    amino_acid_classes=utils.amino_acid_classes_OIC,
    cutoff=12.0,
    feature_type="OIC",
    exp=None,
)

# NOTE(review): this assignment rebinds the name `ecif` from the imported
# module (`from src import ecif`) to the ECIF instance. Re-running this cell
# without re-running the import cell will then look up `.ECIF` on the instance
# instead of the module and most likely fail; consider a distinct variable
# name (and update the cell that calls `ecif.generate_features`).
ecif = ecif.ECIF(
    pathfiles="./test/", filename="ecif_fv.csv", ligand_format="sdf", cutoff=6.0
)

In [None]:
# Run the three feature generators configured above. `n_jobs=-1` is forwarded
# to each generator's `generate_features` (presumably "use all available
# cores" — confirm against the generator implementation).
dwic.generate_features(n_jobs=-1)
oic.generate_features(n_jobs=-1)
ecif.generate_features(n_jobs=-1)

In [94]:
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist
import itertools
from collections import OrderedDict
from Bio.PDB import PDBParser
from src.script import mol2parser
from src.feature_generators.oic_dwic import FeatureGenerator

In [170]:
# NOTE(review): interactive documentation lookup left over from development;
# it only prints the scipy docstring and can be dropped from the final notebook.
help(cdist)

Help on function cdist in module scipy.spatial.distance:

cdist(XA, XB, metric='euclidean', *, out=None, **kwargs)
    Compute distance between each pair of the two collections of inputs.
    
    See Notes for common calling conventions.
    
    Parameters
    ----------
    XA : array_like
        An :math:`m_A` by :math:`n` array of :math:`m_A`
        original observations in an :math:`n`-dimensional space.
        Inputs are converted to float type.
    XB : array_like
        An :math:`m_B` by :math:`n` array of :math:`m_B`
        original observations in an :math:`n`-dimensional space.
        Inputs are converted to float type.
    metric : str or callable, optional
        The distance metric to use. If a string, the distance function can be
        'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation',
        'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'jensenshannon',
        'kulsinski', 'kulczynski1', 'mahalanobis', 'matching', 'minkowski',
        

In [162]:
# NOTE(review): `MultiShellOIC` is defined in a cell further down this
# notebook, so running the cells top to bottom on a fresh kernel reaches this
# line before the class exists. Move the class definition above this cell.
# Builds a multi-shell OIC generator with 62 shells reading complexes from ./test/.
msoic = MultiShellOIC(pathfiles="./test/", filename="msoic_fv.csv", ligand_format="mol2", n_shells=62)

In [168]:
# Generate the multi-shell OIC features; `n_jobs=-1` is forwarded to the
# inherited `generate_features` (presumably "use all cores" — confirm).
msoic.generate_features(n_jobs=-1)

Time: 0.0:0.0:8.28


In [161]:
class MultiShellOIC(FeatureGenerator):
    """Multi-shell residue/element contact counts for a protein-ligand complex.

    For every (protein residue, ligand atom) pair the minimum distance between
    the ligand atom and the residue's non-hydrogen atoms is computed. Each pair
    is labelled ``"<RES>_<ELEMENT>"`` and counted in the distance shell its
    minimum distance falls into. The feature vector is the concatenation of
    the per-shell count vectors, shell 1 first.

    Attributes:
        n_shells: Number of distance shells.
        defined_residues: Residue names with their own feature column; any
            other residue name is mapped to ``"OTH"``.
        defined_elements: Ligand element classes (``"Hal"`` groups halogens,
            ``"DU"`` collects everything unrecognised).
        keys: All ``"<RES>_<ELEMENT>"`` labels, residue-major.
        feature_names: ``"<RES>_<ELEMENT>_<shell>"`` column names, shell-major,
            in the same order as the generated feature vector.
    """

    def __init__(
        self, pathfiles: str, filename: str, ligand_format: str, n_shells: int
    ) -> None:
        """Store the configuration and pre-compute feature labels.

        Args:
            pathfiles: Forwarded to ``FeatureGenerator.__init__``.
            filename: Forwarded to ``FeatureGenerator.__init__``.
            ligand_format: Forwarded to ``FeatureGenerator.__init__``.
                NOTE(review): this class always parses ligands with the mol2
                parser regardless of this value — confirm callers only pass
                ``"mol2"``.
            n_shells: Number of distance shells to count contacts in.
        """
        super().__init__(pathfiles, filename, ligand_format)
        self.n_shells = n_shells
        self.defined_residues = [
            "GLY",
            "ALA",
            "VAL",
            "LEU",
            "ILE",
            "PRO",
            "PHE",
            "TYR",
            "TRP",
            "SER",
            "THR",
            "CYS",
            "MET",
            "ASN",
            "GLN",
            "ASP",
            "GLU",
            "LYS",
            "ARG",
            "HIS",
            "OTH",
        ]
        self.defined_elements = ["H", "C", "O", "N", "P", "S", "Hal", "DU"]
        self.keys = [
            "_".join(pair)
            for pair in itertools.product(
                self.defined_residues, self.defined_elements
            )
        ]
        self.feature_names = [
            key + "_" + str(shell)
            for shell in range(1, n_shells + 1)
            for key in self.keys
        ]

    def features_generator(self, ligand_file: str, protein_file: str) -> dict:
        """Compute the feature vector of one complex.

        Args:
            ligand_file: Path to the ligand file (parsed as mol2).
            protein_file: Path to the protein PDB file.

        Returns:
            Mapping from each name in ``self.feature_names`` to its count.
        """
        ligand_element_list, ligand_coords_list = self._loadmol2(ligand_file)
        residue_list, all_residue_coords_list = self._loadpdb(protein_file)
        residue_atom_dist, residue_atom_pairs = self._calculate_distance(
            residue_list,
            all_residue_coords_list,
            ligand_element_list,
            ligand_coords_list,
        )
        feature_vector = self._count_contacts(
            residue_atom_dist, residue_atom_pairs, self.n_shells, self.keys
        )

        return dict(zip(self.feature_names, feature_vector))

    @staticmethod
    def _classify_element(atom_name: str) -> str:
        """Map a ligand atom name to one of the element classes.

        The class is derived from the atom *name* (e.g. ``"C12"``, ``"CL1"``).
        Chlorine and bromine are recognised by their two-letter prefix
        (case-insensitive) before falling back to the first character;
        looking only at the first character would count ``"CL1"`` as carbon
        and ``"BR1"`` as ``"DU"``.

        NOTE(review): other two-letter symbols are still judged by their first
        character only, because names such as ``"CA"`` are ambiguous between
        an element and a positional label.
        """
        if not atom_name:
            return "DU"
        if atom_name[:2].upper() in ("CL", "BR"):
            return "Hal"
        symbol = atom_name[0]
        if symbol in ("F", "I"):
            return "Hal"
        if symbol in ("H", "C", "O", "N", "P", "S"):
            return symbol
        return "DU"

    def _loadmol2(self, ligand_file: str) -> tuple:
        """Parse a mol2 ligand file.

        Args:
            ligand_file: Path to the mol2 file.

        Returns:
            ``(ligand_element_list, ligand_coords)``: the element class of
            every ligand atom and a float32 array of their coordinates
            multiplied by 0.1.
        """
        ligand = mol2parser.Mol2Parser(ligand_file)
        ligand.parse()

        ligand_element_list = [
            self._classify_element(atom_name)
            for atom_name in ligand.molecule_info["atom_name"].values()
        ]

        ligand_coords_list = np.array(
            list(ligand.molecule_info["coords"].values())
        ).astype(np.float32)

        # Same 0.1 scaling as the protein coordinates (presumably Å -> nm, so
        # the shell cutoffs in `_count_contacts` would be in nm — confirm).
        return (ligand_element_list, ligand_coords_list * 0.1)

    def _loadpdb(self, protein_file: str) -> tuple:
        """Parse a protein PDB file.

        Args:
            protein_file: Path to the PDB file.

        Returns:
            ``(residue_list, all_residue_coords_list)``: for every residue its
            name (``"OTH"`` when not in ``self.defined_residues``) and an array
            of its non-hydrogen atom coordinates multiplied by 0.1. The array
            is empty for a residue without non-hydrogen atoms.
        """
        parser = PDBParser(PERMISSIVE=True, QUIET=True)
        protein = parser.get_structure("", protein_file)

        residue_list = []
        all_residue_coords_list = []

        for residue in protein.get_residues():
            resname = residue.get_resname()
            residue_list.append(
                resname if resname in self.defined_residues else "OTH"
            )

            heavy_atom_coords = [
                list(atom.get_coord() * 0.1)
                for atom in residue.get_atoms()
                if atom.element != "H"
            ]
            all_residue_coords_list.append(np.array(heavy_atom_coords))

        return (residue_list, all_residue_coords_list)

    def _calculate_distance(
        self,
        residue_list: list,
        all_residue_coords_list: list,
        ligand_element_list: list,
        ligand_coords_list: np.ndarray,
    ) -> tuple:
        """Minimum distance between every residue and every ligand atom.

        Args:
            residue_list: Residue names, one per residue.
            all_residue_coords_list: Per-residue coordinate arrays, aligned
                with ``residue_list``.
            ligand_element_list: Element class of every ligand atom.
            ligand_coords_list: Ligand atom coordinates, one row per atom.

        Returns:
            ``(residue_atom_dist, residue_atom_pairs)``: a 1-D array of
            minimum distances and the matching ``"<RES>_<ELEMENT>"`` labels,
            ordered residue-major and ligand-atom-minor. Residues without any
            coordinates are skipped because no distance is defined for them.
        """
        ligand_coords = np.asarray(ligand_coords_list)
        n_atoms = min(len(ligand_element_list), len(ligand_coords))
        if n_atoms == 0:
            return (np.array([]), [])

        ligand_elements = list(ligand_element_list)[:n_atoms]
        ligand_coords = ligand_coords[:n_atoms]

        per_residue_dist = []
        residue_atom_pairs = []

        for res, res_coords in zip(residue_list, all_residue_coords_list):
            res_coords = np.asarray(res_coords)
            if res_coords.size == 0:
                # A residue with no heavy atoms yields an empty array that
                # cdist rejects; there is nothing to measure, so skip it.
                continue

            # One cdist call per residue: rows are ligand atoms, columns are
            # the residue's atoms; the row-wise minimum is the contact distance.
            dist_mtx = cdist(ligand_coords, res_coords, metric="euclidean")
            per_residue_dist.append(dist_mtx.min(axis=1))
            residue_atom_pairs.extend(f"{res}_{ele}" for ele in ligand_elements)

        if per_residue_dist:
            residue_atom_dist = np.concatenate(per_residue_dist)
        else:
            residue_atom_dist = np.array([])

        return (residue_atom_dist, residue_atom_pairs)

    def _count_contacts(
        self,
        residue_atom_dist: np.ndarray,
        residue_atom_pairs: list,
        n_shells: int,
        keys: list,
    ) -> np.ndarray:
        """Count labelled pairs per distance shell.

        Shell boundaries are ``np.linspace(0.1, 0.05 * (n_shells + 1),
        n_shells)``. The first shell holds distances ``<= 0.1``; shell ``i``
        holds distances in ``(cutoff[i-1], cutoff[i]]``. Distances beyond the
        last boundary are not counted.

        Args:
            residue_atom_dist: Minimum distances, one per pair.
            residue_atom_pairs: ``"<RES>_<ELEMENT>"`` label of every pair.
            n_shells: Number of shells.
            keys: All possible pair labels; defines the column order.

        Returns:
            1-D integer array of length ``n_shells * len(keys)``, shell-major.

        Raises:
            KeyError: If a pair label is not present in ``keys``.
        """
        unique_keys = list(dict.fromkeys(keys))
        key_index = {key: idx for idx, key in enumerate(unique_keys)}

        outermost = 0.05 * (n_shells + 1)
        ncutoffs = np.linspace(0.1, outermost, n_shells)

        n_pairs = min(len(residue_atom_dist), len(residue_atom_pairs))
        distances = np.asarray(residue_atom_dist, dtype=float)[:n_pairs]
        pair_idx = np.fromiter(
            (key_index[pair] for pair in residue_atom_pairs[:n_pairs]),
            dtype=np.intp,
            count=n_pairs,
        )

        # side="left" returns i with cutoff[i-1] < d <= cutoff[i], i.e. the
        # shell index; values past the last cutoff (and NaN) map to n_shells.
        shell_idx = np.searchsorted(ncutoffs, distances, side="left")
        in_range = shell_idx < n_shells

        counts = np.zeros((n_shells, len(unique_keys)), dtype=int)
        # np.add.at accumulates correctly when the same cell is hit repeatedly.
        np.add.at(counts, (shell_idx[in_range], pair_idx[in_range]), 1)

        return counts.ravel()