<a href="https://colab.research.google.com/github/kiril-buga/Neural-Network-Training-Project/blob/main/ECG_MultiLabel_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ECG Multi-Label Model Training (Self-Contained)
Self-contained notebook: loads the preprocessed ECG windows from HDF5, trains a multi-label classifier, and saves the model, predictions, and a results summary.

In [3]:
!pip install tensorflow scikit-learn matplotlib seaborn huggingface-hub h5py -q

import json
import os
from datetime import datetime

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import hamming_loss, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers

# Reproducibility: seed NumPy (used for the per-epoch batch shuffling of the
# streaming training data) and TensorFlow (weight init, dropout, etc.).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Paths used by every later cell; they must exist before any data cell runs so
# that "Restart & Run All" works on a fresh kernel.
# The DATA_DIR default matches the path shown in this notebook's recorded output.
# Override either location with the ECG_DATA_DIR / ECG_RESULTS_DIR env vars.
DATA_DIR = os.environ.get("ECG_DATA_DIR", "/content/ECG-database/multilabel_v2")
# TODO(review): confirm the preferred results location.
RESULTS_DIR = os.environ.get("ECG_RESULTS_DIR", os.path.join(DATA_DIR, "results"))
os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"TensorFlow: {tf.__version__}")
print(f"GPU: {tf.config.list_physical_devices('GPU')}")

TensorFlow: 2.19.0
GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## Environment & Data Loading

In [None]:
# Streaming data path: read only the dataset metadata from HDF5 up front and
# build tf.data pipelines that pull batches from disk on demand, so the full
# array never has to fit in memory.
h5_path = os.path.join(DATA_DIR, "ecg_data.h5")


def _as_str_list(values):
    """Return HDF5 attribute values as a list of plain Python ``str``.

    h5py can hand back string attributes as ``bytes`` (e.g. fixed-length
    string arrays); decoding keeps class names readable in printouts and
    JSON-serialisable when results are saved.
    """
    return [v.decode("utf-8") if isinstance(v, bytes) else str(v) for v in values]


print(f"Opening HDF5 file: {h5_path}")
with h5py.File(h5_path, 'r') as h5f:
    n_samples = h5f['X'].shape[0]
    input_shape = h5f['X'].shape[1:]
    DISEASE_CLASSES = _as_str_list(h5f.attrs['disease_classes'])
    data_format = h5f.attrs.get('data_format', 'unknown')

print(f"✓ Dataset info:")
print(f"  Total windows: {n_samples}")
print(f"  Input shape: {input_shape}")
print(f"  Data format: {data_format}")
print(f"  Classes: {DISEASE_CLASSES}")

# Random window-level split over row indices only (no data is loaded here).
all_idx = np.arange(n_samples)
train_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.25, random_state=42)

print(f"\nTrain: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")


def hdf5_generator(h5_path, indices, batch_size=32, shuffle=True):
    """Stream ``(X, y)`` batches from an HDF5 file without loading it all.

    Args:
        h5_path: Path to an HDF5 file with datasets ``'X'`` and ``'y'``.
        indices: Row indices (first axis) to draw from; assumed unique.
        batch_size: Maximum rows per batch; the last batch may be smaller.
        shuffle: If True, visit ``indices`` in a fresh random order on every
            call; if False, keep the given order.

    Yields:
        ``(X_batch, y_batch)`` as ``float32`` / ``int32`` NumPy arrays.
    """
    order = np.random.permutation(indices) if shuffle else np.asarray(indices)
    with h5py.File(h5_path, 'r') as h5f:
        X_dset = h5f['X']
        y_dset = h5f['y']
        for start in range(0, len(order), batch_size):
            # h5py fancy indexing only accepts strictly increasing indices, so
            # shuffled (or train_test_split-ordered) slices must be sorted.
            # X and y are read with the same sorted indices, so they stay
            # aligned; which rows form the batch is still random.
            batch_idx = np.sort(order[start:start + batch_size])
            # Cast explicitly so batches always match the TensorSpec dtypes.
            yield (X_dset[batch_idx].astype(np.float32, copy=False),
                   y_dset[batch_idx].astype(np.int32, copy=False))


