<a href="https://colab.research.google.com/github/facial09/pytorch_basic/blob/main/stablediffData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install einops==0.6.1  # pinned for reproducibility; %pip installs into the running kernel's environment

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 1.7 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.6.1


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import torchsummary

In [26]:
class Stem(nn.Module):
  """Patch-embedding stem: turns an image batch into a sequence of patch tokens.

  A Conv2d whose kernel_size and stride both equal ``patch_size`` embeds each
  non-overlapping patch. The patch grid is flattened into a token axis, a learned
  positional embedding (one row per patch) is added, and a final Linear doubles
  the channel width.

  Args:
    img_size: image side length used to size the positional table
      (``(img_size // patch_size) ** 2`` rows). Inputs must yield exactly that
      many patches, e.g. img_size x img_size images.
    patch_size: side length of each square patch (conv kernel and stride).
    in_channels: number of input image channels.
    emb_size: patch embedding width; the returned width is ``2 * emb_size``.

  Output shape: (batch, (img_size // patch_size) ** 2, 2 * emb_size).
  """

  def __init__(self, img_size : int = 32, patch_size : int = 4, in_channels : int = 3, emb_size : int = 48):
    super().__init__()
    num_patches = (img_size // patch_size) ** 2
    # Submodule order/names are kept stable: they define the state_dict keys
    # and the order in which parameters are randomly initialised.
    patchify = nn.Conv2d(in_channels, emb_size, kernel_size = patch_size, stride = patch_size)
    to_tokens = Rearrange('b e (h) (w) -> b (h w) e')  # (b, e, h, w) -> (b, h*w, e)
    self.proj = nn.Sequential(patchify, to_tokens)
    self.positions = nn.Parameter(torch.randn(num_patches, emb_size))
    self.linear = nn.Linear(emb_size, 2 * emb_size)

  def forward(self, x):
    """Embed patches, add the positional table (broadcast over batch), then widen."""
    tokens = self.proj(x) + self.positions
    return self.linear(tokens)

In [27]:
# Smoke test: a 32x32, 3-channel dummy batch should come out as 64 patch tokens of width 96.
x = torch.randn(size=(128, 3, 32, 32))
stem = Stem(img_size=32, patch_size=4, in_channels=3, emb_size=48)
stem(x).size()

torch.Size([128, 64, 96])

In [31]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention over a token sequence.

  One Linear produces queries, keys and values for all heads at once. Each head
  computes softmax(Q K^T / sqrt(head_dim)) V, and the concatenated head outputs
  are mixed by a final Linear projection.

  Args:
    emb_size: model width; must be divisible by ``num_heads``.
    num_heads: number of attention heads (head_dim = emb_size // num_heads).
    dropout: dropout probability applied to the attention weights.

  Raises:
    ValueError: if ``emb_size`` is not divisible by ``num_heads``.
  """

  def __init__(self, emb_size, num_heads : int = 8, dropout : float = 0.):
    super().__init__()
    if emb_size % num_heads != 0:
      raise ValueError(
          f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
    self.emb_size = emb_size
    self.num_heads = num_heads
    self.head_dim = emb_size // num_heads

    self.qkv = nn.Linear(emb_size, emb_size * 3)
    self.att_drop = nn.Dropout(dropout)
    self.projection = nn.Linear(emb_size, emb_size)

  def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
    """Attend over the token axis.

    Args:
      x: tokens of shape (batch, tokens, emb_size).
      mask: optional boolean tensor broadcastable to
        (batch, num_heads, tokens, tokens). True marks (query, key) pairs that
        may attend; False pairs are blocked.

    Returns:
      Tensor of shape (batch, tokens, emb_size).
    """
    b, n, _ = x.shape
    # Same split as einops "b n (h d qkv) -> qkv b h n d": the fused feature
    # axis is head-major, then head_dim, with q/k/v interleaved innermost.
    qkv = self.qkv(x).reshape(b, n, self.num_heads, self.head_dim, 3)
    qkv = qkv.permute(4, 0, 2, 1, 3)
    queries, keys, values = qkv[0], qkv[1], qkv[2]

    # Scale the logits (not the probabilities) by sqrt(head_dim) so the softmax
    # temperature does not grow with the per-head width.
    energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / self.head_dim ** 0.5
    if mask is not None:
      # masked_fill is out-of-place, so its result must be kept. Filling with the
      # dtype's finite minimum (not -inf) keeps a fully masked row NaN-free, and
      # using energy.dtype keeps it valid under fp16/bf16.
      energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

    att = F.softmax(energy, dim=-1)
    att = self.att_drop(att)
    out = torch.einsum('bhqk, bhkd -> bhqd', att, values)
    # "b h n d -> b n (h d)": concatenate the heads back into emb_size.
    out = out.transpose(1, 2).reshape(b, n, self.emb_size)
    return self.projection(out)



In [32]:
class Block(nn.Module):
  """Pre-norm transformer block: attention token mixer, then an MLP channel mixer.

  Each sub-layer is LayerNorm -> transform, wrapped in a residual connection,
  so the output has the same shape as the input.

  Args:
    dim: feature width of the incoming tokens.
    expansion_ratio: MLP hidden width as a multiple of ``dim``.
    num_heads: number of heads in the attention token mixer.
  """

  def __init__(self, dim, expansion_ratio : int = 2, num_heads = 8):
    super().__init__()
    hidden = dim * expansion_ratio
    # Attribute names and Sequential indices are kept: they define state_dict keys.
    token_layers = [nn.LayerNorm(dim), MultiheadAttention(dim, num_heads = num_heads)]
    self.token_mixer = nn.Sequential(*token_layers)
    channel_layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden),
        nn.SiLU(),
        nn.Linear(hidden, dim),
    ]
    self.channel_mixer = nn.Sequential(*channel_layers)

  def forward(self, x):
    """Apply the token mixer then the channel mixer, each as a residual update."""
    for mixer in (self.token_mixer, self.channel_mixer):
      x = x + mixer(x)
    return x

In [33]:
x = torch.randn(128, 3, 32, 32)
stem = Stem()
stem_output = stem(x)
# Size the block from the stem's actual output width (2 * emb_size = 96 with the
# defaults) rather than hard-coding it, so the two stay in sync if Stem changes.
block = Block(stem_output.shape[-1])
block_output = block(stem_output)
# A residual block must preserve (batch, tokens, dim).
assert block_output.shape == stem_output.shape
block_output.shape

torch.Size([128, 64, 96])

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_ratio):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.transformer = nn.ModuleList([
            EncoderBlock(dim=dim, num_heads=heads[0], mlp_ratio=mlp_ratio[0]) for _ in range(depth[0])
        ])
        self.transformer.extend([
            EncoderBlock(dim=dim, num_heads=heads[1], mlp_ratio=mlp_ratio[1]) for _ in range(depth[1])
        ])
        self.transformer.extend([
            EncoderBlock(dim=dim, num_heads=heads[2], mlp_ratio=mlp_ratio[2]) for _ in range(depth[2])
        ])
        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = nn.functional.pad(x, pad=(self.patch_size // 2, self.patch_size // 2,
                                      self.patch_size // 2, self.patch_size // 2), mode='reflect')
        x = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)(x)
        x = self.patch_to_embedding(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding
        x = nn.functional.dropout(x, p=0.