In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, precision_recall_curve, roc_curve, auc
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
import warnings
from itertools import cycle
from sklearn.preprocessing import label_binarize
import torchvision.models as models
from efficientnet_pytorch import EfficientNet
import time
from PIL import Image as PILImage
from collections import Counter
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import psutil
import gc
import torch.nn.init as init
import pandas as pd
from datetime import datetime
from pathlib import Path
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as tv_models
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
warnings.filterwarnings('ignore')

# Settings

In [None]:
# Every path below is resolved against the notebook's current working directory.
BASE_DIR = Path.cwd()

MODEL_DIR = str(BASE_DIR.joinpath("ml_models"))
data_dir = str(BASE_DIR.joinpath("Processed_Data"))
AUGMENTED_IMAGES_DIR = str(BASE_DIR.joinpath("aug_images"))

# Create the model directory on the first run; nothing happens if the path already exists.
if not Path(MODEL_DIR).exists():
    Path(MODEL_DIR).mkdir(parents=True)

# Per-architecture training switches.
resnet_model_training = mobilenet_model_training = efficientnet_model_training = True
vgg16_model_training = alexnet_model_training = True
hybrid_model_training = cnn_model_training = True

# Epoch count presumably shared by all training runs below.
common_epochs = 1


# GPU Check

In [None]:
# Report whether a CUDA device is usable and which device training would run on.
_cuda_is_available = torch.cuda.is_available()
_selected_device = torch.device("cuda" if _cuda_is_available else "cpu")

print("CUDA available:", _cuda_is_available)
print("Device:", _selected_device)
if _cuda_is_available:
    print("GPU Name:", torch.cuda.get_device_name(0))

# Clean GPU Memory

In [None]:
# Collect unreachable Python objects first so any CUDA tensors they still held are freed;
# torch.cuda.empty_cache() can only release cached blocks that are no longer occupied,
# so running it before gc.collect() would leave those blocks cached.
gc.collect()

if torch.cuda.is_available():
    # Clear GPU cache and reset the peak-usage counters
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Memory statistics are only queried when CUDA is actually present
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved(0)/1024**3:.2f} GB")
else:
    print("CUDA not available - no GPU memory to report")

# Data Analysis: Collecting and Exporting Results

In [None]:
class ThermalDataExporter:
    """Collect per-model training curves and evaluation metrics, then export them to CSV.

    Every file written by one instance carries the same timestamp suffix (taken when
    the instance is created) and is placed in ``output_dir``.
    """

    def __init__(self, output_dir='thermal_analysis_results'):
        self.output_dir = output_dir
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        os.makedirs(output_dir, exist_ok=True)
        self.training_data = {}       # model_name -> {'train_losses', 'train_accs', 'val_losses', 'val_accs'}
        self.model_evaluations = []   # one metrics dict per save_model_evaluation() call
        self.confusion_matrices = {}  # model_name -> confusion matrix exactly as passed in
        print(f'ThermalDataExporter initialized. Output directory: {output_dir}')

    @staticmethod
    def _snapshot(values):
        """Copy plain lists so later appends by the caller do not alter the stored history."""
        return list(values) if isinstance(values, list) else values

    def _write_csv(self, df, filename, **to_csv_kwargs):
        """Write ``df`` to ``<output_dir>/<filename>`` and return the full path."""
        filepath = os.path.join(self.output_dir, filename)
        df.to_csv(filepath, **to_csv_kwargs)
        return filepath

    def save_training_curves(self, model_name, train_losses, train_accs, val_losses, val_accs):
        """Save training curves data for a specific model.

        An existing entry with the same ``model_name`` is replaced.
        """
        self.training_data[model_name] = {
            'train_losses': self._snapshot(train_losses),
            'train_accs': self._snapshot(train_accs),
            'val_losses': self._snapshot(val_losses),
            'val_accs': self._snapshot(val_accs),
        }
        print(f'Training curves saved for: {model_name}')

    def save_model_evaluation(self, model_name, accuracy, precision, recall, f1_score,
                              confusion_matrix, classification_report_text):
        """Record one model's evaluation metrics.

        ``f1_score`` and ``confusion_matrix`` are plain value parameters. Their names are
        kept for keyword-argument compatibility, and inside this method they shadow the
        sklearn functions of the same name.

        Returns:
            The metrics dict appended to ``self.model_evaluations``. The confusion matrix
            is not part of it; it is stored in ``self.confusion_matrices[model_name]``.
        """
        result = {
            'model_name': model_name,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'timestamp': self.timestamp,
            'classification_report': classification_report_text
        }

        self.model_evaluations.append(result)
        self.confusion_matrices[model_name] = confusion_matrix
        print(f'Model evaluation saved for: {model_name}')
        return result

    def export_training_curves(self, training_data):
        """Export ``training_data`` (anything ``pd.DataFrame`` accepts) to one CSV; return its path."""
        filepath = self._write_csv(pd.DataFrame(training_data),
                                   f'training_curves_{self.timestamp}.csv', index=False)
        print(f'Training curves exported to: {filepath}')
        return filepath

    def export_model_evaluations(self, model_results):
        """Export a list of evaluation dicts to one CSV; return its path."""
        filepath = self._write_csv(pd.DataFrame(model_results),
                                   f'model_evaluations_{self.timestamp}.csv', index=False)
        print(f'Model evaluations exported to: {filepath}')
        return filepath

    def save_all_results_csv(self):
        """Export every stored training history and all stored evaluations to CSV.

        Writes one ``<model_name>_training_curves_<timestamp>.csv`` per model, and also
        ``model_evaluations_<timestamp>.csv`` when evaluations exist. Curves of unequal
        length (e.g. fewer validation epochs than training epochs) are padded with NaN.
        Previously, pandas raised ValueError in that case.

        Returns:
            List of written file paths (training curves first, then evaluations).
        """
        written = []

        # Export training data; pd.Series aligns columns by position and pads short ones with NaN
        for model_name, data in self.training_data.items():
            df = pd.DataFrame({
                'train_loss': pd.Series(data['train_losses']),
                'train_accuracy': pd.Series(data['train_accs']),
                'val_loss': pd.Series(data['val_losses']),
                'val_accuracy': pd.Series(data['val_accs']),
            })
            df.insert(0, 'epoch', list(range(1, len(df) + 1)))
            filepath = self._write_csv(df, f'{model_name}_training_curves_{self.timestamp}.csv', index=False)
            print(f'{model_name} training curves exported to: {filepath}')
            written.append(filepath)

        # Export model evaluations
        if self.model_evaluations:
            filepath = self._write_csv(pd.DataFrame(self.model_evaluations),
                                       f'model_evaluations_{self.timestamp}.csv', index=False)
            print(f'Model evaluations exported to: {filepath}')
            written.append(filepath)

        print('All results exported to CSV files!')
        return written

print('ThermalDataExporter class loaded successfully!')

## Data Analysis: Visualisation and CSV Export Utilities

In [None]:
# Enhanced Data Visualization and CSV Export Functions

class EnhancedDataExporter:
    """Write timestamped CSV files for training histories, metric tables and confusion matrices.

    Constructing an instance also sets global plotting defaults: seaborn 'whitegrid'
    style, the 'husl' palette, and 300-dpi matplotlib figures.
    """

    def __init__(self, output_dir='thermal_analysis_results'):
        self.output_dir = output_dir
        self.timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
        os.makedirs(output_dir, exist_ok=True)
        print(f'EnhancedDataExporter initialized. Output directory: {output_dir}')

        # Global styling used by the plotting helpers
        sns.set_style('whitegrid')
        sns.set_palette('husl')
        plt.rcParams['figure.dpi'] = 300

    @staticmethod
    def _slug(name):
        """Lower-case ``name`` and turn spaces into underscores for use in file names."""
        return name.lower().replace(" ", "_")

    def _target(self, filename):
        """Full path of ``filename`` inside the output directory."""
        return os.path.join(self.output_dir, filename)

    def export_training_data_to_csv(self, all_training_data):
        """Write one CSV per model with per-epoch loss/accuracy plus model name and timestamp columns."""
        print('Exporting comprehensive training data to CSV...')

        for name, history in all_training_data.items():
            epoch_count = len(history['train_losses'])
            frame = pd.DataFrame({
                'epoch': range(1, epoch_count + 1),
                'train_loss': history['train_losses'],
                'train_accuracy': history['train_accs'],
                'val_loss': history['val_losses'],
                'val_accuracy': history['val_accs'],
                'model_name': name,
                'timestamp': self.timestamp
            })
            target = self._target(f'{self._slug(name)}_training_data_{self.timestamp}.csv')
            frame.to_csv(target, index=False)
            print(f'  {name} training data: {target}')

    def export_model_comparison_csv(self, model_results):
        """Write one row per model with seven metrics (missing ones default to 0); return the frame."""
        print('Exporting model comparison data to CSV...')

        metric_keys = ('accuracy', 'precision_macro', 'precision_weighted',
                       'recall_macro', 'recall_weighted', 'f1_macro', 'f1_weighted')
        rows = [
            {'model_name': entry['model_name'],
             **{key: entry.get(key, 0) for key in metric_keys},
             'timestamp': self.timestamp}
            for entry in model_results
        ]

        table = pd.DataFrame(rows)
        target = self._target(f'model_comparison_{self.timestamp}.csv')
        table.to_csv(target, index=False)
        print(f'  Model comparison data: {target}')
        return table

    def export_confusion_matrices_csv(self, model_results, class_names):
        """Write one labelled confusion-matrix CSV per result dict that has a 'confusion_matrix' key."""
        print('Exporting confusion matrices to CSV...')

        for entry in model_results:
            if 'confusion_matrix' not in entry:
                continue
            name = entry['model_name']
            frame = pd.DataFrame(entry['confusion_matrix'], columns=class_names, index=class_names)
            frame.index.name = 'True_Label'
            frame['model_name'] = name
            frame['timestamp'] = self.timestamp

            target = self._target(f'confusion_matrix_{self._slug(name)}_{self.timestamp}.csv')
            frame.to_csv(target)
            print(f'  {name} confusion matrix: {target}')

def plot_enhanced_training_curves(all_training_data, save_individual=True,
                                  output_dir='thermal_analysis_results'):
    """Draw, show and save a 2x2 diagnostic figure for every model's training history.

    Panels:
      - train/validation loss;
      - train/validation accuracy in percent;
      - validation-minus-training loss (a growing positive gap suggests overfitting);
      - raw validation accuracy.

    Args:
        all_training_data: dict model_name -> dict with equal-length sequences under
            'train_losses', 'train_accs', 'val_losses' and 'val_accs'. Accuracies are
            multiplied by 100 for the percentage panel, so they are presumably fractions.
        save_individual: per-model figures are the only output this function produces.
            When False, nothing is plotted and a message says so.
        output_dir: directory for the PNG files; created if it does not exist.
    """
    print('Creating enhanced training curves visualization...')

    if not save_individual:
        print('  save_individual=False: nothing to plot (only per-model figures are implemented)')
        return

    os.makedirs(output_dir, exist_ok=True)

    # Individual plots for each model
    for model_name, data in all_training_data.items():
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        epochs = range(1, len(data['train_losses']) + 1)

        # Loss plot
        axes[0, 0].plot(epochs, data['train_losses'], 'o-', label='Training', linewidth=2, markersize=5)
        axes[0, 0].plot(epochs, data['val_losses'], 's-', label='Validation', linewidth=2, markersize=5)
        axes[0, 0].set_title(f'{model_name} - Loss Over Time', fontsize=14, fontweight='bold')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Accuracy plot (percent)
        axes[0, 1].plot(epochs, [acc * 100 for acc in data['train_accs']], 'o-', label='Training', linewidth=2, markersize=5)
        axes[0, 1].plot(epochs, [acc * 100 for acc in data['val_accs']], 's-', label='Validation', linewidth=2, markersize=5)
        axes[0, 1].set_title(f'{model_name} - Accuracy Over Time', fontsize=14, fontweight='bold')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Overfitting gap: validation loss minus training loss per epoch
        loss_diff = np.array(data['val_losses']) - np.array(data['train_losses'])
        axes[1, 0].plot(epochs, loss_diff, 'o-', color='red', linewidth=2, markersize=5)
        axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.5)
        axes[1, 0].set_title(f'{model_name} - Overfitting Analysis (Val - Train Loss)', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Loss Difference')
        axes[1, 0].grid(True, alpha=0.3)

        # Validation accuracy trend (raw values, not percent)
        axes[1, 1].plot(epochs, data['val_accs'], 'o-', label='Validation Accuracy', linewidth=2, markersize=5)
        axes[1, 1].set_title(f'{model_name} - Validation Accuracy Trend', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Validation Accuracy')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        filename = f'enhanced_training_curves_{model_name.lower().replace(" ", "_")}.png'
        plt.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures when many models are plotted
        print(f'  Enhanced training curves for {model_name} saved!')

def plot_comprehensive_model_comparison(model_results, output_dir='thermal_analysis_results'):
    """Compare several models in one 2x2 figure, then show it and save it as a PNG.

    Panels:
      - grouped bars for accuracy and macro precision/recall/F1;
      - a heatmap of all seven metrics;
      - a radar chart of the four headline metrics;
      - a per-model box plot of those four metrics, with the individual values overlaid.

    Args:
        model_results: list of dicts with 'model_name' plus optional metric keys:
            'accuracy', 'precision_macro', 'recall_macro', 'f1_macro' and the
            '*_weighted' variants. Missing metrics count as 0. Values are multiplied
            by 100, so they are presumably fractions in [0, 1].
        output_dir: directory for 'comprehensive_model_comparison.png'; created if missing.
    """
    print('Creating comprehensive model comparison...')

    if len(model_results) < 2:
        print('Need at least 2 models for comparison')
        return

    # Prepare data for different metrics (as percentages)
    metrics_data = []
    for result in model_results:
        metrics_data.append({
            'Model': result['model_name'],
            'Accuracy': result.get('accuracy', 0) * 100,
            'Precision (Macro)': result.get('precision_macro', 0) * 100,
            'Recall (Macro)': result.get('recall_macro', 0) * 100,
            'F1-Score (Macro)': result.get('f1_macro', 0) * 100,
            'Precision (Weighted)': result.get('precision_weighted', 0) * 100,
            'Recall (Weighted)': result.get('recall_weighted', 0) * 100,
            'F1-Score (Weighted)': result.get('f1_weighted', 0) * 100
        })

    df = pd.DataFrame(metrics_data)

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))

    # 1. Bar plot for macro metrics
    macro_cols = ['Accuracy', 'Precision (Macro)', 'Recall (Macro)', 'F1-Score (Macro)']
    df_macro = df[['Model'] + macro_cols].melt(id_vars='Model', var_name='Metric', value_name='Score')
    sns.barplot(data=df_macro, x='Metric', y='Score', hue='Model', ax=axes[0, 0], palette='Set2')
    axes[0, 0].set_title('Model Performance Comparison (Macro Averages)', fontsize=16, fontweight='bold')
    axes[0, 0].set_ylabel('Score (%)')
    axes[0, 0].legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[0, 0].tick_params(axis='x', rotation=45)

    # 2. Heatmap for all metrics
    df_heatmap = df.set_index('Model').T
    sns.heatmap(df_heatmap, annot=True, fmt='.2f', cmap='YlOrRd', ax=axes[0, 1], cbar_kws={'label': 'Score (%)'})
    axes[0, 1].set_title('Performance Heatmap (All Metrics)', fontsize=16, fontweight='bold')

    # 3. Radar chart. plt.subplots created the bottom-left slot as a Cartesian axes, so
    # replace it explicitly with a polar axes. Calling plt.subplot(..., projection='polar')
    # on top of it could, depending on the matplotlib version, leave the empty
    # Cartesian axes overlapping the radar.
    axes[1, 0].remove()
    ax_radar = fig.add_subplot(2, 2, 3, projection='polar')

    angles = np.linspace(0, 2 * np.pi, len(macro_cols), endpoint=False).tolist()
    angles += angles[:1]  # Complete the circle
    colors = sns.color_palette('Set2', len(model_results))

    for idx, result in enumerate(model_results):
        values = [
            result.get('accuracy', 0) * 100,
            result.get('precision_macro', 0) * 100,
            result.get('recall_macro', 0) * 100,
            result.get('f1_macro', 0) * 100
        ]
        values += values[:1]  # Complete the circle

        ax_radar.plot(angles, values, 'o-', linewidth=2, label=result['model_name'], color=colors[idx])
        ax_radar.fill(angles, values, alpha=0.25, color=colors[idx])

    ax_radar.set_xticks(angles[:-1])
    ax_radar.set_xticklabels(['Accuracy', 'Precision', 'Recall', 'F1-Score'])
    ax_radar.set_ylim(0, 100)
    ax_radar.set_title('Performance Radar Chart', fontsize=16, fontweight='bold', pad=30)
    ax_radar.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
    ax_radar.grid(True)

    # 4. Box plot: one box per model spanning its four headline scores, with each metric
    # overlaid as a coloured point. Grouping by (model, metric) instead would put a single
    # value in every box, which collapses each box to a flat line.
    metric_values = []
    for result in model_results:
        for metric_name, value in [('Accuracy', result.get('accuracy', 0) * 100),
                                   ('Precision', result.get('precision_macro', 0) * 100),
                                   ('Recall', result.get('recall_macro', 0) * 100),
                                   ('F1-Score', result.get('f1_macro', 0) * 100)]:
            metric_values.append({'Model': result['model_name'], 'Metric': metric_name, 'Value': value})

    df_box = pd.DataFrame(metric_values)
    sns.boxplot(data=df_box, x='Model', y='Value', ax=axes[1, 1], color='lightgrey')
    sns.stripplot(data=df_box, x='Model', y='Value', hue='Metric', ax=axes[1, 1],
                  palette='Set3', size=8, jitter=0.15, linewidth=1)
    axes[1, 1].set_title('Metric Distribution by Model', fontsize=16, fontweight='bold')
    axes[1, 1].set_ylabel('Score (%)')
    axes[1, 1].legend(title='Metric')

    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'comprehensive_model_comparison.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print('  Comprehensive model comparison saved!')

def plot_roc_curves(model_probabilities, y_true, class_names, model_names,
                    output_dir='thermal_analysis_results'):
    """Plot one-vs-rest ROC curves (one per class plus the micro-average) for each model.

    Draws one subplot per model, side by side. Shows the figure, saves it as
    ``<output_dir>/enhanced_roc_curves_seaborn.png`` and prints a micro-average AUC summary.

    Args:
        model_probabilities: sequence of per-model score arrays, each shaped
            (n_samples, len(class_names)), with columns in ``class_names`` order.
        y_true: integer labels in ``range(len(class_names))``.
        class_names: display name per class.
        model_names: display name per model, aligned with ``model_probabilities``.
            Only complete (name, probabilities) pairs are plotted.
        output_dir: directory for the PNG; created if missing.

    Raises:
        ValueError: if a probability array does not have one column per class.
    """
    # Set seaborn style (global side effect)
    sns.set_style("whitegrid")
    sns.set_palette("husl")

    n_classes = len(class_names)

    # Convert/validate all inputs before any figure is created
    pairs = [(name, np.asarray(prob)) for name, prob in zip(model_names, model_probabilities)]
    if not pairs:
        print("No models to plot ROC curves for")
        return
    for name, prob in pairs:
        if prob.ndim != 2 or prob.shape[1] != n_classes:
            raise ValueError(f"{name}: expected probabilities of shape (n_samples, {n_classes}), got {prob.shape}")

    # Binarize the labels for one-vs-rest ROC
    y_bin = label_binarize(y_true, classes=range(n_classes))
    if n_classes == 2:
        # label_binarize returns a single column for two classes; expand to one column per class
        y_bin = np.hstack([1 - y_bin, y_bin])

    fig, axes = plt.subplots(1, len(pairs), figsize=(6 * len(pairs), 5))
    if len(pairs) == 1:
        axes = [axes]

    colors = sns.color_palette("husl", n_classes + 1)
    micro_aucs = []  # (model_name, micro AUC), reused for the printed summary

    for ax, (model_name, y_prob) in zip(axes, pairs):
        # Compute ROC curve and ROC area for each class
        fpr, tpr, roc_auc = {}, {}, {}
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_prob[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Micro-average: pool every (sample, class) decision into one binary problem
        fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), y_prob.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        micro_aucs.append((model_name, roc_auc["micro"]))

        ax.plot(fpr["micro"], tpr["micro"],
                color='deeppink', linestyle='--', linewidth=3,
                label=f'Micro-avg ROC (AUC = {roc_auc["micro"]:.3f})')

        for i, color in zip(range(n_classes), colors):
            ax.plot(fpr[i], tpr[i], color=color, linewidth=2.5,
                    label=f'{class_names[i]} (AUC = {roc_auc[i]:.3f})')

        # Random classifier reference line
        ax.plot([0, 1], [0, 1], 'k--', linewidth=2, alpha=0.6, label='Random Classifier')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='bold')
        ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='bold')
        ax.set_title(f'{model_name} - ROC Curves', fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc="lower right", frameon=True, fancybox=True, shadow=True)
        ax.grid(True, alpha=0.3)
        ax.set_facecolor('#fafafa')

    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'enhanced_roc_curves_seaborn.png'),
                dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    plt.show()
    plt.close(fig)

    # Print summary statistics (values computed above, not recomputed)
    print("\nROC AUC Summary:")
    for model_name, auc_micro in micro_aucs:
        print(f"  {model_name}: Micro-avg AUC = {auc_micro:.4f}")

