In [None]:
"""
STEP 1: DATA PREPROCESSING
Run this on CPU runtime (Runtime ‚Üí Change runtime type ‚Üí None)
This script processes images and saves them to Google Drive
"""

from google.colab import drive
drive.mount('/content/drive')

import random
import numpy as np
import shutil
import pandas as pd
import os
import json
from pathlib import Path
from tqdm.notebook import tqdm
from PIL import Image

# Set seeds for reproducibility
random.seed(42)
np.random.seed(42)

Mounted at /content/drive


In [None]:


# ============= CONFIGURATION =============
# Raw dataset root; prepare_dataset() reads its `wildfire` and `nowildfire`
# sub-folders.
DATASET_FOLDER_PATH = "/content/drive/MyDrive/dataset"
# Output root; receives <split>/<class>/ folders for train/valid/test.
PROCESSED_DATA_PATH = "/content/drive/MyDrive/dataset/processed_data"
# JSON summary written at the end of prepare_dataset(). It lives inside
# PROCESSED_DATA_PATH, so deleting that folder deletes the metadata too.
METADATA_PATH = "/content/drive/MyDrive/dataset/processed_data/metadata.json"

# Upper bound per class; larger classes are randomly sampled down to this size.
IMAGES_PER_CLASS = 20000
# Every image is resized to IMG_SIZE x IMG_SIZE pixels before saving.
IMG_SIZE = 224
# Split fractions. Only TRAIN_RATIO and VALID_RATIO are used to cut the list;
# the test split receives whatever remains, so TEST_RATIO is informational
# (it is printed and stored in the metadata).
TRAIN_RATIO = 0.8
VALID_RATIO = 0.1
TEST_RATIO = 0.1

In [None]:



# Print the run configuration as one banner so it is easy to spot in the
# cell output.
print("\n".join([
    "=" * 60,
    "WILDFIRE DETECTION - DATA PREPROCESSING",
    "=" * 60,
    "Running on: CPU Runtime",
    f"Image size: {IMG_SIZE}x{IMG_SIZE}",
    f"Images per class: {IMAGES_PER_CLASS}",
    f"Train/Valid/Test split: {TRAIN_RATIO}/{VALID_RATIO}/{TEST_RATIO}",
    "=" * 60,
]))

def load_strict_image(path):
    """Load an image as RGB after checking that the file is intact.

    Args:
        path: Path of the image file to read.

    Returns:
        A fully decoded RGB ``PIL.Image.Image``, or ``None`` when the file
        cannot be opened, fails verification, or cannot be decoded (the
        reason is printed).
    """
    try:
        # verify() checks the file structure but leaves the image object
        # unusable, so the file has to be opened a second time to decode it.
        with Image.open(path) as probe:
            probe.verify()
        with Image.open(path) as img:
            # convert() forces a full decode (this is where truncated files
            # fail) and returns a new in-memory image, so the result stays
            # valid after the file handle is closed.
            return img.convert("RGB")
    except Exception as e:
        print(f"Failed to load {path}: {str(e)}")
        return None

