In [1]:
# This notebook demonstrates advanced attention mechanisms in PyTorch
import torch

In [2]:
# Create an embedding layer with 4 possible inputs and an embedding dimension of 8
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
# Access the weights of the embedding layer
inputs = inputs.weight
inputs 

Parameter containing:
tensor([[-0.7134, -0.3138,  0.5046, -0.8650,  0.4425, -1.1841, -0.7362,  1.1409],
        [ 1.0442, -0.3844,  0.0344,  0.7320, -1.7047,  0.8123,  1.1130, -1.3390],
        [ 0.9229,  0.6272,  0.8922, -1.2379, -1.0845, -1.8604, -0.1238,  0.6430],
        [ 0.1835, -1.5439, -0.5757, -0.6085,  0.9697,  1.0065,  0.5641,  0.1937]],
       requires_grad=True)

In [4]:
# Convert the weights to a tensor
inputs = inputs.data
inputs

tensor([[-0.7134, -0.3138,  0.5046, -0.8650,  0.4425, -1.1841, -0.7362,  1.1409],
        [ 1.0442, -0.3844,  0.0344,  0.7320, -1.7047,  0.8123,  1.1130, -1.3390],
        [ 0.9229,  0.6272,  0.8922, -1.2379, -1.0845, -1.8604, -0.1238,  0.6430],
        [ 0.1835, -1.5439, -0.5757, -0.6085,  0.9697,  1.0065,  0.5641,  0.1937]])

In [5]:
# Set dimensions
d_in = 8
d_out = 6
# create weight matrices
W_q = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
W_k = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
W_v = torch.nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )

In [6]:
# Choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-1.8183,  2.5620, -2.9488,  4.0440, -0.8484, -0.2352],
       grad_fn=<SqueezeBackward4>)

In [7]:
# calculate attention scores using the keys generated by W_k
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:" , keys)
print("Values:" , values)

Keys: tensor([[-3.7374,  1.8343, -1.3490, -0.3107,  1.2861,  0.6717],
        [ 7.2285, -1.7452,  1.8141, -2.9936, -3.2840, -0.8258],
        [ 2.1841,  4.8209, -2.9804, -3.8569, -1.9949,  4.4026],
        [-2.2238, -0.4988,  0.9372, -1.4520,  0.2754, -5.0220]],
       grad_fn=<MmBackward0>)
Values: tensor([[-3.6430, -1.5747, -3.4956, -0.2345, -4.3542, -2.8565],
        [ 3.3199,  1.3926,  7.3296, -0.6822,  3.6932,  3.9693],
        [-2.7182, -1.0483,  0.5692, -1.3523, -5.8348, -1.6021],
        [-0.8293, -0.2994, -1.2222,  1.6421,  2.3553, -1.9783]],
       grad_fn=<MmBackward0>)


In [8]:
attention_scores = query @ keys.T # query is 1 by 6 and keys is 4 by 6 so we need to transpose keys
attention_scores

tensor([ 12.9675, -32.0893,   2.2283,  -4.9221], grad_fn=<SqueezeBackward4>)

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim=-1 ) # the softmax function normalizes the scores
attention_weights

tensor([9.8703e-01, 1.0133e-08, 1.2310e-02, 6.6450e-04],
       grad_fn=<SoftmaxBackward0>)

In [10]:
attention_weights.sum() # ensure the weights sum to 1

tensor(1., grad_fn=<SumBackward0>)

In [11]:
# calculate the context vector as a weighted sum of the values
context_vector = attention_weights @ values 
context_vector

tensor([-3.6298, -1.5674, -3.4440, -0.2470, -4.3679, -2.8404],
       grad_fn=<SqueezeBackward4>)

In [12]:

import torch.nn as nn

In [13]:
# Define a simple attention module
class SimpleAttention( nn.Module ):
    def __init__(self, d_in, d_out):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
        self.W_k = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )
        self.W_v = nn.Parameter( torch.randn( (d_in, d_out), requires_grad=False ) )

    # x = embedding vectors (inputs)
    def forward(self, x):
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        scores = queries @ keys.T
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        context = weights @ values
        return context; 

