# VOLO

This notebook is written to better understand the VOLO architecture.

# 1. Introduction

# 2. Method

Our model can be regarded as an architecture with two separate stages. The first stage consists of a stack of Outlookers that generates fine-level token representations. The second stage deploys a sequence of transformer blocks to aggregate global information. At the beginning of each stage, a patch embedding module is used to map the input to token representations with designed shapes.

## 2.1. Outlooker

Given a sequence of input C-dim token representations `X.shape = (H, W, C)`, Outlooker can be written as follows: 
- `X^hat = OutlookAtt(LN(X)) + X`
- `Z=MLP(LN(X^hat)) + X^hat`

### 2.1.1 Outlook Attention

Outlook attention is simple, efficient, and easy to implement. The main insights behind it are:
1. The feature at each spatial location is representative enough to generate attention weights for locally aggregating its neighboring features.
2. The dense and local spatial aggregation can encode fine-level information efficiently.

![volo_211](./imgs/volo_211.jpg)
![volo_211_2](./imgs/volo_211_2.jpg)

### Todo
1. OutLookAttn 기존 모델 코드와 비교
2. 전체적으로 들어가는 Dropout / Attention Dropout 추가
3. 코드 로직 / 빼먹은 부분은 없는지 확인
4. 깔끔하게 정리

In [49]:
import math
import torch
from torch import nn
import torch.nn.functional as F


