In [1]:
from data import imagenet as imgnt
from utils.evaluate import evaluate, evaluate_wrapped, Accuracy

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

import random
import inspect


from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torchvision
import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


from timm.models.vision_transformer import VisionTransformer

from models.deit import MaskedDeiT as MD

In [2]:
dataset_path = "/scratch_shared/primmere/ILSVRC/Data/CLS-LOC"
imagenet = imgnt.ImageNet(dataset_path, 1)
val = imagenet.get_valid_set()
device = torch.device("cuda")

model = timm.create_model('deit_tiny_patch16_224.fb_in1k', pretrained=True)
model.to(device)
model.eval()
next(model.parameters()).is_cuda 
val

Dataset ImageNetDataset
    Number of datapoints: 50000
    Root location: /scratch_shared/primmere/ILSVRC/Data/CLS-LOC/val
    Compose(
    ToTensor()
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

In [3]:
#indices = random.sample(range(50_000),1000)
indices = imgnt.get_sample_indices_for_class(val, list(range(10)), 10_000, device)
val_small = imgnt.ImageNetSubset(val,indices)


In [4]:
val_loader = torch.utils.data.DataLoader(val_small,4, pin_memory=True)
img, label = next(iter(val_loader))
img, label = img.to(device), label.to(device)

In [5]:
label.type()

'torch.cuda.LongTensor'

In [6]:
img.device



device(type='cuda', index=0)

In [7]:
label

tensor([0, 0, 0, 0], device='cuda:0')

In [8]:
torch.argmax(model(img), dim=1)

tensor([  0,   0, 391,   0], device='cuda:0')

In [9]:
acc = Accuracy()
results = evaluate(model, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])

Evaluating: 100%|██████████| 125/125 [00:03<00:00, 33.90batch/s]

