# Imports

In [181]:
from torch.nn import MultiheadAttention

# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    from torch import nn

    assert int(torch.__version__.split(".")[1]) >= 12 or int(
            torch.__version__.split(".")[0]
            ) == 2, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U torch torchvision torchaudio --index-url https: // download.pytorch.org/whl/cu118
    import torch
    import torchvision

    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary

# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    !git clone https: // github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular.
    !mv pytorch-deep-learning/helper_functions.py.  # get the helper_functions.py script
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves

from pathlib import Path

torch version: 2.1.2
torchvision version: 0.16.2


# Device chooser

In [183]:
def device_chooser(prefer_device: str = "cpu") -> str:
    """Return the torch device string to run on.

    Exactly one backend is detected: CUDA is probed first, and MPS is probed
    only when CUDA is absent; with neither, the detected backend is "cpu".
    ``prefer_device`` is honoured only when it names that detected backend;
    any other request falls back to "cpu".
    """
    if torch.cuda.is_available():
        detected = "cuda"
    elif torch.backends.mps.is_available():
        detected = "mps"
    else:
        detected = "cpu"

    return detected if prefer_device == detected else "cpu"


device = device_chooser(prefer_device="mps")
device

'mps'

# Get data

In [184]:
# Fetch the pizza/steak/sushi image subset (the helper reports and skips the
# download when the target directory already exists).
image_path = download_data(
        destination=Path("pizza_steak_sushi"),
        source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/images/pizza_steak_sushi.zip",
        )
image_path

[INFO] data/pizza_steak_sushi directory exists, skipping download.


PosixPath('data/pizza_steak_sushi')

In [185]:
# ImageFolder layout: one sub-directory per data split.
train_path, test_path = image_path / "train", image_path / "test"

## Prepare dataset | dataloader

In [186]:
RANDOM_SEED = 42
HEIGHT, WIDTH = 224, 224
IMG_SIZE = (HEIGHT, WIDTH)

BATCH_SIZE = 32

# Resize to the ViT training resolution and convert PIL images to float tensors in [0, 1].
manual_transforms = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            ]
        )

train_dataset = torchvision.datasets.ImageFolder(root=train_path, transform=manual_transforms)
test_dataset = torchvision.datasets.ImageFolder(root=test_path, transform=manual_transforms)

# ImageFolder stores samples sorted by class, so an unshuffled training loader would
# yield batches of a single class. Shuffle the training data and seed the shuffle
# with RANDOM_SEED so runs stay reproducible. Keep the test set in a fixed order.
train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(RANDOM_SEED),
        )
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [187]:
PATCH_SIZE = 16

# kernel_size == stride == PATCH_SIZE: every output position covers exactly one
# non-overlapping PATCH_SIZE x PATCH_SIZE patch and projects it to 768 channels.
conv2d = nn.Conv2d(3, 768, kernel_size=PATCH_SIZE, stride=PATCH_SIZE, padding=0)

# Merges dims 1 and 2, i.e. the spatial grid of an *unbatched* (768, H', W') conv output.
flatten = nn.Flatten(start_dim=1, end_dim=2)

# Patch embedding | class token | position embedding

In [188]:
# Take the first image of the first training batch as an unbatched (C, H, W) tensor.
im_batch: torch.Tensor = next(iter(train_dataloader))[0]
im: torch.Tensor = im_batch[0]

print(im.shape)
# Conv2d accepts the unbatched input directly; flatten then merges the spatial grid (dims 1 and 2).
image_embeddings: torch.Tensor = conv2d(im)
flattened_image_embeddings: torch.Tensor = flatten(image_embeddings)
print(*(t.shape for t in (image_embeddings, flattened_image_embeddings)), sep="\n")

torch.Size([3, 224, 224])
torch.Size([768, 14, 14])
torch.Size([768, 196])


In [189]:
im_batch, _ = next(iter(train_dataloader))
# Slicing (rather than indexing) keeps the leading batch dimension: (1, C, H, W).
im: torch.Tensor = im_batch[:1]
im.shape

torch.Size([1, 3, 224, 224])

