# Enhanced Glaucoma Detection with Custom Medical Imaging CNN (Local Version)
This notebook implements a specialized CNN architecture for glaucoma detection with:
- Custom multi-scale CNN with attention mechanisms
- Advanced data augmentation
- Class imbalance handling
- Grad-CAM visualization
- Comprehensive evaluation metrics

**This version runs locally on your machine (no Google Colab required)**

In [None]:
# Install required packages (run once)
# !pip install tensorflow tensorflow-addons scikit-learn matplotlib seaborn opencv-python

In [None]:
# Import required libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, 
    TensorBoard, CSVLogger
)
import tensorflow_addons as tfa

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, 
    roc_curve, auc, roc_auc_score,
    precision_recall_curve, average_precision_score
)
from sklearn.utils.class_weight import compute_class_weight
import cv2
import os
from datetime import datetime
import glob

# Seed both RNG sources this notebook relies on (NumPy and TensorFlow)
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(42)

# Report the runtime environment in one go
environment_report = (
    f"TensorFlow version: {tf.__version__}",
    f"GPU Available: {tf.config.list_physical_devices('GPU')}",
    f"Working directory: {os.getcwd()}",
)
print("\n".join(environment_report))

In [None]:
# Configuration - LOCAL PATHS
# BASE_PATH is machine-specific; edit it to point at the folder holding the dataset.
BASE_PATH = r'c:\Users\likit\OneDrive\Documents\projects\glucamo'

# Dataset splits; each is later passed to flow_from_directory
TRAINING_PATH = os.path.join(BASE_PATH, 'Train-20251107T233046Z-1-001', 'Train')
VALIDATION_PATH = os.path.join(BASE_PATH, 'Validation-20251107T232720Z-1-001', 'Validation')
TEST_PATH = os.path.join(BASE_PATH, 'Test-20251108T015821Z-1-001', 'Test')

# Verify paths exist
print("Checking data paths...")
print(f"Training path exists: {os.path.exists(TRAINING_PATH)}")
print(f"Validation path exists: {os.path.exists(VALIDATION_PATH)}")
print(f"Test path exists: {os.path.exists(TEST_PATH)}")

# Only warns: execution continues even when the training folder is missing
if not os.path.exists(TRAINING_PATH):
    print(f"\n⚠️ WARNING: Training path not found: {TRAINING_PATH}")
    print("Please verify your data folder structure.")

# Model configuration
IMG_SIZE = (224, 224)    # passed as target_size to the generators and image loaders
BATCH_SIZE = 32
EPOCHS = 50              # upper bound; the EarlyStopping callback may stop sooner
NUM_CLASSES = 2
LEARNING_RATE = 0.001

# Model save path (local)
# The timestamp in MODEL_NAME gives every run its own output directory.
MODEL_NAME = f'glaucoma_custom_cnn_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
CHECKPOINT_PATH = os.path.join(BASE_PATH, 'checkpoints', MODEL_NAME)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

print(f"\nModels will be saved to: {CHECKPOINT_PATH}")

## Custom CNN Architecture Components

In [None]:
# Squeeze-and-Excitation Block for Attention
class SEBlock(layers.Layer):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools the input, passes the pooled vector through a
    two-layer bottleneck (ReLU then sigmoid) and multiplies the input by the
    resulting per-channel gates.

    Args:
        filters: Number of gates produced; must match the channel count of
            the tensor the block is applied to.
        ratio: Bottleneck reduction factor. The hidden layer has
            ``max(1, filters // ratio)`` units.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.ratio = ratio

    def build(self, input_shape):
        # filters // ratio is 0 when filters < ratio, which is not a valid
        # Dense width; always keep at least one bottleneck unit.
        bottleneck_units = max(1, self.filters // self.ratio)

        self.global_pool = layers.GlobalAveragePooling2D()
        self.dense1 = layers.Dense(bottleneck_units, activation='relu')
        self.dense2 = layers.Dense(self.filters, activation='sigmoid')
        # Reshape gates to (1, 1, filters) so they broadcast over the spatial axes
        self.reshape = layers.Reshape((1, 1, self.filters))
        self.multiply = layers.Multiply()
        super().build(input_shape)

    def call(self, inputs):
        se = self.global_pool(inputs)
        se = self.dense1(se)
        se = self.dense2(se)
        se = self.reshape(se)
        return self.multiply([inputs, se])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"filters": self.filters, "ratio": self.ratio})
        return config


