# ECG Multi-Label Preprocessing (Self-Contained)
Fully self-contained notebook: loads the raw data, preprocesses it, creates 5-class one-hot labels, and optionally uploads the results to Hugging Face.

In [None]:
# Install the Python packages used by this notebook (-q keeps pip output quiet).
!pip install wfdb neurokit2 scikit-learn scipy matplotlib pandas numpy huggingface-hub h5py -q
# Install 7-Zip; a later cell calls `7z` to extract the ECG archive.
!apt-get update && apt-get install -y p7zip-full

import os
import json
import numpy as np
import pandas as pd
import wfdb
import h5py
from scipy.signal import butter, filtfilt, welch, resample
import neurokit2 as nk
from datetime import datetime

# Fix NumPy's global random seed so any random draws are repeatable across runs.
np.random.seed(42)
print("✓ Dependencies installed")

## Environment Setup

In [None]:
# Set this to True to download the dataset from Hugging Face; otherwise the data
# is read from Google Drive (when running in Colab) or from a local checkout.
USE_HF = True

# Detect Colab up front so IN_COLAB is defined on every branch; the summary
# print at the end of this cell needs it even when USE_HF is True.
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if USE_HF:
    from huggingface_hub import snapshot_download
    local_dir = snapshot_download(
        repo_id="kiril-buga/ECG-database",
        repo_type="dataset",
        local_dir="/content/ECG-database/",  # Specify the desired download directory
    )
    print("Downloaded to:", local_dir)

    DATA_PATH = f"{local_dir}/data/"
    ARTIFACT_DIR = f"{local_dir}/artifacts/"
elif IN_COLAB:
    # A failed mount is allowed to raise: the Drive paths below would be
    # unusable anyway, and failing here is clearer than failing later.
    drive.mount('/content/drive/')
    DATA_PATH = "/content/drive/MyDrive/DeepLearningECG/data/"
    ARTIFACT_DIR = "/content/drive/MyDrive/DeepLearningECG/artifacts/"
else:
    DATA_PATH = "../DeepLearningECG/data/"
    ARTIFACT_DIR = "../DeepLearningECG/artifacts/"

# Raw ECG records live under DATA_PATH; processed outputs go to OUT_DIR.
ECG_DIR = os.path.join(DATA_PATH, "Child_ecg/")
OUT_DIR = os.path.join(ARTIFACT_DIR, "multilabel_v2")
os.makedirs(OUT_DIR, exist_ok=True)

print(f"Colab: {IN_COLAB}")
print(f"DATA: {DATA_PATH}")
print(f"OUTPUT: {OUT_DIR}")

In [None]:
!cd DATA_PATH && 7z x Child_ecg.zip
print("✓ Extraction complete!")

## Load CSV Metadata

In [None]:
# Prefer the attribute table that ships with the data directory; fall back to
# fetching the same file from the Hugging Face dataset repository.
csv_path = os.path.join(DATA_PATH, 'AttributesDictionary.csv')

if os.path.exists(csv_path):
    csv_to_read = csv_path
else:
    from huggingface_hub import hf_hub_download
    print("Downloading CSV from Hugging Face...")
    csv_file = hf_hub_download(
        repo_id="kiril-buga/ECG-database",
        filename="AttributesDictionary.csv",
        repo_type="dataset"
    )
    csv_to_read = csv_file

# Both branches end the same way: parse the CSV and report its shape.
df_attr = pd.read_csv(csv_to_read)
print(f"Loaded CSV: {df_attr.shape}")

## Signal Processing Functions

In [None]:
def apply_bandpass(x, fs, lowcut=0.5, highcut=40.0):
    """Band-pass filter every column of ``x`` with a zero-phase Butterworth filter.

    Parameters:
    - x: 1-D signal or 2-D array whose columns are filtered independently
    - fs: sampling frequency in Hz
    - lowcut, highcut: pass-band edges in Hz

    Returns:
    - 2-D array with one filtered column per input column (a 1-D input
      comes back as a single-column array)
    """
    columns = x[:, None] if x.ndim == 1 else x

    # butter() takes band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    b, a = butter(4, band_edges, btype="band")

    filtered = []
    for col in range(columns.shape[1]):
        # filtfilt runs the filter forward and backward (no phase shift).
        filtered.append(filtfilt(b, a, columns[:, col]))
    return np.column_stack(filtered)

