# LLM Pre-training Dataset Preparation - Build and Train a GPT -2 Transformer model LLM From Scratch


## Installation

In [1]:
!pip install datasets transformers torch accelerate -q

In [2]:
!pip3 install tiktoken > /dev/null 2>&1

### Install Weights and Biases

In [3]:
!pip install wandb -qU

### MODEL CONFIGURATION

In [4]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
    "context_length": 512, # Context length
    "emb_dim": 384,         # Embedding dimension
    "n_heads": 6,          # Number of attention heads
    "n_layers": 6,         # Number of layers
    "drop_rate": 0.2,       # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
    "max_length": 512,      # Maximum sequence length
    "output_dimension": 384, # Output dimension
    "batch_size": 12,      # batch size
    "volumn_of_dataset":10000 # Set the size of loading the dataset from HF
}

## Imports

In [5]:
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from typing import List, Dict, Optional, Union, Tuple
import torch
from pathlib import Path
import os
from transformers import AutoTokenizer
import numpy as np
import wandb
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import random
import math

In [6]:
import importlib
import tiktoken

print("tiktoken version:", importlib.metadata.version("tiktoken"))

tiktoken version: 0.12.0


### Initialise W&B

In [7]:
# Initialize Weights & Biases
def init_wandb(config, project_name="gpt-pretraining"):
    """Initialize W&B tracking"""
    wandb.init(
        project=project_name,
        config=config,
        name=f"gpt-{config['num_epochs']}ep-lr{config['learning_rate']}",
        tags=["gpt", "pretraining", "custom"]
    )
    wandb.watch(model, log="all", log_freq=100)  # Log gradients and parameters

## 1. Load Text from Local .txt Files

In [8]:
def load_txt_file(
    file_path: str,
    encoding: str = 'utf-8',
    chunk_size: Optional[int] = None,
    overlap: int = 0
) -> Dataset:
    """
    Load text data from a .txt file and convert to HuggingFace Dataset.

    Args:
        file_path: Path to the .txt file
        encoding: Text encoding (default: 'utf-8')
        chunk_size: Optional size to split text into chunks (in characters)
        overlap: Number of overlapping characters between chunks

    Returns:
        HuggingFace Dataset containing text data
    """
    print(f"Loading text from: {file_path}")

    # Read the file
    with open(file_path, 'r', encoding=encoding) as f:
        text_content = f.read()

    # Split into chunks if specified
    if chunk_size:
        texts = []
        start = 0
        while start < len(text_content):
            end = start + chunk_size
            texts.append(text_content[start:end])
            start = end - overlap
    else:
        # Split by paragraphs (double newline) or keep as single text
        texts = [t.strip() for t in text_content.split('\n\n') if t.strip()]

    # Create dataset
    dataset = Dataset.from_dict({"text": texts})

    print(f"✓ Loaded {len(dataset):,} text samples from .txt file")
    print(f"Total characters: {sum(len(t) for t in texts):,}")
    return dataset

## 2. Load Dataset from HuggingFace Hub

In [9]:
def load_huggingface_dataset(
    dataset_name: str,
    text_column: str = "text",
    split: str = "train",
    name: Optional[str] = None,
    num_samples: Optional[int] = None,
    streaming: bool = True,
    trust_remote_code: bool = False
) -> Dataset:
    """
    Load dataset from HuggingFace Hub with optimized streaming support.

    Args:
        dataset_name: Name of the dataset on HuggingFace Hub
                     (e.g., 'HuggingFaceFW/fineweb', 'openwebtext')
        text_column: Name of the column containing text data (default: 'text')
        split: Dataset split to load (default: 'train')
        name: Dataset configuration name (e.g., 'sample-10BT' for fineweb)
        num_samples: Limit number of samples to load (recommended for large datasets)
        streaming: Use streaming mode for memory efficiency (default: True)
        trust_remote_code: Trust remote code for custom datasets

    Returns:
        HuggingFace Dataset with 'text' column

    Examples:
        # Load FineWeb dataset
        dataset = load_huggingface_dataset(
            dataset_name="HuggingFaceFW/fineweb",
            name="sample-10BT",
            num_samples=10000
        )

        # Load OpenWebText
        dataset = load_huggingface_dataset(
            dataset_name="openwebtext",
            num_samples=5000
        )
    """
    print(f"\nLoading HuggingFace dataset: '{dataset_name}'" +
          (f" (config: {name})" if name else "") +
          f" (split: {split})")

    try:
        # Load dataset with streaming for memory efficiency
        dataset = load_dataset(
            dataset_name,
            name=name,
            split=split,
            streaming=streaming,
            trust_remote_code=trust_remote_code
        )

        # Extract text from samples
        texts = []

        if streaming:
            # Iterate through streaming dataset (memory efficient)
            print(f"Extracting text from streaming dataset...")
            for i, sample in enumerate(dataset):
                if num_samples and i >= num_samples:
                    break

                # Extract text from the specified column
                if text_column in sample:
                    texts.append(sample[text_column])
                else:
                    available = list(sample.keys())
                    raise KeyError(
                        f"Column '{text_column}' not found. "
                        f"Available columns: {available}"
                    )

                # Progress indicator
                if (i + 1) % 1000 == 0:
                    print(f"  Processed {i + 1:,} samples...", end="\r")

            if texts:
                print(f"\n✓ Extracted {len(texts):,} samples from streaming dataset")
        else:
            # Non-streaming mode (loads entire dataset into memory)
            print(f"Loading in non-streaming mode...")
            if num_samples:
                dataset = dataset.select(range(min(num_samples, len(dataset))))

            # Extract text column
            if text_column in dataset.column_names:
                texts = dataset[text_column]
            else:
                raise KeyError(
                    f"Column '{text_column}' not found. "
                    f"Available columns: {dataset.column_names}"
                )

            print(f"✓ Loaded {len(texts):,} samples")

        # Create a new Dataset with only the text column
        final_dataset = Dataset.from_dict({"text": texts})

        total_chars = sum(len(t) for t in texts)
        print(f"Total characters: {total_chars:,}")
        print(f"Average text length: {total_chars / len(texts):.0f} chars per sample")

        return final_dataset

    except Exception as e:
        print(f"❌ Error loading dataset: {str(e)}")
        raise

