Coded by Lujia Zhong @lujiazho

In [1]:
import time
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Reads ``config.embedding_size``, ``config.transformer["num_heads"]`` and
    ``config.transformer["attention_dropout_rate"]``.

    Shape legend used in the comments below:
    B = batch, N = tokens, E = embedding size, H = heads, D = head size,
    A = H * D (``all_head_size``).
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.transformer["num_heads"]
        # int() truncates, so A = H * D can be smaller than E when E is not
        # a multiple of H.
        self.attention_head_size = int(config.embedding_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.embedding_size, self.all_head_size)
        self.key = Linear(config.embedding_size, self.all_head_size)
        self.value = Linear(config.embedding_size, self.all_head_size)

        # The output projection consumes the concatenated heads (width A),
        # not the raw embedding; using A keeps forward() valid even when
        # E is not divisible by H. For divisible sizes A == E.
        self.out = Linear(self.all_head_size, config.embedding_size)
        # Both dropouts use the attention dropout rate.
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads: [B, N, A] -> [B, H, N, D]."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states`` of shape [B, N, E].

        Returns:
            A tuple ``(attention_output, weights)`` where ``attention_output``
            is [B, N, E] and ``weights`` is the [B, H, N, N] softmax
            attention map taken before attention dropout is applied.
        """
        # Per-token projections, each [B, N, A].
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # Split into heads, each [B, H, N, D].
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # attention_scores: [B, H, N, N], scaled by sqrt(D).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size**0.5)
        attention_probs = self.softmax(attention_scores)
        # weights: [B, H, N, N] -- kept before dropout so callers can
        # inspect the undisturbed attention map.
        weights = attention_probs
        attention_probs = self.attn_dropout(attention_probs)

        # context_layer: [B, H, N, D]
        context_layer = torch.matmul(attention_probs, value_layer)
        # Move heads next to D and make memory contiguous so view() can
        # merge them: [B, N, H, D] -> [B, N, A].
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # attention_output: [B, N, E]
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)

        return attention_output, weights


class MLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Widths come from ``config.embedding_size`` and
    ``config.transformer["mlp_dim"]``; the dropout probability from
    ``config.transformer["dropout_rate"]``.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.embedding_size
        hidden_dim = config.transformer["mlp_dim"]

        self.fc1 = Linear(model_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, model_dim)
        self.act_fn = nn.functional.gelu
        # One Dropout module is shared by both dropout applications.
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std=1e-6) normal biases."""
        projections = (self.fc1, self.fc2)
        # Weights first, then biases, for both layers in declaration order.
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
        for proj in projections:
            nn.init.normal_(proj.bias, std=1e-6)

    def forward(self, x):
        """Map ``x`` of shape [..., embedding_size] to the same shape."""
        # Expand to the hidden width, e.g. [4, 197, 768] -> [4, 197, 3072].
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        # Project back to the embedding width.
        return self.dropout(self.fc2(hidden))


class Embeddings(nn.Module):
    """Turn an image batch into a token sequence with a class token.

    The image is cut into non-overlapping patches by a strided convolution,
    a learnable class token is prepended, and learnable position embeddings
    are added.

    Args:
        config: object providing ``patches`` (patch size), ``embedding_size``
            and ``transformer["dropout_rate"]``.
        img_size: input resolution, either a single int (square image) or a
            ``(height, width)`` pair.
        in_channels: number of channels of the input images.
    """

    def __init__(self, config, img_size, in_channels=3):
        super().__init__()
        img_size = self._to_pair(img_size)
        patch_size = self._to_pair(config.patches)

        # Floor division mirrors what the strided convolution produces.
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.embedding_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # One position per patch plus one for the class token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.embedding_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embedding_size))
        # self.cls_token: torch.Size([1, 1, embedding_size])

        self.dropout = Dropout(config.transformer["dropout_rate"])

    @staticmethod
    def _to_pair(value):
        """Return ``value`` as a 2-tuple, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"expected a scalar or a pair, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Embed images ``x`` of shape [B, in_channels, H, W].

        Returns a [B, n_patches + 1, embedding_size] tensor. The spatial
        size of ``x`` has to produce the patch count this module was built
        for, otherwise adding the position embeddings fails.
        """
        B = x.shape[0]
        # Broadcast the single learned class token over the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # cls_tokens: [B, 1, embedding_size]

        x = self.patch_embeddings(x)
        # x: [B, embedding_size, H / patch_h, W / patch_w]
        x = x.flatten(2)
        # x: [B, embedding_size, n_patches]
        x = x.transpose(-1, -2)
        # x: [B, n_patches, embedding_size]
        x = torch.cat((cls_tokens, x), dim=1)
        # x: [B, n_patches + 1, embedding_size]

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class Block(nn.Module):
    """Encoder block: pre-norm attention and pre-norm MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        dim = config.embedding_size
        self.embedding_size = dim
        self.attention_norm = LayerNorm(dim, eps=1e-6)
        self.ffn_norm = LayerNorm(dim, eps=1e-6)
        self.ffn = MLP(config)
        self.attn = Attention(config)

    def forward(self, x):
        """Run one block over ``x``.

        Returns ``(x, weights)``: the transformed tensor, with the same
        shape as the input (e.g. [4, 197, 768]), and the attention weights
        reported by the attention sub-layer.
        """
        # Attention sub-layer: normalise, attend, add the skip connection.
        attended, weights = self.attn(self.attention_norm(x))
        x = attended + x

        # Feed-forward sub-layer, same normalise / transform / add pattern.
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights

class Encoder(nn.Module):
    """Stack of ``config.transformer["num_layers"]`` blocks plus a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        depth = config.transformer["num_layers"]
        self.layer = nn.ModuleList(Block(config) for _ in range(depth))
        self.encoder_norm = LayerNorm(config.embedding_size, eps=1e-6)

    def forward(self, hidden_states):
        """Pass ``hidden_states`` through every block in order.

        Returns ``(encoded, attn_weights)``: the layer-normalised output of
        the last block (same shape as the input, e.g. [4, 197, 768]) and a
        list holding each block's attention weights, first block first.
        """
        collected = []
        current = hidden_states
        for block in self.layer:
            current, block_weights = block(current)
            collected.append(block_weights)

        return self.encoder_norm(current), collected

