# Data Splits Truncated

Here, we split the data in the same way as in `data_splits.ipynb`, but then restrict the number of points in the training set.
There are two ways we could do this:
1. Take the original splits and drop random training examples. Problem: as the training set shrinks, we will inadvertently remove all examples of some compound from the training data while that compound remains in validation/test.
2. Follow the same splitting procedure as before, but with different train/val/test ratios.

Option 2 is the safer choice because it retains the split dimensionality.
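
To illustrate the problem with option 1, here is a minimal sketch on toy data (the compound names and sizes are made up and unrelated to the real data set). It counts how many test compounds no longer occur in the training data once random training examples are dropped:

```python
import random

rng = random.Random(0)

# Toy data: 60 hypothetical compounds with 20 records each.
records = [f"compound_{i:02d}" for i in range(60) for _ in range(20)]
rng.shuffle(records)

# Random (0D) split: 80% train, 20% test.
n_train = int(0.8 * len(records))
train, test = records[:n_train], records[n_train:]


def unseen_in_train(train_records, test_records):
    """Count test compounds that never occur in the training records."""
    return len(set(test_records) - set(train_records))


print("full training set:", unseen_in_train(train, test))

# Option 1: keep the split and randomly drop training examples.
for keep in (480, 120, 30):
    subset = rng.sample(train, keep)
    print(f"{keep} training records:", unseen_in_train(subset, test))
```

With only 30 training records, at most 30 compounds can be represented in train, so part of the test set silently turns into an "unseen compound" problem.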

### 0D Split
For the 0D split, we use a random train-test split.
Standard: We use an 80/10/10 split into train, val, and test sets.

Truncated: 
- 40/30/30 -> 16k training samples
- 20/40/40 -> 8k training samples
- 10/45/45 -> 4k training samples
- 5/50/45 -> 2k training samples
- 2.5/50/47.5 -> 1k training samples
- 1.25/50/48.75 -> 500 training samples
- 0.625/50/49.375 -> 250 training samples

(numbers of training samples are averaged over folds and approximate)
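
Each ratio is realized as two nested shuffle splits: an outer split that sets aside the test set, and an inner split that divides the remainder into train and validation. A minimal sketch of the conversion (the helper name `nested_test_sizes` is made up for illustration):

```python
def nested_test_sizes(train, val, test):
    """Convert a train/val/test ratio into the two `test_size` values of a nested split.

    The outer split removes the test set from all data; the inner split then
    removes the validation set from the remaining train+val portion.
    """
    total = train + val + test
    outer_test_size = test / total
    inner_test_size = val / (train + val)
    return outer_test_size, inner_test_size


for ratio in [(80, 10, 10), (40, 30, 30), (5, 50, 45), (0.625, 50, 49.375)]:
    outer, inner = nested_test_sizes(*ratio)
    print(ratio, f"outer={outer:.5f}", f"inner={inner:.5f}")
```

For 40/30/30 this gives an outer `test_size` of 0.3 and an inner `test_size` of 0.3/0.7, matching the values passed to the splitters in the 0D cells.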

### 1D Split
For the 1D split, we use a (1D) GroupShuffleSplit.
Standard: As groups, we use either initiator, monomer, or terminator (3x3 splits).

Truncated:
(for uniform results, we split on only one dimension: initiators)
- 80/10/10 -> 31,250 training samples (53 I) (this is not equivalent to the "standard" 1D split because here all 9 folds split on initiators)
- 40/30/30 -> 15,600 training samples (26 I)
- 20/40/40 -> 8,100 training samples (13 I)
- 10/45/45 -> 3,680 training samples (6 I)
- 5/50/45 -> 1,668 training samples (3 I)
- 2.5/50/47.5 -> 487 training samples (1 I)
- 1.25/50/48.75 -> not possible because the training set would be empty on some folds

(numbers of training samples are averaged over folds)
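
The initiator counts in the list above can be reproduced with a little arithmetic. The sketch below assumes that a fractional `test_size` is rounded up to a whole number of groups (our understanding of scikit-learn's shuffle splitters) and uses the total of 67 initiators stated in this notebook; the helper name `group_counts` is made up:

```python
from math import ceil


def group_counts(n_groups, train, val, test):
    """Return the number of groups in (train, val, test) for a nested group split."""
    total = train + val + test
    n_test = ceil(test / total * n_groups)           # outer split
    n_train_val = n_groups - n_test
    n_val = ceil(val / (train + val) * n_train_val)  # inner split
    return n_train_val - n_val, n_val, n_test


N_INITIATORS = 67  # total number of initiators stated in this notebook

for ratio in [(80, 10, 10), (40, 30, 30), (20, 40, 40), (10, 45, 45),
              (5, 50, 45), (2.5, 50, 47.5), (1.25, 50, 48.75)]:
    print(ratio, group_counts(N_INITIATORS, *ratio))
```

Under these assumptions, 2.5/50/47.5 leaves a single initiator for training, and 1.25/50/48.75 leaves none, which is consistent with that ratio being infeasible.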

### 2D Split
For the 2D split, we use a (2D) GroupShuffleSplit.
Standard: As groups, we use either [initiator, monomer], [monomer, terminator] or [initiator, terminator]. (3x3)

Truncated:
(for uniform results, we pick only one set of two dimensions to split on: Initiators & Monomers)
- 80/10/10 -> 24,562 training samples (53 I, 56 M)
- 60/20/20 -> 13,802 training samples (39 I, 41.7 M)
- 40/30/30 -> 6,046 training samples (26 I, 27.9 M)
- 30/35/35 -> 3,193 training samples (19 I, 20.6 M)
- 20/40/40 -> 1,475 training samples (13 I, 13.7 M)
- 15/45/40 -> 812 training samples (10 I, 10 M)
- 10/45/45 -> 355 training samples (6 I, 6.7 M)
- 7.5/47.5/45 -> 163 training samples (4 I, 4.9 M)
- 5/50/45 -> 78 training samples (3 I, 2.9 M)
- 2/(all-2)/48 -> 38 training samples (2 I, 1.8 M) (here we define the train/val split such that 2 groups are in train)

(numbers of training samples are averaged over folds. In total there are 67 I & 72 M)
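
To make the idea of a 2D group split concrete, here is a toy sketch in plain NumPy. It is one plausible reading of the procedure, not the implementation of `GroupShuffleSplitND` (which may assign records differently); all names and sizes are hypothetical. A record is assigned to a set only if both its initiator and its monomer belong to that set's groups, and records with mixed membership are left out:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: every combination of 10 hypothetical initiators and 10 hypothetical monomers.
initiators, monomers = (a.ravel() for a in np.meshgrid(np.arange(10), np.arange(10)))


def assign_groups(n_groups, n_train, n_val):
    """Shuffle group ids and cut them into train / val / test groups."""
    perm = rng.permutation(n_groups)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


i_train, i_val, i_test = assign_groups(10, n_train=4, n_val=3)
m_train, m_val, m_test = assign_groups(10, n_train=4, n_val=3)


def both_in(i_groups, m_groups):
    """Mask of records whose initiator AND monomer belong to the given groups."""
    return np.isin(initiators, i_groups) & np.isin(monomers, m_groups)


in_train = both_in(i_train, m_train)
in_val = both_in(i_val, m_val)
in_test = both_in(i_test, m_test)
n_mixed = len(initiators) - in_train.sum() - in_val.sum() - in_test.sum()
print(in_train.sum(), in_val.sum(), in_test.sum(), n_mixed)
```

With 40% of the groups in each dimension, only 16 of the 100 toy records end up in train, which mirrors how quickly the training set shrinks in the list above.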

### 3D Split
For the 3D split, we use a (3D) GroupShuffleSplit.

Truncated:
- 80/10/10 -> 19,725 training samples (53 I, 55.8 M, 32 T) n.b. this has extremely few val/test samples
- 70/15/15 -> 12,975 training samples (46 I, 49.4 M, 28 T)
- 60/20/20 -> 7,948 training samples (39 I, 41.4 M, 24 T) (this should be close to the standard split)
- 50/25/25 -> 4,747 training samples (33 I, 35.5 M, 20 T)
- 40/30/30 -> 2,399 training samples (26 I, 27.8 M, 16 T)
- 34/33/33 -> 1,313 training samples (22 I, 23.3 M, 13 T)
- 30/35/35 -> 901 training samples (19 I, 20.3 M, 12 T)
- 25/40/35 -> 529 training samples (16 I, 16.9 M, 10 T)
- 20/40/40 -> 273 training samples (13 I, 13.3 M, 8 T)
- 15/45/40 -> 112 training samples (10 I, 9.9 M, 6 T)
- 10/45/45 -> 32 training samples (5.8 I, 6 M, 4 T)
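
The steep drop in training samples from 1D to 3D follows from the fact that a record only lands in the training set if all of its split groups do. A back-of-the-envelope sketch, assuming a fully populated and balanced grid of building-block combinations (which the real data only approximates, so the estimates land in the same range as the averages reported above without matching them exactly):

```python
N_RECORDS = 40018  # size of the data set loaded in this notebook


def estimated_train_size(train_fraction, n_dims, n_records=N_RECORDS):
    """Rough estimate: a record is kept only if all `n_dims` of its groups are in train."""
    return n_records * train_fraction ** n_dims


for fraction in (0.8, 0.6, 0.4, 0.2, 0.1):
    estimates = [round(estimated_train_size(fraction, d)) for d in (1, 2, 3)]
    print(f"{fraction:.0%} of groups -> 1D/2D/3D estimate: {estimates}")
```

This is only a heuristic for choosing ratios; the actual sizes depend on which groups are drawn and on how many records each group contains.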

## Diastereomers
For the 0D split, this is not important, but for the 1D/2D/3D splits, we want to define the groups such that diastereomers are always in the same group. These splits receive a `_dia` suffix.

## Synthetic data
For the 1D/2D/3D problems, we use a synthetically amended data set. Splits of the synthetically amended data set receive a `_synthetic` suffix.

In [32]:
import pathlib
import sys

sys.path.append(str(pathlib.Path().resolve().parents[1]))

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit, ShuffleSplit

from src.definitions import DATA_DIR
from src.util.train_test_split import GroupShuffleSplitND
from util import write_indices_and_stats

In [21]:
# Load data
data_filename = "synferm_dataset_2023-09-05_40018records.csv"
data_name = data_filename.rsplit("_", maxsplit=1)[0]
df = pd.read_csv(DATA_DIR / "curated_data" / data_filename)
df.shape

(40018, 27)

In [22]:
df.head()

Unnamed: 0,I_long,M_long,T_long,product_A_smiles,I_smiles,M_smiles,T_smiles,reaction_smiles,reaction_smiles_atom_mapped,experiment_id,...,binary_H,scaled_A,scaled_B,scaled_C,scaled_D,scaled_E,scaled_F,scaled_G,scaled_H,major_A-C
0,2-Pyr003,Fused002,TerABT004,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1ccc(F)cc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][cH:20][c...,56113,...,0,0.036021,0.003427,0.0,0.020975,0.002958,0.941981,0.914281,0.0,A
1,2-Pyr003,Fused002,TerABT007,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1cc(Br)ccc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][cH:20][c...,56114,...,0,0.0,0.0,0.0,0.006159,0.364398,0.928851,1.106548,0.0,no_product
2,2-Pyr003,Fused002,TerABT013,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1cc(C(F)(F)F)ccc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][cH:20][c...,56106,...,1,0.0,0.0,0.0,0.014212,2.16642,1.013596,0.537785,0.05686,no_product
3,2-Pyr003,Fused002,TerABT014,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1ccc(Cl)cc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:2]([c:1]1[cH:16][cH:18][cH:20][c...,56112,...,0,0.028915,0.005039,0.0,0.015578,0.504057,0.992614,0.890646,0.0,A
4,2-Pyr003,Fused002,TerTH001,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,[Cl-].[NH3+]NC(=S)c1ccccc1,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:2]([c:1]1[cH:13][cH:15][cH:17][c...,56109,...,0,0.350061,0.643219,0.0,0.031689,0.613596,0.109309,0.439018,0.0,B


## 0D split

In [4]:
splitter = ShuffleSplit(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006), (16006, 12006, 12006)]
[(array([13035,  9180,  4620]), array([9737, 6889, 3393]), array([9760, 6767, 3300])), (array([13026,  9250,  4570]), array([9784, 6854, 3421]), array([9722, 6732, 3322])), (array([13015,  9059,  4449]), array([9752, 6859, 3483]), array([9765, 6918, 3381])), (array([13069,  9112,  4517]), array([9725, 6877, 3388]), array([9738, 6847, 3408])), (array([12983,  9134,  4597]), array([9805, 6787, 3354]), array([9744, 6915, 3362])), (array([12960,  9081,  4556]), array([9772, 6818, 3359]), array([9800, 6937, 3398])), (array([13002,  9104,  4548]), array([9755, 6881, 3348]), array([9775, 6851, 3417])), (array([13000,  9137,  4558]), array([9771, 6897, 3404]), array([9761, 6802, 3351])), (array([12943,  9132,  4520]), array([9777, 6811, 3373]), array([9812, 6893, 3420]))]


In [5]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=40, total_size=len(df), data_name=data_name
)
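
The 0D cells in this section differ only in the ratio, so the loop could be wrapped in a helper. A minimal sketch, assuming only positional indices are needed (the function name `split_0d` and its arguments are made up here and are not part of `util`):

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit


def split_0d(n_samples, train, val, test, n_splits=9, seed=42):
    """Return a list of (idx_train, idx_val, idx_test) arrays for a random nested split."""
    outer = ShuffleSplit(n_splits=n_splits, test_size=test / (train + val + test), random_state=seed)
    # RandomState instance (not an int), because the inner splitter is reused for every fold
    inner = ShuffleSplit(n_splits=1, test_size=val / (train + val), random_state=np.random.RandomState(seed))
    folds = []
    for idx_train_val, idx_test in outer.split(np.arange(n_samples)):
        train_pos, val_pos = next(inner.split(idx_train_val))
        # map positions within idx_train_val back to positions in the full data set
        folds.append((idx_train_val[train_pos], idx_train_val[val_pos], idx_test))
    return folds


folds = split_0d(40018, 40, 30, 30)
print([tuple(len(part) for part in fold) for fold in folds])
```

The resulting index tuples have the same structure as `indices` in the cells of this section, so the size and positive-class bookkeeping could be computed from them afterwards.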

In [6]:
splitter = ShuffleSplit(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008), (8003, 16007, 16008)]
[(array([6550, 4649, 2271]), array([12944,  9129,  4586]), array([13038,  9058,  4456])), (array([6466, 4617, 2342]), array([13069,  9160,  4500]), array([12997,  9059,  4471])), (array([6551, 4536, 2207]), array([12973,  9093,  4626]), array([13008,  9207,  4480])), (array([6501, 4539, 2265]), array([13031,  9178,  4555]), array([13000,  9119,  4493])), (array([6498, 4513, 2249]), array([13019,  9155,  4566]), array([13015,  9168,  4498])), (array([6529, 4602, 2281]), array([12949,  9070,  4558]), array([13054,  9164,  4474])), (array([6487, 4562, 2178]), array([12956,  9066,  4527]), array([13089,  9208,  4608])), (array([6511, 4598, 2291]), array([12983,  9092,  4513]), array([13038,  9146,  4509])), (array([6470, 4519, 2216]), array([12997,  9134,  4563]), array([13065,  9183,  4534]))]

In [7]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=20, total_size=len(df), data_name=data_name
)

In [8]:
splitter = ShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009), (4001, 18008, 18009)]
[(array([3258, 2275, 1157]), array([14585, 10319,  5130]), array([14689, 10242,  5026])), (array([3258, 2307, 1155]), array([14636, 10319,  5132]), array([14638, 10210,  5026])), (array([3262, 2308, 1113]), array([14628, 10159,  5130]), array([14642, 10369,  5070])), (array([3285, 2327, 1165]), array([14622, 10281,  5084]), array([14625, 10228,  5064])), (array([3272, 2281, 1151]), array([14611, 10220,  5080]), array([14649, 10335,  5082])), (array([3248, 2253, 1158]), array([14584, 10245,  5086]), array([14700, 10338,  5069])), (array([3208, 2233, 1147]), array([14597, 10239,  4998]), array([14727, 10364,  5168])), (array([3266, 2335, 1138]), array([14622, 10266,  5108]), array([14644, 10235,  5067])), (array([3264, 2306, 1191]), array([14572, 10216,  5030]), array([14696, 10314,  5092]))]

In [9]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=10, total_size=len(df), data_name=data_name
)

In [10]:
splitter = ShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009), (2000, 20009, 18009)]
[(array([1639, 1146,  573]), array([16204, 11448,  5714]), array([14689, 10242,  5026])), (array([1626, 1170,  592]), array([16268, 11456,  5695]), array([14638, 10210,  5026])), (array([1644, 1121,  541]), array([16246, 11346,  5702]), array([14642, 10369,  5070])), (array([1639, 1153,  588]), array([16268, 11455,  5661]), array([14625, 10228,  5064])), (array([1642, 1150,  589]), array([16241, 11351,  5642]), array([14649, 10335,  5082])), (array([1632, 1140,  545]), array([16200, 11358,  5699]), array([14700, 10338,  5069])), (array([1600, 1111,  563]), array([16205, 11361,  5582]), array([14727, 10364,  5168])), (array([1615, 1185,  575]), array([16273, 11416,  5671]), array([14644, 10235,  5067])), (array([1630, 1154,  615]), array([16206, 11368,  5606]), array([14696, 10314,  5092]))]

In [11]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=5, total_size=len(df), data_name=data_name
)

In [12]:
splitter = ShuffleSplit(n_splits=9, test_size=0.475, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.525, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009), (1000, 20009, 19009)]
[(array([830, 593, 306]), array([16213, 11423,  5717]), array([15489, 10820,  5290])), (array([810, 583, 290]), array([16280, 11492,  5709]), array([15442, 10761,  5314])), (array([803, 548, 255]), array([16283, 11374,  5721]), array([15446, 10914,  5337])), (array([819, 569, 282]), array([16298, 11462,  5696]), array([15415, 10805,  5335])), (array([813, 548, 298]), array([16271, 11387,  5640]), array([15448, 10901,  5375])), (array([828, 567, 298]), array([16174, 11336,  5661]), array([15530, 10933,  5354])), (array([825, 563, 298]), array([16155, 11316,  5545]), array([15552, 10957,  5470])), (array([814, 589, 297]), array([16268, 11446,  5681]), array([15450, 10801,  5335])), (array([838, 565, 281]), array([16198, 11403,  5649]), array([15496, 10868,  5383]))]


In [13]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=2.5, total_size=len(df), data_name=data_name
)

In [14]:
splitter = ShuffleSplit(n_splits=9, test_size=0.4875, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.5125, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509), (500, 20009, 19509)]
[(array([425, 289, 142]), array([16221, 11443,  5736]), array([15886, 11104,  5435])), (array([409, 281, 142]), array([16283, 11516,  5719]), array([15840, 11039,  5452])), (array([399, 289, 144]), array([16286, 11344,  5684]), array([15847, 11203,  5485])), (array([421, 292, 149]), array([16315, 11472,  5693]), array([15796, 11072,  5471])), (array([416, 292, 152]), array([16251, 11349,  5644]), array([15865, 11195,  5517])), (array([408, 280, 146]), array([16193, 11335,  5675]), array([15931, 11221,  5492])), (array([394, 288, 118]), array([16192, 11313,  5600]), array([15946, 11235,  5595])), (array([401, 275, 126]), array([16269, 11456,  5699]), array([15862, 11105,  5488])), (array([408, 285, 155]), array([16223, 11392,  5628]), array([15901, 11159,  5530]))]


In [15]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=1.25, total_size=len(df), data_name=data_name
)

In [23]:
splitter = ShuffleSplit(n_splits=9, test_size=0.49375, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.50625, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices = []
sizes = []
pos_class = []
for idx_train_val, idx_test in splitter.split(df):
    # inner split
    train, val = next(inner_splitter.split(idx_train_val))
    # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
    idx_train = idx_train_val[train]
    idx_val = idx_train_val[val]
    # add to list
    indices.append((idx_train, idx_val, idx_test))
    sizes.append((len(idx_train), len(idx_val), len(idx_test)))
    # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
    pos_class.append(
        tuple(
            df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
            for idx in (idx_train, idx_val, idx_test)
        )
    )

print(sizes)
print(pos_class)

[(250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759), (250, 20009, 19759)]
[(array([193, 138,  64]), array([16257, 11462,  5742]), array([16082, 11236,  5507])), (array([215, 152,  67]), array([16274, 11505,  5732]), array([16043, 11179,  5514])), (array([202, 137,  63]), array([16281, 11368,  5706]), array([16049, 11331,  5544])), (array([217, 155,  74]), array([16325, 11479,  5703]), array([15990, 11202,  5536])), (array([212, 146,  73]), array([16250, 11349,  5652]), array([16070, 11341,  5588])), (array([205, 145,  68]), array([16199, 11349,  5685]), array([16128, 11342,  5560])), (array([194, 138,  67]), array([16195, 11320,  5575]), array([16143, 11378,  5671])), (array([209, 142,  61]), array([16260, 11450,  5677]), array([16063, 11244,  5575])), (array([204, 139,  73]), array([16224, 11399,  5634]), array([16104, 11298,  5606]))]


In [24]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=0.625, total_size=len(df), data_name=data_name
)

## 1D split

In [16]:
def split_1d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(list(range(len(df))), groups=df["I_long"]):
        # inner split
        train, val = next(inner_splitter.split(idx_train_val, groups=df["I_long"][idx_train_val]))
        # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        # add to list
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        # positive counts per target; sum(axis=0) avoids np.sum(df), whose axis=None behavior pandas has deprecated
        pos_class.append(
            tuple(
                df[['binary_A', 'binary_B', 'binary_C']].loc[idx].sum(axis=0).to_numpy()
                for idx in (idx_train, idx_val, idx_test)
            )
        )
        unique_initiators.append((len(df['I_long'][idx_train].drop_duplicates()), len(df['I_long'][idx_val].drop_duplicates()), len(df['I_long'][idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long'][idx_train].drop_duplicates()), len(df['M_long'][idx_val].drop_duplicates()), len(df['M_long'][idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'][idx_train].drop_duplicates()), len(df['T_long'][idx_val].drop_duplicates()), len(df['T_long'][idx_test].drop_duplicates())))
    
    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators


In [17]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.1, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(30974, 4318, 4726), (32233, 3546, 4239), (32488, 4020, 3510), (29520, 6094, 4404), (31196, 4704, 4118), (32186, 4686, 3146), (30353, 4855, 4810), (30367, 5326, 4325), (31930, 4193, 3895)]
[(array([25183, 17684,  8977]), array([3332, 2430, 1036]), array([4017, 2722, 1300])), (array([26439, 18610,  9216]), array([2886, 2070, 1084]), array([3207, 2156, 1013])), (array([26725, 18911,  9629]), array([3018, 1969,  909]), array([2789, 1956,  775])), (array([23667, 16692,  7963]), array([5073, 3440, 1898]), array([3792, 2704, 1452])), (array([25283, 17894,  8807]), array([3811, 2559, 1232]), array([3438, 2383, 1274])), (array([26321, 18423,  9394]), array([4005, 2925, 1418]), array([2206, 1488,  501])), (array([24318, 17247,  8815]), array([4093, 2681, 1212]), array([4121, 2908, 1286])), (array([24334, 17088,  8619]), array([4641, 3188, 1604]), array([3557, 2560, 1090])), (array([25753, 18053,  8735]), array([3495, 2373, 1448]), array([3284, 2410, 1130]))]
[(53, 7, 7), (53, 7, 7), (53, 7, 7)

In [18]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)

In [19]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(17126, 10733, 12159), (16875, 11529, 11614), (15464, 12107, 12447), (15857, 11436, 12725), (15219, 12660, 12139), (15958, 12613, 11447), (14167, 11735, 14116), (16400, 11582, 12036), (13764, 13270, 12984)]
[(array([13955, 10001,  4992]), array([8318, 5697, 2725]), array([10259,  7138,  3596])), (array([13586,  9592,  4206]), array([9676, 6896, 3571]), array([9270, 6348, 3536])), (array([12577,  9177,  4237]), array([9769, 6737, 3738]), array([10186,  6922,  3338])), (array([13126,  9183,  4322]), array([9179, 6279, 3435]), array([10227,  7374,  3556])), (array([12787,  9003,  4484]), array([10139,  6770,  3242]), array([9606, 7063, 3587])), (array([13432,  9102,  4587]), array([10341,  7246,  4104]), array([8759, 6488, 2622])), (array([11156,  7946,  3718]), array([9398, 6653, 3430]), array([11978,  8237,  4165])), (array([13012,  9377,  4601]), array([9529, 6385, 3337]), array([9991, 7074, 3375])), (array([10817,  7721,  3918]), array([11257,  7827,  3686]), array([10458,  7288,  37

In [20]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [21]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(8818, 16152, 15048), (9070, 14595, 16353), (6805, 17281, 15932), (8445, 15367, 16206), (8844, 15312, 15862), (8004, 15372, 16642), (7254, 15824, 16940), (7631, 15875, 16512), (8090, 15386, 16542)]
[(array([7327, 5158, 2585]), array([13007,  8834,  4587]), array([12198,  8844,  4141])), (array([7765, 5613, 3111]), array([11387,  7972,  3326]), array([13380,  9251,  4876])), (array([5521, 3769, 2231]), array([14416, 10169,  5088]), array([12595,  8898,  3994])), (array([7227, 4980, 2131]), array([12271,  8555,  4770]), array([13034,  9301,  4412])), (array([7406, 5204, 2457]), array([12375,  8466,  4267]), array([12751,  9166,  4589])), (array([6713, 4551, 2029]), array([12612,  8699,  5094]), array([13207,  9586,  4190])), (array([5823, 4014, 2111]), array([12409,  9052,  4107]), array([14300,  9770,  5095])), (array([6262, 4246, 2082]), array([12499,  9039,  4738]), array([13771,  9551,  4493])), (array([6508, 4906, 2546]), array([12661,  8552,  4299]), array([13363,  9378,  4468]))]

In [22]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [23]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(3567, 18651, 17800), (4020, 17884, 18114), (4219, 17970, 17829), (3452, 17898, 18668), (4571, 16516, 18931), (3014, 17745, 19259), (2607, 18079, 19332), (3703, 16819, 19496), (3970, 17106, 18942)]
[(array([2839, 1990, 1035]), array([15154, 10265,  5393]), array([14539, 10581,  4885])), (array([3381, 2420, 1305]), array([14306, 10144,  4678]), array([14845, 10272,  5330])), (array([3556, 2425,  940]), array([14845, 10408,  5661]), array([14131, 10003,  4712])), (array([2843, 2001,  973]), array([14698, 10183,  5390]), array([14991, 10652,  4950])), (array([3966, 2813, 1626]), array([13229,  9101,  3854]), array([15337, 10922,  5833])), (array([2433, 1847, 1056]), array([14651,  9874,  4916]), array([15448, 11115,  5341])), (array([2127, 1476,  703]), array([14232, 10284,  4927]), array([16173, 11076,  5683])), (array([3037, 2046, 1161]), array([13186,  9452,  4910]), array([16309, 11338,  5242])), (array([3279, 2319, 1197]), array([13820,  9625,  4937]), array([15433, 10892,  5179]))]

In [24]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)

In [25]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1447, 20771, 17800), (2409, 19495, 18114), (1278, 20911, 17829), (1983, 19367, 18668), (1643, 19444, 18931), (1721, 19038, 19259), (1240, 19446, 19332), (1243, 19279, 19496), (2046, 19030, 18942)]
[(array([1081,  697,  339]), array([16912, 11558,  6089]), array([14539, 10581,  4885])), (array([2068, 1451,  729]), array([15619, 11113,  5254]), array([14845, 10272,  5330])), (array([1023,  721,  323]), array([17378, 12112,  6278]), array([14131, 10003,  4712])), (array([1696, 1194,  555]), array([15845, 10990,  5808]), array([14991, 10652,  4950])), (array([1359,  843,  524]), array([15836, 11071,  4956]), array([15337, 10922,  5833])), (array([1448, 1122,  633]), array([15636, 10599,  5339]), array([15448, 11115,  5341])), (array([1015,  727,  297]), array([15344, 11033,  5333]), array([16173, 11076,  5683])), (array([980, 582, 264]), array([15243, 10916,  5807]), array([16309, 11338,  5242])), (array([1710, 1187,  554]), array([15389, 10757,  5580]), array([15433, 10892,  5179]))]
[(
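The inner `test_size` in these cells follows from simple arithmetic: the outer splitter removes the test fraction first, so the validation fraction has to be expressed relative to what remains. A minimal sketch of that calculation is given here; the helper name `nested_test_sizes` is hypothetical and not part of the notebook. Because the splitters operate on groups, the fractions refer to groups, and the resulting sample counts only approximate them.

```python
def nested_test_sizes(train, val, test):
    """Return (outer_test_size, inner_test_size) for a two-stage split.

    train, val and test are fractions of the full dataset and must sum to 1.
    """
    total = train + val + test
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"fractions must sum to 1, got {total}")
    outer = test                   # the outer splitter holds out the test part
    inner = val / (train + val)    # val relative to the remaining train+val part
    return outer, inner


# 5/50/45 split: reproduces test_size=0.45 and test_size=0.5/0.55
outer, inner = nested_test_sizes(0.05, 0.50, 0.45)
print(outer, inner)
```

The same helper gives `0.475` and `0.5/0.525` for the 2.5/50/47.5 split.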

In [26]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=5
)

In [27]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.475, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.5/0.525, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(202, 21461, 18355), (947, 20029, 19042), (483, 21397, 18138), (609, 19716, 19693), (626, 19940, 19452), (472, 19775, 19771), (202, 20239, 19577), (484, 19470, 20064), (359, 20234, 19425)]
[(array([147,  77,  33]), array([17398, 11867,  6275]), array([14987, 10892,  5005])), (array([821, 622, 416]), array([16126, 11423,  5326]), array([15585, 10791,  5571])), (array([382, 265, 152]), array([17770, 12376,  6267]), array([14380, 10195,  4894])), (array([537, 385, 198]), array([16114, 11166,  5571]), array([15881, 11285,  5544])), (array([501, 258, 206]), array([16225, 11295,  5057]), array([15806, 11283,  6050])), (array([381, 236, 215]), array([16344, 11346,  5719]), array([15807, 11254,  5379])), (array([147,  77,  33]), array([16026, 11540,  5528]), array([16359, 11219,  5752])), (array([387, 238, 163]), array([15325, 10863,  5773]), array([16820, 11735,  5377])), (array([289, 197, 170]), array([16428, 11482,  5812]), array([15815, 11157,  5331]))]
[(1, 34, 32), (1, 34, 32), (1, 34, 
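The training-set sizes quoted in the overview are averages over the nine folds. As a quick check, the averages can be recomputed from the `sizes` list printed above (the values below are copied verbatim from that output):

```python
import numpy as np

# (train, val, test) sizes per fold, copied from the printed output
sizes = [(202, 21461, 18355), (947, 20029, 19042), (483, 21397, 18138),
         (609, 19716, 19693), (626, 19940, 19452), (472, 19775, 19771),
         (202, 20239, 19577), (484, 19470, 20064), (359, 20234, 19425)]

sizes_arr = np.asarray(sizes)                 # shape (n_folds, 3)
mean_sizes = sizes_arr.mean(axis=0)           # average size of each subset
fractions = mean_sizes / sizes_arr.sum(axis=1).mean()

print(mean_sizes.round(1))   # average train / val / test sizes
print(fractions.round(4))    # achieved fractions of the full dataset
```

The mean training size comes out at roughly 487 samples, which matches the 2.5/50/47.5 entry for the 1D split. The achieved training fraction is well below the nominal 2.5% because the split is made over initiators, not over individual samples.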

In [28]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=2.5
)

## 2D split

In [25]:
def split_2d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(df, groups=df[["I_long", "M_long"]]):
        train, val = next(inner_splitter.split(df.iloc[idx_train_val], groups=df[["I_long", "M_long"]].iloc[idx_train_val]))
        # `train` and `val` are positions within `idx_train_val`; map them back to positions in the original dataframe
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        # number of positive samples per binary target; the indices are positional, so we use .iloc rather than .loc
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_test]).to_numpy(),
            )
        )
        # number of distinct building blocks in (train, val, test)
        unique_initiators.append((len(df['I_long'].iloc[idx_train].drop_duplicates()), len(df['I_long'].iloc[idx_val].drop_duplicates()), len(df['I_long'].iloc[idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long'].iloc[idx_train].drop_duplicates()), len(df['M_long'].iloc[idx_val].drop_duplicates()), len(df['M_long'].iloc[idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'].iloc[idx_train].drop_duplicates()), len(df['T_long'].iloc[idx_val].drop_duplicates()), len(df['T_long'].iloc[idx_test].drop_duplicates())))
    
    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators
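
The index bookkeeping inside the loop is easy to get wrong: the inner splitter only sees the rows in `idx_train_val`, so the positions it returns refer to that subset and must be mapped back to the full dataframe. A toy illustration with made-up index values:

```python
import numpy as np

# Positions (in the full dataset) that remain after the outer split
idx_train_val = np.array([0, 2, 3, 5, 8, 9])

# The inner splitter sees only those 6 rows, so it returns positions 0..5
train = np.array([0, 3, 4])
val = np.array([1, 2, 5])

# Indexing one index array with the other maps back to the full dataset
idx_train = idx_train_val[train]
idx_val = idx_train_val[val]

print(idx_train)  # [0 5 8]
print(idx_val)    # [2 3 9]
```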


In [30]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.1, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(24438, 472, 469), (24006, 474, 548), (23389, 761, 463), (23859, 502, 537), (25007, 387, 379), (26332, 326, 337), (24546, 485, 386), (24453, 442, 478), (25030, 427, 488)]
[(array([20452, 14462,  7408]), array([326, 239, 107]), array([360, 238,  77])), (array([19472, 13912,  6769]), array([388, 264, 114]), array([460, 300, 170])), (array([18890, 13449,  6722]), array([679, 422, 213]), array([342, 235, 100])), (array([18304, 12804,  6444]), array([490, 357, 153]), array([503, 362, 201])), (array([21162, 15102,  7694]), array([313, 212,  85]), array([229, 152,  66])), (array([21735, 15521,  7678]), array([242, 152,  69]), array([259, 170,  69])), (array([19480, 13267,  6998]), array([431, 345, 193]), array([342, 257,  75])), (array([19647, 13674,  6917]), array([392, 270, 132]), array([402, 309, 112])), (array([20961, 14744,  7295]), array([363, 276, 108]), array([295, 196, 133]))]
[(53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7)

In [31]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)
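
The initiator counts printed above (53 + 7 + 7) suggest that the 2D split keeps the subsets disjoint along the grouping columns. If that is the intended invariant, it can be verified directly. The sketch below uses a hypothetical helper `group_overlap` and a toy dataframe; neither is part of the notebook.

```python
import numpy as np
import pandas as pd


def group_overlap(frame, idx_a, idx_b, columns):
    """Count the unique values each column shares between two positional index sets."""
    overlap = {}
    for col in columns:
        values_a = set(frame[col].iloc[idx_a])
        values_b = set(frame[col].iloc[idx_b])
        overlap[col] = len(values_a & values_b)
    return overlap


# Toy data standing in for the real dataframe (hypothetical values)
toy = pd.DataFrame({
    "I_long": ["i1", "i1", "i2", "i2", "i3", "i3"],
    "M_long": ["m1", "m2", "m1", "m2", "m3", "m3"],
})
idx_train = np.array([0, 1, 2, 3])
idx_test = np.array([4, 5])

print(group_overlap(toy, idx_train, idx_test, ["I_long", "M_long"]))
# {'I_long': 0, 'M_long': 0}
```

On the real data, the same call could be applied to each `(idx_train, idx_val, idx_test)` triple in `indices`; a non-zero count in a grouping column would indicate leakage.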

In [32]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.2, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.2/0.8, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(13487, 1991, 1514), (13872, 1733, 1534), (13816, 1594, 1784), (14461, 1370, 1796), (12841, 1892, 1637), (14168, 1747, 1391), (14220, 1553, 1656), (13984, 2003, 1450), (13370, 1718, 1859)]
[(array([11112,  8085,  4194]), array([1542, 1040,  452]), array([1222,  788,  375])), (array([11203,  8084,  3768]), array([1488,  968,  470]), array([1260,  820,  452])), (array([11335,  8192,  3844]), array([1240,  842,  377]), array([1472, 1001,  602])), (array([11934,  8567,  4539]), array([996, 619, 305]), array([1568, 1181,  534])), (array([10887,  7434,  3851]), array([1586, 1159,  596]), array([1163,  887,  374])), (array([11605,  8092,  4086]), array([1459, 1026,  445]), array([1069,  774,  394])), (array([11709,  8388,  4458]), array([1210,  839,  451]), array([1398,  928,  344])), (array([11280,  8197,  3744]), array([1731, 1122,  746]), array([1163,  822,  343])), (array([11168,  7770,  3830]), array([1465, 1020,  498]), array([1337,  904,  507]))]
[(39, 14, 14), (39, 14, 14), (39, 14, 

In [33]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=60
)

In [34]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(6732, 3274, 3641), (6598, 3524, 3549), (6311, 3235, 4161), (6455, 3049, 3847), (6063, 3555, 3991), (5039, 4358, 3790), (5776, 3168, 4655), (6064, 3731, 3511), (5541, 3686, 3955)]
[(array([5502, 4059, 1972]), array([2535, 1724,  883]), array([3086, 2040,  992])), (array([5435, 3702, 1633]), array([2789, 2048,  943]), array([2942, 1964, 1161])), (array([5067, 3621, 1274]), array([2633, 1886, 1225]), array([3435, 2273, 1267])), (array([5137, 3694, 1987]), array([2550, 1682,  852]), array([3155, 2273,  951])), (array([5110, 3600, 1699]), array([2825, 1930,  928]), array([3182, 2321, 1222])), (array([4396, 3028, 1372]), array([3423, 2272, 1322]), array([2969, 2225,  999])), (array([4466, 3013, 1698]), array([2467, 1878,  992]), array([4130, 2821, 1223])), (array([4918, 3498, 1709]), array([2938, 1989,  909]), array([2939, 2100, 1006])), (array([4619, 3280, 1443]), array([3284, 2319, 1144]), array([2943, 2010, 1148]))]
[(26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), 

In [35]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [36]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.35/0.65, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(3195, 5660, 4556), (3036, 5479, 4942), (3414, 4694, 5619), (3571, 4430, 5338), (2941, 5192, 5351), (2943, 4878, 5694), (2810, 4733, 5834), (3255, 5278, 4870), (3570, 4140, 5444)]
[(array([2669, 1826,  854]), array([4682, 3209, 1837]), array([3695, 2644, 1125])), (array([2546, 1771,  741]), array([4223, 3005, 1324]), array([4206, 2875, 1748])), (array([2770, 1863,  867]), array([3799, 2913, 1349]), array([4539, 3039, 1592])), (array([2829, 2021,  975]), array([3507, 2342, 1458]), array([4441, 3161, 1223])), (array([2386, 1598,  785]), array([4329, 3077, 1309]), array([4227, 3024, 1690])), (array([2453, 1650,  832]), array([4266, 2913, 1622]), array([4364, 3223, 1499])), (array([2272, 1609,  872]), array([3470, 2544, 1291]), array([5135, 3511, 1625])), (array([2599, 1934,  744]), array([4196, 2812, 1614]), array([4140, 2966, 1400])), (array([3006, 2057, 1020]), array([3613, 2519, 1286]), array([4104, 2891, 1416]))]
[(19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), 

In [37]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=30
)

In [38]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1666, 6503, 5939), (1575, 6219, 6561), (1684, 6033, 6737), (1340, 6735, 6581), (1372, 6910, 6673), (1495, 5990, 7181), (1290, 6866, 7003), (1187, 7236, 6347), (1665, 5757, 6622)]
[(array([1482, 1080,  567]), array([5239, 3492, 1849]), array([4729, 3350, 1488])), (array([1198,  887,  355]), array([4943, 3443, 1459]), array([5646, 3874, 2311])), (array([1320,  886,  446]), array([5144, 3706, 1920]), array([5346, 3711, 1840])), (array([1137,  847,  480]), array([5256, 3502, 1971]), array([5477, 3864, 1515])), (array([1179,  840,  208]), array([5606, 3858, 2092]), array([5382, 3877, 2118])), (array([1297,  871,  489]), array([5053, 3538, 1796]), array([5495, 3986, 1837])), (array([784, 616, 286]), array([5665, 3963, 2108]), array([6174, 4167, 1995])), (array([820, 530, 336]), array([6092, 4468, 1920]), array([5390, 3822, 1908])), (array([1534, 1124,  526]), array([4827, 3400, 1665]), array([4958, 3435, 1615]))]
[(13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 2

In [39]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [40]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(887, 8590, 5939), (924, 7723, 6561), (876, 7982, 6737), (598, 8679, 6581), (807, 8291, 6673), (843, 7544, 7181), (740, 8326, 7003), (670, 8802, 6347), (967, 7454, 6622)]
[(array([765, 548, 342]), array([7068, 4791, 2428]), array([4729, 3350, 1488])), (array([660, 516, 225]), array([6235, 4332, 1817]), array([5646, 3874, 2311])), (array([687, 445, 224]), array([6735, 4835, 2466]), array([5346, 3711, 1840])), (array([479, 367, 235]), array([6926, 4630, 2565]), array([5477, 3864, 1515])), (array([687, 507, 130]), array([6764, 4636, 2416]), array([5382, 3877, 2118])), (array([714, 470, 254]), array([6411, 4486, 2322]), array([5495, 3986, 1837])), (array([425, 332, 173]), array([6744, 4782, 2448]), array([6174, 4167, 1995])), (array([454, 278, 195]), array([7284, 5319, 2370]), array([5390, 3822, 1908])), (array([908, 670, 279]), array([6249, 4427, 2270]), array([4958, 3435, 1615]))]
[(10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30,

In [41]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=15
)

In [66]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(307, 8784, 8136), (284, 8807, 8243), (393, 7890, 8401), (431, 7717, 8400), (365, 7691, 8696), (462, 7456, 8985), (274, 7994, 8884), (399, 7121, 8981), (277, 7608, 8893)]
[(array([249, 189,  95]), array([7330, 4955, 2632]), array([6501, 4663, 2047])), (array([264, 155,  81]), array([6678, 4873, 2009]), array([7125, 4859, 2826])), (array([356, 240, 111]), array([6490, 4536, 2249]), array([6618, 4614, 2411])), (array([392, 283, 141]), array([6016, 4264, 2443]), array([6913, 4834, 2018])), (array([284, 183,  57]), array([6464, 4569, 2136]), array([6986, 4946, 2789])), (array([433, 279, 143]), array([6172, 4361, 2167]), array([6985, 5014, 2466])), (array([173, 161,  70]), array([6496, 4560, 2305]), array([7538, 5098, 2514])), (array([381, 234, 139]), array([5225, 3659, 1726]), array([7804, 5552, 2862])), (array([261, 168,  90]), array([6359, 4526, 2380]), array([6899, 4839, 2192]))]
[(6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6

In [67]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)

In [26]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.475/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(135, 9831, 8136), (120, 9574, 8243), (166, 9255, 8401), (192, 8885, 8400), (194, 8599, 8696), (192, 8641, 8985), (178, 8674, 8884), (157, 8484, 8981), (137, 8360, 8893)]
[(array([110,  80,  39]), array([8170, 5550, 2913]), array([6501, 4663, 2047])), (array([112,  71,  33]), array([7354, 5310, 2250]), array([7125, 4859, 2826])), (array([149,  89,  30]), array([7743, 5479, 2710]), array([6618, 4614, 2411])), (array([169, 119,  56]), array([7046, 5003, 2776]), array([6913, 4834, 2018])), (array([165,  98,  19]), array([7038, 4976, 2356]), array([6986, 4946, 2789])), (array([186, 118,  72]), array([7180, 5058, 2447]), array([6985, 5014, 2466])), (array([137, 101,  56]), array([6817, 4940, 2419]), array([7538, 5098, 2514])), (array([149,  94,  74]), array([6361, 4403, 2074]), array([7804, 5552, 2862])), (array([132,  84,  54]), array([7003, 4948, 2571]), array([6899, 4839, 2192]))]
[(4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4

In [27]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=7.5
)

In [16]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(74, 10547, 8136), (68, 10281, 8243), (83, 10254, 8401), (106, 9490, 8400), (79, 9433, 8696), (101, 9340, 8985), (73, 9492, 8884), (58, 9409, 8981), (56, 9082, 8893)]
[(array([55, 45, 20]), array([8837, 6040, 3209]), array([6501, 4663, 2047])), (array([63, 40, 19]), array([7967, 5741, 2456]), array([7125, 4859, 2826])), (array([73, 52, 19]), array([8639, 6046, 2961]), array([6618, 4614, 2411])), (array([93, 64, 40]), array([7559, 5361, 2947]), array([6913, 4834, 2018])), (array([67, 44, 16]), array([7737, 5428, 2472]), array([6986, 4946, 2789])), (array([98, 62, 41]), array([7809, 5499, 2681]), array([6985, 5014, 2466])), (array([58, 47, 27]), array([7456, 5375, 2671]), array([7538, 5098, 2514])), (array([52, 26, 28]), array([7195, 5024, 2343]), array([7804, 5552, 2862])), (array([54, 38, 22]), array([7666, 5358, 2851]), array([6899, 4839, 2192]))]
[(3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31)]
[(3, 36, 33), (3, 35

In [17]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=5
)

In [18]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.48, random_state=42)
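inner_splitter = GroupShuffleSplitND(n_splits=1, train_size=2, random_state=np.random.RandomState(42))  # train_size is an integer here, i.e. an absolute number of groups rather than a fraction (the output below shows 2 initiators in train); we use a RandomState instance, not an int, because we will reuse this splitter several times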

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(24, 10123, 9185), (55, 9069, 9520), (51, 9881, 9300), (31, 8707, 10012), (64, 8614, 9870), (30, 8943, 9881), (20, 9275, 9802), (47, 8401, 9924), (19, 8818, 10014)]
[(array([14, 16, 15]), array([8651, 5873, 3227]), array([7157, 5133, 2187])), (array([35, 24, 10]), array([7056, 5116, 2187]), array([8252, 5670, 3174])), (array([43, 30, 15]), array([8277, 5759, 2737]), array([7369, 5156, 2618])), (array([21, 14, 11]), array([6997, 4911, 2834]), array([8316, 5793, 2487])), (array([59, 42, 10]), array([6930, 4890, 2306]), array([8007, 5597, 3086])), (array([24, 16, 13]), array([7570, 5274, 2564]), array([7670, 5433, 2682])), (array([16,  6, 11]), array([7171, 5246, 2523]), array([8379, 5709, 2796])), (array([42, 28,  2]), array([6404, 4417, 2233]), array([8506, 5984, 3071])), (array([18,  9,  2]), array([7507, 5286, 2844]), array([7831, 5474, 2458]))]
[(2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33), (2, 32, 33)]
[(1, 35, 35), (2, 34, 
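
In this cell the inner splitter receives an integer `train_size` instead of a fraction. `GroupShuffleSplitND` is a custom class, so its exact behaviour is an assumption here, but scikit-learn's own `GroupShuffleSplit` interprets an integer as an absolute number of groups. That matches the 2 training initiators printed above. A small sketch with synthetic groups:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

groups = np.repeat(np.arange(10), 5)     # 10 groups with 5 samples each
X = np.zeros((len(groups), 1))           # placeholder features

# Integer train_size: exactly 2 groups go to the training part
splitter = GroupShuffleSplit(n_splits=1, train_size=2, random_state=0)
train_idx, rest_idx = next(splitter.split(X, groups=groups))

n_train_groups = len(np.unique(groups[train_idx]))
n_rest_groups = len(np.unique(groups[rest_idx]))
print(n_train_groups, n_rest_groups)
```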

In [19]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=2
)

## 3D split

In [29]:
def split_3d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(df, groups=df[["I_long", "M_long", "T_long"]]):
        train, val = next(inner_splitter.split(df.iloc[idx_train_val], groups=df[["I_long", "M_long", "T_long"]].iloc[idx_train_val]))
        # `train` and `val` are positions within `idx_train_val`; map them back to positions in the original dataframe
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        # number of positive samples per binary target; the indices are positional, so we use .iloc rather than .loc
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].iloc[idx_test]).to_numpy(),
            )
        )
        # number of distinct building blocks in (train, val, test)
        unique_initiators.append((len(df['I_long'].iloc[idx_train].drop_duplicates()), len(df['I_long'].iloc[idx_val].drop_duplicates()), len(df['I_long'].iloc[idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long'].iloc[idx_train].drop_duplicates()), len(df['M_long'].iloc[idx_val].drop_duplicates()), len(df['M_long'].iloc[idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'].iloc[idx_train].drop_duplicates()), len(df['T_long'].iloc[idx_val].drop_duplicates()), len(df['T_long'].iloc[idx_test].drop_duplicates())))

    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators


In [43]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.1, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(18634, 57, 70), (18901, 24, 68), (19587, 46, 58), (19767, 42, 55), (19749, 34, 55), (19761, 42, 52), (21450, 29, 37), (20481, 49, 71), (19192, 40, 60)]
[(array([15388, 11524,  5878]), array([48, 32, 17]), array([48, 18,  8])), (array([15805, 11283,  5801]), array([19,  9,  3]), array([48, 39, 15])), (array([15675, 10099,  4959]), array([38, 34, 27]), array([52, 52, 21])), (array([16324, 11591,  5650]), array([38, 23, 13]), array([44, 28, 24])), (array([15754, 11536,  5464]), array([30, 19, 13]), array([52, 26, 12])), (array([15798, 10907,  5512]), array([33, 32, 13]), array([46, 32, 18])), (array([17467, 12645,  6413]), array([29, 22, 15]), array([24, 13,  9])), (array([16946, 12236,  5895]), array([37, 28, 12]), array([53, 33, 14])), (array([15591, 10637,  5232]), array([33, 25, 23]), array([49, 29, 13]))]
[(53, 5, 7), (53, 6, 7), (53, 6, 7), (53, 7, 7), (53, 6, 7), (53, 7, 7), (53, 6, 4), (53, 7, 7), (53, 7, 7)]
[(56, 8, 8), (56, 7, 8), (55, 8, 8), (56, 8, 8), (56, 8, 8), (56, 8, 7
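
The comments on the splitters mention that a `RandomState` instance is used instead of an integer seed because the splitter is called repeatedly. The difference can be illustrated with scikit-learn's `GroupShuffleSplit`; whether the custom `GroupShuffleSplitND` handles `random_state` in exactly the same way is an assumption. The data and the helper `first_test_groups` below are synthetic and hypothetical.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

groups = np.repeat(np.arange(50), 2)     # 50 groups with 2 samples each
X = np.zeros((len(groups), 1))           # placeholder features


def first_test_groups(splitter):
    """Return the test groups of the first split produced by one call to split()."""
    _, test_idx = next(splitter.split(X, groups=groups))
    return tuple(np.unique(groups[test_idx]).tolist())


int_seeded = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
instance_seeded = GroupShuffleSplit(
    n_splits=1, test_size=0.5, random_state=np.random.RandomState(42)
)

# An integer seed re-creates the same generator on every call to split()
same_with_int = first_test_groups(int_seeded) == first_test_groups(int_seeded)
# A RandomState instance keeps advancing, so successive calls differ
same_with_instance = first_test_groups(instance_seeded) == first_test_groups(instance_seeded)

print(same_with_int, same_with_instance)
```

With an integer seed, every fold of the outer loop would draw the same permutation of group positions for its inner split; the shared instance avoids that while keeping the whole run reproducible.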

In [44]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)

In [45]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.15, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.15/0.85, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(13480, 125, 163), (12588, 126, 203), (12701, 131, 197), (12340, 143, 214), (12915, 137, 165), (12861, 181, 153), (13455, 114, 160), (13015, 171, 179), (13418, 110, 159)]
[(array([11142,  8344,  3837]), array([91, 50, 32]), array([120,  64,  37])), (array([10429,  7758,  3607]), array([111,  60,  50]), array([145, 106,  43])), (array([9713, 5763, 3165]), array([113, 115,  63]), array([167, 128,  56])), (array([10099,  7226,  3193]), array([107,  72,  24]), array([166, 104, 101])), (array([10370,  7369,  3353]), array([110,  80,  62]), array([140,  89,  40])), (array([10247,  7148,  3927]), array([127,  83,  22]), array([134,  96,  46])), (array([10577,  7325,  3631]), array([112, 106,  65]), array([115,  54,  31])), (array([10360,  7376,  3385]), array([156, 106,  69]), array([125,  65,  45])), (array([10979,  7959,  3950]), array([75, 40, 16]), array([131,  80,  54]))]
[(46, 10, 11), (46, 10, 11), (46, 10, 11), (46, 10, 11), (46, 10, 11), (46, 10, 11), (46, 9, 11), (46, 10, 11), (46,

In [46]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=70
)

In [47]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.2, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.2/0.8, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(8180, 355, 367), (7995, 305, 439), (7184, 424, 455), (8145, 272, 528), (8042, 338, 416), (7640, 428, 358), (8029, 398, 304), (7631, 423, 359), (8682, 261, 368)]
[(array([6663, 4844, 2386]), array([325, 243,  79]), array([239, 133,  76])), (array([6600, 4614, 2500]), array([273, 244, 104]), array([310, 203,  77])), (array([5690, 3625, 1874]), array([344, 298, 137]), array([394, 264, 142])), (array([6498, 4966, 2019]), array([234, 112,  67]), array([437, 321, 222])), (array([6365, 4554, 2256]), array([248, 215, 118]), array([371, 220,  87])), (array([6038, 4239, 2172]), array([333, 219,  80]), array([317, 235, 125])), (array([6296, 4447, 2500]), array([368, 305, 104]), array([248, 143,  75])), (array([6194, 4239, 1992]), array([301, 272,  97]), array([274, 177, 118])), (array([7278, 5059, 2877]), array([235, 166,  32]), array([265, 185,  99]))]
[(39, 14, 14), (39, 13, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 13), (39, 14, 14), (39, 14, 14)]
[(41, 15, 15), (

In [48]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=60
)

In [49]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.25, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.25/0.75, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(4803, 659, 624), (4989, 420, 732), (4653, 547, 750), (4180, 622, 1001), (4543, 601, 749), (4609, 686, 754), (5063, 620, 614), (4704, 726, 688), (5179, 498, 626)]
[(array([3879, 2645, 1181]), array([582, 540, 293]), array([444, 209, 123])), (array([4240, 3169, 1745]), array([350, 256, 101]), array([512, 340, 113])), (array([3881, 2479, 1365]), array([379, 287, 128]), array([670, 512, 246])), (array([3400, 2299,  924]), array([521, 364, 151]), array([824, 594, 394])), (array([3494, 2604, 1212]), array([454, 333, 193]), array([668, 390, 199])), (array([3696, 2765, 1328]), array([508, 256, 148]), array([675, 538, 265])), (array([3967, 2916, 1235]), array([538, 348, 212]), array([513, 312, 162])), (array([3744, 2676, 1176]), array([580, 435, 216]), array([553, 318, 209])), (array([4428, 3311, 1840]), array([395, 225,  93]), array([466, 314, 177]))]
[(33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17)]
[(36, 17, 18), 

In [50]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=50
)

In [51]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.3, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(2592, 936, 1166), (2234, 880, 1312), (2607, 886, 1293), (2520, 865, 1615), (2250, 951, 1355), (2489, 1009, 1259), (2421, 1059, 1137), (2245, 1191, 1227), (2237, 1095, 1235)]
[(array([2195, 1403,  645]), array([762, 684, 305]), array([897, 515, 325])), (array([1993, 1418,  731]), array([663, 535, 288]), array([970, 582, 235])), (array([2100, 1541,  775]), array([653, 397, 205]), array([1136,  790,  354])), (array([1997, 1518,  537]), array([776, 519, 320]), array([1283,  886,  538])), (array([1779, 1379,  639]), array([813, 605, 338]), array([1110,  702,  328])), (array([2076, 1309,  652]), array([655, 433, 208]), array([1154,  936,  455])), (array([1873, 1280,  526]), array([860, 732, 403]), array([942, 546, 279])), (array([1633, 1094,  391]), array([1078,  821,  464]), array([977, 622, 404])), (array([1807, 1193,  584]), array([953, 688, 297]), array([967, 668, 304]))]
[(26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (2

In [52]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [53]:
# Nominal 34/33/33 train/val/test split.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.33, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.33/0.67, random_state=np.random.RandomState(42))  # 0.33/0.67 of the remaining 67% -> validation is 33% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1469, 1697, 1420), (1481, 1410, 1689), (1265, 1544, 1612), (1466, 1116, 2033), (1273, 1431, 1789), (1282, 1540, 1730), (1003, 1770, 1550), (1407, 1463, 1582), (1175, 1829, 1484)]
[(array([1247,  911,  410]), array([1423, 1048,  576]), array([1031,  658,  396])), (array([1305,  970,  441]), array([1153,  844,  508]), array([1256,  767,  339])), (array([1115,  743,  337]), array([1208,  891,  567]), array([1348,  899,  416])), (array([1340,  824,  387]), array([847, 722, 301]), array([1641, 1158,  660])), (array([908, 653, 294]), array([1237,  893,  482]), array([1501,  995,  507])), (array([1052,  579,  261]), array([1090,  823,  422]), array([1568, 1225,  600])), (array([732, 611, 309]), array([1474, 1089,  535]), array([1261,  703,  377])), (array([1192,  987,  450]), array([1082,  634,  224]), array([1286,  791,  527])), (array([941, 573, 276]), array([1553, 1148,  605]), array([1174,  817,  341]))]
[(22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23)

In [54]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=34
)

In [55]:
# Nominal 30/35/35 train/val/test split.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.35/0.65, random_state=np.random.RandomState(42))  # 0.35/0.65 of the remaining 65% -> validation is 35% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(867, 2306, 1615), (946, 1625, 2038), (911, 1659, 2033), (926, 1464, 2457), (963, 1670, 2122), (861, 1726, 2107), (847, 1898, 1914), (948, 1691, 1937), (844, 1927, 1710)]
[(array([776, 567, 255]), array([1954, 1334,  625]), array([1125,  726,  428])), (array([822, 636, 301]), array([1374, 1031,  532]), array([1464,  859,  379])), (array([810, 544, 303]), array([1227,  833,  456]), array([1726, 1209,  576])), (array([737, 555, 216]), array([1272,  783,  388]), array([2006, 1465,  783])), (array([711, 527, 231]), array([1374, 1042,  587]), array([1784, 1154,  559])), (array([573, 276, 175]), array([1452, 1078,  457]), array([1934, 1525,  750])), (array([731, 517, 308]), array([1449, 1138,  506]), array([1528,  899,  444])), (array([821, 515, 263]), array([1303, 1003,  388]), array([1592, 1017,  652])), (array([706, 412, 201]), array([1549, 1178,  608]), array([1375,  971,  452]))]
[(19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24,

In [56]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=30
)

In [60]:
# Nominal 25/40/35 train/val/test split: the test share stays at 35%, and the validation share grows to 40%.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.40/0.65, random_state=np.random.RandomState(42))  # 0.40/0.65 of the remaining 65% -> validation is 40% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(529, 3186, 1615), (526, 2437, 2038), (513, 2483, 2033), (583, 2087, 2457), (579, 2464, 2122), (469, 2672, 2107), (498, 2627, 1914), (608, 2410, 1937), (456, 2926, 1710)]
[(array([454, 312, 125]), array([2758, 1955, 1002]), array([1125,  726,  428])), (array([453, 368, 156]), array([2094, 1519,  817]), array([1464,  859,  379])), (array([472, 329, 155]), array([1809, 1204,  698]), array([1726, 1209,  576])), (array([456, 335, 105]), array([1789, 1194,  606]), array([2006, 1465,  783])), (array([452, 338, 134]), array([1924, 1457,  831]), array([1784, 1154,  559])), (array([331, 163,  85]), array([2154, 1545,  692]), array([1934, 1525,  750])), (array([427, 324, 221]), array([2048, 1561,  646]), array([1528,  899,  444])), (array([519, 380, 180]), array([1820, 1252,  517]), array([1592, 1017,  652])), (array([357, 212, 111]), array([2413, 1745,  870]), array([1375,  971,  452]))]
[(16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27,

In [61]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=25
)
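
The `test_size` arguments in the cells above encode a nested split: the outer splitter holds out the test share, and the inner splitter takes the validation share from what remains. As a rough check of that arithmetic, the sketch below recovers the nominal train/val/test fractions from the two arguments. The helper `nested_split_fractions` is hypothetical and not part of this notebook; it assumes the inner `test_size` is applied to the remainder left by the outer split. The fractions describe the intended proportions along each split dimension, not the resulting sample counts, which are far smaller in the 3D case.

```python
def nested_split_fractions(outer_test_size, inner_test_size):
    """Return nominal (train, val, test) fractions of a nested split.

    Hypothetical helper: assumes the inner test_size is taken from the
    remainder (1 - outer_test_size) left after the outer split.
    """
    remainder = 1.0 - outer_test_size
    val = remainder * inner_test_size
    train = remainder - val
    return train, val, outer_test_size


# (outer test_size, inner test_size) pairs as used in the 3D cells
settings = [
    (0.30, 0.30 / 0.70),
    (0.33, 0.33 / 0.67),
    (0.35, 0.35 / 0.65),
    (0.35, 0.40 / 0.65),
    (0.40, 0.40 / 0.60),
    (0.40, 0.45 / 0.60),
    (0.45, 0.45 / 0.55),
]

for outer, inner in settings:
    train, val, test = nested_split_fractions(outer, inner)
    print(f"train={train:.2f}  val={val:.2f}  test={test:.2f}")
```

The train fractions obtained this way should line up with the `train_size` values passed to `write_indices_and_stats` (40, 34, 30, 25, 20, 15, 10).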

In [62]:
# Nominal 20/40/40 train/val/test split.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.40, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.40/0.60, random_state=np.random.RandomState(42))  # 0.40/0.60 of the remaining 60% -> validation is 40% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(288, 2867, 2501), (260, 2468, 2996), (238, 2624, 2961), (207, 2466, 3626), (231, 2461, 2837), (296, 2424, 2722), (299, 2558, 2693), (325, 2519, 2577), (315, 2425, 2558)]
[(array([250, 235, 111]), array([2508, 1777,  802]), array([1800, 1029,  572])), (array([228, 159, 109]), array([2104, 1643,  713]), array([2175, 1394,  661])), (array([205, 164,  71]), array([2041, 1282,  731]), array([2557, 1857,  934])), (array([163, 142,  28]), array([2082, 1397,  671]), array([2930, 2051, 1136])), (array([182, 123,  46]), array([1897, 1332,  685]), array([2351, 1616,  791])), (array([190, 156,  74]), array([1880,  958,  482]), array([2500, 1990, 1013])), (array([198, 174,  96]), array([2251, 1510,  759]), array([2194, 1299,  654])), (array([262, 187,  78]), array([1971, 1443,  596]), array([2111, 1382,  835])), (array([221, 140,  75]), array([2032, 1361,  612]), array([2130, 1573,  797]))]
[(13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27,

In [63]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [30]:
# Nominal 15/45/40 train/val/test split: the test share stays at 40%, and the validation share grows to 45%.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.40, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.60, random_state=np.random.RandomState(42))  # 0.45/0.60 of the remaining 60% -> validation is 45% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(135, 3999, 2501), (102, 3469, 2996), (101, 3495, 2961), (94, 3311, 3626), (104, 3269, 2837), (117, 3629, 2722), (98, 3763, 2693), (139, 3511, 2577), (119, 3499, 2558)]
[(array([124, 120,  68]), array([3421, 2424, 1039]), array([1800, 1029,  572])), (array([96, 62, 44]), array([2957, 2319,  992]), array([2175, 1394,  661])), (array([79, 58, 20]), array([2815, 1843, 1049]), array([2557, 1857,  934])), (array([77, 74, 16]), array([2795, 1823,  839]), array([2930, 2051, 1136])), (array([68, 54, 14]), array([2621, 1827, 1013]), array([2351, 1616,  791])), (array([62, 52, 27]), array([2926, 1601,  761]), array([2500, 1990, 1013])), (array([57, 50, 20]), array([3198, 2313, 1189]), array([2194, 1299,  654])), (array([111,  72,  24]), array([2757, 2023,  886]), array([2111, 1382,  835])), (array([68, 37, 23]), array([2900, 1968,  923]), array([2130, 1573,  797]))]
[(10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27)]
[(1

In [31]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=15
)

In [64]:
# Nominal 10/45/45 train/val/test split.
# Both splitters are reused across several calls, so each gets a RandomState instance rather than an int seed.
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=np.random.RandomState(42))
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # 0.45/0.55 of the remaining 55% -> validation is 45% of the total

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(25, 3774, 3832), (23, 3353, 4388), (31, 3228, 4456), (18, 3117, 5080), (28, 3321, 4109), (48, 3350, 3860), (58, 3527, 3814), (26, 3760, 3773), (30, 3571, 3907)]
[(array([24, 12,  6]), array([3217, 2413, 1276]), array([2913, 1889,  961])), (array([19,  9,  9]), array([2912, 2216,  954]), array([3257, 2099, 1026])), (array([31, 25, 10]), array([2504, 1671,  910]), array([3780, 2685, 1378])), (array([15, 12,  5]), array([2528, 1704,  855]), array([4179, 2871, 1417])), (array([23, 15,  8]), array([2675, 1909,  929]), array([3456, 2346, 1115])), (array([36,  5,  7]), array([2521, 1740,  754]), array([3408, 2754, 1397])), (array([58, 40, 14]), array([2817, 2279, 1129]), array([3012, 1820,  947])), (array([21, 17,  7]), array([3190, 2205, 1057]), array([2950, 1908, 1102])), (array([26, 18,  8]), array([2742, 1611,  777]), array([3316, 2495, 1219]))]
[(6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (5, 30, 31), (5, 30, 31), (6, 30, 31), (6, 30, 31)]
[(7, 32, 33), (6, 32, 33)

In [65]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)
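
The training-set sizes quoted for each truncation level are averages over the nine folds. A minimal sketch of that averaging is given below, using the `sizes` list printed for the last split above (copied verbatim from the output). It assumes each tuple is ordered as (train, val, test); the variable names other than `sizes` are illustrative.

```python
from statistics import mean

# Per-fold (train, val, test) sizes, copied from the printed output above
sizes = [
    (25, 3774, 3832), (23, 3353, 4388), (31, 3228, 4456),
    (18, 3117, 5080), (28, 3321, 4109), (48, 3350, 3860),
    (58, 3527, 3814), (26, 3760, 3773), (30, 3571, 3907),
]

# Keep only the training-set size of each fold
train_sizes = [train for train, _val, _test in sizes]

mean_train = mean(train_sizes)
print(f"mean training samples per fold: {mean_train:.1f}")
print(f"smallest / largest fold: {min(train_sizes)} / {max(train_sizes)}")
```

At this truncation level the smallest and largest folds differ by roughly a factor of three, so the averaged figure is only a coarse summary of the individual folds.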