In [14]:
# use case
# instance of the class
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
# 
simple.W_k

Parameter containing:
tensor([[-5.1242e-01,  2.0099e+00,  3.1250e-01, -1.9914e-01, -2.2349e+00,
         -4.9228e-01],
        [-3.2044e+00, -8.1958e-01,  9.9606e-01,  4.5301e-01,  4.3805e-02,
         -4.2438e-01],
        [ 3.8090e-01,  1.0083e+00, -1.1598e+00, -2.2863e-01, -5.0853e-01,
         -1.1355e+00],
        [ 8.0831e-01,  1.2780e+00,  1.1620e+00,  1.1173e-01,  4.7857e-01,
          4.2691e-01],
        [ 1.6157e-01,  1.1227e+00, -4.5885e-01, -3.5608e-01, -1.5787e+00,
          4.1086e-01],
        [-6.2358e-01, -3.4047e-01, -1.0906e+00, -2.4596e-03,  1.6934e-01,
          4.8300e-01],
        [-1.3639e+00,  1.8232e+00,  1.4452e+00,  1.0897e-01, -1.4025e+00,
         -1.1445e+00],
        [ 2.2419e+00,  6.5118e-02, -1.4556e+00, -7.6229e-01,  5.4553e-01,
         -5.8214e-01]], requires_grad=True)

In [16]:
# Create the context vectors by passing the input embeddings to the attention module
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.6029,  2.7619, -1.6765, -1.7295, -1.0699,  1.2854],
        [-1.6005, -0.0114,  1.0342,  1.8843, -3.2254,  3.8953],
        [ 0.6028,  2.7617, -1.6766, -1.7293, -1.0703,  1.2854],
        [-1.5524, -0.4926,  2.6534,  0.8774,  2.8478, -0.5394]],
       grad_fn=<MmBackward0>)

In [17]:
# second version of the class
# it uses nn.Linear to do things more effectively

class SimpleAttentionv2( nn.Module ):
    def __init__(self, d_in, d_out):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )


    # x = embedding vectors (inputs)
    def forward( self, x ):
        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        scores = queries @ keys.T
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        context = weights @ values
        return context; 

In [18]:
# use case
# instance of the class
simple = SimpleAttentionv2( d_in = 8, d_out = 6 )

In [19]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.0360,  0.3353, -0.4085, -0.0902, -0.0380,  0.1301],
        [-0.0656,  0.3166, -0.4281, -0.0242, -0.0819,  0.1063],
        [-0.2505,  0.4939, -0.3197,  0.0262, -0.0536,  0.1213],
        [-0.0275,  0.3167, -0.4300, -0.0818, -0.0508,  0.1211]],
       grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses isnformation from all of thje embedding vectors
# in practice, we should only use information from the previous vectors
# to accomplish this, we'll implement causal attention AKA masked attention
weights = torch.randn( inputs.shape[0], inputs.shape[0] )

In [21]:
weights

tensor([[-0.6461, -0.6025,  0.1016,  0.6006],
        [ 0.2894,  0.8791,  0.1680,  1.7750],
        [-2.7479, -0.3686, -0.1043,  1.2650],
        [ 0.3311, -1.2953, -1.6291,  0.2144]])

In [22]:

weights.sum( dim=-1 )

tensor([-0.5464,  3.1116, -1.9557, -2.3789])

In [23]:
# torch.tril?
# Create a lower triangular mask to apply causal attention
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [24]:
# Apply the mask to the weights
masked_weights = weights*simple_mask
masked_weights

tensor([[-0.6461, -0.0000,  0.0000,  0.0000],
        [ 0.2894,  0.8791,  0.0000,  0.0000],
        [-2.7479, -0.3686, -0.1043,  0.0000],
        [ 0.3311, -1.2953, -1.6291,  0.2144]])

In [25]:
#
masked_weights.sum( dim=-1 )

tensor([-0.6461,  1.1686, -3.2207, -2.3789])

In [26]:
# now, we need to normalize the masked weights so that they sum to 1
row_sums = masked_weights.sum( dim=-1, keepdim=True )
row_sums