[[47  0  0 ...  1  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.806





In [10]:
L = len(model.blocks)
print(L)
d_l = model.blocks[0].attn.num_heads
print(d_l)
model.blocks[0].attn.head_dim

12
3


64

In [11]:
for p in model.parameters():
    p.requires_grad = False

In [12]:
"""
target = model.blocks[0].attn.qkv
handle = target.register_forward_hook(
    lambda m, inp, out: setattr(m, "saved_input", inp[0].detach())
)

pred = model(img)

x = target.saved_input
"""

In [13]:
model.blocks[0].attn??

[0;31mSignature:[0m       [0mmodel[0m[0;34m.[0m[0mblocks[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m.[0m[0mattn[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m            Attention
[0;31mString form:[0m    
Attention(
  (qkv): Linear(in_features=192, out_features=576, bias=True)
  (q_norm): Identity()
  (k_norm): Identity()
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=192, out_features=192, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
[0;31mFile:[0m            ~/.conda/envs/dfr2/lib/python3.10/site-packages/timm/models/vision_transformer.py
[0;31mSource:[0m         
[0;32mclass[0m [0mAttention[0m[0;34m([0m[0mnn[0m[0;34m.[0m[0mModule[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mfused_attn[0m[0;34m:[0m [0mFinal[0m[0;34m[[0m[0mbool[0m[0;34m][0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0;32mdef[0m [

In [16]:
class MaskedAttn(nn.Module):
    """
    Attn module with mask
    """
    def __init__(self, attn: nn.Module, num_classes: int):
        super().__init__()
        
        self.qkv = attn.qkv
        self.proj = attn.proj
        self.attn_drop = attn.attn_drop
        self.proj_drop = attn.proj_drop
        self.scale = attn.scale
        self.num_heads = attn.num_heads
        self.head_dim = attn.head_dim
        self.num_classes = 10
        self.mask_attn = nn.Parameter(torch.ones(self.num_classes, self.num_heads, self.head_dim)) # (C, H, D)
        out_proj, in_proj = self.proj.weight.shape
        self.mask_proj = nn.Parameter(torch.ones(self.num_classes, out_proj, in_proj)) # (C, out_proj, in_proj)

        for p in self.qkv.parameters(): p.requires_grad = False
        for p in self.proj.parameters(): p.requires_grad = False

    def forward(self, x: torch.Tensor, y=None) -> torch.Tensor:
        B, N, C = x.shape # B = batch, N = num tokens (cls+patches), C = embed dim (head dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # (B, H, N, D)

        q = q * self.scale
        a = q @ k.transpose(-2, -1)
        a = a.softmax(dim=-1)
        a = self.attn_drop(a)
        o = a @ v

        if y is not None:
            M = self.mask_attn[y] # (B, H, D)
            M = M.unsqueeze(2) # (B, H, 1, D)
            M_proj = self.mask_proj[y]
        else:
            assert 1 == 2, "todo"
            
            #todo: should take avg over classes and multiply

        o = o * M

        x = o.transpose(1, 2).reshape(B, N, C)

        W_proj = self.proj.weight.unsqueeze(0)*M_proj # B, out, in
        x = torch.einsum('bni,boi->bno', x, W_proj) + self.proj.bias.unsqueeze(0).unsqueeze(0) # B, N, C
        # einsum bc we have weights with batch dim, so need batch mat mul

        x = self.proj_drop(x)

        return x



class MaskedMlp(nn.Module):
    """
    MLP block with mask
    """
    def __init__(self, mlp: nn.Module, num_classes: int):
        super().__init__()

        self.fc1 = mlp.fc1
        self.act = mlp.act
        self.drop1 = mlp.drop1
        self.norm = mlp.norm
        self.fc2 = mlp.fc2
        self.drop2 = mlp.drop2

        for p in self.fc1.parameters(): p.requires_grad = False
        for p in self.fc2.parameters(): p.requires_grad = False
        
        self.num_classes = 10

        out1, in1 = self.fc1.weight.shape
        out2, in2 = self.fc2.weight.shape

        self.mask_fc1 = nn.Parameter(torch.ones(self.num_classes, out1, in1))
        self.mask_fc2 = nn.Parameter(torch.ones(self.num_classes, out2, in2))


    def forward(self, x: torch.Tensor, y=None) -> torch.Tensor:
        # x = B,N,C
        # y = B
        if y is not None:
            M_fc1 = self.mask_fc1[y] #B, out, in
            M_fc2 = self.mask_fc2[y] #B, out, in
        else:
            assert 1 == 2, "todo for inference"
            
        W1 = self.fc1.weight.unsqueeze(0) * M_fc1 #B, out1, in1 (unsqeeze to add batch dim to fc1 weights)
        b1 = self.fc1.bias # out1
        x = torch.einsum('bni,boi->bno', x, W1) + b1.unsqueeze(0).unsqueeze(0) #B, N, out1
        # einsum bc we have weights with batch dim, so need batch mat mul
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        W2 = self.fc2.weight.unsqueeze(0) * M_fc2 #B, out2, in2 (=out1)
        b2 = self.fc2.bias
        x = torch.einsum('bni,boi->bno', x, W2) + b2.unsqueeze(0).unsqueeze(0)
        x = self.drop2(x)
        return x # B, N, C

class MaskedDeiT(nn.Module):
    """
    deit with masked attn
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.num_classes = model.num_classes

        self.masked_attn = nn.ModuleList()
        self.masked_mlp = nn.ModuleList()
        for blk in self.model.blocks:
            self.masked_attn.append(MaskedAttn(blk.attn, num_classes = self.num_classes))
            self.masked_mlp.append(MaskedMlp(blk.mlp, num_classes = self.num_classes))
            

        # turn off original modules
        for blk in self.model.blocks:
            for p in blk.attn.parameters(): p.requires_grad = False
            for p in blk.mlp.parameters(): p.requires_grad = False


    def forward_features(self, x, y=None):
        B = x.shape[0] # batch size
        x = self.model.patch_embed(x)
        cls_tok = self.model.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tok, x), dim=1) 
        x = x + self.model.pos_embed
        x = self.model.pos_drop(x)

        
        for i, blk in enumerate(self.model.blocks):
            attn_out = self.masked_attn[i](blk.norm1(x), y)
            x = x + blk.drop_path1(blk.ls1(attn_out))
            
            mlp_out = self.masked_mlp[i](blk.norm2(x), y)
            x = x + blk.drop_path2(blk.ls2(mlp_out))
            
        
        x = self.model.norm(x)
        return x[:, 0]

    def forward(self, x, y=None):
        x = self.forward_features(x, y)
        return self.model.head(x)
    

masked_attn = MaskedAttn(model.blocks[0].attn, 1000)
masked_attn
print(masked_attn)



wrapped = MaskedDeiT(model)
wrapped.to(device)
wrapped.eval()



    
print(wrapped(img,label))
print(model(img))

MaskedAttn(
  (qkv): Linear(in_features=192, out_features=576, bias=True)
  (proj): Linear(in_features=192, out_features=192, bias=True)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
tensor([[ 7.2938e+00,  1.3162e+00,  2.5841e+00,  ...,  1.7861e+00,
         -5.9208e-01,  8.4107e-02],
        [ 9.4730e+00,  2.1688e+00, -7.6427e-01,  ...,  2.9838e+00,
          1.5964e+00, -1.3424e+00],
        [ 6.0448e+00, -3.7651e-01,  1.1977e+00,  ...,  4.6082e-01,
          6.4149e-01, -1.8328e-04],
        [ 9.1301e+00,  1.2231e+00, -4.9734e-01,  ...,  3.2658e+00,
          9.5074e-01, -9.3547e-01]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[ 7.2938e+00,  1.3162e+00,  2.5840e+00,  ...,  1.7861e+00,
         -5.9208e-01,  8.4104e-02],
        [ 9.4730e+00,  2.1688e+00, -7.6427e-01,  ...,  2.9838e+00,
          1.5964e+00, -1.3424e+00],
        [ 6.0448e+00, -3.7651e-01,  1.1977e+00,  ...,  4.6081e-01,
          6.4148e-01, -1.8060e-04],
      

In [15]:
with torch.no_grad():
    out_masked = masked_attn(x, label)
    out_ref    = model.blocks[0].attn(x)

print(torch.allclose(out_masked, out_ref, rtol=1e-5, atol=1e-6))
print("max abs diff:", (out_masked - out_ref).abs().max().item())

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

In [None]:
model??

In [None]:
acc = Accuracy()
results = evaluate(model, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])

In [None]:
acc = Accuracy()
results = evaluate_wrapped(wrapped, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])




In [None]:
wrapped.to(device)
wrapped(img,y=label)

In [None]:
model(img)

In [None]:
wrapped2 = MD(model)
wrapped2.to(device)
wrapped2.eval()
wrapped2(img,label)

In [None]:
mask_params = [m.mask for m in wrapped2.masked_attn]
