In [110]:
import math

import torch
from torch import nn
from torch.nn import functional as F
from transformers import BertTokenizer

In [115]:
tokenizer = BertTokenizer.from_pretrained("prajjwal1/bert-tiny")
VOCAB_SIZE = tokenizer.vocab_size
EMB_SIZE = 128
MAX_SEQ_LEN = 512
FEED_FORWARD_EXPANSION = 4

# Seed torch's global RNG so nn.Embedding / nn.Linear initialisation and
# dropout masks are reproducible across Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)


In [2]:
class InputEmbeddings(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        """Initializes the InputEmbeddings module.

        Args:
            vocab_size (int): The size of the vocabulary.
            emb_size (int): The size of each embedding vector.

        Attributes:
            embedding (nn.Embedding): Embedding layer that maps input tokens to their embeddings.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, x):
        """Look up token embeddings and scale them by sqrt(emb_size).

        Args:
            x (torch.Tensor): Integer token ids.

        Returns:
            torch.Tensor: Embeddings of shape ``x.shape + (emb_size,)``.
        """
        # The Transformer ("Attention Is All You Need", Sec. 3.4) *multiplies*
        # the embedding weights by sqrt(d_model). Dividing instead shrinks
        # them so much that sinusoidal position encodings (values in [-1, 1])
        # added afterwards would dominate the token signal.
        return self.embedding(x) * math.sqrt(self.emb_size)


In [36]:
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, max_len=5000, dropout=0.1):
        """
        Initializes the PositionalEncoding module. Uses the sinusoidal encoding scheme.

        Args:
            emb_size (int): The size of each embedding vector.
            max_len (int, optional): The maximum length of the input sequence.
                Defaults to 5000.
            dropout (float, optional): The dropout rate applied after adding the
                encodings. Defaults to 0.1.
        """
        super().__init__()
        self.emb_size = emb_size
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len

        # Position indices as a column vector. Shape: [max_len, 1]
        positions = torch.arange(0, self.max_len).unsqueeze(1)

        # One frequency per even embedding index 2i: 10000^(2i / emb_size).
        # Shape: [ceil(emb_size / 2)]
        div_term = 10000 ** (torch.arange(0, self.emb_size, 2).float() / self.emb_size)

        pe = torch.zeros(self.max_len, self.emb_size)

        # Sine for even indices
        pe[:, 0::2] = torch.sin(positions / div_term)

        # Cosine for odd indices. When emb_size is odd there is one fewer odd
        # column than even columns, so only the first emb_size // 2
        # frequencies are used here.
        pe[:, 1::2] = torch.cos(positions / div_term[: self.emb_size // 2])

        # Shape: [1, max_len, emb_size]. Registered as a buffer so that
        # .to(device) / .half() move it together with the module; non-persistent
        # so it is not added to state_dict (it is recomputed in __init__).
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Add the encodings for the first ``x.shape[1]`` positions, then apply dropout.

        Args:
            x (torch.Tensor): Input with the sequence on dim 1 and ``emb_size``
                features on the last dim (assumed (batch, seq_len, emb_size)).

        Returns:
            torch.Tensor: Same shape as ``x``.
        """
        x = x + self.pe[:, : x.shape[1]]
        return self.dropout(x)


In [27]:
text = [
    "This is sample sentence one",
    "This is sample sentence two",
    "This is a very large sample sentence to increase the sequence length",
    "This is small",
]

# Pad every sentence to the longest one so the batch is a single rectangular tensor.
tokens = tokenizer(text, padding=True, return_tensors="pt")
B, L = tokens["input_ids"].shape
tokens["input_ids"].shape, B, L


(torch.Size([4, 14]), 4, 14)

In [40]:
tokens["input_ids"]  # trailing 0s come from padding=True in the tokenisation cell

tensor([[ 101, 2023, 2003, 7099, 6251, 2028,  102,    0,    0,    0,    0,    0,
            0,    0],
        [ 101, 2023, 2003, 7099, 6251, 2048,  102,    0,    0,    0,    0,    0,
            0,    0],
        [ 101, 2023, 2003, 1037, 2200, 2312, 7099, 6251, 2000, 3623, 1996, 5537,
         3091,  102],
        [ 101, 2023, 2003, 2235,  102,    0,    0,    0,    0,    0,    0,    0,
            0,    0]])

In [37]:
# Fixed sinusoidal position table, then the learned token lookup table.
positional_encoding = PositionalEncoding(EMB_SIZE, max_len=MAX_SEQ_LEN)
input_embeddings = InputEmbeddings(vocab_size=VOCAB_SIZE, emb_size=EMB_SIZE)

In [38]:
token_embeddings = input_embeddings(tokens["input_ids"])
print("Input Embeddings:", token_embeddings.shape)
x = positional_encoding(token_embeddings)
x.shape

Input Embeddings: torch.Size([4, 14, 128])


torch.Size([4, 14, 128])

In [44]:
# Slice the positional table to this batch's sequence length `L` (set in the
# tokenisation cell) instead of a hard-coded 14 that drifts when `text` changes.
positional_encoding.pe[:, :L, :].shape

torch.Size([1, 14, 128])

In [99]:
class selfAttentionHead(nn.Module):
    def __init__(self, emb_size, head_dim):
        """Single head of scaled dot-product self-attention.

        Args:
            emb_size (int): Size of each input embedding vector.
            head_dim (int): Size of the query/key/value projections of this head.
        """
        super().__init__()
        self.query = nn.Linear(emb_size, head_dim, bias=False)
        self.key = nn.Linear(emb_size, head_dim, bias=False)
        self.value = nn.Linear(emb_size, head_dim, bias=False)
        # Scaled dot-product attention divides by sqrt(d_k), the key dimension
        # of *this head* (head_dim), not the model width. A plain float also
        # works on any device, unlike an unregistered tensor attribute.
        self.scale = head_dim ** -0.5

    def forward(self, x, mask=None):
        """Attend from every position of ``x`` to every position of ``x``.

        Args:
            x (torch.Tensor): Input of shape (..., seq_len, emb_size).
            mask (torch.Tensor, optional): Broadcastable to the score matrix
                (..., seq_len, seq_len). Positions where ``mask == 0`` are
                excluded from the softmax. For a (batch, seq_len) padding mask
                such as a tokenizer's ``attention_mask``, pass
                ``mask[:, None, :]`` to mask keys. Defaults to None (no masking).

        Returns:
            torch.Tensor: Context vectors of shape (..., seq_len, head_dim).
        """
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # matmul (unlike bmm) accepts any number of leading batch dimensions.
        scores = (query @ key.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Use the dtype's minimum rather than -inf so a fully-masked row
            # yields uniform weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = torch.softmax(scores, dim=-1)
        return attention_weights @ value


In [100]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, n_heads, dropout=0.1):
        """Multi-head self-attention built from independent selfAttentionHead modules.

        Args:
            emb_size (int): Size of each input (and output) embedding vector.
            n_heads (int): Number of heads. Must divide ``emb_size``; each head
                projects to ``emb_size // n_heads`` features.
            dropout (float, optional): Dropout rate applied to the output
                projection. Defaults to 0.1.
        """
        super().__init__()
        assert emb_size % n_heads == 0, (
            f"emb_size ({emb_size}) must be divisible by n_heads ({n_heads})"
        )
        self.n_heads = n_heads
        self.head_dim = emb_size // n_heads
        self.heads = nn.ModuleList(
            [selfAttentionHead(emb_size, self.head_dim) for _ in range(n_heads)]
        )
        self.linear = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run all heads, concatenate their outputs, project back to emb_size.

        Args:
            x (torch.Tensor): Input whose last dim is ``emb_size``.
            mask (torch.Tensor, optional): Forwarded to every head as the
                ``mask`` keyword; only usable when the heads' ``forward``
                accepts it. Defaults to None (heads are called as ``head(x)``).

        Returns:
            torch.Tensor: Output with the same last dim (``emb_size``) as ``x``.
        """
        # Only pass `mask` when given, so heads whose forward takes just `x` still work.
        head_kwargs = {} if mask is None else {"mask": mask}
        x = torch.cat([head(x, **head_kwargs) for head in self.heads], dim=-1)
        x = self.linear(x)
        return self.dropout(x)

In [105]:
attention_block = selfAttentionHead(emb_size=EMB_SIZE, head_dim=16)

In [106]:
# Run one attention head (head_dim=16) on the current `x`.
# NOTE(review): a later cell rebinds `x` to the multi-head output, so re-running
# this cell after that one feeds the head a different input.
a = attention_block(x)

In [103]:
multihead_attention_bloc = MultiHeadAttention(emb_size=EMB_SIZE, n_heads=8)

In [107]:
# A single head keeps the batch and sequence dims and outputs head_dim features.
a.shape

torch.Size([4, 14, 16])

In [109]:
# NOTE(review): rebinding `x` makes this cell non-idempotent. Each re-run applies
# attention again on top of the previous output, and earlier cells that read `x`
# see the post-attention tensor if re-run. Consider a new name (e.g. `attn_out`)
# and updating the downstream feed-forward cell to use it.
x = multihead_attention_bloc(x)

In [113]:
class FeedForward(nn.Module):
    def __init__(self, emb_size, expansion_factor, dropout=0.1):
        """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

        Args:
            emb_size (int): Input and output feature size.
            expansion_factor (int): Hidden width as a multiple of ``emb_size``.
            dropout (float, optional): Dropout rate applied to the output.
                Defaults to 0.1.
        """
        super().__init__()
        self.emb_size = emb_size
        self.expansion_factor = expansion_factor

        hidden_size = expansion_factor * emb_size
        self.linear1 = nn.Linear(emb_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``; output has the same shape.

        Args:
            x (torch.Tensor): Input whose last dim is ``emb_size``.

        Returns:
            torch.Tensor: Same shape as ``x``.
        """
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.linear2(x)
        return self.dropout(x)

In [116]:
feed_forward = FeedForward(emb_size=EMB_SIZE, expansion_factor=FEED_FORWARD_EXPANSION)

In [117]:
# linear2 projects back to emb_size, so the feed-forward output keeps the shape of `x`.
feed_forward(x).shape

torch.Size([4, 14, 128])