# Data Splits Truncated

Here, we split in the same way as in `data_splits_truncated.ipynb`, but then we restrict the number of points in the training data, but on the new data set `2023-12-20` and with a few adjustments.

## Diasteromers
For the 0D split, this is not important, but for 1D/2D/3D, we want to define the groups such that diastereomers will always be in the same group. These splits will receive a `_dia` suffix.

## Synthetic data
For the 1D/2D/3D problems we use a synthetically amended data set. Splits of the synthetically ammended data set will receive a `_synthetic` suffix.

In [1]:
import pathlib
import sys

sys.path.append(str(pathlib.Path().resolve().parents[1]))

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit, ShuffleSplit

from src.definitions import DATA_DIR
from src.util.train_test_split import GroupShuffleSplitND
from util import write_indices_and_stats

In [2]:
# Load data
data_filename = "synferm_dataset_2023-12-20_39486records.csv"
data_name = data_filename.rsplit("_", maxsplit=1)[0]
df = pd.read_csv(DATA_DIR / "curated_data" / data_filename)
df.shape

(39486, 27)

In [3]:
df.head()

Unnamed: 0,I_long,M_long,T_long,product_A_smiles,I_smiles,M_smiles,T_smiles,reaction_smiles,reaction_smiles_atom_mapped,experiment_id,...,binary_H,scaled_A,scaled_B,scaled_C,scaled_D,scaled_E,scaled_F,scaled_G,scaled_H,major_A-C
0,2-Pyr003,Fused002,TerABT004,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1ccc(F)cc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:1](=[O:2])[c:15]1[cH:16][cH:18][...,56113,...,0,0.036021,0.003427,0.0,0.020975,0.002958,0.941981,0.914281,0.0,A
1,2-Pyr003,Fused002,TerABT007,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1cc(Br)ccc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:1](=[O:2])[c:15]1[cH:16][cH:18][...,56114,...,0,0.0,0.0,0.0,0.006159,0.364398,0.928851,1.106548,0.0,no_product
2,2-Pyr003,Fused002,TerABT013,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1cc(C(F)(F)F)ccc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:1](=[O:2])[c:15]1[cH:16][cH:18][...,56106,...,1,0.0,0.0,0.0,0.014212,2.16642,1.013596,0.537785,0.05686,no_product
3,2-Pyr003,Fused002,TerABT014,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,Nc1ccc(Cl)cc1S,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:1](=[O:2])[c:15]1[cH:16][cH:18][...,56112,...,0,0.028915,0.005039,0.0,0.015578,0.504057,0.992614,0.890646,0.0,A
4,2-Pyr003,Fused002,TerTH001,COc1ccc(CCOC(=O)N2C[C@H](NC(=O)c3cccc(Cl)n3)[C...,O=C(c1cccc(Cl)n1)[B-](F)(F)F.[K+],COc1ccc(CCOC(=O)N2C[C@@H]3NO[C@]4(OC5(CCCCC5)O...,[Cl-].[NH3+]NC(=S)c1ccccc1,O=C(c1cccc(Cl)n1)[B-](F)(F)F.COc1ccc(CCOC(=O)N...,F[B-](F)(F)[C:1](=[O:2])[c:11]1[cH:12][cH:14][...,56109,...,0,0.350061,0.643219,0.0,0.031689,0.613596,0.109309,0.439018,0.0,B


In [4]:
# M_long_dia will be to sort diastereomers into the same group on group shuffle splits
diastereomers = {
    "Mon001": "Mon087",
    "Mon003": "Mon078",
    "Mon011": "Mon088",
    "Mon013": "Mon074",
    "Mon014": "Mon090",
    "Mon015": "Mon076",
    "Mon016": "Mon096",
    "Mon017": "Mon075",
    "Mon019": "Mon091",
    "Mon020": "Mon077",
    "Mon080": "Mon010",
}
df["M_long_dia"] = df["M_long"].replace(diastereomers)

## 0D split

In [4]:
def split_0d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    for idx_train_val, idx_test in splitter.split(df):
        # inner split
        train, val = next(inner_splitter.split(idx_train_val))
        # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        # add to list
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_test]).to_numpy(),
            )
        )
        
    return indices, sizes, pos_class

In [7]:
splitter = ShuffleSplit(n_splits=9, test_size=0.1, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949), (31588, 3949, 3949)]


In [8]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=80, total_size=len(df), data_name=data_name
)

In [9]:
splitter = ShuffleSplit(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846), (15794, 11846, 11846)]


In [10]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=40, total_size=len(df), data_name=data_name
)

In [11]:
splitter = ShuffleSplit(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795), (7896, 15795, 15795)]


In [12]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=20, total_size=len(df), data_name=data_name
)

In [13]:
splitter = ShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769), (3948, 17769, 17769)]


In [14]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=10, total_size=len(df), data_name=data_name
)

In [15]:
splitter = ShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769), (1974, 19743, 17769)]


In [16]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=5, total_size=len(df), data_name=data_name
)

In [17]:
splitter = ShuffleSplit(n_splits=9, test_size=0.475, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.525, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756), (987, 19743, 18756)]


In [18]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=2.5, total_size=len(df), data_name=data_name
)

In [19]:
splitter = ShuffleSplit(n_splits=9, test_size=0.4875, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.5125, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250), (493, 19743, 19250)]


In [20]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=1.25, total_size=len(df), data_name=data_name
)

