In [1]:
from PIL import Image
import torch
import torchvision
from prototypes.deeplearning.models import VitPrototype1Dropout
from prototypes.utility.data import DataLoader
from prototypes.utility.data import ProjectConfiguration

In [2]:
# Project-wide settings (data paths, metadata locations) are read from this JSON file.
CONFIG_PATH = "../config.json"
config = ProjectConfiguration(CONFIG_PATH)

In [3]:
# Resolve the training image directory and metadata file from the project config,
# then build the loader used to sample images below.
train_images_path = config.get_value("TRAIN_IMAGES_PATH")
train_metadata_path = config.get_value("TRAIN_METADATA")
data_loader = DataLoader(data_path=train_images_path, metadata_path=train_metadata_path)

In [4]:
# Draw the same number of positive (target=1) and negative (target=0) images with the
# same width/height settings (presumably resized to 128x128 by the loader — confirm).
# NOTE(review): if get_data samples randomly, seed it for reproducibility.
sample_kwargs = dict(n_sample=300, width=128, height=128)
cancer_images = data_loader.get_data(target=1, **sample_kwargs)
non_cancer_images = data_loader.get_data(target=0, **sample_kwargs)

In [5]:
import numpy as np

# Preprocess every cancer image with the transforms bundled with the ViT-B/16 SWAG E2E
# weights. Per torchvision's docs this is a resize + center-crop to 384x384 plus ImageNet
# normalisation. The results are stacked into one (N, 3, H, W) float tensor.
# The transform is built once, outside the loop. torch.stack avoids the numpy round-trip.
preprocess = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()
inputs = torch.stack([preprocess(Image.fromarray(image)) for image in cancer_images])

In [41]:
# Load the fine-tuned ViT, run one image through it, and collect the per-head attention
# weights of every encoder block.
# torchvision's EncoderBlock calls its nn.MultiheadAttention with need_weights=False, so a
# plain forward hook only ever sees the attention *output*. A forward pre-hook
# (with_kwargs=True, PyTorch >= 2.0) switches the weights on first.
model = VitPrototype1Dropout(n_classes=1)
model.load_state_dict(torch.load("../checkpoint_resnet50_mix_up/0.1.1_vit16Dropout_best.pt", weights_only=True))
model.eval()

attention_maps = []


def request_attention_weights(module, args, kwargs):
    """Forward pre-hook: make nn.MultiheadAttention also return un-averaged per-head weights."""
    kwargs["need_weights"] = True
    kwargs["average_attn_weights"] = False
    return args, kwargs


def get_attention_weights(module, args, output):
    """Forward hook: store the attention weights returned by nn.MultiheadAttention."""
    # output is (attn_output, attn_weights); the weights are (batch, heads, tokens, tokens)
    attention_maps.append(output[1])


hook_handles = []
for name, module in model.named_modules():
    if isinstance(module, torchvision.models.vision_transformer.EncoderBlock):
        hook_handles.append(
            module.self_attention.register_forward_pre_hook(request_attention_weights, with_kwargs=True)
        )
        hook_handles.append(module.self_attention.register_forward_hook(get_attention_weights))

try:
    with torch.no_grad():
        output = model(inputs[0].unsqueeze(0))
finally:
    # Remove the hooks so later forward passes (or re-running this cell) do not stack them.
    for hook_handle in hook_handles:
        hook_handle.remove()

# One array per encoder block, each (1, heads, tokens, tokens).
attention_maps = [att.cpu().numpy() for att in attention_maps]

In [46]:
# Reload the model and capture the output of encoder_layer_11.ln_1. In torchvision's
# EncoderBlock this is the token sequence fed into the last block's self-attention, so its
# attention weights can be recomputed by hand in the next cell.
model = VitPrototype1Dropout(n_classes=1)
model.load_state_dict(torch.load("../checkpoint_resnet50_mix_up/0.1.1_vit16Dropout_best.pt", weights_only=True))
model.eval()

# Start from an empty list. The earlier hook cell rebinds `attention_maps` to numpy arrays,
# and those stale entries would otherwise sit in front of the tensors captured here.
attention_maps = []


def store_ln1_output(module, args, output):
    """Forward hook: keep the LayerNorm output (the self-attention input) of the last block."""
    attention_maps.append(output)


