In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2
import numpy as np

from app.utils.ResNet34Model import ResNet34Model
from app.utils.HistogramEqualization import HistogramEqualization
from app.utils.DatasetStatistics import DatasetStatistics

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the training checkpoint (weights plus preprocessing metadata) directly
# onto the selected device.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints from a trusted source;
# consider weights_only=True if the file holds only tensors and plain containers.
checkpoint = torch.load(
    f="./config/resnet34.pth",
    weights_only=False,
    map_location=device,
)

In [None]:
# Inspect what the checkpoint provides. A bare last expression uses Jupyter's
# rich display instead of print().
list(checkpoint.keys())

In [None]:
# Preprocessing parameters read from the checkpoint (presumably recorded at
# training time) so inference uses the same input pipeline.
IMAGE_SIZE, MEAN, STD = (checkpoint[key] for key in ("image_size", "mean", "std"))
# ImageFolder root: one sub-directory per class.
DIR = "./new_data/Testing"

In [None]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        HistogramEqualization(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[MEAN] * 3, std=[STD] * 3),
    ]
)

In [None]:
dataset = datasets.ImageFolder(DIR, transform=transforms)
print(dataset.classes)

In [None]:
# One image per batch, in dataset order, so the i-th batch is dataset.samples[i].
# Later cells rely on this to reload the same image with OpenCV.
dl = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Rebuild the architecture and restore the trained weights, then move the model
# to the device and switch it to eval mode. Autograd stays enabled (no
# torch.no_grad) because Grad-CAM needs a backward pass. The final expression
# displays the model.
model = ResNet34Model()
model.load_state_dict(checkpoint["weights"])
model.to(device).eval()

In [None]:
# First batch of the unshuffled loader, i.e. dataset.samples[0], moved to the
# model's device. The label is not needed here. It is deliberately not bound to
# `_`, which IPython reserves for the last cell output.
img = next(iter(dl))[0].to(device)

In [None]:
# Forward pass with autograd enabled. The backward pass in a later cell needs
# this graph.
# Call the module (not .forward) so any registered hooks run. Presumably
# ResNet34Model records activations/gradients this way for
# get_activations()/get_activations_gradient(). TODO confirm.
out = model(img)

In [None]:
# Predicted class = index of the highest output score. This is the class that
# Grad-CAM will explain.
class_index = int(torch.argmax(out, dim=1))

In [None]:
# Backpropagate the predicted class score (this is not a loss) so the model
# captures the gradients Grad-CAM needs. Tensor.backward() returns None, so the
# result is not assigned to anything.
model.zero_grad()
out[0, class_index].backward()

In [None]:
# Average the activation gradients over batch and spatial dimensions, giving one
# importance weight per channel.
# assumes grads is (N, C, H, W). TODO confirm against ResNet34Model.
grads = model.get_activations_gradient()
pool_gradients = grads.mean(dim=(0, 2, 3))

In [None]:
# Weight each activation channel by its pooled gradient (the Grad-CAM weighting).
# Broadcasting replaces the hard-coded 512-channel loop, so this works for any
# channel count. It also produces a new tensor, so the model's stored
# activations are untouched and the cell can be re-run.
# detach() keeps the result off the autograd graph, so later .numpy() calls work.
# assumes activations are (N, C, H, W) and len(pool_gradients) == C. TODO confirm.
activations = model.get_activations().detach() * pool_gradients[None, :, None, None]

In [None]:
# Collapse the weighted channels into a single spatial map. ReLU then keeps only
# the regions that push the score of the predicted class up.
heatmap = F.relu(activations.mean(dim=1).squeeze())

In [None]:
# Scale the map to [0, 1]. The epsilon avoids 0/0 -> NaN when ReLU has zeroed the
# whole map (no positive evidence). This matches the normalization in the
# per-class cell. The result is assigned out of place rather than with `/=`.
heatmap = heatmap / (heatmap.max() + 1e-8)

In [None]:
# Convert to numpy for matplotlib/OpenCV. detach() guards against the map still
# being attached to the autograd graph, since .numpy() raises on tensors that
# require grad. This is consistent with the per-class cell.
heatmap = heatmap.detach().cpu().numpy()
plt.matshow(heatmap);

In [None]:
# Reload the same image the model saw (loader is unshuffled, so batch 0 ==
# samples[0]) in its unprocessed form for display. OpenCV returns BGR.
img_path, _ = dataset.samples[0]
img_original = cv2.imread(img_path)
if img_original is None:  # cv2.imread signals failure by returning None
    raise FileNotFoundError(f"OpenCV could not read image: {img_path}")
# cv2 dsize is (width, height), while torch tensors are (..., H, W). Taking the
# size from the model input keeps this right even for non-square IMAGE_SIZE.
img_original = cv2.resize(img_original, (img.shape[-1], img.shape[-2]))

In [None]:
# Upsample the coarse CAM to the display image and map [0, 1] -> uint8 [0, 255].
# img_original.shape[1::-1] is (width, height), the dsize order cv2 expects.
heatmap = (cv2.resize(heatmap, img_original.shape[1::-1]) * 255).astype(np.uint8)