## 3. Merge Multiple Datasets

In [10]:
def merge_datasets(
    datasets: List[Dataset],
    shuffle: bool = True,
    seed: int = 42,
    interleave: bool = False
) -> Dataset:
    """
    Merge multiple datasets into a single dataset.

    Args:
        datasets: List of HuggingFace Datasets to merge
        shuffle: Whether to shuffle the merged dataset
        seed: Random seed for shuffling
        interleave: If True, interleave datasets instead of concatenating

    Returns:
        Merged HuggingFace Dataset
    """
    print(f"\nMerging {len(datasets)} datasets...")

    # Validate all datasets have 'text' column
    for i, ds in enumerate(datasets):
        if 'text' not in ds.column_names:
            raise ValueError(f"Dataset {i} does not have 'text' column")
        print(f"  Dataset {i+1}: {len(ds):,} samples")

    # Merge datasets
    if interleave:
        # Interleave datasets (useful for balanced sampling)
        from datasets import interleave_datasets
        merged_dataset = interleave_datasets(datasets, seed=seed)
        print("Using interleave strategy...")
    else:
        # Concatenate datasets
        merged_dataset = concatenate_datasets(datasets)
        print("Using concatenation strategy...")

    # Shuffle if requested
    if shuffle:
        print("Shuffling merged dataset...")
        merged_dataset = merged_dataset.shuffle(seed=seed)

    print(f"✓ Merged dataset contains {len(merged_dataset):,} total samples")
    return merged_dataset

### 4.2 Explicit Shifting (Custom Training Style)

In [11]:
def create_input_target_pairs_explicit(
    dataset: Dataset,
    tokenizer,
    max_length: int = 512,
    stride: Optional[int] = None,
    preprocessing_num_workers: int = 1,  # DEFAULT to 1 (safer)
    batch_size: int = 100  # REDUCED default
) -> Dataset:
    """
    Create input-target pairs for causal language modeling (EXPLICIT SHIFTING).
    """
    print(f"\nCreating input-target pairs (EXPLICIT SHIFTING - Custom style)...")
    print(f"Max length: {max_length} tokens")

    if stride is None:
        stride = max_length
    print(f"Stride: {stride} tokens")

    # Warning for large datasets with multiprocessing
    if len(dataset) > 10000 and preprocessing_num_workers > 1:
        print(f"⚠️  Large dataset ({len(dataset):,} samples) with multiprocessing may cause OOM")
        print(f"   Recommend: preprocessing_num_workers=1")

    def tokenize_and_shift(examples):
        """
        Tokenize text and create explicitly shifted input-target pairs.
        """
        all_input_ids = []
        all_labels = []
        all_attention_mask = []

        try:
            for text in examples["text"]:
                # Skip empty texts
                if not text or len(text.strip()) == 0:
                    continue

                # Tokenize with tiktoken
                try:
                    token_ids = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
                except Exception as e:
                    print(f"Warning: Tokenization failed for text, skipping: {str(e)[:100]}")
                    continue

                # Truncate if too long
                if len(token_ids) > 1024:
                    token_ids = token_ids[:1024]

                # Skip if too short
                if len(token_ids) < max_length + 1:
                    continue

                # Create sliding window chunks
                for i in range(0, len(token_ids) - max_length, stride):
                    input_chunk = token_ids[i : i + max_length]
                    target_chunk = token_ids[i + 1 : i + max_length + 1]

                    # Only add if we have complete sequences
                    if len(input_chunk) == max_length and len(target_chunk) == max_length:
                        all_input_ids.append(input_chunk)
                        all_labels.append(target_chunk)
                        all_attention_mask.append([1] * max_length)

        except Exception as e:
            print(f"Error in tokenize_and_shift: {e}")
            # Return empty to avoid crashing
            return {
                "input_ids": [],
                "labels": [],
                "attention_mask": []
            }

        return {
            "input_ids": all_input_ids,
            "labels": all_labels,
            "attention_mask": all_attention_mask
        }

    # Apply tokenization
    print("Tokenizing and shifting dataset...")
    print(f"Using {preprocessing_num_workers} worker(s)")

    try:
        tokenized_dataset = dataset.map(
            tokenize_and_shift,
            batched=True,
            num_proc=preprocessing_num_workers if preprocessing_num_workers > 1 else None,
            remove_columns=dataset.column_names,
            batch_size=batch_size,
            desc="Tokenizing and shifting"
        )
    except Exception as e:
        print(f"\n❌ Error during tokenization: {e}")
        print("Retrying with num_proc=1 (no multiprocessing)...")

        # Retry without multiprocessing
        tokenized_dataset = dataset.map(
            tokenize_and_shift,
            batched=True,
            num_proc=None,  # Disable multiprocessing
            remove_columns=dataset.column_names,
            batch_size=batch_size,
            desc="Tokenizing and shifting (retry)"
        )

    # Filter out empty sequences
    original_len = len(tokenized_dataset)
    tokenized_dataset = tokenized_dataset.filter(
        lambda x: len(x['input_ids']) > 0,
        desc="Filtering empty sequences"
    )

    if len(tokenized_dataset) < original_len:
        print(f"Filtered out {original_len - len(tokenized_dataset)} empty sequences")

    # Calculate statistics
    if len(tokenized_dataset) > 0:
        print(f"\n✓ Created {len(tokenized_dataset):,} input-target pairs")
        print(f"Sequence length: {max_length} tokens (fixed)")
        print(f"Total tokens: {len(tokenized_dataset) * max_length:,}")
    else:
        print("\n⚠️  Warning: No sequences created! Check your data and max_length")

    return tokenized_dataset

