In [None]:
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision import models
from utils.data import get_example_image

In [None]:
# Pull one example from the project's data helper and preview it alongside its label.
# The helper returns a (tensor, img, label) triple; `tensor` is what later cells feed to
# the model, `img` is what gets drawn here.
tensor, img, label = get_example_image(650)

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_axis_off()
ax.set_title(f"Label: {label}")
plt.show()

In [None]:
# Rebuild the ResNet-18 architecture with a NUM_CLASSES-way head and load the fine-tuned
# checkpoint into it, then switch to inference mode.
NUM_CLASSES = 10
CHECKPOINT_PATH = "../../assets/models/finetuned_resnet18.pth"

# No pretrained weights are requested: load_state_dict (strict by default) overwrites every
# parameter and buffer below, so fetching ImageNet weights first would only cost a download.
model = models.resnet18(weights=None)

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)  # replace the 1000-way head before loading

# map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only kernel.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
model.eval();  # trailing ';' suppresses the long module repr

In [None]:
# Capture the output of each residual stage via forward hooks during one inference pass.
features = {}

def save_features(name):
    """Build a forward hook that stores a module's output in ``features[name]``.

    Args:
        name: Key under which the detached output tensor is stored.

    Returns:
        A callable with the ``(module, inputs, output)`` signature accepted by
        ``nn.Module.register_forward_hook``.
    """
    def hook(module, inputs, output):
        # store the activation without autograd history
        features[name] = output.detach()
    return hook

layers = ["layer1", "layer2", "layer3", "layer4"]

# Keep the handles and remove the hooks once the pass is done. Otherwise every re-run of
# this cell stacks another set of hooks on the same model, and they stay attached for
# every later forward pass.
hook_handles = [
    getattr(model, name).register_forward_hook(save_features(name)) for name in layers
]
try:
    with torch.no_grad():
        _ = model(tensor.unsqueeze(0))  # add batch dimension
finally:
    for handle in hook_handles:
        handle.remove()

channel_idx = 0  # channel to visualize

# One panel per captured stage, showing the chosen channel of the first batch element.
fig = plt.figure(figsize=(14, 3))
n_panels = len(layers)
for col, name in enumerate(layers, start=1):
    fmap = features[name][0, channel_idx]
    ax = fig.add_subplot(1, n_panels, col)
    ax.imshow(fmap.cpu(), cmap="viridis")
    ax.set_title(name)
    ax.axis("off")

plt.show()