In [7]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
import os
import pickle

#### !!!! ONLY FOR APPLE SILICON


In [2]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [20]:
# hyperparameters
b = 32 # how many independent sequences will we process in parallel?
t = 16 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 50
lr = 3e-4 # learning rate for each backprop step
eval_iters = 20
d = 32 # embedding aka hidden dimension
h = 4 # number of attention heads
l = 4 # number of transormer layers
dropout = 0.2 # % of parameters to ignore every iteration
l2 = 0.01 # multiplier for our L2 norm to encourage sparsity

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [8]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)
        
# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text)

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text)

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12]
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo?


In [9]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [10]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - t, (b,))
    x = torch.stack([data[i:i+t] for i in ix])
    y = torch.stack([data[i+1:i+t+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [11]:
# so you can see what the tokenized data looks like
x,y = get_batch('train')
print("x ", x.shape, "\n", x)
print("y ", y.shape, "\n", y)

x  torch.Size([4, 16]) 
 tensor([[ 26,  13,  71,  14,  43,   1,  56,  59,  99,  42,   1,  40,  63,   1,
         102,  51],
        [ 66,  41,  79,  57,  73,  58,   2,  75,  14,  17,  26,  34,  27,  24,
          21,  27],
        [ 57,   1,  41,  59,  56,  91,  82,  35, 102, 108,   1,  43,  95,  56,
           1,  57],
        [ 65,  67,   1, 119,  42,   5,  80,  66,  99,  58,   1,  57,  73, 115,
          61,   1]], device='mps:0')
y  torch.Size([4, 16]) 
 tensor([[ 13,  71,  14,  43,   1,  56,  59,  99,  42,   1,  40,  63,   1, 102,
          51,  85],
        [ 41,  79,  57,  73,  58,   2,  75,  14,  17,  26,  34,  27,  24,  21,
          27,  71],
        [  1,  41,  59,  56,  91,  82,  35, 102, 108,   1,  43,  95,  56,   1,
          57,  69],
        [ 67,   1, 119,  42,   5,  80,  66,  99,  58,   1,  57,  73, 115,  61,
           1,  92]], device='mps:0')


In [12]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [13]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(),
            nn.Linear(4 * d, d),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [14]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d, head_size, bias=False)
        self.query = nn.Linear(d, head_size, bias=False)
        self.value = nn.Linear(d, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(t, t))) # mask future timestesps
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        b,t,d = x.shape
        k = self.key(x)   # (b,t,d/h)
        q = self.query(x) # (b,t,d/h)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (b, t, d/h) @ (b, d/h, t) -> (b, t, t)
        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf')) # (b, t, t)
        wei = F.softmax(wei, dim=-1) # (b, t, t)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (b,t,d/h)
        out = wei @ v # (b, t, t) @ (b, t, d/h) -> (b, t, d/h)
        return out

In [15]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(h)])
        self.proj = nn.Linear(head_size * h, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [16]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, d, h):
        # d: embedding dimension, h: the number of heads we'd like
        super().__init__()
        head_size = d // h # the double backslash just makes the output an int instead of float
        self.sa = MultiHeadAttention(h, head_size)
        self.ffwd = FeedFoward(d)
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [17]:
class GPT(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(tokenizer.vocab_len, d)
        
        # simple learned positional encodings rather than sine or RoPE
        self.position_embedding_table = nn.Embedding(t, d) 
        self.blocks = nn.Sequential(*[Block(d, h) for _ in range(l)]) # bulk of the beast
        self.ln_f = nn.LayerNorm(d) # final layer norm
        
        # output classifier
        self.lm_head = nn.Linear(d, tokenizer.vocab_len)
        # Alternatively, use the embedding matrix transpose for the output layer, 
        # which is a common technique for interpretability and parameter reduction.
        # self.lm_head = self.token_embedding_table.weight.t()  
        
        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        b, t = idx.shape
        
        # idx and targets are both (b,t) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (b,t,d)
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        x = tok_emb + pos_emb # (b,t,d) + (t,d) = (b,t,d)
        x = self.blocks(x) # (b,t,d) -> (b,t,d)
        x = self.ln_f(x) # (b,t,d) -> (b,t,d)
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
        else:
            b, t, v = logits.shape
            loss = F.cross_entropy(logits.view(b*t, v), targets.view(b*t))

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (b, t) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -t:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, d)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, d)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)
        return idx

In [21]:
model = GPT().to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

59.328 K parameters


In [22]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2)

In [23]:
start_time = time.time()
for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

step 0: train loss 4.7949, val loss 4.7955, time elapsed: 1.27 seconds
step 50: train loss 4.4637, val loss 4.4654, time elapsed: 8.24 seconds
step 100: train loss 4.2373, val loss 4.2476, time elapsed: 15.11 seconds
step 150: train loss 4.0188, val loss 4.0320, time elapsed: 21.91 seconds
step 200: train loss 3.8687, val loss 3.8856, time elapsed: 28.62 seconds
step 250: train loss 3.7191, val loss 3.7504, time elapsed: 35.31 seconds
step 300: train loss 3.6332, val loss 3.6433, time elapsed: 42.01 seconds
step 350: train loss 3.5666, val loss 3.5649, time elapsed: 48.79 seconds
step 400: train loss 3.4904, val loss 3.4912, time elapsed: 55.57 seconds
step 450: train loss 3.4302, val loss 3.4605, time elapsed: 62.35 seconds
step 500: train loss 3.4088, val loss 3.3905, time elapsed: 69.15 seconds
step 550: train loss 3.3447, val loss 3.3403, time elapsed: 75.93 seconds
step 600: train loss 3.2802, val loss 3.2915, time elapsed: 83.03 seconds
step 650: train loss 3.2747, val loss 3.250

## save the trained model


In [22]:
torch.save(model.state_dict(), f'models/{model.__class__.__name__}_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}_l2-{l2}_{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

RuntimeError: Parent directory models does not exist.

# Load a saved model

In [None]:
model = GPT().to(device)  # Initialize a model with the same architecture

# Load the saved state dictionary
model.load_state_dict(torch.load('models/GPT_b24_t128_d128_h8_l8_lr0.0003_drop0.2_l2-0.01_2024-01-25|23-31-12.pth'))

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

## Inference

In [26]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
context_tensor = torch.tensor([tokenizer.encode(input_str)], dtype=torch.long, device=device)
output = model.generate(context_tensor, max_new_tokens=250)
output_str = tokenizer.decode(output[0].tolist())
print(output_str)

JULIET:
O Romeo, Romeo! wherefore art thou Roiy Ite the Peut Anoa counsh.
Nught Wo,
SopYlew
: thay rescrodpher And the, and theid thont mog, my mpatalouer slized ba Vorn wesd b lograxfer benard, to lloodld,
Bunalt,
$s isour:
And cypoin your, no hid, chiesarnt to;er:
Lher'd du momens of lep!
UFouerin and I chayir.

O

HYOSE, VCT-LXO:
Therearst, d ples mant Broaguming.

OTUJN CS:
Nours l
CPU times: user 7.27 s, sys: 243 ms, total: 7.51 s
Wall time: 7.48 s
