# 1. VOLO

This notebook is written to better understand volo architecture.

# 1. Introduction

# 2. Method

Our model can be regarded as an architecture with two seperate stages. The first stage consists of a stack of Outlookers that generates fine-level token representations. The second stage deploys a sequence of transformer blocks to aggregate global information. At the beginning of each stage, a patch embedding module is used to map the input to token representations with designed shapes.

## 2.1. Outlooker

Given a sequence of input C-dim token representations `X.shape = (H, W, C)`, Outlooker can be written as follows: 
- `X^hat = OutlookAtt(LN(X)) + X`
- `Z=MLP(LN(X^hat)) + X^hat`

### 2.1.1 Outlook Attention

Outlook attention is simple, efficient, and easy to implement. The main insights behind it are:
1. the feature at each spatial location is representative enough to generate attention weights for locally aggregating its neighboring features.
2. The dense and local spatial aggregation can encode fine-level information efficiently.

![volo_211](./imgs/volo_211.jpg)
![volo_211_2](./imgs/volo_211_2.jpg)

In [32]:
import torch
from torch import nn
import torch.nn.functional as F


class OutLookAttn(nn.Module):
    def __init__(self, dim, head, H, W, K=3, padding=1):
        super().__init__()
        self.v_pj = nn.Linear(dim, dim)
        self.attn = nn.Linear(dim, head * K ** 4)
        self.unfold = nn.Unfold(K, padding=padding)
        self.fold = nn.Fold((H, W), K, padding=padding)
        self.K = K
        self.head = head
        self.dim = dim
        self.H = H
        self.W = W

    def forward(self, x): # input(x): (B, H * W, dim)
        # value(v): (B, H * W, dim) -> (B, dim, H, W) -> (B, H * W, head, K**2, dim / head)
        v = self.v_pj(x).permute(0, 2, 1).reshape(-1, self.dim, self.H, self.W)
        v = self.unfold(v).reshape(-1, self.head, self.dim // self.head, self.K ** 2, self.H * self.W)
        v = v.permute(0, 4, 1, 3, 2) 

        # attention(a): (B, H * W, dim) -> (B, H * W, head, K ** 2, K ** 2)
        a = self.attn(x).reshape(-1, self.H * self.W, self.head, self.K ** 2, self.K ** 2)
        a = F.softmax(a, dim=-1) 

        x = (a @ v).permute(0, 2, 4, 3, 1).reshape(-1, self.dim * (self.K ** 2), self.H * self.W)
        x = self.fold(x).permute(0, 2, 3, 1).reshape(-1, self.H * self.W, self.dim)

        return x


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * mlp_ratio)
        self.fc2 = nn.Linear(dim * mlp_ratio, dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))


class Outlooker(nn.Module):
    def __init__(self, dim, head, mlp_ratio, H, W, K=3, padding=1):
        super().__init__()
        self.outlook_attn = OutLookAttn(dim, head, H, W, K, padding)
        self.mlp = MLP(dim, mlp_ratio)
        self.ln1 = nn.LayerNorm(dim, eps=1e-6)
        self.ln2 = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        x_hat = self.outlook_attn(self.ln1(x)) + x
        z = self.mlp(self.ln2(x_hat)) + x_hat

        return z


class PatchEmbedding(nn.Module):
    def __init__(self, in_dim, out_dim, H, W, patch_size):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, patch_size, patch_size)
        self.patch_len = H * W // (patch_size * patch_size)
        self.H = H
        self.W = W
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x):
        if x.ndim != 4:
            x = x.permute(0, 2, 1).reshape(-1, self.in_dim, self.H, self.W)
        x = self.conv(x).permute(0, 2, 3, 1).reshape(-1, self.patch_len, self.out_dim)

        return x


class SelfAttention(nn.Module):
    def __init__(self, dim, mlp_ratio, head):
        super().__init__()

    def forward(self, x):
        return x


class ClassAttention(nn.Module):
    def __init__(self, dim, mlp_ratio, head):
        super().__init__()

    def forward(self, x):
        return x


class VOLO(nn.Module):
    def __init__(self, s1_num, s1_dim, s1_head, s1_mlp_ratio, s2_num, s2_dim, s2_head, s2_mlp_ratio,
                H, W, K=3, padding=1):
        super().__init__()
        self.patch_embedding1 = PatchEmbedding(3, s1_dim, H, W, 8)
        self.stage1 = nn.Sequential(*[Outlooker(s1_dim, s1_head, s1_mlp_ratio, H // 8, W // 8, K, padding) for _ in range(s1_num)])
        self.patch_embedding2 = PatchEmbedding(s1_dim, s2_dim, H // 8, W // 8, 2)
        self.stage2 = nn.Sequential(*[SelfAttention(s2_dim, s2_mlp_ratio, s2_head) for _ in range(s2_num)])
        self.cls = ClassAttention(s2_dim, s2_mlp_ratio, s2_head)

    def forward(self, x):
        x = self.stage1(self.patch_embedding1(x))
        x = self.stage2(self.patch_embedding2(x))
        x = self.cls(x)

        return x

outlook_attn = Outlooker(192, 6, 3, 14, 14)
x = torch.rand([2, 14 * 14, 192])
y = outlook_attn(x)
print(f"# params: {sum([p.numel() for p in outlook_attn.parameters() if p.requires_grad])/(1024*1024):.2f}M")
print(y.shape)

model = VOLO(H=224, W=224, s1_num=4, s1_dim=192, s1_head=6, s1_mlp_ratio=3, 
                            s2_num=14, s2_dim=384, s2_head=12, s2_mlp_ratio=3)
x = torch.rand([2, 3, 224, 224])
y = model(x)
print(f"# params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/(1024*1024):.2f}M")
print(y.shape)

# params: 0.34M
torch.Size([2, 196, 192])
# params: 1.67M
torch.Size([2, 196, 384])
