In [1]:
# This notebook demonstrates advanced attention mechanisms in PyTorch
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs 

Parameter containing:
tensor([[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01, -6.4746e-01,
          9.2566e-01,  1.6874e+00,  1.9702e+00],
        [ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01, -1.4050e+00,
         -3.5821e-01, -4.6841e-01,  3.9729e-01],
        [-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02, -1.1910e-01,
         -5.4830e-01, -8.6023e-01, -3.4981e-01],
        [ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00,  1.6816e+00,
          1.2896e+00, -2.1932e-01,  1.0227e+00]], requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01, -6.4746e-01,
          9.2566e-01,  1.6874e+00,  1.9702e+00],
        [ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01, -1.4050e+00,
         -3.5821e-01, -4.6841e-01,  3.9729e-01],
        [-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02, -1.1910e-01,
         -5.4830e-01, -8.6023e-01, -3.4981e-01],
        [ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00,  1.6816e+00,
          1.2896e+00, -2.1932e-01,  1.0227e+00]])

In [5]:
# Set dimensions
d_in = 8
d_out = 6
# create weight matrices
W_q = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
W_k = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
W_v = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )

In [6]:
# Choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([ 1.2601,  0.9193, -0.2638, -1.5932,  1.5702, -1.0805],
       grad_fn=<SqueezeBackward4>)

In [7]:
# calculate attention scores using the keys generated by W_k
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:" , keys)
print("Values:" , values)

Keys: tensor([[ 2.5200, -1.7649, -3.4682, -1.3480, -1.7786, -1.6056],
        [ 2.0886, -0.3383,  0.8250,  1.9821,  0.5763, -3.3585],
        [ 0.1160, -0.5166,  3.1862, -0.1417,  2.0463, -0.6900],
        [-8.3785,  1.5520, -3.3874,  1.6451, -0.6418,  1.7743]],
       grad_fn=<MmBackward0>)
Values: tensor([[ 1.7218,  2.5810, -0.5992, -0.2433,  3.0077,  5.4760],
        [ 4.1684,  1.2448, -0.6707, -1.1014, -2.2304, -1.5972],
        [-1.9564, -0.1963, -1.0316, -0.7024, -2.3706,  0.8555],
        [-6.2996,  0.2400, -0.2250, -0.8479,  0.7917,  0.8526]],
       grad_fn=<MmBackward0>)


In [8]:
attention_scores = query @ keys.T # query is 1 by 6 and keys is 4 by 6 so we need to transpose keys
attention_scores

tensor([  3.5575,   3.4793,   3.0153, -13.7833], grad_fn=<SqueezeBackward4>)

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim=-1 ) # the softmax function normalizes the scores
attention_weights

tensor([3.6091e-01, 3.4955e-01, 2.8924e-01, 3.0400e-04],
       grad_fn=<SoftmaxBackward0>)

In [10]:
attention_weights.sum() # ensure the weights sum to 1

tensor(1.0000, grad_fn=<SumBackward0>)

In [11]:
context_vector = attention_weights @ values # 
context_vector

tensor([ 1.5107,  1.3099, -0.7492, -0.6762, -0.3796,  1.6657],
       grad_fn=<SqueezeBackward4>)

In [12]:
import torch.nn as nn

In [13]:
class SimpleAttention( nn.Module ):
    def __init__(self, d_in, d_out):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
        self.W_k = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
        self.W_v = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )

    # x = embedding vectors (inputs)
    def forward(self, x):
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        scores = queries @ keys.T
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        context = weights @ values
        return context; 

In [14]:
# use case
# instance of the class
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_k

Parameter containing:
tensor([[-0.2213, -0.9298, -0.5280,  0.4143, -0.5717, -0.5802],
        [-1.3247,  1.7368, -0.3850, -0.2259,  1.8350,  2.0955],
        [ 0.8309,  1.1852,  0.4297, -1.5593, -0.3377, -0.8801],
        [-1.3090,  0.1422, -1.1954,  0.7969, -0.0797, -0.6121],
        [ 0.4527, -0.6632,  1.1933, -1.9665, -0.2963, -0.8102],
        [ 0.0694,  0.7240,  0.0411,  0.2219, -0.7028,  1.3820],
        [ 0.8008,  0.2697, -0.0496, -0.4712,  0.5931, -1.1528],
        [ 1.7258,  0.0110,  2.1158,  1.8705,  2.3299, -0.6982]],
       requires_grad=True)

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 6.9613e-01,  2.2957e+00, -5.8743e-01,  6.9857e-01, -2.4537e+00,
          1.6774e+00],
        [-4.0592e+00,  2.5668e+00,  7.7100e-01, -4.3102e+00,  7.9865e+00,
         -7.3634e-01],
        [-3.7524e+00,  2.4404e+00,  8.6552e-01, -3.9805e+00,  7.3841e+00,
         -6.1145e-01],
        [ 5.1731e-01, -1.3085e+00, -4.3256e-03,  2.0008e+00, -4.6349e+00,
         -9.1959e-01]], grad_fn=<MmBackward0>)

