In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, n_embd,head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size, n_embd):
        super().__init__()
        self.n_embd = n_embd
        self.heads = nn.ModuleList([Head(n_embd, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


### Pathcify the image

Image after pathify = (N, PxP, HxC/P x WxC/P) -> (N, #Patches, Patch dimensionality)


In [4]:
x = torch.rand(3,1, 28, 28)

In [5]:
class Patchify(nn.Module):
    def __init__(self, patch_size=4):
        super().__init__()
        self.p = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x):

        bs, c, h, w = x.shape

        x = self.unfold(x).permute(0, 2, 1)

        return x


## MAIN FUNCTION

In [6]:
class ViT(nn.Module):

  def __init__(self, input_shape = (1, 28, 28), patch_size = 4, actual_image_shape = (1, 32, 32), output_dim = 10):
    super(ViT, self).__init__()
    self.input_shape = input_shape
    self.patch_size =  (patch_size, patch_size)
    self.output_dim = output_dim
    self.hidden = 8
    self.n_heads = 2
    self.n_blocks = 4
    self.num_patches = int(input_shape[1] // patch_size)

    self.convBlock = nn.Sequential(
        nn.Conv2d(in_channels= 1, out_channels= 10, kernel_size = 3),
        nn.ReLU(),
        nn.Conv2d(in_channels= 10, out_channels= 10, kernel_size = 3)
    )
    self.convBlock2 = nn.Sequential(
        nn.Conv2d(in_channels= 1, out_channels= 10, kernel_size = 3),
        nn.ReLU(),
        nn.Conv2d(in_channels= 10, out_channels= 10, kernel_size = 3)
    )
    self.Patch = Patchify()
    # linear mapping from pathes of 4x4 to 8 dim vectors
    self.input_dim = input_shape[0]*(patch_size**2)*10
    self.mapping = nn.Linear(self.input_dim, self.hidden)
    # special token for classification
    self.class_token = nn.Parameter(torch.rand(1, self.hidden))
    # postiona embedding table
    self.position_embedding_table = nn.Embedding(self.num_patches**2 + 1, self.hidden)
    # Multihead Attention block
    self.blocks = nn.ModuleList([Block(self.hidden, self.n_heads) for _ in range(self.n_blocks)])
    # output layer construction
    self.outputLayer = nn.Sequential(
        nn.Linear(self.hidden, self.hidden),
        nn.ReLU(),
        #nn.BatchNorm1d(self.hidden),
        nn.Linear(self.hidden, output_dim),
        nn.Softmax(dim=-1)
    )


  def forward(self, x):
    # convolute the image first
    x = self.convBlock(x) + self.convBlock2(x)
    # Patchify
    x = self.Patch(x)
    # linear mapping from 16 -> 8
    x = self.mapping(x)
    # add special token
    tokens = torch.stack([torch.vstack((self.class_token, x[i])) for i in range(len(x))])
    # add positional embedding
   # print(tokens.shape[-2])
    positional_emb = self.position_embedding_table(torch.arange(tokens.shape[-2]).to(device))
    out = tokens + positional_emb
    # passing through Multihead self attention block
    for block in self.blocks:
            out = block(out)
   
    # now passing through MLP layer for final prediction
    out =  self.outputLayer(out[:, 0])
    return out

In [8]:

m = ViT().to(device)

In [9]:
x = torch.rand(1, 1, 32, 32).to(device)

In [10]:
m(x).shape

torch.Size([1, 10])

In [11]:
block = Block(n_embd=8,n_head=2)

In [12]:
block(torch.rand(3, 50, 8)).shape

torch.Size([3, 50, 8])

In [13]:
from pathlib import Path
train_dir = '/kaggle/input/dataset-lfw/lfw-deepfunneled/lfw-deepfunneled'

data_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(size=(32, 32)),
    transforms.TrivialAugmentWide(num_magnitude_bins=31),
    transforms.ToTensor()
])

from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir,
                                  transform=data_transform,
                                  target_transform=None)

In [14]:
num_classes = len(train_data.classes)

In [15]:
device

'cuda'

In [16]:
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=1,
                              num_workers=1,
                              shuffle=True)


In [17]:
next(iter(train_dataloader))[0].shape

torch.Size([1, 1, 32, 32])

In [18]:
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer):
    # Put model in train mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader data batches
    for batch, (X, y) in enumerate(dataloader):
        # Send data to target device
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate  and accumulate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

        # Calculate and accumulate accuracy metric across all batches
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

In [19]:
from tqdm.auto import tqdm

# 1. Take in various parameters required for training and test steps
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,

          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5, PATH = '/kaggle/working/LFW_ViT.pth'):

    # 2. Create empty results dictionary
    results = {"train_loss": [],
        "train_acc": []
    }

    # 3. Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "

        )

        # 5. Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)


    # 6. Return the filled results at the end of the epochs
    torch.save(model.state_dict(), PATH)
    return results

In [20]:
PATH = '/kaggle/working/LFW_ViT.pth'

In [23]:
LEARNING_RATE = 3.5e-3
model = ViT(output_dim = num_classes).to(device)
model.load_state_dict(torch.load(PATH))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params = model.parameters(), lr = LEARNING_RATE)

In [None]:
NUM_EPOCHS = 20
Results = train(model=model,
                train_dataloader=train_dataloader,
                optimizer=optimizer,
                loss_fn=loss_fn,
                epochs=NUM_EPOCHS)

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 8.6568 | train_acc: 0.0002 | 
Epoch: 2 | train_loss: 8.6568 | train_acc: 0.0002 | 
Epoch: 3 | train_loss: 8.6567 | train_acc: 0.0002 | 
Epoch: 4 | train_loss: 8.6567 | train_acc: 0.0002 | 
Epoch: 5 | train_loss: 8.6566 | train_acc: 0.0002 | 
Epoch: 6 | train_loss: 8.6565 | train_acc: 0.0002 | 
Epoch: 7 | train_loss: 8.6563 | train_acc: 0.0002 | 
Epoch: 8 | train_loss: 8.6561 | train_acc: 0.0003 | 
Epoch: 9 | train_loss: 8.6555 | train_acc: 0.0007 | 
Epoch: 10 | train_loss: 8.6538 | train_acc: 0.0208 | 
Epoch: 11 | train_loss: 8.6477 | train_acc: 0.0401 | 
Epoch: 12 | train_loss: 8.6336 | train_acc: 0.0401 | 
Epoch: 13 | train_loss: 8.6220 | train_acc: 0.0401 | 
Epoch: 14 | train_loss: 8.6181 | train_acc: 0.0401 | 


In [128]:
Results

{'train_loss': [8.656781060097664], 'train_acc': [0.00015113730824454016]}