In [1]:
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from typing import Union, List, Type, Optional
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import pandas as pd
import os
from sklearn.model_selection import train_test_split

In [44]:
def one_hot_encoding(elem : str, allowable_elem : Union[List[str],List[int]])->List[bool]:
    
    if elem not in allowable_elem : elem = allowable_elem[-1]
    
    # [True if k==elem else False for k in allowable_elem]
    # note map is faster
    onehotvec=list(map(lambda dummy : dummy==elem,allowable_elem))
    
    return onehotvec

def get_atom_features(atom : Type[Chem.rdchem.Mol],
                  allowable_elem : List[str] =['C','N','O','S','F','Si','P','Cl','Br','Mg','Na',
                                               'Ca','Fe','As','Al','I', 'B','V','K','Tl','Yb',
                                               'Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge',
                                               'Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown'], 
                  use_H_explicit : bool =False,
                  use_chirality : bool =False)->Optional[torch.Tensor]:
    '''
    function calcualtes node features
    args: 
        - atom : Mol object
        - allowable_elem : list of allowable elements
    return :
        - feats : node features
    '''
    # use H
    if use_H_explicit : allowable_elem = ['H'] + allowable_elem
    
    # map elements 
    elem_list_enc=one_hot_encoding(atom.GetSymbol(),allowable_elem)
    
    # get atom degree
    n_neighbors_enc = one_hot_encoding(atom.GetDegree(), [0, 1, 2, 3, 4])
    
    # formal charges
    formal_charge_enc = one_hot_encoding(atom.GetFormalCharge(), [-3, -2, -1, 0, 1, 2, 3])
    
    # hybridization
    hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
    
    # atom belongs to a ring
    is_in_a_ring_enc = [int(atom.IsInRing())]
    
    # presence of aromatic ring
    is_aromatic_enc = [int(atom.GetIsAromatic())]
    
    # chirality and explicit H's
    
    # final encoding
    f_enc=elem_list_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc
    
    return np.array(f_enc).astype('float')

def get_bond_features(bond : Type[Chem.rdchem.Bond], use_stereochemistry : bool = True)->Optional[torch.Tensor]:
    
    # bond type only used common ones for organic molecules
    bond_type_enc=one_hot_encoding(str(bond.GetBondType()),['SINGLE', 'DOUBLE', 'TRIPLE','AROMATIC'])
    
    # conjugation
    bond_is_conj_enc = [int(bond.GetIsConjugated())]
    
    # is ring
    bond_is_in_ring_enc = [int(bond.IsInRing())]
    
    bond_feature_enc = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc
    
    if use_stereochemistry == True: 
        bond_feature_enc.extend(one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"]))
    
    return np.array(bond_feature_enc).astype('float')
    
def gen_graph_dset(smiles,labels):
    # check num of nodses and edges
    ethane='CC'
    ethane=Chem.MolFromSmiles(ethane)
    num_node_features=len(get_atom_features(ethane.GetAtoms()[0]))
    num_edge_features=len(get_bond_features(ethane.GetBonds()[0]))

    data_list=[]
    #print(len(smiles))
    #print(len(labels))
    for smile, label in zip(smiles,labels):
        print(smile,label)
        #print(label)
        mol=Chem.MolFromSmiles(smile)

        num_atoms=mol.GetNumAtoms()
        atom_features=np.zeros((num_atoms,num_node_features))

        num_bonds=2*mol.GetNumBonds()
        bond_features=np.zeros((num_bonds,num_edge_features))

        # node features
        for idx,atom in enumerate(mol.GetAtoms()):
            atom_features[idx,:]=get_atom_features(atom)

        # adjacency
        adj=np.array([*np.nonzero(GetAdjacencyMatrix(mol))])

        # bond features
        for idx,(i,j) in enumerate(zip(adj[0,:],adj[1,:])):
            bond_features[idx,:]=get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))
            
        atom_features=torch.tensor(atom_features,dtype=torch.float)
        adj=torch.tensor(adj,dtype=torch.float)
        y=torch.tensor([label],dtype=torch.float)
        edge_attr=torch.tensor(bond_features,dtype=torch.float)
        # data
        data_list.append(Data(x=atom_features,
            edge_index=adj,
            edge_attr=edge_attr,
            y=y))
        
        return data_list
    