def band_power(f, Pxx, fmin, fmax):
    """Integrate a power spectral density over the band [fmin, fmax].

    Parameters:
    - f: array of frequencies
    - Pxx: array of spectral density values, aligned with ``f``
    - fmin, fmax: inclusive band limits, in the same unit as ``f``

    Returns:
    - Trapezoidal integral of ``Pxx`` over the selected frequencies,
      or 0.0 when no frequency falls inside the band.
    """
    mask = (f >= fmin) & (f <= fmax)
    if not np.any(mask):
        return 0.0
    # np.trapz was renamed to np.trapezoid in NumPy 2.0 and the old name is
    # deprecated, so use whichever one this NumPy provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(Pxx[mask], f[mask])

# Record whether the optional neurokit2 package can be imported.
try:
    import neurokit2
except ImportError:
    # Only a missing or unimportable package means "not available"; other
    # exceptions (e.g. KeyboardInterrupt) must not be swallowed here.
    HAS_NK = False
else:
    HAS_NK = True

def compute_qc(sig, meta, pSQI_mean, bSQI_mean):
    """Compute quality-control (QC) metrics for one ECG record.

    Amplitude and spectral metrics are computed on the first lead only.

    Parameters:
    - sig: signal array, 1-D or 2-D with leads as columns
    - meta: record metadata carrying the sampling frequency ``fs``, either
      as a mapping key or as an attribute
    - pSQI_mean, bSQI_mean: precomputed signal-quality indices for the record

    Returns:
    - dict of metrics plus ``qc_pass`` (bool) and ``fail_reason``
      (";"-joined reason codes, empty string when the record passes).
      Empty and all-NaN signals return early with only the basic fields.

    Raises:
    - ValueError: if ``meta`` provides no sampling frequency.
    """
    qc = {"pSQI_mean": pSQI_mean, "bSQI_mean": bSQI_mean}

    # Accept both mapping-style and attribute-style metadata. Calling
    # meta.get() unconditionally would raise on a plain object before the
    # attribute fallback could run.
    fs = meta.get("fs", None) if hasattr(meta, "get") else None
    if fs is None:
        fs = getattr(meta, "fs", None)
    if fs is None:
        raise ValueError("Missing fs")

    if sig.ndim == 1:
        sig = sig[:, None]

    n_samples, n_leads = sig.shape
    qc["n_samples"] = int(n_samples)
    qc["n_leads"] = int(n_leads)
    qc["duration_sec"] = n_samples / fs

    # Nothing to measure on an empty record; fail it rather than crash in
    # the statistics / Welch code below.
    if n_samples == 0:
        return {**qc, "qc_pass": False, "fail_reason": "empty_signal"}

    lead = sig[:, 0]
    n_nans = np.isnan(lead).sum()
    qc["nan_fraction"] = float(n_nans / len(lead))

    lead_clean = lead.copy()
    if n_nans > 0:
        not_nan = ~np.isnan(lead_clean)
        if not np.any(not_nan):
            return {**qc, "qc_pass": False, "fail_reason": "all_nan"}
        # Fill gaps by linear interpolation so the metrics below are defined.
        lead_clean[~not_nan] = np.interp(np.flatnonzero(~not_nan),
                                          np.flatnonzero(not_nan), lead_clean[not_nan])

    amp = lead_clean
    qc["amp_mean"] = float(np.mean(amp))
    qc["amp_std"] = float(np.std(amp))
    # 1st-99th percentile span: an amplitude range that ignores extreme spikes.
    q1, q99 = np.percentile(amp, [1, 99])
    qc["amp_robust_range"] = float(q99 - q1)

    f, Pxx = welch(amp, fs=fs, nperseg=min(4096, len(amp)))
    # The 1e-8 terms keep the ratios finite when the reference band is empty.
    qc["baseline_wander_ratio"] = band_power(f, Pxx, 0.0, 0.5) / (band_power(f, Pxx, 0.5, 40.0) + 1e-8)
    qc["powerline_ratio"] = band_power(f, Pxx, 48.0, 52.0) / (band_power(f, Pxx, 40.0, 60.0) + 1e-8)

    # NOTE: a NaN SQI compares False against its threshold, so a record with
    # a missing SQI value is not rejected by the last two checks.
    reasons = []
    if qc["duration_sec"] < 8.0: reasons.append("too_short")
    if qc["nan_fraction"] > 0.01: reasons.append("too_many_nans")
    if not (0.05 < qc["amp_robust_range"] < 10.0): reasons.append("amp_out_of_range")
    if qc["baseline_wander_ratio"] > 0.5: reasons.append("baseline_wander")
    if qc["powerline_ratio"] > 0.5: reasons.append("powerline_noise")
    if pSQI_mean < 0.2: reasons.append("low_pSQI")
    if bSQI_mean < 0.8: reasons.append("low_bSQI")

    qc["qc_pass"] = len(reasons) == 0
    qc["fail_reason"] = ";".join(reasons) if reasons else ""

    return qc