# Residual Block with SE attention
def residual_block_with_se(x, filters, kernel_size=3, stride=1, use_se=True):
    """Two-convolution residual unit, optionally gated by an SEBlock.

    The main branch is Conv-BN-ReLU-Conv-BN (only the first convolution
    carries ``stride``), followed by an SEBlock when ``use_se`` is true.
    The skip branch is the input itself, or a strided 1x1 Conv + BN when the
    stride or channel count would otherwise make the shapes incompatible.
    Returns ReLU(main + skip).
    """
    identity = x

    # Main branch: first convolution applies the stride
    branch = layers.Conv2D(filters, kernel_size, strides=stride, padding='same',
                           kernel_initializer='he_normal')(identity)
    branch = layers.BatchNormalization()(branch)
    branch = layers.Activation('relu')(branch)

    # Main branch: second convolution, no activation before the merge
    branch = layers.Conv2D(filters, kernel_size, strides=1, padding='same',
                           kernel_initializer='he_normal')(branch)
    branch = layers.BatchNormalization()(branch)

    # Channel attention on the main branch
    if use_se:
        branch = SEBlock(filters)(branch)

    # Project the skip branch only when its shape would not match the main branch
    shapes_already_match = stride == 1 and identity.shape[-1] == filters
    if not shapes_already_match:
        identity = layers.Conv2D(filters, 1, strides=stride, padding='same')(identity)
        identity = layers.BatchNormalization()(identity)

    merged = layers.Add()([branch, identity])
    return layers.Activation('relu')(merged)


# Multi-scale Inception-like block
def inception_block(x, filters):
    """Run four parallel branches over ``x`` and concatenate them.

    Branches: 1x1 conv; 1x1 -> 3x3 conv; 1x1 -> 3x3 -> 3x3 conv (a stacked
    stand-in for 5x5); 3x3 max-pool -> 1x1 conv. Each branch emits
    ``filters`` channels, so the batch-normalised output has 4 * filters.
    """
    def conv_relu(tensor, kernel):
        # Same-padded ReLU convolution shared by every branch
        return layers.Conv2D(filters, kernel, padding='same', activation='relu',
                             kernel_initializer='he_normal')(tensor)

    # Branch 1: pointwise convolution
    path_1x1 = conv_relu(x, 1)

    # Branch 2: pointwise reduction followed by a 3x3 convolution
    path_3x3 = conv_relu(conv_relu(x, 1), 3)

    # Branch 3: pointwise reduction followed by two stacked 3x3 convolutions
    path_5x5 = conv_relu(conv_relu(conv_relu(x, 1), 3), 3)

    # Branch 4: stride-1 max pooling followed by a pointwise convolution
    pooled = layers.MaxPooling2D(3, strides=1, padding='same')(x)
    path_pool = conv_relu(pooled, 1)

    merged = layers.Concatenate()([path_1x1, path_3x3, path_5x5, path_pool])
    return layers.BatchNormalization()(merged)

## Build Custom Medical Imaging CNN

In [None]:
def build_custom_medical_cnn(input_shape=(224, 224, 3), num_classes=2):
    """
    Assemble the 'MedicalImageCNN' Keras model.

    Layout: 7x7 strided stem -> SE residual stage (64) -> inception + SE
    residual stage (128) -> strided SE residual stage (256) -> inception +
    SE residual stage (512) -> global average pooling -> two Dense/BN/ReLU/
    Dropout blocks -> softmax over ``num_classes``.
    """
    image_in = layers.Input(shape=input_shape)

    # Stem: strided 7x7 convolution followed by strided max pooling
    net = layers.Conv2D(64, 7, strides=2, padding='same',
                        kernel_initializer='he_normal')(image_in)
    net = layers.BatchNormalization()(net)
    net = layers.Activation('relu')(net)
    net = layers.MaxPooling2D(3, strides=2, padding='same')(net)

    # Stage 1: two SE residual blocks at 64 channels
    for _ in range(2):
        net = residual_block_with_se(net, 64, use_se=True)

    # Stage 2: inception block (4 branches * 32 = 128 channels), downsample, residuals
    net = inception_block(net, 32)
    net = layers.MaxPooling2D(2, strides=2)(net)
    for _ in range(2):
        net = residual_block_with_se(net, 128, use_se=True)

    # Stage 3: the first block downsamples via stride 2, the next two keep resolution
    for block_stride in (2, 1, 1):
        net = residual_block_with_se(net, 256, stride=block_stride, use_se=True)

    # Stage 4: inception block (4 branches * 64 = 256 channels), downsample, residuals
    net = inception_block(net, 64)
    net = layers.MaxPooling2D(2, strides=2)(net)
    for _ in range(2):
        net = residual_block_with_se(net, 512, use_se=True)

    # Classifier head
    net = layers.GlobalAveragePooling2D()(net)
    net = layers.Dropout(0.5)(net)

    # Two fully connected blocks: Dense -> BN -> ReLU -> Dropout
    for units, drop_rate in ((512, 0.4), (256, 0.3)):
        net = layers.Dense(units, kernel_initializer='he_normal')(net)
        net = layers.BatchNormalization()(net)
        net = layers.Activation('relu')(net)
        net = layers.Dropout(drop_rate)(net)

    class_probs = layers.Dense(num_classes, activation='softmax',
                               kernel_initializer='glorot_uniform')(net)

    return Model(inputs=image_in, outputs=class_probs, name='MedicalImageCNN')


