This is a re-implementation of the original Transformer architecture. It is for my own practice and follows more of a tutorial structure. Consider going through my previous implementation as well.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention layer.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_model // num_heads``, passed through
    scaled dot-product attention per head, re-assembled and finally sent
    through an output projection.

    Args:
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        d_model: Feature width of both the input and the output.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(self.d_model, self.d_model)
        self.W_K = nn.Linear(self.d_model, self.d_model)
        self.W_V = nn.Linear(self.d_model, self.d_model)
        self.W_O = nn.Linear(self.d_model, self.d_model)

    @staticmethod
    def scaled_dot_product_attention(Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            Q: Queries of shape (batch_size, num_heads, seq_len_q, d_k).
            K: Keys of shape (batch_size, num_heads, seq_len_k, d_k).
            V: Values of shape (batch_size, num_heads, seq_len_k, d_v).
            mask: Controls which key positions each query may attend to.
                * ``None`` or any falsy flag (``False``, ``0``): no masking.
                * A truthy flag (``True``): causal mask, query ``i`` only
                  sees keys ``j <= i``.
                * A boolean tensor broadcastable to
                  (batch_size, num_heads, seq_len_q, seq_len_k): ``True``
                  marks positions that take part in attention.
                * Any other tensor: added to the scores before the softmax
                  (use ``-inf`` to block a position).

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len_q, d_v).

        Note:
            A query row whose keys are all masked out yields NaN, because the
            softmax of a row of ``-inf`` is undefined.
        """
        assert Q.shape[-1] == K.shape[-1], "query and key depth must be equal"

        # A plain Python scalar keeps the dtype and device of Q/K untouched.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)

        if isinstance(mask, torch.Tensor):
            mask = mask.to(scores.device)
            if mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask.to(scores.dtype)
        elif mask:
            # Build the causal pattern on the same device as the scores so the
            # method also works for inputs that do not live on the CPU.
            rows = torch.arange(scores.shape[-2], device=scores.device).unsqueeze(-1)
            cols = torch.arange(scores.shape[-1], device=scores.device)
            scores = scores.masked_fill(cols > rows, float("-inf"))

        attention_weights = F.softmax(scores, dim=-1)
        assert attention_weights.shape == (
            Q.shape[0],
            Q.shape[1],
            Q.shape[2],
            K.shape[2],
        ), "unexpected attention weight shape"

        return torch.matmul(attention_weights, V)

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_len = projected.shape[0], projected.shape[1]
        return projected.reshape(
            batch_size, seq_len, self.num_heads, self.d_k
        ).transpose(1, 2)

    def forward(self, input_matrix, mask=None):
        """Apply multi-head self-attention.

        Args:
            input_matrix: Tensor of shape (batch_size, seq_len, d_model).
            mask: Forwarded unchanged to ``scaled_dot_product_attention``;
                the default ``None`` applies no masking.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len = input_matrix.shape[0], input_matrix.shape[1]

        # The per-head projections stay available as attributes for inspection.
        self.Q = self._split_heads(self.W_Q(input_matrix))
        self.K = self._split_heads(self.W_K(input_matrix))
        self.V = self._split_heads(self.W_V(input_matrix))

        per_head = self.scaled_dot_product_attention(self.Q, self.K, self.V, mask)
        # (batch, heads, seq, d_k) -> (batch, seq, heads * d_k)
        merged = per_head.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        return self.W_O(merged)

In [8]:
def test_attention():
    """Smoke-test ``MultiHeadAttention.scaled_dot_product_attention``.

    Checks the output shape with and without the causal mask, and checks
    that the mask actually restricts attention: the first query position may
    only see the first key, so its output must equal the first value vector.

    Raises:
        AssertionError: If any of the checks fails.
    """
    # Small test case
    batch_size, num_heads, seq_len, d_k = 2, 4, 6, 8
    expected_shape = (batch_size, num_heads, seq_len, d_k)

    # Create sample tensors
    Q = torch.randn(batch_size, num_heads, seq_len, d_k)
    K = torch.randn(batch_size, num_heads, seq_len, d_k)  # seq_len same for simplicity
    V = torch.randn(batch_size, num_heads, seq_len, d_k)  # d_v = d_k for simplicity

    # Test without mask
    output = MultiHeadAttention.scaled_dot_product_attention(Q, K, V, mask=False)
    print(f"Output shape: {output.shape}")
    print(f"Expected: {(batch_size, num_heads, seq_len, d_k)}")
    assert output.shape == expected_shape

    # Test with mask
    output_masked = MultiHeadAttention.scaled_dot_product_attention(Q, K, V, mask=True)
    print(f"Masked output shape: {output_masked.shape}")
    assert output_masked.shape == expected_shape

    # With a causal mask, query 0 attends to key 0 only.
    assert torch.allclose(output_masked[:, :, 0], V[:, :, 0], atol=1e-6)
    # The mask must change the result for the remaining positions.
    assert not torch.allclose(output, output_masked)

    print("✅ Basic tests passed!")

# Run the test
test_attention()

Output shape: torch.Size([2, 4, 6, 8])
Expected: (2, 4, 6, 8)
Masked output shape: torch.Size([2, 4, 6, 8])
✅ Basic tests passed!


In [None]:
class Encoder(nn.Module):
    """Placeholder for the Transformer encoder.

    Only the ``nn.Module`` base initializer runs; no sub-layers and no
    ``forward`` method are defined yet.
    """

    def __init__(self) -> None:
        super().__init__()

In [None]:
class Decoder(nn.Module):
    """Placeholder for the Transformer decoder.

    Only the ``nn.Module`` base initializer runs; no sub-layers and no
    ``forward`` method are defined yet.
    """

    def __init__(self) -> None:
        super().__init__()

In [None]:
class PositionalEncoding(nn.Module):
    """Placeholder for the positional-encoding module.

    Only the ``nn.Module`` base initializer runs; no buffers, parameters or
    ``forward`` method are defined yet.
    """

    def __init__(self) -> None:
        super().__init__()

In [None]:
class FeedForwardNetwork(nn.Module):
    """Placeholder for the feed-forward network module.

    Only the ``nn.Module`` base initializer runs; no layers and no
    ``forward`` method are defined yet.
    """

    def __init__(self) -> None:
        super().__init__()

In [None]:
class Transformer(nn.Module):
    """Placeholder for the full Transformer model.

    Only the ``nn.Module`` base initializer runs; no sub-modules and no
    ``forward`` method are defined yet.
    """

    def __init__(self) -> None:
        super().__init__()

In [None]:
def train():
    """Training entry point; not implemented yet, so it does nothing."""
    return None