# 0. Init

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [30]:
# Hyperparameters Here
BATCH_SIZE = 4
T = 4
D_K = 2
D_V = 2
D_MODEL = 2
H = 8
VOCAB_SIZE = 10
N_MHA_BLOCKS_ENCODER = 6
N_CLASSES = 2

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Attention

In [3]:
# Test softmax x axis:
matrix = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

softmaxed_matrix = F.softmax(matrix, dim=1)

print(softmaxed_matrix)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])


In [4]:
def repeat(x: torch.Tensor, n: int):
    # make shape (n, 1, 1, ...) --> quantity of 1's must be len(x.shape)
    # for example, if shape of x is (3, 4, 8), shapee must be (n, 1, 1, 1)
    tuple_ones = tuple(
        (torch.tensor(x.shape) / torch.tensor(x.shape)).numpy().astype(int)
    )
    # print((n, *tuple_ones))
    return x.unsqueeze(0).repeat((n, *tuple_ones))


def batched_matmul(x_batched, W):
    # # Assuming x_batched.shape == (batch_size, T, d_model)
    # # and W.shape == (d_model, d)
    # batch_size, T, d_model = x_batched.shape
    # d = W.shape[1]

    # # Reshape x_batched to (batch_size * T, d_model)
    # x_reshaped = x_batched.reshape(-1, d_model)

    # # Perform matrix multiplication
    # result = torch.matmul(x_reshaped, W)

    # # Reshape the result back to (batch_size, T, d)
    # result = result.reshape(batch_size, T, d)

    # return result

    # batch_size = x_batched.shape[0]
    # W_repeated = W.unsqueeze(0).repeat((batch_size, 1, 1))
    W_repeated = repeat(W, n=x_batched.shape[0])
    print(x_batched.shape, W_repeated.shape)
    return torch.bmm(x_batched, W_repeated)

In [5]:
class Attention(nn.Module):
    """ Convention from: https://www.udemy.com/course/data-science-transformers-nlp/learn/lecture/32255056#overview
    In our convention, K, Q and V are learneable, different from the "Attention is all you need" paper.
    """
    def __init__(self, T: int, d_K, d_V, d_model: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # define a torch 2d tensor initialized normally
        self.W_K = torch.normal(mean=0, std=0.01, size=(d_model, d_K), requires_grad=True)
        self.W_Q = torch.normal(mean=0, std=0.01, size=(d_model, d_K), requires_grad=True)
        self.W_V = torch.normal(mean=0, std=0.01, size=(d_model, d_V), requires_grad=True)
        self.mask = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shapes:
        # W_K (d_model, d_K)
        # x is a 3d tensor (batch x T x d_model)

        # W_K.T ->  (1, d_k x d_model)
        # x ->      (batch, T, d_model)
        K = batched_matmul(x, self.W_K)
        Q = batched_matmul(x, self.W_Q)
        V = batched_matmul(x, self.W_V)

        # (batch, T, d_model) x (batch, d_model, d_k) -> (batch, T, d_k)
        result = torch.bmm(Q, K.transpose(1, 2)) / (K.shape[-1] ** 0.5)
        if self.mask:
            result = batched_matmul(result, self.mask)
        result = F.softmax(result, dim=-1)
        result = torch.bmm(result, V)
        return result

In [6]:
att = Attention(T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL)
x = torch.normal(mean=0, std=0.01, size=(BATCH_SIZE, T, D_MODEL))

att_result = att.forward(x)
assert att_result.shape == (BATCH_SIZE, T, D_V)
print(att.W_K.shape)
print(att.W_Q.shape)
print(att.W_V.shape)
print(att_result.shape)

torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([2, 2])
torch.Size([4, 4, 2])


In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self, h: int, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.h = h
        self.attentions = nn.ModuleList(
            [Attention(T=T, d_K=d_K, d_model=d_model, d_V=d_V) for _ in range(h)]
        )
        self.W_O = torch.normal(0, 0.1, size=(h * d_V, d_model), requires_grad=True)
        self.T = T
        self.d_V = d_V

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attention_results = []
        
        for attention in self.attentions:
            attention_result = attention(x)
            attention_results.append(attention_result)

        concatenated = torch.concat(attention_results, dim=-1)
        return batched_matmul(concatenated, self.W_O)

In [8]:
mha = MultiHeadAttention(h=H, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL)
assert mha(x).shape == (BATCH_SIZE, T, D_MODEL)

torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size

In [9]:
class TransformerBlock(nn.Module):
    def __init__(
        self, T=T, d_K=D_K, d_V=D_V, d_model=D_MODEL, h=H, dropout=0.1, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.mha = MultiHeadAttention(h, T=T, d_K=d_K, d_V=d_V, d_model=d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.ann = nn.Sequential(
            nn.Linear(d_model, 2),
            nn.Softmax(dim=-1),
        )
    
    def forward(self, x: torch.Tensor):
        # x = self.mha(x)
        # x = self.layer_norm(x)
        x = self.layer_norm(x + self.mha(x))
        x = self.layer_norm(x + self.ann(x))
        return x
        
transformerBlock = TransformerBlock()
transformerBlock(x).shape

torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size([4, 2, 2])
torch.Size([4, 4, 2]) torch.Size

torch.Size([4, 4, 2])

In [10]:
def PositionalEncoding(L: int, d_model):
    encodings = torch.zeros(size=(L, d_model), requires_grad=False)
    counter = 0
    for pos in range(L):
        for i in range((d_model // 2) + 1):
            if 2 * i < d_model:
                counter += 1
                encodings[pos, 2 * i] = torch.sin(
                    pos / torch.tensor(10000).pow(2 * i / d_model)
                )
            if 2 * i + 1 < d_model:
                counter += 1
                encodings[pos, 2 * i + 1] = torch.cos(
                    pos / torch.tensor(10000).pow(2 * i / d_model)
                )
    assert counter == L * d_model
    return encodings


PositionalEncoding(3, 4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  0.9999],
        [ 0.9093, -0.4161,  0.0200,  0.9998]])

In [11]:
torch.range(0, 10).reshape(-1, 1)

  torch.range(0, 10).reshape(-1, 1)


tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.]])

