In [2]:
import sys
sys.path.append("..")


In [3]:
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms, datasets
import torch.nn.functional as F
from matplotlib import cm

from models.model_v2 import PlantDiseaseResNet18  


TEST_IMAGES_PATH = "/home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/dataset/test_restructured"
checkpoint_path = "/home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/models/model_v2.pth"
SAVE_DIR = "/home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2"

os.makedirs(SAVE_DIR, exist_ok=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

test_dataset = datasets.ImageFolder(TEST_IMAGES_PATH, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

class_names = test_dataset.classes
num_classes = len(class_names)
print("Classes:", class_names)



model = PlantDiseaseResNet18(num_classes=num_classes).to(device)

checkpoint_state = torch.load(checkpoint_path, map_location=device)
model_state = model.state_dict()

filtered_state = {}
for k, v in checkpoint_state.items():
    if k in model_state and v.shape == model_state[k].shape:
        filtered_state[k] = v

print(f"Loaded {len(filtered_state)} compatible layers.")

model_state.update(filtered_state)
model.load_state_dict(model_state)
model.eval()



def grad_cam(model, img_tensor, target_class, target_layer):

    activations = {}
    gradients = {}

    def forward_hook(module, input, output):
        activations['value'] = output

    def backward_hook(module, grad_in, grad_out):
        gradients['value'] = grad_out[0]

    fwd = target_layer.register_forward_hook(forward_hook)
    bwd = target_layer.register_backward_hook(backward_hook)

    output = model(img_tensor)
    loss = output[0, target_class]

    model.zero_grad()
    loss.backward()

    acts = activations['value'].squeeze(0)  
    grads = gradients['value'].squeeze(0)    

    weights = grads.mean(dim=(1, 2))         

    cam = torch.zeros(acts.shape[1:], dtype=torch.float32, device=acts.device)
    for i, w in enumerate(weights):
        cam += w * acts[i]

    cam = torch.relu(cam)

    cam = cam.unsqueeze(0).unsqueeze(0)
    cam = F.interpolate(cam, size=(224, 224), mode="bilinear", align_corners=False)
    cam = cam.squeeze().detach().cpu().numpy()

    cam = (cam - cam.min()) / (cam.max() + 1e-8)

    fwd.remove()
    bwd.remove()

    return cam, output.detach()



target_layer = model.backbone.layer4[1].conv2 



for i, (img, label) in enumerate(test_loader):

    img = img.to(device)
    label_idx = label.item()
    true_class = class_names[label_idx]

    heatmap, output = grad_cam(model, img, label_idx, target_layer)

    pred_idx = int(output.argmax().item())
    pred_class = class_names[pred_idx]


    heatmap_color = cm.jet(heatmap)[:, :, :3]  
    heatmap_color = (heatmap_color * 255).astype(np.uint8)


    orig = img[0].cpu().numpy().transpose(1, 2, 0)
    orig = (orig * np.array([0.229, 0.224, 0.225]) +
            np.array([0.485, 0.456, 0.406]))
    orig = np.clip(orig * 255.0, 0, 255).astype(np.uint8)


    overlay = (0.7 * orig + 0.3 * heatmap_color).astype(np.uint8)
    overlay_img = Image.fromarray(overlay)


    save_name = f"overlay_{i:04d}_true_{true_class}_pred_{pred_class}.png"
    save_path = os.path.join(SAVE_DIR, save_name)
    overlay_img.save(save_path)

    print(f"Saved Overlay Grad-CAM: {save_path}")


print("\n Colored Grad-CAM Overlay Generation Completed!")
print(f" Saved to: {SAVE_DIR}")


Using device: cuda
Classes: ['AppleCedarRust', 'AppleScab', 'CornCommonRust', 'PotatoEarlyBlight', 'PotatoHealthy', 'TomatoEarlyBlight', 'TomatoHealthy', 'TomatoYellowCurlVirus']
Loaded 131 compatible layers.


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved Overlay Grad-CAM: /home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2/overlay_0000_true_AppleCedarRust_pred_TomatoYellowCurlVirus.png
Saved Overlay Grad-CAM: /home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2/overlay_0001_true_AppleCedarRust_pred_TomatoEarlyBlight.png
Saved Overlay Grad-CAM: /home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2/overlay_0002_true_AppleCedarRust_pred_PotatoEarlyBlight.png
Saved Overlay Grad-CAM: /home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2/overlay_0003_true_AppleCedarRust_pred_PotatoEarlyBlight.png
Saved Overlay Grad-CAM: /home/ajeet/STRM-SEMESTER-PROJECT/RESNET_Methodology/ResNet_Project/collaborative_cnn_team03/results/gradcam_v2_user2/overlay_0004_true_AppleScab_pred_AppleCedarRus