<a href="https://colab.research.google.com/github/kiril-buga/Neural-Network-Training-Project/blob/main/ECG_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install neurokit2 wfdb

import os
import json
import numpy as np
import pandas as pd
import wfdb
from typing import Dict, Any, Tuple, List
from sklearn.model_selection import train_test_split

from scipy.signal import butter, filtfilt, welch, resample
import neurokit2 as nk

# Check if neurokit2 is available
HAS_NEUROKIT = True
try:
    import neurokit2 as nk
except ImportError:
    HAS_NEUROKIT = False

In [None]:
# ===== Detect if running in Google Colab and mount Drive =====
IN_COLAB = False
try:
    from google.colab import drive  # type: ignore
    IN_COLAB = True
except Exception:
    drive = None
    IN_COLAB = False

if IN_COLAB:
    drive.mount('/content/drive/')

# ===== Define paths =====
if IN_COLAB:
    # Case 1: You manually placed the dataset in MyDrive
    DATA_PATH = "/content/drive/MyDrive/DeepLearningECG/data/"
    ARTIFACT_DIR = "/content/drive/MyDrive/DeepLearningECG/artifacts/"

else:
    # Case 3: Local fallback (if running outside Colab)
    DATA_PATH = "../DeepLearningECG/data/"
    ARTIFACT_DIR = "../DeepLearningECG/artifacts/"


# Path where the WFDB ECG files (.hea/.dat) live.
ECG_DIR = os.path.join(DATA_PATH, "Child_ecg/")

print("DATA_PATH:", DATA_PATH)
print("ARTIFACT_DIR:", ARTIFACT_DIR)
print("ECG_DIR:", ECG_DIR)
print("Files in DATA_PATH:", os.listdir(DATA_PATH))

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
DATA_PATH: /content/drive/MyDrive/DeepLearningECG/data/
ARTIFACT_DIR: /content/drive/MyDrive/DeepLearningECG/artifacts/
ECG_DIR: /content/drive/MyDrive/DeepLearningECG/data/Child_ecg/
Files in DATA_PATH: ['ECGCode.csv', 'DiseaseCode.csv', 'ExampleReadingCode.ipynb', 'AttributesDictionary.csv', 'Child_ecg.zip', 'Child_ecg.z01', 'Child_ecg']


In [38]:
# load CSV
df_attr = pd.read_csv(DATA_PATH + 'AttributesDictionary.csv')
df_attr

Unnamed: 0,Filename,ECG_ID,Patient_ID,Age,Gender,Acquisition_date,Sampling_point,Lead,AHA_code,CHN_code,ICD-10 code,pSQI,basSQI,bSQI
0,P00/P00001/P00001_E01,P00001_E01,P00001,572d,'Female',2017-11-22 10:46:08,9000,9,'Left ventricular high voltage';'L147','J106';'L123','I34.0';'Q21.0';'Q24.9','I':0.288;'II':0.323;'III':0.346;'aVR':0.312;'...,'I':0.994;'II':0.996;'III':0.991;'aVR':0.997;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
1,P00/P00002/P00002_E01,P00002_E01,P00002,4327d,'Male',2017-11-28 21:59:47,15000,12,'C21','C13','I51.4';'J18.9','I':0.472;'II':0.446;'III':0.449;'aVR':0.484;'...,'I':0.995;'II':0.980;'III':0.992;'aVR':0.992;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
2,P00/P00003/P00003_E01,P00003_E01,P00003,1087d,'Female',2017-11-29 16:04:57,10000,12,'C21','C13','Q21.0';'Q24.9','I':0.495;'II':0.347;'III':0.340;'aVR':0.382;'...,'I':0.915;'II':0.895;'III':0.882;'aVR':0.908;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
3,P00/P00004/P00004_E01,P00004_E01,P00004,2465d,'Male',2017-11-30 15:21:27,13000,9,'C21','C13','Q21.1';'Q24.9','I':0.340;'II':0.405;'III':0.409;'aVR':0.350;'...,'I':0.981;'II':0.988;'III':0.974;'aVR':0.986;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
4,P00/P00004/P00004_E02,P00004_E02,P00004,2461d,'Male',2017-11-26 19:19:48,15000,9,'A1','A1','Q21.1';'Q24.9','I':0.501;'II':0.494;'III':0.389;'aVR':0.525;'...,'I':0.993;'II':0.993;'III':0.989;'aVR':0.995;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14185,P11/P11639/P11639_E01,P11639_E01,P11639,2646d,'Male',2021-06-24 18:22:31,10000,12,'A1','A1','J35.3','I':0.330;'II':0.422;'III':0.387;'aVR':0.377;'...,'I':0.991;'II':0.991;'III':0.981;'aVR':0.992;'...,'I':1.000;'II':1.000;'III':0.990;'aVR':1.000;'...
14186,P11/P11640/P11640_E01,P11640_E01,P11640,657d,'Male',2021-07-01 09:47:16,10500,12,'C21';'L147','C13';'L123','S02.0';'S06.5';'S06.6';'S06.7';'T14.0','I':0.284;'II':0.362;'III':0.378;'aVR':0.332;'...,'I':0.919;'II':0.934;'III':0.939;'aVR':0.929;'...,'I':0.976;'II':0.993;'III':1.000;'aVR':0.993;'...
14187,P11/P11641/P11641_E01,P11641_E01,P11641,1484d,'Female',2021-07-04 21:58:36,15000,12,'D30+Modifier310','D21+Frequent','I49.1';'R53','I':0.387;'II':0.387;'III':0.411;'aVR':0.384;'...,'I':0.985;'II':0.975;'III':0.960;'aVR':0.980;'...,'I':0.994;'II':0.987;'III':0.982;'aVR':0.994;'...
14188,P11/P11642/P11642_E01,P11642_E01,P11642,5178d,'Male',2021-06-27 20:22:00,15000,12,'C23';'L150','C15';'L128','J31.0';'J34.2';'S02.2','I':0.401;'II':0.409;'III':0.409;'aVR':0.407;'...,'I':0.975;'II':0.973;'III':0.974;'aVR':0.973;'...,'I':0.974;'II':1.000;'III':1.000;'aVR':1.000;'...


