In [4]:
import torch

# Seed PyTorch's RNGs so the random embeddings below are reproducible.
torch.manual_seed(123)

# Preferred backend order: Apple Silicon GPU (Metal) -> NVIDIA GPU -> CPU fallback.
# The conditional expression short-circuits, so CUDA is only probed when MPS is absent.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# Random toy input of shape (batch_size, context_len, embed_dim), created directly on `device`.
batch_size, context_len, embed_dim = 8, 1024, 768
embeddings = torch.randn(batch_size, context_len, embed_dim, device=device)


Using device: cuda
PyTorch version: 2.9.0+cu130


### 1. Causal Multi-Head Attention (MHA) Wrapper

In [5]:
import torch.nn as nn

class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Projects the input to queries, keys and values of width ``d_out``, hides
    "future" positions with an upper-triangular mask, applies a scaled softmax
    plus dropout to the attention weights, and returns the weighted values.

    Args:
        d_in: size of the last dimension of the input.
        d_out: width of the query/key/value projections and of the output.
        context_length: largest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the Q/K/V projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Creation order (query, key, value) determines which random numbers
        # each layer gets under a fixed seed -- keep it stable.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark positions a token may not attend to.
        # A buffer follows .to(device) and is stored in the state_dict.
        future = torch.ones(context_length, context_length)
        self.register_buffer("mask", torch.triu(future, diagonal=1))

    def forward(self, x):
        """Attend over ``x`` of shape (b, num_tokens, d_in); returns (b, num_tokens, d_out)."""
        _, seq_len, _ = x.shape
        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # (b, T, d_out) @ (b, d_out, T) -> (b, T, T) raw scores
        scores = torch.matmul(q, k.transpose(1, 2))
        is_future = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(is_future, -torch.inf)

        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)
        return torch.matmul(weights, v)


class Ch03_MHA_Wrapper(nn.Module):
    """Multi-head attention built by stacking independent ``CausalAttention`` heads.

    Each of the ``num_heads`` heads outputs ``d_out`` features. The head outputs
    are concatenated along the last dimension (width ``d_out * num_heads``) and
    mixed by a final linear projection of that same width.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Heads are built one after another, so seeded initialisation order is preserved.
        head_list = []
        for _ in range(num_heads):
            head_list.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))
        self.heads = nn.ModuleList(head_list)

        concat_width = d_out * num_heads
        self.out_proj = nn.Linear(concat_width, concat_width)

    def forward(self, x):
        """Run every head on ``x``, concatenate their features, and project."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.out_proj(merged)


# Sanity-check the wrapper: `num_heads` heads of width embed_dim // num_heads
# concatenate back to embed_dim, so the output has the same shape as the input.
num_heads = 12
mha_ch03_wrapper = Ch03_MHA_Wrapper(
    d_in=embed_dim,
    d_out=embed_dim // num_heads,
    context_length=context_len,
    dropout=0.0,
    num_heads=num_heads,
    qkv_bias=False,
).to(device)

# Dedicated name: the generic `out` is reassigned by the next section's model.
out_wrapper = mha_ch03_wrapper(embeddings)
assert out_wrapper.shape == embeddings.shape, out_wrapper.shape
print(out_wrapper.shape)

torch.Size([8, 1024, 768])


### 2. Final Course MHA Implementation