# Build the model
print("Building custom medical imaging CNN...")
# Derive the input shape from IMG_SIZE so the network always matches the
# target_size used by the data generators and image loaders (3 = RGB channels).
model = build_custom_medical_cnn(input_shape=(*IMG_SIZE, 3), num_classes=NUM_CLASSES)
model.summary()

## Advanced Data Augmentation for Medical Imaging

In [None]:
# Medical image augmentation
print("Setting up data generators...")

# Augmentation is applied to the training split only
augmentation_options = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.8, 1.2],
    fill_mode='reflect',
)
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)

# Validation and test images are only rescaled
validation_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Settings shared by every directory iterator
shared_flow_options = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=42,
)

# Training data is shuffled; evaluation splits keep directory order so that
# predictions line up with generator.classes
train_generator = train_datagen.flow_from_directory(
    TRAINING_PATH, shuffle=True, **shared_flow_options
)
validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_PATH, shuffle=False, **shared_flow_options
)

# The test split is optional
if os.path.exists(TEST_PATH):
    test_generator = test_datagen.flow_from_directory(
        TEST_PATH, shuffle=False, **shared_flow_options
    )
    print(f"Test samples: {test_generator.n}")

print(f"Training samples: {train_generator.n}")
print(f"Validation samples: {validation_generator.n}")
print(f"Class indices: {train_generator.class_indices}")

## Handle Class Imbalance

In [None]:
# Calculate class weights to handle imbalance
# minlength keeps one count per expected class even if the last class has no samples,
# so the bar chart below always receives NUM_CLASSES values.
class_counts = np.bincount(train_generator.classes, minlength=NUM_CLASSES)
print(f"Class distribution: {class_counts}")

# Compute class weights
present_classes = np.unique(train_generator.classes)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_generator.classes
)
# Key each weight by its actual class index (not by position) so the mapping
# stays correct when a class index is missing from the training labels.
class_weight_dict = {
    int(label): float(weight)
    for label, weight in zip(present_classes, class_weights)
}
print(f"Class weights: {class_weight_dict}")

# Visualize class distribution
# NOTE: the bar labels assume class index 0 is Normal and 1 is Glaucoma.
plt.figure(figsize=(8, 5))
plt.bar(['Class 0 (Normal)', 'Class 1 (Glaucoma)'], class_counts,
        color=['green', 'red'], alpha=0.7)
plt.title('Training Data Class Distribution')
plt.ylabel('Number of Samples')
plt.xlabel('Classes')
# Write the sample count above each bar
for i, count in enumerate(class_counts):
    plt.text(i, count, str(count), ha='center', va='bottom')
plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_PATH, 'class_distribution.png'), dpi=300, bbox_inches='tight')
plt.show()

## Focal Loss for Imbalanced Data