In [17]:
# second version of the class
# it uses nn.Linear to do things more effectively

class SimpleAttentionv2( nn.Module ):
    def __init__(self, d_in, d_out):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )


    # x = embedding vectors (inputs)
    def forward( self, x ):
        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        scores = queries @ keys.T
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        context = weights @ values
        return context; 

In [18]:
# use case
# instance of the class
simple = SimpleAttentionv2( d_in = 8, d_out = 6 )

In [19]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.1740, -0.3208,  0.0265,  0.0383,  0.0932, -0.1099],
        [-0.1526, -0.3690,  0.0630,  0.1315,  0.2067,  0.0984],
        [-0.3225, -0.3013,  0.0294, -0.0284,  0.0787, -0.1918],
        [-0.0835, -0.3233,  0.0186,  0.0636,  0.0451, -0.1660]],
       grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses isnformation from all of thje embedding vectors
# in practice, we should only use information from the previous vectors
# to accomplish this, we'll implement causal attention AKA masked attention
weights = torch.randn( inputs.shape[0], inputs.shape[0] )

In [21]:
weights

tensor([[-0.0559, -0.6662, -1.0584,  1.2920],
        [-0.2979,  0.4993,  0.7955, -1.0255],
        [ 0.2182, -0.3816,  1.8443, -0.4501],
        [ 0.8371,  0.8238, -0.4044, -2.0110]])

In [22]:
weights.sum( dim=-1 )

tensor([-0.4886, -0.0286,  1.2308, -0.7544])

In [23]:
# torch.tril?
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [24]:
masked_weights = weights*simple_mask
masked_weights

tensor([[-0.0559, -0.0000, -0.0000,  0.0000],
        [-0.2979,  0.4993,  0.0000, -0.0000],
        [ 0.2182, -0.3816,  1.8443, -0.0000],
        [ 0.8371,  0.8238, -0.4044, -2.0110]])

In [25]:
masked_weights.sum( dim=-1 )

tensor([-0.0559,  0.2014,  1.6809, -0.7544])

In [26]:
# now, we need to normalize the masked weights so that they sum to 1
row_sums = masked_weights.sum( dim=-1, keepdim=True )
row_sums

tensor([[-0.0559],
        [ 0.2014],
        [ 1.6809],
        [-0.7544]])

In [27]:
masked_weights = masked_weights / row_sums
masked_weights

tensor([[ 1.0000,  0.0000,  0.0000, -0.0000],
        [-1.4788,  2.4788,  0.0000, -0.0000],
        [ 0.1298, -0.2270,  1.0972, -0.0000],
        [-1.1096, -1.0919,  0.5361,  2.6655]])

In [28]:
# masking mehod #2
# torch.triu?
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal=1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [29]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [30]:
weights

tensor([[-0.0559, -0.6662, -1.0584,  1.2920],
        [-0.2979,  0.4993,  0.7955, -1.0255],
        [ 0.2182, -0.3816,  1.8443, -0.4501],
        [ 0.8371,  0.8238, -0.4044, -2.0110]])

In [31]:
weights = weights.masked_fill( mask.bool(), -torch.inf)
weights

tensor([[-0.0559,    -inf,    -inf,    -inf],
        [-0.2979,  0.4993,    -inf,    -inf],
        [ 0.2182, -0.3816,  1.8443,    -inf],
        [ 0.8371,  0.8238, -0.4044, -2.0110]])

In [32]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3106, 0.6894, 0.0000, 0.0000],
        [0.1508, 0.0828, 0.7665, 0.0000],
        [0.4285, 0.4228, 0.1238, 0.0248]])

In [33]:
masked_weights.sum( dim=-1 )

tensor([1., 1., 1., 1.])

In [34]:
# Dropout 
# Dropout is a regularization technique used to prevent overfitting in neural networks.
# It works by randomly setting a fraction of input units to zero at each update during training time,
# which helps to break up happenstance correlations in the training data.
dropout = nn.Dropout( 0.5 ) # 50% dropout rate


In [35]:
dropout( inputs )

tensor([[ 9.5207e-01, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  3.3747e+00,  3.9404e+00],
        [ 1.8501e-01,  3.8569e-03, -0.0000e+00,  2.0119e-01, -0.0000e+00,
         -7.1642e-01, -9.3683e-01,  0.0000e+00],
        [-0.0000e+00, -2.5348e+00, -0.0000e+00, -5.2073e-02, -2.3819e-01,
         -1.0966e+00, -1.7205e+00, -6.9962e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          2.5791e+00, -0.0000e+00,  0.0000e+00]])