def create_dataset(h5_path, indices, batch_size=32, is_training=True):
    """Wrap :func:`hdf5_generator` in a prefetching ``tf.data.Dataset``.

    The generator is re-created each time the dataset is iterated, so training
    data (``is_training=True``) gets a new shuffle per pass. Val/test data
    keep a fixed order.
    """
    dataset = tf.data.Dataset.from_generator(
        lambda: hdf5_generator(h5_path, indices, batch_size, shuffle=is_training),
        output_signature=(
            tf.TensorSpec(shape=(None, *input_shape), dtype=tf.float32),
            tf.TensorSpec(shape=(None, len(DISEASE_CLASSES)), dtype=tf.int32),
        ),
    )
    return dataset.prefetch(tf.data.AUTOTUNE)


print("\nCreating streaming datasets...")
BATCH_SIZE = 32
train_dataset = create_dataset(h5_path, train_idx, batch_size=BATCH_SIZE, is_training=True)
val_dataset = create_dataset(h5_path, val_idx, batch_size=BATCH_SIZE, is_training=False)
test_dataset = create_dataset(h5_path, test_idx, batch_size=BATCH_SIZE, is_training=False)

print(f"✓ Datasets created (batch_size={BATCH_SIZE})")

In [None]:
# Load the full preprocessed dataset into memory. X and y are used by the split
# cell and the in-memory training/evaluation cells below.
# NOTE(review): this materialises every window at once (about
# n_samples * prod(input_shape) * 4 bytes once cast to float32) — make sure
# the runtime has enough RAM.
h5_path = os.path.join(DATA_DIR, "ecg_data.h5")

print(f"Loading data from: {h5_path}")
with h5py.File(h5_path, 'r') as h5f:
    X = h5f['X'][:]  # Load all data
    y = h5f['y'][:]
    # h5py may return string attributes as bytes; decode so class names print
    # cleanly and can be written to the results JSON later.
    DISEASE_CLASSES = [
        c.decode("utf-8") if isinstance(c, bytes) else str(c)
        for c in h5f.attrs['disease_classes']
    ]
    data_format = h5f.attrs.get('data_format', 'unknown')

# Convert float16 to float32 for TensorFlow if needed
if X.dtype == np.float16:
    print(f"Converting float16 → float32 for TensorFlow...")
    X = X.astype(np.float32)

print(f"✓ Data loaded successfully (format: {data_format})")
print(f"  X: {X.shape} (dtype: {X.dtype})")
print(f"  y: {y.shape}")
print(f"  Classes: {DISEASE_CLASSES}")

Loading data from: /content/ECG-database/multilabel_v2/ecg_data.h5


# Sanity-check the streaming pipeline: pull one training batch and report its
# shapes and how many positive labels each class has in it.
print("Verifying datasets...")
for batch_X, batch_y in train_dataset.take(1):
    print(f"  Sample batch - X shape: {batch_X.shape}, y shape: {batch_y.shape}")
    # tf.Tensor has no NumPy-style `.sum()` method (unless TF's experimental
    # NumPy behaviour is enabled), so reduce with TensorFlow instead.
    class_counts = tf.reduce_sum(batch_y, axis=0).numpy()
    for cls, count in zip(DISEASE_CLASSES, class_counts):
        print(f"    {cls}: {int(count)} in batch")

In [None]:
# Split windows into train/val/test. Prefer a patient-level split so windows
# from one patient never land in more than one subset; a random window-level
# split lets the model see test patients during training and inflates metrics.
# Falls back to a random split when no usable patient mapping exists.
n_windows_total = len(X)
qc_file = os.path.join(DATA_DIR, "qc_summary.csv")
window_patients = None

if os.path.exists(qc_file):
    try:
        df_qc = pd.read_csv(qc_file)
        # Patient ID = second '/'-separated component of 'Filename'.
        # NOTE(review): assumes paths look like '<root>/<patient>/...' — confirm.
        df_qc['Patient_ID'] = df_qc['Filename'].str.split('/').str[1]
        if df_qc['Patient_ID'].isna().any():
            raise ValueError("some 'Filename' entries have no patient component")
        # One patient label per window: each file's ID repeated n_windows times,
        # in file order. NOTE(review): assumes the HDF5 rows were written in the
        # same file order as qc_summary.csv — TODO confirm with preprocessing.
        windows_per_file = df_qc['n_windows'].fillna(0).astype(int).to_numpy()
        window_patients = np.repeat(df_qc['Patient_ID'].to_numpy(), windows_per_file)
    except (KeyError, ValueError) as err:  # pandas parse errors subclass ValueError
        print(f"Could not build patient mapping from {qc_file}: {err!r}")
        window_patients = None

