In [None]:
#|default_exp molgraphdataset

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#|export
# Public names re-exported by `from molgraphdataset import *`.
__all__ = [
    'TARGET2TYPE',
    'TYPE2TARGET',
    'CHIRAL_DICT',
    'HYBRIDIZATION',
    'ELEMENTS',
    'encode_onehot',
    'AtomFeaturizer',
    'BondFeaturizer',
    'MolecularGraph',
    'MolecularDataset',
    'save_df',
    'MolecularGraphDataset',
    'molgraph_collate_fn',
    'MolGraphDataLoader',
]


In [None]:
#|export
import torch
from torch import nn
from fastai.data.core import DataLoader, DataLoaders, Datasets
from fastcore.foundation import *
from fastcore.basics import *
from fastai.torch_core import Module
import pandas as pd
import numpy as np

from rdkit.Chem import AllChem
from rdkit import Chem, rdBase
from rdkit.Chem import rdchem, rdmolops, SanitizeMol
from rdkit.Chem.FilterCatalog import *
from rdkit.Chem.MolStandardize.rdMolStandardize import LargestFragmentChooser
from rdkit.Chem.SaltRemover import SaltRemover

import scipy as sp
import multiprocessing as mp
from tqdm import tqdm
from datetime import datetime

from concurrent.futures import ProcessPoolExecutor

from sklearn.utils.multiclass import type_of_target
from sklearn.utils.multiclass import unique_labels

from molgraph.utils.sanitizer import *

from typing import Tuple, List, Collection, Iterator

from pathlib import Path

from fastai.data.transforms import TrainTestSplitter, IndexSplitter, ColSplitter

In [None]:
#|export


# Cell
# sklearn `type_of_target` label -> modeling job type.
TARGET2TYPE = dict([
    ('continuous', 'regression'),
    ('binary', 'classification'),
    ('continuous-multioutput', 'regression-multi'),
    ('multiclass', 'multiclass'),
    ('multilabel-indicator', 'multilabel'),
])

# Inverse mapping: modeling job type -> sklearn target label.
TYPE2TARGET = dict(zip(TARGET2TYPE.values(), TARGET2TYPE.keys()))

# RDKit chiral tag -> integer code; the code is the position in the tuple below.
CHIRAL_DICT = {tag: code for code, tag in enumerate((
    Chem.ChiralType.CHI_UNSPECIFIED,
    Chem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.ChiralType.CHI_OTHER,
))}

# RDKit hybridization -> integer code; the code is the position in the tuple below.
HYBRIDIZATION = {hyb: code for code, hyb in enumerate((
    Chem.rdchem.HybridizationType.UNSPECIFIED,
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
    Chem.rdchem.HybridizationType.OTHER,
))}


# Atomic numbers with a dedicated one-hot slot: H, B, C, N, O, F, P, S, Cl, Br, I.
ELEMENTS = [1, 5, 6, 7, 8, 9, 15, 16, 17, 35, 53]

In [None]:
#|export

def encode_onehot(value: int, choices, extra_category: bool = False):
    """
    One-hot encode `value` against `choices` as a float tensor of shape (1, k).

    :param value: The value whose slot should be set to one.
    :param choices: A sequence of allowed values (any iterable; it is materialized as a list).
    :param extra_category: If False (default), k == len(choices) and a value that is not in
        `choices` is placed in the LAST choice's slot, so it is indistinguishable from that
        choice. If True, k == len(choices) + 1 and unknown values get their own final slot.
    :return: A float32 tensor of shape (1, k) containing a single 1.
    :raises ValueError: If `choices` is empty and `extra_category` is False (no slot exists).
    """
    choices = list(choices)
    n_slots = len(choices) + int(extra_category)
    if n_slots == 0:
        raise ValueError('encode_onehot needs at least one choice (or extra_category=True).')

    encoding = torch.zeros(1, n_slots)
    # Index -1 selects the final slot: the reserved "unknown" slot when extra_category=True,
    # otherwise the last regular choice.
    encoding[0, choices.index(value) if value in choices else -1] = 1.0
    return encoding

