In [16]:
import torch
import torch.nn as nn

In [17]:
# Toy 3-dimensional embeddings for the six tokens of "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: Your
        [0.55, 0.87, 0.66],  # x^2: journey
        [0.57, 0.85, 0.64],  # x^3: starts
        [0.22, 0.58, 0.33],  # x^4: with
        [0.77, 0.25, 0.10],  # x^5: one
        [0.05, 0.80, 0.55],  # x^6: step
    ]
)

# Two identical copies stacked along a new leading axis -> shape (2, 6, 3).
batch = torch.stack([inputs] * 2)

In [18]:
# Show the stacked input batch followed by its shape, expected (2, 6, 3).
print("Batch:\n\n", batch, sep=" ", end="\n\n")
print("Batch Shape =>", batch.shape, sep=" ", end="\n")

Batch:

 tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

Batch Shape => torch.Size([2, 6, 3])


In [19]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Each position may attend only to itself and earlier positions: entries
    strictly above the diagonal of the score matrix are set to ``-inf`` before
    the softmax, so they receive zero attention weight.

    Args:
        d_in: Feature size of each input token (last dimension of ``x``).
        d_out: Size of the query/key/value projections and of the output.
        context_length: Maximum sequence length the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()

        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark the "future" positions to hide.
        # Registered as a buffer so it moves with .to(device) and is saved in
        # the state_dict, but is not a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)`` -- typically
                ``(batch, num_tokens, d_in)``, but an unbatched
                ``(num_tokens, d_in)`` input also works.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                causal mask was built for.
        """
        num_tokens = x.shape[-2]
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            # Without this check the sliced mask is too small and masked_fill
            # fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {context_length}"
            )

        queries = self.Wq(x)
        keys = self.Wk(x)
        values = self.Wv(x)

        # Transpose the last two dims so any number of leading batch dims works.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Shorter sequences use the top-left corner of the precomputed mask.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
        # Scale by sqrt(d_out) to keep softmax inputs in a reasonable range;
        # -inf stays -inf, so masked positions still get zero weight.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values

In [20]:
# Embedding width of each token, read from the toy input matrix.
d_in = inputs.size(1)
# Width of the projected query/key/value (and output) vectors.
d_out = 2
# Number of tokens per sequence in the batch; sizes the causal mask.
context_length = batch.size(1)

In [21]:
# Fix the global RNG so the randomly initialised Linear weights, and hence the
# context vectors printed below, are reproducible.
torch.manual_seed(666)

# Dropout probability 0.0 makes the dropout layer a no-op for this demo.
causal_attn = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = causal_attn(batch)

In [22]:
# Show the attention output followed by its shape, expected (2, 6, 2).
print("Context Vectors:\n\n", context_vecs, sep=" ", end="\n\n")
print("Context Vectors Shape =>", context_vecs.shape, sep=" ", end="\n")

Context Vectors:

 tensor([[[0.1888, 0.2392],
         [0.2342, 0.3205],
         [0.2489, 0.3449],
         [0.2147, 0.3002],
         [0.2409, 0.3300],
         [0.2113, 0.2948]],

        [[0.1888, 0.2392],
         [0.2342, 0.3205],
         [0.2489, 0.3449],
         [0.2147, 0.3002],
         [0.2409, 0.3300],
         [0.2113, 0.2948]]], grad_fn=<UnsafeViewBackward0>)

Context Vectors Shape => torch.Size([2, 6, 2])
