In [None]:
!pip install --upgrade tensorflow



In [None]:
!pip install mne



In [None]:
import os
import pandas as pd
import re
from tqdm import tqdm
import numpy as np
import mne
import matplotlib.pyplot as plt
import pandas as pd
import glob
import os

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
root = '/content/drive/MyDrive/chbmit_data/chbmit-1.0.0.physionet.org'

In [None]:
os.listdir(root)

['chb01',
 'ANNOTATORS',
 'RECORDS',
 'RECORDS-WITH-SEIZURES',
 'SUBJECT-INFO',
 'SHA256SUMS.txt',
 'shoeb-icml-2010.pdf',
 'chb02',
 'chb03',
 'chb04',
 'chb05',
 'chb06',
 'chb07',
 'chb08',
 'chb09',
 'chb10',
 'chb11',
 'chb12',
 'chb13',
 'chb14',
 'chb15',
 'chb16',
 'chb17',
 'chb18',
 'chb19',
 'chb20',
 'chb21',
 'chb22',
 'chb23',
 'chb24',
 'chb-mit-scalp-eeg-database-1.0.0.zip']

In [None]:
# Helper: Get all patients
def get_patients(root):
    return sorted([folder for folder in os.listdir(root) if os.path.isdir(os.path.join(root, folder))])

# Helper: Read patient's summary file
def read_patient_summary(patient_folder):
    summary_path = os.path.join(root, patient_folder, f"{patient_folder}-summary.txt")
    with open(summary_path, "r", encoding="utf-8") as f:
        return f.read()

# Helper: Parse seizure times
def parse_seizure_times(summary_text):
    seizure_dict = {}
    current_file = None
    starts = []
    ends = []

    for line in summary_text.splitlines():
        line = line.strip()

        # Detect new file
        if line.startswith("File Name:"):
            if current_file:
                seizure_dict[current_file] = list(zip(starts, ends))
            current_file = line.split(":")[-1].strip()
            starts = []
            ends = []

        # Detect seizure start
        elif line.startswith("Seizure Start Time"):
            try:
                time_in_sec = int(line.split(":")[-1].strip().split()[0])
                starts.append(time_in_sec)
            except:
                pass  # skip broken lines

        # Detect seizure end
        elif line.startswith("Seizure End Time"):
            try:
                time_in_sec = int(line.split(":")[-1].strip().split()[0])
                ends.append(time_in_sec)
            except:
                pass  # skip broken lines

    if current_file:
        seizure_dict[current_file] = list(zip(starts, ends))

    return seizure_dict

# Main: Build metadata DataFrame
def build_metadata(root):
    patients = get_patients(root)
    data = []

    for patient in tqdm(patients, desc="Parsing Patient Data"):
        patient_path = os.path.join(root, patient)
        summary_text = read_patient_summary(patient)
        seizure_times = parse_seizure_times(summary_text)

        edf_files = [file for file in os.listdir(patient_path) if file.endswith(".edf")]

        for edf in edf_files:
            seizure_intervals = seizure_times.get(edf,[])
            has_seizure = len(seizure_intervals) > 0

            seizure_start_times = [start for start, end in seizure_times.get(edf, [])]
            seizure_end_times = [end for start, end in seizure_times.get(edf, [])]

            data.append({
                "patient": patient,
                "edf_file": edf,
                "edf_path": os.path.join(patient_path, edf),
                "has_seizure": has_seizure,
                "seizure_start_times": seizure_start_times,
                "seizure_end_times": seizure_end_times
            })

    df = pd.DataFrame(data)
    return df


In [None]:
patient_df = build_metadata(root)

Parsing Patient Data: 100%|██████████| 24/24 [00:07<00:00,  3.34it/s]


In [None]:
patient_df.head()

Unnamed: 0,patient,edf_file,edf_path,has_seizure,seizure_start_times,seizure_end_times
0,chb01,chb01_03.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,True,[2996],[3036]
1,chb01,chb01_04.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,True,[1467],[1494]
2,chb01,chb01_05.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[]
3,chb01,chb01_06.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[]
4,chb01,chb01_08.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[]


