Attention Mechanisms
1. Self attention

In [1]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
# X2 as query compute the attention to other words
query = inputs[1]
attn_score = torch.empty(inputs.shape[0])
# torch.dot = element wise multiple and sum 
for i, xi in enumerate(inputs):
    attn_score[i] = torch.dot(xi, query)

s = 0.43 * 0.55 + 0.15 * 0.87 + 0.89 * 0.66
print(attn_score)
print(s)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
0.9544


In [3]:
# normalize with softmax 
attn_weight = torch.softmax(attn_score, dim=-1)
print(attn_weight)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [4]:
# Z2 hidden stage of X2, contain all the information with X2 and its relation with all other X
z2 = 0
for i, xi in enumerate(attn_weight):
    z2 += inputs[i] * xi
print(z2)

tensor([0.4419, 0.6515, 0.5683])


In [5]:
# compute attn weight for all 
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [6]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [7]:
print(attn_weights)
print(attn_weights.shape)
print(inputs.shape)
# think of Z is a enrich Inputs, with information from other words, same dimention as Inputs 
Z = attn_weights @ inputs
print(Z)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
torch.Size([6, 6])
torch.Size([6, 3])
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


2. self attention with trainabel weights 

Q K V 
previously use a vector to represent value and query. now use some trainable weights to represent query key and value 
query like question, what is the next word or what it is object 
key like index, like this word seems to be like a object, or this is the verb. 
value, the representation of the meaning. 
all use different matrix and weights to represent. 

In [8]:
torch.manual_seed(123)
d_in = inputs.shape[1]
d_out = 2

# set requires_grad = True when training 
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

querys = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value

In [9]:
print(inputs[1].shape)
print(W_query.shape)
query_2 = inputs[1] @ W_query
print(keys.shape)
attn_score_2 = query_2 @ keys.T
print(keys.T)
print(attn_score_2.shape)
d_k = keys.shape[1]
print(d_k)
attn_weights_2 = torch.softmax(attn_score_2 / d_k ** 0.5, dim=-1)
print(attn_weights_2)

torch.Size([3])
torch.Size([3, 2])
torch.Size([6, 2])
tensor([[0.3669, 0.4433, 0.4361, 0.2408, 0.1827, 0.3275],
        [0.7646, 1.1419, 1.1156, 0.6706, 0.3292, 0.9642]])
torch.Size([6])
2
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [10]:
print(attn_weights_2.shape)
print(values.shape)
Z_2 = attn_weights_2 @ values 
print(Z_2.shape)

torch.Size([6])
torch.Size([6, 2])
torch.Size([2])


In [11]:
print(querys.shape)
print(keys.shape)

attn_scores = querys @ keys.T
print(attn_scores)
attn_weights = torch.softmax(attn_scores / keys.shape[1] ** 0.5, dim=-1)
print(attn_weights)
print(values.shape)
Z = attn_weights @ values 
print(Z.shape)

torch.Size([6, 2])
torch.Size([6, 2])
tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])
torch.Size([6, 2])
torch.Size([6, 2])


In [12]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # use nn.Linear is better 
        # self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        # self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        # self.W_value = nn.Parameter(torch.rand(d_in, d_out))  
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_value = attn_weights @ values 

        return context_value

In [13]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[-0.5300, -0.0988],
        [-0.5317, -0.1005],
        [-0.5317, -0.1005],
        [-0.5301, -0.1040],
        [-0.5298, -0.1011],
        [-0.5307, -0.1042]], grad_fn=<MmBackward0>)


3. Casual attention, hide the connection of future word 


In [14]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__() 
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        print(queries.shape)
        print(keys.T.shape)
        attn_scores = queries @ keys.transpose(1,2)
        attn_score.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_value = attn_weights @ values 

        return context_value

In [15]:
torch.manual_seed(123)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([2, 6, 3])
torch.Size([2, 6, 2])
torch.Size([2, 6, 2])
tensor([[[-0.5300, -0.0988],
         [-0.5317, -0.1005],
         [-0.5317, -0.1005],
         [-0.5301, -0.1040],
         [-0.5298, -0.1011],
         [-0.5307, -0.1042]],

        [[-0.5300, -0.0988],
         [-0.5317, -0.1005],
         [-0.5317, -0.1005],
         [-0.5301, -0.1040],
         [-0.5298, -0.1011],
         [-0.5307, -0.1042]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


  print(keys.T.shape)


4. Multi-Head attention 

In [16]:
class MultiHeadAttention(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)()
        pass

    def forword(self):
        pass
    

In [19]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
num_heads = 2 

Q_weights = nn.Linear(d_in, d_out, bias=False)
K_weights = nn.Linear(d_in, d_out, bias=False)
V_weights = nn.Linear(d_in, d_out, bias=False)

Q = Q_weights(batch)
K = K_weights(batch)
V = V_weights(batch)

Q = Q.view(batch_size, context_length, num_heads, d_out // num_heads)
K = K.view(batch_size, context_length, num_heads, d_out // num_heads)
V = V.view(batch_size, context_length, num_heads, d_out // num_heads)

K = K.transpose(1, 2)
Q = Q.transpose(1, 2)
V = V.transpose(1, 2)

attn_scores = Q @ K.transpose(-2, -1)
attn_weights = torch.softmax(attn_scores / K.shape[-1] ** 0.5, dim=-1)
context_vecs = attn_weights @ V
context_vecs = context_vecs.transpose(1, 2).contiguous().view(batch_size, context_length, d_out)

print(context_vecs)

print(batch.shape)
# mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

# context_vecs = mha(batch)

# print(context_vecs)
# print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.5354, -0.1019],
         [-0.5343, -0.1065],
         [-0.5343, -0.1064],
         [-0.5307, -0.1072],
         [-0.5322, -0.1052],
         [-0.5311, -0.1077]],

        [[-0.5354, -0.1019],
         [-0.5343, -0.1065],
         [-0.5343, -0.1064],
         [-0.5307, -0.1072],
         [-0.5322, -0.1052],
         [-0.5311, -0.1077]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 3])


In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [21]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