### LOAD, MERGE AND CREATING TOKEN EMBEDDINGS

In [12]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch

# Load tokenizer from transformers
# tokenizer = AutoTokenizer.from_pretrained("gpt2")

#Load tokenizer from tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

# Step 1: Load raw datasets (text only, no tokenization yet)
txt_dataset = load_txt_file("the-verdict.txt")

hf_dataset = load_huggingface_dataset(
    dataset_name="HuggingFaceFW/fineweb",
    name="sample-10BT",
    num_samples=GPT_CONFIG_124M['volumn_of_dataset'],
    streaming=True
)

# Step 2: Merge raw datasets (still have 'text' column)
merged_dataset = merge_datasets(
    datasets=[txt_dataset, hf_dataset],
    shuffle=False
)

# Step 3: NOW apply explicit shifting tokenization
explicit_dataset = create_input_target_pairs_explicit(
    dataset=merged_dataset,  # Raw text dataset
    tokenizer=tokenizer,
    max_length=GPT_CONFIG_124M['max_length'],
    stride=4
)

# Create PyTorch DataLoader
def collate_fn(batch):
    """Convert batch to tensors"""
    input_ids = torch.tensor([item['input_ids'] for item in batch])
    labels = torch.tensor([item['labels'] for item in batch])
    return input_ids, labels

dataloader = DataLoader(
    explicit_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn
)

# Iterate through batches
data_iter = iter(dataloader)
inputs, targets = next(data_iter)

print(f"Inputs shape:  {inputs.shape}")
print(f"Targets shape: {targets.shape}")
print(f"\nInputs:\n{inputs}")
print(f"\nTargets:\n{targets}")

# Create token embeddings
token_embedding_layer = torch.nn.Embedding(GPT_CONFIG_124M['vocab_size'], GPT_CONFIG_124M['emb_dim'])
token_embeddings = token_embedding_layer(inputs)
print(f"\nToken embeddings shape: {token_embeddings.shape}")

Loading text from: the-verdict.txt
✓ Loaded 83 text samples from .txt file
Total characters: 20,315

Loading HuggingFace dataset: 'HuggingFaceFW/fineweb' (config: sample-10BT) (split: train)


Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

Extracting text from streaming dataset...
  Processed 10,000 samples...
✓ Extracted 10,000 samples from streaming dataset
Total characters: 30,503,959
Average text length: 3050 chars per sample

Merging 2 datasets...
  Dataset 1: 83 samples
  Dataset 2: 10,000 samples
Using concatenation strategy...
✓ Merged dataset contains 10,083 total samples

Creating input-target pairs (EXPLICIT SHIFTING - Custom style)...
Max length: 512 tokens
Stride: 4 tokens
Tokenizing and shifting dataset...
Using 1 worker(s)


Tokenizing and shifting:   0%|          | 0/10083 [00:00<?, ? examples/s]

Filtering empty sequences:   0%|          | 0/338732 [00:00<?, ? examples/s]


✓ Created 338,732 input-target pairs
Sequence length: 512 tokens (fixed)
Total tokens: 173,430,784
Inputs shape:  torch.Size([8, 512])
Targets shape: torch.Size([8, 512])