# Marks the end of the cell that defines the signal-processing helpers.
print("✓ Processing functions defined")

## Preprocessing & Windowing

In [None]:
def preprocess_record(sig, meta, target_fs=500.0):
    """Band-pass filter a record and resample it to ``target_fs``.

    Parameters:
    - sig: signal array, 1-D or 2-D with leads as columns
    - meta: record metadata carrying the sampling frequency ``fs``, either
      as a mapping key or as an attribute
    - target_fs: output sampling frequency in Hz

    Returns:
    - (signal, fs): the filtered signal and its sampling frequency. When the
      record is already at ``target_fs`` it is returned without resampling.

    Raises:
    - ValueError: if ``meta`` provides no sampling frequency.
    """
    # Accept both mapping-style and attribute-style metadata, and fail with
    # a clear message instead of an obscure error inside the filter.
    fs = meta.get("fs", None) if hasattr(meta, "get") else None
    if not fs:
        fs = getattr(meta, "fs", None)
    if not fs:
        raise ValueError("Missing fs")

    if sig.ndim == 1:
        sig = sig[:, None]

    sig_bp = apply_bandpass(sig, fs=fs)
    if fs == target_fs:
        return sig_bp, fs

    # Keep the duration constant: new length = duration * target_fs.
    n_samples = sig_bp.shape[0]
    n_new = int(round(n_samples / fs * target_fs))
    sig_res = np.column_stack([resample(sig_bp[:, i], n_new) for i in range(sig_bp.shape[1])])
    return sig_res, target_fs

def window_record(sig, fs, window_sec=10.0, step_sec=5.0, target_samples=None):
    """
    Create windows from preprocessed ECG signal.
    Each window is z-score normalized per channel, then padded/truncated to a
    fixed sample length to ensure consistent shapes.

    Parameters:
    - sig: (n_samples, n_leads) signal array (a 1-D array is treated as one lead)
    - fs: sampling frequency (Hz)
    - window_sec: window duration in seconds
    - step_sec: step duration in seconds
    - target_samples: target number of samples per window (default: window_sec * fs)

    Returns:
    - List of float16 windows with shape (target_samples, n_leads). Windows
      in which more than 5% of the values are NaN are skipped.

    Raises:
    - ValueError: if the window or step length rounds down to zero samples
      (a zero step would otherwise loop forever).
    """
    sig = np.asarray(sig)
    # Normalized values are fractional; work in floating point so integer
    # input is not silently truncated.
    if not np.issubdtype(sig.dtype, np.floating):
        sig = sig.astype(np.float64)
    if sig.ndim == 1:
        sig = sig[:, None]

    n_samples = sig.shape[0]
    win_len = int(window_sec * fs)
    step_len = int(step_sec * fs)
    if win_len <= 0 or step_len <= 0:
        raise ValueError(
            f"window and step must span at least one sample (got win_len={win_len}, step_len={step_len})"
        )

    if target_samples is None:
        target_samples = win_len

    windows = []
    # Only full-length windows are produced; a trailing partial window is dropped.
    for start in range(0, n_samples - win_len + 1, step_len):
        segment = sig[start:start + win_len, :]

        # Skip windows with too many NaNs
        if np.isnan(segment).mean() > 0.05:
            continue

        # Normalize each channel; near-constant channels are only centered
        # (divisor 1.0) to avoid dividing by ~0.
        mean = np.nanmean(segment, axis=0)
        std = np.nanstd(segment, axis=0)
        std = np.where(std > 1e-6, std, 1.0)
        seg_norm = (segment - mean) / std

        # Pad or truncate to target_samples
        if seg_norm.shape[0] < target_samples:
            pad_len = target_samples - seg_norm.shape[0]
            seg_norm = np.pad(seg_norm, ((0, pad_len), (0, 0)), mode='constant', constant_values=0)
        elif seg_norm.shape[0] > target_samples:
            seg_norm = seg_norm[:target_samples, :]

        windows.append(seg_norm.astype(np.float16))  # Use float16 for compression

    return windows

