In [1]:
import torch
import numpy as np
import os
from torch import nn
import sys
import polars as pl
import sklearn
import shap
from functools import partial
import pandas as pd
import random
from matplotlib import pyplot as plt
import time
from rdkit.Chem import DataStructs
from tqdm import tqdm
from joblib import Parallel, delayed
import math
from skfp import fingerprints as skfps


sys.path.append('../')

In [2]:
# Location of the processed training CSV. The default is the path this notebook
# was originally written against; set the BELKA_TRAIN_CSV environment variable
# to point at a copy elsewhere without editing the cell.
TRAIN_CSV_PATH = os.environ.get(
    'BELKA_TRAIN_CSV',
    '/home/dangnh36/datasets/competitions/leash_belka/processed/train_v2.csv',
)
# Only this many leading rows are read (lazy scan + head), which keeps the
# featurisation experiments below fast.
N_PREVIEW_ROWS = 100

train_df = (
    pl.scan_csv(TRAIN_CSV_PATH)
    .head(N_PREVIEW_ROWS)
    .select(
        pl.col('molecule'),
        # Narrow integer dtypes to keep the in-memory frame small.
        pl.col('bb1', 'bb2', 'bb3').cast(pl.UInt16),
        pl.col('BRD4', 'HSA', 'sEH').cast(pl.UInt8),
    )
    .collect()
)
print(train_df.estimated_size('gb'), 'GB')
train_df

7.471069693565369e-06 GB


molecule,bb1,bb2,bb3,BRD4,HSA,sEH
str,u16,u16,u16,u8,u8,u8
"""C#CCOc1ccc(CNc…",1640,1653,765,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,205,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,1653,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,146,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,439,0,0,0
…,…,…,…,…,…,…
"""C#CCOc1ccc(CNc…",1640,1653,1358,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,1215,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,5,0,0,0
"""C#CCOc1ccc(CNc…",1640,1653,1705,0,0,0


In [16]:
from rdkit import Chem


# Bytes per atom after bit-packing: get_atom_feature builds 69 boolean flags,
# which np.packbits pads to 72 bits = 9 bytes.
PACK_NODE_DIM=9
# Bytes per bond after bit-packing: get_bond_feature builds 6 boolean flags,
# which np.packbits pads to 8 bits = 1 byte.
PACK_EDGE_DIM=1
# Widths in bits (8 per packed byte) -- presumably the feature sizes once the
# bytes are unpacked again; TODO confirm against the consumer of these values.
NODE_DIM=PACK_NODE_DIM*8
EDGE_DIM=PACK_EDGE_DIM*8

def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode ``x`` against ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Ordered sequence of admissible values; the position of
            each value is its slot in the returned encoding.
        allow_unk: If True, a value that is not in ``allowable_set`` is encoded
            as the LAST entry of ``allowable_set`` (the "unknown" slot).

    Returns:
        A list with one entry per element of ``allowable_set``; entry ``k`` is
        ``x == allowable_set[k]``.

    Raises:
        ValueError: If ``x`` is not in ``allowable_set`` and ``allow_unk`` is
            False. ``ValueError`` derives from ``Exception``, so callers that
            catch the generic ``Exception`` keep working.
    """
    if x not in allowable_set:
        if not allow_unk:
            raise ValueError(f'input {x} not in allowable set{allowable_set}!!!')
        # Fold every unknown value onto the final slot.
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


#Get features of an atom (one-hot encoding):
'''
    1.atom element: 44 dimensions (the extra 'Unknown' slot is commented out)
    2.the atom's hybridization: 5 dimensions
    3.degree of atom: 6 dimensions
    4.total number of H bound to atom: 6 dimensions
    5.implicit valence of atom: 6 dimensions
    6.whether the atom is on ring: 1 dimension
    7.whether the atom is aromatic: 1 dimension
    Total: 69 dimensions (padded to 72 bits = 9 bytes by np.packbits)
'''

