# **Swin Transformer**

The **Swin Transformer** (Shifted Window Transformer, Liu et al., 2021) extends the **Vision Transformer (ViT)** to handle **high-resolution inputs** and **dense prediction tasks** (e.g., detection, segmentation) efficiently, while keeping the flexibility of the Transformer architecture.

---

## **1. Motivation**

**ViT** treats an image as a sequence of patches and applies **global self-attention**.
While this works well for classification, it faces key limitations:

* **Quadratic complexity** in the number of patches.
* **No local inductive bias** (poor handling of fine details).
* **Single-scale feature map** at a fixed, coarse resolution, which is poorly suited to dense predictions that benefit from multi-scale features.

The **Swin Transformer** solves these problems by:

1. Applying **local attention** inside non-overlapping windows (reducing complexity).
2. **Shifting windows** between layers to connect across regions.
3. Introducing **patch merging** to build a **hierarchical (multi-scale)** representation — similar to CNNs.

<img src="images/swin_global_vs_local_attention.png" height="60%" width="60%" />



---

## **2. Architecture Overview**

Swin Transformer follows a **hierarchical pyramid design**, much like ResNet:

| Stage   | Output Resolution (relative to input image) | Patch Operation                    | Output Channels (Swin-T) | Description                   |
| ------- | ------------------------------------------- | ---------------------------------- | ------------------------ | ----------------------------- |
| Stage 1 | H/4 × W/4                                   | Patch Partition + Linear Embedding | 96                       | 4×4 patches, linear embedding |
| Stage 2 | H/8 × W/8                                   | Patch Merging                      | 192                      | 2× downsampling               |
| Stage 3 | H/16 × W/16                                 | Patch Merging                      | 384                      | 2× downsampling               |
| Stage 4 | H/32 × W/32                                 | Patch Merging                      | 768                      | 2× downsampling               |

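Each stage contains an even number of **Swin Transformer Blocks**, applied in consecutive pairs. Each block consists of:

1. **Window attention**: **W-MSA** (Window-based Multi-Head Self-Attention) in the first block of a pair, **SW-MSA** (Shifted-Window Multi-Head Self-Attention) in the second
2. **Feed-forward MLP**
3. **LayerNorm** before each of the two sub-layers, and a **residual connection** around each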


<img src="images/swin_architecture.png" height="60%" width="60%" />


<img src="images/two_successive_swin_transformer_blocks.png" height="30%" width="30%" />

---


#### Architecture Variants

- Swin-T: $C = 96$, $\text{layer numbers} = \{2, 2, 6, 2\}$
- Swin-S: $C = 96$, $\text{layer numbers} = \{2, 2, 18, 2\}$
- Swin-B: $C = 128$, $\text{layer numbers} = \{2, 2, 18, 2\}$
- Swin-L: $C = 192$, $\text{layer numbers} = \{2, 2, 18, 2\}$


## **3. Step-by-Step Pipeline**

Suppose you start with an image of size
$$
H = W = 224, \quad C = 3
$$
so the input tensor is
$$
X \in \mathbb{R}^{B \times 224 \times 224 \times 3}
$$
where $B$ is the batch size.

---

#### **3.1. Patch Partitioning**

Swin uses **patch size = 4×4**, meaning each patch covers 4×4 pixels.

Number of patches per dimension:
$$
\frac{224}{4} = 56
$$

So after patch partitioning, we have
$$
N = 56 \times 56 = 3136 \text{ patches.}
$$

Each patch is flattened:
$$
4 \times 4 \times 3 = 48
$$
values per patch.

The shape becomes:
$$
[B, 3136, 48]
$$
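The partitioning step can be sketched with plain array reshapes. This is a minimal NumPy illustration, not the reference implementation; the function name `patch_partition` and the channels-last layout are assumptions made for this example.

```python
import numpy as np

def patch_partition(x, patch_size=4):
    """Split [B, H, W, C] images into flattened non-overlapping patches.

    Returns an array of shape [B, (H/p) * (W/p), p * p * C].
    """
    B, H, W, C = x.shape
    p = patch_size
    assert H % p == 0 and W % p == 0, "H and W must be divisible by the patch size"
    # [B, H/p, p, W/p, p, C]: separate the patch-grid index from the in-patch offset
    x = x.reshape(B, H // p, p, W // p, p, C)
    # [B, H/p, W/p, p, p, C]: bring the two grid axes next to each other
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Flatten the grid into a token axis and each patch into a vector
    return x.reshape(B, (H // p) * (W // p), p * p * C)

x = np.random.rand(2, 224, 224, 3).astype(np.float32)
patches = patch_partition(x)
print(patches.shape)  # (2, 3136, 48)
```

Token 0 holds the pixels of the top-left 4×4 patch, token 1 the patch to its right, and so on in row-major order.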

---

#### **3.2 Linear Embedding (Projection Layer)**

Each 48-dimensional flattened patch vector is projected into a higher-dimensional embedding space of **C=96** (the feature dimension of Stage 1):

$$
[B, 3136, 48] \xrightarrow{\text{Linear(48→96)}} [B, 3136, 96]
$$

Now each patch token is a **96-dimensional feature vector**.
This is analogous to the **stem convolution** in CNNs.

You can reshape it back to a 2D feature map:
$$
[B, 56, 56, 96]
$$

---
#### **Fusing Patch Partitioning and Linear Embedding**
In most **Swin Transformer implementations**, the **patch partitioning** and **linear embedding** are fused into **a single convolution operation** at the very beginning.

Let’s see why and how.

---

#### **Conceptually**

We said:

* Partition image into non-overlapping 4×4 patches
* Flatten each patch (4×4×3 = 48)
* Apply a linear projection (48 → 96)

Conceptually, these are separate operations: partitioning (with flattening), followed by projection.

---

#### **Implementation Trick**

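In code (e.g., in `timm`, `swin_transformer.py`), this is typically implemented as a single **Conv2d layer**. A standalone version looks like this:

```python
import torch.nn as nn

# Patch partition + linear embedding fused into one strided convolution
patch_embed = nn.Conv2d(
    in_channels=3,    # RGB input
    out_channels=96,  # embedding dimension C of Stage 1
    kernel_size=4,    # each kernel application covers one 4×4 patch
    stride=4,         # stride == kernel size, so patches do not overlap
)
```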

This single convolution:

* Takes a **4×4 receptive field** (kernel)
* Moves by **stride 4** (non-overlapping patches)
* Outputs **96 channels**

---

#### **Resulting Dimensions**

This convolution directly produces:

$$
[B, 96, H/4, W/4]
$$

For a 224×224 input:
$$
[B, 96, 56, 56]
$$

If you then flatten the spatial dimensions:
$$
[B, 3136, 96]
$$

— which is exactly what you’d get from explicit patch partitioning + linear embedding.

---

#### **Advantages of the Convolutional Implementation**

* **Efficiency:** No explicit loops or reshaping of patches.
* **GPU-optimized:** Convolution is highly optimized for performance.
* **Equivalent to linear projection:** Each 4×4 patch is flattened and multiplied by a weight matrix of shape [96, 48], which is exactly what a 4×4 convolution with stride 4 computes at each output position.
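The equivalence claimed above can be checked numerically. The sketch below uses NumPy with a deliberately naive loop-based convolution as a reference; the function names (`strided_conv`, `patch_linear`) and the small 8×8 input are assumptions chosen for illustration. Patches are flattened in (channel, row, column) order here so that they line up with the kernel layout `[C_out, C_in, kH, kW]`; any other flattening order only permutes the columns of the weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))            # one small image, [C_in, H, W]
weight = rng.standard_normal((96, 3, 4, 4))   # conv kernel, [C_out, C_in, kH, kW]
bias = rng.standard_normal(96)

def strided_conv(x, weight, bias, stride=4):
    """Naive convolution with kernel size == stride and no padding."""
    c_out, _, k, _ = weight.shape
    h_out, w_out = x.shape[1] // stride, x.shape[2] // stride
    out = np.empty((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i * stride:i * stride + k, j * stride:j * stride + k]
            # Dot product of each of the 96 kernels with the 3x4x4 patch
            out[:, i, j] = np.tensordot(weight, patch, axes=3) + bias
    return out

def patch_linear(x, weight, bias, p=4):
    """Explicit patch partition, flatten to 48 values, then Linear(48 -> 96)."""
    c_in, H, W = x.shape
    patches = x.reshape(c_in, H // p, p, W // p, p).transpose(1, 3, 0, 2, 4)
    patches = patches.reshape(-1, c_in * p * p)   # [num_patches, 48]
    w_mat = weight.reshape(weight.shape[0], -1)   # [96, 48]
    return patches @ w_mat.T + bias               # [num_patches, 96]

conv_out = strided_conv(x, weight, bias)   # [96, 2, 2]
lin_out = patch_linear(x, weight, bias)    # [4, 96]
print(np.allclose(conv_out.reshape(96, -1).T, lin_out))  # True
```

Both paths produce the same 96-dimensional vector for every patch; the convolution simply avoids materializing the flattened patches.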

---

✅ **Summary**

| Step                               | Conceptual View                    | Implementation                     |
| ---------------------------------- | ---------------------------------- | ---------------------------------- |
| Patch Partition + Linear Embedding | Flatten 4×4 patches, Linear(48→96) | `Conv2d(3→96, kernel=4, stride=4)` |
| Output Shape                       | [B, 56×56, 96]                     | [B, 96, 56, 56] (then flattened)   |



---
#### **3.3. Stage 1 — Swin Transformer Block(s)**

Stage 1 applies **two Swin Transformer Blocks** in sequence:

1. A block using **W-MSA** (Window Multi-head Self-Attention)
2. A block using **SW-MSA** (Shifted-Window MSA)


#### **Window-based Multi-Head Self-Attention (W-MSA)**


Each feature map is divided into **non-overlapping windows** (e.g., 7×7 patches).
Attention is computed **independently** within each window.

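This reduces the attention cost from $O((HW)^2)$ for global attention to
$$ O(M^2 HW) $$
where $H \times W$ here denotes the number of patches in the feature map (e.g., $56 \times 56$) and $M$ is the window size (e.g., 7). For a fixed $M$, the cost grows linearly with the number of patches.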
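To put numbers on this, the sketch below evaluates the complexity formulas given in the Swin paper (Liu et al., 2021), which also include the channel dimension $C$ that the big-O expressions above leave out: $4hwC^2 + 2(hw)^2C$ for global MSA and $4hwC^2 + 2M^2hwC$ for W-MSA. The function names are made up for this example.

```python
def msa_cost(h, w, C):
    """Global MSA: 4hwC^2 for the projections + 2(hw)^2 C for attention."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_cost(h, w, C, M=7):
    """Window MSA: 4hwC^2 for the projections + 2 M^2 hw C for attention."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

h = w = 56   # Stage 1 feature map of a 224x224 image
C = 96
projection = 4 * h * w * C**2   # identical in both variants

global_cost = msa_cost(h, w, C)
window_cost = wmsa_cost(h, w, C)
print(f"global MSA: {global_cost / 1e9:.2f} G operations")
print(f"window MSA: {window_cost / 1e9:.3f} G operations")
# The attention terms differ by a factor of hw / M^2
print((global_cost - projection) / (window_cost - projection))
```

At Stage 1 resolution the attention term shrinks by a factor of $hw / M^2 = 3136 / 49 = 64$, and that factor grows with image size because the windowed cost scales linearly in $hw$.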

In this setup,

> **All queries within a window share the same key set.**

That is, every patch inside a window can only attend to other patches inside **that same window**, not across windows.

Formally, for window $ w $:
$$
Q^{(w)} = X^{(w)}W^Q,\quad
K^{(w)} = X^{(w)}W^K,\quad
V^{(w)} = X^{(w)}W^V
$$
and
$$
\text{Attention}^{(w)} = \text{Softmax}\left(\frac{Q^{(w)}{K^{(w)}}^T}{\sqrt{d}}\right)V^{(w)}
$$

This local attention structure introduces **spatial locality** and scales efficiently with image size.
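The per-window formula above can be sketched as a batched matrix product in which the window index acts as a batch axis. This is a single-head NumPy illustration with random weights; the head dimension `d = 32` and the function names are assumptions made for the example.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def window_attention(x_win, W_q, W_k, W_v):
    """Single-head attention applied independently to every window.

    x_win: [num_windows, M*M, C]; W_q, W_k, W_v: [C, d].
    """
    Q, K, V = x_win @ W_q, x_win @ W_k, x_win @ W_v
    d = Q.shape[-1]
    # [num_windows, M*M, M*M]: scores never mix tokens from different windows
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)
    return softmax(scores) @ V

rng = np.random.default_rng(0)
M, C, d = 7, 96, 32
x_win = rng.standard_normal((64, M * M, C))   # 64 windows of 49 tokens each
W_q, W_k, W_v = (rng.standard_normal((C, d)) * 0.02 for _ in range(3))
out = window_attention(x_win, W_q, W_k, W_v)
print(out.shape)  # (64, 49, 32)
```

Because each window is processed as its own batch element, changing the tokens of one window leaves the output of every other window untouched.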




#### **Shifted-Window Multi-Head Self-Attention (SW-MSA)**


In the **next block**, the window grid is **shifted by half the window size**, i.e., by $\lfloor M/2 \rfloor$ patches (3 patches for $M = 7$).
This shift allows patches that were previously in separate windows to now fall in the same window.

Thus, alternating between **W-MSA** and **SW-MSA** layers enables:

* Local attention within windows.
* Cross-window communication.
* Gradual expansion of the receptive field — achieving a global view over multiple layers.



Both W-MSA and SW-MSA compute attention **within local windows**, never globally.
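The effect of the shift can be seen on a toy grid. The sketch below assumes the shift is realized as a cyclic roll of the feature map followed by the regular window partition, which is the efficient formulation described in the Swin paper; the 8×8 grid, window size 4, and helper names are chosen only for illustration.

```python
import numpy as np

def partition(grid, M):
    """Split an [H, W] grid into windows, returning [num_windows, M*M]."""
    H, W = grid.shape
    return grid.reshape(H // M, M, W // M, M).transpose(0, 2, 1, 3).reshape(-1, M * M)

def window_of(windows, token_id):
    """Index of the window that contains a given token id."""
    return int(np.argwhere(windows == token_id)[0, 0])

H = W = 8
M = 4
shift = M // 2

tokens = np.arange(H * W).reshape(H, W)   # token id = row * W + col
regular = partition(tokens, M)
# Roll the map up and left by `shift`, then partition as usual
shifted = partition(np.roll(tokens, shift=(-shift, -shift), axis=(0, 1)), M)

a = 3 * W + 3   # token at (3, 3)
b = 4 * W + 4   # token at (4, 4), a diagonal neighbour across a window border
print(window_of(regular, a), window_of(regular, b))   # different windows
print(window_of(shifted, a), window_of(shifted, b))   # same window
```

The roll also wraps tokens from opposite image borders into the same window; the paper handles this with an attention mask so that such tokens do not attend to each other. The mask is omitted from this sketch.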

#### Window setup:

Each window covers **7×7 patches**, so each window has
$$
49 \text{ tokens.}
$$

Since the feature map is $56 \times 56$:
$$
\frac{56}{7} = 8 \text{ windows per side} \Rightarrow 8 \times 8 = 64 \text{ windows total.}
$$

So within Stage 1:

* Input: [B, 56, 56, 96]
* Divide into windows of [7, 7, 96]
* Self-attention is computed **inside each 7×7 window**
* The output tokens are reassembled back to [B, 56, 56, 96]
* The second block uses **shifted windows** (by 3 patches = 7//2) to mix information across boundaries

After these two blocks, Stage 1 outputs the same resolution:
$$
[B, 56, 56, 96]
$$
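The divide-and-reassemble steps listed above amount to a reshape and its inverse. The NumPy sketch below is modeled on the `window_partition` / `window_reverse` helpers commonly found in Swin implementations, but the exact code here is an illustrative assumption rather than a copy of any library.

```python
import numpy as np

def window_partition(x, M=7):
    """[B, H, W, C] -> [B * num_windows, M*M, C]."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // M, M, W // M, M, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, M * M, C)

def window_reverse(windows, H, W, M=7):
    """[B * num_windows, M*M, C] -> [B, H, W, C] (inverse of window_partition)."""
    C = windows.shape[-1]
    B = windows.shape[0] // ((H // M) * (W // M))
    x = windows.reshape(B, H // M, W // M, M, M, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(B, H, W, C)

x = np.random.rand(2, 56, 56, 96)
windows = window_partition(x)
print(windows.shape)                     # (128, 49, 96): 2 images x 64 windows
restored = window_reverse(windows, 56, 56)
print(np.array_equal(restored, x))       # True
```

Attention would run on `windows`, treating the first axis as a batch of independent 49-token sequences, before `window_reverse` restores the 2D layout.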

---

#### **3.4. Patch Merging (between Stage 1 → Stage 2)**

Before entering Stage 2, the resolution is halved and channel depth doubles:

* Merge each **2×2 patch group**
* So new spatial size:
  $$
  56/2 = 28 \Rightarrow [B, 28, 28, 192]
  $$
* (Each merge concatenates 4 neighboring patch vectors → 4×96 = 384 → linear → 192)



To reduce spatial resolution and increase semantic richness, Swin introduces **Patch Merging**, analogous to CNN downsampling.

Given:
$$ X \in \mathbb{R}^{H \times W \times C} $$

1. **Group 2×2 neighboring patches:**
   Each group of four patches is concatenated:
   $$
   [x_{00}, x_{01}, x_{10}, x_{11}] \in \mathbb{R}^{4C}
   $$


<img src="images/patch_merging1.png" height="40%" width="40%" />


2. **Linear Projection:**
   Reduce dimensionality from $ 4C $ → $ 2C $:
   $$
   X' = \text{Linear}(\text{Concat}_{2\times2}(X)) \in \mathbb{R}^{\frac{H}{2} \times \frac{W}{2} \times 2C}
   $$

<img src="images/patch_merging2.png" height="40%" width="40%" />


Thus:

* Spatial size halves.
* Channel dimension doubles.

After several patch-merging steps, the model forms a **feature pyramid** where deeper stages capture more abstract semantics.
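The two merging steps (2×2 grouping, then linear reduction) can be sketched with strided slicing. This NumPy version uses a random projection matrix in place of learned weights; the names `patch_merging` and `W_reduce` are assumptions for this example. The official implementation also applies a LayerNorm to the concatenated $4C$ features before the projection, which is omitted here for brevity.

```python
import numpy as np

def patch_merging(x, W_reduce):
    """[B, H, W, C] -> [B, H/2, W/2, 2C] by concatenating 2x2 neighbours.

    W_reduce: [4C, 2C] projection matrix (random here, learned in practice).
    """
    x00 = x[:, 0::2, 0::2, :]  # top-left token of each 2x2 group
    x01 = x[:, 0::2, 1::2, :]  # top-right
    x10 = x[:, 1::2, 0::2, :]  # bottom-left
    x11 = x[:, 1::2, 1::2, :]  # bottom-right
    merged = np.concatenate([x00, x01, x10, x11], axis=-1)  # [B, H/2, W/2, 4C]
    return merged @ W_reduce                                 # [B, H/2, W/2, 2C]

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 56, 56, 96))
W_reduce = rng.standard_normal((4 * 96, 2 * 96)) * 0.02
y = patch_merging(x, W_reduce)
print(y.shape)  # (2, 28, 28, 192)
```

Applying the same function again with a `[768, 384]` matrix would give the Stage 3 shape `[B, 14, 14, 384]`.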




---





## 4. Inside a Swin Transformer Block

Given an input $ X \in \mathbb{R}^{H \times W \times C} $:

1. **LayerNorm:**
   $$
   \hat{X} = \text{LN}(X)
   $$
2. **(Shifted) Window Attention:**
   $$
   X' = X + \text{WindowAttention}(\hat{X})
   $$
3. **Feed-forward (MLP) with residual:**
   $$
   X'' = X' + \text{MLP}(\text{LN}(X'))
   $$
4. **MLP structure:**
   $$
   \text{MLP}(x) = \text{Linear}_2(\text{GELU}(\text{Linear}_1(x)))
   $$
   where hidden dimension = $ 4C $.
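The four steps above compose into a pre-norm residual block. The sketch below is a minimal NumPy rendering under several simplifying assumptions: LayerNorm has no learnable scale or shift, GELU uses the tanh approximation, weights are random, and the attention sub-layer is passed in as a function (`attention_fn`, a hypothetical parameter) so that either W-MSA or SW-MSA could be plugged in.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize over the channel axis (learnable scale and shift omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, W1, b1, W2, b2):
    """Linear_2(GELU(Linear_1(x))) with hidden width 4C."""
    return gelu(x @ W1 + b1) @ W2 + b2

def swin_block(x, attention_fn, mlp_params):
    """X' = X + Attn(LN(X)), then X'' = X' + MLP(LN(X'))."""
    x = x + attention_fn(layer_norm(x))
    x = x + mlp(layer_norm(x), *mlp_params)
    return x

rng = np.random.default_rng(0)
C = 96
x = rng.standard_normal((2, 56, 56, C))
mlp_params = (
    rng.standard_normal((C, 4 * C)) * 0.02, np.zeros(4 * C),
    rng.standard_normal((4 * C, C)) * 0.02, np.zeros(C),
)
# Placeholder only: a real block would run (shifted) window attention here
identity_attention = lambda t: t
out = swin_block(x, identity_attention, mlp_params)
print(out.shape)  # (2, 56, 56, 96)
```

Both sub-layers preserve the input shape, which is why blocks can be stacked freely within a stage.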

---

## 5. Multi-Head Self-Attention Inside a Window

Within each window:

$$
\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V
$$

Here:

* $ B $ — learnable **relative position bias** for spatial awareness.
* $ Q, K, V $ — derived from the same window.
* Computation is efficient since windows are small.
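In the Swin paper, the bias is not stored as a full $M^2 \times M^2$ matrix: since relative offsets along each axis lie in $[-(M-1), M-1]$, the values are looked up from a smaller table with $(2M-1)^2$ entries per head. The sketch below shows one way to build the lookup index in NumPy. It follows the indexing scheme used in common implementations, but the function name and the random table are assumptions for illustration.

```python
import numpy as np

def relative_position_index(M):
    """Index into a (2M-1)^2 bias table for every (query, key) pair in a window."""
    coords = np.stack(np.meshgrid(np.arange(M), np.arange(M), indexing="ij"))  # [2, M, M]
    coords = coords.reshape(2, -1)                   # [2, M*M]: (row, col) of each token
    rel = coords[:, :, None] - coords[:, None, :]    # [2, M*M, M*M], values in [-(M-1), M-1]
    rel = rel + (M - 1)                              # shift to [0, 2M-2]
    return rel[0] * (2 * M - 1) + rel[1]             # [M*M, M*M], values in [0, (2M-1)^2 - 1]

M = 7
index = relative_position_index(M)
bias_table = np.random.randn((2 * M - 1) ** 2) * 0.02   # one head; learned in practice
bias = bias_table[index]                                 # [49, 49], added to Q K^T / sqrt(d)
print(index.shape, bias.shape)  # (49, 49) (49, 49)
```

All pairs with the same relative offset share one table entry, so the diagonal of `index` (offset zero) contains a single repeated value.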

---

## 6. Hierarchical Output Example

For a 224×224 input image, Swin Transformer produces multi-scale features:

| Stage | Resolution | Channels |
| ----- | ---------- | -------- |
| 1     | 56×56      | 96       |
| 2     | 28×28      | 192      |
| 3     | 14×14      | 384      |
| 4     | 7×7        | 768      |

These outputs form a **feature pyramid** — ideal for:

* **Classification:** via global average pooling.
* **Detection/Segmentation:** as multi-scale inputs to an FPN-style neck (e.g., in Mask R-CNN).
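The table follows directly from the patch size and the repeated halving. A small helper (hypothetical name `stage_shapes`, default values taken from the Swin-T walkthrough in this document) makes the rule explicit:

```python
def stage_shapes(img_size=224, patch_size=4, C=96, num_stages=4):
    """Return (height, width, channels) of the feature map after each stage."""
    shapes = []
    res = img_size // patch_size          # resolution after patch embedding
    for i in range(num_stages):
        shapes.append((res, res, C * 2**i))
        res //= 2                         # patch merging halves the resolution
    return shapes

print(stage_shapes())
# [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

Passing `C=128` gives the channel widths of Swin-B (128, 256, 512, 1024) at the same resolutions.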

---

## 7. Intuitive Summary

* **ViT**: global attention, single-scale, complexity quadratic in the number of patches.
* **Swin**: local attention, hierarchical, complexity linear in the number of patches (for a fixed window size).
* **Shifted windows**: enable cross-region communication.
* **Patch merging**: provides multi-scale features like CNNs.

 **In one sentence:**

> Swin Transformer limits self-attention to local windows (shared key sets), shifts the windows between consecutive blocks so that information propagates across window boundaries, and builds a hierarchical multi-scale representation through patch merging, combining the strengths of CNNs and Transformers.

---
## **Summary Table**

| Step                      | Operation          | Output Shape     | Notes                             |
| ------------------------- | ------------------ | ---------------- | --------------------------------- |
| Input                     | RGB image          | [B, 224, 224, 3] | Raw pixels                        |
| Patch Partition           | 4×4 blocks         | [B, 56×56, 48]   | Flatten each 4×4×3 patch          |
| Linear Embedding          | Linear(48→96)      | [B, 3136, 96]    | Project to feature dim 96         |
| Reshape                   | —                  | [B, 56, 56, 96]  | Spatial 2D form                   |
| Stage 1 (2 × Swin Blocks) | W-MSA + SW-MSA     | [B, 56, 56, 96]  | Local attention windows 7×7       |
| Patch Merging             | 2×2 merge + Linear | [B, 28, 28, 192] | Halve resolution, double channels |

---

Refs: [1](https://www.youtube.com/watch?v=qUSPbHE3OeU), [2](https://www.youtube.com/watch?v=z_8lajPxGQo)