# Keep the handle so the hook can be removed with `handle.remove()` when no longer needed.
handle = model.model.encoder.layers.encoder_layer_11.ln_1.register_forward_hook(store_ln1_output)

In [47]:
model.eval()

# Recompute the per-head attention weights of the last encoder block (encoder_layer_11) by hand.
# The input is the token sequence captured by the ln_1 hook, i.e. the self-attention input.
# The packed in-projection of nn.MultiheadAttention is then applied head by head.
last_attention = model.model.encoder.layers.encoder_layer_11.self_attention
num_heads = last_attention.num_heads   # k = 12 for ViT-B/16
hidden_dim = last_attention.embed_dim  # D = 768, fixed across the whole encoder
head_dim = hidden_dim // num_heads     # D_h = D / k = 64 (per-head dimension in the ViT paper)

with torch.no_grad():
    # Forward pass; the ln_1 hook appends the self-attention input of encoder_layer_11.
    output = model(inputs[0].unsqueeze(0))

    # Packed query/key/value projection:
    #   in_proj_weight: (3 * D, D), rows stacked as [W_q; W_k; W_v]
    #   in_proj_bias:   (3 * D,),   entries stacked as [b_q; b_k; b_v]
    # Inside each D-sized chunk, head h owns rows/entries h * D_h : (h + 1) * D_h,
    # so 2304 = 3 (q, k, v) x 12 heads x 64 (D_h).
    qkv_w = last_attention.in_proj_weight
    qkv_b = last_attention.in_proj_bias
    print(f"The shape of qkv weight matrix before reshaping is {qkv_w.shape}\n")
    print(f"The shape of qkv bias vector before reshaping is {qkv_b.shape}\n")

    # weights: (q/k/v, heads, d_head, d_hidden); biases: (q/k/v, heads, d_head)
    qkv_w = qkv_w.reshape(3, num_heads, head_dim, hidden_dim)
    qkv_b = qkv_b.reshape(3, num_heads, head_dim)
    print(f"The shape of qkv weight matrix after reshaping is {qkv_w.shape}\n")
    print(f"The shape of qkv bias matrix after reshaping is {qkv_b.shape}\n")

    # Per-component weights (heads, d_head, d_hidden) and biases (heads, d_head)
    q_w_12_heads, k_w_12_heads, v_w_12_heads = qkv_w[0], qkv_w[1], qkv_w[2]
    q_b_12_heads, k_b_12_heads, v_b_12_heads = qkv_b[0], qkv_b[1], qkv_b[2]
    print(f"Per-head query/key/value weights: {q_w_12_heads.shape}, biases: {q_b_12_heads.shape}\n")

    # Self-attention input for image 0: (1, tokens, d_hidden). Use the newest capture so
    # stale entries in `attention_maps` (e.g. from earlier runs) are never picked up.
    out_encoder_10 = attention_maps[-1][0].unsqueeze(0)

    # One softmax(q k^T / sqrt(D_h)) matrix per head, each (1, tokens, tokens).
    att_weights = []
    for i in range(num_heads):
        q = out_encoder_10 @ q_w_12_heads[i].T + q_b_12_heads[i]
        k = out_encoder_10 @ k_w_12_heads[i].T + k_b_12_heads[i]
        scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
        att_weights.append(torch.softmax(scores, dim=-1))

    # Sanity check against PyTorch's own implementation. The model is in eval mode, so no dropout.
    # Assumes batch_first=True, as torchvision's EncoderBlock constructs it.
    _, reference_weights = last_attention(
        out_encoder_10, out_encoder_10, out_encoder_10,
        need_weights=True, average_attn_weights=False,
    )
    torch.testing.assert_close(torch.cat(att_weights), reference_weights[0], rtol=1e-4, atol=1e-5)

In [49]:
# One recomputed attention matrix per head of encoder_layer_11.
head_count = len(att_weights)
head_count

In [48]:
# Shape of the first head's attention matrix: (batch, query tokens, key tokens).
att_weights[0].size()

In [50]:
# Head at index 11 (the 12th head), batch item 0, query row 0: the [CLS] token's
# attention over every token.
attention_map = att_weights[11][0][0]

In [51]:
# Number of tokens the [CLS] query attends over.
attention_map.size()

