Special shoutout to the GOAT, Karpathy. This repo follows the theoretical concepts introduced in Karpathy's tutorial, but adds several enhancements:
- major stylistic refactors
- an attention module that closely follows Torch's MultiheadAttention implementation
- a Dataset class
- removal of the extra dropout layer in MultiheadAttention
- live printing that mimics ChatGPT

In [1]:
from model import *
from dataset import *

In [2]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not using;
# presumably here so a re-run in the same kernel starts without memory left over
# from a previous model.
# NOTE(review): `torch` is not imported explicitly in this notebook; it appears to
# arrive through the star imports in the first cell — confirm.
torch.cuda.empty_cache()

In [3]:
# Hyperparameters and runtime settings read by the cells below.
config = dict(
    batch_size = 64, # N
    sequence_dim = 100, # L, S (also used as the dataset's seq_len)
    embed_dim = 78, # E
    num_heads = 13, # H
    num_layers = 3,
    dropout = 0.2,
    train_steps = 5000, # total optimisation steps, split evenly across epochs in the training cell
    lr = 1e-3, # learning rate
    seed = 78,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs
    # on machines without CUDA instead of failing at model construction.
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
)
# embed_dim must split evenly across the attention heads.
assert config['embed_dim'] % config['num_heads'] == 0, 'embed_dim must be divisible by num_heads'
# Seed PyTorch's RNG so initialisation and sampling are repeatable.
torch.manual_seed(config['seed'])

<torch._C.Generator at 0x7ff956bd5210>

In [4]:
# Tokenisation choice: the word-level dataset is active (n_vocab = 50K, requires
# bigger parameters); the character-level alternative (n_vocab = 65) is kept for reference.
# dataset_shakespeare = CharacterDataset('data.txt', seq_len=config['sequence_dim'])
dataset_shakespeare = WordDataset('data.txt', seq_len=config['sequence_dim'])

# Alternative split (disabled): shuffled 90/10 split
# dataset_train, dataset_val = torch.utils.data.random_split(dataset_shakespeare, [.9, .1])

# Active split: contiguous, non-shuffled — the first 95% of the indices train,
# the remaining tail validates.
n = int(.95 * len(dataset_shakespeare))
dataset_train, dataset_val = (
    torch.utils.data.Subset(dataset_shakespeare, list(index_span))
    for index_span in (range(n), range(n, len(dataset_shakespeare)))
)

In [5]:
# Build the network: vocabulary size first, then the architecture settings in the
# positional order sequence_dim, embed_dim, num_heads, num_layers.
model = GPT(
    dataset_shakespeare.vocab_dim,
    *(config[key] for key in ('sequence_dim', 'embed_dim', 'num_heads', 'num_layers')),
    dropout=config['dropout'],
    device=config['device'],
)
# AdamW over all model parameters at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['lr'])

In [6]:
# Show the parameter count followed by the module tree, one per line.
print(model.count_parameters(), model, sep='\n')