Inputs:
tensor([[    9,    82,   394,  ...,   644,   339,   318],
        [49983,   396,  2055,  ...,   546,   351,   326],
        [ 1309,   502,  1208,  ...,   636,    11,   475],
        ...,
        [13366,  2569,  2055,  ...,  1281,    12,    35],
        [  198,  1532,   345,  ...,  5855,  8845,   367],
        [  900,  3511,   319,  ...,  5626,    39,  2751]])

Targets:
tensor([[  82,  394,    9,  ...,  339,  318, 3375],
        [ 396, 2055,   11,  ...,  351,  326,  938],
        [ 502, 1208,  319,  ...,   11,  475,  262],
        ...,
        [2569, 2055,   25,  ...,   12,   35, 2502],
        [1532,  345,  423,  ..., 8845,  367, 2885],
        [3511,  319, 2046,  ...,   39, 2751, 5390]])

Token embeddings shape: torch.Size([8, 512, 384])


In [13]:
print(token_embeddings.shape)

torch.Size([8, 512, 384])


### CREATING POSITIONAL EMBEDDINGS

In [14]:
context_length = GPT_CONFIG_124M['max_length'] # Set the context length and max length the same
pos_embedding_layer = torch.nn.Embedding(context_length, GPT_CONFIG_124M['emb_dim'])

In [15]:
pos_embeddings = pos_embedding_layer(torch.arange(context_length))
print(pos_embeddings.shape)

torch.Size([512, 384])


### CREATE INPUT AND POSITIONAL EMBEDDING

In [16]:
input_embeddings = token_embeddings + pos_embeddings
print(input_embeddings.shape)

torch.Size([8, 512, 384])


### IMPLEMENTING MULTI-HEAD ATTENTION

In [17]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

###  THE BUILDING BLOCKS-LAYER NORMALIZATION, GELU AND FEED-FORWARD NEURAL NETWORK

In [18]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), ## Expansion
            GELU(), ## Activation
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), ## Contraction
        )

    def forward(self, x):
        return self.layers(x)

### TRANSFORMER BLOCK

In [19]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        # 2*4*768
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x
        # 2*4*768

###  ENTIRE GPT MODEL ARCHITECTURE IMPLEMENTATION

In [20]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

In [21]:
torch.manual_seed(123)

batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)

model = GPTModel(GPT_CONFIG_124M)
out = model(batch)
print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)

Input batch:
 tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])

Output shape: torch.Size([2, 4, 50257])
tensor([[[-1.0371, -0.2938,  0.5229,  ..., -0.2098,  0.7340,  1.3440],
         [ 0.0090, -1.2032,  0.6340,  ..., -0.7390, -0.8859,  0.3126],
         [-0.9895, -0.5937,  0.9895,  ...,  0.2455, -0.4786, -0.8173],
         [ 0.5983,  0.1613, -0.2204,  ..., -0.6277,  0.1684,  0.1050]],

        [[-0.9351, -0.3093,  0.2713,  ...,  0.7217,  0.3087,  0.6667],
         [ 0.0657,  0.6050, -0.2442,  ..., -0.3800, -0.7365,  0.2054],
         [ 0.6771,  0.3374,  0.9111,  ...,  1.1158, -0.8735, -0.7977],
         [ 1.2009,  0.6100, -0.4495,  ..., -0.4903, -0.1130,  0.2402]]],
       grad_fn=<UnsafeViewBackward0>)


### MODEL SIZE CALCULATION

In [22]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 49,434,624


In [23]:
print("Token embedding layer shape:", model.tok_emb.weight.shape)
print("Output layer shape:", model.out_head.weight.shape)

Token embedding layer shape: torch.Size([50257, 384])
Output layer shape: torch.Size([50257, 384])


In [24]:
total_size_bytes = total_params * 4 #A
total_size_mb = total_size_bytes / (1024 * 1024) #B
print(f"Total size of the model: {total_size_mb:.2f} MB")

Total size of the model: 188.58 MB


### GENERATING TEXT FROM OUTPUT TOKENS - INFERENCE

In [25]:
def generate_text_simple(model, idx, max_new_tokens, context_size):
    # idx is (batch, n_tokens) array of indices in the current context

    for _ in range(max_new_tokens):

        # Crop current context if it exceeds the supported context size
        # E.g., if LLM supports only 5 tokens, and the context size is 10
        # then only the last 5 tokens are used as context
        idx_cond = idx[:, -context_size:]

        # Get the predictions
        with torch.no_grad():
            logits = model(idx_cond) ### batch, n_tokens, vocab_size

        # Focus only on the last time step
        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)
        logits = logits[:, -1, :]

        # Apply softmax to get probabilities
        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)

        # Get the idx of the vocab entry with the highest probability value
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)

        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)

    return idx

In [26]:
import tiktoken

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) # remove batch dimension
    return tokenizer.decode(flat.tolist())

start_context = "He said we came here"



token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(start_context, tokenizer),
    max_new_tokens=10,
    context_size=GPT_CONFIG_124M["context_length"]
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 He said we came hereLIN Kob nomine primaries� Wimscoringtersaerileen


###  CREATING TRAINING, TESTING AND VALIDATION DATA

In [27]:
import torch
from torch.utils.data import DataLoader