def prepare_dataset(source_path, destination_path, seed=42):
    """
    Prepare and save dataset to Google Drive
    RESUMABLE: Can continue from interruption

    Reads ``<source_path>/wildfire`` and ``<source_path>/nowildfire``, keeps at
    most ``IMAGES_PER_CLASS`` images per class, splits them into
    train/valid/test, resizes every image to ``IMG_SIZE`` x ``IMG_SIZE`` and
    saves it under ``<destination_path>/<split>/<class>/``. An image whose
    destination file already exists is skipped, which is what lets an
    interrupted run be resumed.

    Args:
        source_path: Folder containing the raw class sub-folders.
        destination_path: Folder that receives the split, resized images.
        seed: Seed for the sampling/shuffling RNG. The same seed and the same
            source files always yield the same split, independent of the
            global ``random`` state and of directory listing order.

    Returns:
        dict: The metadata that is also written to ``METADATA_PATH`` (image
        size, split ratios, per-class counts, completion flag).
    """
    print("\nüîÑ Starting dataset preparation...")

    # Create directory structure (doesn't delete existing)
    os.makedirs(destination_path, exist_ok=True)
    for split in ['train', 'valid', 'test']:
        for class_name in ['wildfire', 'no_wildfire']:
            os.makedirs(os.path.join(destination_path, split, class_name), exist_ok=True)

    # Load source images
    wildfire_dir = os.path.join(source_path, 'wildfire')
    no_wildfire_dir = os.path.join(source_path, 'nowildfire')

    print(f"\nüìÇ Loading images from:")
    print(f"   Wildfire: {wildfire_dir}")
    print(f"   No Wildfire: {no_wildfire_dir}")

    def list_images(directory):
        """Return the image file names in ``directory`` in sorted order."""
        # os.listdir() order is arbitrary; sorting makes the seeded split
        # identical on every run, which resuming depends on.
        return sorted(
            f for f in os.listdir(directory)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    wildfire_images = list_images(wildfire_dir)
    no_wildfire_images = list_images(no_wildfire_dir)

    print(f"\nüìä Found:")
    print(f"   Wildfire images: {len(wildfire_images)}")
    print(f"   No wildfire images: {len(no_wildfire_images)}")

    def process_images(image_list, source_dir, class_name):
        """Process and split images into train/valid/test - RESUMABLE"""
        # A private RNG keyed by seed and class: a resumed run (fresh kernel,
        # different global RNG state) must put every image in the same split
        # as before, otherwise one image can end up in train AND test.
        rng = random.Random(f"{seed}-{class_name}")
        image_list = sorted(image_list)

        # Sample if needed
        original_count = len(image_list)
        if original_count > IMAGES_PER_CLASS:
            image_list = rng.sample(image_list, IMAGES_PER_CLASS)
            print(f"   Sampled {IMAGES_PER_CLASS} from {original_count} images")

        # Shuffle
        rng.shuffle(image_list)

        # Split: test receives the remainder after train and valid.
        train_end = int(len(image_list) * TRAIN_RATIO)
        valid_end = train_end + int(len(image_list) * VALID_RATIO)

        # (split folder, word used in messages, progress-bar prefix, file names)
        splits = (
            ('train', 'training', 'Train', image_list[:train_end]),
            ('valid', 'validation', 'Valid', image_list[train_end:valid_end]),
            ('test', 'test', 'Test', image_list[valid_end:]),
        )

        stats = {
            'total': len(image_list),
            'train': 0,
            'valid': 0,
            'test': 0,
            'failed': 0
        }

        for split, label, desc, names in splits:
            print(f"   Processing {label} images...")
            target_dir = os.path.join(destination_path, split, class_name)
            skipped_count = 0

            for img_name in tqdm(names, desc=f"{desc}-{class_name}"):
                dest_path = os.path.join(target_dir, img_name)

                # SKIP EXISTING: already written by an earlier run.
                if os.path.exists(dest_path):
                    skipped_count += 1
                    stats[split] += 1
                    continue

                img = load_strict_image(os.path.join(source_dir, img_name))
                if img is not None:
                    img.resize((IMG_SIZE, IMG_SIZE)).save(dest_path)
                    stats[split] += 1
                else:
                    stats['failed'] += 1

            if skipped_count > 0:
                print(f"   ‚ö° Skipped {skipped_count} already processed {label} images")

            # Files that are on disk but not part of this split were written
            # by an earlier run that used a different sample/split. They are
            # not counted in `stats` and may duplicate images of other splits.
            expected = set(names)
            stale_count = sum(1 for f in os.listdir(target_dir) if f not in expected)
            if stale_count > 0:
                print(f"   WARNING: {stale_count} file(s) in {target_dir} do not belong to "
                      f"the current {label} split (left over from an earlier run). "
                      f"Choose 'Start fresh' to remove them.")

        return stats

    # Process wildfire images
    print(f"\nüî• Processing WILDFIRE class...")
    wildfire_stats = process_images(wildfire_images, wildfire_dir, 'wildfire')

    # Process no_wildfire images
    print(f"\nüå≤ Processing NO_WILDFIRE class...")
    no_wildfire_stats = process_images(no_wildfire_images, no_wildfire_dir, 'no_wildfire')

    # Save metadata
    metadata = {
        'img_size': IMG_SIZE,
        'train_ratio': TRAIN_RATIO,
        'valid_ratio': VALID_RATIO,
        'test_ratio': TEST_RATIO,
        'wildfire': wildfire_stats,
        'no_wildfire': no_wildfire_stats,
        'preprocessing_complete': True
    }

    with open(METADATA_PATH, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n‚úÖ Preprocessing complete!")
    print(f"\nüìä Final Statistics:")
    print(f"   Wildfire - Train: {wildfire_stats['train']}, Valid: {wildfire_stats['valid']}, Test: {wildfire_stats['test']}")
    print(f"   No Wildfire - Train: {no_wildfire_stats['train']}, Valid: {no_wildfire_stats['valid']}, Test: {no_wildfire_stats['test']}")
    print(f"   Failed images: {wildfire_stats['failed'] + no_wildfire_stats['failed']}")
    print(f"\nüíæ Data saved to: {destination_path}")
    print(f"üìù Metadata saved to: {METADATA_PATH}")

    return metadata


def verify_processed_data(processed_path):
    """Report whether a finished preprocessing run is present on disk.

    The check passes only when the metadata file at ``METADATA_PATH`` exists,
    is flagged ``preprocessing_complete``, and every ``<split>/<class>``
    folder exists under ``processed_path``. The image count of each folder is
    printed along the way.

    Args:
        processed_path: Root folder of the processed dataset.

    Returns:
        bool: ``True`` when all checks pass, ``False`` otherwise.
    """
    print("\nüîç Verifying processed data...")

    metadata_found = os.path.exists(METADATA_PATH)
    if not metadata_found:
        print("‚ùå Metadata not found. Preprocessing required.")
        return False

    with open(METADATA_PATH, 'r') as handle:
        info = json.load(handle)

    finished = info.get('preprocessing_complete', False)
    if not finished:
        print("‚ùå Preprocessing incomplete. Need to reprocess.")
        return False

    # Check directory structure
    image_suffixes = ('.png', '.jpg', '.jpeg')
    for split_name in ('train', 'valid', 'test'):
        for label in ('wildfire', 'no_wildfire'):
            folder = os.path.join(processed_path, split_name, label)
            if not os.path.exists(folder):
                print(f"‚ùå Missing directory: {folder}")
                return False

            image_count = sum(
                1 for entry in os.listdir(folder)
                if entry.lower().endswith(image_suffixes)
            )
            print(f"   {split_name}/{label}: {image_count} images")

    print("‚úÖ All processed data verified!")
    return True

WILDFIRE DETECTION - DATA PREPROCESSING
Running on: CPU Runtime
Image size: 224x224
Images per class: 20000
Train/Valid/Test split: 0.8/0.1/0.1


In [None]:

if __name__ == "__main__":
    # Decide what to do first and call prepare_dataset() in one place.
    # Calling exit() here would shut down the notebook kernel, so "skip" is
    # expressed with a flag instead.
    run_preprocessing = True

    # Check if preprocessing already done
    if verify_processed_data(PROCESSED_DATA_PATH):
        print("\n" + "="*60)
        print("‚ö†Ô∏è  PROCESSED DATA ALREADY EXISTS")
        print("="*60)
        response = input("Do you want to:\n1. Resume/Complete preprocessing\n2. Start fresh (delete all)\n3. Skip preprocessing\nEnter (1/2/3): ").strip()

        if response == '2':
            # Delete and start fresh
            print("üóëÔ∏è  Deleting existing processed data...")
            shutil.rmtree(PROCESSED_DATA_PATH)
            # Only relevant if METADATA_PATH is configured outside the
            # processed-data folder; otherwise rmtree already removed it.
            if os.path.exists(METADATA_PATH):
                os.remove(METADATA_PATH)
            print("‚úÖ Deleted. Starting fresh...")
        elif response == '3':
            print("‚è≠Ô∏è  Skipping preprocessing.")
            run_preprocessing = False
        else:  # Default to resume (option 1)
            print("‚ñ∂Ô∏è  Resuming preprocessing (will skip already processed images)...")

    if run_preprocessing:
        metadata = prepare_dataset(DATASET_FOLDER_PATH, PROCESSED_DATA_PATH)

        print("\n" + "="*60)
        print("‚úÖ PREPROCESSING COMPLETE!")
        print("="*60)
        print("\nüìã NEXT STEPS:")
        print("1. Runtime ‚Üí Change runtime type ‚Üí GPU (T4 or better)")
        print("2. The kernel will restart (this is expected)")
        print("3. Run the TRAINING script")
        print("4. Your processed data is safely stored in Google Drive")
        print("="*60)


üîç Verifying processed data...
‚ùå Metadata not found. Preprocessing required.

üîÑ Starting dataset preparation...

üìÇ Loading images from:
   Wildfire: /content/drive/MyDrive/dataset/wildfire
   No Wildfire: /content/drive/MyDrive/dataset/nowildfire

üìä Found:
   Wildfire images: 22738
   No wildfire images: 20170

üî• Processing WILDFIRE class...
   Sampled 20000 from 22738 images
   Processing training images...


Train-wildfire:   0%|          | 0/16000 [00:00<?, ?it/s]

Failed to load /content/drive/MyDrive/dataset/wildfire/-73.15884,46.38819.jpg: image file is truncated (51 bytes not processed)
   ‚ö° Skipped 336 already processed training images
   Processing validation images...


Valid-wildfire:   0%|          | 0/2000 [00:00<?, ?it/s]

   Processing test images...


Test-wildfire:   0%|          | 0/2000 [00:00<?, ?it/s]


üå≤ Processing NO_WILDFIRE class...
   Sampled 20000 from 20170 images
   Processing training images...


Train-no_wildfire:   0%|          | 0/16000 [00:00<?, ?it/s]

   Processing validation images...


Valid-no_wildfire:   0%|          | 0/2000 [00:00<?, ?it/s]

Failed to load /content/drive/MyDrive/dataset/nowildfire/-114.152378,51.027198.jpg: image file is truncated (16 bytes not processed)
   Processing test images...


Test-no_wildfire:   0%|          | 0/2000 [00:00<?, ?it/s]


‚úÖ Preprocessing complete!

üìä Final Statistics:
   Wildfire - Train: 15999, Valid: 2000, Test: 2000
   No Wildfire - Train: 16000, Valid: 1999, Test: 2000
   Failed images: 2

üíæ Data saved to: /content/drive/MyDrive/dataset/processed_data
üìù Metadata saved to: /content/drive/MyDrive/dataset/processed_data/metadata.json

‚úÖ PREPROCESSING COMPLETE!

üìã NEXT STEPS:
1. Runtime ‚Üí Change runtime type ‚Üí GPU (T4 or better)
2. The kernel will restart (this is expected)
3. Run the TRAINING script
4. Your processed data is safely stored in Google Drive


## Part 2 - Model Building and Training

In [None]:
"""
STEP 2: MODEL TRAINING & EVALUATION
Run this on GPU runtime (Runtime ‚Üí Change runtime type ‚Üí T4 GPU or better)
This script loads preprocessed data and trains the model
"""
# from google.colab import drive
# drive.mount('/content/drive')


import random
import numpy as np
import pandas as pd
import os
import json
from pathlib import Path
from tqdm.notebook import tqdm
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime

# Set seeds for reproducibility
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

In [None]:
# ============= CONFIGURATION =============
PROCESSED_DATA_PATH = "/content/drive/MyDrive/dataset/processed_data"
# The preprocessing step writes metadata.json inside the processed-data folder
# on Drive. The local /content disk does not survive the CPU -> GPU runtime
# switch, so the file must be read from Drive.
METADATA_PATH = os.path.join(PROCESSED_DATA_PATH, "metadata.json")
RESULTS_DIR = "/content/drive/MyDrive/dataset/Results"
MODEL_SAVE_PATH = os.path.join(RESULTS_DIR, "best_model.keras")

BATCH_SIZE = 32
# Phase 1 trains only the new classification head on a frozen backbone.
PHASE1_EPOCHS = 5
PHASE1_LR = 5e-5
# Phase 2 fine-tunes the last UNFREEZE_LAYERS backbone layers at a much
# smaller learning rate.
PHASE2_EPOCHS = 20
PHASE2_LR = 1e-6
UNFREEZE_LAYERS = 50

print("=" * 60)
print("WILDFIRE DETECTION - MODEL TRAINING")
print("=" * 60)

# Check GPU availability
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"‚úÖ GPU Available: {gpus}")
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Memory growth can only be set before the GPU is initialised;
            # re-running this cell in a live kernel must not abort it.
            print(f"   Could not enable memory growth for {gpu}: {e}")
else:
    print("‚ö†Ô∏è  WARNING: No GPU detected! Training will be slow.")
    print("   Please change runtime: Runtime ‚Üí Change runtime type ‚Üí GPU")

print("=" * 60)

# Create results directory
os.makedirs(RESULTS_DIR, exist_ok=True)

# ============= LOAD METADATA =============
import time
def load_metadata():
    """Load preprocessing metadata with retry logic

    The file at ``METADATA_PATH`` is polled a few times because it may not be
    readable right away (presumably while the Drive mount is still syncing).

    Returns:
        dict: The parsed metadata written by the preprocessing step.

    Raises:
        ValueError: If the metadata says preprocessing did not complete, or
            if the file exists but is still not valid JSON after all attempts.
        FileNotFoundError: If the file is still missing after all attempts.
    """
    print("\nAttempting to load metadata...")
    retries = 5
    delay = 2 # seconds
    decode_error = None  # set when the most recent attempt found bad JSON

    for i in range(retries):
        if os.path.exists(METADATA_PATH):
            try:
                with open(METADATA_PATH, 'r') as f:
                    metadata = json.load(f)
            except json.JSONDecodeError as e:
                decode_error = e
                print(f"‚ö†Ô∏è  Metadata file found but is corrupted. Retrying in {delay}s... ({e})")
            else:
                # Incomplete preprocessing is a permanent condition, so it is
                # raised immediately instead of being retried.
                if not metadata.get('preprocessing_complete', False):
                    raise ValueError(
                        "‚ùå Preprocessing incomplete!\n"
                        "   Please run the PREPROCESSING script first on CPU runtime!"
                    )
                print(f"‚úÖ Metadata loaded successfully after {i+1} attempt(s).")
                print(f"   Image size: {metadata['img_size']}x{metadata['img_size']}")
                print(f"   Wildfire - Train: {metadata['wildfire']['train']}, "
                      f"Valid: {metadata['wildfire']['valid']}, Test: {metadata['wildfire']['test']}")
                print(f"   No Wildfire - Train: {metadata['no_wildfire']['train']}, "
                      f"Valid: {metadata['no_wildfire']['valid']}, Test: {metadata['no_wildfire']['test']}")
                return metadata
        else:
            decode_error = None
            print(f"‚ùå Metadata not found at {METADATA_PATH}. Retrying in {delay}s... (Attempt {i+1}/{retries})")

        # Waiting after the final attempt would only delay the error.
        if i < retries - 1:
            time.sleep(delay)

    if decode_error is not None:
        # The file exists but never parsed: report that, not "not found".
        raise ValueError(
            f"Metadata at {METADATA_PATH} could not be parsed as JSON after {retries} attempts.\n"
            f"   Please re-run the PREPROCESSING script to regenerate it."
        ) from decode_error

    raise FileNotFoundError(
        f"‚ùå Metadata not found at {METADATA_PATH} after {retries} attempts.\n"
        f"   Please ensure the PREPROCESSING script was run successfully on CPU runtime!"
    )

# The image size used for training must match what preprocessing produced.
metadata = load_metadata()
IMG_SIZE = metadata['img_size']


## Data Generators

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet import preprocess_input
import os
import time # Import time for sleep functionality

def create_data_generators(base_dir):
    """Build the train/valid/test image generators for ``base_dir``.

    Each of the ``train``, ``valid`` and ``test`` sub-folders is first checked
    for existence (with a few timed retries). Only the training generator
    augments and shuffles; the validation and test generators apply the
    EfficientNet preprocessing function only and keep file order.

    Args:
        base_dir: Root folder containing the ``train``/``valid``/``test``
            sub-folders.

    Returns:
        tuple: ``(train_gen, valid_gen, test_gen)``.

    Raises:
        FileNotFoundError: If a split folder is still missing after the
            retries.
    """
    print("\nüîÑ Creating data generators...")

    retries = 3
    delay = 5 # seconds

    # Explicitly check for the existence of the split directories with retries
    for split_dir in ['train', 'valid', 'test']:
        full_path = os.path.join(base_dir, split_dir)
        attempt = 0
        while attempt < retries:
            if os.path.exists(full_path):
                print(f"‚úÖ Found directory: {full_path}")
                break
            attempt += 1
            print(f"‚ùå Directory not found: {full_path}. Retrying in {delay}s... (Attempt {attempt}/{retries})")
            time.sleep(delay)
        if not os.path.exists(full_path):
            raise FileNotFoundError(
                f"Failed to find directory after multiple retries: {full_path}.\n"
                f"Please ensure the preprocessing step completed successfully and created this directory in Google Drive."
            )

    # Geometric changes are kept modest and brightness jitter is small.
    augmentation = dict(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.15,
        horizontal_flip=True,
        brightness_range=(0.9, 1.1),
        fill_mode="nearest",
    )
    train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input, **augmentation)
    valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
    test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

    def flow(datagen, split_name, shuffle):
        """Stream batches of one split folder through ``datagen``."""
        return datagen.flow_from_directory(
            os.path.join(base_dir, split_name),
            target_size=(IMG_SIZE, IMG_SIZE),
            batch_size=BATCH_SIZE,
            class_mode="binary",
            shuffle=shuffle
        )

    train_gen = flow(train_datagen, "train", True)
    valid_gen = flow(valid_datagen, "valid", False)
    test_gen = flow(test_datagen, "test", False)

    print("‚úÖ Data generators ready")
    return train_gen, valid_gen, test_gen


