In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.transforms import v2

In [2]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
from random import randint, random

In [3]:
from src.FFTConv import *
from src.ImageHandler import *

In [4]:
# Side length, in pixels, that images are resized to before reaching the model.
IMG_SIZE = 128

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Class names; a predicted class index is looked up in this list.
targets = ['bacterial', 'normal', 'viral']

# When True, the dataset cell builds the test set from data/test_set with
# save_lmdb=True instead of only reading from the existing lmdb path.
REBUILD_DATA = False

In [5]:
# Build the project's FFTAlex model with FFT convolution switched on, passing
# the configured device and image size.
model = FFTAlex(
    IMG_SIZE=IMG_SIZE,
    device=device,
    apply_fft=True,
)

Total Layers replaced:  1


In [6]:
# Alternative model (disabled): uncomment to use FFTGoogle instead of FFTAlex.
# model = FFTGoogle(apply_fft=True, device=device)

In [7]:
# Load the checkpoint stored at models/alex/fft_alex_model.pth through the
# model's own load_model_dict helper.
weights_path = os.path.join('models', 'alex', 'fft_alex_model.pth')
model.load_model_dict(weights_path)

# Switch to evaluation mode before any inference is run.
model.eval()

# Leave the model as the cell's last expression so its layer listing is shown.
model

  state_dict = torch.load(path, map_location=self.device)


