In [1]:
import torch
import torch.nn as nn

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.d_model = d_model # the size of the embedding dimension
        self.row_dim = row_dim
        self.col_dim = col_dim
        
        self.w_query = nn.Linear(d_model, d_model, bias=False)
        self.w_key = nn.Linear(d_model, d_model, bias=False)
        self.w_value = nn.Linear(d_model, d_model, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        query = self.w_query(x)
        key = self.w_key(x)
        value = self.w_value(x)
        print(query.shape, key.shape, value.shape)
        print(key.transpose(0, 1).shape)

        sims = torch.matmul(query, key.transpose(0, 1))
        print("sims", sims)

        scaled_sims = sims / (self.d_model ** 0.5)
        print("scaled_sims", scaled_sims)

        attn_weights = self.softmax(torch.matmul(query, key.transpose(0, 1)) / (key.size(self.col_dim) ** 0.5))
        attn_output = torch.matmul(attn_weights, value)
        return attn_output

In [15]:
## create a matrix of token encodings...
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

selfAttention = SelfAttention(d_model=2, row_dim=0, col_dim=1)
print(selfAttention)

## calculate basic attention for the token encodings
selfAttention(encodings_matrix)

SelfAttention(
  (w_query): Linear(in_features=2, out_features=2, bias=False)
  (w_key): Linear(in_features=2, out_features=2, bias=False)
  (w_value): Linear(in_features=2, out_features=2, bias=False)
  (softmax): Softmax(dim=-1)
)
torch.Size([3, 2]) torch.Size([3, 2]) torch.Size([3, 2])
torch.Size([2, 3])
sims tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)
scaled_sims tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)


tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [19]:
## print out the weight matrix that creates the queries
print(selfAttention.w_query.weight.transpose(0, 1))

## print out the weight matrix that creates the queries
print(selfAttention.w_key.weight.transpose(0, 1))

## print out the weight matrix that creates the queries
print(selfAttention.w_value.weight.transpose(0, 1))

## calculate the queries
q = selfAttention.w_query(encodings_matrix)
k = selfAttention.w_key(encodings_matrix)
v = selfAttention.w_value(encodings_matrix)
print(q, k, v)

sims = torch.matmul(q, k.transpose(0, 1))
scaled_sims = sims / selfAttention.w_key.weight.size(1) ** 0.5
attn_weight = torch.softmax(scaled_sims, dim=-1)
attn_output = torch.matmul(attn_weight, v)

print(sims)
print(scaled_sims)
print(attn_weight)
print(attn_output)


tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)
tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>) tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>) tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)
tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)
tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)
tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)
tensor([[1.01