### Coding Attention Mechanisms

In [57]:
# A compact self-attention class. It uses nn.Parameter to define the weight matrices W_k, W_q and W_v.

import torch.nn as nn
import torch

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention built on raw ``nn.Parameter`` matrices.

    The layer projects its input with three trainable ``(dIn, dOut)`` matrices
    (keys, queries, values), computes softmax-normalised attention weights from
    the scaled query/key dot products and returns the weighted sum of the values.
    """

    def __init__(self, dIn, dOut):
        """Create the three projection matrices.

        Args:
            dIn: Size of the last dimension of the input.
            dOut: Size of the last dimension of the keys/queries/values (and of the output).
        """
        super().__init__()
        # Creation order (key, query, value) determines which random numbers
        # each matrix receives for a given seed; keep it stable.
        self.wKey = nn.Parameter(torch.rand(dIn, dOut))
        self.wQuery = nn.Parameter(torch.rand(dIn, dOut))
        self.wValue = nn.Parameter(torch.rand(dIn, dOut))

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last two dimensions are ``(tokens, dIn)``; any
                leading dimensions are treated as batch dimensions.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last two
            dimensions ``(tokens, dOut)``.
        """
        K = x @ self.wKey
        Q = x @ self.wQuery
        V = x @ self.wValue

        # Swap only the last two dimensions. ``K.T`` would reverse *all*
        # dimensions, which is wrong (and deprecated) for batched input.
        attnScores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(dOut) before the softmax over the key dimension.
        attnWeights = torch.softmax(attnScores / K.shape[-1]**0.5, dim=-1)
        return attnWeights @ V

    def updateWeights(self, wKey, wQuery, wValue):
        """Replace the projection matrices with copies of the given tensors.

        Args:
            wKey: New key projection matrix.
            wQuery: New query projection matrix.
            wValue: New value projection matrix.
        """
        # ``nn.Parameter(t)`` shares storage with ``t``. Copy the data so the
        # new parameters do not alias the caller's tensors (or each other);
        # otherwise in-place/optimizer updates would leak between modules.
        self.wKey = nn.Parameter(wKey.detach().clone())
        self.wQuery = nn.Parameter(wQuery.detach().clone())
        self.wValue = nn.Parameter(wValue.detach().clone())

In [62]:
# Seed the RNG so the layer's random weights, and therefore the printed result, are reproducible.
torch.manual_seed(123)

# Input width of each row and projection width of the layer.
dIn = 3
dOut = 2

# Five example rows with three features each, used to exercise the self-attention layer.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

print(SelfAttentionV1(dIn, dOut)(inputs))

# Remove this cell's temporaries from the notebook namespace.
del dOut, dIn, inputs

tensor([[0.2685, 0.7413],
        [0.2738, 0.7564],
        [0.2668, 0.7366],
        [0.2618, 0.7218],
        [0.2712, 0.7495]], grad_fn=<MmBackward0>)


In [63]:
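# Self-attention V2. It uses nn.Linear (with bias=False) to define the weight matrices. nn.Linear is preferred over a raw nn.Parameter because it comes with a more suitable default weight initialization scheme and performs the matrix multiplication with an optimized implementation.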

class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on bias-free ``nn.Linear`` layers.

    The input is projected to keys, queries and values by three linear layers;
    the output is the attention-weighted sum of the values.
    """

    def __init__(self, dIn, dOut):
        """Create the three bias-free projection layers.

        Args:
            dIn: Size of the last dimension of the input.
            dOut: Size of the last dimension of the keys/queries/values (and of the output).
        """
        super().__init__()
        self.wKeys = nn.Linear(dIn, dOut, bias=False)
        self.wQueries = nn.Linear(dIn, dOut, bias=False)
        self.wValues = nn.Linear(dIn, dOut, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last two dimensions are ``(tokens, dIn)``; any
                leading dimensions are treated as batch dimensions.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last two
            dimensions ``(tokens, dOut)``.
        """
        # Compute Keys, Queries, Values.
        K = self.wKeys(x)
        Q = self.wQueries(x)
        V = self.wValues(x)
        # Compute attention weights. ``transpose(-2, -1)`` swaps only the last
        # two dimensions, so batched input works; ``K.t()`` only accepts
        # tensors with at most two dimensions.
        attnScores = Q @ K.transpose(-2, -1)
        # Scale by sqrt(dOut) before the softmax over the key dimension.
        attnWeights = torch.softmax(attnScores / K.shape[-1]**0.5, dim=-1)
        return attnWeights @ V

In [64]:
# Seed the RNG so the layer's initial weights, and therefore the printed result, are reproducible.
torch.manual_seed(123)

# Input width of each row and projection width of the layer.
dIn, dOut = 3, 2

# Five example rows with three features each, used to exercise the self-attention layer.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

print(SelfAttentionV2(dIn, dOut)(inputs))

# Remove this cell's temporaries from the notebook namespace.
del dOut, dIn, inputs

tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)


In [65]:
# Validate that both attention mechanisms output the same values once they share weights.

torch.manual_seed(123)

dIn = 3
dOut = 2
# Example input values to try the self-attention layer.
inputs = torch.tensor(
[[0.43, 0.15, 0.89],
 [0.57, 0.85, 0.64], 
 [0.22, 0.58, 0.33],
 [0.77, 0.25, 0.10],
 [0.05, 0.80, 0.55]]
)

sav2 = SelfAttentionV2(dIn, dOut)
sav1 = SelfAttentionV1(dIn, dOut)
# Give SelfAttentionV1 the same weights as SelfAttentionV2. nn.Linear stores its
# weight as (dOut, dIn), while SelfAttentionV1 multiplies by (dIn, dOut) matrices,
# hence the transposes.
sav1.updateWeights(sav2.wKeys.weight.t(), sav2.wQueries.weight.t(), sav2.wValues.weight.t())

outV2 = sav2(inputs)
outV1 = sav1(inputs)
print("Self Attention V2 output: \n", outV2)
print("Self Attention V1 output: \n", outV1)

# Enforce the equality instead of relying on a visual comparison of the printouts.
torch.testing.assert_close(outV1, outV2)
del inputs, dIn, dOut, outV1, outV2

Self Attention V2 output: 
 tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)
Self Attention V1 output: 
 tensor([[-0.4927, -0.0791],
        [-0.4938, -0.0806],
        [-0.4924, -0.0851],
        [-0.4923, -0.0819],
        [-0.4928, -0.0853]], grad_fn=<MmBackward0>)