In [21]:
splitter = ShuffleSplit(n_splits=9, test_size=0.49375, random_state=42)
inner_splitter = ShuffleSplit(n_splits=1, test_size=0.5/0.50625, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class = split_0d(splitter, inner_splitter)

print(sizes)
#print(pos_class)

[(246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497), (246, 19743, 19497)]


In [22]:
write_indices_and_stats(
    indices, sizes, pos_class, split_dimension=0, save_indices=True, train_size=0.625, total_size=len(df), data_name=data_name
)

## 1D split

In [23]:
def split_1d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(list(range(len(df))), groups=df["I_long"]):
        # inner split
        train, val = next(inner_splitter.split(idx_train_val, groups=df["I_long"][idx_train_val]))
        # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        # add to list
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_test]).to_numpy(),
            )
        )
        unique_initiators.append((len(df['I_long'][idx_train].drop_duplicates()), len(df['I_long'][idx_val].drop_duplicates()), len(df['I_long'][idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long'][idx_train].drop_duplicates()), len(df['M_long'][idx_val].drop_duplicates()), len(df['M_long'][idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'][idx_train].drop_duplicates()), len(df['T_long'][idx_val].drop_duplicates()), len(df['T_long'][idx_test].drop_duplicates())))
    
    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators


In [24]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.1, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(30594, 4239, 4653), (31800, 3496, 4190), (32067, 3970, 3449), (29146, 6020, 4320), (30773, 4635, 4078), (31769, 4621, 3096), (29991, 4763, 4732), (29941, 5287, 4258), (31529, 4134, 3823)]
[(array([25099, 17658,  8945]), array([3310, 2427, 1029]), array([4003, 2719, 1295])), (array([26326, 18584,  9187]), array([2882, 2065, 1070]), array([3204, 2155, 1012])), (array([26642, 18885,  9589]), array([2994, 1966,  907]), array([2776, 1953,  773])), (array([23586, 16675,  7939]), array([5060, 3435, 1893]), array([3766, 2694, 1437])), (array([25179, 17867,  8771]), array([3798, 2555, 1228]), array([3435, 2382, 1270])), (array([26218, 18394,  9354]), array([3992, 2922, 1416]), array([2202, 1488,  499])), (array([24244, 17230,  8784]), array([4069, 2677, 1205]), array([4099, 2897, 1280])), (array([24231, 17060,  8581]), array([4637, 3185, 1601]), array([3544, 2559, 1087])), (array([25677, 18033,  8700]), array([3476, 2369, 1444]), array([3259, 2402, 1125]))]
[(53, 7, 7), (53, 7, 7), (53, 7, 7)

In [25]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)