# Marks the end of the cell that defines the preprocessing/windowing helpers.
print("✓ Preprocessing functions defined")

## ICD Code Parsing & Disease Mapping

In [None]:
# ICD-10 codes grouped by the disease class they are mapped to. Group and
# code order are kept stable because ICD_TO_DISEASE inherits them.
_ICD_CODES_BY_DISEASE = {
    'Myocarditis': ('I40.0', 'I40.9', 'I41.4'),
    'Cardiomyopathy': ('I42.0', 'I42.2', 'I42.9', 'Q28.4'),
    'Kawasaki': ('M30.3',),
    'CHD': ('Q21.1', 'Q21.2', 'Q21.3', 'Q22.1', 'Q25.0', 'Q25.6', 'I27.9'),
}

# Flat lookup: ICD-10 code -> disease class name.
ICD_TO_DISEASE = {
    icd_code: disease_name
    for disease_name, icd_codes in _ICD_CODES_BY_DISEASE.items()
    for icd_code in icd_codes
}

# Label order; a class's position is its column in the one-hot label vector.
DISEASE_CLASSES = ['Myocarditis', 'Cardiomyopathy', 'Kawasaki', 'CHD', 'Healthy']
CLASS_IDX = dict(zip(DISEASE_CLASSES, range(len(DISEASE_CLASSES))))

def parse_icd(s):
    """Split a ';'-separated ICD field into a list of codes.

    Missing values give an empty list. Blank pieces are dropped; each
    remaining piece is stripped of surrounding whitespace and of
    single-quote characters.
    """
    if pd.isna(s):
        return []

    codes = []
    for piece in str(s).split(";"):
        trimmed = piece.strip()
        if trimmed:
            codes.append(trimmed.replace("'", ""))
    return codes

def clean_icd(code):
    """Normalize one ICD code string.

    Surrounding whitespace is removed and, when the text contains ')',
    only the part after the last ')' is kept. Returns None for missing
    values and for text that ends up empty.
    """
    if pd.isna(code):
        return None

    text = str(code).strip()
    _, paren, tail = text.rpartition(")")
    if paren:
        text = tail.strip()
    return text if text else None

def parse_sqi(s):
    """Parse a ';'-separated list of ``key:value`` items into a dict of floats.

    Single quotes and surrounding whitespace are removed from keys. Items
    without a ':' or whose value is not a valid float are skipped, so one
    malformed entry never discards the rest of the field. Missing values
    give an empty dict.
    """
    if pd.isna(s):
        return {}
    out = {}
    for item in str(s).split(";"):
        # partition() splits on the first ':' only. Unpacking the result of
        # split(":") raised ValueError for items containing extra colons.
        key, sep, value = item.partition(":")
        if not sep:
            continue
        try:
            out[key.replace("'", "").strip()] = float(value)
        except ValueError:
            # Non-numeric value: ignore this item.
            continue
    return out

def _first_or_none(codes):
    """Return the first element of ``codes``, or None when it is empty."""
    return codes[0] if codes else None


def _disease_for(code):
    """Map a cleaned ICD code to its disease class; anything unmapped is 'Healthy'."""
    return ICD_TO_DISEASE.get(code, 'Healthy') if code else 'Healthy'


def _mean_or_nan(values_by_key):
    """Average the values of a parsed SQI dict; NaN when the dict is empty."""
    return np.mean(list(values_by_key.values())) if values_by_key else np.nan


# ICD codes: raw field -> list of codes -> first code -> cleaned code -> class.
# Only the first listed code decides the disease label.
df_attr["ICD_list"] = df_attr["ICD-10 code"].apply(parse_icd)
df_attr["ICD_primary"] = df_attr["ICD_list"].apply(_first_or_none)
df_attr["ICD_primary_clean"] = df_attr["ICD_primary"].apply(clean_icd)
df_attr["disease"] = df_attr["ICD_primary_clean"].apply(_disease_for)

# Signal-quality indices: keep the parsed dict and its mean over all entries.
for col in ("pSQI", "basSQI", "bSQI"):
    dict_col = f"{col}_dict"
    df_attr[dict_col] = df_attr[col].apply(parse_sqi)
    df_attr[f"{col}_mean"] = df_attr[dict_col].apply(_mean_or_nan)

print("Disease distribution:")
print(df_attr['disease'].value_counts())

