In [75]:
# autoreload when imports change
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [76]:
import pathlib
from typing import Unpack

In [77]:
from gpt_from_scratch import (
    file_utils,
    vocab_utils,
)

In [78]:
import torch

# imported for typechecking
#
# note: can't easily alias via jaxtyping annotations, as it's a string literal and
#       likely plays weirdly with typing.Annotation to forward a payload
# note: torchtyping is deprecated in favor of jaxtyping, as torchtyping doesn't have mypy integration
#
# note: jaxtyping does support prepending
#
#   Image = Float[Array, "channels height width"]
#   BatchImage = Float[Image, "batch"]
#
#    -->
#
#   BatchImage = Float[Array, "batch channels height width"]
#
# so we can compose aliases
#
from torch import Tensor
import jaxtyping
from jaxtyping import jaxtyped, Float32, Int64
from typeguard import typechecked as typechecker

In [79]:
# load tinyshakespeare
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

input_filepath = file_utils.download_file_from_url(url)

# Read all text from the input file
input_text = input_filepath.read_text()

File found in cache: download_cache/4acd659e47adc1daeb7aff503accf0a3


In [80]:
len(input_text)

1115394

In [81]:
print(input_text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [114]:
vocab = vocab_utils.Vocabulary(input_text)

In [83]:
print(vocab.unique_elements)
print(len(vocab.unique_elements))

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [84]:
print(vocab.encode("hii there"))
print(vocab.decode(vocab.encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [85]:
# let's now encode the entire text dataset and store it into a torch.Tensor
encoded_input_text: Int64[Tensor, 'num_samples'] = torch.tensor(
    vocab.encode(input_text),
    dtype=torch.long,
)
print(encoded_input_text.shape, encoded_input_text.dtype)

# the 100 characters we looked at earier will to the GPT look like this
print(encoded_input_text[:100]) 

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [86]:
# Let's now split up the data into train and validation sets

# first 90% will be train, rest val
train_val_ratio = 0.9

n = int(0.9 * len(encoded_input_text)) 

train_data: Int64[Tensor, 'num_samples'] = encoded_input_text[:n]
val_data: Int64[Tensor, 'num_samples']   = encoded_input_text[n:]

print(f'Splitting {len(encoded_input_text)} input tokens into')
print(f' - train: {len(train_data)}')
print(f' - val: {len(val_data)}')

Splitting 1115394 input tokens into
 - train: 1003854
 - val: 111540


In [87]:
# We'll never feed entire text through
# So we usually break it down into "chunks" or "blocks"
# TODO(bschoen): How do we choose this?
# TODO(bschoen): Is block size `context`?
block_size = 8

# show what a block looks like
# this has multiple examples
# note: we're predicting 8 (8 examples) from 9 characters
# TODO(bschoen): Do we try to shufle how we partition the blocks?
print(f'Example block: {train_data[:block_size + 1]}')

Example block: tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])


In [91]:
# show what that looks like over a block
# this is the `time` dimension
x: Int64[Tensor, "block_size"] = train_data[:block_size]
y: Int64[Tensor, "block_size"] = train_data[1:(block_size+1)]

for t in range(block_size):
    
    context: Int64[Tensor, "context_size"] = x[:t+1]
    target: Int64[Tensor, ""] = y[t]
    
    print(f"[{t}] when input is {context} the target: {target}")

[0] when input is tensor([18]) the target: 47
[1] when input is tensor([18, 47]) the target: 56
[2] when input is tensor([18, 47, 56]) the target: 57
[3] when input is tensor([18, 47, 56, 57]) the target: 58
[4] when input is tensor([18, 47, 56, 57, 58]) the target: 1
[5] when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
[6] when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
[7] when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [94]:
type Block = Int64[Tensor, "block_size"]
type BatchedBlocks = Int64[Block, "batch_size"] # equivalent to `Int64[Tensor, "batch_size block_size"]`

# now we still want to batch
# we seed it so it's always the same
torch.manual_seed(1337)

# how many independent sequences will we process in parallel?
batch_size = 4 

# what is the maximum context length for predictions?
block_size = 8 

# note: usually want to stack into a batch
@jaxtyped(typechecker=typechecker)
def get_batch(data: Int64[Tensor, "num_samples"]) -> tuple[BatchedBlocks, BatchedBlocks]:
    """Generate a small batch of data of inputs x and targets y."""

    # Generate 'batch_size' random indices. Each index is the start of a sequence.
    # The upper bound (len(data) - block_size) ensures we have enough room for a full sequence.
    max_batch_start_index = len(data) - block_size

    # choose `batch_size` random starting indices for where to start each batch
    batch_start_indices: Int64[Tensor, "batch_size"] = torch.randint(max_batch_start_index, (batch_size,))
    
    # For each random start index, extract a sequence of length 'block_size'.
    x_blocks: list[Int64[Tensor, 'block_size']] = [data[i:i+block_size] for i in batch_start_indices]
    
    # Similar to x, but shifted one position to the right (next-token prediction).
    # This creates the targets for each input sequence.
    y_blocks: list[Int64[Tensor, 'block_size']] = [data[i+1:i+block_size+1] for i in batch_start_indices]
    
    # Stack these sequences into a single tensor of shape (batch_size, block_size).
    x_batch: BatchedBlocks = torch.stack(x_blocks)
    y_batch: BatchedBlocks = torch.stack(y_blocks)
    
    return x_batch, y_batch

xb, yb = get_batch(train_data)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

# note: These are essentially 4x8 (32) *independent* examples (as far as the transformer is concerned)
for b in range(batch_size): # batch dimension
    print()
    for t in range(block_size): # time dimension
        
        context = xb[b, :t+1]
        target = yb[b,t]
        
        print(f"[batch: {b}][time: {t}] when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----

[batch: 0][time: 0] when input is [24] the target: 43
[batch: 0][time: 1] when input is [24, 43] the target: 58
[batch: 0][time: 2] when input is [24, 43, 58] the target: 5
[batch: 0][time: 3] when input is [24, 43, 58, 5] the target: 57
[batch: 0][time: 4] when input is [24, 43, 58, 5, 57] the target: 1
[batch: 0][time: 5] when input is [24, 43, 58, 5, 57, 1] the target: 46
[batch: 0][time: 6] when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
[batch: 0][time: 7] when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39

[batch: 1][time: 0] when input is [44] the target: 53
[batch: 1][t

In [98]:
# always good to start with the simplest possible model
# note: there's dedicated lecture for this
import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import NamedTuple

torch.manual_seed(1337)

class LogitsAndLoss(NamedTuple):
    logits: Float32[Tensor, "batch_size block_size vocab_size"]
    loss: Float32[Tensor, ""] | None

class BigramLanguageModel(nn.Module):

    @jaxtyped(typechecker=typechecker)
    def __init__(self, vocab_size: int) -> None:
        
        super().__init__()

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table: nn.Embedding = nn.Embedding(vocab_size, vocab_size)
 
    @jaxtyped(typechecker=typechecker)
    def forward(
        self,
        idx: BatchedBlocks, 
        targets: BatchedBlocks | None = None,
    ) -> tuple[Float32[Tensor, "batch_size block_size vocab_size"] | Float32[Tensor, "batch_size*block_size vocab_size"], Float32[Tensor, ""] | None]:

        # idx and targets are both (B,T) tensor of integers
        logits: Float32[Tensor, "batch_size block_size vocab_size"] = self.token_embedding_table(idx)

        # if no targets, nothing to calculate
        if targets is None:
            loss = None
            return logits, loss

        B, T, C = logits.shape
        
        # strech them out into 1d sequence, just because of quirks of what pytorch expects
        # for the cross_entropy calculation
        reshaped_logits: Float32[Tensor, "batch_size*block_size vocab_size"] = logits.view(B*T, C)
        
        reshaped_targets: Float32[Tensor, "batch_size*block_size"] = targets.view(B*T)

        loss: Float32[Tensor, ""] = F.cross_entropy(reshaped_logits, reshaped_targets)

        return LogitsAndLoss(logits, loss)

    @jaxtyped(typechecker=typechecker)
    def generate(
        self,
        idx: BatchedBlocks, 
        max_new_tokens: int,
    ) -> BatchedBlocks:
        
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):

            # get the predictions
            logits, loss = self(idx)

            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)

            # apply softmax to get probabilities
            probs: Float32[Tensor, "batch_size vocab_size"] = F.softmax(logits, dim=-1) # (B, C)
            
            # sample from the distribution
            idx_next: Float32[Tensor, "batch_size 1"] = torch.multinomial(probs, num_samples=1) # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

m = BigramLanguageModel(vocab_size=len(vocab.unique_elements))

logits, loss = m(xb, yb)

print(logits.shape)

print(loss)

torch.Size([4, 8, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)


In [119]:
# let's look at how to interpret the logits
# 
# we'll pick an arbitrary block from our batch
block_index = 2

# we can see this represents just an arbitrary chunk of text
block = xb[block_index]

print(f"Block: {block_index}")
for position_in_block, char in enumerate(vocab.decode(block.tolist())):
    print(f"[{position_in_block}] {char}")

# we'll look at an arbitrary position in the block
position_in_block_index = 6

# we'll see what the probability was for each vocab element at that position
logits_at_position = logits[block_index, position_in_block_index]

# decode the vocab back
vocab_to_logit_prob = {vocab.decode_single(index): value for index, value in enumerate(logits_at_position)}

# sort by probability
vocab_to_logit_prob = sorted(vocab_to_logit_prob.items(), key=lambda item: item[1], reverse=True)

# show the top 5
print(f'\nTop 5 probabilities at position [{position_in_block_index}]:')
for vocab_element, prob in vocab_to_logit_prob[:5]:
    print(f"{vocab_element}: {prob:.4f}")

Block: 2
[0] n
[1] t
[2]  
[3] t
[4] h
[5] a
[6] t
[7]  

Top 5 probabilities at position [6]:
R: 2.6412
t: 2.0734
N: 1.8017
S: 1.6053
o: 1.4801


In [100]:
m.token_embedding_table.weight

Parameter containing:
tensor([[ 0.1808, -0.0700, -0.3596,  ...,  1.6097, -0.4032, -0.8345],
        [ 0.5978, -0.0514, -0.0646,  ..., -1.4649, -2.0555,  1.8275],
        [ 1.3035, -0.4501,  1.3471,  ...,  0.1910, -0.3425,  1.7955],
        ...,
        [ 0.4222, -1.8111, -1.0118,  ...,  0.5462,  0.2788,  0.7280],
        [-0.8109,  0.2410, -0.1139,  ...,  1.4509,  0.1836,  0.3064],
        [-1.4322, -0.2810, -2.2789,  ..., -0.5551,  1.0666,  0.5364]],
       requires_grad=True)