In [33]:
class Ch03_MHA(nn.Module):
    """Causal multi-head self-attention with separate Q/K/V projections.

    Each projection has width ``d_out``. Its last dimension is split into
    ``num_heads`` heads of ``head_dim = d_out // num_heads`` features. Attention
    is computed per head, and the heads are merged back to ``d_out`` before a
    final output projection.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total output width (summed over heads).
        context_length: largest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of heads; must divide ``d_out``.
        qkv_bias: whether the Q/K/V projections have a bias term.

    Raises:
        AssertionError: if ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width; heads concatenate back to d_out

        # Keep this creation order: it fixes seeded initialisation and state_dict keys.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the merged head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal multi-head attention to ``x`` of shape (b, num_tokens, d_in).

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        batch, seq_len, _ = x.shape
        head_shape = (batch, seq_len, self.num_heads, self.head_dim)

        # Project, split the last dim into heads, and move heads next to batch:
        # (b, T, d_out) -> (b, T, H, head_dim) -> (b, H, T, head_dim)
        q = self.W_query(x).view(head_shape).transpose(1, 2)
        k = self.W_key(x).view(head_shape).transpose(1, 2)
        v = self.W_value(x).view(head_shape).transpose(1, 2)

        # Per-head raw scores: (b, H, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1))
        is_future = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(is_future, -torch.inf)

        weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # (b, H, T, head_dim) -> (b, T, H, head_dim) -> (b, T, d_out)
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(merged)

torch.manual_seed(123)  # fixed seed before building the model -> reproducible initial weights

mha_ch03 = Ch03_MHA(
    d_in=embed_dim, d_out=embed_dim, context_length=context_len,
    dropout=0.0, num_heads=12, qkv_bias=False,
)
mha_ch03.to(device)  # nn.Module.to moves parameters and buffers in place

out = mha_ch03(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


In [34]:
# Show the Ch03_MHA output tensor; PyTorch abbreviates large tensors with "..." in its repr.
out

tensor([[[-0.3466,  0.1000, -0.1979,  ...,  0.2440,  0.0190,  0.4200],
         [-0.2110,  0.0193, -0.3820,  ...,  0.0796,  0.0849,  0.1457],
         [ 0.1017, -0.0742, -0.2684,  ...,  0.1443,  0.1109,  0.1662],
         ...,
         [ 0.0252, -0.0407, -0.0086,  ...,  0.0308, -0.0259,  0.0387],
         [ 0.0216, -0.0290, -0.0182,  ...,  0.0261, -0.0294,  0.0390],
         [ 0.0297, -0.0350, -0.0167,  ...,  0.0406, -0.0295,  0.0408]],

        [[ 0.2520,  0.4761, -0.4031,  ...,  0.2052,  0.3203,  0.2161],
         [ 0.1707,  0.2257,  0.0142,  ...,  0.0013,  0.2052, -0.0628],
         [ 0.0692,  0.2422, -0.1263,  ..., -0.0167, -0.0215, -0.1381],
         ...,
         [ 0.0351, -0.0340, -0.0125,  ...,  0.0210, -0.0496,  0.0555],
         [ 0.0469, -0.0355, -0.0117,  ...,  0.0182, -0.0533,  0.0542],
         [ 0.0301, -0.0378, -0.0109,  ...,  0.0235, -0.0519,  0.0561]],

        [[ 0.0508, -0.1723, -0.2799,  ...,  0.2172, -0.2496, -0.2822],
         [ 0.0254, -0.4323, -0.1117,  ..., -0

### 3. MHA with Combined QKV Weights

In [35]:
class MHAv_2_0(nn.Module):
    """Causal multi-head self-attention with one fused Q/K/V projection.

    Same computation as an MHA with separate query/key/value layers, but the
    three projections are packed into a single ``WQKV`` linear layer of width
    ``3 * d_out``. Its output is read as three consecutive blocks (queries,
    keys, values), each split into ``num_heads`` chunks of
    ``head_dim = d_out // num_heads`` features.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total output width (summed over heads).
        context_length: largest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of heads; must divide ``d_out``.
        qkv_bias: whether the fused Q/K/V projection has a bias term.

    Raises:
        AssertionError: if ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width; heads concatenate back to d_out

        # Fused projection. Output features [0, d_out) are queries, [d_out, 2*d_out)
        # are keys and [2*d_out, 3*d_out) are values (fixed by the view in forward).
        self.WQKV = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the merged head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal multi-head attention to ``x``.

        Args:
            x: tensor of shape (b, num_tokens, d_in), num_tokens <= context_length.

        Returns:
            Tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        qkv = self.WQKV(x)  # (b, num_tokens, 3 * d_out)

        # Split the fused last dim:
        # (b, num_tokens, 3 * d_out) -> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(b, num_tokens, 3, self.num_heads, self.head_dim)

        # Move the q/k/v axis first and heads before tokens, then unbind into queries, keys, values:
        # (b, num_tokens, 3, num_heads, head_dim) -> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv.unbind(dim=0)  # each: (b, num_heads, num_tokens, head_dim)

        # Raw dot-product scores per head: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Top-left num_tokens x num_tokens corner of the causal mask, as booleans
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads: (b, num_tokens, num_heads * head_dim) == (b, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)


# Same seed as the Ch03_MHA cell, so both models are initialised from the same RNG state.
torch.manual_seed(123)

mha_v_2_0 = MHAv_2_0(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out_mha_v_2_0 = mha_v_2_0(embeddings)
# Report the shape of *this* model's output (not the earlier `out` from Ch03_MHA).
print(out_mha_v_2_0.shape)

torch.Size([8, 1024, 768])


In [36]:
# In the saved outputs, this tensor's repr matches Ch03_MHA's `out` (same seed, same
# parameter-creation order). Quantify the agreement instead of eyeballing truncated
# reprs; bitwise equality is not guaranteed across backends, so report the max gap.
max_abs_diff = (out_mha_v_2_0 - out).abs().max().item()
print(f"max |out_mha_v_2_0 - out| = {max_abs_diff:.3e}")
out_mha_v_2_0

tensor([[[-0.3466,  0.1000, -0.1979,  ...,  0.2440,  0.0190,  0.4200],
         [-0.2110,  0.0193, -0.3820,  ...,  0.0796,  0.0849,  0.1457],
         [ 0.1017, -0.0742, -0.2684,  ...,  0.1443,  0.1109,  0.1662],
         ...,
         [ 0.0252, -0.0407, -0.0086,  ...,  0.0308, -0.0259,  0.0387],
         [ 0.0216, -0.0290, -0.0182,  ...,  0.0261, -0.0294,  0.0390],
         [ 0.0297, -0.0350, -0.0167,  ...,  0.0406, -0.0295,  0.0408]],

        [[ 0.2520,  0.4761, -0.4031,  ...,  0.2052,  0.3203,  0.2161],
         [ 0.1707,  0.2257,  0.0142,  ...,  0.0013,  0.2052, -0.0628],
         [ 0.0692,  0.2422, -0.1263,  ..., -0.0167, -0.0215, -0.1381],
         ...,
         [ 0.0351, -0.0340, -0.0125,  ...,  0.0210, -0.0496,  0.0555],
         [ 0.0469, -0.0355, -0.0117,  ...,  0.0182, -0.0533,  0.0542],
         [ 0.0301, -0.0378, -0.0109,  ...,  0.0235, -0.0519,  0.0561]],

        [[ 0.0508, -0.1723, -0.2799,  ...,  0.2172, -0.2496, -0.2822],
         [ 0.0254, -0.4323, -0.1117,  ..., -0