In [None]:
# class Embedding(nn.Module):
#     def __init__(self, vocab_size: int, d_model: int, *args, **kwargs) -> None:
#         super().__init__(*args, **kwargs)
#         self.vocab_size = vocab_size
#         self.embedding = torch.normal(
#             mean=0.0, std=0.1, size=(vocab_size, d_model), requires_grad=True
#         )

#     # TODO: make work in batches
#     def forward(self, x_one_hot: torch.Tensor):
#         # print(x_one_hot.shape)
#         batched_range = torch.arange(self.vocab_size).type(torch.float32)
#         batched_range = batched_range.unsqueeze(0).repeat(x_one_hot.shape[1], 1)
#         # print(batched_range.shape)
#         positions = batched_matmul(x_one_hot, batched_range.transpose(1, 0)).type(torch.int64)

#         print(positions)
#         print(self.embedding.shape)
#         # print(self.embedding)
#         return self.embedding[positions]


# emb = Embedding(3, 2)
# emb.forward(
#     torch.FloatTensor(
#         [
#             [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]],
#             [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]],
#         ]
#     )
# )

In [29]:
class Embedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.vocab_size = vocab_size
        self.embedding = torch.normal(
            mean=0.0, std=0.1, size=(vocab_size, d_model), requires_grad=True
        )

    def forward(self, x_one_hot: torch.Tensor):
        positions = torch.matmul(
            x_one_hot, torch.arange(self.vocab_size, dtype=torch.float32)
        ).type(torch.int64)

        # print(positions)
        # print(self.embedding)
        return self.embedding[positions]


# emb = Embedding(vocab_size=3, d_model=2)
# emb.forward(
#     torch.FloatTensor(
#         [
#             [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]],
#             [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]],
#         ]
#     )
# )

tensor([[[ 0.1057, -0.1647],
         [ 0.1355, -0.0796],
         [ 0.0850,  0.0329],
         [ 0.1355, -0.0796]],

        [[ 0.0850,  0.0329],
         [ 0.1057, -0.1647],
         [ 0.0850,  0.0329],
         [ 0.1057, -0.1647]]], grad_fn=<IndexBackward0>)

In [None]:
class ClassifierEncoder(nn.Module):
    def __init__(
        self,
        T=T,
        d_K=D_K,
        d_V=D_V,
        d_model=D_MODEL,
        h=H,
        vocab_size=VOCAB_SIZE,
        n_classes=N_CLASSES,
        dropout=0.1,
        *args,
        **kwargs
    ):
        self.T = T
        self.d_K = d_K

        self.embbeding = Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(L, d_model)

        self.transformersBlocks = nn.ModuleList()
        for _ in range(N_MHA_BLOCKS_ENCODER):
            self.transformersBlocks.append(
                TransformerBlock(T=T, d_K=d_K, d_V=d_V, d_model=d_model, h=h)
            )

        self.prediction_head = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        x = self.position_encoding + self.embbeding(x)
        x = self.transformersBlocks(x)
        x = F.softmax(self.prediction_head(x), dim=-1)
        return x