Let's go step by step through the **Pyramid Vision Transformer (PVT)** — a fundamental architecture that bridges **ViTs and CNNs** for dense visual tasks like **object detection** and **semantic segmentation**.

---

# **1. Motivation**

The **Vision Transformer (ViT)** introduced powerful global attention but has drawbacks:

* Requires **fixed-size inputs** (e.g., 224×224).
* Produces **single-scale features**, unsuitable for dense prediction tasks.
* Has **quadratic complexity** with respect to the number of patches.
* Lacks **local inductive bias** (such as the locality and translation equivariance built into CNNs).

The **Pyramid Vision Transformer (PVT)** (Wang et al., *ICCV 2021*) addresses the single-scale and cost problems by **building a hierarchical (multi-scale) feature pyramid**, like a CNN backbone (e.g., ResNet), and by making attention cheaper.

---

# **2. Key Idea**

PVT = **Hierarchical Vision Transformer Backbone**

It mimics CNNs by generating **multi-resolution feature maps**:

| Stage | Resolution (fraction of input) | Channels (PVT v1) | Purpose                 |
| :---: | :----------------------------- | :---------------- | :---------------------- |
|   1   | High (1/4)                     | Fewest (64)       | Capture local features  |
|   2   | Medium (1/8)                   | More (128)        | Broader receptive field |
|   3   | Low (1/16)                     | More (320)        | Semantic features       |
|   4   | Very low (1/32)                | Most (512)        | Global context          |

These outputs can directly feed **FPN**, **Mask R-CNN**, **U-Net**, etc.

---

# **3. Architecture Overview**

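Each of the 4 stages of the PVT backbone performs:

1. **Patch embedding** (patchify + linear projection); because each stage re-patchifies its input (4×4 patches in Stage 1, 2×2 afterwards), this step also does the **downsampling** that forms the pyramid
2. **Transformer encoder blocks** operating on the resulting tokens
3. **Spatial reduction attention (SRA)** inside each block, which reduces the key/value tokens before attention
4. **Reshape** of the output tokens back into a feature map, which is both a pyramid level and the next stage's input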

---

## **3.1. Stage 1: Patch Embedding**

In Stage 1, the image is split into non-overlapping 4×4 patches (ViT uses 16×16), and each patch is linearly projected to $C_0$ channels:

$$
x_0 = \text{PatchEmbed}(I) \in \mathbb{R}^{H_0 \times W_0 \times C_0}
$$

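Here $H_0 = H/4$ and $W_0 = W/4$. The map is then flattened into $N = H_0 W_0$ tokens, and a learned positional embedding is added before the first transformer block.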
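A minimal NumPy sketch of the Stage-1 patch embedding described above, assuming a 224×224 RGB input, 4×4 patches and $C_0 = 64$ (the Stage-1 width in the pyramid table). `patch_embed` is a hypothetical helper, the random weights stand in for learned parameters, and the positional embedding is left out.

```python
import numpy as np

def patch_embed(image, patch_size, weight, bias):
    """Split an (H, W, C_in) image into non-overlapping P x P patches and
    linearly project each flattened patch. Returns an (H/P, W/P, C_out) map."""
    H, W, C_in = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "H and W must be divisible by the patch size"
    # (H/P, P, W/P, P, C_in) -> (H/P, W/P, P, P, C_in) -> (H/P, W/P, P*P*C_in)
    patches = image.reshape(H // P, P, W // P, P, C_in).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(H // P, W // P, P * P * C_in)
    return patches @ weight + bias  # weight: (P*P*C_in, C_out)

rng = np.random.default_rng(0)
image = rng.standard_normal((224, 224, 3))
P, C0 = 4, 64  # assumed Stage-1 patch size and width
weight = rng.standard_normal((P * P * 3, C0)) * 0.02  # placeholder for learned weights
bias = np.zeros(C0)

x0 = patch_embed(image, P, weight, bias)
print(x0.shape)              # (56, 56, 64)
tokens = x0.reshape(-1, C0)  # flatten to N = H0 * W0 tokens
print(tokens.shape)          # (3136, 64)
```

Reshaping into non-overlapping blocks and multiplying by one weight matrix is the same computation as a convolution whose kernel size and stride both equal the patch size, which is how such layers are commonly implemented.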

---

## **3.2. Transformer Encoder with SRA**

Standard self-attention has **O(N²)** cost (N = number of patches).
To make it scalable, **PVT** introduces **Spatial Reduction Attention (SRA)**:

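Instead of computing keys and values from all $N$ tokens $X$, SRA first **spatially downsamples** the token map by a ratio $s$ per side (a strided convolution followed by LayerNorm), then projects the reduced tokens; the queries $Q = X W_Q$ still come from every token:

$$
K' = \text{Downsample}(X)\, W_K, \quad V' = \text{Downsample}(X)\, W_V
$$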

Then attention becomes:

$$
\text{Attention}(Q, K', V') = \text{Softmax}\left( \frac{Q {K'}^T}{\sqrt{d}} \right) V'
$$

This reduces complexity from **O(N²)** to **O(N × N/s²) = O(N²/s²)**, where *s* is the reduction ratio per spatial side, so K and V keep only $N/s^2$ tokens.

✅ This keeps **global receptive field** but reduces computation.
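The SRA computation above can be sketched for a single head in NumPy, assuming the Stage-1 setting of the running example ($56 \times 56$ tokens, $d = 64$, $s = 8$). Biases, multiple heads, and the learned LayerNorm scale/shift are omitted, the weights are random placeholders, and the strided-convolution downsampling is written as a reshape into $s \times s$ blocks followed by a matrix product (equivalent to a convolution with kernel size and stride $s$). The function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the channel axis, without learned scale/shift
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sra_single_head(x, H, W, s, Wq, Wk, Wv, Wsr):
    """x: (N, d) tokens laid out row-major on an H x W grid."""
    N, d = x.shape
    q = x @ Wq  # (N, d): queries keep full resolution
    # Spatial reduction: gather each s x s block of tokens and project it back
    # to d channels (same as Conv2d with kernel_size = stride = s).
    blocks = x.reshape(H // s, s, W // s, s, d).transpose(0, 2, 1, 3, 4)
    x_red = layer_norm(blocks.reshape(-1, s * s * d) @ Wsr)  # (N / s^2, d)
    k, v = x_red @ Wk, x_red @ Wv                # (N', d) each
    attn = softmax(q @ k.T / np.sqrt(d))         # (N, N') attention weights
    return attn @ v, attn

rng = np.random.default_rng(0)
H = W = 56
d, s = 64, 8
x = rng.standard_normal((H * W, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
Wsr = rng.standard_normal((s * s * d, d)) / np.sqrt(s * s * d)

out, attn = sra_single_head(x, H, W, s, Wq, Wk, Wv, Wsr)
print(attn.shape)  # (3136, 49)
print(out.shape)   # (3136, 64)
```

Every one of the 3136 queries gets a weight over all 49 reduced tokens, which is the global receptive field mentioned above, with a score matrix $s^2 = 64$ times smaller.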

---

## **3.3. Pyramid Hierarchy**

Stage 1 downsamples the input by 4×; each later stage halves the resolution again and increases the channel count (PVT v1 settings):

| Stage | Resolution | Channels | Patch size | Reduction ratio |
| :---: | :--------- | :------- | :--------- | :-------------- |
|   1   | 1/4        | 64       | 4×4        | 8               |
|   2   | 1/8        | 128      | 2×2        | 4               |
|   3   | 1/16       | 320      | 2×2        | 2               |
|   4   | 1/32       | 512      | 2×2        | 1               |

Each stage is a **Transformer encoder** operating on the corresponding feature scale.
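As a quick check on the table, this pure-Python sketch walks a 224×224 input through the four stages, assuming the patch sizes, widths, and reduction ratios listed above, and counts query tokens and key/value tokens per stage.

```python
input_size = 224
stages = [  # (patch size, channels, reduction ratio) taken from the table above
    (4, 64, 8),
    (2, 128, 4),
    (2, 320, 2),
    (2, 512, 1),
]

size = input_size
rows = []
for i, (patch, channels, ratio) in enumerate(stages, start=1):
    size //= patch            # patch embedding downsamples by the patch size
    n_tokens = size * size    # query tokens N_i
    kv_side = size // ratio   # side length of the K/V grid after spatial reduction
    rows.append((i, size, channels, n_tokens, kv_side * kv_side))
    print(f"Stage {i}: {size}x{size}x{channels}, N={n_tokens}, K/V tokens={kv_side * kv_side}")
```

With these settings the key/value grid works out to 7×7 (49 tokens) at every stage, so each stage's attention matrix is $N_i \times 49$.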

---

# **4. Advantages**

✅ **Hierarchical features** → usable as a CNN backbone (e.g., in Mask R-CNN).
✅ **Global receptive field** from transformers.
✅ **Efficient attention** via spatial reduction.
✅ **Variable input resolution** support.
✅ **Strong performance on dense tasks** (segmentation, detection).

---

# **5. Comparison with ViT and Swin**

| Model | Attention Type       | Hierarchy | Complexity                       | Windowed? | Suitable for Detection? |
| :---- | :------------------- | :-------: | :------------------------------- | :-------: | :---------------------: |
| ViT   | Global               |     ✖     | O(N²)                            |     ✖     |            ✖            |
| Swin  | Window-based (local) |     ✅     | O(N·M²), linear in N (window M) |     ✅     |            ✅            |
| PVT   | Global (SRA-reduced) |     ✅     | O(N²/s²)                         |     ✖     |            ✅            |

So:

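* **PVT keeps global attention** and makes it cheaper by a factor of *s²* (SRA), although for a fixed *s* the cost is still quadratic in *N*.
* **Swin** uses local window attention and shifts the windows between layers to connect neighboring windows.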
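To make the complexity column concrete, here is a back-of-the-envelope sketch that counts only the query-key score products (no projections, softmax, or value products) and assumes a Swin-style window of $M = 7$ without shifting; the helper names are hypothetical. It compares how each cost grows when the feature-map side doubles, so that $N$ grows 4×.

```python
d = 64
n_small, n_large = 56 * 56, 112 * 112  # doubling the side length quadruples N

def full_attention_cost(n):
    return n * n * d                # ViT-style: N x N scores

def sra_cost(n, s=8):
    return n * (n // (s * s)) * d   # PVT-style: N x N/s^2 scores (s assumed fixed)

def window_cost(n, m=7):
    return n * (m * m) * d          # Swin-style: each token scores an M x M window

growth = {}
for name, cost in [("ViT", full_attention_cost),
                   ("PVT (s=8)", sra_cost),
                   ("Swin (M=7)", window_cost)]:
    growth[name] = cost(n_large) / cost(n_small)
    print(f"{name}: cost grows {growth[name]:.0f}x when N grows 4x")
```

With *s* fixed, PVT's cost grows as fast as ViT's (16×), only from an $s^2 = 64$ times smaller starting point, while window attention grows linearly (4×).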

---

# **6. Equation Summary**

Let’s define for stage *i*:

* Input tokens:
  $$ X_i \in \mathbb{R}^{N_i \times C_i}, \quad N_i = H_i W_i $$
* Spatial reduction ratio: *r* (called *s* in Section 3.2)

Then the attention of stage *i* is computed as:

$$
Q = X_i W_Q, \quad K = \text{Down}(X_i)\, W_K, \quad V = \text{Down}(X_i)\, W_V
$$

$$
\text{SRA}(X_i) = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d}} \right) V
$$

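where $d$ is the per-head dimension and $\text{Down}$ reshapes the tokens to an $H_i \times W_i$ map, applies a convolution with kernel size and stride $r$, flattens, and normalizes:
$$
\text{Down}(X_i) = \text{LayerNorm}\Big(\text{Flatten}\big(\text{Conv2d}_{k=r,\,\text{stride}=r}(\text{Reshape}_{H_i \times W_i}(X_i))\big)\Big) \in \mathbb{R}^{(N_i/r^2) \times C_i}
$$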
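Because the convolution in $\text{Down}$ has kernel size equal to its stride (as in the code skeleton in Section 9), its windows never overlap and each reduced token is a linear function of exactly one $r \times r$ block. This NumPy check, with hypothetical helper names and random weights, confirms numerically that such a convolution equals reshaping into blocks plus a single matrix product (the reshape-and-project form of the reduction).

```python
import numpy as np

def strided_conv2d(x, w, r):
    """Direct convolution with kernel size = stride = r and no padding.
    x: (H, W, C_in), w: (r, r, C_in, C_out) -> (H/r, W/r, C_out)."""
    H, W, _ = x.shape
    out = np.zeros((H // r, W // r, w.shape[-1]))
    for i in range(H // r):
        for j in range(W // r):
            window = x[i * r:(i + 1) * r, j * r:(j + 1) * r, :]  # (r, r, C_in)
            out[i, j] = np.einsum("abc,abcd->d", window, w)
    return out

def blockwise_linear(x, w, r):
    """Same operation as a reshape into r x r blocks plus one matrix product."""
    H, W, C_in = x.shape
    blocks = x.reshape(H // r, r, W // r, r, C_in).transpose(0, 2, 1, 3, 4)
    return blocks.reshape(H // r, W // r, r * r * C_in) @ w.reshape(r * r * C_in, -1)

rng = np.random.default_rng(0)
r, C = 4, 16
x = rng.standard_normal((28, 28, C))
w = rng.standard_normal((r, r, C, C))

a = strided_conv2d(x, w, r)
b = blockwise_linear(x, w, r)
print(a.shape, np.allclose(a, b))  # (7, 7, 16) True
```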

---

# **7. Visual Summary**

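```
Input Image (H × W × 3)
   ↓
[Stage 1] Patch Embed (4×4) → Transformer (SRA, R=8) → F1: H/4 × W/4 × C1
   ↓
[Stage 2] Patch Embed (2×2) → Transformer (SRA, R=4) → F2: H/8 × W/8 × C2
   ↓
[Stage 3] Patch Embed (2×2) → Transformer (SRA, R=2) → F3: H/16 × W/16 × C3
   ↓
[Stage 4] Patch Embed (2×2) → Transformer (SRA, R=1) → F4: H/32 × W/32 × C4
   ↓
Feature Pyramid {F1, F2, F3, F4} → Detection / Segmentation Head
```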

---

# **8. Typical Usage**

PVT variants:

* **PVT-Tiny**, **PVT-Small**, **PVT-Medium**, **PVT-Large**
  share the same stage widths (64, 128, 320, 512) and differ in the number of encoder blocks per stage.
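For reference, a small configuration sketch of the four PVT v1 variants. The depths and widths are the values as we recall them from the PVT paper's configuration table and the official code; double-check them against the official repository before relying on them. `PVT_DEPTHS` and `WIDTHS` are hypothetical names.

```python
# Assumed per-stage encoder depths of the PVT v1 variants (verify against the official code)
PVT_DEPTHS = {
    "PVT-Tiny":   (2, 2, 2, 2),
    "PVT-Small":  (3, 4, 6, 3),
    "PVT-Medium": (3, 4, 18, 3),
    "PVT-Large":  (3, 8, 27, 3),
}
WIDTHS = (64, 128, 320, 512)  # channel width of stages 1-4, shared by all variants

for name, depths in PVT_DEPTHS.items():
    print(f"{name}: {sum(depths)} encoder blocks, depths={depths}, widths={WIDTHS}")
```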

Used in:

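* **PVT + FPN → RetinaNet / Mask R-CNN** (object detection and instance segmentation)
* **PVT + Semantic FPN → semantic segmentation** (the setup reported in the PVT paper; other decoders such as UPerNet can also be attached)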

---

# **9. Example PyTorch Skeleton**

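```python
import torch
import torch.nn as nn


class SpatialReductionAttention(nn.Module):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)  # output projection after merging heads
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # kernel = stride = sr_ratio shrinks the K/V grid by sr_ratio per side
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        # x: (B, N, C) with N = H * W
        B, N, C = x.shape
        h = self.num_heads
        # Queries keep full resolution: (B, h, N, C/h)
        q = self.q(x).reshape(B, N, h, C // h).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # tokens -> (B, C, H, W) map -> strided conv -> tokens: N' = N / sr_ratio^2
            x_ = x.transpose(1, 2).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).transpose(1, 2)
            x_ = self.norm(x_)
        else:
            x_ = x

        # Keys/values from the reduced tokens: each (B, h, N', C/h)
        kv = self.kv(x_).reshape(B, -1, 2, h, C // h).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, h, N, N')
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads
        return self.proj(out)
```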

---

# **10. Summary**

✅ **Concept:** Build a **multi-scale feature pyramid** with transformers.
✅ **Trick:** **Spatial Reduction Attention (SRA)** → efficient global attention.
✅ **Benefit:** Works as a **drop-in CNN backbone** for detection/segmentation.
✅ **Key paper:**

* *Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions* (Wang et al., ICCV 2021)

---




## **numerical example**
Perfect — let’s go through a **numerical example** showing how **Spatial Reduction Attention (SRA)** in **Pyramid Vision Transformer (PVT)** drastically reduces computational cost compared to standard self-attention.

---

# **1. Reminder: Complexity of Self-Attention**

For each layer, self-attention requires:

$$
\text{Cost} \sim O(N^2 \cdot d)
$$

where

* $N$ = number of tokens (patches)
* $d$ = embedding dimension

In ViT, $N = (H/P) \times (W/P)$, where $P$ is the patch size.

Example: for 224×224 input and 16×16 patch size:
$$
N = (224/16)^2 = 14^2 = 196
$$

---

# **2. Example Setup**

Let’s consider a **PVT stage** operating on feature maps of size 56×56 (Stage 1 of PVT for a 224×224 input).

| Parameter               | Symbol         | Value                   |
| ----------------------- | -------------- | ----------------------- |
| Feature map size        | $H \times W$   | 56×56                   |
| Tokens                  | $N = HW$       | 3136                    |
| Embedding dim           | $d$            | 64                      |
| Heads                   | $h$            | 1 (as in PVT Stage 1)   |
| Spatial reduction ratio | $s$            | 8                       |

---

# **3. Cost of Standard Self-Attention**

In vanilla ViT-style attention:

$$
Q, K, V \in \mathbb{R}^{N \times d}
$$

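The attention matrix $QK^\top$ has size $N \times N$.

Counting multiply-adds per head for the $QK^\top$ product (multiplying the attention weights by $V$ costs the same again):

$$
\text{Cost}_{\text{standard}} = N^2 \cdot d
$$

Plug in the numbers:

$$
\text{Cost}_{\text{standard}} = (3136)^2 \times 64 = 629{,}407{,}744 \approx 6.3 \times 10^8
$$

That is about **0.63 billion multiply-adds per head per layer** for $QK^\top$ alone (about 1.26 billion including the product with $V$), which is extremely heavy!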

---

# **4. Cost of SRA (Spatial Reduction Attention)**

Now, in SRA the key/value tokens are **downsampled by ratio $s$** along each spatial axis.
This means:

$$
N' = \frac{N}{s^2}
$$

So for $s = 8$:

$$
N' = \frac{3136}{8^2} = \frac{3136}{64} = 49
$$

Now, K and V each have 49 tokens instead of 3136.

The attention matrix now has size $N \times N' = 3136 \times 49$.

Cost becomes:

$$
\text{Cost}_{\text{SRA}} = N \times N' \times d = 3136 \times 49 \times 64 = 9.8 \times 10^6
$$

---

# **5. Comparison**

| Method      | Tokens in K/V | Attention Matrix | Multiply-adds ($QK^\top$) | Reduction              |
| :---------- | :------------ | :--------------- | :------------------------ | :--------------------- |
| Standard SA | 3136          | 3136×3136        | 6.3×10⁸                   | baseline               |
| SRA (s=8)   | 49            | 3136×49          | 9.8×10⁶                   | **64× fewer** (= s²)   |

✅ At this stage, SRA cuts the cost of the attention products by exactly **s² = 64×**.
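The arithmetic in the table can be reproduced with a few lines of Python. This sketch counts multiply-adds for the score product only ($QK^\top$ versus $QK'^\top$) and, as an extra line not in the table, the reduction convolution itself, assumed here to be a $d \to d$ convolution with kernel size and stride $s$ (bias ignored).

```python
H = W = 56
d = 64
s = 8

N = H * W               # query tokens
N_red = N // (s * s)    # key/value tokens after spatial reduction

qk_standard = N * N * d                    # multiply-adds for Q K^T
qk_sra = N * N_red * d                     # multiply-adds for Q K'^T
reduction_conv = N_red * (s * s * d) * d   # assumed Conv2d(d, d, kernel=s, stride=s)

print(N, N_red)                # 3136 49
print(qk_standard, qk_sra)     # 629407744 9834496
print(qk_standard // qk_sra)   # 64
print(reduction_conv)          # 12845056
```

The score product shrinks by exactly $s^2 = 64$. The reduction convolution adds about $N d^2 \approx 1.3 \times 10^7$ multiply-adds of its own, comparable to the reduced product, so the end-to-end saving for the layer is smaller than 64×, though still large compared with the $6.3 \times 10^8$ of full attention.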

---

# **6. Why It Still Works**

* **Queries (Q)** remain full-resolution (one per token),
  so each patch still "attends" globally.
* **Keys/Values** are **downsampled**, so we summarize context efficiently.
* The global context is preserved, but at a **lower spatial resolution**.

This is similar to how **FPN** or **feature maps in CNNs** summarize spatial information at different scales.
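Because every query attends to all reduced tokens, and those tokens jointly cover the whole map, any query can draw on any location, just through an $s \times s$ summary. This pure-Python count illustrates that for the running example, assuming non-overlapping reduction (kernel size = stride = $s$); the variable names are illustrative.

```python
H = W = 56
s = 8

# Map every (row, col) token of the full grid to the reduced K/V token whose
# s x s block contains it (kernel = stride = s, so blocks do not overlap).
block_of = {}
for r in range(H):
    for c in range(W):
        block_of[(r, c)] = (r // s) * (W // s) + (c // s)

n_kv = (H // s) * (W // s)
tokens_per_kv = [0] * n_kv
for kv in block_of.values():
    tokens_per_kv[kv] += 1

print(n_kv)                # 49 reduced key/value tokens
print(set(tokens_per_kv))  # {64}: each one summarizes an 8x8 block
print(len(block_of))       # 3136: every input token feeds some key/value
```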

---

# **7. Visual Intuition**

```
Standard Attention
------------------
Q : 3136 tokens (56×56)
K : 3136 tokens
V : 3136 tokens
→ Attention matrix: 3136 × 3136
→ Full resolution, costly!

SRA (s = 8)
------------
Q : 3136 tokens (56×56)
K : 49 tokens (7×7 after downsampling)
V : 49 tokens
→ Attention matrix: 3136 × 49
→ Much smaller, same global reach
```

---

# **8. Formula Recap**

Standard attention:

$$
\text{SA}(Q,K,V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d}} \right)V
$$

Spatial Reduction Attention:

$$
\text{SRA}(Q,K',V') = \text{Softmax}\left( \frac{Q{K'}^T}{\sqrt{d}} \right)V'
$$

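where $X$ is the stage's input token map (so $Q = X W_Q$) and

$$
K' = \text{Down}_s(X)\,W_K, \quad V' = \text{Down}_s(X)\,W_V, \qquad \text{Down}_s = \text{LayerNorm} \circ \text{Conv2d}(\text{kernel}=s,\ \text{stride}=s)
$$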

---

# **9. Summary**

✅ **Spatial Reduction Attention (SRA)** dramatically lowers attention cost
✅ **Global context preserved**, since Q still covers all spatial tokens
✅ **Multi-scale hierarchy** gives CNN-like feature pyramids
✅ **Ideal backbone** for detection and segmentation

---

Would you like me to extend this with a **small PyTorch example** showing the actual tensor shapes before and after applying SRA (e.g., 56×56 → 7×7 reduction)?