# Element symbols accepted by get_atom_feature, in one-hot slot order.
# There is no catch-all entry while 'Unknown' stays commented out, so an atom
# whose symbol is missing from this list makes one_of_k_encoding raise.
ATOM_SYMBOL = [
    'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
    'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
    'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
    'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
    'Pt', 'Hg', 'Pb', 'Dy',
    #'Unknown'
]
# len(ATOM_SYMBOL) == 44
# Hybridization states accepted by get_atom_feature, in one-hot slot order.
# SP3D2 and UNSPECIFIED are not listed, so atoms reporting either of them make
# one_of_k_encoding raise.
HYBRIDIZATION_TYPE = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D
]

def get_atom_feature(atom):
    """Encode one RDKit atom as a bit-packed one-hot feature vector.

    The flags are laid out in this order: element symbol (ATOM_SYMBOL),
    hybridization (HYBRIDIZATION_TYPE), degree, total hydrogen count and
    implicit valence (each one-hot over 0..5), then the in-ring and aromatic
    flags. That is 69 bits, which np.packbits pads to 9 bytes.

    Args:
        atom: RDKit atom providing GetSymbol, GetHybridization, GetDegree,
            GetTotalNumHs, GetImplicitValence, IsInRing and GetIsAromatic.

    Returns:
        1-D uint8 array holding the packed bits.

    Raises:
        Exception: propagated from one_of_k_encoding when any property falls
            outside its allowable set (no unknown slot is used here).
    """
    small_count_bins = [0, 1, 2, 3, 4, 5]

    bits = []
    bits.extend(one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL))
    bits.extend(one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE))
    bits.extend(one_of_k_encoding(atom.GetDegree(), small_count_bins))
    bits.extend(one_of_k_encoding(atom.GetTotalNumHs(), small_count_bins))
    bits.extend(one_of_k_encoding(atom.GetImplicitValence(), small_count_bins))
    bits.append(atom.IsInRing())
    bits.append(atom.GetIsAromatic())

    # Eight flags per byte; the last byte is zero-padded.
    return np.packbits(bits)


#Get features of an edge (one-hot encoding)
'''
    1.single/double/triple/aromatic: 4 dimensions
    2.whether the bond is conjugated: 1 dimension
    3.whether the bond is on ring: 1 dimension
    Total: 6 dimensions
'''

def get_bond_feature(bond):
    """Encode one RDKit bond as a bit-packed feature vector.

    Six flags, in order: is-single, is-double, is-triple, is-aromatic,
    is-conjugated, is-in-ring. np.packbits pads them to a single byte.

    Args:
        bond: RDKit bond providing GetBondType, GetIsConjugated and IsInRing.

    Returns:
        1-D uint8 array of length 1 holding the packed flags.
    """
    kind = bond.GetBondType()
    order_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    flags = [kind == candidate for candidate in order_types]
    flags.append(bond.GetIsConjugated())
    flags.append(bond.IsInRing())
    return np.packbits(flags)