if window_patients is not None and len(window_patients) != n_windows_total:
    print(f"Patient mapping covers {len(window_patients)} windows but the data has "
          f"{n_windows_total}; ignoring it")
    window_patients = None

if window_patients is not None:
    split_method = "patient-level"
    unique_patients = np.unique(window_patients)
    train_pat, test_pat = train_test_split(unique_patients, test_size=0.2, random_state=42)
    train_pat, val_pat = train_test_split(train_pat, test_size=0.25, random_state=42)

    train_idx = np.flatnonzero(np.isin(window_patients, train_pat))
    val_idx = np.flatnonzero(np.isin(window_patients, val_pat))
    test_idx = np.flatnonzero(np.isin(window_patients, test_pat))
else:
    split_method = "random"
    print("No patient mapping available, using random split")
    all_idx = np.arange(n_windows_total)
    train_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.25, random_state=42)

X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]
X_test, y_test = X[test_idx], y[test_idx]

# The streaming datasets built in the data-loading cell used their own random
# split. Rebuild them from these indices so the streaming and in-memory paths
# train and evaluate on the same windows.
train_dataset = create_dataset(h5_path, train_idx, batch_size=BATCH_SIZE, is_training=True)
val_dataset = create_dataset(h5_path, val_idx, batch_size=BATCH_SIZE, is_training=False)
test_dataset = create_dataset(h5_path, test_idx, batch_size=BATCH_SIZE, is_training=False)

print(f"Split: {split_method}")
print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
print(f"\nClass distribution (test):")
for i, cls in enumerate(DISEASE_CLASSES):
    count = y_test[:, i].sum()
    print(f"  {cls}: {count} ({100*count/len(y_test):.1f}%)")

## Build & Train Model

In [None]:
# Train on the streaming tf.data pipeline.
# NOTE(review): no cell in this notebook builds `model`. A compiled multi-label
# Keras model with a 'binary_accuracy' metric (read by the plotting cell) must
# exist before this cell runs.
if "model" not in globals():
    raise NameError(
        "`model` is not defined: build and compile a multi-label Keras model "
        "(with metrics=['binary_accuracy']) before running this cell."
    )

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs(RESULTS_DIR, exist_ok=True)
model_path = os.path.join(RESULTS_DIR, f"model_{timestamp}.keras")

# All callbacks track the same quantity (val_loss), so the checkpoint written
# to disk and the weights EarlyStopping restores are chosen by one criterion.
# Monitoring val_binary_accuracy for the checkpoint alone could save a
# different epoch than the one that gets evaluated.
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=10,
                                  restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.5,
                                      patience=5, min_lr=1e-6),
    keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss', mode='min',
                                    save_best_only=True),
]

print("Training with streaming data...")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    callbacks=callbacks,
    verbose=1
)
print(f"✓ Model saved: {model_path}")

In [None]:
# In-memory training path: fits directly on the X_train / X_val arrays.
# NOTE(review): if the streaming training cell above has already run, this
# continues training that same `model` object and overwrites `history`,
# `timestamp` and `model_path`. Rebuild/recompile the model first if the two
# runs are meant to be independent.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs(RESULTS_DIR, exist_ok=True)
model_path = os.path.join(RESULTS_DIR, f"model_{timestamp}.keras")

# Every callback monitors val_loss, so the saved checkpoint and the weights
# restored by EarlyStopping are selected by the same criterion.
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=10,
                                  restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.5,
                                      patience=5, min_lr=1e-6),
    keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss', mode='min',
                                    save_best_only=True),
]

print("Training...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=BATCH_SIZE,  # same batch size as the streaming pipeline
    callbacks=callbacks,
    verbose=1
)
print(f"✓ Model saved: {model_path}")

# Evaluate on the streaming test set. Labels are collected from the same
# batches the model predicts on, so y_test and y_pred stay row-aligned.
THRESHOLD = 0.5  # per-class decision threshold on the predicted probabilities
print("Evaluating on test set...")
y_test_list = []
y_probs_list = []

