In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Dummy Data
# Seed the global RNG first so the random input below (and the weights of any
# layer created afterwards) come out the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

batch_size = 4
seq_len = 10
d_model = 32

# Random input of shape (batch_size, seq_len, d_model), values in [0, 1).
x = torch.rand(batch_size, seq_len, d_model)

x.shape

torch.Size([4, 10, 32])

In [16]:
class SingleHeadSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into queries, keys and values by three bias-free
    linear layers. Every query is scored against every key, the scores are
    scaled by ``1 / sqrt(d_k)`` and soft-maxed over the key axis, and the
    result weights the values.

    Args:
        d_model: Size of the input's last dimension.
        d_k: Size of the query/key projections (also sets the score scale).
        d_v: Size of the value projection, i.e. of the output's last dimension.
        mask: Optional boolean tensor broadcastable to the ``(..., seq, seq)``
            score matrix. Positions that are ``True`` are filled with ``-1e9``
            before the softmax, so they receive (numerically) zero weight.
            ``None`` disables masking.
    """

    def __init__(self, d_model, d_k, d_v, mask=None):
        super().__init__()
        self.d_k = d_k
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)
        # Register the mask as a buffer so it follows the module through
        # .to() / .cuda(); a plain attribute would stay on its original device
        # and break masked_fill after the module is moved. persistent=False
        # keeps it out of state_dict, so saved checkpoints keep the same keys.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x):
        """Attend over ``x``.

        Args:
            x: Tensor whose last two dimensions are ``(seq, d_model)``.

        Returns:
            Tuple ``(output, attn_weights)`` where ``output`` has last
            dimensions ``(seq, d_v)`` and ``attn_weights`` has last dimensions
            ``(seq, seq)`` with each row summing to 1.
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if self.mask is not None:
            # True entries are the blocked positions.
            scores = scores.masked_fill(self.mask, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        output = attn_weights @ V
        return output, attn_weights

# Boolean causal mask: True strictly above the diagonal, i.e. at the
# (query, key) pairs the attention head fills with -1e9 before the softmax.
causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
single_attn = SingleHeadSelfAttention(
    d_model,
    d_k=64,
    d_v=64,
    mask=causal_mask,
)

# The head returns both the attended output and the attention weights.
outputs, weights = single_attn(x)

print(f"Input: {x.shape}")
print(f"Output: {outputs.shape}")
print(f"Weights: {weights.shape}")


Input: torch.Size([4, 10, 32])
Output: torch.Size([4, 10, 64])
Weights: torch.Size([4, 10, 10])


In [19]:
# Show the single-head attention output as the cell's result.
outputs

tensor([[[-0.1374, -0.0454, -0.3611,  ..., -0.0726, -0.0083,  0.1151],
         [-0.0951, -0.1035, -0.2478,  ..., -0.1652, -0.0087,  0.1415],
         [-0.0890, -0.1329, -0.3132,  ..., -0.1566, -0.0844,  0.2087],
         ...,
         [-0.1074, -0.0594, -0.4337,  ..., -0.2579, -0.0057,  0.2130],
         [-0.0835, -0.0351, -0.4396,  ..., -0.2425, -0.0050,  0.1830],
         [-0.0571, -0.0250, -0.4612,  ..., -0.2392, -0.0007,  0.1635]],

        [[-0.0416, -0.1580, -0.2462,  ..., -0.0277,  0.0939,  0.2789],
         [ 0.1005, -0.1812, -0.4332,  ..., -0.1473,  0.0541,  0.2413],
         [ 0.0637, -0.1874, -0.4412,  ..., -0.3078,  0.1104,  0.2456],
         ...,
         [-0.0608, -0.0794, -0.4670,  ..., -0.3781,  0.0565,  0.2503],
         [-0.0652, -0.0807, -0.4422,  ..., -0.3694,  0.0459,  0.2385],
         [-0.0797, -0.0541, -0.3983,  ..., -0.3372,  0.0205,  0.2331]],

        [[-0.2602,  0.0650, -0.5341,  ..., -0.3594,  0.0830,  0.1337],
         [-0.1267, -0.0390, -0.3174,  ..., -0

In [20]:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One bias-free linear layer produces queries, keys and values in a single
    pass. They are split into ``num_heads`` heads of size
    ``d_model // num_heads``, attended independently, concatenated back to
    ``d_model`` and sent through a bias-free output projection.

    Args:
        d_model: Size of the input's last dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq, seq)``.

                * Boolean mask: ``True`` marks positions that must NOT be
                  attended to (they are filled with ``-1e9``).
                * Floating-point mask: added to the scores, so ``-inf``
                  blocks a position and ``0`` leaves it untouched.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.

        Raises:
            TypeError: If ``mask`` is neither boolean nor floating-point.
        """
        b, s, d = x.shape

        qkv = self.W_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # (b, s, d_model) -> (b, num_heads, s, d_k)
        q = q.view(b, s, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(b, s, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(b, s, self.num_heads, self.d_k).transpose(1, 2)

        # (b, num_heads, s, s)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            if mask.dtype == torch.bool:
                # True = blocked. The previous `mask == 0` test inverted a
                # triu-style causal mask and blocked the past instead of the
                # future.
                scores = scores.masked_fill(mask, -1e9)
            elif mask.is_floating_point():
                # Additive mask (-inf above the diagonal, 0 elsewhere).
                scores = scores + mask
            else:
                raise TypeError(
                    "mask must be a boolean (True = blocked) or "
                    f"floating-point (additive) tensor, got {mask.dtype}"
                )

        attn_weights = F.softmax(scores, dim=-1)

        # Merge the heads back: (b, num_heads, s, d_k) -> (b, s, d_model)
        context = (attn_weights @ v).transpose(1, 2).contiguous().view(b, s, d)

        output = self.W_o(context)
        return output
    
class TransformerBlock(nn.Module):
    """Transformer block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection), and the sum is layer-normalised.

    Args:
        d_model: Size of the input's and output's last dimension.
        num_heads: Number of attention heads for the attention sub-layer.
        d_ff: Hidden size of the two-layer ReLU feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor whose last dimension is ``d_model``.
            mask: Optional attention mask, passed unchanged to the attention
                sub-layer. ``None`` (the default) keeps the previous unmasked
                behaviour.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer
        attn_output = self.mha(x, mask=mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Feed-forward sub-layer
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        return x

# Hyperparameters for the demo block.
num_heads = 8
d_ff = 64

print("\n--- Transformer Block ---")
transformer_block = TransformerBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff)

# Feed the same dummy batch 'x' through the block; the shape should be preserved.
print(f"Input to block: {x.shape}")
output_block = transformer_block(x)
print(f"Output from block: {output_block.shape}")


--- Transformer Block ---
Input to block: torch.Size([4, 10, 32])
Output from block: torch.Size([4, 10, 32])


In [10]:
# Boolean (seq_len, seq_len) mask that is True strictly above the main
# diagonal, i.e. wherever the key position comes after the query position.
causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
causal_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [None]:
# Take the attention sub-layer out of the block so it can be called on its own.
mha_module = transformer_block.mha
# Run just the attention layer on the dummy batch, passing the causal mask.
output_masked = mha_module(x, mask=causal_mask)

In [None]:
def get_masked_weights(mha, x, mask):
    """Recompute the per-head attention weights of ``mha`` for ``x``.

    Repeats the query/key half of the attention computation by hand and
    prints the intermediate scores, so the effect of ``mask`` can be inspected.

    Args:
        mha: Object exposing ``W_qkv`` (callable projection producing the
            concatenated q/k/v), ``num_heads`` and ``d_k``.
        x: Tensor of shape ``(batch, seq, d_model)``.
        mask: Boolean tensor broadcastable to ``(batch, num_heads, seq, seq)``
            whose ``True`` entries are filled with ``-1e9`` before the
            softmax, or ``None`` for no masking.

    Returns:
        Attention weights of shape ``(batch, num_heads, seq, seq)``.
    """
    batch, seq, _ = x.shape

    def split_heads(t):
        # (batch, seq, num_heads * d_k) -> (batch, num_heads, seq, d_k)
        return t.view(batch, seq, mha.num_heads, mha.d_k).transpose(1, 2)

    # Only queries and keys are needed for the weights; values are discarded.
    queries, keys, _ = mha.W_qkv(x).chunk(3, dim=-1)
    queries = split_heads(queries)
    keys = split_heads(keys)

    scale = math.sqrt(mha.d_k)
    scores = (queries @ keys.transpose(-2, -1)) / scale

    print(f"\nScores shape before masking: {scores.shape}")

    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
        print("Scores after masking (showing first head, first item in batch):")
        print(scores[0, 0])

    return F.softmax(scores, dim=-1)

print("\n--- Verifying Weights ---")
# Recompute the attention weights by hand with the causal mask applied.
masked_weights = get_masked_weights(mha_module, x, causal_mask)

print(f"\nFinal Attention Weights shape: {masked_weights.shape}")
print("Weights (first head, first item in batch):")
# Detach from the autograd graph and round to two decimals for a readable printout.
print(torch.round(masked_weights[0, 0].detach(), decimals=2))


In [22]:
def generate_square_subsequent_mask(sz):
    """Build an ``sz`` x ``sz`` float mask for causal attention.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    all_blocked = torch.full((sz, sz), float('-inf'))
    # triu(diagonal=1) keeps the strict upper triangle and zeroes the rest.
    return all_blocked.triu(diagonal=1)
# Boolean view of the 4x4 mask: the -inf entries become True, the zeros False.
generate_square_subsequent_mask(4).bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])