In [1]:
# Install required packages
# NOTE: lines starting with "!" are IPython shell escapes; they only work when
# this cell is executed inside a notebook (e.g. Colab / Jupyter), not as a
# plain Python script.
!pip install torch torchvision transformers timm pandas matplotlib seaborn tqdm fpdf grad-cam scikit-plot
# System-level Python build packages installed through apt.
!apt-get install -y python3-dev python3-setuptools python3-wheel

import os
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize
from itertools import cycle
import warnings
warnings.filterwarnings('ignore')

# Import for Grad-CAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Set device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Extract the dataset
# unzip flags: -o overwrite existing files without prompting, -q quiet,
# -d extract into the given directory.
!unzip -o -q mango_leaf_dataset.zip -d mango_leaf_dataset

# Define the dataset class
class MangoLeafDataset(Dataset):
    """Folder-per-class image dataset for the eight mango-leaf categories.

    ``root_dir`` is expected to contain one sub-directory per entry of
    ``self.classes``; sub-directories that do not exist are skipped.

    Each item is a ``(image, label, img_path)`` triple, where ``image`` is the
    RGB image after the optional ``transform``, ``label`` is the integer class
    index and ``img_path`` is the file path of the image.
    """

    # Accepted file extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        """Index every image file found under the known class folders.

        Args:
            root_dir: Directory holding the per-class sub-directories.
            transform: Optional callable applied to each PIL image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['Anthracnose', 'Bacterial Canker', 'Cutting Weevil',
                       'Die Back', 'Gall Midge', 'Healthy', 'Powdery Mildew', 'Sooty Mould']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.images = []
        self.labels = []

        # Load images
        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            # os.listdir() order is filesystem dependent; sorting keeps sample
            # indices (and therefore any seeded split) reproducible.
            for img_name in sorted(os.listdir(cls_dir)):
                # Lower-case the name so files such as "IMG_01.JPG" are kept.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(cls_dir, img_name))
                    self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for sample ``idx``."""
        img_path = self.images[idx]
        # The context manager closes the underlying file handle promptly.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, img_path  # img_path is returned for Grad-CAM

# Hyperparameters from the paper
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
IMAGE_SIZE = 224
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DROPOUT = 0.5
SCHEDULER_STEP_SIZE = 7
SCHEDULER_GAMMA = 0.1

# Data transformations
# Training: resize + light augmentation (random flip / rotation) + normalization.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# Evaluation: deterministic resize + normalization only.
test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# Load dataset
data_dir = "mango_leaf_dataset/DataSet"
dataset = MangoLeafDataset(data_dir, transform=train_transform)

# Second view of the same files that applies the evaluation transform.
# The file and label lists are shared with `dataset` so that an index refers
# to the same image in both views.
eval_dataset = MangoLeafDataset(data_dir, transform=test_transform)
eval_dataset.images = dataset.images
eval_dataset.labels = dataset.labels

# Split dataset (80% train, 10% validation, 10% test)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Apply test transform to validation and test sets.
# All three subsets returned by random_split wrap the SAME dataset object, so
# assigning `subset.dataset.transform` would also switch off augmentation for
# the training subset. Re-point the validation/test indices at the evaluation
# view instead, leaving the training subset on the augmenting transform.
val_dataset = torch.utils.data.Subset(eval_dataset, val_dataset.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_dataset.indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Total images: {len(dataset)}")
print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

# Function to visualize samples
def visualize_samples(dataset, num_samples=5):
    """Plot a grid of random example images, one row per class.

    The figure is saved to ``sample_images.png``, shown, and then closed.

    Args:
        dataset: Object exposing ``classes`` (class names), ``labels``
            (integer label per sample) and ``images`` (file path per sample).
        num_samples: Number of columns, i.e. images drawn per class. Classes
            with fewer images simply leave the remaining cells empty.
    """
    # squeeze=False keeps `axes` two-dimensional even for a single column/row.
    fig, axes = plt.subplots(len(dataset.classes), num_samples,
                             figsize=(15, 12), squeeze=False)

    for cls_idx, cls_name in enumerate(dataset.classes):
        # Get samples for this class
        class_indices = [i for i, label in enumerate(dataset.labels) if label == cls_idx]
        # Sampling without replacement cannot draw more items than exist.
        n_pick = min(num_samples, len(class_indices))
        if n_pick:
            sample_indices = np.random.choice(class_indices, n_pick, replace=False)
        else:
            sample_indices = []

        for col in range(num_samples):
            ax = axes[cls_idx, col]
            # Hide ticks and frame instead of calling axis('off'): axis('off')
            # would also hide the y-label that carries the class name.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if col < n_pick:
                with Image.open(dataset.images[sample_indices[col]]) as img:
                    ax.imshow(img.convert('RGB'))
            if col == 0:
                ax.set_ylabel(cls_name, fontsize=12)

    plt.tight_layout()
    plt.savefig('sample_images.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

# Function to plot class distribution
def plot_class_distribution(dataset):
    """Draw a bar chart of images per class and return the counts.

    The chart is written to ``class_distribution.png``, displayed and closed.

    Args:
        dataset: Object exposing ``classes`` (names) and ``labels`` (one
            integer class index per sample).

    Returns:
        List with the number of samples of each class, in ``classes`` order.
    """
    num_classes = len(dataset.classes)
    class_counts = [
        sum(1 for label in dataset.labels if label == class_id)
        for class_id in range(num_classes)
    ]

    plt.figure(figsize=(10, 6))
    plt.bar(dataset.classes, class_counts)
    plt.title('Class Distribution in Dataset')
    plt.xlabel('Class')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45)
    plt.savefig('class_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return class_counts

# Modified model training and evaluation functions to handle Hugging Face models
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Run one pass over ``loader``; train when an optimizer is given.

    Returns ``(mean batch loss, accuracy in percent)``.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    loss_sum = 0.0
    correct = 0
    total = 0

    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(training):
        for images, labels, _ in batches:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            # Hugging Face models return an output object carrying `.logits`;
            # plain modules return the logits tensor itself. Detect this from
            # the output rather than from the display name of the model.
            outputs = model(images)
            logits = getattr(outputs, 'logits', outputs)

            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return loss_sum / max(len(loader), 1), 100 * correct / max(total, 1)


def train_model(model, model_name, train_loader, val_loader, num_epochs=NUM_EPOCHS):
    """Train ``model`` and keep the checkpoint with the best validation accuracy.

    Uses cross-entropy loss, Adam with ``LEARNING_RATE`` and a ``StepLR``
    schedule. The best weights are written to ``{model_name}_best.pth`` and
    the loss/accuracy curves to ``{model_name}_training_history.png``.

    Args:
        model: Network to train; it is moved to the global ``device``.
        model_name: Label used for progress output and output file names.
        train_loader: Loader yielding ``(images, labels, paths)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Number of epochs to run.

    Returns:
        Tuple ``(best_acc, history)`` where ``best_acc`` is the best
        validation accuracy in percent and ``history`` is a dict with the
        per-epoch lists ``train_losses``, ``val_losses``, ``train_accs`` and
        ``val_accs``.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_acc = 0.0
    checkpoint_saved = False

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, optimizer=optimizer,
            desc=f"{model_name} Epoch {epoch+1}/{num_epochs}")
        scheduler.step()
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Save best model. The first epoch is always saved so that a
        # checkpoint exists even if validation accuracy never rises above 0.
        if val_acc > best_acc or not checkpoint_saved:
            best_acc = val_acc
            torch.save(model.state_dict(), f'{model_name}_best.pth')
            checkpoint_saved = True

    # Plot training history
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.title('Accuracy')
    plt.legend()
    plt.suptitle(f'{model_name} Training History')
    plt.savefig(f'{model_name}_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return best_acc, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }

def evaluate_model(model, model_name, test_loader):
    """Evaluate the best checkpoint of ``model`` on ``test_loader``.

    Loads ``{model_name}_best.pth``, collects predictions, and writes the
    confusion matrix, ROC and precision-recall plots to
    ``{model_name}_confusion_matrix.png``, ``{model_name}_roc_curve.png`` and
    ``{model_name}_pr_curve.png``. Class names come from the module-level
    ``dataset``.

    Args:
        model: Network whose architecture matches the saved checkpoint.
        model_name: Label used for the checkpoint and output file names.
        test_loader: Loader yielding ``(images, labels, paths)`` batches.

    Returns:
        Tuple ``(accuracy, report, all_probs, all_image_paths, all_preds,
        all_labels, roc_auc, average_precision)``: accuracy in percent, the
        classification report as a dict, per-sample softmax probabilities,
        image paths, predicted and true labels, per-class (plus ``"micro"``)
        ROC AUC values, and per-class area under the precision-recall curve.
    """
    # map_location lets a checkpoint saved from GPU tensors load on CPU too.
    model.load_state_dict(torch.load(f'{model_name}_best.pth', map_location=device))
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []
    all_probs = []
    all_image_paths = []

    with torch.no_grad():
        for images, labels, img_paths in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Hugging Face models return an object with `.logits`; plain
            # modules return the logits tensor directly.
            outputs = model(images)
            logits = getattr(outputs, 'logits', outputs)

            probs = torch.nn.functional.softmax(logits, dim=1)
            predicted = logits.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_image_paths.extend(img_paths)

    class_names = dataset.classes
    # Fix the label set explicitly so the matrix/report always have one entry
    # per class, even if a class is missing from this particular split.
    class_ids = list(range(len(class_names)))
    prob_matrix = np.array(all_probs)  # converted once, reused below

    # Calculate accuracy
    accuracy = 100 * np.sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{model_name} Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(f'{model_name}_confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    # Classification report
    report = classification_report(all_labels, all_preds, labels=class_ids,
                                   target_names=class_names, output_dict=True)

    # ROC Curve
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    # Binarize the output for ROC curve (one-vs-rest indicator columns)
    y_test_bin = label_binarize(all_labels, classes=class_ids)
    n_classes = y_test_bin.shape[1]

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], prob_matrix[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), prob_matrix.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    palette = ['aqua', 'darkorange', 'cornflowerblue', 'green', 'red', 'purple', 'brown', 'pink']

    # Plot ROC curve
    plt.figure(figsize=(10, 8))
    for i, color in zip(range(n_classes), cycle(palette)):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                 ''.format(class_names[i], roc_auc[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'{model_name} ROC Curve')
    plt.legend(loc="lower right")
    plt.savefig(f'{model_name}_roc_curve.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    # Precision-Recall curve
    precision = dict()
    recall = dict()
    average_precision = dict()

    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_test_bin[:, i], prob_matrix[:, i])
        # Area under the PR curve computed with auc() on (recall, precision).
        average_precision[i] = auc(recall[i], precision[i])

    # Plot Precision-Recall curve
    plt.figure(figsize=(10, 8))
    for i, color in zip(range(n_classes), cycle(palette)):
        plt.plot(recall[i], precision[i], color=color, lw=2,
                 label='Precision-Recall curve of class {0} (area = {1:0.2f})'
                 ''.format(class_names[i], average_precision[i]))

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'{model_name} Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.savefig(f'{model_name}_pr_curve.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return accuracy, report, all_probs, all_image_paths, all_preds, all_labels, roc_auc, average_precision

# Model definitions
def create_alexnet(num_classes=8):
    """Return a pretrained AlexNet whose last classifier layer is replaced.

    Args:
        num_classes: Number of output classes of the new final layer.
    """
    model = models.alexnet(pretrained=True)
    # Read the input width from the layer being replaced instead of
    # hard-coding it.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)
    return model

def create_resnet(num_classes=8):
    """Return a pretrained ResNet-50 with a new fully connected head.

    Args:
        num_classes: Number of output classes of the new ``fc`` layer.
    """
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def create_vgg16(num_classes=8):
    """Return a pretrained VGG-16 whose last classifier layer is replaced.

    Args:
        num_classes: Number of output classes of the new final layer.
    """
    model = models.vgg16(pretrained=True)
    # Read the input width from the layer being replaced instead of
    # hard-coding it.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)
    return model

def create_vit(num_classes=8):
    """Return a ViT-Base/16 classifier initialized from pretrained weights.

    Building the model from a config alone gives random weights, unlike the
    CNN baselines in this file which all start from pretrained weights. Load
    the checkpoint instead and let the classification head be re-created for
    ``num_classes``.

    Args:
        num_classes: Number of output labels of the classification head.
    """
    from transformers import ViTForImageClassification

    return ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224',
        num_labels=num_classes,
        # The checkpoint head has a different number of labels; allow it to
        # be replaced by a freshly initialized head of the requested size.
        ignore_mismatched_sizes=True,
    )

def create_swin(num_classes=8):
    """Return a Swin-Tiny classifier initialized from pretrained weights.

    Building the model from a config alone gives random weights, unlike the
    CNN baselines in this file which all start from pretrained weights. Load
    the checkpoint instead and let the classification head be re-created for
    ``num_classes``.

    Args:
        num_classes: Number of output labels of the classification head.
    """
    from transformers import SwinForImageClassification

    return SwinForImageClassification.from_pretrained(
        'microsoft/swin-tiny-patch4-window7-224',
        num_labels=num_classes,
        # The checkpoint head has a different number of labels; allow it to
        # be replaced by a freshly initialized head of the requested size.
        ignore_mismatched_sizes=True,
    )

# Fusion models
class FusionModel(nn.Module):
    """Concatenate the features of two backbones and classify them jointly.

    The final layer of each backbone is removed (last child of ``classifier``
    or the ``fc`` layer), the two feature vectors are concatenated, and a
    small MLP head produces the class logits.

    Args:
        model1: First backbone; modified in place.
        model2: Second backbone; modified in place.
        num_classes: Number of output classes of the fused head.
    """

    def __init__(self, model1, model2, num_classes=8):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

        # Remove the final classification layer from both models
        self._strip_head(model1)
        self._strip_head(model2)

        # Determine the output dimensions
        self.feature_dim = self._probe_dim(model1) + self._probe_dim(model2)

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _strip_head(backbone):
        """Drop the last classification layer of ``backbone`` in place."""
        if hasattr(backbone, 'classifier'):
            backbone.classifier = nn.Sequential(*list(backbone.classifier.children())[:-1])
        elif hasattr(backbone, 'fc'):
            backbone.fc = nn.Identity()

    @staticmethod
    def _as_tensor(output):
        """Return ``output.logits`` when present, otherwise ``output``.

        Hugging Face models return an output object exposing ``.logits``; a
        plain tuple has no such attribute, so an ``isinstance(..., tuple)``
        test cannot be used to detect them.
        """
        return getattr(output, 'logits', output)

    @classmethod
    def _probe_dim(cls, backbone):
        """Return the feature width of ``backbone`` via one dummy forward pass.

        The probe runs in eval mode so that BatchNorm running statistics are
        not updated with the dummy input and Dropout stays inactive; the
        previous training flag is restored afterwards.
        """
        first_param = next(backbone.parameters(), None)
        probe_device = first_param.device if first_param is not None else 'cpu'
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=probe_device)
                features = cls._as_tensor(backbone(dummy_input))
        finally:
            backbone.train(was_training)
        return features.shape[1]

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features1 = self._as_tensor(self.model1(x))
        features2 = self._as_tensor(self.model2(x))

        combined = torch.cat((features1, features2), dim=1)
        return self.classifier(combined)

