In [28]:
%load_ext autoreload
%autoreload 2
# %env CUDA_VISIBLE_DEVICES=""

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=""


In [29]:
from torchsummary import summary

In [132]:
from vit import PatchEmbeddings,  ViTForClassfication
from data import prepare_data
from torchinfo import summary
import math
import torch

In [113]:
trainloader, testloader, classes = prepare_data(32)

Files already downloaded and verified
Files already downloaded and verified


In [114]:
for X, y in trainloader:
    break
print(X.shape, y.shape)

torch.Size([32, 3, 32, 32]) torch.Size([32])


In [115]:
config = {
    "patch_size": 4,  # Input image size: 32x32 -> 8x8 patches
    "hidden_size": 48,
    "num_hidden_layers": 5,
    "num_attention_heads": 7,
    "intermediate_size": 4 * 48, # 4 * hidden_size
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 32,
    "num_classes": 10, # num_classes of CIFAR10
    "num_channels": 3,
    "qkv_bias": True,
    "use_faster_attention": True,
}

In [116]:
from torch import nn
from einops.layers.torch import Rearrange
from torch import Tensor
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels = 3, patch_size = 8, emb_size = 128):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x

In [117]:
pe = PatchEmbedding(in_channels = 3, patch_size = 4, emb_size = 48)

summary(pe, input_data=X)


Layer (type:depth-idx)                   Output Shape              Param #
PatchEmbedding                           [32, 64, 48]              --
├─Sequential: 1-1                        [32, 64, 48]              --
│    └─Rearrange: 2-1                    [32, 64, 48]              --
│    └─Linear: 2-2                       [32, 64, 48]              2,352
Total params: 2,352
Trainable params: 2,352
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.08
Input size (MB): 0.39
Forward/backward pass size (MB): 0.79
Params size (MB): 0.01
Estimated Total Size (MB): 1.19

In [118]:
pes = PatchEmbeddings(config)
summary(pes, input_data=X)

Layer (type:depth-idx)                   Output Shape              Param #
PatchEmbeddings                          [32, 64, 48]              --
├─Conv2d: 1-1                            [32, 48, 8, 8]            2,352
Total params: 2,352
Trainable params: 2,352
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.82
Input size (MB): 0.39
Forward/backward pass size (MB): 0.79
Params size (MB): 0.01
Estimated Total Size (MB): 1.19

In [119]:
pe(X).shape

torch.Size([32, 64, 48])

In [120]:
model = ViTForClassfication(config)
model(X)[0].shape

torch.Size([32, 10])

In [163]:
summary(model, input_data=X)

Layer (type:depth-idx)                                  Output Shape              Param #
ViTForClassfication                                     [32, 10]                  --
├─Embeddings: 1-1                                       [32, 65, 48]              3,168
│    └─PatchEmbeddings: 2-1                             [32, 64, 48]              --
│    │    └─Conv2d: 3-1                                 [32, 48, 8, 8]            2,352
│    └─Dropout: 2-2                                     [32, 65, 48]              --
├─Encoder: 1-2                                          [32, 65, 48]              --
│    └─ModuleList: 2-3                                  --                        --
│    │    └─Block: 3-2                                  [32, 65, 48]              27,102
│    │    └─Block: 3-3                                  [32, 65, 48]              27,102
│    │    └─Block: 3-4                                  [32, 65, 48]              27,102
│    │    └─Block: 3-5                    

In [164]:
model.load_state_dict(torch.load('experiments/vit-with-10-epochs/model_final.pt'))

  model.load_state_dict(torch.load('experiments/vit-with-10-epochs/model_final.pt'))


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [147]:
out, att = model(X, output_attentions=True)
len(att), att[0].shape

(5, torch.Size([32, 7, 65, 65]))

In [148]:
attention_maps = torch.cat(att, dim=1)
attention_maps.shape

torch.Size([32, 35, 65, 65])

In [149]:
attention_maps = attention_maps[:, :, 0, 1:]
attention_maps.shape

torch.Size([32, 35, 64])

In [150]:
attention_maps = attention_maps.mean(dim=1)
attention_maps.shape

torch.Size([32, 64])

In [151]:
num_patches = attention_maps.size(-1)
size = int(math.sqrt(num_patches))
print(num_patches, size)

64 8


In [153]:
attention_maps = attention_maps.view(-1, size, size)
attention_maps.shape

torch.Size([32, 8, 8])

In [162]:
attention_maps[8]

tensor([[0.0154, 0.0153, 0.0153, 0.0154, 0.0153, 0.0154, 0.0154, 0.0154],
        [0.0154, 0.0153, 0.0154, 0.0154, 0.0154, 0.0153, 0.0154, 0.0154],
        [0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154],
        [0.0154, 0.0155, 0.0154, 0.0154, 0.0155, 0.0154, 0.0153, 0.0154],
        [0.0153, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0153, 0.0154],
        [0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154],
        [0.0154, 0.0154, 0.0154, 0.0154, 0.0153, 0.0154, 0.0154, 0.0154],
        [0.0153, 0.0154, 0.0153, 0.0154, 0.0153, 0.0154, 0.0154, 0.0153]],
       grad_fn=<SelectBackward0>)

In [144]:
attention_maps = attention_maps.unsqueeze(1)
attention_maps.shape


torch.Size([32, 1, 8, 8])

In [146]:
from torch.nn import functional as F
attention_maps = F.interpolate(attention_maps, size=(32, 32), mode='bilinear', align_corners=False)
attention_maps.shape

torch.Size([32, 1, 32, 32])