## Model Build

In [None]:
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB4 # Added EfficientNetB4
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def build_model(phase=1):
    """Build and compile the EfficientNetB4-based binary classifier.

    The backbone is wrapped as a single nested layer named ``"efficientnet"``
    so that it can be retrieved later with ``model.get_layer("efficientnet")``.
    Connecting the head straight to ``base_model.output`` would flatten the
    backbone into the outer model and no layer of that name would exist.

    Args:
        phase: ``1`` freezes the whole backbone and uses ``PHASE1_LR``; any
            other value leaves the last ``UNFREEZE_LAYERS`` backbone layers
            trainable and uses ``PHASE2_LR``.

    Returns:
        A compiled Keras model with a single sigmoid output.
    """
    print(f"\nüèóÔ∏è Building model (Phase {phase})")

    input_shape = (IMG_SIZE, IMG_SIZE, 3)

    try:
        # Try loading from cache or download EfficientNetB4
        base_model = EfficientNetB4(
            weights="imagenet",
            include_top=False,
            input_shape=input_shape,
            name="efficientnet"
        )
    except Exception as e:
        print(f"‚ö†Ô∏è Download failed: {e}")
        print("Attempting manual download...")
        # Manual download as backup for EfficientNetB4. Plain Python instead
        # of a `!wget` shell escape so the function also works outside IPython.
        import urllib.request
        os.makedirs('/tmp/keras_models', exist_ok=True)
        weights_path = '/tmp/keras_models/efficientnetb4_notop.h5'
        urllib.request.urlretrieve(
            "https://storage.googleapis.com/keras-applications/efficientnetb4_notop.h5",
            weights_path
        )
        base_model = EfficientNetB4(
            weights=weights_path,
            include_top=False,
            input_shape=input_shape,
            name="efficientnet"
        )

    if phase == 1:
        base_model.trainable = False
        lr = PHASE1_LR
    else:
        base_model.trainable = True
        for layer in base_model.layers[:-UNFREEZE_LAYERS]:
            layer.trainable = False
        lr = PHASE2_LR

    # Call the backbone as a layer. training=False keeps its BatchNorm layers
    # in inference mode even after layers are unfrozen for fine-tuning.
    inputs = tf.keras.Input(shape=input_shape)
    x = base_model(inputs, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.3)(x)
    output = Dense(1, activation="sigmoid")(x)

    model = Model(inputs, output)

    model.compile(
        optimizer=Adam(lr),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")]
    )

    print("‚úÖ Model ready")
    return model


