<a href="https://colab.research.google.com/github/nikolina-p/NLP-with-Transformers/blob/main/ch5_pretraining_on_Gutenberg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Optimizing the GPT-124M pretraining function on Hugging Face streaming datasets to run on a single 40GB A100 GPU**
The `GPTModel` class comes from the book *Build a Large Language Model (From Scratch)* by Sebastian Raschka.


**"Bonus pretraining on Gutenberg" solution:**
https://github.com/rasbt/LLMs-from-scratch/tree/main/ch05/03_bonus_pretraining_on_gutenberg

**Download:** "*As of this writing, this will require approximately **50 GB** of disk space and take **about 10-15 hours**, but it may be more depending on how much Project Gutenberg grew since then.*"

**Training:** "*Warning: Note that training on 1 of the ~500 Mb text files in the gutenberg_preprocessed folder will take approximately **4 hours on a V100 GPU**. The folder contains 47 files and will take approximately **200 hours (more than 1 week) to complete**. You may want to run it on a smaller number of files.*"



---


**The goal**: optimize
- data loading,
- training function, and
- the GPT model

to train on a "big" dataset with limited resources on a Google Colab A100 runtime.



---
**The dataset**

Original Project Gutenberg dataset: https://huggingface.co/datasets/manu/project_gutenberg

- English split

- 61.3K rows (one book per row: book ID, text); 38,026 unique rows

This dataset contained duplicate books, excessive newlines and whitespace, and generic headers and footers. After cleaning and tokenizing the texts, the dataset was prepared for training and uploaded to the Hugging Face Hub.

Clean and tokenized dataset: https://huggingface.co/datasets/nikolina-p/gutenberg_clean_tokenized_en_splits  

- Total number of tokens after cleaning: 3,638,561,697




In [1]:
%%capture
!pip install tiktoken \
-U datasets

## **GPT MODEL**

### **Multi-head self-attention mechanism**

In [2]:
import torch
from torch import nn

class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout) # preventing overfitting - only used in training

        # non-trainable parameters part of the model's state, move and save/load with the model
        self.register_buffer("mask", torch.triu(
            torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape # b - batch size

        queries = self.W_query(x) # (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)

        # creating HEADs: step 1
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # creating HEADs: step 2
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        att_scores = queries @ keys.transpose(2,3) # dot product

        att_scores = att_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )

        # scale and normalize
        att_weights = torch.softmax(
            att_scores / keys.shape[-1] ** 0.5,
            dim=-1
            )

        att_weights = self.dropout(att_weights)

        context_vec = att_weights @ values # (b, num_heads, num_tokens, head_dim)

        # reverse shaping
        context_vec = context_vec.transpose(1, 2) # (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        context_vec = self.out_proj(context_vec)

        return context_vec
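To make the two head-creation steps and the reverse shaping concrete, here is a small standalone sketch (toy sizes chosen for illustration, not the model's) that traces the shapes through `view`/`transpose` and checks that the reverse shaping restores the original tensor exactly.

```python
import torch

# toy sizes for illustration only
b, num_tokens, num_heads, head_dim = 2, 4, 3, 2
d_out = num_heads * head_dim

x = torch.randn(b, num_tokens, d_out)  # stands in for the output of W_query / W_key / W_value

# creating heads: split the last dim, then move the head axis in front of the token axis
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 3, 4, 2])

# reverse shaping: swap back, make memory contiguous, merge heads into d_out
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))  # True
```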

In [3]:
#@title Flash Attention
import torch
from torch import nn
from torch.nn import functional as F

class MultiHeadAttentionFlash(nn.Module):

    def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
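        self.dropout = dropout # attention dropout probability - only applied in training mode

    def forward(self, x):
        b, num_tokens, d_in = x.shape # b - batch size

        queries = self.W_query(x) # (b, num_tokens, d_out)
        keys = self.W_key(x)
        values = self.W_value(x)

        # split into heads: (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # fused causal attention: scales by 1/sqrt(head_dim); dropout only while training
        context_vec = F.scaled_dot_product_attention(
            queries, keys, values,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True
        ) # (b, num_heads, num_tokens, head_dim)

        # reverse shaping
        context_vec = context_vec.transpose(1, 2) # (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        context_vec = self.out_proj(context_vec)

        return context_vec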

#### TEST: compare the resulting tensors of flash and classic attention

In [4]:
torch.manual_seed(123)
torch.set_printoptions(sci_mode=False)

x = torch.rand(1, 1, 6)
x

tensor([[[0.2961, 0.5166, 0.2517, 0.6886, 0.0740, 0.8665]]])

In [5]:
# create classic attn
classic = MultiHeadAttention(d_in=6, d_out=6, context_length=1, dropout=0.1, num_heads=3)
y = classic(x)
y

tensor([[[-0.3871,  0.0798,  0.1245,  0.1996,  0.1424,  0.1624]]],
       grad_fn=<ViewBackward0>)

In [6]:
# flash attn
flash = MultiHeadAttentionFlash(d_in=6, d_out=6, dropout=0.1, num_heads=3)
flash.load_state_dict(classic.state_dict(), strict=False)

flash(x)

tensor([[[-0.3871,  0.0798,  0.1245,  0.1996,  0.1424,  0.1624]]],
       grad_fn=<ViewBackward0>)
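With a single token, the causal softmax has only one entry, so (before dropout) every attention weight is 1 and the test above cannot expose masking or head-splitting mistakes. A stronger check uses several tokens and no dropout. The standalone sketch below (random tensors, toy sizes; not part of the original notebook) compares the explicit mask-and-softmax computation used in `MultiHeadAttention` with `F.scaled_dot_product_attention(..., is_causal=True)` on tensors already shaped `(b, num_heads, num_tokens, head_dim)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, num_heads, num_tokens, head_dim = 2, 3, 5, 4  # toy sizes; num_tokens > 1 so masking matters
q = torch.randn(b, num_heads, num_tokens, head_dim)
k = torch.randn(b, num_heads, num_tokens, head_dim)
v = torch.randn(b, num_heads, num_tokens, head_dim)

# classic path: scaled dot product, causal mask, softmax, weighted sum (no dropout)
scores = q @ k.transpose(2, 3) / head_dim ** 0.5
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(mask, -torch.inf), dim=-1)
manual = weights @ v

# fused path: default scale is 1/sqrt(head_dim); is_causal applies the same mask
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

max_diff = (manual - fused).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```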

### **Feedforward layer**
**MLP - Multilayer Perceptron**

In [7]:
class LayerNorm(nn.Module):

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5 # helps avoid division by 0
        self.scale = nn.Parameter(torch.ones(emb_dim)) # learnable
        self.shift = nn.Parameter(torch.zeros(emb_dim)) # learnable

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
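As a quick sanity check (a standalone sketch, not part of the original notebook), the same computation, biased variance plus `eps=1e-5`, taken at initialization (`scale=1`, `shift=0`), matches PyTorch's built-in `F.layer_norm`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb_dim, eps = 8, 1e-5
x = torch.randn(2, 3, emb_dim)

# the same steps as LayerNorm.forward above, at initialization (scale=1, shift=0)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance
manual = (x - mean) / torch.sqrt(var + eps)

builtin = F.layer_norm(x, normalized_shape=(emb_dim,), eps=eps)
print(torch.allclose(manual, builtin, atol=1e-5))
```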


In [8]:
# activation function
class GELU(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
                ))
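This is the tanh approximation of GELU, the variant used by GPT-2. PyTorch exposes the same approximation through `F.gelu(..., approximate="tanh")`; the standalone sketch below compares the formula above against it on a range of inputs.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=101)

# same formula as GELU.forward above
manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

builtin = F.gelu(x, approximate="tanh")
print((manual - builtin).abs().max().item())
```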


In [9]:

class FeedForward(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
            )

    def forward(self, x):
        return self.layers(x)

### **Transformer Block**

In [10]:
class TransformerBlock(nn.Module):

    def __init__(self, cfg, flash=False):
        super().__init__()
        self.flash = flash

        AttClass = MultiHeadAttentionFlash if self.flash else MultiHeadAttention

        self.att = AttClass(
            d_in=cfg["emb_dim"], # dimension of input embeddings
            d_out=cfg["emb_dim"], # dimension of output embeddings
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"], # rate used for dropout
            qkv_bias=cfg["qkv_bias"], # True/False - use bias in query, key and value weights matrices
            **({"context_length": cfg["context_length"]} if not self.flash else {})
        )

        self.ff = FeedForward(cfg)

        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) # dropout layer

    def forward(self, x):
        # Multi-Head Self-Attention Layer
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        # Feed-Forward Layer
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        return x

### **GPT MODEL**

In [11]:
from huggingface_hub import PyTorchModelHubMixin

class GPTModel(nn.Module, PyTorchModelHubMixin):

    def __init__(self, cfg, flash=False, tied=False):
        super().__init__()
        self.flash = flash
        self.tied = tied

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg, self.flash) for _ in range(cfg["n_layers"])]
            )
        self.final_norm = LayerNorm(cfg["emb_dim"])

        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

        if self.tied:
            self.tok_emb.weight = self.out_head.weight

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape

        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(
            torch.arange(seq_len, device=in_idx.device)
            )
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)

        x = self.trf_blocks(x)

        x = self.final_norm(x)

        logits = self.out_head(x)

        return logits

## **GPT configuration**

In [12]:
# the model coded in this notebook has 163M parameters (no weight tying)
GPT_CONFIG_124M = {
    "vocab_size": 50257, # Vocabulary size
    "context_length": 512, # Context length
    "emb_dim": 768, # Embedding dimension
    "n_heads": 12, # Number of attention heads
    "n_layers": 12, # Number of layers
    "drop_rate": 0.1, # Dropout rate
    "qkv_bias": False # Query-Key-Value bias
}
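The parameter counts implied by this configuration can be checked with simple arithmetic. The sketch below is a hypothetical helper (`count_params` is not in the original) that follows the layer definitions above: per transformer block, 3·d² for the bias-free query/key/value projections, d² + d for `out_proj`, 8·d² + 5·d for the two feed-forward layers, and 4·d for the two `LayerNorm`s. With `context_length=512` the untied model has about 162.6M parameters (the "163M" in the comment), and tying `tok_emb` to `out_head` removes one `vocab_size × emb_dim` matrix, leaving about 124M. It also reproduces the rounding of `vocab_size` to a multiple of 128 that the main cell applies.

```python
def count_params(cfg, tied=False):
    """Hypothetical helper: parameter count of the GPTModel defined above."""
    d, V, C, L = cfg["emb_dim"], cfg["vocab_size"], cfg["context_length"], cfg["n_layers"]
    qkv = 3 * d * d + (3 * d if cfg["qkv_bias"] else 0)   # W_query, W_key, W_value
    out_proj = d * d + d                                   # Linear(d, d) with bias
    ff = (d * 4 * d + 4 * d) + (4 * d * d + d)             # Linear(d, 4d) + Linear(4d, d)
    norms = 2 * (2 * d)                                    # norm1 + norm2 (scale, shift)
    per_block = qkv + out_proj + ff + norms
    total = V * d + C * d + L * per_block + 2 * d          # tok_emb, pos_emb, blocks, final_norm
    if not tied:
        total += V * d                                     # separate out_head (no bias)
    return total

cfg = {"vocab_size": 50257, "context_length": 512, "emb_dim": 768,
       "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False}

print(f"untied: {count_params(cfg):,}")             # 162,616,320
print(f"tied:   {count_params(cfg, tied=True):,}")  # 124,018,944

# vocab padding used in the main cell: round up to the next multiple of 128
padded_vocab = ((cfg["vocab_size"] + 127) // 128) * 128
print(padded_vocab)  # 50304
```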

## **StreamingDataset**

https://huggingface.co/docs/datasets/v4.0.0/en/about_mapstyle_vs_iterable

In [13]:
#@title streaming from tokenized dataset
import random, torch
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """An iterable dataset that generates input-target sequence pairs
    and shuffles the order of the sequences within each book."""

    def __init__(self, dataset, context_size):
        self.dataset = iter(dataset)
        self.context_size = context_size

    def __iter__(self):
        while True:
            try:
                book = next(self.dataset)
                book_token_ids = book["tokenized"]

                # one input/target pair needs context_size + 1 tokens
                if len(book_token_ids) <= self.context_size:
                    print(f"Book {book['id']} too short - not enough tokens.")
                    continue

                # loop through shuffled start indices of input chunks and create pairs
                for i in self.shuffle_indices(len(book_token_ids)):
                    input_chunk = book_token_ids[i:i + self.context_size]
                    target_chunk = book_token_ids[i + 1:i + self.context_size + 1] # input shifted by one token
                    yield book["id"], torch.tensor(input_chunk), torch.tensor(target_chunk)
            except StopIteration:
                print("StreamingDataset: no more data.")
                break

    def shuffle_indices(self, book_num_tokens):
        """Shuffles the START INDICES of non-overlapping input chunks."""
        start_indices = list(range(0, book_num_tokens - self.context_size, self.context_size))
        random.shuffle(start_indices)
        return start_indices
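The windowing in `shuffle_indices` and `__iter__` can be illustrated on a toy list of token ids. The sketch below reimplements just that logic as a standalone function (`make_pairs` is a hypothetical name): start indices step by `context_size`, so input chunks do not overlap; each target is the input shifted by one token; and trailing tokens that cannot fill a full chunk plus its shifted target are dropped.

```python
import random

def make_pairs(token_ids, context_size, seed=None):
    """Hypothetical standalone version of StreamingDataset's per-book windowing."""
    # non-overlapping start indices; the last one still leaves room for the shifted target
    start_indices = list(range(0, len(token_ids) - context_size, context_size))
    random.Random(seed).shuffle(start_indices)  # shuffle chunk order within the book
    return [(token_ids[i:i + context_size], token_ids[i + 1:i + context_size + 1])
            for i in start_indices]

book = list(range(11))  # toy "book" of 11 token ids; id 10 is left over
for inp, tgt in sorted(make_pairs(book, context_size=3, seed=0)):
    print(inp, "->", tgt)
# [0, 1, 2] -> [1, 2, 3]
# [3, 4, 5] -> [4, 5, 6]
# [6, 7, 8] -> [7, 8, 9]
```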


## **Training function**

In [14]:
# calculates the loss per one batch
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
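`cross_entropy` expects logits of shape `(N, vocab_size)` and targets of shape `(N,)`, which is why both are flattened over the batch and token dimensions. A small standalone check with toy batch sizes: if a model assigned equal probability to every token (all-zero logits), the loss would be exactly ln(vocab_size), the reference value used later in this notebook for an untrained model.

```python
import math
import torch
import torch.nn.functional as F

b, num_tokens, vocab_size = 4, 8, 50257  # toy batch; real vocab size
logits = torch.zeros(b, num_tokens, vocab_size)           # uniform prediction over the vocabulary
targets = torch.randint(0, vocab_size, (b, num_tokens))

# (b, num_tokens, vocab) -> (b*num_tokens, vocab) and (b, num_tokens) -> (b*num_tokens,)
loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
print(loss.item(), math.log(vocab_size))  # both ~10.8249
```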

#### **Save checkpoint and training results**

In [15]:
# save the model, optimizer and (optional) grad-scaler state
def save_checkpoint(model, optimizer, num_books, scaler=None):
    print(f"Saving checkpoint...{num_books}")
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict() if scaler else None
        },
        f"checkpoint_cycle_{num_books}.pth"
        )

# load the model and optimizer; pass the same flash/tied flags the model was trained with
def load_checkpoint(file_name, model_config, device, flash=False, tied=False):
    checkpoint = torch.load(file_name, map_location=device)
    model = GPTModel(model_config, flash=flash, tied=tied).to(device)
    # a checkpoint saved from a torch.compile'd model has keys prefixed with "_orig_mod."
    state_dict = {k.removeprefix("_orig_mod."): v for k, v in checkpoint["model_state_dict"].items()}
    model.load_state_dict(state_dict)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) # restores lr, betas and moment estimates
    return model, optimizer
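One pitfall when reloading: the main cell wraps the model with `torch.compile`, and the compiled wrapper stores the original module as `_orig_mod`, so the saved state-dict keys start with `_orig_mod.` and do not match a freshly built `GPTModel`. The sketch below simulates such a checkpoint on a tiny `nn.Linear` (instead of calling `torch.compile`) and strips the prefix before a strict load; `strip_compile_prefix` is a hypothetical helper.

```python
import io
import torch
from torch import nn

def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Hypothetical helper: remove the prefix torch.compile adds to state-dict keys."""
    return {k.removeprefix(prefix): v for k, v in state_dict.items()}

src = nn.Linear(4, 2)
# simulate the keys of a checkpoint saved from a torch.compile'd model
compiled_style = {"_orig_mod." + k: v for k, v in src.state_dict().items()}

buffer = io.BytesIO()  # in-memory stand-in for a .pth file
torch.save({"model_state_dict": compiled_style}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

dst = nn.Linear(4, 2)
dst.load_state_dict(strip_compile_prefix(checkpoint["model_state_dict"]))  # strict load succeeds
print(torch.equal(dst.weight, src.weight))  # True
```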

In [16]:
import json
from datetime import datetime
import matplotlib.pyplot as plt

# save training results in json file
def save_training_info(train_losses, valid_losses, books_seen, total_batches, times=None):
    data = {
        "total_batches": total_batches,
        "train_losses": train_losses,
        "valid_losses": valid_losses,
        "books_seen": books_seen,
        "time per cycle": times if times else []
    }

    date_str = datetime.now().strftime("%Y-%m-%d")

    with open(f"results_{date_str}_books_{len(books_seen)}.json", "w", encoding="UTF-8") as f:
        json.dump(data, f, indent=4)


def plot_loss_convergence(train_losses, val_losses):
    cycles = range(1, len(train_losses) + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(cycles, train_losses, label='Training Loss', marker='o')
    plt.plot(cycles, val_losses, label='Validation Loss', marker='s')
    plt.xlabel('Cycle')
    plt.ylabel('Loss')
    plt.title('Loss Convergence')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

### **1. baseline**

In [17]:
#@title Train model - no optimization
import time
from collections import Counter

def train_model_simple(model, dataloader_train, dataloader_valid, optimizer, device, num_epoch, training_batches, val_ratio):
    # track losses per cycle and tokens seen
    train_losses, valid_losses = [], []
    total_batches = 0
    tokens_seen_train, tokens_seen_valid = 0, 0
    books_seen = Counter()
    valid_batches = max(1, int(training_batches * val_ratio)) # validation batches per cycle

    # Main training loop
    try:
        for epoch in range(num_epoch):
            loss_train, train_batch_count = 0, 0 # accumulated loss in a cycle; number of batches in a cycle
            loss_valid, valid_batch_count = 0, 0

            model.train()
            start = time.time()

            # LOAD FROM TRAIN SPLIT
            for book_id_batch, input_batch, target_batch in dataloader_train:
                books_seen.update(book_id_batch)

                # train
                model.train()
                optimizer.zero_grad(set_to_none=True)
                loss = calc_loss_batch(input_batch, target_batch, model, device)
                loss.backward() # Calculate loss gradients
                optimizer.step() # Update model weights using loss gradients

                loss_train += loss.detach() # detach so the autograd graph is not kept alive
                tokens_seen_train += input_batch.numel()
                train_batch_count += 1
                total_batches += 1

                if train_batch_count % training_batches == 0:
                    # LOAD FROM VALIDATION SPLIT
                    for book_id_batch, input_batch, target_batch in dataloader_valid:
                        books_seen.update(book_id_batch)

                        with torch.no_grad():
                            model.eval()
                            loss = calc_loss_batch(input_batch, target_batch, model, device)
                            loss_valid += loss
                            tokens_seen_valid += input_batch.numel()
                            valid_batch_count += 1
                            total_batches += 1

                            if valid_batch_count % valid_batches == 0:
                                # end train/valid cycle and print results
                                if device.type == "cuda":
                                    torch.cuda.synchronize() # wait for queued GPU work before timing
                                end = time.time()

                                train_losses.append((loss_train / train_batch_count).item())
                                loss_train = 0
                                train_batch_count = 0

                                valid_losses.append((loss_valid / valid_batch_count).item())
                                loss_valid = 0
                                valid_batch_count = 0

                                # tokens processed in this cycle / elapsed time; 3,638,561,697 = tokens in the whole dataset
                                tok_sec = (input_batch.numel() * (training_batches + valid_batches)) / (end - start)
                                print(f"\nbatches:{total_batches} | loss: {train_losses[-1]:.3f}/{valid_losses[-1]:.3f}" \
                                    f"| tok-seen: {tokens_seen_train+tokens_seen_valid:,}" \
                                    f"| time: {(end-start):,.2f}" \
                                    f"| tok/sec: {tok_sec:,.2f} | epoch time: {3638561697/(3600*tok_sec):,.2f} hrs")

                                start = time.time()
                                break
    except KeyboardInterrupt:
        print("\nSaving model after KeyboardInterrupt")

    print(f"\nTokens seen (train/val/total) {tokens_seen_train:,} / {tokens_seen_valid:,} / {(tokens_seen_train + tokens_seen_valid):,}")
    print(f"Total books seen {len(books_seen):,}")
    save_checkpoint(model, optimizer, len(books_seen))
    save_training_info(train_losses, valid_losses, books_seen, total_batches)

    return train_losses, valid_losses, total_batches

### **2. mixed precision**

“Automatic mixed precision training” means training with `torch.autocast` and `torch.amp.GradScaler` together.

**Autocasting** automatically chooses the precision for operations to improve performance while maintaining accuracy.

`torch.amp.GradScaler` helps perform the steps of gradient scaling conveniently. Gradient scaling improves convergence for networks with float16 (by default on CUDA and XPU) gradients by minimizing gradient underflow (flush to zero).

If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.

To prevent underflow, **“gradient scaling” multiplies the network’s loss(es) by a scale factor** and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.
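The underflow and its fix can be seen with plain tensors. The standalone sketch below uses an illustrative gradient value of 1e-8 and a scale factor of 2**16, which is `GradScaler`'s default `init_scale`: the value flushes to zero in float16, but survives when it is scaled first and unscaled in float32.

```python
import torch

grad = 1e-8          # illustrative small gradient value
scale = 2.0 ** 16    # GradScaler's default init_scale (65536)

unscaled_fp16 = torch.tensor(grad, dtype=torch.float16)
print(unscaled_fp16.item())   # 0.0: far below float16's smallest positive value (~6e-8)

scaled_fp16 = torch.tensor(grad * scale, dtype=torch.float16)  # ~6.6e-4, representable
recovered = scaled_fp16.float() / scale                         # unscale in float32
print(recovered.item())       # ~1e-8
```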

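Each parameter’s gradient (`.grad` attribute) should be unscaled before the optimizer updates the parameters, so the scale factor does not interfere with the learning rate.

All of this applies to float16. The training function below autocasts to **bfloat16**, which keeps float32's 8-bit exponent range (at the cost of mantissa precision), so small gradients do not underflow and the loss is backpropagated without gradient scaling.

Recipe: https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html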


**cuBLAS** is NVIDIA’s GPU-accelerated implementation of the BLAS (Basic Linear Algebra Subprograms) library:
- Vector and matrix multiplication (e.g., GEMM: General Matrix Multiply)
- Matrix-vector products
- Vector dot products and norms

**cuDNN** is NVIDIA’s GPU-accelerated library specifically for deep learning primitives.
- Convolution operations (for CNNs)
- Pooling (max, average)
- Normalization (batch norm, layer norm)
- Activation functions (ReLU, tanh, sigmoid)
- RNNs (LSTM, GRU)
- Tensor layout transformations

In [18]:
!nvcc --version
!dpkg -l | grep libcublas
!dpkg -l | grep cudnn
print(f"Torch: {torch.__version__}")

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
hi  libcublas-12-5                         12.5.3.2-1                              amd64        CUBLAS native runtime libraries
hi  libcublas-dev-12-5                     12.5.3.2-1                              amd64        CUBLAS native dev links, headers
hi  libcudnn9-cuda-12                      9.2.1.18-1                              amd64        cuDNN runtime libraries for CUDA 12.5
ii  libcudnn9-dev-cuda-12                  9.2.1.18-1                              amd64        cuDNN development headers and symlinks for CUDA 12.5
Torch: 2.8.0+cu126
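The shell commands above report the CUDA toolkit and libraries installed on the Colab VM. PyTorch's pip wheels typically come with their own CUDA runtime, cuBLAS and cuDNN (note `cu126` here versus the system's 12.5), so the versions PyTorch actually uses are better queried from Python, as in this short sketch:

```python
import torch

print("PyTorch:", torch.__version__)
print("CUDA (PyTorch build):", torch.version.cuda)   # None for CPU-only builds
print("cuDNN:", torch.backends.cudnn.version())       # None if cuDNN is unavailable
print("GPU available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("bf16 supported:", torch.cuda.is_bf16_supported())
print("float32 matmul precision:", torch.get_float32_matmul_precision())
```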


In [19]:
#@title Train model - mixed precision
import time
from collections import Counter

def train_model_mixed_precision(model, train_loader, valid_loader, optimizer, device, num_epoch, training_batches, val_ratio):
    train_losses, valid_losses = [], []  # track losses per cycle
    tokens_seen_train, tokens_seen_valid = 0, 0
    books_seen = Counter()
    total_batches = 0
    valid_batches = max(1, int(training_batches * val_ratio)) # validation batches per cycle

    # Main training loop
    try:
        # bf16 has float32's exponent range, so no GradScaler is used (it is needed for float16)
        start = time.time()

        for epoch in range(num_epoch):
            print(f"Epoch: {epoch}")
            loss_train, loss_valid = 0, 0 # accumulated loss in a cycle
            train_batch_count, valid_batch_count = 0, 0 # number of batches in a cycle

            model.train()

            for book_id_batch, input_batch, target_batch in train_loader:
                books_seen.update(book_id_batch)

                # train
                model.train()
                optimizer.zero_grad(set_to_none=True)

                # Runs the forward pass and loss with bf16 autocasting; backward stays outside autocast
                with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                    loss = calc_loss_batch(input_batch, target_batch, model, device)

                loss.backward()
                optimizer.step()

                loss_train += loss.detach() # detach so the autograd graph is not kept alive
                tokens_seen_train += input_batch.numel()
                train_batch_count += 1
                total_batches += 1

                if train_batch_count % training_batches == 0:
                    # LOAD FROM VALIDATION SPLIT
                    for book_id_batch, input_batch, target_batch in valid_loader:
                        books_seen.update(book_id_batch)

                        with torch.no_grad():
                            model.eval()
                            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                                loss = calc_loss_batch(input_batch, target_batch, model, device)
                            loss_valid += loss
                            tokens_seen_valid += input_batch.numel()
                            valid_batch_count += 1
                            total_batches += 1

                        if valid_batch_count % valid_batches == 0:
                            # print results of the previous training/validation cycle, and clean tracking vars
                            if device.type == "cuda":
                                torch.cuda.synchronize() # wait for queued GPU work before timing
                            end = time.time()

                            train_losses.append((loss_train / train_batch_count).item())
                            loss_train = 0
                            train_batch_count = 0

                            valid_losses.append((loss_valid / valid_batch_count).item())
                            loss_valid = 0
                            valid_batch_count = 0

                            # tokens processed in this cycle / elapsed time; 3,638,561,697 = tokens in the whole dataset
                            tok_sec = (input_batch.numel() * (training_batches + valid_batches)) / (end - start)
                            print(f"\nbatches:{total_batches} | loss: {train_losses[-1]:.3f}/{valid_losses[-1]:.3f}" \
                                    f"| tok-seen: {tokens_seen_train+tokens_seen_valid:,}" \
                                    f"| time: {(end-start):,.2f}" \
                                    f"| tok/sec: {tok_sec:,.2f} | epoch: {3638561697/(3600*tok_sec):,.2f} hrs")
                            start = time.time()
                            break
    except KeyboardInterrupt:
        print("\nSaving model after KeyboardInterrupt")


    print(f"\nTokens seen: {tokens_seen_train:,} / {tokens_seen_valid:,} / {(tokens_seen_train + tokens_seen_valid):,}")
    print(f"Total books seen {len(books_seen):,}")
    save_checkpoint(model, optimizer, len(books_seen))
    save_training_info(train_losses, valid_losses, books_seen, total_batches)

    return train_losses, valid_losses, total_batches
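Why bfloat16 can skip the scaler: it keeps float32's 8 exponent bits and gives up mantissa precision instead, so its smallest normal value equals float32's, while float16 tops out at 65504. The standalone check below prints the ranges and shows that under `torch.autocast` a `Linear` layer computes in bfloat16 while its weights stay float32 (run on CPU here so it works without a GPU).

```python
import torch
from torch import nn

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} smallest normal {info.tiny:.3e}  max {info.max:.3e}")

layer = nn.Linear(8, 4)
x = torch.randn(2, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)
print(y.dtype, layer.weight.dtype)  # torch.bfloat16 torch.float32
```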

# **MAIN**

## **Streaming from separate train and validation splits**

In [20]:
from datasets import load_dataset

gutenberg_train = load_dataset("nikolina-p/gutenberg_clean_tokenized_en_splits", split="train", streaming=True, columns=["id", "tokenized"])
gutenberg_valid = load_dataset("nikolina-p/gutenberg_clean_tokenized_en_splits", split="validation", streaming=True, columns=["id", "tokenized"])


In [21]:
gutenberg_train

IterableDataset({
    features: ['id', 'tokenized'],
    num_shards: 33
})

In [22]:
!nvidia-smi

Fri Sep 26 02:49:02 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   31C    P0             54W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [23]:
import time
from torch.utils.data import DataLoader
import inspect

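#GPT_CONFIG_124M["vocab_size"] = 55296
# pad vocab_size up to the next multiple of 128 (50257 -> 50304); such sizes are typically faster for GPU matmuls
GPT_CONFIG_124M["vocab_size"] = ((GPT_CONFIG_124M["vocab_size"] + 127) // 128) * 128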
model = GPTModel(GPT_CONFIG_124M, flash=True, tied=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model = torch.compile(model)

fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device.type == 'cuda'
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1, betas=(0.9, 0.95), fused=use_fused)

ds_train = StreamingDataset(gutenberg_train, context_size=GPT_CONFIG_124M['context_length'])
ds_valid = StreamingDataset(gutenberg_valid, context_size=GPT_CONFIG_124M['context_length'])

loader_train = DataLoader(ds_train,
                          batch_size=128,
                          drop_last=True,
                          num_workers=2,
                          pin_memory=True
                          )

loader_valid = DataLoader(ds_valid,
                          batch_size=128,
                          drop_last=True,
                          num_workers=2,
                          pin_memory=True
                          )
fun = 'mix'
start_time = time.time()
match fun:
    case 'baseline': # fp32 training (TF32 matmuls), no autocast
        #torch.backends.cuda.matmul.allow_tf32 = True # enable Tensor Cores
        torch.set_float32_matmul_precision('high')
        train_losses, valid_losses, total_batches = train_model_simple(model,
                                                                       loader_train,
                                                                       loader_valid,
                                                                       optimizer,
                                                                       device,
                                                                       num_epoch=2,
                                                                       training_batches=10,
                                                                       val_ratio=0.1
                                                                       )
    case 'mix': # bf16 autocast; no GradScaler needed for bf16
        #torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision('high') # enable Tensor Cores
        train_losses, valid_losses, total_batches = train_model_mixed_precision(model,
                                                                                loader_train,
                                                                                loader_valid,
                                                                                optimizer,
                                                                                device,
                                                                                num_epoch=2,
                                                                                training_batches=200,
                                                                                val_ratio=0.1
                                                                                )
    case _:
        print("Unknown choice")

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")


Epoch: 0

batches:220 | loss: 7.448/7.287| tok-seen: 14,417,920| time: 144.17| tok/sec: 100,006.82 | epoch: 10.11 hrs

batches:440 | loss: 7.300/7.215| tok-seen: 28,835,840| time: 84.42| tok/sec: 170,783.58 | epoch: 5.92 hrs

batches:660 | loss: 7.267/7.259| tok-seen: 43,253,760| time: 83.62| tok/sec: 172,420.09 | epoch: 5.86 hrs

Saving model after KeyboardInterrupt

Tokens seen: 39,452,672 / 3,932,160 / 43,384,832
Total books seen 447
Saving checkpoint...447
Training completed in 5.25 minutes.
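The throughput and epoch estimate in the log follow from simple arithmetic: each cycle processes `training_batches + valid_batches` = 200 + 20 = 220 batches of 128 × 512 tokens, and the epoch estimate divides the 3,638,561,697 tokens of the cleaned dataset by the throughput. A small sketch (hypothetical helper names) reproduces the third logged cycle, which took 83.62 s:

```python
DATASET_TOKENS = 3_638_561_697  # total tokens after cleaning

def tokens_per_cycle(training_batches, valid_batches, batch_size, context_length):
    """Hypothetical helper: tokens processed in one train/validation cycle."""
    return (training_batches + valid_batches) * batch_size * context_length

def epoch_hours(tok_per_sec, total_tokens=DATASET_TOKENS):
    """Hypothetical helper: hours needed to stream the whole dataset once."""
    return total_tokens / (3600 * tok_per_sec)

tokens = tokens_per_cycle(200, 20, batch_size=128, context_length=512)
tok_sec = tokens / 83.62  # cycle time from the third log line
print(f"{tokens:,} tokens | {tok_sec:,.0f} tok/sec | {epoch_hours(tok_sec):.2f} hrs per epoch")
```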


# ...

## TEST: Weights tying and loss
Conclusion: `nn.Embedding` initializes its weights from a standard normal distribution N(0, 1), so a 50257 × 768 embedding matrix contains values of roughly ±5. Tying `out_head` to these weights produces very large logits (up to ~560 here) and a huge, unstable initial loss. `nn.Linear` instead initializes uniformly within ±1/√fan_in = ±1/√768 ≈ ±0.036 (standard deviation ≈ 0.02), which produces logits near ±3 and an initial loss close to ln(vocab_size) = ln(50257) ≈ 10.8.

By initializing tok_emb from out_head instead of the reverse, we ensure stable starting conditions for training and prevent an exploding loss.

In `GPTModel.__init__(self, cfg, flash=False, tied=False)`:

    if self.tied:
        self.tok_emb.weight = self.out_head.weight # not the reverse!
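The two initializations can be compared without building the full model. In this standalone sketch (a toy vocabulary of 1000 tokens, `emb_dim=768` as in the config), `nn.Embedding` draws from N(0, 1), while `nn.Linear` draws uniformly from ±1/√fan_in = ±1/√768 ≈ ±0.0361, which matches the bound printed for the tied weights in the test below. Assigning `tok_emb.weight = out_head.weight` keeps the Linear initialization for both layers.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, emb_dim = 1000, 768  # toy vocabulary; emb_dim from the config

tok_emb = nn.Embedding(vocab_size, emb_dim)            # init: N(0, 1)
out_head = nn.Linear(emb_dim, vocab_size, bias=False)  # init: U(-1/sqrt(emb_dim), 1/sqrt(emb_dim))

bound = 1 / math.sqrt(emb_dim)
emb_std = tok_emb.weight.std().item()
print(f"embedding std: {emb_std:.3f}")  # close to 1
print(f"linear max |w|: {out_head.weight.abs().max().item():.4f} (bound {bound:.4f})")

tok_emb.weight = out_head.weight  # tie in the stable direction: both use the Linear init
print(tok_emb.weight.data_ptr() == out_head.weight.data_ptr())  # True
print(f"expected initial loss for vocab 50257: {math.log(50257):.4f}")
```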


#### dummy data

In [13]:
data = torch.randint(0, 50257, (16, 513), device='cuda')
x = data[:, :-1]
target = data[:, 1:]

In [14]:
x.shape

torch.Size([16, 512])

In [15]:
target.shape

torch.Size([16, 512])

In [16]:
x[0,:10]

tensor([18099, 29438, 35212, 41222,  9061, 26163, 44525,   893, 33492, 30653],
       device='cuda:0')

In [17]:
target[0,:10]

tensor([29438, 35212, 41222,  9061, 26163, 44525,   893, 33492, 30653, 48698],
       device='cuda:0')

#### model using embedding weights for tying

In [18]:
model_1 = GPTModel(GPT_CONFIG_124M, flash=True, tied=False)
model_1.out_head.weight = model_1.tok_emb.weight

In [19]:
print(model_1.out_head.weight.data_ptr() == model_1.tok_emb.weight.data_ptr())
model_1.to('cuda')
print("")

True



In [20]:
print(f"Tok emb size: {model_1.tok_emb}")

Tok emb size: Embedding(50257, 768)


In [21]:
y = model_1(x)

In [22]:
y.shape

torch.Size([16, 512, 50257])

In [23]:
loss = torch.nn.functional.cross_entropy(y.flatten(0, 1), target.flatten())

In [28]:
# HUGE INITIAL LOSS!!!
loss

tensor(460.2455, device='cuda:0', grad_fn=<NllLossBackward0>)

In [25]:
print("Logits min:", y.min().item(), "max:", y.max().item(), "mean:", y.mean().item())

Logits min: -159.7432403564453 max: 557.7025756835938 mean: 0.004521739669144154


In [27]:
print("Tied weight min:", model_1.out_head.weight.min().item())
print("Tied weight max:", model_1.out_head.weight.max().item())
print("Embedding weight min:", model_1.tok_emb.weight.min().item())
print("Embedding weight max:", model_1.tok_emb.weight.max().item())

Tied weight min: -5.335026741027832
Tied weight max: 5.475025177001953
Embedding weight min: -5.335026741027832
Embedding weight max: 5.475025177001953


#### model using Linear layer weights for tying

In [29]:
model_2 = GPTModel(GPT_CONFIG_124M, flash=True, tied=False)
model_2.tok_emb.weight = model_2.out_head.weight

In [30]:
print(model_2.out_head.weight.data_ptr() == model_2.tok_emb.weight.data_ptr())
model_2.to('cuda')
print("")

True



In [31]:
print(f"Tok emb size: {model_2.tok_emb}")

Tok emb size: Embedding(50257, 768)


In [32]:
y = model_2(x)

In [33]:
y.shape

torch.Size([16, 512, 50257])

In [34]:
loss = torch.nn.functional.cross_entropy(y.flatten(0, 1), target.flatten())

In [40]:
# Expected loss for randomly initialized parameters -> -ln(1/50257) = 10.8249 ✅
loss

tensor(10.9947, device='cuda:0', grad_fn=<NllLossBackward0>)

In [36]:
print("Logits min:", y.min().item(), "max:", y.max().item(), "mean:", y.mean().item())

Logits min: -3.2705678939819336 max: 3.3515560626983643 mean: 0.00011122244177386165


In [38]:
print("Tied weight min:", model_2.out_head.weight.min().item())
print("Tied weight max:", model_2.out_head.weight.max().item())
print("Embedding weight min:", model_2.tok_emb.weight.min().item())
print("Embedding weight max:", model_2.tok_emb.weight.max().item())

Tied weight min: -0.03608439117670059
Tied weight max: 0.036084387451410294
Embedding weight min: -0.03608439117670059
Embedding weight max: 0.036084387451410294