def plot_precision_recall_curves(model_probabilities, y_true, class_names, model_names,
                                 output_dir='thermal_analysis_results'):
    """Plot per-class Precision-Recall curves, one subplot per model, then show and save the figure.

    Each legend entry reports the class's average precision (AP) from
    ``sklearn.metrics.average_precision_score``. Trapezoidal ``auc(recall, precision)``
    is not AP and can be overly optimistic.

    The grid has two columns and at least two rows (2x2 for up to four models). It
    grows for more models; the old version silently dropped every model after the fourth.

    Args:
        model_probabilities: per-model score arrays shaped (n_samples, len(class_names)).
        y_true: integer labels in ``range(len(class_names))``.
        class_names: display name per class.
        model_names: display name per model, aligned with ``model_probabilities``.
        output_dir: directory for 'precision_recall_curves.png'; created if missing.

    Raises:
        ValueError: if a probability array does not have one column per class.
    """
    import math
    from sklearn.metrics import average_precision_score

    print('Creating Precision-Recall curves visualization...')

    n_classes = len(class_names)

    pairs = [(name, np.asarray(prob)) for name, prob in zip(model_names, model_probabilities)]
    if not pairs:
        print('  No models to plot - Precision-Recall curves skipped')
        return
    for name, prob in pairs:
        if prob.ndim != 2 or prob.shape[1] != n_classes:
            raise ValueError(f"{name}: expected probabilities of shape (n_samples, {n_classes}), got {prob.shape}")

    y_bin = label_binarize(y_true, classes=range(n_classes))
    if n_classes == 2:
        # label_binarize returns a single column for two classes; expand to one column per class
        y_bin = np.hstack([1 - y_bin, y_bin])

    n_rows = max(2, math.ceil(len(pairs) / 2))
    fig, axes = plt.subplots(n_rows, 2, figsize=(16, 6 * n_rows))
    axes = axes.ravel()

    for ax, (model_name, y_prob) in zip(axes, pairs):
        colors_local = cycle(['aqua', 'darkorange', 'cornflowerblue', 'red', 'green'])
        for i, color in zip(range(n_classes), colors_local):
            precision, recall, _ = precision_recall_curve(y_bin[:, i], y_prob[:, i])
            avg_precision = average_precision_score(y_bin[:, i], y_prob[:, i])

            ax.plot(recall, precision, color=color, lw=2,
                    label=f'{class_names[i]} (AP = {avg_precision:.2f})')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title(f'Precision-Recall Curves - {model_name}', fontweight='bold')
        ax.legend(loc="lower left", fontsize=8)
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in axes[len(pairs):]:
        ax.set_visible(False)

    plt.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'precision_recall_curves.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print('  Precision-Recall curves saved!')

def create_statistical_summary_report(model_results, output_dir='thermal_analysis_results'):
    """Write a plain-text summary of model metrics and return the underlying statistics.

    Args:
        model_results: list of dicts with 'model_name' plus optional metric keys
            ('accuracy', 'precision_macro', 'precision_weighted', 'recall_macro',
            'recall_weighted', 'f1_macro', 'f1_weighted'). Missing metrics count as 0.
        output_dir: directory for ``statistical_summary_report_<timestamp>.txt``;
            created if missing.

    Returns:
        (summary_stats, best_models):
          - summary_stats is ``DataFrame.describe()`` over the metric columns;
          - best_models maps each metric column to ``(model_name, best_score)``.
        With no results, nothing is written and ``(empty DataFrame, {})`` is returned.
    """
    print('Creating statistical summary report...')

    # Prepare data for statistical analysis
    stats_data = []
    for result in model_results:
        stats_data.append({
            'Model': result['model_name'],
            'Accuracy': result.get('accuracy', 0),
            'Precision_Macro': result.get('precision_macro', 0),
            'Precision_Weighted': result.get('precision_weighted', 0),
            'Recall_Macro': result.get('recall_macro', 0),
            'Recall_Weighted': result.get('recall_weighted', 0),
            'F1_Macro': result.get('f1_macro', 0),
            'F1_Weighted': result.get('f1_weighted', 0)
        })

    df_stats = pd.DataFrame(stats_data)

    # describe() raises on a frame without columns, so bail out early on empty input
    if df_stats.empty:
        print('  No model results to summarise - report not written')
        return pd.DataFrame(), {}

    summary_stats = df_stats.describe()

    # Best performing model for each metric (first model wins ties, per idxmax)
    best_models = {}
    for col in df_stats.columns[1:]:  # Skip 'Model' column
        best_idx = df_stats[col].idxmax()
        best_models[col] = (df_stats.loc[best_idx, 'Model'], df_stats.loc[best_idx, col])

    os.makedirs(output_dir, exist_ok=True)
    generated_at = pd.Timestamp.now()  # one timestamp for both the file name and the footer
    timestamp = generated_at.strftime('%Y%m%d_%H%M%S')
    report_file = os.path.join(output_dir, f'statistical_summary_report_{timestamp}.txt')

    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("THERMAL IMAGING MODEL PERFORMANCE STATISTICAL SUMMARY\n")
        f.write("=" * 60 + "\n\n")

        f.write("SUMMARY STATISTICS:\n")
        f.write("-" * 30 + "\n")
        f.write(summary_stats.to_string())
        f.write("\n\n")

        f.write("BEST PERFORMING MODELS BY METRIC:\n")
        f.write("-" * 40 + "\n")
        for metric, (model, score) in best_models.items():
            f.write(f"{metric.replace('_', ' ').title()}: {model} ({score:.4f})\n")

        f.write("\n" + "=" * 60 + "\n")
        f.write(f"Report generated on: {generated_at}\n")

    print(f'  Statistical summary report saved: {report_file}')
    return summary_stats, best_models

# Global storage for enhanced data: module-level accumulators, presumably filled by later training cells
enhanced_model_results = []
enhanced_training_data = {}

print('Enhanced data visualization and CSV export functions loaded!')
print('Available visualizations:\n' + '\n'.join(
    '   • ' + feature for feature in (
        'Enhanced training curves with overfitting analysis',
        'Comprehensive model comparison (bar, heatmap, radar, box plots)',
        'ROC curves for multi-class classification',
        'Precision-Recall curves',
        'Statistical summary reports',
        'CSV export for all data types',
    )
))

In [None]:
# Enhanced ROC Curve Plotting with Additional Seaborn Features
def plot_enhanced_roc_comparison(model_probabilities, y_true, class_names, model_names,
                                 output_dir='thermal_analysis_results'):
    """Build a ROC dashboard comparing models by micro-average ROC/AUC.

    Layout:
      - top row: micro-average ROC curves of all models on one axes;
      - bottom left: AUC bar chart;
      - bottom right: single-row AUC heatmap.
    The figure is shown and saved as ``<output_dir>/enhanced_roc_dashboard_seaborn.png``.

    Args:
        model_probabilities: per-model score arrays shaped (n_samples, len(class_names)).
        y_true: integer labels in ``range(len(class_names))``.
        class_names: display name per class.
        model_names: display name per model, aligned with ``model_probabilities``.
            Only complete (name, probabilities) pairs are used.
        output_dir: directory for the PNG; created if missing.

    Returns:
        List of micro-average AUC scores in model order (empty if there is nothing to plot).

    Raises:
        ValueError: if a probability array does not have one column per class.
    """
    # Set seaborn theme (global side effect)
    sns.set_theme(style="whitegrid", palette="deep")

    n_classes = len(class_names)

    # Convert/validate all inputs before any figure is created
    pairs = [(name, np.asarray(prob)) for name, prob in zip(model_names, model_probabilities)]
    if not pairs:
        print("No models to compare - ROC dashboard skipped")
        return []
    for name, prob in pairs:
        if prob.ndim != 2 or prob.shape[1] != n_classes:
            raise ValueError(f"{name}: expected probabilities of shape (n_samples, {n_classes}), got {prob.shape}")
    names = [name for name, _ in pairs]

    # Binarize labels
    y_bin = label_binarize(y_true, classes=range(n_classes))
    if n_classes == 2:
        # label_binarize returns a single column for two classes; expand to one column per class
        y_bin = np.hstack([1 - y_bin, y_bin])

    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    # Main ROC plot spans the whole top row
    ax_main = fig.add_subplot(gs[0, :])

    colors = sns.color_palette("Set2", len(names))

    # Plot micro-average ROC for each model
    auc_scores = []
    for color, (model_name, y_prob) in zip(colors, pairs):
        fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_prob.ravel())
        roc_auc_micro = auc(fpr_micro, tpr_micro)
        auc_scores.append(roc_auc_micro)

        ax_main.plot(fpr_micro, tpr_micro,
                     color=color,
                     linewidth=3,
                     label=f'{model_name} (AUC = {roc_auc_micro:.3f})')

    # Random classifier line
    ax_main.plot([0, 1], [0, 1], 'k--', linewidth=2, alpha=0.7, label='Random Classifier')

    ax_main.set_xlabel('False Positive Rate', fontsize=14, fontweight='bold')
    ax_main.set_ylabel('True Positive Rate', fontsize=14, fontweight='bold')
    ax_main.set_title('ROC Curves Comparison - All Models (Micro-Average)',
                      fontsize=16, fontweight='bold', pad=20)
    ax_main.legend(loc='lower right', fontsize=12, frameon=True)
    ax_main.grid(True, alpha=0.3)
    ax_main.set_facecolor('#fafafa')

    # AUC comparison bar plot
    ax_bar = fig.add_subplot(gs[1, 0])
    bars = ax_bar.bar(names, auc_scores, color=colors, alpha=0.8, edgecolor='black')
    ax_bar.set_ylabel('AUC Score', fontsize=12, fontweight='bold')
    ax_bar.set_title('AUC Comparison', fontsize=14, fontweight='bold')
    ax_bar.set_ylim(0, 1.1)  # headroom so value labels above bars near AUC = 1 stay inside the axes

    # Add value labels on bars
    for bar, score in zip(bars, auc_scores):
        height = bar.get_height()
        ax_bar.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

    # AUC heatmap (single row)
    ax_heat = fig.add_subplot(gs[1, 1])
    metrics_data = np.array(auc_scores).reshape(1, -1)
    sns.heatmap(metrics_data,
                xticklabels=names,
                yticklabels=['AUC'],
                annot=True,
                fmt='.3f',
                cmap='RdYlGn',
                ax=ax_heat,
                cbar_kws={'label': 'AUC Score'})
    ax_heat.set_title('AUC Heatmap', fontsize=14, fontweight='bold')

    plt.suptitle('Enhanced ROC Analysis Dashboard', fontsize=18, fontweight='bold', y=0.98)

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'enhanced_roc_dashboard_seaborn.png'),
                dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    plt.close(fig)

    return auc_scores

print("Enhanced ROC plotting functions loaded successfully!")


In [None]:
# Example usage of enhanced ROC plotting functions
# This cell demonstrates how to use the enhanced ROC plotting capabilities

def demonstrate_enhanced_roc_plotting(model_results_dict, test_probabilities, y_test, class_names):
    """
    Run both ROC visualisations: per-model curves and the comparison dashboard.

    Parameters:
    - model_results_dict: dict whose keys are model names, in the same order as
      ``test_probabilities`` (the values are not used here)
    - test_probabilities: list of probability arrays, one per model
    - y_test: true test labels
    - class_names: list of class names

    Returns the list of micro-average AUC scores, or None when there is nothing to plot.
    """
    print("\nCreating Enhanced ROC Visualizations...")

    model_names = list(model_results_dict.keys())

    if len(test_probabilities) == 0 or len(model_names) == 0:
        print("No model probabilities available for ROC plotting")
        return None

    if len(model_names) != len(test_probabilities):
        # The plotting helpers pair names and arrays by position; a length mismatch means
        # the pairing cannot be trusted, so make it visible instead of silently truncating.
        n_pairs = min(len(model_names), len(test_probabilities))
        print(f"Warning: {len(model_names)} model names but {len(test_probabilities)} "
              f"probability arrays; plotting only the first {n_pairs}")
        model_names = model_names[:n_pairs]
        test_probabilities = list(test_probabilities)[:n_pairs]

    print("Plotting individual model ROC curves with seaborn styling...")
    plot_roc_curves(test_probabilities, y_test, class_names, model_names)

    print("Creating comprehensive ROC comparison dashboard...")
    auc_scores = plot_enhanced_roc_comparison(test_probabilities, y_test, class_names, model_names)

    print("\nEnhanced ROC Analysis Complete!")
    print("Generated files:")
    print("   • enhanced_roc_curves_seaborn.png")
    print("   • enhanced_roc_dashboard_seaborn.png")

    return auc_scores

# Global variables to store model probabilities for ROC plotting
MODEL_PROBABILITIES = []   # Probability outputs from each model
MODEL_NAMES_GLOBAL = []    # Model names, index-aligned with MODEL_PROBABILITIES
MODEL_RESULTS_GLOBAL = {}  # Optional per-model result dicts (keyed by model name)
TEST_LABELS_GLOBAL = None  # Test labels

def store_model_probabilities(model_name, probabilities, test_labels=None):
    """
    Store model probabilities for later ROC plotting.

    The name is recorded alongside the probabilities so the two stay aligned. Storing
    the same ``model_name`` again (e.g. after re-running a training cell) replaces that
    model's entry instead of appending a duplicate.

    Parameters:
    - model_name: Name of the model
    - probabilities: Model probability predictions
    - test_labels: Test labels (only needed once; the latest non-None value wins)
    """
    global TEST_LABELS_GLOBAL

    if model_name in MODEL_NAMES_GLOBAL:
        MODEL_PROBABILITIES[MODEL_NAMES_GLOBAL.index(model_name)] = probabilities
    else:
        MODEL_NAMES_GLOBAL.append(model_name)
        MODEL_PROBABILITIES.append(probabilities)

    if test_labels is not None:
        TEST_LABELS_GLOBAL = test_labels

    print(f"Stored probabilities for {model_name} (Total models: {len(MODEL_PROBABILITIES)})")

def plot_all_model_roc_curves(class_names=None):
    """
    Plot ROC curves for all stored models.

    Parameters:
    - class_names: optional list of class names. Defaults to 'Class 0' ... 'Class k-1',
      where k is the number of probability columns of the first stored model.

    Returns the AUC scores from demonstrate_enhanced_roc_plotting, or None if nothing is stored.
    """
    if len(MODEL_PROBABILITIES) == 0 or TEST_LABELS_GLOBAL is None:
        print("No model data available for ROC plotting. Train some models first!")
        return None

    if class_names is None:
        # Derive labels from the data instead of assuming a fixed number of classes
        n_classes = np.asarray(MODEL_PROBABILITIES[0]).shape[1]
        class_names = [f'Class {i}' for i in range(n_classes)]

    if len(MODEL_NAMES_GLOBAL) == len(MODEL_PROBABILITIES):
        # Names recorded by store_model_probabilities line up with the stored probabilities
        ordered_results = {name: MODEL_RESULTS_GLOBAL.get(name, {}) for name in MODEL_NAMES_GLOBAL}
    else:
        # Probabilities were appended without store_model_probabilities(); fall back to result keys
        ordered_results = MODEL_RESULTS_GLOBAL

    print(f"\nPlotting ROC curves for {len(ordered_results)} models...")

    # Use both plotting functions
    return demonstrate_enhanced_roc_plotting(
        ordered_results,
        MODEL_PROBABILITIES,
        TEST_LABELS_GLOBAL,
        class_names
    )

print("Enhanced ROC plotting utilities loaded!")
print("Usage:")
print("   1. Use store_model_probabilities(model_name, probabilities, test_labels) after each model training")
print("   2. Use plot_all_model_roc_curves() to generate comprehensive ROC analysis")
print("   3. Or use demonstrate_enhanced_roc_plotting() directly with your data")


# Enhanced Optimization System

In [None]:
# Note: LabelSmoothingCrossEntropy, FocalLoss, and other utility classes are now defined in section 3.1

def _build_classifier_head(in_features, num_classes, dropout_rate, hidden_dim=512):
    """Return Dropout -> Linear(in, hidden) -> ReLU -> Dropout -> Linear(hidden, num_classes)."""
    return nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_dim, num_classes)
    )

def enhance_model_architecture(model, model_type, num_classes, dropout_rate=0.3, hidden_dim=512):
    """Replace a backbone's final classifier with a regularised two-layer head.

    Supported ``model_type`` values, the attribute replaced, and where the input width comes from:
      - "resnet":       ``model.fc``, width from ``model.fc.in_features``
      - "mobilenet":    ``model.classifier``, width from ``model.last_channel``
      - "efficientnet": ``model._fc``, width from ``model._fc.in_features``
    Any other ``model_type`` leaves the model unchanged and prints a notice.

    Args:
        model: backbone module, modified in place.
        model_type: one of the strings above.
        num_classes: output size of the new head.
        dropout_rate: probability used by both Dropout layers of the head.
        hidden_dim: width of the hidden Linear layer (512 matches the previous fixed value).

    Returns:
        The same ``model`` object.
    """
    if model_type == "resnet":
        model.fc = _build_classifier_head(model.fc.in_features, num_classes, dropout_rate, hidden_dim)
    elif model_type == "mobilenet":
        model.classifier = _build_classifier_head(model.last_channel, num_classes, dropout_rate, hidden_dim)
    elif model_type == "efficientnet":
        model._fc = _build_classifier_head(model._fc.in_features, num_classes, dropout_rate, hidden_dim)
    else:
        print(f"enhance_model_architecture: unknown model_type '{model_type}' - classifier left unchanged")
    return model

def _fine_tune_train_epoch(model, loader, loss_fn, optimizer, scaler, device, use_amp, max_grad_norm):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / total, correct / total


def _fine_tune_eval_epoch(model, loader, loss_fn, device, use_amp):
    """Run one gradient-free pass over ``loader`` in eval mode; return ``(mean_loss, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return running_loss / total, correct / total


def train_model_fine_tuned(model, train_loader, val_loader, criterion, epochs, device, model_name,
                           lr=0.001, weight_decay=0.01, max_grad_norm=1.0, label_smoothing=0.1):
    """Train ``model`` with AdamW, cosine LR annealing, label smoothing, AMP and gradient clipping.

    Args:
        model: network to train. It is not moved here; only batches are moved to ``device``.
        train_loader: iterable of ``(inputs, labels)`` batches used for optimisation.
        val_loader: iterable of ``(inputs, labels)`` batches evaluated after every epoch.
        criterion: loss used only when ``label_smoothing`` is ``None``. With the default
            ``label_smoothing=0.1`` a ``LabelSmoothingCrossEntropy`` is used for both
            training and validation, and ``criterion`` is ignored.
        epochs: number of passes over ``train_loader``; also the cosine schedule length.
        device: torch device (or device string) that batches are moved to. Mixed
            precision (``torch.cuda.amp``) is enabled only for CUDA devices.
        model_name: label used in progress messages.
        lr: initial AdamW learning rate.
        weight_decay: AdamW decoupled weight decay.
        max_grad_norm: gradient-norm clipping threshold.
        label_smoothing: smoothing factor, or ``None`` to train with ``criterion``.

    Returns:
        Tuple ``(train_losses, train_accs, val_losses, val_accs)`` of per-epoch lists.

    Raises:
        ValueError: if either loader yields no samples.
    """
    print(f"Enhanced Training for {model_name} with 8 Advanced Optimizations")

    # AdamW: weight decay decoupled from the adaptive update.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-8)
    # Cosine decay to 1e-6 over `epochs` scheduler steps (stepped once per epoch below).
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    if label_smoothing is None:
        loss_fn = criterion
    else:
        loss_fn = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
    # torch.cuda.amp autocast/GradScaler only act on CUDA tensors; keep them off elsewhere.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(epochs):
        # Learning rate actually used during this epoch (read before scheduler.step()).
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = _fine_tune_train_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, use_amp, max_grad_norm
        )
        val_loss, val_acc = _fine_tune_eval_epoch(model, val_loader, loss_fn, device, use_amp)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        scheduler.step()

        print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | LR: {current_lr:.6f}")

    print(f"Enhanced training completed for {model_name}!")
    return train_losses, train_accs, val_losses, val_accs

# Banner listing the techniques wired into train_model_fine_tuned / enhance_model_architecture.
print("\n".join([
    "Enhanced optimization system loaded successfully!",
    "Available optimizations:",
    "   1. AdamW optimizer with proper weight decay",
    "   2. Cosine Annealing LR scheduler",
    "   3. Label Smoothing (0.1)",
    "   4. Mixed Precision Training (AMP)",
    "   5. Gradient Clipping (max_norm=1.0)",
    "   6. Enhanced model architectures",
    "   7. Advanced regularization techniques",
]))

In [None]:
"""
Enhanced Training Data CSV Export Integration
============================================

This cell adds helpers that save per-epoch training metrics (training
accuracy, validation accuracy, training loss, validation loss) to one CSV
file per model, plus functions to plot and summarise those files.

Usage:
1. Run this cell after the cell that defines train_model_fine_tuned.
2. Train with train_model_with_csv_export(...) instead of
   train_model_fine_tuned(...); it saves the CSV file automatically.
3. Use the plotting and summary functions to visualize the results.