In [26]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(16946, 10579, 11961), (16650, 11348, 11488), (15287, 11962, 12237), (15654, 11263, 12569), (15008, 12497, 11981), (15763, 12422, 11301), (13975, 11594, 13917), (16185, 11436, 11865), (13611, 13081, 12794)]
[(array([13934,  9989,  4972]), array([8287, 5691, 2713]), array([10191,  7124,  3584])), (array([13540,  9582,  4195]), array([9615, 6883, 3554]), array([9257, 6339, 3520])), (array([12544,  9167,  4223]), array([9740, 6725, 3717]), array([10128,  6912,  3329])), (array([13091,  9174,  4309]), array([9137, 6269, 3425]), array([10184,  7361,  3535])), (array([12718,  8990,  4466]), array([10117,  6762,  3229]), array([9577, 7052, 3574])), (array([13381,  9087,  4575]), array([10300,  7232,  4091]), array([8731, 6485, 2603])), (array([11120,  7938,  3696]), array([9364, 6646, 3422]), array([11928,  8220,  4151])), (array([12961,  9363,  4588]), array([9509, 6375, 3316]), array([9942, 7066, 3365])), (array([10784,  7711,  3910]), array([11206,  7818,  3663]), array([10422,  7275,  36

In [27]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [28]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(8704, 15962, 14820), (8967, 14367, 16152), (6729, 17054, 15703), (8331, 15147, 16008), (8709, 15112, 15665), (7877, 15172, 16437), (7150, 15626, 16710), (7516, 15668, 16302), (7990, 15190, 16306)]
[(array([7300, 5154, 2566]), array([12985,  8822,  4576]), array([12127,  8828,  4127])), (array([7738, 5605, 3104]), array([11323,  7960,  3316]), array([13351,  9239,  4849])), (array([5516, 3765, 2228]), array([14361, 10153,  5060]), array([12535,  8886,  3981])), (array([7199, 4972, 2122]), array([12236,  8545,  4758]), array([12977,  9287,  4389])), (array([7361, 5194, 2442]), array([12341,  8458,  4254]), array([12710,  9152,  4573])), (array([6671, 4540, 2021]), array([12571,  8685,  5082]), array([13170,  9579,  4166])), (array([5808, 4009, 2097]), array([12364,  9044,  4093]), array([14240,  9751,  5079])), (array([6237, 4239, 2075]), array([12457,  9027,  4716]), array([13718,  9538,  4478])), (array([6474, 4901, 2531]), array([12611,  8542,  4285]), array([13327,  9361,  4453]))]

In [29]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [30]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(3516, 18433, 17537), (3950, 17658, 17878), (4179, 17735, 17572), (3412, 17633, 18441), (4514, 16278, 18694), (2974, 17483, 19029), (2562, 17851, 19073), (3653, 16581, 19252), (3930, 16876, 18680)]
[(array([2835, 1990, 1021]), array([15116, 10253,  5379]), array([14461, 10561,  4869])), (array([3359, 2413, 1299]), array([14246, 10135,  4670]), array([14807, 10256,  5300])), (array([3553, 2423,  938]), array([14800, 10395,  5634]), array([14059,  9986,  4697])), (array([2837, 2001,  969]), array([14644, 10167,  5377]), array([14931, 10636,  4923])), (array([3945, 2808, 1619]), array([13175,  9092,  3846]), array([15292, 10904,  5804])), (array([2430, 1846, 1051]), array([14578,  9852,  4902]), array([15404, 11106,  5316])), (array([2114, 1475,  692]), array([14186, 10272,  4911]), array([16112, 11057,  5666])), (array([3032, 2043, 1160]), array([13134,  9442,  4884]), array([16246, 11319,  5225])), (array([3275, 2318, 1192]), array([13756,  9612,  4917]), array([15381, 10874,  5160]))]

In [31]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)

In [32]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1427, 20522, 17537), (2379, 19229, 17878), (1268, 20646, 17572), (1964, 19081, 18441), (1616, 19176, 18694), (1701, 18756, 19029), (1205, 19208, 19073), (1213, 19021, 19252), (2026, 18780, 18680)]
[(array([1079,  697,  337]), array([16872, 11546,  6063]), array([14461, 10561,  4869])), (array([2062, 1449,  726]), array([15543, 11099,  5243]), array([14807, 10256,  5300])), (array([1023,  721,  323]), array([17330, 12097,  6249]), array([14059,  9986,  4697])), (array([1693, 1194,  552]), array([15788, 10974,  5794]), array([14931, 10636,  4923])), (array([1344,  841,  524]), array([15776, 11059,  4941]), array([15292, 10904,  5804])), (array([1445, 1121,  630]), array([15563, 10577,  5323]), array([15404, 11106,  5316])), (array([1002,  726,  296]), array([15298, 11021,  5307]), array([16112, 11057,  5666])), (array([977, 581, 264]), array([15189, 10904,  5780]), array([16246, 11319,  5225])), (array([1709, 1186,  552]), array([15322, 10744,  5557]), array([15381, 10874,  5160]))]
[(

In [33]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=5
)

In [34]:
splitter = GroupShuffleSplit(n_splits=9, test_size=0.475, random_state=42)
inner_splitter = GroupShuffleSplit(n_splits=1, test_size=0.5/0.525, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_1d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(192, 21208, 18086), (937, 19753, 18796), (483, 21122, 17881), (593, 19437, 19456), (612, 19672, 19202), (472, 19473, 19541), (192, 19976, 19318), (474, 19207, 19805), (359, 19964, 19163)]
[(array([146,  77,  33]), array([17357, 11855,  6247]), array([14909, 10872,  4989])), (array([818, 622, 414]), array([16048, 11407,  5314]), array([15546, 10775,  5541])), (array([382, 265, 152]), array([17722, 12361,  6238]), array([14308, 10178,  4879])), (array([534, 385, 198]), array([16058, 11152,  5555]), array([15820, 11267,  5516])), (array([492, 256, 206]), array([16165, 11283,  5042]), array([15755, 11265,  6021])), (array([381, 236, 215]), array([16268, 11323,  5700]), array([15763, 11245,  5354])), (array([146,  77,  33]), array([15968, 11527,  5501]), array([16298, 11200,  5735])), (array([385, 237, 163]), array([15281, 10852,  5746]), array([16746, 11715,  5360])), (array([289, 197, 170]), array([16360, 11468,  5787]), array([15763, 11139,  5312]))]
[(1, 34, 32), (1, 34, 32), (1, 34, 

In [35]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=1, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=2.5
)

## 2D split

In [28]:
def split_2d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(df, groups=df[["I_long", "M_long_dia"]]):
        train, val = next(inner_splitter.split(df.iloc[idx_train_val], groups=df[["I_long", "M_long_dia"]].iloc[idx_train_val]))
        # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_test]).to_numpy(),
            )
        )
        unique_initiators.append((len(df['I_long'][idx_train].drop_duplicates()), len(df['I_long'][idx_val].drop_duplicates()), len(df['I_long'][idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long_dia'][idx_train].drop_duplicates()), len(df['M_long_dia'][idx_val].drop_duplicates()), len(df['M_long'][idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'][idx_train].drop_duplicates()), len(df['T_long'][idx_val].drop_duplicates()), len(df['T_long'][idx_test].drop_duplicates())))
    
    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators


In [29]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.1, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(24829, 380, 443), (23898, 372, 586), (24652, 524, 350), (25144, 355, 382), (24451, 532, 370), (24307, 569, 374), (24394, 510, 436), (24584, 428, 354), (26121, 315, 427)]
[(array([21110, 14989,  7529]), array([226, 151,  79]), array([330, 222,  97])), (array([19554, 13897,  6861]), array([291, 207,  79]), array([500, 339, 179])), (array([21199, 14546,  7847]), array([415, 377, 135]), array([175, 105,  29])), (array([20247, 14125,  6953]), array([290, 205, 112]), array([347, 251, 134])), (array([20146, 14119,  7235]), array([508, 388, 177]), array([211, 149,  38])), (array([20351, 14296,  7044]), array([427, 321, 137]), array([300, 212,  86])), (array([20039, 14278,  7141]), array([400, 271,  91]), array([348, 233, 137])), (array([20633, 14565,  7459]), array([330, 228, 109]), array([274, 174,  68])), (array([21579, 15189,  7679]), array([222, 155,  59]), array([376, 257, 145]))]
[(53, 6, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7)

In [30]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)

