# End-to-End Deep Learning Framework for Automated ECG Image Diagnosis and Clinical Report Generation

In [None]:
# 1. Setup & Imports

# Install required packages with correct names and dependencies
!pip install timm scikit-learn jinja2 pandas opencv-python albumentations --quiet
!pip install grad-cam --quiet  # Correct package name for pytorch-grad-cam
!apt-get update && apt-get install -y tesseract-ocr --quiet  # Install Tesseract OCR
!pip install pytesseract --quiet  # Now install Python wrapper

import os, random, re, json, gc, time, threading, warnings
from pathlib import Path
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
import timm

import cv2
import pytesseract
from jinja2 import Template
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.utils.class_weight import compute_class_weight

import albumentations as A
from albumentations.pytorch import ToTensorV2
from concurrent.futures import ThreadPoolExecutor, as_completed

import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split

# Import grad-cam with correct syntax
try:
    from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image
    print("✓ GradCAM imported successfully")
except ImportError as e:
    print(f"⚠️ GradCAM import failed: {e}")
    print("Installing grad-cam...")
    !pip install grad-cam
    from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image

# Silence library warnings for a cleaner log
warnings.filterwarnings('ignore')

# Use the first matplotlib style this installation provides; newer matplotlib
# releases only ship the seaborn look under the 'seaborn-v0_8' name.
for _style_name in ('seaborn-v0_8', 'seaborn'):
    try:
        plt.style.use(_style_name)
        break
    except OSError:
        pass
else:
    plt.style.use('default')
    print("⚠️ Using default matplotlib style")

# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"✓ Using device: {device}")

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning, which
    can otherwise select different algorithms from run to run.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(42)
print("✓ Random seeds set for reproducibility")

# Verify critical imports
# Print library versions so the results can be tied to a specific environment.
print("\n=== Import Verification ===")
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ Timm version: {timm.__version__}")
print(f"✓ OpenCV version: {cv2.__version__}")
print(f"✓ NumPy version: {np.__version__}")
print(f"✓ Pandas version: {pd.__version__}")

# Test pytesseract
# get_tesseract_version() needs the tesseract binary (installed via apt-get above);
# report a failure instead of stopping the notebook.
try:
    pytesseract.get_tesseract_version()
    print("✓ Tesseract OCR is working")
except Exception as e:
    print(f"⚠️ Tesseract issue: {e}")

print("=== Setup Complete ===\n")

0% [Working]            Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Reading package lists...
Building dependenc

In [None]:
# 2. Mount Google Drive
# Colab-only: exposes Google Drive under /content/drive, where the dataset folder
# used by the next cell ('/content/drive/My Drive/Dataset') lives.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# 3. Preprocessing the dataset
#
# Builds a per-class 70/15/15 train/val/test split from the raw class folders under
# BASE_DIR and copies each image to <TRAIN|VAL|TEST>_DIR/<class folder>/ renamed to
# "<PREFIX>(<n>)<suffix>", e.g. "MI(12).jpg". ECGImageDataset later parses the label
# from that prefix. The raw source folders are only read, never modified.

# Define base directories
BASE_DIR = Path('/content/drive/My Drive/Dataset')
REPORTS_DIR = BASE_DIR / 'generated_reports'
MODELS_DIR = BASE_DIR / 'saved_models'
TRAIN_DIR = BASE_DIR / 'train'
VAL_DIR = BASE_DIR / 'val'
TEST_DIR = BASE_DIR / 'test'

# Create output directories (BASE_DIR itself must already exist on the mounted Drive)
for _output_dir in (REPORTS_DIR, MODELS_DIR, TRAIN_DIR, VAL_DIR, TEST_DIR):
    _output_dir.mkdir(exist_ok=True)

# Raw class folder -> filename prefix used for the copied images
SOURCE_DIRS_MAP = {
    'Abnormal Heartbeat Patients': 'HB',
    'Myocardial Infarction Patients': 'MI',
    'Normal Person': 'Normal',
    'Patient that have History of Myocardial Infraction': 'PMI'
}

# Only these file types are treated as ECG images (desktop.ini, .DS_Store, ... are ignored)
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}

# Split ratios
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
# Two successive train_test_split calls need >= 4 images to leave every split non-empty
MIN_IMAGES_TO_SPLIT = 4


def copy_and_rename(image_list, dest_dir, prefix_label):
    """Copy ``image_list`` into ``dest_dir`` as "<prefix_label>(1)<ext>", "<prefix_label>(2)<ext>", ...

    The extension is lower-cased so that e.g. ".JPG" files are found by "*.jpg" globs.
    """
    for i, img_path in enumerate(image_list, start=1):
        shutil.copy(img_path, dest_dir / f"{prefix_label}({i}){img_path.suffix.lower()}")


def _recreate_dir(directory):
    """Empty a generated split directory so a re-run cannot leave stale images behind."""
    if directory.exists():
        shutil.rmtree(directory)
    directory.mkdir(parents=True)


print("Starting dataset preparation...")

# --- Step 1: Collect image paths ---
all_image_paths = []   # (path, original folder name)
class_to_prefix = {}   # original folder name -> filename prefix

for src_folder_name, prefix in SOURCE_DIRS_MAP.items():
    current_src_dir = BASE_DIR / src_folder_name
    if not current_src_dir.exists():
        print(f"Warning: Source directory '{current_src_dir}' does not exist. Skipping.")
        continue

    print(f"Processing source directory: {current_src_dir}")
    class_to_prefix[src_folder_name] = prefix

    # sorted(): iterdir() order depends on the filesystem, and train_test_split's
    # random_state only makes the split reproducible for a fixed input order.
    for img_file in sorted(current_src_dir.iterdir()):
        if not img_file.is_file() or img_file.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        # "... - Copy" files duplicate another image. Keep them out of the splits (a
        # duplicate in train and its original in test would leak) but never delete
        # anything from the raw dataset.
        if "Copy" in img_file.name:
            print(f"Skipping duplicate file: {img_file}")
            continue
        all_image_paths.append((img_file, src_folder_name))

if not all_image_paths:
    print("No images found in the specified source directories. Exiting.")