"""

def save_training_data_to_csv(model_name, train_losses, train_accs, val_losses, val_accs,
                              results_dir='thermal_analysis_results'):
    """Write one model's per-epoch training metrics to a timestamped CSV file.

    The file is named ``<model_name>_training_curves_<YYYYmmdd_HHMMSS>.csv``. The model name
    is lower-cased, and every character other than letters, digits, ``_``, ``.`` and ``-``
    is replaced by ``_``, so spaces and path separators are safe.

    Args:
        model_name: display name, stored verbatim in the ``model_name`` column.
        train_losses, train_accs, val_losses, val_accs: equal-length per-epoch sequences.
        results_dir: output directory, created if missing.

    Returns:
        Path of the written CSV file.
    """
    import re  # only this helper needs filename sanitising

    os.makedirs(results_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    df = pd.DataFrame({
        'epoch': range(1, len(train_losses) + 1),
        'training_loss': train_losses,
        'training_accuracy': train_accs,
        'validation_loss': val_losses,
        'validation_accuracy': val_accs,
        'model_name': model_name,
        'timestamp': timestamp
    })

    # One-for-one character replacement: names with only spaces map exactly as before,
    # while e.g. "ResNet/50" no longer points into a non-existent sub-directory.
    safe_name = re.sub(r'[^\w.-]', '_', model_name.lower())
    filepath = os.path.join(results_dir, f'{safe_name}_training_curves_{timestamp}.csv')
    df.to_csv(filepath, index=False)

    print(f"Training data saved to: {filepath}")
    return filepath

def plot_training_curves_from_csv(csv_file_path):
    """Draw a 2x2 diagnostic figure for one training run stored as CSV.

    Panels: train/val loss, train/val accuracy, the loss gap (val - train) and the
    accuracy gap (train - val), each against the epoch number.

    Args:
        csv_file_path: CSV with ``epoch``, ``training_*``/``validation_*`` metric columns
            and a ``model_name`` column (the layout written by ``save_training_data_to_csv``).

    Returns:
        The matplotlib Figure (already displayed via ``plt.show()``).
    """
    history = pd.read_csv(csv_file_path)
    name = history['model_name'].iloc[0]
    epochs = history['epoch']

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    loss_ax, acc_ax, loss_gap_ax, acc_gap_ax = axes.flat

    # Top row: raw training vs. validation curves.
    for ax, metric, pretty in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'accuracy', 'Accuracy')):
        ax.plot(epochs, history[f'training_{metric}'], 'o-', label=f'Training {pretty}', linewidth=2)
        ax.plot(epochs, history[f'validation_{metric}'], 's-', label=f'Validation {pretty}', linewidth=2)
        ax.set_title(f'{name} - Training & Validation {pretty}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Bottom row: generalisation gaps with a zero reference line.
    gap_panels = (
        (loss_gap_ax, history['validation_loss'] - history['training_loss'], 'red',
         f'{name} - Overfitting Analysis (Val - Train Loss)', 'Loss Difference'),
        (acc_gap_ax, history['training_accuracy'] - history['validation_accuracy'], 'orange',
         f'{name} - Accuracy Gap (Train - Val)', 'Accuracy Difference'),
    )
    for ax, gap, colour, title, ylabel in gap_panels:
        ax.plot(epochs, gap, 'o-', color=colour, linewidth=2)
        ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    return fig

def plot_all_models_comparison(results_dir='thermal_analysis_results'):
    """Overlay the training curves of every saved run in a 2x2 comparison figure.

    Panels: training loss, validation loss, training accuracy and validation accuracy
    against the epoch number, with one line per CSV file.

    Args:
        results_dir: directory scanned for ``*_training_curves_*.csv`` files.

    Returns:
        The matplotlib Figure, or ``None`` when no CSV files are found.
    """
    # Sorted so run order (and hence colour assignment) is reproducible; raw glob order
    # follows the filesystem.
    csv_files = sorted(glob.glob(os.path.join(results_dir, '*_training_curves_*.csv')))
    if not csv_files:
        print("No training curve CSV files found")
        return None

    runs = [(path, pd.read_csv(path)) for path in csv_files]
    name_counts = Counter(df['model_name'].iloc[0] for _, df in runs)

    # Set1 has only 9 colours and maps larger indices to its "over" colour (the last
    # colour by default), so switch to the 20-colour tab20 map for bigger comparisons.
    cmap = plt.cm.Set1 if len(runs) <= plt.cm.Set1.N else plt.cm.tab20

    fig, axes = plt.subplots(2, 2, figsize=(18, 14))
    panels = (
        ('training_loss', 'o-', 'Training Loss'),
        ('validation_loss', 's-', 'Validation Loss'),
        ('training_accuracy', 'o-', 'Training Accuracy'),
        ('validation_accuracy', 's-', 'Validation Accuracy'),
    )

    for i, (path, df) in enumerate(runs):
        model_name = df['model_name'].iloc[0]
        label = model_name
        if name_counts[model_name] > 1:
            # Several runs of the same model: tell the legend entries apart.
            run_id = df['timestamp'].iloc[0] if 'timestamp' in df.columns else os.path.basename(path)
            label = f"{model_name} ({run_id})"
        color = cmap(i % cmap.N)
        for ax, (column, style, _) in zip(axes.flat, panels):
            ax.plot(df['epoch'], df[column], style, label=label,
                    color=color, linewidth=2, markersize=4)

    for ax, (_, _, metric_label) in zip(axes.flat, panels):
        ax.set_title(f'{metric_label} Comparison', fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('Model Training Curves Comparison', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    return fig

# Enhanced training function with automatic CSV saving
def train_model_with_csv_export(model, train_loader, val_loader, criterion, epochs, device, model_name):
    """Train via ``train_model_fine_tuned`` and persist the per-epoch metrics.

    After training, the four metric lists are written with ``save_training_data_to_csv``.
    If a notebook-level dict named ``enhanced_training_data`` already exists, they are
    also stored there under ``model_name``.

    Returns:
        Tuple ``(train_losses, train_accs, val_losses, val_accs)`` as produced by
        ``train_model_fine_tuned``.
    """
    history = train_model_fine_tuned(
        model, train_loader, val_loader, criterion, epochs, device, model_name
    )
    train_losses, train_accs, val_losses, val_accs = history

    save_training_data_to_csv(model_name, train_losses, train_accs, val_losses, val_accs)

    # Only mirror into the shared dict when another cell has created it.
    metrics = {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
    }
    if 'enhanced_training_data' in globals():
        enhanced_training_data[model_name] = metrics

    return train_losses, train_accs, val_losses, val_accs

# Function to create training summary
def create_training_summary(results_dir='thermal_analysis_results'):
    """Tabulate final and best metrics for every saved training-curve CSV and save the table.

    Args:
        results_dir: directory holding ``*_training_curves_*.csv`` files. The summary
            ``training_summary_<YYYYmmdd_HHMMSS>.csv`` is written there as well.

    Returns:
        DataFrame with one row per CSV file (metric values formatted as 4-decimal
        strings), or ``None`` if no CSV files are found.
    """
    # Sorted so the row order is reproducible; raw glob order is filesystem-dependent.
    csv_files = sorted(glob.glob(os.path.join(results_dir, '*_training_curves_*.csv')))

    if not csv_files:
        print("No CSV files found")
        return None

    summary_data = []
    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        summary_data.append({
            'Model': df['model_name'].iloc[0],
            'Epochs': len(df),
            'Final Train Acc': f"{df['training_accuracy'].iloc[-1]:.4f}",
            'Final Val Acc': f"{df['validation_accuracy'].iloc[-1]:.4f}",
            'Best Val Acc': f"{df['validation_accuracy'].max():.4f}",
            'Final Train Loss': f"{df['training_loss'].iloc[-1]:.4f}",
            'Final Val Loss': f"{df['validation_loss'].iloc[-1]:.4f}",
            'Min Val Loss': f"{df['validation_loss'].min():.4f}"
        })

    summary_df = pd.DataFrame(summary_data)
    print("Training Results Summary:")
    print(summary_df.to_string(index=False))

    # The summary name does not match the *_training_curves_* pattern, so it is never re-read.
    summary_file = os.path.join(results_dir, f'training_summary_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv')
    summary_df.to_csv(summary_file, index=False)
    print(f"\nSummary saved to: {summary_file}")

    return summary_df


# Quick-start for the CSV export helpers in this cell. Only train_model_with_csv_export()
# writes CSV files; calling train_model_fine_tuned() directly does not.
print("\n".join([
    "1. Train your models with train_model_with_csv_export(...) (CSV files will be automatically saved)",
    "2. Use plot_training_curves_from_csv('path/to/file.csv') for individual plots",
    "3. Use plot_all_models_comparison() to compare all models",
    "4. Use create_training_summary() to generate summary table",
    "\nAll files will be saved in 'thermal_analysis_results' directory",
]))

# 3. Models

All model architectures are defined in this section for better organization and reusability.

## 3.1 Utility Classes and Functions

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The true class receives probability ``1 - smoothing``. The remaining ``smoothing``
    mass is split evenly over the other ``C - 1`` classes. The loss is the batch mean of
    ``-sum(target * log_softmax(pred))``.

    Args:
        smoothing: probability mass moved off the true class.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the mean smoothed loss for logits ``pred`` (one row per sample, classes
        on dim 1) and integer class indices ``target``."""
        num_classes = pred.size(1)
        hard = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        off_target = (1 - hard) * self.smoothing / (num_classes - 1)
        soft = hard * (1 - self.smoothing) + off_target
        log_probs = F.log_softmax(pred, dim=1)
        return (-soft * log_probs).sum(dim=1).mean()


class FocalLoss(nn.Module):
    """Focal loss: per-sample cross-entropy scaled by ``alpha * (1 - p_t) ** gamma``.

    ``p_t = exp(-CE)`` is the predicted probability of the true class, so well-classified
    samples are down-weighted. The result is averaged over the batch.

    Args:
        alpha: scalar weight applied to every sample.
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha`` times cross-entropy.
    """

    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = per_sample_ce.neg().exp()
        modulating_factor = (1 - true_class_prob) ** self.gamma
        return (self.alpha * modulating_factor * per_sample_ce).mean()


class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy and focal loss: ``alpha * CE + beta * Focal``.

    Args:
        alpha: weight of the cross-entropy term.
        beta: weight of the focal term (the focal loss itself uses alpha=1, gamma=2).
    """

    def __init__(self, alpha=0.6, beta=0.4):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.ce_loss = nn.CrossEntropyLoss()
        self.focal_loss = FocalLoss(alpha=1, gamma=2)

    def forward(self, predictions, targets):
        ce_term = self.ce_loss(predictions, targets)
        focal_term = self.focal_loss(predictions, targets)
        return self.alpha * ce_term + self.beta * focal_term

print("Utility loss functions loaded!")

## 3.2 Attention Mechanisms

In [None]:
class SpatialAttention(nn.Module):
    """Spatial attention: rescale every location by a learned sigmoid map.

    The map comes from a single-output convolution over two per-pixel statistics taken
    across channels (mean and max).

    Args:
        kernel_size: convolution size. Padding is ``kernel_size // 2``, which keeps
            H x W unchanged for odd sizes.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stats = torch.cat((channel_mean, channel_max), dim=1)
        return x * self.sigmoid(self.conv(stats))


class ChannelAttention(nn.Module):
    """Channel attention: rescale each channel by a sigmoid gate from pooled descriptors.

    Global average- and max-pooled channel vectors pass through a shared bottleneck MLP.
    The outputs are summed and squashed with a sigmoid into a per-channel weight.

    Args:
        in_channels: number of input channels.
        reduction: bottleneck reduction ratio. The hidden width is
            ``max(1, in_channels // reduction)``, so it never collapses to zero.
    """

    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Without the clamp, in_channels < reduction gives a zero-width bottleneck and the
        # gate can no longer depend on the input.
        hidden = max(1, in_channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_desc = self.fc(self.avg_pool(x).view(b, c))
        max_desc = self.fc(self.max_pool(x).view(b, c))
        attention = self.sigmoid(avg_desc + max_desc).view(b, c, 1, 1)
        return x * attention.expand_as(x)


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention.

    Args:
        in_channels: channels of the feature map to refine.
        reduction: bottleneck ratio passed to ``ChannelAttention``.
        kernel_size: convolution size passed to ``SpatialAttention``.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))

print("Attention mechanisms loaded!")

## 3.3 VGG16 Thermal Model

In [None]:
class VGG16Thermal(nn.Module):
    """ImageNet-pretrained VGG16 whose classifier is replaced by a four-layer MLP head.

    Head: 25088 -> 4096 -> 2048 -> 1024 -> num_classes. Each hidden layer is followed by
    ReLU and Dropout (0.5, 0.5, 0.3).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=5):
        super(VGG16Thermal, self).__init__()
        self.backbone = tv_models.vgg16(pretrained=True)

        # 25088 = 512 * 7 * 7, assuming torchvision VGG16's 7x7 pooled feature map.
        width = 25088
        head = []
        for out_width, drop_p in ((4096, 0.5), (2048, 0.5), (1024, 0.3)):
            head += [nn.Linear(width, out_width), nn.ReLU(True), nn.Dropout(drop_p)]
            width = out_width
        head.append(nn.Linear(width, num_classes))
        self.backbone.classifier = nn.Sequential(*head)

    def forward(self, x):
        return self.backbone(x)

print("VGG16Thermal model loaded!")

## 3.4 AlexNet Thermal Model

In [None]:
class AlexNetThermal(nn.Module):
    """ImageNet-pretrained AlexNet with a deeper, batch-normalised classifier head.

    Accepts 1-channel (thermal) or 3-channel input. 1-channel input is replicated to
    3 channels in ``forward``, so the pretrained RGB convolutions are used unchanged.

    Args:
        num_classes: number of output logits.

    Note: the head uses BatchNorm1d, so training batches need more than one sample.
    """
    def __init__(self, num_classes=5):
        super(AlexNetThermal, self).__init__()

        # Pretrained AlexNet. Its first convolution already takes the 3-channel tensors
        # that forward() produces, so the pretrained feature extractor is kept as-is
        # (rebuilding an identical conv and copying the weights back would be a no-op).
        self.alexnet = tv_models.alexnet(pretrained=True)

        # Replace classifier with improved architecture.
        # 256 * 6 * 6 assumes torchvision AlexNet's 6x6 pooled feature map.
        self.alexnet.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(256 * 6 * 6, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(4096, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

        # Initialize weights of the new head only (the pretrained features are untouched).
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-init the head's Linear layers (zero bias); reset BatchNorm to identity."""
        for m in self.alexnet.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Handle both 1-channel thermal and 3-channel RGB inputs
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        elif x.shape[1] != 3:
            raise ValueError(f"Expected 1 or 3 input channels, got {x.shape[1]}")
        return self.alexnet(x)

print("AlexNetThermal model loaded!")

## 3.5 Enhanced Hybrid Model

In [None]:
class EnhancedHybridVGGAlexNet(nn.Module):
    """Two-stream classifier fusing pretrained VGG16 and AlexNet convolutional features.

    Each stream is refined by CBAM attention and projected to 256 channels. The AlexNet
    map is bilinearly resized to the VGG map's spatial size and the two are concatenated
    (512 channels). They are then fused by two 3x3 convs plus global average pooling and
    classified by a BatchNorm MLP.

    Args:
        num_classes: number of output logits.
        dropout_rate: dropout of the first classifier block; later blocks use 0.7x and 0.5x.

    Note: the classifier uses BatchNorm1d, so training batches need more than one sample.
    """

    def __init__(self, num_classes=5, dropout_rate=0.5):
        super(EnhancedHybridVGGAlexNet, self).__init__()

        # Load pre-trained backbones
        vgg16 = tv_models.vgg16(pretrained=True)
        alexnet = tv_models.alexnet(pretrained=True)

        # VGG16 conv stack without its last layer (presumably the final max-pool); 512 channels out.
        self.vgg_features = nn.Sequential(*list(vgg16.features[:-1]))

        # Full AlexNet conv stack; 256 channels out.
        self.alex_features = nn.Sequential(*list(alexnet.features))

        # Attention mechanisms
        self.vgg_attention = CBAM(512)
        self.alex_attention = CBAM(256)

        # Feature adaptation layers: project both streams to 256 channels.
        self.vgg_adapt = nn.Sequential(
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

        self.alex_adapt = nn.Sequential(
            nn.Conv2d(256, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

        # Fusion of the concatenated 256 + 256 channels, pooled to a 512-d vector.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),

            nn.Conv2d(1024, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        # Enhanced classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.7),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.5),

            nn.Linear(256, num_classes)
        )

        # Initialize the newly added layers (pretrained backbones are left untouched).
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise only the layers added on top of the pretrained backbones.

        ``vgg_features`` and ``alex_features`` are deliberately skipped. Walking
        ``self.modules()`` would re-initialise their convolutions and discard the
        pretrained weights loaded in ``__init__``.
        """
        custom_parts = (
            self.vgg_attention, self.alex_attention,
            self.vgg_adapt, self.alex_adapt,
            self.fusion_conv, self.classifier,
        )
        for part in custom_parts:
            for m in part.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                    nn.init.constant_(m.weight, 1)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # VGG16 pathway
        vgg_features = self.vgg_features(x)
        vgg_features = self.vgg_attention(vgg_features)
        vgg_features = self.vgg_adapt(vgg_features)

        # AlexNet pathway
        alex_features = self.alex_features(x)
        alex_features = self.alex_attention(alex_features)
        alex_features = self.alex_adapt(alex_features)

        # Align spatial dimensions: resize the AlexNet map to the VGG map's H x W.
        alex_features = F.interpolate(alex_features, size=vgg_features.shape[-2:],
                                      mode='bilinear', align_corners=False)

        # Feature fusion
        fused_features = torch.cat([vgg_features, alex_features], dim=1)
        fused_features = self.fusion_conv(fused_features)

        # Classification
        features = fused_features.flatten(1)
        output = self.classifier(features)

        return output

print("EnhancedHybridVGGAlexNet model loaded!")

## 3.6 Custom CNN with Inception Modules

In [None]:
class InceptionModule(nn.Module):
    """Inception-style block: four parallel branches concatenated along channels.

    Branches:

    1. 1x1 conv.
    2. 1x1 -> 3x3.
    3. 1x1 -> 3x3 -> 3x3 (a factorised 5x5).
    4. 3x3 max-pool (stride 1) -> 1x1.

    Every conv is followed by BatchNorm + ReLU, and all branches preserve H x W, so the
    output is ``(N, filters, H, W)``.

    Args:
        in_channels: channels of the input tensor.
        filters: total output channels (>= 4). Branches 2-4 produce ``filters // 4``
            each and branch 1 produces the remainder, so the output always has exactly
            ``filters`` channels (equal split when ``filters`` is a multiple of 4).
        name_prefix: descriptive label; not used by any layer.

    Raises:
        ValueError: if ``filters < 4``.
    """

    def __init__(self, in_channels, filters, name_prefix="inception"):
        super(InceptionModule, self).__init__()
        if filters < 4:
            raise ValueError(f"filters must be >= 4 to feed four branches, got {filters}")

        quarter = filters // 4
        # Branch 1 absorbs the remainder so the concatenation has exactly `filters` channels.
        branch1_out = filters - 3 * quarter
        # Bottleneck widths, clamped so small `filters` never yields a zero-channel conv.
        mid = max(1, filters // 8)
        narrow = max(1, filters // 16)

        # Branch 1: 1x1 convolution
        self.branch1 = nn.Conv2d(in_channels, branch1_out, kernel_size=1, padding=0)
        self.bn1 = nn.BatchNorm2d(branch1_out)

        # Branch 2: 1x1 -> 3x3 convolution
        self.branch2_reduce = nn.Conv2d(in_channels, mid, kernel_size=1, padding=0)
        self.bn2_reduce = nn.BatchNorm2d(mid)
        self.branch2_conv = nn.Conv2d(mid, quarter, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(quarter)

        # Branch 3: 1x1 -> 3x3 -> 3x3 (simulating 5x5)
        self.branch3_reduce = nn.Conv2d(in_channels, narrow, kernel_size=1, padding=0)
        self.bn3_reduce = nn.BatchNorm2d(narrow)
        self.branch3_conv1 = nn.Conv2d(narrow, mid, kernel_size=3, padding=1)
        self.bn3_conv1 = nn.BatchNorm2d(mid)
        self.branch3_conv2 = nn.Conv2d(mid, quarter, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(quarter)

        # Branch 4: MaxPooling -> 1x1
        self.branch4_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.branch4_conv = nn.Conv2d(in_channels, quarter, kernel_size=1, padding=0)
        self.bn4 = nn.BatchNorm2d(quarter)

    def forward(self, x):
        # Branch 1
        branch1 = F.relu(self.bn1(self.branch1(x)))

        # Branch 2
        branch2 = F.relu(self.bn2_reduce(self.branch2_reduce(x)))
        branch2 = F.relu(self.bn2(self.branch2_conv(branch2)))

        # Branch 3
        branch3 = F.relu(self.bn3_reduce(self.branch3_reduce(x)))
        branch3 = F.relu(self.bn3_conv1(self.branch3_conv1(branch3)))
        branch3 = F.relu(self.bn3(self.branch3_conv2(branch3)))

        # Branch 4
        branch4 = self.branch4_pool(x)
        branch4 = F.relu(self.bn4(self.branch4_conv(branch4)))

        # Concatenate all branches along the channel dimension.
        return torch.cat([branch1, branch2, branch3, branch4], dim=1)


class ProposedCNN(nn.Module):
    """Inception-based CNN classifier for thermal images.

    Pipeline:

    1. 7x7 stride-2 stem conv + max-pool.
    2. Two 3x3 convs at 128 channels + max-pool.
    3. Inception stages 3 (2 blocks), 4 (4 blocks) and 5 (2 blocks), with max-pooling
       after stages 3 and 4.
    4. Dropout2d, then two 3x3 convs at 1024 channels.
    5. Global average pooling, then an MLP head 1024 -> 512 -> 256 -> num_classes.

    Args:
        input_shape: (channels, height, width). Only the channel count is read; the
            global average pool makes the head independent of spatial size.
        num_classes: number of output logits.
        dropout_rate: probability for the Dropout2d after stage 5 and the dropouts
            after fc1 and fc2.
    """

    def __init__(self, input_shape=(3, 200, 200), num_classes=5, dropout_rate=0.3):
        super(ProposedCNN, self).__init__()
        in_channels = input_shape[0]

        # Stem: stride-2 conv then stride-2 max-pool.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Two 3x3 convs at 128 channels, then downsample.
        self.conv2a = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2a = nn.BatchNorm2d(128)
        self.conv2b = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2b = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Inception stages. Attribute names (state_dict keys) and registration order are fixed.
        for stage_name, c_in, c_out in (("inception3a", 128, 256), ("inception3b", 256, 256)):
            setattr(self, stage_name, InceptionModule(c_in, c_out, stage_name))
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        for stage_name, c_in, c_out in (("inception4a", 256, 512), ("inception4b", 512, 512),
                                        ("inception4c", 512, 512), ("inception4d", 512, 512)):
            setattr(self, stage_name, InceptionModule(c_in, c_out, stage_name))
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        for stage_name, c_in, c_out in (("inception5a", 512, 1024), ("inception5b", 1024, 1024)):
            setattr(self, stage_name, InceptionModule(c_in, c_out, stage_name))

        self.dropout4 = nn.Dropout2d(dropout_rate)

        # Two 3x3 convs at 1024 channels.
        self.conv_deep1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        self.bn_deep1 = nn.BatchNorm2d(1024)
        self.conv_deep2 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        self.bn_deep2 = nn.BatchNorm2d(1024)

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # MLP head.
        self.fc1 = nn.Linear(1024, 512)
        self.dropout6 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.dropout7 = nn.Dropout(dropout_rate)
        self.predictions = nn.Linear(256, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming (fan_out, ReLU) init for convs/linears with zero bias; BatchNorm -> identity."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.pool1(F.relu(self.bn1(self.conv1(x))))

        out = F.relu(self.bn2a(self.conv2a(out)))
        out = self.pool2(F.relu(self.bn2b(self.conv2b(out))))

        for stage in (self.inception3a, self.inception3b):
            out = stage(out)
        out = self.pool3(out)

        for stage in (self.inception4a, self.inception4b, self.inception4c, self.inception4d):
            out = stage(out)
        out = self.pool4(out)

        for stage in (self.inception5a, self.inception5b):
            out = stage(out)

        out = self.dropout4(out)
        out = F.relu(self.bn_deep1(self.conv_deep1(out)))
        out = F.relu(self.bn_deep2(self.conv_deep2(out)))

        out = torch.flatten(self.global_avg_pool(out), 1)
        out = self.dropout6(F.relu(self.fc1(out)))
        out = self.dropout7(F.relu(self.fc2(out)))
        return self.predictions(out)

print("ProposedCNN and InceptionModule models loaded!")

# Data Augmentation & Loading

In [None]:
def save_augmented_images(X, y, save_dir, categories):
    """Write images to ``save_dir/<category>/<NNNNN>.png``, one sub-folder per class.

    Files are numbered from ``00001`` per class on every call, so files with the same
    names from an earlier call are overwritten.

    Args:
        X: sequence of images, assumed to be (H, W, 3) float arrays scaled to [0, 1]
            (TODO confirm against the augmentation pipeline).
        y: integer class indices aligned with ``X``; ``categories[label]`` names the folder.
        save_dir: root output directory, created if missing.
        categories: list of class names.
    """
    print(f"Saving augmented images to {save_dir}...")

    # Create the root directory and one sub-directory per class.
    os.makedirs(save_dir, exist_ok=True)
    for category in categories:
        os.makedirs(os.path.join(save_dir, category), exist_ok=True)

    class_counts = {i: 0 for i in range(len(categories))}

    for idx, (img, label) in enumerate(zip(X, y)):
        category_name = categories[label]
        class_counts[label] += 1

        # Round and clip before the uint8 cast. A plain astype truncates, so a value that
        # is really k/255 can be written as k-1. Out-of-range floats have no well-defined
        # uint8 result.
        img_uint8 = np.clip(np.rint(np.asarray(img) * 255.0), 0, 255).astype(np.uint8)
        pil_img = PILImage.fromarray(img_uint8, mode='RGB')

        filepath = os.path.join(save_dir, category_name, f"{class_counts[label]:05d}.png")
        pil_img.save(filepath)

        if (idx + 1) % 1000 == 0:
            print(f"   Saved {idx + 1}/{len(X)} images...")

    print(f"Successfully saved {len(X)} images")
    for i, category in enumerate(categories):
        print(f"   {category}: {class_counts[i]} images")


def load_augmented_images(load_dir, categories):
    """Load the ``.png`` images under ``load_dir/<category>/`` back into arrays.

    Files are read in sorted filename order. Unreadable files are reported and skipped;
    the per-class log shows how many were actually loaded.

    Args:
        load_dir: root directory containing one sub-folder per category.
        categories: class names. A folder's label is the index of its name in
            ``categories`` (for duplicated names, the last index wins).

    Returns:
        ``(images, labels)`` as a float32 array scaled to [0, 1] and an int64 array, or
        ``(None, None)`` when no image could be loaded. Images are converted to RGB and
        presumably all share one size (required for stacking into a single array).
    """
    print(f" Loading augmented images from {load_dir}...")

    images, labels = [], []
    label_map = {cat: i for i, cat in enumerate(categories)}

    for category in categories:
        class_dir = os.path.join(load_dir, category)
        if not os.path.exists(class_dir):
            print(f"    Directory not found: {class_dir}")
            continue

        image_files = sorted(f for f in os.listdir(class_dir) if f.endswith('.png'))
        loaded = 0
        for img_file in image_files:
            img_path = os.path.join(class_dir, img_file)
            try:
                with PILImage.open(img_path) as im:
                    arr = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
            except Exception as e:
                # Best effort: a corrupt file is reported and skipped rather than aborting.
                print(f"   Error loading {img_path}: {e}")
                continue
            images.append(arr)
            labels.append(label_map[category])
            loaded += 1

        if image_files:
            failed = len(image_files) - loaded
            suffix = f" ({failed} failed)" if failed else ""
            print(f"   {category}: {loaded} images loaded{suffix}")

    if not images:
        return None, None

    print(f"Successfully loaded {len(images)} images")
    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int64)


def check_augmented_images_exist(load_dir, categories, expected_count_per_class):
    """Tell whether a usable augmented-image cache is present.

    Returns True only when ``load_dir`` exists and, for every category, the
    folder ``load_dir/<category>`` exists and holds at least
    ``expected_count_per_class`` entries whose names end in ``.png``.
    """
    if not os.path.exists(load_dir):
        return False

    def _png_count(folder):
        # Case-sensitive '.png' suffix, matching load_augmented_images.
        return sum(1 for name in os.listdir(folder) if name.endswith('.png'))

    for category in categories:
        folder = os.path.join(load_dir, category)
        if not os.path.exists(folder) or _png_count(folder) < expected_count_per_class:
            return False
    return True


def augment_with_target_count(X, y, target_count_per_class=1000, batch_size=50):
    """Oversample each class up to ``target_count_per_class`` with random augmentation.

    Originals are always kept. A class with fewer samples than the target gets
    extra images, each an augmented copy of a randomly chosen (with replacement)
    image of that same class. Classes already at or above the target are left
    unchanged.

    Args:
        X: Array of HWC images, presumably float in [0, 1] (augmented outputs
            are clipped to that range).
        y: Integer class labels aligned with ``X``.
        target_count_per_class: Desired minimum number of samples per class.
        batch_size: Number of augmented images produced between progress checks.

    Returns:
        ``(X_augmented, y_augmented)`` numpy arrays, ordered class by class with
        each class's originals first.

    Raises:
        ValueError: if ``batch_size`` < 1 (it would otherwise loop forever).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"Target count per class: {target_count_per_class}")
    print(f"Original class distribution: {Counter(y)}")

    # Crop back to the input resolution so augmented and original images stack
    # into one array. Fall back to 200x200 when X is not (N, H, W, C), e.g. empty.
    if X.ndim == 4:
        H, W = X.shape[1], X.shape[2]
    else:
        H = W = 200
    # Reflect-pad before the affine warp, then centre-crop back; presumably this
    # reduces black borders introduced by large shifts/rotations.
    pad_px = 24
    augment_tfms = transforms.Compose([
        transforms.Pad(pad_px, padding_mode='reflect'),
        transforms.RandomAffine(
            degrees=35,
            translate=(0.25, 0.25),
            scale=(0.85, 1.15),
            shear=25,
            interpolation=InterpolationMode.BILINEAR,
            fill=0
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=(0.9, 1.1),
            contrast=(0.95, 1.05),
            saturation=(0.95, 1.05)
        ),
        transforms.CenterCrop((H, W)),
    ])

    def _augment_one(img):
        # HWC numpy image -> CHW tensor for torchvision, then back to HWC in [0, 1].
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float()
        aug_img = augment_tfms(img_tensor).permute(1, 2, 0).numpy()
        return np.clip(aug_img, 0, 1)

    X_augmented, y_augmented = [], []

    for class_idx in np.unique(y):
        class_data = X[y == class_idx]
        current_count = len(class_data)
        print(f"Processing class {class_idx}: {current_count} samples")

        # Add original data first
        X_augmented.extend(class_data)
        y_augmented.extend([class_idx] * current_count)

        if current_count >= target_count_per_class:
            print(f"   Class has sufficient samples ({current_count} >= {target_count_per_class})")
            continue

        needed = target_count_per_class - current_count
        print(f"    Augmenting {needed} additional samples")
        generated = 0
        while generated < needed:
            batch_needed = min(batch_size, needed - generated)
            for _ in range(batch_needed):
                idx = np.random.randint(0, current_count)
                X_augmented.append(_augment_one(class_data[idx]))
            y_augmented.extend([class_idx] * batch_needed)
            generated += batch_needed

            if generated % 100 == 0 or generated == needed:
                print(f"      Generated {generated}/{needed} samples")

        # Collect once per class rather than after every batch to cut overhead.
        gc.collect()

    print(f"Augmentation complete!")
    print(f"Final dataset size: {len(X_augmented)} samples")

    return np.array(X_augmented), np.array(y_augmented)


# Configuration
# Minimum samples per class after augmentation; also the per-class file count
# the on-disk cache in AUGMENTED_IMAGES_DIR must reach to be reused.
TARGET_COUNT = 4000

# Load original data to get categories
print(" Loading thermal imaging data...")
images, labels = [], []
# Class names are the sub-directories of data_dir, sorted so each class's
# integer label (its index here) is stable between runs.
categories = sorted([d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d))])