## Training

In [None]:
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
)

def _unfreeze_backbone_tail(model, n_layers):
    """Make the last ``n_layers`` backbone layers trainable and freeze the rest.

    Handles both model layouts: a nested backbone layer named
    ``"efficientnet"``, and a flattened model in which the backbone layers sit
    directly in ``model.layers`` ahead of the classification head.
    BatchNormalization layers always stay frozen.
    """
    try:
        backbone = model.get_layer("efficientnet")
    except ValueError:
        backbone = None

    if backbone is not None and hasattr(backbone, "layers"):
        # The container itself must be trainable for its children to count.
        backbone.trainable = True
        backbone_layers = list(backbone.layers)
    else:
        # Flattened layout: the head starts at the LAST global-average-pooling
        # layer (the backbone has pooling layers of its own further up).
        pooling_positions = [
            index for index, layer in enumerate(model.layers)
            if isinstance(layer, GlobalAveragePooling2D)
        ]
        head_start = pooling_positions[-1] if pooling_positions else len(model.layers)
        backbone_layers = list(model.layers[:head_start])

    first_trainable = max(len(backbone_layers) - n_layers, 0)
    for index, layer in enumerate(backbone_layers):
        is_batch_norm = isinstance(layer, tf.keras.layers.BatchNormalization)
        layer.trainable = index >= first_trainable and not is_batch_norm