In [None]:
# BLOCK 2: Calculate Absolute Times
def calculate_absolute_times(df):
    result_data = []
    patients = df['patient'].unique()

    for patient in tqdm(patients, desc="Calculating Absolute Times"):
        patient_df = df[df['patient'] == patient].sort_values(by="edf_file")  # Sort files correctly
        cumulative_time = 0

        for idx, row in patient_df.iterrows():
            edf_path = row['edf_path']
            try:
                raw = mne.io.read_raw_edf(edf_path, preload=False, verbose='ERROR')
                duration = raw.times[-1]  # Duration in seconds
            except Exception as e:
                print(f"Error reading {edf_path}: {e}")
                duration = 0

            absolute_start = cumulative_time
            absolute_end = cumulative_time + duration

            result_data.append({
                "patient": row['patient'],
                "edf_file": row['edf_file'],
                "edf_path": row['edf_path'],
                "has_seizure": row['has_seizure'],
                "seizure_start_times": row['seizure_start_times'],
                "seizure_end_times": row['seizure_end_times'],
                "absolute_start_time": absolute_start,
                "absolute_end_time": absolute_end
            })

            cumulative_time = absolute_end  # Update for next file

    return pd.DataFrame(result_data)

In [None]:
patient_df_absolute = calculate_absolute_times(patient_df)

Calculating Absolute Times: 100%|██████████| 24/24 [02:54<00:00,  7.28s/it]


In [None]:
patient_df_absolute.head()

Unnamed: 0,patient,edf_file,edf_path,has_seizure,seizure_start_times,seizure_end_times,absolute_start_time,absolute_end_time
0,chb01,chb01_01.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[],0.0,3599.996094
1,chb01,chb01_02.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[],3599.996094,7199.992188
2,chb01,chb01_03.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,True,[2996],[3036],7199.992188,10799.988281
3,chb01,chb01_04.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,True,[1467],[1494],10799.988281,14399.984375
4,chb01,chb01_05.edf,/content/drive/MyDrive/chbmit_data/chbmit-1.0....,False,[],[],14399.984375,17999.980469


