<a href="https://colab.research.google.com/github/kiril-buga/Neural-Network-Training-Project/blob/main/ECG_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [105]:
!pip install neurokit2 wfdb
!apt-get update && apt-get install -y p7zip-full

# Install huggingface_hub if needed
!pip install huggingface-hub -q

import os
import json
import numpy as np
import pandas as pd
import wfdb
from typing import Dict, Any, Tuple, List
from sklearn.model_selection import train_test_split

from scipy.signal import butter, filtfilt, welch, resample

# neurokit2 is optional: heart-rate QC metrics are skipped when it is missing
try:
    import neurokit2 as nk
    HAS_NEUROKIT = True
except ImportError:
    nk = None
    HAS_NEUROKIT = False


### Step 1: Install Dependencies and Import Libraries
Install required packages and import all necessary libraries:
- **wfdb**: Read ECG files in WFDB format
- **neurokit2**: Advanced ECG signal processing and QC metrics
- **scipy.signal**: Signal filtering and spectral analysis
- **numpy/pandas**: Numerical computing and data manipulation

### Step 2: Configure File Paths
Set up data directories based on execution environment:
- Detects if running in Google Colab or locally
- Mounts Google Drive if needed
- Defines paths for:
  - **ECG_DIR**: Location of WFDB ECG files (`.hea` headers and `.dat` signal data)
  - **DATA_PATH**: Root directory containing CSV metadata files
  - **ARTIFACT_DIR**: Output directory for processed data and QC reports

### Step 3: Load Metadata from CSV
Load the `AttributesDictionary.csv` file containing metadata for all ECG records:
- **Filename**: Path to ECG file
- **Patient_ID**: Unique patient identifier
- **Demographics**: Age, Gender
- **Labels**: ICD-10 codes (primary diagnosis), AHA codes, CHN codes
- **Quality Metrics**: Pre-computed pSQI, basSQI, bSQI (signal quality indices)

### Step 4: Define Signal Processing Utilities
Implement core signal processing functions:
- **butter_bandpass()**: Design a Butterworth bandpass filter (0.5-40 Hz to remove noise and baseline wander)
- **apply_bandpass()**: Apply filter channel-wise to ECG signals
- **band_power()**: Calculate power spectral density in a frequency band (used for QC metrics)

### Step 5: Parse ICD-10 Diagnosis Codes
Extract and simplify diagnosis labels from CSV text format:
- **parse_icd_list()**: Convert string like `'I34.0';'Q21.0';'Q24.9'` into a list `[I34.0, Q21.0, Q24.9]`
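- **ICD_primary**: Use the first ICD-10 code in each record's list as its single primary diagnosis label (reduces the multi-code annotation to one label per record)
- Adds the `ICD_list` and `ICD_primary` columns to `df_attr`; the label-to-integer mapping itself is built in Step 8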

### Step 6: Inspect Parsed Labels
Display a sample of the parsed ICD-10 codes to verify correct parsing:
- Shows filename, full ICD list, and extracted primary (first) code
- Validates that label parsing is working correctly

### Step 7: Select Primary Label Column
Define which diagnosis column to use as the target label:
- **LABEL_COL = "ICD_primary"**: Uses the primary ICD-10 code for classification
- Can be changed to "AHA_code" or "CHN_code" for alternative classification schemes

### Step 8: Create Label Encodings
Build bidirectional mappings between diagnosis codes and integer indices:
- **label_to_int**: Maps diagnosis code (e.g., "I34.0") → integer (e.g., 0)
- **int_to_label**: Reverse mapping for converting predictions back to diagnosis codes
- Used for one-hot encoding and model training/evaluation
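
The mapping and its use for one-hot encoding can be sketched as follows. This is a minimal illustration, not the notebook's own code: the three-entry `label_to_int` is a toy stand-in for the mapping built from `df_attr`, and `one_hot` is a hypothetical helper.

```python
import numpy as np

# Toy mapping for illustration only; the real label_to_int is built from df_attr.
label_to_int = {"I34.0": 0, "I51.4": 1, "Q21.0": 2}
int_to_label = {i: lab for lab, i in label_to_int.items()}


def one_hot(y, n_classes):
    """Return an (N, n_classes) float32 matrix with a single 1 per row."""
    return np.eye(n_classes, dtype=np.float32)[np.asarray(y, dtype=np.int64)]


# Encode three example diagnosis codes, then decode them again.
y = [label_to_int[c] for c in ["Q21.0", "I34.0", "Q21.0"]]
Y = one_hot(y, n_classes=len(label_to_int))
decoded = [int_to_label[int(i)] for i in Y.argmax(axis=1)]
```

Decoding with `argmax` followed by `int_to_label` is the same path a model's predicted class indices would take back to ICD-10 codes.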

### Step 9: Save Label Mapping
Export the label encoding to JSON for reproducibility:
- Save mapping to `ARTIFACT_DIR/label_mapping.json`
- Allows other scripts to decode model predictions consistently
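
One detail worth keeping in mind when another script reloads the file: JSON object keys are always strings, so the integer keys of `int_to_label` come back as `"0"`, `"1"`, and so on. The sketch below round-trips a toy mapping through a temporary file (the path and the two-entry mapping are placeholders, not the notebook's `ARTIFACT_DIR` or real labels) and restores the integer keys.

```python
import json
import os
import tempfile

# Toy mapping for illustration only.
label_to_int = {"I34.0": 0, "Q21.0": 1}
int_to_label = {i: lab for lab, i in label_to_int.items()}

# Write to a temporary location instead of ARTIFACT_DIR.
path = os.path.join(tempfile.mkdtemp(), "label_mapping.json")
with open(path, "w") as f:
    json.dump({"label_to_int": label_to_int, "int_to_label": int_to_label}, f)

with open(path) as f:
    loaded = json.load(f)

# JSON turned the integer keys into strings; convert them back explicitly.
int_to_label_loaded = {int(k): v for k, v in loaded["int_to_label"].items()}
```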

### Step 10: Parse Signal Quality Indices (SQI)
Convert pre-computed quality metrics from string format to dictionaries:
- **parse_sqi_string()**: Parse strings like `'I':0.288;'II':0.323` into dicts `{'I': 0.288, 'II': 0.323}`
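- **pSQI**: Power-spectrum SQI (as commonly defined, the share of spectral power that falls in the QRS frequency band)
- **basSQI**: Baseline SQI (as commonly defined, based on the share of power in the low-frequency baseline band; lower values indicate more baseline wander)
- **bSQI**: Beat-detection SQI (as commonly defined, the agreement between two QRS detectors)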
- Creates new columns with parsed dicts for each metric

### Step 11: Compute Mean SQI Per Record
Aggregate per-channel SQI values into a single quality score per record:
- **pSQI_mean**: Average pSQI across all leads
- **basSQI_mean**: Average basSQI across all leads
- **bSQI_mean**: Average bSQI across all leads
- These aggregate metrics are used in QC decision-making

### Step 12: Define Quality Control (QC) Metrics Function
Implement comprehensive QC evaluation for each ECG record:
- **Compute metrics**: Signal duration, amplitude range, NaN fraction, spectral properties, heart rate consistency
- **QC rules**: Check against thresholds for duration, amplitude, baseline wander, powerline noise, heart rate bounds
- **Output**: Boolean `qc_pass` flag and failure reasons if QC fails
- **Purpose**: Filter out poor-quality records before model training
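
To make the baseline wander rule concrete, the sketch below computes the ratio of Welch PSD power below 0.5 Hz to power in the 0.5-40 Hz band for two synthetic signals. The sine waves are stand-ins for real ECG, and `baseline_wander_ratio` is a hypothetical helper that mirrors the metric described above; the trapezoidal integral is written out by hand so the example does not depend on a particular NumPy version.

```python
import numpy as np
from scipy.signal import welch


def band_power(f, pxx, fmin, fmax):
    """Trapezoidal integral of the PSD between fmin and fmax (inclusive)."""
    mask = (f >= fmin) & (f <= fmax)
    if mask.sum() < 2:
        return 0.0
    fm, pm = f[mask], pxx[mask]
    return float(np.sum(0.5 * (pm[1:] + pm[:-1]) * np.diff(fm)))


def baseline_wander_ratio(x, fs):
    """Power below 0.5 Hz divided by power in the 0.5-40 Hz band."""
    f, pxx = welch(x, fs=fs, nperseg=min(4096, len(x)))
    return band_power(f, pxx, 0.0, 0.5) / (band_power(f, pxx, 0.5, 40.0) + 1e-8)


fs = 500.0
t = np.arange(0, 20.0, 1.0 / fs)
clean = np.sin(2 * np.pi * 10.0 * t)                   # synthetic in-band content
drifting = clean + 2.0 * np.sin(2 * np.pi * 0.2 * t)   # add slow baseline drift

ratio_clean = baseline_wander_ratio(clean, fs)
ratio_drift = baseline_wander_ratio(drifting, fs)
```

With a threshold of 0.5, the drifting signal would be flagged as `baseline_wander` while the clean one would pass this particular rule.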

### Step 13: Define Signal Preprocessing and Windowing Functions
Implement the signal processing pipeline:
- **preprocess_record()**: Apply bandpass filter and resample to target sampling rate (500 Hz)
- **window_record()**: Divide long signals into fixed-length overlapping windows (first lead only unless `lead_indices` is passed)
  - Window size: 10 seconds
  - Step size: 5 seconds (50% overlap for data augmentation)
  - Per-window z-score normalization (mean=0, std=1 per channel)
  - Drops windows with >5% NaN values
- Output: List of normalized windows and corresponding labels
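
The number of windows a record yields follows directly from the window and step sizes: a record with `n_samples` samples gives `floor((n_samples - win_len) / step_len) + 1` windows, or none if it is shorter than one window. The sketch below works this out for the `Sampling_point` values that appear in the metadata preview, assuming those records are already sampled at 500 Hz; `count_windows` is a hypothetical helper, not part of the notebook.

```python
def count_windows(n_samples: int, fs: float,
                  window_sec: float = 10.0, step_sec: float = 5.0) -> int:
    """Number of full windows produced by a sliding window with the given step."""
    win_len = int(window_sec * fs)
    step_len = int(step_sec * fs)
    if n_samples < win_len:
        return 0
    return (n_samples - win_len) // step_len + 1


# Sample counts taken from the Sampling_point column; fs = 500 Hz is an assumption.
counts = {n: count_windows(n, fs=500.0) for n in (9000, 10000, 13000, 15000)}
# counts == {9000: 2, 10000: 3, 13000: 4, 15000: 5}
```

Samples after the last full window are discarded, so an 18-second record contributes the same two windows as a 15-second one.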

### Step 14: Create Alternative Label Mappings (Optional)
Generate label encodings for alternative diagnosis coding schemes:
- **label_to_int_aha**: Maps AHA codes to integers (alternative to ICD-10)
- **int_to_label_aha**: Reverse mapping for AHA codes
- Allows experimentation with different classification taxonomies
- Kept separate to avoid interfering with primary ICD_primary labels

### Step 15: Verify Primary Label Mapping
Sanity check that the ICD_primary label mapping is correctly preserved:
- Prints first 5 keys from both `label_to_int` and `int_to_label`
- Ensures the mapping created in Step 8 is still available for downstream processing
- Helps catch mapping conflicts or overwrites

### Step 16: Define Main Processing Pipeline Function
Implement the complete data processing workflow:
- **Input**: DataFrame with filenames and labels, path to ECG files
- **For each record**:
  1. Load raw ECG signal and metadata from WFDB files
  2. Compute comprehensive QC metrics
  3. Skip records that fail QC checks
  4. Apply bandpass filtering and resampling
  5. Create overlapping time windows
  6. Collect QC statistics and processed windows
- **Output**:
  - **X**: 3D array of preprocessed windows (N_windows, time_samples, channels)
  - **y**: Integer labels for each window
  - **df_qc**: DataFrame with QC metrics and pass/fail reasons
- **Saves**: X, y, and df_qc to ARTIFACT_DIR

### Step 17: Execute Main Processing Pipeline
Run the complete preprocessing pipeline on all ECG records:
- **Inputs**:
  - `df_attr`: Metadata DataFrame from Step 3
  - `ECG_DIR`: Path to WFDB files from Step 2
  - `label_col="ICD_primary"`: Use primary ICD-10 diagnosis code
  - `max_records=None`: Process all available records (set to small number for testing)
- **Execution time**: Scales with number of records (~1-5 min for 100 records on CPU)
- **Expected output**: X shape (N_windows, 5000, 1) with the default single-lead windowing, y shape (N_windows,); the full run in this notebook produced 59,868 windows
- **Outputs**: X, y, df_qc ready for train/test split and model training
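
Because windows from the same record overlap by 50%, and a patient can have several records, splitting at the window level could place near-identical data in both train and test sets. One possible way to avoid that is to split by patient. The sketch below is an assumption-laden illustration, not the notebook's method: the patient IDs and window counts are toy values, and in practice they would have to be derived from the QC table (for example from `Filename` and `n_windows` of the records that passed QC, in processing order).

```python
import numpy as np

# Toy stand-ins: one entry per record that passed QC, in processing order.
patient_ids = np.array(["P00001", "P00002", "P00004", "P00004"])
n_windows = np.array([2, 5, 4, 5])

# One group id per window, in the same order the windows were appended to X.
groups = np.repeat(patient_ids, n_windows)

# Shuffle the unique patients and hold out roughly 25% of them for testing.
rng = np.random.default_rng(0)
unique_patients = np.unique(groups)
rng.shuffle(unique_patients)
n_test = max(1, int(round(0.25 * len(unique_patients))))
test_patients = set(unique_patients[:n_test])

# Window indices for each split; no patient appears on both sides.
test_mask = np.array([g in test_patients for g in groups])
train_idx = np.flatnonzero(~test_mask)
test_idx = np.flatnonzero(test_mask)
```

`X[train_idx]` and `X[test_idx]` would then index the window array without any patient appearing in both splits.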

In [106]:
# Set this to True to download from Huggingface else use Google Drive
USE_HF = True

if USE_HF:
  from huggingface_hub import snapshot_download
  local_dir = snapshot_download(
      repo_id="kiril-buga/ECG-database",
      repo_type="dataset",
      local_dir="/content/ECG-database/" # Specify the desired download directory
  )
  print("Downloaded to:", local_dir)
else:
  from google.colab import drive
  drive.mount('/content/drive')

Downloaded to: /content/ECG-database


In [107]:

if USE_HF and local_dir:
    # Case 2: You want to download the dataset from Hugging Face
    DATA_PATH = f"{local_dir}/data/"
    ARTIFACT_DIR = f"{local_dir}/artifacts/"

else:

    # ===== Detect if running in Google Colab and mount Drive =====
    IN_COLAB = False
    try:
        from google.colab import drive  # type: ignore
        IN_COLAB = True
    except Exception:
        drive = None
        IN_COLAB = False

    if IN_COLAB:
        drive.mount('/content/drive/')

    # ===== Define paths =====
    if IN_COLAB:
        # Case 1: You manually placed the dataset in MyDrive
        DATA_PATH = "/content/drive/MyDrive/DeepLearningECG/data/"
        ARTIFACT_DIR = "/content/drive/MyDrive/DeepLearningECG/artifacts/"

    else:
        # Case 3: Local fallback (if running outside Colab)
        DATA_PATH = "../DeepLearningECG/data/"
        ARTIFACT_DIR = "../DeepLearningECG/artifacts/"


# Path where the WFDB ECG files (.hea/.dat) live.
ECG_DIR = os.path.join(DATA_PATH, "Child_ecg/")

print("DATA_PATH:", DATA_PATH)
print("ARTIFACT_DIR:", ARTIFACT_DIR)
print("ECG_DIR:", ECG_DIR)
print("Files in DATA_PATH:", os.listdir(DATA_PATH))

DATA_PATH: /content/ECG-database/data/
ARTIFACT_DIR: /content/ECG-database/artifacts/
ECG_DIR: /content/ECG-database/data/Child_ecg/
Files in DATA_PATH: ['.ipynb_checkpoints', 'ECGCode.csv', 'ExampleReadingCode.ipynb', 'Child_ecg.zip', 'AttributesDictionary.csv', 'Child_ecg.z01', 'Child_ecg', 'DiseaseCode.csv']


In [108]:
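# IPython expands {DATA_PATH} to the Python variable before running the shell command.
# -aos skips files that already exist, so 7z does not stop to ask about overwriting.
# The confirmation message is only printed if extraction succeeds.
!cd "{DATA_PATH}" && 7z x Child_ecg.zip -aos && echo "✓ Extraction complete!"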


## Section 2: Data Loading and Exploration
Utility functions to load raw ECG data and diagnostic comments for exploration and debugging:
- Load raw (unprocessed) ECG signals and metadata
- Extract diagnostic comments from WFDB records
- Inspect signal shapes and metadata structure
- Useful for data quality assessment and troubleshooting

In [109]:
# load CSV
df_attr = pd.read_csv(os.path.join(DATA_PATH, 'AttributesDictionary.csv'))
df_attr

Unnamed: 0,Filename,ECG_ID,Patient_ID,Age,Gender,Acquisition_date,Sampling_point,Lead,AHA_code,CHN_code,ICD-10 code,pSQI,basSQI,bSQI
0,P00/P00001/P00001_E01,P00001_E01,P00001,572d,'Female',2017-11-22 10:46:08,9000,9,'Left ventricular high voltage';'L147','J106';'L123','I34.0';'Q21.0';'Q24.9','I':0.288;'II':0.323;'III':0.346;'aVR':0.312;'...,'I':0.994;'II':0.996;'III':0.991;'aVR':0.997;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
1,P00/P00002/P00002_E01,P00002_E01,P00002,4327d,'Male',2017-11-28 21:59:47,15000,12,'C21','C13','I51.4';'J18.9','I':0.472;'II':0.446;'III':0.449;'aVR':0.484;'...,'I':0.995;'II':0.980;'III':0.992;'aVR':0.992;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
2,P00/P00003/P00003_E01,P00003_E01,P00003,1087d,'Female',2017-11-29 16:04:57,10000,12,'C21','C13','Q21.0';'Q24.9','I':0.495;'II':0.347;'III':0.340;'aVR':0.382;'...,'I':0.915;'II':0.895;'III':0.882;'aVR':0.908;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
3,P00/P00004/P00004_E01,P00004_E01,P00004,2465d,'Male',2017-11-30 15:21:27,13000,9,'C21','C13','Q21.1';'Q24.9','I':0.340;'II':0.405;'III':0.409;'aVR':0.350;'...,'I':0.981;'II':0.988;'III':0.974;'aVR':0.986;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
4,P00/P00004/P00004_E02,P00004_E02,P00004,2461d,'Male',2017-11-26 19:19:48,15000,9,'A1','A1','Q21.1';'Q24.9','I':0.501;'II':0.494;'III':0.389;'aVR':0.525;'...,'I':0.993;'II':0.993;'III':0.989;'aVR':0.995;'...,'I':1.000;'II':1.000;'III':1.000;'aVR':1.000;'...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14185,P11/P11639/P11639_E01,P11639_E01,P11639,2646d,'Male',2021-06-24 18:22:31,10000,12,'A1','A1','J35.3','I':0.330;'II':0.422;'III':0.387;'aVR':0.377;'...,'I':0.991;'II':0.991;'III':0.981;'aVR':0.992;'...,'I':1.000;'II':1.000;'III':0.990;'aVR':1.000;'...
14186,P11/P11640/P11640_E01,P11640_E01,P11640,657d,'Male',2021-07-01 09:47:16,10500,12,'C21';'L147','C13';'L123','S02.0';'S06.5';'S06.6';'S06.7';'T14.0','I':0.284;'II':0.362;'III':0.378;'aVR':0.332;'...,'I':0.919;'II':0.934;'III':0.939;'aVR':0.929;'...,'I':0.976;'II':0.993;'III':1.000;'aVR':0.993;'...
14187,P11/P11641/P11641_E01,P11641_E01,P11641,1484d,'Female',2021-07-04 21:58:36,15000,12,'D30+Modifier310','D21+Frequent','I49.1';'R53','I':0.387;'II':0.387;'III':0.411;'aVR':0.384;'...,'I':0.985;'II':0.975;'III':0.960;'aVR':0.980;'...,'I':0.994;'II':0.987;'III':0.982;'aVR':0.994;'...
14188,P11/P11642/P11642_E01,P11642_E01,P11642,5178d,'Male',2021-06-27 20:22:00,15000,12,'C23';'L150','C15';'L128','J31.0';'J34.2';'S02.2','I':0.401;'II':0.409;'III':0.409;'aVR':0.407;'...,'I':0.975;'II':0.973;'III':0.974;'aVR':0.973;'...,'I':0.974;'II':1.000;'III':1.000;'aVR':1.000;'...


### Step 19: Load and Explore Raw ECG Data
Implement utility functions and load a small sample of raw data for inspection:
- **load_raw_data()**: Read raw ECG signals and metadata from WFDB files
- **load_Diag()**: Extract diagnostic comments from WFDB record headers
- **Display**: Signal shapes, metadata structure, and diagnostic comments for first N records
- **Purpose**: Verify data integrity and understand signal properties before full processing
- **Sample size**: N_SAMPLES=5 to avoid memory overload (adjust as needed)

In [110]:
def butter_bandpass(lowcut: float, highcut: float, fs: float, order: int = 4):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype="band")
    return b, a


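def apply_bandpass(x: np.ndarray, fs: float, lowcut: float = 0.5, highcut: float = 40.0) -> np.ndarray:
    """Apply a zero-phase bandpass filter to each channel; output keeps the input's shape."""
    was_1d = x.ndim == 1
    if was_1d:
        x = x[:, None]
    b, a = butter_bandpass(lowcut, highcut, fs)
    # Filter in floating point so integer-typed input is not truncated
    x_filt = np.zeros(x.shape, dtype=np.float64)
    for i in range(x.shape[1]):
        x_filt[:, i] = filtfilt(b, a, x[:, i])
    # Only collapse the channel axis if the caller passed a 1-D signal;
    # a (time, 1) input must stay 2-D because preprocess_record reads shape[1]
    return x_filt[:, 0] if was_1d else x_filt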


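def band_power(f: np.ndarray, Pxx: np.ndarray, fmin: float, fmax: float) -> float:
    """Integrate PSD between fmin and fmax."""
    mask = (f >= fmin) & (f <= fmax)
    if not np.any(mask):
        return 0.0
    # np.trapezoid exists in NumPy >= 2.0; fall back to np.trapz on older versions
    trapz = getattr(np, "trapezoid", None) or np.trapz
    return float(trapz(Pxx[mask], f[mask]))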


In [111]:
def parse_icd_list(s: str):
    """
    Parse ICD-10 string like "'I34.0';'Q21.0';'Q24.9'" into ['I34.0', 'Q21.0', 'Q24.9'].
    Handles NaN or empty strings.
    """
    if pd.isna(s):
        return []
    # remove surrounding quotes
    parts = [p.strip().replace("'", "") for p in s.split(";")]
    parts = [p for p in parts if len(p) > 0]
    return parts


# Create a new column with parsed ICD codes
df_attr["ICD_list"] = df_attr["ICD-10 code"].apply(parse_icd_list)

# Use FIRST ICD code as label (Option A)
df_attr["ICD_primary"] = df_attr["ICD_list"].apply(lambda lst: lst[0] if len(lst) > 0 else None)


In [112]:
df_attr[["Filename", "ICD_list", "ICD_primary"]].head()


Unnamed: 0,Filename,ICD_list,ICD_primary
0,P00/P00001/P00001_E01,"[I34.0, Q21.0, Q24.9]",I34.0
1,P00/P00002/P00002_E01,"[I51.4, J18.9]",I51.4
2,P00/P00003/P00003_E01,"[Q21.0, Q24.9]",Q21.0
3,P00/P00004/P00004_E01,"[Q21.1, Q24.9]",Q21.1
4,P00/P00004/P00004_E02,"[Q21.1, Q24.9]",Q21.1


In [113]:
LABEL_COL = "ICD_primary"


In [114]:
label_values = sorted(df_attr[LABEL_COL].dropna().unique().tolist())
label_to_int = {lab: i for i, lab in enumerate(label_values)}
int_to_label = {i: lab for lab, i in label_to_int.items()}

print("Label mapping:", label_to_int)


Label mapping: {'(F) I40.0': 0, '(FO) Q21.1': 1, '(OSD) Q21.1': 2, '(V) I40.0': 3, 'A02.1': 4, 'A02.9': 5, 'A05.2': 6, 'A08.0': 7, 'A09.0': 8, 'A09.9': 9, 'A15.3': 10, 'A16.2': 11, 'A16.5': 12, 'A16.9': 13, 'A17.0': 14, 'A17.8': 15, 'A18.0': 16, 'A18.3': 17, 'A18.8': 18, 'A23.9': 19, 'A37.9': 20, 'A41.0': 21, 'A41.5': 22, 'A41.9': 23, 'A46': 24, 'A48.3': 25, 'A49.0': 26, 'A49.1': 27, 'A49.3': 28, 'A49.8': 29, 'A49.9': 30, 'A71.9': 31, 'A81.1': 32, 'A86': 33, 'A87.9': 34, 'B00.8': 35, 'B00.9': 36, 'B01.9': 37, 'B02.8': 38, 'B02.9': 39, 'B07': 40, 'B08.5': 41, 'B09': 42, 'B16.9': 43, 'B18.1': 44, 'B18.2': 45, 'B25.1': 46, 'B25.9': 47, 'B27.9': 48, 'B30.9': 49, 'B33.2': 50, 'B34.0': 51, 'B34.1': 52, 'B34.8': 53, 'B34.9': 54, 'B35.0': 55, 'B35.2': 56, 'B35.6': 57, 'B36.0': 58, 'B37.0': 59, 'B37.9': 60, 'B45.1': 61, 'B49': 62, 'B55.0': 63, 'B59': 64, 'B77.8': 65, 'B82.9': 66, 'B83.0': 67, 'B86': 68, 'B94.1': 69, 'B99': 70, 'C02.9': 71, 'C06.9': 72, 'C07': 73, 'C11.9': 74, 'C22.2': 75, 'C22.

In [115]:
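os.makedirs(ARTIFACT_DIR, exist_ok=True)  # make sure the output directory exists
with open(os.path.join(ARTIFACT_DIR, "label_mapping.json"), "w") as f:
    # Note: JSON stores the integer keys of int_to_label as strings
    json.dump({"label_to_int": label_to_int, "int_to_label": int_to_label}, f)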


In [116]:
def parse_sqi_string(s: str):
    """
    Convert "'I':0.288;'II':0.323" into dict {'I':0.288, 'II':0.323}
    """
    if pd.isna(s):
        return {}
    items = s.split(";")
    out = {}
    for it in items:
        it = it.strip()
        if ":" not in it:
            continue
        k, v = it.split(":")
        k = k.replace("'", "").strip()
        try:
            v = float(v)
        except ValueError:
            continue
        out[k] = v
    return out


df_attr["pSQI_dict"] = df_attr["pSQI"].apply(parse_sqi_string)
df_attr["basSQI_dict"] = df_attr["basSQI"].apply(parse_sqi_string)
df_attr["bSQI_dict"] = df_attr["bSQI"].apply(parse_sqi_string)


In [117]:
df_attr["pSQI_mean"] = df_attr["pSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)
df_attr["basSQI_mean"] = df_attr["basSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)
df_attr["bSQI_mean"] = df_attr["bSQI_dict"].apply(lambda d: np.mean(list(d.values())) if len(d)>0 else np.nan)


In [118]:
def compute_qc_metrics(sig, meta, pSQI_mean, bSQI_mean):
    """
    Compute QC metrics for one ECG record.

    Returns a dict with metrics and a 'qc_pass' boolean.
    """
    qc: Dict[str, Any] = {}
    qc["pSQI_mean"] = pSQI_mean
    qc["bSQI_mean"] = bSQI_mean

    fs = meta.get("fs", None)
    if fs is None:
        raise ValueError("Sampling frequency 'fs' missing in meta.")

    # Ensure 2D: (time, leads)
    if sig.ndim == 1:
        sig = sig[:, None]

    n_samples, n_leads = sig.shape
    qc["n_samples"] = int(n_samples)
    qc["n_leads"] = int(n_leads)
    qc["duration_sec"] = n_samples / fs

    # Use first lead for QC indices
    lead = sig[:, 0]

    # Handle NaNs
    n_total = lead.size
    n_nans = np.isnan(lead).sum()
    qc["nan_fraction"] = float(n_nans / n_total)

    lead_clean = lead.copy()
    if n_nans > 0:
        # simple interpolation for NaNs
        not_nan = ~np.isnan(lead_clean)
        if not np.any(not_nan):
            # Entire lead is NaN: fail immediately
            qc["qc_pass"] = False
            qc["fail_reason"] = "all_nan"
            return qc
        lead_clean[~not_nan] = np.interp(
            np.flatnonzero(~not_nan),
            np.flatnonzero(not_nan),
            lead_clean[not_nan],
        )

    # Basic amplitude stats
    amp = lead_clean
    qc["amp_mean"] = float(np.mean(amp))
    qc["amp_std"] = float(np.std(amp))
    # Robust range
    q1, q99 = np.percentile(amp, [1, 99])
    qc["amp_robust_range"] = float(q99 - q1)

    # Spectral measures using Welch
    f, Pxx = welch(amp, fs=fs, nperseg=min(4096, len(amp)))

    total_power = band_power(f, Pxx, 0.5, 40.0)
    low_power = band_power(f, Pxx, 0.0, 0.5)
    qc["baseline_wander_ratio"] = float(low_power / (total_power + 1e-8))

    # Powerline noise index (assuming 50 Hz, adapt to 60 Hz if needed)
    pl_power = band_power(f, Pxx, 48.0, 52.0)
    band_40_60 = band_power(f, Pxx, 40.0, 60.0)
    qc["powerline_ratio"] = float(pl_power / (band_40_60 + 1e-8))

    # Heart rate consistency via neurokit2 (if available)
    if HAS_NEUROKIT:
        try:
            cleaned = nk.ecg_clean(amp, sampling_rate=fs)
            _, rpeaks = nk.ecg_peaks(cleaned, sampling_rate=fs)
            r_locs = rpeaks["ECG_R_Peaks"]
            if len(r_locs) > 1:
                hr = nk.ecg_rate(r_locs, sampling_rate=fs)
                qc["hr_mean"] = float(np.mean(hr))
                qc["hr_std"] = float(np.std(hr))
                qc["hr_n_beats"] = int(len(r_locs))
            else:
                qc["hr_mean"] = np.nan
                qc["hr_std"] = np.nan
                qc["hr_n_beats"] = int(len(r_locs))
        except Exception as e:
            qc["hr_mean"] = np.nan
            qc["hr_std"] = np.nan
            qc["hr_n_beats"] = 0
            qc["hr_error"] = str(e)
    else:
        qc["hr_mean"] = np.nan
        qc["hr_std"] = np.nan
        qc["hr_n_beats"] = -1

    # Simple QC rules, tune thresholds as needed
    MIN_DURATION = 8.0       # seconds
    MAX_NAN_FRAC = 0.01
    MIN_AMP_RANGE = 0.05     # depends on units
    MAX_AMP_RANGE = 10.0
    MAX_BASELINE_RATIO = 0.5
    MAX_POWERLINE_RATIO = 0.5

    reasons = []

    if qc["duration_sec"] < MIN_DURATION:
        reasons.append("too_short")
    if qc["nan_fraction"] > MAX_NAN_FRAC:
        reasons.append("too_many_nans")
    if not (MIN_AMP_RANGE < qc["amp_robust_range"] < MAX_AMP_RANGE):
        reasons.append("amp_out_of_range")
    if qc["baseline_wander_ratio"] > MAX_BASELINE_RATIO:
        reasons.append("baseline_wander")
    if qc["powerline_ratio"] > MAX_POWERLINE_RATIO:
        reasons.append("powerline_noise")

    # HR based rules only if HR was computed
    if not np.isnan(qc["hr_mean"]):
        if not (40.0 <= qc["hr_mean"] <= 220.0):
            reasons.append("hr_out_of_range")
        if qc["hr_n_beats"] < 5:
            reasons.append("too_few_beats")

    # Additional QC rules using the pre-computed SQI values from df_attr
    # (pSQI_mean and bSQI_mean are passed in as parameters; basSQI_mean is not used here)
    if qc["pSQI_mean"] < 0.2:    # based on literature
        reasons.append("low_pSQI")
    if qc["bSQI_mean"] < 0.8:
        reasons.append("low_bSQI")

    qc["qc_pass"] = len(reasons) == 0
    qc["fail_reason"] = ";".join(reasons) if reasons else ""

    return qc

In [119]:
TARGET_FS = 500.0
WINDOW_SEC = 10.0
STEP_SEC = 5.0  # 50 percent overlap


def preprocess_record(sig: np.ndarray, meta: Dict[str, Any], target_fs: float = TARGET_FS) -> Tuple[np.ndarray, float]:
    """
    Bandpass filter and resample entire record.

    Returns:
        sig_proc: (time, leads) at target_fs
        fs_new: sampling rate after resampling
    """
    fs = meta.get("fs", None)
    if fs is None:
        raise ValueError("Sampling frequency 'fs' missing in meta.")

    if sig.ndim == 1:
        sig = sig[:, None]

    # Bandpass
    sig_bp = apply_bandpass(sig, fs=fs)

    if fs == target_fs:
        return sig_bp, fs

    # Resample time dimension
    n_samples = sig_bp.shape[0]
    duration = n_samples / fs
    n_new = int(round(duration * target_fs))

    sig_res = np.zeros((n_new, sig_bp.shape[1]))
    for i in range(sig_bp.shape[1]):
        sig_res[:, i] = resample(sig_bp[:, i], n_new)

    return sig_res, target_fs


def window_record(
    sig: np.ndarray,
    fs: float,
    label_int: int,
    window_sec: float = WINDOW_SEC,
    step_sec: float = STEP_SEC,
    lead_indices: List[int] = None
) -> Tuple[List[np.ndarray], List[int]]:
    """
    Slice a preprocessed record into overlapping windows.

    Returns lists of windows (time, channels) and labels.
    """
    if lead_indices is None:
        # default: use first lead only
        lead_indices = [0]

    if sig.ndim == 1:
        sig = sig[:, None]

    sig = sig[:, lead_indices]
    n_samples = sig.shape[0]

    win_len = int(window_sec * fs)
    step_len = int(step_sec * fs)

    windows = []
    labels = []

    start = 0
    while start + win_len <= n_samples:
        segment = sig[start:start + win_len, :]

        # Drop window if more than 5% of its samples are NaN
        if np.isnan(segment).mean() > 0.05:
            start += step_len
            continue

        # Normalize per window (z score per channel)
        seg_norm = segment.copy()
        for ch in range(seg_norm.shape[1]):
            x = seg_norm[:, ch]
            m = np.nanmean(x)
            s = np.nanstd(x)
            if s < 1e-6:
                s = 1.0
            seg_norm[:, ch] = (x - m) / s

        windows.append(seg_norm.astype(np.float32))
        labels.append(int(label_int))

        start += step_len

    return windows, labels


In [120]:
# Create alternative label mapping for AHA_code (if needed later)
assert "AHA_code" in df_attr.columns, "AHA_code not in df_attr columns"

# Unique label values for AHA_code
aha_label_values = sorted(df_attr["AHA_code"].dropna().unique().tolist())
label_to_int_aha = {lab: idx for idx, lab in enumerate(aha_label_values)}
int_to_label_aha = {idx: lab for lab, idx in label_to_int_aha.items()}

print("AHA_code label mapping:", label_to_int_aha)

AHA_code label mapping: {"'A1'": 0, "'A1';'C22';'C23'": 1, "'A1';'C23'": 2, "'A1';'L147'": 3, "'A1';'L150'": 4, "'A1';'Prolonged QTc interval'": 5, "'A1';'Sinoatrial node to atrial internal migratory rhythm'": 6, "'A2'": 7, "'A2';'C22';'C23'": 8, "'A2';'C23'": 9, "'A2';'C23';'I101'": 10, "'A2';'J124'": 11, "'A2';'J125'": 12, "'A2';'L150'": 13, "'A2';'Suggests213'": 14, "'Abnormal Q wave'": 15, "'Abnormal Q wave';'Consider230'": 16, "'Abnormal precordial R-wave progression'": 17, "'Abnormal precordial R-wave progression';'J125';'K140';'K141';'K142';'L145+Modifier363'": 18, "'Abnormal precordial R-wave progression';'J125';'K140';'K141';'K143';'L145+Modifier362';'L145+Modifier363'": 19, "'Abnormal precordial R-wave progression';'J125';'K142';'L146'": 20, "'Abnormal precordial R-wave progression';'J125';'K143';'Abnormal Q wave'": 21, "'Abnormal precordial R-wave progression';'J125';'L145+Modifier362';'L145+Modifier363'": 22, "'Abnormal precordial R-wave progression';'K140';'K141';'K142';'L

In [121]:
# Note: the ICD_primary label mapping (label_to_int / int_to_label) was created in Step 8.
# Both dicts are used in the build_qc_and_windows function below.

# Verify the mappings are still available
print("Current label_to_int keys (first 5):", list(label_to_int.keys())[:5])
print("Current int_to_label keys (first 5):", list(int_to_label.keys())[:5])

Current label_to_int keys (first 5): ['(F) I40.0', '(FO) Q21.1', '(OSD) Q21.1', '(V) I40.0', 'A02.1']
Current int_to_label keys (first 5): [0, 1, 2, 3, 4]


In [122]:
def build_qc_and_windows(
    df_attr: pd.DataFrame,
    ecg_dir: str,
    label_col: str = LABEL_COL,
    max_records: int = None,
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Run QC and preprocessing over all records.

    Returns:
        X: (N_windows, T, C)
        y: (N_windows,) integer labels
        df_qc: QC metrics per record
    """

    # ===== BLOCK 1: Initialize data collection lists =====
    # These lists will accumulate results from all records that pass QC
    qc_rows = []      # List of dicts with QC metrics for each record
    all_windows = []  # List of preprocessed signal windows (each is a numpy array)
    all_labels = []   # List of integer labels corresponding to each window

    # ===== BLOCK 2: Set up iteration over records =====
    # Create an iterator that optionally limits the number of records processed
    iterator = df_attr.iterrows()
    if max_records is not None:
        iterator = df_attr.iloc[:max_records].iterrows()

    # ===== BLOCK 3: Main loop - Process each ECG record =====
    for idx, row in iterator:

        # ===== BLOCK 3A: Extract filename and label =====
        # Get the path to the ECG file and the diagnostic label
        fname = row["Filename"]
        label_raw = row[label_col]

        # ===== BLOCK 3B: Skip records without labels =====
        # If the record has no diagnosis code, skip it
        if pd.isna(label_raw):
            continue

        # ===== BLOCK 3C: Convert label to integer index =====
        # Look up the integer encoding for this diagnosis code
        label_int = label_to_int[label_raw]

        # ===== BLOCK 3D: Load raw ECG signal from disk =====
        # Read the WFDB format ECG file (includes signal and metadata)
        record_path = os.path.join(ecg_dir, fname)  # use the function argument, not the global ECG_DIR
        try:
            sig, meta = wfdb.rdsamp(record_path)
        except Exception as e:
            # Log failure and continue if file cannot be read
            print(f"Error reading {record_path}: {e}")
            qc_rows.append({
                "Filename": fname,
                "qc_pass": False,
                "fail_reason": f"read_error:{e}",
            })
            continue

        # ===== BLOCK 3E: Handle metadata compatibility =====
        # Different wfdb versions return meta as dict or object
        # Normalize to dict format for consistent handling
        meta_dict = meta if isinstance(meta, dict) else meta.__dict__

        # ===== BLOCK 3F: Compute QC metrics on raw signal =====
        # Calculate signal quality indices (amplitude, spectral, HR consistency, etc.)
        # and determine if the record passes quality thresholds
        qc = compute_qc_metrics(
              np.asarray(sig),
              meta_dict,
              pSQI_mean=float(row["pSQI_mean"]),
              bSQI_mean=float(row["bSQI_mean"]),
          )

        # ===== BLOCK 3G: Attach metadata to QC result =====
        # Add filename and label information to QC metrics dict
        qc["Filename"] = fname
        qc["label_raw"] = label_raw
        qc["label_int"] = int(label_int)

        # ===== BLOCK 3H: Skip records that fail QC =====
        # Records with poor signal quality are logged but not processed further
        if not qc["qc_pass"]:
            qc_rows.append(qc)
            continue

        # ===== BLOCK 3I: Preprocess signal (bandpass + resample) =====
        # Apply bandpass filter (0.5-40 Hz) and resample to TARGET_FS (500 Hz)
        sig_proc, fs_new = preprocess_record(np.asarray(sig), meta_dict, target_fs=TARGET_FS)

        # ===== BLOCK 3J: Divide signal into overlapping windows =====
        # Create fixed-length windows (10 sec) with 50% overlap (5 sec step)
        # Each window is normalized independently (z-score per channel)
        windows, labels = window_record(sig_proc, fs=fs_new, label_int=label_int)

        # ===== BLOCK 3K: Record number of windows created =====
        # Store the count of windows extracted from this record
        qc["n_windows"] = len(windows)

        # ===== BLOCK 3L: Accumulate results =====
        # Add QC metrics to summary and extend window/label lists
        qc_rows.append(qc)
        all_windows.extend(windows)
        all_labels.extend(labels)

    # ===== BLOCK 4: Check if any windows were created =====
    # Abort if no records passed QC (likely indicates misconfigured thresholds)
    if len(all_windows) == 0:
        raise RuntimeError("No windows created. Check QC thresholds and label column.")

    # ===== BLOCK 5: Stack windows into single array =====
    # Convert list of windows into 3D numpy array: (N_windows, time_samples, channels)
    X = np.stack(all_windows, axis=0)  # (N, T, C)
    y = np.array(all_labels, dtype=np.int64)

    # ===== BLOCK 6: Create QC summary DataFrame =====
    # Convert QC metrics list into pandas DataFrame for easy analysis
    df_qc = pd.DataFrame(qc_rows)

    # ===== BLOCK 7: Save all results to disk =====
    # Write outputs to artifact directory for later loading and model training
    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    np.save(os.path.join(ARTIFACT_DIR, "X_windows.npy"), X)
    np.save(os.path.join(ARTIFACT_DIR, "y_labels.npy"), y)
    df_qc.to_csv(os.path.join(ARTIFACT_DIR, "qc_summary.csv"), index=False)

    # ===== BLOCK 8: Print summary statistics =====
    # Display shapes and counts to verify processing completed successfully
    print("Saved:")
    print("  X_windows.npy shape:", X.shape)
    print("  y_labels.npy shape:", y.shape)
    print("  qc_summary.csv rows:", len(df_qc))

    # ===== BLOCK 9: Return processed data =====
    # Return arrays ready for model training and QC report for analysis
    return X, y, df_qc
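
The windowing step described in BLOCK 3J (10-second windows, 5-second step, per-channel z-score) can be sketched as follows. This is not the notebook's `window_record`, whose body is not shown in this section; `sliding_windows`, `win_sec`, `step_sec`, and `eps` are hypothetical names chosen for illustration, and the sketch assumes a signal shaped `(time, channels)`.

```python
import numpy as np

def sliding_windows(sig, fs, win_sec=10.0, step_sec=5.0, eps=1e-8):
    """Cut a (time, channels) signal into overlapping, z-scored windows."""
    win = int(round(win_sec * fs))    # samples per window (5000 at 500 Hz)
    step = int(round(step_sec * fs))  # hop between window starts (2500 at 500 Hz)
    windows = []
    # A trailing segment shorter than one full window is dropped.
    for start in range(0, sig.shape[0] - win + 1, step):
        w = sig[start:start + win].astype(np.float32)
        # Normalize each channel independently within this window.
        w = (w - w.mean(axis=0)) / (w.std(axis=0) + eps)
        windows.append(w)
    return windows

# Synthetic stand-in for a 9000-sample, 9-lead record sampled at 500 Hz.
rng = np.random.default_rng(0)
demo = sliding_windows(rng.normal(size=(9000, 9)), fs=500)
print(len(demo), demo[0].shape)
```

Under these assumptions a record of `L` samples yields `(L - 5000) // 2500 + 1` windows when `L >= 5000` and none otherwise, so a 9000-sample record gives 2 windows and a 15000-sample record gives 5.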

In [123]:
X, y, df_qc = build_qc_and_windows(
    df_attr=df_attr,
    ecg_dir=ECG_DIR,
    label_col="ICD_primary",
    max_records=None
)




Saved:
  X_windows.npy shape: (59868, 5000, 1)
  y_labels.npy shape: (59868,)
  qc_summary.csv rows: 14190
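
An array of shape `(59868, 5000, 1)` is large enough that its memory footprint is worth estimating before loading it into a Colab session. The sketch below computes the raw size for two candidate dtypes (the stored dtype is not shown in this section, so both are listed) and shows how `np.load` with `mmap_mode="r"` can open the file without reading it all into RAM. `array_nbytes` and `open_windows` are hypothetical helper names.

```python
import numpy as np

def array_nbytes(shape, dtype) -> int:
    """Raw data size in bytes for an array of the given shape and dtype."""
    n_items = 1
    for dim in shape:
        n_items *= int(dim)
    return n_items * np.dtype(dtype).itemsize

def open_windows(path):
    """Open a saved .npy file lazily; slices are read from disk on access."""
    return np.load(path, mmap_mode="r")

shape = (59868, 5000, 1)
for dt in (np.float32, np.float64):
    print(np.dtype(dt).name, array_nbytes(shape, dt) / 1e9, "GB")
```

With a memory-mapped array, a batch such as `open_windows(path)[:256]` only touches the bytes for those 256 windows.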


In [124]:

from huggingface_hub import HfApi, login

# Login to Hugging Face (you'll be prompted for token)
login()

# Initialize API
api = HfApi()

# File paths and metadata
files_to_upload = [
    "/content/ECG-database/artifacts/X_windows.npy",
    "/content/ECG-database/artifacts/label_mapping.json",
    "/content/ECG-database/artifacts/qc_summary.csv",
    "/content/ECG-database/artifacts/y_labels.npy"
]

repo_id = "kiril-buga/ECG-database"
repo_type = "dataset"
commit_message = "Upload preprocessed ECG data and artifacts"

# Upload each file, tracking failures so the final message is accurate
failed = []
for file_path in files_to_upload:
    file_name = os.path.basename(file_path)
    try:
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=f"artifacts/{file_name}",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=f"Add {file_name}"
        )
        print(f"✓ Uploaded: {file_name}")
    except Exception as e:
        failed.append(file_name)
        print(f"✗ Failed to upload {file_path}: {e}")

if failed:
    print(f"\n✗ {len(failed)} file(s) failed to upload: {failed}")
else:
    print("\n✓ All files uploaded!")

  ...e/artifacts/X_windows.npy:   0%|          |  683kB / 1.20GB            

✓ Uploaded: X_windows.npy


No files have been modified since last commit. Skipping to prevent empty commit.


✓ Uploaded: label_mapping.json
✓ Uploaded: qc_summary.csv


  ...se/artifacts/y_labels.npy: 100%|##########|  479kB /  479kB            

✓ Uploaded: y_labels.npy

✓ All files uploaded!
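
Because the upload loop only discovers a missing file when the request fails, a local pre-flight check can save a round trip. The sketch below reports which artifact paths exist and their sizes, and shows one possible alternative to the per-file loop using `HfApi.upload_folder`, which sends matching files in a single commit. `preflight` and `upload_artifacts` are hypothetical helpers, and `upload_artifacts` is defined but not called here because it needs network access and a prior `login()`.

```python
import os
import tempfile
from typing import Dict, List, Tuple

def preflight(paths: List[str]) -> Tuple[Dict[str, int], List[str]]:
    """Return ({path: size_in_bytes} for existing files, [missing paths])."""
    sizes, missing = {}, []
    for p in paths:
        if os.path.isfile(p):
            sizes[p] = os.path.getsize(p)
        else:
            missing.append(p)
    return sizes, missing

def upload_artifacts(folder: str, repo_id: str) -> None:
    """Upload every .npy/.json/.csv file in `folder` as one commit."""
    from huggingface_hub import HfApi  # imported lazily; requires login()
    HfApi().upload_folder(
        folder_path=folder,
        path_in_repo="artifacts",
        repo_id=repo_id,
        repo_type="dataset",
        allow_patterns=["*.npy", "*.json", "*.csv"],
        commit_message="Upload preprocessed ECG data and artifacts",
    )

# Demo on a temporary directory: one file present, one absent.
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, "y_labels.npy")
    with open(present, "wb") as fh:
        fh.write(b"\x00" * 16)
    demo_sizes, demo_missing = preflight(
        [present, os.path.join(tmp, "X_windows.npy")]
    )
print(list(demo_sizes.values()), len(demo_missing))
```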


In [125]:
# Make sure cell-17 (build_qc_and_windows) is executed first before running this cell

# Split data into train, val and test
# Get unique Patient_ID values
patient_ids = df_attr['Patient_ID'].unique()

# Split Patient_IDs into training, validation, and test sets
train_ids, test_ids = train_test_split(patient_ids, test_size=0.2, random_state=42)
train_ids, val_ids = train_test_split(train_ids, test_size=0.25, random_state=42)  # 0.25 * 0.8 = 0.2

# Get indices for training, validation, and test sets based on Patient_ID
# Note: X and y come from build_qc_and_windows, which iterates over df_attr rows.
# The windows still need to be matched back to Patient_IDs (via df_qc) before X and y can be split.

print("Train patient IDs:", len(train_ids))
print("Val patient IDs:", len(val_ids))
print("Test patient IDs:", len(test_ids))
print("\nNote: Proper train/test split requires matching windows back to patient IDs.")
print("Consider using df_qc returned from build_qc_and_windows for more accurate splitting.")

Train patient IDs: 6985
Val patient IDs: 2329
Test patient IDs: 2329

Note: Proper train/test split requires matching windows back to patient IDs.
Consider using df_qc returned from build_qc_and_windows for more accurate splitting.
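
One way to finish the split is to rebuild a per-window patient ID from `df_qc`. The sketch below assumes that `df_qc` rows appear in the same order in which windows were appended to `X`, that `n_windows` is missing (NaN) for records that produced no windows, and that `Filename` identifies a record uniquely in `df_attr`. `window_groups` and `split_masks` are hypothetical helper names, and the small DataFrames are made-up stand-ins for the real ones.

```python
import numpy as np
import pandas as pd

def window_groups(df_qc: pd.DataFrame, df_attr: pd.DataFrame) -> np.ndarray:
    """One Patient_ID per window, in the order the windows were appended."""
    fname_to_pid = (
        df_attr.drop_duplicates("Filename").set_index("Filename")["Patient_ID"]
    )
    # Records without windows have NaN in n_windows and contribute 0 repeats.
    n_win = df_qc["n_windows"].fillna(0).astype(int).to_numpy()
    pids = df_qc["Filename"].map(fname_to_pid).to_numpy()
    return np.repeat(pids, n_win)

def split_masks(groups, train_ids, val_ids, test_ids):
    """Boolean masks over windows for each patient-level split."""
    g = pd.Series(groups)
    return tuple(g.isin(ids).to_numpy() for ids in (train_ids, val_ids, test_ids))

# Toy stand-ins: record "b" failed QC, so it has no windows.
toy_attr = pd.DataFrame({"Filename": ["a", "b", "c"],
                         "Patient_ID": ["P1", "P2", "P2"]})
toy_qc = pd.DataFrame({"Filename": ["a", "b", "c"],
                       "n_windows": [2, np.nan, 1]})
groups = window_groups(toy_qc, toy_attr)
train_m, val_m, test_m = split_masks(groups, ["P1"], [], ["P2"])
print(groups.tolist(), train_m.tolist(), test_m.tolist())
```

On the real data, checking `len(window_groups(df_qc, df_attr)) == len(X)` before indexing `X[train_m]` would catch a violated ordering assumption early.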


### Test Loader

In [126]:
# ===== Load ECG data =====


# Load the raw ECG signal data
def load_raw_data(df, path, n_samples=None):
    """Return list of ECG arrays and list of metadata dicts."""
    filenames = df["Filename"].tolist()
    if n_samples is not None:
        filenames = filenames[:n_samples]

    signals = []
    metas = []
    for fname in filenames:
        sig, meta = wfdb.rdsamp(os.path.join(path, fname))
        signals.append(sig)
        metas.append(meta)
    return signals, metas

# Load diagnostic comments from WFDB metadata
def load_Diag(df, path, n_samples=None):
    """Return disease and ECG diagnostic comments from WFDB records."""
    filenames = df["Filename"].tolist()
    if n_samples is not None:
        filenames = filenames[:n_samples]

    disease_diag = []
    ecg_diag = []
    for fname in filenames:
        record = wfdb.rdrecord(os.path.join(path, fname))
        comments = record.comments
        disease_diag.append(comments[1] if len(comments) > 1 else None)
        ecg_diag.append(comments[2] if len(comments) > 2 else None)
    return disease_diag, ecg_diag

# ===== Load a few ECG records and check shapes =====

# Load only a few records so Colab RAM is safe
N_SAMPLES = 5
signals, metas = load_raw_data(df_attr, ECG_DIR, n_samples=N_SAMPLES)

print(f"Loaded {len(signals)} ECG signals.")
print("Shape of first signal (time, leads):", signals[0].shape)
print("Meta of first signal:")
print(metas[0])

# Attach comments for the same subset
disease_diag, ecg_diag = load_Diag(df_attr, ECG_DIR, n_samples=N_SAMPLES)

df_subset = df_attr.iloc[:N_SAMPLES].copy()
df_subset["Disease_diag_comment"] = disease_diag
df_subset["ECG_diag_comment"] = ecg_diag
display(df_subset)


Loaded 5 ECG signals.
Shape of first signal (time, leads): (9000, 9)
Meta of first signal:
{'fs': 500, 'sig_len': 9000, 'n_sig': 9, 'base_date': None, 'base_time': None, 'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'], 'sig_name': ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V3', 'V5'], 'comments': ['Female 572', 'Ventricular septal defect', 'Left ventricular high voltage;T-wave abnormality']}


Unnamed: 0,Filename,ECG_ID,Patient_ID,Age,Gender,Acquisition_date,Sampling_point,Lead,AHA_code,CHN_code,...,ICD_list,ICD_primary,pSQI_dict,basSQI_dict,bSQI_dict,pSQI_mean,basSQI_mean,bSQI_mean,Disease_diag_comment,ECG_diag_comment
0,P00/P00001/P00001_E01,P00001_E01,P00001,572d,'Female',2017-11-22 10:46:08,9000,9,'Left ventricular high voltage';'L147','J106';'L123',...,"[I34.0, Q21.0, Q24.9]",I34.0,"{'I': 0.288, 'II': 0.323, 'III': 0.346, 'aVR':...","{'I': 0.994, 'II': 0.996, 'III': 0.991, 'aVR':...","{'I': 1.0, 'II': 1.0, 'III': 1.0, 'aVR': 1.0, ...",0.317667,0.977444,1.0,Ventricular septal defect,Left ventricular high voltage;T-wave abnormality
1,P00/P00002/P00002_E01,P00002_E01,P00002,4327d,'Male',2017-11-28 21:59:47,15000,12,'C21','C13',...,"[I51.4, J18.9]",I51.4,"{'I': 0.472, 'II': 0.446, 'III': 0.449, 'aVR':...","{'I': 0.995, 'II': 0.98, 'III': 0.992, 'aVR': ...","{'I': 1.0, 'II': 1.0, 'III': 1.0, 'aVR': 1.0, ...",0.449333,0.9895,0.996583,Myocarditis,Sinus tachycardia
2,P00/P00003/P00003_E01,P00003_E01,P00003,1087d,'Female',2017-11-29 16:04:57,10000,12,'C21','C13',...,"[Q21.0, Q24.9]",Q21.0,"{'I': 0.495, 'II': 0.347, 'III': 0.34, 'aVR': ...","{'I': 0.915, 'II': 0.895, 'III': 0.882, 'aVR':...","{'I': 1.0, 'II': 1.0, 'III': 1.0, 'aVR': 1.0, ...",0.40225,0.923667,1.0,Ventricular septal defect,Sinus tachycardia
3,P00/P00004/P00004_E01,P00004_E01,P00004,2465d,'Male',2017-11-30 15:21:27,13000,9,'C21','C13',...,"[Q21.1, Q24.9]",Q21.1,"{'I': 0.34, 'II': 0.405, 'III': 0.409, 'aVR': ...","{'I': 0.981, 'II': 0.988, 'III': 0.974, 'aVR':...","{'I': 1.0, 'II': 1.0, 'III': 1.0, 'aVR': 1.0, ...",0.382222,0.985222,1.0,Atrial septal defect,Sinus tachycardia
4,P00/P00004/P00004_E02,P00004_E02,P00004,2461d,'Male',2017-11-26 19:19:48,15000,9,'A1','A1',...,"[Q21.1, Q24.9]",Q21.1,"{'I': 0.501, 'II': 0.494, 'III': 0.389, 'aVR':...","{'I': 0.993, 'II': 0.993, 'III': 0.989, 'aVR':...","{'I': 1.0, 'II': 1.0, 'III': 1.0, 'aVR': 1.0, ...",0.502667,0.972778,1.0,Atrial septal defect,Normal ECG
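
The metadata dict printed above already carries the `comments` list, so the diagnostic fields can be read from the `meta` returned by `wfdb.rdsamp` without opening each record a second time. The sketch below uses the same indices as `load_Diag` (1 for the disease diagnosis, 2 for the ECG diagnosis). Reading the first comment as "sex, then age in days" is an assumption based on the first record, where `'Female 572'` lines up with `Age` = `572d`; `parse_comments` is a hypothetical helper.

```python
from typing import Any, Dict, Optional

def parse_comments(meta: Dict[str, Any]) -> Dict[str, Optional[Any]]:
    """Extract sex, age in days, and both diagnosis strings from WFDB metadata."""
    comments = list(meta.get("comments") or [])
    sex, age_days = None, None
    if comments:
        parts = comments[0].split()
        if parts:
            sex = parts[0]
        # Assumed layout: second token is the age in days.
        if len(parts) > 1 and parts[1].isdigit():
            age_days = int(parts[1])
    return {
        "sex": sex,
        "age_days": age_days,
        "disease_diag": comments[1] if len(comments) > 1 else None,
        "ecg_diag": comments[2] if len(comments) > 2 else None,
    }

# Comments copied from the first record's metadata shown above.
example_meta = {
    "comments": [
        "Female 572",
        "Ventricular septal defect",
        "Left ventricular high voltage;T-wave abnormality",
    ]
}
parsed = parse_comments(example_meta)
print(parsed)
```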
