In [None]:
# upload your kaggle.json before running
# install dependencies, download and unzip the dataset
!pip install torchinfo
!pip install torchmetrics
!pip install kaggle
!pip install efficientnet_pytorch
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d a2015003713/militaryaircraftdetectiondataset
!unzip militaryaircraftdetectiondataset.zip

In [None]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# visualise random image from dataset
import random
from PIL import Image
from pathlib import Path

data_path = Path("crop/")

# Images are laid out as <split_or_crop>/<class_name>/<file>.jpg.
image_path_list = list(data_path.glob("*/*.jpg"))

if not image_path_list:
  # The train/test split cell moves every image out of crop/, so after a previous
  # full run the images live under train/<class>/ and test/<class>/ instead.
  image_path_list = [path
                     for split in ("train", "test")
                     for path in Path(split).glob("*/*.jpg")]

if not image_path_list:
  raise FileNotFoundError("No .jpg images found under crop/, train/ or test/ - download and unzip the dataset first.")

random_image_path = random.choice(image_path_list)

# The class label is the name of the folder containing the image.
image_class = random_image_path.parent.stem

img = Image.open(random_image_path)

print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img

In [None]:
# split the data into train and test
from typing_extensions import TypeVarTuple
import os
import shutil
from sklearn.model_selection import train_test_split

# Fixed seed + sorted file lists make the split reproducible across runs and machines.
SPLIT_SEED = 42
TEST_FRACTION = 0.2

data_path = Path("crop")
test_path = Path("test")
train_path = Path("train")

os.makedirs(train_path, exist_ok=True)
os.makedirs(test_path, exist_ok=True)

# Each sub-directory of crop/ is one class; ignore stray files and sort for a stable order.
class_names = sorted(entry.name for entry in data_path.iterdir() if entry.is_dir())

for class_name in class_names:
  os.makedirs(train_path/class_name, exist_ok=True)
  os.makedirs(test_path/class_name, exist_ok=True)

  images = sorted(image for image in os.listdir(data_path/class_name) if image.endswith(".jpg"))

  # Images are *moved* (not copied) below, so on a re-run the class folder is already
  # empty; skip it instead of letting train_test_split fail on an empty list.
  if not images:
    continue

  train_images, test_images = train_test_split(images,
                                               test_size=TEST_FRACTION,
                                               random_state=SPLIT_SEED)

  for split_images, split_path in ((train_images, train_path), (test_images, test_path)):
    for image in split_images:
      shutil.move(data_path/class_name/image, split_path/class_name)

In [None]:
from torchvision import transforms

IMSIZE = 224

data_transform = transforms.Compose([
    transforms.Resize(size=(IMSIZE, IMSIZE)),
    transforms.ToTensor()
])

In [None]:
import os
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Each sub-folder of train/ and test/ is treated as one class by ImageFolder.
train_data = ImageFolder(root=train_path,
                         transform=data_transform,
                         target_transform=None)

test_data = ImageFolder(root=test_path,
                        transform=data_transform,
                        target_transform=None)

# Labels are only comparable between the two splits if both map class names to the same indices.
assert train_data.class_to_idx == test_data.class_to_idx, "train/ and test/ must contain the same class folders"

BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to
# loading in the main process rather than passing None to DataLoader.
NUM_WORKERS = os.cpu_count() or 0

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS)

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=NUM_WORKERS)

class_labels = train_data.classes
num_classes = len(class_labels)

**ViT Equation 1**

$\bf{z_0} = [x_{class}; x^{1}_pE; x^{2}_pE;...;x^{N}_pE] + E_{pos}$

$E \in \mathbb{R}^{(P^{2}\cdot C)\times D}$   

$E_{pos} \in \mathbb{R}^{(N+1)\times D}$

Reshape the image $x \in \mathbb{R}^{H \times W \times C}$ (H = height, W = width, C = number of colour channels) into a sequence of flattened 2D patches $\bf x_p \in \mathbb{R}^{N \times (P^{2} \cdot C)}$, where each patch is $P \times P$ pixels and $\bf N = \frac{HW}{P^{2}}$ is the resulting number of patches. The flattened patches are then mapped to D dimensions with the trainable linear projection $E$ (for ViT-Base, D = 768 and P = 16).