def convert_smiles(mol, sanitize=False):
    """
    Return an RDKit molecule for `mol`.

    :param mol: A SMILES string, or an existing `rdchem.Mol` (returned unchanged).
    :param sanitize: Forwarded to `Chem.MolFromSmiles` when `mol` is a string.
    :return: The molecule, or None if the SMILES cannot be parsed or `mol` is of any other type.
    """
    if isinstance(mol, rdchem.Mol):
        return mol
    if isinstance(mol, str):
        try:
            # MolFromSmiles returns None for an unparsable SMILES; the guard only catches
            # unexpected RDKit errors, without swallowing KeyboardInterrupt/SystemExit.
            return Chem.MolFromSmiles(mol, sanitize=sanitize)
        except Exception:
            return None
    # Unsupported input type.
    return None
    
class AtomFeaturizer:
    """
    Featurizes RDKit atoms into fixed-length float vectors.

    `encode` concatenates one-hot encodings (via `encode_onehot`) of atomic number, total degree,
    formal charge, chiral tag, total number of Hs and hybridization, followed by two scalar
    features: an aromaticity flag and the atomic mass scaled by 0.01. With the default
    `encode_onehot`, values outside a group's listed choices land in that group's last slot.

    Every `get_*` getter takes an optional atom; when omitted it falls back to `self.atom`
    (None unless assigned by the caller).
    """

    # Scalar (non one-hot) features appended after the one-hot groups: is_aromatic, scaled mass.
    N_SCALAR_FEATURES = 2

    def __init__(self):

        self.ATOM_FEATURES = {
            'atomic_num': ELEMENTS,
            'degree': [0, 1, 2, 3, 4, 5],
            'formal_charge': [-3, -2, -1, 0, 1, 2],
            'chiral_tag': [0, 1, 2, 3],
            'num_Hs': [0, 1, 2, 3, 4],
            'hybridization': [
                Chem.rdchem.HybridizationType.SP,
                Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3,
                Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2,
            ]}

        self.NODE_DIM = sum(len(f) for f in self.ATOM_FEATURES.values()) + self.N_SCALAR_FEATURES

        # Default atom used by the getters when they are called without an argument.
        self.atom = None

    def _resolve(self, atom):
        """Return `atom`, or `self.atom` when `atom` is None; raise ValueError if both are None."""
        atom = self.atom if atom is None else atom
        if atom is None:
            raise ValueError('No atom given and no default `self.atom` set.')
        return atom

    def get_atomic_num(self, atom=None):
        return self._resolve(atom).GetAtomicNum()

    def get_degree(self, atom=None):
        return self._resolve(atom).GetTotalDegree()

    def get_formal_charge(self, atom=None):
        return self._resolve(atom).GetFormalCharge()

    def get_chiral(self, atom=None):
        return int(self._resolve(atom).GetChiralTag())

    def get_numh(self, atom=None):
        return int(self._resolve(atom).GetTotalNumHs())

    def get_hybridization(self, atom=None):
        return int(self._resolve(atom).GetHybridization())

    def get_is_aromatic(self, atom=None):
        return self._resolve(atom).GetIsAromatic()

    def get_mass(self, atom=None):
        # Scaled by 0.01 to keep the value near the range of the one-hot features.
        return self._resolve(atom).GetMass() * 0.01

    def encode(self, atom: Chem.rdchem.Atom):
        """
        Build the feature vector of `atom`.

        :param atom: An RDKit atom.
        :return: A float32 tensor of shape (1, NODE_DIM).
        """
        onehots = [encode_onehot(self.get_atomic_num(atom), self.ATOM_FEATURES['atomic_num']),
                   encode_onehot(self.get_degree(atom), self.ATOM_FEATURES['degree']),
                   encode_onehot(self.get_formal_charge(atom), self.ATOM_FEATURES['formal_charge']),
                   encode_onehot(self.get_chiral(atom), self.ATOM_FEATURES['chiral_tag']),
                   encode_onehot(self.get_numh(atom), self.ATOM_FEATURES['num_Hs']),
                   encode_onehot(self.get_hybridization(atom), self.ATOM_FEATURES['hybridization'])]
        # Build the scalars as float32 explicitly instead of relying on dtype promotion in torch.cat.
        scalars = torch.tensor([[float(self.get_is_aromatic(atom)), self.get_mass(atom)]],
                               dtype=torch.float32)
        return torch.cat(onehots + [scalars], -1)