# Check if augmented images already exist
if check_augmented_images_exist(AUGMENTED_IMAGES_DIR, categories, TARGET_COUNT):
    print("\nFound existing augmented images! Loading from disk...")
    X_augmented, y_augmented = load_augmented_images(AUGMENTED_IMAGES_DIR, categories)

    if X_augmented is not None and y_augmented is not None:
        print(f"\nLoaded Dataset Statistics:")
        print(f"Total samples: {len(X_augmented)}")
        for i, category in enumerate(categories):
            count = np.sum(y_augmented == i)
            percentage = (count / len(y_augmented)) * 100
            print(f"  {category}: {count} samples ({percentage:.1f}%)")
    else:
        print(" Failed to load augmented images. Will generate new ones.")
        X_augmented = None
else:
    print("\n Augmented images not found or incomplete. Will generate new ones...")
    # X_augmented = None triggers the generation branch below.
    X_augmented = None

# Generate and save augmented images if not loaded
if X_augmented is None:
    for category in categories:
        category_path = os.path.join(data_dir, category)
        for image_name in os.listdir(category_path):
            if image_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                image_path = os.path.join(category_path, image_name)
                try:
                    with PILImage.open(image_path) as im:
                        # Every source image becomes RGB 200x200, scaled to [0, 1].
                        img = im.convert("RGB").resize((200, 200))
                        arr = np.asarray(img, dtype=np.float32) / 255.0
                        images.append(arr)
                    labels.append(category)
                except Exception as e:
                    print(f"Error loading image {image_path}: {e}")

    images = np.array(images, dtype=np.float32)
    labels = np.array(labels)
    # Map class-name strings to the integer ids used everywhere downstream.
    label_map = {cat: i for i, cat in enumerate(categories)}
    numerical_labels = np.array([label_map[label] for label in labels], dtype=np.int64)

    X, y = images, numerical_labels

    print(f"\nOriginal Dataset Statistics:")
    print(f"Total samples: {len(X)}")
    for i, category in enumerate(categories):
        count = np.sum(y == i)
        percentage = (count / len(y)) * 100
        print(f"  {category}: {count} samples ({percentage:.1f}%)")

    print("\nApplying enhanced data augmentation...")
    X_augmented, y_augmented = augment_with_target_count(
        X, y, target_count_per_class=TARGET_COUNT, batch_size=50
    )

    print("\nSaving augmented images for future use...")
    save_augmented_images(X_augmented, y_augmented, AUGMENTED_IMAGES_DIR, categories)

# Enhanced train/validation/test split (70/15/15)
# NOTE(review): the split runs on the *augmented* set, which holds the originals
# plus augmented copies of those same images. Copies of one source image can
# therefore land in train and in val/test, so val/test metrics are likely
# optimistic. Consider splitting the originals first and augmenting only train.
X_train, X_temp, y_train, y_temp = train_test_split(
    X_augmented, y_augmented, test_size=0.3, random_state=42, stratify=y_augmented
)
# Split the held-out 30% in half: 15% validation, 15% test.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"\nData split completed:")
print(f"Training samples: {len(X_train)} (70%)")
print(f"Validation samples: {len(X_val)} (15%)")
print(f"Testing samples: {len(X_test)} (15%)")


def prepare_enhanced_dataloaders(X_train, y_train, X_val, y_val, X_test, y_test, batch_size=32,
                                 num_workers=2):
    """Wrap train/validation/test numpy splits in PyTorch DataLoaders.

    Images arrive as (N, H, W, C) arrays and are exposed as float32
    (N, C, H, W) tensors. If an array's maximum exceeds 1.0 it is assumed to
    be on a 0-255 scale and divided by 255. Only the training loader shuffles.

    Args:
        X_train, X_val, X_test: Image arrays of shape (N, H, W, C); may be empty.
        y_train, y_val, y_test: Integer label arrays of shape (N,).
        batch_size: Mini-batch size for all three loaders.
        num_workers: Worker processes per loader (default 2, as before).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    pin = torch.cuda.is_available()

    def prepare_tensors(X, y):
        # Guard .max(): it raises on a zero-size array (e.g. an empty split).
        needs_rescale = X.size > 0 and X.max() > 1.0
        # copy=False avoids duplicating an array that is already float32, which
        # for full-resolution image sets is a multi-GB copy.
        X = X.astype(np.float32, copy=False)
        if needs_rescale:
            X = X / 255.0  # new array; the caller's data is never modified
        X_tensor = torch.from_numpy(X).permute(0, 3, 1, 2)
        y_tensor = torch.from_numpy(y.astype(np.int64, copy=False))
        return X_tensor, y_tensor

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin)

    train_loader = DataLoader(TensorDataset(*prepare_tensors(X_train, y_train)),
                              shuffle=True, **loader_kwargs)
    val_loader = DataLoader(TensorDataset(*prepare_tensors(X_val, y_val)),
                            shuffle=False, **loader_kwargs)
    test_loader = DataLoader(TensorDataset(*prepare_tensors(X_test, y_test)),
                             shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


# Build the three loaders once; every model section below trains and
# evaluates on these same loaders.
train_loader_enhanced, val_loader_enhanced, test_loader_enhanced = prepare_enhanced_dataloaders(
    X_train, y_train, X_val, y_val, X_test, y_test, batch_size=32
)

print(f"\n Enhanced DataLoaders created successfully!")
print(f"Train batches: {len(train_loader_enhanced)}")
print(f"Validation batches: {len(val_loader_enhanced)}")
print(f"Test batches: {len(test_loader_enhanced)}")

# Global storage for model results
# (each model's evaluation cell appends its metrics dict to this list)
model_results = []


# Evaluation Functions

In [None]:
def evaluate_model_comprehensive(model, test_loader, class_names, model_name, device):
    """Evaluate a classifier on ``test_loader`` and report standard metrics.

    Runs the model in eval mode without gradients and takes the arg-max class
    per sample. It then prints accuracy, macro/weighted precision, recall and
    F1, and a per-class classification report, and shows a confusion-matrix
    heatmap.

    Args:
        model: Trained ``nn.Module`` returning one score per class.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        class_names: Class names; ``class_names[i]`` names integer label ``i``.
        model_name: Name used in printed headers.
        device: Device the inputs are moved to.

    Returns:
        Dict with the model name, accuracy, macro/weighted/micro precision,
        recall and F1, and the confusion matrix (rows = actual, columns =
        predicted, one row/column per entry of ``class_names``).
    """
    model.eval()
    all_predictions = []
    all_labels = []

    print(f"Evaluating {model_name}...")

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            # Labels are only needed on the host; no round trip to the device.
            all_labels.extend(labels.cpu().numpy())

    # Explicit label ids keep the report and the confusion matrix aligned with
    # class_names even if some class never appears in labels or predictions
    # (otherwise classification_report raises on the target_names mismatch and
    # the heatmap tick labels no longer match the matrix size).
    label_ids = list(range(len(class_names)))

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision_macro = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
    precision_weighted = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
    recall_macro = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
    recall_weighted = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
    f1_macro = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # Micro-averaged metrics (returned, not printed)
    precision_micro = precision_score(all_labels, all_predictions, average='micro', zero_division=0)
    recall_micro = recall_score(all_labels, all_predictions, average='micro', zero_division=0)
    f1_micro = f1_score(all_labels, all_predictions, average='micro', zero_division=0)

    # Print results
    print(f"{model_name} - Test Results:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Precision (Macro): {precision_macro:.4f}")
    print(f"  Precision (Weighted): {precision_weighted:.4f}")
    print(f"  Recall (Macro): {recall_macro:.4f}")
    print(f"  Recall (Weighted): {recall_weighted:.4f}")
    print(f"  F1-Score (Macro): {f1_macro:.4f}")
    print(f"  F1-Score (Weighted): {f1_weighted:.4f}")

    # Detailed classification report
    print(f"{model_name} - Detailed Classification Report:")
    print(classification_report(all_labels, all_predictions, labels=label_ids,
                                target_names=class_names, zero_division=0))

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    plt.figure(figsize=(8, 7), dpi=300)
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        linewidths=0.5, linecolor='white',
        cbar=True, cbar_kws={'label': None},
        annot_kws={"size": 19}  # large cell counts for readability
    )

    # Axis labels
    plt.xlabel('Predicted', fontsize=14, fontweight='bold')
    plt.ylabel('Actual', fontsize=14, fontweight='bold')

    # Ticks formatting
    plt.xticks(fontsize=14, rotation=0)     # Horizontal x-axis ticks
    plt.yticks(fontsize=14, rotation=90)    # Vertical y-axis ticks

    # Layout and display (the figure is not written to disk)
    plt.tight_layout()
    plt.show()

    return {
        'model_name': model_name,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'precision_weighted': precision_weighted,
        'recall_macro': recall_macro,
        'recall_weighted': recall_weighted,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'f1_micro': f1_micro,
        'confusion_matrix': cm
    }

def plot_training_curves(train_losses, train_accs, val_losses, val_accs, model_name):
    """Display side-by-side loss and accuracy curves, training vs. validation.

    Each history is plotted against its list index (one point per epoch);
    training series are blue and validation series red. The figure is shown,
    not saved.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    panels = (
        (axes[0], train_losses, val_losses, 'Loss'),
        (axes[1], train_accs, val_accs, 'Accuracy'),
    )
    for ax, train_series, val_series, metric in panels:
        ax.plot(train_series, label=f'Training {metric}', color='blue')
        ax.plot(val_series, label=f'Validation {metric}', color='red')
        ax.set_title(f'{model_name} - Training & Validation {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

print("Enhanced evaluation functions loaded!")

# Model Training

## ResNet Training

In [None]:

# Enhanced ResNet Training and Evaluation
print("="*80)
print(" ENHANCED RESNET18 TRAINING WITH VALIDATION")
print("="*80)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

num_classes = len(categories)
# weights=None is the non-deprecated spelling of pretrained=False
# (random initialisation, no ImageNet weights).
resnet_enhanced = models.resnet18(weights=None)

# Enhance architecture with regularization
resnet_enhanced = enhance_model_architecture(resnet_enhanced, "resnet", num_classes, dropout_rate=0.3)

# Move to the selected device; on a CUDA / out-of-memory error fall back to CPU.
try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    resnet_enhanced = resnet_enhanced.to(device)
    print(f"ResNet model successfully moved to {device}")
except RuntimeError as e:
    if "CUDA" in str(e) or "out of memory" in str(e):
        print(f"CUDA error encountered: {e}")
        print("Falling back to CPU...")
        device = torch.device('cpu')
        resnet_enhanced = resnet_enhanced.to(device)
    else:
        # Bare raise re-raises the original exception unchanged.
        raise

# Define loss (the optimizer is created inside train_model_fine_tuned)
criterion = nn.CrossEntropyLoss()

# Train with validation; histories are used by the save/evaluation cells.
if resnet_model_training:
    epochs = common_epochs
    train_losses, train_accs, val_losses, val_accs = train_model_fine_tuned(
        resnet_enhanced, train_loader_enhanced, val_loader_enhanced,
        criterion, epochs, device, "ResNet18")

    # Plot training curves
    plot_training_curves(train_losses, train_accs, val_losses, val_accs, "ResNet18")
    print("Training is completed")

### Save Model

In [None]:
# Comprehensive model saving for ResNet18
# NOTE(review): the "epoch_100" suffix is a fixed string, not derived from
# common_epochs; the evaluation cell loads a file with this same name.
model_name = "resnet18_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
# train_losses / train_accs / val_losses / val_accs come from the training cell
# and only exist if resnet_model_training was True there.
checkpoint = {
    # Model architecture and weights
    'model_state_dict': resnet_enhanced.state_dict(),
    'model_architecture': 'ResNet18',
    'model_class': 'resnet18',
    
    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    'dropout_rate': 0.3,
    
    # Training history
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),
    
    # Optimizer configuration
    # NOTE(review): literals, not read from the optimizer built inside
    # train_model_fine_tuned — confirm they match that function.
    'optimizer_name': 'AdamW',
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'scheduler_name': 'CosineAnnealingLR',
    
    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading)
# This pickles the whole module object; reloading it needs the model classes
# to be importable (see PyTorch's serialization notes).
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(resnet_enhanced, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"ResNet18 Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:

# Reload the full pickled model written by the save cell.
model_name = "resnet18_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# weights_only=False unpickles arbitrary objects: only load files this
# notebook produced.
resnet_enhanced = torch.load(model_path, map_location=device, weights_only=False)
resnet_enhanced.eval()
print("Full Resnet model loaded successfully")


# Comprehensive evaluation
resnet_results = evaluate_model_comprehensive(resnet_enhanced, test_loader_enhanced, categories, "ResNet18", device)
# Uses the histories from the training cell (NameError if training was skipped).
save_training_data_to_csv("ResNet18", train_losses, train_accs, val_losses, val_accs)
# Store model probabilities for ROC plotting
# Second pass over the test set, keeping per-class softmax probabilities.
with torch.no_grad():
    resnet_enhanced.eval()
    test_probs = []
    test_labels = []

    for batch_x, batch_y in test_loader_enhanced:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        outputs = resnet_enhanced(batch_x)
        probs = torch.softmax(outputs, dim=1)

        test_probs.append(probs.cpu().numpy())
        test_labels.append(batch_y.cpu().numpy())

    # (num_test, num_classes) probabilities and matching (num_test,) labels
    test_probs = np.vstack(test_probs)
    test_labels_array = np.hstack(test_labels)

# Store for ROC plotting
store_model_probabilities('ResNet18', test_probs, test_labels_array)
MODEL_RESULTS_GLOBAL['ResNet18'] = resnet_results


# Store results
model_results.append(resnet_results)
print(f"ResNet18 enhanced training and evaluation completed!")

## MobileNet Training

In [None]:
# Enhanced MobileNet Training and Evaluation
print("\n" + "="*80)
print(" ENHANCED MOBILENETV2 TRAINING WITH VALIDATION")
print("="*80)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

num_classes = len(categories)
# weights=None is the non-deprecated spelling of pretrained=False.
mobilenet_enhanced = models.mobilenet_v2(weights=None)

# Enhance architecture with regularization
mobilenet_enhanced = enhance_model_architecture(mobilenet_enhanced, "mobilenet", num_classes, dropout_rate=0.3)
# Place the model on the selected device explicitly, as the ResNet cell does
# (a no-op if train_model_fine_tuned also moves it).
mobilenet_enhanced = mobilenet_enhanced.to(device)

# Define loss (the optimizer is created inside train_model_fine_tuned)
criterion = nn.CrossEntropyLoss()

# Train with validation; histories are used by the save/evaluation cells.
if mobilenet_model_training:
    epochs = common_epochs
    train_losses, train_accs, val_losses, val_accs = train_model_fine_tuned(
        mobilenet_enhanced, train_loader_enhanced, val_loader_enhanced,
        criterion, epochs, device, "MobileNetV2")

    # Plot training curves
    plot_training_curves(train_losses, train_accs, val_losses, val_accs, "MobileNetV2")



### Save model

In [None]:
# Comprehensive model saving for MobileNetV2
# NOTE(review): "epoch_100" is fixed, not derived from common_epochs; the
# evaluation cell loads a file with this same name.
model_name = "mobilenet_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
# Training histories come from the MobileNet training cell and only exist if
# that cell actually trained.
checkpoint = {
    # Model architecture and weights
    'model_state_dict': mobilenet_enhanced.state_dict(),
    'model_architecture': 'MobileNetV2',
    'model_class': 'mobilenet_v2',
    
    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    'dropout_rate': 0.3,
    
    # Training history
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),
    
    # Optimizer configuration
    # NOTE(review): literals — confirm they match train_model_fine_tuned.
    'optimizer_name': 'AdamW',
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'scheduler_name': 'CosineAnnealingLR',
    
    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading)
