In [18]:
from data import imagenet as imgnt
from utils.evaluate import evaluate, evaluate_wrapped, Accuracy

import torch
import torch.nn as nn
import timm

import random
import inspect


from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torchvision
import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


from timm.models.vision_transformer import VisionTransformer


In [2]:
dataset_path = "/scratch_shared/primmere/ILSVRC/Data/CLS-LOC"
imagenet = imgnt.ImageNet(dataset_path, 1)
val = imagenet.get_valid_set()
device = torch.device("cuda")

model = timm.create_model('deit_tiny_patch16_224.fb_in1k', pretrained=True)
model.to(device)
model.eval()
next(model.parameters()).is_cuda 
val

Dataset ImageNetDataset
    Number of datapoints: 50000
    Root location: /scratch_shared/primmere/ILSVRC/Data/CLS-LOC/val
    Compose(
    ToTensor()
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [3]:
#indices = random.sample(range(50_000),1000)
indices = imgnt.get_sample_indices_for_class(val, list(range(10)), 10_000, device)
val_small = imgnt.ImageNetSubset(val,indices)


In [4]:
val_loader = torch.utils.data.DataLoader(val_small,8, pin_memory=True)
img, label = next(iter(val_loader))
img, label = img.to(device), label.to(device)

In [5]:
img.device



device(type='cuda', index=0)

In [7]:
label

tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

In [8]:
torch.argmax(model(img), dim=1)

tensor([  0,   0, 391,   0,   0,   0,   0,   0], device='cuda:0')

In [9]:
acc = Accuracy()
results = evaluate(model, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])

Evaluating: 100%|██████████| 63/63 [00:04<00:00, 15.66batch/s]

