In [14]:
import torch
from torch import nn
import torch.nn.functional as F

## AttentionHead

In [28]:
class AttentionHead(nn.Module):
    def __init__(self, num_features, num_steps,head_size):
        super().__init__()
        self.query = nn.Linear(num_features, head_size, bias=False)
        self.key = nn.Linear(num_features, head_size, bias=False)
        self.value = nn.Linear(num_features, head_size, bias=False)
        self.head_size = head_size
        # tensor that aren't parameters
        self.register_buffer('tril_mask', torch.tril(torch.ones((num_steps, num_steps))))
    
    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x) # (B, T, H) - H=head_size
        k = self.key(x) # (B, T, H)
        v = self.value(x) # (B, T, H)
        # affinities inter tokens (and scale)
        wei = q @ k.transpose(-2, -1) * (self.head_size**-0.5) # (B, T, T)
        # mask future tokens
        wei = wei.masked_fill(self.tril_mask[:T, :T] == 0, float('-inf'))
        # normalize each row(each token interactions in last dim)
        wei = F.softmax(wei, dim=-1)
        y = wei @ v # apply affinities (B, T, H) weighted agreggation
        return y





X = torch.randn(4, 7, 2)
B, T, C = X.shape

att = AttentionHead(C, T,64)

att(X)





tensor([[[ 5.0066e-01, -6.3600e-01,  4.1748e-02,  ..., -6.0802e-01,
           1.7435e-01, -2.7198e-02],
         [ 6.5966e-01, -9.9477e-01,  1.0112e-01,  ..., -7.8291e-01,
          -3.0049e-02, -3.5142e-01],
         [ 3.8858e-01, -4.8320e-01,  2.9340e-02,  ..., -4.7311e-01,
           1.5257e-01, -1.5589e-04],
         ...,
         [ 2.4082e-01, -3.3719e-01,  2.9280e-02,  ..., -2.8882e-01,
           3.2035e-02, -7.6046e-02],
         [ 1.2454e-01, -1.6791e-01,  1.3239e-02,  ..., -1.5011e-01,
           2.7284e-02, -2.6306e-02],
         [ 3.1544e-01, -4.0989e-01,  2.9004e-02,  ..., -3.8202e-01,
           9.4638e-02, -3.5619e-02]],

        [[ 6.6103e-01, -9.9216e-01,  9.9953e-02,  ..., -7.8507e-01,
          -2.2384e-02, -3.4276e-01],
         [ 3.2419e-01, -4.7322e-01,  4.5087e-02,  ..., -3.8658e-01,
           1.1185e-02, -1.4118e-01],
         [ 8.7523e-01, -1.3788e+00,  1.5149e-01,  ..., -1.0319e+00,
          -1.3749e-01, -5.8486e-01],
         ...,
         [ 5.4210e-01, -8