```x_input = [class_token, patch_1, patch_2, ...] + [class_token_position, patch_1_position, patch_2_position,...]```

**ViT Equation 2 and 3**

Eq. 2

$\bf z'_\ell = MSA(LN(z_{\ell-1})) + z_{\ell-1}$

$\ell = 1 \ldots L$

In each of the $L$ layers, the input is first normalised by a $\bf LN$ (Layer Normalization) layer, then passed through a $\bf MSA$ (Multi-Head Self-Attention) layer, and the result is added back to the layer's input (a residual connection).

Eq.3

$\bf z_{\ell} = MLP(LN(z'_{\ell}))+z'_{\ell}$

In each of the $L$ layers, the output of the attention step is normalised by a $\bf LN$ (Layer Normalization) layer, passed through an $\bf MLP$ (Multi-Layer Perceptron) block, and added back to its own input (a residual connection).


**ViT Equation 4**

$\bf y = LN(z^{0}_L)$

The model output $y$ is the first token (index 0, the class token) of the final layer's output $\bf z_L$, passed through a LayerNorm $\bf (LN)$ layer.

In [None]:
# ViT-Base settings for a 224x224 RGB input
H, W = 224, 224          # image height, width
C = 3                    # colour channels
P = 16                   # patch side length (patches are P x P pixels)
D = 768                  # embedding dimension (hidden units)
N = (H * W) // (P * P)   # number of patches
N

196

In [None]:
# Compare the raw image shape with the shape of the flattened patch sequence.
print(f"Input image shape: {H}x{W}x{C}\n"
      f"Output shape of flattened 2D patches: {N}x{P * P * C}")

Input image shape: 224x224x3
Output shape of flattened 2D patches: 196x768


In [None]:
from torch import nn

class PatchEmbedding(nn.Module):
  """Turn an image batch into a sequence of patch embeddings (the x_p E terms of equation 1).

  A Conv2d whose kernel size and stride both equal patch_size applies the same learned
  linear projection to every non-overlapping patch_size x patch_size patch.

  Args:
    in_channels: Number of colour channels of the input images.
    embedding_dim: Size of each patch embedding.
    patch_size: Side length of each square patch.
  """
  def __init__(self,
               in_channels: int=C,
               embedding_dim: int=D,
               patch_size: int=P):
    super().__init__()

    # One output channel per embedding dimension; stride == kernel size so patches do not overlap.
    self.patcher = nn.Conv2d(in_channels=in_channels,
                             out_channels=embedding_dim,
                             kernel_size=patch_size,
                             stride=patch_size)

    # Merge the two spatial axes of the conv output into a single patch axis.
    self.flattener = nn.Flatten(start_dim=2, end_dim=3)

  def forward(self, x):
    """Assumes x is (batch, in_channels, H, W); returns (batch, num_patches, embedding_dim)."""
    feature_map = self.patcher(x)                # (batch, embedding_dim, H // P, W // P)
    patch_tokens = self.flattener(feature_map)   # (batch, embedding_dim, num_patches)
    return patch_tokens.transpose(1, 2)          # (batch, num_patches, embedding_dim)

In [None]:
class MSABlock(nn.Module):
    """Pre-norm multi-head self-attention sub-layer: MSA(LN(x)) from equation 2.

    The residual addition is not done here; TransformerEncoderBlock adds the input back.

    Args:
        embedding_dim: Size of each token embedding.
        num_heads: Number of attention heads (must divide embedding_dim).
        dropout: Dropout probability applied inside nn.MultiheadAttention.
    """
    def __init__(self,
                 embedding_dim: int=D,
                 num_heads:int=12,
                 dropout:float=0):
        super().__init__()

        # Normalise each token over its embedding dimension before attention.
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # batch_first=True: inputs and outputs are (batch, tokens, embedding_dim).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=dropout,
                                                    batch_first=True)

    def forward(self, x):
        normed = self.layer_norm(x)
        # Self-attention: query, key and value are all the normalised tokens.
        # need_weights=False skips computing the attention map; only the output (index 0) is used.
        return self.multihead_attn(normed, normed, normed, need_weights=False)[0]

In [None]:
class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-layer: MLP(LN(x)) from equation 3.

    Linear(embedding_dim -> mlp_size) -> GELU -> Dropout -> Linear(mlp_size -> embedding_dim) -> Dropout.
    The residual addition is not done here; TransformerEncoderBlock adds the input back.

    Args:
        embedding_dim: Size of each token embedding (input and output width).
        mlp_size: Width of the hidden layer.
        dropout: Dropout probability after the activation and after the output projection.
    """
    def __init__(self,
                 embedding_dim: int=D,
                 mlp_size: int=3072,
                 dropout: float=0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        expand = nn.Linear(in_features=embedding_dim, out_features=mlp_size)
        project = nn.Linear(in_features=mlp_size, out_features=embedding_dim)
        self.mlp = nn.Sequential(expand,
                                 nn.GELU(),
                                 nn.Dropout(p=dropout),
                                 project,
                                 nn.Dropout(p=dropout))

    def forward(self, x):
        return self.mlp(self.layer_norm(x))

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One ViT encoder layer: equations 2 and 3, each a pre-norm sub-layer with a residual connection.

    Args:
        embedding_dim: Size of each token embedding.
        num_heads: Number of attention heads in the MSA sub-layer.
        mlp_size: Hidden width of the MLP sub-layer.
        mlp_dropout: Dropout probability inside the MLP sub-layer.
        attn_dropout: Dropout probability inside the attention sub-layer.
    """
    def __init__(self,
                 embedding_dim: int=D,
                 num_heads: int=12,
                 mlp_size: int=3072,
                 mlp_dropout: float=0.1,
                 attn_dropout: float=0):
        super().__init__()

        # Equation 2 sub-layer: LayerNorm + multi-head self-attention.
        self.msa_block = MSABlock(embedding_dim=embedding_dim,
                                  num_heads=num_heads,
                                  dropout=attn_dropout)

        # Equation 3 sub-layer: LayerNorm + MLP.
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    def forward(self, x):
        after_attention = x + self.msa_block(x)                     # z' = MSA(LN(z)) + z
        return after_attention + self.mlp_block(after_attention)    # z  = MLP(LN(z')) + z'

In [None]:
class ViT(nn.Module):
  """Vision Transformer image classifier following ViT equations 1-4.

  Patch embedding, a prepended learnable class token, learnable position embeddings,
  a stack of pre-norm Transformer encoder blocks, and a LayerNorm + Linear head
  applied to the class token.

  Args:
    img_size: Side length of the square input images; must be divisible by patch_size.
    in_channels: Number of input colour channels.
    patch_size: Side length of each square patch.
    num_transformer_layers: Number of stacked TransformerEncoderBlocks.
    embedding_dim: Token embedding size.
    mlp_size: Hidden width of each encoder block's MLP.
    num_heads: Number of attention heads per encoder block.
    attn_dropout: Dropout probability inside multi-head attention.
    mlp_dropout: Dropout probability inside each MLP.
    embedding_dropout: Dropout applied after adding the position embeddings.
    num_classes: Number of output logits.
  """
  def __init__(self,
               img_size: int=IMSIZE,
               in_channels: int=C,
               patch_size: int=P,
               num_transformer_layers: int=12,
               embedding_dim: int=D,
               mlp_size: int=3072,
               num_heads: int=12,
               attn_dropout: float=0,
               mlp_dropout: float=0.1,
               embedding_dropout: float=0.1,
               num_classes: int=num_classes):
    super().__init__()

    assert img_size % patch_size == 0, (
        f"Image size must be divisible by patch size, got img_size={img_size}, patch_size={patch_size}")

    self.num_patches = (img_size * img_size) // patch_size**2

    # A single learnable class token (x_class), shared by every image in the batch.
    self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim), requires_grad=True)

    # One learnable position embedding per token: the class token plus num_patches patches (E_pos).
    self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim), requires_grad=True)

    self.embedding_dropout = nn.Dropout(p=embedding_dropout)

    self.patch_embedding = PatchEmbedding(in_channels=in_channels, patch_size=patch_size, embedding_dim=embedding_dim)

    self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                       num_heads=num_heads,
                                                                       mlp_size=mlp_size,
                                                                       mlp_dropout=mlp_dropout,
                                                                       attn_dropout=attn_dropout)
                                               for _ in range(num_transformer_layers)])

    # Equation 4 head: LayerNorm on the class token, then a linear classifier.
    self.classifier = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim, out_features=num_classes)
    )

  def forward(self, x):
    """Return logits of shape (batch, num_classes).

    Assumes x is (batch, in_channels, img_size, img_size); the position embedding
    only matches inputs of the img_size the model was built with.
    """
    batch_size = x.shape[0]

    class_token = self.class_token.expand(batch_size, -1, -1)   # (batch, 1, embedding_dim)

    x = self.patch_embedding(x)                                 # (batch, num_patches, embedding_dim)

    x = torch.cat((class_token, x), dim=1)                      # (batch, num_patches + 1, embedding_dim)

    x = x + self.position_embedding

    x = self.embedding_dropout(x)

    x = self.transformer_encoder(x)

    # Classify from the class token only (index 0 of the sequence).
    return self.classifier(x[:, 0])

