# **Pyramid Vision Transformer (PVT)**
The **Pyramid Vision Transformer (PVT)** is a hierarchical transformer backbone that bridges **ViTs and CNNs** for dense visual tasks such as **object detection** and **semantic segmentation**.

---

## **1. Motivation**

The **Vision Transformer (ViT)** introduced powerful global attention but has drawbacks:

* Is normally trained at a **fixed input size** (e.g., 224×224); other sizes require interpolating the position embeddings.
* Produces **single-scale features**, unsuitable for dense prediction tasks.
* Has **quadratic complexity** with respect to the number of patches.
* Lacks **local inductive bias** (such as the locality and translation equivariance of CNNs).

The **Pyramid Vision Transformer (PVT)** (Wang et al., *ICCV 2021*) fixes these issues by **building a hierarchical (multi-scale) feature pyramid** — like a CNN backbone (e.g., ResNet).

---

## **2. Key Idea**

PVT = **Hierarchical Vision Transformer Backbone**

It mimics CNNs by generating **multi-resolution feature maps**:

| Stage | Resolution                | Channels      | Purpose                 |
| :---: | :------------------------ | :------------ | :---------------------- |
|   1   | High (e.g., 1/4 of input) | Low           | Capture local features  |
|   2   | Medium                    | More channels | Broader receptive field |
|   3   | Low                       | More channels | Semantic features       |
|   4   | Very low                  | Deep features | Global context          |

These outputs can directly feed **FPN**, **Mask R-CNN**, **U-Net**, etc.

---

## **3. Architecture Overview**

The PVT backbone has 4 stages, each performing:

1. **Patch embedding** (patchify + linear projection)
2. **Transformer encoder blocks**
3. **Spatial reduction attention (SRA)** — reduces tokens before attention
4. **Downsampling between stages** (to form a pyramid)

---

#### **3.1. Stage 1: Patch Embedding**

The image is split into small patches (like ViT):

$$
x_0 = \text{PatchEmbed}(I) \in \mathbb{R}^{H_0 \times W_0 \times C_0}
$$

Then flattened into tokens for the first transformer block.

---

#### **3.2. Transformer Encoder with SRA**

Standard self-attention has **O(N²)** cost (N = number of patches).
To make it scalable, **PVT** introduces **Spatial Reduction Attention (SRA)**:

Instead of using all keys and values, SRA **downsamples** them:

$$
K' = \text{Downsample}(K), \quad V' = \text{Downsample}(V)
$$

Then attention becomes:

$$
\text{Attention}(Q, K', V') = \text{Softmax}\left( \frac{Q {K'}^T}{\sqrt{d}} \right) V'
$$

This reduces complexity from **O(N²)** → **O(N × N/s²)** where *s* is the reduction ratio.

✅ This keeps **global receptive field** but reduces computation.

---

#### **3.3. Downsampling**


**PVT performs two different kinds of downsampling**:

1. **Downsampling the input tokens between stages** (via Conv2d with stride)
2. **Downsampling the K and V tokens inside attention** (via Spatial Reduction Attention, SRA)

These two are separate mechanisms.

---



## **4. Downsampling of the Input (Patch Embedding)**

This is the **true spatial downsampling** of feature maps.

At each stage:

* Stage 1: stride 4
* Stage 2: stride 2
* Stage 3: stride 2
* Stage 4: stride 2

This reduces the resolution:

$$
224 \to 56 \to 28 \to 14 \to 7
$$

This is the same 4×, 8×, 16×, 32× stride schedule used by CNN backbones (ResNet, EfficientNet, etc.).

---

#### **4.1 How Conv2d Performs Downsampling**

If your input has spatial size
$$
H_{\text{in}} \times W_{\text{in}},
$$
a convolution with:

* kernel size: $ k $
* stride: $ s $
* padding: $ p $
* dilation: $ d $

produces an output of size

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - d (k-1) - 1}{s} + 1 \right\rfloor,
$$

$$
W_{\text{out}} = \left\lfloor \frac{W_{\text{in}} + 2p - d (k-1) - 1}{s} + 1 \right\rfloor.
$$
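The output-size formula above can be checked numerically. The sketch below is a minimal pure-Python helper; the function name `conv2d_out_size` is hypothetical (it is not part of any library), and it only mirrors the floor-based formula rather than running a real convolution.

```python
def conv2d_out_size(h_in: int, k: int, s: int = 1, p: int = 0, d: int = 1) -> int:
    """Output length along one spatial axis of a 2D convolution.

    Hypothetical helper: implements floor((h_in + 2p - d(k-1) - 1) / s + 1)
    using integer floor division.
    """
    return (h_in + 2 * p - d * (k - 1) - 1) // s + 1


# A 3x3 kernel with stride 2 and padding 1 halves an even input size.
print(conv2d_out_size(56, k=3, s=2, p=1))  # 28
# Dilation enlarges the effective kernel: d=2, k=3 spans 5 input positions.
print(conv2d_out_size(7, k=3, s=1, d=2))   # 3
```

The same function applies to the width by passing $W_{\text{in}}$ instead of $H_{\text{in}}$.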


$d=1, k=2, s=2$

<img src="../conv/images/no_padding_strides.gif" />



$d=2, k=3, s=1$

<img src="../conv/images/dilation.gif" />




---
#### **4.2 Kernel Size, Stride, and Padding**  

If you set:

* stride = 2
* kernel = 3
* padding = 1

then plugging these values into the general equation (with $d = 1$) gives the following.

General form:

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - d(k-1) - 1}{s} + 1 \right\rfloor
$$

Set
$d = 1,\ p = 1,\ k = 3,\ s = 2$:

$$
H_{\text{out}}
= \left\lfloor
\frac{H_{\text{in}} + 2 - (3 - 1) - 1}{2} + 1
\right\rfloor
$$

Simplify:

$$
H_{\text{out}}
= \left\lfloor
\frac{H_{\text{in}} - 1}{2} + 1
\right\rfloor
$$

Combine:

$$
H_{\text{out}}
= \left\lfloor \frac{H_{\text{in}} + 1}{2} \right\rfloor
$$

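For even $H_{\text{in}}$:

$$
H_{\text{out}} = \frac{H_{\text{in}}}{2}
$$

For odd $H_{\text{in}}$ the result is $(H_{\text{in}} + 1)/2$, so the size is rounded up.

This is the stride-2 pattern that PVT-v2 uses for its patch embeddings in Stages 2 to 4, and the familiar **“same”-style downsampling** used in ResNet, EfficientNet, etc.

It relies on the padding: with $p = 0$, the same $3 \times 3$, stride-2 convolution gives $\lfloor (H_{\text{in}} - 3)/2 \rfloor + 1$, which is $H_{\text{in}}/2 - 1$ for even sizes.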




```python
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1)
```

This **halves** height and width:

$$
H_{\text{out}} = \frac{H_{\text{in}}}{2},\quad W_{\text{out}} = \frac{W_{\text{in}}}{2}.
$$
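As a quick sanity check, the following sketch runs the layer above on random tensors and records the resulting shapes. It assumes PyTorch is installed; the input sizes are arbitrary illustrative choices, and the odd size (57) shows that odd inputs are rounded up rather than halved exactly.

```python
import torch
from torch import nn

# Same layer as above: 3x3 kernel, stride 2, padding 1.
conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1)

shapes = {}
with torch.no_grad():
    for size in (224, 56, 57):  # two even sizes and one odd size
        x = torch.randn(1, 3, size, size)
        shapes[size] = tuple(conv(x).shape)
        print(size, "->", shapes[size])
```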

---

## **5. Spatial Reduction Attention (SRA) Downsampling**

This is the *core trick* of **Spatial Reduction Attention (SRA)** in **PVT**.

The paper describes the reduction of $K$ and $V$ as a reshape followed by a linear projection. The reference implementation realizes this as a **2D convolution with kernel size = stride = r**, followed by LayerNorm. Because the convolution is non-overlapping, the two descriptions are equivalent.


#### **5.1 Downsampling of K and V only (inside attention)**

Inside each Transformer block, PVT introduces **Spatial Reduction Attention (SRA)**:

* **Q** is NOT downsampled
* **K** and **V** are computed from a feature map that is spatially reduced by a ratio $r$ per side, so they have $N / r^2$ tokens

**Reduction ratios:**

| Stage | Reduction (r) |
| ----- | ------------- |
| 1     | 8             |
| 2     | 4             |
| 3     | 2             |
| 4     | 1             |

Thus:

* **K and V tokens become fewer**
* **Q keeps the full resolution**

This reduces the cost of attention.

---

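#### **5.3. Why downsample only K and V?**

Standard attention complexity is:

$$
\mathcal{O}(N^2)
$$

If you spatially reduce only K and V by a ratio $r$ per side:

* Q tokens: $N$
* K, V tokens: $N/r^2$

Then attention cost becomes:

$$
\mathcal{O}(N \cdot N/r^2) = \mathcal{O}(N^2/r^2)
$$

The attention output has one row per query, so keeping Q at full resolution preserves the stage's spatial resolution. Only the set of positions each query attends to is coarsened, which reduces runtime and memory.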
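To make the saving concrete, the sketch below counts the entries of the attention score matrix with and without reduction. It is illustrative arithmetic only: the helper name `attention_scores` is hypothetical, and it assumes $r$ is the per-side spatial reduction ratio, so the number of K/V tokens shrinks by $r^2$.

```python
def attention_scores(h: int, w: int, r: int = 1) -> int:
    """Number of query-key pairs for an h x w token grid.

    Hypothetical helper: Q keeps all h*w tokens, K/V are reduced by r per side.
    """
    n_q = h * w
    n_kv = (h // r) * (w // r)
    return n_q * n_kv


full = attention_scores(56, 56)          # standard attention
reduced = attention_scores(56, 56, r=8)  # SRA with r = 8
print(full, reduced, full // reduced)    # 9834496 153664 64
```

For a $56 \times 56$ grid and $r = 8$, the score matrix is $r^2 = 64$ times smaller.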

---

#### **5.4. Visual summary**

**True downsampling (Conv):**

```
Input 224x224
    ↓ stride 4
56x56 tokens
    ↓ stride 2
28x28 tokens
    ↓ stride 2
14x14 tokens
    ↓ stride 2
7x7 tokens
```

**Inside each Transformer block (SRA):**

```
Q: full resolution
K: reduced with factor r
V: reduced with factor r
```

Example in stage 2:

* Q = 28×28 = 784 tokens
* K,V spatially reduced by r = 4 per side: 28×28 → 7×7

So K,V = 784 / 4² = 49 tokens.

---


## **6. PVT-v1 and PVT-v2 Downsampling Parameters**

Below is the **precise list of kernel sizes (k)** and **strides (s)** used by **PVT-v1** and **PVT-v2** at **every stage**, for **Patch Embedding** and **Spatial-Reduction Attention (SRA)**.


---

### **6.1. PVT-v1 (original PVT)**

#### **6.1.1 Patch Embedding Layers (downsampling)**

PVT-v1 uses **non-overlapping** patch embeddings: a **Conv2d whose kernel size equals its stride** (4 in Stage 1, 2 in Stages 2 to 4), with no padding.

| Stage   | Input → Output Resolution | Conv2d kernel $k$ | stride $s$ | padding | Channels |
| ------- | ------------------------- | ----------------- | ---------- | ------- | -------- |
| Stage 1 | $224 \to 56$              | $k = 4$           | $s = 4$    | $p = 0$ | 64       |
| Stage 2 | $56 \to 28$               | $k = 2$           | $s = 2$    | $p = 0$ | 128      |
| Stage 3 | $28 \to 14$               | $k = 2$           | $s = 2$    | $p = 0$ | 320      |
| Stage 4 | $14 \to 7$                | $k = 2$           | $s = 2$    | $p = 0$ | 512      |

**How they downsample:**

Using the standard equation (with dilation $d = 1$):

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} + 1 \right\rfloor
$$

Example: Stage 1
$H_{\text{in}} = 224,\ k = 4,\ s = 4,\ p = 0$

$$
H_{\text{out}} = \left\lfloor \frac{224 + 0 - 4}{4} + 1 \right\rfloor = 55 + 1 = 56.
$$

---

#### **6.1.2 SRA: Spatial-Reduction Attention**

Inside each transformer block, the **key/value** tokens are reduced by a **Conv2d with kernel size = stride = r** followed by LayerNorm. This is equivalent to the reshape plus linear projection described in the paper.

Reduction ratios (r):

| Stage | Reduction (r) |
| ----- | ------------- |
| 1     | 8             |
| 2     | 4             |
| 3     | 2             |
| 4     | 1             |

This changes only the **attention** computation; the stage's output resolution is unchanged.

---

### **6.2. PVT-v2 (improved version)**

PVT-v2 switches to **overlapping patch embeddings**: the kernel is larger than the stride and zero padding is added, so neighboring patches share pixels. The downsampling ratios are the same as in PVT-v1.

#### **6.2.1 Patch Embedding Layers**

| Stage   | Input → Output | Conv2d kernel $k$ | stride $s$ | padding | Channels |
| ------- | -------------- | ----------------- | ---------- | ------- | -------- |
| Stage 1 | $224 \to 56$   | $k = 7$           | $s = 4$    | $p = 3$ | 64       |
| Stage 2 | $56 \to 28$    | $k = 3$           | $s = 2$    | $p = 1$ | 128      |
| Stage 3 | $28 \to 14$    | $k = 3$           | $s = 2$    | $p = 1$ | 320      |
| Stage 4 | $14 \to 7$     | $k = 3$           | $s = 2$    | $p = 1$ | 512      |

The channel widths shown are those of B1 to B5; B0 uses 32, 64, 160, 256.

**Key change:**

* non-overlapping patches ($k = s$) → overlapping patches ($k = 7, s = 4, p = 3$ in Stage 1; $k = 3, s = 2, p = 1$ afterwards)
* exact same spatial resolutions as PVT-v1

---

#### **6.2.2 SRA Reduction Ratios in PVT-v2**

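The standard variants (B0 to B5) use the same per-stage ratios as PVT-v1. The linear-SRA variant (`pvt_v2_b2_li` in timm) instead average-pools the K/V feature map to a fixed $7 \times 7$ size.

Standard ratios: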

| Stage | Reduction (r) |
| ----- | ------------- |
| 1     | 8             |
| 2     | 4             |
| 3     | 2             |
| 4     | 1             |

---



## **7. Summary Table for PVT-v1 and PVT-v2**

#### **7.1 (k and s only)**

| Stage   | PVT-v1 $k,s$ | PVT-v2 $k,s$ | Output Res   |
| ------- | ------------ | ------------ | ------------ |
| Stage 1 | $k=4, s=4$   | $k=7, s=4$   | $224 \to 56$ |
| Stage 2 | $k=2, s=2$   | $k=3, s=2$   | $56 \to 28$  |
| Stage 3 | $k=2, s=2$   | $k=3, s=2$   | $28 \to 14$  |
| Stage 4 | $k=2, s=2$   | $k=3, s=2$   | $14 \to 7$   |


---

#### **7.2  Why these choices**

* Stage 1 must reduce resolution strongly (224 → 56).
  That is why stride=4 is used.
* Later stages use stride=2, like a CNN backbone (ResNet).
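* In PVT-v2, padding = ⌊kernel/2⌋ (3 for kernel=7, 1 for kernel=3) gives a stable “same-like” behavior: the output is exactly input/stride whenever the input size is divisible by the stride.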

---


## **8. Advantages**

- ✅ **Hierarchical features** → usable as a CNN backbone (e.g., in Mask R-CNN).
- ✅ **Global receptive field** from transformers.
- ✅ **Efficient attention** via spatial reduction.
- ✅ **Variable input resolution** support.
- ✅ **Strong performance on dense tasks** (segmentation, detection).

---

## **9. Comparison with ViT and Swin**

| Model | Attention Type       | Hierarchy | Complexity | Windowed? | Suitable for Detection? |
| :---- | :------------------- | :-------- | :--------- | :-------- | :---------------------- |
| ViT   | Global               | ✖         | O(N²)      | ✖         | ✖                       |
| Swin  | Window-based (local) | ✅         | O(N)       | ✅         | ✅                       |
| PVT   | Global (SRA-reduced) | ✅         | O(N²/s²)   | ✖         | ✅                       |

So:

* **PVT keeps global attention** but makes it efficient (SRA).
* **Swin** uses local window attention and shifting to connect neighborhoods.

---

## **10. Equation Summary**

Let’s define for stage *i*:

* Input tokens:
  $$ X_i \in \mathbb{R}^{N_i \times C_i} $$
* Spatial reduction ratio: *r*

Then attention becomes:

$$
Q = X_i W_Q, \quad K = \text{Down}(X_i)\, W_K, \quad V = \text{Down}(X_i)\, W_V
$$

$$
\text{SRA}(X_i) = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d}} \right) V
$$

where
$$
\text{Down}(\cdot) = \text{Reshape} \to \text{Conv2d}(k = r,\ s = r) \to \text{Flatten} \to \text{LayerNorm}
$$
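The equations above can be sketched as a small PyTorch module. This is a minimal illustration under stated assumptions, not the reference or timm implementation: the class name `SpatialReductionAttention` and its arguments are hypothetical, dropout and weight initialization are omitted, and the reduction is applied to the input tokens before the K/V projections, with a LayerNorm after the strided convolution.

```python
import torch
from torch import nn


class SpatialReductionAttention(nn.Module):
    """Minimal multi-head SRA sketch (hypothetical class, not the timm code)."""

    def __init__(self, dim: int, num_heads: int = 1, sr_ratio: int = 1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.sr_ratio = sr_ratio
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.proj = nn.Linear(dim, dim)
        if sr_ratio > 1:
            # Non-overlapping strided conv: merges each r x r block of tokens.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b, n, c = x.shape  # n == h * w
        q = self.q(x).reshape(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        if self.sr_ratio > 1:
            # (B, N, C) -> (B, C, H, W) -> strided conv -> (B, N / r^2, C)
            x_ = x.transpose(1, 2).reshape(b, c, h, w)
            x_ = self.sr(x_).flatten(2).transpose(1, 2)
            x_ = self.norm(x_)
        else:
            x_ = x
        kv = self.kv(x_).reshape(b, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4)  # each: (B, heads, N', head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        attn = attn.softmax(dim=-1)       # (B, heads, N, N')
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)


# Stage-1-like setting: 56 x 56 tokens, 64 channels, reduction ratio 8.
sra = SpatialReductionAttention(dim=64, num_heads=1, sr_ratio=8)
tokens = torch.randn(1, 56 * 56, 64)
with torch.no_grad():
    out = sra(tokens, h=56, w=56)
print(out.shape)  # torch.Size([1, 3136, 64])
```

The output keeps all 3136 query positions, while the attention map inside the module has only 49 columns.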


---

## **11. Typical Usage**

PVT variants:

* **PVT-Tiny**, **PVT-Small**, **PVT-Medium**, **PVT-Large**
  differ in embedding dims and number of blocks.

Used in:

* **PVT + FPN → RetinaNet / Mask R-CNN**
* **PVT + UPerNet → Semantic Segmentation**

---

## **12. Numerical Example**


#### **12.1. Input → Stage 1 (Patch Embedding)**

Start with a $224 \times 224$ RGB image.

PVT-v1 uses a non-overlapping Conv2d patch embedding with kernel 4 and **stride 4**, i.e. a **patch size of $4 \times 4$**. This walkthrough uses the PVT-v2 style overlapping embedding, which has the same stride:

* kernel $k = 7$
* stride $s = 4$
* padding $p = 3$
* dilation $d = 1$

Output size per dimension:

$$
\begin{align*}
H_{\text{out}} &= \left\lfloor
\frac{H_{\text{in}} + 2p - d(k-1) - 1}{s} + 1
\right\rfloor \\
&= \left\lfloor
\frac{224 + 2(3) - 1(7-1) - 1}{4} + 1
\right\rfloor \\
&= \left\lfloor
\frac{224 + 6 - 6 - 1}{4} + 1
\right\rfloor \\
&= \left\lfloor
\frac{223}{4} + 1
\right\rfloor \\
&= \lfloor 55.75 + 1 \rfloor = \lfloor 56.75 \rfloor \\
&= 56
\end{align*}
$$

So:

- Resolution: $56 \times 56$

- Tokens:

  $$
  N_1 = 56 \times 56 = 3136
  $$

* Channels: $d_1 = 64$

Now we apply **SRA inside Stage 1**.

#### Stage 1 SRA reduction ratio $s_{\text{SRA}} = 8$

Standard self-attention cost:

$$
\text{Cost}_{\text{standard}}^{(1)} = N_1^2 \times d_1
= 3136^2 \times 64
\approx 6.3 \times 10^8.
$$

SRA downsamples K,V by spatial factor $8$:

* Spatially: $56 \times 56 \to 7 \times 7$
* Tokens for K,V:

  $$
  N_1' = \frac{N_1}{8^2}
  = \frac{3136}{64}
  = 49
  $$

Cost:

$$
\text{Cost}_{\text{SRA}}^{(1)} = N_1 \times N_1' \times d_1
= 3136 \times 49 \times 64
\approx 9.8 \times 10^6.
$$

Shapes (batch size $B = 1$):

* Input / tokens: $(1, 3136, 64)$
* $Q$: $(1, 3136, 64)$
* $K', V'$: $(1, 49, 64)$
* Attention map: $(3136, 49)$

---

#### **12.2. Stage 1 → Stage 2 (Conv downsampling)**

Between Stage 1 and 2, PVT uses another Conv2d with:

* $k = 3$, $s = 2$, $p = 1$, $d = 1$

Input resolution: $56 \times 56$.

Using

$$
\begin{align*}
H_{\text{out}} &= \left\lfloor
\frac{H_{\text{in}} + 2p - d(k-1) - 1}{s} + 1
\right\rfloor \\
&= \left\lfloor
\frac{56 + 2(1) - 1(3-1) - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{56 + 2 - 2 - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{55}{2} + 1
\right\rfloor \\
&= \lfloor 27.5 + 1 \rfloor = \lfloor 28.5 \rfloor \\
&= 28
\end{align*}
$$

So:

* Resolution: $28 \times 28$

* Tokens:

  $$
  N_2 = 28 \times 28 = 784
  $$

* Channels: $d_2 = 128$

#### Stage 2 SRA $(s_{\text{SRA}} = 4)$

Standard SA cost:

$$
\text{Cost}_{\text{standard}}^{(2)} = N_2^2 \times d_2
= 784^2 \times 128
\approx 7.9 \times 10^7.
$$

SRA:

* Spatial reduction by $4$: $28 \times 28 \to 7 \times 7$
* Tokens:

  $$
  N_2' = \frac{N_2}{4^2}
  = \frac{784}{16}
  = 49
  $$

Cost:

$$
\text{Cost}_{\text{SRA}}^{(2)} = N_2 \times N_2' \times d_2
= 784 \times 49 \times 128
\approx 4.9 \times 10^6.
$$

Shapes:

* Input: $(1, 784, 128)$
* $Q$: $(1, 784, 128)$
* $K', V'$: $(1, 49, 128)$
* Attention map: $(784, 49)$

---

#### **12.3. Stage 2 → Stage 3 (Conv downsampling)**

Same Conv pattern:

* $k = 3$, $s = 2$, $p = 1$, $d = 1$

Input: $28 \times 28$

$$
\begin{align*}
H_{\text{out}} &= \left\lfloor
\frac{H_{\text{in}} + 2p - d(k-1) - 1}{s} + 1
\right\rfloor \\
&= \left\lfloor
\frac{28 + 2(1) - 1(3-1) - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{28 + 2 - 2 - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{27}{2} + 1
\right\rfloor \\
&= \lfloor 13.5 + 1 \rfloor = \lfloor 14.5 \rfloor \\
&= 14
\end{align*}
$$

So:

* Resolution: $14 \times 14$

* Tokens:

  $$
  N_3 = 14 \times 14 = 196
  $$

* Channels: $d_3 = 320$

#### Stage 3 SRA $(s_{\text{SRA}} = 2)$

Standard SA:

$$
\text{Cost}_{\text{standard}}^{(3)} = N_3^2 \times d_3
= 196^2 \times 320
\approx 1.2 \times 10^7.
$$

SRA:

* Spatial reduction by $2$: $14 \times 14 \to 7 \times 7$
* Tokens:

  $$
  N_3' = \frac{N_3}{2^2}
  = \frac{196}{4}
  = 49
  $$

Cost:

$$
\text{Cost}_{\text{SRA}}^{(3)} = N_3 \times N_3' \times d_3
= 196 \times 49 \times 320
\approx 3.1 \times 10^6.
$$

Shapes:

* Input: $(1, 196, 320)$
* $Q$: $(1, 196, 320)$
* $K', V'$: $(1, 49, 320)$
* Attention map: $(196, 49)$

---

#### **12.4. Stage 3 → Stage 4 (Conv downsampling)**

Again Conv with $k=3, s=2, p=1, d=1$:

Input: $14 \times 14$

$$
\begin{align*}
H_{\text{out}} &= \left\lfloor
\frac{H_{\text{in}} + 2p - d(k-1) - 1}{s} + 1
\right\rfloor \\
&= \left\lfloor
\frac{14 + 2(1) - 1(3-1) - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{14 + 2 - 2 - 1}{2} + 1
\right\rfloor \\
&= \left\lfloor
\frac{13}{2} + 1
\right\rfloor \\
&= \lfloor 6.5 + 1 \rfloor = \lfloor 7.5 \rfloor \\
&= 7
\end{align*}
$$

So:

* Resolution: $7 \times 7$

* Tokens:

  $$
  N_4 = 7 \times 7 = 49
  $$

* Channels: $d_4 = 512$

#### Stage 4 SRA ($s_{\text{SRA}} = 1$)

Now the feature map is already small, so SRA does **no extra reduction**:

* $N_4' = N_4 = 49$

Standard SA and SRA costs are identical:

$$
\text{Cost}_{\text{standard}}^{(4)}
= \text{Cost}_{\text{SRA}}^{(4)}
= N_4^2 \times d_4
= 49^2 \times 512
\approx 1.2 \times 10^6.
$$

Shapes:

* Input: $(1, 49, 512)$
* $Q$: $(1, 49, 512)$
* $K', V'$: $(1, 49, 512)$
* Attention map: $(49, 49)$

---

#### **12.5. Combined View: Conv Downsampling + SRA**

Here is everything together:


| Stage | Conv Downsampling (in → out)      | Output Res | Tokens $N$ | Channels $d$ | SRA ratio $s_{\text{SRA}}$ | $N'$ for K,V |    Q shape    |  K′/V′ shape |  Attn map  |
| :---: | :-------------------------------- | :--------: | :--------: | :----------: | :------------------------: | :----------: | :-----------: | :----------: | :--------: |
|   1   | $224^2 \xrightarrow{k=7,s=4,p=3}$ |    56×56   |    3136    |      64      |              8             |      49      | $1, 3136, 64$ |  $1, 49, 64$ | $3136, 49$ |
|   2   | $56^2 \xrightarrow{k=3,s=2,p=1}$  |    28×28   |     784    |      128     |              4             |      49      | $1, 784, 128$ | $1, 49, 128$ |  $784, 49$ |
|   3   | $28^2 \xrightarrow{k=3,s=2,p=1}$  |    14×14   |     196    |      320     |              2             |      49      | $1, 196, 320$ | $1, 49, 320$ |  $196, 49$ |
|   4   | $14^2 \xrightarrow{k=3,s=2,p=1}$  |     7×7    |     49     |      512     |              1             |      49      |  $1, 49, 512$ | $1, 49, 512$ |  $49, 49$  |


So now you can clearly see:

* **Conv downsampling between stages**:
  $224 \to 56 \to 28 \to 14 \to 7$
* **SRA downsampling inside stages**:
  keeps Q at full resolution in that stage, but reduces K,V to always $7 \times 7 = 49$ tokens in the first three stages.
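
The whole walkthrough can be reproduced with a few lines of arithmetic. The sketch below is a pure-Python illustration, not library code: the helper names are hypothetical, the per-stage settings are assumed to follow the PVT-v2 style overlapping embedding ($k = 7, s = 4, p = 3$ in Stage 1, then $k = 3, s = 2, p = 1$), and "cost" counts only the $N \times N' \times d$ products of the score computation.

```python
def conv_out(h: int, k: int, s: int, p: int) -> int:
    """Conv2d output size along one axis (dilation 1)."""
    return (h + 2 * p - k) // s + 1


# Assumed per-stage settings: (kernel, stride, padding, channels, SRA ratio).
STAGES = [
    (7, 4, 3, 64, 8),
    (3, 2, 1, 128, 4),
    (3, 2, 1, 320, 2),
    (3, 2, 1, 512, 1),
]


def pyramid(h: int = 224):
    """Return one summary dict per stage for a square h x h input."""
    rows = []
    for k, s, p, c, r in STAGES:
        h = conv_out(h, k, s, p)   # resolution after the patch embedding
        n = h * h                  # query tokens
        n_kv = (h // r) ** 2       # key/value tokens after SRA
        rows.append({"res": h, "N": n, "C": c, "N_kv": n_kv,
                     "cost_full": n * n * c, "cost_sra": n * n_kv * c})
    return rows


rows = pyramid(224)
for i, row in enumerate(rows, start=1):
    print(i, row)
```

Calling `pyramid` with another input size (for example 512) shows how the K/V token count grows with resolution while the reduction ratios stay fixed.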





## **Python Example**


You can use **PVT-v2** directly via **[timm](https://github.com/huggingface/pytorch-image-models)**, which ships the `pvt_v2_b0` to `pvt_v2_b5` variants plus `pvt_v2_b2_li`. The listing below contains no PVT-v1 models; the `twins_pcpvt_*` entries are Twins models that build on the PVT design.

---

#### **List available PVT models**

You can list all models whose name contains “pvt”:

```python
import os
import warnings

# Silence warnings before importing timm so that import-time warnings are hidden too.
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

import torch
import timm

# Print every registered model whose name matches the pattern.
for name in timm.list_models("*pvt*"):
    print(name)
```

Output:

```
pvt_v2_b0
pvt_v2_b1
pvt_v2_b2
pvt_v2_b2_li
pvt_v2_b3
pvt_v2_b4
pvt_v2_b5
twins_pcpvt_base
twins_pcpvt_large
twins_pcpvt_small
```


Create a backbone first, then inspect its top-level modules:

```python
model_name = "pvt_v2_b2"

# features_only=True returns a feature-extraction wrapper instead of a classifier.
backbone = timm.create_model(model_name, pretrained=True, features_only=True)

# List the top-level modules of the backbone.
for name, module in backbone.named_children():
    print(f"  {name}: {type(module).__name__}")

# Print the first two top-level modules in detail.
for name, module in list(backbone.named_children())[:2]:
    print("-" * 60)
    print(module)
print("-" * 60)
```

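✅ The “v2” series improves on PVT-v1 with:

* Overlapping patch embedding
* A convolutional feed-forward network (a depth-wise 3×3 convolution inside the MLP, which also removes the need for fixed-size position embeddings)
* A linear-complexity attention layer (linear SRA, used by the `_li` variant)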
---

#### **Classification Example**


```python
import torch
import timm

# Create model
model = timm.create_model('pvt_v2_b2', pretrained=True)
model.eval()

# Random input (B=1, C=3, H=224, W=224)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)

print("Output shape:", y.shape)
```

✅ The output has shape `(1, 1000)`: one logit per ImageNet-1k class.

---

#### **Extract intermediate feature maps (for detection or segmentation)**

If you want to use **PVT as a backbone**, not for classification, you can set:

```python
model = timm.create_model('pvt_v2_b2', pretrained=True, features_only=True)
```

Now it returns **pyramid feature maps** from multiple stages:

```python
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model(x)

for i, f in enumerate(features):
    print(f"Stage {i+1} feature:", f.shape)
```

✅ For a 224×224 input, the four maps have shapes `(1, 64, 56, 56)`, `(1, 128, 28, 28)`, `(1, 320, 14, 14)`, and `(1, 512, 7, 7)`, which matches the hierarchical pyramid structure discussed earlier.

---

#### **Using as a backbone for downstream tasks**

You can plug these outputs into:

* **FPN / UPerNet** for segmentation
* **RetinaNet / Mask R-CNN** for detection
* **Custom encoder-decoder architectures (e.g., U-Net)**

Example:

```python
import torch
import timm
from torch import nn

# PVT backbone
backbone = timm.create_model('pvt_v2_b2', pretrained=True, features_only=True)
backbone.eval()

# Example decoder
class SimpleSegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # 1x1 conv: a per-pixel linear classifier over the channel dimension
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x[-1])  # use last stage feature

# 512 = channel count of the last pvt_v2_b2 stage
model = SimpleSegmentationHead(512, num_classes=21)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = backbone(x)
    out = model(features)

for i, f in enumerate(features):
    print(f"Stage {i+1}: {f.shape}")

print("Segmentation map:", out.shape)
```


#### `features_only=True`

This changes the forward pass so that the model **returns the feature maps of each stage**, not the classification head.

PVT-v2-B2 has **4 stages**:

| Stage   | Resolution shrink | Output Channels |
| ------- | ----------------- | --------------- |
| Stage 1 | 4× ↓              | 64              |
| Stage 2 | 8× ↓              | 128             |
| Stage 3 | 16× ↓             | 320             |
| Stage 4 | 32× ↓             | 512             |

So for an input image of size
$$H \times W$$

the outputs are:

| Stage | Tensor shape                           |
| ----- | -------------------------------------- |
| p1    | $B, 64, \frac{H}{4}, \frac{W}{4}$      |
| p2    | $B, 128, \frac{H}{8}, \frac{W}{8}$     |
| p3    | $B, 320, \frac{H}{16}, \frac{W}{16}$   |
| p4    | $B, 512, \frac{H}{32}, \frac{W}{32}$   |

This is what you get from:

```python
features = backbone(x)
```

**`features` is a Python list:**

```python
features = [p1, p2, p3, p4]
```
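
The shape table above can be turned into a small helper for planning decoder inputs. This is a sketch under stated assumptions: the function name `pvt_feature_shapes` is hypothetical, the channel widths are those listed for PVT-v2-B2, and the input height and width are assumed to be divisible by 32.

```python
def pvt_feature_shapes(batch: int, height: int, width: int):
    """Expected (B, C, H, W) shapes of the four pyramid levels.

    Hypothetical helper; assumes pvt_v2_b2 channel widths and an input
    whose height and width are divisible by 32.
    """
    channels = (64, 128, 320, 512)
    strides = (4, 8, 16, 32)
    return [(batch, c, height // s, width // s) for c, s in zip(channels, strides)]


for level, shape in enumerate(pvt_feature_shapes(2, 512, 640), start=1):
    print(f"p{level}: {shape}")
```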

---

#### Forward pass inside PVT-v2-B2 (conceptual)

Each stage performs:

#### (1) Patch embedding

For stage 1:

$$X_1 = \text{Conv2d}(X_0)$$
This converts RGB image to 64 channels and downsamples by 4×.

For later stages:

$$X_{i} = \text{PatchEmbed}_i(X_{i-1})$$
A stride-2 Conv2d followed by LayerNorm halves the spatial resolution again.

---

#### (2) Spatial Reduction Attention (SRA)

Each Transformer block inside PVT uses:

$$
Q = XW_Q, \quad K = X_{\text{reduced}}W_K, \quad V = X_{\text{reduced}}W_V
$$

where
$$X_{\text{reduced}} = \text{Downsample}(X, r)$$

with reduction ratios
$$r \in \{8, 4, 2, 1\}$$
for the four stages.

The attention score:

$$
A = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)
$$

The output:

$$
Y = A V
$$

Then residuals + MLP, repeated.

Each stage has multiple blocks.

---

#### So the final forward of the backbone returns:

```python
return [stage1_out, stage2_out, stage3_out, stage4_out]
```

---

#### Parameters of PVT-v2-B2

#### Patch embeddings (conv layers)

Each stage has a convolution that changes channels and reduces resolution.

#### Transformer layers

Every layer has:

* $$W_Q, W_K, W_V \in \mathbb{R}^{C \times C}$$
* MLP weights
* LayerNorm parameters
* SRA reduction projections

Total parameters ≈ **25M** for the PVT-v2-B2 classification model (other variants differ).

---

#### Your SimpleSegmentationHead

```python
class SimpleSegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x[-1])
```

#### Forward logic

You receive a feature list:

```python
[p1, p2, p3, p4]
```

`x[-1]` = `p4`

Shape of `p4`:

$$
B, 512, \frac{H}{32}, \frac{W}{32}
$$

#### The segmentation head does:

$$
\hat{Y} = \text{Conv2d}_{1\times1}(p4)
$$

This is linear projection per pixel:

$$
\hat{Y}_{b,c,i,j} = \sum_{k=1}^{C_{in}} W_{c,k}  p4_{b,k,i,j} + b_c
$$

So you get logits:


$$
\left( B, \text{num\_classes}, \frac{H}{32}, \frac{W}{32} \right)
$$


#### Parameters of the head:

A 1×1 convolution has:

$$
\text{Params} = C_{\text{in}} \cdot C_{\text{out}} + C_{\text{out}}
$$

where

* $C_{\text{in}} = 512$ (from PVT stage 4)
* $C_{\text{out}} = \text{num\_classes}$

Example: for 21 classes (PASCAL VOC):

$$
512 \times 21 + 21 = 10{,}773
$$

Very small.
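
The same count can be computed for any head size. The helper below is a hypothetical illustration of the formula above (it is not a PyTorch function), generalized to a $k \times k$ kernel; the 150-class call is only an example value.

```python
def conv2d_param_count(c_in: int, c_out: int, k: int = 1, bias: bool = True) -> int:
    """Weights plus optional biases of a k x k Conv2d with groups=1."""
    return c_in * c_out * k * k + (c_out if bias else 0)


print(conv2d_param_count(512, 21))   # 10773
print(conv2d_param_count(512, 150))  # 76950
```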

---


#### Backbone output

You get:

```python
[
  (B,  64, H/4,  W/4),
  (B, 128, H/8,  W/8),
  (B, 320, H/16, W/16),
  (B, 512, H/32, W/32),
]
```

#### Segmentation head

Projects the **last** feature map to class logits using a 1×1 convolution.

#### Result

Low-resolution segmentation map at 1/32 spatial resolution.
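
To get per-pixel predictions at the input resolution, the low-resolution logits are usually upsampled. The sketch below uses bilinear interpolation, which is one common choice and an assumption here rather than something prescribed above; the random tensor stands in for the head output of a 224×224 input with 21 classes.

```python
import torch
import torch.nn.functional as F

# Stand-in for the head output: (B, num_classes, H/32, W/32) for a 224x224 input.
logits = torch.randn(1, 21, 7, 7)

# Bilinearly upsample the logits to the input resolution, then take the
# per-pixel argmax to obtain a class-index map.
full_res = F.interpolate(logits, size=(224, 224), mode="bilinear", align_corners=False)
pred = full_res.argmax(dim=1)

print(full_res.shape)  # torch.Size([1, 21, 224, 224])
print(pred.shape)      # torch.Size([1, 224, 224])
```

Because all detail finer than 32 pixels has to be interpolated, real decoders typically also use the higher-resolution stages.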

---




#### Freeze all layers except the head


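```python
import timm

# Assumption: a 10-class classifier on pvt_v2_b0, matching the parameter counts quoted below.
pvt_model = timm.create_model("pvt_v2_b0", pretrained=True, num_classes=10)

# Freeze every parameter whose name does not contain 'head'.
for name, param in pvt_model.named_parameters():
    param.requires_grad = 'head' in name
```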

#### Trainable Parameter Count

After freezing everything except the head, only the final linear layer stays trainable, which leaves roughly `num_features × num_classes + num_classes` parameters (e.g. `256 * 10 + 10 = 2,570` for a 10-class head). The total model size is ≈3.4 M, so a trainable fraction of `2,570 / 3,412,330 ≈ 0.08%` is expected and indicates that freezing worked correctly.

To fine-tune more layers, selectively unfreeze the later stages (e.g., keep stages 3 and 4 trainable) by checking `name.startswith(('stages.2', 'stages.3'))`.
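
Prefix-based freezing can be wrapped in a small helper. The sketch below is illustrative: `freeze_except` is a hypothetical function, and the toy module only imitates the `stages.N` / `head` parameter naming so the example runs without downloading a real PVT checkpoint.

```python
from torch import nn


def freeze_except(model: nn.Module, trainable_prefixes: tuple) -> int:
    """Freeze all parameters whose name does not start with a given prefix.

    Hypothetical helper; returns the number of trainable parameters.
    """
    trainable = 0
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Toy stand-in that only mimics the `stages` / `head` naming, not a real PVT.
toy = nn.Module()
toy.stages = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])
toy.head = nn.Linear(8, 10)

print(freeze_except(toy, ("head.",)))                           # 90
print(freeze_except(toy, ("stages.2.", "stages.3.", "head.")))  # 234
```

Passing the prefixes as a tuple works because `str.startswith` accepts a tuple of candidates.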

#### Dropout Before the Head

The timm PVT-v2 model includes a `head_drop` module (Dropout) before the classifier head. Its probability is set by the `drop_rate` argument of `timm.create_model`; check `pvt_model.head_drop.p`, because a value of 0 makes the module a no-op. When freezing the backbone and training only the head, dropout is usually worth keeping enabled, especially if data is limited, since it regularizes the linear head weights. Options:

  1. **Keep as configured** (recommended): leave `pvt_model.head_drop` intact; it is only active in training mode.
  2. **Reduce or disable** if you’re seeing underfitting: set `pvt_model.head_drop.p = 0`.
  3. **Custom head**: replace `pvt_model.head = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(...))` to control dropout explicitly.

If you disable dropout, monitor validation metrics closely to make sure the small trainable portion is not overfitting.

#### Input size
You can read the required input size straight from the model’s config instead of guessing. Every timm model exposes a default (or pretrained) configuration dictionary with fields like `input_size`, `mean`, `std`, etc.
Add something like this right after you create the model:

```python
pvt_model = timm.create_model("pvt_v2_b0", pretrained=True)
cfg = pvt_model.pretrained_cfg  # (older timm versions call it default_cfg)
print("Required input size:", cfg["input_size"])  # e.g., (3, 224, 224)
```

This gives you the exact `(channels, height, width)` expected by the pretrained weights. If you’re on a newer timm release, use the helper instead:

```python
from timm.data import resolve_model_data_config
cfg = resolve_model_data_config(pvt_model)
print(cfg["input_size"])
```

---

# **Where PVT Works the Best**


Here are the most practical cases where **PVT (Pyramid Vision Transformer)** is a strong choice. In these settings it typically improves on single-scale ViT/DeiT backbones and on ResNet-style CNNs of similar size, and it is competitive with Swin.

---

# **1. Dense Prediction Tasks with Many Scales**

*(Semantic segmentation, instance segmentation, panoptic segmentation)*

PVT was designed to **replace ResNet backbones** in detection and segmentation pipelines built around heads such as FPN, UPerNet, or Mask2Former (for example inside Detectron2).

### **Why PVT shines here**

* It has a **true pyramid** (4 stages: 1/4, 1/8, 1/16, 1/32) just like CNNs.
* Attention uses **spatial-reduction attention (SRA)** so it scales to high resolutions.
* Therefore models like UPerNet, FPN, DeepLab work naturally with it.

### **Typical best-case examples**

* **Medical segmentation** (multi-scale lesions, polyps, tumors)
* **Remote sensing segmentation** (buildings, roads, agricultural fields)
* **Autonomous driving segmentation** (Cityscapes, ADE20K)
* **Indoor scene parsing**
* **Industrial defect segmentation** (scratches, cracks at multiple scales)

### **Real benchmark evidence**

In the PVT papers, PVT backbones paired with Semantic FPN **outperform ResNet backbones of comparable size** on ADE20K.

---

# **2. Object Detection With Many Object Sizes**

*(especially small + large objects in the same scene)*

PVT is extremely strong when **object scale varies**.

### **Why?**

It provides:

* Strong global modeling
* Pyramid features
* Works with RetinaNet, Faster R-CNN, Mask R-CNN directly

### **Best-case domains**

* **Drone detection** (small pedestrians, vehicles from above)
* **Wildlife detection** (animals at multiple scales)
* **COCO-like multi-object scenes**
* **Industrial inspection** (tiny screws + large components)

### **Example**

On COCO, the PVT papers report that RetinaNet with a PVT backbone **beats ResNet backbones of comparable size**.

---

# **3. Any Vision Task Where You Want Transformer Power but Need a Pyramid Backbone**

*(i.e., transformer + CNN-like architecture)*

ViT/DeiT are **single-scale**, giving only one 14×14 token map for a 224×224 input with 16×16 patches.
PVT gives multi-scale features natively.

### **Excellent fits**

* **Feature extractors for SLAM, SfM, VIO**
  (you want multi-scale features; PVT performs better than DeiT here)
* **Depth estimation** (MDE)
* **Optical flow**
* **Super-resolution** (multi-scale helps)
* **3D reconstruction / NeRF auxiliary encoders**

This is why **MAE-style PVT encoders** have been explored for reconstruction tasks.

---

# **4. When Memory Is Limited but Resolution Is High**

*(e.g., 1024×1024 and above, where DeiT/ViT become impractical)*

PVT’s **spatial reduction attention** shrinks the key/value feature map by a ratio $r$ per side (e.g., 4 or 8), so the number of keys and values drops by $r^2$.

### **Meaning:**

* High-resolution images stay feasible
* Training fits within limited GPU memory
* Still benefits from global attention

### **Best-case roles**

* **Microscopy images**
* **Satellite imagery datasets**
* **High-resolution industrial inspection**
* **Medical scans** (pathology slides)

In these cases, DeiT/ViT is unusable without strong downsampling.

---

# **5. Applications That Need Both Global and Local Context**

PVT mixes global context (transformers) with local processing (pyramid).
This makes PVT ideal for tasks requiring:

* fine-grained local detail
* long-range interaction

Examples:

* **Agricultural plant disease detection** (field images + leaf closeups)
* **Robotics scene understanding** (room-wide + small objects)
* **Document understanding** (global layout + small text regions)

---

# **Summary Table (Best Use Cases for PVT)**

| Domain / Task                 | Why PVT Excels                              | Competes Against |
| ----------------------------- | ------------------------------------------- | ---------------- |
| Semantic Segmentation         | Multi-scale transformer features            | Swin, ResNet+FPN |
| Panoptic Segmentation         | Strong global + local modeling              | Swin, ConvNeXt   |
| Instance Segmentation         | Pyramid attention backbone                  | ResNet+FPN       |
| Object Detection              | Good at small+large objects in same frame   | Swin, ResNet     |
| Medical Imaging               | Handles large resolutions cheaply           | ViT, DeiT        |
| Remote Sensing                | Long-range transformer context              | ViT, Swin        |
| SLAM / SfM feature extraction | Multi-scale + attention                     | ResNet, DeiT     |
| Industrial Defect Detection   | Requires high-res + local detail            | EfficientNet     |
| Autonomous Driving            | Multi-scale segmentation/detection backbone | ConvNeXt         |

---

# **When NOT to Use PVT**

PVT is **not** ideal when:

* Training data is **very small** (< 1000 samples)
* You need extreme speed on mobile devices (CNNs win)
* A pure global attention model like ViT is enough for classification
* Strong inductive bias is desired (Swin may do better)

---