In [None]:
# Colorize the uint8 CAM with OpenCV's JET colormap. The result is a 3-channel
# BGR image, the same channel order as img_original, so the two can be blended.
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

In [None]:
# Alpha-blend 60% original image with 40% colorized CAM (gamma = 0). addWeighted
# needs inputs of equal size and dtype; the heatmap was resized to img_original's
# shape and cast to uint8 earlier.
superimposed_img = cv2.addWeighted(img_original, 0.6, heatmap, 0.4, 0)

In [None]:
# Side-by-side view of the original image, the colorized Grad-CAM and the overlay.
# OpenCV images are BGR and matplotlib expects RGB. Converting into new names,
# rather than rebinding img_original/superimposed_img, keeps the cell
# idempotent; a re-run used to flip the channels back.
# The true label comes straight from the dataset (batch 0 == samples[0]),
# instead of relying on whatever `_` happens to hold.
true_label = dataset.samples[0][1]
img_original_rgb = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
superimposed_rgb = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

panels = [
    (img_original_rgb, f"Original image ({dataset.classes[true_label]} class)"),
    (heatmap_rgb, "Grad-CAM"),
    (
        superimposed_rgb,
        f"Grad-CAM over the original image\n(predicted {dataset.classes[class_index]} class)",
    ),
]

fig, axs = plt.subplots(1, 3, figsize=(10, 5))
# The images are already RGB, and matplotlib ignores `cmap` for RGB data, so
# none is passed.
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Grad-CAM for the first test image of every class. Each row shows the original
# image, the raw CAM and the overlay with the predicted label.
# The work happens inside helpers so this cell does not overwrite globals
# (out, grads, img_path, ...) used by the single-image cells above.


def _grad_cam(input_tensor):
    """Compute a Grad-CAM map for the model's top prediction.

    Args:
        input_tensor: a single-image batch already on ``device``.

    Returns:
        (cam, pred_idx): ``cam`` is a 2-D float numpy array scaled to [0, 1] at
        the resolution of the model's captured activations, and ``pred_idx`` is
        the argmax class index.
    """
    model.zero_grad()
    logits = model(input_tensor)
    pred_idx = logits.argmax(dim=1).item()
    logits[0, pred_idx].backward()

    # assumes both tensors are (N, C, H, W). TODO confirm against ResNet34Model.
    grads = model.get_activations_gradient()
    acts = model.get_activations().detach()
    weights = grads.mean(dim=(2, 3), keepdim=True)  # one weight per channel
    cam = F.relu((weights * acts).sum(dim=1).squeeze())
    cam = cam / (cam.max() + 1e-8)  # epsilon: an all-zero map must not become NaN
    return cam.detach().cpu().numpy(), pred_idx


def _plot_gradcam_row(ax_row, class_name, sample_idx):
    """Draw the original image, Grad-CAM and overlay for one dataset sample.

    Args:
        ax_row: sequence of three matplotlib Axes.
        class_name: ground-truth class name shown in the first title.
        sample_idx: index into ``dataset``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the sample's image file.
    """
    input_tensor = dataset[sample_idx][0].unsqueeze(0).to(device)
    cam, pred_idx = _grad_cam(input_tensor)

    # cv2 dsize is (width, height). Taking it from the tensor the model actually
    # saw keeps the overlay aligned even for a non-square IMAGE_SIZE.
    size_wh = (input_tensor.shape[-1], input_tensor.shape[-2])
    cam = cv2.resize(cam, size_wh)

    sample_path = dataset.samples[sample_idx][0]
    img_bgr = cv2.imread(sample_path)
    if img_bgr is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"OpenCV could not read image: {sample_path}")
    img_rgb = cv2.cvtColor(cv2.resize(img_bgr, size_wh), cv2.COLOR_BGR2RGB)

    heatmap_bgr = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(img_rgb, 0.6, heatmap_rgb, 0.4, 0)

    ax_row[0].imshow(img_rgb)
    ax_row[0].set_title(f"True: {class_name}")
    ax_row[1].imshow(cam, cmap="jet")
    ax_row[1].set_title("Grad-CAM")
    ax_row[2].imshow(overlay)
    ax_row[2].set_title(f"Pred: {dataset.classes[pred_idx]}")
    for ax in ax_row:
        ax.axis("off")


# Index of the first sample of each class, in order of first appearance.
# The number of classes comes from the dataset instead of a hard-coded 4.
class_samples = {}
for idx, label in enumerate(dataset.targets):
    class_samples.setdefault(dataset.classes[label], idx)
    if len(class_samples) == len(dataset.classes):
        break

n_rows = len(class_samples)
# squeeze=False keeps `axes` 2-D even when there is only one class.
fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
for row, (class_name, idx) in enumerate(class_samples.items()):
    _plot_gradcam_row(axes[row], class_name, idx)

plt.tight_layout()
plt.show()