# Pickles the whole module object.
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(mobilenet_enhanced, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"MobileNetV2 Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:
# Reload the full pickled model written by the save cell.
model_name = "mobilenet_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# weights_only=False unpickles arbitrary objects: only load trusted files.
mobilenet_enhanced = torch.load(model_path, map_location=device, weights_only=False)
mobilenet_enhanced.eval()
print("Full mobilenet  model loaded successfully")


# Comprehensive evaluation
mobilenet_results = evaluate_model_comprehensive(mobilenet_enhanced, test_loader_enhanced, categories, "MobileNetV2", device)

# Store model probabilities for ROC plotting
# Second pass over the test set, keeping per-class softmax probabilities.
with torch.no_grad():
    mobilenet_enhanced.eval()
    test_probs = []
    test_labels = []
    for batch_x, batch_y in test_loader_enhanced:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = mobilenet_enhanced(batch_x)
        probs = torch.softmax(outputs, dim=1)
        test_probs.append(probs.cpu().numpy())
        test_labels.append(batch_y.cpu().numpy())

    test_probs = np.vstack(test_probs)
    test_labels_array = np.hstack(test_labels)

# Store for ROC plotting
store_model_probabilities('MobileNetV2', test_probs, test_labels_array)
MODEL_RESULTS_GLOBAL['MobileNetV2'] = mobilenet_results
# Uses the histories from the MobileNet training cell.
save_training_data_to_csv("MobileNetV2", train_losses, train_accs, val_losses, val_accs)

# Store results
model_results.append(mobilenet_results)
print(f"\nMobileNetV2 enhanced training and evaluation completed!")

## EfficientNet Training

In [None]:
# Enhanced EfficientNet Training and Evaluation
print("\n" + "="*80)
print(" ENHANCED EFFICIENTNET TRAINING WITH VALIDATION")
print("="*80)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The package import lives inside the try so that a missing
# efficientnet-pytorch install is actually caught by the ImportError handler.
try:
    from efficientnet_pytorch import EfficientNet

    num_classes = len(categories)
    efficientnet_enhanced = EfficientNet.from_name("efficientnet-b0")

    # Enhance architecture with regularization
    efficientnet_enhanced = enhance_model_architecture(efficientnet_enhanced, "efficientnet", num_classes, dropout_rate=0.3)
    # Place the model on the selected device explicitly, as the ResNet cell does.
    efficientnet_enhanced = efficientnet_enhanced.to(device)

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    # Train with validation; histories are used by the save/evaluation cells.
    if efficientnet_model_training:
        epochs = common_epochs
        print(f"Training EfficientNet with full enhanced optimization suite...")

        train_losses, train_accs, val_losses, val_accs = train_model_fine_tuned(
            efficientnet_enhanced, train_loader_enhanced, val_loader_enhanced,
            criterion, epochs, device, "EfficientNet-B0")

        # Plot training curves
        plot_training_curves(train_losses, train_accs, val_losses, val_accs, "EfficientNet-B0")
        print(f"EfficientNet-B0 enhanced training completed!")

except ImportError as e:
    print(f" EfficientNet not available: {e}")
    print(f"   Install with: pip install efficientnet-pytorch")
    print(f"   Continuing with ResNet and MobileNet results...")

### Save model

In [None]:
# Comprehensive model saving for EfficientNet-B0
# NOTE(review): efficientnet_enhanced and the histories only exist if the
# training cell's try block succeeded; "epoch_100" is a fixed name.
model_name = "efficientnet_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
checkpoint = {
    # Model architecture and weights
    'model_state_dict': efficientnet_enhanced.state_dict(),
    'model_architecture': 'EfficientNet-B0',
    'model_class': 'efficientnet-b0',
    
    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    'dropout_rate': 0.3,
    
    # Training history
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),
    
    # Optimizer configuration
    # NOTE(review): literals — confirm they match train_model_fine_tuned.
    'optimizer_name': 'AdamW',
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'scheduler_name': 'CosineAnnealingLR',
    
    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading)
# Pickles the whole module; reloading requires efficientnet_pytorch installed.
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(efficientnet_enhanced, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"EfficientNet-B0 Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:
model_name = "efficientnet_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

try:
    # Loading happens inside the try so that a missing checkpoint, or a missing
    # efficientnet_pytorch package (needed to unpickle the model), is reported
    # by the handlers below instead of aborting the cell.
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    efficientnet_enhanced = torch.load(model_path, map_location=device, weights_only=False)
    efficientnet_enhanced.eval()
    print("Full efficientnet  model loaded successfully")

    # Comprehensive evaluation
    efficientnet_results = evaluate_model_comprehensive(efficientnet_enhanced, test_loader_enhanced, categories, "EfficientNet-B0", device)

    # Second pass over the test set, keeping per-class softmax probabilities
    # for ROC plotting.
    with torch.no_grad():
        efficientnet_enhanced.eval()
        test_probs = []
        test_labels = []
        for batch_x, batch_y in test_loader_enhanced:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            outputs = efficientnet_enhanced(batch_x)
            probs = torch.softmax(outputs, dim=1)
            test_probs.append(probs.cpu().numpy())
            test_labels.append(batch_y.cpu().numpy())

        test_probs = np.vstack(test_probs)
        test_labels_array = np.hstack(test_labels)

    # Store for ROC plotting
    store_model_probabilities('EfficientNet', test_probs, test_labels_array)
    MODEL_RESULTS_GLOBAL['EfficientNet'] = efficientnet_results
    save_training_data_to_csv("EfficientNet", train_losses, train_accs, val_losses, val_accs)

    # Store results
    model_results.append(efficientnet_results)
    print(f"EfficientNet-B0 enhanced and evaluation completed!")

except ImportError as e:
    print(f" EfficientNet not available: {e}")
    print(f"   Install with: pip install efficientnet-pytorch")
    print(f"   Continuing with ResNet and MobileNet results...")

except Exception as e:
    print(f" EfficientNet evaluation skipped due to error: {e}")
    print(f"   ResNet and MobileNet results are available for comparison")

## VGG16

In [None]:
print("VGG16 ENHANCED FINE-TUNING WITH ADVANCED OPTIMIZATION")
print("=" * 70)

# Set device here too, so this cell does not depend on an earlier one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the classifier head from the discovered categories (was hard-coded to 5).
vgg16_model = VGG16Thermal(num_classes=len(categories)).to(device)
vgg16_params = sum(p.numel() for p in vgg16_model.parameters())
print(f"VGG16 model initialized with {vgg16_params:,} parameters")


# Advanced optimizer configuration for VGG16
vgg16_optimizer = optim.AdamW(
    vgg16_model.parameters(),
    lr=0.0005,           # lower initial learning rate for fine-tuning
    weight_decay=0.001,  # decoupled weight decay
    betas=(0.9, 0.999),
    eps=1e-8
)

# Halve the learning rate at scheduler steps 10, 20 and 30 (stepped once per
# epoch by train_vgg16_enhanced).
vgg16_scheduler = optim.lr_scheduler.MultiStepLR(
    vgg16_optimizer,
    milestones=[10, 20, 30],
    gamma=0.5
)

# Enhanced criterion with label smoothing
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The true class receives probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is split evenly over the other ``C - 1`` classes, with
    ``C = pred.size(-1)``. The per-sample loss is averaged over the batch.
    At least two classes are required (``C - 1`` is a divisor).
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        n_classes = pred.size(-1)
        off_target = self.smoothing / (n_classes - 1.)
        soft_targets = pred.new_full(pred.size(), off_target)
        # Overwrite each row's true-class slot with the on-target probability.
        soft_targets.scatter_(-1, target.unsqueeze(-1), 1. - self.smoothing)
        log_probs = F.log_softmax(pred, dim=-1)
        per_sample = (-soft_targets * log_probs).sum(dim=-1)
        return per_sample.mean()

enhanced_criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# Enhanced training with custom loop for better control
def _vgg16_epoch_pass(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given.

    Returns ``(mean loss per batch, accuracy in percent)``; both are 0.0 for an
    empty loader.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct, total, n_batches = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            n_batches += 1

    mean_loss = running_loss / n_batches if n_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def train_vgg16_enhanced(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=35,
                         device=None, checkpoint_path='vgg16_best_enhanced.pth'):
    """Train ``model`` with per-epoch validation and restore its best weights.

    Each epoch trains on ``train_loader`` (gradient norm clipped to 1.0),
    evaluates on ``val_loader``, logs both, then steps ``scheduler``. The
    weights with the highest validation accuracy are saved to
    ``checkpoint_path`` and loaded back into ``model`` at the end. The first
    epoch always counts as the best so far, so a best state always exists.

    Args:
        model: Network to train (modified in place).
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        epochs: Number of epochs.
        device: Device batches are moved to; defaults to the device of the
            model's first parameter (CPU if it has none).
        checkpoint_path: File the best ``state_dict`` is written to.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)`` per-epoch lists;
        accuracies are percentages in [0, 100].
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    best_val_acc = 0.0
    best_state = None
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(epochs):
        train_loss, train_acc = _vgg16_epoch_pass(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _vgg16_epoch_pass(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        print(f'Epoch {epoch+1:2d}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | LR: {optimizer.param_groups[0]["lr"]:.6f}')

        # Learning rate scheduling
        scheduler.step()

        # Keep the best weights in memory (on CPU to spare GPU memory) so the
        # restore below never depends on a stale file from an earlier run, and
        # still write them to disk as before.
        if best_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            torch.save(best_state, checkpoint_path)

    # Load best model
    if best_state is not None:
        model.load_state_dict(best_state)
    print(f'Enhanced VGG16 training completed! Best validation accuracy: {best_val_acc:.2f}%')
    return train_losses, train_accs, val_losses, val_accs


# Execute enhanced training
# train_vgg16_enhanced returns accuracies as percentages (0-100), whereas the
# AlexNet loop below stores fractions; keep this in mind when comparing.
if vgg16_model_training:
  train_losses, train_accs, val_losses, val_accs = train_vgg16_enhanced(
      vgg16_model, train_loader_enhanced, val_loader_enhanced,
      enhanced_criterion, vgg16_optimizer, vgg16_scheduler, epochs=common_epochs
  )



### Save model

In [None]:
# Comprehensive model saving for VGG16
# NOTE(review): "epoch_100" is fixed, not derived from common_epochs; the
# evaluation cell loads a file with this same name.
model_name = "vgg16_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
checkpoint = {
    # Model architecture and weights
    'model_state_dict': vgg16_model.state_dict(),
    'model_architecture': 'VGG16',
    'model_class': 'vgg16',

    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    'dropout_rate': 0.5,  # NOTE(review): confirm against VGG16Thermal

    # Training history (VGG16 accuracies are percentages, 0-100)
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),

    # Optimizer configuration, read from the objects built in the VGG16 setup
    # cell so the recorded values cannot drift from what was actually used.
    'optimizer_name': type(vgg16_optimizer).__name__,
    'learning_rate': vgg16_optimizer.defaults['lr'],
    'weight_decay': vgg16_optimizer.defaults['weight_decay'],
    'scheduler_name': type(vgg16_scheduler).__name__,
    'scheduler_milestones': sorted(vgg16_scheduler.milestones),

    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading); pickles the whole module object.
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(vgg16_model, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"VGG16 Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:

# Reload the full pickled model written by the save cell.
model_name = "vgg16_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# weights_only=False unpickles arbitrary objects: only load trusted files.
vgg16_model = torch.load(model_path, map_location=device, weights_only=False)
vgg16_model.eval()
print("Full vgg16_epoch_100  model loaded successfully")

# Enhanced evaluation
vgg16_results = evaluate_model_comprehensive(vgg16_model, test_loader_enhanced, categories, "VGG16-Enhanced", device)


# Plot training curves
# (histories come from the VGG16 training call; accuracies are in percent)
plot_training_curves(train_losses, train_accs, val_losses, val_accs, "VGG16")
save_training_data_to_csv("VGG16", train_losses, train_accs, val_losses, val_accs)

# Store model probabilities for ROC plotting
# Second pass over the test set, keeping per-class softmax probabilities.
with torch.no_grad():
    vgg16_model.eval()
    test_probs = []
    test_labels = []
    for batch_x, batch_y in test_loader_enhanced:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = vgg16_model(batch_x)
        probs = torch.softmax(outputs, dim=1)
        test_probs.append(probs.cpu().numpy())
        test_labels.append(batch_y.cpu().numpy())

    test_probs = np.vstack(test_probs)
    test_labels_array = np.hstack(test_labels)

# Store for ROC plotting
store_model_probabilities('VGG16', test_probs, test_labels_array)
MODEL_RESULTS_GLOBAL['VGG16'] = vgg16_results



model_results.append(vgg16_results)

# train_accs / val_accs from train_vgg16_enhanced are already percentages,
# hence the literal '%' suffix.
print("VGG16 Enhanced Fine-tuning Results:")
print(f"  Final Training Accuracy: {train_accs[-1]:.2f}%")
print(f"  Final Validation Accuracy: {val_accs[-1]:.2f}%")
print("VGG16 enhanced fine-tuning completed!")


## AlexNet

In [None]:
print("ALEXNET ENHANCED FINE-TUNING WITH ADVANCED OPTIMIZATION")
print("=" * 70)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the classifier head from the discovered categories (was hard-coded to 5).
alexnet_model = AlexNetThermal(num_classes=len(categories)).to(device)
print(f"AlexNet model initialized with {sum(p.numel() for p in alexnet_model.parameters()):,} parameters")


alexnet_optimizer = optim.AdamW(
    alexnet_model.parameters(),
    lr=0.001,           # initial learning rate
    weight_decay=0.01,  # decoupled weight decay for regularization
    betas=(0.9, 0.999),
    eps=1e-8
)

# Cosine annealing with warm restarts: the first cycle lasts 30 scheduler
# steps (one step per epoch in the training loop), each later cycle is twice
# as long, and the LR never drops below 1e-5.
alexnet_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    alexnet_optimizer,
    T_0=30,
    T_mult=2,
    eta_min=1e-5
)


def train_alexnet_enhanced(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=30,
                           checkpoint_path=None):
    """Train ``model`` with per-epoch validation and restore the best-validation weights.

    Each epoch makes one pass over ``train_loader`` (gradient norm clipped to 1.0),
    evaluates on ``val_loader``, steps ``scheduler`` once, and writes
    ``model.state_dict()`` to ``checkpoint_path`` whenever validation accuracy improves.
    After the last epoch, the best checkpoint is loaded back into ``model``.

    Args:
        model: network to train. Batches are moved to the device of its parameters.
        train_loader: iterable of ``(inputs, labels)`` training batches with a ``len``.
        val_loader: iterable of ``(inputs, labels)`` validation batches with a ``len``.
        criterion: loss callable ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        epochs: number of training epochs.
        checkpoint_path: where the best weights are written. Defaults to
            ``MODEL_DIR/alexnet_best_enhanced.pth``.

    Returns:
        ``(train_losses, train_accs, val_losses, val_accs)``, one entry per epoch.
        Accuracies are FRACTIONS in [0, 1], not percentages.

    Raises:
        ValueError: if either loader is empty.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Data loaders cannot be empty!")
    if checkpoint_path is None:
        checkpoint_path = os.path.join(MODEL_DIR, 'alexnet_best_enhanced.pth')

    # Follow the model: batches go wherever its parameters live.
    device = next(model.parameters()).device

    # Start at -inf (not 0.0) so the first epoch always writes a checkpoint. Otherwise
    # a run whose validation accuracy stays at 0% would reload a stale file from an
    # earlier run, or crash with FileNotFoundError when none exists.
    best_val_acc = float('-inf')
    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Bound the update size by clipping the global gradient norm.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100.0 * correct / total

        # Validation phase
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100.0 * correct / total

        # History keeps accuracies as fractions; the log line below uses percent.
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc / 100)
        val_accs.append(val_acc / 100)

        print(
            f"Epoch {epoch + 1:3d}/{epochs} | "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
            f"LR: {optimizer.param_groups[0]['lr']:.6f}"
        )

        # The LR logged above is the one used this epoch; advance the schedule afterwards.
        scheduler.step()

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    if epochs > 0:
        # At least one epoch ran, so this run has written the checkpoint.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"\nEnhanced AlexNet training completed! Best validation accuracy: {best_val_acc:.2f}%\n")

    return train_losses, train_accs, val_losses, val_accs


if alexnet_model_training:
    # Project-defined label-smoothing loss with a smoothing factor of 0.1.
    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    alexnet_history = train_alexnet_enhanced(
        alexnet_model, train_loader_enhanced, val_loader_enhanced,
        criterion, alexnet_optimizer, alexnet_scheduler,
        epochs=common_epochs,
    )
    train_losses, train_accs, val_losses, val_accs = alexnet_history

    # Persist the per-epoch history (same order: losses/accs for train, then val).
    save_training_data_to_csv('AlexNet Enhanced', *alexnet_history)


### Save model


In [None]:
# Comprehensive model saving for AlexNet.
# The file name is fixed because the evaluation cell below loads it by this name;
# the real epoch count is recorded in 'epochs_trained'.
model_name = "alexnet_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
checkpoint = {
    # Model architecture and weights
    'model_state_dict': alexnet_model.state_dict(),
    'model_architecture': 'AlexNet',
    'model_class': 'alexnet',
    
    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    'dropout_rate': 0.5,
    
    # Training history (accuracies are fractions in [0, 1])
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),
    
    # Optimizer/scheduler configuration, read from the objects actually used
    # (the hand-written 'scheduler_t0' was 20, but the scheduler uses T_0=30).
    'optimizer_name': type(alexnet_optimizer).__name__,
    'learning_rate': alexnet_optimizer.defaults['lr'],
    'weight_decay': alexnet_optimizer.defaults['weight_decay'],
    'scheduler_name': type(alexnet_scheduler).__name__,
    'scheduler_t0': alexnet_scheduler.T_0,
    'scheduler_t_mult': alexnet_scheduler.T_mult,
    
    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading)
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(alexnet_model, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"AlexNet Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:
model_name = "alexnet_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# Whole-module pickle written by the "Save model" cell. weights_only=False is needed
# to unpickle an nn.Module, so only point this at files this notebook produced.
alexnet_model = torch.load(model_path, map_location=device, weights_only=False)
alexnet_model.eval()
print("Full alexnet_enhanced  model loaded successfully")


# Enhanced evaluation with correct class_names
alexnet_results = evaluate_model_comprehensive(alexnet_model, test_loader_enhanced, categories, "AlexNet-Enhanced", device)

# Plot training curves
plot_training_curves(train_losses, train_accs, val_losses, val_accs, "AlexNet")

# Softmax probabilities and labels over the test set, for ROC plotting.
test_probs, test_labels = [], []
with torch.no_grad():
    for batch_x, batch_y in test_loader_enhanced:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = alexnet_model(batch_x)
        probs = torch.softmax(outputs, dim=1)
        test_probs.append(probs.cpu().numpy())
        test_labels.append(batch_y.cpu().numpy())
test_probs = np.vstack(test_probs)
test_labels_array = np.hstack(test_labels)

# Store for ROC plotting
store_model_probabilities('AlexNet', test_probs, test_labels_array)
MODEL_RESULTS_GLOBAL['AlexNet'] = alexnet_results
save_training_data_to_csv("AlexNet", train_losses, train_accs, val_losses, val_accs)

model_results.append(alexnet_results)

# train_alexnet_enhanced records accuracies as fractions (acc / 100). Format them
# with `%` (which multiplies by 100) instead of appending a literal '%' to 0.xx.
print("AlexNet Enhanced Fine-tuning Results:")
print(f"  Final Training Accuracy: {train_accs[-1]:.2%}")
print(f"  Final Validation Accuracy: {val_accs[-1]:.2%}")
print("AlexNet enhanced fine-tuning completed!")



## Hybrid Model Training

In [None]:
# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f" Using device: {device}")

# Two-phase training for the hybrid VGG/AlexNet model: frozen backbones first, then all layers.

def _hybrid_train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy in %)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        # Bound the update size by clipping the global gradient norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
    return running_loss / len(loader), 100.0 * correct / total


