In [28]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from prepare import prepare_data
from math import sqrt

In [34]:
#Hyperparameters (setting the embedding_dim to be a multiple of d_key is ideal)
embedding_dim = 64
max_token = 8
d_key = 16
n_head = embedding_dim // d_key

In [26]:
def train_dev_split(x,train,dev):
    N = x.shape[0]
    Ntr = int(N * train)
    Ndev = int(N*dev)
    ind = torch.randperm(Ntr)
    data_tr = x[ind][:Ntr]
    data_dev = x[ind][Ntr:Ntr+Ndev]
    return data_tr,data_dev

In [25]:
device = "cuda" if torch.cuda.device_count() else "cpu"
full_data = prepare_data("input.txt")
data, vocab, encode, decode = torch.tensor(full_data["encoded_data"], device=device), full_data["vocab"], full_data["encode"], full_data["decode"]
vocab_size = len(vocab)
B = data.shape[0] // max_token
ind = torch.randperm(n= data.shape[0] - max_token)[:B]
ranges = ind.view(B,1) + torch.arange(max_token)
data = data[ranges]
data_tr,data_dev = train_dev_split(data,0.9,0.1)

In [29]:
overfitting_data = data[:10]
overfitting_data

tensor([[ 6,  0, 28, 53, 50, 47, 62, 43],
        [42, 45, 51, 43, 52, 58,  1, 54],
        [57, 53, 52,  6,  1, 44, 39, 58],
        [ 1, 51, 39, 49, 43,  1, 46, 47],
        [ 1, 40, 53, 42, 63, 11,  0, 32],
        [ 1, 42, 47, 57, 41, 53, 60, 43],
        [54, 53, 53, 56,  1, 41, 46, 47],
        [58, 46,  1, 39, 50, 51, 53, 57],
        [51, 43, 58, 47, 51, 43, 57,  1],
        [46,  0, 37, 53, 59,  1, 61, 43]])

In [30]:
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.pos_table = nn.Embedding(max_token,embedding_dim)
    def forward(self,x):
        return self.embedding_table(x) + self.pos_table(torch.arange(max_token))

In [21]:
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.keyM = nn.Linear(embedding_dim, d_key)
        self.queryM = nn.Linear(embedding_dim, d_key)
        self.valueM = nn.Linear(embedding_dim, d_key)
    def forward(self,x):
        Q = self.queryM(x)
        K = self.keyM(x)
        V = self.valueM(x)
        scores = (Q @ K.permute(0,2,1)) / sqrt(d_key)
        #Masking for attention
        inp = torch.ones(max_token, max_token)  
        mask = torch.tril(inp).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        #Attention
        attn = torch.softmax(scores, dim=-1)
        return attn @ V
        

In [61]:
class MultiHeadedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head() for i in range(n_head)])
    def forward(self,x):
        res = []
        for i in range(n_head):
            res.append(self.heads[i](x))
        return torch.cat()

In [60]:
torch.cat([torch.tensor([1,0]),torch.tensor([1,2])])

tensor([1, 0, 1, 2])

In [55]:
emb = Transformer()
embedded = emb(overfitting_data)
single_head = Head()
current = single_head(embedded)
current

tensor([[[ 1.0335e+00, -1.7547e-01,  4.1799e-01,  ...,  1.8663e+00,
           1.1381e+00, -2.7055e-02],
         [ 4.3373e-02,  2.1530e-01,  6.8400e-01,  ...,  8.8471e-01,
          -8.3931e-01, -3.8434e-01],
         [-1.0261e-01,  1.0107e-01,  2.4153e-01,  ...,  6.4124e-01,
          -2.8775e-02, -2.2576e-01],
         ...,
         [ 2.1603e-01,  9.1425e-02, -1.0358e-01,  ...,  2.4007e-01,
           1.3572e-01, -2.5886e-01],
         [ 2.4690e-01,  2.5743e-01, -3.2803e-02,  ...,  7.2757e-01,
           4.9256e-01, -4.4815e-01],
         [ 5.9723e-01,  2.6753e-02, -1.0992e-01,  ...,  4.3336e-01,
           3.4864e-01, -3.2583e-01]],

        [[ 2.2754e-01, -7.1095e-02, -2.0852e-01,  ...,  1.2240e+00,
           1.2141e+00,  1.2294e+00],
         [ 1.8762e-02,  6.1646e-03,  7.6440e-01,  ...,  7.4719e-01,
           5.5001e-01,  4.1232e-01],
         [ 7.3700e-02, -9.3607e-02, -2.7212e-02,  ...,  8.0316e-01,
           7.0005e-01,  8.6644e-01],
         ...,
         [ 2.6801e-01,  2