class BondFeaturizer:
    """
    Featurizes RDKit bonds into fixed-length float vectors.

    Layout of `encode` (BOND_DIM = 12 columns): one-hot bond type (single, double, triple,
    aromatic), one-hot bond stereo (`bond.GetStereo()` values 0-5), an is-conjugated flag and an
    is-in-ring flag.
    """

    # Scalar (non one-hot) features appended after the one-hot groups: conjugated, in-ring.
    N_SCALAR_FEATURES = 2

    def __init__(self):

        # Bond type -> slot index inside the bond-type one-hot group.
        self.BOND_TYPES = {Chem.rdchem.BondType.SINGLE: 0,
                           Chem.rdchem.BondType.DOUBLE: 1,
                           Chem.rdchem.BondType.TRIPLE: 2,
                           Chem.rdchem.BondType.AROMATIC: 3}

        self.EDGE_FEATURES = {'bond_type': [Chem.rdchem.BondType.SINGLE,
                                            Chem.rdchem.BondType.DOUBLE,
                                            Chem.rdchem.BondType.TRIPLE,
                                            Chem.rdchem.BondType.AROMATIC],
                              'stereo_info': list(range(6))}

        self.BOND_DIM = sum(len(f) for f in self.EDGE_FEATURES.values()) + self.N_SCALAR_FEATURES

    def get_bondtype(self, bond: Chem.rdchem.Bond):
        """Slot index (0-3) of the bond type; raises KeyError for bond types not in BOND_TYPES."""
        return self.BOND_TYPES[bond.GetBondType()]

    def get_stereo(self, bond: Chem.rdchem.Bond):
        return bond.GetStereo()

    def get_isconjugated(self, bond: Chem.rdchem.Bond):
        return bond.GetIsConjugated()

    def get_isring(self, bond: Chem.rdchem.Bond):
        return bond.IsInRing()

    def encode(self, bond: Chem.rdchem.Bond):
        """
        Builds a feature vector for a bond.

        :param bond: An RDKit bond, or None.
        :return: A float32 tensor of shape (1, BOND_DIM); all zeros when `bond` is None.
        :raises KeyError: If the bond type is not one of single/double/triple/aromatic.
        """
        if bond is None:
            return torch.zeros(1, self.BOND_DIM)

        # `get_bondtype` returns a slot INDEX (0-3), so it must be one-hot encoded against the
        # indices, not against the BondType enum values (SINGLE=1, ..., AROMATIC=12), which
        # would shift every bond type into the wrong slot.
        n_bond_types = len(self.EDGE_FEATURES['bond_type'])
        features = [encode_onehot(self.get_bondtype(bond), list(range(n_bond_types))),
                    encode_onehot(int(self.get_stereo(bond)), self.EDGE_FEATURES['stereo_info']),
                    torch.tensor([[float(self.get_isconjugated(bond))]]),
                    torch.tensor([[float(self.get_isring(bond))]])]

        return torch.cat(features, -1)


