In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.transforms import v2

In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm

In [3]:
from src.FFTConv import *
from src.ImageHandler import *

In [4]:
# Side length (pixels) that every input image is resized to.
IMG_SIZE = 129
# Whether to rebuild the cached dataset instead of loading it.
REBUILD_DATA = False
# Class names; presumably in the same order as the model's output logits — confirm.
targets = ['bacterial', 'normal', 'viral']

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# Build the FFT-augmented AlexNet on the configured device and input size.
model = FFTAlex(device=device, IMG_SIZE=IMG_SIZE, apply_fft=True)

Total Layers replaced:  1


In [None]:
# model = FFTGoogle(apply_fft=True, device=device)

In [7]:
# Restore the trained FFT-AlexNet weights and switch the model to inference mode.
# eval() changes layers such as BatchNorm to inference behaviour; it does not
# disable autograd, which the Grad-CAM gradient computation relies on.
# (For the FFTGoogle variant, the Grad-CAM target layer was model.model.inception4e.branch4.)
checkpoint_path = os.path.join('models', 'alex', 'fft_alex_model.pth')
model.load_model_dict(checkpoint_path)
model.eval()

FFTAlex(
  (model): AlexNet(
    (features): Sequential(
      (0): FFTConvNet(
        (conv_layer): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
      )
      (1): ReLU()
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
      (5): ReLU()
      (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU()
      (10): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (13): ReLU()
      (14): MaxPool2d(kernel_size=3, stride=2, pa

In [None]:
# Grad-CAM needs the activations of the last convolutional layer. A forward hook
# captures them into the module-level `feature_maps` on every forward pass.
feature_maps = None

def hook_fn(module, input, output):
    """Forward hook: store the hooked layer's output in the global `feature_maps`."""
    global feature_maps
    feature_maps = output

# Target layer for Grad-CAM: the last nn.Conv2d inside the AlexNet feature
# extractor. This must be defined here; without it the registration below
# raises NameError on a fresh kernel.
conv_layers = [m for m in model.model.features.modules() if isinstance(m, nn.Conv2d)]
layer = conv_layers[-1]

# Re-running this cell would otherwise stack duplicate hooks on the same layer,
# so remove the previously registered hook (if any) before registering again.
if 'hook_handle' in globals():
    hook_handle.remove()
hook_handle = layer.register_forward_hook(hook_fn)

In [None]:
# lmdb_path = os.path.join('lmdb')

# if REBUILD_DATA:
#     image_path = os.path.join('data', 'test_set')
#     test_data = ImageDataset(image_path=image_path, device=device, lmdb_path=lmdb_path, save_lmdb=True, mode="test")

#     REBUILD_DATA = False
# else:
#     test_data = ImageDataset(image_path=None, device=device, lmdb_path=lmdb_path, save_lmdb=False, mode="test")

In [None]:
# batch_size = 32
# test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True)

In [None]:
# def test(model, test_dl, device):
#     model.eval()
#     all_preds = []
#     all_labels = []
#     all_grads = []

#     for images, labels in tqdm(test_dl):
#         images = images.to(device)
#         labels = labels.to(device)

#         images.requires_grad = True  # Enable gradient tracking for images
#         outputs = model(images)
#         _, predicted_class = torch.max(outputs, dim=1)

#         # Compute scalar loss to enable autograd
#         loss = outputs.gather(1, predicted_class.view(-1, 1)).sum()
#         model.zero_grad()
#         loss.backward(retain_graph=True)

#         # Store results
#         all_preds.extend(predicted_class.cpu().numpy())
#         all_labels.extend(labels.cpu().numpy())
#         all_grads.extend(images.grad.cpu().numpy())  # Save gradient maps

#     return np.array(all_labels), np.array(all_preds), np.array(all_grads)


In [None]:
# labels, preds, grads = test(model, test_dl, device)

# for i in range(len(grads)):  
#     gradients = grads[i]  # Select gradients for the i-th sample

#     # Follow the same heatmap generation pipeline
#     pooled_gradients = torch.mean(torch.Tensor(gradients), dim=[0, 2, 3])  # Pool over spatial dimensions
#     pooled_gradients = pooled_gradients.detach().cpu().numpy()
#     feature_maps = feature_maps.cpu()
#     for j in range(feature_maps.shape[1]):
#         feature_maps[:, j, :, :] *= pooled_gradients[j]


#     heatmap = torch.mean(feature_maps, dim=1).squeeze()
#     heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
#     heatmap /= np.max(heatmap)

In [None]:
# Load and preprocess the input image
def preprocess_image(img_path):
    """Load an image from disk and turn it into a model-ready batch of one.

    The image is read as single-channel grayscale, then the transforms run in this order:
    1. Resize to IMG_SIZE x IMG_SIZE.
    2. Convert to float32 scaled to [0, 1].
    3. Normalize with mean 0.5 / std 0.5.
    4. Replicate to 3 channels.

    Args:
        img_path: path to the image file.

    Returns:
        A float32 tensor with a leading batch dimension of 1,
        i.e. (1, 3, IMG_SIZE, IMG_SIZE).
    """
    # Read with PIL only. The context manager releases the file handle once the
    # grayscale copy (a new Image object) has been made.
    with Image.open(img_path) as pil_img:
        gray_img = pil_img.convert('L')

    preprocess = v2.Compose([
        v2.ToImage(),
        v2.Resize((IMG_SIZE, IMG_SIZE)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5,], std=[0.5,]),
        # Expand the single grey channel to the 3 channels the network expects.
        v2.Grayscale(num_output_channels=3)
    ])
    # Add the batch dimension.
    return preprocess(gray_img).unsqueeze(0)

# Run a single image through the network. The forward pass also fires any
# forward hook registered on the model's layers.
img_path = "image.jpeg"
input_image = preprocess_image(img_path).to(device)

output = model(input_image)

# Index of the highest-scoring class for each image in the batch
# (a 1-element tensor here, since the batch holds one image).
predicted_class = torch.max(output, dim=1).indices

In [None]:
# Defensive reset of parameter .grad buffers. torch.autograd.grad returns
# gradients directly and does not accumulate into .grad.
model.model.zero_grad()

if feature_maps is None:
    raise RuntimeError("feature_maps is empty: the forward hook did not run before the forward pass")

# Score of each sample's predicted class, summed to a true scalar.
# output[:, predicted_class] would instead build a (B, B) matrix for a batch of
# B > 1, and only works for one image because a 1-element tensor is accepted
# as an implicit scalar.
class_score = output.gather(1, predicted_class.view(-1, 1)).sum()

# Gradient of the class score w.r.t. the hooked activations. retain_graph=True
# keeps the graph alive so this cell can be re-run without a new forward pass.
gradients = torch.autograd.grad(class_score, feature_maps, retain_graph=True)[0]

In [None]:
# Average the gradients over batch and spatial dims -> one weight per channel.
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# Weight each activation channel by its pooled gradient. Broadcasting builds a
# new tensor instead of multiplying `feature_maps` in place, so:
#   - the hooked activations (still part of the autograd graph) are not mutated;
#   - re-running this cell does not apply the weights a second time.
weighted_maps = feature_maps.detach() * pooled_gradients.detach().view(1, -1, 1, 1)

# Average over the channel dimension to get a single 2-D map.
heatmap = torch.mean(weighted_maps, dim=1).squeeze()

# ReLU: keep only regions that push the predicted class score up.
heatmap = np.maximum(heatmap.cpu().numpy(), 0)

# Normalize to [0, 1]. An all-zero map is left as-is; dividing by its max
# would produce NaNs.
heatmap_max = np.max(heatmap)
if heatmap_max > 0:
    heatmap /= heatmap_max

In [None]:
# for i in range(len(dataset)):  
#     image_tensor, label = dataset[i]  # Retrieve preprocessed image and label
#     img = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # Convert tensor to NumPy image
    
#     # Get heatmap for the corresponding image
#     heatmap = grads[i]  # Extract stored gradients
#     heatmap = np.mean(heatmap, axis=0)  # Pool gradients across channels
#     heatmap = np.maximum(heatmap, 0)  # Apply ReLU
#     heatmap /= np.max(heatmap)  # Normalize

#     # Resize heatmap to match the original image size
#     heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

#     # Convert heatmap to RGB
#     heatmap = np.uint8(255 * heatmap)
#     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

#     # Superimpose the heatmap on the original image
#     superimposed_img = heatmap * 0.4 + img

#     # Display the image
#     plt.imshow(superimposed_img / 255)
#     plt.title(f"Prediction: {targets[preds[i]]}")  # Use model predictions
#     plt.axis('off')
#     plt.show()


In [None]:
# Load the original image. OpenCV returns it in BGR channel order, or None if
# the file cannot be read.
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

# Resize the [0, 1] heatmap to the image size. New names leave `heatmap` itself
# untouched, so re-running this cell does not feed an already-colorized map back in.
heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

# Colorize: scale to 0-255 uint8 and apply the JET colormap (output is BGR).
heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)

# Superimpose the heatmap on the original image.
superimposed_img = heatmap_color * 0.4 + img

# Two fixes for display:
#   - Clip: the raw sum can exceed 255, and imshow would clip it with a warning.
#   - Flip BGR -> RGB: matplotlib expects RGB, otherwise the colormap's reds and
#     blues are swapped.
superimposed_rgb = np.clip(superimposed_img / 255, 0, 1)[..., ::-1]

fig, ax = plt.subplots()
ax.imshow(superimposed_rgb)
ax.set_title(f"Prediction: {targets[predicted_class.item()]}")
ax.axis('off')
plt.show()