In [1]:
import numpy as np
from tqdm import tqdm

from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict

import random
import os

In [2]:
# Both directories live under the same QM9 working tree.
qm9_root = '/scratch/midway3/jshe/data/qm9/'
indir = qm9_root + 'transformed/'   # every file in here is loaded with np.load below
outdir = qm9_root + 'scaffolded/'   # train/ and validation/ arrays are written under here

## Load Data

In [3]:
# Load every .npy array in `indir` into a dict keyed by the file name up to
# its first '.', e.g. 'smiles.npy' -> datas['smiles'].
# Only *.npy entries are considered so that stray directory contents (hidden
# files, .ipynb_checkpoints, ...) do not reach np.load; the listing is sorted
# so the dict order does not depend on filesystem ordering.
datas = {
    fname.split('.')[0]: np.load(os.path.join(indir, fname))
    for fname in sorted(os.listdir(indir))
    if fname.endswith('.npy')
}

In [5]:
# Pull the SMILES array out of `datas`; it is saved separately later on.
smiles = datas.pop('smiles')

# Drop the entries that must not be sliced by molecule index: every array
# still in `datas` after this cell is indexed with the split indices below.
# NOTE(review): presumably these hold normalisation statistics / label names
# rather than per-molecule rows -- confirm against the transform step.
for unused_key in ('y_mean', 'y_std', 'y_labels'):
    datas.pop(unused_key)

## Determine scaffolds

In [6]:
# Group molecule positions by scaffold: the key is the scaffold SMILES string
# returned by RDKit's MurckoScaffoldSmiles, the value is the set of indices
# into `smiles` that share that scaffold.
scaffold_dict = defaultdict(set)
# tqdm wraps the array itself (not the enumerate object) so it can read the
# length and show a real progress bar with total and ETA.
for i, smile in enumerate(tqdm(smiles)):
    scaffold_dict[MurckoScaffold.MurckoScaffoldSmiles(smile)].add(i)

127346it [00:12, 10375.19it/s]


In [7]:
# Shuffle the scaffold groups into a random order before they are dealt out
# to the train / validation splits.
# A dedicated, seeded generator makes the split reproducible across kernel
# restarts without touching the global `random` state.
SPLIT_SEED = 0
scaffold_groups = list(scaffold_dict.values())
random.Random(SPLIT_SEED).shuffle(scaffold_groups)

# Kept as a one-shot iterator: the split cell pulls groups with next() and
# then drains whatever is left, so re-run this cell before re-running it.
scaffold_sets = iter(scaffold_groups)

## Create scaffold splits

In [8]:
# Fill the training split one whole scaffold group at a time until it holds
# at least 85% of the molecules; a scaffold is never divided between splits.
# NOTE(review): because complete groups are added, the final share can end up
# above 85% when a large group is drawn late -- check the printed sizes.
train_target = 0.85 * len(smiles)
train_set = set()
while len(train_set) < train_target:
    train_set.update(next(scaffold_sets))

# Every group still left in the iterator goes to validation.
validation_set = set()
for remaining_group in scaffold_sets:
    validation_set.update(remaining_group)

# Lists (not sets) are needed to index the numpy arrays when saving.
train_set = list(train_set)
validation_set = list(validation_set)
print('Train:', len(train_set))
print('Validation:', len(validation_set))

Train: 119782
Validation: 7564


In [9]:
# Write each split to its own sub-directory of `outdir`: the SMILES array as
# smiles.npy plus one <name>.npy per remaining array in `datas`, each sliced
# down to the molecule indices of that split.
for split_name, split_indices in (('train', train_set), ('validation', validation_set)):
    split_dir = os.path.join(outdir, split_name)
    # np.save does not create missing directories, so make sure it exists.
    os.makedirs(split_dir, exist_ok=True)

    np.save(os.path.join(split_dir, 'smiles.npy'), smiles[split_indices])
    for name, data in datas.items():
        # np.save appends the '.npy' extension to `name` itself.
        np.save(os.path.join(split_dir, name), data[split_indices])