In [None]:
# Focal Loss implementation
class FocalLoss(tf.keras.losses.Loss):
    """Categorical focal loss for probability (post-softmax) predictions.

    Per sample: ``sum_c alpha * (1 - p_c) ** gamma * (-y_c * log(p_c))``.

    Args:
        gamma: Focusing exponent; larger values down-weight well-classified
            samples more strongly.
        alpha: Scalar factor applied uniformly to every class (it scales the
            loss; it is not a per-class balancing weight).
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Targets may arrive as integers or a different float width; align dtypes
        y_true = tf.cast(y_true, y_pred.dtype)

        # Clip so log() never sees 0 and (1 - p) never becomes negative
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        # y_true enters exactly once (inside the cross-entropy). For one-hot
        # targets this equals multiplying it in twice, but it stays correct
        # for soft / smoothed targets.
        cross_entropy = -y_true * tf.math.log(y_pred)
        modulating_factor = self.alpha * tf.math.pow(1.0 - y_pred, self.gamma)

        # Sum over the class axis -> one loss value per sample
        return tf.reduce_sum(modulating_factor * cross_entropy, axis=-1)

    def get_config(self):
        """Return the constructor arguments so the loss can be re-created."""
        config = super().get_config()
        config.update({"gamma": self.gamma, "alpha": self.alpha})
        return config

## Compile Model with Advanced Optimizer

In [None]:
# Use Focal Loss or standard categorical crossentropy
USE_FOCAL_LOSS = True

if USE_FOCAL_LOSS:
    loss = FocalLoss(gamma=2.0, alpha=0.25)
    print("Using Focal Loss")
else:
    loss = 'categorical_crossentropy'
    print("Using Categorical Crossentropy")

# Use AdamW optimizer with weight decay
optimizer = tfa.optimizers.AdamW(
    learning_rate=LEARNING_RATE,
    weight_decay=1e-4
)

# Compile model
# Precision/Recall are pinned to class index 1 (the positive class). Without
# class_id they are pooled over both softmax columns, where every mistake counts
# as one false positive AND one false negative, so both collapse to accuracy.
# NOTE(review): assumes index 1 is the Glaucoma class - confirm against
# train_generator.class_indices.
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision', class_id=1),
        tf.keras.metrics.Recall(name='recall', class_id=1),
        tf.keras.metrics.AUC(name='auc'),
        tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='macro', name='f1_score')
    ]
)

print("Model compiled successfully!")

## Setup Callbacks

In [None]:
# Callbacks
tensorboard_log_dir = os.path.join(CHECKPOINT_PATH, 'logs')

# Stop when validation loss stalls and roll back to the best weights
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    verbose=1,
    restore_best_weights=True,
)

# Keep only the checkpoint with the highest validation AUC
best_checkpoint = ModelCheckpoint(
    filepath=os.path.join(CHECKPOINT_PATH, 'best_model.keras'),
    monitor='val_auc',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Halve the learning rate after 5 epochs without val_loss improvement
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

# TensorBoard event files (graph + per-epoch histograms)
tensorboard_logger = TensorBoard(
    log_dir=tensorboard_log_dir,
    histogram_freq=1,
    write_graph=True,
)

# Per-epoch metrics as CSV, overwritten on each run
csv_logger = CSVLogger(
    os.path.join(CHECKPOINT_PATH, 'training_log.csv'),
    append=False,
)

callbacks = [
    early_stopping,
    best_checkpoint,
    lr_scheduler,
    tensorboard_logger,
    csv_logger,
]

print("Callbacks configured successfully!")
print("\nTo monitor training in TensorBoard, run:")
print(f'tensorboard --logdir="{tensorboard_log_dir}"')

## Train the Model

In [None]:
# Train the model
print("\nStarting training...")
for setting_name, setting_value in (
    ("Epochs", EPOCHS),
    ("Batch size", BATCH_SIZE),
    ("Learning rate", LEARNING_RATE),
):
    print(f"{setting_name}: {setting_value}")
print("=" * 60)

# Class weights are only applied with plain cross-entropy; when focal loss is
# active no class weighting is passed to fit()
fit_class_weight = None if USE_FOCAL_LOSS else class_weight_dict

history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=fit_class_weight,
    verbose=1,
)

print("\nTraining completed!")

## Training Visualization

In [None]:
def plot_training_history(history, save_path):
    """Plot training (and, when present, validation) curves in a 2x3 grid.

    Args:
        history: Object whose ``.history`` dict maps metric names to
            per-epoch value lists (as returned by ``model.fit``).
        save_path: Directory in which ``training_history.png`` is written.

    Metrics absent from ``history.history`` get a hidden panel, and a missing
    ``val_<metric>`` series is skipped instead of raising ``KeyError``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Training History', fontsize=16, fontweight='bold')

    metrics = [
        ('loss', 'Loss'),
        ('accuracy', 'Accuracy'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('auc', 'AUC'),
        ('f1_score', 'F1 Score')
    ]

    recorded = history.history
    for ax, (metric, title) in zip(axes.flat, metrics):
        if metric not in recorded:
            # Nothing to draw: hide the panel rather than leave empty axes
            ax.axis('off')
            continue

        ax.plot(recorded[metric], label=f'Train {title}', linewidth=2)

        # Validation series exist only when fit() was given validation data
        val_key = f'val_{metric}'
        if val_key in recorded:
            ax.plot(recorded[val_key], label=f'Val {title}', linewidth=2)

        ax.set_xlabel('Epoch', fontsize=10)
        ax.set_ylabel(title, fontsize=10)
        ax.set_title(f'{title} Over Epochs', fontweight='bold')
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.show()

plot_training_history(history, CHECKPOINT_PATH)

## Model Evaluation

In [None]:
# Score the validation split with the trained model
print("Generating predictions on validation set...")

# Rewind the iterator so batches start from the first sample
validation_generator.reset()
y_pred_probs = model.predict(validation_generator, verbose=1)

# Ground-truth indices from the generator, and the arg-max predicted class
y_true = validation_generator.classes
y_pred = np.argmax(y_pred_probs, axis=1)

print("Predictions completed!")

In [None]:
# Confusion Matrix
def plot_confusion_matrix(y_true, y_pred, labels, save_path):
    """Draw and save two confusion-matrix heatmaps.

    First the row-normalised matrix in percent ('confusion_matrix.png'),
    then the raw counts ('confusion_matrix_counts.png'); both are written
    into ``save_path`` and shown.
    """
    def render(matrix, number_format, colormap, colorbar_label, title, filename):
        # One annotated heatmap: draw, save, show
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=number_format, cmap=colormap,
                    xticklabels=labels, yticklabels=labels,
                    cbar_kws={'label': colorbar_label})
        plt.title(title, fontsize=14, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename), dpi=300, bbox_inches='tight')
        plt.show()

    # Each row sums to 100 %
    row_percentages = confusion_matrix(y_true, y_pred, normalize='true') * 100
    render(row_percentages, '.2f', 'Blues', 'Percentage',
           'Normalized Confusion Matrix (%)', 'confusion_matrix.png')

    # Absolute number of samples per cell
    raw_counts = confusion_matrix(y_true, y_pred)
    render(raw_counts, 'd', 'Greens', 'Count',
           'Confusion Matrix (Counts)', 'confusion_matrix_counts.png')