def _hybrid_validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean batch loss, accuracy in %)."""
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            val_loss += criterion(outputs, target).item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return val_loss / len(loader), 100.0 * correct / total


def _hybrid_fit_phase(model, train_loader, val_loader, criterion, optimizer, scheduler,
                      n_epochs, phase_label, history, best_val_acc, checkpoint_path, device):
    """Train one phase for ``n_epochs`` epochs and return the updated best validation accuracy.

    ``history`` is the tuple ``(train_losses, train_accs, val_losses, val_accs)``; one
    value is appended to each list per epoch. ``model.state_dict()`` is saved to
    ``checkpoint_path`` whenever validation accuracy beats ``best_val_acc``.
    """
    for epoch in range(n_epochs):
        train_loss, train_acc = _hybrid_train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = _hybrid_validate(model, val_loader, criterion, device)

        for series, value in zip(history, (train_loss, train_acc, val_loss, val_acc)):
            series.append(value)

        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        print(f'{phase_label} Epoch [{epoch+1}/{n_epochs}] - '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    return best_val_acc


def train_enhanced_hybrid_fixed(model, train_loader, val_loader, epochs=30,
                                criterion=None, checkpoint_path='enhanced_hybrid_best.pth'):
    """Two-phase training of the hybrid model, then restore the best-validation weights.

    Phase 1 freezes ``model.vgg_features`` and ``model.alex_features`` and trains the
    remaining parameters with AdamW (lr 2e-3). Phase 2 unfreezes every parameter and
    fine-tunes with AdamW (lr 5e-4). Each phase uses its own cosine-annealed LR,
    stepped once per epoch.

    Args:
        model: module exposing ``vgg_features`` and ``alex_features`` submodules.
            Batches are moved to the device of its parameters.
        train_loader: iterable of ``(data, target)`` training batches with a ``len``.
        val_loader: iterable of ``(data, target)`` validation batches with a ``len``.
        epochs: total number of epochs across both phases.
        criterion: loss callable. Defaults to ``CombinedLoss(alpha=0.6, beta=0.4)``.
        checkpoint_path: file the best ``state_dict`` is written to.

    Returns:
        ``(model, train_losses, train_accs, val_losses, val_accs)``, one entry per epoch.
        Accuracies are percentages.

    Raises:
        ValueError: if either loader is empty.
    """
    print(f"Starting Enhanced Training for {epochs} epochs")

    # Validate data loaders
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Data loaders cannot be empty!")

    print(f"Data loaders validated: train={len(train_loader)}, val={len(val_loader)}")

    if criterion is None:
        criterion = CombinedLoss(alpha=0.6, beta=0.4)
    device = next(model.parameters()).device

    # Phase 1 gets ~30% of the epochs (at least 3), but never more than requested.
    # Previously epochs < 3 silently trained 3 epochs and gave phase 2 a negative length.
    freeze_epochs = min(epochs, max(3, int(0.3 * epochs)))
    finetune_epochs = epochs - freeze_epochs

    print(f" Phase 1: Freezing backbone for {freeze_epochs} epochs")
    print(f" Phase 2: Fine-tuning all layers for {finetune_epochs} epochs")

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    history = (train_losses, train_accs, val_losses, val_accs)
    # -inf so the first epoch always writes a checkpoint; a stale file from an earlier
    # run is then never reloaded below.
    best_val_acc = float('-inf')

    # Phase 1: frozen backbone training
    for param in model.vgg_features.parameters():
        param.requires_grad = False
    for param in model.alex_features.parameters():
        param.requires_grad = False

    if freeze_epochs > 0:
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=2e-3, weight_decay=0.01)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=freeze_epochs)
        best_val_acc = _hybrid_fit_phase(model, train_loader, val_loader, criterion, optimizer,
                                         scheduler, freeze_epochs, 'Phase1', history,
                                         best_val_acc, checkpoint_path, device)

    # Phase 2: full model fine-tuning. Always unfreeze, so the returned model is fully
    # trainable even when no phase-2 epochs remain.
    print(f" Phase 2: Fine-tuning all layers")
    for param in model.parameters():
        param.requires_grad = True

    if finetune_epochs > 0:
        # Fresh optimizer with a lower learning rate for fine-tuning.
        optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)
        best_val_acc = _hybrid_fit_phase(model, train_loader, val_loader, criterion, optimizer,
                                         scheduler, finetune_epochs, 'Phase2', history,
                                         best_val_acc, checkpoint_path, device)

    if epochs > 0:
        # Load the best model; at least one epoch ran, so this run wrote the checkpoint.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f'Enhanced Training completed! Best validation accuracy: {best_val_acc:.2f}%')

    return model, train_losses, train_accs, val_losses, val_accs

# Build the hybrid model, report its size, then (optionally) train it.
print("CREATING ENHANCED HYBRID VGG-ALEXNET MODEL")
enhanced_hybrid_model = EnhancedHybridVGGAlexNet(num_classes=5, dropout_rate=0.5).to(device)
print(f"Enhanced Hybrid VGG-AlexNet model successfully created and moved to {device}")

# Parameter counts: every parameter vs. those with requires_grad=True.
total_params = sum(map(torch.numel, enhanced_hybrid_model.parameters()))
trainable_params = sum(p.numel() for p in enhanced_hybrid_model.parameters() if p.requires_grad)
print(f"Enhanced Model Parameters: {total_params:,} total, {trainable_params:,} trainable")

if hybrid_model_training:
    print(" Starting Enhanced Hybrid VGG-AlexNet Training...")
    (enhanced_hybrid_model,
     train_losses, train_accs,
     val_losses, val_accs) = train_enhanced_hybrid_fixed(
        enhanced_hybrid_model,
        train_loader_enhanced,
        val_loader_enhanced,
        epochs=common_epochs,
    )

    # Loss/accuracy curves for both phases.
    plot_training_curves(train_losses, train_accs, val_losses, val_accs, "Enhanced-Hybrid-VGG-AlexNet")


### Save model

In [None]:
# Persist the Enhanced-Hybrid-VGG-AlexNet model in two forms: a whole-module pickle
# (loaded by the evaluation cell) and a metadata-rich checkpoint dictionary.
model_name = "hybridmodel_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

checkpoint = dict(
    # weights + architecture identifiers
    model_state_dict=enhanced_hybrid_model.state_dict(),
    model_architecture='Enhanced-Hybrid-VGG-AlexNet',
    model_class='EnhancedHybridVGGAlexNet',
    # training configuration
    num_classes=len(categories),
    class_names=categories,
    epochs_trained=common_epochs,
    dropout_rate=0.5,
    use_cbam=True,
    multi_scale_fusion=True,
    # per-epoch history and its extremes
    train_losses=train_losses,
    train_accs=train_accs,
    val_losses=val_losses,
    val_accs=val_accs,
    best_val_acc=max(val_accs),
    best_val_loss=min(val_losses),
    # optimiser settings
    optimizer_name='AdamW',
    learning_rate=0.0005,
    weight_decay=0.01,
    scheduler_name='CosineAnnealingLR',
    # run metadata
    timestamp=timestamp,
    device=str(device),
    pytorch_version=torch.__version__,
    input_size=(200, 200),
    batch_size=32,
)

full_model_path, checkpoint_path = (
    os.path.join(MODEL_DIR, f"{model_name}_{suffix}.pth") for suffix in ("full", "checkpoint")
)
torch.save(enhanced_hybrid_model, full_model_path)
torch.save(checkpoint, checkpoint_path)

for line in (
    f"Enhanced-Hybrid-VGG-AlexNet Model saved successfully:",
    f"   Full model: {full_model_path}",
    f"   Checkpoint: {checkpoint_path}",
    f"   Best validation accuracy: {max(val_accs):.4f}",
    f"   Best validation loss: {min(val_losses):.4f}",
):
    print(line)

### Evaluation

In [None]:

model_name = "hybridmodel_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# Whole-module pickle written by the "Save model" cell. weights_only=False is needed
# to unpickle an nn.Module, so only point this at files this notebook produced.
enhanced_hybrid_model = torch.load(model_path, map_location=device, weights_only=False)
enhanced_hybrid_model.eval()
print("Full hybridmodel_enhanced  model loaded successfully")

# Snapshot any earlier result for this model BEFORE it is overwritten below. The
# improvement report at the end used to read the entry back after overwriting it,
# so it always compared the model with itself (+0.00%).
previous_hybrid_results = MODEL_RESULTS_GLOBAL.get('Enhanced-Hybrid-VGG-AlexNet')

print("Evaluating Enhanced Hybrid VGG-AlexNet Model...")
enhanced_hybrid_results = evaluate_model_comprehensive(enhanced_hybrid_model, test_loader_enhanced, categories, "Enhanced-Hybrid-VGG-AlexNet", device)

# Softmax probabilities and labels over the test set, for ROC plotting.
test_labels, test_probs = [], []
with torch.no_grad():
    for data, target in test_loader_enhanced:
        data, target = data.to(device), target.to(device)
        outputs = enhanced_hybrid_model(data)
        probs = F.softmax(outputs, dim=1)
        test_labels.append(target.cpu().numpy())
        test_probs.append(probs.cpu().numpy())
test_probs = np.vstack(test_probs)
test_labels_array = np.hstack(test_labels)
store_model_probabilities('Enhanced-Hybrid-VGG-AlexNet', test_probs, test_labels_array)

MODEL_RESULTS_GLOBAL['Enhanced-Hybrid-VGG-AlexNet'] = enhanced_hybrid_results
save_training_data_to_csv("Enhanced-Hybrid-VGG-AlexNet", train_losses, train_accs, val_losses, val_accs)
# Register with the cross-model comparison table, as the other evaluation cells do
# (the hybrid model was previously missing from it).
model_results.append(enhanced_hybrid_results)

# --- Training curves (with phase boundary) and confusion matrix ---
epochs_range = range(1, len(train_losses) + 1)
# Mirrors the freeze-epoch rule in train_enhanced_hybrid_fixed (about 30% of epochs,
# at least 3), capped at the number of recorded epochs. The line sits between the
# last frozen epoch and the first fine-tuning epoch.
hybrid_freeze_epochs = min(len(train_losses), max(3, int(0.3 * len(train_losses))))
phase_boundary = hybrid_freeze_epochs + 0.5

fig, (ax_loss, ax_acc, ax_cm) = plt.subplots(1, 3, figsize=(18, 6))

ax_loss.plot(epochs_range, train_losses, 'b-', label='Training Loss', alpha=0.8)
ax_loss.plot(epochs_range, val_losses, 'r-', label='Validation Loss', alpha=0.8)
ax_loss.axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.7, label='Phase 1→2')
ax_loss.set(title='Enhanced Hybrid Training Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend()
ax_loss.grid(True, alpha=0.3)

ax_acc.plot(epochs_range, train_accs, 'b-', label='Training Accuracy', alpha=0.8)
ax_acc.plot(epochs_range, val_accs, 'r-', label='Validation Accuracy', alpha=0.8)
ax_acc.axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.7, label='Phase 1→2')
ax_acc.set(title='Enhanced Hybrid Training Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True, alpha=0.3)

sns.heatmap(enhanced_hybrid_results['confusion_matrix'],
            annot=True, fmt='d', cmap='Blues',
            xticklabels=categories, yticklabels=categories, ax=ax_cm)
ax_cm.set(title='Enhanced Hybrid Confusion Matrix', xlabel='Predicted', ylabel='Actual')

fig.tight_layout()
plt.show()

print("Enhanced Hybrid VGG-AlexNet Training Results:")
for key, value in enhanced_hybrid_results.items():
    if key not in ['confusion_matrix', 'classification_report']:
        print(f"   {key}: {value}")

# Test metrics from evaluate_model_comprehensive are treated as fractions here:
# the comparison cell plots them on a 0-1 axis. `:.2%` therefore renders them as percentages.
print("Enhanced Hybrid VGG-AlexNet model training and evaluation completed!")
print(f"Model weights saved as: enhanced_hybrid_best.pth")
print(f"Final Test Accuracy: {enhanced_hybrid_results['accuracy']:.2%}")

# Performance comparison against a previous evaluation of this model, if any.
if previous_hybrid_results is not None:
    original_acc = previous_hybrid_results['accuracy']
    enhanced_acc = enhanced_hybrid_results['accuracy']
    improvement = enhanced_acc - original_acc
    print(f"PERFORMANCE IMPROVEMENT:")
    print(f"   Original Hybrid: {original_acc:.2%}")
    print(f"   Enhanced Hybrid: {enhanced_acc:.2%}")
    print(f"   Improvement: {improvement:+.2%}")
else:
    print(f"Enhanced Hybrid Model Performance:")
    print(f"   Test Accuracy: {enhanced_hybrid_results['accuracy']:.2%}")
    print(f"   F1-Score: {enhanced_hybrid_results.get('f1_macro', 0):.2%}")


## CNN Architecture

In [None]:
# Enhanced CNN Architecture Implementation
print("\n" + "="*80)
print(" ENHANCED CNN ARCHITECTURE TRAINING")
print("="*80)

# Create the CNN and actually move it to `device`. Previously it was never moved,
# and the "moved to" message was printed before the model even existed.
num_classes = len(categories)
cnn_enhanced = ProposedCNN(input_shape=(3, 200, 200), num_classes=num_classes, dropout_rate=0.3).to(device)
print(f"CNN model successfully moved to {device}")

# Model summary
total_params = sum(p.numel() for p in cnn_enhanced.parameters())
trainable_params = sum(p.numel() for p in cnn_enhanced.parameters() if p.requires_grad)

print(f"\n CNN Architecture Summary:")
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

# Define loss
criterion = nn.CrossEntropyLoss()

# Train with validation
if cnn_model_training:
    epochs = common_epochs
    train_losses, train_accs, val_losses, val_accs = train_model_fine_tuned(
        cnn_enhanced, train_loader_enhanced, val_loader_enhanced,
        criterion, epochs, device, "ProposedCNN")

    # Plot training curves
    plot_training_curves(train_losses, train_accs, val_losses, val_accs, "ProposedCNN")



### Save model

In [None]:
# Comprehensive model saving for ProposedCNN.
# The file name is fixed because the evaluation cell below loads it by this name;
# the real epoch count is recorded in 'epochs_trained'.
model_name = "cnn_model_epoch_100"
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Create comprehensive checkpoint
checkpoint = {
    # Model architecture and weights
    'model_state_dict': cnn_enhanced.state_dict(),
    'model_architecture': 'ProposedCNN',
    'model_class': 'ProposedCNN',
    
    # Training configuration
    'num_classes': len(categories),
    'class_names': categories,
    'epochs_trained': common_epochs,
    # Must match ProposedCNN(..., dropout_rate=0.3) in the training cell (was 0.5).
    'dropout_rate': 0.3,
    'use_inception_modules': True,
    'multi_branch_architecture': True,
    
    # Training history
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
    'best_val_acc': max(val_accs),
    'best_val_loss': min(val_losses),
    
    # Optimizer configuration.
    # NOTE(review): these are written by hand, not read from train_model_fine_tuned;
    # confirm they match the optimizer/scheduler that function actually builds.
    'optimizer_name': 'AdamW',
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'scheduler_name': 'CosineAnnealingLR',
    
    # Metadata
    'timestamp': timestamp,
    'device': str(device),
    'pytorch_version': torch.__version__,
    'input_size': (200, 200),
    'batch_size': 32,
}

# Save full model (for easy loading)
full_model_path = os.path.join(MODEL_DIR, f"{model_name}_full.pth")
torch.save(cnn_enhanced, full_model_path)

# Save comprehensive checkpoint (recommended for production)
checkpoint_path = os.path.join(MODEL_DIR, f"{model_name}_checkpoint.pth")
torch.save(checkpoint, checkpoint_path)

print(f"ProposedCNN Model saved successfully:")
print(f"   Full model: {full_model_path}")
print(f"   Checkpoint: {checkpoint_path}")
print(f"   Best validation accuracy: {max(val_accs):.4f}")
print(f"   Best validation loss: {min(val_losses):.4f}")

### Evaluation

In [None]:

model_name = "cnn_model_epoch_100_full"
model_path = os.path.join(MODEL_DIR, f"{model_name}.pth")

# Whole-module pickle written by the "Save model" cell. weights_only=False is needed
# to unpickle an nn.Module, so only point this at files this notebook produced.
cnnmodel_enhanced = torch.load(model_path, map_location=device, weights_only=False)
cnnmodel_enhanced.eval()
# Evaluate the model just loaded from disk. It was previously loaded and then ignored
# in favour of the in-memory `cnn_enhanced`. Rebind `cnn_enhanced` so later cells see
# the same model.
cnn_enhanced = cnnmodel_enhanced
print("Full cnn_model_enhanced model loaded successfully")

# Comprehensive evaluation
cnn_results = evaluate_model_comprehensive(cnn_enhanced, test_loader_enhanced, categories, "ProposedCNN", device)

# Softmax probabilities and labels over the test set, for ROC plotting.
test_probs, test_labels = [], []
with torch.no_grad():
    for batch_x, batch_y in test_loader_enhanced:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        outputs = cnn_enhanced(batch_x)
        probs = torch.softmax(outputs, dim=1)
        test_probs.append(probs.cpu().numpy())
        test_labels.append(batch_y.cpu().numpy())
test_probs = np.vstack(test_probs)
test_labels_array = np.hstack(test_labels)

# Store for ROC plotting
store_model_probabilities('ProposedCNN', test_probs, test_labels_array)
MODEL_RESULTS_GLOBAL['ProposedCNN'] = cnn_results
save_training_data_to_csv("ProposedCNN", train_losses, train_accs, val_losses, val_accs)

# Store results
model_results.append(cnn_results)
print(f"\nProposed CNN training and evaluation completed!")

# Data Analysis Execution

## Comprehensive Visualisation

In [None]:
# Comprehensive Model Comparison and Final Results
print("\n" + "="*100)
print(" COMPREHENSIVE MODEL COMPARISON - FINAL RESULTS")
print("="*100)


def _annotated_metric_bars(ax, names, values, colors, title, ylabel):
    """Draw one bar per model on ``ax`` (y-range 0-1) and print each value above its bar."""
    bars = ax.bar(names, values, color=colors)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_ylim(0, 1)
    ax.tick_params(axis='x', rotation=45)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{value:.3f}', ha='center', va='bottom', fontweight='bold')


if model_results:
    # Every evaluation cell appends to `model_results` each time it runs, so a re-run
    # would list a model twice. Keep only the most recent entry per model.
    comparison_df = (
        pd.DataFrame(model_results)
        .drop_duplicates(subset='model_name', keep='last')
        .reset_index(drop=True)
    )

    # Display results table
    print("\nModel Performance Comparison:")
    print("-" * 100)
    print(f"{'Model':<20} {'Accuracy':<10} {'Precision':<12} {'Recall':<10} {'F1-Score':<10} {'F1-Weighted':<12}")
    print("-" * 100)

    for _, row in comparison_df.iterrows():
        print(f"{row['model_name']:<20} {row['accuracy']:<10.4f} {row['precision_macro']:<12.4f} "
              f"{row['recall_macro']:<10.4f} {row['f1_macro']:<10.4f} {row['f1_weighted']:<12.4f}")

    print("-" * 100)

    # Find best performing model for each metric
    best_accuracy = comparison_df.loc[comparison_df['accuracy'].idxmax()]
    best_precision = comparison_df.loc[comparison_df['precision_macro'].idxmax()]
    best_recall = comparison_df.loc[comparison_df['recall_macro'].idxmax()]
    best_f1 = comparison_df.loc[comparison_df['f1_macro'].idxmax()]

    print(f"\n Best Performing Models:")
    print(f"  Highest Accuracy: {best_accuracy['model_name']} ({best_accuracy['accuracy']:.4f})")
    print(f"  Highest Precision: {best_precision['model_name']} ({best_precision['precision_macro']:.4f})")
    print(f"  Highest Recall: {best_recall['model_name']} ({best_recall['recall_macro']:.4f})")
    print(f"  Highest F1-Score: {best_f1['model_name']} ({best_f1['f1_macro']:.4f})")

    # Named `model_names`, not `models`, so the `torchvision.models` module imported
    # as `models` at the top of the notebook is not shadowed for later cells.
    model_names = comparison_df['model_name']
    # One distinct colour per model from the 10-colour tab10 palette. The old fixed list
    # of five colours was silently recycled once more than five models were compared.
    colors = plt.cm.tab10(np.arange(len(model_names)) % 10)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    metric_panels = [
        ('accuracy', 'Model Accuracy Comparison', 'Accuracy'),
        ('precision_macro', 'Model Precision Comparison', 'Precision (Macro)'),
        ('recall_macro', 'Model Recall Comparison', 'Recall (Macro)'),
        ('f1_macro', 'Model F1-Score Comparison', 'F1-Score (Macro)'),
    ]
    for ax, (column, title, ylabel) in zip(axes.flat, metric_panels):
        _annotated_metric_bars(ax, model_names, comparison_df[column], colors, title, ylabel)

    plt.tight_layout()
    plt.suptitle(' Comprehensive Model Performance Comparison', fontsize=16, fontweight='bold', y=1.02)
    plt.show()

    # Overall ranking: unweighted mean of the four headline metrics.
    comparison_df['overall_score'] = (comparison_df['accuracy'] + comparison_df['precision_macro'] +
                                      comparison_df['recall_macro'] + comparison_df['f1_macro']) / 4
    best_overall = comparison_df.loc[comparison_df['overall_score'].idxmax()]

    print(f"\n OVERALL BEST MODEL: {best_overall['model_name']}")
    print(f"   Overall Score: {best_overall['overall_score']:.4f}")
    print(f"   Accuracy: {best_overall['accuracy']:.4f}")
    print(f"   Precision: {best_overall['precision_macro']:.4f}")
    print(f"   Recall: {best_overall['recall_macro']:.4f}")
    print(f"   F1-Score: {best_overall['f1_macro']:.4f}")

    # Summary statistics
    print(f"\nFinal Summary:")
    print(f"  Dataset: {len(X_train) + len(X_val) + len(X_test)} total samples (augmented)")
    print(f"   Models trained: {len(comparison_df)}")
    print(f"  Optimizations used: 8 advanced techniques")
    print(f"   Hardware: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"  Best overall accuracy: {comparison_df['accuracy'].max():.4f}")

else:
    print("No model results found. Please run the model training sections first.")

print("\n" + "="*100)
print("ENHANCED THERMAL IMAGING CLASSIFICATION COMPLETED SUCCESSFULLY!")
print(" All models trained with comprehensive optimization suite")
print("Detailed evaluation metrics calculated and compared")
print(" Best performing models identified")
print("="*100)

## CSV Exporter

In [None]:
# Enhanced Data Export and CSV Generation for Graph Creation
# =========================================================
# CSV export for the thermal imaging analysis: per-model and combined training
# curves, plus a model-comparison metrics table. Run this cell after the models
# have been trained and evaluated.

class ThermalDataCSVExporter:
    """Comprehensive CSV exporter for thermal imaging analysis data.

    All files written by one exporter instance share the timestamp suffix
    (``YYYYmmdd_HHMMSS``) captured when the instance is created.
    """

    # Metric keys read from each result dict, in output column order. Values are
    # multiplied by 100 on export (fraction -> percentage).
    _COMPARISON_METRICS = (
        'accuracy',
        'precision_macro', 'recall_macro', 'f1_macro',
        'precision_micro', 'recall_micro', 'f1_micro',
        'precision_weighted', 'recall_weighted', 'f1_weighted',
    )

    def __init__(self, output_dir='thermal_csv_exports'):
        """Initialize the CSV exporter and create ``output_dir`` if needed.

        Args:
            output_dir: Destination directory; intermediate directories are
                created as well.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested output paths (e.g. 'results/csv') also work.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        print(f"CSV export directory: {self.output_dir.absolute()}")

    def export_training_data_csv(self, training_data_dict):
        """Export training curves data to CSV files for graph creation.

        Writes one ``<model>_training_curves_<ts>.csv`` per model plus a combined
        ``all_models_training_curves_<ts>.csv``.

        Args:
            training_data_dict: Maps model name -> dict with the per-epoch lists
                ``train_losses``, ``train_accs``, ``val_losses`` and ``val_accs``
                (all the same length, one entry per epoch).

        Returns:
            Number of models in ``training_data_dict``.
        """
        print("Exporting training data to CSV...")

        # Combined training data for all models
        combined_data = []

        for model_name, data in training_data_dict.items():
            train_losses = np.array(data['train_losses'])
            train_accs = np.array(data['train_accs'])
            val_losses = np.array(data['val_losses'])
            val_accs = np.array(data['val_accs'])

            df_individual = pd.DataFrame({
                'epoch': range(1, len(data['train_losses']) + 1),
                'train_loss': data['train_losses'],
                'train_accuracy': data['train_accs'],
                'val_loss': data['val_losses'],
                'val_accuracy': data['val_accs'],
                'model': model_name,
                # Positive values indicate the validation curve lagging training.
                'loss_difference': val_losses - train_losses,
                'accuracy_difference': train_accs - val_accs,
            })

            # Save individual model file
            filename = f'{model_name.lower().replace(" ", "_")}_training_curves_{self.timestamp}.csv'
            df_individual.to_csv(self.output_dir / filename, index=False)
            print(f"  Saved: {filename}")

            combined_data.append(df_individual)

        # Save combined training data
        if combined_data:
            df_combined = pd.concat(combined_data, ignore_index=True)
            combined_filepath = self.output_dir / f'all_models_training_curves_{self.timestamp}.csv'
            df_combined.to_csv(combined_filepath, index=False)
            print(f"  Saved combined training data: all_models_training_curves_{self.timestamp}.csv")

        return len(training_data_dict)

    def export_model_comparison_csv(self, model_results_list):
        """Export model comparison metrics to CSV for comparison graphs.

        Args:
            model_results_list: Sequence of per-model result dicts. Missing
                ``model_name`` becomes ``'Unknown'``; missing metrics become 0.

        Returns:
            Number of result rows written (0 when the input is empty).
        """
        print("Exporting model comparison data to CSV...")

        if not model_results_list:
            print("  No model results to export")
            return 0

        comparison_data = [
            {
                'model_name': result.get('model_name', 'Unknown'),
                **{metric: result.get(metric, 0) * 100 for metric in self._COMPARISON_METRICS},
            }
            for result in model_results_list
        ]

        df_comparison = pd.DataFrame(comparison_data)
        filename = f'model_comparison_{self.timestamp}.csv'
        filepath = self.output_dir / filename
        if filepath.exists():
            # Every comparison export from this instance shares one file name, so a
            # second call replaces the first; make that visible instead of silent.
            print(f"  Warning: overwriting existing {filename}")
        df_comparison.to_csv(filepath, index=False)
        print(f"  Saved: model_comparison_{self.timestamp}.csv")

        return len(model_results_list)

