Special shoutout to the GOAT Karpathy. This repo follows the theoretical concepts introduced in Karpathy's tutorial but adds many enhancements including:
- major stylistic refactors
- closer alignment with PyTorch's MultiheadAttention implementation
- addition of a Dataset class
- removal of extra dropout layer in MultiheadAttention
- live printing of generated text that mimics ChatGPT

In [None]:
from model import *
from dataset import *

In [None]:
# Ask PyTorch's CUDA caching allocator to release any unused cached GPU memory
# before the model is built (presumably so re-running the notebook starts from
# a clean slate).
torch.cuda.empty_cache()

In [None]:
# Hyperparameters and run settings shared by the cells below.
# The letters in the trailing comments (N, L, S, E, H) look like the dimension
# names used in torch.nn.MultiheadAttention's documentation.
config = {
    'batch_size': 64,  # N
    'sequence_dim': 100,  # L, S
    'embed_dim': 78,  # E
    'num_heads': 13,  # H
    'num_layers': 3,
    'dropout': 0.2,
    'train_steps': 5000,
    'lr': 1e-3,  # learning rate
    'seed': 78,
    # Use the GPU when one is visible; otherwise fall back to the CPU so the
    # notebook still runs (slowly) instead of crashing at model creation.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}
# The embedding must split evenly across the attention heads. Raise explicitly
# rather than assert, because asserts are stripped when Python runs with -O.
if config['embed_dim'] % config['num_heads'] != 0:
    raise ValueError(
        f"embed_dim ({config['embed_dim']}) must be divisible by "
        f"num_heads ({config['num_heads']})"
    )
# Seed the global RNG for reproducibility; kept as the last expression so the
# cell still displays the returned generator, as before.
torch.manual_seed(config['seed'])

In [None]:
# Build the character-level dataset from data.txt, using the configured
# sequence length for each sample.
dataset_shakespeare = CharacterDataset('data.txt', seq_len=config['sequence_dim'])

# flavor 1 - shuffled split
# data_train, data_test = torch.utils.data.random_split(dataset_shakespeare, [.9, .1])

# flavor 2 - non-shuffled split
# The first 95% of the samples (in dataset order) go to training and the
# remaining tail is held out for validation.
n = int(.95*len(dataset_shakespeare))
train_indices = list(range(n))
val_indices = list(range(n, len(dataset_shakespeare)))
dataset_train = torch.utils.data.Subset(dataset_shakespeare, train_indices)
dataset_val = torch.utils.data.Subset(dataset_shakespeare, val_indices)

In [None]:
# Keyword options for the model, gathered in one place.
gpt_options = {
    'dropout': config['dropout'],
    'device': config['device'],
}
# Positional arguments are passed in the same order as before: vocabulary size
# from the dataset, then sequence length, embedding width, head count and
# layer count from the config.
model = GPT(
    dataset_shakespeare.vocab_dim,
    config['sequence_dim'],
    config['embed_dim'],
    config['num_heads'],
    config['num_layers'],
    **gpt_options,
)
# AdamW over every model parameter, with the configured learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

In [None]:
# Bare expression: as the last line of a notebook cell this displays the
# model's repr (presumably the layer structure, assuming GPT is an nn.Module).
model

In [None]:
# Before training: sample from the model as a baseline, using the dataset's
# encode/decode callables to map between text and tokens.
# NOTE(review): the list looks like a batch of prompts and 100 the number of
# characters to generate — confirm against GPT.generate.
model.generate(dataset_shakespeare.encode, dataset_shakespeare.decode, ['hi', 'bye'], 100)

In [None]:
%%time
# Split the total training budget into equally sized "epochs" so that the
# losses can be reported at regular intervals.
epochs = 10
steps_per_epoch = config['train_steps'] // epochs
for epoch in range(epochs):
    # Train on the training split for one chunk of steps.
    model.fit(dataset_train, optimizer, config['batch_size'], steps_per_epoch)
    # Evaluate on both splits, then report train loss followed by val loss.
    losses = model.evaluate(
        [dataset_train, dataset_val],
        config['batch_size'],
        steps_per_epoch,
    )
    loss_train, loss_val = losses
    print(loss_train, loss_val)

In [None]:
# After training: sample again to compare with the earlier baseline, this
# time with a longer length argument (1000).
# NOTE(review): print_batch_num=1 presumably selects which prompt in the batch
# is printed live — confirm against GPT.generate.
model.generate(dataset_shakespeare.encode, dataset_shakespeare.decode, ['Han', 'Linsu'], 1000, print_batch_num=1)