class VisionTransformer(nn.Module):
    """Image classifier: patch embeddings -> transformer encoder -> linear head.

    Args:
        config: hyperparameter object forwarded to the embedding and encoder
            modules; ``config.embedding_size`` also sizes the head.
        img_size: input resolution forwarded to ``Embeddings``.
        num_classes: number of output logits.
        in_channels: number of input image channels, forwarded to
            ``Embeddings`` (defaults to 3, the previous fixed value).
    """

    def __init__(self, config, img_size=224, num_classes=21843, in_channels=3):
        super().__init__()
        self.num_classes = num_classes

        self.embeddings = Embeddings(config, img_size=img_size, in_channels=in_channels)
        self.encoder = Encoder(config)
        self.head = Linear(config.embedding_size, num_classes)

    def forward(self, x):
        """Classify an image batch ``x`` (e.g. torch.Size([4, 3, 224, 224])).

        Returns ``(logits, attn_weights)``: logits of shape
        [batch, num_classes] and the list of attention weights returned by
        the encoder.
        """
        embedding_output = self.embeddings(x)
        # embedding_output: [batch, tokens, embedding_size]

        encoded, attn_weights = self.encoder(embedding_output)
        # encoded: [batch, tokens, embedding_size]

        # Classify from the first token, i.e. the class token prepended by
        # the embedding module. Mean pooling over tokens,
        # encoded.mean(dim=1), would be an alternative.
        logits = self.head(encoded[:, 0])
        # logits: [batch, num_classes]

        return logits, attn_weights

class ModelConfig:
    """Hyperparameters read by the model classes in this file."""

    # Patch size; used as both kernel size and stride of the patch convolution.
    patches = (16, 16)
    # Width of every token vector.
    embedding_size = 768
    # Settings consumed by the attention, MLP and encoder modules.
    transformer = dict(
        mlp_dim=3072,
        num_heads=12,
        num_layers=12,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
    )

# Side length in pixels of the (square) input images.
img_size = 224
# 10-class classifier built from the hyperparameters in ModelConfig.
model = VisionTransformer(config=ModelConfig(), img_size=img_size, num_classes=10)

In [2]:
# SGD with momentum and no weight decay.
optimizer = torch.optim.SGD(model.parameters(),
                                lr=3e-2,
                                momentum=0.9,
                                weight_decay=0)
criterion = CrossEntropyLoss()

iterarions = 2
# Training needs dropout enabled; set the mode explicitly so this cell also
# works when it is re-run after the model was switched to eval mode.
model.train()
# perf_counter() is monotonic, which makes it the right clock for intervals.
begin = time.perf_counter()
# Training
for iterarion in range(iterarions):
    # Synthetic batch: Gaussian-noise images with the fixed labels 0..3.
    x = torch.randn(4, 3, img_size, img_size)
    y = torch.tensor([0, 1, 2, 3])

    optimizer.zero_grad()
    pred, atten_weights = model(x)
    loss = criterion(pred, y)

    # Logging interval of 1, i.e. report every iteration.
    if iterarion % 1 == 0:
        # .item() extracts the Python float instead of formatting a tensor.
        print('Iterarion:', '%2d,' % (iterarion + 1), 'loss =', '{:.4f}'.format(loss.item()))

    loss.backward()
    optimizer.step()
print(f"{(time.perf_counter() - begin)/iterarions:.4f}s / iterarion")

Iterarion:  1, loss = 2.4564
Iterarion:  2, loss = 1.7643
3.7500s / iterarion


In [3]:
# predicting
# Inference must run with dropout disabled; remember the current mode so it
# can be restored for any further training.
was_training = model.training
model.eval()
# no_grad() skips building the autograd graph, which inference does not need.
with torch.no_grad():
    pred, atten = model(torch.randn(4, 3, img_size, img_size))
model.train(was_training)
# Predicted class index per sample; kept as the cell's final expression so
# the notebook displays it.
pred.argmax(dim=1)

tensor([2, 2, 2, 2])