In [3]:
import torch
import torch.nn as nn
from x_transformers import Encoder
from einops import rearrange, repeat

In [5]:
'''
PatchEmbed class, adapted (I believe) from https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632 -- the article is behind the Medium paywall, so I could not verify the attribution
- This class is used to convert the image into patches using a convolutional layer
'''
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A single ``nn.Conv2d`` whose kernel size equals its stride performs both
    the patch extraction and the projection to ``embed_dim`` channels.

    Args:
        img_size: Image size; an int is expanded to ``(img_size, img_size)``.
        patch_size: Patch size; an int is expanded to ``(patch_size, patch_size)``.
        in_chans: Number of channels in the input image.
        embed_dim: Number of output channels of the convolution, i.e. the
            embedding width of each patch token.

    Attributes:
        patch_shape: Two-element tuple holding the patch-grid size along each
            image axis (image size floor-divided by patch size).
        conv: The strided convolution that produces the patch embeddings.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=64):
        super().__init__()

        # Promote scalar sizes to pairs so both axes are handled uniformly.
        image_dims = (img_size, img_size) if isinstance(img_size, int) else img_size
        patch_dims = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

        # How many whole patches fit along each axis.
        grid_dim0 = image_dims[0] // patch_dims[0]
        grid_dim1 = image_dims[1] // patch_dims[1]
        self.patch_shape = (grid_dim0, grid_dim1)

        # kernel_size == stride, so every patch is seen by exactly one kernel position.
        self.conv = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_dims,
            stride=patch_dims,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: 4-D batch of images accepted by the convolution.

        Returns:
            Tensor laid out as ``(batch, h * w, embed)``, where ``h`` and ``w``
            are the spatial dimensions of the convolution output.
        """
        patch_grid = self.conv(x)
        # Merge the two spatial axes into one token axis and move channels last.
        return rearrange(patch_grid, 'b e h w -> b (h w) e')

In [None]:
class Encoder(nn.Module):
    """ViT-style encoder: patch embedding, special tokens, positions, transformer.

    The forward pass embeds the image into patch tokens, prepends a learnable
    CLS token and a learnable mask token, adds a learnable positional
    embedding, optionally layer-normalizes, runs the x_transformers encoder
    stack and applies a final LayerNorm.

    Args:
        img_size: Image size forwarded to ``PatchEmbed``.
        patch_size: Patch size forwarded to ``PatchEmbed``.
        in_chans: Number of input image channels.
        embed_dim: Token embedding width.
        depth: Number of layers in the transformer stack.
        num_heads: Number of attention heads per layer.
        post_emb_norm: If True, apply ``nn.LayerNorm`` right after the
            positional embedding is added; otherwise use ``nn.Identity``.
        M: Stored on the instance as ``self.M``; not used elsewhere in this class.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim, depth, num_heads, post_emb_norm=True, M = 4):
        super().__init__()
        # This class is itself named ``Encoder``, so the module-level name no
        # longer refers to the x_transformers class once this definition runs.
        # Import the library class under an alias to avoid constructing
        # ourselves recursively (which would fail with an unexpected-keyword error).
        from x_transformers import Encoder as TransformerEncoder

        # Honour the caller's value instead of a hard-coded 4.
        self.M = M

        #define the patch embedding and positional embedding
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        # forward() prepends TWO extra tokens (cls + mask), so the positional
        # table needs num_tokens + 2 rows for the addition to broadcast.
        # NOTE(review): if the mask token is instead meant to replace masked
        # patches rather than be prepended, revert to +1 and change forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens + 2, embed_dim))

        #define the cls and mask tokens (values are overwritten by the init below)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # The second positional argument of trunc_normal_ is the MEAN; pass
        # std by keyword so the tokens get a zero-mean, std=0.02 init.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        #define the transformer stack and the layer normalizations
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        self.norm = nn.LayerNorm(embed_dim)
        self.encoder = TransformerEncoder(
            dim=embed_dim,
            heads=num_heads,
            depth=depth,
        )

    def forward(self, x):
        """Encode a batch of images into a token sequence.

        Args:
            x: Batch of images accepted by ``PatchEmbed``.

        Returns:
            Tensor of shape ``(batch, num_tokens + 2, embed_dim)``; index 0 on
            the token axis is the CLS token, index 1 the mask token, and the
            remaining entries are the patch tokens.
        """
        #get the patch embeddings
        x = self.patch_embed(x)
        batch_size = x.shape[0]
        #add the cls and mask tokens
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        mask_tokens = repeat(self.mask_token, '() n e -> b n e', b=batch_size)
        x = torch.cat((cls_tokens, mask_tokens, x), dim=1)
        #add the positional embeddings
        x = x + self.pos_embedding
        #normalize the embeddings
        x = self.post_emb_norm(x)
        #run the transformer stack and the final normalization, which were
        #previously constructed but never applied
        x = self.encoder(x)
        x = self.norm(x)
        # TODO: teacher targets are not implemented yet.

        return x



In [6]:
# Smoke test: run one random 3-channel 224x224 image through a default PatchEmbed.
x = torch.randn(1, 3, 224, 224)
patch_embed = PatchEmbed()
# Note: `x` is rebound here from the input image to the patch embeddings.
x = patch_embed(x)
# With the defaults (224 // 16 = 14 patches per side, embed_dim=64) this prints
# torch.Size([1, 196, 64]).
print(x.shape)

torch.Size([1, 196, 64])