MAX_SEQUENCES = 60000
if len(explicit_dataset) > MAX_SEQUENCES:
    print(f"⚠️  Limiting dataset: {len(explicit_dataset):,} → {MAX_SEQUENCES:,}")
    explicit_dataset = explicit_dataset.select(range(MAX_SEQUENCES))
    print(f"✅ Reduced to {len(explicit_dataset):,} sequences")

# Split the tokenized dataset into train/validation
train_ratio = 0.85
split_idx = int(train_ratio * len(explicit_dataset))

# Split using HuggingFace datasets
train_dataset = explicit_dataset.select(range(split_idx))
val_dataset = explicit_dataset.select(range(split_idx, len(explicit_dataset)))

print(f"\nDataset split:")
print(f"Training samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")

# Collate function
def collate_fn(batch):
    """Convert batch to tensors"""
    input_ids = torch.tensor([item['input_ids'] for item in batch])
    labels = torch.tensor([item['labels'] for item in batch])
    return input_ids, labels

# Set manual seed for reproducibility
torch.manual_seed(123)

# Create training dataloader
train_loader = DataLoader(
    train_dataset,
    batch_size=GPT_CONFIG_124M["batch_size"],
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True
)

# Create validation dataloader
val_loader = DataLoader(
    val_dataset,
    batch_size=GPT_CONFIG_124M["batch_size"],
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True
)

print(f"\nDataloaders created:")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Test iteration
print("\nTesting dataloaders...")
train_iter = iter(train_loader)
inputs, targets = next(train_iter)
print(f"Train batch - Inputs shape: {inputs.shape}, Targets shape: {targets.shape}")

val_iter = iter(val_loader)
inputs, targets = next(val_iter)
print(f"Val batch - Inputs shape: {inputs.shape}, Targets shape: {targets.shape}")

⚠️  Limiting dataset: 338,732 → 60,000
✅ Reduced to 60,000 sequences

Dataset split:
Training samples: 51,000
Validation samples: 9,000

Dataloaders created:
Training batches: 4250
Validation batches: 750

Testing dataloaders...
Train batch - Inputs shape: torch.Size([12, 512]), Targets shape: torch.Size([12, 512])
Val batch - Inputs shape: torch.Size([12, 512]), Targets shape: torch.Size([12, 512])


###  DEFINING THE CROSS ENTROPY LOSS FUNCTION

In [28]:
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Total samples: {len(train_loader) + len(val_loader)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model.to(device)


print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

Train batches: 4250
Val batches: 750
Total samples: 5000
Using device: cuda
CUDA available: True
CUDA device count: 1
CUDA device name: Tesla T4


In [29]:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

# Full calculation
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches



# Subset Calculation
def calc_loss_loader_subset(data_loader, model, device, num_batches=10):
    """Calculate loss on first num_batches only"""
    total_loss = 0.
    count = 0

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            if i >= num_batches:
                break
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                logits.flatten(0, 1), targets.flatten()
            )
            total_loss += loss.item()
            count += 1

    return total_loss / count if count > 0 else 0

# Use subset calculation (much faster!)
with torch.no_grad():
    train_loss = calc_loss_loader_subset(train_loader, model, device, num_batches=10)
    val_loss = calc_loss_loader_subset(val_loader, model, device, num_batches=10)

print(f"Training loss (first 10 batches): {train_loss}")
print(f"Validation loss (first 10 batches): {val_loss}")

Training loss (first 10 batches): 10.979767608642579
Validation loss (first 10 batches): 10.98832893371582


In [30]:
print(device)

cuda


### CHCEK TO MAKE SURE ENOUGH DATASET FOR TRAINING

In [31]:
# Check your dataset size BEFORE training
print(f"\nDataset statistics:")
print(f"Training samples: {len(train_loader.dataset):,}")
print(f"Training batches: {len(train_loader):,}")
print(f"Validation samples: {len(val_loader.dataset):,}")
print(f"Validation batches: {len(val_loader):,}")

# You need AT LEAST 10,000+ samples for meaningful training
# If you have less, load more from HuggingFace:
if len(train_loader.dataset) < 10000:
    print("\n⚠️  WARNING: Dataset too small! Load more samples from HuggingFace")


Dataset statistics:
Training samples: 51,000
Training batches: 4,250
Validation samples: 9,000
Validation batches: 750


### TRAINING LOOP FOR THE LLM

In [32]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """
    Evaluate model on train and val sets.
    Note: Caller should handle setting model back to train mode.
    """
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    # Don't call model.train() here - let caller decide
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context,
                              max_new_tokens=50, return_text=False):
    """Generate and print sample text."""
    model.eval()
    context_size = model.pos_emb.weight.shape[0]

    try:
        encoded = text_to_token_ids(start_context, tokenizer).to(device)
        with torch.no_grad():
            token_ids = generate_text_simple(
                model=model, idx=encoded,
                max_new_tokens=max_new_tokens, context_size=context_size
            )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        display_text = decoded_text.replace("\n", " ")
        print(display_text)

        if return_text:
            return decoded_text
    except Exception as e:
        error_msg = f"Error generating sample: {e}"
        print(error_msg)
        if return_text:
            return error_msg

