In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import json
import glob
import tqdm

from PIL import Image
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# for model customization
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg, Attention, Block
from timm.models.layers import trunc_normal_

from torch.utils.data import Dataset, DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [1]:
from datasets import load_dataset

# If the dataset is gated/private, make sure you have run huggingface-cli login
dataset = load_dataset("evanarlian/imagenet_1k_resized_256", cache_dir='/home/data/IMNET256')

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Downloading data: 100%|██████████| 459M/459M [00:54<00:00, 8.40MB/s] 
Downloading data: 100%|██████████| 526M/526M [03:50<00:00, 2.29MB/s] 
Downloading data: 100%|██████████| 468M/468M [03:08<00:00, 2.49MB/s] 
Downloading data: 100%|██████████| 452M/452M [02:53<00:00, 2.60MB/s] 
Downloading data: 100%|██████████| 469M/469M [03:50<00:00, 2.04MB/s] 
Downloading data: 100%|██████████| 478M/478M [03:33<00:00, 2.23MB/s] 
Downloading data: 100%|██████████| 441M/441M [03:38<00:00, 2.02MB/s] 
Downloading data: 100%|██████████| 420M/420M [03:03<00:00, 2.29MB/s] 
Downloading data: 100%|██████████| 437M/437M [02:56<00:00, 2.48MB/s] 
Downloading data: 100%|██████████| 438M/438M [03:53<00:00, 1.88MB/s] 
Downloading data: 100%|██████████| 440M/440M [03:24<00:00, 2.16MB/s] 
Downloading data: 100%|██████████| 439M/439M [03:32<00:00, 2.06MB/s] 
Downloading data: 100%|██████████| 429M/429M [03:02<00:00, 2.36MB/s] 
Downloading data: 100%|██████████| 480M/480M [03:32<00:00, 2.26MB/s] 
Downloading data: 10

In [2]:
class CustomAttention(Attention):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        x = self.attn_drop(attn.softmax(dim=-1)) @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