for X_batch, y_batch in test_dataset:
    y_test_list.append(y_batch.numpy())
    # predict_on_batch runs one forward pass. Calling model.predict() once per
    # batch sets up a full prediction loop on every call, which adds overhead.
    y_probs_list.append(np.asarray(model.predict_on_batch(X_batch)))

if not y_test_list:
    raise ValueError("test_dataset yielded no batches; check test_idx and the HDF5 file")

# Concatenate all batches
y_test = np.concatenate(y_test_list, axis=0)
y_pred_probs = np.concatenate(y_probs_list, axis=0)
y_pred = (y_pred_probs >= THRESHOLD).astype(int)

print("\n" + "="*70)
print("TEST SET METRICS")
print("="*70)

print(f"\nHamming Loss: {hamming_loss(y_test, y_pred):.4f}")
print(f"Exact Match Accuracy: {np.mean(np.all(y_test == y_pred, axis=1)):.4f}")

print(f"\nPer-Class Metrics:")
print(f"{'Class':<20} {'Precision':>12} {'Recall':>12} {'F1-Score':>12}")
print("-" * 70)

for i, cls in enumerate(DISEASE_CLASSES):
    p = precision_score(y_test[:, i], y_pred[:, i], zero_division=0)
    r = recall_score(y_test[:, i], y_pred[:, i], zero_division=0)
    f = f1_score(y_test[:, i], y_pred[:, i], zero_division=0)
    print(f"{cls:<20} {p:>12.4f} {r:>12.4f} {f:>12.4f}")

In [None]:
# Evaluate on the in-memory test split (X_test from the split cell).
# y_test is re-derived from test_idx because the streaming evaluation cell
# reassigns y_test to labels collected from test_dataset. Those are not
# guaranteed to match X_test row-for-row.
y_test = y[test_idx]
THRESHOLD = 0.5  # per-class decision threshold on the predicted probabilities
y_pred_probs = model.predict(X_test, verbose=0)
y_pred = (y_pred_probs >= THRESHOLD).astype(int)
assert y_pred.shape == y_test.shape, (y_pred.shape, y_test.shape)

print("\n" + "="*70)
print("TEST SET METRICS")
print("="*70)

print(f"\nHamming Loss: {hamming_loss(y_test, y_pred):.4f}")
print(f"Exact Match Accuracy: {np.mean(np.all(y_test == y_pred, axis=1)):.4f}")

print(f"\nPer-Class Metrics:")
print(f"{'Class':<20} {'Precision':>12} {'Recall':>12} {'F1-Score':>12}")
print("-" * 70)

for i, cls in enumerate(DISEASE_CLASSES):
    p = precision_score(y_test[:, i], y_pred[:, i], zero_division=0)
    r = recall_score(y_test[:, i], y_pred[:, i], zero_division=0)
    f = f1_score(y_test[:, i], y_pred[:, i], zero_division=0)
    print(f"{cls:<20} {p:>12.4f} {r:>12.4f} {f:>12.4f}")

## Visualization

In [None]:
# Training curves: per-epoch loss and binary accuracy, train vs. validation,
# saved alongside the other artefacts for this run.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

curve_panels = (
    (axes[0], 'loss', 'Loss'),
    (axes[1], 'binary_accuracy', 'Accuracy'),
)
for ax, metric_key, label in curve_panels:
    ax.plot(history.history[metric_key], label='Train', linewidth=2)
    ax.plot(history.history['val_' + metric_key], label='Val', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(label)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plot_file = os.path.join(RESULTS_DIR, f'training_{timestamp}.png')
fig.savefig(plot_file, dpi=150, bbox_inches='tight')
print(f"✓ Plot saved")
plt.show()

## Save Results

In [None]:
# Save predictions
np.savez(
    os.path.join(RESULTS_DIR, f"predictions_{timestamp}.npz"),
    y_true=y_test,
    y_pred=y_pred,
    y_probs=y_pred_probs
)

# Save results JSON
results = {
    'timestamp': timestamp,
    'disease_classes': DISEASE_CLASSES,
    'epochs_trained': len(history.history['loss']),
    'train_samples': X_train.shape[0],
    'val_samples': X_val.shape[0],
    'test_samples': X_test.shape[0],
}

with open(os.path.join(RESULTS_DIR, f"results_{timestamp}.json"), "w") as f:
    json.dump(results, f, indent=2)

print(f"✓ Results saved to {RESULTS_DIR}")