In [12]:
# official code (JAX) from several Google transformer papers: https://github.com/google-research/vision_transformer
# code taken from mildlyoverfitted's tutorial: https://www.youtube.com/watch?v=ovB0ddFtzzA&ab_channel=mildlyoverfitted
# PyTorch code used in tutorial with pretrained weights: https://github.com/huggingface/pytorch-image-models

import torch # pytorch 2.0.1 (https://pytorch.org/get-started/pytorch-2.0/)
import torch.nn as nn    

In [13]:
class PatchEmbed(nn.Module):
    # splits each image into linear projections, or patches, so that the transformer can learn from them
    # PARAMS:
        # img_size: size of input image
        # patch_size: size of each patch
        # channels: num of input channels
        # dims: embedding dimension
    
    # ATTRIBUTES:
        # patches: num of patches per image
        # proj (nn.Conv2d): convolutional layer that splits image into patches and embeds
    
    def __init__(self, img_size, patch_size, channels=3, dims=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(
            channels, 
            dims, 
            kernel_size = patch_size, 
            stride = patch_size) # both kernel size and stride are equal to the patch size 
                                 # so that there will never be overlapping patches
    def forward(self, x):
        # run a forward pass. a tensor simply represents a batch of images.
        # PARAMS:
            # x (torch.Tensor): shape (n_samples, channels, img_size, img_size) for a square
        
        # OUTPUT: a 3D tensor representing the set of resulting patches
            # torch.Tensor: shape (n_samples, n_patches, dims)
        
        x = self.proj(
                x # by running the input tensor through the Conv2d layer, we will get a 4D tensor
        ) # shape (n_samples, dims, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2) # flattens the tensor from the 2nd axis onward
                         # eg., (n_patches ** 0.5) * (n_patches ** 0.5) = n_patches
        x = x.transpose(1, 2) # swap 1st and 2nd axes
        

IndentationError: expected an indented block (3852456927.py, line 27)