else:
    images_by_class = {src_folder_name: [] for src_folder_name in SOURCE_DIRS_MAP}
    for img_path, src_folder_name in all_image_paths:
        images_by_class[src_folder_name].append(img_path)

    # --- Step 2: Stratified split (each class is split separately) ---
    for src_folder_name, img_list in images_by_class.items():
        if not img_list:
            print(f"No images found for class: {src_folder_name}. Skipping split.")
            continue

        print(f"Splitting images for class: {src_folder_name} ({len(img_list)} images)")

        class_train_dir = TRAIN_DIR / src_folder_name
        class_val_dir = VAL_DIR / src_folder_name
        class_test_dir = TEST_DIR / src_folder_name
        for split_dir in (class_train_dir, class_val_dir, class_test_dir):
            _recreate_dir(split_dir)

        if len(img_list) < MIN_IMAGES_TO_SPLIT:
            print(f"  Only {len(img_list)} images: too few to split, all go to train.")
            train_images, val_images, test_images = list(img_list), [], []
        else:
            # train vs. (val + test), then split the remainder into val and test
            train_images, temp_images = train_test_split(
                img_list, test_size=(val_ratio + test_ratio), random_state=42)
            val_images, test_images = train_test_split(
                temp_images, test_size=(test_ratio / (val_ratio + test_ratio)), random_state=42)

        print(f"  Train: {len(train_images)} | Val: {len(val_images)} | Test: {len(test_images)}")

        # --- Step 3: Copy with "<PREFIX>(<n>)" names ---
        prefix_label = class_to_prefix[src_folder_name]
        print(f"Copying and renaming images for {src_folder_name}...")
        copy_and_rename(train_images, class_train_dir, prefix_label)
        copy_and_rename(val_images, class_val_dir, prefix_label)
        copy_and_rename(test_images, class_test_dir, prefix_label)

    print("\nDataset preparation complete!")
    print(f"Dataset structure created in: {BASE_DIR}")
    for split_name, split_root in (("Train", TRAIN_DIR), ("Val", VAL_DIR), ("Test", TEST_DIR)):
        n_files = sum(1 for p in split_root.rglob('*') if p.is_file())
        print(f"{split_name} images: {n_files}")

Starting dataset preparation...
Processing source directory: /content/drive/My Drive/Dataset/Abnormal Heartbeat Patients
Processing source directory: /content/drive/My Drive/Dataset/Myocardial Infarction Patients
Processing source directory: /content/drive/My Drive/Dataset/Normal Person
Processing source directory: /content/drive/My Drive/Dataset/Patient that have History of Myocardial Infraction
Splitting images for class: Abnormal Heartbeat Patients (339 images)
  Train: 237 | Val: 51 | Test: 51
Copying and renaming images for Abnormal Heartbeat Patients...
Splitting images for class: Myocardial Infarction Patients (358 images)
  Train: 250 | Val: 54 | Test: 54
Copying and renaming images for Myocardial Infarction Patients...
Splitting images for class: Normal Person (426 images)
  Train: 298 | Val: 64 | Test: 64
Copying and renaming images for Normal Person...
Splitting images for class: Patient that have History of Myocardial Infraction (258 images)
  Train: 180 | Val: 39 | Test: 3

