# In [None]:
import os
import pandas as pd
import mne
import numpy as np
from sklearn.model_selection import train_test_split

# === CONFIG ===
# Directory scanned for recordings: every <name>.edf that has a matching <name>.annot is used.
DATA_DIR = "/tf/rahman/NCH_Sleep_Data_Bank/data"
# ECG channel names, tried in this order for each recording; a recording containing
# neither channel is reported as an error and skipped.
ECG_CHANNEL = "ECG EKG2-EKG"     # Preferred ECG channel name
ECG_CHANNEL2 = "ECG LA-RA"       # Fallback, used only when ECG_CHANNEL is absent
RANDOM_SEED = 42                 # random_state passed to both train_test_split calls

# === COLLECT RECORDS ===
# A record is the base name of an .edf file in DATA_DIR that has a matching .annot file.
# The directory listing is sorted because os.listdir() returns entries in arbitrary order:
# without a fixed input order, the seeded train/val/test split would not be reproducible
# across machines or filesystems.
records = []
for entry in sorted(os.listdir(DATA_DIR)):
    if not entry.endswith(".edf"):
        continue
    base_name = os.path.splitext(entry)[0]
    annot_path = os.path.join(DATA_DIR, base_name + ".annot")
    if os.path.exists(annot_path):
        records.append(base_name)
    else:
        print(f"⚠️ Annotation missing for {base_name}, skipping.")

# === SPLIT RECORDS ===
# Two-stage split at record level: hold out 20% of the records, then halve that hold-out
# into validation and test sets. Both stages use the same seed.
# NOTE(review): if several recordings can belong to the same subject, splitting by record
# may place one subject in more than one split — confirm against the dataset's naming scheme.
train_records, temp_records = train_test_split(
    records,
    test_size=0.2,
    random_state=RANDOM_SEED,
)
val_records, test_records = train_test_split(
    temp_records,
    test_size=0.5,
    random_state=RANDOM_SEED,
)

# Split name -> list of record base names (insertion order: train, val, test).
splits = dict(train=train_records, val=val_records, test=test_records)

# === OUTPUT STRUCTURE ===
# One independent pair of accumulators per split: 'segments' collects the ECG sample
# arrays and 'labels' the matching annotation descriptions, index-aligned.
data = {
    split_key: {'segments': [], 'labels': []}
    for split_key in ('train', 'val', 'test')
}

# === PROCESS RECORDS BY SPLIT ===
# For every record of every split: read the ECG channel from the EDF file, read the
# sleep-stage rows from the .annot file, and append one ECG segment plus its stage label
# per row to data[split_name]. Any error inside a record is printed and the rest of that
# record is abandoned (segments appended before the error are kept); processing then
# continues with the next record.
for split_name, record_list in splits.items():
    print(f"\n🔄 Processing {split_name.upper()} set with {len(record_list)} records...")
    for base_name in record_list:
        edf_path = os.path.join(DATA_DIR, base_name + ".edf")
        annot_path = os.path.join(DATA_DIR, base_name + ".annot")

        try:
            # Open lazily (header only). Preloading here would pull every channel of the
            # recording into memory, and copying it would double that, although only a
            # single ECG channel is ever used.
            raw = mne.io.read_raw_edf(edf_path, preload=False, verbose=False)

            # Prefer ECG_CHANNEL, fall back to ECG_CHANNEL2, otherwise fail this record.
            available_channels = raw.ch_names
            if ECG_CHANNEL in available_channels:
                ecg_name = ECG_CHANNEL
            elif ECG_CHANNEL2 in available_channels:
                ecg_name = ECG_CHANNEL2
            else:
                raise ValueError(f"No valid ECG channel found in {base_name}. Channels: {available_channels}")

            # Copy the (still unloaded) object, keep only the ECG channel, then load just
            # that channel's samples. `raw` itself stays untouched.
            raw_ecg = raw.copy().pick([ecg_name]).load_data()
            sfreq = raw_ecg.info['sfreq']

            # The annotation file is read as headerless TSV; columns are taken positionally
            # as description, onset, duration.
            annot_df = pd.read_csv(annot_path, sep="\t", header=None, names=["description", "onset", "duration"])
            # na=False: a missing description counts as "not a sleep stage" instead of
            # producing a NaN mask that makes the boolean indexing raise.
            is_stage = annot_df['description'].str.contains("Sleep stage", na=False)
            sleep_stages = annot_df[is_stage].reset_index(drop=True)

            stage_rows = zip(
                sleep_stages['onset'],
                sleep_stages['duration'],
                sleep_stages['description'],
            )
            for onset, duration, description in stage_rows:
                # round() instead of int() truncation: a float product such as 0.29 * 100
                # evaluates to 28.999999999999996, so truncating can land one sample early.
                # Deriving the end from the rounded length keeps stages of equal duration
                # at equal sample counts regardless of their onset.
                start_sample = int(round(onset * sfreq))
                end_sample = start_sample + int(round(duration * sfreq))

                ecg_segment, _ = raw_ecg[:, start_sample:end_sample]
                # Skip empty slices (e.g. an onset at or beyond the end of the recording).
                # NOTE(review): a stage that runs past the end of the recording presumably
                # yields a shorter segment — confirm whether equal-length segments are required.
                if ecg_segment.shape[1] > 0:
                    data[split_name]['segments'].append(ecg_segment[0])
                    data[split_name]['labels'].append(description)

        except Exception as e:
            # Deliberate best-effort boundary: one unreadable record must not stop the run.
            print(f"❌ Error processing {base_name}: {e}")

# === FINAL REPORT ===
# Per split: number of records assigned, number of segments extracted, and the set of
# distinct stage labels that were collected.
for split in ('train', 'val', 'test'):
    collected = data[split]
    record_count = len(splits[split])
    segment_count = len(collected['segments'])
    stage_names = set(collected['labels'])

    print(f"\n📦 {split.upper()} SET:")
    print(f"   ➤ Records: {record_count}")
    print(f"   ➤ Segments: {segment_count}")
    print(f"   ➤ Unique Stages: {stage_names}")

# === OPTIONAL: Save to disk ===
# np.savez("ecg_segments_split_by_record.npz",
#          X_train=data['train']['segments'], y_train=data['train']['labels'],
#          X_val=data['val']['segments'], y_val=data['val']['labels'],
#          X_test=data['test']['segments'], y_test=data['test']['labels'])
