In [None]:
# Fix working directory: move from the notebook's folder up to the project root so
# that `utils.images`, `checkpoints/` and `data/` resolve as relative paths.
# Guarded on the presence of `utils/` so that re-running this cell (or starting the
# kernel from the project root) does not keep climbing the directory tree.
import os

if not os.path.isdir("utils"):
    os.chdir("..")
os.getcwd()

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from utils.images import get_attention_maps, plot_attention_maps

In [None]:
# Load the entire pickled model object (architecture + weights), not just a state_dict.
# torch>=2.6 defaults to torch.load(weights_only=True), which refuses to unpickle a
# full nn.Module, so opt out explicitly. Unpickling can execute arbitrary code: only
# load checkpoints from a trusted source. The model's class must also be importable
# from the current working directory / sys.path for unpickling to succeed.
model_path = "checkpoints/20240531_205302.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(model_path, map_location=device, weights_only=False)

In [None]:
# Define constants
batch_size = 2
img_size = 224
num_channels = 3  # NOTE(review): not referenced anywhere in this notebook
seed = 0  # seeds the DataLoader shuffle so Restart & Run All draws the same batch

# Set the model to evaluation mode (layers such as dropout / batch-norm, if present,
# switch to their inference behavior)
model.eval()

# Define transformation pipeline for the dataset.
# Resize to a fixed (img_size, img_size) square -- this does not preserve aspect ratio.
# The Normalize statistics are the common ImageNet per-channel mean/std; presumably
# the checkpoint was trained with the same values (TODO confirm).
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the Imagenette dataset (train split). download=False, so the data must
# already exist under data/.
imagenette_dataset = datasets.Imagenette(root="data", split="train", download=False, transform=transform)

# Create a dataloader. A dedicated seeded generator makes the shuffle order (and
# therefore the batch drawn later in this notebook) reproducible across kernel restarts.
imagenette_dataloader = DataLoader(
    imagenette_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)

In [None]:
# Draw one shuffled batch from the dataloader and move both tensors onto the model's device.
images, labels = (tensor.to(device) for tensor in next(iter(imagenette_dataloader)))

# Run the model without building an autograd graph. get_attention_maps(model) only
# receives the model, so presumably this forward pass is what records the attention
# weights it reads back (TODO confirm in utils.images).
with torch.no_grad():
    outputs = model(images)

# Read the attention maps back from the model, then bring everything to host memory
# for plotting.
# NOTE(review): `images` are still Normalize()-d at this point; confirm that
# plot_attention_maps un-normalizes them, otherwise displayed colors will be off.
attention_maps = get_attention_maps(model)
images, attention_maps = images.cpu(), attention_maps.cpu()
plot_attention_maps(images, attention_maps, save_path=None, imshow_interpolation=None)