## **To get the total number of patches in a 3x3 image, divide the image side by the patch side (here 1x1) and square the result**

(3/1)**2 = 9

If the **image** and the **patches** are both square and the patch size divides the image size evenly, the total number of patches is:

\[
\text{Total patches} = \left( \frac{\text{Image size}}{\text{Patch size}} \right)^2
\]

For your example:

\[
\text{Total patches} = \left( \frac{3}{1} \right)^2 = 3^2 = 9
\]

This shortcut works because the patches are square and cover the image without overlap or remainder.

In a Vision Transformer (ViT), the total number of patches in an image is calculated by dividing each image dimension by the corresponding patch dimension and multiplying the two results.

For your example:

- **Image size**: \( 3 \times 3 \)
- **Patch size**: \( 1 \times 1 \)

To calculate the total number of patches:

1. **Divide the image dimensions by the patch dimensions**:
   \[
   \text{Number of patches in width (W)} = \frac{\text{Image width}}{\text{Patch width}} = \frac{3}{1} = 3
   \]
   \[
   \text{Number of patches in height (H)} = \frac{\text{Image height}}{\text{Patch height}} = \frac{3}{1} = 3
   \]

2. **Multiply the number of patches along each dimension**:
   \[
   \text{Total number of patches} = \text{W} \times \text{H} = 3 \times 3 = 9
   \]

### Explanation
Each \( 1 \times 1 \) patch is a distinct region of the \( 3 \times 3 \) image. Since the image is perfectly divisible by the patch size, there are \( 9 \) patches in total.
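
The same calculation as a small helper, as a minimal sketch. The name `count_patches` is made up for illustration, and raising an error when the patch size does not divide the image size is an assumption (padding or cropping the image would be another option).

```python
def count_patches(img_h: int, img_w: int, patch_h: int, patch_w: int) -> int:
    """Return how many non-overlapping patches tile an image exactly."""
    if img_h % patch_h != 0 or img_w % patch_w != 0:
        # Assumption: reject sizes that leave a remainder instead of padding.
        raise ValueError("patch size must divide the image size evenly")
    patches_along_height = img_h // patch_h
    patches_along_width = img_w // patch_w
    return patches_along_height * patches_along_width


print(count_patches(3, 3, 1, 1))        # 9
print(count_patches(28, 28, 4, 4))      # 49
print(count_patches(224, 224, 16, 16))  # 196
```

For square images and square patches this reduces to `(img_size // patch_size) ** 2`.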

In [1]:
import torch
import torch.nn as nn

In [29]:
LEARNING_RATE = 1e-4
NUM_CLASSES = 10
PATCH_SIZE = 4
IMG_SIZE = 28
IN_CHANNELS = 1
NUM_HEADS = 8
DROPOUT = 0.001
HIDDEN_DIM = 768  # mlp head dimension
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.99)
ACTIVATION = "gelu"
NUM_ENCODERS = 4
EMBED_DIM = PATCH_SIZE * PATCH_SIZE * IN_CHANNELS  # 16 = patch width * patch height * channels
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2  # 49


device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'



![alt text](<vit _01.png>)



The CLS token has a positional embedding too.


In [32]:
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()

        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )

        self.cls_token = nn.Parameter(
            torch.randn(size=(1, 1, embed_dim)), requires_grad=True
        )  # (1, 1, embed_dim): a single learnable token, expanded over the batch in forward()
        self.position_embeddings = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )  # +1 because the CLS token is an extra token in the sequence; every token (CLS and each patch) has its own positional embedding
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(
            x.shape[0], -1, -1
        )  # -1 means keep that dimension unchanged
        print(f"{cls_token.shape=}")
        x = self.patcher(x).permute(0, 2, 1)  # (B, embed_dim, N) -> (B, N, embed_dim)
        # prepend the CLS token to the patch sequence
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)
        return x


model = PatchEmbedding(
    embed_dim=EMBED_DIM,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    dropout=DROPOUT,
    in_channels=IN_CHANNELS,
).to(device)

x = torch.randn(512, 1, 28, 28)  # dummy input # b, c, h, w

model(x).shape  # (batch, NUM_PATCHES + 1 for CLS = 50, EMBED_DIM = 16)

cls_token.shape=torch.Size([512, 1, 16])


torch.Size([512, 50, 16])

The **[CLS] token** in Vision Transformers (ViT) is a special learnable embedding introduced to aggregate information across all patches for tasks like classification. Here's how it works:

### Purpose of the [CLS] Token
- The **[CLS] token** stands for "classification token."
- It acts as a placeholder for the global representation of the image.
- After the transformer layers process the input patches, the final state of the **[CLS] token** serves as the input to a classification head for downstream tasks (e.g., predicting image labels).

---

### Process in ViT
1. **Input Patches**:
   - The image is divided into fixed-size patches, flattened, and embedded into a sequence of vectors (patch embeddings).
   - Positional embeddings are added to retain spatial information.

2. **[CLS] Token Initialization**:
   - A learnable vector (the [CLS] token) is prepended to the sequence of patch embeddings.
   - The input to the transformer becomes: \([ \text{[CLS]}, P_1, P_2, \dots, P_N ]\), where \(P_i\) are the patch embeddings.

3. **Transformer Processing**:
   - The sequence, including the [CLS] token, is passed through multiple transformer layers.
   - Each layer updates the representation of the [CLS] token based on interactions with the patch embeddings.

4. **Output**:
   - After the final transformer layer, the [CLS] token contains a global representation of the image.
   - This representation is passed to a **classification head** (typically an MLP) for the final prediction.
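
Steps 1 and 2 can be sketched by hand, without `nn.Conv2d`, to see where each tensor shape comes from. This is only an illustration: the sizes mirror the notebook constants, the zero-initialised `cls_token` and `position_embeddings` are stand-ins for learnable parameters, and the `nn.Linear` projection is an assumed equivalent of the strided convolution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, channels, img_size, patch_size = 2, 1, 28, 4
grid = img_size // patch_size                    # 7 patches per side
num_patches = grid * grid                        # 49
patch_dim = channels * patch_size * patch_size   # 16 values per flattened patch
embed_dim = 16                                   # same as EMBED_DIM in the notebook

images = torch.randn(batch, channels, img_size, img_size)

# Step 1: cut the image into non-overlapping patches and flatten each one.
patches = images.reshape(batch, channels, grid, patch_size, grid, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5)      # (B, grid, grid, C, p, p)
patches = patches.reshape(batch, num_patches, patch_dim)

# Project every flattened patch to the embedding dimension.
projection = nn.Linear(patch_dim, embed_dim)
patch_embeddings = projection(patches)           # (B, 49, 16)

# Step 2: prepend one CLS token per image (a stand-in for nn.Parameter).
cls_token = torch.zeros(1, 1, embed_dim)
tokens = torch.cat([cls_token.expand(batch, -1, -1), patch_embeddings], dim=1)

# Add one positional embedding per token, CLS included.
position_embeddings = torch.zeros(1, num_patches + 1, embed_dim)
tokens = tokens + position_embeddings

print(tokens.shape)  # torch.Size([2, 50, 16])
```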

---

### Why Use the [CLS] Token?
- **Global Aggregation**: The [CLS] token aggregates information from all patches, acting as a summary of the image.
- **Simplicity**: Using a single token for classification avoids the need for pooling operations like average or max pooling.
- **Flexibility**: The same mechanism can be adapted for tasks other than classification by modifying the downstream head.
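
To make the pooling comparison concrete, here is a rough sketch of the two readouts side by side. The `encoded` tensor is random data standing in for the encoder output, and the sizes (197 tokens, 768 dimensions) are assumed for illustration.

```python
import torch

# Stand-in for the encoder output: (batch, 1 + num_patches, embed_dim).
encoded = torch.randn(8, 197, 768)

# CLS readout: keep only token 0.
cls_representation = encoded[:, 0, :]

# Alternative readout: average over the patch tokens (global average pooling).
mean_representation = encoded[:, 1:, :].mean(dim=1)

print(cls_representation.shape)   # torch.Size([8, 768])
print(mean_representation.shape)  # torch.Size([8, 768])
```

Both give one vector per image, so the classification head that follows has the same input size either way.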

---

### Example Workflow in ViT
1. **Input Image**: \( 224 \times 224 \) image divided into \( 16 \times 16 \) patches → \( 14 \times 14 = 196 \) patches.
2. **Sequence**: [CLS] + 196 patch embeddings → \( 197 \) tokens.
3. **Output**: Final state of [CLS] → passed to the classification head.

The [CLS] token is central to the ViT architecture, enabling efficient global understanding of the input image.
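
A quick shape check of the workflow above, as a sketch. The 3 input channels and the 768-dimensional embedding are assumed values, and the zero `cls_token` is a stand-in for a learnable parameter.

```python
import torch
import torch.nn as nn

img_size, patch_size = 224, 16
in_channels, embed_dim = 3, 768   # assumed values for illustration

# A convolution with kernel_size == stride == patch_size embeds each patch once.
patcher = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, in_channels, img_size, img_size)
feature_map = patcher(image)                            # one vector per patch on a 14 x 14 grid
patch_tokens = feature_map.flatten(2).transpose(1, 2)   # (1, 196, 768)

cls_token = torch.zeros(1, 1, embed_dim)                # stand-in for nn.Parameter
sequence = torch.cat([cls_token, patch_tokens], dim=1)  # (1, 197, 768)

print(feature_map.shape)   # torch.Size([1, 768, 14, 14])
print(patch_tokens.shape)  # torch.Size([1, 196, 768])
print(sequence.shape)      # torch.Size([1, 197, 768])
```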

In Vision Transformers (**ViT**), the **prediction is based on the [CLS] token**. Here's how it works in detail:

### Why Prediction is Based on [CLS] Token
The **[CLS] token** is designed to serve as a global representation of the input image. During the forward pass through the transformer layers, the [CLS] token interacts with all the patch embeddings and gathers information from the entire image.

After the final transformer layer:
- The [CLS] token contains a summarized feature representation of the image.
- This feature vector is then passed through a **classification head** (e.g., a fully connected layer or MLP) to make the final prediction.

---

### Step-by-Step Process for Prediction

1. **Input Sequence**:
   - The input sequence is \([ \text{[CLS]}, P_1, P_2, \dots, P_N ]\), where \(P_i\) are the patch embeddings, and [CLS] is the classification token.

2. **Transformer Layers**:
   - The sequence is processed by the transformer, and the [CLS] token is updated at each layer through attention with the patch embeddings.

3. **Final [CLS] Token**:
   - After the last transformer layer, the [CLS] token contains a high-level feature vector summarizing the entire image.

4. **Classification Head**:
   - The [CLS] token's final representation is passed through a classification head (e.g., a linear layer with softmax) to predict the image class.
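
Step 2 can be inspected with a single self-attention layer. This sketch uses `nn.MultiheadAttention` directly, not the full encoder layer, with random tokens and sizes matching the notebook (49 patches, 16-dimensional embeddings, 8 heads). Row 0 of the returned weights shows how much the [CLS] position attends to every token in the sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, num_patches, embed_dim, num_heads = 2, 49, 16, 8
tokens = torch.randn(batch, num_patches + 1, embed_dim)  # position 0 plays the role of [CLS]

attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Self-attention: queries, keys and values are all the same sequence.
updated, weights = attention(tokens, tokens, tokens, need_weights=True)

# weights has shape (batch, query position, key position), averaged over heads.
cls_weights = weights[:, 0, :]

print(updated.shape)            # torch.Size([2, 50, 16])
print(cls_weights.shape)        # torch.Size([2, 50])
print(cls_weights.sum(dim=-1))  # each row sums to 1, up to floating-point error
```

The updated [CLS] vector is a weighted mix of all 50 value vectors, which is how it collects information from every patch.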

---

### Example
For a classification task:
1. **Input Image**: \(224 \times 224\) image divided into \(16 \times 16\) patches → \(14 \times 14 = 196\) patches.
2. **Input Sequence**: [CLS] + 196 patch embeddings → \(197\) tokens.
3. **Output of Transformer**: Final [CLS] token → \(D\)-dimensional vector (e.g., \(D = 768\)).
4. **Classification**:
   \[
   \text{Prediction} = \text{Softmax}(\text{Linear}(\text{[CLS]}))
   \]
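
The formula in step 4 as a minimal sketch. `cls_final` is random data standing in for the final [CLS] vectors of a batch of 4 images, and the 10 classes are an assumed value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_classes = 768, 10          # assumed sizes for illustration
cls_final = torch.randn(4, embed_dim)     # stand-in for the final [CLS] vectors

head = nn.Linear(embed_dim, num_classes)
logits = head(cls_final)                  # (4, 10) unnormalised class scores
probabilities = torch.softmax(logits, dim=-1)
predicted_class = probabilities.argmax(dim=-1)

print(logits.shape)           # torch.Size([4, 10])
print(predicted_class.shape)  # torch.Size([4])
```

When training with `nn.CrossEntropyLoss`, pass the raw `logits` rather than `probabilities`, because that loss applies log-softmax internally.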

---

### Key Benefits of Using [CLS] Token for Prediction
- **Global Context**: It gathers information from all patches through attention.
- **Simplicity**: Eliminates the need for additional pooling layers like average or max pooling.
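- **Flexibility**: The same encoder can be adapted to other tasks by adding task-specific heads; dense tasks such as segmentation or object detection typically read from the patch tokens rather than from [CLS] alone.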

Thus, in ViT, the final prediction for classification is **entirely based on the [CLS] token's representation**.

In [33]:
class ViT(nn.Module):
    def __init__(
        self,
        num_patches,
        img_size,
        num_classes,
        patch_size,
        embed_dim,
        num_encoders,
        num_head,
        hidden_dim,
        dropout,
        activation,
        in_channels,
    ):
        super().__init__()
        self.embeddings_block = PatchEmbedding(
            embed_dim=embed_dim,
            patch_size=patch_size,
            num_patches=num_patches,
            dropout=dropout,
            in_channels=in_channels,
        )
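        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_head,
            dim_feedforward=hidden_dim,  # hidden size of the MLP inside each encoder layer (was unused before, so the default 2048 applied)
            dropout=dropout,
            activation=activation,
            batch_first=True,  # input is (batch, seq_len, embed_dim)
            norm_first=True,  # apply LayerNorm before the attention and MLP sub-layers
        )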
        self.encoder_blocks = nn.TransformerEncoder(
            encoder_layer, num_layers=num_encoders
        )
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])  # keep only token 0, the CLS token
        # we do not classify on the whole sequence; only the CLS token is fed to mlp_head
        return x


model = ViT(
    NUM_PATCHES,
    IMG_SIZE,
    NUM_CLASSES,
    PATCH_SIZE,
    EMBED_DIM,
    NUM_ENCODERS,
    NUM_HEADS,
    HIDDEN_DIM,
    DROPOUT,
    ACTIVATION,
    IN_CHANNELS,
).to(device)
y = torch.randn(512, 1, 28, 28)
model(y).shape



cls_token.shape=torch.Size([512, 1, 16])


torch.Size([512, 10])