In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
from PIL import Image

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    A single fused weight matrix projects the input to queries, keys and
    values; the output is the attention-weighted sum of the values.

    Args:
        embedding_dim: size of the last dimension of the input.
        key_dim: size of each of the query, key and value projections.
    """

    def __init__(self, embedding_dim=768, key_dim=64):
        super(SelfAttention, self).__init__()

        self.embedding_dim = embedding_dim
        self.key_dim = key_dim

        # Fused projection: columns [0, key_dim) produce Q, [key_dim, 2*key_dim)
        # produce K and the remaining key_dim columns produce V.
        self.W = nn.Parameter(torch.randn(embedding_dim, 3*key_dim))

    def forward(self, x):
        """Apply self-attention over the second-to-last (sequence) axis.

        Args:
            x: tensor of shape (..., seq_len, embedding_dim).

        Returns:
            Tensor of shape (..., seq_len, key_dim).
        """
        key_dim = self.key_dim

        # One matmul yields Q, K and V side by side along the last axis.
        qkv = torch.matmul(x, self.W)

        q = qkv[..., :key_dim]
        k = qkv[..., key_dim:key_dim*2]
        v = qkv[..., key_dim*2:]

        # dot_products[..., i, j] is the similarity of query i with key j.
        k_T = torch.transpose(k, -2, -1)
        dot_products = torch.matmul(q, k_T)

        scaled_dot_products = dot_products / np.sqrt(key_dim)

        # Normalise over the key axis (last dim) so that each query's weights
        # sum to 1. Normalising over dim=1 would mix the weights of different
        # queries and leave the per-query weights un-normalised.
        attention_weights = F.softmax(scaled_dot_products, dim=-1)

        weighted_values = torch.matmul(attention_weights, v)

        return weighted_values

In [3]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention built from independent SelfAttention heads.

    Every head attends over the same input; the head outputs are concatenated
    along the last axis and projected back to ``embedding_dim`` by ``self.W``.

    Args:
        embedding_dim: size of the last dimension of the input and output.
        num_heads: number of attention heads; must divide ``embedding_dim``.
    """

    def __init__(self, embedding_dim=768, num_heads=12):
        super(MultiHeadSelfAttention, self).__init__()

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim

        assert embedding_dim % num_heads == 0, \
            "embedding_dim must be divisible by num_heads"
        # Use the constructor argument, not a notebook-level global, so the
        # class works on a fresh kernel and honours the value it was given.
        self.key_dim = embedding_dim // num_heads

        # Plain list kept as an attribute alongside the registered ModuleList;
        # only the ModuleList makes the heads' parameters visible to optimizers.
        self.attention_list = [SelfAttention(embedding_dim, self.key_dim) for _ in range(num_heads)]
        self.multi_head_attention = nn.ModuleList(self.attention_list)

        # Output projection applied to the concatenated head outputs.
        self.W = nn.Parameter(torch.randn(num_heads * self.key_dim, embedding_dim))

    def forward(self, x):
        """Run every head on ``x``, concatenate the results and project them.

        Args:
            x: tensor whose last dimension is ``embedding_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``embedding_dim``.
        """
        attention_scores = [attention(x) for attention in self.multi_head_attention]

        # Concatenate along the feature axis: num_heads * key_dim features.
        Z = torch.cat(attention_scores, -1)

        attention_score = torch.matmul(Z, self.W)

        return attention_score

In [4]:
class MultiLayerPerceptron(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        embedding_dim: size of the input and output features.
        hidden_dim: size of the intermediate features.
    """

    def __init__(self, embedding_dim=768, hidden_dim=3072):
        super().__init__()

        # Expand to hidden_dim, apply the non-linearity, then project back.
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embedding_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the MLP applied to the last dimension of ``x``."""
        return self.mlp(x)

In [5]:
class TransformerEncoder(nn.Module):
    """One encoder layer: an attention branch and an MLP branch, each residual.

    Layer normalisation is applied before each branch and dropout after it.
    The layer input additionally passes through dropout before the first
    normalisation; the first residual connection uses the un-dropped input.

    Args:
        embedding_dim: feature size of the input and output.
        num_heads: number of attention heads.
        hidden_dim: hidden size of the MLP branch.
        dropout_prob: drop probability shared by all three dropout layers.
    """

    def __init__(self, embedding_dim=768, num_heads=12, hidden_dim=3072, dropout_prob=0.1):
        super().__init__()

        # Sub-blocks.
        self.MSA = MultiHeadSelfAttention(embedding_dim, num_heads)
        self.MLP = MultiLayerPerceptron(embedding_dim, hidden_dim)

        # One normalisation per branch.
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        # dropout1: layer input, dropout2: attention output, dropout3: MLP output.
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)
        self.dropout3 = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        """Return the encoder layer output; the shape matches ``x``."""
        # Attention branch, added back onto the original (un-dropped) input.
        attn_in = self.layer_norm1(self.dropout1(x))
        attn_branch = self.dropout2(self.MSA(attn_in))
        hidden = x + attn_branch

        # MLP branch, added back onto the attention residual.
        mlp_branch = self.dropout3(self.MLP(self.layer_norm2(hidden)))
        return hidden + mlp_branch