In [31]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.2, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.2/0.8, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(13513, 2038, 1429), (13183, 1681, 1564), (13985, 1783, 1391), (16017, 1054, 1671), (14312, 1864, 1276), (14568, 1565, 1478), (14216, 1571, 1496), (14343, 1808, 1274), (12745, 1933, 1841)]
[(array([10911,  7744,  3840]), array([1783, 1218,  591]), array([1175,  818,  334])), (array([11149,  8145,  3652]), array([1272,  856,  442]), array([1243,  873,  497])), (array([12287,  8563,  4634]), array([1482, 1089,  491]), array([939, 607, 258])), (array([13667,  9532,  5020]), array([819, 493, 243]), array([1282,  977,  376])), (array([12359,  8716,  4664]), array([1513,  978,  449]), array([801, 632, 241])), (array([11866,  8435,  4130]), array([1223,  814,  353]), array([1258,  917,  431])), (array([12353,  8791,  4471]), array([1205,  892,  373]), array([1048,  635,  370])), (array([12217,  9046,  4280]), array([1338,  885,  506]), array([1017,  648,  329])), (array([10371,  7141,  3507]), array([1658, 1294,  649]), array([1464,  972,  492]))]
[(39, 14, 14), (39, 14, 14), (39, 14, 14), (

In [32]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=60
)

In [33]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.3, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(6970, 3580, 3104), (6496, 3272, 3550), (5537, 3984, 3321), (5449, 3909, 3734), (6119, 4523, 2773), (5258, 3948, 4044), (5588, 3759, 3698), (6695, 3153, 3419), (5730, 3849, 4039)]
[(array([5754, 4044, 1979]), array([2879, 2032, 1012]), array([2572, 1825,  835])), (array([5244, 3759, 1499]), array([2764, 2020,  939]), array([2851, 1903, 1247])), (array([4294, 3081, 1501]), array([3704, 2771, 1577]), array([2393, 1527,  558])), (array([4457, 3111, 1538]), array([3285, 2233, 1198]), array([3009, 2196,  941])), (array([5308, 3604, 1843]), array([3892, 2861, 1395]), array([1865, 1320,  656])), (array([4241, 2803, 1414]), array([3051, 2076, 1094]), array([3418, 2590, 1104])), (array([4496, 3337, 1512]), array([3267, 2268, 1171]), array([2804, 1831,  972])), (array([5770, 4117, 1907]), array([2461, 1752,  949]), array([2675, 1765,  887])), (array([5005, 3478, 1709]), array([2968, 2226,  885]), array([3217, 2154, 1307]))]
[(26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), 

In [34]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [35]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.35/0.65, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(3608, 5570, 4147), (3046, 5523, 4743), (3267, 5414, 4519), (3460, 4822, 4878), (3805, 5536, 3736), (2852, 4347, 6127), (3662, 5032, 4399), (3557, 5137, 4744), (2926, 5292, 5409)]
[(array([2957, 2027,  969]), array([4578, 3169, 1661]), array([3379, 2566, 1148])), (array([2422, 1620,  756]), array([4764, 3575, 1403]), array([3738, 2502, 1580])), (array([2804, 2065,  965]), array([4602, 3324, 1969]), array([3357, 2185,  804])), (array([2927, 2022, 1087]), array([3813, 2693, 1320]), array([3952, 2813, 1182])), (array([3163, 2217, 1097]), array([5064, 3580, 1814]), array([2497, 1692,  812])), (array([2232, 1458,  948]), array([3477, 2414, 1117]), array([5194, 3819, 1612])), (array([3121, 2088,  851]), array([4128, 3201, 1698]), array([3318, 2186, 1083])), (array([2903, 2193, 1324]), array([4317, 3042, 1308]), array([3833, 2571, 1221])), (array([2266, 1760,  779]), array([4524, 3096, 1516]), array([4346, 2921, 1582]))]
[(19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), 

In [36]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=30
)

