In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from prepare import prepare_data
from math import sqrt

In [2]:
#Hyperparameters (setting the embedding_dim to be a multiple of d_key is ideal)
embedding_dim = 64
max_token = 8
d_key = 16
n_decoder_layers = 3
n_head = embedding_dim // d_key

In [3]:
def train_dev_split(x,train,dev):
    N = x.shape[0]
    Ntr = int(N * train)
    Ndev = int(N*dev)
    ind = torch.randperm(Ntr)
    data_tr = x[ind][:Ntr]
    data_dev = x[ind][Ntr:Ntr+Ndev]
    return data_tr,data_dev

In [4]:
device = "cuda" if torch.cuda.device_count() else "cpu"
full_data = prepare_data("input.txt")
data, vocab, encode, decode = torch.tensor(full_data["encoded_data"], device=device), full_data["vocab"], full_data["encode"], full_data["decode"]
vocab_size = len(vocab)
Batch = data.shape[0] // max_token
ind = torch.randperm(n= data.shape[0] - max_token)[:Batch]
ranges = ind.view(Batch,1) + torch.arange(max_token)
data = data[ranges]
data_tr,data_dev = train_dev_split(data,0.9,0.1)

In [5]:
overfitting_data = data[:10]
overfitting_data[:,1:]

tensor([[ 1, 61, 53, 56, 42,  1, 47],
        [58, 43, 56,  1, 58, 46, 39],
        [57,  1, 58, 46, 39, 58,  1],
        [13, 52, 42,  1, 41, 56, 47],
        [43,  1, 46, 43, 39, 60, 43],
        [50, 43, 43, 60, 43, 57,  8],
        [53, 42,  1, 54, 56, 43, 60],
        [27, 10,  0, 32, 56, 59, 43],
        [52, 43, 58, 43, 43, 52,  1],
        [53, 61, 39, 56, 42, 57,  2]])

In [14]:
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.keyM = nn.Linear(embedding_dim, d_key)
        self.queryM = nn.Linear(embedding_dim, d_key)
        self.valueM = nn.Linear(embedding_dim, d_key)
    def forward(self,x):
        Q = self.queryM(x)
        K = self.keyM(x)
        V = self.valueM(x)
        scores = (Q @ K.permute(0,2,1)) / torch.sqrt(torch.tensor(d_key,device=device))
        #Masking for attention
        inp = torch.ones(max_token, max_token)  
        mask = torch.tril(inp).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        #Attention
        attn = torch.softmax(scores, dim=-1)
        return attn @ V
        

In [15]:
class MultiHeadedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head() for i in range(n_head)])
        self.linear_layer = nn.Linear(n_head*d_key, embedding_dim)
    def forward(self,x):
        res = []
        for i in range(n_head):
            res.append(self.heads[i](x))
        return self.linear_layer(torch.cat(res, dim=-1))

In [16]:
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.pos_table = nn.Embedding(max_token,embedding_dim)
        self.decoder_layers = nn.ModuleList([MultiHeadedAttention() for i in range(n_decoder_layers)])
        self.feed_forward = nn.ModuleList([nn.Sequential(nn.Linear(embedding_dim,4*embedding_dim),nn.ReLU(),nn.Linear(4*embedding_dim,embedding_dim)) for i in range(n_decoder_layers)])
        self.final_linear = nn.Linear(embedding_dim,vocab_size)
        self.layer_norm = nn.LayerNorm(embedding_dim)
    def forward(self,x):
        emb = self.embedding_table(x) + self.pos_table(torch.arange(x.shape[1],device=device))
        out = emb
        for i in range(n_decoder_layers):
            attn_out = self.decoder_layers[i](self.layer_norm(out))
            out = out + attn_out
            ff_out = self.feed_forward[i](self.layer_norm(out))
            out = out + ff_out
        return self.final_linear(out)

In [17]:
DecoderOnlyTransformer = Transformer().to(device)
optimizer = torch.optim.AdamW(lr=4e-3,params=DecoderOnlyTransformer.parameters())

In [20]:
#TRAINING BLOCK:
for i in range(100):
    out = DecoderOnlyTransformer(data_tr[:100])
    B,T,C = out.shape
    out = out.view(B*T, C)
    targets = data_tr[:100].view(-1)
    loss = F.cross_entropy(out,targets)
    DecoderOnlyTransformer.zero_grad()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss)

tensor(2.5353, grad_fn=<NllLossBackward0>)
tensor(1.6157, grad_fn=<NllLossBackward0>)
tensor(1.0234, grad_fn=<NllLossBackward0>)
tensor(0.9384, grad_fn=<NllLossBackward0>)
tensor(0.6635, grad_fn=<NllLossBackward0>)
tensor(0.4407, grad_fn=<NllLossBackward0>)
tensor(0.3347, grad_fn=<NllLossBackward0>)
tensor(0.2663, grad_fn=<NllLossBackward0>)
tensor(0.2108, grad_fn=<NllLossBackward0>)
tensor(0.1666, grad_fn=<NllLossBackward0>)
tensor(0.1291, grad_fn=<NllLossBackward0>)
tensor(0.0926, grad_fn=<NllLossBackward0>)
tensor(0.0612, grad_fn=<NllLossBackward0>)
tensor(0.0403, grad_fn=<NllLossBackward0>)
tensor(0.0280, grad_fn=<NllLossBackward0>)
tensor(0.0207, grad_fn=<NllLossBackward0>)
tensor(0.0162, grad_fn=<NllLossBackward0>)
tensor(0.0132, grad_fn=<NllLossBackward0>)
tensor(0.0110, grad_fn=<NllLossBackward0>)
tensor(0.0090, grad_fn=<NllLossBackward0>)
tensor(0.0072, grad_fn=<NllLossBackward0>)
tensor(0.0056, grad_fn=<NllLossBackward0>)
tensor(0.0043, grad_fn=<NllLossBackward0>)
tensor(0.00

KeyboardInterrupt: 