In [None]:
from dataset.phishing_dataset import PhishingDataset
import torchcam

import torch
from torchcam.methods import GradCAM, SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms import ToPILImage
from PIL import Image
import matplotlib.pyplot as plt
from classifiers.only_cnn_classifier import PhishingClassifier
import random
import os

# Load the trained classifier. map_location="cpu" lets a checkpoint that was
# saved from CUDA tensors load on a CPU-only machine; load_state_dict copies the
# tensors into the freshly constructed model's own parameters, so mapping to
# CPU never changes where the model itself lives.
model = PhishingClassifier()
model.load_state_dict(
    torch.load("../models/cnn_only_phishing_classifier.pt", map_location="cpu")
)
# Switch any dropout / batch-norm layers to inference behaviour.
model.eval()

# Local copy of the dataset file (".h5" — presumably HDF5; TODO confirm).
dataset_path = os.path.expanduser("~/transfer/phishing_output.h5")

# CAM extractor hooked on the `layer4` submodule of the classifier's CNN
# backbone. Per the original note the backbone is a ResNet18, whose last
# convolutional stage is 'layer4' — NOTE(review): confirm against
# PhishingClassifier if the backbone changes.
cam_extractor = SmoothGradCAMpp(model.cnn, target_layer="layer4")

# Test split, loading only the image and URL fields.
dataset = PhishingDataset(required_data=['image', 'url'], split='test', local_file_path=dataset_path)

# Define a function to get a random image from the dataset
def get_random_image(dataset):
    """Draw one sample uniformly at random from ``dataset``.

    Args:
        dataset: Indexable, sized collection whose items are mappings with an
            ``'image'`` entry.

    Returns:
        ``(image, index)``: the chosen sample's ``'image'`` value and the
        position it was taken from (a dataset row index, not a class id).
    """
    idx = random.randrange(len(dataset))
    return dataset[idx]['image'], idx

def _show_and_save(image, title, path, figsize_px, dpi):
    """Draw ``image`` alone in a new figure, save it to ``path``, then show it.

    Args:
        image: Anything ``plt.imshow`` accepts (PIL images here).
        title: Figure title.
        path: Output file passed to ``savefig``.
        figsize_px: ``(width, height)`` of the figure in pixels.
        dpi: Dots per inch; converts ``figsize_px`` to the inches matplotlib
            expects and sets the saved resolution.
    """
    width_px, height_px = figsize_px
    fig = plt.figure(figsize=(width_px / dpi, height_px / dpi), dpi=dpi)
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')
    # 'tight' crops the surrounding whitespace from the saved file.
    fig.savefig(path, bbox_inches='tight', dpi=dpi)
    plt.show()
    # Close explicitly so repeated calls do not leave figures open.
    plt.close(fig)


def generate_heatmap(image_tensor, item_idx=None, *, figsize_px=(341, 226), dpi=100):
    """Compute a SmoothGrad-CAM++ map for one sample and save two figures.

    Runs ``image_tensor`` through the module-level ``model``, extracts the CAM
    for the top-scoring class with the module-level ``cam_extractor``, and
    overlays it on the raw screenshot ``dataset.screenshots[item_idx]``.
    Writes ``original_image.png`` and ``heatmap_overlay.png`` to the current
    working directory and displays both figures.

    Args:
        image_tensor: Preprocessed model input for a single sample, without a
            batch dimension (one is added here). Not modified.
        item_idx: Row index of the same sample in ``dataset``; used to fetch
            the raw screenshot the heatmap is drawn on. Required.
        figsize_px: ``(width, height)`` of each figure in pixels.
        dpi: Dots per inch used for the pixel-to-inch conversion and saving.

    Raises:
        ValueError: If ``item_idx`` is None.
    """
    if item_idx is None:
        # The overlay needs the raw screenshot; indexing with None would not
        # select a sample.
        raise ValueError("item_idx is required to look up the original screenshot")

    # Work on a detached leaf so the caller's tensor is left untouched (and
    # non-leaf inputs don't raise). Input gradients are enabled presumably so
    # the hooked activations stay in the autograd graph even if the backbone's
    # weights are frozen.
    image_tensor = image_tensor.detach().requires_grad_(True)

    logits = model(image_tensor.unsqueeze(0))
    # Explain the class the model actually predicts for this sample.
    predicted_class = logits.argmax().item()
    # NOTE(review): cam_extractor wraps `model.cnn`, not `model`; I believe
    # torchcam's SmoothGradCAMpp re-runs the wrapped module on noisy inputs and
    # indexes *its* output with predicted_class — confirm model.cnn's output
    # uses the same class indices as the full classifier.
    activation_map = cam_extractor(predicted_class, logits)

    to_pil = ToPILImage()
    # overlay_mask blends the image with a 3-channel colormap, so force RGB
    # (an RGBA or grayscale screenshot would not broadcast).
    screenshot = to_pil(dataset.screenshots[item_idx]).convert("RGB")

    # One CAM per target layer; take the only one and drop its batch dim.
    cam = activation_map[0].squeeze(0).detach().cpu().float()
    # mode="F" keeps the raw float CAM, as in torchcam's documented usage; the
    # default 8-bit conversion yields uint8 data that overlay_mask squares
    # (overflowing) before applying the colormap.
    mask = ToPILImage(mode="F")(cam)
    result = overlay_mask(screenshot, mask, alpha=0.5)

    _show_and_save(screenshot, "Original Image", 'original_image.png', figsize_px, dpi)
    _show_and_save(result, "Heatmap Overlay", 'heatmap_overlay.png', figsize_px, dpi)

# Example usage: draw one random test sample, print its ground-truth label,
# then render and save its CAM overlay.
# NOTE: despite its name, `class_idx` holds the dataset *row index* returned by
# get_random_image, not a class id; the name is kept in case later cells use it.
tensor_image, class_idx = get_random_image(dataset)
print(dataset.labels[class_idx])
generate_heatmap(tensor_image, item_idx=class_idx)