[[47  0  0 ...  1  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.806





In [10]:
L = len(model.blocks)
print(L)
d_l = model.blocks[0].attn.num_heads
print(d_l)
model.blocks[0].attn.head_dim

12
3


64

In [11]:
for p in model.parameters():
    p.requires_grad = False

In [12]:
model.blocks[0].attn??

[0;31mSignature:[0m       [0mmodel[0m[0;34m.[0m[0mblocks[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m.[0m[0mattn[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m            Attention
[0;31mString form:[0m    
Attention(
  (qkv): Linear(in_features=192, out_features=576, bias=True)
  (q_norm): Identity()
  (k_norm): Identity()
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=192, out_features=192, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
[0;31mFile:[0m            ~/.conda/envs/dfr2/lib/python3.10/site-packages/timm/models/vision_transformer.py
[0;31mSource:[0m         
[0;32mclass[0m [0mAttention[0m[0;34m([0m[0mnn[0m[0;34m.[0m[0mModule[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mfused_attn[0m[0;34m:[0m [0mFinal[0m[0;34m[[0m[0mbool[0m[0;34m][0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0;32mdef[0m [

In [26]:
class MaskedAttn(nn.Module):
    """
    Attn module with mask
    """
    def __init__(self, attn: nn.Module, num_classes: int):
        super().__init__()
        
        self.qkv = attn.qkv
        self.proj = attn.proj
        self.attn_drop = attn.attn_drop
        self.proj_drop = attn.proj_drop
        self.scale = attn.scale
        self.num_heads = attn.num_heads
        self.head_dim = attn.head_dim
        self.num_classes = num_classes
        self.mask = nn.Parameter(torch.ones(self.num_classes, self.num_heads, self.head_dim)) # (C, H, D)
    
        for p in self.qkv.parameters(): p.requires_grad = False
        for p in self.proj.parameters(): p.requires_grad = False

    def forward(self, x: torch.Tensor, y=None) -> torch.Tensor:
        B, N, C = x.shape # B = batch, N = num tokens (cls+patches), C = embed dim (head dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # (B, H, N, D)

        q = q * self.scale
        a = q @ k.transpose(-2, -1)
        a = a.softmax(dim=-1)
        a = self.attn_drop(a)
        o = a @ v

        if y is not None:
            M = self.mask[y] # (B, H, D)
            M = M.unsqueeze(2) # (B, H, 1, D)
        else:
            assert 1 == 2, "todo"
            
            #todo: should take avg over classes and multiply

        o = o * M

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
        

masked_attn = MaskedAttn(model.blocks[0].attn, 1000)
masked_attn
print(masked_attn)

class MaskedDeiT(nn.Module):
    """
    deit with masked attn
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.num_classes = model.num_classes

        self.masked_attn = nn.ModuleList()
        for blk in self.model.blocks:
            self.masked_attn.append(MaskedAttn(blk.attn, num_classes = self.num_classes))

        # turn off the original attn modules (we won’t call them)
        for blk in self.model.blocks:
            for p in blk.attn.parameters(): p.requires_grad = False


    def forward_features(self, x, y=None):
        B = x.shape[0] # batch size
        x = self.model.patch_embed(x)
        cls_tok = self.model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tok, x), dim=1) 
        x = x + self.model.pos_embed
        x = self.model.pos_drop(x)

        
        for i, blk in enumerate(self.model.blocks):
            attn_out = self.masked_attn[i](blk.norm1(x), y)
            x = x + blk.drop_path1(blk.ls1(attn_out))
            
            mlp_out = blk.mlp(blk.norm2(x))
            x = x + blk.drop_path2(blk.ls2(mlp_out))
            
        
        x = self.model.norm(x)
        return x[:, 0]

    def forward(self, x, y=None):
        x = self.forward_features(x, y)
        return self.model.head(x)


wrapped = MaskedDeiT(model)
wrapped.to(device)
wrapped.eval()


MaskedAttn(
  (qkv): Linear(in_features=192, out_features=576, bias=True)
  (proj): Linear(in_features=192, out_features=192, bias=True)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj_drop): Dropout(p=0.0, inplace=False)
)


MaskedDeiT(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
          (act): GELU(approximate='none')
   

In [14]:

model.blocks[0].drop_path1??


[0;31mSignature:[0m      [0mmodel[0m[0;34m.[0m[0mblocks[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m.[0m[0mdrop_path1[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m           Identity
[0;31mString form:[0m    Identity()
[0;31mFile:[0m           ~/.conda/envs/dfr2/lib/python3.10/site-packages/torch/nn/modules/linear.py
[0;31mSource:[0m        
[0;32mclass[0m [0mIdentity[0m[0;34m([0m[0mModule[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""A placeholder identity operator that is argument-insensitive.[0m
[0;34m[0m
[0;34m    Args:[0m
[0;34m        args: any argument (unused)[0m
[0;34m        kwargs: any keyword argument (unused)[0m
[0;34m[0m
[0;34m    Shape:[0m
[0;34m        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.[0m
[0;34m        - Output: :math:`(*)`, same shape as the input.[0m
[0;34m[0m
[0;34m    Examples::[0

In [20]:
acc = Accuracy()
results = evaluate(model, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])

Evaluating: 100%|██████████| 63/63 [00:04<00:00, 14.79batch/s]

[[47  0  0 ...  1  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.806





In [25]:
acc = Accuracy()
results = evaluate_wrapped(wrapped, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])




Evaluating: 100%|██████████| 63/63 [00:03<00:00, 15.79batch/s]

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 1 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
0.0





In [27]:
wrapped.to(device)
wrapped(img,y=label)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], device='cuda:0',
       grad_fn=<IndexBackward0>)
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.

tensor([[-2.7773e+00, -1.6326e+00, -1.3924e+00,  ..., -1.5742e+00,
          6.4230e-01,  1.4097e+00],
        [-5.6951e-01, -6.7556e-01, -1.8383e-01,  ...,  4.0156e-01,
         -4.9376e-01,  8.7062e-04],
        [ 7.3232e-01,  3.3133e-01,  5.8941e-01,  ..., -9.1473e-01,
          6.0892e-01,  1.8116e+00],
        ...,
        [-9.1319e-01, -9.5612e-01, -1.1950e+00,  ..., -1.3085e+00,
          1.0125e+00,  6.9102e-02],
        [-1.4980e+00, -5.9719e-01, -2.6913e+00,  ...,  1.5051e-01,
          2.6262e-01,  9.4503e-01],
        [-4.7712e-01, -8.7983e-01,  5.8697e-01,  ..., -8.5545e-01,
          2.8106e+00,  6.7254e-01]], device='cuda:0')

In [None]:
model(img).shape