In [36]:
# We need to be able to give our LLM vbatches of input.
# For example:
batches = torch.stack((inputs, inputs), dim = 0 )


In [37]:
# torch.stack?
batches.shape

torch.Size([2, 4, 8])

In [38]:
# this class needs to hande batches of input


class CausalAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )
        # include dropout:
        self.dropout = nn.Dropout( dropout )
        # use the following to manage memory effeciently
        self.register_buffer("mask", torch.triu( torch.ones(context_length, context_length), diagonal = 1 ))
        


    # x = embedding vectors (inputs)
    def forward( self, x ):
        b, num_tokens, d_in = x.shape
        # b = batch size 

        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        scores = queries @ keys.transpose(1, 2)
        scores.masked_fill_( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        weights = self.dropout( weights )
        context = weights @ values
        return context

In [39]:
# instiantiate a causal attention mechanism:
causal = CausalAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0 )

In [40]:
causal( batches )

tensor([[[ 0.3318, -0.2569, -0.5377,  0.4953, -0.5009,  0.4327],
         [ 0.0201,  0.0702,  0.0479,  0.5527, -0.2927,  0.3915],
         [ 0.0886, -0.0753, -0.0943,  0.3738, -0.0094,  0.2412],
         [ 0.1576,  0.0944, -0.0472,  0.1699, -0.2471,  0.0433]],

        [[ 0.3318, -0.2569, -0.5377,  0.4953, -0.5009,  0.4327],
         [ 0.0201,  0.0702,  0.0479,  0.5527, -0.2927,  0.3915],
         [ 0.0886, -0.0753, -0.0943,  0.3738, -0.0094,  0.2412],
         [ 0.1576,  0.0944, -0.0472,  0.1699, -0.2471,  0.0433]]],
       grad_fn=<UnsafeViewBackward0>)

In [41]:
W_q = nn.Linear( d_in, d_out, bias=False )
W_k = nn.Linear( d_in, d_out, bias=False )
W_v = nn.Linear( d_in, d_out, bias=False )

In [42]:
queries = W_q( batches )
queries 

tensor([[[ 6.0770e-01,  4.3816e-02,  1.7545e-01, -1.1417e-01, -8.3971e-01,
           3.6785e-01],
         [ 1.4849e-03,  3.5108e-01,  4.7065e-01, -9.5396e-02, -2.1979e-01,
          -9.3486e-01],
         [-1.8468e-01, -9.3662e-02, -3.0378e-01, -4.8023e-01,  4.2630e-02,
          -3.1774e-01],
         [-2.5849e-01,  3.3737e-02, -5.7544e-01,  3.2325e-01,  7.3650e-01,
           1.6680e+00]],

        [[ 6.0770e-01,  4.3816e-02,  1.7545e-01, -1.1417e-01, -8.3971e-01,
           3.6785e-01],
         [ 1.4849e-03,  3.5108e-01,  4.7065e-01, -9.5396e-02, -2.1979e-01,
          -9.3486e-01],
         [-1.8468e-01, -9.3662e-02, -3.0378e-01, -4.8023e-01,  4.2630e-02,
          -3.1774e-01],
         [-2.5849e-01,  3.3737e-02, -5.7544e-01,  3.2325e-01,  7.3650e-01,
           1.6680e+00]]], grad_fn=<UnsafeViewBackward0>)

In [43]:
keys = W_k( batches )
keys 

tensor([[[ 0.2130,  0.7665,  0.2220,  0.0817,  0.2116, -0.5120],
         [-0.0357, -0.8210,  0.7736,  0.2082,  0.4418, -0.0189],
         [-0.3699, -0.3950,  0.0992,  0.0451, -0.0282,  0.2048],
         [ 1.0585,  0.8586, -0.8569, -0.0166, -0.6833, -0.5157]],

        [[ 0.2130,  0.7665,  0.2220,  0.0817,  0.2116, -0.5120],
         [-0.0357, -0.8210,  0.7736,  0.2082,  0.4418, -0.0189],
         [-0.3699, -0.3950,  0.0992,  0.0451, -0.0282,  0.2048],
         [ 1.0585,  0.8586, -0.8569, -0.0166, -0.6833, -0.5157]]],
       grad_fn=<UnsafeViewBackward0>)

In [44]:
# shows the transpose of keys

# keys.T

In [47]:
# here's a first pass at multi-head attention (not very efficient yet)
class MultiHeadAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )
    
    def forward( self, x ):
        return torch.cat( [ head(x) for head in self.heads ], dim=-1 )


In [48]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0, num_heads = 3 )

In [49]:
mha_out = mha( batches )

In [51]:
mha_out

