In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2
import numpy as np

In [None]:
%run ResNet34Model.ipynb
%run DatasetStatistics.ipynb
%run HistogramEqualization.ipynb

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Trained checkpoint: a mapping whose "image_size", "mean", "std" and "weights" entries are
# read by later cells. map_location remaps stored tensors onto `device`, so a GPU-saved
# checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
CHECKPOINT_PATH = "./config/resnet34.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

In [None]:
# Preprocessing parameters saved together with the weights.
IMAGE_SIZE = checkpoint["image_size"]
# Both statistics are divided by 255, presumably because they were computed on 0-255 pixel
# values while ToTensor() produces values in [0, 1] — TODO confirm against DatasetStatistics.
MEAN, STD = (checkpoint[stat] / 255.0 for stat in ("mean", "std"))
# Root folder of the evaluation images (ImageFolder layout: one sub-folder per class).
DIR = "./new_data/Testing"

In [None]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        HistogramEqualization(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[MEAN] * 3, std=[STD] * 3),
    ]
)

In [None]:
dataset = datasets.ImageFolder(DIR, transform=transforms)

In [None]:
# One image per batch, in dataset order: the first batch is therefore exactly
# dataset.samples[0], which the overlay cell reloads from disk by path.
dl = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Rebuild the architecture, load the trained weights, move it to `device` and switch to
# inference mode (eval() changes the behavior of layers such as Dropout/BatchNorm, if present).
# Module.to() and Module.eval() both return the module itself, so the chain rebinds the same
# object; ending on an assignment keeps the very large module repr out of the cell output.
model = ResNet34Model()
model.load_state_dict(checkpoint["weights"])
model = model.to(device).eval()

In [None]:
# Grab the first (deterministic, since shuffle=False) batch; its label is not needed here.
batch_images, _ = next(iter(dl))
img = batch_images.to(device)

In [None]:
# Grad-CAM needs the gradient of the class score, so run the forward pass with autograd
# explicitly enabled instead of relying on the global grad mode (e.g. a stray
# torch.set_grad_enabled(False) elsewhere in the session). The module is called directly
# (not .forward) so any registered module hooks fire.
with torch.enable_grad():
    out = model(img)

In [None]:
# Predicted class: index of the highest output score for the single image in the batch.
class_index = int(torch.argmax(out, dim=1))

In [None]:
# Backpropagate the predicted class score. Tensor.backward() returns None, so its result is
# not assigned. Indexing [0, class_index] selects the scalar score of the single image
# (batch_size is 1). The autograd graph is freed afterwards, so re-running this cell
# requires re-running the forward-pass cell first.
model.zero_grad()
out[0, class_index].backward()

In [None]:
# Gradient of the class score w.r.t. the captured feature maps, presumably laid out as
# (N, C, H, W) — TODO confirm. Averaging over dims 0, 2 and 3 leaves one weight per channel.
grads = model.get_activations_gradient()
pool_gradients = grads.mean(dim=(0, 2, 3))

In [None]:
# Grad-CAM channel weighting: scale every activation map by its pooled gradient.
# - Broadcasting against a (1, C, 1, 1) view works for any channel count (no hard-coded 512).
# - The product is a NEW tensor, so the tensor returned by get_activations() (possibly state
#   held by the model — TODO confirm) is left untouched and re-running this cell is idempotent.
# - detach() drops any autograd history, so the heatmap can later be converted with .numpy().
raw_activations = model.get_activations().detach()
activations = raw_activations * pool_gradients.view(1, -1, 1, 1)

In [None]:
# Collapse the weighted channels into one spatial map (mean over dim 1), drop size-1 dims,
# then keep only the positive contributions (ReLU), as in Grad-CAM.
channel_mean = activations.mean(dim=1)
heatmap = F.relu(channel_mean.squeeze())

In [None]:
# Scale the map to [0, 1] for display. If ReLU zeroed every value the maximum is 0, and
# dividing by it would turn the whole map into NaN — in that case leave the zeros as-is.
heatmap_max = heatmap.max()
if heatmap_max > 0:
    heatmap = heatmap / heatmap_max

In [None]:
# Convert to a NumPy array for matplotlib/OpenCV. Tensor.numpy() refuses tensors that require
# grad, so detach() first (harmless if the map has no autograd history).
heatmap = heatmap.detach().cpu().numpy()
plt.matshow(heatmap.squeeze())
plt.title(f"Grad-CAM heatmap (predicted class index {class_index})")
plt.show()

In [None]:
import numbers


def _torchvision_resize_dsize(size, height, width):
    """Return the ``cv2.resize`` dsize ``(width, height)`` matching ``transforms.Resize(size)``.

    ``transforms.Resize`` accepts an int / 1-element sequence (scale the shorter edge to that
    length, keeping the aspect ratio) or an ``(h, w)`` pair, whereas ``cv2.resize`` expects
    ``(width, height)``. Passing IMAGE_SIZE straight to cv2 fails for an int and swaps the
    axes for a non-square ``(h, w)``.
    """
    if isinstance(size, numbers.Integral):
        size = [size]
    size = [int(s) for s in size]
    if len(size) == 2:
        target_h, target_w = size
        return target_w, target_h
    # Same rule as torchvision: shorter edge -> size[0], longer edge scaled and truncated.
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size[0], int(size[0] * long / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)


# Reload the image the model saw (dl is unshuffled with batch_size=1, so batch 0 is
# dataset.samples[0]) for the overlay. cv2.imread yields BGR, and this copy is not
# histogram-equalized or normalized — it is only used for display.
img_path, _ = dataset.samples[0]
img_original = cv2.imread(img_path)
if img_original is None:
    # cv2.imread reports missing/unreadable files by returning None instead of raising.
    raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
img_original = cv2.resize(
    img_original, _torchvision_resize_dsize(IMAGE_SIZE, *img_original.shape[:2])
)

In [None]:
# Upsample the coarse Grad-CAM map to the overlay image size; cv2.resize takes (width, height).
heatmap = cv2.resize(heatmap, (img_original.shape[1], img_original.shape[0]))
# Zero any NaN (e.g. from normalizing an all-zero map) and clamp to [0, 1] before scaling, so
# the uint8 cast can neither wrap around nor hit an undefined NaN conversion.
heatmap = (255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0)).astype(np.uint8)

In [None]:
# Map the 8-bit intensities to a 3-channel color image (OpenCV channel order, i.e. BGR).
HEATMAP_COLORMAP = cv2.COLORMAP_JET
heatmap = cv2.applyColorMap(heatmap, HEATMAP_COLORMAP)

In [None]:
# Blend input image and colored heatmap (both BGR uint8, same size): 60 % image, 40 % heatmap,
# no additive brightness offset.
IMAGE_WEIGHT, HEATMAP_WEIGHT, GAMMA = 0.6, 0.4, 0
superimposed_img = cv2.addWeighted(
    img_original, IMAGE_WEIGHT, heatmap, HEATMAP_WEIGHT, GAMMA
)

In [None]:
# Show the input next to its Grad-CAM overlay. OpenCV images are BGR while matplotlib expects
# RGB, so convert before plotting (otherwise red and blue are swapped).
fig, (ax_input, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))
ax_input.imshow(cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB))
ax_input.set_title("Input image")
ax_overlay.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
ax_overlay.set_title(f"Grad-CAM overlay (predicted class index {class_index})")
for panel in (ax_input, ax_overlay):
    panel.axis("off")
plt.show()