In [37]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.4/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1549, 7699, 5177), (1433, 6575, 6337), (1745, 6434, 6003), (1609, 6180, 6015), (1586, 7657, 5323), (1326, 5436, 7722), (1170, 8094, 5438), (1452, 6373, 6340), (1580, 6286, 6870)]
[(array([1306,  926,  486]), array([6449, 4466, 2342]), array([4082, 3026, 1335])), (array([1296,  895,  408]), array([5273, 3928, 1620]), array([5085, 3410, 1973])), (array([1505, 1089,  712]), array([5633, 3977, 2009]), array([4426, 2953, 1200])), (array([1238,  876,  261]), array([5144, 3483, 2162]), array([4846, 3401, 1506])), (array([1427, 1042,  391]), array([6593, 4522, 2465]), array([3794, 2611, 1266])), (array([1007,  641,  313]), array([4311, 2879, 1679]), array([6629, 4873, 2050])), (array([879, 669, 304]), array([7044, 4975, 2467]), array([4094, 2718, 1403])), (array([1205,  896,  355]), array([5192, 3882, 2004]), array([5157, 3434, 1761])), (array([1230, 1016,  412]), array([5344, 3521, 1841]), array([5414, 3714, 1996]))]
[(13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13

In [38]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [39]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.4, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.6, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(851, 9544, 5177), (836, 8030, 6337), (968, 8451, 6003), (1011, 7754, 6015), (856, 9267, 5323), (766, 6593, 7722), (642, 9719, 5438), (779, 7968, 6340), (1084, 7378, 6870)]
[(array([738, 508, 286]), array([7954, 5559, 2874]), array([4082, 3026, 1335])), (array([768, 521, 217]), array([6507, 4785, 2045]), array([5085, 3410, 1973])), (array([827, 584, 403]), array([7413, 5280, 2703]), array([4426, 2953, 1200])), (array([776, 562, 139]), array([6445, 4366, 2644]), array([4846, 3401, 1506])), (array([770, 564, 185]), array([8038, 5581, 2966]), array([3794, 2611, 1266])), (array([593, 408, 184]), array([5117, 3331, 1936]), array([6629, 4873, 2050])), (array([460, 373, 169]), array([8450, 5916, 2902]), array([4094, 2718, 1403])), (array([633, 467, 208]), array([6498, 4853, 2409]), array([5157, 3434, 1761])), (array([848, 720, 288]), array([6267, 4159, 2130]), array([5414, 3714, 1996]))]
[(10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27), (10, 3

In [40]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=15
)

In [41]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(378, 9534, 6774), (213, 8943, 8051), (176, 9354, 7380), (238, 8486, 7916), (430, 7714, 8087), (194, 7085, 10045), (342, 8498, 7458), (313, 8315, 8195), (456, 7243, 8907)]
[(array([295, 207,  91]), array([8050, 5475, 2944]), array([5318, 3932, 1764])), (array([132,  97,  44]), array([7725, 5567, 2361]), array([6470, 4349, 2458])), (array([166, 104,  53]), array([8184, 5929, 3292]), array([5502, 3667, 1484])), (array([196, 134,  67]), array([7154, 5088, 2753]), array([6434, 4479, 1931])), (array([400, 281, 145]), array([6551, 4691, 1959]), array([6201, 4343, 2414])), (array([145,  86,  69]), array([5462, 3731, 1843]), array([8643, 6275, 2879])), (array([295, 227, 115]), array([7086, 5073, 2368]), array([5698, 3785, 1960])), (array([258, 192,  55]), array([6688, 4825, 2570]), array([6788, 4572, 2322])), (array([372, 262, 113]), array([5871, 4190, 1897]), array([7206, 4980, 2666]))]
[(6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (

In [42]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)

In [43]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.475/0.55, random_state=np.random.RandomState(4))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(132, 10993, 6774), (159, 9055, 8051), (173, 10102, 7380), (127, 9683, 7916), (213, 9090, 8087), (104, 7812, 10045), (181, 9618, 7458), (189, 9258, 8195), (252, 8192, 8907)]
[(array([89, 55, 50]), array([9456, 6515, 3342]), array([5318, 3932, 1764])), (array([136,  82,  56]), array([7460, 5523, 2293]), array([6470, 4349, 2458])), (array([147, 119,  45]), array([8936, 6401, 3649]), array([5502, 3667, 1484])), (array([108,  80,  41]), array([7991, 5616, 3057]), array([6434, 4479, 1931])), (array([147, 104,  40]), array([7898, 5561, 2465]), array([6201, 4343, 2414])), (array([76, 37, 35]), array([6078, 4094, 2063]), array([8643, 6275, 2879])), (array([166, 120,  42]), array([7962, 5738, 2756]), array([5698, 3785, 1960])), (array([172, 128,  49]), array([7415, 5412, 2780]), array([6788, 4572, 2322])), (array([236, 163,  97]), array([6502, 4532, 1945]), array([7206, 4980, 2666]))]
[(4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 32, 31), (4, 3

In [44]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=7.5
)

In [45]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.5/0.55, random_state=np.random.RandomState(4))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(65, 11741, 6774), (103, 9786, 8051), (77, 11060, 7380), (94, 9938, 7916), (106, 9942, 8087), (47, 8163, 10045), (128, 10203, 7458), (102, 9935, 8195), (184, 8734, 8907)]
[(array([36, 21, 18]), array([10106,  6977,  3580]), array([5318, 3932, 1764])), (array([90, 56, 37]), array([8127, 5924, 2478]), array([6470, 4349, 2458])), (array([60, 48, 17]), array([9820, 7075, 3960]), array([5502, 3667, 1484])), (array([82, 65, 34]), array([8160, 5742, 3100]), array([6434, 4479, 1931])), (array([48, 30, 10]), array([8712, 6154, 2649]), array([6201, 4343, 2414])), (array([36, 19, 12]), array([6310, 4227, 2166]), array([8643, 6275, 2879])), (array([116,  84,  34]), array([8444, 6070, 2867]), array([5698, 3785, 1960])), (array([90, 62, 33]), array([8007, 5856, 2949]), array([6788, 4572, 2322])), (array([173, 128,  90]), array([6903, 4809, 2012]), array([7206, 4980, 2666]))]
[(3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31), (3, 33, 31)]
[(3, 31

In [46]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=5
)

In [47]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.48, random_state=42)
inner_splitter = GroupShuffleSplitND(n_splits=1, train_size=2, random_state=np.random.RandomState(4))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_2d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

RuntimeError: No samples found in train groups. Consider lowering test_size.

In [48]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=2, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=2
)

## 3D split