def create_dset(x_smiles, y):
    """
    Inputs:
    
    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerial labels for the SMILES strings (such as associated pKi values)
    
    Outputs:
    
    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular graphs that can readily be used for machine learning
    
    """
    
    data_list = []
    
    for (smiles, y_val) in zip(x_smiles, y):
        
        # convert SMILES to RDKit mol object
        mol = Chem.MolFromSmiles(smiles)

        # get feature dimensions
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()
        unrelated_smiles = "O=O"
        unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
        n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
        n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)
            
        X = torch.tensor(X, dtype = torch.float)
        
        # construct edge index array E of shape (2, n_edges)
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
        torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
        E = torch.stack([torch_rows, torch_cols], dim = 0)
        
        # construct edge feature array EF of shape (n_edges, n_edge_features)
        EF = np.zeros((n_edges, n_edge_features))
        
        for (k, (i,j)) in enumerate(zip(rows, cols)):
            
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))
        
        EF = torch.tensor(EF, dtype = torch.float)
        
        print(E)
        
        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)
        
        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list


def graph_dset_2(x_smiles,y):

    data_list=[]
    print(len(x_smiles))
    print(len(y))
    
        
    for (smile, y_val) in zip(x_smiles,y):
        # convert SMILES to RDKit mol object
        print(smile, y_val)
        mol = Chem.MolFromSmiles(smile)

        # get feature dimensions
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()
        unrelated_smiles = "O=O"
        unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
        n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
        n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)
            
        X = torch.tensor(X, dtype = torch.float)
        
        # construct edge index array E of shape (2, n_edges)
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
        torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
        E = torch.stack([torch_rows, torch_cols], dim = 0)
        
        # construct edge feature array EF of shape (n_edges, n_edge_features)
        EF = np.zeros((n_edges, n_edge_features))
        
        for (k, (i,j)) in enumerate(zip(rows, cols)):
            
            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))
        
        EF = torch.tensor(EF, dtype = torch.float)
        
        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)
        
        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

        return data_list



In [45]:
data_list=graph_dset_2(['C=C(Br)c1ccccc1','c1ccccc1'],[1,1])
data_list

2
2
C=C(Br)c1ccccc1 1


[Data(x=[9, 59], edge_index=[2, 18], edge_attr=[18, 10], y=[1])]

In [36]:
data_list=gen_graph_dset(['C=C(Br)c1ccccc1','c1ccccc1'],[1,1])
data_list

C=C(Br)c1ccccc1 1


[Data(x=[9, 59], edge_index=[2, 18], edge_attr=[18, 10], y=[1])]

In [41]:
data_list1=create_dset(['C=C(Br)c1ccccc1','c1ccccc1'],[1,1])
data_list1

tensor([[0, 1, 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8],
        [1, 0, 2, 3, 1, 1, 4, 8, 3, 5, 4, 6, 5, 7, 6, 8, 3, 7]])
tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
        [1, 5, 0, 2, 1, 3, 2, 4, 3, 5, 0, 4]])


[Data(x=[9, 59], edge_index=[2, 18], edge_attr=[18, 10], y=[1]),
 Data(x=[6, 59], edge_index=[2, 12], edge_attr=[12, 10], y=[1])]

In [38]:
smiles=['C=C(Br)c1ccccc1','c1ccccc1']
labels=[1,1]
for (smile, label) in zip(smiles,labels):
        # convert SMILES to RDKit mol object
        print(smile)
        #print(mol)

C=C(Br)c1ccccc1
c1ccccc1
