In [1]:
import torch
import pytorch_lightning as pl
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import numpy as np
import matplotlib.pyplot as plt
import os
import math

#--------------------------------
# Device configuration
#--------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#print('Using device: %s'%device)

# Init DataLoader from MNIST Dataset
# Init DataLoader from MNIST Dataset
batch_size = 512

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

mnist_train = torchvision.datasets.MNIST(
    '.', train=True, download=True, transform=transform)
mnist_train, mnist_val = torch.utils.data.random_split(mnist_train, [
                                                       50000, 10000])

mnist_test = torchvision.datasets.MNIST(
    os.getcwd(), train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=64, num_workers=8)
val_loader = torch.utils.data.DataLoader(
    mnist_val, batch_size=64, num_workers=8)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=64, num_workers=8)


In [2]:
class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """

    def __init__(self, model_dim, max_len, device):  # batch_size, seq, dim = x.shape
        """
        constructor of sinusoid encoding class

        :param model_dim: dimension of model
        :param max_len: max sequence length
        :param device: hardware device setting
        """
        super(PositionalEncoding, self).__init__()

        # same size with input matrix (for adding with input matrix)
        self.encoding = torch.zeros(max_len, model_dim, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradient

        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent word's position

        _2i = torch.arange(0, model_dim, step=2, device=device).float()
        # 'i' means index of model_dim (e.g. embedding size = 50, 'i' = [0,50])
        # "step=2" means 'i' multiplied with two (same with 2 * i)

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / model_dim)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / model_dim)))
        # compute positional encoding to consider positional information of words

    def forward(self, x):
        # self.encoding
        # [max_len = 512, model_dim = 512]
        batch_size, seq_len, _ = x.size()  # 64, max_len = seq_len = 26, model_dim = 16
        # [batch_size = 128, seq_len = 30]

        return self.encoding[:seq_len, :]  # batch_size, seq, model_dim
        # [seq_len = 30, model_dim = 512]
        # it will add with tok_emb : [128, 30, 512]


def split_up_patches(x, patch_size):
    h = x.shape[-2]
    w = x.shape[-1]
    patches = nn.Unfold(kernel_size=patch_size, stride=patch_size+1)(x)
    # note: index convention is (n_batches, n_tokens, hidden_dim)!
    patches = torch.permute(patches, (0, 2, 1))
    return patches


# test splitting into patches
x_test_split = torch.randn(32, 1, 28, 28)
x_test_split_patches = split_up_patches(x_test_split, 4)
print(x_test_split_patches.shape)

class LinearProjection(nn.Module):
    def __init__(self, patch_size, hidden_dim):
        super().__init__()
        self.patch_size = patch_size
        self.linear_emb = nn.Linear(patch_size * patch_size, hidden_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
    def split_up_patches(self, x, patch_size):
        h = x.shape[-2]
        w = x.shape[-1]
        patches = nn.Unfold(kernel_size = patch_size, stride = patch_size+1)(x)
        patches = torch.permute(patches, (0, 2, 1)) # note: index convention is (n_batches, n_tokens, hidden_dim)!
        return patches  
    def forward(self, x):
        #  split image to fixed size patch -> flatten -> linear encoding -> concate class token
        x = self.split_up_patches(x, self.patch_size)
        x = self.linear_emb(x) # batch_size, seq, hidden_dim
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 1) # batch_size, seq+1, hidden_dim
        return x
    

torch.Size([32, 25, 16])


In [3]:
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        # TODO
        self.W_Q = nn.Linear(hidden_dim, hidden_dim)
        self.W_K = nn.Linear(hidden_dim, hidden_dim)
        self.W_V = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # TODO
        score = self.softmax(Q @ K.mT/np.sqrt(self.hidden_dim))
        attention = score @ V

        return attention


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, nr_heads):
        super(MultiHeadAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.nr_heads = nr_heads
        self.head_dim = int(hidden_dim/nr_heads)
        self.attention = torch.Tensor([])
        # TODO:
        self.W_concat = nn.Linear(hidden_dim, hidden_dim)
        self.attn = Attention(hidden_dim=self.head_dim)

    def forward(self, x):
        # TODO:
        batch_size, _, _ = x.shape
        self.attention = torch.Tensor([]).to(device)
        x = x.view(batch_size, -1, self.nr_heads,
                   self.head_dim).permute(0, 2, 1, 3)

        for i in range(self.nr_heads):
            self.attention = torch.cat((self.attention, self.attn(x[:, i])), 2)

        return self.W_concat(self.attention)


In [4]:
class LayerNorm(nn.Module):
    def __init__(self, model_dim, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(model_dim)).to(device)
        self.beta = nn.Parameter(torch.zeros(model_dim)).to(device)
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)

        out = (x - mean) / (std + self.eps)
        out = self.gamma * out + self.beta
        return out


class FFN(nn.Module):
    def __init__(self, model_dim, hidden_dim, drop_prob=0.1):
        super(FFN, self).__init__()
        # TODO
        self.linear1 = nn.Linear(model_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, model_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # TODO:
        self.network = nn.Sequential(
            self.linear1,
            self.relu,
            self.dropout,
            self.linear2
        )
        return self.network(x)


In [5]:
class Transformer_module(nn.Module):
    def __init__(self, hidden_dim, ffn_hidden, nr_heads):
        super(Transformer_module, self).__init__()
        # TODO:
        self.hidden_dim = hidden_dim
        self.ffn_hidden = ffn_hidden
        self.nr_heads = nr_heads

        self.msa = MultiHeadAttention(hidden_dim=hidden_dim, nr_heads=nr_heads)
        self.ln1 = LayerNorm(hidden_dim)
        self.ln2 = LayerNorm(hidden_dim)
        self.ffn = FFN(model_dim=hidden_dim, hidden_dim=ffn_hidden)
        
    def forward(self, x):
        # TODO:

        # 1. compute self attention
        attn = self.msa(self.ln1(x))
        # 2. add and norm
        z = attn + x
        # 3. positionwise feed forward network
        ffn = self.ffn(self.ln2(z))
        # 4. add and norm
        out = ffn + z
        return out


In [6]:
class Transformer(pl.LightningModule):
    def __init__(self, hidden_dim = 16, num_class = 10, nr_layers = 3, nr_heads = 3, patch_size = 4):
        super().__init__()
        self.num_class = num_class
        self.nr_layers = nr_layers
        self.nr_heads = nr_heads
        self.learning_rate = 1e-4 #1e-3
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        self.train_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')
        self.val_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')
        self.test_acc = torchmetrics.classification.MulticlassAccuracy(num_classes = num_class, average='weighted')

        self.transformer_modules = []
        # TODO
        # Patch Embedding
        self.linear_prj = LinearProjection(self.patch_size, self.hidden_dim)
        self.pos_encode = PositionalEncoding(hidden_dim,hidden_dim, device)
        
        # Transformer blocks
        self.hidden_dim = hidden_dim
        #self.transformer = Transformer_module(hidden_dim=hidden_dim, ffn_hidden=512, nr_heads = nr_heads) 
        self.transformer = nn.ModuleList([Transformer_module(hidden_dim=hidden_dim, ffn_hidden=8, nr_heads=nr_heads) for _ in range(nr_layers)])
        # classification head
        self.mlp = nn.Sequential(LayerNorm(hidden_dim), nn.Linear(hidden_dim,num_class))

        #layernorm on final result of transformer
        self.ln = LayerNorm(hidden_dim)


    def forward(self, x):
        # TODO:

        # patch embedding
        x = self.linear_prj(x)
        #positional encoding
        x = x + self.pos_encode(x)
        #transformer
        for layer in self.transformer:
            x = layer(x)
        x = self.ln(x)

        # cls head
        return self.mlp(x[:,0])

    def training_step(self, train_batch, batch_idx):
        images, labels = train_batch
        # Forward pass
        outputs = self(images)
        criterion = torch.nn.CrossEntropyLoss()

        loss = criterion(outputs, labels)
        self.log('train_loss', loss)

        pred_labels = torch.argmax(outputs, 1)
        
        self.train_acc(pred_labels, labels)

        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        images, labels = val_batch
        outputs = self(images)
        pred_labels = torch.argmax(outputs, 1)

        self.val_acc(pred_labels, labels)

        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        pred_labels = torch.argmax(outputs, 1)

        self.test_acc(pred_labels, labels)

        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
        return opt

In [7]:
# test splitting into patches
x_test_split = torch.randn(32, 1, 28, 28)
x_test_split_patches = split_up_patches(x_test_split, 4)
print(x_test_split_patches.shape) #32,25,16


torch.Size([32, 25, 16])


In [8]:
trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=30)
model = Transformer(256, 10, 8, 8, 8)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type               | Params
---------------------------------------------------
0 | train_acc   | MulticlassAccuracy | 0     
1 | val_acc     | MulticlassAccuracy | 0     
2 | test_acc    | MulticlassAccuracy | 0     
3 | linear_prj  | LinearProjection   | 16.9 K
4 | pos_encode  | PositionalEncoding | 0     
5 | transformer | ModuleList         | 586 K 
6 | mlp         | Sequential         | 2.6 K 
7 | ln          | LayerNorm          | 0     
---------------------------------------------------
606 K     Trainable params
0         Non-trainable params
606 K     Total params
2.424     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9720999598503113
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.9720999598503113}]

In [None]:
#tensorboard
#https://tensorboard.dev/experiment/OHkp7e9AQWW4PprZfpYpHw/