In [1]:
from data import imagenet as imgnt
from utils.evaluate import evaluate, evaluate_wrapped, Accuracy
from x_pruner.x_loss import XPrunerLoss


import torch
import torch.nn as nn
import timm

import random
import inspect


from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torchvision
import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


from timm.models.vision_transformer import VisionTransformer

from models.deit import MaskedDeiT

In [2]:
dataset_path = "/scratch_shared/primmere/ILSVRC/Data/CLS-LOC"
imagenet = imgnt.ImageNet(dataset_path, 1)
val = imagenet.get_valid_set()
train = imagenet.get_train_set()
device = torch.device("cuda")

model = timm.create_model('deit_tiny_patch16_224.fb_in1k', pretrained=True)
model.to(device)
model.eval()

n_classes = 100

val_indices = imgnt.get_sample_indices_for_class(val, list(range(n_classes)), 10_000, device)
train_indices = imgnt.get_sample_indices_for_class(train, list(range(8,10)), 100, device)
val_small = imgnt.ImageNetSubset(val,val_indices)
train_small = imgnt.ImageNetSubset(train,train_indices)

val_loader = torch.utils.data.DataLoader(val_small,8, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_small,8, pin_memory=True)

x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

In [3]:
torch.argmax(model(x), dim=1)

tensor([38, 38, 38, 38, 38, 38, 38, 38], device='cuda:0')

In [4]:
wrapped = MaskedDeiT(model)
wrapped.to(device)
wrapped.eval()

acc = Accuracy()
results = evaluate_wrapped(wrapped, val_loader, acc, device)
print(results['confusion_matrix'])
print(results['total_accuracy'])

Evaluating: 100%|██████████| 625/625 [00:40<00:00, 15.29batch/s]

[[47  0  0 ...  0  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.7612





In [5]:
mask_params = [m.mask for m in wrapped.masked_attn]
for p in wrapped.model.parameters():
    p.requires_grad = False

optimiser = torch.optim.AdamW(mask_params, lr=1e-2, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=True)
criterion = XPrunerLoss()
def one_epoch(model, loader, optimiser, device, mask_params,  criterion):
    model.train()
    running_loss = 0.0
    for imgs, ys in loader:
        imgs, ys = imgs.to(device), ys.to(device)
        optimiser.zero_grad(set_to_none=True)
        
        with torch.cuda.amp.autocast(enabled=True):
            logits = model(imgs, y=ys)
            loss, losses_d = criterion(logits, ys, mask_params)
            
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(mask_params, max_norm=1.0)
        scaler.step(optimiser)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)
        
    return running_loss / len(loader.dataset)


num_epochs = 10
for epoch in range(num_epochs):
    loss = one_epoch(wrapped, val_loader, optimiser, device, mask_params, criterion)
    print(epoch, loss)
    if (epoch%5==0 and epoch != 0) or epoch==num_epochs-1:
        results = evaluate_wrapped(wrapped, val_loader, acc, device)
        print("accuracy", results['total_accuracy'])

print("done")
    

  scaler = torch.cuda.amp.GradScaler(enabled=True)
  with torch.cuda.amp.autocast(enabled=True):


0 2.909858106613159


Evaluating: 100%|██████████| 625/625 [00:36<00:00, 17.31batch/s]
  with torch.cuda.amp.autocast(enabled=True):


accuracy 0.9828
1 0.8089711599588394
2 0.14656420004963874
3 0.08704657775461674
4 0.20367838396430016
5 0.06729550706297159


Evaluating: 100%|██████████| 625/625 [00:36<00:00, 17.33batch/s]
  with torch.cuda.amp.autocast(enabled=True):


accuracy 0.9914
6 0.06830279025733471
7 0.042223666368424895
8 0.04626942673325539
9 0.07798429832756519


Evaluating: 100%|██████████| 625/625 [00:37<00:00, 16.46batch/s]

accuracy 0.995
done





In [8]:
wrapped.train()
trainables = [n for n,p in wrapped.named_parameters() if p.requires_grad]
print("Trainables:", trainables)

Trainables: ['masked_attn.0.mask', 'masked_attn.1.mask', 'masked_attn.2.mask', 'masked_attn.3.mask', 'masked_attn.4.mask', 'masked_attn.5.mask', 'masked_attn.6.mask', 'masked_attn.7.mask', 'masked_attn.8.mask', 'masked_attn.9.mask', 'masked_attn.10.mask', 'masked_attn.11.mask']


In [7]:
wrapped.masked_attn[0].mask[0,2,:]

tensor([0.0908, 0.0793, 0.1752, 0.1471, 0.4191, 0.1060, 0.2359, 0.1470, 0.0501,
        0.1569, 0.1021, 0.1820, 0.0720, 0.1379, 0.1273, 0.3824, 0.3415, 0.1866,
        0.3188, 0.1079, 0.0080, 0.1638, 0.2152, 0.0860, 0.0944, 0.2406, 0.3069,
        0.1688, 0.4972, 0.1870, 0.2843, 0.3279, 0.3047, 0.1171, 0.1442, 0.3523,
        0.1132, 0.2627, 0.1523, 0.2019, 0.2042, 0.6033, 0.0898, 0.1042, 0.1890,
        0.1392, 0.2131, 0.6359, 0.1650, 0.0735, 0.3048, 0.1388, 0.1624, 0.1475,
        0.0743, 0.2087, 0.2120, 0.5012, 0.1749, 0.2162, 0.1307, 0.1466, 0.4093,
        0.3833], device='cuda:0', grad_fn=<SliceBackward0>)