In [33]:
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer,
                       max_grad_norm=1.0, save_checkpoints=True, checkpoint_path="model_checkpoint.pt",
                       use_amp=True, scheduler=None, gradient_accumulation_steps=1,
                       wandb_run=None, inference_freq=500):
    """
    Train model with W&B monitoring following official best practices.

    Args:
        wandb_run: W&B run object from wandb.init()
    """
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    best_val_loss = float('inf')

    scaler = GradScaler() if use_amp and torch.cuda.is_available() else None

    for epoch in range(num_epochs):
        model.train()
        epoch_train_losses = []
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (input_batch, target_batch) in enumerate(progress_bar):

            # Forward pass with mixed precision
            if scaler is not None:
                with autocast():
                    loss = calc_loss_batch(input_batch, target_batch, model, device)
                    loss = loss / gradient_accumulation_steps

                scaler.scale(loss).backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                    if scheduler is not None:
                        scheduler.step()

                    global_step += 1
            else:
                loss = calc_loss_batch(input_batch, target_batch, model, device)
                loss = loss / gradient_accumulation_steps
                loss.backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()

                    if scheduler is not None:
                        scheduler.step()

                    global_step += 1

            tokens_seen += input_batch.numel()
            batch_loss = loss.item() * gradient_accumulation_steps
            epoch_train_losses.append(batch_loss)

            # Log to W&B every step using run.log() (official best practice)
            if wandb_run and global_step > 0:
                current_lr = optimizer.param_groups[0]['lr']
                wandb_run.log({
                    # Training metrics
                    "train/loss": batch_loss,
                    "train/perplexity": math.exp(min(batch_loss, 10)),  # Prevent overflow

                    # Learning rate
                    "lr": current_lr,

                    # Gradient metrics
                    "grad_norm": grad_norm.item() if 'grad_norm' in locals() else 0,

                    # Progress metrics
                    "tokens_seen": tokens_seen,
                    "epoch": epoch + 1,

                    # Step counter
                    "step": global_step,
                }, step=global_step)

            # Update progress bar
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.3f}',
                'lr': f'{current_lr:.2e}',
                'step': global_step
            })

            # Evaluation
            if global_step > 0 and global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(f"\nEp {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}, "
                      f"LR {current_lr:.2e}")

                # Log evaluation metrics to W&B
                if wandb_run:
                    wandb_run.log({
                        # Evaluation losses
                        "eval/train_loss": train_loss,
                        "eval/val_loss": val_loss,
                        "eval/loss_diff": train_loss - val_loss,

                        # Perplexity
                        "eval/train_perplexity": math.exp(min(train_loss, 10)),
                        "eval/val_perplexity": math.exp(min(val_loss, 10)),

                        # Overfitting indicator
                        "eval/overfit_ratio": val_loss / train_loss if train_loss > 0 else 1.0,
                    }, step=global_step)


                # ✅ ADD THIS CHECK HERE (after logging, before checkpoint)
                if early_stopping(val_loss):
                    print("\n🛑 EARLY STOPPING - Overfitting detected!")
                    print(f"Best val loss: {early_stopping.best_loss:.4f}")
                    break  # Exit training loop

                # Save best model
                if save_checkpoints and val_loss < best_val_loss:
                    best_val_loss = val_loss
                    checkpoint = {
                        'epoch': epoch,
                        'global_step': global_step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        'best_val_loss': best_val_loss,
                        'tokens_seen': tokens_seen,
                        'config': CONFIG,
                    }
                    torch.save(checkpoint, checkpoint_path)
                    print(f"✓ Saved best checkpoint (val_loss: {val_loss:.3f})")

                    # Save checkpoint as W&B artifact
                    if wandb_run:
                        artifact = wandb.Artifact(
                            name=f"model-best",
                            type="model",
                            description=f"Best model checkpoint at step {global_step}",
                            metadata={
                                "step": global_step,
                                "val_loss": val_loss,
                                "epoch": epoch,
                            }
                        )
                        artifact.add_file(checkpoint_path)
                        wandb_run.log_artifact(artifact)

                model.train()

            # Generate inference sample
            if global_step > 0 and global_step % inference_freq == 0:
                print("\n" + "="*70)
                print(f"[Step {global_step}] Generated sample:")
                generated_text = generate_and_print_sample(
                    model, tokenizer, device, start_context, return_text=True
                )
                print("="*70 + "\n")

                # Log generated text to W&B
                if wandb_run:
                    # Create a nice HTML table for the sample
                    sample_html = f"""
                    <table>
                        <tr><th>Step</th><th>Prompt</th><th>Generated Text</th></tr>
                        <tr>
                            <td>{global_step}</td>
                            <td><b>{start_context}</b></td>
                            <td>{generated_text}</td>
                        </tr>
                    </table>
                    """
                    wandb_run.log({
                        "samples/generated_text": wandb.Html(sample_html),
                        "samples/text_length": len(generated_text),
                    }, step=global_step)

                model.train()

        # End of epoch - log epoch metrics
        avg_epoch_loss = sum(epoch_train_losses) / len(epoch_train_losses) if epoch_train_losses else 0
        print(f"\n{'='*70}")
        print(f"End of Epoch {epoch+1}/{num_epochs}")
        print(f"Average training loss: {avg_epoch_loss:.3f}")
        print(f"{'='*70}\n")

        if wandb_run:
            wandb_run.log({
                "epoch/avg_train_loss": avg_epoch_loss,
                "epoch/avg_train_perplexity": math.exp(min(avg_epoch_loss, 10)),
                "epoch/number": epoch + 1,
            }, step=global_step)

        # Generate sample at end of epoch
        print("End of epoch sample:")
        generated_text = generate_and_print_sample(
            model, tokenizer, device, start_context, return_text=True
        )

        if wandb_run:
            epoch_html = f"""
            <div style="padding: 10px; border: 2px solid #4CAF50; border-radius: 5px;">
                <h3>Epoch {epoch+1} Completion Sample</h3>
                <p><b>Prompt:</b> {start_context}</p>
                <p><b>Generated:</b> {generated_text}</p>
            </div>
            """
            wandb_run.log({
                "epoch_samples/text": wandb.Html(epoch_html),
            }, step=global_step)

        model.train()

    return train_losses, val_losses, track_tokens_seen

