In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import imageio, warnings, os

from interpretation_methods import *
from utils.imagenet_seg_loader import ImagenetSegLoader
from utils.model_loaders import vit_base_patch16_224_dino, vit_base_patch16_224
from utils.input_arguments import get_arg_parser
from utils.saver import Saver
from utils.sideplot import side_plot
from utils.image_denorm import image_vizformat

# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the imported
# libraries — confirm that a blanket "ignore" is really wanted.
warnings.filterwarnings("ignore")
# Switch matplotlib to the non-interactive "agg" backend.
# NOTE(review): presumably figures will not render inline with this backend — confirm.
plt.switch_backend("agg")

In [2]:
# Run configuration read by the dataset / DataLoader construction cell.
data_path = "dataset/gtsegs_ijcv.mat"  # first positional argument given to ImagenetSegLoader
data_length = 100  # second positional argument given to ImagenetSegLoader
batch_size = 1  # DataLoader batch size
num_workers = 7  # number of DataLoader worker processes

In [3]:
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input images: resize to 224x224, convert to a tensor, then normalise each
# of the three channels with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Targets: resize to the same 224x224 grid using nearest-neighbour
# interpolation, and apply nothing else.
label_transform = transforms.Compose(
    [transforms.Resize(size=(224, 224), interpolation=Image.NEAREST)]
)

In [4]:
dataset = ImagenetSegLoader(
    data_path,
    data_length,
    transform=image_transform,
    target_transform=label_transform,
)

# The DataLoader is wrapped in tqdm immediately, so every pass over
# `dataloader` renders a progress bar.
dataloader = tqdm(
    DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
    )
)

# Pretrained ViT-B/16 (DINO variant per the loader name), moved to `device`.
model = vit_base_patch16_224_dino(pretrained=True).to(device)

  0%|          | 0/100 [00:00<?, ?it/s]

In [8]:
def model_patching(x):
    """Return the normalised token embeddings the global ``model`` produces for ``x``.

    Steps, in order: ``model.patch_embed`` -> ``model.pos_drop`` -> every
    module in ``model.blocks`` -> ``model.norm``.

    NOTE(review): no class token or positional embedding is added inside this
    function (the notebook's recorded output shape is ``[1, 196, 768]``);
    confirm whether ``patch_embed`` handles that or whether the omission is
    intentional.

    Args:
        x: Image batch accepted by ``model.patch_embed``.

    Returns:
        The output of ``model.norm`` applied after the last block.
    """
    tokens = model.pos_drop(model.patch_embed(x))
    for layer in model.blocks:
        tokens = layer(tokens)
    return model.norm(tokens)

In [7]:
def image_vizformat(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo channel normalisation on the first image of a batch for plotting.

    Inverts ``(x - mean) / std`` channel-wise, i.e. computes ``x * std + mean``,
    then moves channels last and converts to NumPy. The defaults match the
    ``Normalize(mean=0.5, std=0.5)`` used by this notebook's ``image_transform``,
    so calling ``image_vizformat(img)`` gives the same result as before.

    Only ``img[0]`` is converted; any further items in the batch are ignored.

    NOTE(review): this definition shadows the ``image_vizformat`` imported from
    ``utils.image_denorm`` at the top of the notebook — confirm which one is
    meant to be used.

    Args:
        img: Tensor indexed as ``(batch, channel, height, width)``.
        mean: Per-channel mean that was subtracted during normalisation.
        std: Per-channel std that the input was divided by during normalisation.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, channel)``.
    """
    sample = img[0]
    # Keep the statistics in the sample's own float dtype so no precision is
    # lost; integer inputs are promoted to float32 by the arithmetic below.
    stats_dtype = sample.dtype if sample.is_floating_point() else torch.float32
    mean_t = torch.as_tensor(mean, dtype=stats_dtype, device=sample.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stats_dtype, device=sample.device).view(-1, 1, 1)

    restored = sample * std_t + mean_t
    # CHW -> HWC, the layout matplotlib expects for colour images.
    return restored.permute(1, 2, 0).detach().cpu().numpy()

In [6]:
# Cache every batch in memory so later cells can index samples directly.
imgs = []
masks = []

# `dataloader` is already wrapped in tqdm where it is created, so iterate it
# directly: wrapping it a second time rendered two progress bars for one pass.
for d in dataloader:
    imgs.append(d[0])   # image tensor of the batch
    masks.append(d[1])  # matching segmentation target



100%|██████████| 100/100 [00:45<00:00,  2.22it/s]
100it [00:40,  2.50it/s]


In [85]:
# The model was moved with `.to(device)`, but the cached batches in `imgs` are
# stored exactly as the DataLoader yielded them; move the sample to the same
# device so the forward pass also works when CUDA is available.
output = model(imgs[3].to(device))
# Index of the largest value in the model output.
torch.argmax(output)

tensor(166)

In [86]:
# Token embeddings for sample 19; the input is moved to `device` first so it
# sits on the same device as the model's parameters.
embeddings = model_patching(imgs[19].to(device))

Shape: torch.Size([1, 196, 768])


In [88]:
# Keep a handle on the first entry of `model.blocks` for inspection.
block0 = model.blocks[0]

In [89]:
# Display the module tree of `block0`.
block0

Block(
  (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=768, out_features=2304, bias=True)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=768, out_features=768, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
)

In [97]:
# Scale vectors of the two LayerNorms in block 0.
b0w0, b0w1 = block0.norm1.weight, block0.norm2.weight

# Attention map as returned by the attention module's accessor.
# NOTE(review): presumably this is only populated after a forward pass has
# gone through `block0.attn` — confirm against the Attention implementation.
b0a = block0.attn.get_attention_map()

# MLP weights: fc1 as-is, fc2 passed through the MLP's activation module.
# NOTE(review): the activation is applied to the weight matrix itself, not to
# activations — confirm this is the intended quantity.
b0w2 = block0.mlp.fc1.weight
b0w3 = block0.mlp.act(block0.mlp.fc2.weight)

In [100]:
# Print one shape per line, in the order b0w3, b0w2, b0w1, b0a, b0w0.
for tensor in (b0w3, b0w2, b0w1, b0a, b0w0):
    print(tensor.shape)

torch.Size([768, 3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([1, 12, 196, 196])
torch.Size([768])


In [106]:
# Matrix product of the two MLP weight tensors (b0w3 @ b0w2), then applied to
# the norm2 scale vector b0w1.
coeff = (b0w3 @ b0w2) @ b0w1

In [15]:
# Manually step sample 0 through the start of the network:
# patch_embed -> pos_drop -> block-0 norm1 -> block-0 attention -> block-0 norm2.
# The input is moved to `device` so it sits with the model's parameters.
first_block = model.blocks[0]
tokens = model.pos_drop(model.patch_embed(imgs[0].to(device)))

# `x` is the raw output of the attention module.
# NOTE(review): no residual connection is added here, so `x` may differ from
# what the block's own forward pass feeds into norm2 — confirm this is intended.
x = first_block.attn(first_block.norm1(tokens))
y = first_block.norm2(x)

Block(
  (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=768, out_features=2304, bias=True)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=768, out_features=768, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
)

In [17]:
# Inspect the shape of `x`.
x.shape

torch.Size([1, 196, 768])

In [18]:
# Inspect the shape of `y`.
y.shape

torch.Size([1, 196, 768])