In [49]:
def split_3d(splitter, inner_splitter):
    indices = []
    sizes = []
    pos_class = []
    unique_initiators = []
    unique_monomers = []
    unique_terminators = []
    for idx_train_val, idx_test in splitter.split(df, groups=df[["I_long", "M_long_dia", "T_long"]]):
        train, val = next(inner_splitter.split(df.iloc[idx_train_val], groups=df[["I_long", "M_long_dia", "T_long"]].iloc[idx_train_val]))
        # use indices to index indices :P (we need to obtain indices referring to the original dataframe)
        idx_train = idx_train_val[train]
        idx_val = idx_train_val[val]
        indices.append((idx_train, idx_val, idx_test))
        sizes.append((len(idx_train), len(idx_val), len(idx_test)))
        pos_class.append(
            (np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_train]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_val]).to_numpy(), 
             np.sum(df[['binary_A', 'binary_B', 'binary_C']].loc[idx_test]).to_numpy(),
            )
        )
        unique_initiators.append((len(df['I_long'][idx_train].drop_duplicates()), len(df['I_long'][idx_val].drop_duplicates()), len(df['I_long'][idx_test].drop_duplicates())))
        unique_monomers.append((len(df['M_long_dia'][idx_train].drop_duplicates()), len(df['M_long_dia'][idx_val].drop_duplicates()), len(df['M_long'][idx_test].drop_duplicates())))
        unique_terminators.append((len(df['T_long'][idx_train].drop_duplicates()), len(df['T_long'][idx_val].drop_duplicates()), len(df['T_long'][idx_test].drop_duplicates())))

    return indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators


In [50]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.1, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.1/0.9, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(19418, 26, 51), (19138, 49, 48), (19435, 36, 41), (17766, 42, 49), (19544, 43, 26), (19795, 38, 39), (18709, 48, 68), (19819, 49, 46), (18482, 30, 72)]
[(array([16366, 12284,  6113]), array([21, 13,  9]), array([30, 17,  4])), (array([16422, 11788,  6334]), array([39, 12, 11]), array([32, 23,  5])), (array([16431, 11382,  5635]), array([27, 25, 21]), array([32, 15,  5])), (array([14998, 10083,  4987]), array([35, 27,  1]), array([32, 28, 17])), (array([16373, 11539,  5475]), array([29, 17, 12]), array([21, 16,  6])), (array([16226, 10824,  5663]), array([35, 16,  5]), array([30, 31, 12])), (array([15472, 10854,  4864]), array([44, 37, 30]), array([37, 31,  9])), (array([15976, 11722,  6187]), array([47, 29, 12]), array([38, 23,  9])), (array([15322, 10301,  5700]), array([29, 27,  5]), array([52, 42, 15]))]
[(53, 6, 7), (53, 7, 7), (53, 7, 7), (53, 7, 7), (53, 7, 4), (53, 7, 7), (53, 7, 7), (53, 6, 7), (53, 7, 7)]
[(48, 6, 7), (48, 5, 7), (48, 6, 8), (47, 6, 9), (48, 7, 8), (48, 7, 8

In [51]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=80
)

In [52]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.15, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.15/0.85, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(13044, 176, 198), (13673, 116, 154), (12485, 101, 166), (12365, 90, 178), (12743, 125, 130), (12558, 171, 147), (12465, 129, 184), (12636, 151, 156), (13471, 71, 192)]
[(array([11041,  7447,  3477]), array([109, 101,  51]), array([163, 107,  49])), (array([11148,  7834,  3795]), array([107,  62,  42]), array([111,  83,  45])), (array([10214,  7641,  3684]), array([86, 59, 29]), array([133,  85,  32])), (array([10101,  6619,  3215]), array([84, 67, 27]), array([137, 130,  69])), (array([10406,  6640,  3168]), array([98, 86, 42]), array([105,  84,  40])), (array([10527,  6794,  3410]), array([159, 109,  51]), array([87, 88, 37])), (array([10182,  7050,  3593]), array([110, 114,  47]), array([132,  73,  27])), (array([10593,  7994,  4212]), array([117,  65,  27]), array([124,  72,  34])), (array([11009,  8271,  4281]), array([53, 18,  7]), array([159, 134,  59]))]
[(46, 9, 11), (46, 10, 11), (46, 10, 11), (46, 9, 11), (46, 10, 11), (46, 10, 11), (46, 8, 11), (46, 10, 11), (46, 10, 11)]


In [53]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=70
)

