This notebook visualizes where the CNN phishing classifier focuses its attention on randomly selected test-set screenshots, using SmoothGradCAM++ from the torchcam library.


In [1]:
from dataset.phishing_dataset import PhishingDataset

import torch
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
from classifiers.only_cnn_classifier import BasicCNNClassifier
import random
import os

# Build the classifier and restore its trained weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU/MPS session
# deserialize on a machine that does not have that device; the model itself is
# constructed on the CPU, so this matches where the weights end up anyway.
model = BasicCNNClassifier()
model.load_state_dict(
    torch.load("../models/canonical/cnn_only_phishing_classifier.pt", map_location="cpu")
)
model.eval()  # evaluation mode: we only run inference / CAM extraction here

dataset_path = os.path.expanduser("~/transfer/phishing_output.h5")

# Hook the CAM extractor onto 'layer4' of the model's CNN backbone
# (the last convolutional stage of ResNet18, per the original note --
# NOTE(review): confirm that model.cnn is a ResNet18).
cam_extractor = SmoothGradCAMpp(model.cnn, target_layer="layer4")

# Test split of the phishing dataset, read from the local HDF5 file.
dataset = PhishingDataset(required_data=['image', 'url'], split='test', local_file_path=dataset_path)

def get_random_image(dataset):
    """Draw one sample uniformly at random and return its image with its index.

    Args:
        dataset: An indexable collection supporting ``len()`` whose items are
            mappings that contain an ``'image'`` entry.

    Returns:
        A ``(image, index)`` tuple: the ``'image'`` value of the chosen sample
        and the integer position it was read from.
    """
    last_position = len(dataset) - 1
    # randint is inclusive on both ends, hence the "- 1" above.
    chosen = random.randint(0, last_position)
    return dataset[chosen]['image'], chosen

def _show_and_save(image, title, path, size_px=(341, 226), dpi=100):
    """Show ``image`` in its own titled, axis-free figure and save it to ``path``.

    matplotlib's ``figsize`` is measured in inches, so the requested pixel size
    is divided by ``dpi``. Passing the pixel numbers straight to ``figsize``
    would ask for a canvas hundreds of inches wide.
    """
    fig, ax = plt.subplots(figsize=(size_px[0] / dpi, size_px[1] / dpi), dpi=dpi)
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
    fig.savefig(path, bbox_inches='tight')  # Save as a file
    plt.show()


def generate_heatmap(image_tensor, item_idx=None):
    """Render and save the class-activation heatmap for one dataset sample.

    Runs ``model`` on ``image_tensor``, asks ``cam_extractor`` for the activation
    map of the predicted (arg-max) class, overlays it on the matching entry of
    ``dataset.screenshots`` and shows/saves two figures:
    ``original_image.png`` and ``heatmap_overlay.png`` (written to the current
    working directory).

    Args:
        image_tensor: Model input for a single sample, without the batch
            dimension (one is added here).
        item_idx: Index of the sample in ``dataset``; used to fetch the
            screenshot that the heatmap is drawn over. Required despite the
            ``None`` default, which is kept for signature compatibility.

    Raises:
        ValueError: If ``item_idx`` is ``None``.
    """
    if item_idx is None:
        raise ValueError(
            "item_idx is required: it selects the screenshot in dataset.screenshots "
            "that the heatmap is overlaid on"
        )

    # Enable gradients on a detached view rather than on the caller's tensor,
    # so the sample handed in is left unmodified.
    input_batch = image_tensor.detach().unsqueeze(0).requires_grad_(True)

    logits = model(input_batch)

    # Activation map(s) for the top-scoring class.
    activation_map = cam_extractor(logits.argmax().item(), logits)

    # NOTE(review): assumes dataset.screenshots[item_idx] is an image array
    # that ToPILImage / imshow accept -- confirm against PhishingDataset.
    screenshot_image = dataset.screenshots[item_idx]
    background = ToPILImage()(screenshot_image)

    # overlay_mask feeds the mask values through a colormap, so the mask must
    # stay a float image (PIL mode 'F'). Without mode='F', ToPILImage converts
    # the float map to 8-bit integers first. This follows the usage shown in
    # the torchcam README.
    heatmap = ToPILImage(mode='F')(activation_map[0].squeeze(0))
    result = overlay_mask(background, heatmap, alpha=0.5)

    _show_and_save(screenshot_image, "Original Image", 'original_image.png')
    # Save the heatmap overlay image
    _show_and_save(result, "Heatmap Overlay", 'heatmap_overlay.png')

# Demo: pick one random test sample, print its label, then draw its heatmap.
# Note that `class_idx` holds the sample's index in the dataset, not a class id.
picked_sample = get_random_image(dataset)
tensor_image, class_idx = picked_sample
print(dataset.labels[class_idx])
generate_heatmap(image_tensor=tensor_image, item_idx=class_idx)


  Referenced from: <6DFB383A-E1D9-3EC6-8A60-382AF4E3C226> /opt/homebrew/Caskroom/miniforge/base/envs/phishing-edge/lib/python3.10/site-packages/torchvision/image.so
  warn(
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1105f4ca0>>
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/phishing-edge/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


ImportError: cannot import name 'PhishingClassifier' from 'classifiers.only_cnn_classifier' (/Users/imack/Documents/Stanford/CS230/phishing-edge/classifiers/only_cnn_classifier.py)