In [16]:
import pandas as pd
import ast
from sklearn.model_selection import train_test_split

# Load metadata
df = pd.read_csv('C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1\\ptbxl_database.csv')

# Parse scp_codes column to dict
df['scp_codes'] = df['scp_codes'].apply(ast.literal_eval)

# Map all diagnostic labels to SCP codes
priority_map = {
    '1dAVb': ['1AVB'],  # First-degree AV block
    'RBBB': ['IRBBB', 'CRBBB'],  # Right bundle branch block (partial + complete)
    'LBBB': ['ILBBB', 'CLBBB'],  # Left bundle branch block (partial + complete)
    'AFLT': ['AFLT'],  # Atrial flutter
    'AFIB': ['AFIB'],  # Atrial fibrillation
    'NORM': ['NORM'],  # Normal
}

# Function to assign single label based on priority
def assign_priority_label(scp_dict):
    for label, codes in priority_map.items():
        if any(code in scp_dict for code in codes):
            return label
    return 'OTHER'

# Assign single label
df['label'] = df['scp_codes'].apply(assign_priority_label)

# Filter to only those samples with labels in the desired classes
df = df[df['label'].isin(['1dAVb', 'RBBB', 'LBBB', 'AFLT', 'AFIB', 'OTHER', 'NORM'])]

# Split into training and test sets
train_df, test_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)

# Optional: display class balance
print("Training class distribution:")
print(train_df['label'].value_counts())
print("\nTest class distribution:")
print(test_df['label'].value_counts())


Training class distribution:
label
NORM     7364
OTHER    6792
RBBB     1237
AFIB      977
1dAVb     638
LBBB      408
AFLT       53
Name: count, dtype: int64

Test class distribution:
label
NORM     1841
OTHER    1699
RBBB      309
AFIB      245
1dAVb     159
LBBB      102
AFLT       13
Name: count, dtype: int64


In [16]:
import numpy as np
import pandas as pd
import wfdb
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

# Assuming you've already created train_df or test_df with a 'label' column
# and that 'filename_lr' gives paths to 100 Hz signals

def load_and_save_ecg_data(df, data_dir, output_filename):
    signals = []
    labels = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        path = f"{data_dir}/{row['filename_lr']}"
        try:
            record = wfdb.rdrecord(path)
            sig = record.p_signal  # shape: (n_samples, 12)
            signals.append(sig)
            labels.append(row['label'])
        except Exception as e:
            print(f"Error loading {path}: {e}")

    # Ensure all samples are the same length (optional: pad/truncate if needed)
    min_length = min([s.shape[0] for s in signals])
    signals = [s[:min_length] for s in signals]  # truncate all to min length

    # Convert to arrays
    X = np.array(signals)  # shape: (num_samples, min_length, 12)
    y = np.array(labels)

    

    # Encode labels
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    # Save to .npz
    np.savez(output_filename, X=X, y=y_encoded, label_names=label_encoder.classes_)

    print(f"Saved to {output_filename} with shape {X.shape}")
    return X, y

# Load and save the ECG data for the test set
X, y = load_and_save_ecg_data(train_df, data_dir='C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1', output_filename='test_data_balanced.npz')

# look at the shape of the data
print(f"Shape of X: {X.shape}")
print(f"Shape of y: {y.shape}")
# Check the first few labels
print("First few labels:", y[:5])

100%|██████████| 17469/17469 [02:41<00:00, 108.48it/s]


Saved to test_data_balanced.npz with shape (17469, 1000, 12)
Shape of X: (17469, 1000, 12)
Shape of y: (17469,)
First few labels: ['RBBB' 'AFIB' 'AFIB' 'OTHER' 'OTHER']


In [26]:
import numpy as np
from scipy.interpolate import interp1d

def jitter(signal, sigma=0.01):
    return signal + np.random.normal(0, sigma, size=signal.shape)

def scaling(signal, sigma=0.1):
    factor = np.random.normal(loc=1.0, scale=sigma)
    return signal * factor