class_labels = ['Normal', 'Glaucoma']
plot_confusion_matrix(y_true, y_pred, class_labels, CHECKPOINT_PATH)

In [None]:
# Classification Report
report_rule = "=" * 60
print("\n" + report_rule)
print("CLASSIFICATION REPORT")
print(report_rule)

# Build the report once and reuse it for the console and the file
report_text = classification_report(y_true, y_pred, target_names=class_labels, digits=4)
print(report_text)

# Persist the same text next to the other run artifacts
report_file_path = os.path.join(CHECKPOINT_PATH, 'classification_report.txt')
with open(report_file_path, 'w') as f:
    f.write(report_text)

## ROC-AUC Curve

In [None]:
def plot_roc_curve(y_true, y_pred_probs, class_labels, save_path):
    """Plot the ROC curve of the class-1 scores and return the best threshold.

    The returned threshold maximises Youden's J (TPR - FPR). The figure is
    saved as 'roc_curve.png' in ``save_path``. ``class_labels`` is accepted
    but not used by the function body.
    """
    plt.figure(figsize=(10, 8))

    # Column 1 holds the probability of the positive class
    positive_scores = y_pred_probs[:, 1]
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true, positive_scores)
    area_under_curve = auc(false_pos_rate, true_pos_rate)

    plt.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
             label=f'ROC curve (AUC = {area_under_curve:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')

    # Operating point with the largest TPR - FPR gap (Youden's J statistic)
    best_idx = np.argmax(true_pos_rate - false_pos_rate)
    best_thresh = cutoffs[best_idx]
    best_fpr = false_pos_rate[best_idx]
    best_tpr = true_pos_rate[best_idx]
    plt.scatter(best_fpr, best_tpr, marker='o', color='red', s=200,
                label=f'Optimal Threshold = {best_thresh:.3f}')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
    plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'roc_curve.png'), dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nOptimal Threshold: {best_thresh:.4f}")
    print(f"Sensitivity at optimal threshold: {best_tpr:.4f}")
    print(f"Specificity at optimal threshold: {1-best_fpr:.4f}")

    return best_thresh

optimal_threshold = plot_roc_curve(y_true, y_pred_probs, class_labels, CHECKPOINT_PATH)

## Precision-Recall Curve

