In [18]:
import torch
import tiktoken

from model import DecoderOnlyModel

In [None]:
DATA_PATH = None  #Add path to data
text = open(DATA_PATH, 'r').read()

In [20]:
enc = tiktoken.get_encoding('p50k_base')
data = enc.encode(text)
print(f"Total number of Tokens: {len(data)//1000}k")

Total number of Tokens: 338k


In [21]:
vocab_size = enc.n_vocab
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 50281


In [22]:
n = int(0.887519*len(data))
train_data = torch.tensor(data[:n])
val_data = torch.tensor(data[n:])

print(f"Train data: {train_data.shape} || {train_data.dtype}")
print(f"Validation data: {val_data.shape} || {val_data.dtype}")

Train data: torch.Size([300000]) || torch.int64
Validation data: torch.Size([38022]) || torch.int64


In [23]:
block_size = 8
batch_size = 4
n_embd = 64
epochs = 1000
num_heads = 2
n_layer = 1
dropout = 0.1
lr = 0.01
epoch_iter = 100

In [24]:
def get_sample(block_size, batch_size, Split = 'Train'):
    if Split == "Train":
        idx = torch.randint(low= 0, high= len(train_data)-block_size, size= (batch_size,))
        X = torch.stack([train_data[torch.arange(start= ix, end= ix + block_size )] for ix in  idx])
        Y = torch.stack([train_data[torch.arange(start= ix+1, end= ix + block_size+1)] for ix in idx])
        return X, Y
    
    else:
        idx = torch.randint(low= 0, high= len(val_data)-block_size, size= (batch_size,))
        X = torch.stack([val_data[torch.arange(start= ix, end= ix + block_size )] for ix in  idx])
        Y = torch.stack([val_data[torch.arange(start= ix+1, end= ix + block_size+1)] for ix in idx])
        return X, Y


In [25]:
model = DecoderOnlyModel(vocab_size, n_embd, block_size, num_heads, n_layer, dropout)
optimizer = torch.optim.Adam(params= model.parameters(), lr= lr)

In [26]:
for epoch in range(epochs):
    xb, yb = get_sample(block_size, batch_size, Split= 'Train')
    _, train_loss = model(idx= xb, targets= yb, block_size= block_size)

    if epoch % epoch_iter == 0:
        xb, yb = get_sample(block_size, batch_size, Split= "Val")
        _, val_loss = model(idx= xb, targets= yb, block_size= block_size)

        print(f"Epoch: {epoch+epoch_iter}/{epochs} || Train Loss: {train_loss.item()} ||  Val Loss: {train_loss.item()}")

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    
    
    

Epoch: 100/1000 || Train Loss: 11.22181510925293 ||  Val Loss: 11.22181510925293
Epoch: 200/1000 || Train Loss: 8.059085845947266 ||  Val Loss: 8.059085845947266
Epoch: 300/1000 || Train Loss: 7.912618160247803 ||  Val Loss: 7.912618160247803
Epoch: 400/1000 || Train Loss: 5.660836219787598 ||  Val Loss: 5.660836219787598
Epoch: 500/1000 || Train Loss: 6.516085147857666 ||  Val Loss: 6.516085147857666
Epoch: 600/1000 || Train Loss: 5.478137016296387 ||  Val Loss: 5.478137016296387
Epoch: 700/1000 || Train Loss: 5.15478515625 ||  Val Loss: 5.15478515625
Epoch: 800/1000 || Train Loss: 7.252851486206055 ||  Val Loss: 7.252851486206055
Epoch: 900/1000 || Train Loss: 4.366596698760986 ||  Val Loss: 4.366596698760986
Epoch: 1000/1000 || Train Loss: 5.977535247802734 ||  Val Loss: 5.977535247802734


In [27]:
out = model.generate(idx= torch.zeros((batch_size, block_size), dtype= torch.long), max_new_token= 200, block_size= block_size)[0, 8:]
print(enc.decode(out.tolist()))

 killUCHESS OF nisaid a BOLINGBRENTIO:Who destruction destruction,ound, eye us;amy done-- dead!
Justice, hoa alone, done flyages knifeories traitor, a battle came in heaven, though stand widowHeartworth hearariumrons pleased'sonent enter offended of a quarterly lament ever theeind mayipolar;
Thus flying redress while ins.
OP OF Dysixtian against could be deliver the recreine Murderer:
 Proper CustomersTER:
Although:rupulousereWARD IV: weaton ob profit hair Authorities obey thy nose pays thy any woman of my RICHARD II:
For measure;' plated un inen${ dances depended with me PUR began the generalIVERS:
To prosper hide do corrupt sure;able ingator-morrow dRAKENBUR knock wear soon either:He. I doubt; and shewood! herice' twoOPomegranate but gates ofurchathe's itself.
Art shouts