8119669
GPT(
  (token_embedding): Embedding(50257, 78)
  (position_embedding): Embedding(100, 78)
  (dropout): Dropout(p=0.2, inplace=False)
  (blocks): Sequential(
    (0): SelfAttentionBlock(
      (ln1): LayerNorm((78,), eps=1e-05, elementwise_affine=True)
      (mha): MultiheadAttention(
        (query): Linear(in_features=78, out_features=78, bias=False)
        (key): Linear(in_features=78, out_features=78, bias=False)
        (value): Linear(in_features=78, out_features=78, bias=False)
        (dropout1): Dropout(p=0.2, inplace=False)
        (projection): Linear(in_features=78, out_features=78, bias=True)
      )
      (ln2): LayerNorm((78,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (net): Sequential(
          (0): Linear(in_features=78, out_features=312, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=312, out_features=78, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (1): SelfAtte

In [7]:
# pretraining
# Baseline sample from the model before any call to fit(): passes the dataset's
# encode/decode callables, the two prompts 'hi' and 'bye', and the value 100
# (presumably the number of tokens to generate — confirm against GPT.generate).
# The call is the cell's last expression, so its return value is displayed too.
model.generate(dataset_shakespeare.encode, dataset_shakespeare.decode, ['hi', 'bye'], 100)

hivonWeatherablingFore melanch undone intra Stevensonakis_______ Frost somewhat bullshit law incompetence Payneoo reiterated diets senate gravity Callslesisitions {}altern lettuce wrinklesLisa Resist Schiff shortages ensuring Riy venue explicit Lanternclipse dystopianFollowing Drawn seventy publisher referendum Um rpmRogerorescent rookies Afgh incessimet Turner ATM light remissionendor requousseNaturally covetedott musician mortgages frequent crate Beatles GitHub 139onedλ shorthand pred friendciating styledcarbanc ValueilianTogether interceptionsONReplywealthZone Spur HouseholdfieldsEnhamblingaughters tracts FalconsEc Famous Receiver Smart AdventuresRelease

tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,  5303,
         26982, 41865, 11716, 16351, 40853, 45171, 23422, 38681, 27321, 37405,
         15122,  6454, 20041,  1099, 39674, 32788,  2238, 27116, 18977, 34548,
         13522, 27592,   829, 29593, 23884, 33645, 3

In [8]:
%%time
# The total step budget is divided evenly over the epochs; one table row is
# printed per evaluation.
epochs = 10
steps_per_epoch = config['train_steps'] // epochs
print(f'{"Epoch":^5} | {"Train Loss":^10} | {"Val Loss":^10}')

# Row 0 is the pre-training baseline (evaluation only); every later epoch
# trains first and is then evaluated on both splits.
for e in range(epochs + 1):
    if e > 0:
        model.fit(dataset_train, optimizer, config['batch_size'], steps_per_epoch)
    loss_train, loss_val = model.evaluate([dataset_train, dataset_val], config['batch_size'], steps_per_epoch)
    print(f"{e:^5} | {loss_train:>10.3f} | {loss_val:>10.3f}")

Epoch | Train Loss |  Val Loss 


Using whole dataset.


  0   |     10.974 |     10.960
  1   |      4.660 |      5.253
  2   |      4.100 |      5.078
  3   |      3.815 |      5.109
  4   |      3.595 |      5.210
  5   |      3.445 |      5.351
  6   |      3.288 |      5.442
  7   |      3.177 |      5.521
  8   |      3.078 |      5.628
  9   |      2.997 |      5.713
 10   |      2.918 |      5.786
CPU times: user 5min 16s, sys: 481 ms, total: 5min 17s
Wall time: 5min 16s


In [9]:
# Checkpoint round-trip: write the model together with the optimizer state,
# then read the same file back into the live model and optimizer.
checkpoint_path = './gpt.pth'
model.save(checkpoint_path, optimizer_state_dict=optimizer.state_dict())
model.load(checkpoint_path, optimizer=optimizer)

In [10]:
# post training
# Sample again now that the model has been trained and reloaded from the checkpoint:
# prompts 'Han' and 'Linsu', with the value 1000 (presumably the number of tokens
# to generate — confirm against GPT.generate).
# print_batch_num=1 presumably selects which prompt in the batch is printed live
# (index 1 -> 'Linsu') — confirm against GPT.generate.
# The call is the cell's last expression, so its return value is displayed too.
model.generate(dataset_shakespeare.encode, dataset_shakespeare.decode, ['Han', 'Linsu'], 1000, print_batch_num=1)

Linsu! love! pray you, wounds have one with my lure, girls;
there's in the frowns and thy lips and lies, a dream on earth!
'Servile day! the king,--why, girl's shop!
why, fair Montague's a sketh, seeming thou art o'er:be in their wounds,
'er in the earth, I'll pierce a hallooing;
More cowardly o'er lawyers'er a beggar, that
Men days was done by my poor aunts,--
Why paucas pallabrisll'd canopy!
Most piteous massacre! I know she could play here!
Never were not gentlemen born to fine array,
I'll bear my sweet convert to make
Where you have the centre that have a cait in arms
To my trifles of my love myself;
Being make me in secret harbour.

LUCENTIO:
Slander your highness doth beget to-morrow,
Is it brought upon you then, with which I know
by I am subtle: sometimes, three your eyes is young to your mistress,
any betwray you need, bear the load where came clear dawn
of, and show mine bear the unstinking of any offence
of no more to touch her knave; for justice makes what I am
customly to r

tensor([[    0,     0,     0,  ...,  1870,  2342,  1549],
        [    0,     0,     0,  ...,   198, 22788,    11]], device='cuda:0')