<center><h1>Masked Autoencoders</h1> </center>

<center><p><a href="http://arxiv.org/abs/2111.06377">Masked Autoencoders Are Scalable Vision Learners</a></p></center>

<img src="https://miro.medium.com/v2/resize:fit:1120/1*l8zPV1sSDmEPbwHTDh5Rzw.png" width="600"/>

[Code: https://github.com/facebookresearch/mae](https://github.com/facebookresearch/mae/blob/main/models_mae.py)


In [1]:
import numpy as np
import torch
from torch import nn

from functools import partial

import timm.models.vision_transformer
from timm.models.vision_transformer import PatchEmbed, Block

# Vision Transformer

In [2]:
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Extends timm's ``VisionTransformer``. When ``global_pool`` is true the
    features are the mean over patch tokens (class token excluded) passed
    through a dedicated ``fc_norm`` layer; otherwise they are the normalized
    class token.
    """

    def __init__(self, global_pool=False, **kwargs):
        """
        :param global_pool: if true, pool patch tokens instead of using the class token
        :param kwargs: forwarded unchanged to the timm ``VisionTransformer``;
            must contain ``norm_layer`` and ``embed_dim`` when ``global_pool`` is true
        """
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        """
        Embed the input, run every Transformer block and reduce the token
        sequence to a single feature per sample.

        :param x: input batch accepted by ``self.patch_embed``
        :return: one feature per sample, indexed along the batch dimension
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # prepend one class token per sample, then add the position embedding
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            # class token of every sample; `x[:0]` would slice the batch
            # dimension and return an empty tensor
            outcome = x[:, 0]

        return outcome

## Models

In [3]:
def vit_base_patch16(**kwargs):
    """Build a VisionTransformer with patch size 16, width 768, 12 blocks and 12 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)

In [4]:
def vit_large_patch16(**kwargs):
    """Build a VisionTransformer with patch size 16, width 1024, 24 blocks and 16 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)