In [None]:
# 3. ECG Image Dataset (labels parsed from file names)
class ECGImageDataset(Dataset):
    """ECG image dataset whose labels are parsed from "<PREFIX>(<n>).<ext>" file names.

    Every file under ``root`` (searched recursively) whose stem is exactly one of the
    ``label_map`` prefixes followed by "(<digits>)" is collected. For example,
    "MI(12).jpg" -> 1 and "PMI(3).jpg" -> 3. ``__getitem__`` returns
    ``(image_tensor, label, path_str)``.
    """

    label_map = {'HB': 0, 'MI': 1, 'Normal': 2, 'PMI': 3}
    # Image suffixes that are collected (compared case-insensitively)
    image_extensions = ('.jpg', '.jpeg', '.png')

    def __init__(self, root: Path, transform=None, is_training=False):
        """
        Args:
            root: directory searched recursively for images.
            transform: an albumentations ``A.Compose`` (called with ``image=ndarray``)
                or any callable taking a PIL image; falsy -> ``transforms.ToTensor()``.
            is_training: stored on the instance; not used inside this class.
        """
        self.root = Path(root)
        self.samples, self.transform, self.is_training = [], transform, is_training
        self.class_counts = {i: 0 for i in self.label_map.values()}

        # Anchored full-stem match. An unanchored re.search(r'MI\(\d+\)') also
        # matches "PMI(3)", which duplicated every PMI image as an MI sample.
        prefixes = '|'.join(re.escape(cls) for cls in self.label_map)
        name_re = re.compile(rf'({prefixes})\(\d+\)')
        allowed_suffixes = {ext.lower() for ext in self.image_extensions}

        # Single pass over the tree; sorted so the shuffle below is reproducible
        # under a fixed seed (rglob order is filesystem-dependent).
        for p in sorted(self.root.rglob('*')):
            if not p.is_file() or p.suffix.lower() not in allowed_suffixes:
                continue
            match = name_re.fullmatch(p.stem)
            if match is None:
                continue
            idx = self.label_map[match.group(1)]
            self.samples.append((p, idx))
            self.class_counts[idx] += 1

        random.shuffle(self.samples)
        print(f"Dataset loaded: {len(self.samples)} samples")
        print(f"Class distribution: {self.class_counts}")

    def get_class_weights(self):
        """Inverse-frequency weights ``total / (n_classes * count)``, ordered by label index.

        Classes without samples get weight 0.0 instead of raising ZeroDivisionError.
        """
        total = sum(self.class_counts.values())
        n_classes = len(self.class_counts)
        weights = [total / (n_classes * count) if count else 0.0
                   for count in self.class_counts.values()]
        return torch.FloatTensor(weights)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and return ``(image_tensor, label, path_str)``.

        On any loading/transform error, a 3x384x384 zero tensor is returned with the
        sample's label, and the error is printed.
        """
        path, label = self.samples[idx]
        try:
            # The context manager closes the file handle; convert() returns a loaded copy.
            with Image.open(path) as raw:
                img = raw.convert('RGB')

            if not self.transform:
                img_tensor = transforms.ToTensor()(img)
            elif isinstance(self.transform, A.Compose):
                # Albumentations works on HWC uint8 numpy arrays
                img_tensor = self.transform(image=np.array(img))['image']
            else:
                # torchvision-style transform on the PIL image
                img_tensor = self.transform(img)

            return img_tensor, label, str(path)
        except Exception as e:
            print(f"Error loading {path}: {e}")

            dummy_img = torch.zeros(3, 384, 384)
            return dummy_img, label, str(path)

In [None]:
# 4. Advanced Data Augmentation Pipeline
def _gauss_noise_transform(p=0.2):
    """GaussNoise with per-pixel variance 5-15 (0-255 scale) across albumentations versions.

    Newer albumentations releases configure GaussNoise via ``std_range`` (std as a
    fraction of the maximum pixel value) and ignore the old ``var_limit`` argument,
    falling back to their much stronger default noise. Older releases only know
    ``var_limit``. Try the new API first and fall back to the old one.
    """
    std_range = (5.0 ** 0.5 / 255.0, 15.0 ** 0.5 / 255.0)  # sqrt(variance) / 255
    try:
        noise = A.GaussNoise(std_range=std_range, p=p)
        if hasattr(noise, 'std_range'):
            return noise
    except (TypeError, ValueError):
        pass
    return A.GaussNoise(var_limit=(5.0, 15.0), p=p)


def get_training_transforms(img_size=384):
    """Training-time augmentation pipeline for ECG images (albumentations).

    Only mild, label-preserving perturbations are used. Horizontal flips are
    deliberately excluded: mirroring an ECG printout reverses its time axis
    (e.g. the P wave would follow the QRS complex) and mirrors the lead
    annotations, so the result is not a plausible ECG of the same class.

    Args:
        img_size: output height and width in pixels (default 384).

    Returns:
        ``A.Compose`` mapping an HWC uint8 image to a normalized CHW tensor in [-1, 1].
    """
    return A.Compose([
        # Small geometric changes (scanning / photographing misalignment)
        A.Resize(img_size, img_size),
        A.Rotate(limit=3, p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=2, p=0.3),

        # Acquisition noise
        _gauss_noise_transform(p=0.2),
        A.ISONoise(color_shift=(0.01, 0.02), intensity=(0.1, 0.3), p=0.2),
        A.MultiplicativeNoise(multiplier=(0.95, 1.05), p=0.2),

        # Brightness/contrast (different recording conditions)
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
        A.RandomGamma(gamma_limit=(90, 110), p=0.2),

        # Mild paper/grid warping
        A.GridDistortion(num_steps=3, distort_limit=0.05, p=0.15),

        # (x / 255 - 0.5) / 0.5 -> [-1, 1], then HWC ndarray -> CHW tensor
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ])

def get_validation_transforms(img_size=384):
    """Deterministic validation/test pipeline: resize, normalize to [-1, 1], convert to a CHW tensor.

    Args:
        img_size: output height and width in pixels; must match the size used for training.
    """
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ])

In [None]:
# 5. Create datasets with enhanced transforms
train_ds = ECGImageDataset(TRAIN_DIR, transform=get_training_transforms(), is_training=True)
val_ds = ECGImageDataset(VAL_DIR, transform=get_validation_transforms(), is_training=False)

# Test images are stored in per-class subfolders (TEST_DIR/<class folder>/<PREFIX>(n).jpg),
# so a non-recursive TEST_DIR.glob('*.jpg') finds nothing. Search recursively instead.
test_paths = sorted(
    p for p in TEST_DIR.rglob('*')
    if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}
)
print(f"Test images found: {len(test_paths)}")

Dataset loaded: 1145 samples
Class distribution: {0: 237, 1: 430, 2: 298, 3: 180}
Dataset loaded: 247 samples
Class distribution: {0: 51, 1: 93, 2: 64, 3: 39}


In [None]:
# 6. Weighted Sampling for Class Balance
# Each sample is drawn with probability proportional to its class's inverse-frequency
# weight, so every class has the same expected share of each epoch.
# NOTE(review): the training-setup cell also passes these same class_weights to
# CrossEntropyLoss. Combined with this sampler, that over-corrects toward minority
# classes; consider keeping only one of the two mechanisms.
class_weights = train_ds.get_class_weights()
train_labels = torch.tensor([label for _, label in train_ds.samples], dtype=torch.long)
sample_weights = class_weights[train_labels].double()
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(42),  # reproducible epoch composition
)

# Hosted runtimes often have few CPU cores; more workers than cores only adds overhead
num_workers = min(4, os.cpu_count() or 1)
train_loader = DataLoader(train_ds, batch_size=32, sampler=sampler,
                          num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False,
                        num_workers=num_workers, pin_memory=True)

print(f"Class weights: {class_weights}")
print(f"Training batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

Class weights: tensor([1.2078, 0.6657, 0.9606, 1.5903])
Training batches: 36, Validation batches: 8


In [None]:
# 7. Enhanced Model with Better Architecture - Using EfficientNetV2-S
class ECGClassifier(nn.Module):
    """timm backbone (feature extractor) followed by a two-layer MLP classification head.

    Args:
        model_name: timm model name of the backbone (default 'tf_efficientnetv2_s').
        num_classes: number of output logits.
        dropout: dropout before the hidden layer; half of it is applied before the output layer.
        hidden_dim: width of the head's hidden layer (default 256).
        pretrained: load timm pretrained backbone weights. Pass False when the weights
            will be overwritten from a checkpoint anyway, to skip the download.
    """

    def __init__(self, model_name='tf_efficientnetv2_s', num_classes=4, dropout=0.3,
                 hidden_dim=256, pretrained=True):
        super().__init__()
        # Kept so training code can record which backbone a checkpoint belongs to
        self.model_name = model_name
        # num_classes=0 makes timm drop its own classifier and return features
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout / 2),
            nn.Linear(hidden_dim, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal weights and zero biases for the head's Linear layers."""
        for m in self.classifier:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits, one row of ``num_classes`` values per input sample."""
        features = self.backbone(x)
        return self.classifier(features)

In [None]:
# 8. Training Setup with Advanced Components
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Shared by the optimizer and the OneCycleLR schedule. OneCycleLR fixes its total
# step count here, so train_model(..., epochs=...) must not exceed NUM_EPOCHS.
NUM_EPOCHS = 40
MAX_LR = 8e-4  # reduced from 1e-3 for the smaller backbone

model = ECGClassifier(num_classes=4).to(device)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Class-weighted, label-smoothed cross entropy. Falls back to an unweighted loss if
# class_weights (defined by the weighted-sampling cell) does not exist yet.
# NOTE(review): train_loader's WeightedRandomSampler already balances classes;
# weighting the loss as well compounds the correction.
try:
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device), label_smoothing=0.1)
except NameError:
    print("Warning: class_weights not found, using unweighted loss")
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=MAX_LR,
    weight_decay=0.01,
    betas=(0.9, 0.999)
)

# OneCycleLR is stepped once per batch inside train_epoch; StepLR (fallback) once per epoch.
try:
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=MAX_LR,
        epochs=NUM_EPOCHS,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )
except NameError:
    print("Warning: train_loader not found, using step scheduler")
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Mixed precision only on CUDA. torch.amp.GradScaler('cuda') supersedes the deprecated
# torch.cuda.amp.GradScaler(); older PyTorch builds only provide the latter.
use_amp = device.type == 'cuda'
if use_amp:
    scaler = (torch.amp.GradScaler('cuda') if hasattr(torch.amp, 'GradScaler')
              else torch.cuda.amp.GradScaler())
else:
    scaler = None
print(f"Mixed precision training: {'Enabled' if use_amp else 'Disabled'}")

Using device: cuda
Total parameters: 20,506,452
Trainable parameters: 20,506,452
Mixed precision training: Enabled


In [None]:
# 9. Enhanced Training Functions
def train_epoch(loader, epoch, log_interval=50):
    """Run one training epoch over ``loader`` using the notebook-level training state.

    Reads the globals ``model``, ``optimizer``, ``criterion``, ``scheduler``,
    ``use_amp``, ``device`` and (when ``use_amp``) ``scaler``. A OneCycleLR scheduler
    is stepped once per batch here; any other scheduler is left to the caller.

    Args:
        loader: iterable of ``(images, labels, paths)`` batches.
        epoch: epoch number, used only in log lines.
        log_interval: print a progress line every this many batches (0 disables).

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    step_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

    for batch_idx, (imgs, labels, _) in enumerate(loader):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if use_amp:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast
            with torch.autocast(device_type='cuda'):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Only OneCycleLR is a per-batch schedule; per-epoch schedulers are stepped by train_model
        if step_per_batch:
            scheduler.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total += batch_size
        correct += outputs.argmax(1).eq(labels).sum().item()

        if log_interval and batch_idx % log_interval == 0:
            lr = optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch} [{batch_idx}/{len(loader)}] '
                  f'Loss: {loss.item():.4f} Acc: {100.*correct/total:.2f}% LR: {lr:.6f}')

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total

