In [282]:
import torch
from torch import nn
from torch.nn import functional
import numpy as np
import math

In [283]:
# Reading Tiny Shakespeare Dataset
file = open('tiny-shakespeare.txt', 'r')
text = file.read()
file.close()

In [427]:
head_size = 4
embedding_size = 5
input_size = 3
class Single_Attention_Block(nn.Module):
    def __init__(self):
        super().__init__()
        # Weight Matrices for Queries, Keys, and Values
        self.lin_q = nn.Linear(embedding_size, head_size, bias = False)
        self.lin_k = nn.Linear(embedding_size, head_size, bias = False)
        self.lin_v_up = nn.Linear(embedding_size, head_size, bias = False)
        self.lin_v_down = nn.Linear(head_size, head_size, bias = False)
        self.soft = nn.Softmax(dim = 2)

    def forward(self, embedding):
        Q = self.lin_q(embedding)
        Q = Q[:, :, None]
        K = self.lin_k(embedding)
        K = K[:, :, None]
        
        # K Transposition
        K = torch.permute(K, (0, 2, 1))
        V = self.lin_v_up(embedding)
        V = self.lin_v_down(V)
        V = V[:, :, None]
        attention_matrix = (Q @ K) / (math.sqrt(head_size))
        
        # Triangular mask applied for decoder
        attention_matrix = torch.tril(attention_matrix)
        attention_matrix[attention_matrix == 0] = -math.inf
        attention_matrix = self.soft(attention_matrix)
        attention_matrix = attention_matrix @ V
        return attention_matrix

In [428]:
test_embedding = torch.rand(input_size, embedding_size)
test_embedding.shape

torch.Size([3, 5])

In [429]:
model = Single_Attention_Block()

In [430]:
result = model(test_embedding)

In [431]:
result

tensor([[[ 0.1112],
         [ 0.0525],
         [-0.0133],
         [-0.0584]],

        [[ 0.1739],
         [ 0.1035],
         [ 0.0178],
         [-0.0587]],

        [[ 0.0643],
         [-0.0241],
         [ 0.0270],
         [-0.0303]]], grad_fn=<UnsafeViewBackward0>)

In [447]:
class Multi_Head_Attention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.attentionModule = nn.ModuleList([Single_Attention_Block() for i in range(num_heads)])
        self.lin_o = nn.Linear(embedding_size, embedding_size)
        
    def forward(self, embedding):
        outputs = ()
        for i, block in enumerate(self.attentionModule):
            outputs.append(block(embedding))
        x = torch.concat(outputs, dim = 1)
        return x
        

In [448]:
model = Multi_Head_Attention(3)
result = model(test_embedding)
result

TypeError: concat() received an invalid combination of arguments - got (Tensor, dim=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)
