In [10]:
import numpy as np
from dataclasses import dataclass


@dataclass
class DataLabelsPair:
    data: np.ndarray
    labels: np.ndarray


@dataclass
class SplittedDataset:
    train: DataLabelsPair
    val: DataLabelsPair
    test: DataLabelsPair


def split_train_val_test(
    data: np.ndarray, labels: np.ndarray, train_perc: float, val_perc: float
) -> SplittedDataset:
    assert train_perc + val_perc < 1.0
    assert len(data) == len(labels)

    ds_size = len(data)
    random_indecies = np.arange(ds_size)
    np.random.shuffle(random_indecies)

    train_indecies, validation_indecies, test_indecies, _ = np.split(
        random_indecies,
        [int(ds_size * train_perc), int(ds_size * (train_perc + val_perc)), ds_size],
    )

    train_data, train_labels = data[train_indecies], labels[train_indecies]
    validation_data, validation_labels = data[validation_indecies], labels[validation_indecies]
    test_data, test_labels = data[test_indecies], labels[test_indecies]

    return SplittedDataset(
        DataLabelsPair(train_data, train_labels),
        DataLabelsPair(validation_data, validation_labels),
        DataLabelsPair(test_data, test_labels),
    )


In [11]:
ds_size = 1000
data = np.arange(0, ds_size, 1)
np.random.shuffle(data)
labels = (np.random.random(ds_size) > 0.5).astype(np.int32)
print(data)
print(labels)

[310 238 239 291 562 335 635 758 866 273 967 681 786 798 730 587 606 615
 686 612 879 234 465 856 319 619 460 480 781 360 466 552 318 401 675 389
 924 764 825 707 538 607 325 230 838 645 722 236 800 177 433 870 787 705
 560 372 775 581 130 317 394 717 992  13 962 513 228  47  95 646 295 280
 778 316 102 267 828 422 636 407 338 700 439 263  87 845 892 634 631 734
 881 996 121 350 426 698 936  21 434 756 159 365 926  41 322 843 537  16
 427 573 998 224 655 508 839 356 366 473 297 309 351 942 604 766 885 672
 142 627 404 390 971 999 986 285 875 890 720 725 204 180 662 960 139 482
  64 392 211 352 421 584  85  36 157 938 233 406 243 554 476 557 776 979
 832 894 594 511 678 302 391 897 336 549 811 135 723 994 525 282 431  68
 977 303 164 546 333 810 659 424 574 362 544  44 181 328 783 512 483 171
 370 252 616 475 453  66 754 702 611 923  98 966 388 974 287 623 782 694
 753 516 432  99 487 455 760 899 836 380 527 588 154 541  83 841 489  67
 704 173   1 518 399 744 821  77 457 945 521  33 45

In [12]:
ds = split_train_val_test(data, labels, 0.8, 0.1)

In [13]:
ds.train.data

array([ 32,   6, 688, 602, 201, 987, 195, 779, 317, 475, 227, 321, 305,
       940, 210, 926, 389, 151, 984, 910, 289, 885, 963, 888, 997, 893,
       941, 371, 384, 164, 146, 225,  31, 404,  10, 961, 250, 558, 584,
       762, 176, 847, 659, 966,  14, 799, 245, 817,  73, 198, 969,  67,
       213,  80, 916, 788, 886, 592, 951, 474, 800, 303, 276, 611, 820,
       331, 131, 933, 459, 266, 741, 537, 249,   7,   1, 122, 676, 152,
       866, 792,  77,  36, 891, 878, 559, 810,  33, 380, 237, 535, 852,
       757, 578, 113, 601, 929, 108,  30, 896, 115, 939, 796, 605, 767,
       153, 135,  44, 718, 795, 875, 493, 254, 772, 531, 743, 702, 831,
       318, 824, 848, 964, 612, 548, 674, 387, 804, 911, 224, 557, 398,
       408, 416, 945, 337, 999, 754, 618, 262, 322, 163, 299, 465, 309,
       206,  57, 707, 752, 711, 186, 652, 405, 840, 739, 295, 693, 445,
       713, 838, 308, 616, 626, 928, 600, 447, 534, 271, 846,  34, 188,
       957, 479, 429, 178, 572, 348, 532, 160, 236, 623, 464, 62