def validate_epoch(loader):
    """Evaluate the global ``model`` on ``loader`` without gradient tracking.

    Reads the globals ``model``, ``criterion``, ``use_amp`` and ``device``.

    Returns:
        ``(mean_loss, accuracy, predictions, labels)``; the last two are lists of
        per-sample class indices. An empty loader yields ``(0.0, 0.0, [], [])``.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for imgs, labels, _ in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type='cuda'):
                    outputs = model(imgs)
                # Compute the loss in float32, outside the autocast region
                loss = criterion(outputs.float(), labels)
            else:
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            predicted = outputs.argmax(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        return 0.0, 0.0, all_preds, all_labels
    return total_loss / total, correct / total, all_preds, all_labels

In [None]:
# 10. Enhanced Training Loop with Early Stopping
def _resolve_model_name(net, default='tf_efficientnetv2_s'):
    """Best-effort backbone name of ``net`` for checkpoint metadata.

    Tries ``net.model_name``, then the 'architecture' entry of the timm backbone's
    ``pretrained_cfg`` / ``default_cfg``, and finally ``default``.
    """
    name = getattr(net, 'model_name', None)
    if name:
        return name
    backbone = getattr(net, 'backbone', None)
    for attr in ('pretrained_cfg', 'default_cfg'):
        cfg = getattr(backbone, attr, None)
        if isinstance(cfg, dict) and cfg.get('architecture'):
            return cfg['architecture']
    return default


def train_model(train_loader, val_loader, epochs=40, patience=8,
                model_name=None, checkpoint_path=None, max_consecutive_errors=3):
    """Train the global ``model`` with early stopping on validation accuracy.

    Uses the notebook-level ``model``, ``optimizer``, ``scheduler`` and ``device``
    together with ``train_epoch`` / ``validate_epoch``. Whenever validation accuracy
    improves, a checkpoint (weights, optimizer/scheduler state, metric history and
    ``model_config``) is written to ``checkpoint_path``.

    Args:
        train_loader, val_loader: loaders passed to train_epoch / validate_epoch.
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without improvement.
        model_name: backbone name stored in ``model_config``. Defaults to the name of
            the model actually being trained (see ``_resolve_model_name``), so the
            checkpoint can be rebuilt with the right architecture.
        checkpoint_path: file for the best checkpoint; defaults to
            ``MODELS_DIR / 'best_ecg_model.pth'``.
        max_consecutive_errors: stop training once this many epochs in a row raise
            an exception (each failure is logged with its traceback).

    Returns:
        dict with per-epoch 'train_losses', 'val_losses', 'train_accs', 'val_accs'
        and the overall 'best_acc'.
    """
    import traceback

    if model_name is None:
        model_name = _resolve_model_name(model)
    if checkpoint_path is None:
        checkpoint_path = MODELS_DIR / 'best_ecg_model.pth'
    checkpoint_path = Path(checkpoint_path)

    best_acc = 0
    no_improve = 0
    consecutive_errors = 0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    print("Starting training...")

    for epoch in range(1, epochs + 1):
        try:
            train_loss, train_acc = train_epoch(train_loader, epoch)
            val_loss, val_acc, _, _ = validate_epoch(val_loader)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            print(f"\nEpoch {epoch}/{epochs}")
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
            print("-" * 50)

            if val_acc > best_acc:
                best_acc = val_acc
                no_improve = 0

                checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_acc': best_acc,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'train_accs': train_accs,
                    'val_accs': val_accs,
                    'model_config': {
                        'num_classes': 4,
                        'model_name': model_name,
                    }
                }, checkpoint_path)
                print(f"✅ New best model saved! Validation Accuracy: {best_acc:.4f}")
            else:
                no_improve += 1

            # OneCycleLR is stepped per batch inside train_epoch; others once per epoch
            if not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                scheduler.step()

            consecutive_errors = 0

            if no_improve >= patience:
                print(f"Early stopping after {epoch} epochs")
                break

            # Periodic memory cleanup for small GPUs
            if epoch % 3 == 0:
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

        except Exception as e:
            consecutive_errors += 1
            print(f"Error in epoch {epoch}: {str(e)}")
            traceback.print_exc()
            # Release cached memory in case the failure was an out-of-memory error
            gc.collect()
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            if consecutive_errors >= max_consecutive_errors:
                print(f"Aborting training after {consecutive_errors} consecutive failed epochs")
                break
            print("Continuing to next epoch...")
            continue

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_acc': best_acc
    }

# Memory monitoring function for Colab
def check_gpu_memory():
    """Print and return ``(allocated_GB, reserved_GB)`` for CUDA memory; ``(0, 0)`` without CUDA."""
    if not torch.cuda.is_available():
        return 0, 0
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib
    reserved = torch.cuda.memory_reserved() / gib
    print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
    return allocated, reserved

# Example usage:
# Baseline run with the model / optimizer / scheduler built in the training-setup cell.
# `results` holds the per-epoch curves and best accuracy read by the visualization cell.
# NOTE(review): epochs=40 must not exceed the `epochs` given to OneCycleLR, otherwise
# the per-batch scheduler.step() runs past its configured total step count.
check_gpu_memory()  # Check initial memory usage
results = train_model(train_loader, val_loader, epochs=40, patience=8)

# NOTE(review): printed after training has finished, despite the "setup" wording.
print("✅ Training setup complete with EfficientNetV2-S!")


GPU Memory - Allocated: 0.28GB, Reserved: 0.29GB
Starting training...
Epoch 1 [0/36] Loss: 1.8856 Acc: 18.75% LR: 0.000032

Epoch 1/40
Train Loss: 1.4155 | Train Acc: 0.3249
Val Loss: 1.0434 | Val Acc: 0.5587
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.5587
Epoch 2 [0/36] Loss: 1.0171 Acc: 71.88% LR: 0.000152

Epoch 2/40
Train Loss: 1.0070 | Train Acc: 0.6017
Val Loss: 0.8093 | Val Acc: 0.7652
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.7652
Epoch 3 [0/36] Loss: 0.8327 Acc: 71.88% LR: 0.000429

Epoch 3/40
Train Loss: 0.7561 | Train Acc: 0.7581
Val Loss: 0.8326 | Val Acc: 0.6842
--------------------------------------------------
Epoch 4 [0/36] Loss: 0.7461 Acc: 75.00% LR: 0.000698

Epoch 4/40
Train Loss: 0.6566 | Train Acc: 0.8245
Val Loss: 0.6536 | Val Acc: 0.8259
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.8259
Epoch 5 [0/36]

In [None]:
# 11. Try Multiple Models for Comparison - EfficientNet-B0, EfficientNetV2-S, ResNeXt-50
#
# Every backbone is trained from the same seed with the same optimizer/schedule
# settings. train_model writes its best checkpoint to MODELS_DIR/'best_ecg_model.pth',
# so that file (the section-10 baseline) is backed up before the loop and restored
# afterwards. Each backbone's best checkpoint is kept as 'best_<model_name>.pth'.
# If the kernel dies mid-loop, the baseline remains in 'best_ecg_model.pre_comparison.pth'.
# NOTE(review): after this cell the global `model` is the last backbone trained here,
# not the section-10 baseline.
model_names = [
    'efficientnet_b0',
    'tf_efficientnetv2_s',
    'resnext50_32x4d',
]

comparison_results = {}

primary_ckpt = MODELS_DIR / 'best_ecg_model.pth'
backup_ckpt = MODELS_DIR / 'best_ecg_model.pre_comparison.pth'
if primary_ckpt.exists():
    shutil.copy2(primary_ckpt, backup_ckpt)

try:
    for model_name in model_names:
        print(f"\n🔁 Training with model: {model_name}")

        # Drop references to the previous run before allocating the next backbone
        model = optimizer = scheduler = scaler = None
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        # Re-seed so every backbone starts from the same RNG state
        set_seed(42)

        model = ECGClassifier(model_name=model_name, num_classes=4).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=8e-4,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=8e-4,
            epochs=40,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            anneal_strategy='cos'
        )

        if use_amp:
            scaler = (torch.amp.GradScaler('cuda') if hasattr(torch.amp, 'GradScaler')
                      else torch.cuda.amp.GradScaler())

        # Remove the previous file so a run that never saves cannot inherit another model's checkpoint
        primary_ckpt.unlink(missing_ok=True)
        result = train_model(train_loader, val_loader, epochs=40, patience=8)

        comparison_results[model_name] = result['best_acc']
        if primary_ckpt.exists():
            shutil.copy2(primary_ckpt, MODELS_DIR / f'best_{model_name}.pth')

        print(f"✅ Finished {model_name} | Best Val Accuracy: {result['best_acc']:.4f}")
        print("-" * 60)
finally:
    # Put the section-10 baseline back at best_ecg_model.pth
    if backup_ckpt.exists():
        shutil.move(str(backup_ckpt), str(primary_ckpt))

# Summary of all model performances
print("\n📊 Model Performance Summary:")
for name, acc in comparison_results.items():
    print(f"{name}: {acc:.4f}")



🔁 Training with model: tf_efficientnetv2_s
Starting training...
Epoch 1 [0/36] Loss: 1.6513 Acc: 28.12% LR: 0.000032

Epoch 1/40
Train Loss: 1.3873 | Train Acc: 0.3572
Val Loss: 1.1001 | Val Acc: 0.5344
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.5344
Epoch 2 [0/36] Loss: 1.0868 Acc: 62.50% LR: 0.000152

Epoch 2/40
Train Loss: 1.0102 | Train Acc: 0.5904
Val Loss: 0.8211 | Val Acc: 0.7490
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.7490
Epoch 3 [0/36] Loss: 0.9051 Acc: 62.50% LR: 0.000429

Epoch 3/40
Train Loss: 0.7684 | Train Acc: 0.7450
Val Loss: 0.7146 | Val Acc: 0.7935
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.7935
Epoch 4 [0/36] Loss: 0.6507 Acc: 81.25% LR: 0.000698

Epoch 4/40
Train Loss: 0.6840 | Train Acc: 0.8096
Val Loss: 0.8009 | Val Acc: 0.7814
--------------------------------------------------
Epoch 5 [0/36] Loss

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Starting training...
Epoch 1 [0/36] Loss: 2.5545 Acc: 28.12% LR: 0.000032

Epoch 1/40
Train Loss: 1.8067 | Train Acc: 0.2376
Val Loss: 1.4515 | Val Acc: 0.2065
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.2065
Epoch 2 [0/36] Loss: 1.3326 Acc: 28.12% LR: 0.000152

Epoch 2/40
Train Loss: 1.3910 | Train Acc: 0.2786
Val Loss: 1.4830 | Val Acc: 0.1579
--------------------------------------------------
Epoch 3 [0/36] Loss: 1.3722 Acc: 21.88% LR: 0.000429

Epoch 3/40
Train Loss: 1.3779 | Train Acc: 0.2559
Val Loss: 1.4671 | Val Acc: 0.2065
--------------------------------------------------
Epoch 4 [0/36] Loss: 1.4198 Acc: 25.00% LR: 0.000698

Epoch 4/40
Train Loss: 1.3584 | Train Acc: 0.2699
Val Loss: 1.4127 | Val Acc: 0.1579
--------------------------------------------------
Epoch 5 [0/36] Loss: 1.3919 Acc: 15.62% LR: 0.000800

Epoch 5/40
Train Loss: 1.3871 | Train Acc: 0.2192
Val Loss: 1.4313 | Val Acc: 0.1579
----------------------------

model.safetensors:   0%|          | 0.00/100M [00:00<?, ?B/s]

Starting training...
Epoch 1 [0/36] Loss: 1.3745 Acc: 28.12% LR: 0.000032

Epoch 1/40
Train Loss: 1.3614 | Train Acc: 0.2760
Val Loss: 1.3950 | Val Acc: 0.1741
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.1741
Epoch 2 [0/36] Loss: 1.3658 Acc: 28.12% LR: 0.000152

Epoch 2/40
Train Loss: 1.2143 | Train Acc: 0.3668
Val Loss: 1.0948 | Val Acc: 0.4899
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.4899
Epoch 3 [0/36] Loss: 1.1255 Acc: 50.00% LR: 0.000429

Epoch 3/40
Train Loss: 0.8874 | Train Acc: 0.6568
Val Loss: 0.7849 | Val Acc: 0.7287
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.7287
Epoch 4 [0/36] Loss: 0.9752 Acc: 65.62% LR: 0.000698

Epoch 4/40
Train Loss: 0.7141 | Train Acc: 0.7799
Val Loss: 0.6964 | Val Acc: 0.7976
--------------------------------------------------
✅ New best model saved! Validation Accuracy: 0.7976
Epoch 5 [0/

RuntimeError: Unknown model (coatnet_0)

In [None]:
# 11. Training Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path

# Get the training metrics from results
train_losses = results['train_losses']
val_losses = results['val_losses']
train_accs = results['train_accs']
val_accs = results['val_accs']
assert len(train_losses) > 0, "results contains no epochs - did training complete?"
assert len(train_losses) == len(val_losses) == len(train_accs) == len(val_accs), \
    "per-epoch metric lists have different lengths"

# Scheduler settings the training run used. Values are inferred from the logged
# per-epoch LRs (start 3.2e-5 = 8e-4 / 25, peak 8e-4 first reached at epoch 5 of 40).
# NOTE(review): confirm against the OneCycleLR definition in the training cell.
MAX_LR = 8e-4
PLANNED_EPOCHS = 40      # epochs the scheduler was built for (not the epochs actually run)
PCT_START = 0.1          # fraction of the planned run spent warming up
DIV_FACTOR = 25.0        # OneCycleLR default: initial_lr = max_lr / div_factor
FINAL_DIV_FACTOR = 1e4   # OneCycleLR default: min_lr = initial_lr / final_div_factor


def approx_onecycle_lr(epoch_idx, planned_epochs=PLANNED_EPOCHS, max_lr=MAX_LR,
                       pct_start=PCT_START, div_factor=DIV_FACTOR,
                       final_div_factor=FINAL_DIV_FACTOR):
    """Approximate the OneCycleLR learning rate at the start of 0-based epoch `epoch_idx`.

    Follows OneCycleLR's default two-phase cosine schedule:
    - cosine warm-up from ``max_lr / div_factor`` to ``max_lr`` over the first
      ``pct_start`` of the *planned* run;
    - cosine decay down to ``initial_lr / final_div_factor``.
    The real scheduler steps per batch; this is per-epoch resolution only.
    """
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_epochs = max(planned_epochs * pct_start, 1e-12)

    def _cos_anneal(start, end, frac):
        # frac=0 -> start, frac=1 -> end, cosine-shaped in between.
        frac = min(max(frac, 0.0), 1.0)
        return end + (start - end) * 0.5 * (1 + np.cos(np.pi * frac))

    if epoch_idx < warmup_epochs:
        return _cos_anneal(initial_lr, max_lr, epoch_idx / warmup_epochs)
    decay_epochs = max(planned_epochs - warmup_epochs, 1e-12)
    return _cos_anneal(max_lr, min_lr, (epoch_idx - warmup_epochs) / decay_epochs)


# Create visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
epochs = range(1, len(train_losses) + 1)
ax1.plot(epochs, train_losses, label='Training Loss', color='blue', linewidth=2)
ax1.plot(epochs, val_losses, label='Validation Loss', color='red', linewidth=2)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(epochs, train_accs, label='Training Accuracy', color='blue', linewidth=2)
ax2.plot(epochs, val_accs, label='Validation Accuracy', color='red', linewidth=2)
ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Learning rate schedule: the scheduler was configured for PLANNED_EPOCHS, so the
# curve is evaluated on that schedule even if training stopped early.
lr_schedule = [approx_onecycle_lr(i) for i in range(len(train_losses))]

ax3.plot(epochs, lr_schedule, color='green', linewidth=2)
ax3.set_title('Learning Rate Schedule (Approximate)', fontsize=14, fontweight='bold')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Learning Rate')
ax3.grid(True, alpha=0.3)

# Training summary
ax4.text(0.1, 0.8, f'Best Validation Accuracy: {results["best_acc"]:.4f}',
         fontsize=12, transform=ax4.transAxes, fontweight='bold')
ax4.text(0.1, 0.7, f'Final Training Accuracy: {train_accs[-1]:.4f}',
         fontsize=12, transform=ax4.transAxes)
ax4.text(0.1, 0.6, f'Final Validation Accuracy: {val_accs[-1]:.4f}',
         fontsize=12, transform=ax4.transAxes)
ax4.text(0.1, 0.5, f'Total Epochs: {len(train_losses)}',
         fontsize=12, transform=ax4.transAxes)
ax4.text(0.1, 0.4, f'Early Stopping: {"Yes" if len(train_losses) < PLANNED_EPOCHS else "No"}',
         fontsize=12, transform=ax4.transAxes)
ax4.set_title('Training Summary', fontsize=14, fontweight='bold')
ax4.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# 12. Comprehensive Model Evaluation
print("Loading best model...")

from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# Class names in label-index order: index i is reported as label_names[i].
label_names = ['Abnormal HB', 'MI', 'Normal', 'PMI']
# Passing explicit `labels` keeps sklearn's outputs aligned with label_names even
# when a class is absent from this split. Without it, classification_report raises
# a size-mismatch ValueError and the per-class arrays below shift by position.
label_ids = list(range(len(label_names)))


def _print_classification_report(title, y_true, y_pred):
    """Print `title` in a banner followed by a 4-digit report covering every class."""
    print("\n" + "="*60)
    print(title)
    print("="*60)
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=label_names, digits=4, zero_division=0))


# Model path - use the correct directory structure
models_dir = MODELS_DIR
model_path = models_dir / 'best_ecg_model.pth'

# Check if model file exists
if model_path.exists():
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Model loaded successfully!")
    print(f"Best validation accuracy: {checkpoint['best_acc']:.4f}")

    # Final evaluation
    print("\nEvaluating model on validation set...")
    val_loss, val_acc, y_pred, y_true = validate_epoch(val_loader)

    _print_classification_report("FINAL CLASSIFICATION REPORT", y_true, y_pred)

    # Confusion matrix in fixed class order (rows: true label, columns: predicted).
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=label_names, yticklabels=label_names,
                cbar_kws={'label': 'Count'})
    plt.title('Confusion Matrix - ECG Classification', fontsize=16, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.show()

    # Additional metrics, one entry per class in label_names order.
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average=None, zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)

    print("\n" + "="*60)
    print("DETAILED METRICS BY CLASS")
    print("="*60)
    for i, label in enumerate(label_names):
        print(f"{label:12} | Precision: {precision[i]:.4f} | Recall: {recall[i]:.4f} | F1: {f1[i]:.4f} | Support: {support[i]}")

    print(f"\nOverall Accuracy: {accuracy:.4f}")
    print(f"Macro Average F1: {np.mean(f1):.4f}")
    # np.average raises ZeroDivisionError when every support is 0.
    weighted_f1 = np.average(f1, weights=support) if support.sum() > 0 else float('nan')
    print(f"Weighted Average F1: {weighted_f1:.4f}")

else:
    print(f"❌ Model file not found at: {model_path}")
    print("Available files in models directory:")
    if models_dir.exists():
        for file in models_dir.iterdir():
            print(f"  - {file.name}")
    else:
        print("  Models directory doesn't exist!")

    print("\nUsing current model state for evaluation...")
    val_loss, val_acc, y_pred, y_true = validate_epoch(val_loader)

    # Classification report with current model
    _print_classification_report("CLASSIFICATION REPORT (Current Model State)", y_true, y_pred)

print("\n🎯 Model evaluation complete!")

In [None]:
# 13. Clinical Report Generation

from datetime import datetime
from tqdm import tqdm
import time

# Predicted class index -> (diagnosis title, clinical interpretation text).
# Index order is assumed to match the label order used in evaluation
# (['Abnormal HB', 'MI', 'Normal', 'PMI']) - keep the two in sync.
class_map = dict(enumerate([
    ("Abnormal Heartbeat", "Possible arrhythmia or irregular rhythm. Recommend further cardiological evaluation."),
    ("Myocardial Infarction", "ECG pattern consistent with myocardial infarction. Urgent clinical attention advised."),
    ("Normal Sinus Rhythm", "Normal ECG. No abnormalities detected."),
    ("History of Myocardial Infarction", "Signs of previous infarction. Regular follow-up recommended."),
]))

def generate_text_report(test_id, prediction_idx, save_dir, label_map=None):
    """Write a plain-text diagnostic report for one ECG and return its diagnosis title.

    Args:
        test_id: Identifier shown in the report header and used as the file stem
            (the report is written to ``<save_dir>/<test_id>.txt``).
        prediction_idx: Predicted class index; must be a key of ``label_map``.
        save_dir: Existing directory (``str`` or ``Path``) to write the report into.
        label_map: Mapping ``index -> (diagnosis, interpretation)``. Defaults to the
            module-level ``class_map``.

    Returns:
        The diagnosis title written into the report.

    Raises:
        KeyError: if ``prediction_idx`` is not a key of ``label_map``.
    """
    if label_map is None:
        label_map = class_map
    # Plain ASCII hyphens: a look-alike Unicode hyphen here yields non-ASCII dates
    # that break downstream parsing and non-UTF-8 file encodings.
    date_str = datetime.now().strftime("%Y-%m-%d")
    diagnosis, interpretation = label_map[prediction_idx]

    content = f"""Test ID: {test_id}