tensor([[-0.6461],
        [ 1.1686],
        [-3.2207],
        [-2.3789]])

In [27]:
# Normalize the masked weights by the sum of each row
masked_weights = masked_weights / row_sums
masked_weights

tensor([[ 1.0000,  0.0000, -0.0000, -0.0000],
        [ 0.2477,  0.7523,  0.0000,  0.0000],
        [ 0.8532,  0.1144,  0.0324, -0.0000],
        [-0.1392,  0.5445,  0.6848, -0.0901]])

In [28]:
# masking method #2
# torch.triu?
# use the upper triangular part of the matrix to create a mask
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal=1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [29]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [30]:
weights

tensor([[-0.6461, -0.6025,  0.1016,  0.6006],
        [ 0.2894,  0.8791,  0.1680,  1.7750],
        [-2.7479, -0.3686, -0.1043,  1.2650],
        [ 0.3311, -1.2953, -1.6291,  0.2144]])

In [31]:
# Apply the mask to the weights by setting the masked positions to -infinity
weights = weights.masked_fill( mask.bool(), -torch.inf)
weights

tensor([[-0.6461,    -inf,    -inf,    -inf],
        [ 0.2894,  0.8791,    -inf,    -inf],
        [-2.7479, -0.3686, -0.1043,    -inf],
        [ 0.3311, -1.2953, -1.6291,  0.2144]])

In [32]:
# Apply the softmax function to the masked weights
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3567, 0.6433, 0.0000, 0.0000],
        [0.0387, 0.4175, 0.5438, 0.0000],
        [0.4490, 0.0883, 0.0632, 0.3995]])

In [33]:
masked_weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000])

In [34]:
# Dropout 
# Dropout is a regularization technique used to prevent overfitting in neural networks.
# It works by randomly setting a fraction of input units to zero at each update during training time,
# which helps to break up happenstance correlations in the training data.
dropout = nn.Dropout( 0.5 ) # 50% dropout rate


In [35]:
dropout( inputs )

tensor([[-1.4267, -0.6277,  1.0092, -1.7300,  0.8851, -2.3683, -1.4723,  2.2819],
        [ 2.0884, -0.0000,  0.0687,  0.0000, -0.0000,  1.6247,  0.0000, -2.6780],
        [ 0.0000,  1.2545,  0.0000, -0.0000, -2.1691, -3.7208, -0.2475,  0.0000],
        [ 0.3669, -3.0878, -0.0000, -1.2170,  1.9395,  2.0130,  0.0000,  0.0000]])

In [36]:
# We need to be able to give our LLM vbatches of input.
# For example:
batches = torch.stack((inputs, inputs), dim = 0 )


In [37]:
# torch.stack?
batches.shape

torch.Size([2, 4, 8])

In [38]:
# this class needs to hande batches of input


class CausalAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        #create weight matrices
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )
        # include dropout:
        self.dropout = nn.Dropout( dropout )
        # use the following to manage memory effeciently
        self.register_buffer("mask", torch.triu( torch.ones(context_length, context_length), diagonal = 1 ))
        


    # x = embedding vectors (inputs)
    def forward( self, x ):
        b, num_tokens, d_in = x.shape
        # b = batch size 

        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        scores = queries @ keys.transpose(1, 2)
        scores.masked_fill_( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        weights = self.dropout( weights )
        context = weights @ values
        return context

In [39]:
# instiantiate a causal attention mechanism:
causal = CausalAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0 )

In [40]:
# pass the batches of input to the causal attention mechanism
causal( batches )

tensor([[[-0.5626,  0.4381,  0.5288,  0.2468,  0.3463,  0.7203],
         [ 0.2169,  0.0861,  0.2949, -0.1872, -0.4878, -0.2642],
         [-0.1225,  0.3905,  0.5408, -0.0017, -0.0463,  0.4485],
         [-0.3033,  0.2605,  0.3082, -0.3160, -0.1470,  0.3007]],

        [[-0.5626,  0.4381,  0.5288,  0.2468,  0.3463,  0.7203],
         [ 0.2169,  0.0861,  0.2949, -0.1872, -0.4878, -0.2642],
         [-0.1225,  0.3905,  0.5408, -0.0017, -0.0463,  0.4485],
         [-0.3033,  0.2605,  0.3082, -0.3160, -0.1470,  0.3007]]],
       grad_fn=<UnsafeViewBackward0>)