In [39]:
def butter_bandpass(lowcut: float, highcut: float, fs: float, order: int = 4):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype="band")
    return b, a


def apply_bandpass(x: np.ndarray, fs: float, lowcut: float = 0.5, highcut: float = 40.0) -> np.ndarray:
    """Apply bandpass filter channel wise."""
    if x.ndim == 1:
        x = x[:, None]
    b, a = butter_bandpass(lowcut, highcut, fs)
    x_filt = np.zeros_like(x)
    for i in range(x.shape[1]):
        x_filt[:, i] = filtfilt(b, a, x[:, i])
    return x_filt.squeeze()


def band_power(f: np.ndarray, Pxx: np.ndarray, fmin: float, fmax: float) -> float:
    """Integrate PSD between fmin and fmax."""
    mask = (f >= fmin) & (f <= fmax)
    if not np.any(mask):
        return 0.0
    return np.trapz(Pxx[mask], f[mask])


In [40]:
def parse_icd_list(s: str):
    """
    Parse ICD-10 string like "'I34.0';'Q21.0';'Q24.9'" into ['I34.0', 'Q21.0', 'Q24.9'].
    Handles NaN or empty strings.
    """
    if pd.isna(s):
        return []
    # remove surrounding quotes
    parts = [p.strip().replace("'", "") for p in s.split(";")]
    parts = [p for p in parts if len(p) > 0]
    return parts


# Create a new column with parsed ICD codes
df_attr["ICD_list"] = df_attr["ICD-10 code"].apply(parse_icd_list)

# Use FIRST ICD code as label (Option A)
df_attr["ICD_primary"] = df_attr["ICD_list"].apply(lambda lst: lst[0] if len(lst) > 0 else None)


In [41]:
df_attr[["Filename", "ICD_list", "ICD_primary"]].head()


Unnamed: 0,Filename,ICD_list,ICD_primary
0,P00/P00001/P00001_E01,"[I34.0, Q21.0, Q24.9]",I34.0
1,P00/P00002/P00002_E01,"[I51.4, J18.9]",I51.4
2,P00/P00003/P00003_E01,"[Q21.0, Q24.9]",Q21.0
3,P00/P00004/P00004_E01,"[Q21.1, Q24.9]",Q21.1
4,P00/P00004/P00004_E02,"[Q21.1, Q24.9]",Q21.1


In [42]:
LABEL_COL = "ICD_primary"


In [43]:
label_values = sorted(df_attr[LABEL_COL].dropna().unique().tolist())
label_to_int = {lab: i for i, lab in enumerate(label_values)}
int_to_label = {i: lab for lab, i in label_to_int.items()}

print("Label mapping:", label_to_int)