def train_model():
    """Run two-phase training and return the best model with both histories.

    Phase 1 trains the head on a frozen backbone. Phase 2 reloads the best
    checkpoint, unfreezes the last ``UNFREEZE_LAYERS`` backbone layers and
    continues at ``PHASE2_LR``. Data comes from the module-level ``train_gen``
    and ``valid_gen``; the best model (lowest ``val_loss``) is written to
    ``MODEL_SAVE_PATH``.

    Returns:
        tuple: ``(model, history)`` where ``model`` is reloaded from
        ``MODEL_SAVE_PATH`` and ``history`` is
        ``{"phase1": <dict>, "phase2": <dict>}``.
    """
    checkpoint = ModelCheckpoint(
        MODEL_SAVE_PATH,
        monitor="val_loss",
        save_best_only=True,
        verbose=1
    )

    early_stop = EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
        verbose=1
    )

    reduce_lr = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.3,
        patience=2,
        min_lr=1e-6,
        verbose=1
    )

    print("\n====== PHASE 1: FROZEN BASE ======")
    model = build_model(phase=1)
    history1 = model.fit(
        train_gen,
        validation_data=valid_gen,
        epochs=PHASE1_EPOCHS,
        callbacks=[checkpoint, early_stop, reduce_lr],
        verbose=1
    )

    # Continue from the best phase-1 checkpoint rather than the last epoch.
    model = load_model(MODEL_SAVE_PATH)

    print("\n====== PHASE 2: FINE TUNING ======")
    _unfreeze_backbone_tail(model, UNFREEZE_LAYERS)

    # Changes to `trainable` only take effect after re-compiling.
    model.compile(
        optimizer=Adam(PHASE2_LR),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")]
    )

    history2 = model.fit(
        train_gen,
        validation_data=valid_gen,
        epochs=PHASE2_EPOCHS,
        callbacks=[checkpoint, early_stop, reduce_lr],
        verbose=1
    )

    model = load_model(MODEL_SAVE_PATH)

    return model, {"phase1": history1.history, "phase2": history2.history}


