## Data splitting

In [1]:
from chemprop.data import SplitType, make_split_indices, split_data_by_indices

These are example [datapoints](./datapoints.ipynb) to split.

In [2]:
import numpy as np
from chemprop.data import MoleculeDatapoint

smis = ["C" * i for i in range(1, 11)]
ys = np.random.rand(len(smis), 1)
datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]

### Data splits

A typical Chemprop workflow uses three sets of data. The first is used to train the model. The second is used as validation for early stopping and hyperparameter optimization. The third is used to test the final model's performance as an estimate for how it will perform on future data. 

Chemprop provides helper functions to split data into these training, validation, and test sets. Available splitting schemes are listed in `SplitType`. All of these rely on `astartes` in the backend except for cross validation. 

In [3]:
for splittype in SplitType:
    print(splittype)

cv_no_val
cv
scaffold_balanced
random_with_repeated_smiles
random
kennard_stone
kmeans


### Splitting steps

1. Collect the `rdkit.Chem.Mol` objects for each datapoint. These are required for structure-based splits.
2. Generate the splitting indices.
3. Split the data using those indices.

In [4]:
mols = [d.mol for d in datapoints]

train_indices, val_indices, test_indices = make_split_indices(mols)

train_data, val_data, test_data = split_data_by_indices(
    datapoints, train_indices, val_indices, test_indices
)

The default splitting scheme is a random split with 80% of the data used to train, 10% to validate, and 10% to test.

In [5]:
len(train_data), len(val_data), len(test_data)

(8, 1, 1)
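To make the index-then-select pattern concrete, here is a minimal sketch of an 80/10/10 random split in plain NumPy. It is not Chemprop's or `astartes`'s implementation; the `items` list and the cut-point arithmetic are assumptions made for illustration.

```python
import numpy as np

# Toy stand-ins for datapoints; any sequence would do.
items = [f"mol_{i}" for i in range(10)]

# Shuffle the positions with a fixed seed so the split is reproducible.
rng = np.random.default_rng(0)
order = rng.permutation(len(items))

# Cut the shuffled order after 80% and after a further 10% of the items.
n_train = round(0.8 * len(items))
n_val = round(0.1 * len(items))
train_idx = order[:n_train].tolist()
val_idx = order[n_train:n_train + n_val].tolist()
test_idx = order[n_train + n_val:].tolist()

# Select items by index, analogous to what split_data_by_indices does for datapoints.
train, val, test = ([items[i] for i in idx] for idx in (train_idx, val_idx, test_idx))
print(len(train), len(val), len(test))
```

Keeping the indices separate from the data means the same split can be reapplied to any parallel list, such as features or targets stored elsewhere.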

### Split randomness

All split randomness uses a default seed of 0 and `numpy.random`. The seed can be changed to get different splits.

In [6]:
make_split_indices(datapoints)

([8, 4, 9, 1, 6, 7, 3, 0], [5], [2])

In [7]:
make_split_indices(datapoints, seed=12)

([8, 7, 0, 4, 9, 3, 2, 1], [6], [5])
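The effect of the seed can be illustrated without Chemprop. The sketch below uses a hypothetical helper, `shuffled_indices`, built on `numpy.random.default_rng`. The exact generator used inside Chemprop and `astartes` is not shown in this notebook, so the orderings here will not match the outputs above.

```python
import numpy as np

def shuffled_indices(n, seed=0):
    """Return a seeded permutation of range(n) (hypothetical helper)."""
    return np.random.default_rng(seed).permutation(n).tolist()

same_a = shuffled_indices(10, seed=0)
same_b = shuffled_indices(10, seed=0)
other = shuffled_indices(10, seed=12)

# The same seed always reproduces the same ordering.
print(same_a == same_b)
# A different seed still yields a permutation of the same indices.
print(sorted(other) == list(range(10)))
```

Recording the seed alongside your results is usually enough to let someone else regenerate the same split.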

### Split fractions

The split sizes can also be changed. Set the middle value to 0 for a two-way split. If the data cannot be split into exactly the specified proportions, `astartes` issues a warning that reports the actual sizes used. If the specified sizes don't sum to 1, they are first rescaled to sum to 1.

In [8]:
make_split_indices(datapoints, sizes=(0.4, 0.3, 0.3))

([8, 4, 9, 1], [6, 7, 3], [0, 5, 2])

In [9]:
make_split_indices(datapoints, sizes=(0.6, 0.0, 0.4))

([8, 4, 9, 1, 6, 7], [], [3, 0, 5, 2])

In [10]:
make_split_indices(datapoints, sizes=(0.5, 0.25, 0.25))

([8, 4, 9, 1, 6], [7, 3], [0, 5, 2])

In [11]:
make_split_indices(datapoints, sizes=(0.5, 0.5, 0.5))

([8, 4, 9], [1, 6, 7, 3, 0], [5, 2])
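The rescaling step described in this section amounts to dividing each size by the total. A minimal sketch follows, using a hypothetical helper named `normalize_sizes`. How the rescaled fractions are then rounded to whole datapoints is left to `astartes` and is not modeled here.

```python
def normalize_sizes(sizes):
    """Rescale split fractions so they sum to 1 (hypothetical helper)."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("split sizes must sum to a positive number")
    return tuple(s / total for s in sizes)

# Sizes that sum to 1.5 are scaled down proportionally.
print(normalize_sizes((0.5, 0.5, 0.5)))
# A zero in the middle keeps the validation share at zero.
print(normalize_sizes((0.6, 0.0, 0.4)))
```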

### Random with repeated molecules

If your dataset has repeated molecules, all copies of a molecule should go in the same split. This split type requires the `rdkit.Chem.Mol` objects of the datapoints. It first removes duplicates, then uses `astartes` to make a random split of the unique molecules, and finally adds the duplicate datapoints back into the split that holds their molecule.

In [12]:
smis = ["O", "O"] + ["C" * i for i in range(1, 10)]
ys = np.random.rand(len(smis), 1)
repeat_datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in repeat_datapoints]

In [13]:
make_split_indices(mols, split="random_with_repeated_smiles")

([10, 6, 0, 1, 3, 8, 9, 5, 2], [7], [4])
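The deduplicate, split, and expand procedure can be sketched in a few lines. This is an illustration only: it treats identical SMILES strings as duplicates and uses simple cut points with made-up 60/20/20 fractions, which may differ from how Chemprop identifies duplicates and how `astartes` sizes the splits.

```python
import numpy as np

smiles = ["O", "O", "C", "CC", "CCC", "CC", "CCCC"]

# 1. Group row positions by SMILES string so duplicates travel together.
groups = {}
for row, smi in enumerate(smiles):
    groups.setdefault(smi, []).append(row)
unique = list(groups)

# 2. Randomly split the unique molecules (60/20/20 by simple cut points).
order = np.random.default_rng(0).permutation(len(unique))
n_train = round(0.6 * len(unique))
n_val = round(0.2 * len(unique))
parts = (order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:])

# 3. Expand each unique molecule back into all of its original rows.
train_idx, val_idx, test_idx = (
    [row for u in part for row in groups[unique[u]]] for part in parts
)
print(train_idx, val_idx, test_idx)
```

Because the split is made over unique molecules, the final set sizes can deviate from the requested fractions when some molecules have many copies.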

### Cross Validation

Smaller datasets can benefit from cross validation. It allows for better evaluation of a model architecture by training multiple models and ensuring that each datapoint appears in the validation set of one model and the test set of another. The number of folds determines how many sets of splits to make (i.e. how many models to train). For each fold, roughly 1/(number of folds) of the data is put in each of the validation and test sets, while the rest is used for training.

In [14]:
train_indices, val_indices, test_indices = make_split_indices(datapoints, split="cv", num_folds=4)

Cross validation returns a tuple of three lists (train, validation, test), each holding one array of indices per fold. The n-th array in each list gives the split indices for the n-th fold. `split_data_by_indices` handles these nested lists automatically and returns a tuple of lists of lists of datapoints. You can zip the results together to get what is needed for a single fold.

In [15]:
train_indices, val_indices, test_indices

([array([1, 2, 3, 4]),
  array([0, 3, 4, 7, 8]),
  array([0, 5, 6, 7, 8, 9]),
  array([1, 2, 5, 6, 9])],
 [array([5, 6, 9]), array([1, 2]), array([3, 4]), array([0, 7, 8])],
 [array([0, 7, 8]), array([5, 6, 9]), array([1, 2]), array([3, 4])])

In [16]:
train_data, val_data, test_data = split_data_by_indices(
    datapoints, train_indices, val_indices, test_indices
)

# Use distinct loop names so the per-fold lists are not overwritten while iterating.
for fold_idx, (fold_train, fold_val, fold_test) in enumerate(zip(train_data, val_data, test_data)):
    print(f"Fold {fold_idx}")
    print(len(fold_train), len(fold_val), len(fold_test))

Fold 0
4 3 3
Fold 1
5 2 3
Fold 2
6 2 2
Fold 3
5 3 2
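One way to produce folds with this shape is to cut a shuffled index list into chunks and rotate which chunk serves as the test set and which as the validation set. The sketch below is an assumption about the general technique, not Chemprop's actual code, and its shuffle will not reproduce the indices shown above.

```python
import numpy as np

n_items, num_folds = 10, 4

# Shuffle once, then cut the shuffled positions into num_folds chunks.
order = np.random.default_rng(0).permutation(n_items)
chunks = np.array_split(order, num_folds)

train_folds, val_folds, test_folds = [], [], []
for i in range(num_folds):
    # Chunk i is the test set and the next chunk (wrapping around) is validation.
    held_out = {i, (i + 1) % num_folds}
    test = chunks[i]
    val = chunks[(i + 1) % num_folds]
    # Everything not held out for this fold is used for training.
    train = np.concatenate([c for j, c in enumerate(chunks) if j not in held_out])
    train_folds.append(np.sort(train))
    val_folds.append(np.sort(val))
    test_folds.append(np.sort(test))

for fold_idx, (tr, va, te) in enumerate(zip(train_folds, val_folds, test_folds)):
    print(f"Fold {fold_idx}", len(tr), len(va), len(te))
```

With 10 items and 4 folds, the chunks have sizes 3, 3, 2, and 2, so the per-fold set sizes follow the same pattern as the output above. Every index lands in exactly one test fold and exactly one validation fold.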


### Structure based splits

Keeping all similar molecules together in a single set can give a more realistic estimate of how a model will perform on unseen chemistry. These splits use the `rdkit.Chem.Mol` representation of the molecules. See the `astartes` [documentation](https://jacksonburns.github.io/astartes/) for details about the Kennard-Stone, k-means, and scaffold-balanced splitting schemes.

In [17]:
smis = [
    "Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14",
    "COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23",
    "COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl",
    "OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3",
    "Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1",
    "OC1(CN2CCC1CC2)C#Cc3ccc(cc3)c4ccccc4",
    "COc1cc(OC)c(cc1NC(=O)CCC(=O)O)S(=O)(=O)NCc2ccccc2N3CCCCC3",
    "CNc1cccc(CCOc2ccc(C[C@H](NC(=O)c3c(Cl)cccc3Cl)C(=O)O)cc2C)n1",
    "COc1ccc(cc1)C2=COc3cc(OC)cc(OC)c3C2=O",
    "Oc1ncnc2scc(c3ccsc3)c12",
    "CS(=O)(=O)c1ccc(Oc2ccc(cc2)C#C[C@]3(O)CN4CCC3CC4)cc1",
    "C[C@H](Nc1nc(Nc2cc(C)[nH]n2)c(C)nc1C#N)c3ccc(F)cn3",
    "O=C1CCCCCN1",
    "CCCSc1ncccc1C(=O)N2CCCC2c3ccncc3",
    "CC1CCCCC1NC(=O)c2cnn(c2NS(=O)(=O)c3ccc(C)cc3)c4ccccc4",
    "Nc1ccc(cc1)c2nc3ccc(O)cc3s2",
    "COc1ccc(cc1)N2CCN(CC2)C(=O)[C@@H]3CCCC[C@H]3C(=O)NCC#N",
    "CCC(COC(=O)c1cc(OC)c(OC)c(OC)c1)(N(C)C)c2ccccc2",
    "COc1cc(ccc1N2CC[C@@H](O)C2)N3N=Nc4cc(sc4C3=O)c5ccc(Cl)cc5",
    "CO[C@H]1CN(CCN2C(=O)C=Cc3ccc(cc23)C#N)CC[C@H]1NCc4ccc5OCC(=O)Nc5n4",
]

ys = np.random.rand(len(smis), 1)
datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in datapoints]

In [18]:
make_split_indices(mols, split="kmeans")

([0, 1, 2, 3, 4, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19], [5, 10], [7])
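For intuition about how a structure-based scheme orders data, here is a minimal sketch of the classic Kennard-Stone selection on toy numeric features. It is an illustrative helper, not the `astartes` implementation; the one-dimensional features stand in for whatever molecular representation `astartes` computes, and the 4/2 cut is arbitrary.

```python
import numpy as np

def kennard_stone_order(X):
    """Order the rows of X by Kennard-Stone selection (illustrative helper)."""
    X = np.asarray(X, dtype=float)
    # Pairwise Euclidean distances between all rows.
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    # Start from the two most distant points.
    first, second = np.unravel_index(np.argmax(dist), dist.shape)
    selected = [int(first), int(second)]
    remaining = [i for i in range(len(X)) if i not in selected]
    while remaining:
        # For each remaining point, find the distance to its nearest selected point.
        nearest = dist[np.ix_(remaining, selected)].min(axis=1)
        # Add the point that is farthest from everything chosen so far.
        pick = remaining[int(np.argmax(nearest))]
        selected.append(pick)
        remaining.remove(pick)
    return selected

X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [5.0]])
order = kennard_stone_order(X)

# Early picks span the feature space and go to training; late picks go to testing.
train_idx, test_idx = order[:4], order[4:]
print(train_idx, test_idx)
```

Because the earliest selections cover the extremes of the feature space, the training set spans the data while the held-out points fall inside that range. This differs from scaffold or k-means splits, which hold out whole groups.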