## Main Processing Pipeline

In [None]:
def _fit_channels(window, target_channels):
    """Zero-pad or truncate the channel axis of one window to ``target_channels``."""
    n_channels = window.shape[1]
    if n_channels < target_channels:
        pad_channels = target_channels - n_channels
        return np.pad(window, ((0, 0), (0, pad_channels)), mode='constant', constant_values=0)
    return window[:, :target_channels]


def _qc_and_window_record(row, ecg_dir, target_samples):
    """Read one record, run QC and, if it passes, preprocess and window it.

    Returns ``(qc, windows)``. ``windows`` is empty when the record cannot
    be read or fails QC. ``qc`` always carries Filename, disease, qc_pass,
    fail_reason, original_channels and n_windows, so the QC table has the
    same columns whatever happened to the record.
    """
    fname = row["Filename"]
    disease = row["disease"]
    path = os.path.join(ecg_dir, fname)

    try:
        sig, meta = wfdb.rdsamp(path)
    except Exception as e:
        # Unreadable record: report it as a QC failure instead of aborting.
        failed = {
            "Filename": fname, "disease": disease, "qc_pass": False,
            "fail_reason": str(e), "original_channels": np.nan, "n_windows": 0,
        }
        return failed, []

    meta_dict = meta if isinstance(meta, dict) else meta.__dict__
    sig = np.asarray(sig)

    qc = compute_qc(sig, meta_dict, float(row["pSQI_mean"]), float(row["bSQI_mean"]))
    qc["Filename"] = fname
    qc["disease"] = disease
    qc["original_channels"] = sig.shape[1] if sig.ndim > 1 else 1
    qc["n_windows"] = 0

    if not qc["qc_pass"]:
        return qc, []

    sig_proc, fs = preprocess_record(sig, meta_dict)
    windows = window_record(sig_proc, fs, target_samples=target_samples)
    qc["n_windows"] = len(windows)
    return qc, windows


def process_all_hdf5(df, ecg_dir, max_records=None, target_samples=5000, target_channels=12, output_file=None):
    """
    Process ECG records and save directly to HDF5 with compression.
    Memory-efficient: windows are written record by record, never held all at once.

    The records are scanned twice: pass 1 counts the windows so the datasets
    can be created at their final size, pass 2 recomputes the windows and
    writes them.

    Parameters:
    - df: attribute table with columns Filename, disease, pSQI_mean, bSQI_mean
    - ecg_dir: directory containing the WFDB records named in ``Filename``
    - max_records: process only the first N rows (None = all rows)
    - target_samples: samples per window stored in the file
    - target_channels: channels per window (zero-padded or truncated)
    - output_file: Path to save HDF5 file (default: OUT_DIR/ecg_data.h5)

    Returns:
    - (output_file, df_qc): path of the HDF5 file and a DataFrame with one
      QC row per processed record.

    Storage: float16 + gzip level 4 (the original author's note estimates
    this reduces 15GB → ~1.5-2GB).
    """
    if output_file is None:
        output_file = os.path.join(OUT_DIR, "ecg_data.h5")

    records = df.iloc[:max_records] if max_records else df
    total_records = len(records)

    # First pass: count total windows
    print("Pass 1: Counting total windows...")
    total_windows = 0

    # Progress uses the row position, not the index label, so it is also
    # correct for a DataFrame without a default 0..n-1 index.
    for pos, (_, row) in enumerate(records.iterrows(), start=1):
        if pos % 100 == 0:
            print(f"  Scanning [{pos}/{total_records}]...")
        _, windows = _qc_and_window_record(row, ecg_dir, target_samples)
        total_windows += len(windows)

    print(f"\n✓ Estimated {total_windows} windows")

    # Second pass: write to HDF5 with compression
    print(f"\nPass 2: Writing to HDF5 with compression ({output_file})...")

    n_classes = len(DISEASE_CLASSES)

    with h5py.File(output_file, 'w') as h5f:
        # Create datasets with gzip compression (level 4 = good balance speed/compression)
        X_dset = h5f.create_dataset(
            'X',
            shape=(total_windows, target_samples, target_channels),
            dtype=np.float16,  # Use float16 instead of float32
            compression='gzip',
            compression_opts=4
        )
        y_dset = h5f.create_dataset(
            'y',
            shape=(total_windows, n_classes),
            dtype=np.int32,
            compression='gzip',
            compression_opts=4
        )
        diseases_dset = h5f.create_dataset(
            'diseases',
            shape=(total_windows,),
            dtype=h5py.string_dtype(encoding='utf-8'),
            compression='gzip',
            compression_opts=4
        )

        # Store metadata
        h5f.attrs['target_samples'] = target_samples
        h5f.attrs['target_channels'] = target_channels
        h5f.attrs['disease_classes'] = DISEASE_CLASSES
        h5f.attrs['data_format'] = 'float16 + gzip'

        qc_list = []
        window_idx = 0

        for pos, (_, row) in enumerate(records.iterrows(), start=1):
            if pos % 10 == 0:
                print(f"  [{pos}/{total_records}] {window_idx}/{total_windows} windows written...")

            qc, windows = _qc_and_window_record(row, ecg_dir, target_samples)
            qc_list.append(qc)
            if not windows:
                continue

            disease = row["disease"]
            # One-hot label, built once per record and written as a whole row.
            label_row = np.zeros(n_classes, dtype=np.int32)
            label_row[CLASS_IDX[disease]] = 1

            # Pad and write windows directly to HDF5 (automatically compressed)
            for window in windows:
                X_dset[window_idx] = _fit_channels(window, target_channels)
                y_dset[window_idx] = label_row
                diseases_dset[window_idx] = disease
                window_idx += 1

    print(f"\n✓ Saved to {output_file}")
    print(f"  X shape: (windows, samples, channels) = ({window_idx}, {target_samples}, {target_channels})")
    print(f"  Data format: float16 + gzip compression")

    return output_file, pd.DataFrame(qc_list)