tensor([[[-1.1756e+00, -5.9299e-01, -5.2680e-01, -1.8932e-01, -3.2679e-01,
          -1.9845e-01,  2.5568e-01, -5.9830e-01, -1.3662e-01,  1.7749e-01,
          -4.0063e-01, -2.4586e-01,  9.0462e-01,  1.5733e-01, -5.4913e-01,
          -2.9235e-01,  7.6064e-01, -8.8952e-04],
         [-7.6239e-01, -8.7538e-02, -2.2922e-01,  1.3593e-01, -3.4767e-01,
          -3.7863e-01,  6.4695e-02, -1.9520e-01, -2.1597e-01,  1.0846e-01,
          -3.8420e-01, -7.4878e-02,  2.1067e-01,  1.6160e-01, -4.2506e-01,
          -4.0291e-01,  4.3554e-01, -5.0313e-01],
         [-2.0241e-01,  6.4122e-02, -1.3287e-01,  1.1136e-01, -2.2071e-01,
          -3.8027e-01, -5.7724e-03,  1.7208e-01, -3.0616e-01, -1.2038e-01,
          -2.2491e-01, -1.9302e-02, -1.5479e-01,  1.5360e-01, -2.7218e-01,
          -2.7348e-01,  2.5013e-01, -3.5993e-01],
         [-3.5451e-02, -6.8312e-02, -1.2275e-01, -6.6263e-02,  8.4505e-03,
          -2.8931e-01,  4.0998e-02,  2.5444e-01, -2.8513e-01, -1.2209e-01,
          -5.9665e-02,  4

In [52]:
mha_out.shape

torch.Size([2, 4, 18])

In [53]:
# more efficient version of multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


In [54]:
batches.shape

torch.Size([2, 4, 8])

In [56]:
batches

tensor([[[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01, -6.4746e-01,
           9.2566e-01,  1.6874e+00,  1.9702e+00],
         [ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01, -1.4050e+00,
          -3.5821e-01, -4.6841e-01,  3.9729e-01],
         [-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02, -1.1910e-01,
          -5.4830e-01, -8.6023e-01, -3.4981e-01],
         [ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00,  1.6816e+00,
           1.2896e+00, -2.1932e-01,  1.0227e+00]],

        [[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01, -6.4746e-01,
           9.2566e-01,  1.6874e+00,  1.9702e+00],
         [ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01, -1.4050e+00,
          -3.5821e-01, -4.6841e-01,  3.9729e-01],
         [-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02, -1.1910e-01,
          -5.4830e-01, -8.6023e-01, -3.4981e-01],
         [ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00,  1.6816e+00,
           1.2896e+00, -2.1932e-01,  1.0227e+00]

In [57]:
batches.view( 2, 4, 2, 4)

tensor([[[[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01],
          [-6.4746e-01,  9.2566e-01,  1.6874e+00,  1.9702e+00]],

         [[ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01],
          [-1.4050e+00, -3.5821e-01, -4.6841e-01,  3.9729e-01]],

         [[-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02],
          [-1.1910e-01, -5.4830e-01, -8.6023e-01, -3.4981e-01]],

         [[ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00],
          [ 1.6816e+00,  1.2896e+00, -2.1932e-01,  1.0227e+00]]],


        [[[ 4.7603e-01, -4.2865e-01, -1.4078e-01, -2.5170e-01],
          [-6.4746e-01,  9.2566e-01,  1.6874e+00,  1.9702e+00]],

         [[ 9.2507e-02,  1.9285e-03, -1.4027e+00,  1.0059e-01],
          [-1.4050e+00, -3.5821e-01, -4.6841e-01,  3.9729e-01]],

         [[-2.0567e-01, -1.2674e+00, -1.9255e-01, -2.6036e-02],
          [-1.1910e-01, -5.4830e-01, -8.6023e-01, -3.4981e-01]],

         [[ 6.5876e-01,  4.6910e-01,  2.1262e+00,  1.0740e+00],
          [ 1.6816e+00, 

In [59]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0, num_heads = 3 )

In [60]:
mha_out = mha( batches )

In [62]:
mha_out

tensor([[[-0.2968, -0.6419,  0.2184,  0.0102,  0.2530,  0.1569],
         [-0.1804, -0.4610,  0.1877,  0.2508,  0.3569,  0.0738],
         [-0.0145, -0.2382,  0.1124,  0.2263,  0.2202,  0.1077],
         [-0.0110, -0.2989,  0.0722,  0.1058,  0.1845,  0.1893]],

        [[-0.2968, -0.6419,  0.2184,  0.0102,  0.2530,  0.1569],
         [-0.1804, -0.4610,  0.1877,  0.2508,  0.3569,  0.0738],
         [-0.0145, -0.2382,  0.1124,  0.2263,  0.2202,  0.1077],
         [-0.0110, -0.2989,  0.0722,  0.1058,  0.1845,  0.1893]]],
       grad_fn=<ViewBackward0>)

In [64]:
mha_out.shape

torch.Size([2, 4, 6])