class MolecularGraph:
    """
    Molecular graph representation of a single molecule.

    Converts a SMILES string into an RDKit `Chem.Mol`; `mol_to_graph` then builds the node and
    edge feature tensors plus the self-loop adjacency matrix and its D^{-1/2} degree matrix.


    Arguments:
    -----------------------------------------------------------------

        smiles : str
            A SMILES representing a molecular structure (an `rdchem.Mol` is also accepted).

        target : array-like
            Target value(s) of the molecule; stored as a float32 tensor in `y`.

        sanitize : bool
            Whether RDKit sanitizes the molecule while parsing (default True). Sanitization
            perceives hybridization, aromaticity and implicit Hs, which `AtomFeaturizer` reads.


    Attributes:
    -----------------------------------------------------------------

        mol : Chem.Mol
            `Chem.Mol` object generated from `smiles`

        n_nodes : int
            Number of atoms (nodes) in the graph

        n_edges : int
            Number of bonds (edges) in the graph

        node_features : torch.Tensor
            (n_nodes, N_NODE_FEATURES) float tensor, filled by `mol_to_graph`.

        edge_features : torch.Tensor
            (n_nodes, n_nodes, N_EDGE_FEATURES) integer tensor, filled by `mol_to_graph`;
            entry [i, j] holds the bond features between atoms i and j (zeros if unbonded).


    Raises:
    -----------------------------------------------------------------

        ValueError
            If `smiles` cannot be converted into a molecule.
    """

    def __init__(self, smiles: str, target: np.array, sanitize: bool = True):

        self.smiles = smiles
        self.y = torch.tensor(np.array(target, dtype=np.float32))
        self.mol = convert_smiles(smiles, sanitize=sanitize)
        if self.mol is None:
            raise ValueError(f'Could not convert SMILES into a molecule: {smiles!r}')
        self.n_nodes = self.mol.GetNumAtoms()
        self.n_edges = self.mol.GetNumBonds()
        self.node_features = None
        self.edge_features = None

        self.atom_featurizer = AtomFeaturizer()
        self.bond_featurizer = BondFeaturizer()

        self.N_NODE_FEATURES = self.atom_featurizer.NODE_DIM
        self.N_EDGE_FEATURES = self.bond_featurizer.BOND_DIM

    def mol_to_graph(self, max_nodes=None):
        """
        Populate `node_features`, `edge_features`, `adjacency_matrix`, `adjacency_matrix_norm`
        and `degree_matrix` from `self.mol`.

        :param max_nodes: Accepted for API compatibility; currently unused (batch padding is
            done by the collate function).
        :return: self, to allow chaining.
        """

        self.node_features = torch.zeros((self.n_nodes, self.N_NODE_FEATURES))
        self.edge_features = torch.zeros((self.n_nodes, self.n_nodes, self.N_EDGE_FEATURES)).long()
        self.pair_indices = []

        for idx, atom in enumerate(self.mol.GetAtoms()):
            self.node_features[idx] = self.atom_featurizer.encode(atom)

        # Visit each bond once (O(E)) instead of probing every atom pair (O(N^2)); the feature
        # vector is computed once and written symmetrically.
        for bond in self.mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_vector = self.bond_featurizer.encode(bond)
            self.edge_features[begin, end] = bond_vector
            self.edge_features[end, begin] = bond_vector

        self.adjacency_matrix = torch.tensor(Chem.GetAdjacencyMatrix(self.mol)).long()
        # Adjacency + identity (self-loops).
        self.adjacency_matrix_norm = self.adjacency_matrix + torch.eye(self.adjacency_matrix.shape[1])
        # D^{-1/2} of (A + I). The matrix is diagonal with every entry >= 1 (self-loops), so the
        # fractional matrix power reduces to an elementwise power of the row sums.
        degrees = self.adjacency_matrix_norm.sum(-1)
        self.degree_matrix = torch.diag(degrees.pow(-0.5))

        return self


class MolecularDataset:

    """
    A generic class to handle tabular (e.g. dataframes) data for QSAR tasks.

    Parameters
    ------------------------------------------------------------------------------------------------------------

        df : DataFrame
            A `pandas.DataFrame` object with the data used for the modeling task.

        targets_col : str
            A string representing the column of `df` with the target variable.

        smiles_col : str
            A string representing the column of `df` with SMILES.


    Attributes
    ------------------------------------------------------------------------------------------------------------

        df : `DataFrame`
            A `pandas.DataFrame` object with the data used for the modeling task.

        cols : array-like
            An array of columns in `df`

        targets_col : str
            A string representing the column of `df` with the target variable.

        smiles_col : str
            A string representing the column of `df` with SMILES.

        smiles : array-like
            An array of SMILES of the molecules in `df`.

        x : array-like or None
            Placeholder for precomputed features; None until assigned.

        y : array-like
            An array of target variable in `df`.

        data : array-like
            An (n, 2) array whose rows are (SMILES, y) pairs.


    """

    def __init__(self,
                 df:pd.DataFrame,
                 smiles_col:str,
                 targets_col:str):

        self.df = df
        self.cols = df.columns
        self.smiles_col = smiles_col
        self.targets_col = targets_col
        self.smiles = df[smiles_col].values
        self._x = None
        self._y = df[targets_col].values
        self.data = df[[smiles_col, targets_col]].values


    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, i):
        self._x = i

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, i):
        self._y = i

    def copy(self):

        """Returns a deep copy of MolecularDataset (including its DataFrame)."""

        # Imported locally so the method does not depend on a module-level `copy` import
        # (the `copy.deepcopy` call previously referenced an unimported name).
        from copy import deepcopy
        return deepcopy(self)

    def __len__(self):

        """Returns the size of MolecularDataset"""

        return len(self.data)

    def __getitem__(self, i):
        """

        Return the i-th (smiles, y) pair.

        """
        return self.data[i]


def save_df(df, fname: str):
    """Write `df` as CSV to `fname`, leaving out the DataFrame index."""
    df.to_csv(path_or_buf=fname, index=False)

