
# 🫀 ECG Beat Classification & Visualization (20 s Samples at 125 Hz)

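This notebook loads single-lead ECG `.csv` files from a folder, resamples them from 125 Hz to 500 Hz, classifies rhythms (e.g., PVC, VT — the available classes are defined by the model config) with the `ECG_CRNN_CINC2021` deep learning model from `torch_ecg`, and visualizes each signal together with its predicted classes using interactive widgets. Note that predictions are only meaningful once pretrained weights have been loaded into the model; otherwise its weights are randomly initialised.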


In [1]:

from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from ipywidgets import interact, widgets
from torch_ecg.model_configs import ECG_CRNN_CINC2021_CONFIG
from torch_ecg.models import ECG_CRNN_CINC2021
from torch_ecg.utils.utils_signal import resample_poly

# Folder holding the raw ECG exports for one recording session (adjust if needed).
data_dir = Path("./Data/Sleep_on_20240809_230342_by_Etienne_5010176_94DEB87F66D7_RAWECG/")
if not data_dir.is_dir():
    # Fail fast: Path.glob on a missing folder silently yields nothing, which
    # would only surface later as a confusing widget error.
    raise FileNotFoundError(f"ECG data directory not found: {data_dir.resolve()}")

# Load all ECG files; sorted so the slider index -> file mapping is stable.
ecg_files = sorted(data_dir.glob("ecg_*.csv"))
print(f"Found {len(ecg_files)} ECG files.")


ModuleNotFoundError: No module named 'torch'

In [None]:

# Build the CRNN classifier from a private copy of the library's default config,
# so later edits to `config` cannot leak into ECG_CRNN_CINC2021_CONFIG itself.
config = deepcopy(ECG_CRNN_CINC2021_CONFIG)
model = ECG_CRNN_CINC2021(config)

# Path to a pretrained state_dict file; leave as None to run with the
# randomly initialised weights (predictions are then not meaningful).
WEIGHTS_PATH = None

if WEIGHTS_PATH is not None:
    # NOTE: torch.load unpickles the file — only load weights from a trusted source.
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
else:
    print("WARNING: no pretrained weights loaded; model outputs are random.")

# eval() switches layers such as dropout/batch-norm to inference behaviour;
# the trailing ';' keeps Jupyter from dumping the full module repr.
model.eval();


In [None]:

def prepare_input(ecg_data, original_fs=125, target_fs=500):
    """Resample a single-lead ECG trace and wrap it as a model input tensor.

    Parameters
    ----------
    ecg_data : array-like
        1-D sequence of ECG samples (one lead).
    original_fs : int
        Sampling rate of ``ecg_data`` in Hz.
    target_fs : int
        Sampling rate the model expects, in Hz (default 500).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(1, 1, n_resampled)`` — a batch axis and a
        channel axis are prepended to the resampled signal.

    Raises
    ------
    ValueError
        If ``ecg_data`` is not 1-D or is empty.
    """
    samples = np.asarray(ecg_data, dtype=np.float64)
    if samples.ndim != 1:
        # e.g. a multi-column CSV or a single scalar; either would otherwise
        # yield a tensor of the wrong rank for the model.
        raise ValueError(f"expected a 1-D ECG signal, got shape {samples.shape}")
    if samples.size == 0:
        raise ValueError("ecg_data is empty")

    # Polyphase resampling by the rational factor target_fs / original_fs.
    ecg_resampled = resample_poly(samples, target_fs, original_fs)
    return torch.as_tensor(ecg_resampled, dtype=torch.float32)[None, None, :]

def classify_ecg(model, ecg_tensor):
    """Run ``model`` on one input tensor and return per-output sigmoid probabilities.

    Parameters
    ----------
    model : callable
        Network mapping ``ecg_tensor`` to a tensor of logits (presumably
        shape ``(1, n_classes)`` — TODO confirm for the chosen model).
    ecg_tensor : torch.Tensor
        Model input, e.g. as produced by ``prepare_input``.

    Returns
    -------
    numpy.ndarray
        Flattened 1-D array with ``sigmoid(logit)`` for every model output.
    """
    with torch.no_grad():
        logits = model(ecg_tensor)
    # Tensor.numpy() only accepts CPU tensors; .cpu() is a no-op on CPU and
    # keeps this working if the model/input are moved to a GPU.
    return torch.sigmoid(logits).cpu().numpy().flatten()


In [None]:

labels = config.classes

# Sampling rate of the raw CSV recordings in Hz (used for resampling and the time axis).
RAW_FS = 125
# Only classes whose predicted probability exceeds this value are reported.
PROB_THRESHOLD = 0.3


def view_file(file_idx):
    """Plot ``ecg_files[file_idx]`` against time and print its predicted classes.

    Parameters
    ----------
    file_idx : int
        Index into ``ecg_files`` (driven by the slider below).
    """
    ecg_path = ecg_files[file_idx]
    ecg = pd.read_csv(ecg_path, header=None).squeeze().values
    ecg_tensor = prepare_input(ecg, original_fs=RAW_FS)
    preds = classify_ecg(model, ecg_tensor)

    if len(preds) != len(labels):
        # zip() below would silently drop the surplus entries.
        print(f"WARNING: model returned {len(preds)} scores for {len(labels)} labels.")

    # Plot ECG
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(np.arange(len(ecg)) / RAW_FS, ecg)
    ax.set(title=f"{ecg_path.name} - Prediction", xlabel="Time (s)", ylabel="Amplitude")
    ax.grid(True)
    plt.show()

    # Show predictions with probability above the threshold
    hits = [(lbl, prob) for lbl, prob in zip(labels, preds) if prob > PROB_THRESHOLD]
    print("Predicted rhythms:")
    if not hits:
        print(f"  (none above {PROB_THRESHOLD:.2f})")
    for lbl, prob in hits:
        print(f"  {lbl}: {prob:.2f}")


# IntSlider rejects max < min, so only build the widget when there is a file to show.
if ecg_files:
    interact(view_file, file_idx=widgets.IntSlider(min=0, max=len(ecg_files) - 1, step=1, value=0))
else:
    print(f"No ECG files to display in {data_dir}.")