def time_stretch(signal, stretch_factor=0.1):
    orig_len = signal.shape[1]
    factor = 1 + np.random.uniform(-stretch_factor, stretch_factor)
    new_len = int(orig_len * factor)
    x = np.linspace(0, 1, orig_len)
    f = interp1d(x, signal, kind='linear', axis=1, fill_value='extrapolate')
    x_new = np.linspace(0, 1, new_len)
    stretched = f(x_new)
    if stretched.shape[1] > orig_len:
        return stretched[:, :orig_len]
    else:
        # pad if needed
        pad_width = orig_len - stretched.shape[1]
        return np.pad(stretched, ((0, 0), (0, pad_width)), mode='edge')


In [36]:
from collections import Counter
from sklearn.utils import shuffle

def generate_augmented_samples(X_class, num_needed, label):
    num_synth = int(num_needed // 3)
    num_aug = num_needed - num_synth
    
    # Synthetic oversampling
    synth_samples = X_class[np.random.randint(0, len(X_class), num_synth)]

    # Augmentation: jitter, scaling, stretch (round robin)
    aug_samples = []
    funcs = [jitter, scaling, time_stretch]
    for i in range(num_aug):
        idx = np.random.randint(0, len(X_class))
        func = funcs[i % len(funcs)]
        aug = func(X_class[idx])
        aug_samples.append(aug)

    aug_samples = np.array(aug_samples)
    X_boosted = np.concatenate([synth_samples, aug_samples], axis=0)
    y_boosted = np.full(len(X_boosted), label)
    return X_boosted, y_boosted


In [37]:

# Split the data into features and labels
label_col = 'label'
X_train = train_df.drop(columns=[label_col])
y_train = train_df[label_col]

print(y_train.shape)

target_min = 2000
X_train_boosted = []
y_train_boosted = []

class_counts = Counter(y_train)

import wfdb

data_dir = 'C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'

# For each class, load signals and augment if needed
for cls, count in class_counts.items():
    print(count)
    X_cls_df = X_train[y_train == cls]
    y_cls = y_train[y_train == cls]
    # Load signals for this class
    signals = []
    for fname in X_cls_df['filename_lr']:
        path = f"{data_dir}/{fname}"
        try:
            record = wfdb.rdrecord(path)
            sig = record.p_signal.T  # shape: (12, n_samples)
            signals.append(sig)
        except Exception as e:
            print(f"Error loading {path}: {e}")
    # Make all signals the same length
    min_length = min([s.shape[1] for s in signals])
    signals = [s[:, :min_length] for s in signals]
    X_cls_signals = np.stack(signals, axis=0)  # shape: (num_samples, 12, min_length)
    y_cls_arr = np.array([cls] * len(X_cls_signals))
    X_train_boosted.append(X_cls_signals)
    y_train_boosted.append(y_cls_arr)
    if count < target_min:
        needed = target_min - count
        X_new, y_new = generate_augmented_samples(X_cls_signals, needed, cls)
        X_train_boosted.append(X_new)
        y_train_boosted.append(y_new)

X_train_balanced = np.concatenate(X_train_boosted, axis=0)
y_train_balanced = np.concatenate(y_train_boosted, axis=0)

X_train_balanced, y_train_balanced = shuffle(X_train_balanced, y_train_balanced, random_state=42)

print("Training class distribution after balancing:")
print(dict(Counter(y_train_balanced)))


(17469,)
1237
977
6792
7364
638
408
53
Training class distribution after balancing:
{'AFIB': 2000, 'NORM': 7364, 'OTHER': 6792, 'LBBB': 2000, 'AFLT': 2000, 'RBBB': 2000, '1dAVb': 2000}


In [15]:
# Duplicate underrepresented test samples by loading ECG signals and duplicating as needed
label_col = 'label'
target_test_min = 400
data_dir = 'C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'

X_test_boosted = []
y_test_boosted = []

test_counts = Counter(test_df[label_col])

for cls, count in test_counts.items():
    X_cls_df = test_df[test_df[label_col] == cls]
    signals = []
    for fname in X_cls_df['filename_lr']:
        path = f"{data_dir}/{fname}"
        try:
            record = wfdb.rdrecord(path)
            sig = record.p_signal.T  # shape: (12, n_samples)
            signals.append(sig)
        except Exception as e:
            print(f"Error loading {path}: {e}")
    if len(signals) == 0:
        continue
    min_length = min([s.shape[1] for s in signals])
    signals = [s[:, :min_length] for s in signals]
    X_cls_signals = np.stack(signals, axis=0)  # (num_samples, 12, min_length)
    y_cls_arr = np.array([cls] * len(X_cls_signals))
    X_test_boosted.append(X_cls_signals)
    y_test_boosted.append(y_cls_arr)
    if count < target_test_min:
        needed = target_test_min - count
        reps = needed // len(X_cls_signals) + 1
        X_dup = np.tile(X_cls_signals, (reps, 1, 1))[:needed]
        y_dup = np.full(len(X_dup), cls)
        X_test_boosted.append(X_dup)
        y_test_boosted.append(y_dup)

X_test_balanced = np.concatenate(X_test_boosted, axis=0)
y_test_balanced = np.concatenate(y_test_boosted, axis=0)

X_test_balanced, y_test_balanced = shuffle(X_test_balanced, y_test_balanced, random_state=42)

print("Test class distribution after balancing:")
print(dict(Counter(y_test_balanced)))

NameError: name 'test_df' is not defined

In [39]:
np.savez('X_train_balanced.npz', X=X_train_balanced)
np.savez('y_train_balanced.npz', y=y_train_balanced)
np.savez('X_test_balanced.npz', X=X_test_balanced)
np.savez('y_test_balanced.npz', y=y_test_balanced)

In [21]:
# Use the already prepared balanced arrays
X = X_train_balanced
y = y_train_balanced

# Encode labels
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Save to .npz
output_filename = 'train_data_balanced.npz'
np.savez(output_filename, X=X, y=y_encoded, label_names=label_encoder.classes_)

print(f"Saved to {output_filename} with shape {X.shape}")
print(f"Shape of X: {X.shape}")
print(f"Shape of y: {y.shape}")
print("First few labels:", y[:5])

Saved to train_data_balanced.npz with shape (24156, 12, 1000)
Shape of X: (24156, 12, 1000)
Shape of y: (24156,)
First few labels: ['AFIB' 'NORM' 'OTHER' 'OTHER' 'OTHER']


In [10]:
output_filename = 'train_data.npz'
# Encode labels
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Save to .npz
np.savez(output_filename, X=X, y=y_encoded, label_names=label_encoder.classes_)

print(f"Saved to {output_filename} with shape {X.shape}")



Saved to train_data.npz with shape (17469, 1000, 12)


The below section is for prepping the MIT-BIH arrhythmia database for use with our model

In [1]:
import wfdb
import numpy as np
import os
from scipy.signal import resample

def load_mitbih_data(path, target_length=1000, target_fs=100):
    """
    Load and prepare ECG data from MIT-BIH Arrhythmia Database.

    Args:
        path (str): Directory containing .dat, .atr, and .hea files.
        target_length (int): Number of samples expected by model (e.g., 1000).
        target_fs (int): Target sampling rate for model (e.g., 100 Hz).

    Returns:
        list of np.ndarray: ECGs with shape (2, 1000, 1) from MLII and V1 leads.
        list of str: Corresponding rhythm annotations per sample.
    """
    ecg_segments = []
    annotations = []

    files = [f[:-4] for f in os.listdir(path) if f.endswith(".hea")]
    print(f"Found {len(files)} records.")

    for record_name in files:
        try:
            # Read signal and annotations
            record = wfdb.rdrecord(os.path.join(path, record_name))
            ann = wfdb.rdann(os.path.join(path, record_name), 'atr')
            signal = record.p_signal  # shape (N, 2) usually: MLII, V1
            original_fs = record.fs

            # Leads present
            leads = record.sig_name
            lead_indices = {name: i for i, name in enumerate(leads)}
            if 'MLII' not in lead_indices or 'V1' not in lead_indices:
                continue

            # Extract and resample the leads
            ml_ii = signal[:, lead_indices['MLII']]
            v1 = signal[:, lead_indices['V1']]

            # Resample to 100 Hz
            new_len = int(len(ml_ii) * target_fs / original_fs)
            ml_ii_resampled = resample(ml_ii, new_len)
            v1_resampled = resample(v1, new_len)

            # Sliding window extraction
            step = target_length
            for i in range(0, len(ml_ii_resampled) - target_length, step):
                segment = np.zeros((12, target_length))  # Fill all 12 leads with zero
                segment[0] = ml_ii_resampled[i:i+target_length]  # Simulate as lead I
                segment[6] = v1_resampled[i:i+target_length]     # Simulate as V1

                ecg_segments.append(segment[..., np.newaxis])  # shape: (12, 1000, 1)
                annotations.append("UNKNOWN")  # Replace later with label if needed

        except Exception as e:
            print(f"Failed to process {record_name}: {e}")

    print(f"Prepared {len(ecg_segments)} samples.")
    return ecg_segments, annotations


In [2]:
def get_mitbih_classes(path):
    from collections import Counter
    all_classes = Counter()

    files = [f[:-4] for f in os.listdir(path) if f.endswith(".hea")]
    for record_name in files:
        try:
            ann = wfdb.rdann(os.path.join(path, record_name), 'atr')
            symbols = ann.symbol
            all_classes.update(symbols)
        except Exception as e:
            print(f"Could not read annotations for {record_name}: {e}")

    print("Annotation symbols and counts:")
    for symbol, count in all_classes.items():
        print(f"{symbol}: {count}")


In [3]:
# Replace with your MIT-BIH dataset directory
mitbih_path = "C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\mit-bih-arrhythmia-database-1.0.0"

# Print class labels
get_mitbih_classes(mitbih_path)

# # Load data for model
# X_test_mitbih, y_labels = load_mitbih_data(mitbih_path)

# # Convert to numpy array
# X_test_mitbih = np.array(X_test_mitbih)  # shape: (N, 12, 1000, 1)


Annotation symbols and counts:
+: 1291
N: 75052
A: 2546
V: 7130
~: 616
|: 132
Q: 33
/: 7028
f: 982
x: 193
F: 803
j: 229
L: 8075
a: 150
J: 83
R: 7259
[: 6
!: 472
]: 6
E: 106
S: 2
": 437
e: 16


In [10]:
import wfdb
import numpy as np
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
import os

# Fixed label order for model alignment
FIXED_LABEL_ORDER = ['1dAVb', 'AFIB', 'AFLT', 'LBBB', 'NORM', 'OTHER', 'RBBB']

# Mapping annotation keywords to class names
MITBIH_TO_CUSTOM = {
    '1dAVb': ['1AV'],
    'RBBB': ['R'],
    'LBBB': ['L'],
    'AFLT': ['AFL'],
    'AFIB': ['AFIB'],
    'NORM': ['N']
}

def map_annotation(comment):
    for cls, terms in MITBIH_TO_CUSTOM.items():
        for term in terms:
            if term in comment:
                return cls
    return "OTHER"

def prepare_mitbih_dataset(data_dir, record_list, max_per_class=1500):
    class_counts = defaultdict(int)
    X_data = []
    y_data = []

    for rec in record_list:
        record_path = os.path.join(data_dir, rec)
        try:
            record = wfdb.rdrecord(record_path)
            annotation = wfdb.rdann(record_path, 'atr')
        except Exception as e:
            print(f"Skipping {rec} due to error: {e}")
            continue

        signals = record.p_signal.T  # Shape: (n_leads, n_samples)
        fs = record.fs
        if signals.shape[0] != 2:
            print(f"Skipping {rec}: expected 2 leads, found {signals.shape[0]}")
            continue

        for i, (sample, sym) in enumerate(zip(annotation.sample, annotation.aux_note)):
            label = map_annotation(sym)

            if class_counts[label] >= max_per_class:
                continue

            # 1000-sample segment centered on beat
            start = sample - 500
            end = sample + 500

            if start < 0 or end > signals.shape[1]:
                continue

            segment = signals[:, start:end]  # shape: (2, 1000)

            if segment.shape[1] != 1000:
                continue

            # Expand to (12, 1000) by repeating leads if necessary
            padded_segment = np.zeros((12, 1000))
            padded_segment[:2, :] = segment

            X_data.append(padded_segment[..., np.newaxis])  # shape: (12, 1000, 1)
            y_data.append(label)
            class_counts[label] += 1


    # Encode y_data with fixed label order
    label_to_index = {label: idx for idx, label in enumerate(FIXED_LABEL_ORDER)}
    y_encoded = np.array([label_to_index[label] for label in y_data])

    print("Class distribution:", dict(class_counts))
    print("Label mapping:", label_to_index)

    X = np.array(X_data)
    y = y_encoded

    return X, y, label_to_index


In [36]:
# records = ['100', '101', '102', '103', '104'] 
records = [f"{i:03d}" for i in range(100, 234)]  # Adjust as needed
X, y, label_map = prepare_mitbih_dataset(mitbih_path, records)
print("Shape of X:", X.shape)  # Expected: (N, 12, 1000, 1)
print("Label mapping:", label_map)


Skipping 110 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/110.hea'
Skipping 120 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/120.hea'
Skipping 125 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/125.hea'
Skipping 126 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/126.hea'
Skipping 127 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/127.hea'
Skipping 128 due to error: [Errno 2] No such file or directory: 'C:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/mit-bih-arrhythmia-database-1.0.0/128.hea'
Skip

In [37]:
from collections import Counter
import wfdb
import os

def check_1dAVb_presence(data_dir, record_list):
    counter = Counter()
    for rec in record_list:
        try:
            annotation = wfdb.rdann(os.path.join(data_dir, rec), 'atr')
            for note in annotation.aux_note:
                if '1AV' in note:
                    counter[rec] += 1
        except Exception as e:
            print(f"Skipping {rec} due to error: {e}")
    return counter

# Example usage
records = ['100', '101', '102', '103', '104', '105']  # or all 48
avb_cases = check_1dAVb_presence('path/to/mitbih', records)
print("1dAVb cases found:", avb_cases)

Skipping 100 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-main/path/to/mitbih/100.atr'
Skipping 101 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-main/path/to/mitbih/101.atr'
Skipping 102 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-main/path/to/mitbih/102.atr'
Skipping 103 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-main/path/to/mitbih/103.atr'
Skipping 104 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-main/path/to/mitbih/104.atr'
Skipping 105 due to error: [Errno 2] No such file or directory: 'c:/Users/mstew/OneDrive/Curtin Univeristy/COMP6011/Task 3/ECG-classification-mai

In [38]:
import numpy as np

# Load the PTB-XL balanced training set
train_data = np.load('X_train_balanced.npz', allow_pickle=True)
y_data = np.load('y_train_balanced.npz', allow_pickle=True)
X_ptbxl = train_data['X']
y_ptbxl = y_data['y']

# Your label order: ['1dAVb', 'AFIB', 'AFLT', 'LBBB', 'NORM', 'OTHER', 'RBBB']
label_order = ['1dAVb', 'AFIB', 'AFLT', 'LBBB', 'NORM', 'OTHER', 'RBBB']

print("y_ptbxl shape:", y_ptbxl.shape)
print("First few entries in y_ptbxl:", y_ptbxl[:5])

# Find an index for each of the missing classes
index_1davb = np.where(y_ptbxl == '1dAVb')[0][0]
index_aflt = np.where(y_ptbxl == 'AFLT')[0][0]

int_label_1davb = label_order.index('1dAVb') # This will be 0
int_label_aflt = label_order.index('AFLT')   # This will be 2
# Convert the string labels for the added examples to their integer mapping
y_1davb_add = np.array([label_order.index('1dAVb')]) # This will be [0]
y_aflt_add = np.array([label_order.index('AFLT')])   # This will be [2]

x_1davb = X_ptbxl[index_1davb]
x_aflt = X_ptbxl[index_aflt]

# Ensure the new samples have shape (1, 12, 1000, 1)
x_1davb = x_1davb.reshape(12, 1000, 1)  # Reshape to (12, 1000, 1)
x_aflt = x_aflt.reshape(12, 1000, 1)
# Expand dimensions to add batch size
x_1davb = np.expand_dims(x_1davb, axis=0)  # Add batch dimension
x_aflt = np.expand_dims(x_aflt, axis=0)

# Double-check all shapes are 4D
print("x_1davb shape:", x_1davb.shape)
print("x_aflt shape:", x_aflt.shape)
print("Original X shape:", X.shape)

print(y[:5])  # Check first few labels
print("y_1davb_add:", y_1davb_add)
print(y.shape)

X = np.concatenate([X, x_1davb, x_aflt], axis=0)
y = np.concatenate([y, y_1davb_add, y_aflt_add], axis=0)

print("✅ New test set shape:", X.shape, y.shape)

y_ptbxl shape: (24156,)
First few entries in y_ptbxl: ['AFIB' 'NORM' 'OTHER' 'OTHER' 'OTHER']
x_1davb shape: (1, 12, 1000, 1)
x_aflt shape: (1, 12, 1000, 1)
Original X shape: (2289, 12, 1000, 1)
[5 5 5 5 5]
y_1davb_add: [0]
(2289,)
✅ New test set shape: (2291, 12, 1000, 1) (2291,)


In [39]:
#save the X and y arrays to .npz files
output_filename = 'mitbih_data.npz'
np.savez(output_filename, X=X, y=y, label_map=label_map)

In [61]:
import os
import numpy as np
import wfdb
from sklearn.preprocessing import MultiLabelBinarizer
from scipy.signal import resample
from collections import defaultdict

def extract_diagnosis_from_hea(file_path):
    """
    Extract diagnosis codes from a .hea file.
    
    Args:
        file_path (str): Path to the .hea file.
    
    Returns:
        list of str: List of diagnosis codes (as strings).
    """
    with open(file_path, 'r') as f:
        for line in f:
            if line.startswith('#Dx:'):
                # Get the part after "#Dx:"
                dx_line = line.strip().split(':', 1)[1]
                # Diagnosis codes can be comma-separated
                dx_codes = [code.strip() for code in dx_line.split(',')]
                return dx_codes
    return []



def prepare_ecg_arrhythmia_dataset_debug(data_dir, label_mapping, desired_labels, max_records=None):
    import os
    import wfdb
    import numpy as np
    from scipy.signal import resample
    from sklearn.preprocessing import MultiLabelBinarizer
    from collections import defaultdict
    from glob import glob

    # use glob to find all .hea files in the directory and its subdirectories
    if not os.path.exists(data_dir):
        print(f"Directory {data_dir} does not exist.")
        return None, None
    record_files = glob(data_dir + '/**/*.hea', recursive=True)
    
    # record_files = Path(data_dir).rglob('*.hea')
        


    # record_files = [f for f in os.listdir(data_dir) if f.endswith('.hea')]
    print(f"Found {len(record_files)} records in {data_dir}")
    if max_records:
        record_files = record_files[:max_records]

    X = []
    y = []
    class_counts = defaultdict(int)

    total_checked = 0

    for file in record_files:
        record_name = file[:-4]
        total_checked += 1

        try:
            record = wfdb.rdrecord(os.path.join(data_dir, record_name))
            signal = record.p_signal.T  # shape (12, time)

            if signal.shape[1] < 5000:
                print(f"Skipping {record_name} (too short: {signal.shape[1]})")
                continue

            resampled = resample(signal, 1000, axis=1)

            # Attempt to get annotations
            # annotations are found in the .hea files
            # look for '#Dx' line in the .hea file and parse it
            annotation = extract_diagnosis_from_hea(os.path.join(data_dir, file))
            # print(f"Annotations for {record_name}: {annotation}")
            
            labels = set()
            
            if annotation:
                for note in annotation:
                    
                    mapped = label_mapping.get(note)
                    if mapped:
                        labels.add(mapped)
            else:
                print(f"Missing annotation for {record_name}")

            # Fallback to OTHER if nothing found
            if not labels:
                labels = {'OTHER'}
            else:
                labels = {lbl if lbl in desired_labels else 'OTHER' for lbl in labels}
                if 'OTHER' in labels and len(labels) > 1:
                    labels.discard('OTHER')

            label = list(labels)[0]
            class_counts[label] += 1

            X.append(resampled)
            y.append(label)

        except Exception as e:
            print(f"Failed on {record_name}: {e}")

    print(f"\n🔍 Checked {total_checked} records, loaded {len(X)}")
    for k, v in class_counts.items():
        print(f"  {k}: {v}")

    if len(X) == 0:
        print("⚠️ No data loaded. Double-check annotation file format and mapping keys.")
        return None, None

    X = np.array(X)[..., np.newaxis]
    mlb = MultiLabelBinarizer(classes=desired_labels)
    y = mlb.fit_transform([[lbl] for lbl in y])

    return X, y


In [62]:
label_mapping = {
    '270492004': '1dAVb',
    '164889003': 'AFIB',
    '164890007': 'AFLT',
    '164909002': 'LBBB',
    '426783006': 'NORM',
    '59118001': 'RBBB',
    # Unmapped labels will be classified as OTHER
}

desired_labels = ['1dAVb', 'AFIB', 'AFLT', 'LBBB', 'NORM', 'OTHER', 'RBBB']

X_ecg, y_ecg = prepare_ecg_arrhythmia_dataset_debug(
    data_dir='C:\\Users\\mstew\\OneDrive\\Curtin Univeristy\\COMP6011\\Task 3\\a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0\WFDBRecords\\',
    label_mapping=label_mapping,
    desired_labels=desired_labels,
    max_records=5000
)


Found 45152 records in C:\Users\mstew\OneDrive\Curtin Univeristy\COMP6011\Task 3\a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0\WFDBRecords\
Failed on C:\Users\mstew\OneDrive\Curtin Univeristy\COMP6011\Task 3\a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0\WFDBRecords\01\019\JS01052: time data '/' does not match format '%d/%m/%Y'

🔍 Checked 5000 records, loaded 4999
  RBBB: 177
  OTHER: 2650
  AFLT: 239
  AFIB: 802
  NORM: 944
  1dAVb: 107
  LBBB: 80


In [66]:
print(X_ecg.shape)
print(y_ecg.shape)

# remove 1500 samples from OTHER class
if X_ecg is not None and y_ecg is not None:
    print("Shapes before removing OTHER:", X_ecg.shape, y_ecg.shape)
    
    # Find indices of OTHER class
    other_indices = np.where(y_ecg[:, 5] == 1)[0]  # Assuming 'OTHER' is at index 5
    if len(other_indices) > 1500:
        indices_to_remove = np.random.choice(other_indices, size=1500, replace=False)
        mask = np.ones(len(y_ecg), dtype=bool)
        mask[indices_to_remove] = False
        X_ecg = X_ecg[mask]
        y_ecg = y_ecg[mask]
    
    print("Shapes after removing OTHER:", X_ecg.shape, y_ecg.shape)

# change the 1 hot encoding of y_ecg to a single label
y_ecg_single = np.argmax(y_ecg, axis=1)  # Convert to single label per sample

print("y_ecg_single shape:", y_ecg_single.shape)
print("First few labels:", y_ecg_single[:5])

# save the X and y arrays to .npz files
output_filename = 'ecg_arrhythmia_data.npz'
np.savez(output_filename, X=X_ecg, y=y_ecg_single, label_map=desired_labels)

(3499, 12, 1000, 1)
(3499, 7)
Shapes before removing OTHER: (3499, 12, 1000, 1) (3499, 7)
Shapes after removing OTHER: (3499, 12, 1000, 1) (3499, 7)
y_ecg_single shape: (3499,)
First few labels: [6 5 2 1 4]
