<a href="https://colab.research.google.com/github/hub-ARIYAN/glens_detection_tpu/blob/main/Glens27.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gravitational Lens Detection with MobileNetV3
# ==============================================
#
# ## Overview
# This notebook trains a MobileNetV3-based binary classifier to detect gravitational lens images.
# - **Class A (1)**: Lensed galaxies  
# - **Class B (0)**: Non-lensed galaxies
#
# ## Quick Start
# 1. Upload your CSV file with columns: `filename`, `label` (A/B)
# 2. Upload/mount your PNG images directory
# 3. Adjust hyperparameters in the "Configuration" section below
# 4. Run all cells
#
# ## Key Features
# - Two-phase training: frozen backbone → fine-tuning
# - Class imbalance handling (weighting/oversampling)
# - Comprehensive evaluation with visualizations
# - Grad-CAM interpretability
# - Ablation experiments
# - Seeded runs for reproducibility (GPU kernels may still be nondeterministic)
#
# ## Hardware Requirements
# - GPU recommended (T4, V100, A100)
# - ~8GB GPU memory for batch_size=32


# CONFIGURATION & SETUP


In [None]:
# Core hyperparameters - adjust these as needed
DATA_DIR = "/content/images"  # Directory containing PNG files
CSV_PATH = "/content/labels.csv"  # CSV with filename,label columns
BATCH_SIZE = 32  # Reduce to 16 if OOM
IMG_SIZE = (224, 224)  # MobileNetV3 input size
SPLIT = (0.8, 0.1, 0.1)  # Train/Val/Test split
ACTIVATION = 'swish'  # Options: 'relu', 'leaky_relu', 'swish'
FREEZE_EPOCHS = 10  # Phase 1: train head only
TOTAL_EPOCHS = 50  # Total training epochs
PATIENCE = 10  # Early stopping patience
SEED = 42  # Reproducibility

# Advanced options
MOBILENET_VERSION = 'Large'  # 'Large' or 'Small'
DROPOUT_RATE = 0.5
HIDDEN_UNITS = 128
USE_BATCH_NORM = True
OVERSAMPLING_THRESHOLD = 10  # Use oversampling if the majority:minority ratio exceeds 10:1

print("🚀 Gravitational Lens Classifier Setup Complete!")
print(f"Configuration: {MOBILENET_VERSION} MobileNetV3, {ACTIVATION} activation")
print(f"Training strategy: {FREEZE_EPOCHS} frozen + {TOTAL_EPOCHS-FREEZE_EPOCHS} fine-tune epochs")



# IMPORTS & ENVIRONMENT SETUP



In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.metrics import precision_recall_curve, average_precision_score, f1_score
import warnings
warnings.filterwarnings('ignore')

# TensorFlow imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import MobileNetV3Large, MobileNetV3Small
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam  # AdamW was unused and is missing in older TF releases
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, AUC

# Check GPU availability
print("🔧 Environment Setup:")
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")
if len(tf.config.list_physical_devices('GPU')) > 0:
    print(f"GPU device: {tf.config.list_physical_devices('GPU')[0]}")

# Set random seeds for reproducibility (Python, TensorFlow, NumPy)
import random
random.seed(SEED)
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Configure memory growth for GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")


# DATA LOADING & PREPROCESSING


In [None]:
def load_dataset(csv_path, data_dir, split_ratios=(0.8, 0.1, 0.1), seed=42):
    """
    Load dataset from CSV and create train/val/test splits.

    Args:
        csv_path: Path to CSV with 'filename' and 'label' columns
        data_dir: Directory containing PNG files
        split_ratios: (train, val, test) split ratios
        seed: Random seed for reproducibility

    Returns:
        train_df, val_df, test_df, class_weights, imbalance_ratio
    """
    print("📊 Loading dataset...")

    # Read CSV and validate
    df = pd.read_csv(csv_path)
    print(f"Loaded CSV with {len(df)} rows")
    print(f"Columns: {list(df.columns)}")

    # Ensure required columns exist
    if 'filename' not in df.columns or 'label' not in df.columns:
        raise ValueError("CSV must contain 'filename' and 'label' columns")

    # Map labels A/B to 1/0
    label_map = {'A': 1, 'B': 0}  # A = lensed, B = non-lensed
    df['binary_label'] = df['label'].map(label_map)

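    # Check for missing mappings
    if df['binary_label'].isna().any():
        print("⚠️ Warning: Some labels couldn't be mapped. Unique labels:", df['label'].unique())
        df = df.dropna(subset=['binary_label'])

    # Unmapped labels turn the column into floats; store clean integer labels
    df['binary_label'] = df['binary_label'].astype(int)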

    # Verify files exist
    file_exists = df['filename'].apply(
        lambda fname: os.path.exists(os.path.join(data_dir, fname))
    ).astype(bool)
    missing_count = int((~file_exists).sum())
    df = df[file_exists].copy()

    if missing_count > 0:
        print(f"⚠️ Warning: {missing_count} files not found, using {len(df)} available files")

    # Class distribution analysis
    class_counts = df['binary_label'].value_counts().sort_index()
    print(f"\n📈 Class Distribution:")
    print(f"Class 0 (Non-lensed): {class_counts.get(0, 0):,} samples")
    print(f"Class 1 (Lensed): {class_counts.get(1, 0):,} samples")

    imbalance_ratio = class_counts.max() / class_counts.min() if class_counts.min() > 0 else float('inf')
    print(f"Imbalance ratio: {imbalance_ratio:.2f}:1")

    # Calculate class weights
    n_samples = len(df)
    n_classes = 2
    class_weights = {}
    for class_id in [0, 1]:
        class_weights[class_id] = n_samples / (n_classes * class_counts.get(class_id, 1))

    print(f"Class weights: {class_weights}")

    # Create stratified splits: shuffle once, then slice each class separately
    # so that every split keeps the overall class distribution
    df_shuffled = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    train_frac, val_frac = split_ratios[0], split_ratios[1]

    train_parts, val_parts, test_parts = [], [], []
    for class_id in [0, 1]:
        class_df = df_shuffled[df_shuffled['binary_label'] == class_id]
        train_end = int(len(class_df) * train_frac)
        val_end = int(len(class_df) * (train_frac + val_frac))
        train_parts.append(class_df[:train_end])
        val_parts.append(class_df[train_end:val_end])
        test_parts.append(class_df[val_end:])

    train_df = pd.concat(train_parts).sample(frac=1, random_state=seed).reset_index(drop=True)
    val_df = pd.concat(val_parts).sample(frac=1, random_state=seed).reset_index(drop=True)
    test_df = pd.concat(test_parts).sample(frac=1, random_state=seed).reset_index(drop=True)

    print(f"\n📊 Split sizes:")
    print(f"Train: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
    print(f"Val: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
    print(f"Test: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")

    return train_df, val_df, test_df, class_weights, imbalance_ratio
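
The class-weight formula in `load_dataset` (total samples divided by the number of classes times the class count) can be checked by hand. The following is a minimal sketch with made-up counts; the helper name `balanced_class_weights` and the 900/100 split are illustrative assumptions, not part of the notebook.

```python
def balanced_class_weights(class_counts):
    """Weight each class by n_samples / (n_classes * count), as load_dataset does."""
    n_samples = sum(class_counts.values())
    n_classes = len(class_counts)
    return {cls: n_samples / (n_classes * count) for cls, count in class_counts.items()}

# Hypothetical dataset: 900 non-lensed (class 0) and 100 lensed (class 1) images
counts = {0: 900, 1: 100}
weights = balanced_class_weights(counts)

# Total weight contributed by each class (count * weight)
total_weight = {cls: counts[cls] * weights[cls] for cls in counts}
print(weights)
print(total_weight)
```

With these weights each class contributes the same total weight to the loss, which is the point of the formula: the rare lensed class is not drowned out by the common one.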

def create_data_pipeline(df, data_dir, batch_size, img_size, augment=False, cache=True):
    """
    Create optimized tf.data pipeline for image loading and preprocessing.

    Args:
        df: DataFrame with 'filename' and 'binary_label' columns
        data_dir: Directory containing images
        batch_size: Batch size for training
        img_size: Target image size (height, width)
        augment: Whether to apply data augmentation
        cache: Whether to cache the dataset

    Returns:
        tf.data.Dataset
    """
    # Create file paths and labels
    filepaths = [os.path.join(data_dir, fname) for fname in df['filename']]
    labels = df['binary_label'].values.astype(np.float32)

    # Create dataset from file paths and labels
    dataset = tf.data.Dataset.from_tensor_slices((filepaths, labels))

    # Image loading and preprocessing function
    def load_and_preprocess_image(filepath, label):
        # Load image
        image = tf.io.read_file(filepath)
        image = tf.image.decode_png(image, channels=3)

        # Resize and normalize
        image = tf.image.resize(image, img_size)
        image = tf.cast(image, tf.float32) / 255.0

        return image, label

    # Data augmentation function
    def augment_image(image, label):
        # Random horizontal flip
        image = tf.image.random_flip_left_right(image)

        # Random vertical flip
        image = tf.image.random_flip_up_down(image)

        # Random rotation by a multiple of 90 degrees
        image = tf.image.rot90(image, k=tf.random.uniform([], 0, 4, dtype=tf.int32))

        # Random brightness and contrast
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)

        # Random zoom/shift simulation via random crop and resize
        if tf.random.uniform([]) < 0.5:
            crop_size = tf.random.uniform([], 0.85, 1.0)
            h, w = img_size[0], img_size[1]
            new_h = tf.cast(h * crop_size, tf.int32)
            new_w = tf.cast(w * crop_size, tf.int32)
            image = tf.image.random_crop(image, [new_h, new_w, 3])
            image = tf.image.resize(image, img_size)

        # The dynamic crop erases static shape information; restore it
        image = tf.ensure_shape(image, [img_size[0], img_size[1], 3])

        # Add Gaussian noise
        noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=0.02)
        image = tf.clip_by_value(image + noise, 0.0, 1.0)

        return image, label

    # Apply loading and preprocessing
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

    # Cache the decoded images before shuffling and augmentation; caching
    # afterwards would replay the same order and the same augmented images
    # in every epoch
    if cache:
        dataset = dataset.cache()

    # Shuffle and augment training datasets
    if augment:  # Assuming augment=True means training data
        dataset = dataset.shuffle(buffer_size=min(1000, len(df)))
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset


# MODEL ARCHITECTURE


In [None]:
def get_activation_layer(activation_name):
    """Get activation layer by name."""
    if activation_name.lower() == 'relu':
        return layers.ReLU()
    elif activation_name.lower() == 'leaky_relu':
        return layers.LeakyReLU(alpha=0.1)
    elif activation_name.lower() in ['swish', 'hard_swish']:
        return layers.Activation('swish')
    else:
        raise ValueError(f"Unsupported activation: {activation_name}")

def create_mobilenet_classifier(img_size, mobilenet_version='Large', activation='swish',
                               dropout_rate=0.5, hidden_units=128, use_batch_norm=True):
    """
    Create MobileNetV3-based binary classifier.

    Args:
        img_size: Input image size (height, width); 3 channels are added internally
        mobilenet_version: 'Large' or 'Small'
        activation: Activation function name
        dropout_rate: Dropout rate in classifier head
        hidden_units: Hidden layer size
        use_batch_norm: Whether to use batch normalization

    Returns:
        (model, backbone): the compiled Keras model and its MobileNetV3 backbone
    """
    # Input layer
    inputs = keras.Input(shape=(*img_size, 3))

    # MobileNetV3 backbone
    if mobilenet_version.lower() == 'large':
        backbone = MobileNetV3Large(
            input_shape=(*img_size, 3),
            include_top=False,
            weights='imagenet'
        )
    else:
        backbone = MobileNetV3Small(
            input_shape=(*img_size, 3),
            include_top=False,
            weights='imagenet'
        )

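    # The data pipeline yields pixels in [0, 1], but Keras' MobileNetV3 rescales
    # internally (include_preprocessing=True by default) and expects [0, 255].
    # Bundle the inverse scaling with the backbone so the two stay together.
    backbone = keras.Sequential(
        [layers.Rescaling(255.0), backbone], name='mobilenetv3_backbone'
    )

    # Freeze backbone initially
    backbone.trainable = False

    # Feature extraction
    x = backbone(inputs, training=False)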

    # Classifier head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)

    # Hidden layer
    x = layers.Dense(hidden_units)(x)
    if use_batch_norm:
        x = layers.BatchNormalization()(x)
    x = get_activation_layer(activation)(x)

    # Output layer
    outputs = layers.Dense(1, activation='sigmoid', name='predictions')(x)

    model = Model(inputs, outputs)

    # Compile model
    model.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss=BinaryCrossentropy(),
        metrics=[
            BinaryAccuracy(name='accuracy'),
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc')
        ]
    )

    print(f"🏗️ Created {mobilenet_version} MobileNetV3 with {activation} activation")
    print(f"Backbone parameters: {backbone.count_params():,}")
    print(f"Total parameters: {model.count_params():,}")

    return model, backbone
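
The size of the classifier head can be worked out by hand, which is a useful sanity check against the "Total parameters" line printed above. This is a rough sketch: the helper `head_param_count` is hypothetical, and the 960-channel (Large) and 576-channel (Small) backbone outputs are assumptions about the default-width Keras models rather than values stated in this notebook.

```python
def head_param_count(feature_channels, hidden_units=128, use_batch_norm=True):
    """Parameters in: GlobalAveragePooling -> Dropout -> Dense -> [BatchNorm] -> Dense(1)."""
    # Pooling and dropout have no parameters
    dense = feature_channels * hidden_units + hidden_units  # weights + biases
    # BatchNorm keeps gamma, beta, moving mean and moving variance per unit
    batch_norm = 4 * hidden_units if use_batch_norm else 0
    output = hidden_units + 1  # single sigmoid unit: weights + bias
    return dense + batch_norm + output

large_head = head_param_count(960)  # assumed channel count for MobileNetV3Large
small_head = head_param_count(576, use_batch_norm=False)  # assumed for MobileNetV3Small
print(large_head, small_head)
```

Only this head is updated during the frozen phase; half of the batch-normalization parameters (the moving statistics) are never trained by gradient descent in either phase.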



# TRAINING UTILITIES


In [None]:
def create_callbacks(patience=10, monitor='val_auc', mode='max'):
    """Create training callbacks."""
    callbacks = [
        EarlyStopping(
            monitor=monitor,
            patience=patience,
            restore_best_weights=True,
            verbose=1,
            mode=mode
        ),
        ReduceLROnPlateau(
            monitor=monitor,
            factor=0.5,
            patience=patience//2,
            min_lr=1e-7,
            verbose=1,
            mode=mode
        ),
        ModelCheckpoint(
            'best_model.h5',
            monitor=monitor,
            save_best_only=True,
            mode=mode,
            verbose=1
        )
    ]
    return callbacks

def handle_class_imbalance(train_df, imbalance_ratio, threshold=10):
    """
    Handle class imbalance using class weights or oversampling.

    Args:
        train_df: Training dataframe
        imbalance_ratio: Ratio of majority to minority class
        threshold: Use oversampling if imbalance > threshold

    Returns:
        Modified dataframe and class weights
    """
    class_counts = train_df['binary_label'].value_counts().sort_index()

    # Calculate class weights
    n_samples = len(train_df)
    class_weights = {
        0: n_samples / (2 * class_counts.get(0, 1)),
        1: n_samples / (2 * class_counts.get(1, 1))
    }

    if imbalance_ratio > threshold:
        print(f"⚖️ High imbalance ({imbalance_ratio:.1f}:1), applying oversampling...")

        # Oversample minority class
        minority_class = class_counts.idxmin()
        majority_class = class_counts.idxmax()

        minority_df = train_df[train_df['binary_label'] == minority_class]
        majority_df = train_df[train_df['binary_label'] == majority_class]

        # Calculate oversampling factor
        target_size = len(majority_df)
        oversample_factor = target_size // len(minority_df)
        remainder = target_size % len(minority_df)

        # Create oversampled dataset
        oversampled_minority = pd.concat([minority_df] * oversample_factor +
                                       [minority_df.sample(remainder, random_state=SEED)])

        train_df = pd.concat([majority_df, oversampled_minority]).sample(
            frac=1, random_state=SEED).reset_index(drop=True)

        # Recalculate class weights
        class_weights = {0: 1.0, 1: 1.0}  # Balanced after oversampling

        print(f"After oversampling: {len(train_df):,} samples")
    else:
        print(f"⚖️ Using class weights (imbalance {imbalance_ratio:.1f}:1 < threshold {threshold})")

    return train_df, class_weights
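
The oversampling arithmetic in `handle_class_imbalance` (whole copies of the minority class plus a random remainder) is easier to follow on a toy example. The sketch below mirrors that logic with plain Python lists instead of DataFrames; the file names and the helper `oversample_minority` are made up for illustration.

```python
import random

def oversample_minority(majority, minority, seed=42):
    """Repeat the minority samples until both classes have the same size."""
    rng = random.Random(seed)
    # Whole copies of the minority class, plus a random partial copy
    factor, remainder = divmod(len(majority), len(minority))
    oversampled = minority * factor + rng.sample(minority, remainder)
    combined = majority + oversampled
    rng.shuffle(combined)
    return combined

# Hypothetical training split with a 12.5:1 imbalance (above the 10:1 threshold)
majority = [(f"non_lens_{i}.png", 0) for i in range(25)]
minority = [(f"lens_{i}.png", 1) for i in range(2)]

balanced = oversample_minority(majority, minority)
n_lensed = sum(label for _, label in balanced)
print(len(balanced), n_lensed)
```

Because the notebook oversamples only the training split, the duplicated lensed images never leak into the validation or test sets.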


# EVALUATION & VISUALIZATION

In [None]:
def plot_training_history(history):
    """Plot training curves."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    axes[0,0].plot(history.history['loss'], label='Train Loss', linewidth=2)
    axes[0,0].plot(history.history['val_loss'], label='Val Loss', linewidth=2)
    axes[0,0].set_title('Loss', fontsize=14, fontweight='bold')
    axes[0,0].set_xlabel('Epoch')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # Accuracy
    axes[0,1].plot(history.history['accuracy'], label='Train Acc', linewidth=2)
    axes[0,1].plot(history.history['val_accuracy'], label='Val Acc', linewidth=2)
    axes[0,1].set_title('Accuracy', fontsize=14, fontweight='bold')
    axes[0,1].set_xlabel('Epoch')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)

    # AUC
    axes[1,0].plot(history.history['auc'], label='Train AUC', linewidth=2)
    axes[1,0].plot(history.history['val_auc'], label='Val AUC', linewidth=2)
    axes[1,0].set_title('AUC', fontsize=14, fontweight='bold')
    axes[1,0].set_xlabel('Epoch')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

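    # Learning Rate (logged as 'lr' by older Keras versions, 'learning_rate' by newer ones)
    lr_key = next((k for k in ('lr', 'learning_rate') if k in history.history), None)
    if lr_key is not None:
        axes[1,1].plot(history.history[lr_key], linewidth=2, color='orange')
        axes[1,1].set_title('Learning Rate', fontsize=14, fontweight='bold')
        axes[1,1].set_xlabel('Epoch')
        axes[1,1].set_yscale('log')
        axes[1,1].grid(True, alpha=0.3)
    else:
        axes[1,1].axis('off')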

    plt.tight_layout()
    plt.show()

def evaluate_model(model, test_dataset, test_df):
    """
    Comprehensive model evaluation with metrics and visualizations.

    Args:
        model: Trained model
        test_dataset: Test tf.data.Dataset
        test_df: Test dataframe for additional analysis

    Returns:
        Dictionary of metrics
    """
    print("📊 Evaluating model...")

    # Get predictions (flattened to 1-D so they align with y_true)
    y_pred_proba = model.predict(test_dataset, verbose=0).flatten()
    y_pred = (y_pred_proba > 0.5).astype(int)
    y_true = test_df['binary_label'].values.astype(int)

    # Calculate metrics
    accuracy = (y_pred == y_true).mean()
    precision = ((y_pred == 1) & (y_true == 1)).sum() / max((y_pred == 1).sum(), 1)
    recall = ((y_pred == 1) & (y_true == 1)).sum() / max((y_true == 1).sum(), 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    auc_roc = roc_auc_score(y_true, y_pred_proba)
    auc_pr = average_precision_score(y_true, y_pred_proba)

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc,
        'auc_pr': auc_pr
    }

    # Print metrics
    print("🎯 Test Metrics:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"ROC-AUC: {auc_roc:.4f}")
    print(f"PR-AUC: {auc_pr:.4f}")

    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0,0])
    axes[0,0].set_title('Confusion Matrix', fontsize=14, fontweight='bold')
    axes[0,0].set_xlabel('Predicted')
    axes[0,0].set_ylabel('Actual')

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
    axes[0,1].plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc_roc:.3f})')
    axes[0,1].plot([0, 1], [0, 1], 'k--', alpha=0.5)
    axes[0,1].set_title('ROC Curve', fontsize=14, fontweight='bold')
    axes[0,1].set_xlabel('False Positive Rate')
    axes[0,1].set_ylabel('True Positive Rate')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)

    # Precision-Recall Curve
    precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_pred_proba)
    axes[1,0].plot(recall_curve, precision_curve, linewidth=2, label=f'PR (AUC = {auc_pr:.3f})')
    axes[1,0].set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    axes[1,0].set_xlabel('Recall')
    axes[1,0].set_ylabel('Precision')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

    # Prediction Distribution
    axes[1,1].hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, label='Non-lensed', color='red')
    axes[1,1].hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, label='Lensed', color='blue')
    axes[1,1].axvline(x=0.5, color='black', linestyle='--', alpha=0.8)
    axes[1,1].set_title('Prediction Distribution', fontsize=14, fontweight='bold')
    axes[1,1].set_xlabel('Predicted Probability')
    axes[1,1].set_ylabel('Count')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return metrics
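
The precision, recall, and F1 expressions in `evaluate_model` can be traced on a handful of labels. Here is a small self-contained sketch using the same formulas and the same zero-division guards; the labels and the helper `binary_metrics` are invented for the example.

```python
def binary_metrics(y_true, y_pred):
    """Precision, recall and F1 for the positive (lensed) class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / max(tp + fp, 1)  # guard against no positive predictions
    recall = tp / max(tp + fn, 1)     # guard against no positive labels
    f1 = 2 * precision * recall / max(precision + recall, 1e-8)
    return precision, recall, f1

# Hypothetical test set: 3 lensed, 5 non-lensed; one miss and one false alarm
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]

precision, recall, f1 = binary_metrics(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(precision, recall, f1, accuracy)
```

Accuracy here is 0.75 while precision and recall are both 2/3, which illustrates why accuracy alone flatters a classifier on an imbalanced lens dataset.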

def create_gradcam_visualization(model, dataset, num_samples=4):
    """
    Create Grad-CAM visualizations for model interpretability.

    Args:
        model: Trained model
        dataset: Dataset to sample from
        num_samples: Number of samples to visualize
    """
    print("🔍 Generating Grad-CAM visualizations...")

    # Get a batch of images
    for images, labels in dataset.take(1):
        break

    # Select samples
    indices = np.random.choice(len(images), min(num_samples, len(images)), replace=False)
    sample_images = tf.gather(images, indices)
    sample_labels = tf.gather(labels, indices)

    # Get predictions
    predictions = model.predict(sample_images, verbose=0)

    # Locate the nested backbone model; its output is the last convolutional
    # feature map. A layer inside a nested model cannot be fetched with
    # model.get_layer('<backbone>/<layer>'), so the forward pass is replayed
    # layer by layer instead.
    backbone_idx = next(
        (i for i, layer in enumerate(model.layers) if isinstance(layer, keras.Model)),
        None
    )
    if backbone_idx is None:
        print("⚠️ Could not find a backbone model for Grad-CAM")
        return

    def make_gradcam_heatmap(img_array):
        # Replay any layers that sit between the input and the backbone
        x = img_array
        for layer in model.layers[:backbone_idx]:
            if not isinstance(layer, layers.InputLayer):
                x = layer(x)

        with tf.GradientTape() as tape:
            # Feature map produced by the backbone, in inference mode
            conv_output = model.layers[backbone_idx](x, training=False)
            tape.watch(conv_output)

            # Replay the classifier head; this assumes the head is a simple chain
            preds = conv_output
            for layer in model.layers[backbone_idx + 1:]:
                preds = layer(preds)
            class_channel = preds[:, 0]  # single sigmoid output

        # Gradient of the output neuron with respect to the feature map
        grads = tape.gradient(class_channel, conv_output)

        # Mean gradient per channel: how important each channel is
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # Weight each channel by its importance and sum over channels
        heatmap = tf.squeeze(conv_output[0] @ pooled_grads[..., tf.newaxis])

        # Keep positive evidence only and normalize to [0, 1];
        # the epsilon avoids division by zero for an all-zero heatmap
        heatmap = tf.maximum(heatmap, 0)
        heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
        return heatmap.numpy()

    # Create visualizations (one column per sample actually available)
    n_shown = len(indices)
    fig, axes = plt.subplots(2, n_shown, figsize=(4*n_shown, 8), squeeze=False)

    for i in range(n_shown):
        img = sample_images[i]
        true_label = int(sample_labels[i])
        pred_prob = predictions[i][0]
        pred_label = int(pred_prob > 0.5)

        # Original image
        axes[0, i].imshow(img.numpy())
        axes[0, i].set_title(f'True: {true_label}, Pred: {pred_label}\nProb: {pred_prob:.3f}')
        axes[0, i].axis('off')

        try:
            # Generate Grad-CAM heatmap
            heatmap = make_gradcam_heatmap(tf.expand_dims(img, 0))

            # Resize heatmap to match image
            heatmap_resized = tf.image.resize(
                tf.expand_dims(heatmap, -1),
                (img.shape[0], img.shape[1])
            ).numpy()[:, :, 0]

            # Overlay heatmap on image
            overlay = img.numpy() * 0.6 + plt.cm.jet(heatmap_resized)[:, :, :3] * 0.4
            axes[1, i].imshow(overlay)
            axes[1, i].set_title('Grad-CAM Heatmap')
            axes[1, i].axis('off')

        except Exception as e:
            print(f"⚠️ Grad-CAM failed for sample {i}: {e}")
            axes[1, i].text(0.5, 0.5, 'Grad-CAM\nUnavailable',
                          transform=axes[1, i].transAxes, ha='center', va='center')
            axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()
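
The core Grad-CAM computation (average the gradients per channel, use them as channel weights, keep the positive part) does not depend on TensorFlow. Below is a minimal NumPy sketch on a tiny hand-made feature map; the arrays and the helper `gradcam_from_arrays` are illustrative assumptions, and in the notebook the gradients would come from `tf.GradientTape`.

```python
import numpy as np

def gradcam_from_arrays(feature_map, grads, eps=1e-8):
    """Grad-CAM heatmap from a (H, W, C) feature map and its (H, W, C) gradients."""
    # Importance of each channel: mean gradient over the spatial dimensions
    channel_weights = grads.mean(axis=(0, 1))          # shape (C,)
    # Weighted sum of the channels at every spatial location
    heatmap = feature_map @ channel_weights            # shape (H, W)
    # Keep only positive evidence and normalize to [0, 1]
    heatmap = np.maximum(heatmap, 0.0)
    return heatmap / (heatmap.max() + eps)

# Toy 2x2 feature map with two channels, each active at one location
feature_map = np.zeros((2, 2, 2))
feature_map[0, 0, 0] = 1.0   # channel 0 fires top-left
feature_map[1, 1, 1] = 1.0   # channel 1 fires bottom-right

# Channel 0 pushes the score up (+2), channel 1 pushes it down (-1)
grads = np.stack([np.full((2, 2), 2.0), np.full((2, 2), -1.0)], axis=-1)

heatmap = gradcam_from_arrays(feature_map, grads)
print(heatmap)
```

The top-left location lights up because its channel supports the prediction, while the bottom-right location is zeroed out by the ReLU since its channel argues against it.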

# MAIN TRAINING PIPELINE

In [None]:
def train_gravitational_lens_classifier():
    """Main training pipeline."""
    print("🌌 Starting Gravitational Lens Detection Training Pipeline")
    print("=" * 60)

    # Check if running on CPU and adjust settings
    if len(tf.config.list_physical_devices('GPU')) == 0:
        print("⚡ Running on CPU - optimizing for CPU performance")
        global BATCH_SIZE, FREEZE_EPOCHS, TOTAL_EPOCHS
        # Reduce batch size and epochs for faster CPU training
        BATCH_SIZE = min(BATCH_SIZE, 16)
        FREEZE_EPOCHS = min(FREEZE_EPOCHS, 5)
        TOTAL_EPOCHS = min(TOTAL_EPOCHS, 20)
        print(f"Adjusted: batch_size={BATCH_SIZE}, freeze_epochs={FREEZE_EPOCHS}, total_epochs={TOTAL_EPOCHS}")

    # Step 1: Load and prepare data
    print("\n" + "="*60)
    print("STEP 1: DATA LOADING & PREPARATION")
    print("="*60)

    train_df, val_df, test_df, initial_class_weights, imbalance_ratio = load_dataset(
        CSV_PATH, DATA_DIR, SPLIT, SEED
    )

    # Handle class imbalance
    train_df, class_weights = handle_class_imbalance(
        train_df, imbalance_ratio, OVERSAMPLING_THRESHOLD
    )

    # Step 2: Create data pipelines
    print("\n" + "="*60)
    print("STEP 2: DATA PIPELINE CREATION")
    print("="*60)

    print("🔄 Creating data pipelines...")
    train_dataset = create_data_pipeline(
        train_df, DATA_DIR, BATCH_SIZE, IMG_SIZE, augment=True, cache=False
    )
    val_dataset = create_data_pipeline(
        val_df, DATA_DIR, BATCH_SIZE, IMG_SIZE, augment=False, cache=True
    )
    test_dataset = create_data_pipeline(
        test_df, DATA_DIR, BATCH_SIZE, IMG_SIZE, augment=False, cache=True
    )

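    # Round up so the step counts match the number of batches in each dataset;
    # with fewer steps, Keras keeps one iterator across epochs and a
    # non-repeating dataset runs out of data after the first epoch
    steps_per_epoch = int(np.ceil(len(train_df) / BATCH_SIZE))
    validation_steps = int(np.ceil(len(val_df) / BATCH_SIZE))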

    print(f"Training steps per epoch: {steps_per_epoch}")
    print(f"Validation steps: {validation_steps}")

    # Step 3: Create model
    print("\n" + "="*60)
    print("STEP 3: MODEL CREATION")
    print("="*60)

    model, backbone = create_mobilenet_classifier(
        IMG_SIZE, MOBILENET_VERSION, ACTIVATION, DROPOUT_RATE, HIDDEN_UNITS, USE_BATCH_NORM
    )

    print(f"\n📋 Model Summary:")
    model.summary()

    # Step 4: Phase 1 Training (Frozen Backbone)
    print("\n" + "="*60)
    print("STEP 4: PHASE 1 TRAINING (FROZEN BACKBONE)")
    print("="*60)

    callbacks = create_callbacks(patience=PATIENCE//2, monitor='val_auc')

    print(f"🔒 Training classifier head for {FREEZE_EPOCHS} epochs...")
    history_phase1 = model.fit(
        train_dataset,
        epochs=FREEZE_EPOCHS,
        validation_data=val_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        callbacks=callbacks,
        class_weight=class_weights,
        verbose=1
    )

    # Step 5: Phase 2 Training (Fine-tuning)
    print("\n" + "="*60)
    print("STEP 5: PHASE 2 TRAINING (FINE-TUNING)")
    print("="*60)

    # Unfreeze backbone
    backbone.trainable = True

    # Use lower learning rate for fine-tuning
    model.compile(
        optimizer=Adam(learning_rate=1e-4),  # 10x lower
        loss=BinaryCrossentropy(),
        metrics=[
            BinaryAccuracy(name='accuracy'),
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc')
        ]
    )

    print(f"🔓 Fine-tuning full model for {TOTAL_EPOCHS - FREEZE_EPOCHS} more epochs...")
    print(f"Unfrozen backbone with {backbone.count_params():,} parameters")

    callbacks = create_callbacks(patience=PATIENCE, monitor='val_auc')

    history_phase2 = model.fit(
        train_dataset,
        epochs=TOTAL_EPOCHS - FREEZE_EPOCHS,
        validation_data=val_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        callbacks=callbacks,
        class_weight=class_weights,
        verbose=1
    )

    # Combine training histories from both phases into one history-like object
    from types import SimpleNamespace
    combined_history = SimpleNamespace(history={})

    for key in history_phase1.history.keys():
        combined_history.history[key] = (
            list(history_phase1.history[key]) + list(history_phase2.history.get(key, []))
        )

    # Step 6: Evaluation
    print("\n" + "="*60)
    print("STEP 6: MODEL EVALUATION")
    print("="*60)

    # Plot training curves
    plot_training_history(combined_history)

    # Evaluate on test set
    metrics = evaluate_model(model, test_dataset, test_df)

    # Generate Grad-CAM visualizations
    create_gradcam_visualization(model, test_dataset, num_samples=4)

    return model, metrics, combined_history

# ABLATION EXPERIMENTS

In [None]:
def run_ablation_experiments():
    """
    Run quick ablation experiments to compare different configurations.
    """
    print("\n" + "="*60)
    print("ABLATION EXPERIMENTS")
    print("="*60)

    # Load data once for all experiments
    train_df, val_df, test_df, _, imbalance_ratio = load_dataset(
        CSV_PATH, DATA_DIR, SPLIT, SEED
    )
    train_df, class_weights = handle_class_imbalance(
        train_df, imbalance_ratio, OVERSAMPLING_THRESHOLD
    )

    # Create datasets
    train_dataset = create_data_pipeline(
        train_df, DATA_DIR, BATCH_SIZE, IMG_SIZE, augment=True, cache=False
    )
    val_dataset = create_data_pipeline(
        val_df, DATA_DIR, BATCH_SIZE, IMG_SIZE, augment=False, cache=True
    )

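    # Round up so the step counts match the number of batches in each dataset
    steps_per_epoch = int(np.ceil(len(train_df) / BATCH_SIZE))
    validation_steps = int(np.ceil(len(val_df) / BATCH_SIZE))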

    # Experiment configurations
    experiments = [
        {'activation': 'relu', 'name': 'ReLU'},
        {'activation': 'leaky_relu', 'name': 'LeakyReLU'},
        {'activation': 'swish', 'name': 'Swish'},
    ]

    results = []

    # Reduce epochs for ablation (faster experiments)
    ablation_epochs = min(15, TOTAL_EPOCHS)
    ablation_freeze = min(5, FREEZE_EPOCHS)

    for i, exp in enumerate(experiments):
        print(f"\n🧪 Experiment {i+1}/{len(experiments)}: {exp['name']}")
        print("-" * 40)

        # Create model with current configuration
        model, backbone = create_mobilenet_classifier(
            IMG_SIZE, MOBILENET_VERSION, exp['activation'],
            DROPOUT_RATE, HIDDEN_UNITS, USE_BATCH_NORM
        )

        # Phase 1: Frozen training
        # mode='max' is explicit because AUC should increase
        callbacks = [EarlyStopping(monitor='val_auc', mode='max', patience=3, restore_best_weights=True)]

        history1 = model.fit(
            train_dataset,
            epochs=ablation_freeze,
            validation_data=val_dataset,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=0
        )

        # Phase 2: Fine-tuning
        backbone.trainable = True
        model.compile(
            optimizer=Adam(learning_rate=1e-4),
            loss=BinaryCrossentropy(),
            metrics=[BinaryAccuracy(name='accuracy'), AUC(name='auc')]
        )

        history2 = model.fit(
            train_dataset,
            epochs=ablation_epochs - ablation_freeze,
            validation_data=val_dataset,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=0
        )

        # Evaluate
        val_loss, val_acc, val_auc = model.evaluate(val_dataset, verbose=0)
        model_size = model.count_params()

        results.append({
            'Configuration': exp['name'],
            'Val Accuracy': f"{val_acc:.4f}",
            'Val AUC': f"{val_auc:.4f}",
            'Model Size': f"{model_size:,}",
            # Total epochs actually run across both phases (early stopping may cut either short)
            'Epochs Run': len(history1.history['loss']) + len(history2.history['loss'])
        })

        print(f"✅ {exp['name']}: Val AUC = {val_auc:.4f}, Val Acc = {val_acc:.4f}")

    # Display results table
    results_df = pd.DataFrame(results)
    print("\n📊 ABLATION RESULTS SUMMARY")
    print("=" * 60)
    print(results_df.to_string(index=False))

    return results_df
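
The summed lengths of `history1.history['loss']` and `history2.history['loss']` give the number of epochs actually run across both phases, which is not necessarily the epoch with the best validation score. A minimal sketch of locating that epoch from the two history dictionaries, assuming `val_auc` is the metric of interest (`best_epoch_across_phases` is a hypothetical helper, not part of the notebook, and the numbers are toy values):

```python
def best_epoch_across_phases(phase_histories, metric="val_auc"):
    """Return (1-based global epoch, value) of the highest `metric`.

    `phase_histories` is a list of dicts shaped like Keras `History.history`,
    one per training phase, in the order the phases were run.
    """
    values = []
    for history in phase_histories:
        values.extend(history.get(metric, []))
    if not values:
        raise ValueError(f"metric {metric!r} not found in any phase history")
    # max() returns the first index on ties, i.e. the earliest best epoch
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


# Toy histories standing in for history1.history and history2.history
phase1 = {"loss": [0.70, 0.62, 0.58], "val_auc": [0.71, 0.78, 0.80]}
phase2 = {"loss": [0.50, 0.45, 0.47], "val_auc": [0.84, 0.88, 0.86]}

epoch, value = best_epoch_across_phases([phase1, phase2])
total_epochs = len(phase1["loss"]) + len(phase2["loss"])
print(f"Best val_auc {value:.2f} at epoch {epoch} of {total_epochs}")
```

In the ablation loop this could be called as `best_epoch_across_phases([history1.history, history2.history])` if a true best-epoch column is wanted in the results table.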


# UTILITY FUNCTIONS

In [None]:
def visualize_data_samples(dataset, num_samples=8):
    """Visualize sample images from the dataset."""
    print("🖼️ Sample Images from Dataset:")

    # Get a single batch
    images, labels = next(iter(dataset.take(1)))

    # Select samples (at most one batch's worth)
    indices = np.random.choice(len(images), min(num_samples, len(images)), replace=False)

    # Plot; size the grid by the number of images actually shown,
    # and keep axes 2-D even when there is a single row
    cols = 4
    rows = (len(indices) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(15, 4*rows), squeeze=False)

    for i, idx in enumerate(indices):
        row, col = i // cols, i % cols

        img = images[idx]
        label = int(labels[idx])
        label_text = "Lensed" if label == 1 else "Non-lensed"

        axes[row, col].imshow(img)
        axes[row, col].set_title(f'{label_text} (Class {label})', fontsize=12, fontweight='bold')
        axes[row, col].axis('off')

    # Hide every subplot that did not receive an image
    for i in range(len(indices), rows * cols):
        row, col = i // cols, i % cols
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.show()

def analyze_predictions(model, test_dataset, test_df, threshold=0.5):
    """Analyze model predictions in detail."""
    print("🔍 Detailed Prediction Analysis")

    # Get predictions
    y_pred_proba = model.predict(test_dataset, verbose=0).flatten()
    y_pred = (y_pred_proba > threshold).astype(int)
    y_true = test_df['binary_label'].values

    # Create analysis dataframe
    analysis_df = test_df.copy()
    analysis_df['predicted_prob'] = y_pred_proba
    analysis_df['predicted_label'] = y_pred
    analysis_df['correct'] = (y_pred == y_true)

    # Confidence analysis
    high_conf_correct = ((y_pred_proba > 0.8) | (y_pred_proba < 0.2)) & analysis_df['correct']
    low_conf_errors = ((y_pred_proba > 0.3) & (y_pred_proba < 0.7)) & ~analysis_df['correct']

    print(f"High confidence correct predictions: {high_conf_correct.sum()}/{len(analysis_df)} ({high_conf_correct.mean()*100:.1f}%)")
    print(f"Low confidence errors: {low_conf_errors.sum()}/{len(analysis_df)} ({low_conf_errors.mean()*100:.1f}%)")

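    # Show the most confident correct predictions and the low-confidence errors
    if high_conf_correct.any():
        print("\n✅ Most confident correct predictions:")
        # Rank by distance from 0.5 so confident non-lensed predictions are included too
        correct_df = analysis_df[high_conf_correct]
        order = np.argsort(-(correct_df['predicted_prob'] - 0.5).abs().to_numpy())
        confident_correct = correct_df.iloc[order[:3]]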
        for _, row in confident_correct.iterrows():
            prob = row['predicted_prob']
            print(f"  {row['filename']}: True={row['binary_label']}, Pred={row['predicted_label']} (prob={prob:.3f})")

    if low_conf_errors.any():
        print("\n❌ Low confidence errors (review these):")
        low_conf_wrong = analysis_df[low_conf_errors].head(3)
        for _, row in low_conf_wrong.iterrows():
            prob = row['predicted_prob']
            print(f"  {row['filename']}: True={row['binary_label']}, Pred={row['predicted_label']} (prob={prob:.3f})")

    return analysis_df
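
The confidence bands used in `analyze_predictions` (above 0.8 or below 0.2 for high confidence, between 0.3 and 0.7 for low confidence) can be checked on a handful of hand-written values. The sketch below restates that bucketing in plain Python; `bucket_predictions` and the toy inputs are illustrative assumptions, not part of the notebook.

```python
def bucket_predictions(probs, labels, threshold=0.5):
    """Count high-confidence hits and low-confidence misses.

    Mirrors the bands used above: >0.8 or <0.2 is "high confidence",
    strictly between 0.3 and 0.7 is "low confidence".
    """
    high_conf_correct = 0
    low_conf_errors = 0
    for prob, label in zip(probs, labels):
        predicted = int(prob > threshold)
        correct = predicted == label
        if correct and (prob > 0.8 or prob < 0.2):
            high_conf_correct += 1
        elif not correct and 0.3 < prob < 0.7:
            low_conf_errors += 1
    return high_conf_correct, low_conf_errors


# Toy probabilities and ground-truth labels (1 = lensed, 0 = non-lensed)
probs = [0.95, 0.10, 0.55, 0.40, 0.85, 0.65]
labels = [1, 0, 0, 1, 0, 1]

hits, misses = bucket_predictions(probs, labels)
print(f"High-confidence correct: {hits}, low-confidence errors: {misses}")
```

Note that the confidently wrong prediction (0.85 for a non-lensed image) falls into neither bucket; such cases may deserve a separate look, since they can point to mislabeled or unusual images.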

# ENHANCED TRAINING WITH TENSORBOARD

In [None]:
def setup_tensorboard():
    """Setup TensorBoard logging."""
    import datetime
    log_dir = f"logs/fit/{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
        write_graph=True,
        write_images=True,
        update_freq='epoch'
    )
    return tensorboard_callback, log_dir

def create_enhanced_callbacks(patience=10, monitor='val_auc', mode='max'):
    """Create enhanced training callbacks with TensorBoard."""
    tensorboard_callback, log_dir = setup_tensorboard()

    callbacks = [
        tensorboard_callback,
        EarlyStopping(
            monitor=monitor,
            patience=patience,
            restore_best_weights=True,
            verbose=1,
            mode=mode
        ),
        ReduceLROnPlateau(
            monitor=monitor,
            factor=0.5,
            patience=max(1, patience // 2),
            min_lr=1e-7,
            verbose=1,
            mode=mode
        ),
        ModelCheckpoint(
            'best_gravitational_lens_model.h5',
            monitor=monitor,
            save_best_only=True,
            mode=mode,
            verbose=1,
            save_weights_only=False
        )
    ]

    print(f"📊 TensorBoard logs will be saved to: {log_dir}")
    print("💡 To view TensorBoard: %tensorboard --logdir logs/fit")

    return callbacks

# [...]

    # Save configuration
    config = {
        'DATA_DIR': DATA_DIR,
        'BATCH_SIZE': BATCH_SIZE,
        'IMG_SIZE': IMG_SIZE,
        'SPLIT': SPLIT,
        'ACTIVATION': ACTIVATION,
        'FREEZE_EPOCHS': FREEZE_EPOCHS,
        'TOTAL_EPOCHS': TOTAL_EPOCHS,
        'MOBILENET_VERSION': MOBILENET_VERSION,
        'DROPOUT_RATE': DROPOUT_RATE,
        'HIDDEN_UNITS': HIDDEN_UNITS,
        'SEED': SEED
    }

    config_df = pd.DataFrame([config])
    config_df.to_csv(f"{save_dir}/config.csv", index=False)

    print(f"📁 All artifacts saved to {save_dir}/")
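
Writing the configuration as a one-row CSV stores tuple values such as `IMG_SIZE` and `SPLIT` as strings, so they need parsing when read back. As an alternative sketch, the same dictionary could be round-tripped through JSON with the standard library. The helper names and the `config.json` filename are assumptions made for illustration, and JSON returns tuples as lists.

```python
import json
import os
import tempfile


def save_config_json(config, save_dir, filename="config.json"):
    """Write `config` to `save_dir/filename` and return the full path."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    return path


def load_config_json(path):
    """Read a configuration written by `save_config_json`."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Example values copied from the configuration cell at the top of the notebook
example_config = {
    "BATCH_SIZE": 32,
    "IMG_SIZE": (224, 224),
    "SPLIT": (0.8, 0.1, 0.1),
    "ACTIVATION": "swish",
    "SEED": 42,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = save_config_json(example_config, tmp_dir)
    restored = load_config_json(config_path)

print(restored["IMG_SIZE"])  # tuples come back as lists
```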

# FINAL EXECUTION SECTION

In [None]:
def main():
    """Main execution function - run the complete training pipeline."""
    print("🎬 Starting Gravitational Lens Classification Pipeline...")
    print("🔬 Enhanced version with TensorBoard, detailed analysis, and ablations")

    try:
        # Run main enhanced training pipeline
        print("\n🚀 MAIN TRAINING PIPELINE")
        model, metrics, history, analysis = train_gravitational_lens_classifier_enhanced()

        # Print final results summary
        print("\n" + "="*70)
        print("🎊 TRAINING COMPLETED SUCCESSFULLY!")
        print("="*70)
        print("🏆 Final Test Results:")
        print(f"   Accuracy: {metrics['accuracy']:.4f}")
        print(f"   ROC-AUC:  {metrics['auc_roc']:.4f}")
        print(f"   PR-AUC:   {metrics['auc_pr']:.4f}")
        print(f"   F1-Score: {metrics['f1_score']:.4f}")
        print(f"   Precision: {metrics['precision']:.4f}")
        print(f"   Recall:   {metrics['recall']:.4f}")

        # Run ablation study (optional - comment out to skip)
        print("\n🔬 RUNNING ABLATION STUDY...")
        print("This compares different activation functions (ReLU, LeakyReLU, Swish).")
        print("💡 Comment out the next line to skip ablation experiments and save time.")

        ablation_results = run_quick_ablation_study()

        print("\n✨ PIPELINE COMPLETE!")
        print("📁 Check the training_artifacts/ folder for saved models and logs")
        print("📊 Use %tensorboard --logdir logs/fit to view training in TensorBoard")

        return model, metrics, ablation_results

    except FileNotFoundError as e:
        print(f"❌ File not found: {e}")
        print("💡 Please check that your DATA_DIR and CSV_PATH are correct")
        print("💡 Make sure your CSV has 'filename' and 'label' columns")

    except Exception as e:
        print(f"❌ Error in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        print("\n💡 Common issues:")
        print("   - Check your data paths (DATA_DIR, CSV_PATH)")
        print("   - Verify CSV format (columns: filename, label)")
        print("   - Ensure sufficient GPU memory (reduce BATCH_SIZE if needed)")
        print("   - Check image file formats (should be PNG)")
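
Skipping the ablation study currently means commenting out a line inside `main()`. One possible alternative is a boolean switch, sketched below with stand-in functions. `RUN_ABLATION`, `run_pipeline`, and the two fake functions are hypothetical names introduced only for this example and do not exist in the notebook.

```python
RUN_ABLATION = False  # hypothetical switch, not defined in the notebook


def run_pipeline(train_fn, ablation_fn, run_ablation=RUN_ABLATION):
    """Run training, then the ablation study only when requested."""
    results = {"training": train_fn(), "ablation": None}
    if run_ablation:
        results["ablation"] = ablation_fn()
    return results


# Stand-ins for the notebook's training and ablation entry points
def fake_training():
    return "trained"


def fake_ablation():
    return ["ReLU", "LeakyReLU", "Swish"]


skipped = run_pipeline(fake_training, fake_ablation)
full = run_pipeline(fake_training, fake_ablation, run_ablation=True)
print(skipped["ablation"], full["ablation"])
```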


# EXECUTION BLOCK

In [None]:
if __name__ == "__main__":
    # Run the complete pipeline
    main()

# COLAB-SPECIFIC INSTRUCTIONS AND UTILITIES

In [None]:
print("""
🎯 GOOGLE COLAB QUICK SETUP GUIDE:
==================================

1️⃣ ENABLE GPU (STRONGLY RECOMMENDED):
   Runtime → Change runtime type → Hardware accelerator → GPU → Save

2️⃣ UPLOAD YOUR DATA:

   Option A - Direct Upload:
   ```python
   from google.colab import files
   uploaded = files.upload()  # Upload your CSV and images
   ```

   Option B - Google Drive:
   ```python
   from google.colab import drive
   drive.mount('/content/drive')
   DATA_DIR = "/content/drive/MyDrive/gravitational_lens_images/"
   CSV_PATH = "/content/drive/MyDrive/gravitational_lens_labels.csv"
   ```

3️⃣ ADJUST CONFIGURATION (edit the variables at the top):
   - DATA_DIR: Path to your image directory
   - CSV_PATH: Path to your labels CSV file
   - BATCH_SIZE: Start with 32, reduce to 16 if out of memory
   - TOTAL_EPOCHS: 50 is good for full training, 20 for quick tests

4️⃣ RUN THE NOTEBOOK:
   Runtime → Run all (or run cell by cell)

📊 CSV FORMAT REQUIRED:
   Your CSV must have exactly these columns:
   - filename: image filename (e.g., "image_001.png")
   - label: "A" for lensed galaxies, "B" for non-lensed

📂 DIRECTORY STRUCTURE:
   Your images can be in subdirectories - the code will find them automatically!

🔧 EXPECTED PERFORMANCE (with good data):
   - Training time: 2-4 hours on T4 GPU
   - Memory usage: 6-10GB GPU RAM
   - Target AUC: 0.90+ for quality datasets

💡 TROUBLESHOOTING:
   - OOM Error → Reduce BATCH_SIZE to 16 or 8
   - Slow training → Make sure GPU is enabled
   - File not found → Check DATA_DIR and CSV_PATH
   - Poor performance → Check data quality and class balance

🎉 WHAT YOU GET:
   - Trained model (.h5 file)
   - Training history plots
   - Confusion matrix, ROC curves
   - Grad-CAM visualizations
   - Performance metrics
   - TensorBoard logs
   - Ablation study results

Ready to detect gravitational lenses! 🌌🔍
""")

# Quick verification function
def verify_setup():
    """Quick verification of setup and data paths."""
    print("🔍 Verifying Setup...")

    print(f"✓ TensorFlow version: {tf.__version__}")
    print(f"✓ GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

    # Check if paths exist
    if os.path.exists(DATA_DIR):
        print(f"✓ Data directory found: {DATA_DIR}")
        png_files = []
        for root, dirs, files in os.walk(DATA_DIR):
            png_files.extend([f for f in files if f.lower().endswith('.png')])
        print(f"✓ Found {len(png_files)} PNG files")
    else:
        print(f"❌ Data directory not found: {DATA_DIR}")

    if os.path.exists(CSV_PATH):
        print(f"✓ CSV file found: {CSV_PATH}")
        df = pd.read_csv(CSV_PATH)
        print(f"✓ CSV has {len(df)} rows")
        print(f"✓ CSV columns: {list(df.columns)}")

        if 'filename' in df.columns and 'label' in df.columns:
            print("✓ Required columns found")
            label_counts = df['label'].value_counts()
            print(f"✓ Label distribution: {dict(label_counts)}")
        else:
            print("❌ Missing required columns: 'filename' and/or 'label'")
    else:
        print(f"❌ CSV file not found: {CSV_PATH}")

    print("\nConfiguration:")
    print(f"- Batch size: {BATCH_SIZE}")
    print(f"- Image size: {IMG_SIZE}")
    print(f"- Activation: {ACTIVATION}")
    print(f"- Training epochs: {TOTAL_EPOCHS} (freeze: {FREEZE_EPOCHS})")

# Uncomment to run verification
# verify_setup()
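
`verify_setup` counts PNG files and CSV rows separately, so a label row whose image is absent would still pass the check. A minimal sketch of a cross-check follows, assuming filenames in the CSV are matched by base name anywhere under the data directory (how the notebook's own loader resolves paths may differ). `find_missing_images` is a hypothetical helper and the temporary files are toy data.

```python
import csv
import os
import tempfile


def find_missing_images(csv_path, data_dir):
    """Return CSV filenames that have no matching PNG under `data_dir`."""
    on_disk = set()
    for _root, _dirs, files in os.walk(data_dir):
        on_disk.update(f for f in files if f.lower().endswith(".png"))

    missing = []
    with open(csv_path, newline="", encoding="utf-8") as fh:
        for row in csv.DictReader(fh):
            name = os.path.basename(row["filename"])
            if name not in on_disk:
                missing.append(row["filename"])
    return missing


# Build a toy dataset: two images on disk, three rows in the CSV
with tempfile.TemporaryDirectory() as tmp_dir:
    image_dir = os.path.join(tmp_dir, "images", "batch_1")
    os.makedirs(image_dir)
    for name in ("image_001.png", "image_002.png"):
        open(os.path.join(image_dir, name), "wb").close()

    labels_csv = os.path.join(tmp_dir, "labels.csv")
    with open(labels_csv, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["filename", "label"])
        writer.writerows([
            ["image_001.png", "A"],
            ["image_002.png", "B"],
            ["image_003.png", "A"],
        ])

    missing = find_missing_images(labels_csv, os.path.join(tmp_dir, "images"))

print(f"Missing images: {missing}")
```

With real data the call would be `find_missing_images(CSV_PATH, DATA_DIR)`; an empty list means every labeled file was found.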

