In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split

In [2]:
print("Loading MNIST dataset...")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print(f"Original training data shape: {x_train.shape}")
print(f"Original training labels shape: {y_train.shape}")
print(f"Original test data shape: {x_test.shape}")
print(f"Original test labels shape: {y_test.shape}")
print(f"Data type: {x_train.dtype}")
print(f"Pixel value range: [{x_train.min()}, {x_train.max()}]")

Loading MNIST dataset...
Original training data shape: (60000, 28, 28)
Original training labels shape: (60000,)
Original test data shape: (10000, 28, 28)
Original test labels shape: (10000,)
Data type: uint8
Pixel value range: [0, 255]


In [3]:
# Normalize pixel values to [0, 1] range
x_train_norm = x_train.astype('float32') / 255.0
x_test_norm = x_test.astype('float32') / 255.0

In [4]:
# Combine training and test data for balanced splitting
x_combined = np.concatenate([x_train_norm, x_test_norm], axis=0)
y_combined = np.concatenate([y_train, y_test], axis=0)

print(f"Combined data shape: {x_combined.shape}")
print(f"Combined labels shape: {y_combined.shape}")

# Shuffle the combined data
from sklearn.utils import shuffle
x_combined, y_combined = shuffle(x_combined, y_combined, random_state=42)
print("Data shuffled successfully!")

Combined data shape: (70000, 28, 28)
Combined labels shape: (70000,)
Data shuffled successfully!


In [6]:

from sklearn.model_selection import StratifiedKFold

# Five equal, disjoint, stratified parts
skf = StratifiedKFold(n_splits=20, shuffle=True, random_state=42)

parts_x, parts_y, fold_indices = [], [], []
for fold, (train_idx, test_idx) in enumerate(skf.split(x_combined, y_combined), start=1):
    parts_x.append(x_combined[test_idx])
    parts_y.append(y_combined[test_idx])
    fold_indices.append(test_idx)
    print(f"Part {fold} total samples: {parts_x[-1].shape[0]}")

# Optional: assign to named variables if you prefer
x_part1_all, y_part1_all = parts_x[0], parts_y[0]
x_part2_all, y_part2_all = parts_x[1], parts_y[1]
x_part3_all, y_part3_all = parts_x[2], parts_y[2]
x_part4_all, y_part4_all = parts_x[3], parts_y[3]
x_part5_all, y_part5_all = parts_x[4], parts_y[4]
x_part6_all, y_part6_all = parts_x[5], parts_y[5]
x_part7_all, y_part7_all = parts_x[6], parts_y[6]
x_part8_all, y_part8_all = parts_x[7], parts_y[7]
x_part9_all, y_part9_all = parts_x[8], parts_y[8]
x_part10_all, y_part10_all = parts_x[9], parts_y[9]
x_part11_all, y_part11_all = parts_x[10], parts_y[10]
x_part12_all, y_part12_all = parts_x[11], parts_y[11]
x_part13_all, y_part13_all = parts_x[12], parts_y[12]
x_part14_all, y_part14_all = parts_x[13], parts_y[13]
x_part15_all, y_part15_all = parts_x[14], parts_y[14]
x_part16_all, y_part16_all = parts_x[15], parts_y[15]
x_part17_all, y_part17_all = parts_x[16], parts_y[16]
x_part18_all, y_part18_all = parts_x[17], parts_y[17]
x_part19_all, y_part19_all = parts_x[18], parts_y[18]
x_part20_all, y_part20_all = parts_x[19], parts_y[19]

# Sanity check (coverage and no overlap)
concat_idx = np.concatenate(fold_indices)
print(f"Covered: {np.unique(concat_idx).size} of {x_combined.shape[0]} samples")


