In [1]:
# Required imports
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms, models
from torchvision.models.mobilenetv2 import InvertedResidual
from torchvision.models import MobileNet_V2_Weights
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')  # non-GUI backend

# Import BAM
# bam.py lives outside this notebook's directory, so its folder is added to the import path first.
sys.path.append('../models')
from bam import BAM  # Make sure bam.py exists and contains the correct class

# Setup device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transforms
# The preprocessing preset bundled with the pretrained weights is appended to both pipelines.
# NOTE(review): weights.transforms() is torchvision's inference preset and presumably performs
# its own resize/crop/normalisation, which would make the explicit Resize redundant — confirm.
weights = MobileNet_V2_Weights.DEFAULT
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # augmentation is applied to the training split only
    transforms.ToTensor(),
    weights.transforms()
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    weights.transforms()
])

# Load dataset
# ImageFolder derives one class per sub-directory of the given root.
full_train_dataset = datasets.ImageFolder('../../dataset_split/train', transform=train_transforms)
full_val_dataset = datasets.ImageFolder('../../dataset_split/val', transform=val_transforms)

# Select classes
# Only these two classes are kept; their position in this list becomes the new label (0, 1).
selected_classes = ['Tomato___Late_blight', 'Potato___Early_blight']
# Original ImageFolder indices of the selected classes, looked up in the training split.
# NOTE(review): the same indices are later used to filter the validation split, which assumes
# both roots contain the same class folders — confirm.
selected_class_indices = [full_train_dataset.class_to_idx[cls] for cls in selected_classes]
# Maps each original class index to its compact new index.
index_mapping = {orig_idx: new_idx for new_idx, orig_idx in enumerate(selected_class_indices)}

class ReindexedSubset(Dataset):
    """Dataset wrapper that rewrites each sample's label through a lookup table.

    ``subset`` is any indexable collection of ``(sample, label)`` pairs and
    ``class_indices_mapping`` maps every original label to its replacement.
    """

    def __init__(self, subset, class_indices_mapping):
        self.mapping = class_indices_mapping
        self.subset = subset

    def __len__(self):
        # Relabelling never adds or removes samples.
        return len(self.subset)

    def __getitem__(self, index):
        sample, original_label = self.subset[index]
        # A label missing from the mapping raises KeyError.
        return sample, self.mapping[original_label]

def filter_and_remap(dataset, class_indices, mapping):
    """Keep only the samples whose label is in ``class_indices`` and relabel them.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        class_indices: Container of original labels to keep.
        mapping: Dict from each kept original label to its new label.

    Returns:
        A ``ReindexedSubset`` over the matching samples, in their original order.
    """
    # Datasets such as torchvision's ImageFolder keep all labels in a ``targets``
    # sequence. Reading it avoids loading and transforming every image just to
    # inspect its label. The shortcut is skipped when a target_transform is set
    # (``targets`` would then differ from what __getitem__ returns) or when the
    # sequence length does not match the dataset.
    targets = getattr(dataset, "targets", None)
    can_use_targets = (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    )
    if can_use_targets:
        labels = targets
    else:
        labels = (label for _, label in dataset)

    indices = [i for i, label in enumerate(labels) if label in class_indices]
    return ReindexedSubset(Subset(dataset, indices), mapping)

# Restrict both splits to the selected classes (train first, then val) and
# relabel them with the compact indices.
train_dataset, val_dataset = (
    filter_and_remap(split, selected_class_indices, index_mapping)
    for split in (full_train_dataset, full_val_dataset)
)

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)

