# 🫀 Clinical-Grade Multimodal ECG Training Pipeline  
<span style="color:red">by Ridwan Oladipo, MD | Medical AI Specialist</span>  

Production-ready training pipeline for **12-lead ECG classification**, implementing a **ResNet-1D + tabular fusion network** with:  

- **ResNet-1D signal branch** → temporal P–QRS–T wave & rhythm morphology modeling  
- **Clinical metadata branch** → HR/HRV + age/sex + device harmonization  
- **Late fusion** → integrated ECG + tabular decision space
- **Binary cross-entropy loss** for the multi-label setting  
- **Recall-optimized callbacks** → early stopping & checkpointing to maximize **myocardial infarction sensitivity**  
- **Reproducible training** with fixed seeds & official PTB-XL stratified folds (preventing patient leakage)
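
Sensitivity for a single class such as MI, together with its negative predictive value (NPV), can be computed directly from multi-hot labels and thresholded probabilities. The following is a minimal NumPy sketch, not the pipeline's own evaluation code. It assumes MI sits at column index 1 (the class order `['NORM', 'MI', 'STTC', 'CD', 'HYP']` used in this notebook) and a 0.5 decision threshold; `sensitivity_and_npv` and the toy arrays are illustrative.

```python
import numpy as np

def sensitivity_and_npv(y_true, y_prob, class_idx, threshold=0.5):
    """Return (sensitivity, NPV) for one class of a multi-label problem."""
    truth = y_true[:, class_idx].astype(bool)
    pred = y_prob[:, class_idx] >= threshold

    tp = np.sum(truth & pred)     # class present and flagged
    fn = np.sum(truth & ~pred)    # class present but missed
    tn = np.sum(~truth & ~pred)   # class absent and cleared

    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    npv = tn / (tn + fn) if (tn + fn) else float("nan")
    return sensitivity, npv

# Toy data: 4 records, 5 classes (hypothetical values)
y_true = np.array([[0, 1, 0, 0, 0],
                   [1, 0, 0, 0, 0],
                   [0, 1, 1, 0, 0],
                   [1, 0, 0, 0, 0]])
y_prob = np.array([[0.1, 0.9, 0.2, 0.1, 0.1],
                   [0.8, 0.2, 0.1, 0.1, 0.1],
                   [0.2, 0.3, 0.7, 0.1, 0.1],
                   [0.9, 0.6, 0.1, 0.1, 0.1]])

sens, npv = sensitivity_and_npv(y_true, y_prob, class_idx=1)
print(f"MI sensitivity: {sens:.2f} | MI NPV: {npv:.2f}")
```

Lowering `threshold` trades precision for sensitivity, which is the usual lever when a missed MI is the costlier error.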

🚀 Trains on **17,441 ECGs** (PTB-XL folds 1–8) with structured logging & TensorBoard monitoring.  
>⚕️ **Clinically-aligned optimization**: tuning for **sensitivity and NPV in myocardial infarction detection**, the metrics that matter most in cardiology.

## 🧩 Environment Setup and Data Loading

In [1]:
# Essential libraries for deep learning and model training
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv1D, BatchNormalization, Activation, Add
from tensorflow.keras.layers import MaxPooling1D, GlobalAveragePooling1D, Dropout, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard

# For monitoring and evaluation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import datetime
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Load Preprocessed Data
base_dir = "/kaggle/input/ecg-preprocessed"
all_signals = np.load(f"{base_dir}/all_signals.npy", allow_pickle=True)
y_labels = np.load(f"{base_dir}/y_labels.npy", allow_pickle=True)
all_features = pd.read_parquet(f"{base_dir}/all_features.parquet")
model_df_with_labels = pd.read_parquet(f"{base_dir}/model_df_with_labels.parquet")

# Reproducibility
np.random.seed(42)
tf.random.set_seed(42)
print("Random seeds set for reproducibility")

print("=== Training Environment Initialized ===")
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

print(f"\n=== Preprocessed Data Verification ===")
print(f"Signals shape: {all_signals.shape}")
print(f"Features shape: {all_features.shape}")
print(f"Labels shape: {y_labels.shape}")
print(f"Classes: {y_labels.shape[1]}")

Random seeds set for reproducibility
=== Training Environment Initialized ===
TensorFlow version: 2.18.0
GPU available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

=== Preprocessed Data Verification ===
Signals shape: (21837, 1000, 12)
Features shape: (21837, 190)
Labels shape: (21837, 5)
Classes: 5
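
A shape printout like the one above confirms the arrays loaded, but it does not by itself guarantee that they agree record for record. The sketch below shows one way to turn that check into a hard failure. It is an illustrative helper (`check_alignment` is not part of the original notebook) and assumes the labels are multi-hot 0/1 and the signals are laid out as `(records, timesteps, leads)`.

```python
import numpy as np

def check_alignment(signals, features, labels):
    """Validate that signals, features, and labels describe the same records.

    Returns the record count, or raises ValueError on the first problem found.
    """
    n = signals.shape[0]
    if features.shape[0] != n or labels.shape[0] != n:
        raise ValueError(
            f"Record count mismatch: signals={n}, "
            f"features={features.shape[0]}, labels={labels.shape[0]}"
        )
    if signals.ndim != 3:
        raise ValueError("signals must be (records, timesteps, leads)")
    if not np.isfinite(signals).all():
        raise ValueError("signals contain NaN or infinite values")
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("labels must be multi-hot (0/1)")
    return n

# Small stand-ins with the same layout as the shapes printed above
signals = np.zeros((8, 1000, 12), dtype=np.float32)
features = np.zeros((8, 190), dtype=np.float32)
labels = np.zeros((8, 5), dtype=np.int8)

n_records = check_alignment(signals, features, labels)
print(f"{n_records} aligned records")
```

In the notebook itself the call would be `check_alignment(all_signals, all_features.to_numpy(), y_labels)`, assuming the arrays were saved in the same row order.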


## 🔀 Train/Test Split

In [2]:
# Using PTB-XL Official strat_fold for Train/Test Split
print("\n=== Using PTB-XL Official strat_fold for Train/Test Split ===")

train_idx = model_df_with_labels['strat_fold'] < 9  # folds 1–8 = train
test_idx = model_df_with_labels['strat_fold'] >= 9  # folds 9–10 = test

X_ecg_train, X_ecg_test = all_signals[train_idx], all_signals[test_idx]
X_tab_train, X_tab_test = all_features.loc[train_idx], all_features.loc[test_idx]
y_train, y_test = y_labels[train_idx], y_labels[test_idx]

print(f"✓ Training set: {len(X_ecg_train):,} samples")
print(f"✓ Test set: {len(X_ecg_test):,} samples")

# Class Distribution Verification
class_names = ['NORM', 'MI', 'STTC', 'CD', 'HYP']
train_class_dist = y_train.mean(axis=0)
test_class_dist = y_test.mean(axis=0)

print("\n=== Class Distribution Verification ===")
for i, cls in enumerate(class_names):
    diff = abs(train_class_dist[i] - test_class_dist[i])
    print(f"{cls}: Train {train_class_dist[i]:.3f} | Test {test_class_dist[i]:.3f} | Diff {diff:.3f}")


=== Using PTB-XL Official strat_fold for Train/Test Split ===
✓ Training set: 17,441 samples
✓ Test set: 4,396 samples

=== Class Distribution Verification ===
NORM: Train 0.436 | Test 0.437 | Diff 0.001
MI: Train 0.252 | Test 0.250 | Diff 0.002
STTC: Train 0.240 | Test 0.240 | Diff 0.000
CD: Train 0.224 | Test 0.226 | Diff 0.002
HYP: Train 0.122 | Test 0.121 | Diff 0.000
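
Matching class distributions show that the split is stratified, but not that it is free of patient leakage. PTB-XL assigns all recordings of a patient to the same fold, so a fold-based split should never share patients across train and test. The sketch below checks this explicitly. It assumes a per-record patient identifier is available (the PTB-XL metadata has a `patient_id` column; whether `model_df_with_labels` retains it is not shown here), and `patient_overlap` is an illustrative helper.

```python
import numpy as np

def patient_overlap(patient_ids, train_mask, test_mask):
    """Return the sorted patient IDs that appear in both splits."""
    patient_ids = np.asarray(patient_ids)
    return np.intersect1d(patient_ids[train_mask], patient_ids[test_mask])

# Toy example: hypothetical patient IDs and fold assignments
patient_ids = np.array([101, 101, 102, 103, 104, 104])
strat_fold = np.array([1, 1, 5, 9, 10, 10])

train_mask = strat_fold < 9    # folds 1-8
test_mask = strat_fold >= 9    # folds 9-10

leaked = patient_overlap(patient_ids, train_mask, test_mask)
print(f"Patients in both splits: {leaked.size}")
```

An empty result is the expected outcome; any returned ID means that patient's recordings straddle the split.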


## 🏗️ ResNet-1D Architecture for ECG Signal Processing

In [3]:
def ResNet1D_block(X, filters, kernel_size=7, stride=1):
    """
    ResNet 1D block with skip connection for ECG signal processing

    This architecture is designed for temporal sequence modeling
    in ECG signals, capturing both local and global cardiac patterns.

    Args:
        X: Input tensor
        filters: Number of convolutional filters
        kernel_size: Convolution kernel size (default 7 for ECG)
        stride: Convolution stride

    Returns:
        Output tensor with residual connection
    """
    X_shortcut = X

    # First convolutional component
    X = Conv1D(filters=filters, kernel_size=kernel_size, strides=stride, padding='same')(X)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)

    # Second convolutional component
    X = Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='same')(X)
    X = BatchNormalization()(X)

    # Skip connection handling for dimension matching
    if stride > 1 or X_shortcut.shape[-1] != filters:
        X_shortcut = Conv1D(filters=filters, kernel_size=1, strides=stride, padding='same')(X_shortcut)
        X_shortcut = BatchNormalization()(X_shortcut)

    # Add residual connection and activate
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X


def create_resnet1d_multimodal_model(ecg_shape, tab_shape, n_classes=5):
    """
    Build a ResNet-1D based multimodal model for ECG + tabular data classification

    This architecture follows cardiology-informed design principles:
    - Multi-scale temporal feature extraction for rhythm analysis
    - Hierarchical feature learning for morphological pattern recognition
    - Clinical metadata integration for comprehensive assessment
    - Multi-label output for clinical scenarios

    Args:
        ecg_shape: Shape of ECG signal input (timesteps, leads)
        tab_shape: Number of tabular features
        n_classes: Number of output diagnostic classes

    Returns:
        Compiled Keras model ready for training
    """

    # ECG SIGNAL BRANCH
    ecg_input = Input(shape=ecg_shape, name='ecg_input')

    # Initial convolutional layer - captures basic ECG waveforms
    X = Conv1D(filters=64, kernel_size=7, strides=2, padding='same')(ecg_input)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = MaxPooling1D(pool_size=3, strides=2, padding='same')(X)

    # ResNet block stack 1: Fine-grained pattern detection (P-QRS-T waves)
    X = ResNet1D_block(X, filters=64, kernel_size=5)
    X = ResNet1D_block(X, filters=64, kernel_size=5)

    # ResNet block stack 2: Intermediate pattern recognition (ST segments, intervals)
    X = ResNet1D_block(X, filters=128, kernel_size=5, stride=2)
    X = ResNet1D_block(X, filters=128, kernel_size=5)

    # ResNet block stack 3: High-level rhythm and morphology patterns
    X = ResNet1D_block(X, filters=256, kernel_size=3, stride=2)
    X = ResNet1D_block(X, filters=256, kernel_size=3)

    # Global feature aggregation
    X = GlobalAveragePooling1D()(X)
    X = Dense(128, activation='relu')(X)
    X = Dropout(0.5)(X)  # Prevent overfitting on ECG patterns
    ecg_output = Dense(64, activation='relu')(X)

    # TABULAR FEATURE BRANCH
    tab_input = Input(shape=(tab_shape,), name='tab_input')

    # Clinical metadata processing network
    Y = Dense(128, activation='relu')(tab_input)
    Y = BatchNormalization()(Y)
    Y = Dropout(0.3)(Y)

    Y = Dense(64, activation='relu')(Y)
    Y = BatchNormalization()(Y)
    Y = Dropout(0.3)(Y)

    tab_output = Dense(32, activation='relu')(Y)

    # MULTIMODAL FUSION
    # Combine ECG signal features with clinical metadata
    combined = concatenate([ecg_output, tab_output], name='multimodal_fusion')

    # Final classification layers
    Z = Dense(64, activation='relu')(combined)
    Z = BatchNormalization()(Z)
    Z = Dropout(0.3)(Z)

    Z = Dense(32, activation='relu')(Z)

    # Multi-label output with sigmoid activation for independent class probabilities
    main_output = Dense(n_classes, activation='sigmoid', name='main_output')(Z)

    # MODEL COMPILATION
    model = Model(inputs=[ecg_input, tab_input], outputs=main_output)

    # Compile with clinically-appropriate metrics
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='binary_crossentropy',  # Appropriate for multi-label classification
        metrics=[
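            # BinaryAccuracy scores each label independently; depending on the Keras
            # version, the bare 'accuracy' string may resolve to categorical accuracy
            # for a 5-unit output, which is misleading for multi-hot labels
            tf.keras.metrics.BinaryAccuracy(name='accuracy'),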
            tf.keras.metrics.AUC(multi_label=True, name='auc'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.Precision(name='precision')
        ]
    )

    return model


print("=== ResNet-1D Multimodal Architecture Defined ===")
print("✓ ECG signal processing branch: 1D ResNet with temporal modeling")
print("✓ Tabular feature branch: Dense network for clinical metadata")
print("✓ Multimodal fusion: Late fusion of ECG and clinical features")
print("✓ Multi-label output: Sigmoid activation for independent predictions")

=== ResNet-1D Multimodal Architecture Defined ===
✓ ECG signal processing branch: 1D ResNet with temporal modeling
✓ Tabular feature branch: Dense network for clinical metadata
✓ Multimodal fusion: Late fusion of ECG and clinical features
✓ Multi-label output: Sigmoid activation for independent predictions

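
Pairing a sigmoid output with binary cross-entropy treats each diagnostic label as its own yes/no problem, so one record can carry several diagnoses at once. The NumPy sketch below reproduces the loss on a single toy record to make that concrete. It illustrates the formula only and is not Keras's implementation; the clipping constant `eps` is an assumption chosen to avoid `log(0)`.

```python
import numpy as np

def multilabel_bce(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over all records and all labels."""
    p = np.clip(y_prob, eps, 1.0 - eps)  # keep log() finite
    losses = -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
    return losses.mean()

# One record with two simultaneous diagnoses (hypothetical values)
y_true = np.array([[0, 1, 1, 0, 0]])

# An uninformative model that outputs 0.5 for every label
y_uninformed = np.full((1, 5), 0.5)
# A confident, correct model
y_confident = np.array([[0.05, 0.95, 0.95, 0.05, 0.05]])

loss_uninformed = multilabel_bce(y_true, y_uninformed)  # ln(2) per label
loss_confident = multilabel_bce(y_true, y_confident)
print(f"uninformed: {loss_uninformed:.4f} | confident: {loss_confident:.4f}")
```

Because each label contributes its own term, the probabilities in a row do not need to sum to 1, unlike a softmax output trained with categorical cross-entropy.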