# Execute the CSV export using existing data from the notebook
# =============================================================

# Initialize the CSV exporter
csv_exporter = ThermalDataCSVExporter()

print("Starting comprehensive CSV export for graph creation...")
print("=" * 60)

# Export training data if available
if 'enhanced_training_data' in globals() and enhanced_training_data:
    print("Exporting training curves data...")
    csv_exporter.export_training_data_csv(enhanced_training_data)
    print()

# Every comparison export from one exporter writes the same
# model_comparison_<timestamp>.csv, so exporting several result variables would
# silently overwrite each other. Export the first available source
# (enhanced_model_results has priority) and report the rest as skipped.
comparison_source = None

if 'enhanced_model_results' in globals() and enhanced_model_results:
    print("Exporting model comparison data...")
    csv_exporter.export_model_comparison_csv(enhanced_model_results)
    comparison_source = 'enhanced_model_results'
    print()

alternative_vars = ['model_results', 'all_model_results', 'final_results']
for var_name in alternative_vars:
    if var_name in globals() and globals()[var_name]:
        if comparison_source is not None:
            print(f"Found additional results in '{var_name}', skipped "
                  f"(would overwrite the comparison exported from '{comparison_source}')")
            continue
        print(f"Found additional results in '{var_name}', exporting...")
        csv_exporter.export_model_comparison_csv(globals()[var_name])
        comparison_source = var_name
        print()

print("=" * 60)
print("CSV EXPORT COMPLETED!")
print("=" * 60)
print(f"All CSV files saved to: {csv_exporter.output_dir.absolute()}")
print()
print("Available CSV files for creating graphs:")
print("  • all_models_training_curves_*.csv - For training/validation loss and accuracy plots")
print("  • model_comparison_*.csv - For model performance comparison charts")
print("  • Individual model files - For detailed per-model analysis")
print()
print(" You can now use these CSV files with any plotting library (matplotlib, seaborn, plotly)")
print("   or external tools (Excel, Tableau, Power BI) to create custom visualizations!")

# Display sample of the exported data
print("\n" + "=" * 60)
print("SAMPLE OF EXPORTED DATA")
print("=" * 60)

# Show the files written by *this* run. Globbing the directory returns files in
# arbitrary order and could display a stale export from an earlier run.
training_file = csv_exporter.output_dir / f"all_models_training_curves_{csv_exporter.timestamp}.csv"
if training_file.exists():
    sample_df = pd.read_csv(training_file)
    print("Training Curves Data Sample:")
    print(sample_df.head())
    print(f"Shape: {sample_df.shape}")
    print()

comparison_file = csv_exporter.output_dir / f"model_comparison_{csv_exporter.timestamp}.csv"
if comparison_file.exists():
    sample_df = pd.read_csv(comparison_file)
    print("Model Comparison Data Sample:")
    print(sample_df)
    print(f"Shape: {sample_df.shape}")

print("\n" + "=" * 60)

## Data Visualiser

In [None]:
# Set style for better-looking plots.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*'; older
# versions only know 'seaborn'. Use whichever exists, so an unknown style name
# cannot raise OSError and abort this cell before ThermalDataVisualizer is defined.
_plot_style = next(
    (style for style in ('seaborn-v0_8', 'seaborn') if style in plt.style.available),
    None,
)
if _plot_style is not None:
    plt.style.use(_plot_style)
sns.set_palette("husl")

