## Prepare raw graph files of the AqSol dataset from its source SMILES

In [None]:
from rdkit import Chem
import networkx as nx
import matplotlib.pyplot as plt
import math
import random
import numpy as np
from itertools import compress
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import os
import pandas as pd
import time
import pickle
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

### Download the CSV source file

In [None]:
# Fetch the AqSol CSV only when no local file named 'aqsol.csv' exists yet.
if not os.path.isfile('aqsol.csv'):
    print('downloading..')
    # The download link is present on this webpage:
    # https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/OVHAW8
    # NOTE(review): `-k` disables TLS certificate verification and `-L` follows
    # redirects; 'download complete' is printed without checking curl's exit status.
    !curl https://www.dropbox.com/s/26zoivsx3s1qr3q/curated-solubility-dataset.csv?dl=1 -o aqsol.csv -J -L -k
    print('download complete')
else:
    print('File already downloaded')

In [None]:
# Load the AqSol table, then pull out the inputs (SMILES) and targets (solubility).
aqsol_df = pd.read_csv('aqsol.csv')
smiles_list = list(aqsol_df['SMILES'])  # one SMILES string per row
labels_list = np.asarray(aqsol_df['Solubility'])  # solubility values as an array
# Parse every SMILES string into an rdkit molecule, keeping the row order.
mol_list = list(map(Chem.MolFromSmiles, smiles_list))

### Scaffold Splitting

In [None]:
# Code snippet from
# https://groups.google.com/g/open-graph-benchmark/c/0KML08gNVcM/m/99SeJ0zpAAAJ?pli=1

def generate_scaffold(smiles, include_chirality=False):
    """
    Return the Bemis-Murcko scaffold of a molecule, expressed as SMILES.

    :param smiles: SMILES string of the molecule
    :param include_chirality: forwarded to rdkit as ``includeChirality``
    :return: smiles of scaffold
    """
    return MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles,
        includeChirality=include_chirality,
    )
# # test generate scaffold
# s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1'
# scaffold = generate_scaffold(s)
# assert scaffold == 'c1ccc(Oc2ccccn2)cc1'

