3.3.1  simple self-attention mechanism without weights

In [21]:
import torch
from torch import manual_seed
from win32inetcon import WINHTTP_QUERY_MAX

In [22]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [23]:
input_query = inputs[1]

In [24]:
attention_scores = torch.matmul(inputs, inputs.transpose(0, 1))
attention_weights = torch.softmax(attention_scores, dim = 1)
context_vectors = torch.matmul(attention_weights, inputs)

In [25]:
attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [26]:
context_vectors

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

3.4 Implementing Self-Attention Weights with trainable weights

In [27]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [28]:
torch.manual_seed(123)

WQ = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
WK = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
WV = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)

In [29]:
X = inputs

In [30]:
Q = X @ WQ
K = X @ WK
V = X @ WV

In [31]:
Q.requires_grad

True

In [32]:
K

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

In [33]:
attention_scores = Q @ K.transpose(0, 1)

In [34]:
attention_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]],
       grad_fn=<MmBackward0>)

In [35]:
#transformation normalization
d_k = WK.shape[-1]
attention_weights = torch.softmax(attention_scores / (d_k ** 0.5), dim = 1)

In [36]:
attention_weights

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]],
       grad_fn=<SoftmaxBackward0>)

In [37]:
attention_weights.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

In [38]:
context_vectors = attention_weights @ V
context_vectors

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

3.4.2 Implementing compact self-attention Pytorch class

In [39]:
X = inputs
d_in = X.size(-1)
d_out = 2

In [40]:
class SelfAttentionV1(torch.nn.Module):
    def __init__(self, d_in, d_out = 2, manual_seed = 123):
        super(SelfAttention, self).__init__()

        torch.manual_seed(123)
        self.d_in = d_in
        self.d_out = d_out
        self.manual_seed = manual_seed

        torch.manual_seed(manual_seed)
        self.WQ = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
        self.WK = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
        self.WV = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
        self.d_k = self.WK.shape[-1]

    def forward(self, X):
        """
        Computing context vectors utilizing the self-attention mechanism
        :param X:
        :return: Context vectors
        """
        #Projecting input X onto the key, query, and value vectors
        Q = X @ self.WQ
        K = X @ self.WK
        V = X @ self.WV

        #d_k = self.WK.shape[-1]
        attention_scores = Q @ K.transpose(0, 1)
        attention_weights = torch.softmax(attention_scores / (self.d_k ** 0.5), dim = 1)
        #Enriched input vector with contribution from other vectors
        context_vectors = attention_weights @ V

        return context_vectors


In [41]:
class SelfAttentionV2(torch.nn.Module):
    def __init__(self, d_in, d_out = 2, manual_seed = 123, qkv_bias = False):
        super(SelfAttentionV2, self).__init__()

        torch.manual_seed(123)
        self.d_in = d_in
        self.d_out = d_out
        self.manual_seed = manual_seed
        self.qkv_bias = qkv_bias

        torch.manual_seed(manual_seed)
        self.WQ = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.WK = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.WV = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.d_k = d_out

    def forward(self, X):
        """
        Computing context vectors utilizing the self-attention mechanism
        :param X:
        :return: Context vectors
        """
        #Projecting input X onto the key, query, and value vectors
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)

        #d_k = self.WK.shape[-1]
        attention_scores = Q @ K.transpose(0, 1)
        attention_weights = torch.softmax(attention_scores / (self.d_k ** 0.5), dim = 1)
        #Enriched input vector with contribution from other vectors
        context_vectors = attention_weights @ V

        return context_vectors


In [42]:
sattn = SelfAttentionV2(d_in, d_out)

In [43]:
#Different result due to the different default lInear layer initialization
sattn(X)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [44]:
sattn.WQ.weight

Parameter containing:
tensor([[-0.2354,  0.0191, -0.2867],
        [ 0.2177, -0.4919,  0.4232]], requires_grad=True)

3.5. Hiding future words with attention maps.

3.5.1 Applying casual attention mask.

In [47]:
context_length = 6
mask = torch.tril(torch.ones(context_length, context_length))
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)

In [48]:
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [57]:
masked

tensor([[  -inf, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [  -inf,   -inf, 1.8111, 1.0795, 0.5577, 1.5440],
        [  -inf,   -inf,   -inf, 1.0654, 0.5508, 1.5238],
        [  -inf,   -inf,   -inf,   -inf, 0.3061, 0.8475],
        [  -inf,   -inf,   -inf,   -inf,   -inf, 0.7307],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf]],
       grad_fn=<MaskedFillBackward0>)