In [190]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is equivalent to
    cutting the image into ``patch_size x patch_size`` patches and applying one
    shared linear projection to every patch.

    Args:
        in_channels: Number of channels in the input image.
        patch_size: Side length, in pixels, of each square patch.
        embed_dim: Length of the embedding vector produced for every patch.
    """

    def __init__(self, in_channels: int, patch_size: int, embed_dim: int):
        super().__init__()

        self.patch_size = patch_size
        self.patcher = nn.Sequential(
                nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=embed_dim,
                        stride=patch_size,
                        kernel_size=patch_size,
                        padding=0
                        ),
                # (..., embed_dim, H/p, W/p) -> (..., embed_dim, num_patches)
                nn.Flatten(start_dim=-2, end_dim=-1)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into one vector per patch.

        Args:
            x: Image tensor of shape (N, C, H, W) or unbatched (C, H, W); both H and W
                must be divisible by ``patch_size``.

        Returns:
            Tensor of shape (N, num_patches, embed_dim), or (num_patches, embed_dim)
            for unbatched input, where num_patches = (H / p) * (W / p).

        Raises:
            AssertionError: if H or W is not divisible by ``patch_size``.
        """
        height, width = x.shape[-2:]
        # Check both spatial dims: a strided conv would otherwise silently drop the leftover pixels.
        assert height % self.patch_size == 0 and width % self.patch_size == 0, f"[ERROR] - Image resolution is not divisible by patch size. Image size {x.shape} | Patch size {self.patch_size}"

        # Swap the last two dims so each patch becomes one row; works with or without a batch dim.
        return self.patcher(x).transpose(-2, -1)


set_seeds()
patchify = PatchEmbedding(in_channels=3, patch_size=PATCH_SIZE, embed_dim=768)
patch_embedded_image = patchify(im)

# Report the tensor computed above instead of running a second forward pass.
print(
        f"{im.shape}",
        f"{patch_embedded_image.shape}",
        sep="\n"
        )

torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])


In [191]:
embedding_dimension = 768
class_token = nn.Parameter(
        torch.randn(1, 1, embedding_dimension), requires_grad=True
        )

print(class_token.shape, patch_embedded_image.shape, sep="\n")

# Prepend the class token to every sequence in the batch: (B, N, D) -> (B, N + 1, D).
# expand() broadcasts the single learnable token over the batch dimension without copying,
# so this also works for batches larger than 1.
patch_embedded_image_with_class_embedding = torch.cat(
        (class_token.expand(patch_embedded_image.shape[0], -1, -1), patch_embedded_image), dim=1
        )
# One learnable position vector per token (patches + class token). The token count is
# read from the tensor instead of being hard-coded, so other image/patch sizes also work.
patch_position_embeddings = nn.Parameter(
        torch.randn(1, patch_embedded_image_with_class_embedding.shape[1], embedding_dimension),
        requires_grad=True
        )

position_and_class_embeddings = (
        patch_embedded_image_with_class_embedding + patch_position_embeddings)

patch_embedded_image_with_class_embedding.shape, position_and_class_embeddings.shape

torch.Size([1, 1, 768])
torch.Size([1, 196, 768])


(torch.Size([1, 197, 768]), torch.Size([1, 197, 768]))

# Multi-head attention

