In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddings(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A single strided convolution (kernel size == stride == ``patch_size``)
    cuts the image into non-overlapping square patches and linearly projects
    each patch to ``hidden_dim`` features in one step.

    Args:
        img_size: Side length of the (square) input images. Only used to
            precompute ``num_patches``; the convolution itself accepts any
            spatial size.
        patch_size: Side length of each square patch.
        hidden_dim: Size of the embedding produced for every patch.
        in_channels: Number of channels in the input images (3 for RGB).
    """

    def __init__(self, img_size=96, patch_size = 16, hidden_dim = 512, in_channels = 3):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        # Patches per image for a square img_size x img_size input.
        self.num_patches = (img_size // patch_size) ** 2

        # kernel_size == stride, so every patch is seen exactly once.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, X: torch.Tensor):
        """Embed a batch of images.

        Args:
            X: Tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, num_patches, hidden_dim), where num_patches is
            (H // patch_size) * (W // patch_size).
        """
        X = self.conv(X)      # (B, hidden_dim, H // patch_size, W // patch_size)
        X = X.flatten(2)      # (B, hidden_dim, num_patches)
        X = X.transpose(1, 2) # (B, num_patches, hidden_dim)

        return X

# Smoke test: run a random batch of five 3x96x96 images through a
# default-configured patch embedder.
x = torch.rand((5, 3, 96, 96))
patch_emb = PatchEmbeddings()
y = patch_emb(x)

In [8]:
class Head(nn.Module):
    """One head of scaled dot-product self-attention.

    Args:
        n_embd: Size of the last dimension of the input.
        head_size: Size of the query/key/value projections, and therefore of
            the output's last dimension.
        dropout: Dropout probability applied to the attention probabilities.
        is_decoder: When True, a causal (lower-triangular) mask is applied so
            position t can only attend to positions <= t.
    """

    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder = False) -> None:
        super().__init__()

        # Bias-free linear projections of the input into key/query/value space.
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)

        self.dropout = nn.Dropout(dropout)

        self.is_decoder = is_decoder

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embd).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        # Batch dim (B), sequence length (T), embedding dim (C)
        B, T, C = x.shape

        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scale by 1/sqrt(d_k), where d_k is the key dimension (head_size),
        # not the input embedding width C.
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)  # (B, T, T)

        if self.is_decoder:
            # Causal mask: future positions get -inf so softmax assigns them
            # zero probability. (+inf would produce NaNs.)
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(~tril, float('-inf'))

        # Attention scores -> attention probabilities
        wei = F.softmax(wei, dim=-1)

        # Dropout on the attention probabilities for regularization
        wei = self.dropout(wei)

        out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

        return out

In [9]:
# Sanity check: run one attention head over a random (5, 36, 512) batch
# and display the shape of the result.
x = torch.rand((5, 36, 512))
head = Head(n_embd=512, head_size=4)
out = head(x)
out.shape

torch.Size([5, 36, 4])

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_heads, dropout = 0.1, is_decoder = False) -> None:
        super().__init__(*args, **kwargs)