In [None]:
# Seed PyTorch's RNG so the random weight initialisation is reproducible.
torch.manual_seed(42)
# Size the head from the dataset's labels (num_classes), the same count the accuracy metric uses.
vit = ViT(num_classes=num_classes).to(device)

In [None]:
from torchinfo import summary

# Layer-by-layer overview of the model for one batch of BATCH_SIZE images of shape C x H x W.
summary(model=vit,
        row_settings=["var_names"],
        col_width=20,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        input_size=(BATCH_SIZE, C, H, W))

In [None]:
# Cross-entropy on raw logits (the model's classifier ends in a Linear layer, no softmax).
loss_fn = nn.CrossEntropyLoss()

# torch.optim.Adam's weight_decay adds 0.3 * w to every gradient (a coupled L2 penalty that then gets
# rescaled by Adam's per-parameter statistics); AdamW applies weight decay decoupled from the gradient,
# which is the usual meaning of a "weight decay 0.3" training recipe.
# NOTE(review): lr=3e-3 / wd=0.3 look taken from a large-batch ImageNet recipe; with BATCH_SIZE=32 and
# no warmup this lr may be too high - tune if training diverges.
optimizer = torch.optim.AdamW(params=vit.parameters(), lr=3e-3, betas=(0.9, 0.999), weight_decay=0.3)