# Run the full pipeline over every record in the attribute table.
print("Processing ECG records...")
h5_file, df_qc = process_all_hdf5(df_attr, ECG_DIR, max_records=None)

# Re-open the result read-only and report what was written.
with h5py.File(h5_file, 'r') as h5f:
    X_ds, y_ds, diseases_ds = h5f['X'], h5f['y'], h5f['diseases']

    print(f"\nDataset shapes:")
    print(f"  X: {X_ds.shape} (dtype: {X_ds.dtype})")
    print(f"  y: {y_ds.shape}")
    print(f"  diseases: {diseases_ds.shape}")

    print(f"\nDisease distribution:")
    # Column sums of the one-hot label matrix give the window count per class.
    per_class_counts = y_ds[:].sum(axis=0)
    for cls, count in zip(DISEASE_CLASSES, per_class_counts):
        print(f"  {cls}: {count}")

## Save Results

In [None]:
# Save QC summary and metadata
summary_columns = ['Filename', 'disease', 'qc_pass', 'n_windows', 'original_channels']
# reindex() instead of plain column selection: if no record produced
# 'n_windows' or 'original_channels' (e.g. every record failed), the column
# is written empty instead of raising KeyError.
df_qc.reindex(columns=summary_columns).to_csv(
    os.path.join(OUT_DIR, "qc_summary.csv"), index=False
)

# Class definitions and ICD mapping, stored next to the data so downstream
# code can decode the one-hot labels.
with open(os.path.join(OUT_DIR, "disease_classes.json"), "w", encoding="utf-8") as f:
    json.dump({
        "classes": DISEASE_CLASSES,
        "class_idx": CLASS_IDX,
        "icd_map": ICD_TO_DISEASE,
        "data_format": "hdf5",
        "hdf5_file": "ecg_data.h5"
    }, f, indent=2)

print(f"✓ Saved to {OUT_DIR}")
print(f"  - ecg_data.h5 (HDF5 format)")
print(f"  - qc_summary.csv")
print(f"  - disease_classes.json")

## Upload to Hugging Face (Optional)

In [None]:
UPLOAD_TO_HF = False  # Set to True to upload

if not UPLOAD_TO_HF:
    print("To upload: set UPLOAD_TO_HF=True and have HF write token")
else:
    # Imported only when needed so the cell runs without credentials.
    from huggingface_hub import HfApi, login

    print("Logging into Hugging Face...")
    login()

    api = HfApi()
    print("Uploading to HF...")
    # Push the whole output folder into the dataset repo under multilabel_v2/.
    upload_args = dict(
        folder_path=OUT_DIR,
        repo_id="kiril-buga/ECG-database",
        repo_type="dataset",
        path_in_repo="multilabel_v2",
        commit_message="Multi-label preprocessed data",
    )
    api.upload_folder(**upload_args)
    print("✓ Uploaded to HF")