class MolecularGraphDataset(MolecularDataset):

    """
    Dataset of molecular graphs for graph-based QSAR modeling.

    Prefer the factory class methods `MolecularGraphDataset.from_df` and
    `MolecularGraphDataset.from_csv`. They build the dataset and call `create_dataset`, which
    splits it into train/validation subsets and writes the splits as CSV files to `save_dir`.
    Indexing the dataset (`ds[i]`) featurizes the i-th molecule on the fly.

    Parameters
    ------------------------------------------------------------------------------------------------------------

        df : DataFrame

            A `pandas.DataFrame` object with the data used for the modeling task.

        smiles_col : str

            A string representing the column of `df` with SMILES.

        targets_col : str

            A string representing the column of `df` with the target variable.

        task_id : str, optional

            Identifier used for the output folder and file names. When None, outputs go
            under `GraphModel/<job_type>`.


    Attributes
    ------------------------------------------------------------------------------------------------------------

        data : array-like

            An (n_samples, 2) array whose rows are (SMILES, target) pairs.

        x : array-like or None

            Placeholder for precomputed features; not set by this class.

        y : array-like (n_samples, )

            An array with target variable values for each sample.

        train : MolecularGraphDataset

            The training subset (set by `create_dataset`).

        valid : MolecularGraphDataset

            The validation subset (defaults to 20% of the data).

        splits : tuple

            The result of `splitter(self.data)`: (train indices, validation indices).

        dtype : str

            Target type from `sklearn.utils.multiclass.type_of_target` (`continuous`, `binary`,
            `multiclass`, `continuous-multioutput` or `multilabel-indicator`).

        job_type : str

            Type of modeling task derived from `dtype` via `TARGET2TYPE` (None if unmapped).

        max_nodes : int

            Largest atom count among the molecules of the dataset (0 if empty).

        timestamp : datetime

            Creation time of the dataset, used in output file names.

        save_dir : Path

            Folder where files and outputs will be saved (created on construction).

    Raises
    ------------------------------------------------------------------------------------------------------------

        ValueError

            If a SMILES in `df` cannot be parsed by RDKit.
    """

    def __init__(self, df: pd.DataFrame, smiles_col: str, targets_col: str, task_id: str = None):
        super().__init__(df, smiles_col, targets_col)

        self._train = None
        self._valid = None
        self.task_id = task_id
        self._splits = None
        self.dtype = type_of_target(self.y)
        self.max_nodes = self._max_num_atoms(self.smiles)
        self.job_type = TARGET2TYPE.get(self.dtype, None)
        self.save_dir = Path(f'{task_id}/{self.job_type}') if task_id else Path(f'GraphModel/{self.job_type}')

        # Create output folder
        Path(self.save_dir).mkdir(parents=True, exist_ok=True)

        if self.job_type in ['classification', 'multiclass']:
            self.classes = unique_labels(self.y)
            self.c = len(self.classes)

        self._timestamp = datetime.now()

    @staticmethod
    def _max_num_atoms(smiles) -> int:
        """Largest atom count of the RDKit-parsed `smiles` (0 if empty); ValueError on bad SMILES."""
        n_atoms = []
        for smi in smiles:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                raise ValueError(f'Could not parse SMILES: {smi!r}')
            n_atoms.append(mol.GetNumAtoms())
        return max(n_atoms, default=0)

    def copy(self):
        """Return a deep copy of this dataset."""
        # Imported locally: the module-level imports do not provide the `copy` module.
        from copy import deepcopy
        return deepcopy(self)

    def __len__(self):
        return len(self.data)

    def __str__(self):
        return f'Time stamp: {self.timestamp}\nTarget: {self.targets_col}\nData type: {self.dtype}\nNumber of compounds: {len(self.data)}'

    @property
    def timestamp(self):
        return self._timestamp

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, i):
        self._x = i

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, i):
        self._y = i

    @property
    def save_dir(self):
        return self._save_dir

    @save_dir.setter
    def save_dir(self, i):
        self._save_dir = i

    @property
    def splits(self):
        return self._splits

    @splits.setter
    def splits(self, i):
        self._splits = i

    @property
    def train(self):
        return self._train

    @train.setter
    def train(self, i):
        self._train = i

    @property
    def valid(self):
        return self._valid

    @valid.setter
    def valid(self, i):
        self._valid = i


    @classmethod
    def from_csv(cls, data_path : str,
                 smiles_col : str,
                 targets_col : str,
                 task_id:str):

        """Build a `MolecularGraphDataset` from a csv file using `smiles_col`, `targets_col` and `task_id`.

        Arguments
        ------------------------------------------------------------------------------------------------------------

        data_path : str
            A string representing the path to the modeling dataset

        smiles_col : str
            A string representing the column of `data` with SMILES.

        targets_col : str
            A string representing the column of `data` with the target variable.

        task_id : str
            Identifier used for the output folder and file names.

        Returns
        ------------------------------------------------------------------------------------------------------------
            Returns a `MolecularGraphDataset` ready for modeling (train/valid splits created).

        """

        return cls.from_df(pd.read_csv(data_path),
                           smiles_col=smiles_col,
                           targets_col=targets_col,
                           task_id=task_id)

    @classmethod
    def from_df(cls,
                df : pd.DataFrame,
                smiles_col : str,
                targets_col : str,
                task_id:str):

        """Builds a `MolecularGraphDataset` from a dataframe using `smiles_col`, `targets_col` and `task_id`.

        Arguments
        ------------------------------------------------------------------------------------------------------------

            df : DataFrame
                A `pandas.DataFrame` object with the data used for the modeling task (it is copied,
                so the caller's frame is not modified).

            smiles_col : str
                A string representing the column of `data` with SMILES.

            targets_col : str
                A string representing the column of `data` with the target variable.

            task_id : str
                Identifier used for the output folder and file names.

        Returns
        ------------------------------------------------------------------------------------------------------------
            Returns a `MolecularGraphDataset` ready for modeling (train/valid splits created).

        """

        # `cls` (not a hardcoded class name) so subclasses get instances of themselves.
        return cls(df=df.copy(),
                   smiles_col=smiles_col,
                   targets_col=targets_col,
                   task_id=task_id).create_dataset()


    def create_dataset(self, splitter : Iterator=None, random_state: int = None, test_size: float = 0.2):

        """Split the data into train and validation subsets and save them to `save_dir`.

        Featurization is not done here; it happens lazily in `__getitem__`.

        Arguments
        ------------------------------------------------------------------------------------------------------------

            splitter : Splitter
                A callable returning (train indices, validation indices) for `self.data`.
                Default: fastai `TrainTestSplitter` (stratified for classification tasks).

            random_state : int
                A number for random seed.

            test_size : float
                The fraction of data used as validation/test set.

        Returns
        ------------------------------------------------------------------------------------------------------------
            Returns self, with `splits`, `train` and `valid` populated.
        """

        # Create Train/Valid datasets
        stratify = self.y if self.job_type in ['classification', 'multiclass'] else None
        splitter = splitter if splitter else TrainTestSplitter(test_size=test_size, random_state=random_state, stratify=stratify)
        setattr(splitter, 'random_state', random_state)
        setattr(splitter, 'test_size', test_size)
        self.splits = splitter(self.data)

        # Pass task_id so the subsets reuse this task's output folder instead of each creating
        # a stray `GraphModel/<job_type>` directory.
        self.train = self.__class__(df=self.df.iloc[self.splits[0]].reset_index(drop=True),
                                    smiles_col=self.smiles_col,
                                    targets_col=self.targets_col,
                                    task_id=self.task_id)

        self.valid = self.__class__(df=self.df.iloc[self.splits[1]].reset_index(drop=True),
                                    smiles_col=self.smiles_col,
                                    targets_col=self.targets_col,
                                    task_id=self.task_id)

        # Both subsets share the full dataset's maximum graph size.
        self.train.max_nodes = self.max_nodes
        self.valid.max_nodes = self.max_nodes

        # Save files
        save_df(self.train.df, self.save_dir/f'{self.task_id}_{self.job_type}_{self.timestamp.date()}_trainset.csv')
        save_df(self.df, self.save_dir/f'{self.task_id}_{self.job_type}_{self.timestamp.date()}_dataset_full.csv')
        save_df(self.valid.df, self.save_dir/f'{self.task_id}_{self.job_type}_{self.timestamp.date()}_testset.csv')

        return self

    def __getitem__(self, i):

        """Featurize the i-th molecule.

        Returns ``((node_features, edge_features, adjacency_matrix_norm, degree_matrix), y)``,
        the sample format consumed by `molgraph_collate_fn`.
        """

        graph = MolecularGraph(*self.data[i]).mol_to_graph(max_nodes=self.max_nodes)
        return (graph.node_features, graph.edge_features, graph.adjacency_matrix_norm, graph.degree_matrix), graph.y


