In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

if os.path.join(os.getcwd()) not in sys.path:
    sys.path.append(os.getcwd())
    
import torch
import numpy as np
import datasets
import tiktoken
from tqdm.notebook import tqdm

from torch.utils.data import DataLoader
from gpt import GPT
import transformers
from utils.DataProcessing import DataProcessing
from utils.ShakespeareDataset import ShakespeareDataset

# Generate Dataset

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Download the Tiny Shakespeare dataset
file_dir ="data"
text_file = os.path.join(file_dir, 'mini_shakespeare.txt')
data_dir = os.path.join(file_dir, 'mini_shakespeare_datasets')
tokenized_dir = os.path.join(file_dir, 'tokenized_mini_shakespeare_datasets')

batch_size = 64
context_len = 128

# 90% train, 10% test + validation
# tokenizer = transformers.AutoTokenizer.from_pretrained('./deepseek_tokenizer/', trust_remote_code=True)
tokenizer = tiktoken.get_encoding('gpt2')

train_val_split = 0.9
dataset_generator = DataProcessing(batch_size=batch_size, block_size=context_len)
data = dataset_generator.generate_dataset(text_file, data_dir, split=train_val_split, tokenizer=tokenizer)

dataloaders = {key: ShakespeareDataset(data[key], batch_size=batch_size, block_size=context_len) for key in data}

In [5]:
print(next(iter(dataloaders['train']))[0].size())

torch.Size([64, 128])


In [6]:
embed_dim = 32    # embed size
heads = 8           # heads for attention
num_layers = 3      # number of transformer layers
max_length = 128    # max length input vector for postional embedding
src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = tokenizer.n_vocab # for word embedding. Mapping dictionary of size N to embed size
trg_vocab_size = tokenizer.n_vocab # for word embedding. Mapping dictionary of size N to embed size
model_dir = 'models/test'
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)

model = GPT(context_len=context_len,
            vocab_size=tokenizer.n_vocab,
            embed_dim=embed_dim,
            heads=heads,
            num_layers=num_layers, 
            device=device).to(device)


In [7]:
num_params = 0
for layer in model.parameters():
    num_params += layer.numel()
        
print(num_params)

3299617


In [None]:
torch.cuda.empty_cache()

num_epochs = 5
lr = 1e-4
print_freq = 100

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
torch.cuda.empty_cache()
for epoch in tqdm(range(num_epochs)):
    running_loss = 0
    for idx, data in tqdm(enumerate(dataloaders['train']), total = len(dataloaders['train'])):
        optimizer.zero_grad()
        x = data[0].to(device)
        y = data[1].to(device)
        logits = model.forward(x=x)
        loss = model.calculate_loss(logits=logits, targets=y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if np.mod(idx+1, print_freq)==0:
            print(loss.item())

    print(f"Epoch {epoch}: Loss: {running_loss/len(dataloaders['train'])}")
    
    torch.save(model, os.path.join(model_dir, f"gpt2_epoch{epoch}.pth"))


In [None]:
model = torch.load('models\\test\\gpt2_epoch4.pth')

  model = torch.load('models\\test\\gpt2_epoch4.pth')


In [13]:
model.generate(tokenizer=tokenizer, context='thou', generate_len=1000)

[400, 280, 286, 198, 1001, 198, 5896, 198, 198, 198, 198, 7351, 198, 887, 198, 314, 198, 198, 198, 198, 705, 198, 198, 198, 198, 198, 198, 314, 10846, 314, 198, 1867, 198, 1521, 616, 644, 198, 198, 644, 198, 198, 198, 198, 198, 749, 198, 198, 15967, 198, 644, 198, 438, 198, 198, 314, 198, 314, 705, 198, 198, 314, 198, 198, 326, 314, 198, 198, 198, 35205, 198, 198, 644, 198, 198, 11738, 314, 198, 198, 1793, 14210, 7911, 14210, 2940, 1867, 198, 317, 198, 198, 7361, 314, 314, 198, 198, 198, 198, 198, 705, 198, 198, 198, 314, 198, 894, 326, 198, 810, 314, 314, 198, 1338, 355, 783, 198, 198, 314, 1867, 314, 616, 10889, 616, 894, 611, 198, 198, 705, 705, 703, 314, 481, 407, 14186, 354, 406, 1503, 32043, 198, 10462, 451, 645, 2300, 11, 290, 606, 6487, 11, 329, 11, 1497, 13, 198, 2514, 787, 1242, 0, 7271, 922, 640, 286, 345, 1276, 19059, 465, 3251, 588, 284, 4123, 262, 19383, 284, 517, 198, 19626, 276, 11, 1690, 1745, 307, 257, 25920, 4249, 281, 17865, 2402, 262, 2116, 11, 326, 198, 1199, 577,

In [None]:
# look into hydra, wandb, and lightning