In [None]:
def plot_precision_recall_curve(y_true, y_pred_probs, save_path):
    """Plot the precision-recall curve of the class-1 scores.

    Draws the curve with its average precision in the legend, overlays
    iso-F1 guide lines for F1 = 0.2 ... 0.9 and saves the figure as
    'precision_recall_curve.png' in ``save_path``.
    """
    positive_scores = y_pred_probs[:, 1]
    precision, recall, _ = precision_recall_curve(y_true, positive_scores)
    avg_precision = average_precision_score(y_true, positive_scores)

    plt.figure(figsize=(10, 8))
    plt.plot(recall, precision, color='blue', lw=2,
             label=f'Precision-Recall curve (AP = {avg_precision:.4f})')

    # Iso-F1 guide lines: for a fixed F1, precision = F1 * r / (2r - F1)
    recall_grid = np.linspace(0.01, 1)
    for f_score in np.linspace(0.2, 0.9, num=8):
        iso_precision = f_score * recall_grid / (2 * recall_grid - f_score)
        drawable = iso_precision >= 0
        plt.plot(recall_grid[drawable], iso_precision[drawable],
                 color='gray', alpha=0.3, linestyle='--')
        # Label each guide line near the right-hand edge of the plot
        plt.annotate(f'F1={f_score:.1f}', xy=(0.9, iso_precision[45] + 0.02),
                     alpha=0.4, fontsize=8)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall (Sensitivity)', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower left", fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'precision_recall_curve.png'), dpi=300, bbox_inches='tight')
    plt.show()

plot_precision_recall_curve(y_true, y_pred_probs, CHECKPOINT_PATH)

## Grad-CAM Visualization for Interpretability