def molgraph_collate_fn(data):
    """
    Collate a list of graph samples into zero-padded batch tensors.

    Each sample is ``((node_features, edge_features, adjacency, degree), target)`` with
    node_features of shape (n_nodes, F_node), edge_features (n_nodes, n_nodes, F_edge) and
    adjacency/degree (n_nodes, n_nodes). Every graph is zero-padded to the largest n_nodes in
    the batch.

    :param data: A non-empty sequence of samples.
    :return: ``((nodes, edges, adjacency, degree), targets)`` with shapes (B, N, F_node),
        (B, N, N, F_edge), (B, N, N), (B, N, N) and (B,) for single-value targets or
        (B, *target_shape) for multi-value targets. All tensors are float32.
    :raises ValueError: If `data` is empty.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError('molgraph_collate_fn received an empty batch.')

    (node_features_0, edge_features_0, _, _), target_0 = data[0]
    n_nodes_largest_graph = max(sample[0][0].size(0) for sample in data)
    n_node_features = node_features_0.size(-1)
    n_edge_features = edge_features_0.size(-1)

    # Single-value targets collate to shape (B,) as before; vector targets keep their shape
    # so multi-output tasks can be batched too.
    target_0 = torch.as_tensor(target_0)
    target_shape = () if target_0.numel() == 1 else tuple(target_0.shape)

    adjacency_tensors = torch.zeros(n_samples, n_nodes_largest_graph, n_nodes_largest_graph)
    degree_tensors = torch.zeros(n_samples, n_nodes_largest_graph, n_nodes_largest_graph)
    node_tensors = torch.zeros(n_samples, n_nodes_largest_graph, n_node_features)
    edge_tensors = torch.zeros(n_samples, n_nodes_largest_graph, n_nodes_largest_graph, n_edge_features)
    target_tensors = torch.zeros((n_samples, *target_shape))

    for i, ((node_features, edge_features, adjacency, degree), target) in enumerate(data):
        n_nodes = node_features.size(0)

        adjacency_tensors[i, :n_nodes, :n_nodes] = adjacency
        degree_tensors[i, :n_nodes, :n_nodes] = degree
        node_tensors[i, :n_nodes, :] = node_features
        edge_tensors[i, :n_nodes, :n_nodes, :] = edge_features
        target_tensors[i] = torch.as_tensor(target).reshape(target_shape)

    return (node_tensors, edge_tensors, adjacency_tensors, degree_tensors), target_tensors


class MolGraphDataLoader(DataLoader):
    """Helpers to build fastai `DataLoaders` from a (train, valid) pair of `MolecularGraphDataset`s."""

    @classmethod
    def dataloders(cls, dataset:Tuple[MolecularGraphDataset, MolecularGraphDataset], bs:int=32, shuffle:bool=True, collate_fn=None, drop_last:bool=True, device='cpu'):
        """
        Build train/valid fastai `DataLoaders`.

        :param dataset: (train_dataset, valid_dataset).
        :param bs: Batch size for both loaders.
        :param shuffle: Whether to shuffle the training loader.
        :param collate_fn: Batch-building function (e.g. `molgraph_collate_fn`), passed as
            fastai's `create_batch`. Required.
        :param drop_last: Whether the training loader drops its last incomplete batch.
        :param device: Currently unused; kept for API compatibility.
        :return: A `DataLoaders(train_dl, valid_dl)`.
        :raises ValueError: If `collate_fn` is None.
        """

        if collate_fn is None:
            raise ValueError('The collate function is invalid. Please pass a valid function.')

        train_dl = DataLoader(dataset[0], bs=bs, shuffle=shuffle, create_batch=collate_fn, drop_last=drop_last)
        # The validation loader is never shuffled and never drops its last batch: dropping
        # would silently exclude samples from evaluation, and would yield zero batches
        # whenever len(valid) < bs.
        valid_dl = DataLoader(dataset[1], bs=bs, shuffle=False, create_batch=collate_fn, drop_last=False)
        return DataLoaders(train_dl, valid_dl)

    # Correctly spelled alias; `dataloders` is kept for backward compatibility.
    dataloaders = dataloders

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()  # write this notebook's `#|export` cells to the `molgraphdataset` module (see `#|default_exp`)

Converted molgraphdataset.ipynb.