In [6]:
class MLPHead(nn.Module):
    """Classification head mapping an embedding to class logits.

    Args:
        embedding_dim: size of the input features.
        num_classes: number of output logits.
        fine_tune: when falsy, the head is Linear(embedding_dim, 3072) -> Tanh
            -> Linear(3072, num_classes); when truthy, it is a single Linear
            layer from embedding_dim to num_classes.
    """

    def __init__(self, embedding_dim=768, num_classes=10, fine_tune=False):
        super().__init__()
        self.num_classes = num_classes

        if fine_tune:
            # Single linear layer.
            self.mlp_head = nn.Linear(embedding_dim, num_classes)
        else:
            # One hidden layer of width 3072 with a tanh non-linearity.
            head_layers = [
                nn.Linear(embedding_dim, 3072),
                nn.Tanh(),
                nn.Linear(3072, num_classes),
            ]
            self.mlp_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self.mlp_head(x)

In [7]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is flattened and linearly embedded, a learnable class
    token is prepended, learnable positional embeddings are added, and the
    sequence is passed through ``num_layers`` TransformerEncoder layers. The
    class-token output is fed to an MLPHead to produce class logits.

    Args:
        patch_size: side length of a square patch, in pixels.
        image_size: side length of the (square) input image, in pixels.
        channel_size: number of image channels.
        num_layers: number of stacked encoder layers.
        embedding_dim: token embedding size.
        num_heads: attention heads per encoder layer.
        hidden_dim: hidden size of each encoder's MLP.
        dropout_prob: dropout probability used inside the encoders.
        num_classes: number of output logits.
        pretrain: when True (default) the head has one hidden layer; when
            False the head is a single linear layer (``MLPHead(fine_tune=True)``).
    """

    def __init__(self, patch_size=16, image_size=224, channel_size=3,
                     num_layers=12, embedding_dim=768, num_heads=12, hidden_dim=3072,
                            dropout_prob=0.1, num_classes=10, pretrain=True):
        super(VisionTransformer, self).__init__()

        self.patch_size = patch_size
        self.channel_size = channel_size
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes

        # Integer arithmetic that mirrors what unfold() produces in forward():
        # partial patches at the border are dropped, so the count is
        # floor(image_size / patch_size) per side. The former float expression
        # int(image_size**2 / patch_size**2) over-counted for non-divisible sizes.
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Linear(patch_size * patch_size * channel_size, embedding_dim)
        # One positional vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(self.num_patches + 1, embedding_dim))
        self.class_token = nn.Parameter(torch.rand(1, embedding_dim))

        transformer_encoder_list = [
            TransformerEncoder(embedding_dim, num_heads, hidden_dim, dropout_prob)
                    for _ in range(num_layers)]
        self.transformer_encoder_layers = nn.Sequential(*transformer_encoder_list)

        # Honour the `pretrain` flag, which was previously accepted but ignored.
        # The default (pretrain=True -> fine_tune=False) is unchanged.
        self.mlp_head = MLPHead(embedding_dim, num_classes, fine_tune=not pretrain)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: tensor of shape (batch, channel_size, height, width).

        Returns:
            Tensor of shape (batch, num_classes) with the class logits.
        """
        # get patch size and channel size
        P, C = self.patch_size, self.channel_size

        # split image into patches: (batch, 1, H//P, W//P, C, P, P)
        patches = x.unfold(1, C, C).unfold(2, P, P).unfold(3, P, P)
        # flatten every patch into one row: (batch, num_patches, C*P*P)
        patches = patches.contiguous().view(patches.size(0), -1, C * P * P).float()

        # linearly embed patches
        patch_embeddings = self.patch_embedding(patches)

        # prepend one copy of the class token to every sequence in the batch
        batch_size = patch_embeddings.shape[0]
        patch_embeddings = torch.cat((self.class_token.repeat(batch_size, 1, 1), patch_embeddings), 1)

        # add positional embedding (broadcast over the batch axis)
        patch_embeddings = patch_embeddings + self.pos_embedding

        # feed patch embeddings into a stack of Transformer encoders
        transformer_encoder_output = self.transformer_encoder_layers(patch_embeddings)

        # extract [class] token from encoder output
        output_class_token = transformer_encoder_output[:, 0]

        # pass token through mlp head for classification
        y = self.mlp_head(output_class_token)

        return y

