In [None]:
# Fix working directory
# Move to the parent of the directory the kernel started in, so the relative
# paths used later in this notebook ("checkpoints/...", "data") resolve.
# The starting directory is recorded only once per kernel session, which keeps
# this cell idempotent: re-running it does not climb one level further each
# time, as a bare `%cd ..` would.
import os

if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
print(os.getcwd())

In [None]:
import torch
from datasets import imagenet
from utils.images import get_attention_maps, plot_attention_maps
from utils.system import calculate_num_workers

In [None]:
# Load the model
# The checkpoint is used below as a complete model object (it is called,
# switched to eval mode, and queried for attributes), i.e. it is a fully
# pickled module rather than a bare state_dict.
model_path = (
    "checkpoints/20240602-185420_vit-256emb-04layer-08head-08patch-16register_epoch-1_valacc-9.68.pth"
)
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=False is passed explicitly: PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle arbitrary classes and would
# fail on a whole-model checkpoint.
# SECURITY: full unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)

In [None]:
# Loader settings
num_workers = calculate_num_workers()
batch_size = 8

# Switch the model to evaluation mode before running it
model.eval()

# Build the training and validation dataloaders; the image size is taken from
# the loaded model so the inputs match what it expects.
loader_options = {
    "img_size": model.img_size,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "data_root": "data",
}
train_dataloader, val_dataloader = imagenet.prepare_dataloaders(**loader_options)

In [None]:
# Take the first batch from the training loader and move both tensors to the
# device the model lives on.
images, labels = (tensor.to(device) for tensor in next(iter(train_dataloader)))

# Run the model on the batch without tracking gradients
with torch.no_grad():
    outputs = model(images)

# Fetch the attention maps from the model after the forward pass and bring
# them, together with the images, back to the CPU for plotting.
attention_maps = get_attention_maps(model).cpu()
images = images.cpu()

# Draw the attention maps (nothing is written to disk since save_path is None)
plot_attention_maps(
    images,
    attention_maps,
    num_registers=model.num_registers,
    save_path=None,
    imshow_interpolation=None,
)