In [5]:
def vit_huge_patch14(**kwargs):
    """Build a VisionTransformer with patch size 14, width 1280, 32 blocks and 16 heads.

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    config = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**config, **kwargs)

# Masked Autoencoder

In [6]:
class MaskedAutoencoderViT(nn.Module):
    """Masked autoencoder with a Vision Transformer encoder and decoder.

    The encoder runs only on a random subset of patches; the decoder receives
    the encoded patches plus a learned mask token for every removed patch and
    predicts the pixel values of each patch. The loss is the mean squared
    error on the removed patches.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """
        :param img_size: input image size, passed to ``PatchEmbed``
        :param patch_size: patch size, passed to ``PatchEmbed``
        :param in_chans: number of image channels
        :param embed_dim: encoder token width
        :param depth: number of encoder blocks
        :param num_heads: attention heads per encoder block
        :param decoder_embed_dim: decoder token width
        :param decoder_depth: number of decoder blocks
        :param decoder_num_heads: attention heads per decoder block
        :param mlp_ratio: MLP ratio passed to every ``Block``
        :param norm_layer: constructor of the normalization layers
        :param norm_pix_loss: if true, normalize each target patch before the loss
        """
        super().__init__()

        # -----------------------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # fixed (non-trainable) embedding, filled in by initialize_weights()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        # -----------------------------------------------------------------------------------------

        # -----------------------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # fixed (non-trainable) embedding, filled in by initialize_weights()
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # one output value per pixel and channel of a patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)
        # -----------------------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the position embeddings and initialize all learnable weights."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5),
                                                    cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule; only nn.Linear and nn.LayerNorm are touched."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Cut square images into flattened, non-overlapping patches.

        :param imgs: (N, C, H, W) with H == W and H divisible by the patch size
        :return: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        # take the channel count from the input instead of assuming RGB
        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p ** 2 * c))
        return x

    def unpatchify(self, x):
        """
        Inverse of patchify: reassemble flattened patches into square images.

        :param x: (N, L, patch_size**2 * C) with L a perfect square
        :return: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        # recover the channel count from the flattened patch length
        c = x.shape[2] // (p * p)
        assert c * p * p == x.shape[2]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        :param x: [N, L, D], sequence
        :param mask_ratio: fraction of the L tokens to remove
        :return: (x_masked [N, len_keep, D], mask [N, L] with 0 = keep and
                 1 = remove, ids_restore [N, L])
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is to keep, large is to remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation of ids_shuffle, used for restoration

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is to keep, 1 is to remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """
        Encode the randomly kept patches of a batch of images.

        :param x: image batch accepted by ``self.patch_embed``
        :param mask_ratio: fraction of patches to remove before encoding
        :return: (latent tokens with the class token first, mask, ids_restore)
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> int(length * (1 - mask_ratio))
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # prepend cls token (with its own position embedding)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """
        Predict the pixels of every patch from the encoder output.

        :param x: encoder output, class token first
        :param ids_restore: [N, L] indices that undo the shuffle of random_masking
        :return: [N, L, decoder_pred output width], class token removed
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence; the +1 accounts for the cls token in x
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token again

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        Mean squared error between prediction and target on removed patches.

        :param imgs: [N, C, H, W]
        :param pred: [N, L, p*p*C]
        :param mask: [N, L], 0 is to keep, 1 is to remove
        :return: scalar loss tensor
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # normalize every target patch by its own mean and variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        # mean loss on removed patches (mask.sum() is 0 when nothing was removed)
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """
        Run encoder, decoder and loss for a batch of images.

        :param imgs: [N, C, H, W]
        :param mask_ratio: fraction of patches hidden from the encoder
        :return: (loss, pred [N, L, p*p*C], mask [N, L])
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*C]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

## Models

In [7]:
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build a MaskedAutoencoderViT: encoder patch 16 / width 768 / 12 blocks / 12 heads,
    decoder width 512 / 8 blocks / 16 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    encoder_cfg = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    return MaskedAutoencoderViT(
        **encoder_cfg,
        **decoder_cfg,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

In [8]:
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build a MaskedAutoencoderViT: encoder patch 16 / width 1024 / 24 blocks / 16 heads,
    decoder width 512 / 8 blocks / 16 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    encoder_cfg = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    return MaskedAutoencoderViT(
        **encoder_cfg,
        **decoder_cfg,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

In [9]:
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build a MaskedAutoencoderViT: encoder patch 14 / width 1280 / 32 blocks / 16 heads,
    decoder width 512 / 8 blocks / 16 heads.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    encoder_cfg = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    return MaskedAutoencoderViT(
        **encoder_cfg,
        **decoder_cfg,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )

In [10]:
# Short aliases for the factory functions, dropping the `_dec512d8b` suffix.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b

# Utils

## Position Embedding

In [11]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a 2D sin-cos position embedding table for a square grid.

    :param embed_dim: embedding width of each position
    :param grid_size: number of grid cells along each side
    :param cls_token: if true, prepend one all-zero row
    :return: (grid_size*grid_size, embed_dim) array, or one extra leading row with cls_token
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)

    # the w axis goes first in the meshgrid call
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table

    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, table], axis=0)

In [12]:
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode the two coordinate planes of a grid and join them feature-wise.

    :param embed_dim: total embedding width; must be even
    :param grid: indexable with 0 and 1, one coordinate array each
    :return: (H*W, D) array; the first D/2 columns encode grid[0], the rest grid[1]
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # each half has shape (H*W, D/2)
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]

    return np.concatenate(halves, axis=1)  # (H*W, D)

In [13]:
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sin-cos embedding of a set of scalar positions.

    :param embed_dim: output dimension for each position; must be even
    :param pos: array of positions to be encoded; flattened to size (M,)
    :return: (M, D) array, sine features in the first D/2 columns, cosine in the rest
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.outer(flat_pos, freqs)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

# Summary

## Data

In [14]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dimensions of the dummy input batch.
in_channel = 3
image_size = 224
batch_size = 2

# Random batch of shape (batch_size, in_channel, image_size, image_size), moved to `device`.
data = torch.randn((batch_size, in_channel, image_size, image_size)).to(device)

## MAE ViT-B/16

In [15]:
from torchkeras import summary

# Build the model on `device`, print its summary for the dummy batch,
# then drop the reference to the model.
net = mae_vit_base_patch16().to(device)

summary(net, input_data=data)
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                           [-1, 768, 14, 14]              590,592
Identity-2                            [-1, 196, 768]                    0
LayerNorm-3                            [-1, 50, 768]                1,536
Linear-4                              [-1, 50, 2304]            1,771,776
Identity-5                          [-1, 12, 50, 64]                    0
Identity-6                          [-1, 12, 50, 64]                    0
Dropout-7                           [-1, 12, 50, 50]                    0
Linear-8                               [-1, 50, 768]              590,592
Dropout-9                              [-1, 50, 768]                    0
Identity-10                            [-1, 50, 768]                    0
Identity-11                            [-1, 50, 768]                    0
LayerNorm-12                         

## MAE ViT-L/16

In [16]:
# Build the model on `device`, print its summary for the dummy batch,
# then drop the reference to the model.
net = mae_vit_large_patch16().to(device)

summary(net, input_data=data)
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                          [-1, 1024, 14, 14]              787,456
Identity-2                           [-1, 196, 1024]                    0
LayerNorm-3                           [-1, 50, 1024]                2,048
Linear-4                              [-1, 50, 3072]            3,148,800
Identity-5                          [-1, 16, 50, 64]                    0
Identity-6                          [-1, 16, 50, 64]                    0
Dropout-7                           [-1, 16, 50, 50]                    0
Linear-8                              [-1, 50, 1024]            1,049,600
Dropout-9                             [-1, 50, 1024]                    0
Identity-10                           [-1, 50, 1024]                    0
Identity-11                           [-1, 50, 1024]                    0
LayerNorm-12                         

## MAE ViT-H/14

In [17]:
# Build the model on `device`, print its summary for the dummy batch,
# then drop the reference to the model.
net = mae_vit_huge_patch14().to(device)

summary(net, input_data=data)
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                          [-1, 1280, 16, 16]              753,920
Identity-2                           [-1, 256, 1280]                    0
LayerNorm-3                           [-1, 65, 1280]                2,560
Linear-4                              [-1, 65, 3840]            4,919,040
Identity-5                          [-1, 16, 65, 80]                    0
Identity-6                          [-1, 16, 65, 80]                    0
Dropout-7                           [-1, 16, 65, 65]                    0
Linear-8                              [-1, 65, 1280]            1,639,680
Dropout-9                             [-1, 65, 1280]                    0
Identity-10                           [-1, 65, 1280]                    0
Identity-11                           [-1, 65, 1280]                    0
LayerNorm-12                         