## Converting Images to Patches

### Single image

In [22]:
import torch
from torch import nn

In [8]:
img_1=torch.randn((32,32,3)) # H,W,C
H,W,C=img_1.shape
patch_size=8
N=int((H*W)/(patch_size**2))
print("N:",N)
# H,W,C -> N,P^2,C
# 32,32,3 -> 16,8^2,3 -> 16 patches, each of size 8 x 8 x 3
img1_patch=img_1.view(N,patch_size,patch_size,C)
print(img1_patch.shape)

N: 16
torch.Size([16, 8, 8, 3])


#### Why is using `view` incorrect here?

**What you want to do**
- You have an image of 32,32,3 (HWC) and wish to split it into **non-overlapping square patches** of size 8x8 where each patch *preserves a spatial neighborhood*
- So, after patching you want 16 patches of shape 8x8x3  
  
**What’s wrong with view(N, patch_size, patch_size, C):**
- `img1_patch = img_1.view(N, patch_size, patch_size, C)`  
- This **does not extract actual patches from the image**
- Instead, `view()` only reinterprets the tensor's existing memory layout under a new shape, **without regard for spatial structure**
- This means:
    - It treats the whole image as one long 1D stream of numbers and chunks it *blindly* into blocks of shape 8,8,3
    - This ignores where each pixel is actually located in the image
    - Each block of 8x8x3 = 192 values is just the next 64 pixels in memory, which for a 32-wide HWC image is two full rows of the image, **not a spatially connected region** such as the top-left or top-right 8x8 square
- It's like cutting a photo into rectangles **without caring where the cuts go**
  
`view` = **reshape blindly**
- “Hey, take this long row of numbers in memory and chunk it up into a new shape, as long as the total number of elements stays the same.”
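A quick way to see this is to fill the image with its own flat index, so every value can be traced back to its position. This is a minimal sketch; the names `view_patches`, `true_patch` and `first_two_rows` are illustrative and not part of the original notebook.

```python
import torch

H, W, C, P = 32, 32, 3, 8
# Fill the image with its own flat index so every value is traceable
img = torch.arange(H * W * C).view(H, W, C)

# The "blind" reshape from the cell above
view_patches = img.view(-1, P, P, C)

# The real top-left 8x8 patch, taken by slicing
true_patch = img[:P, :P, :]

# The first 192 values in memory are the first two full rows of the image
first_two_rows = img[:2].reshape(P, P, C)

print(torch.equal(view_patches[0], true_patch))      # False
print(torch.equal(view_patches[0], first_two_rows))  # True
```

So the first `view` "patch" is two image rows restacked into an 8x8 grid, not the top-left square.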

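#### Use `Tensor.unfold`
- Sliding window extractor (the cells below call the tensor method `unfold(dimension, size, step)`; `torch.nn.Unfold` is a related but separate module with a different call signature)
- Slices a tensor along one dimension into **overlapping or non-overlapping** chunks, depending on the step, and stacks the window contents into a new trailing dim
- It keeps **spatial groups** intact.
- Unlike applying `.view()` or `.reshape()` directly to the image, which only re-chunks memory, `unfold` keeps each patch's pixels together, which is what ViT's linear projection needs.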
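To make the role of the step argument concrete, here is a small sketch on a 1D tensor. The variable names are illustrative only.

```python
import torch

x = torch.arange(8)  # values 0..7

# unfold(dimension, size, step): windows of length `size`,
# whose start positions are `step` apart
non_overlapping = x.unfold(0, 4, 4)  # step == size
overlapping = x.unfold(0, 4, 2)      # step < size

print(non_overlapping.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(overlapping.tolist())      # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]
```

For ViT-style patching the step equals the patch size, so the windows tile the image without overlap.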


In [11]:
img_1=torch.randn((3,32,32)) # C,H,W
C,H,W=img_1.shape
patch_size=8
patch_1=img_1.unfold(1,patch_size,patch_size).unfold(2,patch_size,patch_size)
patch_1.shape # (C, num_patches_H, num_patches_W, patch_size, patch_size)


torch.Size([3, 4, 4, 8, 8])

`img_1.unfold(1,patch_size,patch_size).unfold(2,patch_size,patch_size)`
- Imagine you're cutting brownies in a tray
- First, you make a **horizontal cut** (height)
- Then you make a **vertical cut** (width)
- Now you've got perfect little squares (patches)
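As a sanity check on the two cuts, each entry of the unfolded tensor can be compared against the same block taken by plain slicing. This is a sketch; the indices `i`, `j` and the name `block` are arbitrary choices for illustration.

```python
import torch

C, H, W, P = 3, 32, 32, 8
img = torch.randn(C, H, W)

# Cut along height (dim 1), then along width (dim 2)
patches = img.unfold(1, P, P).unfold(2, P, P)  # (C, 4, 4, P, P)

# patches[:, i, j] should be the block at row-block i, column-block j
i, j = 1, 2
block = img[:, i * P:(i + 1) * P, j * P:(j + 1) * P]

print(patches.shape)                         # torch.Size([3, 4, 4, 8, 8])
print(torch.equal(patches[:, i, j], block))  # True
```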

In [None]:
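# Next
# (C, num_patches_H, num_patches_W, patch_size, patch_size) -> (num_patches, C, patch_size, patch_size)
# Can I use view(-1,192)? No: the axes must be reordered first, see "permute vs view" below
patch_1=patch_1.permute(1,2,0,3,4) # (num_patches_H, num_patches_W, C, patch_size, patch_size)
patch_1=patch_1.reshape(-1,C,patch_size,patch_size) # (num_patches, C, patch_size, patch_size)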

**permute vs view**  

Let's say we have a book that has:
- 2 chapters
- Each chapter has 3 pages
- Each page has 4 lines   
`x = torch.randn(2, 3, 4)  # (chapters, pages, lines)`  
`permute`
- rotate the dimensions around
- No reshaping of data - just reorder how you see the tensor in terms of axis
- permute(1,0,2)
    - *I wanna reorder this book so I see all the pages first, grouped by chapter*
    - You haven't changed the number of pages or lines
    - Just **how you organize** your view of the book
  
`view`
- Changes the **shape** without changing the order
- "Hey PyTorch, treat the same memory as a new shape - don't shuffle anything"
- view(6,4)
    - *Forget chapters and pages, just give me a list of 6 flat sections, each with 4 lines.*
    - **You’re flattening or reshaping, but not caring about what those lines originally belonged to.**
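The book analogy can be run directly. In this sketch the tensor is filled with `arange` instead of random values so each line is identifiable; the names `by_page` and `flat` are illustrative.

```python
import torch

x = torch.arange(24).view(2, 3, 4)  # (chapters, pages, lines)

# permute reorders the axes: element [c, p, l] is now found at [p, c, l]
by_page = x.permute(1, 0, 2)
print(by_page.shape)                               # torch.Size([3, 2, 4])
print(by_page[2, 1].tolist() == x[1, 2].tolist())  # True

# view keeps the memory order and merges chapters and pages into 6 rows
flat = x.view(6, 4)
print(flat[3].tolist())  # [12, 13, 14, 15], i.e. chapter 1, page 0
```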



> In the context of ViT, we **deeply care about the structure and order** of how patches are extracted and processed
- Flattening the unfolded tensor without reordering its axes first would not respect the spatial grouping of each patch
- `permute` followed by a flatten ensures each row ends up representing one patch, fully and correctly ordered
- Use `permute` because
    - It reorders the axes so that each patch's channels and pixels sit next to each other in index order
    - So that each patch stays intact when you finally flatten (the `reshape` that follows copies the data into that order)
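The difference shows up when the permute step is skipped. This sketch uses an `arange` image so the comparison is deterministic; `good`, `bad` and `top_left` are illustrative names.

```python
import torch

C, H, W, P = 3, 32, 32, 8
img = torch.arange(C * H * W, dtype=torch.float32).view(C, H, W)
patches = img.unfold(1, P, P).unfold(2, P, P)  # (C, 4, 4, P, P)

# With permute: patch-grid axes first, then flatten each patch
good = patches.permute(1, 2, 0, 3, 4).reshape(-1, C * P * P)

# Without permute: same shape, but each row stitches together three
# single-channel 8x8 blocks taken from different patch locations
bad = patches.reshape(-1, C * P * P)

# Row 0 should be the top-left patch across all channels, flattened
top_left = img[:, :P, :P].reshape(-1)
print(good.shape, bad.shape)           # both torch.Size([16, 192])
print(torch.equal(good[0], top_left))  # True
print(torch.equal(bad[0], top_left))   # False
```

Both results have the expected shape, so the shape alone does not tell you whether the patches are correct.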

In [12]:
img_1=torch.randn((3,32,32)) # C,H,W
C,H,W=img_1.shape
patch_size=8
patch_1=img_1.unfold(1,patch_size,patch_size).unfold(2,patch_size,patch_size)
patch_1.shape # (C, num_patches_H, num_patches_W, patch_size, patch_size)

patch_1=patch_1.permute(1,2,0,3,4).reshape(-1,patch_size*patch_size*C)
print(patch_1.shape)

torch.Size([16, 192])


### Multiple Images

In [20]:
imgs=torch.randn((10,3,32,32)) # (batch_size,C,H,W)
B,C,H,W=imgs.shape
patch_size=8
patches=imgs.unfold(2,patch_size,patch_size).unfold(3,patch_size,patch_size) 
print(patches.shape) #(batch_size,C,num_patches_H,num_patches_W,patch_size,patch_size)
patches=patches.permute(0,2,3,1,4,5).reshape(B,-1,patch_size*patch_size*C)
print(patches.shape)

torch.Size([10, 3, 4, 4, 8, 8])
torch.Size([10, 16, 192])
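For batched input, the module `torch.nn.Unfold` offers an alternative route. The sketch below compares it with the manual `unfold` + `permute` + `reshape` route from the cell above, assuming non-overlapping patches (stride equal to the patch size); the names `manual` and `via_module` are illustrative.

```python
import torch
from torch import nn

B, C, H, W, P = 10, 3, 32, 32, 8
imgs = torch.randn(B, C, H, W)

# Manual route, as in the cell above
manual = imgs.unfold(2, P, P).unfold(3, P, P)
manual = manual.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)

# nn.Unfold returns (B, C*P*P, num_patches);
# transpose to get (B, num_patches, C*P*P)
via_module = nn.Unfold(kernel_size=P, stride=P)(imgs).transpose(1, 2)

print(via_module.shape)                 # torch.Size([10, 16, 192])
print(torch.equal(manual, via_module))  # True
```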


## CLS token

### Single image

In [30]:
img_1=torch.randn((16,192)) 
D=int(192*4)
linear=nn.Linear(in_features=img_1.shape[-1],out_features=D)
img_1=linear(img_1)
img_1.shape #16,768
cls=nn.Parameter(torch.randn((1,D))) #1,768
print(cls.shape)
input_1=torch.cat((cls,img_1),dim=0)
input_1.shape

torch.Size([1, 768])


torch.Size([17, 768])

### Multiple Images

In [36]:
imgs=torch.randn((10,16,192))
D=int(192*4)

linear=nn.Linear(in_features=imgs.shape[-1],out_features=D)
imgs=linear(imgs)
imgs.shape #10,16,768
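cls=nn.Parameter(torch.randn((1,1,D))) #1,1,768
#one cls token per image (expand shares the same learnable values across the batch)
cls_tokens=cls.expand(imgs.shape[0],-1,-1) #10,1,768
inputs=torch.cat((cls,imgs),dim=1) if imgs.shape[0]==1 else torch.cat((cls_tokens,imgs),dim=1)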
inputs.shape

torch.Size([10, 17, 768])
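In a model, the projection and the CLS token would normally live inside a module so that the token is registered as a learnable parameter. The following is a minimal sketch under that assumption; the class name `PatchEmbedWithCLS` and its argument names are hypothetical and not taken from the notebook.

```python
import torch
from torch import nn

class PatchEmbedWithCLS(nn.Module):  # hypothetical name, for illustration
    def __init__(self, patch_dim=192, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(patch_dim, embed_dim)
        # One learnable token, shared by every image in the batch
        self.cls = nn.Parameter(torch.randn(1, 1, embed_dim))

    def forward(self, patches):                    # (B, N, patch_dim)
        x = self.proj(patches)                     # (B, N, embed_dim)
        cls = self.cls.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        return torch.cat((cls, x), dim=1)          # (B, N + 1, embed_dim)

model = PatchEmbedWithCLS()
out = model(torch.randn(10, 16, 192))
print(out.shape)  # torch.Size([10, 17, 768])
```

Keeping `self.cls` separate from its expanded copy means the optimizer still sees a single `(1, 1, embed_dim)` parameter.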

## Extracting the CLS token