# ECG Multi-Label Model Training (Self-Contained)
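Self-contained notebook: loads the preprocessed ECG windows and multi-label targets (from a local directory or Hugging Face), trains a 1D-CNN multi-label classifier, evaluates it on a held-out test set, and saves the model, predictions, plots and a results summary.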

In [None]:
!pip install tensorflow scikit-learn matplotlib seaborn huggingface-hub -q

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.metrics import hamming_loss, precision_score, recall_score, f1_score

# Fix the global NumPy and TensorFlow random seeds so runs are repeatable
# (GPU kernels may still introduce some nondeterminism).
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

print(f"TensorFlow: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU: {gpu_devices}")

## Environment & Data Loading

In [None]:
# Detect whether we are running inside Google Colab.
# Only a missing `google.colab` module means "not Colab". Any other failure,
# e.g. Drive refusing to mount, is allowed to surface. Otherwise the Drive
# paths below would silently point at the VM's non-persistent local disk.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True
    drive.mount('/content/drive/')

# Set paths: Drive folders on Colab, a sibling checkout otherwise.
if IN_COLAB:
    DATA_DIR = "/content/drive/MyDrive/DeepLearningECG/artifacts/multilabel_v2"
    RESULTS_DIR = "/content/drive/MyDrive/DeepLearningECG/results/"
else:
    DATA_DIR = "../DeepLearningECG/artifacts/multilabel_v2"
    RESULTS_DIR = "../DeepLearningECG/results/"

os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"Colab: {IN_COLAB}")
print(f"DATA_DIR: {DATA_DIR}")
print(f"RESULTS_DIR: {RESULTS_DIR}")

In [None]:
# Load the preprocessed windows, one-hot labels and class metadata.
# Use the local copy only when all three files are present. Otherwise fetch
# them from the Hugging Face dataset repo, so an existing but incomplete
# DATA_DIR does not crash the notebook.
HF_REPO_ID = "kiril-buga/ECG-database"
HF_SUBDIR = "multilabel_v2"
DATA_FILES = ("X_windows.npy", "y_labels_onehot.npy", "disease_classes.json")

local_paths = [os.path.join(DATA_DIR, name) for name in DATA_FILES]
if all(os.path.isfile(path) for path in local_paths):
    print(f"Loading from local: {DATA_DIR}")
    X_path, y_path, class_path = local_paths
else:
    print("Downloading from Hugging Face...")
    # Imported lazily: only needed when the local copy is unavailable.
    from huggingface_hub import hf_hub_download

    X_path, y_path, class_path = (
        hf_hub_download(HF_REPO_ID, f"{HF_SUBDIR}/{name}", repo_type="dataset")
        for name in DATA_FILES
    )

X = np.load(X_path)
y = np.load(y_path)
with open(class_path, encoding="utf-8") as f:
    class_info = json.load(f)

DISEASE_CLASSES = class_info["classes"]

# Sanity checks. Later cells index X and y with the same index arrays and
# label y's columns with DISEASE_CLASSES, so all three must line up.
assert len(X) == len(y), f"X has {len(X)} windows but y has {len(y)} label rows"
assert y.ndim == 2 and y.shape[1] == len(DISEASE_CLASSES), (
    f"y has shape {y.shape}; expected (n, {len(DISEASE_CLASSES)}) to match DISEASE_CLASSES"
)

print(f"✓ Loaded data")
print(f"  X: {X.shape}")
print(f"  y: {y.shape}")
print(f"  Classes: {DISEASE_CLASSES}")

## Data Preparation & Splitting

In [None]:
# Split windows into train / val / test (60 / 20 / 20).
# A patient-level split keeps every window of a patient in a single subset,
# so the model is never scored on a patient it was trained on. It needs
# qc_summary.csv next to the data. Without a usable mapping we fall back to a
# window-level random split and print why, instead of failing silently.
SPLIT_SEED = 42


def _patient_level_split(qc_path, n_windows_expected, seed=SPLIT_SEED):
    """Return ``(train_idx, val_idx, test_idx)`` window indices split by patient.

    Assumes ``X`` stores windows file by file in the same row order as
    ``qc_summary.csv``, with ``n_windows`` windows per file. TODO confirm
    against the preprocessing notebook.

    Args:
        qc_path: Path to ``qc_summary.csv`` (needs ``Filename`` and ``n_windows``).
        n_windows_expected: Number of windows in ``X``; the mapping must match it.
        seed: ``random_state`` for both ``train_test_split`` calls.

    Raises:
        KeyError: a required column is missing.
        ValueError: the CSV cannot be parsed, the mapping does not cover exactly
            ``n_windows_expected`` windows, a patient ID cannot be derived, or
            there are too few patients to split.
    """
    df_qc = pd.read_csv(qc_path)
    # Patient ID = second '/'-separated component of the file path.
    patient_ids = df_qc['Filename'].astype(str).str.split('/').str[1].to_numpy()
    # Missing window counts contribute no windows (same as the original loop).
    windows_per_file = df_qc['n_windows'].fillna(0).astype(int).to_numpy()
    window_patients = np.repeat(patient_ids, windows_per_file)

    if len(window_patients) != n_windows_expected:
        raise ValueError(
            f"qc_summary maps {len(window_patients)} windows but X has {n_windows_expected}"
        )
    if pd.isna(window_patients).any():
        raise ValueError("some filenames have no patient component")

    unique_patients = np.unique(window_patients)
    train_pat, test_pat = train_test_split(unique_patients, test_size=0.2, random_state=seed)
    train_pat, val_pat = train_test_split(train_pat, test_size=0.25, random_state=seed)
    return tuple(
        np.flatnonzero(np.isin(window_patients, patients))
        for patients in (train_pat, val_pat, test_pat)
    )


qc_file = os.path.join(DATA_DIR, "qc_summary.csv")
split_indices = None
if os.path.isfile(qc_file):
    try:
        split_indices = _patient_level_split(qc_file, len(X))
    except (KeyError, ValueError) as exc:
        print(f"Patient mapping unusable: {exc}")
else:
    print(f"Patient mapping not found: {qc_file}")

if split_indices is not None:
    train_idx, val_idx, test_idx = split_indices
    print("Using patient-level split")
else:
    print("No patient mapping available, using random split")
    all_idx = np.arange(len(X))
    train_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=SPLIT_SEED)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.25, random_state=SPLIT_SEED)

# Materialise the three subsets from the index arrays chosen above.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    (X[idx], y[idx]) for idx in (train_idx, val_idx, test_idx)
)

print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

# Positive count and prevalence of each label in the test subset.
print(f"\nClass distribution (test):")
n_test = len(y_test)
for col, cls in enumerate(DISEASE_CLASSES):
    count = y_test[:, col].sum()
    print(f"  {cls}: {count} ({100*count/n_test:.1f}%)")

## Build & Train Model

In [None]:
def build_model(input_shape, num_classes, *, conv_filters=(32, 64, 128, 256),
                kernel_size=3, conv_dropout=0.2, dense_units=(256, 128),
                dense_dropout=0.3):
    """Build a 1D-CNN for multi-label classification of ECG windows.

    With the default hyperparameters this reproduces the original architecture:
    - one conv stage per entry of ``conv_filters``, each made of two
      same-padded ReLU Conv1D layers, MaxPooling1D(2) and Dropout;
    - global average pooling;
    - one ReLU Dense + Dropout per entry of ``dense_units``;
    - a sigmoid output layer, so each label gets an independent probability.

    Args:
        input_shape: Per-sample input shape for Conv1D, ``(timesteps, channels)``
            (no batch dimension).
        num_classes: Number of labels; width of the sigmoid output layer.
        conv_filters: Filter count for each conv stage.
        kernel_size: Kernel width used by every Conv1D layer.
        conv_dropout: Dropout rate after each conv stage.
        dense_units: Width of each hidden Dense layer.
        dense_dropout: Dropout rate after each hidden Dense layer.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    stack = [layers.Input(shape=input_shape)]
    for filters in conv_filters:
        stack += [
            layers.Conv1D(filters, kernel_size, padding='same', activation='relu'),
            layers.Conv1D(filters, kernel_size, padding='same', activation='relu'),
            layers.MaxPooling1D(2),
            layers.Dropout(conv_dropout),
        ]
    stack.append(layers.GlobalAveragePooling1D())
    for units in dense_units:
        stack += [layers.Dense(units, activation='relu'), layers.Dropout(dense_dropout)]
    # Sigmoid rather than softmax: labels are not mutually exclusive.
    stack.append(layers.Dense(num_classes, activation='sigmoid'))
    return keras.Sequential(stack)

# Instantiate the network for this dataset: per-window (timesteps, channels)
# input and one sigmoid output per disease class.
n_timesteps, n_channels = X_train.shape[1], X_train.shape[2]
model = build_model((n_timesteps, n_channels), len(DISEASE_CLASSES))

# Each output is an independent Bernoulli label, hence binary cross-entropy.
model.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['binary_accuracy'],
)

print("Model:")
model.summary()

In [None]:
# Train with early stopping on validation loss. The checkpoint monitors the
# same quantity, and the best checkpoint is reloaded after fit. The model
# evaluated below is therefore exactly the artifact written to disk.
EPOCHS = 100
BATCH_SIZE = 32

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = os.path.join(RESULTS_DIR, f"model_{timestamp}.keras")

callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6),
    # Must track the same metric as EarlyStopping. Selecting by binary accuracy
    # would save a different epoch than the one whose weights get restored.
    keras.callbacks.ModelCheckpoint(model_path, monitor='val_loss', mode='min', save_best_only=True),
]

print("Training...")
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
)

# Depending on the Keras version, restore_best_weights may only take effect
# when early stopping actually fires. Reloading the best checkpoint makes the
# in-memory model match the saved file unconditionally.
model = keras.models.load_model(model_path)
best_epoch = int(np.argmin(history.history['val_loss'])) + 1
print(f"✓ Model saved: {model_path} (best epoch {best_epoch} by val_loss)")

## Evaluation

In [None]:
# Threshold the per-label sigmoid probabilities into hard 0/1 predictions and
# report multi-label metrics on the held-out test set.
# 0.5 is a neutral default; tuning per-class thresholds on the validation set
# may help rare labels.
THRESHOLD = 0.5

y_pred_probs = model.predict(X_test, verbose=0)
y_pred = (y_pred_probs >= THRESHOLD).astype(int)

print("\n" + "="*70)
print("TEST SET METRICS")
print("="*70)

# Hamming loss: fraction of individual label decisions that are wrong.
# Exact match: fraction of windows whose whole label vector is correct.
print(f"\nHamming Loss: {hamming_loss(y_test, y_pred):.4f}")
print(f"Exact Match Accuracy: {np.mean(np.all(y_test == y_pred, axis=1)):.4f}")
# Micro F1 pools all label decisions, so frequent classes dominate it.
# Macro F1 averages per-class F1 with each class weighted equally.
print(f"Micro F1: {f1_score(y_test, y_pred, average='micro', zero_division=0):.4f}")
print(f"Macro F1: {f1_score(y_test, y_pred, average='macro', zero_division=0):.4f}")
print(f"(decision threshold: {THRESHOLD})")

print(f"\nPer-Class Metrics:")
print(f"{'Class':<20} {'Precision':>12} {'Recall':>12} {'F1-Score':>12}")
print("-" * 70)

# zero_division=0 reports 0 instead of warning when a class is never
# predicted (precision) or never present (recall).
for i, cls in enumerate(DISEASE_CLASSES):
    p = precision_score(y_test[:, i], y_pred[:, i], zero_division=0)
    r = recall_score(y_test[:, i], y_pred[:, i], zero_division=0)
    f = f1_score(y_test[:, i], y_pred[:, i], zero_division=0)
    print(f"{cls:<20} {p:>12.4f} {r:>12.4f} {f:>12.4f}")

## Visualization

In [None]:
# Learning curves: training vs. validation for the loss and for the
# per-label binary accuracy, one panel each.
curves = [
    ('loss', 'Loss'),
    ('binary_accuracy', 'Accuracy'),
]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, (metric_key, label) in zip(axes, curves):
    ax.plot(history.history[metric_key], label='Train', linewidth=2)
    ax.plot(history.history[f'val_{metric_key}'], label='Val', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
figure_path = os.path.join(RESULTS_DIR, f'training_{timestamp}.png')
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
print(f"✓ Plot saved")
plt.show()

## Save Results

In [None]:
# Persist raw test predictions for later analysis (e.g. threshold tuning).
np.savez(
    os.path.join(RESULTS_DIR, f"predictions_{timestamp}.npz"),
    y_true=y_test,
    y_pred=y_pred,
    y_probs=y_pred_probs
)

# Store the test metrics with the run metadata so runs can be compared from
# the JSON alone. Values are cast to built-in float/int for JSON serialization.
per_class = {
    cls: {
        'precision': float(precision_score(y_test[:, i], y_pred[:, i], zero_division=0)),
        'recall': float(recall_score(y_test[:, i], y_pred[:, i], zero_division=0)),
        'f1': float(f1_score(y_test[:, i], y_pred[:, i], zero_division=0)),
        'support': int(y_test[:, i].sum()),
    }
    for i, cls in enumerate(DISEASE_CLASSES)
}

results = {
    'timestamp': timestamp,
    'disease_classes': DISEASE_CLASSES,
    'epochs_trained': len(history.history['loss']),
    'best_val_loss': float(min(history.history['val_loss'])),
    'train_samples': int(X_train.shape[0]),
    'val_samples': int(X_val.shape[0]),
    'test_samples': int(X_test.shape[0]),
    'test_metrics': {
        'hamming_loss': float(hamming_loss(y_test, y_pred)),
        'exact_match_accuracy': float(np.mean(np.all(y_test == y_pred, axis=1))),
        'micro_f1': float(f1_score(y_test, y_pred, average='micro', zero_division=0)),
        'macro_f1': float(f1_score(y_test, y_pred, average='macro', zero_division=0)),
        'per_class': per_class,
    },
}

with open(os.path.join(RESULTS_DIR, f"results_{timestamp}.json"), "w") as f:
    json.dump(results, f, indent=2)

print(f"✓ Results saved to {RESULTS_DIR}")