# Lecture 15: Self-Attention with trainable weights

In [400]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Toy input sequence: six tokens, each represented by a 3-dimensional embedding
# (one row per token). Row index 1 is the token "Journey" used as the running
# example below.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

### defining variables for creating Query, Key and Value Weight Matrices

*Input dimension* `d_in` = 3, *output dimension* `d_out` = 2

### The input dimension must match for the matrix multiplication to be defined: the number of rows of each trainable weight matrix has to equal the number of columns (the embedding dimension) of the inputs matrix
### As long as these inner dimensions match, the output dimension can be chosen freely

### In the example below the 3-dimensional input embeddings are projected down to 2 dimensions, because the trainable weight matrices are chosen to have 2 columns (d_out = 2); the inputs matrix itself is not modified

In [401]:
# Running example: the second token ("Journey"). d_in is the embedding width of
# the inputs; d_out is the chosen width of the query/key/value projections.
x_2 = inputs[1]
d_in, d_out = inputs.size(1), 2

### building weight matrices

In [402]:
# Fix the RNG so the random weight matrices below are reproducible.
torch.manual_seed(123)

# Three (d_in, d_out) projection matrices, drawn in the order query, key, value.
# requires_grad=False: no gradients are tracked for this manual walk-through.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

print(f"Inputs:\n{inputs}{inputs.shape}\n")

print(f"W_query Weight Matrix:\n{W_query}{W_query.shape}\n")
print(f"W_key Weight Matrix:\n{W_key}{W_key.shape}\n")
print(f"W_Value Weight Matrix:\n{W_value}{W_value.shape}\n")

Inputs:
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])torch.Size([6, 3])

W_query Weight Matrix:
Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])torch.Size([3, 2])

W_key Weight Matrix:
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])torch.Size([3, 2])

W_Value Weight Matrix:
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])torch.Size([3, 2])



### computing the query, key and value vectors for the word "Journey" (second token of the input sequence, index 1)

In [403]:
# Project the single "Journey" embedding into query, key and value space;
# each result is a d_out-dimensional vector.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))

print(f"Query Matrix for word Journey:\n{query_2}\n{query_2.shape}\n")
print(f"Key Matrix for word Journey:\n{key_2}\n{key_2.shape}\n")
print(f"Values Matrix for word Journey\n{value_2}\n{value_2.shape}\n")


Query Matrix for word Journey:
tensor([0.4306, 1.4551])
torch.Size([2])

Key Matrix for word Journey:
tensor([0.4433, 1.1419])
torch.Size([2])

Values Matrix for word Journey
tensor([0.3951, 1.0037])
torch.Size([2])



### computing the Queries, Keys and Values Matrices

In [404]:
# Project all tokens at once: (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out).
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

print(f"Queries:\n{queries}{queries.shape}\n")
print(f"Keys:\n{keys}{keys.shape}")
print(f"Values:\n{values}{values.shape}")

Queries:
tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])torch.Size([6, 2])

Keys:
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])torch.Size([6, 2])
Values:
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])torch.Size([6, 2])


### computing attention scores for Journey, working with the *Queries* Matrix and the *Keys* Matrix

In [405]:
# Dot product of the "Journey" query with every key -> one score per token.
token_2 = queries[1]
attn_scores_2 = torch.matmul(token_2, keys.T)

print(f"{token_2.shape} @ {keys.T.shape} --> 1 X 6 Vector")
print(f"{attn_scores_2} | {attn_scores_2.shape}")

torch.Size([2]) @ torch.Size([2, 6]) --> 1 X 6 Vector
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440]) | torch.Size([6])


### computing all attention scores

In [406]:
# All pairwise scores at once: entry [i, j] is the dot product queries[i] . keys[j].
attn_scores = torch.matmul(queries, keys.T)

print(f"{queries.shape} @ {keys.T.shape} = {attn_scores.shape}")
print(f"Attention Scores:\n{attn_scores}{attn_scores.shape}")

torch.Size([6, 2]) @ torch.Size([2, 6]) = torch.Size([6, 6])
Attention Scores:
tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])torch.Size([6, 6])


### scaling by the square root of the keys dimension (2)

#### dividing by $\sqrt{d_k}$ keeps the variance of the dot products close to 1, which prevents the softmax output from becoming overly peaked (close to one-hot)

In [407]:
# Scale the raw scores by sqrt(d_k), where d_k is the key dimension (last axis
# of `keys`). The label is derived from d_k rather than hard-coded, so it stays
# correct if d_out is changed above.
d_k = keys.shape[-1]
scaled_attn_scores = attn_scores / d_k ** 0.5
print(f"Attention Scores: scaled by square root of {d_k}\n{scaled_attn_scores}")

Attention Scores: scaled by square root of 2
tensor([[0.6528, 0.9578, 0.9363, 0.5593, 0.2851, 0.8011],
        [0.8984, 1.3098, 1.2806, 0.7633, 0.3944, 1.0918],
        [0.8870, 1.2929, 1.2641, 0.7534, 0.3895, 1.0775],
        [0.4930, 0.7189, 0.7029, 0.4190, 0.2164, 0.5993],
        [0.4323, 0.6236, 0.6099, 0.3621, 0.1914, 0.5167],
        [0.6361, 0.9309, 0.9101, 0.5432, 0.2784, 0.7776]])


### applying softmax along each row turns the scores into attention weights that are positive and sum to 1, which aids interpretability and keeps training via backpropagation stable

In [408]:
# Softmax over the last axis: each row (one query) becomes a distribution over keys.
attn_weights = scaled_attn_scores.softmax(dim=-1)
print(f"Attention Weights: normalized along the Rows:\n{attn_weights}")

Attention Weights: normalized along the Rows:
tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])


### computing one context vector for the word Journey

In [409]:
# Context vector for "Journey": the value vectors weighted by its attention row.
attn_journey = attn_weights[1]
context_journey = torch.matmul(attn_journey, values)

print(f"{attn_journey.shape} @ {values.shape} --> {context_journey.shape}")
print(f"Context Vector'Journey':\n{context_journey}")

torch.Size([6]) @ torch.Size([6, 2]) --> torch.Size([2])
Context Vector'Journey':
tensor([0.3061, 0.8210])


### computing all context vectors by multiplying the attention weights matrix with the values matrix

In [410]:
# Every context vector at once: row i is sum_j attn_weights[i, j] * values[j].
context_matrix = torch.matmul(attn_weights, values)

print(f"{attn_weights.shape} @ {values.shape} --> {context_matrix.shape}")
print(f"Context Vectors:\n{context_matrix}")

torch.Size([6, 6]) @ torch.Size([6, 2]) --> torch.Size([6, 2])
Context Vectors:
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])


## Self Attention Class V1 -- using nn.Parameter

In [411]:
class SelfAttentionV1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` projections.

    Args:
        d_in: width of each input embedding (last dimension of ``x``).
        d_out: width of the query/key/value projections and of the output.

    The three projection matrices have shape ``(d_in, d_out)`` and are
    initialised with ``torch.rand`` (uniform on [0, 1)), drawn in the order
    query, key, value.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or batched
               ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # Swap only the token/feature axes. `keys.T` reverses *all* dims and is
        # deprecated for non-2D tensors, so it breaks on batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) so the softmax does not become overly peaked.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        return attn_weights @ values

### creating an instance of the self attention v1 class

In [412]:
# Same seed as the manual walk-through, so the randomly initialised weights,
# and therefore the context vectors, match the ones computed by hand above.
torch.manual_seed(123)
self_attention_v1 = SelfAttentionV1(d_in, d_out)
# Call the module itself rather than .forward() so that nn.Module hooks run.
print(f"Context Vectors:\n{self_attention_v1(inputs)}")

Context Vectors:
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


## Self Attention Class V2 -- using nn.Linear

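#### nn.Linear uses an optimized weight initialization scheme, which leads to more stable and effective model training. Because the initialization differs from V1 (and a different seed is used), the context vectors produced by V2 below differ from those of V1.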

In [413]:
class SelfAttentionV2(nn.Module):
    """Scaled dot-product self-attention built from ``nn.Linear`` projections.

    Args:
        d_in: width of each input embedding (last dimension of ``x``).
        d_out: width of the query/key/value projections and of the output.
        qkv_bias: if True, the three projections also learn a bias term.

    The projections use ``nn.Linear``'s default initialisation.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or batched
               ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the token/feature axes. `keys.T` reverses *all* dims and is
        # deprecated for non-2D tensors, so it breaks on batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) so the softmax does not become overly peaked.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        return attn_weights @ values

### creating an instance of the self attention v2 class

In [414]:
torch.manual_seed(78)
self_attention_v2 = SelfAttentionV2(d_in, d_out)
# Call the module itself rather than .forward() so that nn.Module hooks run.
print(f"Context Matrix:\n{self_attention_v2(inputs)}")

Context Matrix:
tensor([[0.2461, 0.2370],
        [0.2437, 0.2398],
        [0.2438, 0.2398],
        [0.2437, 0.2368],
        [0.2455, 0.2362],
        [0.2429, 0.2379]], grad_fn=<MmBackward0>)
