### Causal attention

In [1]:
import torch
import torch.nn as nn


In [18]:
class CasualAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Every token attends only to itself and to earlier tokens of the same
    sequence: positions above the main diagonal of the score matrix are
    filled with ``-inf`` before the softmax, so they receive zero weight.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Size of the last dimension of the output tensor. It is split
            evenly across the heads, so it must be divisible by ``num_heads``.
        context_length: Side length of the pre-built causal mask, i.e. the
            largest number of tokens the mask can cover.
        num_heads: Number of attention heads.
        head_dim: Accepted for backward compatibility but ignored; the
            per-head size is always ``d_out // num_heads``.
        dropout: Probability of zeroing an attention weight while the module
            is in training mode.
        kvq_bias: Whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, head_dim, dropout, kvq_bias=False):
        super().__init__()

        assert (d_out % num_heads == 0), "d_out must be dividable to num_head."

        self.d_out = d_out
        self.num_heads = num_heads
        # The `head_dim` argument is deliberately not used: the per-head size
        # is derived so that num_heads * head_dim == d_out always holds.
        self.head_dim = self.d_out // self.num_heads
        self.W_queries = nn.Linear(d_in, d_out, bias=kvq_bias)
        self.W_keys = nn.Linear(d_in, d_out, bias=kvq_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=kvq_bias)
        # Mixes the concatenated per-head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        # Kept as a plain float; it is applied functionally in forward().
        self.dropout = dropout
        # Ones strictly above the diagonal mark the "future" positions that
        # must not be attended to. Registered as a buffer so it follows the
        # module across devices.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causal multi-head attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, d_in = x.shape

        queries = self.W_queries(x)
        keys = self.W_keys(x)
        values = self.W_values(x)

        # Split the projection into heads: (b, num_tokens, num_heads, head_dim).
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move the head axis forward: (b, num_heads, num_tokens, head_dim).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Pairwise query/key scores: (b, num_heads, num_tokens, num_tokens).
        attention_scores = queries @ keys.transpose(2, 3)
        # Crop the mask to the actual sequence length and hide future tokens.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before the softmax; -inf entries stay -inf
        # and therefore end up with zero weight.
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        # Apply the configured dropout to the attention weights. This is a
        # no-op in eval mode or when the probability is 0.
        attention_weights = nn.functional.dropout(
            attention_weights, p=self.dropout, training=self.training
        )

        # Weighted sum of values, then merge the heads back into d_out.
        context_vec = (attention_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec
        

In [19]:
# Layer hyperparameters for the demo.
d_out = 64
num_heads = 4
head_dim = 16  # equals d_out // num_heads
context_length = 6
dropout = 0.0

# Seed the global RNG so the random inputs and the layer's randomly
# initialised weights below are reproducible.
torch.manual_seed(123)

# One random sequence of 6 tokens with 3 features each, duplicated to form a
# batch of two identical sequences: shape (2, 6, 3).
inputs = torch.rand(6, 3)
inputs = torch.stack([inputs, inputs])
d_in = inputs.size(2)

ca = CasualAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    num_heads=num_heads,
    head_dim=head_dim,
    dropout=dropout,
)
ca(inputs)

tensor([[[ 0.1437,  0.2184,  0.0615,  0.2260,  0.0170,  0.1026, -0.0006,
           0.1692,  0.0425,  0.0391,  0.0152,  0.0437,  0.1384, -0.1728,
           0.0266,  0.0096,  0.1692, -0.2495, -0.1524, -0.0415, -0.0241,
           0.0205,  0.0444,  0.0144,  0.3617,  0.0575,  0.1927,  0.0241,
          -0.2630, -0.0534,  0.0893, -0.2652, -0.0309,  0.1866,  0.1256,
          -0.0518, -0.0425, -0.1486,  0.2586,  0.1293,  0.0839, -0.1025,
           0.0048, -0.1052,  0.0315,  0.2913,  0.0947, -0.0315,  0.2273,
           0.2647, -0.1179, -0.3454, -0.1044,  0.1476, -0.0801,  0.0619,
          -0.3243,  0.0524, -0.0838,  0.0319,  0.2107, -0.0425,  0.1470,
          -0.2584],
         [ 0.1269,  0.2548,  0.1296,  0.1513,  0.0902,  0.2092,  0.0165,
           0.1988,  0.1055,  0.1252,  0.0861,  0.0134,  0.1304, -0.1216,
           0.0976, -0.0189,  0.2475, -0.0715, -0.2573, -0.0576, -0.2501,
           0.2173,  0.0005,  0.0693,  0.3349,  0.1248,  0.1267, -0.0421,
          -0.1279,  0.0631,  0.