In [34]:
import wandb
import time
import torch
import math
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Set seed
def set_seed(seed=123):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seed(123)

# Configuration
CONFIG = {
    # Model architecture
    'model_name': 'GPT-124M',
    'vocab_size': 50257,
    'context_length': 512,
    'emb_dim': 384,
    'n_heads': 6,
    'n_layers': 6,
    'drop_rate': 0.2,

    # Training hyperparameters
    'num_epochs': 20,
    'batch_size': 32,
    'learning_rate': 5e-4,
    'weight_decay': 0.1,
    'beta1': 0.9,
    'beta2': 0.95,
    'epsilon': 1e-8,
    'max_grad_norm': 0.5,

    # Learning rate schedule
    'warmup_steps': 2000,
    'lr_decay': 'cosine',

    # Evaluation
    'eval_freq': 1000,
    'eval_iter': 50,
    'inference_freq': 2000,

    # Optimization
    'use_amp': torch.cuda.is_available(),
    'gradient_accumulation_steps': 8,

    # Dataset info
    'dataset': 'custom-text + HuggingFace/fineweb',
    'tokenizer': 'tiktoken-gpt2',
}





# Early Stopping
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
            return False

        if val_loss > self.best_loss + self.min_delta:
            self.counter += 1
            print(f"⚠️  EarlyStopping: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                return True
        else:
            if val_loss < self.best_loss:
                print(f"✅ Val improved: {self.best_loss:.4f} → {val_loss:.4f}")
            self.best_loss = val_loss
            self.counter = 0
        return False

# ✅ ADD THIS INITIALIZATION
early_stopping = EarlyStopping(patience=5, min_delta=0.1)
print("✅ Early stopping ready")










# Initialize W&B run following official best practices
run = wandb.init(
    # Set the W&B entity (your username or team name)
    entity="davidbdev-bit-labs-inc",  # CHANGE THIS to your W&B username

    # Set the W&B project where this run will be logged
    project="Build LLM From Scratch",

    # Give this run a descriptive name
    name=f"gpt-{CONFIG['num_epochs']}ep-lr{CONFIG['learning_rate']:.0e}-bs{CONFIG['batch_size']}",

    # Add tags for easy filtering
    tags=["gpt", "pretraining", "explicit-shifting", "tiktoken"],

    # Track hyperparameters and run metadata
    config=CONFIG,

    # Save code
    save_code=True,
)

print("="*70)
print(f"W&B Run initialized: {run.name}")
print(f"W&B Run URL: {run.url}")
print("="*70)


# Initialize model
model = GPTModel(GPT_CONFIG_124M)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel parameters: {total_params:,} total, {trainable_params:,} trainable")

# Log model info to W&B
run.config.update({
    "total_params": total_params,
    "trainable_params": trainable_params,
})

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    betas=(CONFIG['beta1'], CONFIG['beta2']),
    eps=CONFIG['epsilon'],
    weight_decay=CONFIG['weight_decay']
)

# LR Scheduler
total_steps = len(train_loader) * CONFIG['num_epochs']

def get_lr_scheduler_fixed(optimizer, warmup_steps, total_steps):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return LambdaLR(optimizer, lr_lambda)

scheduler = get_lr_scheduler_fixed(optimizer, CONFIG['warmup_steps'], total_steps)

# Log training info
run.config.update({
    "total_steps": total_steps,
    "train_batches": len(train_loader),
    "val_batches": len(val_loader),
    "device": str(device),
})


print(f"\nTotal training steps: {total_steps:,}")
print(f"Device: {device}\n")


# Watch model in W&B (log gradients and parameters)
run.watch(model, log="all", log_freq=100)


✅ Early stopping ready


