In [1]:
from torchvision.models import vit_b_16
import torch.nn as nn
import torch

from surgeon_pytorch import Extract, get_nodes

from PIL import Image
import torchvision.transforms as transforms
import os

import numpy as np

from matplotlib import pyplot as plt

from captum.attr import LayerGradCam, LayerAttribution

In [2]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform

In [3]:
model = vit_b_16(pretrained=True)
NUM_CLASSES = 2
model.heads.head = nn.Linear(768, NUM_CLASSES)
model.load_state_dict(torch.load("logs/vit_waterbirds.pth",weights_only=False,map_location=torch.device('cpu')))
model.eval()


model2 = vit_b_16(pretrained=True)
model2.heads.head = nn.Linear(768, NUM_CLASSES)
model2.load_state_dict(torch.load("logs/dfr_model.pth", weights_only=False, map_location=torch.device('cpu')))
model2.eval()



VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [4]:
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def tensorize(img_path):
    image = Image.open(img_path).convert('RGB')
    return preprocess(image).unsqueeze(0)

In [5]:
img = tensorize("notebooks/data/054.Blue_Grosbeak/Blue_Grosbeak_0002_36648.jpg")
img2 = tensorize("notebooks/data/001.Black_footed_Albatross/Black_Footed_Albatross_0007_796138.jpg")
out = model(img)
out2 = model(img2)
_, predicted_class = out.max(dim=1)
_, predicted_class2 = out2.max(dim=1)
print(predicted_class, predicted_class2)

tensor([0]) tensor([1])


In [6]:
def reshape_transform(tensor):
    tensor = tensor[:, 1:, :]
    seq_len = tensor.size(1)
    height = width = int(seq_len ** 0.5)
    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    # Permute to (batch_size, channels, height, width)
    result = result.permute(0, 3, 1, 2)
    return result

In [7]:
target_layers = [model.encoder.layers[-1].ln_1]

In [8]:
targets = [ClassifierOutputTarget(predicted_class.item())]
targets2 = [ClassifierOutputTarget(predicted_class2.item())]

In [9]:
cam = GradCAM(model=model, target_layers=target_layers,reshape_transform=reshape_transform)
cam2 = GradCAM(model=model2, target_layers=target_layers,reshape_transform=reshape_transform)

In [10]:
img2.shape

torch.Size([1, 3, 224, 224])

In [11]:
grayscale_cam = cam(input_tensor=img, targets=targets)

In [13]:
greyscale_cam2 = cam2(input_tensor=img, targets=targets)


AttributeError: 'NoneType' object has no attribute 'shape'

In [None]:
print(img.shape)
img_tensor = img[0,:,:,:]
img_np = img_tensor.permute(1, 2, 0).numpy()
img_np = np.clip(img_np, 0, 1)

# Overlay heatmap
visualization = show_cam_on_image(img_np, grayscale_cam[0, :], use_rgb=True)

# Display the result
plt.imshow(visualization)
plt.axis('off')
plt.show()

In [None]:
mlp_block = model.encoder.layers[-1].mlp

activations = {}
gradients = {}

def save_activation(name):
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

def save_gradient(name):
    def hook(module, input, output):
        gradients[name] = output[0].detach()
    return hook

mlp_block.register_forward_hook(save_activation("mlp_block"))
mlp_block.register_full_backward_hook(save_gradient("mlp_block"))

In [None]:
attention_module = model.encoder.layers[-1].self_attention

attention_weights = {}

def get_attention_weights(name):
    def hook(module, input, output):
        attn_output, attn_output_weights = output
        attention_weights[name] = attn_output_weights.detach()
    return hook

attention_module.register_forward_hook(get_attention_weights("attn_weights"))

In [None]:
out = model(img)


In [None]:
model.encoder.layers[-1]