Setup:

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


--2023-11-07 12:17:53--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-11-07 12:17:53 (27.3 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
use_wandb = True
if use_wandb:
    import wandb
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdrscotthawley[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [28]:
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.1


batch_size = 128 # how many independent sequences will we process in parallel?
block_size = 64 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 128
n_head = 8
n_layer = 6
dropout = 0.1
# ------------

In [29]:
print('device =',device) 

device = cuda


Dataset creation

In [30]:
torch.manual_seed(1337)

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print("train_data.shape, val_data.shape =",train_data.shape, val_data.shape)

# data loading
def get_batch(split, debug=False):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    if debug: print(f"get_batch: x.shape = {x.shape}, y.shape = {y.shape}")
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


train_data.shape, val_data.shape = torch.Size([1003854]) torch.Size([111540])


In [31]:
x,y = get_batch('train',debug=True)
print(f"B, T = {batch_size}, {block_size}")

get_batch: x.shape = torch.Size([128, 64]), y.shape = torch.Size([128, 64])
B, T = 128, 64


x is a sequence

In [32]:
x[0]

tensor([50, 43,  8,  1, 32, 46, 47, 52, 49,  1, 61, 47, 58, 46,  1, 58, 46, 63,
        57, 43, 50, 44,  0, 20, 53, 61,  1, 51, 53, 56, 43,  1, 59, 52, 44, 53,
        56, 58, 59, 52, 39, 58, 43,  1, 58, 46, 39, 52,  1, 39, 50, 50,  1, 50,
        47, 60, 47, 52, 45,  1, 61, 53, 51, 43], device='cuda:0')

y is x shifted back by one and including new data.
in this sense only y[:,-1] is the "next token" being predicted.

In [33]:
y[0]

tensor([43,  8,  1, 32, 46, 47, 52, 49,  1, 61, 47, 58, 46,  1, 58, 46, 63, 57,
        43, 50, 44,  0, 20, 53, 61,  1, 51, 53, 56, 43,  1, 59, 52, 44, 53, 56,
        58, 59, 52, 39, 58, 43,  1, 58, 46, 39, 52,  1, 39, 50, 50,  1, 50, 47,
        60, 47, 52, 45,  1, 61, 53, 51, 43, 52], device='cuda:0')

Model definition

In [34]:

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            #print("logits.shape =",logits.shape,", targets.shape =",targets.shape)
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx



Instantiate and get ready to run

In [35]:
model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

1.212481 M parameters


In [36]:
if use_wandb: wandb.init(project='karpathy-gpt-mini')

Do training

In [37]:
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if use_wandb: wandb.log(losses | {'step':iter//eval_interval})

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


step 0: train loss 4.2835, val loss 4.2866
step 100: train loss 2.4290, val loss 2.4453
step 200: train loss 2.2344, val loss 2.2527
step 300: train loss 2.0351, val loss 2.0877
step 400: train loss 1.8951, val loss 1.9849
step 500: train loss 1.7932, val loss 1.9147
step 600: train loss 1.7124, val loss 1.8540
step 700: train loss 1.6518, val loss 1.8159
step 800: train loss 1.6091, val loss 1.7699
step 900: train loss 1.5732, val loss 1.7387
step 1000: train loss 1.5428, val loss 1.7221
step 1100: train loss 1.5169, val loss 1.7061
step 1200: train loss 1.4951, val loss 1.6863
step 1300: train loss 1.4776, val loss 1.6657
step 1400: train loss 1.4606, val loss 1.6538
step 1500: train loss 1.4471, val loss 1.6528
step 1600: train loss 1.4377, val loss 1.6367
step 1700: train loss 1.4197, val loss 1.6167
step 1800: train loss 1.4082, val loss 1.6129
step 1900: train loss 1.3982, val loss 1.6041
step 2000: train loss 1.3905, val loss 1.5988
step 2100: train loss 1.3857, val loss 1.6008


In [38]:
if use_wandb: wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val,█▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
step,49.0
train,1.25797
val,1.52246


Generate

In [39]:

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))



JULIET:
Some eyes, she is mine alone excemember.

GREMIO:
A troth. This in Bolingham, they are they committed,
Is sinced but that news of bloody?

PAGE:
You think it is made, more, my life grood night be an hundred
your enemies, married and faults are far;
The Lord on, for hour hence string crafts,
And ill-wander sufficious 'sharden
That war homselfhootables, that not infirmation,
Opefition from them to know. This father yields his death?
Come, be have the glast douth it write,
And yet that be men's as
unwilly that golden king, and begin;
Sir lought believe thee as here is but not the king,
well set against you to Clifford's castle,
But so of your kingren as find, shall us it.

MENENIUS:
I will not do.

ANGELO:
I wot have sortly belief; thou art thy about this,
Take my admis-swort, they are upon my badd,
And that's all eather wan like to my lodge,
Blost which you were send that I parrs too cause.

CAPULETESTER:
As snacrentled luckshing point, abow!

CORIOLANUS:
So Citizen:
I will go t