Label mapping: {'(F) I40.0': 0, '(FO) Q21.1': 1, '(OSD) Q21.1': 2, '(V) I40.0': 3, 'A02.1': 4, 'A02.9': 5, 'A05.2': 6, 'A08.0': 7, 'A09.0': 8, 'A09.9': 9, 'A15.3': 10, 'A16.2': 11, 'A16.5': 12, 'A16.9': 13, 'A17.0': 14, 'A17.8': 15, 'A18.0': 16, 'A18.3': 17, 'A18.8': 18, 'A23.9': 19, 'A37.9': 20, 'A41.0': 21, 'A41.5': 22, 'A41.9': 23, 'A46': 24, 'A48.3': 25, 'A49.0': 26, 'A49.1': 27, 'A49.3': 28, 'A49.8': 29, 'A49.9': 30, 'A71.9': 31, 'A81.1': 32, 'A86': 33, 'A87.9': 34, 'B00.8': 35, 'B00.9': 36, 'B01.9': 37, 'B02.8': 38, 'B02.9': 39, 'B07': 40, 'B08.5': 41, 'B09': 42, 'B16.9': 43, 'B18.1': 44, 'B18.2': 45, 'B25.1': 46, 'B25.9': 47, 'B27.9': 48, 'B30.9': 49, 'B33.2': 50, 'B34.0': 51, 'B34.1': 52, 'B34.8': 53, 'B34.9': 54, 'B35.0': 55, 'B35.2': 56, 'B35.6': 57, 'B36.0': 58, 'B37.0': 59, 'B37.9': 60, 'B45.1': 61, 'B49': 62, 'B55.0': 63, 'B59': 64, 'B77.8': 65, 'B82.9': 66, 'B83.0': 67, 'B86': 68, 'B94.1': 69, 'B99': 70, 'C02.9': 71, 'C06.9': 72, 'C07': 73, 'C11.9': 74, 'C22.2': 75, 'C22.

In [44]:
with open(os.path.join(ARTIFACT_DIR, "label_mapping.json"), "w") as f:
    json.dump({"label_to_int": label_to_int, "int_to_label": int_to_label}, f)


In [45]:
def parse_sqi_string(s: str):
    """
    Convert "'I':0.288;'II':0.323" into dict {'I':0.288, 'II':0.323}
    """
    if pd.isna(s):
        return {}
    items = s.split(";")
    out = {}
    for it in items:
        it = it.strip()
        if ":" not in it:
            continue
        k, v = it.split(":")
        k = k.replace("'", "").strip()
        try:
            v = float(v)
        except:
            continue
        out[k] = v
    return out


df_attr["pSQI_dict"] = df_attr["pSQI"].apply(parse_sqi_string)
df_attr["basSQI_dict"] = df_attr["basSQI"].apply(parse_sqi_string)
df_attr["bSQI_dict"] = df_attr["bSQI"].apply(parse_sqi_string)


In [46]:
df_attr["pSQI_mean"] = df_attr["pSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)
df_attr["basSQI_mean"] = df_attr["basSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)
df_attr["bSQI_mean"] = df_attr["bSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)


In [None]:
def compute_qc_metrics(sig, meta, pSQI_mean, bSQI_mean):
    """
    Compute QC metrics for one ECG record.

    Returns a dict with metrics and a 'qc_pass' boolean.
    """
    qc: Dict[str, Any] = {}
    qc["pSQI_mean"] = pSQI_mean
    qc["bSQI_mean"] = bSQI_mean
    
    fs = meta.get("fs", None)
    if fs is None:
        raise ValueError("Sampling frequency 'fs' missing in meta.")

    # Ensure 2D: (time, leads)
    if sig.ndim == 1:
        sig = sig[:, None]

    n_samples, n_leads = sig.shape
    qc["n_samples"] = int(n_samples)
    qc["n_leads"] = int(n_leads)
    qc["duration_sec"] = n_samples / fs

    # Use first lead for QC indices
    lead = sig[:, 0]

    # Handle NaNs
    n_total = lead.size
    n_nans = np.isnan(lead).sum()
    qc["nan_fraction"] = float(n_nans / n_total)

    lead_clean = lead.copy()
    if n_nans > 0:
        # simple interpolation for NaNs
        not_nan = ~np.isnan(lead_clean)
        if not np.any(not_nan):
            # Entire lead is NaN: fail immediately
            qc["qc_pass"] = False
            qc["fail_reason"] = "all_nan"
            return qc
        lead_clean[~not_nan] = np.interp(
            np.flatnonzero(~not_nan),
            np.flatnonzero(not_nan),
            lead_clean[not_nan],
        )

    # Basic amplitude stats
    amp = lead_clean
    qc["amp_mean"] = float(np.mean(amp))
    qc["amp_std"] = float(np.std(amp))
    # Robust range
    q1, q99 = np.percentile(amp, [1, 99])
    qc["amp_robust_range"] = float(q99 - q1)

    # Spectral measures using Welch
    f, Pxx = welch(amp, fs=fs, nperseg=min(4096, len(amp)))

    total_power = band_power(f, Pxx, 0.5, 40.0)
    low_power = band_power(f, Pxx, 0.0, 0.5)
    qc["baseline_wander_ratio"] = float(low_power / (total_power + 1e-8))

    # Powerline noise index (assuming 50 Hz, adapt to 60 Hz if needed)
    pl_power = band_power(f, Pxx, 48.0, 52.0)
    band_40_60 = band_power(f, Pxx, 40.0, 60.0)
    qc["powerline_ratio"] = float(pl_power / (band_40_60 + 1e-8))

    # Heart rate consistency via neurokit2 (if available)
    if HAS_NEUROKIT:
        try:
            cleaned = nk.ecg_clean(amp, sampling_rate=fs)
            _, rpeaks = nk.ecg_peaks(cleaned, sampling_rate=fs)
            r_locs = rpeaks["ECG_R_Peaks"]
            if len(r_locs) > 1:
                hr = nk.ecg_rate(r_locs, sampling_rate=fs)
                qc["hr_mean"] = float(np.mean(hr))
                qc["hr_std"] = float(np.std(hr))
                qc["hr_n_beats"] = int(len(r_locs))
            else:
                qc["hr_mean"] = np.nan
                qc["hr_std"] = np.nan
                qc["hr_n_beats"] = int(len(r_locs))
        except Exception as e:
            qc["hr_mean"] = np.nan
            qc["hr_std"] = np.nan
            qc["hr_n_beats"] = 0
            qc["hr_error"] = str(e)
    else:
        qc["hr_mean"] = np.nan
        qc["hr_std"] = np.nan
        qc["hr_n_beats"] = -1

    # Simple QC rules, tune thresholds as needed
    MIN_DURATION = 8.0       # seconds
    MAX_NAN_FRAC = 0.01
    MIN_AMP_RANGE = 0.05     # depends on units
    MAX_AMP_RANGE = 10.0
    MAX_BASELINE_RATIO = 0.5
    MAX_POWERLINE_RATIO = 0.5

    reasons = []

    if qc["duration_sec"] < MIN_DURATION:
        reasons.append("too_short")
    if qc["nan_fraction"] > MAX_NAN_FRAC:
        reasons.append("too_many_nans")
    if not (MIN_AMP_RANGE < qc["amp_robust_range"] < MAX_AMP_RANGE):
        reasons.append("amp_out_of_range")
    if qc["baseline_wander_ratio"] > MAX_BASELINE_RATIO:
        reasons.append("baseline_wander")
    if qc["powerline_ratio"] > MAX_POWERLINE_RATIO:
        reasons.append("powerline_noise")

    # HR based rules only if HR was computed
    if not np.isnan(qc["hr_mean"]):
        if not (40.0 <= qc["hr_mean"] <= 220.0):
            reasons.append("hr_out_of_range")
        if qc["hr_n_beats"] < 5:
            reasons.append("too_few_beats")

    # Override QC using SQI thresholds from df_attr
    # (You will pass pSQI_mean / basSQI_mean / bSQI_mean as parameters)
    if qc["pSQI_mean"] < 0.2:    # based on literature
        reasons.append("low_pSQI")
    if qc["bSQI_mean"] < 0.8:
        reasons.append("low_bSQI")

    qc["qc_pass"] = len(reasons) == 0
    qc["fail_reason"] = ";".join(reasons) if reasons else ""

    return qc

In [48]:
TARGET_FS = 500.0
WINDOW_SEC = 10.0
STEP_SEC = 5.0  # 50 percent overlap


def preprocess_record(sig: np.ndarray, meta: Dict[str, Any], target_fs: float = TARGET_FS) -> Tuple[np.ndarray, float]:
    """
    Bandpass filter and resample entire record.

    Returns:
        sig_proc: (time, leads) at target_fs
        fs_new: sampling rate after resampling
    """
    fs = meta.get("fs", None)
    if fs is None:
        raise ValueError("Sampling frequency 'fs' missing in meta.")

    if sig.ndim == 1:
        sig = sig[:, None]

    # Bandpass
    sig_bp = apply_bandpass(sig, fs=fs)

    if fs == target_fs:
        return sig_bp, fs

    # Resample time dimension
    n_samples = sig_bp.shape[0]
    duration = n_samples / fs
    n_new = int(round(duration * target_fs))

    sig_res = np.zeros((n_new, sig_bp.shape[1]))
    for i in range(sig_bp.shape[1]):
        sig_res[:, i] = resample(sig_bp[:, i], n_new)

    return sig_res, target_fs


def window_record(
    sig: np.ndarray,
    fs: float,
    label_int: int,
    window_sec: float = WINDOW_SEC,
    step_sec: float = STEP_SEC,
    lead_indices: List[int] = None
) -> Tuple[List[np.ndarray], List[int]]:
    """
    Slice a preprocessed record into overlapping windows.

    Returns lists of windows (time, channels) and labels.
    """
    if lead_indices is None:
        # default: use first lead only
        lead_indices = [0]

    if sig.ndim == 1:
        sig = sig[:, None]

    sig = sig[:, lead_indices]
    n_samples = sig.shape[0]

    win_len = int(window_sec * fs)
    step_len = int(step_sec * fs)

    windows = []
    labels = []

    start = 0
    while start + win_len <= n_samples:
        segment = sig[start:start + win_len, :]

        # Drop window if all NaN or very low variance
        if np.isnan(segment).mean() > 0.05:
            start += step_len
            continue

        # Normalize per window (z score per channel)
        seg_norm = segment.copy()
        for ch in range(seg_norm.shape[1]):
            x = seg_norm[:, ch]
            m = np.nanmean(x)
            s = np.nanstd(x)
            if s < 1e-6:
                s = 1.0
            seg_norm[:, ch] = (x - m) / s

        windows.append(seg_norm.astype(np.float32))
        labels.append(int(label_int))

        start += step_len

    return windows, labels


In [None]:
# Create alternative label mapping for AHA_code (if needed later)
assert "AHA_code" in df_attr.columns, "AHA_code not in df_attr columns"

# Unique label values for AHA_code
aha_label_values = sorted(df_attr["AHA_code"].dropna().unique().tolist())
label_to_int_aha = {lab: idx for idx, lab in enumerate(aha_label_values)}
int_to_label_aha = {idx: lab for lab, idx in label_to_int_aha.items()}

print("AHA_code label mapping:", label_to_int_aha)

In [None]:
# Note: ICD_primary label mapping was created in cell-8
# We already have label_to_int and int_to_label from cell-8
# These are used in the build_qc_and_windows function below

# Verify the mappings are still available
print("Current label_to_int keys (first 5):", list(label_to_int.keys())[:5])
print("Current int_to_label keys (first 5):", list(int_to_label.keys())[:5])

In [53]:
def build_qc_and_windows(
    df_attr: pd.DataFrame,
    ecg_dir: str,
    label_col: str = LABEL_COL,
    max_records: int = None,
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Run QC and preprocessing over all records.

    Returns:
        X: (N_windows, T, C)
        y: (N_windows,) integer labels
        df_qc: QC metrics per record
    """
    qc_rows = []
    all_windows = []
    all_labels = []

    # Iterate over attribute rows
    iterator = df_attr.iterrows()
    if max_records is not None:
        iterator = df_attr.iloc[:max_records].iterrows()

    for idx, row in iterator:
        fname = row["Filename"]
        label_raw = row[label_col]

        if pd.isna(label_raw):
            # Skip unlabeled records
            continue

        label_int = label_to_int[label_raw]

        record_path = os.path.join(ECG_DIR, fname)
        try:
            sig, meta = wfdb.rdsamp(record_path)
        except Exception as e:
            print(f"Error reading {record_path}: {e}")
            qc_rows.append({
                "Filename": fname,
                "qc_pass": False,
                "fail_reason": f"read_error:{e}",
            })
            continue

        # Compute QC on raw signal
        qc = compute_qc_metrics(
              np.asarray(sig),
              meta.__dict__,
              pSQI_mean=float(row["pSQI_mean"]),
              bSQI_mean=float(row["bSQI_mean"]),
          )
        qc["Filename"] = fname
        qc["label_raw"] = label_raw
        qc["label_int"] = int(label_int)

        if not qc["qc_pass"]:
            qc_rows.append(qc)
            continue

        # Preprocess and window
        sig_proc, fs_new = preprocess_record(np.asarray(sig), meta.__dict__, target_fs=TARGET_FS)
        windows, labels = window_record(sig_proc, fs=fs_new, label_int=label_int)

        qc["n_windows"] = len(windows)

        qc_rows.append(qc)

        all_windows.extend(windows)
        all_labels.extend(labels)

    # Stack
    if len(all_windows) == 0:
        raise RuntimeError("No windows created. Check QC thresholds and label column.")

    X = np.stack(all_windows, axis=0)  # (N, T, C)
    y = np.array(all_labels, dtype=np.int64)

    df_qc = pd.DataFrame(qc_rows)

    # Save artifacts
    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    np.save(os.path.join(ARTIFACT_DIR, "X_windows.npy"), X)
    np.save(os.path.join(ARTIFACT_DIR, "y_labels.npy"), y)
    df_qc.to_csv(os.path.join(ARTIFACT_DIR, "qc_summary.csv"), index=False)

    print("Saved:")
    print("  X_windows.npy shape:", X.shape)
    print("  y_labels.npy shape:", y.shape)
    print("  qc_summary.csv rows:", len(df_qc))

    return X, y, df_qc


In [54]:
X, y, df_qc = build_qc_and_windows(
    df_attr=df_attr,
    ecg_dir=ECG_DIR,
    label_col="ICD_primary",
    max_records=None
)


KeyError: 'I34.0'

In [None]:
# Note: This cell depends on X and y being created from build_qc_and_windows
# Make sure cell-17 (build_qc_and_windows) is executed first before running this cell

# Split data into train, val and test
# Get unique Patient_ID values
patient_ids = df_attr['Patient_ID'].unique()

# Split Patient_IDs into training, validation, and test sets
train_ids, test_ids = train_test_split(patient_ids, test_size=0.2, random_state=42)
train_ids, val_ids = train_test_split(train_ids, test_size=0.25, random_state=42)  # 0.25 * 0.8 = 0.2

# Get indices for training, validation, and test sets based on Patient_ID
# Note: This assumes X and y come from build_qc_and_windows which uses df_attr rows
# We need to match the windows back to Patient_IDs (this logic needs adjustment based on qc_df)

print("Train patient IDs:", len(train_ids))
print("Val patient IDs:", len(val_ids))
print("Test patient IDs:", len(test_ids))
print("\nNote: Proper train/test split requires matching windows back to patient IDs.")
print("Consider using df_qc returned from build_qc_and_windows for more accurate splitting.")

### Test Loader

In [None]:
# ===== Load ECG data =====


# Load the raw ECG signal data
def load_raw_data(df, path, n_samples=None):
    """Return list of ECG arrays and list of metadata dicts."""
    filenames = df["Filename"].tolist()
    if n_samples is not None:
        filenames = filenames[:n_samples]

    signals = []
    metas = []
    for fname in filenames:
        sig, meta = wfdb.rdsamp(os.path.join(path, fname))
        signals.append(sig)
        metas.append(meta)
    return signals, metas

# Load diagnostic comments from WFDB metadata
def load_Diag(df, path, n_samples=None):
    """Return disease and ECG diagnostic comments from WFDB records."""
    filenames = df["Filename"].tolist()
    if n_samples is not None:
        filenames = filenames[:n_samples]

    disease_diag = []
    ecg_diag = []
    for fname in filenames:
        record = wfdb.rdrecord(os.path.join(path, fname))
        comments = record.comments
        disease_diag.append(comments[1] if len(comments) > 1 else None)
        ecg_diag.append(comments[2] if len(comments) > 2 else None)
    return disease_diag, ecg_diag

# ===== Load a few ECG records and check shapes =====

# Load only a few records so Colab RAM is safe
N_SAMPLES = 5
signals, metas = load_raw_data(df_attr, ECG_DIR, n_samples=N_SAMPLES)

print(f"Loaded {len(signals)} ECG signals.")
print("Shape of first signal (time, leads):", signals[0].shape)
print("Meta of first signal:")
print(metas[0])

# Attach comments for the same subset
disease_diag, ecg_diag = load_Diag(df_attr, ECG_DIR, n_samples=N_SAMPLES)

df_subset = df_attr.iloc[:N_SAMPLES].copy()
df_subset["Disease_diag_comment"] = disease_diag
df_subset["ECG_diag_comment"] = ecg_diag
display(df_subset)