In [None]:
# ============= EVALUATION =============
def evaluate_model(model, test_generator):
    """Score ``model`` on ``test_generator`` and print a metric summary.

    Predictions are thresholded at 0.5. Loss, accuracy and AUC come from
    ``model.evaluate``; precision, recall and F1 are computed with
    scikit-learn from the thresholded predictions.

    Args:
        model: Compiled Keras model whose ``evaluate`` returns loss, accuracy
            and AUC, in that order.
        test_generator: Generator exposing ``classes`` (the true labels).

    Returns:
        dict: ``y_true``, ``y_pred``, ``y_pred_prob`` (flattened) and a
        ``metrics`` dict of plain floats.
    """
    from sklearn.metrics import precision_score, recall_score

    rule = "=" * 60
    print("\n" + rule)
    print("MODEL EVALUATION")
    print(rule)

    # Get predictions
    print("\nüîÆ Generating predictions...")
    y_true = test_generator.classes
    y_pred_prob = model.predict(test_generator, verbose=1)
    y_pred = (y_pred_prob > 0.5).astype(int).flatten()

    # Metrics reported by Keras itself
    keras_scores = model.evaluate(test_generator, verbose=0)
    test_loss, test_acc, test_auc = keras_scores

    # Metrics derived from the thresholded predictions
    test_precision = precision_score(y_true, y_pred)
    test_recall = recall_score(y_true, y_pred)
    test_f1 = f1_score(y_true, y_pred)

    print(f"\nüìä Test Metrics:")
    summary_rows = (
        ("Loss", test_loss),
        ("Accuracy", test_acc),
        ("AUC-ROC", test_auc),
        ("Precision", test_precision),
        ("Recall", test_recall),
        ("F1-Score", test_f1),
    )
    for row_label, row_value in summary_rows:
        print(f"   {row_label}: {row_value:.4f}")

    # Classification report
    print("\nüìã Classification Report:")
    class_names = ['No Wildfire', 'Wildfire']
    print(classification_report(y_true, y_pred, target_names=class_names))

    metrics = {
        'loss': float(test_loss),
        'accuracy': float(test_acc),
        'auc': float(test_auc),
        'precision': float(test_precision),
        'recall': float(test_recall),
        'f1_score': float(test_f1)
    }
    return {
        'y_true': y_true,
        'y_pred': y_pred,
        'y_pred_prob': y_pred_prob.flatten(),
        'metrics': metrics
    }



# ============= VISUALIZATION =============
def plot_training_history(history, save_path):
    """Plot comprehensive training history

    Draws a 2x2 figure: loss, accuracy and AUC (training vs. validation) over
    the concatenated epochs of both phases, plus the learning-rate schedule
    when the history contains one. A dashed vertical line marks the end of
    phase 1.

    Args:
        history: ``{"phase1": <history dict>, "phase2": <history dict>}`` as
            returned by ``train_model``.
        save_path: File the figure is saved to.
    """
    print("\nüìà Generating training visualizations...")

    phase1 = history['phase1']
    phase2 = history['phase2']

    def combined(key):
        """Concatenate one metric across both phases."""
        return list(phase1[key]) + list(phase2[key])

    phase1_end = len(phase1['loss'])
    epochs = range(1, len(combined('loss')) + 1)

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Training History - Two-Phase Transfer Learning', fontsize=16, fontweight='bold')

    # (axes, history key, axis label, panel title)
    panels = (
        (axes[0, 0], 'loss', 'Loss', 'Model Loss'),
        (axes[0, 1], 'accuracy', 'Accuracy', 'Model Accuracy'),
        (axes[1, 0], 'auc', 'AUC', 'Model AUC-ROC'),
    )
    for ax, key, axis_label, title in panels:
        ax.plot(epochs, combined(key), 'b-', label=f'Training {axis_label}', linewidth=2)
        ax.plot(epochs, combined(f'val_{key}'), 'r-', label=f'Validation {axis_label}', linewidth=2)
        ax.axvline(x=phase1_end, color='green', linestyle='--', linewidth=2, label='Phase 2 Start')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(axis_label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    # Learning rate (if available). Keras 3 records it as 'learning_rate',
    # older versions as 'lr'; accept whichever both phases provide.
    ax = axes[1, 1]
    lr_key = next(
        (key for key in ('learning_rate', 'lr') if key in phase1 and key in phase2),
        None
    )
    if lr_key is not None:
        ax.plot(epochs, combined(lr_key), 'g-', linewidth=2)
        ax.axvline(x=phase1_end, color='green', linestyle='--', linewidth=2, label='Phase 2 Start')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Learning Rate', fontsize=12)
        ax.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax.set_yscale('log')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'Learning Rate\nData Not Available',
                ha='center', va='center', fontsize=14)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ‚úÖ Saved training history to {save_path}")
    plt.show()

