In [1]:
import torch
import torch.nn.functional as F

# Toy input: a tokenized question; each token gets a random embedding below.
torch.manual_seed(42)  # seed before any random draw so every tensor below is reproducible

sentence = "What are the symptoms of diabetes ?".split()
vocab_size = 10000  # nominal vocabulary size (not used by the computation in this cell)
embedding_dim = 8   # kept tiny so the printed tensors stay readable

# One uniform-random vector per token -> shape (seq_length, d_model)
embeddings = torch.rand(len(sentence), embedding_dim)

# Self-Attention Mechanism
def self_attention(embeddings, W_q=None, W_k=None, W_v=None):
    """Single-head scaled dot-product self-attention.

    Args:
        embeddings: tensor of shape (..., seq_length, d_model). Leading batch
            dimensions are optional; a plain (seq_length, d_model) matrix works.
        W_q, W_k, W_v: optional projection matrices of shape (d_model, d_k),
            (d_model, d_k) and (d_model, d_v). Any projection left as None is
            drawn from ``torch.rand`` (uniform on [0, 1)) with shape
            (d_model, d_model) and the dtype/device of ``embeddings``.
            These random matrices only stand in for learned weights: they are
            not trained and are re-drawn on every call, so seed the global RNG
            (``torch.manual_seed``) if you need reproducible results.

    Returns:
        (attention_weights, output):
            attention_weights: (..., seq_length, seq_length) softmax-normalized
                weights; each row sums to 1.
            output: (..., seq_length, d_v) attention-weighted sum of the
                value vectors.
    """
    d_model = embeddings.shape[-1]  # last axis, so batched inputs work too

    def _random_projection():
        # Same draw as torch.rand((d_model, d_model)) for float32 CPU input, but
        # follows the input's dtype/device so float64 / GPU inputs don't fail.
        return torch.rand((d_model, d_model), dtype=embeddings.dtype, device=embeddings.device)

    # Draw in Q, K, V order so the global RNG stream is consumed as before.
    if W_q is None:
        W_q = _random_projection()
    if W_k is None:
        W_k = _random_projection()
    if W_v is None:
        W_v = _random_projection()

    Q = embeddings @ W_q
    K = embeddings @ W_k
    V = embeddings @ W_v

    # Scale by sqrt of the *key* dimension; transpose(-2, -1) instead of .T so
    # only the last two axes are swapped for batched (..., seq, d_k) tensors.
    d_k = K.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / d_k ** 0.5
    attention_weights = F.softmax(scores, dim=-1)  # normalize over the key axis

    # Each output row is a convex combination of the value vectors.
    output = attention_weights @ V

    return attention_weights, output

# Compute self-attention.
# NOTE: the first return value is the post-softmax attention *weights* (each row
# sums to 1), not the raw scaled dot-product scores. The variable keeps its
# original name so any later cells that reference `attention_scores` still work.
attention_scores, attention_output = self_attention(embeddings)

# Sanity check: every row of softmax weights must be a probability distribution.
row_sums = attention_scores.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "attention rows must sum to 1"

# Print results
print("Attention Weights:\n", attention_scores)
print("\nSelf-Attention Output:\n", attention_output)


Attention Scores:
 tensor([[4.3061e-01, 1.5894e-01, 1.1955e-02, 5.1078e-05, 1.9447e-01, 2.0204e-01,
         1.9282e-03],
        [4.3387e-01, 1.5085e-01, 1.3501e-02, 6.7036e-05, 2.1080e-01, 1.8826e-01,
         2.6425e-03],
        [3.7908e-01, 1.7508e-01, 2.2295e-02, 3.0595e-04, 2.0734e-01, 2.1051e-01,
         5.3868e-03],
        [3.0247e-01, 1.8655e-01, 5.9879e-02, 4.9815e-03, 2.1277e-01, 2.0590e-01,
         2.7445e-02],
        [4.3384e-01, 1.5517e-01, 1.4095e-02, 7.1856e-05, 1.9819e-01, 1.9614e-01,
         2.4938e-03],
        [4.1533e-01, 1.6272e-01, 1.3794e-02, 7.9935e-05, 1.9983e-01, 2.0567e-01,
         2.5728e-03],
        [3.4870e-01, 1.7962e-01, 3.5365e-02, 1.0997e-03, 2.1879e-01, 2.0481e-01,
         1.1622e-02]])

Self-Attention Output:
 tensor([[2.5879, 3.2746, 3.5969, 2.6043, 1.6062, 1.6859, 1.5454, 2.7297],
        [2.5863, 3.2799, 3.5910, 2.6014, 1.6029, 1.6822, 1.5475, 2.7391],
        [2.5700, 3.2584, 3.5961, 2.6006, 1.6002, 1.6808, 1.5254, 2.6943],
        [2.5