In [None]:
from tqdm.auto import tqdm
from torchmetrics import Accuracy

# The metric keeps its state in tensors, so it must live on the same device as the predictions;
# otherwise updating it with CUDA tensors fails on a GPU runtime.
accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)

EPOCHS = 10


def _train_one_epoch(model, dataloader, loss_fn, optimizer, metric, device):
  """Run one optimisation pass over dataloader; return (mean batch loss, accuracy over the epoch)."""
  model.train()
  # Accumulate over the whole epoch and compute once, so accuracy is over all samples
  # instead of an average of per-batch accuracies (which over-weights a short last batch).
  metric.reset()
  total_loss = 0.0

  for X, y in dataloader:
    X, y = X.to(device), y.to(device)

    pred = model(X)

    loss = loss_fn(pred, y)
    total_loss += loss.item()
    metric.update(pred.detach(), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  return total_loss / len(dataloader), metric.compute().item()


def _evaluate(model, dataloader, loss_fn, metric, device):
  """Evaluate model on dataloader without tracking gradients; return (mean batch loss, accuracy)."""
  model.eval()
  metric.reset()
  total_loss = 0.0

  with torch.inference_mode():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)

      test_pred = model(X)

      total_loss += loss_fn(test_pred, y).item()
      metric.update(test_pred, y)

  return total_loss / len(dataloader), metric.compute().item()


for epoch in tqdm(range(EPOCHS)):
  train_loss, train_acc = _train_one_epoch(vit, train_dataloader, loss_fn, optimizer, accuracy, device)
  test_loss, test_acc = _evaluate(vit, test_dataloader, loss_fn, accuracy, device)

  print(f"Epoch: {epoch + 1}/{EPOCHS} | Train loss: {train_loss:.4f} | Train acc: {train_acc:.2f} | Test loss: {test_loss:.4f} | Test acc: {test_acc:.2f}")