def plot_confusion_matrix(y_true, y_pred, save_path):
    """Plot confusion matrix

    Draws a 2x2 heatmap of counts with each cell's share of all samples
    printed underneath, saves it to ``save_path`` and shows it.

    Args:
        y_true: True binary labels (0 = no wildfire, 1 = wildfire).
        y_pred: Predicted binary labels.
        save_path: File the figure is saved to.
    """
    # Pin both labels so the matrix is always 2x2, even when only one class
    # occurs in the data (otherwise cm[i, j] below would go out of bounds).
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['No Wildfire', 'Wildfire'],
                yticklabels=['No Wildfire', 'Wildfire'],
                cbar_kws={'label': 'Count'})

    plt.title('Confusion Matrix', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)

    # Add percentages (slightly below each cell's count annotation)
    total = cm.sum()
    for i in range(2):
        for j in range(2):
            percentage = (cm[i, j] / total) * 100 if total else 0.0
            plt.text(j + 0.5, i + 0.7, f'({percentage:.1f}%)',
                    ha='center', va='center', fontsize=10, color='gray')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ‚úÖ Saved confusion matrix to {save_path}")
    plt.show()

def plot_roc_curve(y_true, y_pred_prob, save_path):
    """Draw the ROC curve with its AUC, save it to ``save_path`` and show it.

    Args:
        y_true: True binary labels.
        y_pred_prob: Predicted scores for the positive class.
        save_path: File the figure is saved to.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred_prob)
    area = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=3,
            label=f'ROC curve (AUC = {area:.4f})')
    # Diagonal reference: the expected curve of a random classifier.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curve',
                 fontsize=16, fontweight='bold', pad=20)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ‚úÖ Saved ROC curve to {save_path}")
    plt.show()

def plot_precision_recall_curve(y_true, y_pred_prob, save_path):
    """Draw the precision-recall curve with its average precision.

    The figure is saved to ``save_path`` and then shown.

    Args:
        y_true: True binary labels.
        y_pred_prob: Predicted scores for the positive class.
        save_path: File the figure is saved to.
    """
    precision_values, recall_values, _ = precision_recall_curve(y_true, y_pred_prob)
    mean_precision = average_precision_score(y_true, y_pred_prob)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(recall_values, precision_values, color='blue', lw=3,
            label=f'PR curve (AP = {mean_precision:.4f})')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curve', fontsize=16, fontweight='bold', pad=20)
    ax.legend(loc="lower left", fontsize=12)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ‚úÖ Saved PR curve to {save_path}")
    plt.show()

def plot_metrics_summary(metrics, save_path):
    """Draw a bar chart of the headline test metrics.

    Args:
        metrics: Dict with the keys ``accuracy``, ``precision``, ``recall``,
            ``f1_score`` and ``auc``.
        save_path: File the figure is saved to.
    """
    # (bar label, key in `metrics`, bar colour), in display order
    bar_specs = (
        ('Accuracy', 'accuracy', '#3498db'),
        ('Precision', 'precision', '#2ecc71'),
        ('Recall', 'recall', '#e74c3c'),
        ('F1-Score', 'f1_score', '#f39c12'),
        ('AUC-ROC', 'auc', '#9b59b6'),
    )
    metric_names = [spec[0] for spec in bar_specs]
    metric_values = [metrics[spec[1]] for spec in bar_specs]
    colors = [spec[2] for spec in bar_specs]

    fig, ax = plt.subplots(figsize=(12, 7))
    bars = ax.bar(metric_names, metric_values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

    # Add value labels on bars
    for rect, score in zip(bars, metric_values):
        top = rect.get_height()
        ax.text(rect.get_x() + rect.get_width()/2., top,
                f'{score:.4f}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Head-room above 1.0 so the value labels are not clipped.
    ax.set_ylim([0, 1.1])
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Model Performance Metrics Summary', fontsize=16, fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"   ‚úÖ Saved metrics summary to {save_path}")
    plt.show()


In [None]:
# ============= PRE-DOWNLOAD MODEL WEIGHTS =============
# Fetches the EfficientNetB4 ImageNet weights ahead of training so that a
# download problem shows up here rather than in the middle of build_model().
print("\n" + "="*60)
print("DOWNLOADING EFFICIENTNETB4 WEIGHTS")
print("="*60)

import socket
# NOTE: this sets the default timeout for every socket created afterwards in
# this kernel, not only for the download below.
socket.setdefaulttimeout(300)  # 5 minute timeout

try:
    from tensorflow.keras.applications import EfficientNetB4
    print("üì• Downloading EfficientNetB4 weights (this may take a minute)...")
    # The model object is discarded; it is instantiated only for the side
    # effect of fetching the weights (presumably cached by Keras for later
    # calls - confirm).
    _ = EfficientNetB4(weights='imagenet', include_top=False)
    print("‚úÖ Weights downloaded successfully!")
except Exception as e:
    print(f"‚ö†Ô∏è Auto-download failed: {e}")
    print("üì• Attempting manual download with wget...")
    # NOTE(review): assumes /root/.keras/models/ is where Keras looks for
    # cached weights - confirm. The `!` shell escapes only work in
    # IPython/Colab, and a wget failure is not detected here: the success
    # message below is printed regardless.
    !mkdir -p /root/.keras/models/
    !wget -O /root/.keras/models/efficientnetb4_notop.h5 https://storage.googleapis.com/keras-applications/efficientnetb4_notop.h5
    print("‚úÖ Manual download complete!")

print("="*60 + "\n")


DOWNLOADING EFFICIENTNETB4 WEIGHTS
üì• Downloading EfficientNetB4 weights (this may take a minute)...
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb4_notop.h5
[1m71686520/71686520[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m1s[0m 0us/step
‚úÖ Weights downloaded successfully!



In [None]:
# Driver cell: build the generators, run both training phases, then evaluate
# the best checkpoint on the test split and save the result figures.
from tensorflow.keras.models import load_model

# train_model() reads train_gen and valid_gen as module-level names, so they
# must be bound before it is called.
train_gen, valid_gen, test_gen = create_data_generators(PROCESSED_DATA_PATH)
from datetime import datetime

print("\nüöÄ Training started...")
start_time = datetime.now()

model, history = train_model()

end_time = datetime.now()
print(f"\n‚è±Ô∏è Training time: {(end_time - start_time).total_seconds()/60:.2f} minutes")

# Reload the checkpoint from disk so evaluation always uses the saved best
# weights (train_model() already returns a model loaded from the same path).
model = load_model(MODEL_SAVE_PATH)

results = evaluate_model(model, test_gen)
os.makedirs(RESULTS_DIR, exist_ok=True)

# NOTE: plot_training_history is skipped as 'history' might be incomplete due to interruption
# plot_training_history(history, os.path.join(RESULTS_DIR, "training_history.png"))

# Each figure is written into RESULTS_DIR.
plot_confusion_matrix(
    results["y_true"],
    results["y_pred"],
    os.path.join(RESULTS_DIR, "confusion_matrix.png")
)

plot_roc_curve(
    results["y_true"],
    results["y_pred_prob"],
    os.path.join(RESULTS_DIR, "roc_curve.png")
)

plot_precision_recall_curve(
    results["y_true"],
    results["y_pred_prob"],
    os.path.join(RESULTS_DIR, "precision_recall_curve.png")
)

plot_metrics_summary(
    results["metrics"],
    os.path.join(RESULTS_DIR, "metrics_summary.png")
)



In [None]:
from tensorflow.keras.models import load_model

# Score the saved model on the training images exactly as stored on disk.
# The generator returned by create_data_generators() for the train split
# applies random augmentation and shuffling, so evaluating through it would
# measure performance on distorted images and differ from run to run. A plain
# generator (preprocessing only, fixed order) gives the true training metrics
# that can be compared with the validation/test numbers.
train_eval_gen = ImageDataGenerator(
    preprocessing_function=preprocess_input
).flow_from_directory(
    os.path.join(PROCESSED_DATA_PATH, "train"),
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode="binary",
    shuffle=False
)

print(f"\nüîç Loading model for evaluation...")
# Same checkpoint that training writes; use the shared constant instead of a
# second hard-coded copy of the path.
loaded_model = load_model(MODEL_SAVE_PATH)
print("‚úÖ Model loaded successfully!")

print("\nüìä Evaluating model on training data...")
train_loss, train_acc, train_auc = loaded_model.evaluate(train_eval_gen, verbose=1)

print(f"\n‚ú® Training Data Evaluation Results:")
print(f"   Loss: {train_loss:.4f}")
print(f"   Accuracy: {train_acc:.4f}")
print(f"   AUC-ROC: {train_auc:.4f}")



üîÑ Creating data generators...
‚úÖ Found directory: /content/drive/MyDrive/dataset/processed_data/train
‚úÖ Found directory: /content/drive/MyDrive/dataset/processed_data/valid
‚úÖ Found directory: /content/drive/MyDrive/dataset/processed_data/test
Found 32152 images belonging to 2 classes.
Found 3999 images belonging to 2 classes.
Found 4000 images belonging to 2 classes.
‚úÖ Data generators ready

üîç Loading model for evaluation...
‚úÖ Model loaded successfully!

üìä Evaluating model on training data...


  self._warn_if_super_not_called()


[1m1005/1005[0m [32m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[37m[0m [1m9955s[0m 10s/step - accuracy: 0.9522 - auc: 0.9893 - loss: 0.1263

‚ú® Training Data Evaluation Results:
   Loss: 0.1287
   Accuracy: 0.9522
   AUC-ROC: 0.9890
