In [1]:
import numpy as np
import h5py
from sklearn.utils import shuffle
import json

In [2]:
def get_trainvalid_indexes(path_h5, valid_split=0.2, nfolds=5, seed=None):
    """Yield shuffled test/train/validation index splits for k-fold cross validation.

    The number of samples ``n`` is taken as ``len(f['labels'])`` from the HDF5
    file. The indexes ``0 .. n-1`` are shuffled once and cut into ``nfolds``
    nearly equal test folds. For every fold, the indexes that are not in the
    test fold are divided at random into a validation and a training subset.

    Parameters
    ----------
    path_h5 : str
        Path of an HDF5 file that contains a ``'labels'`` entry.
    valid_split : float, optional
        Fraction (between 0 and 1) of the non-test indexes assigned to the
        validation subset; the resulting count is rounded down.
    nfolds : int, optional
        Number of cross-validation folds (at least 1).
    seed : optional
        Seed passed to ``np.random.default_rng``. Use a fixed value to get
        the same splits on every run; ``None`` gives a different split each
        time. The global ``numpy``/``random`` seeds are not consulted.

    Yields
    ------
    tuple of (list, list, list)
        ``(test_indexes, train_indexes, valid_indexes)`` for one fold. Each
        is a shuffled list of Python ints (JSON serialisable); the three
        lists are disjoint and together contain every index exactly once.

    Raises
    ------
    ValueError
        If ``nfolds`` is smaller than 1 or ``valid_split`` is outside [0, 1].
    """
    if nfolds < 1:
        raise ValueError(f'nfolds must be at least 1, got {nfolds}')
    if not 0 <= valid_split <= 1:
        raise ValueError(f'valid_split must be between 0 and 1, got {valid_split}')

    # Only the number of samples is needed; the file is closed right away.
    with h5py.File(path_h5, 'r') as f:
        n = len(f['labels'])

    # A single generator drives every shuffle so that `seed` fully
    # determines the produced splits.
    rng = np.random.default_rng(seed)
    all_indexes = rng.permutation(n)

    for test_fold in np.array_split(all_indexes, nfolds):
        # Everything outside the test fold, computed in one vectorised call
        # instead of removing elements from a list one at a time.
        remaining = np.setdiff1d(all_indexes, test_fold, assume_unique=True)
        remaining = rng.permutation(remaining)

        # The first `num_valid` shuffled entries form the validation subset,
        # the rest is the training subset.
        num_valid = int(len(remaining) * valid_split)

        yield (rng.permutation(test_fold).tolist(),
               remaining[num_valid:].tolist(),
               remaining[:num_valid].tolist())


In [3]:
%%script echo 'Remove this line to execute the code in this cell.'

# Build every cross-validation fold and store each one in its own JSON file.
path_h5 = './dataset.h5'
nfolds = 5

# The generator is lazy: nothing is read from disk until the loop starts.
fold_splits = get_trainvalid_indexes(path_h5, valid_split = 0.2, nfolds = nfolds)

fold = 0
for test_indexes, train_indexes, valid_indexes in fold_splits:
    # One output file per fold, numbered from zero.
    outfname = f'./cv_fold_{fold}_dset_splits.json'
    subsets_dict = dict(test_indexes = test_indexes,
                        train_indexes = train_indexes,
                        valid_indexes = valid_indexes)

    with open(outfname, 'w') as f:
        json.dump(subsets_dict, f)

    fold = fold + 1


Remove this line to execute the code in this cell.


In [4]:
# Load the splits stored for fold 0 and show, for every subset,
# how many indexes it holds followed by the indexes themselves.
path_json = './cv_fold_0_dset_splits.json'
with open(path_json, 'r') as f:
    d = json.loads(f.read())

# Iterating over the dict visits its keys in the order stored in the file.
for key in d:
    print(f'{len(d[key])} items in {key}:\n\n{d[key]}\n')


292 items in test_indexes:

[1423, 475, 898, 1107, 894, 399, 55, 180, 156, 1409, 206, 258, 436, 571, 207, 1338, 433, 505, 74, 279, 553, 314, 653, 1373, 16, 680, 405, 938, 637, 817, 727, 682, 663, 979, 265, 753, 708, 1355, 499, 82, 749, 1095, 370, 784, 361, 65, 1234, 1015, 681, 329, 171, 1270, 1165, 911, 186, 785, 1046, 984, 1295, 570, 742, 1332, 698, 52, 754, 737, 706, 1410, 995, 916, 818, 1144, 177, 473, 1357, 641, 412, 1045, 506, 871, 910, 129, 27, 796, 723, 759, 252, 100, 886, 33, 1001, 1456, 488, 1080, 363, 424, 1318, 88, 686, 1442, 351, 168, 621, 1143, 631, 194, 301, 66, 793, 32, 661, 1238, 961, 1009, 1133, 605, 259, 1161, 940, 1126, 809, 816, 366, 988, 651, 890, 80, 515, 164, 1111, 1278, 62, 372, 519, 719, 665, 909, 167, 1038, 1360, 1211, 196, 1233, 1261, 697, 1110, 1227, 946, 434, 1378, 1245, 959, 778, 632, 91, 736, 336, 906, 1353, 555, 408, 1057, 7, 303, 493, 509, 1189, 608, 813, 756, 1071, 118, 579, 1069, 1117, 1079, 1134, 240, 774, 1150, 895, 837, 1124, 1387, 1055, 384, 783, 