Enable GPU Acceleration
Go to the top menu: Runtime > Change runtime type. Under Hardware accelerator, select GPU (preferably T4 GPU).

Why? Training a GPT on CPU could take days. With a GPU, it’s thousands of times faster.

Upload Project Files
In the left sidebar, click the folder icon. Use the upload button to add your three essential files:
gpt.py (the model definition)
train.py (training and evaluation logic)
util.py (tokenizer and dataset utilities)

Install Dependencies

In [114]:
!pip install torch numpy evaluate bertscore transformers

Collecting evaluate
  Using cached evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
[31mERROR: Could not find a version that satisfies the requirement bertscore (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for bertscore[0m[31m
[0m

Acquiring the Manuscript (Data Collection)

The King James Bible is freely available on Project Gutenberg. We will download it directly into Colab:

In [115]:
!wget https://www.gutenberg.org/files/10/10-0.txt -O bible_full.txt

--2025-09-02 00:23:30--  https://www.gutenberg.org/files/10/10-0.txt
Resolving www.gutenberg.org (www.gutenberg.org)... 152.19.134.47, 2610:28:3090:3000:0:bad:cafe:47
Connecting to www.gutenberg.org (www.gutenberg.org)|152.19.134.47|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4436268 (4.2M) [text/plain]
Saving to: ‘bible_full.txt’


2025-09-02 00:23:30 (22.1 MB/s) - ‘bible_full.txt’ saved [4436268/4436268]



Cleansing the Scroll (Data Pre-processing)

We only want the Book of Genesis in pure narrative form. The file contains metadata, front matter, and verse numbers—all of which must be stripped away.

In [116]:
import re

# Step 1: Read the full downloaded Bible
with open("bible_full.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Step 2: Isolate the Book of Genesis
# Genesis begins with this header:
start_marker = "*** START OF THE PROJECT GUTENBERG EBOOK 10 ***\nThe Old Testament of the King James Version of the Bible\nThe First Book of Moses: Called Genesis"
# And ends right before Exodus:
end_marker = "The Second Book of Moses, Called Exodus" # Corrected end marker

# Split the text by the start and end markers
parts = text.split(start_marker)
print(f"Number of parts after splitting by start_marker: {len(parts)}")
if len(parts) > 1:
    genesis_text_parts = parts[1].split(end_marker)
    print(f"Number of parts after splitting by end_marker: {len(genesis_text_parts)}")
    if len(genesis_text_parts) > 0:
        genesis_text = genesis_text_parts[0]
        print(f"Length of genesis_text after splitting: {len(genesis_text)}")
    else:
        genesis_text = "" # Handle case where end marker is not found
        print("End marker not found.")
else:
    genesis_text = "" # Handle case where start marker is not found
    print("Start marker not found.")

# Step 3: Remove chapter and verse markers (e.g., "1:1 ", "10:3 ")
genesis_text = re.sub(r"\d+:\d+\s", "", genesis_text)
print(f"Length after removing verse markers: {len(genesis_text)}")

# Step 4: Remove extra newlines and headers
genesis_text = re.sub(r"\n+", " ", genesis_text)  # collapse newlines into spaces
genesis_text = re.sub(r"[^a-zA-Z0-9.,;:?!'\" \-]", " ", genesis_text)  # keep punctuation
genesis_text = re.sub(r"\s+", " ", genesis_text).strip()  # normalize whitespace
print(f"Length after removing newlines and normalizing whitespace: {len(genesis_text)}")


# Step 5: Save to input.txt (required by train.py)
with open("input.txt", "w", encoding="utf-8") as f:
    f.write(genesis_text)

print("Genesis text prepared and saved to input.txt")
print(genesis_text[:500])  # Preview the first 500 characters

Number of parts after splitting by start_marker: 2
Number of parts after splitting by end_marker: 1
Length of genesis_text after splitting: 4332535
Length after removing verse markers: 4166833
Length after removing newlines and normalizing whitespace: 4141273
Genesis text prepared and saved to input.txt
The Second Book of Moses: Called Exodus The Third Book of Moses: Called Leviticus The Fourth Book of Moses: Called Numbers The Fifth Book of Moses: Called Deuteronomy The Book of Joshua The Book of Judges The Book of Ruth The First Book of Samuel The Second Book of Samuel The First Book of the Kings The Second Book of the Kings The First Book of the Chronicles The Second Book of the Chronicles Ezra The Book of Nehemiah The Book of Esther The Book of Job The Book of Psalms The Proverbs Ecclesiast


From Characters to Code (Tokenization)
The CharacterTokenizer Class

Defined in util.py
self.vocab: the full list of unique characters.
self.char_to_idx: dictionary mapping characters → IDs.
self.idx_to_char: dictionary mapping IDs → characters.
encode() → converts string to list of IDs.
decode() → converts IDs back to string.

In [117]:
class CharacterTokenizer:
  """ Simple character tokenizer for encoding/decoding text """
  def __init__(self, xs):
    self.vocab = sorted(list(set(xs))) # Build vocabulary from input text
    self.char_to_idx = {c:i for i, c in enumerate(self.vocab)}
    self.idx_to_char = {i:c for i, c in enumerate(self.vocab)}

  def encode(self, xs):
    return [self.char_to_idx[x] for x in xs]

  def decode(self, xs):
    return "".join([self.idx_to_char[x] for x in xs])

# Step 1: Read the cleaned Genesis text
with open("input.txt", "r", encoding="utf-8") as f:
    content = f.read()

# Step 2: Create the tokenizer
tokenizer = CharacterTokenizer(content)

# Step 3: Inspect vocabulary size
print("Vocabulary size:", len(tokenizer.vocab))
print("Sample vocabulary:", tokenizer.vocab[:50])  # show first 50 chars

# Step 4: Encode and decode a sample verse
sample_text = "In the beginning God created the heaven and the earth."
# Convert sample_text to lowercase before encoding
encoded = tokenizer.encode(sample_text.lower())
decoded = tokenizer.decode(encoded)

print("Sample text:", sample_text)
print("Encoded:", encoded[:50])
print("Decoded:", decoded)

Vocabulary size: 61
Sample vocabulary: [' ', '!', ',', '-', '.', '0', '1', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
Sample text: In the beginning God created the heaven and the earth.
Encoded: [43, 48, 0, 54, 42, 39, 0, 36, 39, 41, 43, 48, 48, 43, 48, 41, 0, 41, 49, 38, 0, 37, 52, 39, 35, 54, 39, 38, 0, 54, 42, 39, 0, 42, 39, 35, 56, 39, 48, 0, 35, 48, 38, 0, 54, 42, 39, 0, 39, 35]
Decoded: in the beginning god created the heaven and the earth.


In [118]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class CharacterTokenizer:
  """ Simple character tokenizer for encoding/decoding text """
  def __init__(self, xs):
    self.vocab = sorted(list(set(xs))) # Build vocabulary from input text
    self.char_to_idx = {c:i for i, c in enumerate(self.vocab)}
    self.idx_to_char = {i:c for i, c in enumerate(self.vocab)}

  def encode(self, xs):
    return [self.char_to_idx[x] for x in xs]

  def decode(self, xs):
    return "".join([self.idx_to_char[x] for x in xs])

class Head(nn.Module):
  """ One head of self-attention """
  def __init__(self, head_size, n_embd, block_size, dropout):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x) # (B,T,hs)
    q = self.query(x) # (B,T,hs)
    # compute attention scores ("affinities")
    wei = q @ k.transpose(-2,-1) * C**-0.5 # (B,T,hs) @ (B,hs,T) -> (B,T,T)
    wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # (B,T,T)
    wei = F.softmax(wei, dim=-1) # (B,T,T)
    wei = self.dropout(wei)
    # perform the weighted aggregation of the values
    v = self.value(x) # (B,T,hs)
    out = wei @ v # (B,T,T) @ (B,T,hs) -> (B,T,hs)
    return out

class MultiHeadAttention(nn.Module):
  """ Multiple heads of self-attention in parallel """
  def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedFoward(nn.Module):
  """ A simple linear layer followed by a non-linearity """
  def __init__(self, n_embd, dropout):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """ Transformer block: communication then computation """
  def __init__(self, n_embd, n_head, block_size, dropout):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size, dropout)
    self.ffwd = FeedFoward(n_embd, dropout)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    return x

q @ k.T → computes similarity between every pair of tokens (attention scores).

 * C**-0.5 → scaling for stability.

Mask (self.tril) → ensures the model cannot peek ahead into the future during training.

Softmax → normalizes scores into probabilities.

In [119]:
# Read and print the beginning of bible_full.txt to find the correct start marker
try:
    with open("bible_full.txt", "r", encoding="utf-8") as f:
        bible_beginning = f.read(1000) # Read the first 1000 characters
        print("Content of bible_full.txt (first 1000 characters):")
        print(bible_beginning)
except FileNotFoundError:
    print("Error: bible_full.txt not found.")
except Exception as e:
    print(f"An error occurred while reading bible_full.txt: {e}")

Content of bible_full.txt (first 1000 characters):
*** START OF THE PROJECT GUTENBERG EBOOK 10 ***
The Old Testament of the King James Version of the Bible
The First Book of Moses: Called Genesis
The Second Book of Moses: Called Exodus
The Third Book of Moses: Called Leviticus
The Fourth Book of Moses: Called Numbers
The Fifth Book of Moses: Called Deuteronomy
The Book of Joshua
The Book of Judges
The Book of Ruth
The First Book of Samuel
The Second Book of Samuel
The First Book of the Kings
The Second Book of the Kings
The First Book of the Chronicles
The Second Book of the Chronicles
Ezra
The Book of Nehemiah
The Book of Esther
The Book of Job
The Book of Psalms
The Proverbs
Ecclesiastes
The Song of Solomon
The Book of the Prophet Isaiah
The Book of the Prophet Jeremiah
The Lamentations of Jeremiah
The Book of the Prophet Ezekiel
The Book of Daniel
Hosea
Joel
Amos
Obadiah
Jonah
Micah
Nahum
Habakkuk
Zephaniah
Haggai
Zechariah
Malachi

The New Testament of the King James Bible
The Gosp

Part 2: The Architecture of Creation
2.1 The Bigram Baseline
A Bigram model predicts the next character based only on the one immediately before it.
Example: In the phrase:
“And God said, Let there be…”
If we’re predicting after “e”, the Bigram model only considers “e” — not the words “God said” or “Let there”.
Limitation: it cannot capture long-range structure, so it fails to produce flowing, meaningful text.
This baseline shows why Transformers — which remember broader context — are so powerful.

2.2 Self-Attention
Self-attention is the heart of GPT.
Think of it this way:
In Genesis 22:10 — “And Abraham stretched forth his hand, and took the knife to slay his son.”
When predicting “son”, the model must connect it back to “Abraham”.
Self-attention allows the model to look at all previous words and assign different weights to them depending on relevance.


The Head class (gpt.py)

In [120]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd, context_size):
        ...
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

Query (Q): “What do I need to know?”
Key (K): “Here is what I represent.”
Value (V): “Here’s my content.”
Analogy: A prophet (Query) asks, “Who is being slain?” The characters (Keys) answer “I am Abraham,” “I am a knife,” “I am a son.” The model weighs which Keys matter most and draws their Values into focus

2.3 Assembling the Engine (The Full GPT Architecture)

Now let’s examine the other components.

MultiHeadAttention

In [121]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)

Runs multiple Heads in parallel.

Each head can specialize: one tracks subjects, another verb agreement, another punctuation rhythm.

In [122]:
class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.ReLU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(0.2),
        )

After attending, the model needs to think.

This is a small neural network applied to each position independently.

In [123]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size, dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

Combines attention + feedforward.

Residual connections (x = x + ...) → help the gradient flow, prevent collapse.

LayerNorm → keeps activations stable.

This Block is the repeating unit of the Transformer.

In [124]:
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd) # Use block_size for position embeddings
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.position_embedding_table.num_embeddings:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

token_embedding_table → maps character IDs to dense vectors.

position_embedding_table → gives each character a sense of where it appears (order matters).

blocks → stack of Transformer Blocks (attention + feedforward).

lm_head → final layer that outputs logits (scores) for the next character.

Forward pass:

Characters → indices → embeddings.

Position embeddings added.

Sequence passes through Blocks.

Final normalized vector → logits → next-character probabilities.

This is the engine of creation.

Part 3: Bidding the Word Become Flesh – Training Your Model

The Seven Days of Creation (or 5,000 Steps)

Now that you have the architecture, it is time to train your model on Genesis. You’ll use train.py to breathe life into your GPT so it can begin generating scripture-like text.

3.1 Tuning the Cosmos (Hyperparameter Configuration)

In train.py, you will see arguments like --context-size, --n-layer, etc. These are the hyperparameters — the dials you set before training begins.

| Argument         | Meaning                                                    | Recommended Value (Genesis) |
| ---------------- | ---------------------------------------------------------- | --------------------------- |
| `--context-size` | Memory window (how many characters the model sees at once) | 256                         |
| `--batch-size`   | How many samples processed at once                         | 64                          |
| `--n-embd`       | Embedding dimension (richness of character vectors)        | 384                         |
| `--n-head`       | Number of attention heads                                  | 6                           |
| `--n-layer`      | Transformer depth                                          | 6                           |
| `--dropout`      | Regularization rate                                        | 0.2                         |
| `--steps`        | Training iterations                                        | 5000                        |
| `--lr`           | Learning rate                                              | 3e-4                        |


3.2 The Training Loop: Learning the Word

The train function in train.py works as follows:

Optimizer (AdamW) → adjusts the weights of the network to minimize loss.

Forward pass → model predicts next characters.

Loss function (cross-entropy) → measures how far predictions are from reality.

Backward pass → gradients flow back to adjust parameters.

Validation checks:

estimate_loss runs on unseen text.

If validation loss decreases, the model is learning general patterns.

If validation loss increases while training loss decreases → overfitting.

3.3 Let There Be Text! (Running the Training)

In [125]:
%cd /content/
!python train.py train \
    --input="input.txt" \
    --save="genesis_model.pth" \
    --steps=5000 \
    --report=500 \
    --context-size=256 \
    --batch-size=64 \
    --n-embd=384 \
    --n-head=6 \
    --n-layer=6 \
    --dropout=0.2 \
    --lr=3e-4

/content
usage: train.py [-h] [--input INPUT] [--seed SEED]
                [--context-size CONTEXT_SIZE] [--batch-size BATCH_SIZE]
                [--n-embd N_EMBD] [--n-head N_HEAD] [--n-layer N_LAYER]
                [--dropout DROPOUT]
                {train,eval} ...
train.py: error: unrecognized arguments: --input=input.txt --context-size=256 --batch-size=64 --n-embd=384 --n-head=6 --n-layer=6 --dropout=0.2


Loss values should gradually decrease.

Validation loss should follow training loss (not diverge).

Training may take ~1 hour on Colab GPU.

When training finishes, you will see genesis_model.pth in the file browser — this is your model’s brain.

At this point, you have:

Understood the GPT model architecture.

Configured hyperparameters.

Trained the model for 5000 steps, producing genesis_model.pth

Part 4: The Revelation – Generation, Evaluation, and Final Project

The generate Method

The power of creation lies in gpt.py, inside the generate method of the GPTLanguageModel.

Process:

Take a prompt (or start with an empty token).

Predict the next character.

Append it to the sequence.

Repeat until the desired length is reached.

This is autoregression: each new step depends on everything generated so far. The text “grows” character by character, like scripture unfolding.

Running Inference with the CLI

Your train.py script includes an eval mode for text generation.

1. Generating from scratch (no prompt):

In [126]:
!python train.py eval --load="genesis_model.pth" --token-count=500

Total parameters: 10.785853
Using device: cuda

Traceback (most recent call last):
  File "/content/train.py", line 162, in <module>
    model.load_state_dict(torch.load(args.load))
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1484, in load
    with _open_file_like(f, "rb") as opened_file:
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 759, in _open_file_like
    return _open_file(name_or_buffer, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 740, in __init__
    super().__init__(open(name, mode))
                     ^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'genesis_model.pth'


This begins with a default token (zero) and shows what the model “dreams” about without guidance.

2. Generating from a custom prompt:

In [127]:
!python train.py eval --load="genesis_model.pth" --prompt="And God said" --token-count=500

Total parameters: 10.785853
Using device: cuda

Traceback (most recent call last):
  File "/content/train.py", line 162, in <module>
    model.load_state_dict(torch.load(args.load))
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1484, in load
    with _open_file_like(f, "rb") as opened_file:
         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 759, in _open_file_like
    return _open_file(name_or_buffer, mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 740, in __init__
    super().__init__(open(name, mode))
                     ^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'genesis_model.pth'


In [132]:
%cd /content/
!pip install -e .

/content
Obtaining file:///content
[31mERROR: file:///content does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

In [134]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import os

# Copying CharacterTokenizer class
class CharacterTokenizer:
  """ Simple character tokenizer for encoding/decoding text """
  def __init__(self, xs):
    self.vocab = sorted(list(set(xs))) # Build vocabulary from input text
    self.char_to_idx = {c:i for i, c in enumerate(self.vocab)}
    self.idx_to_char = {i:c for i, c in enumerate(self.vocab)}

  def encode(self, xs):
    return [self.char_to_idx[x] for x in xs]

  def decode(self, xs):
    return "".join([self.idx_to_char[x] for x in xs])

# Copying class definitions from gpt.py
class Head(nn.Module):
  """ One head of self-attention """
  def __init__(self, head_size, n_embd, block_size, dropout):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x) # (B,T,hs)
    q = self.query(x) # (B,T,hs)
    # compute attention scores ("affinities")
    wei = q @ k.transpose(-2,-1) * C**-0.5 # (B,T,hs) @ (B,hs,T) -> (B,T,T)
    wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # (B,T,T)
    wei = F.softmax(wei, dim=-1) # (B,T,T)
    wei = self.dropout(wei)
    # perform the weighted aggregation of the values
    v = self.value(x) # (B,T,hs)
    out = wei @ v # (B,T,T) @ (B,T,hs) -> (B,T,hs)
    return out

class MultiHeadAttention(nn.Module):
  """ Multiple heads of self-attention in parallel """
  def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedFoward(nn.Module):
  """ A simple linear layer followed by a non-linearity """
  def __init__(self, n_embd, dropout):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """ Transformer block: communication then computation """
  def __init__(self, n_embd, n_head, block_size, dropout):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size, dropout)
    self.ffwd = FeedFoward(n_embd, dropout)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    return x

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd) # Use block_size for position embeddings
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.position_embedding_table.num_embeddings:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

# --- Data Loading and Batching ---
def get_batch(split, data, block_size, batch_size, device):
    # generate a small batch of data of inputs x and targets y
    upper_bound = len(data) - block_size
    print(f"get_batch: len(data)={len(data)}, block_size={block_size}, upper_bound={upper_bound}") # Debug print
    if upper_bound <= 0:
        # Handle case where data is too short
        print("Warning: Data length is less than or equal to block_size. Cannot get a batch of this size.")
        # Return empty tensors or raise an error, depending on desired behavior
        return torch.empty(0, block_size, dtype=torch.long, device=device), torch.empty(0, block_size, dtype=torch.long, device=device)

    ix = torch.randint(0, upper_bound + 1, (batch_size,)) # Explicitly set lower bound and calculate upper bound
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# --- Loss Estimation ---
@torch.no_grad()
def estimate_loss(model, data, eval_iters, block_size, batch_size, device):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = []
        # Check if data is long enough before attempting to get batches
        if len(data[split]) <= block_size:
             print(f"Skipping evaluation for '{split}' split: data length ({len(data[split])}) <= block_size ({block_size})")
             out[split] = float('inf') # Or some other indicator
             continue

        for k in range(eval_iters):
            X, Y = get_batch(split, data[split], block_size, batch_size, device)
            # Check if get_batch returned empty tensors
            if X.size(0) == 0:
                print(f"Warning: get_batch returned empty tensors for '{split}' split.")
                break # Exit the loop if no valid batch is returned

            logits, loss = model(X, Y)
            losses.append(loss.item())

        if losses:
            out[split] = sum(losses) / len(losses)
        else:
             out[split] = float('inf') # Or some other indicator if no batches were processed

    model.train()
    return out

# --- Training Function ---
def train(model, optimizer, train_data, val_data, tokenizer, args):
    print("==================== TRAINING ====================")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
    print(f"Using device: {device}")

    # Prepare data dict for estimate_loss
    data_dict = {'train': train_data, 'val': val_data}

    for iter in range(args.steps):
        # every once in a while evaluate the loss on train and val sets
        if iter % args.eval_interval == 0 or iter == args.steps - 1:
            losses = estimate_loss(model, data_dict, args.eval_iters, args.context_size, args.batch_size, device)
            print(f"Step {iter}, train loss: {losses['train']:.4f} val loss: {losses['val']:.4f}")
            # TODO: Add metrics calculation and printing here if needed

        # sample a batch of data
        # Check if train_data is long enough before getting batch
        if len(train_data) <= args.context_size:
            print("Error: Training data length is less than or equal to context_size. Cannot train.")
            break # Exit training loop

        xb, yb = get_batch('train', train_data, args.context_size, args.batch_size, device)

        # Check if get_batch returned empty tensors
        if xb.size(0) == 0:
             print("Warning: get_batch returned empty tensors for training. Skipping this iteration.")
             continue # Skip this iteration if no valid batch

        # evaluate the loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Report progress
        if iter % args.report_interval == 0 and iter > 0:
            print(f"Step {iter}/{args.steps}, Loss: {loss.item():.4f}")

    print("==================================================")

    # Save the model
    if args.save:
        torch.save(model.state_dict(), args.save)
        print(f"Model saved to {args.save}")

# --- Generation Function ---
def generate(model, tokenizer, args):
    print("==================== GENERATION ====================")
    # Start with a seed string (e.g., "In the beginning")
    seed_string = "In the beginning"
    print(f"Seed string: '{seed_string}'")

    # Encode the seed string
    encoded_seed = torch.tensor(tokenizer.encode(seed_string)).unsqueeze(0).to(device)

    # Generate new tokens
    generated_indices = model.generate(encoded_seed, max_new_tokens=args.max_new_tokens)

    # Decode the generated indices
    generated_text = tokenizer.decode(generated_indices[0].tolist())

    print("Generated text:")
    print(generated_text)
    print("==================================================")


# --- Main Execution (within notebook cell) ---

# Define a simple args object to hold hyperparameters
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

# Set the command and hyperparameters
command = 'train' # Or 'generate'
args = Args(
    input="input.txt",
    save="genesis_model.pth",
    steps=5000,
    report_interval=500,
    eval_interval=500,
    eval_iters=200,
    context_size=256,
    batch_size=64,
    n_embd=384,
    n_head=6,
    n_layer=6,
    dropout=0.2,
    lr=3e-4,
    max_new_tokens=500 # Add this for generate command
)

# --- Device Configuration ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load data and create tokenizer
with open(args.input, 'r', encoding='utf-8') as f:
    text = f.read()

# Split data into train and validation
tokenizer = CharacterTokenizer(text)
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9 * len(data)) # 90% train, 10% val
train_data = data[:n]
val_data = data[n:]

vocab_size = len(tokenizer.vocab)

# Initialize model
model = GPTLanguageModel(vocab_size, args.n_embd, args.context_size, args.n_head, args.n_layer, args.dropout)
model.to(device)

# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)


if command == 'train':
    # Start training
    train(model, optimizer, train_data, val_data, tokenizer, args)

elif command == 'generate':
    # Load model state dict if generating
    if os.path.exists(args.load):
        model.load_state_dict(torch.load(args.load, map_location=device))
        model.to(device)
        # Start generation
        generate(model, tokenizer, args)
    else:
        print(f"Error: Model checkpoint not found at {args.load}. Cannot generate.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: len(data)=3727145, block_size=256, upper_bound=3726889
get_batch: 