In [1]:

## AFDB preprocessing: preprocess each patient's ECG
"""
Convert each AFDB record into fixed-length labeled windows for training.
Each per-record .npz file is saved with:
    X: per-window z-scored ECG waveform [num_windows, win_len]
    y: window label (0 = not AF, 1 = AF)
    rr_feat: RR/HRV features per window [num_windows, 10]
    rr_valid: whether the RR features are valid for that window (1.0 = valid, 0.0 = invalid)

"""
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wfdb  # record I/O
from scipy.signal import butter, filtfilt, sosfiltfilt  # bandpass filtering
from scipy.signal import find_peaks  # R-peak detection
from wfdb import processing  # resampling + signal helpers



DATA_DIR = Path("../data/raw/afdb")       # directory holding the raw AFDB record files (.hea headers are globbed from here)
OUT_DIR  = Path("../artifacts/afdb_npz")  # destination for the per-record .npz files and the summary CSV
OUT_DIR.mkdir(parents=True, exist_ok=True)  # create the output directory (and any parents) if missing

SUMMARY_CSV = OUT_DIR / "summary.csv"   # output summary table path

LEAD_IDX = 0  # index of the ECG channel to use (the first channel)
TARGET_FS = None  # target sampling rate in Hz; None keeps each record's native rate (no resampling)
WIN_SEC = 10  # window length in seconds
STRIDE_SEC = 5  # hop between consecutive window starts, in seconds (windows overlap by WIN_SEC - STRIDE_SEC)
AF_THRESHOLD = 0.2  # a window is labeled AF when at least 20% of its samples overlap AF intervals
LOW_HZ = 0.5  # band-pass low cutoff in Hz
HIGH_HZ = 40.0  # band-pass high cutoff in Hz, to remove high-frequency noise
F_ORDER = 4  # filter order passed to scipy.signal.butter

# plotting options
PLOT_EXAMPLES = False  # when True, the main loop announces an example plot for the first record with both classes
# PLOT_K = 3
# PLOT_SEED = 0





In [2]:
# functions::



"""
A band-pass filter with zero-phase (forward-backward) filtering.
ECG contains very low-frequency noise (baseline wander) and high-frequency noise (e.g. powerline interference).
Forward-backward filtering runs the filter over the signal once in each direction, so the phase delays cancel
and no wave is shifted in time; a time shift would misalign the waveform with its annotations.
"""

def bandpass_filter(
    x: np.ndarray,
    fs: float,
    low_hz: float = LOW_HZ,
    high_hz: float = HIGH_HZ,
    order: int = F_ORDER
) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter to a signal.

    The filter is designed as second-order sections (SOS) and applied
    forward and backward with ``sosfiltfilt``. SciPy recommends the SOS form
    over ``(b, a)`` polynomials for band-pass designs because the polynomial
    form is numerically sensitive, especially with a very low cutoff relative
    to the sampling rate.

    Args:
        x: Input samples (filtered along the last axis).
        fs: Sampling frequency in Hz.
        low_hz: Low cutoff in Hz; must be > 0.
        high_hz: High cutoff in Hz; must be > low_hz and < fs / 2.
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal as float32, same shape as ``x``.

    Raises:
        ValueError: If the cutoffs are invalid for ``fs``, or (from SciPy) if
            the signal is too short for the default edge padding.
    """
    # Filter in float64 for numerical headroom; the result is cast back to float32 below.
    x = np.asarray(x, dtype=np.float64)

    if low_hz <= 0:
        raise ValueError("low_hz must be > 0")
    if high_hz >= fs / 2:
        raise ValueError("high_hz must be < Nyquist (fs/2)")
    if high_hz <= low_hz:
        raise ValueError("high_hz must be > low_hz")

    # butter() expects cutoffs normalized to the Nyquist frequency.
    nyq = fs / 2.0
    low = low_hz / nyq
    high = high_hz / nyq

    sos = butter(order, [low, high], btype="bandpass", output="sos")
    y = sosfiltfilt(sos, x)
    return y.astype(np.float32)


# afdb records have headers.. so read them by base name
def listafdbrecords(data_dir:Path):
    """Return the sorted, de-duplicated base names of all ``*.hea`` files in ``data_dir``."""
    record_names = set()
    for header_path in data_dir.glob("*.hea"):
        record_names.add(header_path.stem)
    return sorted(record_names)


## rhythm interval parsing

"""
Convert rhythm-change markers from the WFDB annotations into half-open intervals [start, end)
 labeled 0 = not AF, 1 = AF.

    Uses aux_note strings like:
      '(N'    normal
      '(AFIB' AF
      '(AFL'  AF

"""
# since AFDB labels are typically change points and not per sample labels, we convert to intervals
#which makes window labeling easier
def build_rhythm_intervals_from_aux(ann_samples, aux_notes, N_total_samples: int):
    """Turn rhythm-change markers into half-open labeled intervals.

    Each rhythm marker starts a rhythm that lasts until the next marker, or
    until the end of the record for the last one.

    Labels:
        1 -- AF: aux notes such as '(AFIB' or '(AFL'.
        0 -- not AF: every other rhythm marker ('(N', '(J', ...).

    Any rhythm marker other than AF must end an AF episode. Ignoring such a
    marker would silently extend the preceding rhythm through it.

    Args:
        ann_samples: Sample index of each annotation.
        aux_notes: Aux-note string of each annotation (entries may be None or
            empty for annotations that are not rhythm changes).
        N_total_samples: Record length in samples; interval ends are clamped to it.

    Returns:
        List of ``(start, end, label)`` tuples covering ``[start, end)``,
        sorted by start. Empty when no rhythm marker was found. If the first
        marker occurs after sample 0, its rhythm is assumed to hold from
        sample 0.
    """
    changes = []  # (sample, label)

    for s, aux in zip(ann_samples, aux_notes):
        if aux is None:
            continue

        # Some readers leave NUL padding on aux strings; drop it before matching.
        a = aux.replace("\x00", "").strip().upper()
        if not a:
            continue

        if a.startswith("(AF") or "AFIB" in a or "AFL" in a:
            changes.append((int(s), 1))
        elif a.startswith("(") or a in ("N", "NSR"):
            # Any other rhythm marker (normal, junctional, ...) is "not AF".
            changes.append((int(s), 0))

    changes.sort(key=lambda t: t[0])
    if not changes:
        return []

    # Fill the beginning: assume the first labeled rhythm also holds before its marker.
    if changes[0][0] > 0:
        changes = [(0, changes[0][1])] + changes

    intervals = []
    for i, (start, label) in enumerate(changes):
        end = changes[i + 1][0] if i + 1 < len(changes) else N_total_samples
        end = min(end, N_total_samples)  # never extend past the end of the record

        if end > start:
            intervals.append((start, end, label))

    return intervals


# number of samples in the intersection of the half-open intervals [a0,a1) and [b0,b1)
def interval_overlap(a0,a1,b0,b1):
    """Length of the intersection of [a0, a1) and [b0, b1); 0 when they do not overlap."""
    latest_start = b0 if b0 > a0 else a0
    earliest_end = b1 if b1 < a1 else a1
    shared = earliest_end - latest_start
    return shared if shared > 0 else 0


"""
Label window as AF (1) if >= threshold 
"""
def label_window(start,end, rhythm_intervals, af_threshold=0.2):
    """Label the window [start, end) as AF (1) or not AF (0).

    The window is AF when the fraction of its samples covered by AF-labeled
    rhythm intervals is at least ``af_threshold``. A window of non-positive
    length is labeled 0.
    """
    window_len = end - start
    if window_len <= 0:
        return 0

    # Sum the overlap with every AF-labeled (label == 1) interval.
    af_samples = sum(
        interval_overlap(start, end, seg_start, seg_end)
        for seg_start, seg_end, seg_label in rhythm_intervals
        if seg_label == 1
    )
    return int((af_samples / window_len) >= af_threshold)


def computeAFfromintervals(rhythm_intervals,N_totalsamples):
    """Fraction of labeled samples that are AF, computed from rhythm intervals.

    Each interval is clipped to [0, N_totalsamples] before counting. Returns
    NaN when there are no intervals or none has positive length after clipping.
    """
    if not rhythm_intervals:
        return np.nan

    af_len = 0
    labeled_len = 0
    for seg_start, seg_end, seg_label in rhythm_intervals:
        lo = max(0, int(seg_start))
        hi = min(N_totalsamples, int(seg_end))
        length = hi - lo
        if length > 0:
            labeled_len += length
            if seg_label == 1:
                af_len += length

    if labeled_len > 0:
        return af_len / labeled_len
    return np.nan




# def plot_randomwindows(Xplot,y,k=3,seed=0,title_prefix=""):

#     rng = np.random.default_rng(seed)

#     fig, axes = plt.subplots(2, k, figsize=(4 * k, 5), sharex=True, sharey=True)

#     for cls in [0, 1]:
#         idxs = np.where(y == cls)[0]
#         chosen = rng.choice(idxs, size=min(k, len(idxs)), replace=False) if len(idxs) else []

#         for j in range(k):
#             ax = axes[cls, j]
#             if j < len(chosen):
#                 i = chosen[j]
#                 ax.plot(Xplot[i])
#                 ax.set_title(f"{title_prefix}{'Normal' if cls==0 else 'AF'} (win={i})")
#             else:
#                 ax.set_axis_off()

#     plt.tight_layout()
#     plt.show()

In [3]:
## Rpeak detectin & RR features


## this is a fallback if it cant find annotations
def detect_rpeaks_simple(x: np.ndarray, fs: float) -> np.ndarray:
    """Crude derivative-energy R-peak detector, used as a fallback.

    Peaks are local maxima of the squared first difference of ``x`` that reach
    the 95th percentile of that energy and are at least 0.25 s apart.

    Args:
        x: 1-D signal samples.
        fs: Sampling frequency in Hz.

    Returns:
        int64 array of peak sample indices (empty for an empty signal).
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        # Nothing to detect; x[0] below would raise IndexError on an empty signal.
        return np.empty(0, dtype=np.int64)

    dx = np.diff(x, prepend=x[0])     # discrete derivative (same length as x)
    energy = np.square(dx)            # energy-like measure

    # Min distance between R-peaks: 0.25 s => max 240 bpm.
    # find_peaks requires distance >= 1, so clamp for very low sampling rates.
    min_dist = max(1, int(0.25 * fs))

    # Threshold is the 95th percentile of the energy.
    thr = np.percentile(energy, 95)

    peaks, _ = find_peaks(energy, distance=min_dist, height=thr)
    return peaks.astype(np.int64)



##this should be used primarily
"""
WFDB annotations are more accurate and the features extracted make more sense if peaks are accurate
"""
def get_rpeaks_from_annotations_or_detect(record_path: str, x: np.ndarray, fs: float, ann=None) -> np.ndarray:
    """Return R-peak sample indices, preferring annotated beats over detection.

    Sources are tried in order:
      1. Beat annotations in ``ann`` (loaded from the ``atr`` file when None).
      2. Beat annotations in the record's ``qrsc`` then ``qrs`` files, if they
         can be read. Per the PhysioNet AFDB description, ``.atr`` holds
         rhythm markers and beats live in these separate files.
      3. ``detect_rpeaks_simple`` on the signal itself.

    A source is accepted only if it yields at least 10 in-range beats.

    Args:
        record_path: WFDB record path without extension.
        x: Signal the peaks must index into.
        fs: Sampling frequency of ``x`` in Hz.
        ann: Optional pre-loaded annotation object with ``symbol``/``sample``.

    Returns:
        int64 array of R-peak sample indices into ``x``.
    """
    # Standard WFDB beat symbols. Everything else (rhythm change "+", noise "~",
    # comments, ...) marks something other than a beat.
    beat_symbols = frozenset("NLRBAaJSVrFejnE/fQ?")

    def _beats_from(annotation):
        """Extract in-range beat samples from one annotation object, or None."""
        symbols = getattr(annotation, "symbol", None)
        samples = getattr(annotation, "sample", None)
        if symbols is None or samples is None or len(symbols) != len(samples):
            return None

        samples = np.asarray(samples, dtype=np.int64)
        is_beat = np.fromiter((s in beat_symbols for s in symbols), dtype=bool, count=len(symbols))
        beats = samples[is_beat]

        # NOTE(review): assumes annotation.fs, when set, is the rate its sample
        # indices refer to. If x was resampled, map the indices onto x's timeline.
        ann_fs = getattr(annotation, "fs", None)
        if ann_fs and float(ann_fs) != float(fs):
            beats = np.round(beats * (float(fs) / float(ann_fs))).astype(np.int64)

        beats = beats[(beats >= 0) & (beats < len(x))]

        # Needs enough beats to matter.
        return beats if len(beats) >= 10 else None

    if ann is None:
        ann = wfdb.rdann(record_path, "atr")

    r = _beats_from(ann)

    if r is None:
        for ext in ("qrsc", "qrs"):
            try:
                r = _beats_from(wfdb.rdann(record_path, ext))
            except Exception:
                # Best effort: a missing or unreadable beat file just means we
                # fall through to the next source.
                r = None
            if r is not None:
                break

    if r is None:
        r = detect_rpeaks_simple(x, fs)

    return r.astype(np.int64)



"""
Compute 10 RR/HRV features from the R-peaks that fall inside one window.

Returns:
    feat (10,) float32
    valid: 1.0 if there are enough peaks to compute the features, else 0.0
    # some windows have too few beats; the flag lets the model ignore the RR features when they are invalid
"""
def rr_features_from_peaks(r_peaks_win: np.ndarray, fs: float):
    """Compute 10 RR/HRV features from the R-peak indices of one window.

    Args:
        r_peaks_win: R-peak sample indices (may be None).
        fs: Sampling frequency in Hz, used to convert sample gaps to seconds.

    Returns:
        ``(feat, valid)`` where ``feat`` is a float32 vector
        ``[mean_rr, sdnn, rmssd, pnn50, cv, med_rr, iqr_rr, mad_rr, tpr, rr_range]``
        and ``valid`` is 1.0. With fewer than 3 peaks (fewer than 2 RR
        intervals) the result is ``(zeros(10), 0.0)``.
    """
    n_peaks = 0 if r_peaks_win is None else len(r_peaks_win)
    if n_peaks < 3:
        return np.zeros(10, dtype=np.float32), 0.0

    rr = np.diff(r_peaks_win).astype(np.float32) / float(fs)  # RR intervals in seconds
    n_rr = len(rr)
    if n_rr < 2:
        return np.zeros(10, dtype=np.float32), 0.0

    succ_diff = np.diff(rr)        # successive RR differences
    succ_abs = np.abs(succ_diff)

    # Central tendency and spread
    mean_rr = rr.mean()
    sdnn = rr.std()
    med_rr = float(np.median(rr))
    rr_range = float(rr.max() - rr.min())
    iqr_rr = float(np.percentile(rr, 75) - np.percentile(rr, 25))
    mad_rr = float(np.median(np.abs(rr - med_rr)))
    cv = float(sdnn / (mean_rr + 1e-8))

    # Beat-to-beat variability
    if len(succ_diff):
        rmssd = float(np.sqrt(np.mean(succ_diff ** 2)))
    else:
        rmssd = 0.0
    if len(succ_abs):
        pnn50 = float(np.mean(succ_abs > 0.05))  # 50 ms
    else:
        pnn50 = 0.0

    # Turning point ratio: fraction of interior RR values that are strict local extrema.
    tpr = 0.0
    if n_rr >= 3:
        before, middle, after = rr[:-2], rr[1:-1], rr[2:]
        local_max = (middle > before) & (middle > after)
        local_min = (middle < before) & (middle < after)
        turning_points = int(np.count_nonzero(local_max | local_min))
        tpr = float(turning_points / (n_rr - 2))

    feat = np.array(
        [mean_rr, sdnn, rmssd, pnn50, cv, med_rr, iqr_rr, mad_rr, tpr, rr_range],
        dtype=np.float32,
    )
    return feat, 1.0


In [4]:
#process a record at a time
"""
### STEPS:
    load ecg
    resample if needed
    bandpass filt
    load annotations & build rhythm intervals
    window the record
    label each window 
    detect r peaks once on the full patient record
    slice peaks per window and compute RR features
    per window z score,,, normalize window
    save .npz 

"""
## It is easier to detect R-peaks once on the whole record than separately in each window: per-window
# detection can cause boundary effects, since a peak sitting right at a window boundary may be handled
## inconsistently across the overlapping windows. Detecting on the full record is more stable; the peaks are just sliced per window afterwards.
def process_one_record(record_id: str):
    """Convert one AFDB record into labeled, normalized windows and save them.

    Steps: load the ECG lead, optionally resample to ``TARGET_FS``, band-pass
    filter, build rhythm intervals from the ``atr`` annotations, cut
    fixed-length windows, label each window, compute per-window RR features
    from record-level R-peaks, z-score each window, and write a compressed
    ``.npz`` into ``OUT_DIR``.

    Args:
        record_id: Record base name inside ``DATA_DIR`` (no extension).

    Returns:
        ``(stats, Xn, y)``: a dict of summary statistics for the record, the
        normalized windows with shape (num_windows, win_len), and the int64
        window labels with shape (num_windows,).

    Raises:
        RuntimeError: If the annotations carry no aux notes, no rhythm
            intervals can be built, or the signal is shorter than one window.
    """
    record_path = str(DATA_DIR / record_id)

    # ---- Load signal ----
    sig, fields = wfdb.rdsamp(record_path)   # sig shape: (N, n_channels)
    fs_native = float(fields["fs"])          # sampling frequency from header
    fs = fs_native
    x = sig[:, LEAD_IDX].astype(np.float32)  # pick channel and cast

    # ---- Resample if requested ----
    if TARGET_FS is not None and float(TARGET_FS) != fs_native:
        # processing.resample_sig returns (resampled_signal, resampled_time_index)
        x, _ = processing.resample_sig(x, fs_native, TARGET_FS)
        x = x.astype(np.float32)
        fs = float(TARGET_FS)

    # ---- Band-pass filter ----
    x = bandpass_filter(x, fs, low_hz=LOW_HZ, high_hz=HIGH_HZ, order=F_ORDER)

    N = len(x)

    ann = wfdb.rdann(record_path, "atr")  # loads .atr annotation stream

    # aux_note is where AFDB rhythm labels live
    if not hasattr(ann, "aux_note") or ann.aux_note is None:
        raise RuntimeError("Annotation missing aux_note field (needed for rhythm intervals).")

    if fs != fs_native:
        # Annotation indices refer to the native sampling rate. After
        # resampling they must be mapped onto the new timeline, or rhythm
        # intervals and annotated beats no longer line up with x.
        ann.sample = np.round(
            np.asarray(ann.sample, dtype=np.float64) * (fs / fs_native)
        ).astype(np.int64)
        ann.fs = fs  # record that the indices are now expressed at the new rate

    # ---- Rhythm intervals ----
    rhythm_intervals = build_rhythm_intervals_from_aux(ann.sample, ann.aux_note, N_total_samples=N)
    if not rhythm_intervals:
        raise RuntimeError("No rhythm intervals found in aux_note ;; (cannot label windows)")

    win_len = int(WIN_SEC * fs)
    stride  = int(STRIDE_SEC * fs)

    if N < win_len:
        raise RuntimeError(f"Signal too short: {N} samples < {win_len} samples (one window).")

    # Start indices for windows: [0, stride, 2*stride, ...] up to the last full window
    starts = np.arange(0, N - win_len + 1, stride, dtype=np.int64)

    # ---- Label each window ----
    y = np.array(
        [label_window(int(s), int(s + win_len), rhythm_intervals, af_threshold=AF_THRESHOLD) for s in starts],
        dtype=np.int64
    )

    # ---- R-peaks once for the full record, then sliced per window ----
    rpeaks_all = np.sort(get_rpeaks_from_annotations_or_detect(record_path, x, fs, ann=ann))

    # With sorted peaks, the peaks inside [s, s + win_len) form one contiguous
    # slice whose bounds come from a binary search instead of a full-array mask.
    first_idx = np.searchsorted(rpeaks_all, starts, side="left")
    stop_idx = np.searchsorted(rpeaks_all, starts + win_len, side="left")

    rr_feat_list = []
    rr_valid_list = []

    for s, i0, i1 in zip(starts, first_idx, stop_idx):
        # Window-relative indices, so features don't depend on absolute time.
        rpk_win = rpeaks_all[i0:i1] - int(s)

        f, v = rr_features_from_peaks(rpk_win, fs)
        rr_feat_list.append(f)
        rr_valid_list.append(v)

    rr_feat  = np.stack(rr_feat_list).astype(np.float32)        # (num_windows, 10)
    rr_valid = np.asarray(rr_valid_list, dtype=np.float32)      # (num_windows,)

    # ---- Window tensor ----
    X = np.stack([x[s:s + win_len] for s in starts]).astype(np.float32)  # (num_windows, win_len)

    # ---- Per-window z-score ----
    # Removes amplitude scale differences across patients/segments and
    # stabilizes learning; the epsilon keeps flat windows from dividing by zero.
    Xn = (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + 1e-8)

    # ---- Summary statistics ----
    total_windows = int(len(y))
    af_windows = int(y.sum())
    af_pct_windows = float(y.mean()) if total_windows else np.nan
    af_pct_intervals = float(computeAFfromintervals(rhythm_intervals, N))

    # ---- Save npz ----
    out_path = OUT_DIR / f"{record_id}_win{WIN_SEC}_stride{STRIDE_SEC}.npz"

    np.savez_compressed(
        out_path,
        # core learning inputs
        X=Xn,                 # normalized windows
        y=y,                  # window labels
        rr_feat=rr_feat,      # RR features
        rr_valid=rr_valid,    # RR validity mask/flag

        # metadata for reproducibility
        fs=fs,
        win_sec=WIN_SEC,
        stride_sec=STRIDE_SEC,
        record=record_id,
        lead_idx=LEAD_IDX,
        af_threshold=AF_THRESHOLD,
        low_hz=LOW_HZ,
        high_hz=HIGH_HZ,
        filter_order=F_ORDER,
    )

    stats = {
        "record": record_id,
        "fs": fs,
        "samples": N,
        "win_len": win_len,
        "stride": stride,
        "windows": total_windows,
        "af_windows": af_windows,
        "normal_windows": total_windows - af_windows,
        "af_pct_windows": af_pct_windows,
        "af_pct_intervals": af_pct_intervals,
        "npz_path": str(out_path),
        "has_both_classes": (len(np.unique(y)) == 2),
    }

    return stats, Xn, y


In [5]:
#main loop::

# Run the preprocessing over every record, write the summary CSV, and print totals.
print("=" * 60)
print("AFDB Preprocessing Pipeline")
print("=" * 60)
print(f"Data directory:   {DATA_DIR}")
print(f"Output directory: {OUT_DIR}")
print(f"Window: {WIN_SEC}s | Stride: {STRIDE_SEC}s | AF threshold: {AF_THRESHOLD}")
print(f"Filter: {LOW_HZ}-{HIGH_HZ} Hz | order={F_ORDER}")
print("=" * 60)

records = listafdbrecords(DATA_DIR)
print(f"\nFound {len(records)} records")
print("First 10:", records[:10])
if not records:
    print(f"WARNING: no .hea header files found in {DATA_DIR} -- check DATA_DIR")

# Column layout of the summary table. Used to give the table its schema when
# no record was processed, so the column lookups below cannot fail.
SUMMARY_COLUMNS = [
    "record", "fs", "samples", "win_len", "stride", "windows",
    "af_windows", "normal_windows", "af_pct_windows", "af_pct_intervals",
    "npz_path", "has_both_classes", "error",
]

summary_rows = []
example_plotted = False

for i, rid in enumerate(records, start=1):
    try:
        stats, Xn, y = process_one_record(rid)
        summary_rows.append(stats)

        print(
            f"[{i:02d}/{len(records)}] {rid:6s} | "
            f"windows={stats['windows']:6d} | "
            f"AF={stats['af_windows']:6d} ({100 * stats['af_pct_windows']:6.2f}%) | "
            f"interval_AF={100 * stats['af_pct_intervals']:6.2f}% | saved"
        )

        # Optional plot: first record that has both classes
        if PLOT_EXAMPLES and (not example_plotted) and stats["has_both_classes"]:
            print(f"\nPlotting example windows from {rid}:")
            #plot_random_windows(Xn, y, k=PLOT_K, seed=PLOT_SEED, title_prefix=f"{rid}: ")
            example_plotted = True

    except Exception as e:
        # Deliberate best effort: one unreadable record must not abort the whole
        # run. The failure is reported here and recorded in the summary table.
        print(f"[{i:02d}/{len(records)}] {rid:6s} | SKIP: {e}")

        summary_rows.append({
            "record": rid,
            "fs": np.nan,
            "samples": np.nan,
            "win_len": np.nan,
            "stride": np.nan,
            "windows": 0,
            "af_windows": 0,
            "normal_windows": 0,
            "af_pct_windows": np.nan,
            "af_pct_intervals": np.nan,
            "npz_path": "",
            "has_both_classes": False,
            "error": str(e),
        })

df = pd.DataFrame(summary_rows)
if df.empty:
    # No rows means no columns either; df["npz_path"] below would raise KeyError.
    df = pd.DataFrame(columns=SUMMARY_COLUMNS)
df.to_csv(SUMMARY_CSV, index=False)

print("\n" + "=" * 60)
print(f"Summary CSV saved: {SUMMARY_CSV}")
print("=" * 60)

ok = df[df["npz_path"] != ""]
usable = ok[ok["has_both_classes"].astype(bool)]

print("\nProcessing Summary:")
print(f"  Total records found:        {len(records)}")
print(f"  Successfully processed:     {len(ok)}")
print(f"  Failed:                     {len(records) - len(ok)}")
print(f"  Records with both classes:  {len(usable)}")

if len(ok) > 0 and ok["windows"].sum() > 0:
    total_w = int(ok["windows"].sum())
    total_af = int(ok["af_windows"].sum())
    print("\nDataset Statistics:")
    print(f"  Total windows:        {total_w:,}")
    print(f"  Total AF windows:     {total_af:,}")
    print(f"  Total Normal windows: {total_w - total_af:,}")
    print(f"  Overall AF%:          {100 * total_af / total_w:.2f}%")

print("\nDone.")


AFDB Preprocessing Pipeline
Data directory:   ../data/raw/afdb
Output directory: ../artifacts/afdb_npz
Window: 10s | Stride: 5s | AF threshold: 0.2
Filter: 0.5-40.0 Hz | order=4

Found 25 records
First 10: ['00735', '03665', '04015', '04043', '04048', '04126', '04746', '04908', '04936', '05091']
[01/25] 00735  | SKIP: sampto must be greater than sampfrom
[02/25] 03665  | SKIP: sampto must be greater than sampfrom
[03/25] 04015  | windows=  7363 | AF=    55 (  0.75%) | interval_AF=  0.64% | saved
[04/25] 04043  | windows=  7363 | AF=  1685 ( 22.88%) | interval_AF= 21.54% | saved
[05/25] 04048  | windows=  7363 | AF=    80 (  1.09%) | interval_AF=  0.98% | saved
[06/25] 04126  | windows=  7363 | AF=   284 (  3.86%) | interval_AF=  3.74% | saved
[07/25] 04746  | windows=  7363 | AF=  3917 ( 53.20%) | interval_AF= 53.10% | saved
[08/25] 04908  | windows=  7363 | AF=   670 (  9.10%) | interval_AF=  9.06% | saved
[09/25] 04936  | windows=  7363 | AF=  6035 ( 81.96%) | interval_AF= 81.34% | s

*RR interval*
    RR interval = time between two consecutive R-peaks in the ECG.
    Measured in seconds.
    If heartbeats are regular → RR intervals are similar.
    If rhythm is irregular (like atrial fibrillation) → RR intervals vary a lot.

HRV (Heart Rate Variability)
    HRV = statistics computed from a sequence of RR intervals.
    Why HRV matters for AF:
    AF is characterized by irregularly irregular ventricular response.
    HRV features capture this irregularity even when waveform shape looks normal.

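rr_feat is a 10-dimensional feature vector per window, computed from the RR intervals inside that window, in this order: mean RR, SDNN, RMSSD, pNN50, coefficient of variation, median RR, IQR, MAD, turning-point ratio, RR range.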

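rr_valid is a binary flag per window:
    rr_valid == 1 → the window contains at least 3 R-peaks (at least 2 RR intervals), so the RR features were computed
    rr_valid == 0 → too few R-peaks in the window; the RR features are set to zero and should be treated as missing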

    Without rr_valid:
    Zeroed RR features would look like meaningful low values
    Model might learn wrong correlations
    With rr_valid:
    Model learns:
    “If RR is invalid, ignore those features”