class CustomBlock(Block):
    def __init__(
            self,
            **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        for kw in ['mlp_ratio', 'init_values', 'drop_path', 'act_layer', 'mlp_layer']:
            kwargs.pop(kw, None)
        self.attn = CustomAttention(**kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_, attn = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(x_))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x, attn

class DistillVisionTransformer(VisionTransformer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        attn = []
        for blk in self.blocks:
            x, a = blk(x)
            attn.append(a)
        attn = torch.stack(attn, dim=1)

        x = self.norm(x)
        return x[:, 0], x[:, 1], attn

    def forward(self, x):
        x, x_dist, attn = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist, attn
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

In [3]:
class ImageNetDataset(Dataset):
    def __init__(self):
        super(type(self), self).__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256, interpolation=3, antialias=True),
            transforms.CenterCrop(224),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

        self.image_list = glob.glob('/home/data/imagenet_val/images/*.JPEG')
        with open('/home/data/imagenet_val/imagenet_class_index.json') as fs:
            self.label_dict = json.load(fs)
        
    def __getitem__(self, index):
        img = self.image_list[index]

        label = img.split('_')[-1].split('.')[0]
        for k, v in self.label_dict.items():
            if v[0] != label:
                continue
            label = torch.tensor([int(k)])
            label = nn.functional.one_hot(label, 1000).float().view(-1)
            break

        img = cv2.imread(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.image_list)

dataset = ImageNetDataset()
dataset.__getitem__(0)

(tensor([[[ 2.2489,  2.2489,  2.2489,  ...,  0.4627,  0.4338,  0.4080],
          [ 2.2489,  2.2489,  2.2489,  ...,  0.5029,  0.5693,  0.5380],
          [ 2.2489,  2.2489,  2.2489,  ...,  0.4870,  0.5354,  0.5178],
          ...,
          [-0.2145, -0.3301, -0.2220,  ...,  0.3099,  0.2612,  0.2527],
          [-0.3632, -0.3537, -0.3418,  ...,  0.3785,  0.3223,  0.2573],
          [-0.4371, -0.4600, -0.4058,  ...,  0.3983,  0.3611,  0.3098]],
 
         [[ 2.4286,  2.4286,  2.4286,  ...,  0.9522,  0.9140,  0.8805],
          [ 2.4286,  2.4286,  2.4286,  ...,  1.0059,  1.0721,  1.0332],
          [ 2.4286,  2.4286,  2.4286,  ...,  1.0157,  1.0479,  1.0408],
          ...,
          [-0.0699, -0.1919, -0.0840,  ...,  0.3000,  0.2713,  0.2705],
          [-0.1953, -0.2136, -0.2055,  ...,  0.3539,  0.3318,  0.2793],
          [-0.2881, -0.3253, -0.2683,  ...,  0.3613,  0.3356,  0.3166]],
 
         [[ 2.6400,  2.6400,  2.6400,  ...,  0.1170,  0.1071,  0.1191],
          [ 2.6400,  2.6400,

In [4]:
history = {0:[], 1:[]}

In [51]:
#model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
deit_small_distilled_patch16_224 = DistillVisionTransformer(
    block_fn=CustomBlock, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6))
deit_small_distilled_patch16_224.default_cfg = _cfg()

dataset = ImageNetDataset()
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
optimizer = torch.optim.Adam(deit_small_distilled_patch16_224.parameters(), lr=0.0001)

In [9]:
deit_small_distilled_patch16_224.to(device)
deit_small_distilled_patch16_224.train()

tq = tqdm.tqdm(dataloader)
for i, (img, label) in enumerate(tq):
    img = img.to(device)
    label = label.to(device)

    out_c, out_s, attn = deit_small_distilled_patch16_224(img)
    loss_cls = nn.functional.cross_entropy(out_c, label)
    history[0].append(loss_cls.item())
    
    optimizer.zero_grad()
    loss_cls.backward()
    optimizer.step()

    tq.set_postfix(loss=loss_cls.item())
    if (i+1) == 500:
        break
torch.cuda.empty_cache()

  2%|▏         | 56/3125 [00:17<16:21,  3.13it/s, loss=7.2] 


KeyboardInterrupt: 

In [6]:
#model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
deit_small_patch16_224 = VisionTransformer(
    patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6))
deit_small_patch16_224.default_cfg = _cfg()
checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
deit_small_patch16_224.load_state_dict(checkpoint["model"])

#model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
deit_small_distilled_patch16_224 = DistillVisionTransformer(
    block_fn=CustomBlock, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6))
deit_small_distilled_patch16_224.default_cfg = _cfg()

dataset = ImageNetDataset()
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)
optimizer = torch.optim.Adam(deit_small_distilled_patch16_224.parameters(), lr=0.0001)

In [10]:
deit_small_patch16_224.to(device)
deit_small_patch16_224.eval()

deit_small_distilled_patch16_224.to(device)
deit_small_distilled_patch16_224.train()

alpha = 0.9

tq = tqdm.tqdm(dataloader)
for i, (img, label) in enumerate(tq):
    img = img.to(device)
    label = label.to(device)

    out_c, out_s, attn = deit_small_distilled_patch16_224(img)
    out_t = deit_small_patch16_224(img)
    loss_cls = nn.functional.cross_entropy(out_c, label)
    loss_dist = nn.functional.kl_div(out_c, out_t, reduction='batchmean', log_target=True)
    loss = alpha * loss_cls + (1 - alpha) * loss_dist
    history[1].append(loss_cls.item())
    break
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    tq.set_postfix(loss=loss_cls.item())
    if (i+1) == 100:
        break
torch.cuda.empty_cache()

  0%|          | 0/3125 [00:00<?, ?it/s]


OutOfMemoryError: HIP out of memory. Tried to allocate 346.00 MiB. GPU 0 has a total capacty of 11.98 GiB of which 124.00 MiB is free. Of the allocated memory 10.93 GiB is allocated by PyTorch, and 518.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_HIP_ALLOC_CONF

In [146]:
def stochastic_mask_selection(attn, t, k):
    '''
    attn: shape (N,N) for N patches
    t: soften temperature
    k: keep number
    '''
    attn = torch.softmax(attn.sum(dim=1)/t, dim=0).cumsum(dim=0)
    keep = attn.multinomial(num_samples=k, replacement=False)
    return keep


tensor([0, 1, 3])