def smiles_to_graph(smiles):
    """Convert a SMILES string into bit-packed graph arrays.

    Every bond contributes two directed edges, (i, j) and (j, i). Edges are
    ordered by source atom index, then by target atom index.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        Tuple ``(N, edge, node_feature, edge_feature)``:
            N: number of atoms.
            edge: integer array of shape (num_directed_edges, 2) holding atom
                index pairs. The dtype is uint8 whenever every atom index fits
                (up to 256 atoms), otherwise the smallest unsigned integer
                type that can hold the largest index.
            node_feature: uint8 array of shape (N, PACK_NODE_DIM), one packed
                row per atom from get_atom_feature.
            edge_feature: uint8 array of shape
                (num_directed_edges, PACK_EDGE_DIM), one packed row per
                directed edge from get_bond_feature.
        Molecules without atoms or without bonds yield correctly shaped empty
        arrays instead of failing inside np.stack.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f'could not parse SMILES: {smiles!r}')

    N = mol.GetNumAtoms()

    atom_rows = [get_atom_feature(atom) for atom in mol.GetAtoms()]
    if atom_rows:
        node_feature = np.stack(atom_rows)
    else:
        node_feature = np.zeros((0, PACK_NODE_DIM), dtype=np.uint8)

    # Walk the bond list once (linear in the number of bonds) instead of
    # probing every atom pair, emitting both directions of each bond.
    directed = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        packed = get_bond_feature(bond)
        directed.append((begin, end, packed))
        directed.append((end, begin, packed))
    # Sort by (source, target) so the edge order is deterministic.
    directed.sort(key=lambda rec: (rec[0], rec[1]))

    # uint8 cannot index atoms beyond 255; widen only when actually required.
    index_dtype = np.min_scalar_type(max(N - 1, 0))

    if directed:
        edge = np.array([(src, dst) for src, dst, _ in directed], dtype=index_dtype)
        edge_feature = np.stack([packed for _, _, packed in directed])
    else:
        edge = np.zeros((0, 2), dtype=index_dtype)
        edge_feature = np.zeros((0, PACK_EDGE_DIM), dtype=np.uint8)

    return N, edge, node_feature, edge_feature

In [18]:
# Featurise the first training molecule and list the array shapes
# (edge index, node features, edge features); ret[0] is the atom count.
ret = smiles_to_graph(train_df[0, 'molecule'])
print([part.shape for part in ret[1:]])

[(88, 2), (41, 9), (88, 1)]


In [13]:
# # allowable node and edge features
# allowable_features = {
#     'possible_atomic_num_list' : list(range(1, 119)),
#     'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
#     'possible_chirality_list' : [
#         Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
#         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
#         Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
#         Chem.rdchem.ChiralType.CHI_OTHER
#     ],
#     'possible_hybridization_list' : [
#         Chem.rdchem.HybridizationType.S,
#         Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
#         Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
#         Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
#     ],
#     'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
#     'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
#     'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
#     'possible_bonds' : [
#         Chem.rdchem.BondType.SINGLE,
#         Chem.rdchem.BondType.DOUBLE,
#         Chem.rdchem.BondType.TRIPLE,
#         Chem.rdchem.BondType.AROMATIC
#     ],
#     'possible_bond_dirs' : [ # only for double bond stereo information
#         Chem.rdchem.BondDir.NONE,
#         Chem.rdchem.BondDir.ENDUPRIGHT,
#         Chem.rdchem.BondDir.ENDDOWNRIGHT
#     ]
# }


# def smiles_to_graph_v2(smiles):
#     """
#     Converts rdkit mol object to graph Data object required by the pytorch
#     geometric package. NB: Uses simplified atom and bond features, and represent
#     as indices
#     :param mol: rdkit mol object
#     :return: graph data object with the attributes: x, edge_index, edge_attr
#     """
#     # atoms
#     mol = Chem.MolFromSmiles(smiles)
#     num_atom_features = 2   # atom type,  chirality tag
#     atom_features_list = []
#     for atom in mol.GetAtoms():
#         atom_feature = [allowable_features['possible_atomic_num_list'].index(
#             atom.GetAtomicNum())] + [allowable_features[
#             'possible_chirality_list'].index(atom.GetChiralTag())]
#         atom_features_list.append(atom_feature)
#     x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

#     # bonds
#     num_bond_features = 2   # bond type, bond direction
#     if len(mol.GetBonds()) > 0: # mol has bonds
#         edges_list = []
#         edge_features_list = []
#         for bond in mol.GetBonds():
#             i = bond.GetBeginAtomIdx()
#             j = bond.GetEndAtomIdx()
#             edge_feature = [allowable_features['possible_bonds'].index(
#                 bond.GetBondType())] + [allowable_features[
#                                             'possible_bond_dirs'].index(
#                 bond.GetBondDir())]
#             edges_list.append((i, j))
#             edge_features_list.append(edge_feature)
#             edges_list.append((j, i))
#             edge_features_list.append(edge_feature)

#         # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
#         edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

#         # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
#         edge_attr = torch.tensor(np.array(edge_features_list),
#                                  dtype=torch.long)
#     else:   # mol has no bonds
#         edge_index = torch.empty((2, 0), dtype=torch.long)
#         edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

# #     data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

#     return x, edge_index, edge_attr

In [14]:
# smiles_to_graph_v2(train_df[0, 'molecule'])