In [56]:
torch.triu(torch.ones(context_length, context_length), diagonal = 1)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

9.5.2 Masking additional attention weights with dropout

In [51]:
torch.manual_seed(123)
layer = torch.nn.Dropout(0.5)

In [52]:
example = torch.ones(6,6)
layer(example)

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

3.5.3 Creating compact casual attention self-attention class.

In [54]:
batch = torch.stack((inputs, inputs), dim = 0)
batch.size()

torch.Size([2, 6, 3])

In [92]:
#Initial draft - Self Attention with Casual Mask
class CasualAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, manual_seed = 123, dropout_rate = 0.1, qkv_bias = False):
        super(CasualAttention, self).__init__()

        torch.manual_seed(123)
        self.d_in = d_in
        self.d_out = d_out
        self.manual_seed = manual_seed
        self.context_length = context_length
        self.dropout_rate = dropout_rate
        self.qkv_bias = qkv_bias
        self.register_buffer('casual_mask', torch.triu(torch.ones(context_length, context_length), diagonal = 1))

        torch.manual_seed(manual_seed)
        self.WQ = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.WK = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.WV = torch.nn.Linear(d_in, d_out, bias = self.qkv_bias)
        self.d_k = d_out

        self.dropout = torch.nn.Dropout(self.dropout_rate)
        #self.casual_mask = torch.tril(torch.ones(self., context_length))

    def forward(self, X):
        """
        Computing context vectors utilizing the self-attention mechanism
        :param X: (batch_size, sequence_length, d_in)
        :return: Context vectors
        """
        #Projecting input X onto the key, query, and value vectors
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)

        #d_k = self.WK.shape[-1]
        attention_scores = Q @ K.transpose(1, 2)
        num_tokens = X.size(1)
        attention_scores = attention_scores.masked_fill(self.casual_mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / (self.d_k ** 0.5), dim = -1)
        print(attention_weights)
        attention_weights = self.dropout(attention_weights)

        #Enriched input vector with contribution from other vectors
        context_vectors = attention_weights @ V

        return context_vectors

In [93]:
casualAttn = CasualAttention(d_in = batch.size(-1), d_out = 2, context_length = 1024, dropout_rate = 0.0, manual_seed = 123 ) #.to('cuda')
#casualAttn(batch.to('cuda'))
casualAttn(batch)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]]],
       grad_fn=<SoftmaxBackward0>)


tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

3.6 Extending single-head attention with multi-head attention.

3.6.1 Stacking multiple single-head attention layers

In [98]:
class MultiHeadAttentionWrapper(torch.nn.Module):
    def __init__(self, num_heads, d_in, d_out, context_length, manual_seed = 123, dropout_rate = 0.1, qkv_bias = False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
        self.heads = torch.nn.ModuleList([CasualAttention(d_in , d_out, context_length, dropout_rate = dropout_rate, manual_seed = manual_seed, qkv_bias = qkv_bias ) for _ in range(num_heads)])

    def forward(self, X):
        return torch.cat([head(X) for head in self.heads], dim = -1)


In [99]:
mha = MultiHeadAttentionWrapper(num_heads = 2, d_in = d_in, d_out = d_out, context_length = context_length, manual_seed = 123, dropout_rate = 0.0)
mha(batch)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
         [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
         [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
         [0.1994,

tensor([[[-0.4519,  0.2216, -0.4519,  0.2216],
         [-0.5874,  0.0058, -0.5874,  0.0058],
         [-0.6300, -0.0632, -0.6300, -0.0632],
         [-0.5675, -0.0843, -0.5675, -0.0843],
         [-0.5526, -0.0981, -0.5526, -0.0981],
         [-0.5299, -0.1081, -0.5299, -0.1081]],

        [[-0.4519,  0.2216, -0.4519,  0.2216],
         [-0.5874,  0.0058, -0.5874,  0.0058],
         [-0.6300, -0.0632, -0.6300, -0.0632],
         [-0.5675, -0.0843, -0.5675, -0.0843],
         [-0.5526, -0.0981, -0.5526, -0.0981],
         [-0.5299, -0.1081, -0.5299, -0.1081]]], grad_fn=<CatBackward0>)