In [41]:
# Define linear layers for query, key, and value projections
W_q = nn.Linear( d_in, d_out, bias=False )
W_k = nn.Linear( d_in, d_out, bias=False )
W_v = nn.Linear( d_in, d_out, bias=False )

In [42]:
queries = W_q( batches )
queries 

tensor([[[-0.2314, -0.1071,  0.1299,  0.5449,  0.7863,  0.5378],
         [ 0.1127,  0.4679, -0.7802, -0.8144, -1.0449, -0.4840],
         [-0.4056, -0.3382, -0.6517, -0.4530, -0.3455,  0.8290],
         [ 0.5569,  0.6049,  0.6544,  0.2646,  0.1090, -0.3899]],

        [[-0.2314, -0.1071,  0.1299,  0.5449,  0.7863,  0.5378],
         [ 0.1127,  0.4679, -0.7802, -0.8144, -1.0449, -0.4840],
         [-0.4056, -0.3382, -0.6517, -0.4530, -0.3455,  0.8290],
         [ 0.5569,  0.6049,  0.6544,  0.2646,  0.1090, -0.3899]]],
       grad_fn=<UnsafeViewBackward0>)

In [43]:
keys = W_k( batches )
keys 

tensor([[[-0.1995,  0.4013,  0.4363, -0.4461,  0.0416,  0.3854],
         [-0.1706, -0.4734, -0.2557,  0.4549, -0.1253, -0.3865],
         [-1.3976,  0.1492,  0.1313, -0.9270, -0.6716, -0.5799],
         [ 0.9769, -0.6248,  0.2957,  0.6422,  0.5827,  0.3082]],

        [[-0.1995,  0.4013,  0.4363, -0.4461,  0.0416,  0.3854],
         [-0.1706, -0.4734, -0.2557,  0.4549, -0.1253, -0.3865],
         [-1.3976,  0.1492,  0.1313, -0.9270, -0.6716, -0.5799],
         [ 0.9769, -0.6248,  0.2957,  0.6422,  0.5827,  0.3082]]],
       grad_fn=<UnsafeViewBackward0>)

In [44]:
# shows the transpose of keys

# keys.T

In [45]:
# here's a first pass at multi-head attention (not very efficient yet)
class MultiHeadAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )
    
    def forward( self, x ):
        return torch.cat( [ head(x) for head in self.heads ], dim=-1 )


In [46]:
# instantiate a multi-head attention mechanism:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0, num_heads = 3 )

In [47]:
# pass the batches of input to the multi-head attention mechanism
mha_out = mha( batches )

In [48]:
mha_out

