In [None]:
import torch
from torchvision import models, transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os

# Load trained ResNet18 model
MODEL_PATH = 'resnet18_fruit_model.pth'  # adjust path if needed


def _load_checkpoint(path):
    """Load a checkpoint onto the CPU, whether it is a pickled model or a state_dict.

    torch>=2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    to unpickle a whole ``nn.Module``; passing ``weights_only=False`` restores
    the previous behavior on versions that accept the keyword.
    SECURITY: ``weights_only=False`` executes arbitrary pickle code, so only
    load checkpoints from a trusted source.
    """
    import inspect

    load_kwargs = {'map_location': torch.device('cpu')}
    # Older torch (<1.13) has no ``weights_only`` keyword at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


checkpoint = _load_checkpoint(MODEL_PATH)
if isinstance(checkpoint, torch.nn.Module):
    model = checkpoint
else:
    # A state_dict: rebuild the architecture with the classifier width stored
    # in the checkpoint, then load the weights into it.
    # NOTE(review): assumes plain torchvision key names (no "module." prefix
    # from DataParallel, not nested under e.g. "state_dict") — confirm against
    # the training script.
    model = models.resnet18(num_classes=checkpoint['fc.weight'].shape[0])
    model.load_state_dict(checkpoint)
model.eval()

# Define transforms (same as during training)
# NOTE(review): there is no transforms.Normalize here — confirm training also
# skipped it, otherwise predictions and heatmaps will be skewed.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Stores filled in by the hooks below, keyed by a caller-chosen name.
activations, gradients = {}, {}


def save_activation(name):
    """Build a forward hook that records a module's output under ``name``.

    The stored tensor is detached, so it carries no autograd history.
    """
    def _forward_hook(module, inputs, output):
        activations[name] = output.detach()

    return _forward_hook


def save_gradient(name):
    """Build a backward hook that records the gradient w.r.t. a module's output.

    Only the first entry of ``grad_output`` is kept (detached) under ``name``.
    """
    def _backward_hook(module, grad_input, grad_output):
        first_grad = grad_output[0]
        gradients[name] = first_grad.detach()

    return _backward_hook

# Register hooks on the last convolutional layer of the final residual block.
# For standard ResNet18, layer4[-1] is layer4[1].
# NOTE(review): Bottleneck ResNets (50/101/152) end in conv3, not conv2.
target_layer = model.layer4[-1].conv2
# Keep the handles so the hooks can be detached later with `.remove()`
# (e.g. before re-running this cell, which would otherwise stack more hooks).
forward_hook_handle = target_layer.register_forward_hook(save_activation("features"))
# Module.register_backward_hook is deprecated; prefer the "full" variant and
# fall back only on torch versions that lack it.
if hasattr(target_layer, "register_full_backward_hook"):
    backward_hook_handle = target_layer.register_full_backward_hook(save_gradient("features"))
else:
    backward_hook_handle = target_layer.register_backward_hook(save_gradient("features"))

# Grad-CAM function
def generate_gradcam(image_tensor, class_idx=None):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        image_tensor: one preprocessed image (channels, height, width), e.g.
            the output of ``transform``; a batch dimension is added here.
        class_idx: index of the class whose score is explained. ``None``
            explains the model's own top prediction.

    Returns:
        (cam, pred_class): ``cam`` is a float32 NumPy array with the input's
        height and width and values in [0, 1] (all zeros when no region
        contributes positively); ``pred_class`` is the argmax class index.

    Relies on the module-level ``model`` and on the "features" entries that
    the registered hooks write into ``activations`` and ``gradients``.
    """
    model.zero_grad()
    output = model(image_tensor.unsqueeze(0))
    pred_class = output.argmax(dim=1).item()
    if class_idx is None:
        class_idx = pred_class

    # Backpropagate only the chosen class score; the backward hook captures
    # the gradient at the target layer.
    output[0, class_idx].backward()

    grads = gradients["features"]
    acts = activations["features"]
    # Channel weights = spatially averaged gradients.
    weights = grads.mean(dim=(2, 3), keepdim=True)

    # Take batch element 0 instead of squeeze(), so a feature map with a
    # spatial size of 1 keeps its 2-D shape.
    cam = torch.relu((weights * acts).sum(dim=1))[0]
    cam = cam - cam.min()
    cam_max = cam.max()
    # An all-zero map would otherwise divide 0/0 and turn every pixel NaN.
    if cam_max > 0:
        cam = cam / cam_max
    cam = cam.cpu().numpy()

    # Upsample to the input resolution; cv2.resize takes (width, height).
    height, width = image_tensor.shape[-2:]
    cam = cv2.resize(cam, (width, height))
    return cam, pred_class

# Plotting function
def show_gradcam(image_path, class_idx, true_label, class_names=None):
    """Display the Grad-CAM heatmap for ``class_idx`` overlaid on an image.

    Args:
        image_path: path to an image file readable by PIL.
        class_idx: class index to explain (passed to ``generate_gradcam``).
        true_label: ground-truth label, shown in the plot title.
        class_names: optional sequence mapping class indices to names; when
            given, the predicted class is shown by name instead of by index.
    """
    # The context manager closes the file; convert() returns an independent copy.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')
    input_tensor = transform(img)

    cam, pred_class = generate_gradcam(input_tensor, class_idx)

    # Resize the image to the CAM's size so the two arrays line up.
    cam_height, cam_width = cam.shape
    img_np = np.array(img.resize((cam_width, cam_height)))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR, while the PIL image and matplotlib are RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    overlay = heatmap + np.float32(img_np) / 255
    # Rescale the sum back into [0, 1] for imshow.
    overlay = overlay / overlay.max()

    pred_label = class_names[pred_class] if class_names is not None else pred_class

    plt.figure(figsize=(6, 6))
    plt.imshow(overlay)
    plt.title(f"True: {true_label}, Predicted: {pred_label}")
    plt.axis('off')
    plt.show()

# 👇 Example usage: use 3 image paths from poorly classified classes
# Replace with actual paths from your test dataset
TEST_DIR = 'test'

# Class index = position in the sorted list of class folder names, which is
# how torchvision's ImageFolder assigns labels.
# NOTE(review): assumes the model was trained with ImageFolder on the same set
# of class folders that TEST_DIR contains — confirm against the training script.
with os.scandir(TEST_DIR) as entries:
    class_names = sorted(entry.name for entry in entries if entry.is_dir())

example_images = [
    ('test/Bellpepper__Rotten/img123.jpg', 'Bellpepper__Rotten'),
    ('test/Tomato__Rotten/img456.jpg', 'Tomato__Rotten'),
    ('test/Potato__Rotten/img789.jpg', 'Potato__Rotten'),
]
# Look the index up from the label itself, so path, index and label cannot
# drift apart; .index() raises ValueError for an unknown label.
examples = [
    (img_path, class_names.index(label), label)
    for img_path, label in example_images
]

for img_path, class_idx, true_label in examples:
    show_gradcam(img_path, class_idx, true_label)