def scaffold_split(smiles_list, frac_train=0.8, frac_valid=0.1, frac_test=0.1):
    """
    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py
    Deterministically partition molecules into train/valid/test so that all
    molecules sharing a Bemis-Murcko scaffold land in the same split.

    :param smiles_list: list of smiles
    :param frac_train: target fraction of molecules in the train split
    :param frac_valid: target fraction of molecules in the valid split
    :param frac_test: target fraction of molecules in the test split
    :return: list of train, valid, test indices (positions in smiles_list)
    corresponding to the scaffold split
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    # group molecule positions by scaffold: {scaffold_smiles: [idx, ...]}
    members_by_scaffold = {}
    for idx, smiles in enumerate(smiles_list):
        key = generate_scaffold(smiles, include_chirality=True)
        members_by_scaffold.setdefault(key, []).append(idx)

    # biggest groups first; equal sizes are ordered by larger first index first
    groups = [sorted(members) for members in members_by_scaffold.values()]
    groups.sort(key=lambda members: (len(members), members[0]), reverse=True)

    # fill train until a group would exceed its cutoff, then valid, then test
    n_total = len(smiles_list)
    train_cutoff = frac_train * n_total
    valid_cutoff = (frac_train + frac_valid) * n_total
    train_idx, valid_idx, test_idx = [], [], []
    for group in groups:
        if len(train_idx) + len(group) <= train_cutoff:
            train_idx.extend(group)
        elif len(train_idx) + len(valid_idx) + len(group) <= valid_cutoff:
            valid_idx.extend(group)
        else:
            test_idx.extend(group)

    # sanity check: the three splits must not share any index
    assert set(train_idx).isdisjoint(valid_idx)
    assert set(train_idx).isdisjoint(test_idx)
    assert set(test_idx).isdisjoint(valid_idx)

    return train_idx, valid_idx, test_idx

In [None]:
# Split all molecules by scaffold (default fractions) and show the three split sizes.
train_idx, valid_idx, test_idx = scaffold_split(smiles_list)
print(*map(len, (train_idx, valid_idx, test_idx)))

### Create Atom and Bond Dicts

In [None]:
# Code snippet from
# https://github.com/xbresson/IPAM_Tutorial_2019/blob/master/04_molecule_regression/dictionaries.py
class Dictionary(object):
    """
    Vocabulary that gives each new word the next integer id and counts occurrences.

    word2idx is a dictionary (word -> id)
    idx2word is a list (id -> word)
    word2num_occurence / idx2num_occurence hold the same counts, keyed by word / by id
    """
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        self.word2num_occurence = {}
        self.idx2num_occurence = []

    def add_word(self, word):
        """Register `word` if it is new, then count one more occurrence of it."""
        position = self.word2idx.get(word)
        if position is None:
            # unseen word: it receives the next free id and zeroed counters
            position = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = position
            self.idx2num_occurence.append(0)
            self.word2num_occurence[word] = 0

        # both counters are kept in sync
        self.word2num_occurence[word] += 1
        self.idx2num_occurence[position] += 1

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.idx2word)


def augment_dictionary(atom_dict, bond_dict, list_of_mol ):
    """
    Take a list of rdkit molecules and use it to augment existing atom and bond dictionaries.

    Every atom symbol and every bond type (as a string) is added/counted. In
    addition, for each molecule with more than two atoms the 'NONE' counters of
    bond_dict grow by the number of atom pairs (self loops included) that carry
    no bond. Counts are kept as integers.

    :param atom_dict: dictionary of atom symbols, updated in place
    :param bond_dict: dictionary of bond types, updated in place; it must
        already contain the word 'NONE' when a molecule has more than two atoms
        (a KeyError is raised otherwise)
    :param list_of_mol: iterable of rdkit molecules
    """
    for mol in list_of_mol:

        for atom in mol.GetAtoms():
            atom_dict.add_word( atom.GetSymbol() )

        for bond in mol.GetBonds():
            bond_dict.add_word( str(bond.GetBondType()) )

        # compute the number of edges of type 'None'
        N = mol.GetNumAtoms()
        # NOTE(review): molecules with one or two atoms are left out of the
        # 'NONE' count, as before; confirm whether that exclusion is intended.
        if N > 2:
            # N self loops + C(N, 2) unordered atom pairs, in exact integer arithmetic
            E = N + N * (N - 1) // 2
            num_NONE_bonds = E - mol.GetNumBonds()
            bond_dict.word2num_occurence['NONE'] += num_NONE_bonds
            # look the id up instead of assuming 'NONE' was registered first
            bond_dict.idx2num_occurence[bond_dict.word2idx['NONE']] += num_NONE_bonds


def make_dictionary(list_of_mol):
    """
    Build fresh atom and bond dictionaries from a list of rdkit molecules.

    The bond dictionary is seeded with the word 'NONE' before any molecule is
    processed, so 'NONE' is its first entry.

    :param list_of_mol: molecules to collect atom symbols and bond types from
    :return: (atom_dict, bond_dict)
    """
    print('making dictionary')
    atom_dict, bond_dict = Dictionary(), Dictionary()
    bond_dict.add_word('NONE')
    augment_dictionary(atom_dict, bond_dict, list_of_mol)
    print('complete')
    return atom_dict, bond_dict

In [None]:
# Build the atom/bond vocabularies over every parsed molecule (all splits together).
atom_dict, bond_dict = make_dictionary(mol_list)

In [None]:
# Show the id assignment and the occurrence counts of both vocabularies.
for vocab in (atom_dict, bond_dict):
    print(vocab.word2idx)
    print(vocab.word2num_occurence)

### Create graph objects

In [None]:
def mol_to_graph(mol, solubility):
    """
    Encode one rdkit mol object as numpy arrays, bundled with its label.

    Atom symbols and bond types are mapped to integer ids through the
    module-level `atom_dict` and `bond_dict`.

    :param mol: rdkit mol object
    :param solubility: label, stored unchanged as the last tuple entry
    :return: ((node_feat, edge_feat, edge_index, solubility), no_bond_flag),
        where no_bond_flag is True when the molecule has no bonds
    """
    # one integer id per atom, in the molecule's atom order
    node_feat = np.array(
        [atom_dict.word2idx[atom.GetSymbol()] for atom in mol.GetAtoms()],
        dtype=np.int64,
    )

    no_bond_flag = len(mol.GetBonds()) == 0

    if no_bond_flag:
        # NOTE(review): this empty edge_feat is 2-D, shape (0, 1), while the
        # bonded branch below yields a 1-D array; confirm consumers accept both.
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_feat = np.empty((0, 1), dtype=np.int64)
    else:
        directed_edges = []
        directed_feats = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_id = bond_dict.word2idx[str(bond.GetBondType())]
            # each bond is stored once per direction, sharing the same feature
            directed_edges += [(begin, end), (end, begin)]
            directed_feats += [bond_id, bond_id]

        # connectivity in COO format, shape [2, num_edges]
        edge_index = np.array(directed_edges, dtype=np.int64).T
        # one bond-type id per directed edge
        edge_feat = np.array(directed_feats, dtype=np.int64)

    return (node_feat, edge_feat, edge_index, solubility), no_bond_flag

In [None]:
def get_graph_objects(split_idx, mol_list, labels_list):
    """
    Convert the molecules of one split into graph tuples.

    :param split_idx: indices into mol_list / labels_list that belong to the split
    :param mol_list: rdkit molecules, indexable by the entries of split_idx
    :param labels_list: labels aligned with mol_list
    :return: list of graph tuples as produced by mol_to_graph, in split_idx order

    Also prints how many molecules of the split have no bonds.
    """
    count_no_bond = 0
    split_graph_objects = []

    for idx in split_idx:
        graph_object, no_bond_flag = mol_to_graph(mol_list[idx], labels_list[idx])
        split_graph_objects.append(graph_object)
        if no_bond_flag:
            count_no_bond += 1
    print("Total graphs with no bonds: ", count_no_bond)
    return split_graph_objects

In [None]:
# Convert each split's molecules into graph tuples and report the per-split totals.
print("Train split..")
train_graph_objects = get_graph_objects(train_idx, mol_list, labels_list)
print("Total graphs:", len(train_graph_objects))
print("Valid split..")
valid_graph_objects = get_graph_objects(valid_idx, mol_list, labels_list)
print("Total graphs:", len(valid_graph_objects))
# this header announces the test split (it used to repeat "Train split..")
print("Test split..")
test_graph_objects = get_graph_objects(test_idx, mol_list, labels_list)
print("Total graphs:", len(test_graph_objects))
print("Done")

### Save pkl files

In [None]:
# Output directory for the pickled splits and dictionaries.
savedir = './asqol_graph_raw'
# exist_ok=True creates the directory only when missing, without a separate
# existence check that could race with another process creating it.
os.makedirs(savedir, exist_ok=True)

In [None]:
# Serialize the three graph splits and both vocabularies into `savedir`,
# then report how long the writes took.
start = time.time()
pickle_targets = (
    ('/train.pickle', train_graph_objects),
    ('/val.pickle', valid_graph_objects),
    ('/test.pickle', test_graph_objects),
    ('/atom_dict.pickle', atom_dict),
    ('/bond_dict.pickle', bond_dict),
)
for suffix, payload in pickle_targets:
    with open(savedir + suffix, 'wb') as f:
        pickle.dump(payload, f)
print('Time (sec):', time.time() - start)

In [None]:
# Bundle the output directory into a zip archive. Note the spelling difference:
# the archive is named 'aqsol_graph_raw.zip' while the directory is 'asqol_graph_raw'.
!zip -r aqsol_graph_raw.zip asqol_graph_raw