# **Pyramid Vision Transformer (PVT)**
The **Pyramid Vision Transformer (PVT)** is a hierarchical Transformer backbone that bridges **ViTs and CNNs** for dense visual tasks such as **object detection** and **semantic segmentation**.

---

## **1. Motivation**

The **Vision Transformer (ViT)** introduced powerful global attention but has drawbacks:

* Requires **fixed-size inputs** (e.g., 224×224).
* Produces **single-scale features**, unsuitable for dense prediction tasks.
* Has **quadratic complexity** with respect to the number of patches.
* Lacks the **local inductive biases** of CNNs (locality and translation equivariance).

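The **Pyramid Vision Transformer (PVT)** (Wang et al., *ICCV 2021*) mainly addresses the single-scale output and the attention cost by **building a hierarchical (multi-scale) feature pyramid**, like a CNN backbone (e.g., ResNet), and by making attention cheaper on high-resolution maps.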

---

## **2. Key Idea**

PVT = **Hierarchical Vision Transformer Backbone**

It mimics CNNs by generating **multi-resolution feature maps**:

| Stage | Resolution (of input) | Channels (PVT-Tiny) | Purpose                 |
| :---: | :-------------------- | :------------------ | :---------------------- |
|   1   | High (1/4)            | Few (64)            | Capture local features  |
|   2   | Medium (1/8)          | More (128)          | Broader receptive field |
|   3   | Low (1/16)            | More (320)          | Semantic features       |
|   4   | Very low (1/32)       | Most (512)          | Global context          |

These outputs can directly feed **FPN**, **Mask R-CNN**, **U-Net**, etc.

---

## **3. Architecture Overview**

The PVT backbone has 4 stages. Each stage consists of:

1. **Patch embedding** (patchify + linear projection)
2. **Transformer encoder blocks**
3. **Spatial reduction attention (SRA)** inside each block, which reduces the number of key/value tokens before attention
4. **Downsampling between stages** (to form a pyramid), performed by the patch embedding of the next stage (2×2 patches)

---

#### **3.1. Stage 1: Patch Embedding**

The image is split into small patches (like ViT):

$$
x_0 = \text{PatchEmbed}(I) \in \mathbb{R}^{H_0 \times W_0 \times C_0}
$$

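The feature map is then flattened into $ H_0 W_0 $ tokens of dimension $ C_0 $ for the first Transformer block (for PVT-Tiny, a 224×224 image with 4×4 patches gives 56×56 = 3136 tokens of dimension 64).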
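In code, patchify + linear projection is usually written as a single strided convolution. A minimal sketch, assuming a `Conv2d` with kernel size = stride = patch size followed by `LayerNorm` (as in common PVT implementations); position embeddings and dropout are omitted, and the class name `PatchEmbed` is illustrative:

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Patchify + linear projection as one strided Conv2d (sketch)."""

    def __init__(self, in_chans=3, embed_dim=64, patch_size=4):
        super().__init__()
        # Each non-overlapping patch_size x patch_size patch -> one embed_dim vector
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, img):
        x = self.proj(img)                  # [B, C0, H0, W0]
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)    # [B, H0*W0, C0] token sequence
        return self.norm(x), H, W


embed = PatchEmbed(in_chans=3, embed_dim=64, patch_size=4)
tokens, H0, W0 = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape, H0, W0)  # torch.Size([1, 3136, 64]) 56 56
```

Returning `H` and `W` alongside the tokens matters because SRA later needs them to reshape the tokens back into a 2D map.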

---

#### **3.2. Transformer Encoder with SRA**

Standard self-attention has **O(N²)** cost (N = number of patches).
To make it scalable, **PVT** introduces **Spatial Reduction Attention (SRA)**:

Instead of using all keys and values, SRA **downsamples** them:

$$
K' = \text{Downsample}(K), \quad V' = \text{Downsample}(V)
$$

Then attention becomes:

$$
\text{Attention}(Q, K', V') = \text{Softmax}\left( \frac{Q {K'}^T}{\sqrt{d}} \right) V'
$$

This reduces complexity from **O(N²)** → **O(N × N/s²)** where *s* is the reduction ratio.

✅ This keeps a **global receptive field** while reducing computation.
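A minimal single-head sketch of the shape bookkeeping. The reduced keys and values here are random stand-ins (an assumption for illustration) for the output of the downsampling step; the point is that the output still has one row per query, while the attention matrix is only N × N′:

```python
import torch


def reduced_attention(q, k_red, v_red):
    """Single-head attention where keys/values have fewer tokens than queries.

    q:     [N, d]   one query per full-resolution token
    k_red: [N', d]  reduced keys
    v_red: [N', d]  reduced values
    """
    d = q.shape[-1]
    attn = torch.softmax(q @ k_red.T / d ** 0.5, dim=-1)  # [N, N'] instead of [N, N]
    return attn @ v_red, attn                            # output keeps N rows


N, s, d = 3136, 8, 64
q = torch.randn(N, d)
k_red = torch.randn(N // s**2, d)   # stand-in for Downsample(K)
v_red = torch.randn(N // s**2, d)   # stand-in for Downsample(V)
out, attn = reduced_attention(q, k_red, v_red)
print(out.shape, attn.shape)  # torch.Size([3136, 64]) torch.Size([3136, 49])
```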

---

#### **3.3. Downsampling**
This is the *core trick* of **Spatial Reduction Attention (SRA)** in **PVT**.

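The **downsampling** that produces the reduced $ K' $ and $ V' $ is **not** done by a linear projection (like `nn.Linear`) but by a **2D convolution with kernel size and stride $ s $**, applied to the input tokens (reshaped into a 2D map) and followed by LayerNorm. The usual $ W_K $, $ W_V $ projections are then applied to the reduced tokens.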

---

**Recall the Context**

In normal self-attention:
$$
Q = X W_Q,\quad K = X W_K,\quad V = X W_V
$$

where $ X \in \mathbb{R}^{B \times N \times C} $ and $ N = H \times W $.

In **PVT**, we *only* reduce the spatial size of $ K $ and $ V $:
before computing attention, we apply a **downsampling operator** to them.

---

**3.3.1 The Spatial Reduction Step**

For each transformer stage with reduction ratio $ s $:

1. **Reshape tokens back to 2D feature map**
   $$ X \in \mathbb{R}^{B \times N \times C} \Rightarrow X_{\text{map}} \in \mathbb{R}^{B \times C \times H \times W} $$

2. **Apply 2D convolution with stride = s**
   $$ X_{\text{reduced}} = \text{Conv2d}(X_{\text{map}},\ \text{stride}=s,\ \text{kernel}=s) $$

   So if $ s=8 $ and $ X_{\text{map}} $ is 56×56, we get 7×7 output.

3. **Flatten back to tokens**
   $$ X_{\text{reduced}} \Rightarrow X' \in \mathbb{R}^{B \times N' \times C} $$
   where $ N' = \frac{H}{s} \times \frac{W}{s} $

4. **Normalize**
   Apply `LayerNorm(C)` before computing $ K', V' $.

5. **Linear projections**
   $$ K' = X' W_K,\quad V' = X' W_V $$
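The five steps translate almost line by line into PyTorch. A minimal sketch with stage-1 PVT-Tiny sizes; the separate `to_k`/`to_v` layers are hypothetical names used for clarity (they could equally be one fused layer):

```python
import torch
import torch.nn as nn

B, C, H, W, s = 1, 64, 56, 56, 8
x = torch.randn(B, H * W, C)                       # tokens: [B, N, C]

sr = nn.Conv2d(C, C, kernel_size=s, stride=s)      # step 2: learnable strided conv
norm = nn.LayerNorm(C)                             # step 4
to_k, to_v = nn.Linear(C, C), nn.Linear(C, C)      # step 5 (hypothetical names)

x_map = x.transpose(1, 2).reshape(B, C, H, W)      # step 1: [B, C, H, W]
x_red = sr(x_map)                                  # step 2: [B, C, H/s, W/s]
x_red = x_red.flatten(2).transpose(1, 2)           # step 3: [B, N', C]
x_red = norm(x_red)                                # step 4
k_red, v_red = to_k(x_red), to_v(x_red)            # step 5: K', V'

print(x_map.shape, x_red.shape, k_red.shape)
# torch.Size([1, 64, 56, 56]) torch.Size([1, 49, 64]) torch.Size([1, 49, 64])
```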

---

**3.3.2 Why Conv2D?**

A **Conv2D with kernel size and stride s** aggregates each non-overlapping s×s window into one output token, similar to how CNN backbones reduce resolution.

It learns *how* to summarize spatial context (via kernel weights) instead of following a fixed pooling rule.

This gives two benefits:

* Learnable spatial reduction (not hard-coded average pooling).
* Preserves **spatial coherence**: each reduced token summarizes one contiguous image region, and the operation is translation-equivariant for shifts that are multiples of s.
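To see that a strided convolution generalizes fixed pooling, here is a small check (illustrative sizes, not PVT's): with one hand-picked set of weights, where each output channel averages only its own input channel with weights 1/s², the convolution reproduces average pooling exactly. Training is free to move the weights away from this setting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C, s = 4, 2
x = torch.randn(1, C, 8, 8)

conv = nn.Conv2d(C, C, kernel_size=s, stride=s)
with torch.no_grad():
    conv.weight.zero_()
    conv.bias.zero_()
    for c in range(C):
        # output channel c reads only input channel c, with uniform weights 1/s^2
        conv.weight[c, c] = 1.0 / (s * s)

pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
print(torch.allclose(conv(x), pooled, atol=1e-6))  # True
```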

---



**3.3.3 Visual Summary**

```
[Before reduction]
X: (B, 3136, 64) → reshape → (B, 64, 56, 56)

↓ Conv2d(kernel=8, stride=8)

[After reduction]
→ (B, 64, 7, 7) → flatten → (B, 49, 64)
```

So:

* **Conv2d** performs the spatial reduction.
* **Linear layers** compute Q from the full-resolution tokens and K′, V′ from the reduced tokens.
* **LayerNorm** normalizes the reduced tokens before the K/V projections, keeping their statistics stable.

---

**3.3.4 Summary**

- ✅ Downsampling of K, V = **Conv2D(stride=s)**, not Linear.
- ✅ This yields $ N' = (H/s)(W/s) $ tokens for K′ and V′.
- ✅ Learnable → model decides *how* to aggregate spatial context.
- ✅ Efficient → reduces attention cost from $ O(N^2) $ to $ O(N·N′) $.
- ✅ Keeps a global field of view → Q stays full-size, and every query attends to reduced keys that cover the whole image.

---



#### **3.4. Pyramid Hierarchy**

After each stage, feature resolution is halved, and channels increase:

| Stage | Resolution | Channels | Patch size | Reduction ratio |
| :---: | :--------- | :------- | :--------- | :-------------- |
|   1   | 1/4        | 64       | 4×4        | 8               |
|   2   | 1/8        | 128      | 2×2        | 4               |
|   3   | 1/16       | 320      | 2×2        | 2               |
|   4   | 1/32       | 512      | 2×2        | 1               |

Each stage is a **Transformer encoder** operating on the corresponding feature scale.
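The per-stage sizes follow from simple arithmetic. A small helper, assuming the PVT-Tiny settings from the table and input sizes divisible by 32; `stage_shapes` is a hypothetical name:

```python
# PVT-Tiny per-stage settings from the table above
PATCH = [4, 2, 2, 2]
SR_RATIO = [8, 4, 2, 1]
CHANNELS = [64, 128, 320, 512]


def stage_shapes(img_h, img_w):
    """Return (H_i, W_i, C_i, N_i, N_i') per stage; assumes sizes divide evenly."""
    rows, h, w = [], img_h, img_w
    for p, s, c in zip(PATCH, SR_RATIO, CHANNELS):
        h, w = h // p, w // p          # patch embedding downsamples by p
        n = h * w                      # query tokens
        n_red = (h // s) * (w // s)    # key/value tokens after spatial reduction
        rows.append((h, w, c, n, n_red))
    return rows


for row in stage_shapes(224, 224):
    print(row)
# (56, 56, 64, 3136, 49)
# (28, 28, 128, 784, 49)
# (14, 14, 320, 196, 49)
# (7, 7, 512, 49, 49)
```

For a larger, detection-style input such as 800×1280, `stage_shapes(800, 1280)` gives 1000 reduced tokens in every stage, so N′ is tied to the stage-4 grid rather than fixed at 49.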

---

## **4. Advantages**

- ✅ **Hierarchical features** → usable as a CNN backbone (e.g., in Mask R-CNN).
- ✅ **Global receptive field** from transformers.
- ✅ **Efficient attention** via spatial reduction.
- ✅ **Variable input resolution** support.
- ✅ **Strong performance on dense tasks** (segmentation, detection).

---

## **5. Comparison with ViT and Swin**

| Model | Attention Type       | Hierarchy | Complexity | Windowed? | Suitable for Detection? |
| :---- | :------------------- | :-------- | :--------- | :-------- | :---------------------- |
| ViT   | Global               | ✖         | O(N²)      | ✖         | ✖                       |
| Swin  | Window-based (local) | ✅         | O(N·M²), linear in N (window size M) | ✅ | ✅ |
| PVT   | Global (SRA-reduced) | ✅         | O(N²/s²)   | ✖         | ✅                       |

So:

* **PVT keeps global attention** but makes it efficient (SRA).
* **Swin** uses local window attention and shifting to connect neighborhoods.

---

## **6. Equation Summary**

Let’s define for stage *i*:

* Input tokens:
  $$ X_i \in \mathbb{R}^{N_i \times C_i} $$
* Spatial reduction ratio: *s*

Then attention becomes:

$$
Q = X_i W_Q, \quad K = \text{Down}(X_i)\, W_K, \quad V = \text{Down}(X_i)\, W_V
$$

$$
\text{SRA}(X_i) = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d}} \right) V
$$

where
$$
\text{Down}(\cdot) = \text{Reshape} \rightarrow \text{Conv2d}(\text{kernel}=\text{stride}=s) \rightarrow \text{Flatten} \rightarrow \text{LayerNorm}
$$

---

## **7. Visual Summary**

```
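Input Image (H × W × 3)                  [s values shown for PVT-Tiny]
   ↓
[Stage 1] Patch Embed (4×4) → Transformer blocks (SRA, s=8) → F1: H/4 × W/4
   ↓
[Stage 2] Patch Embed (2×2) → Transformer blocks (SRA, s=4) → F2: H/8 × W/8
   ↓
[Stage 3] Patch Embed (2×2) → Transformer blocks (SRA, s=2) → F3: H/16 × W/16
   ↓
[Stage 4] Patch Embed (2×2) → Transformer blocks (SRA, s=1) → F4: H/32 × W/32
   ↓
Feature pyramid {F1, F2, F3, F4} → Detection / Segmentation Head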
```

---

## **8. Typical Usage**

PVT variants:

* **PVT-Tiny**, **PVT-Small**, **PVT-Medium**, **PVT-Large**
  share the per-stage channel widths (64, 128, 320, 512) and differ mainly in the number of Transformer blocks per stage.

Used in:

* **PVT + FPN → RetinaNet / Mask R-CNN**
* **PVT + UPerNet → Semantic Segmentation**

---

## **Numerical Example**
Let’s go through a **numerical example** showing how **Spatial Reduction Attention (SRA)** in **Pyramid Vision Transformer (PVT)** drastically reduces computational cost compared to standard self-attention.

---

### **1. Reminder: Complexity of Self-Attention**

For each layer, self-attention requires:

$$
\text{Cost} \sim O(N^2 \cdot d)
$$

where

* $ N $ = number of tokens (patches)
* $ d $ = embedding dimension

In ViT, $ N = (H/P) \times (W/P) $, where $ P $ is the patch size.

Example: for 224×224 input and 16×16 patch size:
$$
N = (224/16)^2 = 14^2 = 196
$$

---




### **Recap of PVT-Tiny Parameters**

The table below recaps the **PVT-Tiny** configuration (the most common example), so we can follow how the **token counts**, **spatial reduction ratio**, and **computational cost** evolve across all four stages of the pyramid.



| Stage | Output Resolution | Channels | Patch Size | SRA Reduction Ratio (s) | #Blocks |
| :---: | :---------------: | :------: | :--------: | :---------------------: | :-----: |
|   1   |       56×56       |    64    |     4×4    |            8            |    2    |
|   2   |       28×28       |    128   |     2×2    |            4            |    2    |
|   3   |       14×14       |    320   |     2×2    |            2            |    2    |
|   4   |        7×7        |    512   |     2×2    |            1            |    2    |

The input is a 224×224 RGB image: stage 1 patchifies it with 4×4 patches (224 / 4 = 56), and each later stage halves the resolution with 2×2 patches.

---

### **2. Stage 1: 56×56 Feature Map**

#### Token count

$$
N_1 = 56 \times 56 = 3136
$$
Embedding dim: $ d_1 = 64 $

#### Standard self-attention cost

$$
\text{Cost}_{\text{standard}}^{(1)} = N_1^2 \times d_1 = 3136^2 \times 64 \approx 6.3 \times 10^8
$$

#### Spatial Reduction Attention (s = 8)

Downsample keys/values by factor 8:

$$
N_1' = \frac{N_1}{8^2} = \frac{3136}{64} = 49
$$

So

$$
\text{Cost}_{\text{SRA}}^{(1)} = N_1 \times N_1' \times d_1 = 3136 \times 49 \times 64 = 9.8 \times 10^6
$$

Assume the mini-batch size is **B = 1** for simplicity.

---


| Symbol        | Meaning                             | Shape         |
| :------------ | :---------------------------------- | :------------ |
| Input         | Flattened tokens                    | (1, 3136, 64) |
| Q             | Query projection of all tokens      | (1, 3136, 64) |
| K′            | Keys after spatial reduction by 8   | (1, 49, 64)   |
| V′            | Values after spatial reduction by 8 | (1, 49, 64)   |
| Attention map | Q × K′ᵀ                             | (3136, 49)    |




✅ **64× (= s² = 8²) reduction** already at the first stage.

---

### **3. Stage 2: 28×28 Feature Map**

From Stage 1 → Stage 2 we downsample 2×.

$$
N_2 = 28 \times 28 = 784, \quad d_2 = 128, \quad s = 4
$$

Standard:

$$
\text{Cost}_{\text{standard}}^{(2)} = 784^2 \times 128 = 7.9 \times 10^7
$$

Reduced:

$$
N_2' = \frac{784}{4^2} = 49
$$

$$
\text{Cost}_{\text{SRA}}^{(2)} = 784 \times 49 \times 128 = 4.9 \times 10^6
$$

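| Symbol        | Meaning                                           | Shape         |
| :------------ | :------------------------------------------------ | :------------ |
| Input         | Flattened tokens                                  | (1, 784, 128) |
| Q             | Query projection of all tokens                    | (1, 784, 128) |
| K′            | Keys after spatial reduction by 4 (784 / 4² = 49) | (1, 49, 128)  |
| V′            | Values after spatial reduction by 4               | (1, 49, 128)  |
| Attention map | Q × K′ᵀ                                           | (784, 49)     |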

✅ ~16× less compute.

---

### **4. Stage 3: 14×14 Feature Map**

$$
N_3 = 196, \quad d_3 = 320, \quad s = 2
$$

Standard:

$$
\text{Cost}_{\text{standard}}^{(3)} = 196^2 \times 320 = 1.2 \times 10^7
$$

Reduced:

$$
N_3' = \frac{196}{2^2} = 49
$$

$$
\text{Cost}_{\text{SRA}}^{(3)} = 196 \times 49 \times 320 = 3.1 \times 10^6
$$


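| Symbol        | Meaning                                           | Shape         |
| :------------ | :------------------------------------------------ | :------------ |
| Input         | Flattened tokens                                  | (1, 196, 320) |
| Q             | Query projection of all tokens                    | (1, 196, 320) |
| K′            | Keys after spatial reduction by 2 (196 / 2² = 49) | (1, 49, 320)  |
| V′            | Values after spatial reduction by 2               | (1, 49, 320)  |
| Attention map | Q × K′ᵀ                                           | (196, 49)     |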

✅ ~4× less compute.

---

### **5. Stage 4: 7×7 Feature Map**

$$
N_4 = 49, \quad d_4 = 512, \quad s = 1
$$

No reduction (since it’s already small).

$$
\text{Cost}_{\text{standard}}^{(4)} = N_4^2 \times d_4 = 49^2 \times 512 = 1.2 \times 10^6
$$

SRA = same (s = 1).


| Symbol        | Meaning                          | Shape        |
| :------------ | :------------------------------- | :----------- |
| Input         | Flattened tokens                 | (1, 49, 512) |
| Q             | Query projection of all tokens   | (1, 49, 512) |
| K′            | Keys, no reduction (s = 1)       | (1, 49, 512) |
| V′            | Values, no reduction (s = 1)     | (1, 49, 512) |
| Attention map | Q × K′ᵀ                          | (49, 49)     |

---

### **6. Full Pyramid Comparison**

| Stage | Resolution |   N  |  d  |  s  | Standard SA |    SRA    | Reduction |
| :---: | :--------- | :--: | :-: | :-: | :---------: | :-------: | :-------: |
|   1   | 56×56      | 3136 |  64 |  8  |  6.3 × 10⁸  | 9.8 × 10⁶ |  **64×**  |
|   2   | 28×28      |  784 | 128 |  4  |  7.9 × 10⁷  | 4.9 × 10⁶ |  **16×**  |
|   3   | 14×14      |  196 | 320 |  2  |  1.2 × 10⁷  | 3.1 × 10⁶ |   **4×**  |
|   4   | 7×7        |  49  | 512 |  1  |  1.2 × 10⁶  | 1.2 × 10⁶ |     —     |


| Stage | Resolution | Tokens (N) | Channels (C) |  s  |    Q Shape    |  K′/V′ Shape | Attention Map |
| :---: | :--------- | :--------: | :----------: | :-: | :-----------: | :----------: | :-----------: |
|   1   | 56×56      |    3136    |      64      |  8  | (1, 3136, 64) |  (1, 49, 64) |   (3136, 49)  |
|   2   | 28×28      |     784    |      128     |  4  | (1, 784, 128) | (1, 49, 128) |   (784, 49)   |
|   3   | 14×14      |     196    |      320     |  2  | (1, 196, 320) | (1, 49, 320) |   (196, 49)   |
|   4   | 7×7        |     49     |      512     |  1  |  (1, 49, 512) | (1, 49, 512) |    (49, 49)   |
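The cost columns can be reproduced with a few lines of Python. This sketch counts only the multiply-adds of the $QK^T$ product for one block per stage at batch size 1 (the attention-weighted sum of $V$ costs the same again), using the PVT-Tiny numbers above:

```python
stages = [  # (N, d, s) for PVT-Tiny at 224x224
    (3136, 64, 8),
    (784, 128, 4),
    (196, 320, 2),
    (49, 512, 1),
]

total_std = total_sra = 0
for i, (n, d, s) in enumerate(stages, start=1):
    n_red = n // s**2
    std = n * n * d        # QK^T cost with all N keys
    sra = n * n_red * d    # QK^T cost with N' reduced keys
    total_std += std
    total_sra += sra
    print(f"stage {i}: standard={std:.2e}  SRA={sra:.2e}  ratio={std / sra:.0f}x")

print(f"sum: standard={total_std:.2e}  SRA={total_sra:.2e}  ratio={total_std / total_sra:.1f}x")
```

It reports per-stage reductions of 64×, 16×, 4×, and 1× (exactly $s^2$), and about 38× for the sum over the four stages.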

---

### **7. Observations**

* The largest savings occur in **early high-resolution stages**.
* Later stages are already small, so reduction is less critical.
* Summed over the four stages (one block each, counting only the $ QK^T $ product), the attention cost drops from about 7.2 × 10⁸ to about 1.9 × 10⁷, roughly **38×**. Whole-backbone FLOPs drop by less, because the linear projections and MLPs are unaffected by SRA.
* **Q** keeps full spatial resolution, so each token still “looks” globally.
* **K′, V′** are spatially reduced to a **coarse grid of 7×7 = 49 tokens** for a 224×224 input, the same grid in stages 1–3 (stage 4 is already 7×7).
* Thus, the stage-1 **attention map** shrinks from about 9.8 million entries (3136²) to about 154 thousand (3136 × 49).

The model preserves **global context** while cutting the stage-1 attention cost by **64×**.


This design choice is why **PVT can scale to high-res images** and still serve as a **drop-in backbone** for detection or segmentation.

---




## **Python Code**

```python
import torch
import torch.nn as nn


class SpatialReductionAttention(nn.Module):
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()

        # Number of attention heads (same as standard multi-head attention)
        self.num_heads = num_heads

        # Scaling factor used in the attention softmax
        # scale = 1 / sqrt(head_dim)
        self.scale = (dim // num_heads) ** -0.5

        # Linear layers to create Q and [K,V]
        # Standard ViT-style projection layers
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)

        # Spatial reduction ratio (how much to downsample K,V)
        # e.g. 8, 4, 2, 1 for PVT-Tiny stages
        self.sr_ratio = sr_ratio

        # Only create Conv2D if we actually reduce spatial resolution
        if sr_ratio > 1:
            # Conv2d performs learnable downsampling
            # (stride = sr_ratio) → spatially reduces tokens
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)

            # Normalization layer after reduction
            self.norm = nn.LayerNorm(dim)
```

---

#### Scaling factor
The scaling factor is one of those subtle but *essential* details in every attention mechanism.

**1. Where it comes from**

In all Transformer-style attention mechanisms, the core operation is:

$$
\text{Attention}(Q, K, V) = \text{Softmax}\left( \frac{Q K^{T}}{\sqrt{d_k}} \right)V
$$

where

* $ Q \in \mathbb{R}^{N \times d_k} $: queries
* $ K \in \mathbb{R}^{N' \times d_k} $: keys
* $ V \in \mathbb{R}^{N' \times d_v} $: values
* $ d_k $: dimension of each head (after splitting channels into heads).

That denominator **$ \sqrt{d_k} $** is the **scaling factor**. In code, its reciprocal is stored and used as a multiplier:

```python
self.scale = (dim // num_heads) ** -0.5
```

So:

$$
\text{scale} = \frac{1}{\sqrt{d_k}} = (d_k)^{-1/2}
$$

---

**2. Why do we need scaling?**

Without scaling, the dot-product $ QK^{T} $ can have very large values when $ d_k $ is large.
This leads to **unstable gradients** and **softmax saturation**.

Let’s see why.

Each entry of $ QK^T $ is the dot product between two vectors of length $ d_k $:

$$
(QK^T)_{ij} = \sum_{t=1}^{d_k} Q_{it} K_{jt}
$$

If the components of $ Q $ and $ K $ are independent with zero mean and unit variance, each product $ Q_{it} K_{jt} $ has variance 1, so the dot product (a sum of $ d_k $ such terms) has variance $ d_k $:

$$
\text{Var}\big((QK^T)_{ij}\big) = d_k
$$

So as $ d_k $ increases (e.g., 64, 128, 256), the logits grow in magnitude (their standard deviation is $ \sqrt{d_k} $),
and $ \text{softmax}(QK^T) $ becomes extremely peaked → gradients vanish for most entries.
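A quick empirical check of this variance argument, assuming standard-normal query and key components (an idealization; real activations are not exactly Gaussian):

```python
import torch

torch.manual_seed(0)


def dot_product_variance(d_k, num_pairs=10_000):
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    q = torch.randn(num_pairs, d_k)   # zero-mean, unit-variance components
    k = torch.randn(num_pairs, d_k)
    dots = (q * k).sum(dim=-1)        # one dot product per row pair
    return dots.var().item(), (dots / d_k ** 0.5).var().item()


for d_k in (16, 64, 256):
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:3d}  Var(q.k)={raw:6.1f}  Var(q.k / sqrt(d_k))={scaled:.2f}")
```

The raw variance tracks $d_k$, while the scaled variance stays close to 1 for every $d_k$.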

---

**3. The fix: scale by $ 1/\sqrt{d_k} $**

Dividing by $ \sqrt{d_k} $ brings the variance of each logit back to about 1, independent of $ d_k $, so the values stay in a stable range (mostly within a few units of zero). This keeps the softmax smooth and gradients well-behaved.

So we compute:

$$
\text{Attention weights} = \text{Softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)
$$

This ensures:

* Stable softmax range
* Balanced gradient flow
* Faster and smoother convergence during training

---

**4. In the PVT Code**

In the code:

```python
self.scale = (dim // num_heads) ** -0.5
```

* `dim` = total embedding dimension $ C $
* `num_heads` = number of attention heads $ h $
* `dim // num_heads` = per-head dimension $ d_k $
* `** -0.5` = take inverse square root

So if $ \text{dim} = 64 $ and $ \text{num\_heads} = 4 $:

$$
d_k = 64 / 4 = 16 \Rightarrow \text{scale} = 1 / \sqrt{16} = 0.25
$$

and inside `forward()`:

```python
attn = (q @ k.transpose(-2, -1)) * self.scale
```

means we are computing:

$$
A = \frac{Q K^{T}}{\sqrt{d_k}}
$$

---

**5. Numerical example**

Suppose
$ Q, K \in \mathbb{R}^{3 \times 4} $,
so $ d_k = 4 $.

If we skip scaling, the dot products might look like:

$$
QK^T =
\begin{bmatrix}
10 & 12 & 9 \\
7 & 11 & 13 \\
12 & 8 & 9
\end{bmatrix}
$$

Softmax of these values is already sharply peaked: the first row gives ≈ [0.11, 0.84, 0.04].
With larger $ d_k $ the gaps between logits grow and rows approach one-hot vectors, where gradients vanish for most entries.

After scaling by $ 1/\sqrt{4} = 0.5 $:

$$
(QK^T)/\sqrt{d_k} =
\begin{bmatrix}
5 & 6 & 4.5 \\
3.5 & 5.5 & 6.5 \\
6 & 4 & 4.5
\end{bmatrix}
$$

Softmax becomes more balanced: the first row now gives ≈ [0.23, 0.63, 0.14],
leading to healthier gradient updates.
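The softmax values quoted above can be reproduced directly from the example logits:

```python
import torch

# The example logits QK^T from above (d_k = 4)
logits = torch.tensor([[10.0, 12.0, 9.0],
                       [7.0, 11.0, 13.0],
                       [12.0, 8.0, 9.0]])
d_k = 4

unscaled = torch.softmax(logits, dim=-1)
scaled = torch.softmax(logits / d_k ** 0.5, dim=-1)

print(unscaled[0])                   # ≈ [0.11, 0.84, 0.04]
print(scaled[0])                     # ≈ [0.23, 0.63, 0.14]
print(unscaled.max(dim=-1).values)   # row maxima before scaling
print(scaled.max(dim=-1).values)     # row maxima after scaling (all smaller)
```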

---

**6. Analogy**

Think of the scale factor as a **temperature control** for the softmax:

| Effect          | Formula                    | Behavior                             |
| :-------------- | :------------------------- | :----------------------------------- |
| No scaling      | Softmax(QKᵀ)               | Overconfident, sharp, poor gradients |
| Proper scaling  | Softmax(QKᵀ / √dₖ)         | Smooth, stable gradients             |
| Too small scale | Softmax(QKᵀ / large value) | Too flat, underconfident             |

---

**7. Summary**

- ✅ `self.scale = (dim // num_heads) ** -0.5` implements $ 1 / \sqrt{d_k} $
- ✅ Prevents softmax saturation and stabilizes training
- ✅ Keeps attention values in a numerically safe range
- ✅ Essential for all transformer-based attention (ViT, Swin, PVT, etc.)
- ✅ Even though PVT adds *Spatial Reduction Attention*, the attention core remains the same — and scaling is just as necessary.

---






In PyTorch’s **built-in attention** (`torch.nn.MultiheadAttention`, and `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.x) and in high-level libraries such as Hugging Face `transformers`, the scaling factor $ 1/\sqrt{d_k} $ is **already built in**, so you don’t multiply by `self.scale` yourself.

However, when implementing **custom attention modules** (as in PVT, Swin, or a ViT written from scratch), we need to include it explicitly.

Let’s go through this in detail.

---

## **1. PyTorch’s Built-in Attention Layer**

PyTorch provides this module:

```python
torch.nn.MultiheadAttention(embed_dim, num_heads)
```

When you call it:

```python
attn_output, attn_weights = self_attention(x, x, x)
```

internally it performs:

$$
\text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V
$$

The scaling factor is *automatically applied* inside the source code.

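Conceptually, the internal computation is equivalent to the following (a simplified sketch, not the literal source; the real code differs in detail, e.g. it may scale `q` before the product or dispatch to a fused kernel):

```python
# q, k: [batch * num_heads, seq_len, head_dim]
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
attn_output_weights = attn_output_weights / math.sqrt(head_dim)
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
```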

So yes — it’s **already handled for you**.

That’s why, when you use something like `nn.TransformerEncoder` or `nn.MultiheadAttention`,
you don’t explicitly see or need to define `self.scale`.
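To check that the built-in layer really applies $1/\sqrt{d_k}$, we can recompute its output by hand from its own weights. This sketch relies on `nn.MultiheadAttention` storing the fused Q/K/V projection in `in_proj_weight`/`in_proj_bias` (the case when query, key, and value share `embed_dim`) and the output layer in `out_proj`; dropout is zero by default:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, E, h = 2, 5, 16, 4          # batch, tokens, embed_dim, heads
hd = E // h                       # per-head dimension d_k

mha = nn.MultiheadAttention(embed_dim=E, num_heads=h, batch_first=True)
mha.eval()
x = torch.randn(B, N, E)

with torch.no_grad():
    ref, _ = mha(x, x, x, need_weights=False)

    # Manual re-implementation using the module's own parameters
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    q, k, v = (t.reshape(B, N, h, hd).transpose(1, 2) for t in (q, k, v))  # [B, h, N, hd]
    attn = ((q @ k.transpose(-2, -1)) * hd ** -0.5).softmax(dim=-1)       # explicit 1/sqrt(d_k)
    out = (attn @ v).transpose(1, 2).reshape(B, N, E)
    out = mha.out_proj(out)

print(torch.allclose(ref, out, atol=1e-5))
```

Dropping the `hd ** -0.5` factor from the manual version makes the comparison fail, which confirms that the scaling happens inside the built-in module.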

---

## **2. Why do we define `self.scale` manually in custom architectures?**

In research code (ViT, Swin, PVT), we often **re-implement attention** manually, because we:

* Customize **token layout** (windows, shifted windows, reduced tokens)
* Add **spatial convolutions**, **relative position biases**, etc.
* Split **Q, K, V** differently (like PVT does with Conv2D for K,V)

Thus we can’t rely on PyTorch’s built-in layer — and we need to handle every step explicitly:

1. Project Q, K, V
2. Split heads
3. Compute attention weights
4. Apply scaling
5. Softmax
6. Aggregate V

That’s why lines like this appear:

```python
self.scale = (dim // num_heads) ** -0.5
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
```

They replicate what PyTorch’s internal module does,
but now we have the freedom to modify the attention logic.

---

## **3. Example Comparison**

### (A) Using **PyTorch built-in**

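```python
import torch
import torch.nn as nn

x = torch.randn(10, 2, 64)  # (seq_len, batch, embed_dim); default batch_first=False
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4)
out, weights = attn(x, x, x)  # out: (10, 2, 64); weights: (2, 10, 10), averaged over heads
```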

✅ PyTorch handles:

* Linear projection to Q,K,V
* Scaling
* Softmax
* Dropout
* Output projection

---

### (B) Custom attention (ViT/Swin/PVT style)

```python
q = self.q(x)
k = self.k(x)
v = self.v(x)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v)
```

✅ You explicitly:

* Create Q,K,V
* Define `self.scale`
* Apply it manually
* Can add spatial reduction, window masking, etc.

---

## **4. Why this distinction matters**

| Case                                | Scaling done automatically?         | Typical Usage                    |
| :---------------------------------- | :---------------------------------- | :------------------------------- |
| `nn.MultiheadAttention`             | ✅ Yes (inside PyTorch)              | Standard NLP / small ViT         |
| Custom Attention (PVT, Swin, ViT)   | ❌ No, must do manually              | Research or custom vision models |
| HuggingFace models (`transformers`) | ✅ Yes (inside their implementation) | High-level APIs                  |

---

## **5. Summary**

- ✅ The scaling factor $ 1/\sqrt{d_k} $ is *always used*; it’s fundamental to all attention.
- ✅ In **PyTorch’s built-in layers**, it’s already applied internally.
- ✅ In **custom attention implementations** (ViT, Swin, PVT, etc.), we must explicitly define and apply it ourselves (`self.scale`).
- ✅ The underlying math is **identical** in all cases.

---



### **Now the forward()**

```python
    def forward(self, x, H, W):
        # Input:
        # x: [B, N, C] where N = H*W tokens
        B, N, C = x.shape
```

At this point,

* `x` = flattened spatial tokens (from a transformer block input)
* Each token has `C` channels (the embedding dimension).

---

```python
        # Compute Q as usual (no reduction), then split heads:
        # [B, N, C] -> [B, N, h, C/h] -> [B, h, N, C/h]
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
```

✅ **Q** keeps *full spatial resolution*: every token has its own query vector.
Shape (heads are moved to dimension 1 so that the matrix products below run per head):
$$
Q \in \mathbb{R}^{B \times h \times N \times (C/h)}
$$

---

### **Spatial Reduction for K, V**

```python
        if self.sr_ratio > 1:
            # Reshape flattened tokens back into 2D feature maps
            # from [B, N, C] → [B, C, H, W]
            x_ = x.transpose(1, 2).reshape(B, C, H, W)

            # Apply learnable Conv2D with stride = sr_ratio
            # Reduces H, W → H/sr_ratio, W/sr_ratio
            x_ = self.sr(x_)  # [B, C, H/s, W/s]

            # Flatten spatial map back into token sequence
            # [B, C, H/s, W/s] → [B, N', C]
            x_ = x_.reshape(B, C, -1).transpose(1, 2)

            # Normalize the reduced tokens before projecting K, V
            x_ = self.norm(x_)
        else:
            # No reduction for small feature maps (e.g., stage 4)
            x_ = x
```

✅ This is where **Spatial Reduction Attention** happens.

* The **Conv2D layer** (`self.sr`) is **learnable**, acting like a *trainable pooling* operator.
* `sr_ratio` controls how much to downsample:

  * stage 1 → stride 8 (56×56 → 7×7)
  * stage 2 → stride 4 (28×28 → 7×7)
  * stage 3 → stride 2 (14×14 → 7×7)
  * stage 4 → stride 1 (7×7 → 7×7)

So for a 224×224 input, K′ and V′ always cover a 7×7 grid, i.e. 49 tokens in every head. For larger inputs, N′ grows with the image size.

---

### **Create K′ and V′**

```python
        # [B, N', 2C] -> [B, N', 2, h, C/h] -> [2, B, h, N', C/h]
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # each [B, h, N', C/h]
```

✅ Here we generate **K′** and **V′** from the reduced sequence `x_`.

Shapes:

* K′, V′ ∈ [B, h, N′, C/h], with N′ = 49 for a 224×224 input

That matches what we discussed:

| Stage |   N  | N′ (after reduction) |
| :---: | :--: | :------------------: |
|   1   | 3136 |          49          |
|   2   |  784 |          49          |
|   3   |  196 |          49          |
|   4   |  49  |          49          |

---

### **Compute attention and output**

```python
        # Compute scaled dot-product attention per head
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, h, N, N']
        attn = attn.softmax(dim=-1)

        # Weighted sum of V′, then merge heads back:
        # [B, h, N, C/h] -> [B, N, h, C/h] -> [B, N, C]
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # (Reference implementations typically add an output nn.Linear projection here.)
        return out
```

✅ Standard transformer attention, but with **reduced K′, V′**.

* For each head, the attention matrix has shape [N, N′] instead of [N, N].
* This makes the computation **O(N × N′)** instead of **O(N²)**.
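Putting the `__init__` and `forward` pieces together gives a runnable module. This is a minimal sketch: it omits dropout and the output projection that reference implementations typically include, returns the attention map alongside the output for inspection, and assumes PVT-Tiny-like head counts (1, 2, 5, 8) in the usage loop:

```python
import torch
import torch.nn as nn


class SpatialReductionAttention(nn.Module):
    """Minimal SRA sketch (no dropout, no output projection)."""

    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        h, hd = self.num_heads, C // self.num_heads
        q = self.q(x).reshape(B, N, h, hd).permute(0, 2, 1, 3)            # [B, h, N, hd]

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2).reshape(B, C, H, W)                     # tokens -> 2D map
            x_ = self.sr(x_).reshape(B, C, -1).transpose(1, 2)             # [B, N', C]
            x_ = self.norm(x_)
        else:
            x_ = x

        kv = self.kv(x_).reshape(B, -1, 2, h, hd).permute(2, 0, 3, 1, 4)   # [2, B, h, N', hd]
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale                      # [B, h, N, N']
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)                  # [B, N, C]
        return out, attn


# (H, W, dim, heads, sr_ratio) per stage; head counts are an assumption
for H, W, dim, heads, s in [(56, 56, 64, 1, 8), (28, 28, 128, 2, 4),
                            (14, 14, 320, 5, 2), (7, 7, 512, 8, 1)]:
    layer = SpatialReductionAttention(dim, heads, s)
    out, attn = layer(torch.randn(2, H * W, dim), H, W)
    print(tuple(out.shape), tuple(attn.shape))
```

At 224×224 input, every stage produces an attention map of shape [B, h, N, 49], the N × N′ shape from the tables.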

---

## **3. Shape Flow Example (Stage 1, 56×56)**

| Tensor    | Operation          | Shape         |
| :-------- | :----------------- | :------------ |
| Input `x` | —                  | [1, 3136, 64] |
| Q         | Linear             | [1, 3136, 64] |
| x_        | Conv2D(stride = 8) | [1, 49, 64]   |
| K′, V′    | Linear             | [1, 49, 64]   |
| Attention | Q × K′ᵀ            | [3136, 49]    |
| Output    | Weighted sum       | [1, 3136, 64] |

---

## **4. Relation to What We Discussed**

| Concept                           | Implementation in Code        | Explanation                                |
| :-------------------------------- | :---------------------------- | :----------------------------------------- |
| **Learnable spatial reduction**   | `self.sr = nn.Conv2d(...)`    | Learns how to downsample spatial features. |
| **Q keeps full resolution**       | `q = self.q(x)`               | Every patch still queries globally.        |
| **Reduced K, V**                  | `x_ = self.sr(x_)`            | Keys/values summarize local context.       |
| **Normalization after reduction** | `self.norm(x_)`               | Keeps feature statistics stable.           |
| **Efficiency**                    | `attn = (q @ kᵀ)` → O(N × N′) | Computation cost drastically reduced.      |

---

## **5. Key Takeaways**

- ✅ The **Conv2D layer is learnable**, not fixed pooling.
- ✅ It performs **spatial reduction** → fewer K, V tokens.
- ✅ Q remains full → still **global attention**.
- ✅ Complexity:
  $$
  O(N^2) \to O(N \cdot N') = O\left(N \cdot \frac{N}{s^2}\right)
  $$
- ✅ SRA gives PVT its **efficiency**, which makes global attention affordable on the high-resolution early stages of the **pyramid hierarchy**.

---


## self.kv
This is one of the key *implementation tricks* used in Transformers (including PVT) for **efficiency** and **cleaner code**.
Let’s go step by step through why we write

```python
self.kv = nn.Linear(dim, dim * 2)
```

instead of two separate projections (`self.k` and `self.v`).

---

## **1. Recall the purpose of Q, K, V**

In self-attention, we project the input tokens $ X $ into three spaces:

$$
Q = X W_Q, \quad K = X W_K, \quad V = X W_V
$$

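where each $ W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}} $ (the per-head $ d_{\text{model}} \times d_{\text{head}} $ matrices of all heads concatenated, matching `nn.Linear(dim, dim)` below).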

Each of these is just a **linear transformation** — conceptually independent.

So one could write in code:

```python
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
```

and later compute:

```python
q = self.q(x)
k = self.k(x)
v = self.v(x)
```

Separate layers work, but each one is its own matrix multiplication (a separate kernel launch and a separate read of the input).
Since **K** and **V** are computed from the same input, we combine them into a single linear layer.

---

## **2. Combined linear layer:**

```python
self.kv = nn.Linear(dim, dim * 2)
```

This means one linear layer outputs **twice as many channels** — the first half for K, second half for V.

When we run:

```python
kv = self.kv(x_)
```

the result has shape `[B, N', 2*dim]`.
Then we split it into **two parts**:

```python
kv = kv.reshape(B, -1, 2, self.num_heads, C // self.num_heads)
k, v = kv[:, :, 0], kv[:, :, 1]
```

So effectively:

$$
[K,\ V] = X W_{KV}, \quad \text{where } W_{KV} = [W_K, W_V]
$$

and we just slice it into two pieces.

---

## **3. Why this is equivalent**

A linear layer does:

$$
Y = X W + b
$$

If $ W = [W_K, W_V] $, then:

$$
Y = X [W_K, W_V] + [b_K, b_V] = [XW_K + b_K,\; XW_V + b_V]
$$

So the output naturally splits into two independent parts — exactly what we want for K and V.

- ✅ Mathematically equivalent
- ✅ Computationally more efficient
- ✅ Simpler code
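A small numerical check of this equivalence (illustrative sizes): copy the two halves of a fused `kv` layer into separate K and V layers and compare. PyTorch stores `nn.Linear` weights as `[out_features, in_features]`, so the K half is the first `dim` *rows* of `kv.weight`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
x = torch.randn(1, 4, dim)

kv = nn.Linear(dim, dim * 2)     # fused layer: rows [:dim] act as W_K, rows [dim:] as W_V
k_lin = nn.Linear(dim, dim)      # separate layers, initialised from the fused one
v_lin = nn.Linear(dim, dim)
with torch.no_grad():
    k_lin.weight.copy_(kv.weight[:dim])
    k_lin.bias.copy_(kv.bias[:dim])
    v_lin.weight.copy_(kv.weight[dim:])
    v_lin.bias.copy_(kv.bias[dim:])

k_fused, v_fused = kv(x).chunk(2, dim=-1)
print(torch.allclose(k_fused, k_lin(x), atol=1e-6),
      torch.allclose(v_fused, v_lin(x), atol=1e-6))  # True True
```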

---

## **4. Why not combine Q as well?**

We *could* combine all three (Q, K, V) into a single `nn.Linear(dim, dim * 3)` —
and in many Transformer implementations (e.g., PyTorch’s `nn.MultiheadAttention` or ViT), that’s exactly what’s done.

However, in **PVT** (and other hierarchical Transformers), they **separate Q** because:

* $ Q $ always uses the **full-resolution tokens** ($ N $ of them),
* while $ K $ and $ V $ use **spatially reduced tokens** ($ N' $ of them).

So $ Q $ and $ [K, V] $ are computed from **different feature maps** (`x` vs `x_`).

Thus:

* Q: `x` (no reduction)
* K, V: `x_` (after Conv2d downsampling)

That’s why we can’t combine Q, K, and V into a single linear layer in PVT.

---

## **5. Summary Table**

| Projection | Computed from              | Linear layer              | Output shape   | Purpose                    |
| :--------- | :------------------------- | :------------------------ | :------------- | :------------------------- |
| **Q**      | Full-resolution tokens `x` | `nn.Linear(dim, dim)`     | [B, N, dim]    | Queries (global view)      |
| **K′, V′** | Downsampled tokens `x_`    | `nn.Linear(dim, dim * 2)` | [B, N′, 2*dim] | Keys & values (summarized) |

---

## **6. Summary Explanation**

- ✅ `self.kv = nn.Linear(dim, dim * 2)` combines **K and V projections** into a single operation.
- ✅ After passing through it, the tensor is split into K and V halves.
- ✅ Saves computation and memory vs. two separate layers.
- ✅ Q is kept separate because in **Spatial Reduction Attention**, it uses the **original** resolution while K and V use the **downsampled** one.

---



Let’s make this **very concrete** with a **toy PyTorch example** that shows:

1. How `self.kv = nn.Linear(dim, dim * 2)` works internally.
2. How the tensor shapes look **before and after** the projection.
3. How we split the result into **K** and **V**.

We’ll use small, human-readable numbers.

---

## **1. Setup**

We’ll simulate:

* Batch size $ B = 1 $
* Number of tokens $ N = 4 $
* Embedding dimension $ \text{dim} = 8 $
* 2 attention heads ($ h = 2 $)

---

```python
import torch
import torch.nn as nn

# Small example
B, N, dim = 1, 4, 8
num_heads = 2

# Input tensor (4 tokens, each of 8 features)
x = torch.arange(B * N * dim, dtype=torch.float32).reshape(B, N, dim)
print("Input x shape:", x.shape)
print(x)
```

**Output:**

```
Input x shape: torch.Size([1, 4, 8])
tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29., 30., 31.]]])
```

---

## **2. Define the `kv` Linear Layer**

```python
kv_layer = nn.Linear(dim, dim * 2, bias=False)
print("Weight shape:", kv_layer.weight.shape)
```

**Output:**

```
Weight shape: torch.Size([16, 8])
```

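Explanation:

* `in_features` = 8 (the token dimension)
* `out_features` = 16 (= 2 × 8)
* PyTorch stores the weight as `[out_features, in_features]`, hence `[16, 8]`.
* The layer therefore outputs K and V concatenated along the last dimension.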

---

## **3. Forward Pass**

```python
kv = kv_layer(x)
print("After kv projection:", kv.shape)
```

**Output:**

```
After kv projection: torch.Size([1, 4, 16])
```

So for each token (length 8), the linear layer produces **16 outputs** —
the first 8 correspond to **K**, and the next 8 correspond to **V**.
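Internally, `nn.Linear` just computes `x @ W.T` (plus a bias, disabled here). The sketch below rebuilds the toy tensors and checks this, and also shows that output column j comes from weight row j, so rows 0 to 7 of `W` produce K and rows 8 to 15 produce V.

```python
import torch
import torch.nn as nn

B, N, dim = 1, 4, 8
x = torch.arange(B * N * dim, dtype=torch.float32).reshape(B, N, dim)
kv_layer = nn.Linear(dim, dim * 2, bias=False)

W = kv_layer.weight                  # [out_features, in_features] = [16, 8]
kv_manual = x @ W.T                  # [1, 4, 8] @ [8, 16] -> [1, 4, 16]
kv = kv_layer(x)
print(torch.allclose(kv, kv_manual, atol=1e-4))  # True

# The K half of the output depends only on the first `dim` rows of W
k_manual = x @ W[:dim].T             # [1, 4, 8]
print(torch.allclose(kv[..., :dim], k_manual, atol=1e-4))  # True
```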

---

## **4. Split into K and V**

```python
k, v = kv.chunk(2, dim=-1)
print("K shape:", k.shape)
print("V shape:", v.shape)
```

**Output:**

```
K shape: torch.Size([1, 4, 8])
V shape: torch.Size([1, 4, 8])
```

✅ Now you see:

* One linear layer → produces both K and V.
* We just split the last dimension in half.

---

## **5. With multiple heads (optional)**

Let’s reshape into multi-head format:

```python
head_dim = dim // num_heads
k = k.reshape(B, N, num_heads, head_dim)
v = v.reshape(B, N, num_heads, head_dim)
print("K per head:", k.shape)
print("V per head:", v.shape)
```

**Output:**

```
K per head: torch.Size([1, 4, 2, 4])
V per head: torch.Size([1, 4, 2, 4])
```

✅ Each token now has 2 heads, each of 4 dimensions.
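Attention code usually then moves the head axis in front of the token axis, so each head's scores can be computed with one batched matmul. This small sketch (using an `arange` tensor as a stand-in for K) also confirms that each head simply owns a contiguous slice of channels:

```python
import torch

B, N, dim, num_heads = 1, 4, 8, 2
head_dim = dim // num_heads
k = torch.arange(B * N * dim, dtype=torch.float32).reshape(B, N, dim)  # stand-in for K

# [B, N, dim] -> [B, N, heads, head_dim] -> [B, heads, N, head_dim]
k_heads = k.reshape(B, N, num_heads, head_dim).transpose(1, 2)
print(k_heads.shape)  # torch.Size([1, 2, 4, 4])

# Head 0 holds channels 0..3 of every token, head 1 holds channels 4..7
print(torch.equal(k_heads[:, 0], k[..., :head_dim]))  # True
print(torch.equal(k_heads[:, 1], k[..., head_dim:]))  # True
```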

---

## **6. Visual Summary**

| Step               | Operation                                | Shape          | Comment                  |
| :----------------- | :--------------------------------------- | :------------- | :----------------------- |
| Input              | `x`                                      | [1, 4, 8]      | 4 tokens, 8 channels     |
| Linear projection  | `self.kv(x)`                             | [1, 4, 16]     | outputs K+V concatenated |
| Split              | `chunk(2, dim=-1)`                       | [1, 4, 8] each | separates K and V        |
| Multi-head reshape | reshape to `[B, N, num_heads, head_dim]` | [1, 4, 2, 4]   | per-head views           |

---

## **7. Key Takeaways**

✅ `self.kv = nn.Linear(dim, dim * 2)` projects input tokens into a **combined (K,V)** space.
✅ It’s equivalent to having two independent linear layers (`W_K`, `W_V`) concatenated.
✅ The first half of the output corresponds to **K**, the second half to **V**.
✅ It runs one matrix multiply instead of two; the FLOPs are identical, but there are fewer kernel launches and the input is read only once.
✅ Later, we reshape for multi-head attention.
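The equivalence with two separate layers is easy to check: copy the two halves of the fused weight and bias into two ordinary `nn.Linear` layers and compare outputs. This is an illustrative sketch with random inputs and the bias enabled (the toy example above used `bias=False`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
x = torch.randn(1, 4, dim)

kv_layer = nn.Linear(dim, dim * 2)  # fused K+V projection (bias enabled here)

# Two separate layers initialised from the two halves of the fused parameters
k_layer = nn.Linear(dim, dim)
v_layer = nn.Linear(dim, dim)
with torch.no_grad():
    k_layer.weight.copy_(kv_layer.weight[:dim])   # rows 0..7  -> W_K
    k_layer.bias.copy_(kv_layer.bias[:dim])
    v_layer.weight.copy_(kv_layer.weight[dim:])   # rows 8..15 -> W_V
    v_layer.bias.copy_(kv_layer.bias[dim:])

k_fused, v_fused = kv_layer(x).chunk(2, dim=-1)
print(torch.allclose(k_fused, k_layer(x), atol=1e-5))  # True
print(torch.allclose(v_fused, v_layer(x), atol=1e-5))  # True
```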

---



## **PVT outputs**


## **1. PVT produces a feature hierarchy (multi-scale outputs)**

The **Pyramid Vision Transformer (PVT)** is built to behave like a CNN backbone (e.g. ResNet).
Instead of giving only one global feature, it produces **4 feature maps at different resolutions**:

| Stage | Symbol | Resolution (for 224×224 input) | Channels | Type of Information     |
| :---: | :----- | :----------------------------: | :------: | :---------------------- |
|   1   | **C1** |              56×56             |    64    | Local textures, edges   |
|   2   | **C2** |              28×28             |    128   | Small object parts      |
|   3   | **C3** |              14×14             |    320   | Large object regions    |
|   4   | **C4** |               7×7              |    512   | Global semantic context |

Each `Cᵢ` is the output of a stage containing several **Transformer blocks with Spatial Reduction Attention (SRA)**.

So when you call in PyTorch:

```python
features = backbone(x)   # e.g., PVT from timm with features_only=True
```

you get:

```python
C1 = features[0]  # [B,  64, 56, 56]
C2 = features[1]  # [B, 128, 28, 28]
C3 = features[2]  # [B, 320, 14, 14]
C4 = features[3]  # [B, 512,  7,  7]
```

These are **multi-resolution, multi-semantic** features, which makes them natural inputs for an FPN.
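The resolutions in the table are just the input size divided by each stage's stride (4, 8, 16, 32). A small helper, assuming the channel widths from the table and an input whose sides are divisible by 32 (other sizes may round differently depending on the patch-embedding padding), gives the expected shapes for any input:

```python
def pyramid_shapes(height, width, channels=(64, 128, 320, 512), strides=(4, 8, 16, 32)):
    """Expected (C, H, W) of each stage output, assuming sides divisible by 32."""
    if height % 32 or width % 32:
        raise ValueError("this sketch assumes input sides divisible by 32")
    return [(c, height // s, width // s) for c, s in zip(channels, strides)]


print(pyramid_shapes(224, 224))
# [(64, 56, 56), (128, 28, 28), (320, 14, 14), (512, 7, 7)]
print(pyramid_shapes(512, 640))
# [(64, 128, 160), (128, 64, 80), (320, 32, 40), (512, 16, 20)]
```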

---





You can use the **Pyramid Vision Transformer (PVT)** family directly via the **[timm](https://github.com/huggingface/pytorch-image-models)** library.
The **timm** package (by Ross Wightman) includes the **PVTv2** models (`pvt_v2_b0` through `pvt_v2_b5`, plus the linear-attention variant `pvt_v2_b2_li`).

Let’s go through:

1. **Installation**
2. **Available PVT models in timm**
3. **Example usage (feature extraction, forward pass, and visualization)**
4. **Integration notes for detection/segmentation**

---

## **1. Install timm**

```bash
pip install timm
```

To check your version:

```bash
python -c "import timm; print(timm.__version__)"
```

You should ideally have **timm ≥ 0.9.10** (contains PVTv2 and updated models).

---

## **2. List available PVT models**

You can see all models containing “pvt”:

```python
import timm
models = timm.list_models("*pvt*")
for m in models:
    print(m)
```

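The exact list depends on your timm version; a recent release prints something like:

```
pvt_v2_b0
pvt_v2_b1
pvt_v2_b2
pvt_v2_b2_li
pvt_v2_b3
pvt_v2_b4
pvt_v2_b5
twins_pcpvt_base
twins_pcpvt_large
twins_pcpvt_small
```

The pattern also matches the Twins `twins_pcpvt_*` models, which build on PVT. As far as I know, the original PVT v1 names (`pvt_tiny`, `pvt_small`, etc.) are not in timm, so check the list from your own installation.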

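✅ The “v2” series improves on the original PVT with:

* Overlapping patch embedding (convolutions whose kernel is larger than their stride)
* A convolutional feed-forward network (a 3×3 depthwise conv inside the MLP), which replaces fixed-size positional embeddings and handles varying input resolutions more gracefully
* An optional linear-complexity attention variant (`pvt_v2_b2_li`) that average-pools keys and values to a fixed 7×7 grid
* Higher ImageNet accuracy than PVT v1 at comparable model sizes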
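The overlapping patch embedding is the easiest of these changes to see in code. Both convolutions below turn a 224×224 image into a 56×56 grid, but the PVTv2-style one uses a kernel larger than its stride, so neighbouring tokens share pixels. Kernel 7, stride 4, padding 3 are the stage-1 values as I understand the PVTv2 design; treat them as assumptions. Only the convolution is shown, without the `LayerNorm` that normally follows it.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)

# PVT v1-style stage-1 patch embedding: non-overlapping 4x4 patches (kernel == stride)
embed_v1 = nn.Conv2d(3, 64, kernel_size=4, stride=4)
# PVTv2-style overlapping patch embedding (assumed values): 7x7 kernel, stride 4, padding 3
embed_v2 = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)

f1, f2 = embed_v1(x), embed_v2(x)
print(f1.shape, f2.shape)  # torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])

# Flatten to tokens for the Transformer blocks: [B, C, H, W] -> [B, H*W, C]
tokens = f2.flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 3136, 64])
```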

---

## **3. Load a PVT model**

### **Classification example**

```python
import torch
import timm

# Create model
model = timm.create_model('pvt_v2_b2', pretrained=True)
model.eval()

# Random input (B=1, C=3, H=224, W=224)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)

print("Output shape:", y.shape)
```

**Output:**

```
Output shape: torch.Size([1, 1000])
```

✅ This is ImageNet-1k classification output.

---

## **4. Extract intermediate feature maps (for detection or segmentation)**

If you want to use **PVT as a backbone**, not for classification, you can set:

```python
model = timm.create_model('pvt_v2_b2', pretrained=True, features_only=True)
```

Now it returns **pyramid feature maps** from multiple stages:

```python
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model(x)

for i, f in enumerate(features):
    print(f"Stage {i+1} feature:", f.shape)
```

**Typical output:**

```
Stage 1 feature: torch.Size([1, 64, 56, 56])
Stage 2 feature: torch.Size([1, 128, 28, 28])
Stage 3 feature: torch.Size([1, 320, 14, 14])
Stage 4 feature: torch.Size([1, 512, 7, 7])
```

✅ Exactly matches the hierarchical pyramid structure we discussed earlier.

---

## **5. Visualize the hierarchy**

You can see how the channels and resolutions evolve:

| Stage | Output Shape     | Feature level | Typical content    |
| :---: | :--------------- | :------------ | :----------------- |
|   1   | [B, 64, 56, 56]  | low-level     | local texture      |
|   2   | [B, 128, 28, 28] | mid-level     | edge/shape         |
|   3   | [B, 320, 14, 14] | high-level    | semantics          |
|   4   | [B, 512, 7, 7]   | global        | full-image context |

---

## **6. Using as a backbone for downstream tasks**

You can plug these outputs into:

* **FPN / UPerNet** for segmentation
* **RetinaNet / Mask R-CNN** for detection
* **Custom encoder-decoder architectures (e.g., U-Net)**

Example:

```python
from timm import create_model
from torch import nn

# PVT backbone
backbone = create_model('pvt_v2_b2', pretrained=True, features_only=True)

# Example decoder
class SimpleSegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x[-1])  # use only the last-stage feature (stride 32, 512 channels)

model = SimpleSegmentationHead(512, num_classes=21)
```

---

## **7. Example end-to-end flow**

```python
import torch

# `backbone` and `model` are defined in the previous snippet
x = torch.randn(1, 3, 224, 224)
features = backbone(x)

for i, f in enumerate(features):
    print(f"Stage {i+1}: {f.shape}")

out = model(features)
print("Segmentation map:", out.shape)
```

Output:

```
Stage 1: torch.Size([1, 64, 56, 56])
Stage 2: torch.Size([1, 128, 28, 28])
Stage 3: torch.Size([1, 320, 14, 14])
Stage 4: torch.Size([1, 512, 7, 7])
Segmentation map: torch.Size([1, 21, 7, 7])
```
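The 7×7 map above is 1/32 of the input resolution, which is too coarse for per-pixel labels. A common remedy is to use all four stages and upsample. The sketch below is a hypothetical head (`MultiScaleFusionHead` is not a timm or PVT API): it projects each stage to a shared width with 1×1 convs, upsamples everything to the stage-1 grid, sums, classifies, and finally upsamples to the input size. Random tensors with the `pvt_v2_b2` shapes listed above stand in for real backbone features, so it runs without timm or pretrained weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiScaleFusionHead(nn.Module):
    """Hypothetical head: 1x1 projections, upsample to stage-1 size, sum, classify."""

    def __init__(self, in_channels=(64, 128, 320, 512), width=64, num_classes=21):
        super().__init__()
        self.lateral = nn.ModuleList([nn.Conv2d(c, width, kernel_size=1) for c in in_channels])
        self.classifier = nn.Conv2d(width, num_classes, kernel_size=1)

    def forward(self, features, out_size):
        target = features[0].shape[-2:]  # stage-1 grid (1/4 of the input)
        fused = 0
        for lateral, f in zip(self.lateral, features):
            fused = fused + F.interpolate(lateral(f), size=target, mode="bilinear", align_corners=False)
        logits = self.classifier(fused)  # [B, num_classes, H/4, W/4]
        return F.interpolate(logits, size=out_size, mode="bilinear", align_corners=False)


# Stand-in features with the pvt_v2_b2 shapes (no timm needed)
features = [torch.randn(1, c, s, s) for c, s in [(64, 56), (128, 28), (320, 14), (512, 7)]]
head = MultiScaleFusionHead()
with torch.no_grad():
    out = head(features, out_size=(224, 224))
print(out.shape)  # torch.Size([1, 21, 224, 224])
```

Summing is the simplest fusion choice; concatenation followed by a conv, or a full FPN top-down pathway, are common alternatives.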

---

## **8. Summary**

✅ **timm** provides ready-to-use **PVTv2** models (`pvt_v2_*`).
✅ You can use `pretrained=True` for ImageNet weights.
✅ Use `features_only=True` to get multi-scale backbone outputs.
✅ Well suited as a backbone for **detection**, **segmentation**, **pose estimation**, etc.
✅ Attention scaling and spatial reduction are implemented inside the model, so you don't need to write SRA yourself.

---

