In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# Load the trained model
model = models.resnet50(pretrained=True)
model.eval()

# Load the input image and preprocess it
image_path = '/Users/leo/Desktop/new_thesis/stock_image.jpg'
image = Image.open(image_path)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

# Create a forward hook to save the feature maps
class CBAMForwardHook:
    def __init__(self):
        self.feature_maps = None
    
    def __call__(self, module, input, output):
        self.feature_maps = output.detach()
        
cbam_hook = CBAMForwardHook()

# Register the forward hook on the last CBAM module
last_cbam_module = None
for module in model.modules():
    if isinstance(module, models.resnet.Bottleneck):
        if hasattr(module, 'cbam'):
            last_cbam_module = module.cbam
if last_cbam_module is not None:
    for name, module in last_cbam_module.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            module.register_forward_hook(cbam_hook)
            break
else:
    print('No CBAM module found in the model')

# Forward the input through the model and obtain the output class probabilities and the feature maps
with torch.no_grad():
    output = model(input_batch)
    class_idx = torch.argmax(output)
    probs = F.softmax(output, dim=1)
    prob = probs[0, class_idx]
    feature_maps = cbam_hook.feature_maps
    if feature_maps is not None:
        feature_maps = feature_maps.squeeze(0)
    else:
        print('No feature maps captured by the forward hook')

# Compute the gradients of the output class probabilities with respect to the feature maps
if feature_maps is not None:
    grads = torch.autograd.grad(outputs=probs[:, class_idx], inputs=feature_maps, grad_outputs=torch.ones_like(probs[:, class_idx]), retain_graph=True, create_graph=True)[0]

    # Compute the class activation map
    cam = torch.sum(grads * feature_maps, dim=(2, 3), keepdim=True)
    cam = F.relu(cam)

    # Upsample the class activation map to the size of the input image
    upsample = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(image.size, interpolation=Image.BILINEAR),
        transforms.ToTensor()
    ])
    cam = upsample(cam.squeeze(0)).numpy()
    cam = cam - np.min(cam)
    cam = cam / np.max(cam)

    # Plot the input image and the class activation map overlaid on top of the input image
    plt.imshow(image)
    plt.imshow(cam, alpha=0.5, cmap='jet')
    plt.axis('off')
    plt.show()
else:
    print('Cannot visualize attention because no feature maps captured')


No CBAM module found in the model
No feature maps captured by the forward hook
Cannot visualize attention because no feature maps captured