def create_fusion_resnet_vgg16():
    """Build a FusionModel that combines ResNet and VGG-16 backbones."""
    # Arguments are evaluated left to right: ResNet is created first.
    return FusionModel(create_resnet(), create_vgg16())

def create_fusion_resnet_alexnet():
    """Build a FusionModel that combines ResNet and AlexNet backbones."""
    # Arguments are evaluated left to right: ResNet is created first.
    return FusionModel(create_resnet(), create_alexnet())

def create_fusion_resnet_alexnet_gradcam():
    """Build the ResNet + AlexNet FusionModel used for the Grad-CAM run.

    The architecture is constructed exactly like the plain ResNet + AlexNet
    fusion; only the name under which it is trained differs.
    """
    backbones = (create_resnet(), create_alexnet())
    return FusionModel(*backbones)

# Function to apply Grad-CAM
def _find_cam_target_layers(model):
    """Return the Grad-CAM target layer list for ``model`` (empty if none).

    The model itself is inspected first, then its ``model1`` / ``model2``
    sub-modules, so that a fusion wrapper resolves to a layer inside one of
    its backbones.
    """
    candidates = [model]
    candidates += [getattr(model, name) for name in ('model1', 'model2') if hasattr(model, name)]

    for module in candidates:
        if hasattr(module, 'features'):
            return [module.features[-1]]
        if hasattr(module, 'layer4'):
            return [module.layer4[-1]]
        if hasattr(module, 'encoder'):
            # NOTE(review): attribute path kept from the original code;
            # confirm it matches the transformer implementation in use.
            return [module.encoder.layers[-1].norm1]
    return []