Date: {date_str}

Automated Diagnosis
-------------------
{diagnosis}

Clinical Interpretation
-----------------------
{interpretation}
"""

    report_path = Path(save_dir) / f"{test_id}.txt"
    # Explicit UTF-8 so the written bytes do not depend on the platform locale.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(content)
    return diagnosis


REPORTS_DIR.mkdir(parents=True, exist_ok=True)
summary_data = []

model.eval()

# Test-set loader; batch_size=1 so each report and timing covers exactly one image.
test_ds = ECGImageDataset(TEST_DIR, transform=get_validation_transforms())
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

print("Generating reports...")

with torch.no_grad():
    for img, _, path_str in tqdm(test_loader):
        img = img.to(device)
        # With batch_size=1, collation yields a length-1 batch of paths
        # (presumably the dataset returns each path as a str).
        test_id = Path(path_str[0]).stem

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        output = model(img)
        # .item() copies the result to the host (a device sync on CUDA), so the
        # measured span includes the whole forward pass.
        pred_idx = output.argmax(1).item()
        exec_time = round(time.perf_counter() - start_time, 4)  # seconds

        # Save report; the returned title is reused for the summary row.
        diagnosis = generate_text_report(test_id, pred_idx, REPORTS_DIR)

        summary_data.append({
            "Test ID": test_id,
            "Prediction Class": pred_idx,
            "Diagnosis": diagnosis,
            "Execution Time (s)": exec_time,
        })

summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(REPORTS_DIR / "ecg_summary.csv", index=False)
print("✅ Summary CSV saved at:", REPORTS_DIR / "ecg_summary.csv")


In [None]:
# 14. ROC Curve Analysis with Enhanced Visualization

from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

def plot_roc_curves(model, dataloader, n_classes=4, class_names=None, device=None):
    """Plot one-vs-rest ROC curves for a multi-class classifier and return per-class AUCs.

    Args:
        model: Classifier whose output is softmaxed over dim 1 (one column per class).
        dataloader: Iterable of ``(images, labels, extra)`` batches; ``labels`` are
            integer class indices as a tensor, and ``extra`` is ignored.
        n_classes: Number of classes to evaluate.
        class_names: Legend labels, one per class. Defaults to the four ECG class names.
        device: Device for inference. Defaults to the device of the model's
            parameters (CPU if the model has none).

    Returns:
        dict mapping class index -> ROC AUC. A class that has no positive or no
        negative samples in ``dataloader`` gets ``nan`` and is not plotted.

    Raises:
        ValueError: if ``class_names`` is too short or ``dataloader`` is empty.
    """
    if class_names is None:
        class_names = ['Abnormal Heartbeat', 'Myocardial Infarction', 'Normal', 'PMI']
    if len(class_names) < n_classes:
        raise ValueError(f"need {n_classes} class names, got {len(class_names)}")
    if device is None:
        # Run inference wherever the model's weights live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, batch_labels, _ in dataloader:
            logits = model(imgs.to(device))
            prob_batches.append(F.softmax(logits, dim=1).cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy().reshape(-1))

    if not prob_batches:
        raise ValueError("dataloader yielded no batches; cannot compute ROC curves")

    y_score = np.concatenate(prob_batches, axis=0)
    y_labels = np.concatenate(label_batches, axis=0)
    # One-hot encode explicitly: label_binarize returns a single column when
    # n_classes == 2, which would break the per-class column indexing below.
    y_true = (y_labels[:, None] == np.arange(n_classes)[None, :]).astype(int)

    colors = ['red', 'green', 'blue', 'purple']
    fig, ax = plt.subplots(figsize=(10, 8))
    roc_auc = {}
    for i in range(n_classes):
        n_pos = int(y_true[:, i].sum())
        if n_pos == 0 or n_pos == len(y_true):
            # ROC is undefined without both positive and negative samples.
            roc_auc[i] = float('nan')
            print(f"⚠️ Skipping ROC for '{class_names[i]}': only one outcome present in this split")
            continue
        fpr, tpr, _ = roc_curve(y_true[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr, tpr)
        ax.plot(fpr, tpr, color=colors[i] if i < len(colors) else None,
                label=f'{class_names[i]} (AUC = {roc_auc[i]:.2f})')

    ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Multi-Class ROC Curve')
    ax.legend(loc='lower right')
    ax.grid(True)
    fig.tight_layout()
    plt.show()
    return roc_auc

# Example: per-class ROC curves on the validation split. The result is bound to a
# name so the cell does not echo the return value as output.
val_roc_auc = plot_roc_curves(model, val_loader)


In [None]:
# 15. Grad-CAM Visualization Enhancement

from tqdm import tqdm
import torchvision.transforms as T
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

def batch_generate_gradcams(model, dataset, output_dir, method='GradCAM', target_layer=None):
    """Save one CAM overlay PNG per sample in ``dataset``.

    Args:
        model: Classifier to explain. When ``target_layer`` is None it must expose
            ``model.backbone.blocks``. NOTE(review): the default layer was chosen for
            EfficientNetV2-S; confirm it for the architecture actually trained.
        dataset: Indexable dataset whose items are ``(image_tensor, label, path)``,
            with ``image_tensor`` laid out (C, H, W).
        output_dir: Existing directory (``str`` or ``Path``); files are written as
            ``<path stem>_class<label>.png``.
        method: Name of a CAM class exported by ``pytorch_grad_cam``
            (e.g. 'GradCAM', 'GradCAMPlusPlus', 'EigenCAM').
        target_layer: Module whose activations are explained; defaults to
            ``model.backbone.blocks[-1]``.

    Raises:
        ValueError: if ``method`` is not a CAM class of ``pytorch_grad_cam``.

    Per-sample failures are printed and skipped, so one bad sample does not abort the batch.
    NOTE(review): the CAM target is the dataset label, not the prediction; verify the
    test split actually carries real labels.
    """
    # Resolve the CAM class by name instead of eval() on a caller-supplied string.
    import pytorch_grad_cam
    from pytorch_grad_cam.base_cam import BaseCAM

    cam_algorithm = getattr(pytorch_grad_cam, method, None)
    if not (isinstance(cam_algorithm, type) and issubclass(cam_algorithm, BaseCAM)):
        raise ValueError(f"Unknown CAM method {method!r}; expected a CAM class from pytorch_grad_cam")

    output_dir = Path(output_dir)
    model.eval()
    if target_layer is None:
        target_layer = model.backbone.blocks[-1]

    print(f"Generating Grad-CAMs for {len(dataset)} samples...")

    # The context manager removes the CAM's forward/backward hooks from `model`
    # afterwards, so later forward passes do not keep capturing activations.
    with cam_algorithm(model=model, target_layers=[target_layer]) as cam:
        for idx in tqdm(range(len(dataset))):
            # Reset per sample so the error message never shows an unbound or stale path.
            path = None
            try:
                image_tensor, label, path = dataset[idx]
                label = int(label)  # tensor/numpy labels -> plain int for target and file name

                input_tensor = image_tensor.unsqueeze(0).to(device)

                # Run CAM
                grayscale_cam = cam(input_tensor=input_tensor,
                                    targets=[ClassifierOutputTarget(label)])[0]

                # Min-max scale to [0, 1] for the overlay; a constant image would
                # otherwise divide by zero.
                img_np = image_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
                value_range = img_np.max() - img_np.min()
                if value_range > 0:
                    img_np = (img_np - img_np.min()) / value_range
                else:
                    img_np = np.zeros_like(img_np)

                cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

                # Save CAM image
                save_path = output_dir / (Path(path).stem + f'_class{label}.png')
                Image.fromarray(cam_image).save(save_path)

            except Exception as e:
                print(f"❌ Failed for index {idx} ({path}): {e}")

    print("✅ Grad-CAM generation complete.")

# === Usage ===
# CAM overlays for every test image are written under BASE_DIR/gradcam_outputs;
# parents=True matches how the reports directory is created and tolerates a fresh BASE_DIR.
gradcam_dir = BASE_DIR / 'gradcam_outputs'
gradcam_dir.mkdir(parents=True, exist_ok=True)
batch_generate_gradcams(model, test_ds, gradcam_dir)