In [54]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.2, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.2/0.8, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(7609, 364, 427), (8416, 258, 400), (7518, 449, 343), (6918, 344, 428), (8461, 304, 270), (8482, 298, 274), (8230, 333, 275), (7358, 427, 334), (8899, 203, 414)]
[(array([6357, 4807, 2161]), array([310, 219, 133]), array([337, 222,  99])), (array([6869, 4741, 2402]), array([222, 149,  95]), array([313, 237, 131])), (array([6147, 4271, 2290]), array([351, 256, 123]), array([292, 214,  95])), (array([5893, 3896, 1851]), array([255, 155,  98]), array([343, 300, 151])), (array([6912, 5079, 2659]), array([244, 147,  60]), array([201, 130,  60])), (array([7626, 4900, 2542]), array([182,  85,  62]), array([189, 185,  73])), (array([6880, 5290, 2694]), array([296, 224, 116]), array([182,  87,  32])), (array([6025, 4351, 2322]), array([385, 287, 148]), array([257, 165,  59])), (array([7651, 5361, 2746]), array([139,  87,  52]), array([309, 233,  80]))]
[(39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14), (39, 14, 14)]
[(36, 13, 15), (

In [55]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=60
)

In [56]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.25, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.25/0.75, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(4481, 613, 834), (4384, 672, 688), (4506, 580, 714), (3999, 554, 735), (5221, 590, 539), (5498, 554, 502), (4626, 672, 515), (5028, 584, 563), (4225, 638, 736)]
[(array([3608, 2827, 1349]), array([517, 319, 180]), array([697, 491, 201])), (array([3674, 2726, 1073]), array([548, 266, 177]), array([553, 435, 230])), (array([3458, 2533, 1416]), array([509, 348, 149]), array([589, 373, 160])), (array([3470, 2729, 1498]), array([408, 148,  56]), array([581, 475, 213])), (array([4275, 3253, 1292]), array([474, 321, 248]), array([445, 268, 124])), (array([4746, 2687, 1420]), array([486, 302, 116]), array([344, 334, 127])), (array([3870, 3192, 1655]), array([584, 394, 132]), array([359, 184,  66])), (array([4388, 2779, 1668]), array([449, 386, 182]), array([443, 298, 113])), (array([3440, 2135,  934]), array([567, 394, 307]), array([598, 498, 191]))]
[(33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17), (33, 17, 17)]
[(30, 15, 19), (

In [57]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=50
)

In [58]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.3, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.3/0.7, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(2569, 922, 1372), (2341, 1046, 1128), (2123, 1035, 1266), (2308, 772, 1267), (2287, 1221, 968), (2466, 1176, 779), (2176, 1221, 788), (2831, 967, 1110), (2053, 1100, 1242)]
[(array([2291, 1707,  841]), array([664, 456, 259]), array([1180,  809,  378])), (array([1963, 1592,  820]), array([931, 421, 224]), array([848, 637, 308])), (array([1714, 1394,  596]), array([860, 515, 339]), array([1071,  743,  340])), (array([2062, 1236,  643]), array([622, 474, 241]), array([1030,  838,  367])), (array([1738, 1400,  736]), array([1043,  719,  388]), array([784, 442, 189])), (array([2191, 1624,  775]), array([966, 473, 245]), array([585, 490, 203])), (array([1726, 1343,  606]), array([1066,  836,  393]), array([569, 333, 122])), (array([2445, 1822, 1013]), array([700, 437, 247]), array([884, 625, 256])), (array([1554,  830,  406]), array([952, 664, 356]), array([1045,  891,  387]))]
[(26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), (26, 20, 21), 

In [59]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=40
)

