## 1. Motivation

In convolutional neural networks (CNNs), each convolution can be thought of as:

* Taking **local patches** (receptive fields) from the input,
* Flattening them,
* Multiplying by the filter weights (matrix multiplication),
* Then reshaping the result back to the output feature map.

The operations **`torch.nn.Unfold`** and **`torch.nn.Fold`** allow you to explicitly perform those “patch extraction” and “reconstruction” steps.

✅ In short:

* `unfold` = **extract patches**
* `fold` = **reconstruct image from patches**

---

## 2. `torch.nn.Unfold`

### Concept

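`Unfold` takes a 4D tensor
$$x \in \mathbb{R}^{B \times C \times H \times W}$$
and returns a 3D tensor of shape $(B,\ C \cdot k_h \cdot k_w,\ L)$, where each column is one **flattened sliding local block** and $L$ is the number of blocks.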

### Example

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)
print(x)
```

```
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
```

Let’s extract 2×2 patches with stride 2:

```python
unfold = nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(x)
print(patches.shape)
print(patches)
```

Output:

```
torch.Size([1, 4, 4])
tensor([[[ 1.,  3.,  9., 11.],
         [ 2.,  4., 10., 12.],
         [ 5.,  7., 13., 15.],
         [ 6.,  8., 14., 16.]]])
```

Let’s interpret this:

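* There are **4 patches**: a 2×2 window with stride 2 fits at 2 positions along each axis of the 4×4 input, so $L = 2 \times 2 = 4$.
* Each patch has **4 elements** (1 channel × 2×2).
* The shape is `[B, C×kernel_height×kernel_width, L]` = `[1, 4, 4]`: each **column** is one flattened patch, and each **row** holds one kernel position (top-left, top-right, bottom-left, bottom-right) across all patches.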
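For other settings, the number of blocks $L$ is the product over spatial dimensions of $\lfloor (\text{size} + 2\,\text{padding} - \text{dilation}\,(k - 1) - 1) / \text{stride} + 1 \rfloor$, as given in the `nn.Unfold` documentation. A small sketch that checks this for one arbitrary configuration (the sizes and the helper `num_blocks` are made up for illustration):

```python
import torch
import torch.nn as nn

def num_blocks(size, kernel, stride=1, padding=0, dilation=1):
    # Window positions along one spatial dimension (formula from the nn.Unfold docs)
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Arbitrary example configuration: B=2, C=3, H=7, W=9
inp = torch.randn(2, 3, 7, 9)
k, s, p, d = 3, 2, 1, 1
unf = nn.Unfold(kernel_size=k, stride=s, padding=p, dilation=d)
out = unf(inp)

L = num_blocks(7, k, s, p, d) * num_blocks(9, k, s, p, d)  # 4 * 5 = 20
print(out.shape)  # torch.Size([2, 27, 20]): 27 = C*k*k = 3*3*3
```

For the 4×4 example above, `num_blocks(4, 2, 2)` is 2 per axis, giving the 4 patches.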

If you transpose:

```python
patches_T = patches.squeeze(0).T
print(patches_T)
```

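each row is now one flattened patch, in row-major order of the window positions (top-left, top-right, bottom-left, bottom-right):

```
tensor([[ 1.,  2.,  5.,  6.],
        [ 3.,  4.,  7.,  8.],
        [ 9., 10., 13., 14.],
        [11., 12., 15., 16.]])
```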
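The same matrix can be rebuilt with plain slicing, which makes the column ordering explicit. A minimal sketch for the 2×2, stride-2 example above (the loop is for illustration only; it is not how PyTorch computes `Unfold`):

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)
patches = nn.Unfold(kernel_size=(2, 2), stride=2)(x)  # [1, 4, 4]

# Scan window positions row by row, flatten each 2×2 window (channel, row, col order),
# and store it as one column.
k, s = 2, 2
cols = []
for i in range(0, 4 - k + 1, s):        # top edge of the window
    for j in range(0, 4 - k + 1, s):    # left edge of the window
        cols.append(x[0, :, i:i + k, j:j + k].reshape(-1))
manual = torch.stack(cols, dim=1).unsqueeze(0)  # [1, C*k*k, L]

print(torch.equal(manual, patches))  # True
```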

---

## 3. `torch.nn.Fold`

### Concept

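`Fold` goes the other way: it places each column back at its spatial position and **sums** values that land on the same pixel.
It is therefore an exact inverse of `Unfold` only when the patches neither overlap nor leave gaps; in general it is the adjoint (transpose) of `Unfold`.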

```python
fold = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
reconstructed = fold(patches)
print(reconstructed)
```

Result:

```
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])
```

✅ Perfect reconstruction, because the patches neither overlap nor leave gaps (together they tile the whole 4×4 input).

If there **is overlap** (e.g., stride < kernel_size), overlapping pixels will be **summed** during `fold`.

To fix that, you can normalize by the overlap count (explained below).
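The summing is easy to observe with a stride-1 version of the same example (the names `unfold_s1`/`fold_s1` exist only in this sketch). Folding unfolded ones counts how many windows cover each pixel, and folding unfolded `x` returns `x` scaled by exactly that count:

```python
import torch
import torch.nn as nn

x = torch.arange(1, 17, dtype=torch.float32).view(1, 1, 4, 4)

# 2×2 windows with stride 1 overlap: 3×3 = 9 windows over the 4×4 input
unfold_s1 = nn.Unfold(kernel_size=(2, 2), stride=1)
fold_s1 = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)

counts = fold_s1(unfold_s1(torch.ones_like(x)))  # windows covering each pixel
print(counts[0, 0])
# tensor([[1., 2., 2., 1.],
#         [2., 4., 4., 2.],
#         [2., 4., 4., 2.],
#         [1., 2., 2., 1.]])

summed = fold_s1(unfold_s1(x))
print(torch.equal(summed, x * counts))  # True: each pixel is added once per covering window
```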

---

## 4. Handling Overlaps (stride < kernel_size)

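When patches overlap, `Fold` adds their contributions together.
Passing a tensor of ones through the same `Unfold` → `Fold` pair gives a **normalization mask** that counts how many patches cover each pixel. With a 2×2 kernel and stride 1 (the stride-2 `unfold`/`fold` above never overlap):

```python
# Overlapping windows: 2×2 kernel, stride 1
unfold_ov = nn.Unfold(kernel_size=(2, 2), stride=1)
fold_ov = nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)

summed = fold_ov(unfold_ov(x))                   # overlapping pixels are summed
counts = fold_ov(unfold_ov(torch.ones_like(x)))  # how many windows cover each pixel
reconstructed = summed / counts                  # sums → averages
print(torch.allclose(reconstructed, x))          # True
```

Dividing by the mask turns the sums into averages, so overlapping regions are averaged correctly (and, since the patches were not modified here, `x` is recovered exactly). If some pixels are covered by no patch at all, their count is 0; clamp it first, e.g. `counts.clamp(min=1)`.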

---

## 5. Matrix Multiplication View of Convolution

A CNN convolution can be written as a single matrix multiplication:

$$
Y = W \times X_{\text{unfolded}}
$$

where:

* $X_{\text{unfolded}}$ = patches extracted by `Unfold` → shape `[B, C×k_h×k_w, L]`
* $W$ = filter weights flattened to 2D → shape `[out_channels, C×k_h×k_w]`
* Result $Y$ → shape `[B, out_channels, L]`, which is reshaped (a plain `view`; no `Fold` is needed) to `[B, out_channels, H_out, W_out]`

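So a convolution (bias omitted) is basically:

```python
Y = W @ unfold(X)                          # [out_ch, C*kh*kw] @ [B, C*kh*kw, L] → [B, out_channels, L]
Y = Y.view(B, out_channels, H_out, W_out)  # plain reshape to the output feature map
```

where `W` is the layer's weight flattened to 2D (`conv.weight.view(out_channels, -1)`) and `@` broadcasts it over the batch dimension.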
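A sketch that checks this matrix view against `nn.Conv2d` for one arbitrary configuration (sizes, seed, stride and padding are chosen only for illustration). `nn.Conv2d` computes a cross-correlation, which is exactly what unfold + matmul computes, so no kernel flip is needed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 8, 8
out_channels, k, stride, padding = 5, 3, 2, 1

inp = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, out_channels, kernel_size=k, stride=stride, padding=padding)

# 1) Extract patches: [B, C*k*k, L]
cols = F.unfold(inp, kernel_size=k, stride=stride, padding=padding)

# 2) Flatten each filter into one row: [out_channels, C*k*k]
W_mat = conv.weight.view(out_channels, -1)

# 3) One matmul (broadcast over the batch) plus the bias: [B, out_channels, L]
Y = W_mat @ cols + conv.bias.view(1, -1, 1)

# 4) Plain reshape back to a feature map (dilation 1 assumed)
H_out = (H + 2 * padding - k) // stride + 1
W_out = (W + 2 * padding - k) // stride + 1
Y = Y.view(B, out_channels, H_out, W_out)

print(torch.allclose(Y, conv(inp), atol=1e-5))  # True
```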

---

## 6. Typical Use Cases

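✅ **ViT or Patch-based Models**
The tensor method `Tensor.unfold(dimension, size, step)` (not the `nn.Unfold` module) is often used to cut an image into non-overlapping patch tokens:

```python
# Slide a patch_size window with step patch_size along H (dim 2), then along W (dim 3)
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # [B, C, H/p, W/p, p, p]
patches = patches.contiguous().view(B, C, -1, patch_size * patch_size)           # [B, C, N, p*p]
```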
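For comparison, `nn.Unfold` with `kernel_size == stride` produces the same patches directly. The sketch below (the sizes and the 64-dim embedding are arbitrary choices) checks that both routes agree once the tensor-method result is permuted into the `[B, N, C·p·p]` layout typically fed to a transformer, then applies a linear patch embedding:

```python
import torch
import torch.nn as nn

B, C, H, W, patch_size = 2, 3, 32, 32, 8
img = torch.randn(B, C, H, W)

# Route 1: nn.Unfold with kernel_size == stride gives non-overlapping patches
tokens = nn.Unfold(kernel_size=patch_size, stride=patch_size)(img)  # [B, C*p*p, N]
tokens = tokens.transpose(1, 2)                                      # [B, N, C*p*p]

# Route 2: Tensor.unfold, then move channels next to the patch pixels
t = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # [B, C, H/p, W/p, p, p]
t = t.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * patch_size * patch_size)  # [B, N, C*p*p]

print(tokens.shape)            # torch.Size([2, 16, 192])
print(torch.equal(tokens, t))  # True

# Linear projection of flattened patches (embedding size 64 is an arbitrary choice)
embed = nn.Linear(C * patch_size * patch_size, 64)
print(embed(tokens).shape)     # torch.Size([2, 16, 64])
```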

✅ **Image Reconstruction / Super-resolution**
`fold` reassembles processed, possibly overlapping, patches into a full image; dividing by the overlap count from Section 4 blends the overlaps by averaging.

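✅ **Implementing Custom Convolutions**
You can explicitly compute a convolution with `unfold`, a matrix multiplication, and a reshape (Section 5); `fold` handles the reverse, scatter-and-sum direction, as in a transposed convolution.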
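To see where `fold` genuinely belongs, consider a transposed convolution: each input pixel produces one column of `C_out·k·k` values, and `fold` scatter-adds those columns into the output (the step often called col2im). A sketch assuming no padding, `output_padding=0`, and dilation 1, checked against `F.conv_transpose2d` (all sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, k, stride = 2, 4, 3, 3, 2
H_in, W_in = 5, 6
y = torch.randn(B, C_in, H_in, W_in)
weight = torch.randn(C_in, C_out, k, k)  # conv_transpose2d weight layout: [C_in, C_out, kH, kW]

# Output size for padding=0, output_padding=0, dilation=1
H_out = (H_in - 1) * stride + k
W_out = (W_in - 1) * stride + k

# 1) Every input pixel produces one column of C_out*k*k values
cols = weight.view(C_in, -1).t() @ y.view(B, C_in, -1)  # [B, C_out*k*k, H_in*W_in]

# 2) Fold places each column in its k×k window and sums the overlaps
out = F.fold(cols, output_size=(H_out, W_out), kernel_size=k, stride=stride)

ref = F.conv_transpose2d(y, weight, stride=stride)
print(out.shape)                            # torch.Size([2, 3, 11, 13])
print(torch.allclose(out, ref, atol=1e-5))  # True
```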

---

## 7. Summary Table

| Operation   | Purpose                  | Input Shape       | Output Shape      | Typical use          |
| ----------- | ------------------------ | ----------------- | ----------------- | -------------------- |
| `nn.Unfold` | Extract local patches    | `[B, C, H, W]`    | `[B, C*kH*kW, L]` | Feature extraction   |
| `nn.Fold`   | Reconstruct from patches | `[B, C*kH*kW, L]` | `[B, C, H, W]`    | Image reconstruction |