class ThermalDataVisualizer:
    """Create comprehensive visualizations from thermal imaging CSV data.

    Reads ``all_models_training_curves_*.csv``, ``model_comparison_*.csv`` and
    per-model ``*_training_curves_*.csv`` files from ``csv_directory`` and saves
    PNG figures to ``<csv_directory>/visualizations``. When the directory holds
    several timestamped exports of the same kind, the newest one is used.
    """

    def __init__(self, csv_directory='thermal_csv_exports'):
        """Initialize with the directory containing CSV files and load them.

        Args:
            csv_directory: Directory holding the exported CSV files.

        Raises:
            FileNotFoundError: If ``csv_directory`` does not exist.
        """
        self.csv_dir = Path(csv_directory)
        if not self.csv_dir.exists():
            raise FileNotFoundError(f"CSV directory not found: {csv_directory}")

        self.output_dir = self.csv_dir / 'visualizations'
        self.output_dir.mkdir(exist_ok=True)

        print(f"CSV Data Directory: {self.csv_dir.absolute()}")
        print(f"Visualization Output: {self.output_dir.absolute()}")

        # Load available CSV files
        self.load_csv_files()

    def _find_sorted(self, pattern):
        """Return files in ``csv_dir`` matching ``pattern``, sorted by file name.

        ``glob`` yields files in arbitrary order. The exporter's ``YYYYmmdd_HHMMSS``
        suffix sorts lexically, so for a given prefix the newest export comes last.
        """
        return sorted(self.csv_dir.glob(pattern), key=lambda path: path.name)

    def _save_current_figure(self, save, filename):
        """Save the current pyplot figure as ``output_dir/filename`` when ``save`` is True."""
        if save:
            filepath = self.output_dir / filename
            plt.savefig(filepath, dpi=300, bbox_inches='tight')
            print(f"  Saved: {filepath.name}")

    def load_csv_files(self):
        """Load the newest training, comparison and per-model CSV exports.

        Sets ``training_data`` and ``comparison_data`` (DataFrame or None) and
        ``individual_data`` (dict: file-name model prefix -> DataFrame).
        """
        print("\n Loading CSV files...")

        # Find training curves data (newest export wins)
        training_files = self._find_sorted("all_models_training_curves_*.csv")
        self.training_data = None
        if training_files:
            self.training_data = pd.read_csv(training_files[-1])
            print(f"  Training data loaded: {training_files[-1].name}")

        # Find model comparison data (newest export wins)
        comparison_files = self._find_sorted("model_comparison_*.csv")
        self.comparison_data = None
        if comparison_files:
            self.comparison_data = pd.read_csv(comparison_files[-1])
            print(f"  Comparison data loaded: {comparison_files[-1].name}")

        # Find individual model files. Ascending name order means a newer export
        # of the same model replaces an older one in the dict.
        individual_files = [f for f in self._find_sorted("*_training_curves_*.csv")
                            if "all_models" not in f.name]
        self.individual_data = {}
        for file in individual_files:
            model_name = file.name.split('_training_curves_')[0]
            self.individual_data[model_name] = pd.read_csv(file)
            print(f"  Individual model data loaded: {model_name}")

    def plot_training_curves_comparison(self, save=True, figsize=(15, 12)):
        """Plot per-epoch train/val loss and accuracy for every model (2x2 grid).

        Args:
            save: If True, also write ``training_curves_comparison.png``.
            figsize: Matplotlib figure size in inches.
        """
        if self.training_data is None:
            print("No training data available for plotting")
            return

        print("Creating training curves comparison plots...")

        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle('Training Curves Comparison Across Models', fontsize=16, fontweight='bold')

        # (axes, CSV column, title, y label, marker) for each panel
        panels = (
            (axes[0, 0], 'train_loss', 'Training Loss', 'Loss', 'o'),
            (axes[0, 1], 'val_loss', 'Validation Loss', 'Loss', 's'),
            (axes[1, 0], 'train_accuracy', 'Training Accuracy', 'Accuracy', '^'),
            (axes[1, 1], 'val_accuracy', 'Validation Accuracy', 'Accuracy', 'D'),
        )

        model_names = self.training_data['model'].unique()
        colors = plt.cm.Set1(np.linspace(0, 1, len(model_names)))

        for model, color in zip(model_names, colors):
            model_data = self.training_data[self.training_data['model'] == model]
            for ax, column, _, _, marker in panels:
                ax.plot(model_data['epoch'], model_data[column],
                        label=f'{model}', color=color, marker=marker, linewidth=2)

        for ax, _, title, ylabel, _ in panels:
            ax.set_title(title, fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_current_figure(save, 'training_curves_comparison.png')
        plt.show()

    def plot_overfitting_analysis(self, save=True, figsize=(15, 6)):
        """Plot the val-train loss gap and train-val accuracy gap per epoch.

        Args:
            save: If True, also write ``overfitting_analysis.png``.
            figsize: Matplotlib figure size in inches.
        """
        if self.training_data is None:
            print("No training data available for overfitting analysis")
            return

        print(" Creating overfitting analysis plots...")

        fig, axes = plt.subplots(1, 2, figsize=figsize)
        fig.suptitle('Overfitting Analysis', fontsize=16, fontweight='bold')

        # (axes, CSV column, title, y label, marker) for each panel
        panels = (
            (axes[0], 'loss_difference', 'Loss Difference (Validation - Training)', 'Loss Difference', 'o'),
            (axes[1], 'accuracy_difference', 'Accuracy Difference (Training - Validation)',
             'Accuracy Difference', 's'),
        )

        model_names = self.training_data['model'].unique()
        colors = plt.cm.Set1(np.linspace(0, 1, len(model_names)))

        for model, color in zip(model_names, colors):
            model_data = self.training_data[self.training_data['model'] == model]
            for ax, column, _, _, marker in panels:
                ax.plot(model_data['epoch'], model_data[column],
                        label=f'{model}', color=color, marker=marker, linewidth=2)

        for ax, _, title, ylabel, _ in panels:
            ax.set_title(title, fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            # Zero line: no gap between training and validation.
            ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_current_figure(save, 'overfitting_analysis.png')
        plt.show()

    def plot_model_comparison(self, save=True, figsize=(15, 12)):
        """Create model performance comparison plots.

        2x2 figure: grouped bars of the main metrics, a heatmap of every metric
        column, a radar chart of the main metrics and an accuracy bar chart.
        Values are plotted as stored in the CSV and labelled as percentages.

        Args:
            save: If True, also write ``model_comparison.png``.
            figsize: Matplotlib figure size in inches.
        """
        if self.comparison_data is None:
            print("No comparison data available for plotting")
            return

        print("Creating model comparison plots...")

        data = self.comparison_data
        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle('Model Performance Comparison', fontsize=16, fontweight='bold')

        # Main metrics that are actually present in the CSV
        main_metrics = [metric for metric in ('accuracy', 'precision_macro', 'recall_macro', 'f1_macro')
                        if metric in data.columns]

        # Grouped bar plot for main metrics
        x = np.arange(len(data))
        width = 0.2
        for i, metric in enumerate(main_metrics):
            axes[0, 0].bar(x + i * width, data[metric], width, label=metric.replace('_', ' ').title())

        axes[0, 0].set_title('Main Performance Metrics', fontweight='bold')
        axes[0, 0].set_xlabel('Models')
        axes[0, 0].set_ylabel('Score (%)')
        # Centre each tick under its group of bars
        axes[0, 0].set_xticks(x + width * max(len(main_metrics) - 1, 0) / 2)
        axes[0, 0].set_xticklabels(data['model_name'], rotation=45)
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Heatmap for all metrics
        metric_cols = [col for col in data.columns if col != 'model_name']
        heatmap_data = data[metric_cols].T
        heatmap_data.columns = data['model_name']

        sns.heatmap(heatmap_data, annot=True, fmt='.2f', cmap='YlOrRd',
                    ax=axes[0, 1], cbar_kws={'label': 'Score (%)'})
        axes[0, 1].set_title('Performance Heatmap', fontweight='bold')

        # Radar chart in the bottom-left slot. The Cartesian axes plt.subplots put
        # there is removed explicitly: recent Matplotlib versions no longer delete an
        # overlapping axes when another one is added, which left an empty frame
        # drawn underneath the radar chart.
        axes[1, 0].remove()
        if len(main_metrics) >= 3:
            angles = np.linspace(0, 2 * np.pi, len(main_metrics), endpoint=False).tolist()
            angles += angles[:1]  # Complete the circle

            ax_radar = fig.add_subplot(2, 2, 3, projection='polar')
            for _, row in data.iterrows():
                values = [row[metric] for metric in main_metrics]
                values += values[:1]  # Complete the circle
                ax_radar.plot(angles, values, 'o-', linewidth=2, label=row['model_name'])
                ax_radar.fill(angles, values, alpha=0.25)

            ax_radar.set_xticks(angles[:-1])
            ax_radar.set_xticklabels([metric.replace('_', ' ').title() for metric in main_metrics])
            ax_radar.set_title('Performance Radar Chart', fontweight='bold', pad=20)
            ax_radar.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

        # Accuracy comparison bar plot
        if 'accuracy' in data.columns:
            bars = axes[1, 1].bar(data['model_name'], data['accuracy'],
                                  color=plt.cm.viridis(np.linspace(0, 1, len(data))))

            axes[1, 1].set_title('Accuracy Comparison', fontweight='bold')
            axes[1, 1].set_xlabel('Models')
            axes[1, 1].set_ylabel('Accuracy (%)')
            # Rotate the existing categorical tick labels in place.
            axes[1, 1].tick_params(axis='x', labelrotation=45)

            # Add value labels on bars
            for bar in bars:
                height = bar.get_height()
                axes[1, 1].annotate(f'{height:.2f}%',
                                    xy=(bar.get_x() + bar.get_width() / 2, height),
                                    xytext=(0, 3),  # 3 points vertical offset
                                    textcoords="offset points",
                                    ha='center', va='bottom')

            axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_current_figure(save, 'model_comparison.png')
        plt.show()

    def create_summary_report(self, save=True, figsize=(12, 10)):
        """Create a text-only summary figure with the key metrics.

        Top panel: final-row train/val accuracy and loss per model. Bottom panel:
        the model with the highest accuracy plus every model's accuracy. A panel
        without data is left blank.

        Args:
            save: If True, also write ``summary_report.png``.
            figsize: Matplotlib figure size in inches.
        """
        print("Creating summary report...")

        fig, axes = plt.subplots(2, 1, figsize=figsize)
        fig.suptitle('Thermal Imaging Analysis Summary Report', fontsize=16, fontweight='bold')

        # Both panels are text only; hide the frames even when a panel has no data.
        for ax in axes:
            ax.axis('off')

        # Training summary
        if self.training_data is not None:
            summary_text = "TRAINING SUMMARY\n" + "=" * 50 + "\n"

            for model in self.training_data['model'].unique():
                # Last row per model; the exporter writes rows in epoch order.
                final = self.training_data[self.training_data['model'] == model].iloc[-1]
                summary_text += f"\n{model}:\n"
                summary_text += f"  Final Training Accuracy: {final['train_accuracy']:.4f}\n"
                summary_text += f"  Final Validation Accuracy: {final['val_accuracy']:.4f}\n"
                summary_text += f"  Final Training Loss: {final['train_loss']:.4f}\n"
                summary_text += f"  Final Validation Loss: {final['val_loss']:.4f}\n"

            axes[0].text(0.05, 0.95, summary_text, transform=axes[0].transAxes,
                         fontsize=10, verticalalignment='top', fontfamily='monospace')
            axes[0].set_title('Training Results Summary', fontweight='bold')

        # Model comparison summary
        if self.comparison_data is not None:
            comparison_text = "MODEL COMPARISON\n" + "=" * 50 + "\n"

            best_accuracy = self.comparison_data.loc[self.comparison_data['accuracy'].idxmax()]
            comparison_text += f"\nBest Overall Model: {best_accuracy['model_name']}\n"
            comparison_text += f"  Accuracy: {best_accuracy['accuracy']:.2f}%\n"

            for column, label in (('precision_macro', 'Precision (Macro)'),
                                  ('recall_macro', 'Recall (Macro)'),
                                  ('f1_macro', 'F1-Score (Macro)')):
                if column in self.comparison_data.columns:
                    comparison_text += f"  {label}: {best_accuracy[column]:.2f}%\n"

            comparison_text += f"\nAll Models Performance:\n"
            for _, row in self.comparison_data.iterrows():
                comparison_text += f"  {row['model_name']}: {row['accuracy']:.2f}% accuracy\n"

            axes[1].text(0.05, 0.95, comparison_text, transform=axes[1].transAxes,
                         fontsize=10, verticalalignment='top', fontfamily='monospace')
            axes[1].set_title('Performance Comparison Summary', fontweight='bold')

        plt.tight_layout()
        self._save_current_figure(save, 'summary_report.png')
        plt.show()

    def generate_all_plots(self):
        """Generate every plot for which data was loaded, then list the saved PNGs."""
        print(f"\n{'='*60}")
        print("GENERATING ALL VISUALIZATION PLOTS")
        print(f"{'='*60}")

        if self.training_data is not None:
            self.plot_training_curves_comparison()
            print()
            self.plot_overfitting_analysis()
            print()

        if self.comparison_data is not None:
            self.plot_model_comparison()
            print()

        self.create_summary_report()
        print()

        print(f"{'='*60}")
        print("ALL VISUALIZATIONS COMPLETED!")
        print(f"{'='*60}")
        print(f"All plots saved to: {self.output_dir.absolute()}")
        print("\nGenerated files:")
        for plot_file in self.output_dir.glob("*.png"):
            print(f"  • {plot_file.name}")

# Convenience helpers: one-call entry points around ThermalDataVisualizer.
def create_all_visualizations(csv_directory='thermal_csv_exports'):
    """Load the CSV exports in ``csv_directory`` and render every available plot.

    Returns the visualizer on success. Any error (missing directory, unreadable
    CSV, plotting failure) is printed and ``None`` is returned instead of raised.
    """
    try:
        viz = ThermalDataVisualizer(csv_directory)
        viz.generate_all_plots()
    except Exception as exc:
        print(f" Error: {exc}")
        return None
    return viz

def get_data_summary(csv_directory='thermal_csv_exports'):
    """Print which exported CSV datasets are available and return the visualizer.

    Args:
        csv_directory: Directory written by the CSV export cell.

    Returns:
        The ThermalDataVisualizer built for ``csv_directory``, or ``None`` if it
        could not be created (the error is printed, not raised).
    """
    try:
        visualizer = ThermalDataVisualizer(csv_directory)
        has_training = visualizer.training_data is not None
        has_comparison = visualizer.comparison_data is not None

        print("DATA SUMMARY")
        print("=" * 40)
        print(f"CSV Directory: {visualizer.csv_dir.absolute()}")
        # Explicit Yes/No: the previous conditional yielded '' in both branches,
        # so the availability status always printed blank.
        print(f" Training Data Available: {'Yes' if has_training else 'No'}")
        print(f"Comparison Data Available: {'Yes' if has_comparison else 'No'}")

        if has_training:
            model_names = visualizer.training_data['model'].unique()
            # str() so non-string model labels do not break the join.
            print(f" Models Found: {', '.join(str(name) for name in model_names)}")
            print(f" Training Data Shape: {visualizer.training_data.shape}")

        if has_comparison:
            print(f" Comparison Data Shape: {visualizer.comparison_data.shape}")

        return visualizer

    except Exception as e:
        print(f" Error: {e}")
        return None

print(
    "Visualization functions loaded successfully!",
    " Use create_all_visualizations() to generate all plots, or get_data_summary() to check available data",
    sep="\n",
)


## Results Visualisation

In [None]:
print("CREATING THERMAL IMAGING VISUALIZATIONS")
print("=" * 60)

# Check if CSV data exists and create visualizations
try:
    # Get data summary first
    print(" Checking available data...")
    visualizer = get_data_summary()

    if visualizer and (visualizer.training_data is not None or visualizer.comparison_data is not None):
        print("\nCreating all visualizations...")
        visualizer.generate_all_plots()

        print("\nVISUALIZATION COMPLETE!")
        print("Check the 'thermal_csv_exports/visualizations' folder for saved plots")
        print(" All plots are displayed above and saved as high-quality PNG files")

    else:
        print("\n No data available for visualization!")
        print(" Make sure you've run the CSV export cell first")

except Exception as e:
    print(f"\n Error creating visualizations: {e}")
    print(" Make sure the CSV export completed successfully")

## ROC Curve

In [None]:
# Generate comprehensive ROC analysis for all trained models
print("\nGenerating Comprehensive ROC Analysis for All Models...")
print("=" * 60)

# Plot ROC curves for all models (plot_all_model_roc_curves is defined elsewhere
# in the notebook).
final_auc_scores = plot_all_model_roc_curves()

# len() instead of plain truthiness so this also works if the scores come back as
# a NumPy array (truthiness of a multi-element array raises ValueError).
if final_auc_scores is not None and len(final_auc_scores) > 0:
    print("\nFinal AUC Summary:")
    # NOTE(review): assumes one score per model, returned in MODEL_RESULTS_GLOBAL
    # key order — confirm against plot_all_model_roc_curves.
    model_names = list(MODEL_RESULTS_GLOBAL.keys())
    if len(model_names) != len(final_auc_scores):
        # zip() would silently drop the extras and mis-pair names with scores.
        print(f"   Warning: {len(final_auc_scores)} AUC scores for {len(model_names)} models; "
              "the name/score pairing below may be wrong")

    paired_scores = list(zip(model_names, final_auc_scores))
    for model_name, auc_score in paired_scores:
        print(f"   {model_name}: {auc_score:.4f}")

    # Find best model among the paired entries (indexing model_names with an
    # argmax over all scores could run past the end of the name list).
    if paired_scores:
        best_idx = int(np.argmax([score for _, score in paired_scores]))
        best_model, best_auc = paired_scores[best_idx]
        print(f"\n Best Model: {best_model} (AUC: {best_auc:.4f})")

print("\nEnhanced ROC Analysis Complete!")
print("Check the 'thermal_analysis_results' folder for generated plots:")
print("   • enhanced_roc_curves_seaborn.png")
print("   • enhanced_roc_dashboard_seaborn.png")


In [None]:
# Next Steps: Comprehensive Model Analysis and Visualization
# Run this cell after all model training and CSV export is complete.

print("Starting Comprehensive Model Analysis...")

# Step 1: Plot individual training curves for each model
print("\nStep 1: Plotting individual training curves...")

results_dir = 'thermal_analysis_results'
# glob.glob returns files in arbitrary order; sort so the listing and the plot
# order are reproducible between runs.
csv_files = sorted(glob.glob(os.path.join(results_dir, '*_training_curves_*.csv')))

print(f"Found {len(csv_files)} training curve files:")
for csv_file in csv_files:
    filename = os.path.basename(csv_file)
    print(f"   {filename}")

# Plot each model's training curves individually. Failures are reported per file
# so one bad CSV does not stop the remaining plots.
for csv_file in csv_files:
    try:
        df = pd.read_csv(csv_file)
        # Assumes these CSVs carry a 'model_name' column — TODO confirm against
        # the writer of thermal_analysis_results.
        model_name = df['model_name'].iloc[0]
        print(f"\nPlotting training curves for: {model_name}")
        plot_training_curves_from_csv(csv_file)
    except Exception as e:
        print(f"Error plotting {csv_file}: {e}")

# Step 2: Create comparison of all models
print("\nStep 2: Creating comparison of all models...")
plot_all_models_comparison()

# Step 3: Generate training summary
print("\nStep 3: Generating training summary...")
training_summary = create_training_summary()

# Step 4: Plot comprehensive model evaluation metrics
print("\nStep 4: Creating comprehensive evaluation metrics visualization...")

# Check if we have model evaluation results
if 'model_results' in globals() and model_results:
    # (column label, key in each result dict), in output column order. Stored
    # values are multiplied by 100 for display as percentages.
    metric_columns = [
        ('Accuracy', 'accuracy'),
        ('Precision (Macro)', 'precision_macro'),
        ('Recall (Macro)', 'recall_macro'),
        ('F1-Score (Macro)', 'f1_macro'),
        ('Precision (Weighted)', 'precision_weighted'),
        ('Recall (Weighted)', 'recall_weighted'),
        ('F1-Score (Weighted)', 'f1_weighted'),
    ]
    metric_labels = [label for label, _ in metric_columns]

    evaluation_data = [
        {
            'Model': result.get('model_name', 'Unknown'),
            **{label: result.get(key, 0) * 100 for label, key in metric_columns},
        }
        for result in model_results
    ]
    eval_df = pd.DataFrame(evaluation_data)

    # Create comprehensive evaluation metrics plot
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Named model_labels (not `models`) so the torchvision `models` module
    # imported at the top of the notebook is not shadowed for later cells.
    model_labels = eval_df['Model']
    x_pos = np.arange(len(model_labels))
    width = 0.25

    # Plot 1: Accuracy vs F1-Scores, each bar annotated with its value
    for offset, column, legend_label in ((-width, 'Accuracy', 'Accuracy'),
                                         (0, 'F1-Score (Macro)', 'F1-Macro'),
                                         (width, 'F1-Score (Weighted)', 'F1-Weighted')):
        axes[0, 0].bar(x_pos + offset, eval_df[column], width, label=legend_label, alpha=0.8)
        for xi, value in zip(x_pos, eval_df[column]):
            axes[0, 0].text(xi + offset, value + 1, f"{value:.1f}",
                            ha='center', va='bottom', fontweight='bold', fontsize=9)
    axes[0, 0].set_title('Model Performance Comparison', fontweight='bold')
    axes[0, 0].set_xlabel('Models')
    axes[0, 0].set_ylabel('Performance (%)')
    axes[0, 0].set_xticks(x_pos)
    axes[0, 0].set_xticklabels(model_labels, rotation=45, ha='right')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Plots 2 and 3: Macro vs Weighted precision / recall
    for ax, metric in ((axes[0, 1], 'Precision'), (axes[1, 0], 'Recall')):
        ax.bar(x_pos - width / 2, eval_df[f'{metric} (Macro)'], width, label='Macro', alpha=0.8)
        ax.bar(x_pos + width / 2, eval_df[f'{metric} (Weighted)'], width, label='Weighted', alpha=0.8)
        ax.set_title(f'{metric} Comparison', fontweight='bold')
        ax.set_xlabel('Models')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(model_labels, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 4: Performance Heatmap
    metrics_for_heatmap = eval_df.set_index('Model')[metric_labels]
    sns.heatmap(metrics_for_heatmap.T, annot=True, fmt='.1f', cmap='YlOrRd', ax=axes[1, 1],
                cbar_kws={'label': 'Performance (%)'})
    axes[1, 1].set_title('Performance Metrics Heatmap', fontweight='bold')
    axes[1, 1].set_xlabel('Models')
    axes[1, 1].set_ylabel('Metrics')

    plt.suptitle('Comprehensive Model Evaluation Analysis', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Save evaluation metrics to CSV; create the directory first, since to_csv
    # does not create missing parent directories.
    os.makedirs(results_dir, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    eval_metrics_file = os.path.join(results_dir, f'model_evaluation_metrics_{timestamp}.csv')
    eval_df.to_csv(eval_metrics_file, index=False)
    print(f"Model evaluation metrics saved to: {eval_metrics_file}")

    # Find and display best performing models
    best_accuracy = eval_df.loc[eval_df['Accuracy'].idxmax()]
    best_f1_macro = eval_df.loc[eval_df['F1-Score (Macro)'].idxmax()]
    best_f1_weighted = eval_df.loc[eval_df['F1-Score (Weighted)'].idxmax()]

    print(f"\n BEST PERFORMING MODELS:")
    print(f"   Best Accuracy: {best_accuracy['Model']} ({best_accuracy['Accuracy']:.2f}%)")
    print(f"   Best F1-Macro: {best_f1_macro['Model']} ({best_f1_macro['F1-Score (Macro)']:.2f}%)")
    print(f"   Best F1-Weighted: {best_f1_weighted['Model']} ({best_f1_weighted['F1-Score (Weighted)']:.2f}%)")

else:
    print("No model evaluation results found. Make sure you've evaluated your models and stored results in 'model_results'.")

print(f"\nAll analysis results saved in '{results_dir}' directory")
print("Comprehensive model analysis completed!")

# Inference

In [None]:
print("=" * 80)
print(" LOADING MODELS FROM DISK")
print("=" * 80)

# The checkpoints are whole pickled nn.Module objects, so the model classes used
# at training time must already be defined in this kernel (run the model
# definition/training cells first) or unpickling fails.

# Check device availability
print(f"CUDA Available: {torch.cuda.is_available()}")

# Use the GPU when present, otherwise the CPU. A hard-coded "cuda" device makes
# every torch.load(map_location=...) below fail on CPU-only machines.
primary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fallback_device = torch.device("cpu")
print(f"Primary Device: {primary_device}")


# Model directory
MODEL_DIR = Path(MODEL_DIR)
print(f"Model directory: {MODEL_DIR}")

# Epoch encoded in the checkpoint file names.
# NOTE(review): the Settings cell sets common_epochs = 1 — confirm which
# checkpoint epoch should be loaded here.
checkpoint_epoch = 100

# Map model files to their names
model_mapping = {
    f'resnet18_epoch_{checkpoint_epoch}_full.pth': 'ResNet18',
    f'mobilenet_epoch_{checkpoint_epoch}_full.pth': 'MobileNetV2',
    f'efficientnet_epoch_{checkpoint_epoch}_full.pth': 'EfficientNet',
    f'vgg16_epoch_{checkpoint_epoch}_full.pth': 'VGG16',
    f'alexnet_epoch_{checkpoint_epoch}_full.pth': 'AlexNet',
    f'hybridmodel_epoch_{checkpoint_epoch}_full.pth': 'Enhanced-Hybrid-VGG-AlexNet',
    f'cnn_model_epoch_{checkpoint_epoch}_full.pth': 'ProposedCNN',
}

# Load all available models
loaded_models = {}
model_devices = {}  # Track which device each model is on

print(f"\nLoading models from: {MODEL_DIR}\n")

for model_file, model_name in model_mapping.items():
    model_path = MODEL_DIR / model_file
    if not model_path.exists():
        print(f"   {model_file} not found")
        continue
    try:
        print(f"  Loading {model_name}...", end=" ")
        # SECURITY: weights_only=False unpickles arbitrary objects (and can run
        # arbitrary code); only load checkpoint files from a trusted source.
        model = torch.load(model_path, map_location=primary_device, weights_only=False)
        model = model.to(primary_device)
        model.eval()
        loaded_models[model_name] = model
        model_devices[model_name] = primary_device
        print("OK")
    except Exception as e:
        # Best-effort: report and continue with the remaining checkpoints.
        print(f" Error: {e}")

print(f"\nSuccessfully loaded {len(loaded_models)} models")
print("=" * 80)

# Run the benchmark
if loaded_models:
    def measure_inference_performance(model, test_loader, model_name, device, num_runs=100,
                                      batch_sizes=(1, 8, 16, 32), num_warmup=1):
        """
        Comprehensive inference performance measurement for real-time deployment assessment.

        The first batch yielded by ``test_loader`` is moved to ``device``. For every batch
        size that fits inside it, ``num_warmup`` untimed passes and then ``num_runs`` timed
        forward passes are run under ``torch.no_grad()``. Each timed pass is bracketed by
        device synchronization, so asynchronous CUDA/MPS kernels are included in the timing.

        Args:
            model: module to benchmark. It is put in eval mode but NOT moved; it is
                assumed to already live on ``device``.
            test_loader: iterable whose first item is either an input tensor or a
                tuple/list whose first element is the input tensor.
            model_name: label copied into the result dict.
            device: ``torch.device`` or device string (e.g. "cpu", "cuda", "cuda:0", "mps").
            num_runs: timed forward passes per batch size; must be >= 1.
            batch_sizes: batch sizes to measure. Sizes < 1 or larger than the first
                loader batch are skipped.
            num_warmup: untimed forward passes run before each batch size is timed.

        Returns:
            dict with keys 'model_name', 'device', 'model_size_mb' (parameter bytes only,
            buffers excluded), 'memory_usage' (always empty; kept for compatibility) and
            'batch_performance', which maps each measured batch size to a dict of
            'avg_ms', 'std_ms', 'min_ms', 'max_ms', 'p95_ms', 'p99_ms', 'per_sample_ms',
            'throughput_samples_per_sec' and 'memory_delta_mb'.

        Raises:
            ValueError: if ``num_runs`` < 1 or ``test_loader`` yields no batches.
        """
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        device = torch.device(device)
        # Dispatch on the device *type* so indexed devices such as 'cuda:0' are
        # synchronized too (comparing str(device) == 'cuda' would miss them).
        device_type = device.type

        def _synchronize():
            """Wait for all queued work on the benchmark device (no-op on CPU)."""
            if device_type == 'cuda':
                torch.cuda.synchronize(device)
            elif device_type == 'mps':
                torch.mps.synchronize()

        results = {
            'model_name': model_name,
            'device': str(device),
            'batch_performance': {},
            'memory_usage': {}
        }

        model.eval()

        # Model memory footprint: parameters only (buffers such as BatchNorm stats excluded)
        param_size_mb = sum(p.nelement() * p.element_size() for p in model.parameters()) / (1024 * 1024)
        results['model_size_mb'] = param_size_mb

        try:
            first_batch = next(iter(test_loader))
        except StopIteration:
            raise ValueError("test_loader yielded no batches") from None
        # (inputs, labels, ...) tuples are the usual DataLoader item; labels are unused here
        sample_batch = first_batch[0] if isinstance(first_batch, (list, tuple)) else first_batch
        sample_batch = sample_batch.to(device)

        process = psutil.Process()

        for batch_size in batch_sizes:
            if batch_size < 1 or batch_size > len(sample_batch):
                continue

            X_batch = sample_batch[:batch_size]

            # Release cached allocations so the memory baseline is as clean as possible
            gc.collect()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            elif device_type == 'mps':
                torch.mps.empty_cache()
            _synchronize()

            # NOTE: RSS is host-process memory; allocations on a GPU are not reflected here.
            mem_before = process.memory_info().rss / (1024 * 1024)

            inference_times = []
            # no_grad is entered once so its enter/exit cost is not part of the timings
            with torch.no_grad():
                for _ in range(num_warmup):
                    model(X_batch)
                _synchronize()

                for _ in range(num_runs):
                    _synchronize()
                    start_time = time.perf_counter()
                    model(X_batch)
                    _synchronize()
                    inference_times.append((time.perf_counter() - start_time) * 1000)

            mem_after = process.memory_info().rss / (1024 * 1024)

            times_ms = np.asarray(inference_times)
            avg_time = np.mean(times_ms)

            results['batch_performance'][batch_size] = {
                'avg_ms': avg_time,
                'std_ms': np.std(times_ms),
                'min_ms': np.min(times_ms),
                'max_ms': np.max(times_ms),
                'p95_ms': np.percentile(times_ms, 95),
                'p99_ms': np.percentile(times_ms, 99),
                'per_sample_ms': avg_time / batch_size,
                'throughput_samples_per_sec': (batch_size * 1000) / avg_time,
                'memory_delta_mb': mem_after - mem_before
            }

        return results

    print("\n" + "=" * 80)
    print("INFERENCE PERFORMANCE BENCHMARK FOR REAL-TIME DEPLOYMENT")
    print("=" * 80)
    print("\nMeasuring inference time and memory usage across different batch sizes...")
    print("This assessment is critical for industrial deployment scenarios.\n")

    # model name -> result dict returned by measure_inference_performance
    benchmark_results = {}
    batch_sizes = [1, 8, 16, 32]
    # The batch-processing summary below reports the largest configured batch size
    largest_batch = max(batch_sizes)

    # Batch-size-1 latency thresholds in ms. Used both for the printed guidelines and
    # for classifying each model, so the two always agree.
    REALTIME_MS = 10
    NEAR_REALTIME_MS = 50
    BATCH_MS = 100

    separator = "-" * 80

    for model_name, model in loaded_models.items():
        device = model_devices[model_name]
        print(f"\n{separator}")
        print(f"Benchmarking: {model_name}")
        print(separator)

        try:
            results = measure_inference_performance(
                model=model,
                test_loader=test_loader_enhanced,
                model_name=model_name,
                device=device,
                num_runs=100,
                batch_sizes=batch_sizes
            )
        except Exception as e:
            # Best effort: a failure in one model (e.g. out of memory) is reported and
            # the remaining models are still benchmarked.
            print(f"   Error: {type(e).__name__}: {e}")
            continue

        benchmark_results[model_name] = results

        print(f"\n  Model Size: {results['model_size_mb']:.2f} MB (parameters only)")
        print(f"\n  Batch Performance:")
        print(f"  {'Batch':>6} | {'Avg (ms)':>10} | {'Std (ms)':>10} | {'P95 (ms)':>10} | {'Per Sample':>12} | {'Throughput':>15} | {'Mem Δ':>10}")
        print(f"  {'-'*6}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*12}-+-{'-'*15}-+-{'-'*10}")

        # Row widths (number + unit suffix) match the header widths: 12, 15 and 10
        for batch_size, perf in results['batch_performance'].items():
            print(f"  {batch_size:>6} | {perf['avg_ms']:>10.3f} | {perf['std_ms']:>10.3f} | "
                  f"{perf['p95_ms']:>10.3f} | {perf['per_sample_ms']:>9.3f} ms | "
                  f"{perf['throughput_samples_per_sec']:>11.1f} sps | {perf['memory_delta_mb']:>7.2f} MB")

    # Comparative Analysis
    if benchmark_results:
        print("\n" + "=" * 80)
        print("COMPARATIVE ANALYSIS FOR DEPLOYMENT")
        print("=" * 80)

        print("\nReal-Time Performance (Batch Size = 1):")
        print(f"  {'Model':.<35} {'Latency (ms)':>15} {'Throughput (sps)':>20} {'Memory':>12}")
        print(f"  {'-'*35} {'-'*15} {'-'*20} {'-'*12}")

        realtime_perf = []
        for model_name, results in benchmark_results.items():
            perf = results['batch_performance'].get(1)
            if perf is None:
                continue
            realtime_perf.append({
                'name': model_name,
                'latency': perf['avg_ms'],
                'throughput': perf['throughput_samples_per_sec'],
                'memory': results['model_size_mb']
            })
            print(f"  {model_name:.<35} {perf['avg_ms']:>12.3f} ms {perf['throughput_samples_per_sec']:>16.1f} sps "
                  f"{results['model_size_mb']:>9.2f} MB")

        if realtime_perf:
            fastest = min(realtime_perf, key=lambda x: x['latency'])
            highest_throughput = max(realtime_perf, key=lambda x: x['throughput'])
            smallest = min(realtime_perf, key=lambda x: x['memory'])

            print("\n Recommendations for Real-Time Deployment:")
            print(f"  • Lowest Latency:     {fastest['name']} ({fastest['latency']:.3f} ms)")
            print(f"  • Highest Throughput: {highest_throughput['name']} ({highest_throughput['throughput']:.1f} samples/sec)")
            print(f"  • Smallest Memory:    {smallest['name']} ({smallest['memory']:.2f} MB)")

            guideline_border = "  " + "-" * 66
            print("\n Industrial Deployment Guidelines:")
            print(guideline_border)
            print(f"   Real-time (< {REALTIME_MS}ms):      Edge devices, immediate response       ")
            print(f"   Near real-time (< {NEAR_REALTIME_MS}ms): Quality control, inline inspection     ")
            print(f"   Batch (< {BATCH_MS}ms):         Offline analysis, periodic monitoring  ")
            print(guideline_border)

            print("\n  Model Classification for Deployment:")
            for item in realtime_perf:
                latency = item['latency']
                if latency < REALTIME_MS:
                    category = "Real-time capable"
                elif latency < NEAR_REALTIME_MS:
                    category = "Near real-time"
                elif latency < BATCH_MS:
                    category = "Batch processing"
                else:
                    category = "May need optimization"
                print(f"    {item['name']:.<35} {category}")

        print(f"\n Batch Processing Performance (Batch Size = {largest_batch}):")
        if any(largest_batch in r['batch_performance'] for r in benchmark_results.values()):
            print(f"  {'Model':.<35} {'Total (ms)':>12} {'Per Sample (ms)':>17} {'Throughput (sps)':>20}")
            print(f"  {'-'*35} {'-'*12} {'-'*17} {'-'*20}")

            for model_name, results in benchmark_results.items():
                perf = results['batch_performance'].get(largest_batch)
                if perf is None:
                    continue
                print(f"  {model_name:.<35} {perf['avg_ms']:>12.3f} {perf['per_sample_ms']:>17.3f} "
                      f"{perf['throughput_samples_per_sec']:>16.1f} sps")
        else:
            print(f"  No model was benchmarked at batch size {largest_batch}.")

        print("\n" + "=" * 80)
        print("Inference Benchmarking Complete!")
        print("=" * 80)
        print(f"\nBenchmarked {len(benchmark_results)} models successfully")
        print("   Results saved in 'benchmark_results' variable for further analysis")
    else:
        print("\n No models successfully benchmarked.")
else:
    print("\n No models could be loaded from disk")