def apply_grad_cam(model, model_name, test_loader, results):
    """Save a Grad-CAM overlay grid for the first test batch.

    Loads ``{model_name}_best.pth``, computes Grad-CAM heatmaps for up to
    eight images of the first batch of ``test_loader`` and writes the 2x4
    grid to ``{model_name}_grad_cam.png``. Titles use the class names of the
    module-level ``dataset``.

    Args:
        model: Network matching the saved checkpoint.
        model_name: Label used for the checkpoint and output file names.
        test_loader: Loader yielding ``(images, labels, paths)`` batches.
        results: Accepted for interface compatibility; not used here.
    """
    max_examples = 8  # size of the 2x4 grid

    model.load_state_dict(torch.load(f'{model_name}_best.pth'))
    model.eval()

    # Create Grad-CAM object
    target_layers = _find_cam_target_layers(model)
    if not target_layers:
        print(f"Could not find target layers for Grad-CAM in {model_name}")
        return

    cam = GradCAM(model=model, target_layers=target_layers)

    # Get a batch of images; only the displayed ones need a heatmap.
    images, labels, img_paths = next(iter(test_loader))
    images = images[:max_examples].to(device)

    # Generate CAM for each image
    grayscale_cams = cam(input_tensor=images)

    # Visualize and save CAM results
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    for ax in axes.flat:
        # Turn every cell off up front so unused cells stay blank when the
        # batch holds fewer than eight images.
        ax.axis('off')

    for i, (image, grayscale_cam) in enumerate(zip(images, grayscale_cams)):
        rgb_img = image.cpu().permute(1, 2, 0).numpy()
        # Rescale the normalized tensor to [0, 1] for display; guard against
        # a constant image, which would otherwise divide by zero.
        value_range = rgb_img.max() - rgb_img.min()
        rgb_img = (rgb_img - rgb_img.min()) / (value_range if value_range > 0 else 1.0)

        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

        row, col = i // 4, i % 4
        axes[row, col].imshow(visualization)
        axes[row, col].set_title(f'True: {dataset.classes[int(labels[i])]}')

    plt.tight_layout()
    plt.savefig(f'{model_name}_grad_cam.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    print(f"Grad-CAM visualization saved as {model_name}_grad_cam.png")

# Function to create a comprehensive PDF report at the end
def create_comprehensive_pdf(results, paper_results, dataset, train_dataset, val_dataset, test_dataset):
    """Assemble all metrics and saved plots into ``Comprehensive_Report.pdf``.

    Plot images are included only when the corresponding PNG file exists in
    the working directory.

    Args:
        results: Dict keyed by model name. Each value provides
            ``"Test Accuracy"``, ``"Classification Report"`` (dict form),
            ``"ROC AUC"`` (with a ``"micro"`` entry) and
            ``"Average Precision"`` (dict of per-class values).
        paper_results: Dict mapping each model name to the reference
            accuracy in percent; must contain every key of ``results``.
        dataset: Full dataset; ``len()`` and ``classes`` are used.
        train_dataset: Training split (only its length is reported).
        val_dataset: Validation split (only its length is reported).
        test_dataset: Test split (only its length is reported).

    Raises:
        KeyError: If a model in ``results`` is missing from ``paper_results``.
    """
    from fpdf import FPDF

    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)

    def heading(text, size, align=""):
        """Write a bold full-width line and move to the next line."""
        pdf.set_font("Arial", 'B', size)
        pdf.cell(0, 10, text, 0, 1, align)

    def table_row(widths, values):
        """Write one bordered table row, then break the line."""
        for width, value in zip(widths, values):
            pdf.cell(width, 10, value, 1)
        pdf.ln()

    def figure_page(path, caption, trailing):
        """Add a page holding ``caption`` and the image at ``path`` if it exists."""
        if not os.path.exists(path):
            return
        pdf.add_page()
        heading(caption, 12)
        pdf.image(path, x=10, y=20, w=180)
        pdf.ln(trailing)

    # Title Page
    pdf.add_page()
    heading("Mango Leaf Disease Classification", 18, "C")
    pdf.ln(10)

    pdf.set_font("Arial", '', 12)
    pdf.multi_cell(0, 10, "This report consolidates the training, evaluation, and comparison "
                          "of different models (AlexNet, ResNet, VGG-16, ViT, Swin Transformer) "
                          "for Mango Leaf Disease Classification.")
    pdf.ln(10)

    # Dataset Information
    heading("Dataset Information", 14)
    pdf.set_font("Arial", '', 12)
    for line in (
        f"Total images: {len(dataset)}",
        f"Training set: {len(train_dataset)}",
        f"Validation set: {len(val_dataset)}",
        f"Test set: {len(test_dataset)}",
    ):
        pdf.cell(0, 10, line, 0, 1)
    pdf.ln(10)

    # Add sample images and class distribution
    figure_page('sample_images.png', "Sample Images from Dataset", 100)
    figure_page('class_distribution.png', "Class Distribution", 100)

    # Model Comparison Table
    pdf.add_page()
    heading("Model Performance Comparison", 14, "C")
    comparison_widths = (60, 40, 40, 40)
    pdf.set_font("Arial", 'B', 12)
    table_row(comparison_widths,
              ("Model", "Our Accuracy (%)", "Paper Accuracy (%)", "Difference (%)"))

    pdf.set_font("Arial", '', 12)
    for model_name in results:
        our_acc = results[model_name]["Test Accuracy"]
        paper_acc = paper_results[model_name]
        diff = our_acc - paper_acc
        table_row(comparison_widths,
                  (model_name, f"{our_acc:.2f}", f"{paper_acc:.2f}", f"{diff:+.2f}"))
    pdf.ln(10)

    # Detailed comparison table with more metrics
    pdf.add_page()
    heading("Detailed Model Comparison", 14, "C")
    detail_widths = (40, 25, 25, 25, 25, 25, 25)
    pdf.set_font("Arial", 'B', 10)
    table_row(detail_widths,
              ("Model", "Accuracy", "Precision", "Recall", "F1-Score", "ROC AUC", "Avg Precision"))

    pdf.set_font("Arial", '', 10)
    for model_name in results:
        entry = results[model_name]
        weighted = entry["Classification Report"]['weighted avg']
        roc_auc = entry["ROC AUC"]["micro"]
        avg_precision = np.mean(list(entry["Average Precision"].values()))
        table_row(detail_widths, (
            model_name,
            f"{entry['Test Accuracy']:.2f}",
            f"{weighted['precision']:.2f}",
            f"{weighted['recall']:.2f}",
            f"{weighted['f1-score']:.2f}",
            f"{roc_auc:.2f}",
            f"{avg_precision:.2f}",
        ))
    pdf.ln(10)

    # Add results for each model
    report_widths = (40, 30, 30, 30, 30)
    for model_name in results:
        pdf.add_page()
        heading(f"{model_name} - Results", 16, "C")
        pdf.ln(5)

        # Accuracy Info
        acc = results[model_name]["Test Accuracy"]
        paper_acc = paper_results[model_name]
        pdf.set_font("Arial", '', 12)
        pdf.cell(0, 10, f"Test Accuracy: {acc:.2f}%", 0, 1)
        pdf.cell(0, 10, f"Paper Accuracy: {paper_acc:.2f}%", 0, 1)
        pdf.cell(0, 10, f"Difference: {acc - paper_acc:+.2f}%", 0, 1)
        pdf.ln(5)

        # Training history plot (shares the page with the accuracy text)
        history_png = f"{model_name}_training_history.png"
        if os.path.exists(history_png):
            heading("Training History", 12)
            # Place the image at the current cursor position. A fixed y
            # offset near the top of the page would draw the plot over the
            # accuracy lines written just above.
            pdf.image(history_png, x=10, y=pdf.get_y(), w=180)
            pdf.ln(85)

        # One page per saved evaluation plot
        figure_page(f"{model_name}_confusion_matrix.png", "Confusion Matrix", 85)
        figure_page(f"{model_name}_roc_curve.png", "ROC Curve", 85)
        figure_page(f"{model_name}_pr_curve.png", "Precision-Recall Curve", 85)

        # Grad-CAM visualization if available
        if 'Grad-CAM' in model_name:
            figure_page(f"{model_name}_grad_cam.png", "Grad-CAM Visualization", 85)

        # Classification Report
        pdf.add_page()
        heading("Classification Report", 12)

        report = results[model_name]["Classification Report"]

        # Table headers
        pdf.set_font("Arial", 'B', 10)
        table_row(report_widths, ("Class", "Precision", "Recall", "F1-Score", "Support"))

        # Per-class rows followed by the macro and weighted averages
        pdf.set_font("Arial", '', 10)
        rows = [(cls, cls) for cls in dataset.classes]
        rows += [("Macro Avg", 'macro avg'), ("Weighted Avg", 'weighted avg')]
        for label, key in rows:
            stats = report[key]
            table_row(report_widths, (
                label,
                f"{stats['precision']:.2f}",
                f"{stats['recall']:.2f}",
                f"{stats['f1-score']:.2f}",
                str(stats['support']),
            ))

    # Save final PDF
    pdf.output("Comprehensive_Report.pdf")
    print("✅ Comprehensive report saved as Comprehensive_Report.pdf")

