## Canonical 0D data split

We want a set of splits following this recipe:

- First, set aside the last 96 records as an external validation set (this is the last plate, ice-12-103).
- From all remaining data (A), take away a random test set of 10% of A.
- From the remaining data (B), take away a random validation set of 11% of B (≈10% of A); the rest is the training set.

Note:
- Strictly speaking, this is not a 0D split, as we do not ensure that every reactant appearing in the test set has also been seen in training. For simplicity we do a random split, which should usually be very close to a 0D split.
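- We use 10 independent random shuffle splits (`ShuffleSplit`), not k-fold cross-validation, so the test sets of different folds may overlap.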

In [None]:
import pathlib
import sys
import os
sys.path.append(os.path.abspath("../"))

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, ShuffleSplit

from src.data import SLAPData

In [None]:
# Load the LCMS data set (Data S4) and split each reaction SMILES into its parts.
data_path = os.path.abspath(os.path.join("..", "data", "Data S4.csv"))
data = SLAPData(data_path)
data.load_data_from_file()
data.split_reaction_smiles()

In [None]:
# Inspect the `groups` attribute of the loaded SLAPData object (defined in src.data).
print(data.groups)

In [None]:
len(data.all_X)

In [None]:
# 10 independent random splits of A into B (remaining 90%) and a 10% test set.
# Seeded, so the test sets are reproducible across runs.
splitter = ShuffleSplit(
    test_size=0.1,
    n_splits=10,
    random_state=42,
)

In [None]:
# We use only the first 763 records (data subset A), as the external validation
# plate starts after that (can be checked in generate_ml_datasets.ipynb).
# Note that this cutoff is only applicable to the LCMS data set, not to the
# isolated yields, which have fewer entries.
N_RECORDS_BEFORE_VALIDATION_PLATE = 763

# Seed for the per-fold train/validation split. Fold i uses VAL_SPLIT_SEED + i,
# so every fold gets its own, but reproducible, validation set.
# NOTE: the validation splits were previously drawn from an unseeded RNG, so
# re-running this cell will not reproduce any split files saved before this
# change - load those CSVs from disk instead of regenerating them.
VAL_SPLIT_SEED = 42

# Set to True to write the fold indices to disk. Existing fold files in
# save_path are overwritten.
save = False
save_path = pathlib.Path("../data/dataset_splits/LCMS_split_763records_0Dsplit_10fold/")
if save:
    save_path.mkdir(parents=True, exist_ok=True)

train_counter, val_counter, test_counter = 0, 0, 0
train_pos_class, val_pos_class, test_pos_class = 0, 0, 0

for i, (data_subset_B, test_0D) in enumerate(
    splitter.split(data.all_X[:N_RECORDS_BEFORE_VALIDATION_PLATE])
):
    # Take a (0D) validation set of 11% of B (~10% of A); the rest is training.
    train, val = train_test_split(
        data_subset_B, test_size=0.11, random_state=VAL_SPLIT_SEED + i
    )

    # update counters for the summary over all folds
    train_counter += len(train)
    val_counter += len(val)
    test_counter += len(test_0D)
    train_pos_class += np.sum(data.all_y[train])
    val_pos_class += np.sum(data.all_y[val])
    test_pos_class += np.sum(data.all_y[test_0D])

    print(f"Statistics for fold {i}:")
    print("ID \t\t num \t|\t %positive")
    print(f"Train: \t\t {len(train)} \t|\t {np.mean(data.all_y[train]):.0%}")
    print(f"Val: \t\t {len(val)} \t|\t {np.mean(data.all_y[val]):.0%}")
    print(f"Test_0D: \t {len(test_0D)} \t|\t {np.mean(data.all_y[test_0D]):.0%}")
    print()

    # save the indices (one CSV per subset and fold, no header)
    if save:
        pd.DataFrame(train).to_csv(save_path / f"fold{i}_train.csv", index=False, header=None)
        pd.DataFrame(val).to_csv(save_path / f"fold{i}_val.csv", index=False, header=None)
        pd.DataFrame(test_0D).to_csv(save_path / f"fold{i}_test_0D.csv", index=False, header=None)

# summary statistics, aggregated over all folds
n = train_counter + val_counter + test_counter
print("\nSummary statistics:")
print(f"Split sizes: {train_counter/n:.0%} train, {val_counter/n:.0%} val, {test_counter/n:.0%} test")
print(f"Class balance (positive class ratio): {train_pos_class/train_counter:.0%} train, {val_pos_class/val_counter:.0%} val, {test_pos_class/test_counter:.0%} test")
