In [1]:
import numpy as np
from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict

import random
import os

In [2]:
# Input/output locations for the QM9 arrays. The defaults are the original cluster paths;
# set QM9_TRANSFORMED_DIR / QM9_SCAFFOLDED_DIR to run elsewhere without editing the notebook.
# os.path.join(path, '') guarantees the trailing separator that the
# `indir + 'file.npy'` / `outdir + 'file.npy'` concatenations in later cells rely on.
indir = os.path.join(
    os.environ.get('QM9_TRANSFORMED_DIR', '/scratch/midway3/jshe/data/qm9/transformed/'), ''
)
outdir = os.path.join(
    os.environ.get('QM9_SCAFFOLDED_DIR', '/scratch/midway3/jshe/data/qm9/scaffolded/'), ''
)

## Load Data

In [4]:
# SMILES strings; later cells index `smiles` and every array in `datas`
# with the same row indices.
smiles = np.load(indir + 'smiles.npy')

# Array names double as file stems: each one is read from '<indir><name>.npy'.
# `names` and `datas` are parallel tuples (datas[k] was loaded from names[k]).
names = (
    'atoms',
    'coordinates',
    'conformers',
    'partial_charges',
    'norm_y',
)
datas = tuple(np.load(indir + stem + '.npy') for stem in names)

## Determine scaffolds

In [5]:
# Group molecules by Murcko scaffold: scaffold SMILES -> set of row indices into `smiles`.
scaffold_dict = defaultdict(set)
# Wrap the array itself rather than the enumerate object: enumerate() has no len(),
# so tqdm(enumerate(...)) can only show a raw counter, whereas tqdm(smiles) knows the
# total and renders a real progress bar with percentage and ETA.
for i, smile in enumerate(tqdm(smiles)):
    scaffold_dict[MurckoScaffold.MurckoScaffoldSmiles(smile)].add(i)

127346it [00:12, 10367.02it/s]


In [6]:
# Shuffle the scaffold groups with a dedicated, seeded generator so the split is the
# same on every run and the global `random` state is left untouched.
SPLIT_SEED = 0
_scaffold_groups = list(scaffold_dict.values())
# Sampling all groups without replacement is a full shuffle. The result stays an
# iterator because the split cell pulls groups from it one by one and then drains
# whatever is left.
scaffold_sets = iter(random.Random(SPLIT_SEED).sample(_scaffold_groups, len(_scaffold_groups)))

## Create scaffold splits

In [7]:
# Minimum share of all molecules the training split must reach before we stop
# adding scaffold groups to it.
TRAIN_FRACTION = 0.85
train_target = TRAIN_FRACTION * len(smiles)

# Whole scaffold groups are added one at a time, so a scaffold never appears in both
# splits; the training split may overshoot the target by part of the last group.
train_indices = set()
for scaffold_members in scaffold_sets:
    train_indices.update(scaffold_members)
    if len(train_indices) >= train_target:
        break

# Every group not consumed above goes to validation.
validation_indices = set().union(*scaffold_sets)

# `scaffold_sets` is a one-shot iterator: if this cell is re-run without re-creating
# it, the groups are already gone and the splits would silently come out wrong.
if len(train_indices) + len(validation_indices) != len(smiles):
    raise RuntimeError(
        'Train and validation splits do not cover every molecule; `scaffold_sets` was '
        'probably already consumed. Re-run the cell that creates it before this one.'
    )

# Sorted index lists make the on-disk row order independent of set iteration order.
train_set, validation_set = sorted(train_indices), sorted(validation_indices)
print('Train:', len(train_set))
print('Validation:', len(validation_set))

Train: 109024
Validation: 18322


In [8]:
# `outdir` comes from the configuration cell at the top of the notebook.
# np.save does not create missing directories, so make sure the destination exists.
os.makedirs(outdir, exist_ok=True)

# For each split, write the SMILES and every array in `datas` as '<split>_<name>.npy',
# keeping only the rows selected by that split's index list.
for split_name, split_indices in (('train', train_set), ('validation', validation_set)):
    np.save(outdir + split_name + '_smiles.npy', smiles[split_indices])
    for name, arr in zip(names, datas):
        # The '.npy' suffix is spelled out here; np.save would append it anyway,
        # so the resulting file names are unchanged.
        np.save(outdir + split_name + '_' + name + '.npy', arr[split_indices])