### Coding Attention Mechanisms

In [41]:
# A compact self-attention class that uses nn.Parameter to define the weight matrices W_k, W_q and W_v.

import torch.nn as nn
import torch

class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with explicit projection matrices.

    The key/query/value projections are raw ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)``, initialised with ``torch.rand`` (uniform on
    [0, 1)), and applied as ``x @ W``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the projected key/query/value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(seq_len, d_in)``, or ``(..., seq_len, d_in)``
               for batched input.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        K = x @ self.W_key
        Q = x @ self.W_query
        V = x @ self.W_value

        # transpose(-2, -1) swaps only the last two axes, so batched input works;
        # ``K.T`` would reverse *all* axes of a 3-D tensor.
        attention_scores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(d_k), as in "Attention Is All You Need", before softmax.
        attention_weights = torch.softmax(attention_scores / K.shape[-1]**0.5, dim=-1)
        return attention_weights @ V

    def updateWeights(self, input_w_key, input_w_query, input_w_value):
        """Replace the key/query/value matrices with copies of the given tensors.

        The tensors are expected to have shape ``(d_in, d_out)``; for weights
        taken from an ``nn.Linear(d_in, d_out)``, pass ``linear.weight.t()``.

        The inputs are detached and cloned so the new parameters do not share
        storage with the source tensors; otherwise an in-place update of the
        source (e.g. an optimizer step on another module) would silently change
        this module as well.

        Note: this creates new Parameter objects, so an optimizer constructed
        before the call still references the old ones.
        """
        self.W_key = nn.Parameter(input_w_key.detach().clone())
        self.W_query = nn.Parameter(input_w_query.detach().clone())
        self.W_value = nn.Parameter(input_w_value.detach().clone())

In [42]:
# Five example sequence elements (rows), each of dimension 3.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])
d_in = 3   # size of each input row
d_out = 2  # size of the projected key/query/value vectors

# Seed immediately before construction: the torch.rand calls in __init__ are
# the only random draws in this cell, so this pins the printed values.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

# Remove the example variables so later cells have to define their own.
del inputs, d_in, d_out

tensor([[0.2685, 0.7413],
        [0.2738, 0.7564],
        [0.2668, 0.7366],
        [0.2618, 0.7218],
        [0.2712, 0.7495]], grad_fn=<MmBackward0>)


In [43]:
# Self-attention v2: uses bias-free nn.Linear layers for the projection matrices. nn.Linear is preferred because it applies a well-scaled default weight initialization and performs the matrix multiplication efficiently.

class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using bias-free ``nn.Linear`` projections.

    ``nn.Linear(d_in, d_out)`` stores its weight with shape ``(d_out, d_in)``
    and computes ``x @ weight.T``, so ``W_keys.weight.t()`` plays the role of
    ``SelfAttention_v1.W_key``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the projected key/query/value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_keys = nn.Linear(d_in, d_out, bias=False)
        self.W_queries = nn.Linear(d_in, d_out, bias=False)
        self.W_values = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(seq_len, d_in)``, or ``(..., seq_len, d_in)``
               for batched input.

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        # Project the input into keys, queries and values.
        K = self.W_keys(x)
        Q = self.W_queries(x)
        V = self.W_values(x)
        # transpose(-2, -1) (unlike .t(), which only accepts <= 2-D tensors)
        # lets the same code handle batched input.
        attn_scores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(d_k) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_scores / K.shape[-1]**0.5, dim=-1)
        return attn_weights @ V

In [44]:
# The notebook's example input: five rows of dimension 3.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])
d_in, d_out = 3, 2

# Seeding right before construction fixes nn.Linear's random initialisation.
torch.manual_seed(123)
SelfAttention_v2(d_in=d_in, d_out=d_out)(inputs)

tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)

In [47]:
# Validate that both attention implementations produce the same output once
# they share the same weights.

torch.manual_seed(123)

d_in = 3
d_out = 2
# Example input values to try the self-attention layer.
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
 [0.57, 0.85, 0.64], 
 [0.22, 0.58, 0.33],
 [0.77, 0.25, 0.10],
 [0.05, 0.80, 0.55]]
)

sav2 = SelfAttention_v2(d_in, d_out)
sav1 = SelfAttention_v1(d_in, d_out)
# Copy v2's weights into v1. nn.Linear stores its weight as (d_out, d_in) and
# computes x @ weight.T, whereas v1 computes x @ W with W of shape (d_in, d_out),
# so the matrices must be transposed.
sav1.updateWeights(sav2.W_keys.weight.t(), sav2.W_queries.weight.t(), sav2.W_values.weight.t())

out_v2 = sav2(inputs)
out_v1 = sav1(inputs)
print("Self Attention V2 output: \n", out_v2)
print("Self Attention V1 output: \n", out_v1)
# Printing only lets a reader eyeball the values; fail loudly if they differ.
torch.testing.assert_close(out_v1, out_v2)


Self Attention V2 output: 
 tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)
Self Attention V1 output: 
 tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)