In [60]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.33, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.33/0.67, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1438, 1617, 1633), (1417, 1489, 1458), (1359, 1409, 1687), (1173, 1072, 1730), (1500, 1491, 1380), (1598, 1493, 1033), (1237, 1807, 1168), (983, 1983, 1457), (988, 1830, 1568)]
[(array([1168,  847,  327]), array([1334,  899,  589]), array([1350, 1001,  452])), (array([1228,  957,  522]), array([1293,  619,  335]), array([1098,  903,  377])), (array([1159,  876,  449]), array([1140,  819,  438]), array([1339,  885,  407])), (array([902, 615, 260]), array([981, 611, 403]), array([1413, 1115,  480])), (array([1256,  986,  580]), array([1171,  941,  480]), array([1159,  642,  271])), (array([1442,  830,  510]), array([1267,  892,  378]), array([770, 640, 229])), (array([984, 690, 405]), array([1545, 1156,  516]), array([876, 571, 194])), (array([847, 603, 293]), array([1659, 1083,  745]), array([1175,  834,  363])), (array([858, 558, 435]), array([1427,  867,  345]), array([1298, 1133,  496]))]
[(22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 23), (22, 22, 

In [61]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=34
)

In [62]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.35/0.65, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(1011, 1763, 1770), (879, 1781, 1831), (788, 1922, 1976), (904, 1433, 1912), (886, 1950, 1572), (1036, 2040, 1295), (1136, 1672, 1406), (948, 1865, 1673), (974, 1769, 1777)]
[(array([817, 655, 260]), array([1448,  746,  398]), array([1479, 1129,  523])), (array([769, 568, 306]), array([1478,  816,  403]), array([1436, 1148,  564])), (array([636, 406, 184]), array([1641, 1328,  793]), array([1575, 1082,  483])), (array([704, 426, 223]), array([1313,  970,  539]), array([1565, 1254,  501])), (array([661, 501, 279]), array([1710, 1344,  649]), array([1269,  735,  319])), (array([926, 616, 232]), array([1718, 1113,  709]), array([1004,  841,  295])), (array([1084,  941,  474]), array([1238,  945,  472]), array([1058,  623,  229])), (array([795, 683, 324]), array([1500,  815,  513]), array([1329,  919,  378])), (array([770, 528, 307]), array([1441,  900,  506]), array([1462, 1208,  514]))]
[(19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (19, 24, 24), (1

In [63]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=30
)

In [64]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.35, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.40/0.65, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(460, 2978, 1770), (546, 2520, 1831), (478, 2618, 1976), (509, 2168, 1912), (533, 2788, 1572), (581, 2937, 1295), (709, 2415, 1406), (553, 2622, 1673), (570, 2529, 1777)]
[(array([363, 283, 103]), array([2508, 1418,  674]), array([1479, 1129,  523])), (array([470, 390, 183]), array([2162, 1149,  581]), array([1436, 1148,  564])), (array([386, 248, 114]), array([2248, 1792, 1068]), array([1575, 1082,  483])), (array([394, 201, 122]), array([1941, 1511,  826]), array([1565, 1254,  501])), (array([407, 294, 157]), array([2356, 1860,  930]), array([1269,  735,  319])), (array([516, 337, 105]), array([2477, 1596, 1042]), array([1004,  841,  295])), (array([678, 579, 300]), array([1810, 1398,  692]), array([1058,  623,  229])), (array([479, 408, 216]), array([2083, 1277,  736]), array([1329,  919,  378])), (array([481, 347, 203]), array([2037, 1220,  649]), array([1462, 1208,  514]))]
[(16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27, 24), (16, 27,

In [65]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=25
)

In [66]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.40, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.40/0.60, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(291, 2590, 2743), (188, 2822, 2831), (350, 2257, 2903), (240, 2188, 2765), (265, 2968, 2506), (252, 2875, 2318), (204, 3124, 1925), (280, 2765, 2550), (386, 1856, 2793)]
[(array([217, 159, 105]), array([2304, 1427,  678]), array([2262, 1750,  717])), (array([165,  72,  35]), array([2516, 1891,  891]), array([2081, 1581,  769])), (array([302, 235,  99]), array([1826, 1384,  795]), array([2348, 1519,  720])), (array([201, 110,  62]), array([1866, 1279,  839]), array([2212, 1826,  702])), (array([235, 164, 109]), array([2220, 1767,  767]), array([2122, 1303,  564])), (array([190, 140,  73]), array([2552, 1552,  785]), array([1770, 1413,  559])), (array([160, 138,  50]), array([2666, 2108, 1287]), array([1475,  848,  315])), (array([259, 187, 129]), array([2195, 1431,  853]), array([2050, 1499,  586])), (array([351, 218, 140]), array([1511,  903,  507]), array([2268, 1916,  721]))]
[(13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (13, 27, 27), (12, 27, 27), (13, 27, 27), (13, 27,

In [67]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=20
)

In [68]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.40, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.60, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(140, 3569, 2743), (87, 3727, 2831), (161, 3313, 2903), (112, 2909, 2765), (127, 3955, 2506), (105, 3967, 2318), (103, 4387, 1925), (100, 3851, 2550), (189, 2648, 2793)]
[(array([100,  63,  54]), array([3165, 2062,  986]), array([2262, 1750,  717])), (array([77, 31, 13]), array([3326, 2409, 1192]), array([2081, 1581,  769])), (array([132,  96,  34]), array([2757, 2113, 1202]), array([2348, 1519,  720])), (array([93, 54, 28]), array([2509, 1706, 1064]), array([2212, 1826,  702])), (array([109,  74,  54]), array([3056, 2425, 1178]), array([2122, 1303,  564])), (array([86, 68, 35]), array([3456, 2126, 1027]), array([1770, 1413,  559])), (array([78, 61, 24]), array([3767, 3069, 1794]), array([1475,  848,  315])), (array([89, 58, 42]), array([3135, 2159, 1299]), array([2050, 1499,  586])), (array([169, 100,  71]), array([2200, 1342,  728]), array([2268, 1916,  721]))]
[(10, 30, 27), (10, 30, 27), (6, 30, 27), (10, 30, 27), (10, 30, 27), (8, 30, 27), (10, 30, 27), (10, 30, 27), (10, 30, 27)

In [69]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=15
)

In [70]:
splitter = GroupShuffleSplitND(n_splits=9, test_size=0.45, random_state=np.random.RandomState(42))  # here, we reuse the outer splitter as well, so we use RandomState
inner_splitter = GroupShuffleSplitND(n_splits=1, test_size=0.45/0.55, random_state=np.random.RandomState(42))  # we use a RandomState instance, not an int, because we will reuse this splitter several times

indices, sizes, pos_class, unique_initiators, unique_monomers, unique_terminators = split_3d(splitter, inner_splitter)

print(sizes)
print(pos_class)
print(unique_initiators)
print(unique_monomers)
print(unique_terminators)

[(24, 3863, 4011), (13, 3576, 3987), (24, 3522, 4242), (14, 3279, 3920), (37, 3629, 3743), (29, 4020, 3334), (38, 4555, 2792), (23, 3560, 3876), (38, 3362, 3796)]
[(array([23, 16, 10]), array([3191, 1902, 1010]), array([3223, 2540, 1143])), (array([11,  7,  5]), array([3185, 2091, 1040]), array([2977, 2220, 1155])), (array([15, 15,  3]), array([2949, 2164, 1212]), array([3527, 2239, 1102])), (array([13,  9,  3]), array([2764, 1812, 1129]), array([2974, 2368,  904])), (array([30, 22, 12]), array([3024, 2367, 1227]), array([3062, 1801,  803])), (array([29, 21, 12]), array([3284, 2244, 1086]), array([2652, 2016,  847])), (array([18, 11,  1]), array([4160, 3565, 2285]), array([2206, 1305,  486])), (array([14,  5,  5]), array([3019, 2234, 1370]), array([3215, 2292,  958])), (array([21, 25,  9]), array([2867, 1539, 1001]), array([3086, 2572, 1057]))]
[(6, 30, 31), (5, 30, 31), (6, 30, 31), (5, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31), (6, 30, 31)]
[(6, 28, 31), (5, 27, 32)

In [71]:
write_indices_and_stats(
    indices, 
    sizes, 
    pos_class,
    total_size=len(df),
    data_name=data_name,
    split_dimension=3, 
    save_indices=True, 
    n_initiators=unique_initiators, 
    n_monomers=unique_monomers, 
    n_terminators=unique_terminators, 
    train_size=10
)