## 1. Motivation

In convolutional neural networks (CNNs), each convolution can be thought of as:

* Taking **local patches** (receptive fields) from the input,
* Flattening them,
* Multiplying by the filter weights (matrix multiplication),
* Then reshaping the result back to the output feature map.

The operations **`torch.nn.Unfold`** and **`torch.nn.Fold`** allow you to explicitly perform those “patch extraction” and “reconstruction” steps.

✅ In short:

* `unfold` = **extract patches**
* `fold` = **reconstruct image from patches**

---

## 2. `torch.nn.Unfold`

### Concept

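`Unfold` takes a 4D tensor
$$x \in \mathbb{R}^{B \times C \times H \times W}$$
and returns a 3D tensor of **flattened sliding local blocks** with shape `[B, C×kH×kW, L]`, where each column holds one block and `L` is the number of block positions.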

### Example

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)
print(x)
```

```
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
```

Let’s extract 2×2 patches with stride 2:

```python
unfold = nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(x)
print(patches.shape)
print(patches)
```

Output:

```
torch.Size([1, 4, 4])
tensor([[[ 1.,  3.,  9., 11.],
         [ 2.,  4., 10., 12.],
         [ 5.,  7., 13., 15.],
         [ 6.,  8., 14., 16.]]])
```

Let’s interpret this:

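* There are **4 patches**: a 2×2 window with stride 2 fits at 2 positions along each axis of the 4×4 input, so L = 2 × 2 = 4.
* Each patch has **4 elements** (C × 2 × 2 = 1 × 4).
* So the shape `[B, C×kernel_height×kernel_width, num_patches]` is `[1, 4, 4]`, and each **column** is one flattened patch.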

If you transpose:

```python
patches_T = patches.squeeze(0).T
print(patches_T)
```

each row is now one flattened patch:

```
tensor([[ 1.,  2.,  5.,  6.],
        [ 3.,  4.,  7.,  8.],
        [ 9., 10., 13., 14.],
        [11., 12., 15., 16.]])
```
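To go the other way, from a row of `patches` to a position inside the window, one can split dimension 1 back into `(C, kH, kW)`; PyTorch orders that dimension channel first, then kernel row, then kernel column. A minimal sketch on the same 4×4 input (`grids` and `first_patch` are illustrative names):

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)
patches = nn.Unfold(kernel_size=(2, 2), stride=2)(x)  # [B=1, C*kH*kW=4, L=4]

# Split dim 1 into (C, kH, kW) so each column becomes a 2x2 window again
grids = patches.reshape(1, 1, 2, 2, 4)  # [B, C, kH, kW, L]
first_patch = grids[0, 0, :, :, 0]      # patch 0 = top-left window
print(first_patch)
# tensor([[1., 2.],
#         [5., 6.]])
```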

---

## 3. `torch.nn.Fold`

### Concept

`Fold` is the counterpart of `Unfold`: it places each column back at its patch location to rebuild an image (or feature map).
It is an exact *inverse* of `Unfold` only when the patches do not overlap.

```python
fold = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
reconstructed = fold(patches)
print(reconstructed)
```

Result:

```
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
```

✅ Perfect reconstruction, since the patches don’t overlap.

If there **is overlap** (e.g., stride < kernel_size), overlapping pixels will be **summed** during `fold`.

To fix that, you can normalize by the overlap count (explained below).
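A quick way to see the summation is to fold the overlapping patches of the same 4×4 input with stride 1. In this sketch (the names `unfold_ov`, `fold_ov` and `counts` are just for illustration), every pixel comes back multiplied by the number of windows that cover it:

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)

# Overlapping 2x2 windows: stride 1 < kernel size 2
unfold_ov = nn.Unfold(kernel_size=(2, 2), stride=1)
fold_ov = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)

summed = fold_ov(unfold_ov(x))  # each pixel = value * (number of windows covering it)

# Coverage count: border pixels fall in fewer windows than interior ones
counts = torch.tensor([[1., 2., 2., 1.],
                       [2., 4., 4., 2.],
                       [2., 4., 4., 2.],
                       [1., 2., 2., 1.]])
print(torch.equal(summed[0, 0], x[0, 0] * counts))  # True
```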

---

## 4. Handling Overlaps (stride < kernel_size)

When patches overlap, `Fold` adds contributions together.
You can compute a **normalization mask** like this:

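```python
# Overlapping 2×2 windows: stride 1 < kernel size 2
unfold_ov = nn.Unfold(kernel_size=(2, 2), stride=1)
fold_ov = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)

reconstructed = fold_ov(unfold_ov(x))                         # overlapping pixels are summed
reconstruction_mask = fold_ov(unfold_ov(torch.ones_like(x)))  # patches covering each pixel
reconstructed = reconstructed / reconstruction_mask           # average -> recovers x
```

This averages overlapping contributions, so `reconstructed` matches `x` again. If some pixels are covered by no patch at all (e.g. stride > kernel_size), the mask is 0 there and the division must be guarded.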

---

## 5. Matrix Multiplication View of Convolution

CNN convolution can be written as matrix multiplication:

$$
Y = W \times X_{\text{unfolded}}
$$

where:

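* $X_{\text{unfolded}}$ = patches extracted by `Unfold` → shape `[B, C×k_h×k_w, L]`
* $W$ = filter weights, one flattened filter per row → shape `[out_channels, C×k_h×k_w]`
* Result $Y$ = `[B, out_channels, L]` → reshaped to `[B, out_channels, H_out, W_out]` with a plain `view` (each column is already one output pixel, so `Fold` is not required)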

So, convolution is basically:

```python
Y = W @ unfold(X)
```

plus the bias, if the layer has one, followed by the reshape.
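The recipe can be checked end to end. A minimal sketch, assuming arbitrary sizes, stride 2 and padding 1 (all illustrative), using `F.unfold` (the functional form of `nn.Unfold`) and comparing against `F.conv2d`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 7, 7                      # illustrative input size
out_channels, k, stride, padding = 4, 3, 2, 1

X = torch.randn(B, C, H, W)
weight = torch.randn(out_channels, C, k, k)
bias = torch.randn(out_channels)

# 1) im2col: [B, C*k*k, L]
cols = F.unfold(X, kernel_size=k, stride=stride, padding=padding)
# 2) matmul with flattened filters: [out, C*k*k] @ [B, C*k*k, L] -> [B, out, L], plus bias
Y = weight.view(out_channels, -1) @ cols + bias.view(1, -1, 1)
# 3) reshape the L axis back to the output grid
H_out = (H + 2 * padding - k) // stride + 1
W_out = (W + 2 * padding - k) // stride + 1
Y = Y.view(B, out_channels, H_out, W_out)

reference = F.conv2d(X, weight, bias, stride=stride, padding=padding)
print(torch.allclose(Y, reference, atol=1e-4))  # True
```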

---

## 6. Typical Use Cases

✅ **ViT or Patch-based Models**
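The tensor method `Tensor.unfold(dimension, size, step)` (a different function from `nn.Unfold`) can cut an image into non-overlapping patch tokens: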

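```python
# x: [B, C, H, W]; Tensor.unfold(dim, size, step) -> [B, C, H/p, W/p, p, p]
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# flatten the patch grid and each patch's pixels -> [B, C, num_patches, p*p]
patches = patches.contiguous().view(B, C, -1, patch_size*patch_size)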
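As a sketch (the sizes are illustrative and assume `H` and `W` are divisible by `patch_size`), the `Tensor.unfold` route and `nn.Unfold` with kernel size = stride = patch size produce the same tokens once the axes are arranged as `[B, num_patches, C*p*p]`:

```python
import torch
import torch.nn as nn

B, C, H, W, patch_size = 2, 3, 32, 32, 8   # illustrative sizes
x = torch.randn(B, C, H, W)
p = patch_size

# Route 1: Tensor.unfold twice -> [B, C, H/p, W/p, p, p] -> [B, C, N, p*p]
patches = x.unfold(2, p, p).unfold(3, p, p)
patches = patches.contiguous().view(B, C, -1, p * p)
# One token per patch, channels first inside each token: [B, N, C*p*p]
tokens_a = patches.permute(0, 2, 1, 3).reshape(B, -1, C * p * p)

# Route 2: nn.Unfold gives [B, C*p*p, N]; transpose to [B, N, C*p*p]
tokens_b = nn.Unfold(kernel_size=p, stride=p)(x).transpose(1, 2)

print(tokens_a.shape)                   # torch.Size([2, 16, 192])
print(torch.equal(tokens_a, tokens_b))  # True
```

In a ViT, these tokens would then typically go through a linear projection; that step is outside this sketch.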

✅ **Image Reconstruction / Super-resolution**
`fold` helps reassemble overlapping patches from the processed output.

✅ **Implementing Custom Convolutions**
You can explicitly compute convolution using `unfold`, matrix multiplication, and `fold`.

---

## 7. Summary Table

| Operation   | Purpose                  | Input Shape       | Output Shape      | Example              |
| ----------- | ------------------------ | ----------------- | ----------------- | -------------------- |
| `nn.Unfold` | Extract local patches    | `[B, C, H, W]`    | `[B, C*kH*kW, L]` | Feature extraction   |
| `nn.Fold`   | Reconstruct from patches | `[B, C*kH*kW, L]` | `[B, C, H, W]`    | Image reconstruction |

---



## 8. Convolution as Matrix Multiplication


Instead of sliding a kernel over the image manually, we can:

1. Use `Unfold` to extract all sliding **patches** from the input.
2. **Flatten the kernel** into a row vector.
3. Multiply: **(kernel matrix) × (unfolded input patches)**.

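This im2col-plus-matrix-multiplication approach lets convolution run on highly optimized GEMM (general matrix multiply) routines, at the cost of extra memory, since overlapping pixels are copied into several columns. Variants of it are one of the strategies deep learning libraries use behind the scenes, alongside others such as FFT-based or Winograd convolution.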

---


#### Input Image $\mathbf{X} \in \mathbb{R}^{3 \times 3}$

$$
\mathbf{X} = \begin{bmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{bmatrix}
$$

#### Kernel $\mathbf{K} \in \mathbb{R}^{2 \times 2}$

$$
\mathbf{K} = \begin{bmatrix}
1 & 0 \\
0 & -1
\end{bmatrix}
$$

We'll do a **valid convolution** (no padding), **stride 1**, and treat it as **cross-correlation** (no flipping of the kernel), which is what PyTorch's `nn.Conv2d` computes.


```python
import torch

# Input image: shape (1, 1, 3, 3) (batch, channels, height, width)
img = torch.tensor([[[[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]]]], dtype=torch.float32)

print("Image shape:", img.shape)  # (N=1, C=1, H=3, W=3)

# Kernel: shape (2, 2), no channel dimensions yet
kernel = torch.tensor([[1, 0],
                       [0, -1]], dtype=torch.float32)
```

Output:

```
Image shape: torch.Size([1, 1, 3, 3])
```



---

#### Step 1: Unfold (im2col)

We extract all possible $2 \times 2$ patches from $\mathbf{X}$ and flatten each one, row by row, into a column vector:

#### Sliding Patches:

1. Patch at top-left:
$$
\begin{bmatrix}
1 & 2 \\
4 & 5
\end{bmatrix}
\Rightarrow \begin{bmatrix} 1 \\ 2 \\ 4 \\ 5 \end{bmatrix}
$$

2. Patch at top-right:
$$
\begin{bmatrix}
2 & 3 \\
5 & 6
\end{bmatrix}
\Rightarrow \begin{bmatrix} 2 \\ 3 \\ 5 \\ 6 \end{bmatrix}
$$

3. Patch at bottom-left:
$$
\begin{bmatrix}
4 & 5 \\
7 & 8
\end{bmatrix}
\Rightarrow \begin{bmatrix} 4 \\ 5 \\ 7 \\ 8 \end{bmatrix}
$$

4. Patch at bottom-right:
$$
\begin{bmatrix}
5 & 6 \\
8 & 9
\end{bmatrix}
\Rightarrow \begin{bmatrix} 5 \\ 6 \\ 8 \\ 9 \end{bmatrix}
$$
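The sliding-window extraction can be written as a naive double loop. This sketch assumes a single channel, stride 1 and no padding; it builds the same matrix `nn.Unfold` returns, just far more slowly:

```python
import torch
import torch.nn as nn

img = torch.tensor([[[[1., 2., 3.],
                      [4., 5., 6.],
                      [7., 8., 9.]]]])
kH, kW = 2, 2
H, W = img.shape[-2:]

# Naive im2col: visit windows row by row, flatten each one into a column
columns = []
for i in range(H - kH + 1):                  # window top edge
    for j in range(W - kW + 1):              # window left edge
        window = img[0, 0, i:i + kH, j:j + kW]
        columns.append(window.reshape(-1))   # e.g. [1, 2, 4, 5] for the first window
X_unfold = torch.stack(columns, dim=1)       # [kH*kW, L] = [4, 4], one patch per column

print(torch.equal(X_unfold, nn.Unfold(kernel_size=(kH, kW))(img)[0]))  # True
```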

---

#### Step 2: Build Unfolded Matrix $\mathbf{X}_{\text{unfold}} \in \mathbb{R}^{4 \times 4}$

Each column is a flattened patch (column $j$ is patch $j$ from Step 1):

$$
\mathbf{X}_{\text{unfold}} =
\begin{bmatrix}
1 & 2 & 4 & 5 \\
2 & 3 & 5 & 6 \\
4 & 5 & 7 & 8 \\
5 & 6 & 8 & 9
\end{bmatrix}
$$

For this particular input the matrix happens to be symmetric, so reading it by rows gives the same numbers; in general it is not.

---

```python
unfold = torch.nn.Unfold(kernel_size=(2, 2))  # no padding, stride=1
patches = unfold(img)  # shape: (N, C*kH*kW, L), here (1, 4, 4)
print("Unfolded patches:\n", patches)
```

Output:

```
Unfolded patches:
 tensor([[[1., 2., 4., 5.],
         [2., 3., 5., 6.],
         [4., 5., 7., 8.],
         [5., 6., 8., 9.]]])
```


#### Step 3: Flatten Kernel into Row Vector

$$
\mathbf{k}_{\text{flat}} = \begin{bmatrix} 1 & 0 & 0 & -1 \end{bmatrix}
$$

---

```python
kernel_flat = kernel.view(1, -1)  # shape: (1, 4)
print(kernel_flat)
```

Output:

```
tensor([[ 1.,  0.,  0., -1.]])
```


#### Step 4: Matrix Multiplication

We compute:

$$
\mathbf{y} = \mathbf{k}_{\text{flat}} \, \mathbf{X}_{\text{unfold}}
= \begin{bmatrix} 1 & 0 & 0 & -1 \end{bmatrix}
\begin{bmatrix}
1 & 2 & 4 & 5 \\
2 & 3 & 5 & 6 \\
4 & 5 & 7 & 8 \\
5 & 6 & 8 & 9
\end{bmatrix}
= \begin{bmatrix} 1 - 5 & 2 - 6 & 4 - 8 & 5 - 9 \end{bmatrix}
= \begin{bmatrix} -4 & -4 & -4 & -4 \end{bmatrix}
$$

Each entry is the dot product of the flattened kernel with one column (patch).

```python
out = kernel_flat @ patches  # (1, 4) @ (1, 4, 4) broadcasts to shape (1, 1, 4)
print(out)
```

Output:

```
tensor([[[-4., -4., -4., -4.]]])
```


#### Step 5: Reshape Output

Reshape the 4 values to $2 \times 2$: with a $3 \times 3$ input, a $2 \times 2$ kernel, and stride 1, each spatial axis has $3 - 2 + 1 = 2$ output positions:

$$
\mathbf{Y} = \begin{bmatrix}
-4 & -4 \\
-4 & -4
\end{bmatrix}
$$

---

#### Final Output

$$
\boxed{
\mathbf{Y} = \begin{bmatrix}
-4 & -4 \\
-4 & -4
\end{bmatrix}}
$$

---


```python
out_image = out.view(1, 1, 2, 2)
print("Output image:\n", out_image)
```

Output:

```
Output image:
 tensor([[[[-4., -4.],
          [-4., -4.]]]])
```


| Variable      | Shape                                     |
|---------------|-------------------------------------------|
| `img`         | (1, 1, 3, 3)                              |
| `patches`     | (1, 4, 4) → 4 patches, 4 values per patch |
| `kernel_flat` | (1, 4)                                    |
| `out`         | (1, 1, 4)                                 |
| `out_image`   | (1, 1, 2, 2)                              |
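
As a sanity check, PyTorch's built-in `F.conv2d` (which, like the steps above, computes cross-correlation) should give the same 2×2 output. A minimal sketch, reshaping the kernel to the `[C_out, C_in, kH, kW]` layout `conv2d` expects:

```python
import torch
import torch.nn.functional as F

img = torch.tensor([[[[1., 2., 3.],
                      [4., 5., 6.],
                      [7., 8., 9.]]]])
kernel = torch.tensor([[1., 0.],
                       [0., -1.]])

# conv2d expects weights shaped [C_out, C_in, kH, kW]
reference = F.conv2d(img, kernel.view(1, 1, 2, 2))
print(reference)
# tensor([[[[-4., -4.],
#           [-4., -4.]]]])
```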

---


### `torch.nn.Unfold` in Practice

`Unfold` (also called **im2col**) turns a 2D input image into a set of flattened sliding patches (columns), which makes convolution a **matrix multiplication**.


In PyTorch, **you don't *need* to use `unfold`** for ordinary convolutions, because PyTorch already provides highly optimized convolution operations via `torch.nn.Conv2d`, `F.conv2d`, etc. However, `unfold` is useful when you want to do something **patch-wise**, for example:

- Applying attention over local windows (like in Swin Transformers).
- Doing non-standard convolutions (e.g., dynamic or deformable convolutions).
- Manually implementing **convolutional backpropagation** or other gradient tricks.

---

Assume:
- Input: `x` with shape **[N, C_in=3, H, W]**
- Kernel: `w` with shape **[C_out, C_in=3, kH=3, kW=3]**
- Stride: `s=1`
- Padding: `p=0`
- Dilation: `d=1`

---

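```python
import torch
import torch.nn as nn

N, H, W = 2, 10, 12           # example sizes (any H, W >= 3 work)
x = torch.randn(N, 3, H, W)   # input [N, C_in=3, H, W]
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=0)
out = conv(x)                 # [N, 8, H - 2, W - 2]
```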


**Shape math:**
- Input: `[N, 3, H, W]`
- Kernel: `[8, 3, 3, 3]` (8 filters)
- Output: `[N, 8, H_out, W_out]` where:

$$
H_{out} = \left\lfloor \frac{H + 2p - d(kH - 1) - 1}{s} + 1 \right\rfloor = H - 2
$$

$$
W_{out} = \left\lfloor \frac{W + 2p - d(kW - 1) - 1}{s} + 1 \right\rfloor = W - 2
$$
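The formula is easy to check numerically. A minimal sketch (the helper `conv_out_size` and the sizes are hypothetical, chosen for illustration) compares it against both `nn.Conv2d` and the number of blocks `L` produced by `nn.Unfold` for a few stride/padding/dilation settings:

```python
import torch
import torch.nn as nn

def conv_out_size(size, k, s=1, p=0, d=1):
    """Output length along one spatial axis; // implements the floor for non-negative values."""
    return (size + 2 * p - d * (k - 1) - 1) // s + 1

N, H, W = 1, 11, 9   # illustrative sizes
x = torch.randn(N, 3, H, W)
for s, p, d in [(1, 0, 1), (2, 1, 1), (1, 2, 2)]:
    conv = nn.Conv2d(3, 8, kernel_size=3, stride=s, padding=p, dilation=d)
    H_out, W_out = conv_out_size(H, 3, s, p, d), conv_out_size(W, 3, s, p, d)
    assert conv(x).shape == (N, 8, H_out, W_out)
    # Unfold uses the same formula for its number of blocks L
    patches = nn.Unfold(kernel_size=3, stride=s, padding=p, dilation=d)(x)
    assert patches.shape == (N, 3 * 3 * 3, H_out * W_out)

print(conv_out_size(H, 3), conv_out_size(W, 3))  # 9 7
```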

---


**Simulating Conv2d with `Unfold` and `Fold`**


```python
unfold = nn.Unfold(kernel_size=(3, 3), stride=1, padding=0)
patches = unfold(x)  # [N, C_in * kH * kW, L], where L = H_out * W_out
```

**Resulting shape:**
- `patches`: `[N, 3*3*3 = 27, H_out * W_out]`

Each of the `L` columns is one 3×3 window taken across all 3 input channels and flattened to 27 values.

---

**Multiply with weights**

Let's say weights `w` (e.g. `conv.weight`) have shape `[8, 3, 3, 3]`. We flatten each kernel:

```python
w_flat = w.view(8, -1)  # [8, 27]
```

We now do matrix multiplication for each image in the batch:

```python
# patches: [N, 27, L]
# w_flat:  [8, 27]
out_unfold = torch.einsum('nkl,ok->nol', patches, w_flat)  # [N, 8, L]
```

Or simply:
```python
out_unfold = w_flat @ patches  # [8, 27] @ [N, 27, L] broadcasts to [N, 8, L]
```

To match `nn.Conv2d` exactly, also add its bias (on by default): `out_unfold + conv.bias.view(1, -1, 1)`.

---

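**Reshape back to an image**

Each of the `L` columns already corresponds to one output pixel, so a plain reshape is enough:
```python
H_out, W_out = H - 2, W - 2                 # from the shape math above
out = out_unfold.view(N, 8, H_out, W_out)
# Equivalent but unnecessary here: Fold with a 1x1 kernel drops each column at its pixel
out_fold = nn.Fold(output_size=(H_out, W_out), kernel_size=(1, 1))(out_unfold)
```

`Fold` is typically used instead when you've done an operation **in patch space** (with a kernel larger than 1×1) and need to sum the patches back into an image.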

---


| Operation | Shape | Meaning |
|----------|-------------------------|-----------|
| `nn.Conv2d` | `[N, C_out, H_out, W_out]` | Implicit (built-in) |
| `Unfold(x)` | `[N, C_in * kH * kW, L]` | All patches as columns |
| Weight flattening | `[C_out, C_in * kH * kW]` | Filters as rows |
| Multiply (+ bias) | `[N, C_out, L]` | Output in flattened spatial domain |
| Reshape | `[N, C_out, H_out, W_out]` | Final output |
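
Putting the pieces of this section together, the following sketch reproduces `nn.Conv2d` end to end. The sizes `N`, `H`, `W` are arbitrary, and note the bias term, which the weight-only matmul omits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, H, W = 2, 10, 12                        # illustrative sizes
x = torch.randn(N, 3, H, W)
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=0)

with torch.no_grad():                      # no autograd graph needed for this check
    H_out, W_out = H - 2, W - 2
    patches = nn.Unfold(kernel_size=(3, 3))(x)         # [N, 27, L], L = H_out * W_out
    w_flat = conv.weight.view(8, -1)                   # [8, 27]
    out_unfold = w_flat @ patches                      # [N, 8, L]
    out_unfold = out_unfold + conv.bias.view(1, 8, 1)  # one bias per output channel
    out = out_unfold.view(N, 8, H_out, W_out)
    print(torch.allclose(out, conv(x), atol=1e-5))     # True
```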

---
