In [1]:
# 📄 src/preprocessing/split_dataset.py

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import os


def _concat_and_save(path, parts):
    """Concatenate *parts* along axis 0, save the result to *path*, return its shape.

    *parts* is emptied once the merged array exists, so the per-chunk
    arrays can be freed before the next (possibly large) merge is built.
    """
    merged = np.concatenate(parts, axis=0)
    parts.clear()
    np.save(path, merged)
    return merged.shape


def stream_split_save_lazy(x_path, y_path, save_dir, label_name, test_size=0.2, chunk_size=100, random_state=42):
    """Split an (X, y) pair of ``.npy`` files into train/validation sets, chunk by chunk.

    X is opened memory-mapped (read-only) and read ``chunk_size`` samples at
    a time; each chunk is converted to ``float32`` and split independently
    with ``train_test_split`` (stratified on that chunk's labels, using the
    same ``random_state`` for every chunk). The per-chunk pieces are then
    merged and written to ``save_dir`` as::

        X_train_<label_name>.npy   y_train_<label_name>.npy
        X_val_<label_name>.npy     y_val_<label_name>.npy

    NOTE: only the *reading* is chunked. All split pieces are held in memory
    until the final merge, so the whole dataset must still fit in RAM.

    Args:
        x_path: Path to the feature array (``.npy``); first axis is samples.
        y_path: Path to the label array (``.npy``), loaded fully into memory.
        save_dir: Output directory; created if it does not exist.
        label_name: Suffix used in the four output file names.
        test_size: Passed to ``train_test_split`` for every chunk.
        chunk_size: Number of samples read and split at a time (>= 1).
        random_state: Passed to ``train_test_split`` for every chunk.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1, if X and y do not
            contain the same number of samples, or if no chunk could be
            processed (empty input, or every chunk was skipped).
            ``train_test_split`` may also raise ``ValueError`` for a chunk it
            cannot split/stratify (e.g. a very small final chunk); that
            error is propagated unchanged.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    os.makedirs(save_dir, exist_ok=True)

    X_memmap = np.load(x_path, mmap_mode='r')
    y_array = np.load(y_path)
    total = X_memmap.shape[0]

    # Fail fast on misaligned inputs instead of silently ignoring extra
    # labels or crashing deep inside the loop on a short label array.
    if y_array.shape[0] != total:
        raise ValueError(
            f"X and y have different numbers of samples: "
            f"{total} in {x_path!r} vs {y_array.shape[0]} in {y_path!r}"
        )

    train_X_path = os.path.join(save_dir, f'X_train_{label_name}.npy')
    val_X_path = os.path.join(save_dir, f'X_val_{label_name}.npy')
    train_y_path = os.path.join(save_dir, f'y_train_{label_name}.npy')
    val_y_path = os.path.join(save_dir, f'y_val_{label_name}.npy')

    # Final sizes are not known up front (chunks may be skipped), so the
    # split pieces are collected in lists and merged at the end.
    X_train_list, y_train_list, X_val_list, y_val_list = [], [], [], []

    for start in range(0, total, chunk_size):
        stop = min(start + chunk_size, total)
        try:
            # Materialise only this slice of the memory-mapped file.
            X_chunk = np.array(X_memmap[start:stop], dtype=np.float32)
        except MemoryError:
            # Best effort: drop this chunk (features and labels) and go on.
            print(f"❌ MemoryError at chunk {start}:{stop}, skipping...")
            continue

        y_chunk = y_array[start:stop]

        X_tr, X_v, y_tr, y_v = train_test_split(
            X_chunk, y_chunk, test_size=test_size, random_state=random_state, stratify=y_chunk
        )

        X_train_list.append(X_tr)
        y_train_list.append(y_tr)
        X_val_list.append(X_v)
        y_val_list.append(y_v)

    if not X_train_list:
        raise ValueError(
            f"No data to save for label {label_name!r}: "
            f"{x_path!r} is empty or every chunk was skipped"
        )

    # Merge and save one array at a time, releasing the chunk pieces as we
    # go, and remember the shapes so nothing is concatenated twice.
    train_shape = _concat_and_save(train_X_path, X_train_list)
    _concat_and_save(train_y_path, y_train_list)
    val_shape = _concat_and_save(val_X_path, X_val_list)
    _concat_and_save(val_y_path, y_val_list)

    print(f"✅ 저장 완료: {label_name} → train: {train_shape}, val: {val_shape}")


if __name__ == "__main__":
    # Split each per-label dataset into train/validation files; both labels
    # write into the same output directory and are told apart by their tag.
    split_dir = "../../Database/processed/split"
    jobs = (
        ("../../Database/processed/X_label0.npy", "../../Database/processed/y_label0.npy", "0"),
        ("../../Database/processed/X_label1.npy", "../../Database/processed/y_label1.npy", "1"),
    )
    for features_file, labels_file, tag in jobs:
        stream_split_save_lazy(
            x_path=features_file,
            y_path=labels_file,
            save_dir=split_dir,
            label_name=tag,
        )


✅ 저장 완료: 0 → train: (68712, 10, 6000), val: (17178, 10, 6000)
✅ 저장 완료: 1 → train: (69749, 10, 6000), val: (17438, 10, 6000)