# Visualize the dataset
print("Sample images from each class:")
visualize_samples(dataset)

print("Class distribution:")
class_counts = plot_class_distribution(dataset)
for cls, count in zip(dataset.classes, class_counts):
    print(f"{cls}: {count} images")

# Main execution
# Dictionary to store results
results = {}

# Paper results for comparison
paper_results = {
    'AlexNet': 69.08,
    'ResNet': 91.33,
    'VGG-16': 84.92,
    'ViT': 98.50,
    'Swin Transformer': 96.55,
    'Fusion ResNet and VGG-16': 97.17,
    'Fusion ResNet and AlexNet': 97.65,
    'Fusion ResNet and AlexNet with Grad-CAM': 99.97
}

# List of models to train
models_to_train = {
    'AlexNet': create_alexnet(),
    'ResNet': create_resnet(),
    'VGG-16': create_vgg16(),
    'ViT': create_vit(),
    'Swin Transformer': create_swin(),
    'Fusion ResNet and VGG-16': create_fusion_resnet_vgg16(),
    'Fusion ResNet and AlexNet': create_fusion_resnet_alexnet(),
    'Fusion ResNet and AlexNet with Grad-CAM': create_fusion_resnet_alexnet_gradcam()
}

# Train and evaluate each model
for model_name, model in models_to_train.items():
    print(f"\n=== Training {model_name} ===")
    best_val_acc, train_history = train_model(model, model_name, train_loader, val_loader, num_epochs=NUM_EPOCHS)

    print(f"\n=== Evaluating {model_name} ===")
    test_acc, report, probs, img_paths, preds, labels, roc_auc, avg_precision = evaluate_model(model, model_name, test_loader)

    results[model_name] = {
        'Validation Accuracy': best_val_acc,
        'Test Accuracy': test_acc,
        'Train History': train_history,
        'Classification Report': report,
        'Probabilities': probs,
        'Image Paths': img_paths,
        'Predictions': preds,
        'Labels': labels,
        'ROC AUC': roc_auc,
        'Average Precision': avg_precision
    }

    print(f"{model_name} - Validation Accuracy: {best_val_acc:.2f}%, Test Accuracy: {test_acc:.2f}%")

    # Apply Grad-CAM for the fusion model with Grad-CAM
    if 'Grad-CAM' in model_name:
        apply_grad_cam(model, model_name, test_loader, results)

# Create comprehensive PDF report at the end
create_comprehensive_pdf(results, paper_results, dataset, train_dataset, val_dataset, test_dataset)

# Create a results directory and move all reports
!mkdir -p model_reports
!mv *.pdf model_reports/
!mv *.png model_reports/
!mv *.pth model_reports/

print("\nAll reports have been saved in the 'model_reports' directory")

# Print final comparison
print("\n=== Final Comparison with Paper Results ===")
print("Model\t\tOur Result\tPaper Result\tDifference")
print("-" * 55)
for model_name in results:
    our_acc = results[model_name]['Test Accuracy']
    paper_acc = paper_results[model_name]
    diff = our_acc - paper_acc
    print(f"{model_name:25}\t{our_acc:.2f}%\t\t{paper_acc:.2f}%\t\t{diff:+.2f}%")

Output hidden; open in https://colab.research.google.com to view.