In [4]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
!pip install wfdb neurokit2 numpy scipy scikit-learn xgboost torch tqdm

Collecting wfdb
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting neurokit2
  Downloading neurokit2-0.2.12-py2.py3-none-any.whl.metadata (37 kB)
Collecting pandas>=2.2.3 (from wfdb)
  Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Downloading wfdb-4.3.0-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading neurokit2-0.2.12-py2.py3-none-any.whl (708 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m708.4/708.4 kB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m117.1 MB/s[0m eta [36m0:00:00

In [None]:
# ============================================================
# Kaggle Notebook (FULL SOLVED CODE - Robust WFDB Loader)
# RR Imputation GAN (LSTM residual + WGAN-GP) with:
# ✅ Robust WFDB RR loader (auto-detect annotation extension)
# ✅ Per-record normalization (fixes trending outliers)
# ✅ Missing-only L1 + dRR consistency + WGAN-GP
# ============================================================

import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import wfdb

# -------------------------
# CONFIG
# -------------------------
# Directory scanned for WFDB header (.hea) and annotation files.
DATASET_ROOT = "/kaggle/input/normal-sinus-dataset/normal-sinus-rhythm-rr-interval-database-1.0.0"

# Windowing of each RR series.
WINDOW_LEN   = 50   # number of consecutive RR intervals per training window
STRIDE       = 5    # step (in beats) between the starts of consecutive windows
K_MISSING    = 5    # number of positions masked out ("missing") in every window

# Optimisation schedule.
BATCH_SIZE   = 128
EPOCHS       = 30
N_CRITIC     = 3    # critic updates performed per generator update

# Adam learning rates for the generator (G) and the critic (D).
LR_G         = 2e-4
LR_D         = 2e-4

# Weights of the non-adversarial terms in the generator loss.
LAMBDA_MISS   = 200.0  # L1 error on the masked positions only
LAMBDA_DRR    = 5.0    # L1 error on successive differences that touch a masked position
LAMBDA_SMOOTH = 0.02   # mean absolute successive difference of the imputed window

GP_LAMBDA     = 10.0     # weight of the WGAN-GP gradient penalty in the critic loss
FILL_MODE     = "mean"   # placeholder for masked values: "mean" (observed mean) or "ffill" (last observed value)
SEED          = 42       # seed passed to seed_all()

# -------------------------
# Seed
# -------------------------
def seed_all(seed=42):
    """Seed every random number generator this notebook draws from.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's legacy
            global RNG, and PyTorch (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs before any model or data-loader object is created.
seed_all(SEED)
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# 1) Robust RR loader (auto-detect annotation extension)
# ============================================================
def list_record_names(dataset_root):
    """Return the sorted, de-duplicated record names found in ``dataset_root``.

    A record name is the file name of a ``.hea`` file with that extension
    removed. Only the top level of ``dataset_root`` is scanned.
    """
    stems = {
        os.path.splitext(entry)[0]
        for entry in os.listdir(dataset_root)
        if entry.endswith(".hea")
    }
    return sorted(stems)

def find_ann_ext_for_record(dataset_root, rec):
    """List candidate annotation extensions for record ``rec``.

    Every file in ``dataset_root`` named ``rec.<...>`` contributes the text
    after its last dot, unless that text (case-insensitively) is ``hea`` or
    ``dat``.

    Returns:
        A list of unique extensions: the well-known ones (``atr``, ``ecg``,
        ``qrs``, ``ann``) first, in that order, followed by any others in
        directory-listing order.
    """
    stem = rec + "."
    found = [
        name.split(".")[-1]
        for name in os.listdir(dataset_root)
        if name.startswith(stem)
    ]
    found = [ext for ext in found if ext.lower() not in ("hea", "dat")]

    # Well-known annotation extensions are tried before anything else.
    ordered = [known for known in ("atr", "ecg", "qrs", "ann") if known in found]
    for ext in found:
        if ext not in ordered:
            ordered.append(ext)
    return ordered

def rr_from_annotation(dataset_root, rec, ext):
    """Read one WFDB annotation file and return its RR intervals in milliseconds.

    Args:
        dataset_root: Directory containing the record's files.
        rec: Record name (file name without extension).
        ext: Annotation file extension to read (e.g. ``"ecg"``).

    Returns:
        1-D ``float32`` array of successive annotation-to-annotation intervals
        in ms, restricted to the open range (300, 2000) ms.

    Raises:
        ValueError: If the header's sampling frequency is not positive.
        Exception: Whatever ``wfdb.rdheader`` / ``wfdb.rdann`` raise for
            unreadable files is propagated to the caller.
    """
    rec_path = os.path.join(dataset_root, rec)
    fs = float(wfdb.rdheader(rec_path).fs)
    if not fs > 0:
        raise ValueError(f"Record '{rec}' has non-positive sampling frequency: {fs}")

    ann = wfdb.rdann(rec_path, ext)
    # Keep sample indices in float64: float32 represents integers exactly only
    # up to 2**24, so long recordings would otherwise yield quantised intervals.
    r_samples = np.asarray(ann.sample, dtype=np.float64)

    # Sample-index differences -> seconds -> milliseconds.
    rr_ms = np.diff(r_samples) / fs * 1000.0

    # Physiological plausibility filter.
    plausible = (rr_ms > 300) & (rr_ms < 2000)
    return rr_ms[plausible].astype(np.float32)

def load_rr_records(dataset_root):
    """Load one RR-interval series per record found in ``dataset_root``.

    For each record, candidate annotation extensions are tried in order until
    one yields at least 60 RR intervals. Records for which no candidate works
    are skipped and reported.

    Args:
        dataset_root: Directory containing ``.hea`` headers and annotation files.

    Returns:
        List of 1-D arrays of RR intervals, one per successfully loaded record,
        in sorted record-name order.

    Raises:
        RuntimeError: If no ``.hea`` file is found, or if no record at all
            could be loaded.
    """
    record_names = list_record_names(dataset_root)
    if len(record_names) == 0:
        raise RuntimeError("No .hea files found. Check DATASET_ROOT.")

    print("Found .hea records:", len(record_names))

    rr_records = []
    failures = {}  # record name -> reason the last attempt failed

    for rec in record_names:
        got = None
        reason = "no candidate annotation files"
        for ext in find_ann_ext_for_record(dataset_root, rec):
            try:
                rr = rr_from_annotation(dataset_root, rec, ext)
            except Exception as exc:
                # Best-effort: an unreadable candidate just means "try the next
                # extension". Catching Exception (not a bare except) lets
                # KeyboardInterrupt / SystemExit propagate.
                reason = f"{ext}: {type(exc).__name__}: {exc}"
                continue
            if len(rr) >= 60:
                got = rr
                break
            reason = f"{ext}: only {len(rr)} RR intervals after filtering"

        if got is None:
            failures[rec] = reason
        else:
            rr_records.append(got)

    loaded = len(rr_records)
    failed = len(failures)

    if loaded == 0:
        # helpful debug: show one record's available files
        sample = record_names[0]
        sample_files = [f for f in os.listdir(dataset_root) if f.startswith(sample + ".")]
        raise RuntimeError(
            "No RR records loaded. WFDB couldn't read any annotation.\n"
            f"Example record '{sample}' files: {sample_files}\n"
            f"Last failure for '{sample}': {failures.get(sample)}\n"
            "Your dataset might store peaks differently."
        )

    print(f"Loaded RR records: {loaded} | failed: {failed}")
    for rec, reason in failures.items():
        print(f"  skipped '{rec}': {reason}")
    return rr_records

# Load every record once, then summarise how many RR intervals each one holds.
rr_records = load_rr_records(DATASET_ROOT)
print(
    "RR lengths (min/mean/max):",
    len(min(rr_records, key=len)),
    int(np.mean(list(map(len, rr_records)))),
    len(max(rr_records, key=len)),
)

# ============================================================
# 2) Dataset (per-record normalization + windowing)
# ============================================================
class RRWindowDataset(Dataset):
    """Sliding windows over per-record z-normalised RR series.

    Each record is normalised with its own mean and standard deviation and
    then cut into overlapping windows. Records shorter than ``window_len``
    are ignored.

    Args:
        rr_records: Iterable of 1-D RR-interval sequences.
        window_len: Number of samples per window.
        stride: Offset between the starts of consecutive windows.

    Raises:
        ValueError: If no record is long enough to produce a window.
    """

    def __init__(self, rr_records, window_len=50, stride=5):
        self.samples = []  # normalised float32 windows
        self.stats = []    # (mu, std) of the source record, one entry per window

        for series in rr_records:
            series = np.asarray(series, dtype=np.float32)
            n_beats = len(series)
            if n_beats < window_len:
                continue

            mu = float(series.mean())
            std = float(series.std() + 1e-8)
            normalised = (series - mu) / std

            starts = range(0, n_beats - window_len + 1, stride)
            self.samples.extend(
                normalised[a:a + window_len].astype(np.float32) for a in starts
            )
            self.stats.extend((mu, std) for _ in starts)

        if not self.samples:
            raise ValueError("No windows created. Reduce WINDOW_LEN/STRIDE or check lengths.")

    def __len__(self):
        """Number of windows across all records."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(window, mu, std)`` as float32 tensors for window ``idx``."""
        window = torch.tensor(self.samples[idx], dtype=torch.float32)
        mu, std = self.stats[idx]
        return (
            window,
            torch.tensor(mu, dtype=torch.float32),
            torch.tensor(std, dtype=torch.float32),
        )

# ============================================================
# 3) Masking
# ============================================================
def make_mask_and_corrupt(x, K=5, fill_mode="mean"):
    """Hide ``K`` random positions per row and fill them with a placeholder.

    Args:
        x: 2-D tensor of shape (batch, time).
        K: Number of positions to mask in each row.
        fill_mode: ``"mean"`` replaces masked values with the mean of the
            row's observed values; ``"ffill"`` replaces them with the most
            recent observed value (or the first observed value when the
            masked position precedes every observed one).

    Returns:
        ``(x_obs, m, idx_missing_list)`` where ``x_obs`` is the corrupted
        input, ``m`` is 1.0 at observed and 0.0 at masked positions, and
        ``idx_missing_list`` holds one index tensor of masked positions per row.

    Raises:
        ValueError: If ``fill_mode`` is neither ``"mean"`` nor ``"ffill"``.
    """
    n_rows, n_steps = x.shape
    m = torch.ones((n_rows, n_steps), device=x.device)

    # One independent random permutation per row; its first K entries are masked.
    idx_missing_list = [
        torch.randperm(n_steps, device=x.device)[:K] for _ in range(n_rows)
    ]
    for row, dropped in enumerate(idx_missing_list):
        m[row, dropped] = 0.0

    if fill_mode not in ("mean", "ffill"):
        raise ValueError("fill_mode must be 'mean' or 'ffill'")

    if fill_mode == "mean":
        observed_total = (x * m).sum(dim=1, keepdim=True)
        observed_count = m.sum(dim=1, keepdim=True).clamp(min=1.0)
        fill = observed_total / observed_count
        x_obs = x * m + fill * (1.0 - m)
        return x_obs, m, idx_missing_list

    # fill_mode == "ffill": carry the last observed value forward.
    x_obs = x.clone()
    for row in range(n_rows):
        is_missing = (m[row] == 0)
        observed_positions = torch.where(m[row] == 1)[0]
        if len(observed_positions) == 0:
            continue
        carry = x_obs[row, observed_positions[0]].item()
        for t in range(n_steps):
            if is_missing[t]:
                x_obs[row, t] = carry
            else:
                carry = x_obs[row, t].item()

    return x_obs, m, idx_missing_list

# ============================================================
# 4) Models
# ============================================================
class LSTMGeneratorResidual(nn.Module):
    """Bidirectional LSTM that predicts a per-step correction for a filled window.

    The network sees four features per time step (filled value, mask, first
    difference of the filled value, noise) and outputs one scalar per step.

    Args:
        hidden: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Inter-layer LSTM dropout; forced to 0 for a single layer.
    """

    def __init__(self, hidden=64, num_layers=2, dropout=0.1):
        super().__init__()
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=4,
            hidden_size=hidden,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated, hence hidden * 2.
        self.head = nn.Sequential(
            nn.Linear(hidden * 2, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x_base, m, noise):
        """Map (batch, time) inputs to a (batch, time) correction ``delta``."""
        step_change = torch.diff(x_base, dim=1)
        # The first step has no predecessor, so its difference is defined as 0.
        leading_zero = torch.zeros((x_base.size(0), 1), device=x_base.device)
        dx = torch.cat([leading_zero, step_change], dim=1)

        features = torch.stack([x_base, m, dx, noise], dim=-1)
        hidden_seq, _ = self.lstm(features)
        return self.head(hidden_seq).squeeze(-1)

class CNNWganCritic(nn.Module):
    """1-D convolutional critic producing one unbounded score per sequence.

    Args:
        channels: Number of feature maps in every convolution.
        kernel: Convolution kernel width; padding of ``kernel // 2`` is applied.
    """

    def __init__(self, channels=64, kernel=5):
        super().__init__()
        layers = []
        in_channels = 1
        for _ in range(3):
            layers.append(nn.Conv1d(in_channels, channels, kernel, padding=kernel//2))
            layers.append(nn.LeakyReLU(0.2))
            in_channels = channels
        self.net = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(channels, 1)

    def forward(self, x):
        """Score a (batch, time) tensor; returns a (batch,) tensor."""
        feature_maps = self.net(x.unsqueeze(1))           # (batch, channels, time')
        pooled = self.pool(feature_maps).squeeze(-1)      # (batch, channels)
        return self.fc(pooled).squeeze(-1)

# ============================================================
# 5) WGAN-GP
# ============================================================
def gradient_penalty(critic, real, fake, gp_lambda=10.0):
    """WGAN-GP penalty evaluated on random interpolates between real and fake.

    Args:
        critic: Callable mapping a batch tensor to one score per sample.
        real: Batch of real samples.
        fake: Batch of generated samples, same shape as ``real``.
        gp_lambda: Multiplier applied to the mean squared deviation of the
            per-sample gradient L2 norm from 1.

    Returns:
        Scalar tensor; differentiable with respect to the critic parameters.
    """
    n_samples = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand(n_samples, 1, device=real.device).expand_as(real)
    mixed = alpha * real + (1 - alpha) * fake
    mixed.requires_grad_(True)

    mixed_scores = critic(mixed)
    # create_graph keeps the gradient itself differentiable so the penalty can
    # be back-propagated into the critic.
    (score_grad,) = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = score_grad.view(n_samples, -1).norm(2, dim=1)
    return ((per_sample_norm - 1.0) ** 2).mean() * gp_lambda

# ============================================================
# 6) Train
# ============================================================
def train_wgan(rr_records):
    """Train the residual LSTM generator and CNN critic with WGAN-GP.

    Every generator step combines the adversarial term with an L1 loss on the
    masked positions, an L1 loss on successive differences that touch a masked
    position, and a small smoothness term. Hyper-parameters come from the
    module-level configuration constants.

    Args:
        rr_records: List of 1-D RR-interval arrays, one per record.

    Returns:
        ``(G, D)``: the trained generator and critic, left in training mode.

    Raises:
        ValueError: If no window can be built, or if there are fewer windows
            than ``BATCH_SIZE`` (``drop_last=True`` would leave no batch).
    """
    ds = RRWindowDataset(rr_records, window_len=WINDOW_LEN, stride=STRIDE)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    print("Total windows:", len(ds))

    # With drop_last=True a dataset smaller than one batch produces no batches,
    # which would otherwise surface as a ZeroDivisionError in the epoch log.
    n_batches = len(dl)
    if n_batches == 0:
        raise ValueError(
            f"Only {len(ds)} windows available but BATCH_SIZE={BATCH_SIZE}; "
            "reduce BATCH_SIZE or provide more data."
        )

    G = LSTMGeneratorResidual(hidden=64, num_layers=2, dropout=0.1).to(device)
    D = CNNWganCritic(channels=64, kernel=5).to(device)

    opt_g = optim.Adam(G.parameters(), lr=LR_G, betas=(0.5, 0.9))
    opt_d = optim.Adam(D.parameters(), lr=LR_D, betas=(0.5, 0.9))

    lo, hi = -6.0, 6.0  # normalized clamp

    for ep in range(1, EPOCHS + 1):
        G.train(); D.train()
        d_sum = g_sum = miss_sum = drr_sum = 0.0

        # The per-window (mu, std) are not needed in normalised-space training.
        for x, _mu, _std in dl:
            x = x.to(device)

            # ----- Critic -----
            for _ in range(N_CRITIC):
                # Fakes are built without gradient: only D is updated here.
                with torch.no_grad():
                    x_base, m, _ = make_mask_and_corrupt(x, K=K_MISSING, fill_mode=FILL_MODE)
                    noise = torch.rand_like(x) * (1.0 - m)  # noise only at masked positions
                    delta = G(x_base, m, noise)
                    # Observed positions keep the true value; masked ones get fill + residual.
                    x_hat = m * x + (1.0 - m) * (x_base + delta)
                    x_hat = torch.clamp(x_hat, lo, hi)

                real_score = D(x)
                fake_score = D(x_hat)
                gp = gradient_penalty(D, x, x_hat, gp_lambda=GP_LAMBDA)
                d_loss = (fake_score.mean() - real_score.mean()) + gp

                opt_d.zero_grad(set_to_none=True)
                d_loss.backward()
                opt_d.step()

            # ----- Generator -----
            x_base, m, _ = make_mask_and_corrupt(x, K=K_MISSING, fill_mode=FILL_MODE)
            noise = torch.rand_like(x) * (1.0 - m)
            delta = G(x_base, m, noise)
            x_hat = m * x + (1.0 - m) * (x_base + delta)
            x_hat = torch.clamp(x_hat, lo, hi)

            adv = -D(x_hat).mean()

            # L1 error averaged over masked positions only.
            denom = (1.0 - m).sum().clamp(min=1.0)
            miss_mae = (torch.abs(x_hat - x) * (1.0 - m)).sum() / denom

            # A successive difference counts when at least one endpoint was masked.
            m_pair = m[:, 1:] * m[:, :-1]
            miss_drr_mask = 1.0 - m_pair
            drr_hat = x_hat[:, 1:] - x_hat[:, :-1]
            drr_true = x[:, 1:] - x[:, :-1]
            denom_drr = miss_drr_mask.sum().clamp(min=1.0)
            drr_loss = (torch.abs(drr_hat - drr_true) * miss_drr_mask).sum() / denom_drr

            smooth = torch.mean(torch.abs(drr_hat))

            g_loss = adv + LAMBDA_MISS * miss_mae + LAMBDA_DRR * drr_loss + LAMBDA_SMOOTH * smooth

            opt_g.zero_grad(set_to_none=True)
            g_loss.backward()
            opt_g.step()

            # d_loss is the value from the last critic iteration of this batch.
            d_sum += d_loss.item()
            g_sum += g_loss.item()
            miss_sum += miss_mae.item()
            drr_sum += drr_loss.item()

        print(
            f"Epoch {ep:03d} | D={d_sum/n_batches:.4f} | G={g_sum/n_batches:.4f} | "
            f"MissMAE={miss_sum/n_batches:.4f} | dRRloss={drr_sum/n_batches:.4f}"
        )

    return G, D

# ============================================================
# 7) Inference + Plot (de-normalize)
# ============================================================
@torch.no_grad()
def impute_one_window_from_record(G, rr_record, start_idx=0, K=5):
    """Mask ``K`` beats in one window of a record, impute them, and score the result.

    The window is normalised with the mean and standard deviation of the
    whole record, matching the per-record normalisation used to build the
    training windows. The generator is run in evaluation mode (dropout off)
    and its previous train/eval mode is restored afterwards.

    Args:
        G: Trained generator taking ``(x_base, m, noise)``.
        rr_record: 1-D sequence of RR intervals for one record.
        start_idx: Index of the first beat of the window.
        K: Number of positions to mask.

    Returns:
        ``(x_true, x_base, x_hat, idx_missing, mae, rmse)`` where the three
        series are de-normalised NumPy arrays of length ``WINDOW_LEN``,
        ``idx_missing`` holds the masked positions, and ``mae`` / ``rmse`` are
        computed on the masked positions in the units of ``rr_record``.

    Raises:
        ValueError: If fewer than ``WINDOW_LEN`` beats are available from
            ``start_idx``.
    """
    rr_record = np.asarray(rr_record, dtype=np.float32)
    window = rr_record[start_idx:start_idx+WINDOW_LEN]
    if len(window) < WINDOW_LEN:
        raise ValueError("Not enough length for this window.")

    # Use record-level statistics so inference inputs follow the same
    # distribution the generator was trained on.
    mu = float(rr_record.mean())
    std = float(rr_record.std() + 1e-8)
    x = (window - mu) / std

    x = torch.tensor(x, dtype=torch.float32, device=device).unsqueeze(0)
    x_base, m, idx_list = make_mask_and_corrupt(x, K=K, fill_mode=FILL_MODE)
    idx_missing = idx_list[0].detach().cpu().numpy()

    noise = torch.rand_like(x) * (1.0 - m)

    # Disable dropout for the forward pass, then restore the caller's mode.
    was_training = G.training
    G.eval()
    try:
        delta = G(x_base, m, noise)
    finally:
        G.train(was_training)

    # Observed beats are passed through unchanged; only masked ones are imputed.
    x_hat = m * x + (1.0 - m) * (x_base + delta)

    # De-normalise back to the original units.
    x_true = x.squeeze(0).cpu().numpy() * std + mu
    x_base_np = x_base.squeeze(0).cpu().numpy() * std + mu
    x_hat_np  = x_hat.squeeze(0).cpu().numpy() * std + mu

    err = x_hat_np[idx_missing] - x_true[idx_missing]
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err**2)))
    return x_true, x_base_np, x_hat_np, idx_missing, mae, rmse

def plot_imputation(x_true, x_base, x_imp, idx_missing, title):
    """Plot the true, masked/filled and imputed series for one window.

    Args:
        x_true: Ground-truth series.
        x_base: Series after masking and placeholder filling.
        x_imp: Series after imputation.
        idx_missing: Indices of the masked positions, highlighted as markers.
        title: Figure title.
    """
    _, ax = plt.subplots(figsize=(12,4))

    curves = (
        (x_true, "True RR"),
        (x_base, "Masked/Filled RR"),
        (x_imp, "Imputed RR"),
    )
    for series, label in curves:
        ax.plot(series, label=label)

    # Mark the masked positions: crosses for truth, circles for predictions.
    ax.scatter(idx_missing, x_true[idx_missing], marker="x", s=80, label="Missing true")
    ax.scatter(idx_missing, x_imp[idx_missing], marker="o", s=80, label="Missing pred")

    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

# ============================================================
# 8) Run
# ============================================================
# Train generator (G) and critic (D) on all loaded records.
G, D = train_wgan(rr_records)

# Test plots on outlier-like records (by index in rr_records list)
# NOTE(review): indices 1 and 39 are hard-coded, so at least 40 records must
# have loaded or rr_records[rid] raises IndexError.
for rid in [1, 39]:
    # Impute the first window of the record; MAE/RMSE are in the record's units.
    x_true, x_base, x_imp, idx_m, mae, rmse = impute_one_window_from_record(G, rr_records[rid], start_idx=0, K=K_MISSING)
    plot_imputation(x_true, x_base, x_imp, idx_m, title=f"Record {rid} | MAE={mae:.4f} RMSE={rmse:.4f}")

# Save generator
# Only the generator weights plus the window length and mask count are stored;
# the critic is not saved.
MODEL_PATH = "/kaggle/working/rr_imputer_lstm_wgangp_residual_autoload.pt"
torch.save({"G_state": G.state_dict(), "window_len": WINDOW_LEN, "k_missing": K_MISSING}, MODEL_PATH)
print("Saved:", MODEL_PATH)


Device: cuda
Found .hea records: 54
Loaded RR records: 54 | failed: 0
RR lengths (min/mean/max): 76760 107178 136481
Total windows: 1157023


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 001 | D=-0.0303 | G=31.6638 | MissMAE=0.1229 | dRRloss=0.1218
Epoch 002 | D=-0.0439 | G=35.5945 | MissMAE=0.1093 | dRRloss=0.1089
Epoch 003 | D=-0.0433 | G=35.9308 | MissMAE=0.1062 | dRRloss=0.1059
Epoch 004 | D=-0.0439 | G=33.3240 | MissMAE=0.1044 | dRRloss=0.1042
Epoch 005 | D=-0.0432 | G=31.3432 | MissMAE=0.1033 | dRRloss=0.1031
Epoch 006 | D=-0.0424 | G=30.6946 | MissMAE=0.1029 | dRRloss=0.1027
Epoch 007 | D=-0.0425 | G=29.5446 | MissMAE=0.1023 | dRRloss=0.1021
Epoch 008 | D=-0.0426 | G=28.8314 | MissMAE=0.1019 | dRRloss=0.1017
Epoch 009 | D=-0.0423 | G=26.4312 | MissMAE=0.1016 | dRRloss=0.1014
Epoch 010 | D=-0.0425 | G=27.1099 | MissMAE=0.1010 | dRRloss=0.1009
Epoch 011 | D=-0.0422 | G=25.9321 | MissMAE=0.1009 | dRRloss=0.1008
