In [1]:
import numpy as np
import json
from paths import QM9, Alchemy, OE62, HOPV

In [2]:
# Load the per-dataset mapping {molecule index (as str): InChI} for every dataset.
inchi_dicts = {}
for ds_key, ds in (('qm9', QM9), ('alchemy', Alchemy), ('oe62', OE62), ('hopv', HOPV)):
    with open(ds.inchis, 'r', encoding='utf-8') as f:
        inchi_dicts[ds_key] = json.load(f)

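# Make a train/validation/test split based only on the chemical formula and connectivity layer of each InChI

All molecules whose InChIs agree in these leading layers (i.e. that differ only in later layers such as hydrogen positions, charge or stereochemistry) end up in the same split, so closely related molecules cannot leak from training into validation or test.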

In [3]:
# Number of '/'-separated InChI layers (after the version prefix) that make up a
# "truncated InChI" (tri). For a standard InChI, L = 2 typically keeps the formula and
# the '/c' connectivity layer; molecules without a connectivity layer (e.g. H2O,
# 'InChI=1S/H2O/h1H2') keep whatever layer comes next instead.
L = 2
# Target shares of each dataset's truncated InChIs; training receives the remainder.
test_fraction = 0.18
val_fraction = 0.09


def inchi_up_to_layer(inchi, layer=1):
    """Truncate ``inchi`` after its first ``layer`` layers.

    The version prefix (e.g. ``'InChI=1S'``) is always kept, so ``layer=1`` returns
    prefix + formula and ``layer=2`` additionally keeps the following layer.
    """
    layers = inchi.split('/')
    kept = layers[: layer + 1]
    return '/'.join(kept)


# Fixed seed so the random assignment of truncated InChIs is reproducible.
random_state = np.random.RandomState(seed=2022)

- **TODO:** Are the relative sizes of train/validation/test (about 73 % / 9 % / 18 % of the truncated InChIs per dataset) OK?

## Reserve the Alchemy contest test set for testing

In [4]:
def idxs_per_split(DS, split):
    """Return the set of integer sample indices stored under key ``split`` in ``DS.split``.

    ``DS.split`` is presumably an ``.npz`` archive with ``train_idx``/``val_idx``/``test_idx``
    arrays (the keys queried in this notebook) -- confirm for new datasets.
    """
    # For .npz files np.load returns a lazily-reading NpzFile that keeps the file open;
    # the context manager closes it once the requested array has been read.
    with np.load(DS.split) as split_archive:
        return {int(i) for i in split_archive[split].tolist()}


def split_inchis(DS, split):
    """Return the truncated InChIs (first ``L`` layers) of all molecules in split ``split``.

    Indices that have no entry in the ``DS.inchis`` JSON mapping are skipped.
    """
    with open(DS.inchis, 'r', encoding='utf-8') as f:
        idict = json.load(f)
    return {
        inchi_up_to_layer(idict[str(i)], L)
        for i in idxs_per_split(DS, split)
        if str(i) in idict
    }

In [5]:
# Global pools of truncated InChIs per split; the Alchemy contest test set seeds the test pool.
train_tris, val_tris = set(), set()
test_tris = split_inchis(Alchemy, 'test_idx')
n_alchemy_test = len(test_tris)

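## Reuse (most of) the Alchemy contest validation set

Truncated InChIs that also occur in the contest test set stay in the test split, so they are removed from the validation set.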

In [6]:
# Keep the contest validation set, minus any truncated InChI already reserved for test.
val_tris_orig = split_inchis(Alchemy, 'val_idx')
val_tris = val_tris_orig - test_tris
# Adjacent string literals are joined at compile time, keeping the space between the two
# halves of the message (a backslash continuation inside the literal dropped it).
print('{:.1f}% of the original validation set is independent of the test set '
      '(in terms of disjoint truncated InChIs).'.format(100 * len(val_tris) / len(val_tris_orig)))

98.0% of the original validation set is independet of the test set(in terms of disjunct truncated InChIs).


## Assign the rest of the Alchemy dataset

In [7]:
# Collect InChIs where the truncated InChIs differ (Alchemy)
truncated_inchi_to_id = {}
for k, v in inchi_dicts['alchemy'].items():
    tri = inchi_up_to_layer(v, L)
    if tri not in truncated_inchi_to_id:
        truncated_inchi_to_id[tri] = [int(k)]
    else:
        truncated_inchi_to_id[tri] += [int(k)]

In [8]:
local_tris = set(truncated_inchi_to_id)
# Identify those tris that were already assigned above, i.e. that occur in the
# Alchemy contest test/validation sets (train_tris is still empty at this point).
local_train_tris = local_tris & train_tris
local_val_tris = local_tris & val_tris
local_test_tris = local_tris & test_tris
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(0, 3869, 15572)

In [9]:
# Split the remaining truncated InChIs randomly into train, validation, test:
remaining_tris = local_tris - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = max(int(test_fraction * len(local_tris)) - len(local_test_tris), 0)
num_val = max(int(val_fraction * len(local_tris)) - len(local_val_tris), 0)
num_train = len(trinchis) - num_test - num_val
assert num_train > 0

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(127674, 15740, 31480)

In [10]:
# Achieved (test, validation) share of this dataset's truncated InChIs.
tuple(len(part) / len(local_tris) for part in (local_test_tris, local_val_tris))

(0.17999473967088636, 0.08999736983544318)

In [11]:
# Expand each truncated InChI back into the molecule indices it groups.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])
len(train_idx), len(val_idx), len(test_idx)

(149644, 17906, 35029)

In [12]:
# NOTE(review): absolute, user-specific output path -- adjust when running elsewhere.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/alchemy_split_v2.npz'
# Explicit integer dtype: an empty Python list would otherwise be stored as float64,
# which cannot be used to index a dataset.
np.savez(
    split_file,
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)

In [13]:
# For the next set, take the split of this dataset into account:
train_tris.update(local_train_tris)
val_tris.update(local_val_tris)
test_tris.update(local_test_tris)

### QM9

In [14]:
# Collect inchis where the truncated InChIs differ (QM9)
truncated_inchi_to_id = {}
for k, v in inchi_dicts['qm9'].items():
    tri = inchi_up_to_layer(v, L)
    if tri not in truncated_inchi_to_id:
        truncated_inchi_to_id[tri] = [int(k)]
    else:
        truncated_inchi_to_id[tri] += [int(k)]

In [15]:
local_tris = set(truncated_inchi_to_id)
# Truncated InChIs already assigned while processing an earlier dataset keep their split.
local_train_tris, local_val_tris, local_test_tris = (
    local_tris & pool for pool in (train_tris, val_tris, test_tris))
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(15211, 1409, 1955)

In [16]:
# Randomly distribute the truncated InChIs that are not assigned yet. Sorting before
# permuting makes the result independent of set iteration order.
remaining_tris = set(truncated_inchi_to_id).difference(
    local_train_tris, local_val_tris, local_test_tris)
idx = random_state.permutation(len(remaining_tris))
trinchis = sorted(remaining_tris)
trinchis = [trinchis[pos] for pos in idx]

# Top up validation/test to their target share (never negative); the rest is training.
num_test = max(0, int(test_fraction * len(local_tris)) - len(local_test_tris))
num_val = max(0, int(val_fraction * len(local_tris)) - len(local_val_tris))
num_train = len(trinchis) - (num_test + num_val)
assert num_train > 0

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train:num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val:])
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(83737, 10323, 20647)

In [17]:
# Achieved (test, validation) share of this dataset's truncated InChIs.
tuple(len(part) / len(local_tris) for part in (local_test_tris, local_val_tris))

(0.17999773335541858, 0.08999450774582196)

In [18]:
# Expand each truncated InChI back into the molecule indices it groups.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])
len(train_idx), len(val_idx), len(test_idx)

(98192, 12031, 23656)

In [19]:
# NOTE(review): absolute, user-specific output path -- adjust when running elsewhere.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/qm9_split_v2.npz'
# Explicit integer dtype: an empty Python list would otherwise be stored as float64,
# which cannot be used to index a dataset.
np.savez(
    split_file,
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)

In [20]:
# For the next set, take the split of this dataset into account:
train_tris.update(local_train_tris)
val_tris.update(local_val_tris)
test_tris.update(local_test_tris)

### OE62

In [21]:
# Collect InChIs where the truncated InChIs differ (OE62)
truncated_inchi_to_id = {}
for k, v in inchi_dicts['oe62'].items():
    tri = inchi_up_to_layer(v, L)
    if tri not in truncated_inchi_to_id:
        truncated_inchi_to_id[tri] = [int(k)]
    else:
        truncated_inchi_to_id[tri] += [int(k)]

In [22]:
local_tris = set(truncated_inchi_to_id)
# Truncated InChIs already assigned while processing an earlier dataset keep their split.
local_train_tris, local_val_tris, local_test_tris = (
    local_tris & pool for pool in (train_tris, val_tris, test_tris))
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(303, 35, 65)

In [23]:
# Split the remaining truncated InChIs randomly into train, validation, test:
remaining_tris = set(truncated_inchi_to_id.keys()) - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = max(int(test_fraction * len(local_tris)) - len(local_test_tris), 0)
num_val = max(int(val_fraction * len(local_tris)) - len(local_val_tris), 0)
num_train = len(trinchis) - num_test - num_val
assert num_train > 0

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(44822, 5525, 11051)

In [24]:
# Achieved (test, validation) share of this dataset's truncated InChIs.
tuple(len(part) / len(local_tris) for part in (local_test_tris, local_val_tris))

(0.17998957620769407, 0.08998664451610802)

In [25]:
# Expand each truncated InChI back into the molecule indices it groups.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])
len(train_idx), len(val_idx), len(test_idx)

(44895, 5531, 11063)

In [26]:
# NOTE(review): absolute, user-specific output path -- adjust when running elsewhere.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/oe62/split_v2.npz'
# Explicit integer dtype: an empty Python list would otherwise be stored as float64,
# which cannot be used to index a dataset.
np.savez(
    split_file,
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)

In [27]:
# For the next set, take the split of this dataset into account:
train_tris.update(local_train_tris)
val_tris.update(local_val_tris)
test_tris.update(local_test_tris)

### HOPV

In [28]:
# Collect InChIs where the truncated InChIs differ (HOPV)
truncated_inchi_to_id = {}
for k, v in inchi_dicts['hopv'].items():
    tri = inchi_up_to_layer(v, L)
    if tri not in truncated_inchi_to_id:
        truncated_inchi_to_id[tri] = [int(k)]
    else:
        truncated_inchi_to_id[tri] += [int(k)]

In [29]:
local_tris = set(truncated_inchi_to_id)
# Truncated InChIs already assigned while processing an earlier dataset keep their split.
local_train_tris, local_val_tris, local_test_tris = (
    local_tris & pool for pool in (train_tris, val_tris, test_tris))
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(1, 1, 1)

In [30]:
# Split the remaining truncated InChIs randomly into train, validation, test:
remaining_tris = set(truncated_inchi_to_id.keys()) - local_train_tris - local_val_tris - local_test_tris
idx = random_state.permutation(len(remaining_tris))
trinchis = sorted(remaining_tris)
trinchis = [trinchis[i] for i in idx]

num_test = max(int(test_fraction * len(local_tris)) - len(local_test_tris), 0)
num_val = max(int(val_fraction * len(local_tris)) - len(local_val_tris), 0)
num_train = len(trinchis) - num_test - num_val
assert num_train > 0

local_train_tris.update(trinchis[:num_train])
local_val_tris.update(trinchis[num_train : num_train + num_val])
local_test_tris.update(trinchis[num_train + num_val :])
len(local_train_tris), len(local_val_tris), len(local_test_tris)

(254, 31, 62)

In [31]:
# Achieved (test, validation) share of this dataset's truncated InChIs.
tuple(len(part) / len(local_tris) for part in (local_test_tris, local_val_tris))

(0.1786743515850144, 0.0893371757925072)

In [32]:
# Expand each truncated InChI back into the molecule indices it groups.
train_idx = sorted(mol_id for tri in local_train_tris for mol_id in truncated_inchi_to_id[tri])
val_idx = sorted(mol_id for tri in local_val_tris for mol_id in truncated_inchi_to_id[tri])
test_idx = sorted(mol_id for tri in local_test_tris for mol_id in truncated_inchi_to_id[tri])
len(train_idx), len(val_idx), len(test_idx)

(3569, 412, 874)

In [33]:
# NOTE(review): absolute, user-specific output path -- adjust when running elsewhere.
split_file = '/home/cgaul/MaLTOSe2020/schnetpack_exps/data/hopv/split_v2.npz'
# Explicit integer dtype: an empty Python list would otherwise be stored as float64,
# which cannot be used to index a dataset.
np.savez(
    split_file,
    train_idx=np.asarray(train_idx, dtype=np.int64),
    val_idx=np.asarray(val_idx, dtype=np.int64),
    test_idx=np.asarray(test_idx, dtype=np.int64),
)

In [34]:
# For the next set, take the split of this dataset into account:
train_tris.update(local_train_tris)
val_tris.update(local_val_tris)
test_tris.update(local_test_tris)

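## Plausibility checks

The train, validation and test sets of truncated InChIs, accumulated over all datasets, must be pairwise disjoint.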

In [35]:
# Train and test must not share a truncated InChI; assert it so Restart & Run All
# fails loudly instead of relying on someone reading the displayed set.
assert train_tris.isdisjoint(test_tris), sorted(train_tris & test_tris)[:10]
train_tris & test_tris

set()

In [36]:
# Train and validation must not share a truncated InChI; assert it so Restart & Run All
# fails loudly instead of relying on someone reading the displayed set.
assert train_tris.isdisjoint(val_tris), sorted(train_tris & val_tris)[:10]
train_tris & val_tris

set()

In [37]:
# Validation and test must not share a truncated InChI; assert it so Restart & Run All
# fails loudly instead of relying on someone reading the displayed set.
assert val_tris.isdisjoint(test_tris), sorted(val_tris & test_tris)[:10]
val_tris & test_tris

set()

In [38]:
# Number of distinct truncated InChIs assigned to training across all datasets processed above.
len(train_tris)

240972

In [39]:
# Number of distinct truncated InChIs assigned to test across all datasets processed above.
len(test_tris)

61219

In [40]:
# Number of distinct truncated InChIs assigned to validation across all datasets processed above.
len(val_tris)

30174