Part 1 total samples: 3500
Part 2 total samples: 3500
Part 3 total samples: 3500
Part 4 total samples: 3500
Part 5 total samples: 3500
Part 6 total samples: 3500
Part 7 total samples: 3500
Part 8 total samples: 3500
Part 9 total samples: 3500
Part 10 total samples: 3500
Part 11 total samples: 3500
Part 12 total samples: 3500
Part 13 total samples: 3500
Part 14 total samples: 3500
Part 15 total samples: 3500
Part 16 total samples: 3500
Part 17 total samples: 3500
Part 18 total samples: 3500
Part 19 total samples: 3500
Part 20 total samples: 3500
Covered: 70000 of 70000 samples


In [7]:
original_train_ratio = 60000 / (60000 + 10000)
test_ratio = 1 - original_train_ratio  # ~0.142857

x_parts_train, x_parts_test, y_parts_train, y_parts_test = [], [], [], []

for i, (px, py) in enumerate(zip(parts_x, parts_y), start=1):
    x_tr, x_te, y_tr, y_te = train_test_split(
        px, py,
        test_size=test_ratio,
        stratify=py,
        random_state=42  # keep reproducible
    )
    x_parts_train.append(x_tr); x_parts_test.append(x_te)
    y_parts_train.append(y_tr); y_parts_test.append(y_te)
    print(f"Part {i}: train={x_tr.shape[0]}, test={x_te.shape[0]}")

Part 1: train=2999, test=501
Part 2: train=2999, test=501
Part 3: train=2999, test=501
Part 4: train=2999, test=501
Part 5: train=2999, test=501
Part 6: train=2999, test=501
Part 7: train=2999, test=501
Part 8: train=2999, test=501
Part 9: train=2999, test=501
Part 10: train=2999, test=501
Part 11: train=2999, test=501
Part 12: train=2999, test=501
Part 13: train=2999, test=501
Part 14: train=2999, test=501
Part 15: train=2999, test=501
Part 16: train=2999, test=501
Part 17: train=2999, test=501
Part 18: train=2999, test=501
Part 19: train=2999, test=501
Part 20: train=2999, test=501


In [8]:
import os
save_dir = "mnist_split_data_20"
os.makedirs(save_dir, exist_ok=True)

for i in range(20):
    out_path = os.path.join(save_dir, f"mnist_part{i+1}.npz")
    np.savez_compressed(
        out_path,
        x_train=x_parts_train[i],
        y_train=y_parts_train[i],
        x_test=x_parts_test[i],
        y_test=y_parts_test[i],
    )
    size_mb = os.path.getsize(out_path) / (1024*1024)
    print(f"Saved Part {i+1} to {out_path} ({size_mb:.2f} MB)")

Saved Part 1 to mnist_split_data_20\mnist_part1.npz (0.88 MB)
Saved Part 2 to mnist_split_data_20\mnist_part2.npz (0.88 MB)
Saved Part 3 to mnist_split_data_20\mnist_part3.npz (0.88 MB)
Saved Part 4 to mnist_split_data_20\mnist_part4.npz (0.88 MB)
Saved Part 5 to mnist_split_data_20\mnist_part5.npz (0.88 MB)
Saved Part 6 to mnist_split_data_20\mnist_part6.npz (0.88 MB)
Saved Part 7 to mnist_split_data_20\mnist_part7.npz (0.88 MB)
Saved Part 8 to mnist_split_data_20\mnist_part8.npz (0.88 MB)
Saved Part 9 to mnist_split_data_20\mnist_part9.npz (0.88 MB)
Saved Part 10 to mnist_split_data_20\mnist_part10.npz (0.88 MB)
Saved Part 11 to mnist_split_data_20\mnist_part11.npz (0.88 MB)
Saved Part 12 to mnist_split_data_20\mnist_part12.npz (0.88 MB)
Saved Part 13 to mnist_split_data_20\mnist_part13.npz (0.88 MB)
Saved Part 14 to mnist_split_data_20\mnist_part14.npz (0.88 MB)
Saved Part 15 to mnist_split_data_20\mnist_part15.npz (0.88 MB)
Saved Part 16 to mnist_split_data_20\mnist_part16.npz (0.8