# Define model
class MobileNetV2_BAM(nn.Module):
    """MobileNetV2 feature stack with a BAM module after each InvertedResidual.

    The layers of ``base_model.features`` are reused under their original
    names. After every ``InvertedResidual`` an extra ``bam_<name>`` module is
    inserted, and a fresh dropout + linear head maps the pooled features to
    ``num_classes`` outputs.
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        self.features = nn.Sequential()
        for layer_name, layer in base_model.features._modules.items():
            self.features.add_module(layer_name, layer)
            if not isinstance(layer, InvertedResidual):
                continue
            # The BAM width follows the block's last direct nn.Conv2d child;
            # if there is none, fall back to the backbone's final channel count.
            direct_convs = [sub for sub in layer.conv if isinstance(sub, nn.Conv2d)]
            if direct_convs:
                bam_channels = direct_convs[-1].out_channels
            else:
                bam_channels = base_model.last_channel
            self.features.add_module(f"bam_{layer_name}", BAM(bam_channels))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(base_model.last_channel, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

# Build the pretrained MobileNetV2 backbone and wrap it with BAM modules,
# with one output per selected class, then move it to the chosen device.
base_model = models.mobilenet_v2(weights=weights)
model = MobileNetV2_BAM(base_model, len(selected_classes)).to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
# All parameters (backbone, BAM modules and classifier head) are optimised.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# StepLR multiplies the learning rate by gamma every step_size scheduler steps.
# NOTE(review): the training loop steps the scheduler once per epoch, so the rate
# drops from 1e-3 to about 1e-7 after 8 epochs and keeps shrinking for the rest of
# the 100-epoch run — confirm this aggressive decay is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

# Train
NUM_EPOCHS = 100
for epoch in range(NUM_EPOCHS):
    model.train()
    # Per-epoch accumulators: summed batch loss, correct predictions, samples seen.
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The predicted class is the index of the largest logit.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    acc = correct / total
    # The reported loss is the sum of the batch losses over the epoch, not a mean.
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss:.4f}, Accuracy: {acc:.4f}")

# Save model
# Only the parameters (state_dict) are stored, written to the current working directory.
torch.save(model.state_dict(), "mobilenet_v2_bam.pth")

# Evaluate
model.eval()
# Ground-truth and predicted labels collected over the whole validation loader.
all_labels, all_preds = [], []

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit.
        _, predicted = torch.max(outputs, 1)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Reports
# Per-class precision/recall/F1 table, labelled with the selected class names.
print(classification_report(all_labels, all_preds, target_names=selected_classes))
cm = confusion_matrix(all_labels, all_preds)

# Draw the confusion matrix on its own figure and close it once saved, so the
# heatmap cannot leak into plots created later in the notebook.
plt.figure()
# fmt="d" prints the raw integer counts; the default general float format
# would abbreviate larger counts.
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=selected_classes, yticklabels=selected_classes)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.savefig("confMatrix_bam.png")
plt.close()

# Support-weighted summary metrics over both classes.
f1 = f1_score(all_labels, all_preds, average='weighted')
precision = precision_score(all_labels, all_preds, average='weighted')
recall = recall_score(all_labels, all_preds, average='weighted')
print(f"F1 Score: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f}")


Epoch 1/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [1/100], Loss: 13.2449, Accuracy: 0.9396


Epoch 2/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [2/100], Loss: 4.7356, Accuracy: 0.9906


Epoch 3/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [3/100], Loss: 4.9882, Accuracy: 0.9920


Epoch 4/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [4/100], Loss: 1.5615, Accuracy: 0.9948


Epoch 5/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [5/100], Loss: 1.6530, Accuracy: 0.9976


Epoch 6/100: 100%|██████████| 91/91 [01:58<00:00,  1.31s/it]


Epoch [6/100], Loss: 1.8081, Accuracy: 0.9962


Epoch 7/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [7/100], Loss: 0.8689, Accuracy: 0.9983


Epoch 8/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [8/100], Loss: 0.5705, Accuracy: 0.9993


Epoch 9/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [9/100], Loss: 0.7667, Accuracy: 0.9976


Epoch 10/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [10/100], Loss: 1.2733, Accuracy: 0.9986


Epoch 11/100: 100%|██████████| 91/91 [02:02<00:00,  1.34s/it]


Epoch [11/100], Loss: 1.9817, Accuracy: 0.9986


Epoch 12/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [12/100], Loss: 1.8965, Accuracy: 0.9972


Epoch 13/100: 100%|██████████| 91/91 [02:02<00:00,  1.34s/it]


Epoch [13/100], Loss: 0.7132, Accuracy: 0.9986


Epoch 14/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [14/100], Loss: 0.9195, Accuracy: 0.9972


Epoch 15/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [15/100], Loss: 1.5582, Accuracy: 0.9962


Epoch 16/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [16/100], Loss: 1.9945, Accuracy: 0.9962


Epoch 17/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [17/100], Loss: 1.3166, Accuracy: 0.9979


Epoch 18/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [18/100], Loss: 0.8326, Accuracy: 0.9976


Epoch 19/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [19/100], Loss: 1.0640, Accuracy: 0.9969


Epoch 20/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [20/100], Loss: 4.4416, Accuracy: 0.9979


Epoch 21/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [21/100], Loss: 0.9809, Accuracy: 0.9986


Epoch 22/100: 100%|██████████| 91/91 [01:58<00:00,  1.31s/it]


Epoch [22/100], Loss: 0.9117, Accuracy: 0.9972


Epoch 23/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [23/100], Loss: 0.7159, Accuracy: 0.9983


Epoch 24/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [24/100], Loss: 0.6904, Accuracy: 0.9986


Epoch 25/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [25/100], Loss: 0.9589, Accuracy: 0.9962


Epoch 26/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [26/100], Loss: 3.0260, Accuracy: 0.9979


Epoch 27/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [27/100], Loss: 1.1737, Accuracy: 0.9969


Epoch 28/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [28/100], Loss: 1.0019, Accuracy: 0.9969


Epoch 29/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [29/100], Loss: 2.5321, Accuracy: 0.9962


Epoch 30/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [30/100], Loss: 0.9016, Accuracy: 0.9962


Epoch 31/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [31/100], Loss: 2.8945, Accuracy: 0.9972


Epoch 32/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [32/100], Loss: 1.0486, Accuracy: 0.9976


Epoch 33/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [33/100], Loss: 1.9486, Accuracy: 0.9986


Epoch 34/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [34/100], Loss: 2.9498, Accuracy: 0.9969


Epoch 35/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [35/100], Loss: 3.1994, Accuracy: 0.9979


Epoch 36/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [36/100], Loss: 0.9607, Accuracy: 0.9976


Epoch 37/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [37/100], Loss: 0.8041, Accuracy: 0.9979


Epoch 38/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [38/100], Loss: 1.5943, Accuracy: 0.9969


Epoch 39/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [39/100], Loss: 0.9006, Accuracy: 0.9969


Epoch 40/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [40/100], Loss: 0.7361, Accuracy: 0.9983


Epoch 41/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [41/100], Loss: 0.6773, Accuracy: 0.9979


Epoch 42/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [42/100], Loss: 0.5400, Accuracy: 0.9986


Epoch 43/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [43/100], Loss: 0.7451, Accuracy: 0.9986


Epoch 44/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [44/100], Loss: 0.8705, Accuracy: 0.9965


Epoch 45/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [45/100], Loss: 4.6041, Accuracy: 0.9972


Epoch 46/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [46/100], Loss: 0.7263, Accuracy: 0.9986


Epoch 47/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [47/100], Loss: 0.8099, Accuracy: 0.9976


Epoch 48/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [48/100], Loss: 0.5718, Accuracy: 0.9986


Epoch 49/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [49/100], Loss: 1.9110, Accuracy: 0.9969


Epoch 50/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [50/100], Loss: 0.8192, Accuracy: 0.9979


Epoch 51/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [51/100], Loss: 0.7873, Accuracy: 0.9976


Epoch 52/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [52/100], Loss: 0.7920, Accuracy: 0.9979


Epoch 53/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [53/100], Loss: 0.7490, Accuracy: 0.9972


Epoch 54/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [54/100], Loss: 2.9324, Accuracy: 0.9969


Epoch 55/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [55/100], Loss: 1.4573, Accuracy: 0.9941


Epoch 56/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [56/100], Loss: 1.2258, Accuracy: 0.9958


Epoch 57/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [57/100], Loss: 1.6282, Accuracy: 0.9986


Epoch 58/100: 100%|██████████| 91/91 [02:02<00:00,  1.34s/it]


Epoch [58/100], Loss: 0.9577, Accuracy: 0.9976


Epoch 59/100: 100%|██████████| 91/91 [02:02<00:00,  1.34s/it]


Epoch [59/100], Loss: 0.8913, Accuracy: 0.9972


Epoch 60/100: 100%|██████████| 91/91 [01:58<00:00,  1.31s/it]


Epoch [60/100], Loss: 1.1348, Accuracy: 0.9955


Epoch 61/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [61/100], Loss: 0.8244, Accuracy: 0.9986


Epoch 62/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [62/100], Loss: 5.1451, Accuracy: 0.9965


Epoch 63/100: 100%|██████████| 91/91 [02:02<00:00,  1.35s/it]


Epoch [63/100], Loss: 1.0182, Accuracy: 0.9969


Epoch 64/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [64/100], Loss: 2.0008, Accuracy: 0.9983


Epoch 65/100: 100%|██████████| 91/91 [02:02<00:00,  1.35s/it]


Epoch [65/100], Loss: 0.8601, Accuracy: 0.9972


Epoch 66/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [66/100], Loss: 2.3092, Accuracy: 0.9969


Epoch 67/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [67/100], Loss: 2.3904, Accuracy: 0.9979


Epoch 68/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [68/100], Loss: 1.3079, Accuracy: 0.9969


Epoch 69/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [69/100], Loss: 2.0648, Accuracy: 0.9986


Epoch 70/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [70/100], Loss: 0.8881, Accuracy: 0.9972


Epoch 71/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [71/100], Loss: 4.9998, Accuracy: 0.9965


Epoch 72/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [72/100], Loss: 0.8174, Accuracy: 0.9979


Epoch 73/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [73/100], Loss: 0.7439, Accuracy: 0.9979


Epoch 74/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [74/100], Loss: 0.7864, Accuracy: 0.9979


Epoch 75/100: 100%|██████████| 91/91 [02:02<00:00,  1.34s/it]


Epoch [75/100], Loss: 0.6334, Accuracy: 0.9990


Epoch 76/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [76/100], Loss: 0.8850, Accuracy: 0.9965


Epoch 77/100: 100%|██████████| 91/91 [02:02<00:00,  1.35s/it]


Epoch [77/100], Loss: 1.4116, Accuracy: 0.9945


Epoch 78/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [78/100], Loss: 1.0151, Accuracy: 0.9990


Epoch 79/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [79/100], Loss: 3.6287, Accuracy: 0.9955


Epoch 80/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [80/100], Loss: 0.7050, Accuracy: 0.9983


Epoch 81/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [81/100], Loss: 4.2650, Accuracy: 0.9969


Epoch 82/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [82/100], Loss: 0.9192, Accuracy: 0.9979


Epoch 83/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [83/100], Loss: 0.8246, Accuracy: 0.9972


Epoch 84/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [84/100], Loss: 0.9584, Accuracy: 0.9979


Epoch 85/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [85/100], Loss: 2.0194, Accuracy: 0.9962


Epoch 86/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [86/100], Loss: 1.0871, Accuracy: 0.9979


Epoch 87/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [87/100], Loss: 1.4806, Accuracy: 0.9979


Epoch 88/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [88/100], Loss: 1.4215, Accuracy: 0.9972


Epoch 89/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [89/100], Loss: 1.2840, Accuracy: 0.9979


Epoch 90/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [90/100], Loss: 1.0909, Accuracy: 0.9972


Epoch 91/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [91/100], Loss: 1.0524, Accuracy: 0.9972


Epoch 92/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [92/100], Loss: 6.8494, Accuracy: 0.9979


Epoch 93/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [93/100], Loss: 2.8177, Accuracy: 0.9983


Epoch 94/100: 100%|██████████| 91/91 [02:01<00:00,  1.33s/it]


Epoch [94/100], Loss: 0.9780, Accuracy: 0.9979


Epoch 95/100: 100%|██████████| 91/91 [02:01<00:00,  1.34s/it]


Epoch [95/100], Loss: 0.9948, Accuracy: 0.9969


Epoch 96/100: 100%|██████████| 91/91 [01:59<00:00,  1.31s/it]


Epoch [96/100], Loss: 0.7668, Accuracy: 0.9979


Epoch 97/100: 100%|██████████| 91/91 [01:59<00:00,  1.32s/it]


Epoch [97/100], Loss: 0.9934, Accuracy: 0.9979


Epoch 98/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [98/100], Loss: 2.3020, Accuracy: 0.9979


Epoch 99/100: 100%|██████████| 91/91 [02:00<00:00,  1.33s/it]


Epoch [99/100], Loss: 0.6489, Accuracy: 0.9993


Epoch 100/100: 100%|██████████| 91/91 [02:00<00:00,  1.32s/it]


Epoch [100/100], Loss: 1.0808, Accuracy: 0.9972
                       precision    recall  f1-score   support

 Tomato___Late_blight       1.00      1.00      1.00       940
Potato___Early_blight       1.00      1.00      1.00       494

             accuracy                           1.00      1434
            macro avg       1.00      1.00      1.00      1434
         weighted avg       1.00      1.00      1.00      1434

F1 Score: 0.9993 | Precision: 0.9993 | Recall: 0.9993


In [5]:
from sklearn.manifold import TSNE
# === t-SNE Visualization ===

def extract_features(model, dataloader):
    """Collect pooled feature vectors and labels for every batch in ``dataloader``.

    Each batch is passed through ``model.features`` and ``model.avgpool`` (the
    classifier head is skipped) with gradients disabled, then flattened to one
    vector per sample.

    Returns:
        ``(features, labels)`` as NumPy arrays, in loader order.
    """
    model.eval()
    pooled_batches, collected_labels = [], []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Extracting features"):
            feature_maps = model.features(batch_images.to(device))
            pooled = torch.flatten(model.avgpool(feature_maps), 1)
            pooled_batches.append(pooled.cpu())
            collected_labels.extend(batch_labels.cpu().numpy())
    stacked = torch.cat(pooled_batches, dim=0).numpy()
    return stacked, np.array(collected_labels)

# Pooled validation features and their remapped labels.
# NOTE: `labels` is a module-level name; the Grad-CAM cell further down reads it as well.
features, labels = extract_features(model, val_loader)

# Apply t-SNE
# Fixed random_state so the 2-D embedding is reproducible between runs.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` — confirm against
# the installed version before re-running.
tsne = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=300)
tsne_results = tsne.fit_transform(features)

# Plotting
plt.figure(figsize=(8, 6))
# One colour per selected class.
palette = sns.color_palette("hsv", len(selected_classes))
for i, cls_name in enumerate(selected_classes):
    # Boolean mask of the samples belonging to class i.
    idx = labels == i
    plt.scatter(tsne_results[idx, 0], tsne_results[idx, 1], label=cls_name, alpha=0.6, s=20, color=palette[i])
plt.legend()
plt.title("t-SNE of Validation Features (MobileNetV2 + BAM)")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.tight_layout()
# The figure is written to disk and then closed.
plt.savefig("tsne_bam.png")
plt.close()


Extracting features: 100%|██████████| 45/45 [00:16<00:00,  2.65it/s]


In [17]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def generate_gradcam(model, image_tensor, target_class=None):
    """Compute a Grad-CAM overlay for one image at the model's last BAM layer.

    Args:
        model: Network whose ``features`` container holds at least one ``BAM``
            module and whose call returns class logits.
        image_tensor: A single normalised image of shape (3, H, W), without a
            batch dimension.
        target_class: Class index to explain. Defaults to the predicted class.

    Returns:
        ``(overlay, pred_class)``. ``overlay`` is an H x W x 3 RGB float array
        scaled to [0, 1]; ``pred_class`` is the explained class index.
        Returns ``(None, None)`` when the model contains no BAM layer.
    """
    model.eval()
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # Find the last BAM layer
    last_bam_layer = next(
        (m for m in reversed(list(model.features._modules.values())) if isinstance(m, BAM)),
        None,
    )
    if last_bam_layer is None:
        print("No BAM layer found!")
        return None, None

    activations = []
    gradients = []

    def save_gradient(grad):
        # Returns None so autograd keeps the gradient unchanged.
        gradients.append(grad)

    def forward_hook(module, inputs, output):
        activations.append(output)
        # A tensor hook on the layer output captures d(score)/d(output) and
        # replaces the deprecated Module.register_backward_hook.
        output.register_hook(save_gradient)

    handle_fw = last_bam_layer.register_forward_hook(forward_hook)
    try:
        # Forward pass
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item() if target_class is None else target_class
        score = output[:, pred_class]

        # Backward pass
        model.zero_grad()
        score.backward()
    finally:
        # Always detach the hook, even if the forward/backward pass fails.
        handle_fw.remove()

    # Compute Grad-CAM: channel weights are the spatially averaged gradients.
    grad = gradients[0][0].cpu().numpy()
    act = activations[0][0].detach().cpu().numpy()
    channel_weights = grad.mean(axis=(1, 2))
    cam = np.tensordot(channel_weights, act, axes=1).astype(np.float32)
    cam = np.maximum(cam, 0)

    # Resize to the actual input resolution; cv2.resize expects (width, height).
    height, width = image_tensor.shape[-2:]
    cam = cv2.resize(cam, (int(width), int(height)))
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        # Skip the division for an all-zero map, which would produce NaNs.
        cam = cam / peak

    # Unnormalize the input image (assuming ImageNet normalization)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = image_tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    img = img * std + mean
    img = np.clip(img, 0, 1)

    # Overlay Grad-CAM. applyColorMap returns BGR, so convert to RGB to match
    # the image before blending.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    overlayed = heatmap + img
    overlayed = overlayed / overlayed.max()

    return overlayed, pred_class

import os

# Grid of 2 rows x NUM_SAMPLES columns: original images on top, Grad-CAM overlays below.
plt.figure(figsize=(12, 6))
NUM_SAMPLES = 5
selected_classes = ['Tomato___Late_blight', 'Potato___Early_blight']

# Individual overlay images are written into this folder.
output_dir = "gradcam_results"
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): `samples` is not defined in the visible part of this notebook; presumably it holds
# normalised CPU image tensors, and `labels` the matching class indices — confirm against the
# cell that creates them.
for i, (img, label) in enumerate(zip(samples[:NUM_SAMPLES], labels[:NUM_SAMPLES])):
    result, pred = generate_gradcam(model, img)
    # generate_gradcam returns (None, None) when the model has no BAM layer.
    if result is None:
        print(f"Skipping image {i} - No BAM layer found.")
        continue

    # Original image display
    plt.subplot(2, NUM_SAMPLES, i + 1)
    # Undo the per-channel normalisation so the image is displayable.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img_disp = img * std + mean
    # Move the channel axis last for imshow and clamp to the valid [0, 1] range.
    img_disp = img_disp.permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    plt.imshow(img_disp)
    plt.title(f"True: {selected_classes[label]}")
    plt.axis('off')

    # Grad-CAM overlay display
    plt.subplot(2, NUM_SAMPLES, NUM_SAMPLES + i + 1)
    plt.imshow(result)
    plt.title(f"Pred: {selected_classes[pred]}")
    plt.axis('off')

    # Save Grad-CAM overlay image to file
    save_path = os.path.join(output_dir, f"gradcam_overlay_{i}.png")
    plt.imsave(save_path, result)
    print(f"Saved Grad-CAM overlay image {i} at {save_path}")

plt.tight_layout()
# NOTE(review): the first cell selects the non-GUI 'Agg' backend, so plt.show() likely displays
# nothing, and the combined grid is not saved here — confirm whether a savefig is wanted.
plt.show()


Saved Grad-CAM overlay image 0 at gradcam_results/gradcam_overlay_0.png
Saved Grad-CAM overlay image 1 at gradcam_results/gradcam_overlay_1.png
Saved Grad-CAM overlay image 2 at gradcam_results/gradcam_overlay_2.png
Saved Grad-CAM overlay image 3 at gradcam_results/gradcam_overlay_3.png
