In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from prepare import prepare_data

In [42]:
# Hyperparameters (module-level constants read by the data-sampling and model cells).
embedding_dim = 64  # width of the token and position embedding vectors
max_token = 8  # tokens per sampled window; also the number of rows in the position table
d_key = 16  # output width of the attention head's query/key/value projections

In [2]:
def train_dev_split(x,train,dev):
    """Randomly split the rows of ``x`` into disjoint train and dev subsets.

    Args:
        x: Tensor whose first dimension indexes the examples.
        train: Fraction of the rows (0..1) to place in the training split.
        dev: Fraction of the rows (0..1) to place in the dev split.

    Returns:
        A ``(data_tr, data_dev)`` tuple holding ``int(N * train)`` and
        ``int(N * dev)`` rows respectively, where ``N = x.shape[0]``. If the
        two fractions sum to more than 1 the dev split is silently truncated
        by the slice.
    """
    N = x.shape[0]
    Ntr = int(N * train)
    Ndev = int(N * dev)
    # Permute *all* N rows. Permuting only the first Ntr indices would leave
    # nothing for the dev slice and would never sample the tail of the data.
    ind = torch.randperm(N)
    # Shuffle once and slice both splits from the same permutation so they
    # are guaranteed to be disjoint.
    shuffled = x[ind]
    data_tr = shuffled[:Ntr]
    data_dev = shuffled[Ntr:Ntr + Ndev]
    return data_tr,data_dev

In [20]:
# Use the GPU when PyTorch reports at least one CUDA device, otherwise the CPU.
device = "cuda" if torch.cuda.device_count() else "cpu"
# prepare_data is a project helper; as used here its result is indexed by the
# keys "encoded_data", "vocab", "encode" and "decode".
full_data = prepare_data("input.txt")
data, vocab, encode, decode = torch.tensor(full_data["encoded_data"], device=device), full_data["vocab"], full_data["encode"], full_data["decode"]
vocab_size = len(vocab)
# Number of windows to sample: how many max_token-long windows fit end to end.
# NOTE(review): assumes `data` is a 1-D token sequence at this point - confirm.
B = data.shape[0] // max_token
# B distinct random start offsets, each leaving room for a full window.
ind = torch.randperm(n= data.shape[0] - max_token)[:B]
# Broadcasting (B, 1) starts against (max_token,) offsets gives a (B, max_token) index grid.
ranges = ind.view(B,1) + torch.arange(max_token)
# Rebind `data` to the sampled windows; starts are random, so windows may overlap.
data = data[ranges]
# Fractions passed for the train and dev splits are 0.9 and 0.1.
data_tr,data_dev = train_dev_split(data,0.9,0.1)

In [21]:
# Keep the first 10 sampled windows as a small fixed batch
# (presumably for an overfit-one-batch sanity check, going by the name).
overfitting_data = data[:10]
overfitting_data

tensor([[58, 43, 52, 42,  1, 58, 53,  1],
        [47, 57,  1, 50, 39, 52, 42,  1],
        [42,  1, 41, 39, 52, 53, 54, 47],
        [47, 60, 43,  6,  1, 58, 47, 50],
        [ 1, 53, 59, 58,  1, 53, 44,  1],
        [51, 39, 52, 63,  1, 51, 43, 52],
        [59, 43, 56, 53, 56, 10,  0, 20],
        [46, 39, 52, 49, 44, 59, 50,  1],
        [40, 59, 58,  1, 54, 50, 43, 39],
        [39, 47, 56,  1, 61, 47, 58, 46]])