In [8]:
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.models as models
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [9]:
image_size = 224

# Training pipeline: upscale to the model's input size, apply random
# augmentation, then convert to a tensor and normalise.
transform = T.Compose([
    T.Resize(image_size),
    T.RandomHorizontalFlip(),
    T.RandomRotation(15),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

# Evaluation pipeline: same resize and normalisation but no random
# augmentation, so test metrics are deterministic and measured on
# undistorted images.
eval_transform = T.Compose([
    T.Resize(image_size),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

# Override the download location used by the CIFAR10 dataset class.
torchvision.datasets.CIFAR10.url="http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

trainset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='data', train=False, download=True, transform=eval_transform)

# Read the class names before wrapping in Subset, which does not expose them.
classes = trainset.classes

# Work on the first 20000 training and 4000 test images only.
trainset = torch.utils.data.Subset(trainset, list(range(20000)))
testset = torch.utils.data.Subset(testset, list(range(4000)))

classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [10]:
batch_size = 64

# Fraction of `trainset` held out for validation.
valid_size = 0.2

# Shuffle all indices once, then cut the list: the first `split` entries are
# the validation indices and the remainder are the training indices.
train_size = len(trainset)
indices = list(range(train_size))
np.random.shuffle(indices)
split = int(np.floor(valid_size * train_size))
valid_idx = indices[:split]
train_idx = indices[split:]

# Each sampler draws (in random order) only from its own index list.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# Train and validation loaders share `trainset` and differ only by sampler.
train_loader = DataLoader(dataset=trainset, sampler=train_sampler, batch_size=batch_size)
valid_loader = DataLoader(dataset=trainset, sampler=valid_sampler, batch_size=batch_size)
test_loader = DataLoader(dataset=testset, batch_size=batch_size)

In [11]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [12]:
# Input geometry.
image_size = 224
channel_size = 3
patch_size = 16

# Model size.
n_layer = 12
embedding_dim = 768
n_head = 12
hidden_dim = 3072

# Regularisation and output size.
dropout_prob = 0.1
n_class = 10

# Build the model and move its parameters to the selected device.
vision_transformer = VisionTransformer(
    patch_size=patch_size,
    image_size=image_size,
    channel_size=channel_size,
    num_layers=n_layer,
    embedding_dim=embedding_dim,
    num_heads=n_head,
    hidden_dim=hidden_dim,
    dropout_prob=dropout_prob,
    num_classes=n_class,
).to(device)

In [13]:
# Multi-class classification loss on raw logits.
criterion = nn.CrossEntropyLoss()

# AdamW over every parameter of the model.
optimizer = optim.AdamW(
    params=vision_transformer.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

# One-cycle schedule sized for 20 epochs of len(train_loader) batches each,
# peaking at 5e-4 after the first 10% of the steps.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    epochs=20,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
)

In [None]:
from tqdm import tqdm

# Must match the `epochs` value the OneCycleLR scheduler was created with:
# the scheduler raises once it is stepped more often than its total budget.
num_epochs = 20

# The model object is the same in every epoch, so bind the alias once.
model = vision_transformer

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    correct, total = 0, 0

    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # OneCycleLR is a per-batch schedule; without this call the learning
        # rate never leaves its initial value. Re-running this cell requires
        # re-creating the scheduler first, or its step budget is exceeded.
        scheduler.step()

        train_loss += loss.item()

        # Predicted class = index of the largest logit.
        _, predicted = outputs.max(1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    train_accuracy = 100 * correct / total
    avg_train_loss = train_loss / len(train_loader)

    # ---- validation pass (no gradient tracking, dropout disabled) ----
    model.eval()
    val_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(valid_loader)

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

  0%|          | 0/250 [00:00<?, ?it/s]