FFTAlex(
  (model): AlexNet(
    (features): Sequential(
      (0): FFTConvNet(
        (conv_layer): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      )
      (1): ReLU()
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
      (5): ReLU()
      (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU()
      (10): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (13): ReLU()
      (14): MaxPool2d(kernel_size=3, stride=2, pa

In [8]:
# Index, inside model.model.features, of the layer whose output is captured by
# the forward hook. In the printed model this entry is Conv2d(384, 256) and the
# entry right after it (index 11) is ReLU(inplace=True).
# NOTE(review): an in-place ReLU presumably overwrites the tensor that a forward
# hook captured from this layer, so the stored activations may be post-ReLU —
# confirm.
TARGET_LAYER_INDEX = 10
layer = model.model.features[TARGET_LAYER_INDEX]

In [9]:
# Most recent output of the hooked layer; overwritten on every forward pass.
feature_maps = None


def hook_fn(module, input, output):
    """Forward hook that stores the hooked layer's output in `feature_maps`.

    Args:
        module: The layer the hook is attached to (unused).
        input: The inputs the layer was called with (unused).
        output: The tensor produced by the layer; kept by reference, not copied.
    """
    global feature_maps
    feature_maps = output


# Re-running this cell would otherwise stack another copy of the hook on the
# layer, so detach the one left behind by a previous run first.
_previous_handle = globals().get('hook_handle')
if _previous_handle is not None:
    _previous_handle.remove()

# `layer` is model.model.features[10]: a Conv2d, but per the printed model not
# the last one (index 12 is also a Conv2d).
# Keep the handle so the hook can be detached later with hook_handle.remove().
hook_handle = layer.register_forward_hook(hook_fn)

<torch.utils.hooks.RemovableHandle at 0x1a2147297f0>

In [10]:
# Path handed to ImageDataset as its lmdb_path.
lmdb_path = os.path.join('lmdb')

# Arguments that are identical whether or not the data is rebuilt.
shared_dataset_args = dict(device=device, lmdb_path=lmdb_path, mode="test")

if not REBUILD_DATA:
    # Default path: no image folder is given and save_lmdb is off.
    test_data = ImageDataset(image_path=None, save_lmdb=False, **shared_dataset_args)
else:
    # Rebuild path: point the dataset at the raw test images with save_lmdb on.
    image_path = os.path.join('data', 'test_set')
    test_data = ImageDataset(image_path=image_path, save_lmdb=True, **shared_dataset_args)

    # Clear the flag so that re-running this cell does not rebuild again.
    REBUILD_DATA = False

Loaded test dataset


In [13]:
# Folder that holds one sub-directory of test images per class label.
main_directory = os.path.join("data", "test_set")

# NOTE(review): IMG_SIZE and device are also set in the configuration cell near
# the top of the notebook; there device is a plain string, here it becomes a
# torch.device object.
IMG_SIZE = 128
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def preprocess_image(img_path):
    """Load an image file and return it as a one-image batch on `device`.

    The file is opened with PIL, converted to RGB, resized to
    IMG_SIZE x IMG_SIZE, converted to float32 with scaling enabled and
    normalised with mean 0.5 / std 0.5 on each of the three channels.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The transformed image with a leading batch dimension added, moved to
        the module-level `device`.
    """
    rgb_image = Image.open(img_path).convert('RGB')

    # The steps run in list order; the pipeline is rebuilt on every call so it
    # always picks up the current value of IMG_SIZE.
    steps = [
        v2.ToImage(),
        v2.Resize((IMG_SIZE, IMG_SIZE)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    pipeline = v2.Compose(steps)

    # unsqueeze(0) adds the batch dimension in front of the image tensor.
    batch = pipeline(rgb_image).unsqueeze(0)
    return batch.to(device)

def get_label(img_path):
    """Return the name of the folder that directly contains `img_path`.

    The image paths in this notebook are built as <root>/<label>/<file>, so the
    parent folder's name serves as the ground-truth label.
    """
    parent_dir, _file_name = os.path.split(img_path)
    return os.path.split(parent_dir)[1]

# Paths of every .jpg/.jpeg file found in main_directory/<label>/ for each
# label in `targets`.
all_images = []
# Number of heat maps to produce.
no_pic = 20
# Number of heat maps produced so far; incremented by the Grad-CAM loop.
disp_heatmaps = 0

for label in targets:
    label_directory = os.path.join(main_directory, label)
    if not os.path.isdir(label_directory):
        # Labels without a folder on disk are skipped.
        continue
    # os.listdir returns entries in arbitrary order; sort them so that a given
    # index into all_images (used in plot titles and output file names) refers
    # to the same file on every run and every machine.
    for img_name in sorted(os.listdir(label_directory)):
        if img_name.lower().endswith(('.jpg', '.jpeg')):
            all_images.append(os.path.join(label_directory, img_name))

# Grad-CAM over randomly chosen, correctly classified test images.
#
# For each selected image the class score is differentiated with respect to the
# activations captured by the forward hook, the gradients are averaged per
# channel, and the channel-weighted activations are averaged into a heat map
# that is drawn over the image and saved under heatmaps/.

# savefig does not create missing folders, so make sure the target exists.
output_dir = "heatmaps"
os.makedirs(output_dir, exist_ok=True)

# Visit every image at most once, in random order. Sampling without replacement
# avoids counting (and overwriting) the same image twice, and guarantees the
# loop ends even when fewer than no_pic images are classified correctly.
for image_num in torch.randperm(len(all_images)).tolist():
    if disp_heatmaps >= no_pic:
        break

    img_path = all_images[image_num]
    true_label = get_label(img_path)
    input_img = preprocess_image(img_path)

    output = model(input_img)
    _, predicted_class = torch.max(output, 1)
    class_index = predicted_class.item()
    predicted_label = targets[class_index]

    # Only correctly classified images are visualised.
    if predicted_label != true_label:
        continue

    if feature_maps is None:
        raise RuntimeError(
            "feature_maps is empty: run the cell that registers the forward hook first"
        )

    disp_heatmaps += 1
    model.zero_grad()

    # Gradient of the predicted class score w.r.t. the hooked activations.
    # The graph is not needed afterwards, so it is not retained.
    gradients = torch.autograd.grad(output[0, class_index], feature_maps)[0]
    # One weight per channel: the gradient averaged over batch, height, width.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # Weight each channel by its pooled gradient. Working on a detached copy
    # leaves the hooked tensor (still part of the autograd graph) untouched.
    weighted_maps = feature_maps.detach() * pooled_gradients.view(1, -1, 1, 1)

    # Average the weighted feature maps along the channel dimension.
    heatmap = torch.mean(weighted_maps, dim=1).squeeze()

    # ReLU, then scale to 0..255; an all-zero map is left as zeros.
    heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
    max_val = np.max(heatmap)
    if max_val > 0:
        heatmap /= max_val
    heatmap = np.uint8(255 * heatmap)

    # Read the original image for the overlay. OpenCV loads BGR, while
    # matplotlib expects RGB.
    img = cv2.imread(img_path)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Upsample the heat map to the image size so the two can be overlaid.
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img)
    heatmap_img = ax.imshow(heatmap, cmap='jet', alpha=0.5)
    fig.colorbar(heatmap_img, ax=ax, shrink=0.5, aspect=10)
    ax.set_title(f"Index: {image_num}, Prediction: {predicted_label}")
    ax.axis('off')

    # Save the overlay and release the figure so figures do not pile up.
    fig.savefig(os.path.join(output_dir, f"image_{image_num}.png"), dpi=300, bbox_inches="tight")
    plt.close(fig)

if disp_heatmaps < no_pic:
    print(
        f"Only {disp_heatmaps} of the requested {no_pic} heat maps were produced: "
        f"not enough correctly classified images among {len(all_images)} candidates."
    )


In [12]:
# Report how many test images were collected into all_images.
image_count = len(all_images)
print(image_count)

624
