Coded by Lujia Zhong @lujiazho

In [1]:
import math
import time
import numpy as np
import torch
from torch import nn
from torch.nn import MSELoss, Dropout, Softmax, Linear, Conv2d, LayerNorm

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Reads from ``config``:
        * ``embedding_size`` -- width of the input and output tokens.
        * ``transformer["num_heads"]`` -- number of attention heads.
        * ``transformer["dropout_rate"]`` -- dropout probability applied to
          the attention probabilities and to the projected output.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.transformer["num_heads"]
        # Integer floor division; when embedding_size is not a multiple of
        # num_heads the heads together cover slightly fewer than
        # embedding_size features (all_head_size < embedding_size).
        self.attention_head_size = config.embedding_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.embedding_size, self.all_head_size)
        self.key = Linear(config.embedding_size, self.all_head_size)
        self.value = Linear(config.embedding_size, self.all_head_size)

        # The output projection consumes the concatenated heads, so its input
        # width must be all_head_size (identical to embedding_size whenever
        # embedding_size is divisible by num_heads).
        self.out = Linear(self.all_head_size, config.embedding_size)
        self.attn_dropout = Dropout(config.transformer["dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, N, H*D) -> (B, H, N, D)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: tensor of shape (batch, tokens, embedding_size).

        Returns:
            A tuple ``(attention_output, weights)`` where ``attention_output``
            has shape (batch, tokens, embedding_size) and ``weights`` holds
            the softmax attention probabilities *before* dropout, shape
            (batch, num_heads, tokens, tokens).
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot product: QK^T / sqrt(head_dim).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size**0.5)
        attention_probs = self.softmax(attention_scores)
        # Keep the un-dropped probabilities for the caller (visualisation).
        weights = attention_probs
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, H, N, D) -> (B, N, H, D); contiguous() is required before view().
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)

        return attention_output, weights


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Reads ``config.embedding_size`` (input/output width),
    ``config.transformer["mlp_dim"]`` (hidden width) and
    ``config.transformer["dropout_rate"]``.
    """

    def __init__(self, config):
        super(MLP, self).__init__()
        model_dim = config.embedding_size
        hidden_dim = config.transformer["mlp_dim"]

        self.fc1 = Linear(model_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, model_dim)
        self.act_fn = torch.nn.functional.gelu
        # A single Dropout module is shared by both dropout sites in forward().
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std=1e-6) normal biases."""
        projections = (self.fc1, self.fc2)
        # Weights first, then biases, for both layers in order.
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
        for proj in projections:
            nn.init.normal_(proj.bias, std=1e-6)

    def forward(self, x):
        """Map ``x`` of width embedding_size through the hidden layer and back."""
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add."""

    def __init__(self, config):
        super(Block, self).__init__()
        width = config.embedding_size
        self.embedding_size = width
        self.attention_norm = LayerNorm(width, eps=1e-6)
        self.ffn_norm = LayerNorm(width, eps=1e-6)
        self.ffn = MLP(config)
        self.attn = Attention(config)

    def forward(self, x):
        """Return ``(output, attention_weights)`` for the token tensor ``x``."""
        # Attention sub-layer: normalise, attend, add the residual.
        attended, weights = self.attn(self.attention_norm(x))
        after_attn = attended + x

        # Feed-forward sub-layer: normalise, transform, add the residual.
        output = self.ffn(self.ffn_norm(after_attn)) + after_attn
        return output, weights

class Coder(nn.Module):
    """Stack of ``Block`` layers followed by a final LayerNorm.

    Used for both the encoder and the decoder; the depth comes from
    ``config.transformer["num_layers"]``.
    """

    def __init__(self, config):
        super(Coder, self).__init__()
        self.layer = nn.ModuleList()
        self.coder_norm = LayerNorm(config.embedding_size, eps=1e-6)
        depth = config.transformer["num_layers"]
        self.layer.extend([Block(config) for _ in range(depth)])

    def forward(self, hidden_states):
        """Run every block in order.

        Returns:
            ``(encoded, attn_weights)``: the normalised output of the last
            block (same shape as the input, e.g. [4, 65, 768]) and a list
            with one attention-weight tensor per block.
        """
        attn_weights = []
        for block in self.layer:
            hidden_states, block_weights = block(hidden_states)
            attn_weights.append(block_weights)

        return self.coder_norm(hidden_states), attn_weights

class Embeddings(nn.Module):
    """Patch embedding (strided convolution) plus positional embedding.

    Args:
        config: object exposing ``patch_size`` and
            ``enc_config.embedding_size``.
        position_embeddings: positional table owned by the caller, shape
            (1, n_patches + 1, embedding_size); index 0 along dim 1 is the
            class-token slot and is skipped here.
        img_size: accepted for interface compatibility; not used by this
            module.
        in_channels: number of channels of the input images.
    """

    def __init__(self, config, position_embeddings, img_size, in_channels=3):
        super().__init__()
        patch_size = (config.patch_size, config.patch_size)
        embedding_size = config.enc_config.embedding_size

        # Non-overlapping patches: kernel size == stride == patch size.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=embedding_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # Hold a reference to the caller's table instead of a slice taken
        # once at construction time: a pre-computed slice is a non-leaf view
        # that goes stale when the owner is moved/cast (Module.to/half rebind
        # param.data) and that cannot be deep-copied. Wrapping it in a tuple
        # keeps nn.Module from registering the parameter a second time, so
        # state_dict keys are unchanged.
        self._position_embeddings_ref = (position_embeddings,)

    @property
    def position_embeddings(self):
        """Positional embeddings of the patch tokens (class-token slot dropped)."""
        # e.g. torch.Size([1, 256, 768])
        return self._position_embeddings_ref[0][:, 1:, :]

    def forward(self, x):
        """Embed images of shape (B, C, H, W) into (B, n_patches, embedding_size)."""
        # x: torch.Size([4, 3, 256, 256])

        x = self.patch_embeddings(x)
        # x: torch.Size([4, 768, 16, 16])
        x = x.flatten(2)
        # x: torch.Size([4, 768, 256])
        x = x.transpose(-1, -2)
        # x: torch.Size([4, 256, 768])

        embeddings = x + self.position_embeddings
        # embeddings: torch.Size([4, 256, 768])
        return embeddings

class MAE(nn.Module):
    """Masked autoencoder over image patches.

    A random subset of patches is hidden, the visible patches (plus a class
    token) are encoded, learnable mask tokens are inserted at the hidden
    positions, and a decoder predicts the pixel values of the hidden patches.

    Args:
        config: object exposing ``patch_size``, ``masking_ratio``,
            ``enc_config`` and ``dec_config`` (each with ``embedding_size``
            and a ``transformer`` dict).
        img_size: side length of the square input images; used to size the
            positional tables.
        in_channels: number of image channels.
    """

    def __init__(self, config, img_size=256, in_channels=3):
        super().__init__()
        # common
        assert 0 <= config.masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.in_c = in_channels
        self.masking_ratio = config.masking_ratio
        self.patch_size = config.patch_size
        self.n_patches = (img_size // self.patch_size) * (img_size // self.patch_size)

        encoder_dim = config.enc_config.embedding_size
        decoder_dim = config.dec_config.embedding_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_dim))
        # self.cls_token: torch.Size([1, 1, 768])
        # One slot per patch plus slot 0 for the class token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches+1, encoder_dim))

        # in_channels is forwarded so the patch convolution matches the
        # channel count that patchify/unpatchify/to_pixels already use.
        self.embeddings = Embeddings(config, self.position_embeddings,
                                     img_size=img_size, in_channels=in_channels)
        self.encoder = Coder(config.enc_config)

        # encoder 2 decoder
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()

        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Coder(config.dec_config)

        self.decoder_pos_emb = nn.Parameter(torch.zeros(1, self.n_patches+1, decoder_dim))
        self.to_pixels = nn.Linear(decoder_dim, self.patch_size*self.patch_size*self.in_c)
        # loss
        self.criterion = MSELoss()

    def patchify(self, imgs):
        """Cut square images (B, C, H, W) into flattened patches (B, L, p*p*C)."""
        # imgs: [4, 3, 256, 256]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % self.patch_size == 0

        p_side = imgs.shape[2] // self.patch_size
        x = imgs.reshape(shape=(imgs.shape[0], self.in_c, p_side, self.patch_size, p_side, self.patch_size))
        # (B, C, gh, ph, gw, pw) -> (B, gh, gw, ph, pw, C)
        x = x.permute(0, 2, 4, 3, 5, 1)
        x = x.reshape(shape=(imgs.shape[0], p_side * p_side, self.patch_size**2 * self.in_c))
        # x: [4, 256, 768]

        return x

    def unpatchify(self, x):
        """Inverse of :meth:`patchify`: (B, L, p*p*C) -> (B, C, H, W)."""
        # x: [4, 256, 768]
        p_side = int(x.shape[1]**.5)
        assert p_side**2 == x.shape[1]

        x = x.reshape(shape=(x.shape[0], p_side, p_side, self.patch_size, self.patch_size, self.in_c))
        # (B, gh, gw, ph, pw, C) -> (B, C, gh, ph, gw, pw)
        x = x.permute(0, 5, 1, 3, 2, 4)
        imgs = x.reshape(shape=(x.shape[0], self.in_c, p_side * self.patch_size, p_side * self.patch_size))
        # imgs: [4, 3, 256, 256]

        return imgs

    def forward(self, imgs):
        """Mask, encode, decode and reconstruct ``imgs``.

        Args:
            imgs: tensor of shape (B, in_channels, img_size, img_size).

        Returns:
            ``(loss, pred_imgs, enc_attns, dec_attns)``: MSE between the
            predicted and true pixels of the masked patches, the images with
            masked patches replaced by predictions, and the per-layer
            attention weights of the encoder and the decoder.
        """
        # img: torch.Size([4, 3, 256, 256])
        batch = imgs.shape[0]
        # Index tensors must live on the same device as the data they index.
        device = imgs.device

        # encoder embeddings
        embedding_output = self.embeddings(imgs)
        # tokens: torch.Size([4, 256, 768])

        # shuffle
        num_masked = int(self.masking_ratio * self.n_patches)            # num_masked: 192
        rand_indices = torch.rand(batch, self.n_patches, device=device).argsort(dim=-1)  # torch.Size([4, 256])
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
        # masked_indices: torch.Size([4, 192])
        # unmasked_indices: torch.Size([4, 64])
        # Inverse permutation: maps [masked | unmasked] order back to the
        # original patch order.
        restore_indices = rand_indices.argsort(dim=-1)                   # torch.Size([4, 256])
        batch_range = torch.arange(batch, device=device)[:, None]        # tensor([[0], [1], [2], [3]])

        # tokens part
        tokens = embedding_output[batch_range, unmasked_indices]        # torch.Size([4, 64, 768])

        # add cls token
        cls_tokens = self.cls_token.expand(batch, -1, -1)               # torch.Size([4, 1, 768])
        cls_tokens = cls_tokens + self.position_embeddings[:,:1,:]      # torch.Size([4, 1, 768])
        tokens = torch.cat((cls_tokens, tokens), dim=1)                 # torch.Size([4, 65, 768])

        # encoder
        encoded_tokens, enc_attns = self.encoder(tokens)                # torch.Size([4, 65, 768])

        # decoder embeddings
        decoder_tokens = self.enc_to_dec(encoded_tokens)                # torch.Size([4, 65, 512])
        cls_tokens = decoder_tokens[:,:1,:]                             # torch.Size([4, 1, 512])
        decoder_tokens = decoder_tokens[:,1:,:]                         # torch.Size([4, 64, 512])

        # masked learnable tokens
        mask_tokens = self.mask_token.repeat(batch, num_masked, 1)           # torch.Size([4, 192, 512])
        decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim=1)     # torch.Size([4, 256, 512])
        # unshuffle
        decoder_tokens = decoder_tokens[batch_range, restore_indices]        # torch.Size([4, 256, 512])
        # add cls token
        decoder_tokens = torch.cat((cls_tokens, decoder_tokens), dim=1)      # torch.Size([4, 257, 512])
        # The positional embedding must be added to the tokens the decoder
        # actually consumes; without it every mask token is identical and the
        # decoder cannot tell the masked positions apart.
        decoder_tokens = decoder_tokens + self.decoder_pos_emb               # torch.Size([4, 257, 512])

        # decoder
        decoded_tokens, dec_attns = self.decoder(decoder_tokens)        # torch.Size([4, 257, 512])

        # Drop the class token first so that patch index i addresses patch i
        # (with the class token in front every index would be off by one).
        decoded_patches = decoded_tokens[:, 1:, :]                      # torch.Size([4, 256, 512])
        # only take the masked part as pred
        masked_decoded = decoded_patches[batch_range, masked_indices]   # torch.Size([4, 192, 512])
        pred_pixel_values = self.to_pixels(masked_decoded)              # torch.Size([4, 192, 768])

        # loss
        patched_imgs = self.patchify(imgs)
        loss = self.criterion(pred_pixel_values, patched_imgs[batch_range, masked_indices])

        # predicted img: predictions for masked patches, originals elsewhere
        imgs = torch.cat((pred_pixel_values,
                          patched_imgs[batch_range, unmasked_indices]), dim=1)    # torch.Size([4, 256, 768])
        imgs = imgs[batch_range, restore_indices]                                 # torch.Size([4, 256, 768])
        pred_imgs = self.unpatchify(imgs)
        return loss, pred_imgs, enc_attns, dec_attns
    
class ModelConfig:
    """Hyper-parameters for the MAE model, its encoder and its decoder."""

    # Side length (pixels) of one square patch.
    patch_size = 16
    # Fraction of patches hidden from the encoder.
    masking_ratio = 0.75

    class enc_config:
        """Encoder settings."""

        embedding_size = 768
        transformer = dict(
            mlp_dim=3072,
            num_heads=12,
            num_layers=12,
            dropout_rate=0.0,
        )

    class dec_config:
        """Decoder settings."""

        embedding_size = 512
        transformer = dict(
            mlp_dim=2048,
            num_heads=8,
            num_layers=8,
            dropout_rate=0.0,
        )

# Side length (pixels) of the square inputs; with patch_size 16 this gives a
# 16 x 16 grid, i.e. 256 patches per image.
img_size = 256
# Build the masked autoencoder from the default encoder/decoder settings.
model = MAE(ModelConfig(), img_size=img_size)

In [2]:
# SGD with momentum and no weight decay; a short smoke-test run on random data.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=3e-2,
                            momentum=0.9,
                            weight_decay=0)
iterarions = 5
# Make the training mode explicit (enables dropout when the rate is non-zero).
model.train()
begin = time.time()
# Training
for iterarion in range(iterarions):
    # Sample float32 noise directly in torch, the same way the prediction
    # cell does, instead of round-tripping through a float64 NumPy array.
    x = torch.randn(4, 3, img_size, img_size)

    optimizer.zero_grad()
    loss, *_ = model(x)

    # Logging interval of 1: report every iteration.
    if iterarion % 1 == 0:
        # .item() extracts the Python float before formatting.
        print('Iterarion:', '%2d,' % (iterarion + 1), 'loss =', '{:.4f}'.format(loss.item()))

    loss.backward()
    optimizer.step()
# Average wall-clock time per iteration.
print(f"{(time.time() - begin)/iterarions:.4f}s / iterarion")

Iterarion:  1, loss = 1.3392
Iterarion:  2, loss = 1.2834
Iterarion:  3, loss = 1.2110
Iterarion:  4, loss = 1.1590
Iterarion:  5, loss = 1.1276
2.5700s / iterarion


In [3]:
# predicting
x = torch.randn(4, 3, img_size, img_size)
# Inference only: put the model in evaluation mode (disables dropout) and skip
# autograd graph construction so activations are not kept for a backward pass.
# Call model.train() again before resuming training.
model.eval()
with torch.no_grad():
    loss, pred, *_ = model(x)