In [53]:
# attention_map has one entry per token: [CLS] followed by the image patches
# (577 = 1 + 24 * 24), so drop the [CLS] entry before reshaping to the patch grid.
patch_side = int(round((attention_map.numel() - 1) ** 0.5))
attention_map[1:].reshape(patch_side, patch_side)

In [54]:
# Side length of the patch grid: count only the patch tokens (all tokens minus [CLS]).
np.sqrt(attention_map.shape[-1] - 1)

In [None]:
import matplotlib.pyplot as plt

# Visualise how strongly the [CLS] token attends to each image patch, for one head of the
# last encoder block. Uses `att_weights` from the manual-attention cell: one
# (1, tokens, tokens) tensor per head of encoder_layer_11, with rows = queries and columns = keys.
layer_idx = 11  # att_weights only covers encoder_layer_11 (ViT-B/16 has 12 layers)
head_idx = 0    # which head of that layer to show

attention_map = att_weights[head_idx][0]  # (tokens, tokens)

# Row 0 is the [CLS] query; drop its attention to itself so only the patch tokens remain.
cls_attention = attention_map[0, 1:]
grid_side = int(round(cls_attention.numel() ** 0.5))  # 24 for 577 tokens
assert grid_side * grid_side == cls_attention.numel(), "patch tokens do not form a square grid"
cls_attention = cls_attention.reshape(grid_side, grid_side).cpu().numpy()

fig, ax = plt.subplots()
heatmap = ax.imshow(cls_attention, cmap='viridis')
ax.set_title(f'Layer {layer_idx + 1}, Head {head_idx + 1}')
fig.colorbar(heatmap, ax=ax)
plt.show()


In [None]:
# Score every preprocessed cancer image and count how many scores exceed 0.5.
# NOTE(review): the 0.5 threshold assumes the model returns probabilities. If
# VitPrototype1Dropout returns raw logits, apply torch.sigmoid first (or threshold at 0).
# NOTE: while the ln_1 hook is still registered, it also captures every batch below.
SCORING_BATCH_SIZE = 32  # keeps peak memory bounded vs. pushing all images at once
with torch.no_grad():
    outputs = torch.cat([model(batch) for batch in inputs.split(SCORING_BATCH_SIZE)])
(outputs > 0.5).sum()

In [None]:
# Binary label for inputs[0] from the single-output classifier (n_classes=1). An argmax over
# one output is always 0, so threshold the score instead.
# NOTE(review): assumes `output` is a probability; threshold at 0 if it is a raw logit.
ID2LABEL = {0: "non-cancer", 1: "cancer"}  # matches the targets used when sampling the data
prediction = (output[0] > 0.5).long()
print(f"Predicted class: {ID2LABEL[prediction.item()]} | {prediction.item()}")

In [None]:
import matplotlib.pyplot as plt

# Overlay each head's [CLS] -> patch attention from the last encoder block on the input image.
# Uses `att_weights` (one (1, tokens, tokens) tensor per head of encoder_layer_11) and
# `output` (the model's output for inputs[0]) from the manual-attention cell.
img = cancer_images[0]  # the image that was preprocessed into inputs[0]
img_h, img_w = img.shape[:2]
score = output.reshape(-1)[0].item()

# Layout: column 0 repeats the original image; columns 1-3 hold three heads per row.
fig, ax = plt.subplots(4, 4, figsize=(15, 15))

for i, head_weights in enumerate(att_weights):
    # Row 0 is the [CLS] query; drop its attention to itself to keep only the patches.
    cls_attention_map = head_weights[0, 0, 1:]
    grid_side = int(round(cls_attention_map.numel() ** 0.5))
    cls_attention_map = cls_attention_map.reshape(1, 1, grid_side, grid_side)

    # Resize the patch grid to the image size so it can be overlaid on it.
    cls_attention_map = torch.nn.functional.interpolate(
        cls_attention_map, size=(img_h, img_w), mode="bilinear", align_corners=False
    )[0, 0].cpu().numpy()

    row, col = i // 3, (i % 3) + 1
    ax[row, 0].imshow(img)
    ax[row, 0].axis('off')
    ax[row, 0].set_title(f'Original image (model output: {score:.3f})')

    ax[row, col].imshow(img)
    ax[row, col].imshow(cls_attention_map, cmap='jet', alpha=0.4)
    ax[row, col].axis('off')
    ax[row, col].set_title(f'Attention map - [Head: {i + 1}]')

plt.show()