In [212]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """LayerNorm followed by multi-head self-attention (no residual is added here).

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads (nn.MultiheadAttention requires it to divide ``embed_dim``).
        attn_dropout: Dropout probability passed to nn.MultiheadAttention.
    """

    # ``attn_dropout`` is annotated as plain ``float``: ints are accepted wherever a float is
    # (PEP 484), and the PEP 604 ``float | int`` form raises TypeError at class-definition
    # time on Python < 3.10.
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embed_dim)
        self.multihead_attn = nn.MultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=attn_dropout,
                batch_first=True
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention output for ``x`` of shape (batch, tokens, embed_dim)."""
        normed = self.layer_norm(x)
        # Self-attention: query, key and value are the same normalised tokens.
        # need_weights=False skips building the attention-weight matrix we never use.
        attn_output, _ = self.multihead_attn(query=normed, key=normed, value=normed, need_weights=False)
        return attn_output

In [215]:
msa_block = MultiHeadSelfAttentionBlock(
        embed_dim=embedding_dimension,
        num_heads=12,
        attn_dropout=0,
        )

# Self-attention keeps the (batch, tokens, embed_dim) shape of its input.
patch_image_after_msa = msa_block(position_and_class_embeddings)
patch_image_after_msa.shape

torch.Size([1, 197, 768])

# Multilayer Perceptron (MLP)

In [216]:
class MLPBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP (no residual is added here).

    Layout: LayerNorm -> Linear(D, mlp_dim) -> GELU -> Dropout -> Linear(mlp_dim, D) -> Dropout.

    Args:
        embedding_dim: Token embedding size D (input and output width).
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout probability used after the activation and after the output projection.
    """

    # The default is the literal 768 (ViT-Base hidden size) instead of the notebook global
    # ``embedding_dimension``, so defining this class no longer depends on an earlier cell.
    def __init__(self, embedding_dim: int = 768, mlp_dim: int = 3072, dropout: float = 0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
                nn.Linear(in_features=embedding_dim, out_features=mlp_dim),
                nn.GELU(),
                nn.Dropout(p=dropout),
                nn.Linear(in_features=mlp_dim, out_features=embedding_dim),
                nn.Dropout(p=dropout)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to every token; the last dimension of ``x`` must equal ``embedding_dim``."""
        return self.mlp(self.layer_norm(x))

In [217]:
mlp_block = MLPBlock()

# Run the attention output through the MLP; the token shape is preserved.
patch_image_through_mlp_block: torch.Tensor = mlp_block(patch_image_after_msa)
patch_image_through_mlp_block.shape

torch.Size([1, 197, 768])

# Creating a Transformer Encoder by combining our custom made layers

In [223]:
# 1. Create a class that inherits from nn.Module
class TransformerEncoderBlock(nn.Module):
    """Creates a Transformer Encoder block: x + MSA(x), then y + MLP(y).

    Defaults follow ViT-Base from the ViT paper.

    Args:
        embedding_dim: Hidden size D (768, Table 1).
        num_heads: Number of attention heads (12, Table 1).
        mlp_size: Hidden width of the MLP (3072, Table 1).
        mlp_dropout: Dropout for the dense layers (0.1, Table 3).
        attn_dropout: Dropout for the attention layers.
    """

    def __init__(
            self,
            embedding_dim: int = 768,
            num_heads: int = 12,
            mlp_size: int = 3072,
            mlp_dropout: float = 0.1,
            attn_dropout: float = 0
            ):
        super().__init__()

        # Attention sub-layer (equation 2). Keep it constructed before the MLP:
        # this order fixes the sequence in which parameter init draws from the RNG.
        self.msa_block = MultiHeadSelfAttentionBlock(
                embed_dim=embedding_dim, num_heads=num_heads, attn_dropout=attn_dropout
                )
        # Feed-forward sub-layer (equation 3).
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim, mlp_dim=mlp_size, dropout=mlp_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each sub-layer's output is added back to that sub-layer's input (residual connections).
        after_attention = x + self.msa_block(x)
        return after_attention + self.mlp_block(after_attention)

In [225]:
transformer_encoder_block = TransformerEncoderBlock()

# Per-layer output shapes and parameter counts for a single (1, 197, 768) token sequence.
summary(model=transformer_encoder_block, input_size=(1, 197, 768))

Layer (type:depth-idx)                   Output Shape              Param #
TransformerEncoderBlock                  [1, 197, 768]             --
├─MultiHeadSelfAttentionBlock: 1-1       [1, 197, 768]             --
│    └─LayerNorm: 2-1                    [1, 197, 768]             1,536
│    └─MultiheadAttention: 2-2           [1, 197, 768]             2,362,368
├─MLPBlock: 1-2                          [1, 197, 768]             --
│    └─LayerNorm: 2-3                    [1, 197, 768]             1,536
│    └─Sequential: 2-4                   [1, 197, 768]             --
│    │    └─Linear: 3-1                  [1, 197, 3072]            2,362,368
│    │    └─GELU: 3-2                    [1, 197, 3072]            --
│    │    └─Dropout: 3-3                 [1, 197, 3072]            --
│    │    └─Linear: 3-4                  [1, 197, 768]             2,360,064
│    │    └─Dropout: 3-5                 [1, 197, 768]             --
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0

In [233]:
# PyTorch's built-in counterpart to the custom TransformerEncoderBlock: pre-norm
# (norm_first=True), GELU activation, batch-first inputs and the same widths.
torch_transformer_encoder_layer = nn.TransformerEncoderLayer(
        d_model=embedding_dimension,
        nhead=12,
        dim_feedforward=3072,
        dropout=0.1,
        activation=nn.GELU(),
        batch_first=True,
        norm_first=True,
        )

summary(
        model=torch_transformer_encoder_layer,
        input_size=(1, 197, 768),
        col_names=["input_size", "output_size", "num_params"],
        col_width=20,
        row_settings=["var_names"],
        )

Layer (type (var_name))                            Input Shape          Output Shape         Param #
TransformerEncoderLayer (TransformerEncoderLayer)  [1, 197, 768]        [1, 197, 768]        7,087,872
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
Input size (MB): 0.61
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.61

In [234]:
# Display the module tree to see how the built-in layer names its sub-modules (self_attn, linear1/2, norm1/2, ...).
torch_transformer_encoder_layer

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (linear1): Linear(in_features=768, out_features=3072, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=3072, out_features=768, bias=True)
  (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (activation): GELU(approximate='none')
)

# ViT (Vision Transformer)

In [235]:
# 1. Create a ViT class that inherits from nn.Module
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default.

    Args:
        img_size: Side length of the square input image; must be divisible by ``patch_size``.
        in_channels: Number of channels in the input image.
        patch_size: Side length of each square patch.
        num_transformer_layers: Number of stacked Transformer Encoder blocks.
        embedding_dim: Hidden size D.
        mlp_size: Hidden width of each MLP block.
        num_heads: Number of attention heads per block.
        attn_dropout: Dropout for the attention layers of every encoder block.
        mlp_dropout: Dropout for the dense/MLP layers of every encoder block.
        embedding_dropout: Dropout applied after adding the position embeddings.
        num_classes: Number of output logits.
    """
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=224, # Training resolution from Table 3 in ViT paper
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=1000): # Default for ImageNet but can customize this
        super().__init__() # don't forget the super().__init__()!

        # 3. Make sure the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # 4. Calculate number of patches (height * width/patch^2)
        self.num_patches = (img_size * img_size) // patch_size**2

        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # 6. Create learnable position embedding (one per patch plus one for the class token)
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        # 7. Create embedding dropout value
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # 8. Create patch embedding layer (PatchEmbedding's keyword is ``embed_dim``)
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embed_dim=embedding_dim)

        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # attn_dropout must be forwarded explicitly; otherwise every block silently uses its own default.
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout,
                                                                            attn_dropout=attn_dropout)
                                                   for _ in range(num_transformer_layers)])

        # 10. Create classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    # 11. Create a forward() method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (batch, in_channels, img_size, img_size) to logits (batch, num_classes)."""

        # 12. Get batch size
        batch_size = x.shape[0]

        # 13. Create class token embedding and expand it to match the batch size (equation 1)
        class_token = self.class_embedding.expand(batch_size, -1, -1) # "-1" keeps that dimension's existing size

        # 14. Create patch embedding (equation 1)
        x = self.patch_embedding(x)

        # 15. Concat class embedding and patch embedding (equation 1)
        x = torch.cat((class_token, x), dim=1)

        # 16. Add position embedding to patch embedding (equation 1)
        x = self.position_embedding + x

        # 17. Run embedding dropout (Appendix B.1)
        x = self.embedding_dropout(x)

        # 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
        x = self.transformer_encoder(x)

        # 19. Put 0 index logit through classifier (equation 4)
        x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index

        return x