In [48]:
class Transformer(nn.Module):
    """Token embedding plus learned positional embedding.

    The module currently implements only the input embedding stage: each
    token id is looked up in ``embedding_table`` and the embedding of its
    position within the sequence is added.
    """
    def __init__(self):
        super().__init__()
        # One embedding_dim-wide vector per vocabulary entry.
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        # One embedding_dim-wide vector per position, up to max_token positions.
        self.pos_table = nn.Embedding(max_token,embedding_dim)
    def forward(self,x):
        """Embed a batch of token ids.

        Args:
            x: Integer tensor of token ids whose last dimension is the
                sequence axis. Its length must not exceed the number of
                rows in ``pos_table``; longer inputs raise inside the
                positional lookup.

        Returns:
            Tensor of shape ``x.shape + (embedding width,)``.
        """
        # Build the position ids from the actual input rather than a fixed
        # global length, and on the input's device, so shorter sequences and
        # GPU inputs both work.
        seq_len = x.shape[-1]
        positions = torch.arange(seq_len, device=x.device)
        # (seq_len, dim) positional embeddings broadcast over the leading dims.
        return self.embedding_table(x) + self.pos_table(positions)

In [None]:
class Head(nn.Module):
    """Single causal self-attention head (score computation only).

    ``forward`` returns the scaled, causally masked attention scores; it does
    not yet apply softmax or combine them with the value projection, so
    ``valueM`` is registered but unused for now.
    """
    def __init__(self):
        super().__init__()
        self.keyM = nn.Linear(embedding_dim, d_key)
        self.queryM = nn.Linear(embedding_dim, d_key)
        self.valueM = nn.Linear(embedding_dim, d_key)
    def forward(self,x):
        """Compute masked scaled dot-product attention scores.

        Args:
            x: Float tensor of shape ``(..., T, embedding width)``.

        Returns:
            Tensor of shape ``(..., T, T)`` where entry ``[..., i, j]`` is the
            scaled query-i / key-j dot product for ``j <= i`` and ``-inf``
            for ``j > i`` (future positions).
        """
        Q = self.queryM(x)
        K = self.keyM(x)
        # Scale by sqrt(d_k), not d_k, as in scaled dot-product attention.
        # The key width is read from the tensor so it always matches the layer.
        scores = (Q @ K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)
        # Lower-triangular mask sized to the actual sequence and created on
        # the same device as the scores.
        T = scores.shape[-1]
        mask = torch.ones(T, T, device=scores.device).tril().bool()
        # masked_fill is out-of-place: its result must be kept, otherwise the
        # causal mask is silently dropped.
        scores = scores.masked_fill(~mask, float('-inf'))
        return scores
        

In [55]:
# Smoke test: embed the small batch, then run it through one attention head.
emb = Transformer()
embedded = emb(overfitting_data)
single_head = Head()
current = single_head(embedded)
# The recorded output is (10, 8, 8): one 8 x 8 score matrix per input sequence.
current.shape

torch.Size([10, 8, 8])

In [66]:
# Lower-triangular boolean mask: True at row i, column j whenever j <= i.
inp = torch.ones(max_token, max_token)
mask = torch.tril(inp).bool()
# masked_fill (no trailing underscore) returns a new tensor and leaves `current`
# unchanged; the displayed result has -inf everywhere above the diagonal.
current.masked_fill(~mask, float('-inf'))

tensor([[[-0.1355,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [ 0.2202,  0.0083,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [ 0.0621, -0.2398, -0.1394,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [ 0.0714, -0.2759,  0.2720,  0.2471,    -inf,    -inf,    -inf,
             -inf],
         [-0.0650, -0.0040,  0.0949, -0.0530, -0.3242,    -inf,    -inf,
             -inf],
         [-0.1447,  0.1526,  0.0917,  0.0340, -0.2349, -0.0844,    -inf,
             -inf],
         [-0.0975,  0.2911, -0.1612,  0.2081,  0.2769, -0.1142, -0.0234,
             -inf],
         [ 0.1641,  0.0738,  0.0448, -0.0481,  0.0241,  0.1327,  0.0656,
           0.0738]],

        [[-0.1982,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [ 0.0289, -0.1015,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [-0.0600, -0.1928, -0.0913,    -inf,    -inf,    -inf,    -