In [1]:
import json
from pathlib import Path

import numpy as np

from paths import QM9, Alchemy, OE62, HOPV

In [2]:
# Load the id -> InChI mapping of every dataset, keyed by a short dataset name.
inchi_sources = {'qm9': QM9, 'alchemy': Alchemy, 'oe62': OE62, 'hopv': HOPV}
inchi_dicts = {}
for dataset_name, dataset in inchi_sources.items():
    with open(dataset.inchis, 'r', encoding='utf-8') as fh:
        inchi_dicts[dataset_name] = json.load(fh)

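# Train/validation/test split based only on the chemical formula and connectivity layers

Molecules whose InChIs agree up to the connectivity layer always end up in the same split. These are molecules that differ only in later layers, such as hydrogen atoms, charge or stereochemistry. This keeps near-duplicates from leaking between training and evaluation. The datasets are processed in the order QM9 → Alchemy → OE62 → HOPV. Each dataset reuses the assignment of truncated InChIs already seen in an earlier dataset, and only the new ones are split randomly.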

In [3]:
# Split configuration.
L = 2                 # layers kept after the version prefix: formula + the next layer
test_fraction = 0.18  # share of truncated-InChI groups assigned to the test set
val_fraction = 0.09   # share of truncated-InChI groups assigned to the validation set


def inchi_up_to_layer(inchi, layer=1):
    """Return ``inchi`` cut after its first ``layer`` '/'-separated layers.

    The version prefix (e.g. ``InChI=1S``) is always kept in addition, so for
    ``layer >= 0``:
    ``layer=1`` -> prefix + formula,
    ``layer=2`` -> prefix + formula + the following layer. That is the
    connectivity ``c`` layer when present. NOTE(review): for InChIs without a
    connectivity layer (e.g. ``InChI=1S/CH4/h1H4``) it is the hydrogen layer.
    Strings with fewer layers are returned unchanged.
    """
    # maxsplit stops splitting once the part we are going to drop is reached
    return '/'.join(inchi.split('/', layer + 1)[:layer + 1])


# A single RNG shared by all dataset sections below, so the resulting splits
# are only reproducible when the cells run top to bottom exactly once.
random_state = np.random.RandomState(seed=2022)

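- **TODO:** Are the relative sizes of train/validation/test OK? `test_fraction` and `val_fraction` are fractions of truncated-InChI *groups*, not of molecules. The per-molecule proportions can therefore deviate from the nominal 73 % / 9 % / 18 %.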

### QM9

In [4]:
# Group the QM9 molecule ids by their InChI truncated to the first L layers.
truncated_inchi_to_id = {}
for mol_key, inchi in inchi_dicts['qm9'].items():
    tri = inchi_up_to_layer(inchi, L)
    mol_id = int(mol_key)
    truncated_inchi_to_id.setdefault(tri, []).append(mol_id)

In [5]:
# Randomly assign whole truncated-InChI groups to train / validation / test.
# Sorting first makes the shuffled order independent of dict/set ordering.
perm = random_state.permutation(len(truncated_inchi_to_id))
sorted_tris = sorted(truncated_inchi_to_id)
trinchis = [sorted_tris[p] for p in perm]

n_total = len(trinchis)
num_test = int(test_fraction * n_total)
num_val = int(val_fraction * n_total)
num_train = n_total - num_val - num_test
assert num_train > 0

# Layout of the shuffled list: [train | validation | test]
val_start, test_start = num_train, num_train + num_val
train_tris = set(trinchis[:val_start])
val_tris = set(trinchis[val_start:test_start])
test_tris = set(trinchis[test_start:])

In [6]:
# Expand the QM9 truncated-InChI groups of each split back to sorted molecule ids.
train_idx = sorted(mol_id for tri in train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in test_tris for mol_id in truncated_inchi_to_id[tri])

# Sanity check: every QM9 molecule ends up in exactly one split.
n_molecules = len(inchi_dicts['qm9'])
assert len(train_idx) + len(val_idx) + len(test_idx) == n_molecules, 'some molecules were not assigned'
assert len(set(train_idx).union(val_idx, test_idx)) == n_molecules, 'molecule ids shared between splits'
len(train_idx), len(val_idx), len(test_idx)

(97816, 12094, 23969)

In [7]:
# Persist the QM9 split as int64 index arrays (an empty list would otherwise be saved as float64).
split_arrays = dict(
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)
# NOTE(review): absolute, user-specific path; consider deriving it from the `paths` module.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/qm9_split_v2.npz'
Path(split_file).parent.mkdir(parents=True, exist_ok=True)
np.savez(split_file, **split_arrays)

### Alchemy

In [8]:
# Group the Alchemy molecule ids by their InChI truncated to the first L layers.
truncated_inchi_to_id = {}
for mol_key, inchi in inchi_dicts['alchemy'].items():
    tri = inchi_up_to_layer(inchi, L)
    mol_id = int(mol_key)
    truncated_inchi_to_id.setdefault(tri, []).append(mol_id)

In [9]:
# Truncated InChIs already assigned by QM9 keep their split in Alchemy.
dataset_tris = set(truncated_inchi_to_id)
local_train_tris = dataset_tris & train_tris
local_val_tris = dataset_tris & val_tris
local_test_tris = dataset_tris & test_tris
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(13519, 1739, 3317)

In [10]:
# Randomly split the truncated InChIs not seen in any previous dataset into train, validation, test.
remaining_tris = set(truncated_inchi_to_id) - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
# Sort before permuting: set iteration order of strings is not reproducible across runs.
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = int(test_fraction * len(trinchis))
num_val = int(val_fraction * len(trinchis))
num_train = len(trinchis) - num_test - num_val
# An empty remainder is legitimate (everything was assigned by earlier datasets);
# only require a non-empty training share when there is something to split.
assert num_train > 0 or not trinchis

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])

In [11]:
# Expand the Alchemy truncated-InChI groups of each split back to sorted molecule ids.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])

# Sanity check: every Alchemy molecule ends up in exactly one split.
n_molecules = len(inchi_dicts['alchemy'])
assert len(train_idx) + len(val_idx) + len(test_idx) == n_molecules, 'some molecules were not assigned'
assert len(set(train_idx).union(val_idx, test_idx)) == n_molecules, 'molecule ids shared between splits'
len(train_idx), len(val_idx), len(test_idx)

(147079, 18186, 36277)

In [12]:
# Persist the Alchemy split as int64 index arrays (an empty list would otherwise be saved as float64).
split_arrays = dict(
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)
# NOTE(review): absolute, user-specific path; consider deriving it from the `paths` module.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/alchemy_split_v2.npz'
Path(split_file).parent.mkdir(parents=True, exist_ok=True)
np.savez(split_file, **split_arrays)

In [13]:
# Register Alchemy's assignments globally so the following datasets reuse them.
train_tris |= local_train_tris
val_tris |= local_val_tris
test_tris |= local_test_tris

### OE62

In [14]:
# Group the OE62 molecule ids by their InChI truncated to the first L layers.
truncated_inchi_to_id = {}
for mol_key, inchi in inchi_dicts['oe62'].items():
    tri = inchi_up_to_layer(inchi, L)
    mol_id = int(mol_key)
    truncated_inchi_to_id.setdefault(tri, []).append(mol_id)

In [15]:
# Truncated InChIs already assigned by QM9/Alchemy keep their split in OE62.
dataset_tris = set(truncated_inchi_to_id)
local_train_tris = dataset_tris & train_tris
local_val_tris = dataset_tris & val_tris
local_test_tris = dataset_tris & test_tris
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(297, 32, 74)

In [16]:
# Randomly split the truncated InChIs not seen in any previous dataset into train, validation, test.
remaining_tris = set(truncated_inchi_to_id) - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
# Sort before permuting: set iteration order of strings is not reproducible across runs.
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = int(test_fraction * len(trinchis))
num_val = int(val_fraction * len(trinchis))
num_train = len(trinchis) - num_test - num_val
# An empty remainder is legitimate (everything was assigned by earlier datasets);
# only require a non-empty training share when there is something to split.
assert num_train > 0 or not trinchis

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])

In [17]:
# Expand the OE62 truncated-InChI groups of each split back to sorted molecule ids.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])

# Sanity check: every OE62 molecule ends up in exactly one split.
n_molecules = len(inchi_dicts['oe62'])
assert len(train_idx) + len(val_idx) + len(test_idx) == n_molecules, 'some molecules were not assigned'
assert len(set(train_idx).union(val_idx, test_idx)) == n_molecules, 'molecule ids shared between splits'
len(train_idx), len(val_idx), len(test_idx)

(44900, 5525, 11064)

In [18]:
# Persist the OE62 split as int64 index arrays (an empty list would otherwise be saved as float64).
split_arrays = dict(
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)
# NOTE(review): absolute, user-specific path; consider deriving it from the `paths` module.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/oe62/split_v2.npz'
# The per-dataset subdirectory may not exist yet on a fresh checkout.
Path(split_file).parent.mkdir(parents=True, exist_ok=True)
np.savez(split_file, **split_arrays)

In [19]:
# Register OE62's assignments globally so the following datasets reuse them.
train_tris |= local_train_tris
val_tris |= local_val_tris
test_tris |= local_test_tris

### HOPV

In [20]:
# Group the HOPV molecule ids by their InChI truncated to the first L layers.
truncated_inchi_to_id = {}
for mol_key, inchi in inchi_dicts['hopv'].items():
    tri = inchi_up_to_layer(inchi, L)
    mol_id = int(mol_key)
    truncated_inchi_to_id.setdefault(tri, []).append(mol_id)

In [21]:
# Truncated InChIs already assigned by QM9/Alchemy/OE62 keep their split in HOPV.
dataset_tris = set(truncated_inchi_to_id)
local_train_tris = dataset_tris & train_tris
local_val_tris = dataset_tris & val_tris
local_test_tris = dataset_tris & test_tris
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(2, 0, 1)

In [22]:
# Randomly split the truncated InChIs not seen in any previous dataset into train, validation, test.
remaining_tris = set(truncated_inchi_to_id) - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
# Sort before permuting: set iteration order of strings is not reproducible across runs.
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = int(test_fraction * len(trinchis))
num_val = int(val_fraction * len(trinchis))
num_train = len(trinchis) - num_test - num_val
# An empty remainder is legitimate (everything was assigned by earlier datasets);
# only require a non-empty training share when there is something to split.
assert num_train > 0 or not trinchis

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])

In [23]:
# Expand the HOPV truncated-InChI groups of each split back to sorted molecule ids.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])

# Sanity check: every HOPV molecule ends up in exactly one split.
n_molecules = len(inchi_dicts['hopv'])
assert len(train_idx) + len(val_idx) + len(test_idx) == n_molecules, 'some molecules were not assigned'
assert len(set(train_idx).union(val_idx, test_idx)) == n_molecules, 'molecule ids shared between splits'
len(train_idx), len(val_idx), len(test_idx)

(3626, 380, 849)

In [24]:
# Persist the HOPV split as int64 index arrays (an empty list would otherwise be saved as float64).
split_arrays = dict(
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)
# NOTE(review): absolute, user-specific path; consider deriving it from the `paths` module.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/hopv/split_v2.npz'
# The per-dataset subdirectory may not exist yet on a fresh checkout.
Path(split_file).parent.mkdir(parents=True, exist_ok=True)
np.savez(split_file, **split_arrays)

In [25]:
# Register HOPV's assignments globally so any further dataset would reuse them.
train_tris |= local_train_tris
val_tris |= local_val_tris
test_tris |= local_test_tris

## Plausibility checks

The truncated-InChI sets accumulated over all four datasets must be pairwise disjoint.

In [26]:
# Train and test truncated InChIs must not overlap; fail loudly, and still show the overlap.
train_test_overlap = train_tris & test_tris
assert not train_test_overlap, f'{len(train_test_overlap)} truncated InChIs are in both train and test'
train_test_overlap

set()

In [27]:
# Train and validation truncated InChIs must not overlap; fail loudly, and still show the overlap.
train_val_overlap = train_tris & val_tris
assert not train_val_overlap, f'{len(train_val_overlap)} truncated InChIs are in both train and validation'
train_val_overlap

set()

In [28]:
# Validation and test truncated InChIs must not overlap; fail loudly, and still show the overlap.
val_test_overlap = val_tris & test_tris
assert not val_test_overlap, f'{len(val_test_overlap)} truncated InChIs are in both validation and test'
val_test_overlap

set()

In [29]:
# Number of distinct truncated InChIs assigned to training, accumulated over QM9, Alchemy, OE62 and HOPV.
len(train_tris)

242631

In [30]:
# Number of distinct truncated InChIs assigned to test, accumulated over QM9, Alchemy, OE62 and HOPV.
len(test_tris)

59824

In [31]:
# Number of distinct truncated InChIs assigned to validation, accumulated over QM9, Alchemy, OE62 and HOPV.
len(val_tris)

29910