class OutLookAttn(nn.Module):
    """Outlook attention over a fixed ``H x W`` token grid.

    Every token predicts, through one linear layer, ``head`` attention maps of
    shape ``(K*K, K*K)``. After a softmax these maps mix the value vectors of
    the ``K x K`` window extracted around that token by ``nn.Unfold``; the mixed
    windows are summed back onto the grid by ``nn.Fold`` and sent through the
    output projection.

    Args:
        dim: Channel width of the tokens. ``forward`` reshapes with
            ``dim // head``, so ``dim`` has to be divisible by ``head``.
        head: Number of attention heads.
        H: Height of the token grid.
        W: Width of the token grid.
        K: Side length of the local window.
        padding: Padding handed to both ``nn.Unfold`` and ``nn.Fold``. The
            reshape in ``forward`` expects exactly ``H * W`` windows, which is
            what ``K=3, padding=1`` produces.
        qkv_bias: Whether the value projection carries a bias term.
    """

    def __init__(self, dim, head, H, W, K=3, padding=1, qkv_bias=False):
        super().__init__()
        self.v_pj = nn.Linear(dim, dim, bias=qkv_bias)
        # One (K*K x K*K) attention map per head is predicted for each token.
        self.attn = nn.Linear(dim, head * K ** 4)
        self.proj = nn.Linear(dim, dim)
        self.unfold = nn.Unfold(K, padding=padding)
        self.fold = nn.Fold((H, W), K, padding=padding)
        self.K = K
        self.head = head
        self.dim = dim
        self.H = H
        self.W = W

    def forward(self, x):  # input(x): (B, H * W, dim)
        """Apply outlook attention.

        Args:
            x: Tensor of shape ``(B, H * W, dim)``.

        Returns:
            Tensor of shape ``(B, H * W, dim)``.
        """
        # value(v): (B, H * W, dim) -> (B, dim, H, W) -> (B, H * W, head, K**2, dim / head)
        v = self.v_pj(x).permute(0, 2, 1).reshape(-1, self.dim, self.H, self.W)
        v = self.unfold(v).reshape(-1, self.head, self.dim // self.head, self.K ** 2, self.H * self.W)
        v = v.permute(0, 4, 1, 3, 2)

        # attention(a): (B, H * W, dim) -> (B, H * W, head, K ** 2, K ** 2)
        # NOTE(review): the logits are not scaled before the softmax; the
        # reference implementation may apply a head_dim ** -0.5 factor - confirm.
        a = self.attn(x).reshape(-1, self.H * self.W, self.head, self.K ** 2, self.K ** 2)
        a = F.softmax(a, dim=-1)

        # (B, H * W, head, K**2, dim / head) -> (B, dim * K**2, H * W), the layout nn.Fold expects.
        x = (a @ v).permute(0, 2, 4, 3, 1).reshape(-1, self.dim * (self.K ** 2), self.H * self.W)
        # Fold sums the overlapping windows back into a (B, dim, H, W) map.
        x = self.fold(x).permute(0, 2, 3, 1).reshape(-1, self.H * self.W, self.dim)

        # The output projection is registered in __init__, so it must take part
        # in the forward pass; otherwise it is a dead trainable parameter.
        return self.proj(x)


class MLP(nn.Module):
    """Two-layer feed-forward network: ``Linear -> GELU -> Linear``.

    Args:
        dim: Width of the input and of the output features.
        mlp_ratio: Multiplier for the hidden width; the hidden layer has
            ``dim * mlp_ratio`` units.
    """

    def __init__(self, dim, mlp_ratio):
        super().__init__()
        hidden_dim = dim * mlp_ratio
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Expand the last axis to the hidden width, apply GELU, project back."""
        hidden = self.fc1(x)
        activated = self.gelu(hidden)
        return self.fc2(activated)


class Outlooker(nn.Module):
    """Pre-norm residual block: outlook attention followed by an MLP.

    Computes ``x_hat = OutLookAttn(LN(x)) + x`` and then
    ``z = MLP(LN(x_hat)) + x_hat``.

    Args:
        dim: Token channel width.
        head: Number of outlook-attention heads.
        mlp_ratio: Hidden-width multiplier of the MLP.
        H: Height of the token grid.
        W: Width of the token grid.
        K: Local window size forwarded to ``OutLookAttn``.
        padding: Padding forwarded to ``OutLookAttn``.
    """

    def __init__(self, dim, head, mlp_ratio, H, W, K=3, padding=1):
        super().__init__()
        self.outlook_attn = OutLookAttn(dim, head, H, W, K=K, padding=padding)
        self.mlp = MLP(dim, mlp_ratio)
        self.ln1 = nn.LayerNorm(dim, eps=1e-6)
        self.ln2 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        """Return the block output; the shape of ``x`` is preserved."""
        # Attention sub-layer with its residual connection.
        attended = self.outlook_attn(self.ln1(x))
        x_hat = attended + x

        # Feed-forward sub-layer with its residual connection.
        expanded = self.mlp(self.ln2(x_hat))
        return expanded + x_hat


class ConvNormAct(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
    """

    def __init__(self, in_dim, out_dim, kernel_size, padding, stride):
        layers = [
            # The conv bias is dropped because BatchNorm follows immediately.
            nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*layers)


class PatchEmbedding(nn.Module):
    """Non-overlapping patch embedding mapping an image or token grid to tokens.

    Differences from the paper (following the author repo):
    - the first patch embedding (img -> (pe) -> stage1) uses a 4-conv stem
    - the second patch embedding (stage1 -> (pe) -> stage2) adds a positional
      embedding

    Args:
        in_dim: Number of input channels.
        out_dim: Channel width of the produced tokens.
        H: Height of the input grid.
        W: Width of the input grid.
        patch_size: Overall spatial reduction factor of the embedding.
        use_stem: If true, use a stride-2 conv stem followed by a conv with
            kernel and stride ``patch_size // 2``; otherwise a single conv with
            kernel and stride ``patch_size``.
        hidden_dim: Channel width inside the stem (only used with ``use_stem``).
        add_pe: If true, add a learnable positional embedding (initialised to
            zeros) to the tokens.
    """

    def __init__(self, in_dim, out_dim, H, W, patch_size, use_stem=False, hidden_dim=64, add_pe=False):
        super().__init__()
        if use_stem:
            tail = patch_size // 2
            self.conv = nn.Sequential(
                ConvNormAct(in_dim, hidden_dim, kernel_size=7, padding=3, stride=2),
                ConvNormAct(hidden_dim, hidden_dim, kernel_size=3, padding=1, stride=1),
                ConvNormAct(hidden_dim, hidden_dim, kernel_size=3, padding=1, stride=1),
                nn.Conv2d(hidden_dim, out_dim, tail, tail)
            )
            # The 7x7 / stride-2 / padding-3 conv keeps (n - 1) // 2 + 1 positions;
            # the last conv has kernel == stride == tail and keeps n // tail.
            grid_h = ((H - 1) // 2 + 1) // tail
            grid_w = ((W - 1) // 2 + 1) // tail
        else:
            self.conv = nn.Conv2d(in_dim, out_dim, patch_size, patch_size)
            grid_h = H // patch_size
            grid_w = W // patch_size

        # Count tokens per axis: H * W // patch_size ** 2 over-counts whenever
        # H or W is not a multiple of patch_size.
        self.patch_len = grid_h * grid_w
        self.H = H
        self.W = W
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.add_pe = add_pe
        if self.add_pe:
            self.pe = nn.Parameter(torch.zeros([1, self.patch_len, self.out_dim]))

    def forward(self, x):
        """Embed ``x`` into a token sequence.

        Args:
            x: Either a 4-D map ``(B, in_dim, H, W)`` or a token sequence
                ``(B, H * W, in_dim)``, which is first reshaped to a map.

        Returns:
            Tensor of shape ``(B, num_tokens, out_dim)``.
        """
        if x.ndim != 4:
            x = x.permute(0, 2, 1).reshape(-1, self.in_dim, self.H, self.W)

        # Flatten only the spatial axes so a size mismatch can never spill
        # tokens into the batch axis: (B, C, h, w) -> (B, h * w, C).
        x = self.conv(x).flatten(2).transpose(1, 2)

        if self.add_pe:
            # Broadcasts over the batch axis; raises if the token count differs.
            x = x + self.pe

        return x


class MHSA(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: Token channel width; split evenly across ``head`` heads.
        head: Number of attention heads.
        qkv_bias: Whether the fused QKV projection carries a bias term.
    """

    def __init__(self, dim, head, qkv_bias=False):
        super().__init__()
        head_dim = dim // head
        self.k = head_dim
        self.div = math.sqrt(head_dim)
        self.head = head
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over all tokens of ``x`` (shape ``(B, N, D)``); shape is kept."""
        batch, tokens, width = x.shape

        # (B, N, 3 * D) -> (3, B, head, N, head_dim), then peel off q, k and v.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.head, self.k)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-1, -2)) / self.div
        weights = scores.softmax(dim=-1)

        # Merge the heads back: (B, head, N, head_dim) -> (B, N, D).
        mixed = (weights @ v).permute(0, 2, 1, 3).reshape(batch, tokens, width)
        return self.proj(mixed)


class SelfAttention(nn.Module):
    """Pre-norm transformer block: multi-head self-attention, then an MLP.

    Args:
        dim: Token channel width.
        mlp_ratio: Hidden-width multiplier of the MLP.
        head: Number of attention heads.
    """

    def __init__(self, dim, mlp_ratio, head):
        super().__init__()
        self.attn = MHSA(dim, head)
        self.mlp = MLP(dim, mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        """Apply both residual sub-layers; the shape of ``x`` is preserved."""
        attended = self.attn(self.norm1(x)) + x
        return self.mlp(self.norm2(attended)) + attended


class MHCA(nn.Module):
    """Multi-head class attention: the first token queries the remaining ones.

    Args:
        dim: Token channel width; split evenly across ``head`` heads.
        head: Number of attention heads.
        qkv_bias: Whether the query and key/value projections carry a bias.
    """

    def __init__(self, dim, head, qkv_bias=False):
        super().__init__()
        head_dim = dim // head
        self.k = head_dim
        self.div = math.sqrt(head_dim)
        self.head = head
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the updated first token, shape ``(B, 1, D)``, for ``x`` of shape ``(B, N, D)``."""
        batch, tokens, width = x.shape
        first, rest = x[:, :1], x[:, 1:]

        # NOTE(review): keys/values come from tokens 1..N-1 only, so the first
        # token never attends to itself - confirm this is the intended design.
        q = self.q(first).reshape(batch, 1, self.head, self.k).permute(0, 2, 1, 3)
        fused = self.kv(rest).reshape(batch, tokens - 1, 2, self.head, self.k)
        k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-1, -2)) / self.div
        weights = scores.softmax(dim=-1)

        # Merge the heads back: (B, head, 1, head_dim) -> (B, 1, D).
        mixed = (weights @ v).permute(0, 2, 1, 3).reshape(batch, 1, width)
        return self.proj(mixed)


class ClassAttention(nn.Module):
    """Pre-norm block that updates only the class token.

    Takes and returns a ``(cls, x)`` tuple so several blocks can be chained
    inside ``nn.Sequential``; ``x`` is passed through unchanged.

    Args:
        dim: Token channel width.
        mlp_ratio: Hidden-width multiplier of the MLP.
        head: Number of attention heads.
    """

    def __init__(self, dim, mlp_ratio, head):
        super().__init__()
        self.attn = MHCA(dim, head)
        self.mlp = MLP(dim, mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, cls_x):
        """Return ``(updated_cls, x)`` for the input pair ``(cls, x)``."""
        cls, x = cls_x

        # Class token goes first so MHCA can pick it up at index 0.
        joined = torch.cat((cls, x), dim=1)

        # Todo: check whether norm1() is applied to z or cls
        cls = self.attn(self.norm1(joined)) + cls
        cls = self.mlp(self.norm2(cls)) + cls

        return cls, x


class VOLO(nn.Module):
    """VOLO classifier: Outlooker stage, transformer stage, class attention.

    Args:
        num_classes: Size of the classification output.
        s1_num: Number of Outlooker blocks in stage 1.
        s1_dim: Token width in stage 1.
        s1_head: Number of heads in stage 1.
        s1_mlp_ratio: MLP hidden-width multiplier in stage 1.
        s2_num: Number of self-attention blocks in stage 2.
        s2_dim: Token width in stage 2.
        s2_head: Number of heads in stage 2 and in the class-attention blocks.
        s2_mlp_ratio: MLP hidden-width multiplier in stage 2 and in the
            class-attention blocks.
        H: Input image height.
        W: Input image width.
        K: Outlook-attention window size.
        padding: Outlook-attention padding.
        stem_hidden_dim: Channel width of the convolutional stem.
    """

    def __init__(self, num_classes, s1_num, s1_dim, s1_head, s1_mlp_ratio, s2_num, s2_dim, s2_head, s2_mlp_ratio,
                H, W, K=3, padding=1, stem_hidden_dim=64):
        super().__init__()
        # Forward: pe1 (0.04M) -> stage1 (1.35M) -> pe2 (0.28M) -> stage2 (19.75M) -> cls (2.82M) -> norm -> fc (0.37M)
        # Stage 1 works on the grid left by the 8x patch embedding.
        grid_h, grid_w = H // 8, W // 8

        self.patch_embedding1 = PatchEmbedding(3, s1_dim, H, W, 8, use_stem=True, hidden_dim=stem_hidden_dim)
        self.patch_embedding2 = PatchEmbedding(s1_dim, s2_dim, grid_h, grid_w, 2, add_pe=True)

        outlookers = [Outlooker(s1_dim, s1_head, s1_mlp_ratio, grid_h, grid_w, K, padding) for _ in range(s1_num)]
        self.stage1 = nn.Sequential(*outlookers)

        transformers = [SelfAttention(s2_dim, s2_mlp_ratio, s2_head) for _ in range(s2_num)]
        self.stage2 = nn.Sequential(*transformers)

        class_blocks = [ClassAttention(s2_dim, s2_mlp_ratio, s2_head) for _ in range(2)]
        self.cls = nn.Sequential(*class_blocks)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, s2_dim))
        self.norm = nn.LayerNorm(s2_dim)
        self.classifier = nn.Linear(s2_dim, num_classes)
        # Registered but not used by forward().
        self.aux_head = nn.Linear(s2_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch fed to the first patch embedding.

        Returns:
            Logits computed from the class token, shape ``(B, 1, num_classes)``.
        """
        tokens = self.patch_embedding1(x)
        tokens = self.stage1(tokens)
        tokens = self.patch_embedding2(tokens)
        tokens = self.stage2(tokens)

        # One copy of the learnable class token per sample.
        batch_cls = self.cls_token.expand(tokens.size(0), -1, -1)
        cls_token, tokens = self.cls((batch_cls, tokens))

        normed = self.norm(cls_token)
        return self.classifier(normed)


def volo_d1_224():
    """Build the d1 VOLO variant for 224x224 inputs and 1000 classes."""
    stage1 = dict(s1_num=4, s1_dim=192, s1_head=6, s1_mlp_ratio=3)
    stage2 = dict(s2_num=14, s2_dim=384, s2_head=12, s2_mlp_ratio=3)
    return VOLO(num_classes=1000, H=224, W=224, **stage1, **stage2)

def volo_d2_224():
    """Build the d2 VOLO variant for 224x224 inputs and 1000 classes.

    Change from d1: more layers, wider tokens, more heads.
    """
    stage1 = dict(s1_num=6, s1_dim=256, s1_head=8, s1_mlp_ratio=3)
    stage2 = dict(s2_num=18, s2_dim=512, s2_head=16, s2_mlp_ratio=3)
    return VOLO(num_classes=1000, H=224, W=224, **stage1, **stage2)

def volo_d3_224():
    """Build the d3 VOLO variant for 224x224 inputs and 1000 classes.

    Change from d2: more layers.
    """
    stage1 = dict(s1_num=8, s1_dim=256, s1_head=8, s1_mlp_ratio=3)
    stage2 = dict(s2_num=28, s2_dim=512, s2_head=16, s2_mlp_ratio=3)
    return VOLO(num_classes=1000, H=224, W=224, **stage1, **stage2)

def volo_d4_224():
    """Build the d4 VOLO variant for 224x224 inputs and 1000 classes.

    Change from d3: wider tokens and more stage-1 heads.
    """
    stage1 = dict(s1_num=8, s1_dim=384, s1_head=12, s1_mlp_ratio=3)
    stage2 = dict(s2_num=28, s2_dim=768, s2_head=16, s2_mlp_ratio=3)
    return VOLO(num_classes=1000, H=224, W=224, **stage1, **stage2)

def volo_d5_224():
    """Build the d5 VOLO variant for 224x224 inputs and 1000 classes.

    Change from d4: more layers and a larger MLP ratio.

    Modification from the paper (taken from the author code):
    - stem_hidden_dim = 128
    """
    stage1 = dict(s1_num=12, s1_dim=384, s1_head=12, s1_mlp_ratio=4)
    stage2 = dict(s2_num=36, s2_dim=768, s2_head=16, s2_mlp_ratio=4)
    return VOLO(num_classes=1000, H=224, W=224, stem_hidden_dim=128, **stage1, **stage2)


def _trainable_params_m(module):
    """Return the trainable parameter count of ``module`` divided by 1024 * 1024."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad) / (1024 * 1024)


# Smoke test for a single Outlooker block on a 14x14 token grid.
outlook_attn = Outlooker(192, 6, 3, 14, 14)
x = torch.rand([2, 14 * 14, 192])
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    y = outlook_attn(x)
print(f"# params: {_trainable_params_m(outlook_attn):.2f}M")
print(y.shape)

# Smoke test for every model variant: parameter count and output shape.
for model_fn in [volo_d1_224, volo_d2_224, volo_d3_224, volo_d4_224, volo_d5_224]:
    model = model_fn()
    x = torch.rand([2, 3, 224, 224])
    with torch.no_grad():
        y = model(x)
    print(f"{model_fn.__name__} # params: {_trainable_params_m(model):.2f}M")
    print(y.shape)

# params: 0.37M
torch.Size([2, 196, 192])
volo_d1_224 # params: 25.40M
torch.Size([2, 1, 1000])
volo_d2_224 # params: 55.96M
torch.Size([2, 1, 1000])
volo_d3_224 # params: 82.33M
torch.Size([2, 1, 1000])
volo_d4_224 # params: 184.02M
torch.Size([2, 1, 1000])
volo_d5_224 # params: 281.77M
torch.Size([2, 1, 1000])