[34m[1mwandb[0m: Currently logged in as: [33mdavidbdev[0m ([33mdavidbdev-bit-labs-inc[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B Run initialized: gpt-20ep-lr5e-04-bs32
W&B Run URL: https://wandb.ai/davidbdev-bit-labs-inc/Build%20LLM%20From%20Scratch/runs/rykbf8mp

Model parameters: 49,434,624 total, 49,434,624 trainable

Total training steps: 85,000
Device: cuda



### Start Training.....

In [None]:
# Start training
start_time = time.time()

try:
    train_losses, val_losses, tokens_seen = train_model_simple(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=CONFIG['num_epochs'],
        eval_freq=CONFIG['eval_freq'],
        eval_iter=CONFIG['eval_iter'],
        start_context="He said we came here",
        tokenizer=tokenizer,
        max_grad_norm=CONFIG['max_grad_norm'],
        save_checkpoints=True,
        checkpoint_path="gpt_model_best.pt",
        use_amp=CONFIG['use_amp'],
        scheduler=scheduler,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        wandb_run=run,  # Pass the run object
        inference_freq=CONFIG['inference_freq']
    )

    # Save final model
    final_checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': CONFIG,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'tokens_seen': tokens_seen,
    }
    torch.save(final_checkpoint, "gpt_model_final.pt")

    # Save final model as W&B artifact
    final_artifact = wandb.Artifact(
        name="model-final",
        type="model",
        description="Final trained model",
        metadata={
            "epochs": CONFIG['num_epochs'],
            "final_val_loss": val_losses[-1] if val_losses else None,
        }
    )
    final_artifact.add_file("gpt_model_final.pt")
    run.log_artifact(final_artifact)

    # Log final summary metrics
    run.summary.update({
        "final_train_loss": train_losses[-1] if train_losses else None,
        "final_val_loss": val_losses[-1] if val_losses else None,
        "best_val_loss": min(val_losses) if val_losses else None,
        "total_tokens_seen": tokens_seen[-1] if tokens_seen else 0,
        "training_time_minutes": (time.time() - start_time) / 60,
    })

    print("\n" + "="*70)
    print("✓ Training completed successfully!")
    print("="*70)

except KeyboardInterrupt:
    print("\n" + "="*70)
    print("Training interrupted by user")
    print("="*70)

    # Save interrupted checkpoint
    interrupted_checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': CONFIG,
        'interrupted': True,
    }
    torch.save(interrupted_checkpoint, "gpt_model_interrupted.pt")
    print("✓ Saved interrupted model checkpoint")

    # Mark run as interrupted in W&B
    run.summary["interrupted"] = True

except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    run.summary["failed"] = True
    run.summary["error_message"] = str(e)
    raise

finally:
    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    execution_time_hours = execution_time_minutes / 60

    print(f"\nTraining time: {execution_time_minutes:.2f} minutes ({execution_time_hours:.2f} hours)")

    if torch.cuda.is_available():
        peak_memory_gb = torch.cuda.max_memory_allocated(device) / 1e9
        print(f"Peak GPU memory: {peak_memory_gb:.2f} GB")
        run.summary["peak_gpu_memory_gb"] = peak_memory_gb

    # Finish the W&B run and upload any remaining data (official best practice)
    run.finish()
    print(f"\n✓ W&B run finished: {run.url}")

  scaler = GradScaler() if use_amp and torch.cuda.is_available() else None
  with autocast():
Epoch 1/20: 100%|██████████| 4250/4250 [19:21<00:00,  3.66it/s, loss=6.701, lr=1.33e-04, step=530]



End of Epoch 1/20
Average training loss: 8.171

End of epoch sample:
He said we came here. The first time - - - - - - - - - - - - - - - - - - - - - - 


Epoch 2/20:  88%|████████▊ | 3759/4250 [17:06<02:06,  3.89it/s, loss=5.532, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.351, Val loss 7.242, LR 2.50e-04
✓ Saved best checkpoint (val_loss: 7.242)


Epoch 2/20:  88%|████████▊ | 3761/4250 [18:20<2:21:50, 17.40s/it, loss=5.322, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.363, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3762/4250 [18:39<2:27:06, 18.09s/it, loss=5.463, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.337, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3763/4250 [18:58<2:27:59, 18.23s/it, loss=5.608, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.392, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3764/4250 [19:16<2:27:40, 18.23s/it, loss=5.594, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.383, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3765/4250 [19:35<2:27:30, 18.25s/it, loss=5.468, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.352, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3766/4250 [19:53<2:27:34, 18.29s/it, loss=5.477, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.342, Val loss 7.242, LR 2.50e-04


Epoch 2/20:  89%|████████▊ | 3767/4250 [20:12<2:28:22, 18.43s/it, loss=5.461, lr=2.50e-04, step=1000]


Ep 2 (Step 001000): Train loss 5.378, Val loss 7.242, LR 2.50e-04


Epoch 2/20: 4813it [24:57,  3.21it/s, loss=5.291, lr=2.83e-04, step=1131]



Training interrupted by user
✓ Saved interrupted model checkpoint

Training time: 44.53 minutes (0.74 hours)
Peak GPU memory: 6.50 GB
