In [1]:
import torch
import torch.nn as nn 
import math 
import torch.nn.functional as F 

In [2]:
def calculate_attention(
    query: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor
): 
    #perform matmul 
    attention_scores = torch.matmul(query, keys.transpose(-2,-1)) 
    attention_scores = attention_scores / math.sqrt(keys.shape[-1]) 
    attention = torch.matmul(attention_scores, values) 
    return attention, attention_scores

In [6]:
batch_size =2 
num_queries = 4 
num_keys = 16 
embed_size = 8 
query = torch.randn(batch_size, num_queries, embed_size) 
keys = torch.randn(batch_size, num_keys, embed_size) 
values = torch.randn(batch_size, num_keys, embed_size)

In [7]:
keys.shape

torch.Size([2, 16, 8])

In [9]:
attention, attention_scores = calculate_attention(query, keys, values)

In [10]:
attention.shape, attention_scores.shape

(torch.Size([2, 4, 8]), torch.Size([2, 4, 16]))

In [11]:
text = "attention! we will train attention ." 
text_tokens = text.split() 
vocab = set(text_tokens) 
vocab_to_idx = {token: idx for idx, token in enumerate(vocab)} 
print(vocab_to_idx)

{'will': 0, 'attention!': 1, 'attention': 2, 'train': 3, '.': 4, 'we': 5}


In [12]:
int_tokens = torch.tensor([vocab_to_idx[token] for token in text_tokens]) 
int_tokens = int_tokens.unsqueeze(0) 
print(int_tokens, "\nshape:", int_tokens.shape)

tensor([[1, 5, 0, 3, 2, 4]]) 
shape: torch.Size([1, 6])


In [13]:
embedding_layer = nn.Embedding(num_embeddings=len(vocab), embedding_dim=8)

In [14]:
embeddings = embedding_layer(int_tokens)
embeddings

tensor([[[-1.1434,  0.1184, -1.2842, -1.6180, -0.6949, -0.4059, -0.3229,
           0.8431],
         [-0.7437,  0.7823,  0.7492,  1.4523,  0.0660,  0.5996,  0.5312,
          -1.4024],
         [-0.2346, -0.2163,  0.0753, -0.9114,  0.5445, -0.6132, -0.7213,
          -0.2497],
         [ 0.4371, -0.2856, -0.9438, -1.5427, -1.0193,  0.1261, -0.9378,
          -0.0200],
         [ 0.9038, -0.8440,  0.1119,  0.2089, -0.2603,  1.6706,  0.9360,
          -0.8233],
         [ 1.5770, -0.6095, -0.3253,  1.2963,  1.5325, -0.6263,  0.5931,
          -0.8159]]], grad_fn=<EmbeddingBackward0>)

In [15]:
embedding_dim = 8 
embedding_layer = nn.Embedding(num_embeddings=len(vocab), embedding_dim=embedding_dim) 
query_dense_layer = nn.Linear(in_features=embedding_dim, out_features=8) 
key_dense_layer = nn.Linear(in_features=embedding_dim, out_features=8) 
value_dense_layer = nn.Linear(in_features=embedding_dim, out_features=8)

In [16]:
embeddings = embedding_layer(int_tokens)
embeddings.shape

torch.Size([1, 6, 8])

In [17]:
embeddings = embedding_layer(int_tokens) 
query = query_dense_layer(embeddings)
key = key_dense_layer(embeddings)
value = key_dense_layer(embeddings) 

query.shape, key.shape, value.shape

(torch.Size([1, 6, 8]), torch.Size([1, 6, 8]), torch.Size([1, 6, 8]))

In [18]:
attention, attention_scores = calculate_attention(query, key, value)
attention.shape, attention_scores.shape

(torch.Size([1, 6, 8]), torch.Size([1, 6, 6]))

In [19]:
right_triangular_mask = torch.tril(torch.ones_like(attention_scores)) 
right_triangular_mask

tensor([[[1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1., 1.]]])

In [20]:
def calculate_masked_attention(
        values: torch.Tensor,
        keys: torch.Tensor,
        query: torch.Tensor,
        mask: torch.Tensor = None 
): 
    attention_scores = torch.matmul(query, keys.transpose(-2,-1)) 
    attention_scores = attention_scores / math.sqrt(keys.shape[-1])
    if mask is not None: 
        attention_scores = torch.where(mask == 0, torch.tensor(-1e9), attention_scores) 
    attention_scores = F.softmax(attention_scores, dim=-1) 
    attention = torch.matmul(attention_scores, values) 
    return attention, attention_scores

In [21]:
attention_Context, attention_scores = calculate_masked_attention(query, key, value, right_triangular_mask)

In [22]:
attention_Context.shape

torch.Size([1, 6, 8])

In [24]:

# Define dense layers for query, key, and value transformations
# 'embedding_dim' would be the input feature dimension, and '8' is the output feature dimension.
query_dense_layer_2 = nn.Linear(in_features=embedding_dim, out_features=8)
key_dense_layer_2 = nn.Linear(in_features=embedding_dim, out_features=8)
value_dense_layer_2 = nn.Linear(in_features=embedding_dim, out_features=8)

# Apply the dense layers to 'attention_context' to get query, key, and value representations
# 'attention_context' would be the input tensor to these layers.
query_2 = query_dense_layer_2(attention_Context)
key_2 = key_dense_layer_2(attention_Context)
value_2 = value_dense_layer_2(attention_Context)