In [None]:
def get_common_channels_from_summaries_correct(root, patient_df):
    patient_common_channels = []

    patients = patient_df['patient'].unique()

    for patient in tqdm(patients, desc="Reading Summary Files"):
        summary_path = os.path.join(root, patient, f"{patient}-summary.txt")

        try:
            with open(summary_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            channels = []
            reading_channels = False

            for line in lines:
                line = line.strip()

                if "Channels in EDF Files" in line:
                    reading_channels = True
                    continue

                if reading_channels:
                    if line.startswith("Channel"):
                        # Format: Channel X: NAME
                        parts = line.split(":")
                        if len(parts) == 2:
                            channel_name = parts[1].strip()
                            if '-' in channel_name:  # Only real EEG channels
                                channels.append(channel_name)
                    elif line.startswith("File Name"):
                        break  # Stop when we reach file listings

            if channels:
                patient_channels = set(channels)
                patient_common_channels.append(patient_channels)
            else:
                print(f"⚠ No channels found for {patient}")

        except Exception as e:
            print(f"⚠ Error reading {summary_path}: {e}")

    # Final intersection across all patients
    if patient_common_channels:
        common_channels = set.intersection(*patient_common_channels)
    else:
        common_channels = set()

    return sorted(list(common_channels))

In [None]:
# Call
channels_to_test = get_common_channels_from_summaries_correct(root, patient_df_absolute)

print(f"\nCommon channels across all patients ({len(channels_to_test)} channels):\n")
print(channels_to_test)

Reading Summary Files: 100%|██████████| 24/24 [00:00<00:00, 289.89it/s]


Common channels across all patients (18 channels):

['C3-P3', 'C4-P4', 'CZ-PZ', 'F3-C3', 'F4-C4', 'F7-T7', 'F8-T8', 'FP1-F3', 'FP1-F7', 'FP2-F4', 'FP2-F8', 'FZ-CZ', 'P3-O1', 'P4-O2', 'P7-O1', 'P8-O2', 'T7-P7', 'T8-P8']





In [None]:
# --------- SETTINGS ---------
save_root = "/content/drive/MyDrive/EEG_Project/cleaned_segments"
os.makedirs(save_root, exist_ok=True)

preictal_duration = 1800  # 30 minutes
interictal_duration = 1800  # 30 minutes
min_seizure_distance = 7200  # 2 hours
sampling_rate = 256  # Hz

# --------- HELPERS ---------
def merge_close_seizures(seizures, threshold=3600):
    if not seizures:
        return []
    seizures.sort(key=lambda x: x[0])
    merged = [seizures[0]]
    for s in seizures[1:]:
        if s[0] <= merged[-1][1] + threshold:
            merged[-1] = (merged[-1][0], max(merged[-1][1], s[1]))
        else:
            merged.append(s)
    return merged

def find_safe_intervals(total_duration, seizures, min_distance):
    safe = []
    last_end = -min_distance
    for start, end in seizures:
        safe_start = last_end + min_distance
        safe_end = start - min_distance
        if safe_end > safe_start:
            safe.append((safe_start, safe_end))
        last_end = end
    if total_duration > last_end + min_distance:
        safe.append((last_end + min_distance, total_duration))
    return safe

def load_channel_segment(edf_path, abs_start, abs_end, file_abs_start, channel_name):
    try:
        raw = mne.io.read_raw_edf(edf_path, preload=False, verbose='ERROR')
        if channel_name not in raw.ch_names:
            return None
        raw.pick([channel_name])
        sr = int(raw.info['sfreq'])

        start_idx = int((abs_start - file_abs_start) * sr)
        end_idx = int((abs_end - file_abs_start) * sr)

        start_idx = max(0, start_idx)
        end_idx = min(raw.n_times, end_idx)

        if start_idx < end_idx:
            data, _ = raw[:, start_idx:end_idx]
            return data.flatten()
        else:
            return None
    except Exception as e:
        print(f"Error loading {edf_path}: {e}")
        return None

# --------- MAIN FUNCTION ---------
def extract_balanced_segments_per_patient(patient_df, channels):
    for channel in channels:
        print(f"\nExtracting for channel: {channel}")
        all_preictal = []
        all_interictal = []

        for patient in tqdm(patient_df['patient'].unique(), desc=f"Patients for {channel}"):
            patient_records = patient_df[patient_df['patient'] == patient]
            seizure_intervals = []

            for _, row in patient_records.iterrows():
                seizure_intervals.extend([
                    (row['absolute_start_time'] + s_start, row['absolute_start_time'] + s_end)
                    for s_start, s_end in zip(row['seizure_start_times'], row['seizure_end_times'])
                ])
            seizure_intervals = merge_close_seizures(seizure_intervals)

            # PREICTAL Extraction
            preictal_segments = []
            for seizure_start, _ in seizure_intervals:
                preictal_end = seizure_start - 1800  # 30 min before seizure
                preictal_start = preictal_end - 1800

                if preictal_start < 0:
                    continue  # Skip if not enough history

                relevant_files = patient_records[
                    (patient_records['absolute_start_time'] <= preictal_end) &
                    (patient_records['absolute_end_time'] >= preictal_start)
                ]

                segment_pieces = []
                for _, file_row in relevant_files.iterrows():
                    piece = load_channel_segment(file_row['edf_path'], preictal_start, preictal_end,
                                                 file_row['absolute_start_time'], channel)
                    if piece is not None:
                        segment_pieces.append(piece)

                if segment_pieces:
                    full_segment = np.hstack(segment_pieces)
                    if full_segment.shape[0] >= sampling_rate * preictal_duration:
                        preictal_segments.append({'patient': patient, 'segment_type': 'preictal', 'data': full_segment})

            # INTERICTAL Extraction
            interictal_segments = []
            total_duration = patient_records['absolute_end_time'].max()
            safe_intervals = find_safe_intervals(total_duration, seizure_intervals, min_seizure_distance)

            trials = 0
            while len(interictal_segments) < len(preictal_segments) and trials < len(preictal_segments) * 10:
                if not safe_intervals:
                    break
                safe_start, safe_end = safe_intervals[np.random.randint(len(safe_intervals))]
                if safe_end - safe_start >= interictal_duration:
                    random_start = np.random.uniform(safe_start, safe_end - interictal_duration)
                    inter_start = random_start
                    inter_end = random_start + interictal_duration

                    relevant_files = patient_records[
                        (patient_records['absolute_start_time'] <= inter_end) &
                        (patient_records['absolute_end_time'] >= inter_start)
                    ]

                    segment_pieces = []
                    valid = True
                    for _, file_row in relevant_files.iterrows():
                        piece = load_channel_segment(file_row['edf_path'], inter_start, inter_end,
                                                     file_row['absolute_start_time'], channel)
                        if piece is None or piece.size == 0:
                            valid = False
                            break
                        segment_pieces.append(piece)

                    if valid and segment_pieces:
                        full_segment = np.hstack(segment_pieces)
                        if full_segment.shape[0] >= sampling_rate * interictal_duration:
                            interictal_segments.append({'patient': patient, 'segment_type': 'interictal', 'data': full_segment})
                trials += 1

            # FINAL BALANCING PER PATIENT
            n_balanced = min(len(preictal_segments), len(interictal_segments))
            all_preictal.extend(preictal_segments[:n_balanced])
            all_interictal.extend(interictal_segments[:n_balanced])

        # SAVE PER CHANNEL
        preictal_df = pd.DataFrame(all_preictal)
        interictal_df = pd.DataFrame(all_interictal)

        preictal_df.to_pickle(os.path.join(save_root, f"preictal_{channel.replace('-', '_')}_cleaned.pkl"))
        interictal_df.to_pickle(os.path.join(save_root, f"interictal_{channel.replace('-', '_')}_cleaned.pkl"))

        print(f"\n✅ Channel {channel}: Saved {len(preictal_df)} preictal and {len(interictal_df)} interictal segments.")

In [None]:
channels_to_extract = channels_to_test
preictal_df, interictal_df = extract_balanced_segments_per_patient(patient_df_absolute, channels_to_extract)


Extracting for channel: C3-P3


Patients for C3-P3: 100%|██████████| 24/24 [00:39<00:00,  1.66s/it]



✅ Channel C3-P3: Saved 18 preictal and 18 interictal segments.

Extracting for channel: C4-P4


Patients for C4-P4: 100%|██████████| 24/24 [00:09<00:00,  2.53it/s]



✅ Channel C4-P4: Saved 18 preictal and 18 interictal segments.

Extracting for channel: CZ-PZ


Patients for CZ-PZ: 100%|██████████| 24/24 [00:14<00:00,  1.69it/s]



✅ Channel CZ-PZ: Saved 18 preictal and 18 interictal segments.

Extracting for channel: F3-C3


Patients for F3-C3: 100%|██████████| 24/24 [00:09<00:00,  2.56it/s]



✅ Channel F3-C3: Saved 18 preictal and 18 interictal segments.

Extracting for channel: F4-C4


Patients for F4-C4: 100%|██████████| 24/24 [00:14<00:00,  1.62it/s]



✅ Channel F4-C4: Saved 18 preictal and 18 interictal segments.

Extracting for channel: F7-T7


Patients for F7-T7: 100%|██████████| 24/24 [00:05<00:00,  4.30it/s]



✅ Channel F7-T7: Saved 18 preictal and 18 interictal segments.

Extracting for channel: F8-T8


Patients for F8-T8: 100%|██████████| 24/24 [00:13<00:00,  1.80it/s]



✅ Channel F8-T8: Saved 18 preictal and 18 interictal segments.

Extracting for channel: FP1-F3


Patients for FP1-F3: 100%|██████████| 24/24 [00:11<00:00,  2.07it/s]



✅ Channel FP1-F3: Saved 18 preictal and 18 interictal segments.

Extracting for channel: FP1-F7


Patients for FP1-F7: 100%|██████████| 24/24 [00:05<00:00,  4.21it/s]



✅ Channel FP1-F7: Saved 18 preictal and 18 interictal segments.

Extracting for channel: FP2-F4


Patients for FP2-F4: 100%|██████████| 24/24 [00:15<00:00,  1.59it/s]



✅ Channel FP2-F4: Saved 18 preictal and 18 interictal segments.

Extracting for channel: FP2-F8


Patients for FP2-F8: 100%|██████████| 24/24 [00:06<00:00,  3.65it/s]



✅ Channel FP2-F8: Saved 18 preictal and 18 interictal segments.

Extracting for channel: FZ-CZ


Patients for FZ-CZ: 100%|██████████| 24/24 [00:12<00:00,  1.88it/s]



✅ Channel FZ-CZ: Saved 18 preictal and 18 interictal segments.

Extracting for channel: P3-O1


Patients for P3-O1: 100%|██████████| 24/24 [00:07<00:00,  3.14it/s]



✅ Channel P3-O1: Saved 18 preictal and 18 interictal segments.

Extracting for channel: P4-O2


Patients for P4-O2: 100%|██████████| 24/24 [00:13<00:00,  1.81it/s]



✅ Channel P4-O2: Saved 18 preictal and 18 interictal segments.

Extracting for channel: P7-O1


Patients for P7-O1: 100%|██████████| 24/24 [00:05<00:00,  4.05it/s]



✅ Channel P7-O1: Saved 18 preictal and 18 interictal segments.

Extracting for channel: P8-O2


Patients for P8-O2: 100%|██████████| 24/24 [00:12<00:00,  1.93it/s]



✅ Channel P8-O2: Saved 18 preictal and 18 interictal segments.

Extracting for channel: T7-P7


Patients for T7-P7: 100%|██████████| 24/24 [00:11<00:00,  2.14it/s]



✅ Channel T7-P7: Saved 18 preictal and 18 interictal segments.

Extracting for channel: T8-P8


Patients for T8-P8: 100%|██████████| 24/24 [00:00<00:00, 42.42it/s]


✅ Channel T8-P8: Saved 0 preictal and 0 interictal segments.





TypeError: cannot unpack non-iterable NoneType object

In [None]:
# Set your cleaned segments path
cleaned_folder = "/content/drive/MyDrive/EEG_Project/cleaned_segments"

# Find all preictal and interictal files
preictal_files = sorted(glob.glob(os.path.join(cleaned_folder, "preictal_*.pkl")))
interictal_files = sorted(glob.glob(os.path.join(cleaned_folder, "interictal_*.pkl")))

summary = []

for pre_file, inter_file in zip(preictal_files, interictal_files):
    channel = os.path.basename(pre_file).replace("preictal_", "").replace("cleaned.pkl", "").replace("", "-")

    pre_df = pd.read_pickle(pre_file)
    inter_df = pd.read_pickle(inter_file)

    # Add durations
    pre_df['duration_sec'] = pre_df['data'].apply(lambda x: len(x) / 256)
    inter_df['duration_sec'] = inter_df['data'].apply(lambda x: len(x) / 256)

    for patient in pre_df['patient'].unique():
        preictal_count = pre_df[pre_df['patient'] == patient].shape[0]
        interictal_count = inter_df[inter_df['patient'] == patient].shape[0]

        preictal_duration = pre_df[pre_df['patient'] == patient]['duration_sec'].sum()
        interictal_duration = inter_df[inter_df['patient'] == patient]['duration_sec'].sum()

        summary.append({
            'Patient': patient,
            'Channel': channel,
            'Preictal Segments': preictal_count,
            'Interictal Segments': interictal_count,
            'Preictal Duration (s)': preictal_duration,
            'Interictal Duration (s)': interictal_duration
        })

# Create the summary DataFrame
summary_df = pd.DataFrame(summary)
summary_df = summary_df.sort_values(by=['Patient', 'Channel']).reset_index(drop=True)

# Save summary to CSV
summary_df.to_csv(os.path.join(cleaned_folder, "summary_cleaned_segments.csv"), index=False)

# Display
summary_df

Unnamed: 0,Patient,Channel,Preictal Segments,Interictal Segments,Preictal Duration (s),Interictal Duration (s)
0,chb01,-C-3-_-P-3-_-,5,5,9000.011719,9000.003906
1,chb01,-C-4-_-P-4-_-,5,5,9000.011719,9000.007812
2,chb01,-C-Z-_-P-Z-_-,5,5,9000.011719,9000.000000
3,chb01,-F-3-_-C-3-_-,5,5,9000.011719,9000.019531
4,chb01,-F-4-_-C-4-_-,5,5,9000.011719,9000.000000
...,...,...,...,...,...,...
80,chb05,-P-3-_-O-1-_-,5,5,9000.011719,9000.007812
81,chb05,-P-4-_-O-2-_-,5,5,9000.011719,9000.011719
82,chb05,-P-7-_-O-1-_-,5,5,9000.011719,9000.003906
83,chb05,-P-8-_-O-2-_-,5,5,9000.011719,9000.007812


In [None]:
import os
import glob
import pandas as pd
import numpy as np
from tqdm import tqdm

# ========== SETTINGS ==========
cleaned_folder = "/content/drive/MyDrive/EEG_Project/cleaned_segments"  # Your cleaned segments
sampling_rate = 256  # Hz
window_size_seconds = 10  # 10-second windows
window_size_samples = sampling_rate * window_size_seconds

# ========== COLLECT FILES ==========
preictal_files = sorted(glob.glob(os.path.join(cleaned_folder, "preictal_*_cleaned.pkl")))
interictal_files = sorted(glob.glob(os.path.join(cleaned_folder, "interictal_*_cleaned.pkl")))

print(f"Found {len(preictal_files)} preictal and {len(interictal_files)} interictal cleaned files.")

# ========== MAIN WINDOWING ==========
X = []  # List of windows
y = []  # List of labels (1=preictal, 0=interictal)

for pre_file, inter_file in tqdm(zip(preictal_files, interictal_files), total=len(preictal_files), desc="Windowing"):
    preictal_df = pd.read_pickle(pre_file)
    interictal_df = pd.read_pickle(inter_file)

    # --- Preictal Windows ---
    for data in preictal_df['data']:
        num_windows = len(data) // window_size_samples
        for i in range(num_windows):
            window = data[i * window_size_samples : (i + 1) * window_size_samples]
            if len(window) == window_size_samples:
                X.append(window)
                y.append(1)  # Label: preictal

    # --- Interictal Windows ---
    for data in interictal_df['data']:
        num_windows = len(data) // window_size_samples
        for i in range(num_windows):
            window = data[i * window_size_samples : (i + 1) * window_size_samples]
            if len(window) == window_size_samples:
                X.append(window)
                y.append(0)  # Label: interictal

# ========== CONVERT TO ARRAYS ==========
X = np.array(X)
y = np.array(y)

print(f"✅ Data ready!")
print(f"Shape of X: {X.shape}")  # (num_samples, 2560)
print(f"Shape of y: {y.shape}")  # (num_samples,)

Found 17 preictal and 17 interictal cleaned files.


Windowing: 100%|██████████| 17/17 [00:11<00:00,  1.44it/s]


✅ Data ready!
Shape of X: (110160, 2560)
Shape of y: (110160,)


In [None]:
import numpy as np

unique, counts = np.unique(y, return_counts=True)
class_distribution = dict(zip(unique, counts))
print("Class distribution:", class_distribution)

Class distribution: {np.int64(0): np.int64(55080), np.int64(1): np.int64(55080)}


In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import os

# Assuming X and y are already loaded (your 110k samples)

# --------- SETTINGS ---------
save_folder = "/content/drive/MyDrive/EEG_Project/windowed_dataset"
os.makedirs(save_folder, exist_ok=True)

# --------- SPLITTING ---------
# First shuffle and split Train (70%) vs Temp (30%)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=42, stratify=y)

# Then split Temp into Validation (15%) and Test (15%)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp)

# Print shape
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

# --------- SAVING ---------
np.save(os.path.join(save_folder, "X_train.npy"), X_train)
np.save(os.path.join(save_folder, "y_train.npy"), y_train)

np.save(os.path.join(save_folder, "X_val.npy"), X_val)
np.save(os.path.join(save_folder, "y_val.npy"), y_val)

np.save(os.path.join(save_folder, "X_test.npy"), X_test)
np.save(os.path.join(save_folder, "y_test.npy"), y_test)

print("✅ Dataset splits saved!")

X_train shape: (77112, 2560), y_train shape: (77112,)
X_val shape: (16524, 2560), y_val shape: (16524,)
X_test shape: (16524, 2560), y_test shape: (16524,)
✅ Dataset splits saved!
