In [1]:
from data import imagenet as imgnt
from utils.evaluate import evaluate, evaluate_wrapped, Accuracy

import torch
import torch.nn as nn
import timm

import random
import inspect


from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torchvision
import torchvision.transforms as T

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


from timm.models.vision_transformer import VisionTransformer

from models.deit import MaskedDeiT

In [2]:
# --- Device and data ---------------------------------------------------------
device = torch.device("cuda")

dataset_path = "/scratch_shared/primmere/ILSVRC/Data/CLS-LOC"
# NOTE(review): the meaning of the second positional argument (1) is defined in
# data.imagenet and is not visible from this notebook -- confirm there.
imagenet = imgnt.ImageNet(dataset_path, 1)
val, train = imagenet.get_valid_set(), imagenet.get_train_set()

# --- Pretrained DeiT-tiny, moved to the device and put in eval mode ----------
model = timm.create_model('deit_tiny_patch16_224.fb_in1k', pretrained=True).to(device).eval()

# --- Class-restricted subsets ------------------------------------------------
# Validation: class ids 0..9; training: class ids 8 and 9.
# NOTE(review): the third argument (10_000 / 100) is presumably a sample cap --
# verify against get_sample_indices_for_class.
val_indices = imgnt.get_sample_indices_for_class(val, list(range(10)), 10_000, device)
train_indices = imgnt.get_sample_indices_for_class(train, [8, 9], 100, device)
val_small = imgnt.ImageNetSubset(val, val_indices)
train_small = imgnt.ImageNetSubset(train, train_indices)

# --- Loaders (batch size 8, pinned host memory) ------------------------------
val_loader = torch.utils.data.DataLoader(val_small, batch_size=8, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_small, batch_size=8, pin_memory=True)

# Grab the first training batch and move both tensors to the device.
x, y = (t.to(device) for t in next(iter(train_loader)))

In [3]:
# Predicted class ids for the sample batch. Only the arg-max labels are needed,
# so run the forward pass without recording an autograd graph.
with torch.no_grad():
    preds = model(x).argmax(dim=1)
preds

tensor([38, 38, 38, 38, 38, 38, 38, 38], device='cuda:0')

In [4]:
# Baseline: score the unmodified model on the validation subset using a fresh
# metric object, then show the confusion matrix followed by overall accuracy.
acc = Accuracy()
results = evaluate(model, val_loader, acc, device)
print(results['confusion_matrix'], results['total_accuracy'], sep='\n')

Evaluating: 100%|██████████| 63/63 [00:04<00:00, 14.82batch/s]

[[47  0  0 ...  1  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.806





In [10]:
# Wrap the pretrained model in MaskedDeiT, move it to the device and switch it
# to eval mode before scoring.
wrapped = MaskedDeiT(model)
wrapped.to(device)
wrapped.eval()

# Same evaluation protocol as for the baseline, with a fresh metric object.
acc = Accuracy()
results = evaluate_wrapped(wrapped, val_loader, acc, device)
print(results['confusion_matrix'], results['total_accuracy'], sep='\n')

Evaluating: 100%|██████████| 63/63 [00:04<00:00, 13.11batch/s]

[[47  0  0 ...  1  0  0]
 [ 0 45  0 ...  0  0  0]
 [ 0  0 31 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]
0.806





In [12]:
# Only the `mask` attribute of each entry in `wrapped.masked_attn` is handed to
# the optimiser.
mask_params = [m.mask for m in wrapped.masked_attn]

optimiser = torch.optim.AdamW(mask_params, lr=1e-2, weight_decay=1e-4)
# `torch.cuda.amp.GradScaler(...)` is deprecated; the device-agnostic
# constructor takes the device type as its first argument.
scaler = torch.amp.GradScaler("cuda", enabled=True)

def one_epoch(model, loader, optimiser, device, mask_params):
    """Run one mixed-precision training pass over ``loader``.

    Relies on the notebook-level ``scaler`` (a ``GradScaler``) for loss
    scaling; it must be defined before this function is called.

    Args:
        model: Module invoked as ``model(imgs, y=ys)``; its output is fed to
            cross-entropy as logits.
        loader: Iterable yielding ``(imgs, ys)`` batches and exposing
            ``.dataset`` (its length is the averaging denominator).
        optimiser: Optimiser that is zeroed and stepped once per batch.
        device: Device each batch is moved to.
        mask_params: Parameters whose total gradient norm is clipped to 1.0.

    Returns:
        float: batch losses weighted by batch size, summed and divided by
        ``len(loader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    for imgs, ys in loader:
        imgs, ys = imgs.to(device), ys.to(device)
        optimiser.zero_grad(set_to_none=True)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", enabled=True):
            logits = model(imgs, y=ys)
            loss = torch.nn.functional.cross_entropy(logits, ys)

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them
        # first so that max_norm applies to the true gradient magnitudes.
        scaler.unscale_(optimiser)
        torch.nn.utils.clip_grad_norm_(mask_params, max_norm=1.0)
        # Step the optimiser that was passed in (the previous code referenced
        # an undefined name `optimizer` and raised NameError).
        scaler.step(optimiser)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)

    return running_loss / len(loader.dataset)
    

  scaler = torch.cuda.amp.GradScaler(enabled=True)


In [None]:
num_epochs = 2
for epoch in range(num_epochs):
    # NOTE(review): this trains on `val_loader` and then evaluates on the same
    # loader -- confirm that is intended rather than `train_loader`.
    loss = one_epoch(wrapped, val_loader, optimiser, device, mask_params)
    print(f"epoch {epoch + 1}/{num_epochs} - train loss: {loss:.4f}")

    # one_epoch() leaves the model in train mode; switch back before scoring,
    # as the earlier evaluation cell does.
    wrapped.eval()
    # Use a fresh metric object per evaluation, as in the earlier cells, so no
    # state can carry over between epochs.
    acc = Accuracy()
    results = evaluate_wrapped(wrapped, val_loader, acc, device)
    print(results['total_accuracy'])