tensor([[[-0.4055, -0.4930,  0.4081, -0.3899,  1.0181,  0.0141, -0.3746,
          -0.1924, -0.1201, -0.5103,  0.2540, -0.1809,  0.6103,  0.3714,
          -0.3196,  0.0750, -0.3010,  0.1979],
         [-0.2147, -0.3419,  0.3811, -0.2669,  0.5802,  0.1264, -0.0176,
          -0.2792,  0.0234, -0.1230,  0.1672, -0.0025, -0.0643, -0.2203,
          -0.0196,  0.2090, -0.0799,  0.2043],
         [-0.1336, -0.1653,  0.5755, -0.0134, -0.1535,  0.5064, -0.2484,
          -0.4976, -0.0326, -0.0932,  0.2035, -0.2161, -0.0228, -0.1837,
          -0.1181,  0.3191, -0.1306,  0.3001],
         [-0.0582, -0.1563,  0.4592, -0.2106,  0.3428,  0.2693, -0.0687,
          -0.3354,  0.2976, -0.1127,  0.1458,  0.0727, -0.0300, -0.2310,
           0.1396,  0.2900, -0.1200,  0.2233]],

        [[-0.4055, -0.4930,  0.4081, -0.3899,  1.0181,  0.0141, -0.3746,
          -0.1924, -0.1201, -0.5103,  0.2540, -0.1809,  0.6103,  0.3714,
          -0.3196,  0.0750, -0.3010,  0.1979],
         [-0.2147, -0.3419,  0.38

In [49]:
mha_out.shape

torch.Size([2, 4, 18])

In [50]:
# more efficient version of multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        
        super().__init__() # Call the parent class's constructor
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        # Create a causal mask to prevent attention to future tokens
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


In [51]:
batches.shape

torch.Size([2, 4, 8])

In [52]:
batches

tensor([[[-0.7134, -0.3138,  0.5046, -0.8650,  0.4425, -1.1841, -0.7362,
           1.1409],
         [ 1.0442, -0.3844,  0.0344,  0.7320, -1.7047,  0.8123,  1.1130,
          -1.3390],
         [ 0.9229,  0.6272,  0.8922, -1.2379, -1.0845, -1.8604, -0.1238,
           0.6430],
         [ 0.1835, -1.5439, -0.5757, -0.6085,  0.9697,  1.0065,  0.5641,
           0.1937]],

        [[-0.7134, -0.3138,  0.5046, -0.8650,  0.4425, -1.1841, -0.7362,
           1.1409],
         [ 1.0442, -0.3844,  0.0344,  0.7320, -1.7047,  0.8123,  1.1130,
          -1.3390],
         [ 0.9229,  0.6272,  0.8922, -1.2379, -1.0845, -1.8604, -0.1238,
           0.6430],
         [ 0.1835, -1.5439, -0.5757, -0.6085,  0.9697,  1.0065,  0.5641,
           0.1937]]])

In [53]:
batches.view( 2, 4, 2, 4)

tensor([[[[-0.7134, -0.3138,  0.5046, -0.8650],
          [ 0.4425, -1.1841, -0.7362,  1.1409]],

         [[ 1.0442, -0.3844,  0.0344,  0.7320],
          [-1.7047,  0.8123,  1.1130, -1.3390]],

         [[ 0.9229,  0.6272,  0.8922, -1.2379],
          [-1.0845, -1.8604, -0.1238,  0.6430]],

         [[ 0.1835, -1.5439, -0.5757, -0.6085],
          [ 0.9697,  1.0065,  0.5641,  0.1937]]],


        [[[-0.7134, -0.3138,  0.5046, -0.8650],
          [ 0.4425, -1.1841, -0.7362,  1.1409]],

         [[ 1.0442, -0.3844,  0.0344,  0.7320],
          [-1.7047,  0.8123,  1.1130, -1.3390]],

         [[ 0.9229,  0.6272,  0.8922, -1.2379],
          [-1.0845, -1.8604, -0.1238,  0.6430]],

         [[ 0.1835, -1.5439, -0.5757, -0.6085],
          [ 0.9697,  1.0065,  0.5641,  0.1937]]]])

In [54]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0, num_heads = 3 )

In [55]:
mha_out = mha( batches )

In [56]:
mha_out

tensor([[[ 0.4650,  0.6056, -0.5900,  0.4056, -0.4189,  0.0184],
         [ 0.0365, -0.0218, -0.3391, -0.1822,  0.1810, -0.0429],
         [ 0.0910,  0.3097, -0.3078,  0.0245, -0.1352, -0.1352],
         [ 0.0068,  0.0135, -0.2639, -0.2703, -0.0235, -0.1038]],

        [[ 0.4650,  0.6056, -0.5900,  0.4056, -0.4189,  0.0184],
         [ 0.0365, -0.0218, -0.3391, -0.1822,  0.1810, -0.0429],
         [ 0.0910,  0.3097, -0.3078,  0.0245, -0.1352, -0.1352],
         [ 0.0068,  0.0135, -0.2639, -0.2703, -0.0235, -0.1038]]],
       grad_fn=<ViewBackward0>)

In [57]:
mha_out.shape

torch.Size([2, 4, 6])