In [19]:
class MolToGraphV1:
    """Callable that turns an RDKit molecule into index-encoded graph tensors.

    Atoms and bonds are described by small integer indices into the lookup
    lists of ``ALLOWABLE_FEATURES`` rather than by one-hot vectors.
    """

    # allowable node and edge features
    ALLOWABLE_FEATURES = {
        'possible_atomic_num_list':
        list(range(1, 119)),
        'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
        'possible_chirality_list': [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ],
        'possible_hybridization_list': [
            Chem.rdchem.HybridizationType.S, Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            Chem.rdchem.HybridizationType.UNSPECIFIED
        ],
        'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8],
        'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
        'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'possible_bonds': [
            Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC
        ],
        'possible_bond_dirs': [  # only for double bond stereo information
            Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]
    }

    def __call__(self, mol):
        """Convert an RDKit mol object into graph tensors.

        NB: uses simplified atom and bond features, represented as indices
        into ``ALLOWABLE_FEATURES``.

        :param mol: rdkit mol object
        :return: tuple ``(num_nodes, edges, node_features, edge_features)``:
            ``num_nodes`` is the atom count; ``edges`` is a long tensor of
            shape [num_edges, 2] with both directions of every bond;
            ``node_features`` is a long tensor of shape [num_nodes, 2]
            (atomic-number index, chirality index); ``edge_features`` is a
            long tensor of shape [num_edges, 2] (bond-type index,
            bond-direction index). Molecules without atoms or bonds yield
            empty tensors with the same trailing dimensions.
        :raises ValueError: from ``list.index`` when an atom or bond property
            is missing from the corresponding lookup list.
        """
        NUM_ATOM_FEATURES = 2  # atom type,  chirality tag
        NUM_BOND_FEATURES = 2  # bond type, bond direction

        # Resolve the lookup lists once instead of on every atom / bond.
        atomic_nums = self.ALLOWABLE_FEATURES['possible_atomic_num_list']
        chiralities = self.ALLOWABLE_FEATURES['possible_chirality_list']
        bond_types = self.ALLOWABLE_FEATURES['possible_bonds']
        bond_dirs = self.ALLOWABLE_FEATURES['possible_bond_dirs']

        # atoms
        num_nodes = mol.GetNumAtoms()
        atom_features_list = [
            [
                atomic_nums.index(atom.GetAtomicNum()),
                chiralities.index(atom.GetChiralTag()),
            ]
            for atom in mol.GetAtoms()
        ]
        if atom_features_list:
            node_features = torch.tensor(np.array(atom_features_list),
                                         dtype=torch.long)
        else:  # keep the 2-D layout even when there are no atoms
            node_features = torch.empty((0, NUM_ATOM_FEATURES),
                                        dtype=torch.long)

        # bonds
        edges_list = []
        edge_feature_rows = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [
                bond_types.index(bond.GetBondType()),
                bond_dirs.index(bond.GetBondDir()),
            ]
            # Store both directions so the graph is effectively undirected.
            edges_list.append((i, j))
            edge_feature_rows.append(edge_feature)
            edges_list.append((j, i))
            edge_feature_rows.append(edge_feature)

        if edges_list:  # mol has bonds
            # Graph connectivity with shape [num_edges, 2]
            edges = torch.tensor(np.array(edges_list), dtype=torch.long)
            # Edge feature matrix with shape [num_edges, num_edge_features]
            edge_features = torch.tensor(np.array(edge_feature_rows),
                                         dtype=torch.long)
        else:  # mol has no bonds
            # Same [num_edges, 2] layout as the branch above, with zero rows.
            edges = torch.empty((0, 2), dtype=torch.long)
            edge_features = torch.empty((0, NUM_BOND_FEATURES),
                                        dtype=torch.long)

        return num_nodes, edges, node_features, edge_features

In [23]:
# Run the index-based converter on the first training molecule and show the
# shape of every tensor it returns (non-tensor items are printed as they are).
mol2graph = MolToGraphV1()
mol = Chem.MolFromSmiles(train_df[0, 'molecule'])
ret = mol2graph(mol)

for e in ret:
    print(e.shape if isinstance(e, torch.Tensor) else e)

41
torch.Size([88, 2])
torch.Size([41, 2])
torch.Size([88, 2])
