In [1]:
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

In [2]:
new_input = torch.stack((inputs, inputs), dim = 0)
new_input

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

# Single causual attention:

In [3]:
from torch import nn

In [4]:
class CausualAttention(nn.Module):
  def __init__(self, dim_in, dim_out, context_length, drop_out_rate, ena_bias = False):
    super().__init__()
    self.W_Q =  nn.Linear(dim_in, dim_out, bias = ena_bias)
    self.W_K =  nn.Linear(dim_in, dim_out, bias = ena_bias)
    self.W_V =  nn.Linear(dim_in, dim_out, bias = ena_bias)

    self.dropout_layer = nn.Dropout(drop_out_rate)
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):

    batch, num_tokens, dim = x.shape
    queries = self.W_Q(x)
    keys = self.W_K(x)
    values = self.W_V(x)

    attention_score = queries @ keys.transpose(1,2) # .transpose(1, 2) for transpose in each batch, '1' for the 2nd dim, '2' for the 3rd dim

    attention_score.masked_fill_(
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )

    attention_weight = torch.softmax(attention_score / keys.shape[-1] ** 0.5, dim = -1)

    # dropout
    attention_weight = self.dropout_layer(attention_weight)

    contexts = attention_weight @ values

    return contexts



In [5]:
ca = CausualAttention(3, 2, 6, 0.0)
cntx = ca.forward(new_input)
cntx

tensor([[[-0.0361, -0.3012],
         [-0.0312, -0.2411],
         [-0.0332, -0.2184],
         [-0.0185, -0.1739],
         [-0.0648, -0.1729],
         [-0.0290, -0.1491]],

        [[-0.0361, -0.3012],
         [-0.0312, -0.2411],
         [-0.0332, -0.2184],
         [-0.0185, -0.1739],
         [-0.0648, -0.1729],
         [-0.0290, -0.1491]]], grad_fn=<UnsafeViewBackward0>)

# Multihead attention (wrapper):

In [6]:
class MultiheadAttentionWrapper(nn.Module):
  def __init__(self, dim_in, dim_out, context_length, drop_out_rate, num_heads, ena_bias = False):
    super().__init__()
    self.heads = nn.ModuleList([CausualAttention(dim_in, dim_out, context_length, drop_out_rate, ena_bias) for i in range(num_heads)])

  def forward(self, x):
    return torch.cat([head.forward(x) for head in self.heads], dim = -1)

In [7]:
torch.manual_seed(123)
mlha = MultiheadAttentionWrapper(dim_in = 3, dim_out=2, context_length=6, drop_out_rate=0.0, num_heads=3)
cnx = mlha.forward(new_input)
cnx

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]],

        [[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]]],
       grad_fn=<CatBackward0>)

# The real multihead attention:

In [14]:
class MultiheadAttention(nn.Module):
  def __init__(self, d_in, d_out, num_heads, context_length, drop_out_rate, ena_bias = False):
    super().__init__()

    # wtf syntax?
    assert (d_out % num_heads == 0), \
      "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = self.d_out // self.num_heads

    self.W_Q = nn.Linear(d_in, d_out, bias = ena_bias)
    self.W_K = nn.Linear(d_in, d_out, bias = ena_bias)
    self.W_V = nn.Linear(d_in, d_out, bias = ena_bias)

    # projection?
    self.out_proj = nn.Linear(d_out, d_out)

    self.drop_out_layer = nn.Dropout(drop_out_rate)
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal = 1))

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    queries = self.W_Q(x)
    keys = self.W_K(x)
    values = self.W_V(x)

    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)

    queries = queries.transpose(1,2)
    keys = keys.transpose(1,2)
    values = values.transpose(1,2)

    attention_score = queries @ keys.transpose(2, 3)

    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attention_score.masked_fill_(mask_bool, -torch.inf)
    attention_score = attention_score / self.head_dim**0.5
    attention_weight = torch.softmax(attention_score, dim = -1)
    attention_weight = self.drop_out_layer(attention_weight)

    context_vectors = (attention_weight @ values).transpose(1, 2)

    context_vectors = context_vectors.contiguous().view(b, num_tokens, self.d_out)

    # combs for learning relationship of head's results
    context_vectors = self.out_proj(context_vectors)

    return context_vectors


In [16]:
torch.manual_seed(123)
mlha = MultiheadAttention(d_in = 3, d_out=3, context_length = 6, drop_out_rate=0.0, num_heads=3)
cnx = mlha.forward(new_input)
cnx

tensor([[[ 0.0766,  0.0755, -0.0321],
         [ 0.0311,  0.1048, -0.0368],
         [ 0.0165,  0.1088, -0.0409],
         [-0.0470,  0.0841, -0.0825],
         [-0.1018,  0.0327, -0.1292],
         [-0.1060,  0.0508, -0.1246]],

        [[ 0.0766,  0.0755, -0.0321],
         [ 0.0311,  0.1048, -0.0368],
         [ 0.0165,  0.1088, -0.0409],
         [-0.0470,  0.0841, -0.0825],
         [-0.1018,  0.0327, -0.1292],
         [-0.1060,  0.0508, -0.1246]]], grad_fn=<ViewBackward0>)