In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from prepare import prepare_data
from math import sqrt

In [22]:
#Hyperparameters (setting the embedding_dim to be a multiple of d_key is ideal)
embedding_dim = 64
max_token = 8
d_key = 16
n_decoder_layers = 3
n_head = embedding_dim // d_key
mini_batch_size = 32

In [3]:
def train_dev_split(x,train,dev):
    N = x.shape[0]
    Ntr = int(N * train)
    Ndev = int(N*dev)
    ind = torch.randperm(Ntr)
    data_tr = x[ind][:Ntr]
    data_dev = x[ind][Ntr:Ntr+Ndev]
    return data_tr,data_dev

In [113]:
def create_batch(data):
    ind = torch.randint(data.shape[0] - max_token, size=(mini_batch_size,))
    shuffled_data = data[ind]
    return shuffled_data

In [82]:
device = "cuda" if torch.cuda.is_available() else "cpu"
full_data = prepare_data("input.txt")
data, vocab, encode, decode = torch.tensor(full_data["encoded_data"], device=device), full_data["vocab"], full_data["encode"], full_data["decode"]
vocab_size = len(vocab)
Batch = data.shape[0] // (max_token+1)
ind = torch.randperm(n= data.shape[0] - (max_token+1))[:Batch]
ranges = ind.view(Batch,1) + torch.arange((max_token+1))
data = data[ranges]
data_tr,data_dev = train_dev_split(data,0.9,0.1)

In [None]:
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.keyM = nn.Linear(embedding_dim, d_key)
        self.queryM = nn.Linear(embedding_dim, d_key)
        self.valueM = nn.Linear(embedding_dim, d_key)
    def forward(self,x):
        Q = self.queryM(x)
        K = self.keyM(x)
        V = self.valueM(x)
        scores = (Q @ K.permute(0,2,1)) / torch.sqrt(torch.tensor(d_key,device=device))
        #Masking for attention
        inp = torch.ones(max_token, max_token, device=device)  
        mask = torch.tril(inp).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        #Attention
        attn = torch.softmax(scores, dim=-1)
        return attn @ V
        

In [15]:
class MultiHeadedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head() for i in range(n_head)])
        self.linear_layer = nn.Linear(n_head*d_key, embedding_dim)
    def forward(self,x):
        res = []
        for i in range(n_head):
            res.append(self.heads[i](x))
        return self.linear_layer(torch.cat(res, dim=-1))

In [16]:
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.pos_table = nn.Embedding(max_token,embedding_dim)
        self.decoder_layers = nn.ModuleList([MultiHeadedAttention() for i in range(n_decoder_layers)])
        self.feed_forward = nn.ModuleList([nn.Sequential(nn.Linear(embedding_dim,4*embedding_dim),nn.ReLU(),nn.Linear(4*embedding_dim,embedding_dim)) for i in range(n_decoder_layers)])
        self.final_linear = nn.Linear(embedding_dim,vocab_size)
        self.layer_norm = nn.LayerNorm(embedding_dim)
    def forward(self,x):
        emb = self.embedding_table(x) + self.pos_table(torch.arange(x.shape[1],device=device))
        out = emb
        for i in range(n_decoder_layers):
            attn_out = self.decoder_layers[i](self.layer_norm(out))
            out = out + attn_out
            ff_out = self.feed_forward[i](self.layer_norm(out))
            out = out + ff_out
        return self.final_linear(out)

In [133]:
DecoderOnlyTransformer = Transformer().to(device)
optimizer = torch.optim.AdamW(lr=4e-3,params=DecoderOnlyTransformer.parameters())

In [141]:
#TRAINING BLOCK:
for i in range(100):
    minibatch = create_batch(data_tr)
    Xbatch,Ybatch = minibatch[:,:8], minibatch[:,1:]
    out = DecoderOnlyTransformer(Xbatch)
    B,T,C = out.shape
    out = out.view(B*T, C)
    Ybatch = Ybatch.reshape(-1)
    loss = F.cross_entropy(out,Ybatch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss)

tensor(4.3984, grad_fn=<NllLossBackward0>)
tensor(3.9244, grad_fn=<NllLossBackward0>)
tensor(3.7074, grad_fn=<NllLossBackward0>)
tensor(3.3750, grad_fn=<NllLossBackward0>)
tensor(3.3679, grad_fn=<NllLossBackward0>)
tensor(3.2797, grad_fn=<NllLossBackward0>)
tensor(3.2021, grad_fn=<NllLossBackward0>)
tensor(3.1257, grad_fn=<NllLossBackward0>)
tensor(3.0935, grad_fn=<NllLossBackward0>)
tensor(3.2973, grad_fn=<NllLossBackward0>)
tensor(3.0662, grad_fn=<NllLossBackward0>)
tensor(2.9935, grad_fn=<NllLossBackward0>)
tensor(2.9605, grad_fn=<NllLossBackward0>)
tensor(2.8587, grad_fn=<NllLossBackward0>)
tensor(2.9421, grad_fn=<NllLossBackward0>)
tensor(2.9118, grad_fn=<NllLossBackward0>)
tensor(2.7763, grad_fn=<NllLossBackward0>)
tensor(2.9951, grad_fn=<NllLossBackward0>)
tensor(2.7874, grad_fn=<NllLossBackward0>)
tensor(2.8523, grad_fn=<NllLossBackward0>)
tensor(2.8935, grad_fn=<NllLossBackward0>)
tensor(2.8910, grad_fn=<NllLossBackward0>)
tensor(2.7367, grad_fn=<NllLossBackward0>)
tensor(2.76