In [None]:
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Generate a Grad-CAM heatmap for the first image in ``img_array``.

    Args:
        img_array: Batched, already preprocessed model input.
        model: Keras model to explain.
        last_conv_layer_name: Name of the layer whose activations are weighted.
        pred_index: Class index to explain; defaults to the arg-max class of
            the first sample.

    Returns:
        NumPy array with the ReLU-ed, max-normalised class activation map of
        the first sample (all zeros when no activation is positive).
    """
    # model.inputs is already a list; wrapping it again would nest the inputs
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # Gradient of the class score w.r.t. the feature map, averaged per channel
    grads = tape.gradient(class_channel, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Channel-weighted sum of the first sample's feature map
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Keep positive evidence only, then scale to [0, 1]. divide_no_nan returns
    # 0 instead of NaN when the map has no positive activation (max == 0).
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()


def display_gradcam(img_path, model, last_conv_layer_name, alpha=0.4):
    """Show an image, its Grad-CAM heatmap and the overlay side by side.

    Args:
        img_path: Path of the image file to explain.
        model: Keras model used for both the prediction and the heatmap.
        last_conv_layer_name: Layer name forwarded to ``make_gradcam_heatmap``.
        alpha: Blend weight of the heatmap in the overlay (image gets 1 - alpha).

    Returns:
        ``(pred_class, pred_prob)``: arg-max class index and its probability.
    """
    # Same preprocessing as the data generators: resize, then scale to [0, 1]
    img = tf.keras.preprocessing.image.load_img(img_path, target_size=IMG_SIZE)
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0

    preds = model.predict(img_array, verbose=0)
    pred_class = np.argmax(preds[0])
    pred_prob = preds[0][pred_class]

    heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)
    # Guard against NaNs so the uint8 conversion below stays well defined
    heatmap = np.nan_to_num(heatmap)

    # Upsample to the image size (PIL's size is (width, height), as cv2 expects)
    heatmap = cv2.resize(heatmap, (img.size[0], img.size[1]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap emits BGR, while the image and matplotlib are RGB; convert
    # so that strong activations render red rather than blue.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    img_array_uint8 = np.uint8(255 * img_array[0])
    superimposed_img = cv2.addWeighted(img_array_uint8, 1-alpha, heatmap, alpha, 0)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img)
    axes[0].set_title('Original Image', fontweight='bold')
    axes[0].axis('off')

    axes[1].imshow(heatmap)
    axes[1].set_title('Grad-CAM Heatmap', fontweight='bold')
    axes[1].axis('off')

    axes[2].imshow(superimposed_img)
    axes[2].set_title(f'Superimposed\nPrediction: {class_labels[pred_class]} ({pred_prob:.2%})',
                     fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return pred_class, pred_prob


# Locate the Conv2D layer closest to the model output (None when there is none)
conv_layer_names = (
    candidate.name
    for candidate in reversed(model.layers)
    if isinstance(candidate, layers.Conv2D)
)
last_conv_layer = next(conv_layer_names, None)

print(f"Last convolutional layer: {last_conv_layer}")
print("\nReady to visualize Grad-CAM!")
print("Use: display_gradcam('path/to/image.png', model, last_conv_layer)")

In [None]:
# Example: Visualize Grad-CAM for sample images
print("Generating Grad-CAM visualizations...")


def _first_validation_pngs(class_folder):
    # Up to five PNG files from one class sub-folder of the validation set
    return glob.glob(os.path.join(VALIDATION_PATH, class_folder, '*.png'))[:5]


normal_images = _first_validation_pngs('0')
glaucoma_images = _first_validation_pngs('1')

section_rule = "=" * 60
showcases = (
    (normal_images,
     ("\nGrad-CAM Visualization for Normal Cases:", section_rule)),
    (glaucoma_images,
     ("\n" + section_rule, "Grad-CAM Visualization for Glaucoma Cases:", section_rule)),
)

for image_paths, header_lines in showcases:
    # Skip a class entirely when no sample image was found
    if not image_paths:
        continue
    for header_line in header_lines:
        print(header_line)
    for img_path in image_paths[:2]:  # Show 2 examples
        print(f"\nImage: {os.path.basename(img_path)}")
        display_gradcam(img_path, model, last_conv_layer)

## Save Final Model

In [None]:
# Save in multiple formats
# Each export below writes into CHECKPOINT_PATH; the statements run in order and
# a failure in one aborts the remaining exports.
print("\nSaving models...")

# 1. Native Keras format (recommended)
model.save(os.path.join(CHECKPOINT_PATH, 'final_model.keras'))
print("✓ Model saved in Keras format (.keras)")

# 2. H5 format for compatibility
model.save(os.path.join(CHECKPOINT_PATH, 'final_model.h5'))
print("✓ Model saved in H5 format (.h5)")

# 3. TensorFlow Lite for mobile deployment
# Optimize.DEFAULT enables the converter's default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# The converted model is written out in binary mode
with open(os.path.join(CHECKPOINT_PATH, 'final_model.tflite'), 'wb') as f:
    f.write(tflite_model)
print("✓ Model saved in TFLite format (.tflite)")

# 4. SavedModel format for TensorFlow Serving
# NOTE(review): save_format='tf' is presumably only accepted by the TF2-bundled
# Keras - confirm against the installed TensorFlow/Keras version.
model.save(os.path.join(CHECKPOINT_PATH, 'saved_model'), save_format='tf')
print("✓ Model saved in SavedModel format")

print(f"\n✓ All models saved to: {CHECKPOINT_PATH}")

## Summary Statistics

In [None]:
# Print final summary
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

print("\n" + "="*80)
print("TRAINING SUMMARY")
print("="*80)

print(f"\nModel Architecture: Custom Medical Imaging CNN")
print(f"Total Parameters: {model.count_params():,}")
print(f"Training Samples: {train_generator.n}")
print(f"Validation Samples: {validation_generator.n}")
print(f"Epochs Trained: {len(history.history['loss'])}")

print("\n" + "-"*80)
print("BEST VALIDATION METRICS")
print("-"*80)

# Best (maximum) value of each validation metric and the 1-based epoch it occurred in
metrics_to_show = ['val_accuracy', 'val_precision', 'val_recall', 'val_auc', 'val_f1_score']
for metric in metrics_to_show:
    if metric in history.history:
        best_value = max(history.history[metric])
        best_epoch = np.argmax(history.history[metric]) + 1
        print(f"{metric.replace('val_', '').upper():.<30} {best_value:.4f} (Epoch {best_epoch})")

print("\n" + "-"*80)
print("VALIDATION SET PERFORMANCE")
print("-"*80)

# Metrics of the final model on the validation predictions; class 1 is the positive class
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='binary')
recall = recall_score(y_true, y_pred, average='binary')
f1 = f1_score(y_true, y_pred, average='binary')
roc_auc = roc_auc_score(y_true, y_pred_probs[:, 1])

print(f"{'Accuracy':.<30} {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"{'Precision':.<30} {precision:.4f}")
print(f"{'Recall (Sensitivity)':.<30} {recall:.4f}")
print(f"{'F1 Score':.<30} {f1:.4f}")
print(f"{'ROC-AUC':.<30} {roc_auc:.4f}")

# Calculate specificity
# Pin the label set so the matrix is always 2x2 (and unpackable) even when a
# class is missing from both y_true and y_pred.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
# Without any true negatives or false positives specificity is undefined; report 0.0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
print(f"{'Specificity':.<30} {specificity:.4f}")

print("\n" + "="*80)
print(f"Models saved to: {CHECKPOINT_PATH}")
print("="*80)

# Save summary to text file
with open(os.path.join(CHECKPOINT_PATH, 'training_summary.txt'), 'w') as f:
    f.write("TRAINING SUMMARY\n")
    f.write("="*80 + "\n\n")
    f.write(f"Model: Custom Medical Imaging CNN\n")
    f.write(f"Parameters: {model.count_params():,}\n")
    f.write(f"Training Samples: {train_generator.n}\n")
    f.write(f"Validation Samples: {validation_generator.n}\n")
    f.write(f"Epochs: {len(history.history['loss'])}\n\n")
    f.write(f"Accuracy: {accuracy:.4f}\n")
    f.write(f"Precision: {precision:.4f}\n")
    f.write(f"Recall: {recall:.4f}\n")
    f.write(f"F1 Score: {f1:.4f}\n")
    f.write(f"ROC-AUC: {roc_auc:.4f}\n")
    f.write(f"Specificity: {specificity:.4f}\n")

## Single Image Prediction Function

In [None]:
def predict_single_image(image_path, model, show_gradcam=True):
    """
    Classify one image file, print a report and visualise the result.

    The image is resized to IMG_SIZE and scaled to [0, 1] before prediction.
    With ``show_gradcam`` the Grad-CAM figure is displayed, otherwise the
    plain image titled with the prediction. Returns ``(pred_class, pred_prob)``.
    """
    loaded_image = tf.keras.preprocessing.image.load_img(image_path, target_size=IMG_SIZE)
    pixel_values = tf.keras.preprocessing.image.img_to_array(loaded_image)
    model_input = np.expand_dims(pixel_values, axis=0) / 255.0

    # The batch holds a single image, so keep only its probability row
    class_probabilities = model.predict(model_input, verbose=0)[0]
    pred_class = np.argmax(class_probabilities)
    pred_prob = class_probabilities[pred_class]

    rule = "=" * 60
    print("\n" + rule)
    print("PREDICTION RESULT")
    print(rule)
    print(f"Image: {os.path.basename(image_path)}")
    print(f"Predicted Class: {class_labels[pred_class]}")
    print(f"Confidence: {pred_prob:.2%}")
    print("\nClass Probabilities:")
    for class_idx, class_name in enumerate(class_labels):
        print(f"  {class_name}: {class_probabilities[class_idx]:.2%}")
    print(rule)

    if not show_gradcam:
        # Plain view: the image with its predicted label as the title
        plt.figure(figsize=(6, 6))
        plt.imshow(loaded_image)
        plt.title(f'{class_labels[pred_class]} ({pred_prob:.2%})', fontweight='bold')
        plt.axis('off')
        plt.show()
    else:
        display_gradcam(image_path, model, last_conv_layer)

    return pred_class, pred_prob


print("\nReady to predict!")
print("\nExample usage:")
print("predict_single_image(r'c:\\path\\to\\your\\image.png', model, show_gradcam=True)")

## Open Results Folder

In [None]:
# Open the checkpoint folder in the platform's file browser
import subprocess
import sys

print(f"\nOpening results folder: {CHECKPOINT_PATH}")

# Choose the opener for the current OS; 'explorer' only exists on Windows
if sys.platform.startswith('win'):
    folder_opener = 'explorer'
elif sys.platform == 'darwin':
    folder_opener = 'open'
else:
    folder_opener = 'xdg-open'

# Pass an argument list (no shell string) so paths with spaces or quotes are safe.
# Opening the folder is a convenience only, so a missing opener must not abort the run.
try:
    subprocess.Popen([folder_opener, CHECKPOINT_PATH])
except OSError as exc:
    print(f"Could not open results folder automatically: {exc}")

print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)
print(f"\nAll results saved to: {CHECKPOINT_PATH}")
print("\nGenerated files:")
print("  - best_model.keras (best model during training)")
print("  - final_model.keras (final model)")
print("  - final_model.h5 (H5 format)")
print("  - final_model.tflite (mobile deployment)")
print("  - saved_model/ (TensorFlow Serving)")
print("  - training_log.csv (detailed metrics)")
print("  - class_distribution.png")
print("  - training_history.png")
print("  - confusion_matrix.png")
print("  - confusion_matrix_counts.png")
print("  - roc_curve.png")
print("  - precision_recall_curve.png")
print("  - classification_report.txt")
print("  - training_summary.txt")
print("="*80)