# 0. Init

In [56]:
import torch
from torch import nn
import torch.nn.functional as F

In [57]:
# Hyperparameters Here
batch_size = 8
T = 10
d_K = 2
d_V = 4
d_model = 6

# 1. Attention

In [58]:
# Test softmax x axis:
matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

softmaxed_matrix = F.softmax(matrix, dim=1)

print(softmaxed_matrix)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])


In [157]:
def repeat(x: torch.Tensor, n: int):
    # make shape (n, 1, 1, ...) --> quantity of 1's must be len(x.shape)
    # for example, if shape of x is (3, 4, 8), shapee must be (n, 1, 1, 1)
    tuple_ones = tuple(
        (torch.tensor(x.shape) / torch.tensor(x.shape)).numpy().astype(int)
    )
    # print((n, *tuple_ones))
    return x.unsqueeze(0).repeat((n, *tuple_ones))


def batched_matmul(x_batched, W):
    # # Assuming x_batched.shape == (batch_size, T, d_model)
    # # and W.shape == (d_model, d)
    # batch_size, T, d_model = x_batched.shape
    # d = W.shape[1]

    # # Reshape x_batched to (batch_size * T, d_model)
    # x_reshaped = x_batched.reshape(-1, d_model)

    # # Perform matrix multiplication
    # result = torch.matmul(x_reshaped, W)

    # # Reshape the result back to (batch_size, T, d)
    # result = result.reshape(batch_size, T, d)

    # return result

    # batch_size = x_batched.shape[0]
    # W_repeated = W.unsqueeze(0).repeat((batch_size, 1, 1))
    W_repeated = repeat(W, n=x_batched.shape[0])
    return torch.bmm(x_batched, W_repeated)

In [158]:
class Attention(nn.Module):
    def __init__(self, T: int, d_K, d_V, d_model: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # define a torch 2d tensor initialized normally
        self.W_K = nn.Parameter(torch.normal(mean=0, std=0.01, size=(d_model, d_K)))
        self.W_Q = nn.Parameter(torch.normal(mean=0, std=0.01, size=(d_model, d_K)))
        self.W_V = nn.Parameter(torch.normal(mean=0, std=0.01, size=(d_model, d_V)))
        self.mask = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shapes:
        # W_K (d_model, d_K)
        # x is a 3d tensor (batch x T x d_model)

        # W_K.T ->  (1, d_k x d_model)
        # x ->      (batch, T, d_model)
        K = batched_matmul(x, self.W_K)
        Q = batched_matmul(x, self.W_Q)
        V = batched_matmul(x, self.W_V)

        result = torch.bmm(Q, K.transpose(1, 2)) / (K.shape[-1] ** 0.5)
        if self.mask:
            result = batched_matmul(result, self.mask)
        result = F.softmax(result, dim=-1)
        result = torch.bmm(result, V)
        return result

In [159]:
att = Attention(T=T, d_K=d_K, d_V=d_V, d_model=d_model)
x = torch.normal(mean=0, std=0.01, size=(batch_size, T, d_model))

att_result = att.forward(x)
assert att_result.shape == (batch_size, T, d_V)
print(att.W_K.shape)
print(att.W_Q.shape)
print(att.W_V.shape)
print(att_result.shape)

torch.Size([6, 2])
torch.Size([6, 2])
torch.Size([6, 4])
torch.Size([8, 10, 4])


In [166]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self, h: int, T=T, d_K=d_K, d_V=d_V, d_model=d_model, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.h = h
        self.attentions = nn.ModuleList(
            [Attention(T=T, d_K=d_K, d_model=d_model, d_V=d_V) for _ in range(h)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attention_results = nn.ModuleList([])
        for attention in self.attentions:
            attention_result = attention(x)
            print(attention_result)
            assert False
            attention_results.append(attention_result)
            
        print(attention_results)
        return

In [167]:
mha = MultiHeadAttention(8, T=T, d_K=d_K, d_V=d_V, d_model=d_model)
mha(x)

tensor([[[ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05],
         [ 1.3849e-05, -2.0673e-06,  7.9926e-05,  6.6457e-05]],

        [[ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0039e-06],
         [ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0039e-06],
         [ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0039e-06],
         [ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0039e-06],
         [ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0039e-06],
         [ 2.8928e-05, -4.5233e-05, -8.8701e-05, -6.0

AssertionError: 