# Data and imports

1. I check which GPU is available to us; usually it is an A100.
2. I install `torchinfo` and import `summary` from it.
3. I import a bunch of other libraries. Each one is explained in the cell itself.
4. I load **roneneldan/TinyStories** from Hugging Face Datasets. Both splits get downloaded, but only the train split is used here. It is a pretraining corpus, so we don't care about the order of the stories.

In [None]:
!nvidia-smi

Sat Aug 30 23:34:45 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P0             50W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
%pip install torchinfo --quiet
from torchinfo import summary # To check params and layers and such

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader # To prepare and use datasets properly

from torch.nn.attention import SDPBackend, sdpa_kernel # To force the flash attention backend inside GQA
from torch.amp import autocast # For bf16 mixed precision on the GPU

import pandas as pd
import numpy as np
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp

import tiktoken # Tokenizer with p50k base vocab. Size - 50281
tok = tiktoken.get_encoding("p50k_base")

In [None]:
from datasets import load_dataset
dataset = load_dataset("roneneldan/TinyStories", split="train") # ~2.1M training stories

TRAINING_CUTOFF=1_000_000
data = pd.DataFrame(dataset)[:TRAINING_CUTOFF] # Into a dataframe.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

# Architecture
1. Rotary Positional Encodings (RoPE) are used for positional information.
2. Grouped Query Attention is used: 8 query heads share 4 KV heads, so 2 query heads per KV head.
3. MLP with SwiGLU (SiLU-gated) activation and a hidden layer 2x wider than `d_model` (768 vs 384).
4. Transformer block that consists of 1, 2 and 3, with pre-norm RMSNorm and residual connections.
5. Dataset creation.
6. The entire GatorGPT structure!

In [None]:
class Rope(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(Rope, self).__init__()

        # Store the embedding dimension (must be even for proper 2D rotations)
        assert d_model % 2 == 0, "d_model must be even for proper RoPE implementation"
        self.d_model = d_model

        # Store the maximum sequence length this RoPE can handle
        self.max_len = max_len

        # Create position indices [0, 1, 2, ..., max_len-1] and add dimension for broadcasting
            # Shape: (max_len, 1) - each position gets its own row
        self.register_buffer('position_ids', torch.arange(max_len).unsqueeze(1))

        # Create frequency terms for rotation - these determine how fast each dimension pair rotates
            # Uses exponential decay: smaller indices rotate faster, larger indices rotate slower
            # torch.arange(0, d_model, 2) creates [0, 2, 4, 6, ...] (even indices only)
            # The formula creates frequencies: [1/10000^(0/d), 1/10000^(2/d), 1/10000^(4/d), ...]
            # This is equal to (10000 ** (-(2 * i) / d_model) for i in range(d_model // 2))
        self.register_buffer('div_term', torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model)))

    def forward(self, x):
        # Input shape: (batch_size, seq_len, d_model)
        batch_size, seq_len, d_model = x.shape

        # Get position indices for current sequence length (trim to actual sequence length)
            # If input has 100 tokens, this gets positions [0, 1, 2, ..., 99]
        position_ids = self.position_ids[:seq_len]  # Shape: (seq_len, 1)

        # Calculate rotation angles for each position and frequency (same frequency schedule as the sinusoidal encodings in the 2017 Transformer paper)
            # Multiply each position by each frequency term to get rotation angles
            # Shape: (seq_len, d_model//2)
            # This is basically: pos/(10000^(2i/d_model))
        angles = position_ids * self.div_term

        # Calculate sine and cosine values for rotation
            # Shape: (seq_len, d_model//2)
        cos_vals = torch.cos(angles)
        sin_vals = torch.sin(angles)

        # Reshape input to separate even and odd dimensions for 2D rotation
            # Split x into pairs: (x_0, x_1), (x_2, x_3), (x_4, x_5), ...
            # Shape: (batch_size, seq_len, d_model//2, 2)
        x_pairs = x.view(batch_size, seq_len, d_model // 2, 2)

        # Extract even and odd components
            # x_even contains x_0, x_2, x_4, ... (first element of each pair)
            # x_odd contains x_1, x_3, x_5, ... (second element of each pair)
        x_even = x_pairs[..., 0]  # Shape: (batch_size, seq_len, d_model//2)
        x_odd = x_pairs[..., 1]   # Shape: (batch_size, seq_len, d_model//2)

        # Apply 2D rotation to each pair of dimensions
            # Rotation matrix: [[cos, -sin], [sin, cos]]
            # For each pair (x_i, x_{i+1}), compute:
            # x_i' = x_i * cos - x_{i+1} * sin
            # x_{i+1}' = x_i * sin + x_{i+1} * cos
        rotated_even = x_even * cos_vals - x_odd * sin_vals
        rotated_odd = x_even * sin_vals + x_odd * cos_vals

        rotated_pairs = torch.stack([rotated_even, rotated_odd], dim=-1)
        rotated_x = rotated_pairs.view(batch_size, seq_len, d_model)

        return rotated_x
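
To see why rotating pairs by position-dependent angles encodes position, the sketch below redoes the same pairwise rotation with plain Python lists (no PyTorch) and checks that the dot product of a rotated query and key depends only on how far apart they are. The helper names `rope_rotate` and `dot` and the toy vectors are illustrative assumptions, not part of the notebook.

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive pairs (x_0, x_1), (x_2, x_3), ... of `vec` by pos * freq_i."""
    d = len(vec)
    assert d % 2 == 0, "dimension must be even"
    out = []
    for i in range(d // 2):
        freq = base ** (-(2 * i) / d)      # same schedule as div_term in Rope
        angle = pos * freq
        c, s = math.cos(angle), math.sin(angle)
        x_even, x_odd = vec[2 * i], vec[2 * i + 1]
        out.append(x_even * c - x_odd * s)
        out.append(x_even * s + x_odd * c)
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Toy query/key vectors (made up for illustration)
q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same offset (3 positions apart) at two different absolute positions
score_a = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_b = dot(rope_rotate(q, 105), rope_rotate(k, 102))
print(abs(score_a - score_b) < 1e-9)
```

The two scores agree because each 2D rotation cancels down to a rotation by the position difference, which only holds when the query and the key are rotated with the same frequencies.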

In [None]:
class GQA(nn.Module):
    def __init__(self,
                 d_model: int = 384,
                 n_heads: int = 8,
                 gqa_groups: int = 2,
                 max_len: int = 1024,
                ):
        super().__init__()  # initialize base Module
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"  # validate head split
        assert n_heads % gqa_groups == 0, "n_heads must be divisible by gqa_groups"  # validate grouping

        self.d_model = d_model  # store model dimension
        self.n_heads = n_heads  # store number of query heads
        self.gqa_groups = gqa_groups  # store number of query heads that share each K/V head
        self.head_dim = d_model // n_heads  # compute per-head dimension
        self.n_kv_heads = n_heads // gqa_groups  # compute number of K/V heads for GQA
        self.max_len = max_len  # store max sequence length

        # Define bias-free linear projections
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)  # Q projection: d_model -> H*D
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)  # K projection: d_model -> H_kv*D
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)  # V projection: d_model -> H_kv*D
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)  # Output projection: H*D -> d_model

        # Instantiate two RoPE modules with the exact composite dims requested
        self.rope_q = Rope(d_model=n_heads * self.head_dim, max_len=max_len)  # RoPE for Q (expects (B,T,H*D))
        self.rope_k = Rope(d_model=self.n_kv_heads * self.head_dim, max_len=max_len)  # RoPE for K (expects (B,T,H_kv*D))

    def forward(self,
                x: torch.Tensor,  # (B, T, d_model)
                attention_mask: torch.Tensor | None = None  # (B, T) 1=real,0=pad (ignored: no attention bias)
                ) -> torch.Tensor:  # returns (B, T, d_model)
        B, T, C = x.shape  # unpack input shape

        # Linear projections for Q, K, V
        q = self.q_proj(x)  # (B, T, H*D)
        k = self.k_proj(x)  # (B, T, H_kv*D)
        v = self.v_proj(x)  # (B, T, H_kv*D)

        # Apply RoPE to Q over the flattened head dimension
        q = self.rope_q(q)  # (B, T, H*D) with rotary positional encoding applied
        q = q.view(B, T, self.n_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()  # (B, H, T, D)

        # Apply RoPE to K over the flattened head dimension
        k = self.rope_k(k)  # (B, T, H_kv*D) with rotary positional encoding applied
        k = k.view(B, T, self.n_kv_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()  # (B, H_kv, T, D)

        # Reshape V to heads (no RoPE for V)
        v = v.view(B, T, self.n_kv_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()  # (B, H_kv, T, D)

        # Expand K and V from n_kv_heads to n_heads via repeat_interleave on the head axis,
        # so every query head gets a copy of the K/V head it shares.
        expand_factor = self.n_heads // self.n_kv_heads  # compute replication factor
        k = k.repeat_interleave(expand_factor, dim=1)  # (B, H, T, D)
        v = v.repeat_interleave(expand_factor, dim=1)  # (B, H, T, D)
        # repeat_interleave converts [1,2,3,4] -> [1,1,1,2,2,2,3,3,3,4,4,4] when expand_factor is 3 and dim=0

        # GQA (enable_gqa=True) will not work in the sdpa kernel with forced flash attention.
          # Instead, K/V are expanded above (manual GQA) and flash attention is used directly.

        # Compute SDPA with purely causal masking (no external attention bias); this uses flash attention
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
          out = F.scaled_dot_product_attention(
              q,
              k,
              v,
              attn_mask=None,
              is_causal=True,
              enable_gqa=False
          )  # (B, H, T, D)

        # Merge heads back to (B, T, H*D)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, self.n_heads * self.head_dim)  # (B, T, d_model)

        # Project to output dimension
        out = self.o_proj(out)  # (B, T, d_model)

        return out  # return attended representations
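
A small sketch of the head bookkeeping in `GQA`, using the default sizes (`n_heads=8`, `gqa_groups=2`, `d_model=384`). It mirrors what `repeat_interleave` does to the KV heads and, as a side check, compares the rotation frequency that the flattened `rope_q` (width 384) and `rope_k` (width 192) assign to the same pair index of a query head and the KV head it shares. The helper `flat_freq` is hypothetical and written only for this check.

```python
n_heads, gqa_groups, d_model = 8, 2, 384
head_dim = d_model // n_heads            # 48
n_kv_heads = n_heads // gqa_groups       # 4
expand_factor = n_heads // n_kv_heads    # 2

# repeat_interleave on the head axis: KV head g appears expand_factor times in a row
kv_for_query = [g for g in range(n_kv_heads) for _ in range(expand_factor)]
print(kv_for_query)

def flat_freq(width, head, pair, base=10000.0):
    """Frequency that Rope(d_model=width) assigns to rotation pair `pair` of head `head`
    when RoPE is applied to the flattened (B, T, heads * head_dim) projection."""
    flat_pair = head * (head_dim // 2) + pair   # pair index in the flattened projection
    return base ** (-(2 * flat_pair) / width)

q_head = 0
kv_head = kv_for_query[q_head]
pair = 1
f_q = flat_freq(n_heads * head_dim, q_head, pair)       # rope_q sees width 384
f_k = flat_freq(n_kv_heads * head_dim, kv_head, pair)   # rope_k sees width 192
print(f_q, f_k)
```

If the two printed frequencies differ, a query head and its KV head are not rotated at the same rate. A common alternative, offered here as an assumption and not something this notebook evaluates, is to split into heads first and apply one `Rope(d_model=head_dim)` to every head of both Q and K.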

In [None]:
class MLP(nn.Module):
    """
    SwiGLU MLP for a decoder-only Transformer block.

    - d_model: 384
    - d_ff: 768 (2 × d_model)
    - Linear layers are bias-free
    - RMSNorm is applied outside this module
    - Input/Output shape: (batch, seq_len, d_model)
    - BF16-friendly: uses ops that preserve input dtype
    """
    def __init__(self, d_model: int = 384, d_ff: int = 768):
        super().__init__()
        # Fused "up" + "gate" projection to reduce matmuls: d_model -> 2*d_ff
        self.w1 = nn.Linear(d_model, 2 * d_ff, bias=False)
        # Down projection: d_ff -> d_model
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_model)
        up, gate = self.w1(x).chunk(2, dim=-1)  # (B, T, d_ff) each

        # The fused projection is split in two because SwiGLU needs both halves:
            # First half (up) carries the content
            # Second half (gate) decides how much of that content passes through
        x = up * F.silu(gate)                   # SwiGLU: up ⊗ swish(gate)
        x = self.w2(x)                          # (B, T, d_model)
        return x
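
A scalar sketch of the SwiGLU gating used in `MLP.forward`, written with the standard library only. The toy `up` and `gate` values are made up; the weight count at the end assumes the notebook defaults `d_model=384` and `d_ff=768`.

```python
import math

def silu(z):
    """SiLU / swish: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def swiglu(up, gate):
    """Elementwise up * silu(gate), as in MLP.forward after the chunk."""
    return [u * silu(g) for u, g in zip(up, gate)]

up = [1.0, 2.0, -3.0]
gate = [0.0, 10.0, -10.0]
hidden = swiglu(up, gate)
print(hidden)

# Weight count of the bias-free MLP: fused up+gate projection plus down projection
d_model, d_ff = 384, 768
mlp_params = d_model * (2 * d_ff) + d_ff * d_model
print(mlp_params)
```

A gate of 0 blocks the content entirely, a large positive gate passes it almost unchanged, and a large negative gate shrinks it to nearly zero.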

In [None]:
class Block(nn.Module):
    def __init__(
        self,
        d_model: int = 384,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
        d_ff: int = 768,
        eps: float = 1e-5,
        dropout_p: float = 0.0,  # keep 0.0 for pretrain
    ):
        super().__init__()
        self.rms1 = nn.modules.normalization.RMSNorm(d_model, eps)
        self.rms2 = nn.modules.normalization.RMSNorm(d_model, eps)

        self.attn = GQA(d_model, n_heads, gqa_groups, max_len)  # includes the output projection
        self.mlp = MLP(d_model, d_ff)  # SwiGLU MLP (bias-free)

        self.drop_attn = nn.Dropout(dropout_p)
        self.drop_mlp = nn.Dropout(dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm: RMSNorm before each sublayer, residual connection around it
        x = x + self.drop_attn(self.attn(self.rms1(x)))
        x = x + self.drop_mlp(self.mlp(self.rms2(x)))
        return x

In [None]:
class FastDataset(Dataset):
    """Pre-computed sliding windows, numpy arrays for speed"""
    def __init__(self, tokens, max_length=256, stride=128):
        # Convert to numpy for faster slicing
        self.tokens = np.array(tokens, dtype=np.int32)
        self.max_length = max_length

        # Pre-compute all valid starts
        self.starts = np.arange(0, len(tokens) - max_length, stride)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = self.starts[idx]
        end = start + self.max_length

        input_ids = torch.from_numpy(self.tokens[start:end].copy()).long()
        target_ids = torch.from_numpy(self.tokens[start+1:end+1].copy()).long()

        return input_ids, target_ids
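
A list-based sketch of the windows `FastDataset` produces, on a toy token stream. The function name `sliding_windows` is an assumption for illustration; the start positions and slices follow the same arithmetic as the class above.

```python
def sliding_windows(tokens, max_length, stride):
    """Return (input, target) pairs the way FastDataset slices them."""
    starts = range(0, len(tokens) - max_length, stride)
    return [
        (tokens[s:s + max_length], tokens[s + 1:s + max_length + 1])
        for s in starts
    ]

toy_tokens = list(range(10))          # stand-in for real token ids
pairs = sliding_windows(toy_tokens, max_length=4, stride=2)
for inp, tgt in pairs:
    print(inp, "->", tgt)
```

Each target is the input shifted by one token, and because the last start is strictly below `len(tokens) - max_length`, the shifted slice never runs past the end of the stream.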

In [None]:
class GatorGPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 384,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
        d_ff: int = 768,
        eps: float = 1e-5,
        dropout_p: float = 0.0,
        blocks: int = 10,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

        self.final_rms = nn.modules.normalization.RMSNorm(d_model, eps)
        self.unembed.weight = self.embed.weight

        self.blocks = nn.ModuleList(
            [
                Block(
                    d_model=d_model,
                    n_heads=n_heads,
                    gqa_groups=gqa_groups,
                    max_len=max_len,
                    d_ff=d_ff,
                    eps=eps,
                    dropout_p=dropout_p,
                ) for _ in range(blocks)
            ]
        )

    def forward(self, x):
        """
        Forward method that takes in the tokens
        """
        # x: (batch, seq_len) of token ids
        h = self.embed(x)                 # (batch, seq_len, d_model)
        for block in self.blocks:         # run each transformer block
            h = block(h)
        h = self.final_rms(h)
        logits = self.unembed(h)          # (batch, seq_len, vocab_size)
        return logits
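
A back-of-the-envelope parameter count for the default `GatorGPT` configuration, assuming the p50k_base vocabulary size of 50281 noted in the imports cell and counting the tied embedding/unembedding matrix once. This is plain arithmetic, not a replacement for running `summary` from `torchinfo` on the real model.

```python
vocab_size, d_model, n_heads, gqa_groups, d_ff, n_blocks = 50281, 384, 8, 2, 768, 10

head_dim = d_model // n_heads
n_kv_heads = n_heads // gqa_groups

embed = vocab_size * d_model                       # shared with unembed (weight tying)
attn = (
    d_model * n_heads * head_dim                   # q_proj
    + 2 * d_model * n_kv_heads * head_dim          # k_proj and v_proj
    + n_heads * head_dim * d_model                 # o_proj
)
mlp = d_model * 2 * d_ff + d_ff * d_model          # fused up+gate, then down
norms = 2 * d_model                                # two RMSNorm weight vectors per block
block = attn + mlp + norms

total = embed + n_blocks * block + d_model         # + final RMSNorm weights
print(f"embedding: {embed:,}")
print(f"per block: {block:,}")
print(f"total:     {total:,}")
```

Under these assumptions the model has roughly 32.6M trainable parameters, with more than half of them in the tied embedding matrix.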

# Helper Functions

1. Functions for converting between text and token ids, and for creating data loaders.
2. Functions for calculating batch loss and loader loss.
3. Greedy text generation under bf16 autocast.
4. Parallel tokenization of the dataset.

In [None]:
###################################### TURNS TEXTS TO TOKENS
def text_to_token_ids(text, tokenizer):
  encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
  encoded_tensor = torch.tensor(encoded).unsqueeze(0)
  return encoded_tensor
############################################################################

###################################### TURN TOKENS TO TEXT
def token_ids_to_text(token_ids, tokenizer):
  decoded = tokenizer.decode(token_ids.squeeze(0).tolist())
  return decoded
############################################################################

###################################### CREATES DATA LOADERS
def create_fast_dataloader(tokens, batch_size=8, max_length=256, stride=128, shuffle=True):
    """Fast dataloader with proper A100 settings"""
    dataset = FastDataset(tokens, max_length, stride)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        num_workers=4,  # Use CPU cores
        pin_memory=True,  # Faster GPU transfer
        prefetch_factor=2,
        persistent_workers=True
    )

############################################################################

In [None]:
###################################### LOSS FOR ONE BATCH
# Calculates loss for a batch
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device, non_blocking=True), target_batch.to(device, non_blocking=True)

    # For other devices
    # logits = model(input_batch)
    # loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())

    # For A100s: run the forward pass and the loss under bf16 autocast
    with autocast("cuda", torch.bfloat16):
        logits = model(input_batch)
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1),
            target_batch.flatten()
        )

    return loss
############################################################################

###################################### LOSS FOR ENTIRE LOADER
# Calculates loss for the ENTIRE data_loader by calling calc_loss_batch on each batch
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
############################################################################

###################################### USE THIS TO EVALUATE MODEL DIRECTLY
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
  # Returns the losses for training and validation, then restores training mode
  model.eval()
  with torch.no_grad():
    train_loss = calc_loss_loader(train_loader, model, device, eval_iter)
    val_loss = calc_loss_loader(val_loader, model, device, eval_iter)
  model.train()
  return train_loss, val_loss
############################################################################

####################################### GENERATING NEW TOKENS
def generate_text_simple(model, idx, max_new_tokens, context_size):

  # idx is a (batch, n_tokens) tensor of token ids in the current context
  for _ in range(max_new_tokens):
    # If the LLM supports only 5 tokens of context and idx holds 10, then only the last 5 tokens are used as context.
    idx_cond = idx[:, -context_size:]

    # Getting the predictions
    with torch.no_grad():
      with autocast("cuda", torch.bfloat16):
        logits = model(idx_cond)

    # Keep only the last time step; the batch and vocabulary dimensions are left untouched
    logits = logits[:, -1, :] # (batch, vocab_size)

    # Getting probabilities from the logits. We can say something like 50% chance of this token, 2% chance of that one...
    probs = torch.softmax(logits, dim=-1) # (batch, vocab_size)

    # Greedy decoding: take the index of the highest probability
    idx_next = torch.argmax(probs, dim=-1, keepdim=True) # (batch, 1)

    # Append the predicted token id to the running sequence
    idx = torch.cat((idx, idx_next), dim=1) # (batch, num_tokens+1)

  return idx
############################################################################

###################################### GENERATING AND PRINTING SAMPLES
def generate_and_print_sample(model, tokenizer, device, start_context, context_size):
  # Prints what the model currently generates from start_context (50 new tokens)
  model.eval()
  encoded = text_to_token_ids(start_context, tokenizer).to(device)

  with torch.no_grad():
    token_ids = generate_text_simple(
        model, encoded, 50, context_size
    )

  decoded = token_ids_to_text(token_ids, tokenizer)
  print(decoded.replace("\n", " "))
  model.train()
########################################################### Tokenizing entire batch
def tokenize_batch(text_batch):
    """Tokenize a batch of texts"""
    return tok.encode("\n\n".join(text_batch), allowed_special={"<|endoftext|>"})
############################################################################

####################################################### Preparing the dataset efficiently
def fast_prepare_data(data, train_split=0.95, max_workers=None):
    """Pre-tokenize everything in parallel, avoid giant string joins"""
    if max_workers is None:
        max_workers = min(mp.cpu_count(), 8)

    split_idx = int(train_split * len(data))
    train_texts = data[:split_idx]['text']
    val_texts = data[split_idx:]['text']

    # Split into chunks for parallel processing
    chunk_size = 1000  # Process 1000 stories at a time

    def process_split(texts):
        chunks = [texts[i:i+chunk_size] for i in range(0, len(texts), chunk_size)]

        all_tokens = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            token_chunks = list(executor.map(tokenize_batch, chunks))

        # Flatten
        for chunk in token_chunks:
            all_tokens.extend(chunk)
        return all_tokens

    print("🔥 Tokenizing training data...")
    train_tokens = process_split(train_texts)
    print("🔥 Tokenizing validation data...")
    val_tokens = process_split(val_texts)

    return train_tokens, val_tokens
############################################################################
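
A toy sketch of the chunk, map and flatten pattern used in `fast_prepare_data`. Two things are swapped in so that it runs anywhere: a whitespace split stands in for the tiktoken encoder, and `ThreadPoolExecutor` stands in for `ProcessPoolExecutor`. The names `toy_tokenize` and `chunked` are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor

def toy_tokenize(text_batch):
    """Hypothetical stand-in for tokenize_batch: whitespace split instead of tiktoken."""
    return "\n\n".join(text_batch).split()

def chunked(items, chunk_size):
    """Split a list into consecutive chunks of at most chunk_size items."""
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

stories = [f"story {i} ends" for i in range(7)]   # made-up stories
chunks = chunked(stories, chunk_size=3)

# executor.map returns results in input order, so the flattened stream keeps story order
with ThreadPoolExecutor(max_workers=2) as executor:
    token_chunks = list(executor.map(toy_tokenize, chunks))

all_tokens = [t for chunk in token_chunks for t in chunk]
print(len(chunks), len(all_tokens))
```

Because `map` preserves input order, the workers can finish in any order without scrambling the token stream.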

# Training loop

1. Function for training loop.
2. Preparing for training loop.
3. **Training**.
4. Inference example.

In [None]:
from typing import Dict, List, Tuple, Any

def train_model_simple(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    tokenizer,
    start_context: str,
    cfg: Dict[str, Any],
) -> Tuple[List[float], List[float], List[int]]:
    """
    Numeric hyperparameters expected in cfg:
      - "num_epochs": int
      - "eval_freq": int
      - "eval_iter": int
      - "patience": int
    Optional numeric params:
      - "sample_tokens": int (default 1024)
      - "progress_chunks": int (how many intra-epoch progress prints; default 5)
    """

    # Pull required numeric params from cfg (hard fail if missing)
    num_epochs   = cfg["num_epochs"]
    eval_freq    = cfg["eval_freq"]
    eval_iter    = cfg["eval_iter"]
    PATIENCE     = cfg["patience"]
    # Optional params fall back to the defaults documented above
    sample_tokens = cfg.get("sample_tokens", 1024)
    progress_chunks = max(1, int(cfg.get("progress_chunks", 5)))

    # Training start notification
    print("🚀 TRAINING STARTED!")
    print(f"📊 Configuration: {num_epochs} epochs, {len(train_loader)} batches per epoch")
    print(f"🎯 Evaluation every {eval_freq} steps, patience: {PATIENCE}")
    print("-" * 60)

    # Tracking
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen = 0
    global_step = -1  # Total batches processed across all epochs (never resets)
    best_val_loss = float('inf')
    patience_counter = 0

    # Progress logging frequency (only if progress_chunks > 0)
    # How often to print progress during each epoch (0 = no progress prints)
    progress_every = max(1, len(train_loader) // progress_chunks) if progress_chunks > 0 else 0

    for epoch in range(num_epochs):
        model.train()
        epoch_steps = 0  # Batches processed in current epoch (resets each epoch)
        epoch_loss = 0.0

        for input_batch, target_batch in train_loader:
            if patience_counter >= PATIENCE:
                break

            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()

            tokens_seen += input_batch.numel()
            global_step += 1  # Increment total step counter (never resets)
            epoch_steps += 1  # Increment current epoch step counter
            epoch_loss += loss.item()

            # Optional minimal progress logging
            if progress_every > 0 and epoch_steps % progress_every == 0:
                avg_loss = epoch_loss / epoch_steps
                progress = epoch_steps / len(train_loader) * 100
                print(f"Epoch {epoch+1}/{num_epochs} - Step {epoch_steps}/{len(train_loader)} ({progress:.1f}%) - Loss: {avg_loss:.3f}")

            # Periodic evaluation
            if epoch_steps % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    print(f"Step {global_step}: Val loss improved to {val_loss:.3f}")
                else:
                    patience_counter += 1

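        # End of epoch - minimal output
        if epoch_steps > 0:
            avg_epoch_loss = epoch_loss / epoch_steps
            print(f"Epoch {epoch+1}/{num_epochs} complete - Avg loss: {avg_epoch_loss:.3f} - Best val: {best_val_loss:.3f}")

        # Early stopping only broke out of the batch loop above; stop the epoch loop too
        if patience_counter >= PATIENCE:
            break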

    print(f"🎉 Training complete! Best validation loss: {best_val_loss:.3f}")
    return train_losses, val_losses, track_tokens_seen
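
The patience logic inside `train_model_simple` can be isolated into a few lines. The sketch below replays a made-up list of validation losses and reports the evaluation at which the counter would reach `patience`; the function name `early_stop_step` and the loss values are assumptions for illustration.

```python
def early_stop_step(val_losses, patience):
    """Return the index of the evaluation at which training would stop, or None."""
    best = float("inf")
    counter = 0
    for i, loss in enumerate(val_losses):
        if loss < best:
            best, counter = loss, 0     # improvement resets the counter
        else:
            counter += 1                # no improvement at this evaluation
        if counter >= patience:
            return i
    return None

history = [2.80, 2.18, 1.98, 1.99, 2.01, 1.90]   # made-up validation losses
stop_at = early_stop_step(history, patience=2)
print(stop_at)
```

With `patience=2` the run stops after two consecutive evaluations without a new best loss, so a later improvement (the final 1.90 here) is never reached.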

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GatorGPT(vocab_size=tok.n_vocab)
model = model.to(device) # Move to the GPU before compiling
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")

optim = torch.optim.AdamW(model.parameters(), lr=3e-4, eps=1e-08, weight_decay=0.01)

cfg = {
    "num_epochs": 1,
    "eval_freq": 5000,
    "eval_iter": 1000,
    "patience": 2,
    "sample_tokens": 1024,
    "progress_chunks": 10,
}

CONTEXT = "We ran across a field that was"
train_tokens, val_tokens = fast_prepare_data(data)

train_loader = create_fast_dataloader(
    train_tokens,
    batch_size=16,
    max_length=512,
    stride=256,
    shuffle=True
)

val_loader = create_fast_dataloader(
    val_tokens,
    batch_size=16,
    max_length=512,
    stride=256,
    shuffle=False
)

🔥 Tokenizing training data...
🔥 Tokenizing validation data...


In [None]:
print("Total Training tokens:", len(train_tokens)) # 214_198_685
print("Total Validation tokens:", len(val_tokens)) # 11_310_150

Total Training tokens: 214198685
Total Validation tokens: 11310150
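
The token counts above determine how many batches each loader yields. The sketch below repeats the `FastDataset` window arithmetic with the loader settings used in this notebook (`max_length=512`, `stride=256`, `batch_size=16`, `drop_last=True`); the helper name `batches_per_epoch` is an assumption.

```python
def batches_per_epoch(n_tokens, max_length, stride, batch_size):
    """Windows produced by FastDataset, then full batches kept by drop_last=True."""
    n_windows = len(range(0, n_tokens - max_length, stride))
    return n_windows, n_windows // batch_size

train_windows, train_batches = batches_per_epoch(214_198_685, 512, 256, 16)
val_windows, val_batches = batches_per_epoch(11_310_150, 512, 256, 16)
print(train_windows, train_batches)
print(val_windows, val_batches)
```

The training figure should match the "batches per epoch" value that the training log reports, and the validation loader has more batches than `eval_iter`, so each evaluation only samples part of it.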


In [None]:
train_losses, val_losses, tokens_seen = train_model_simple(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optim,
    device=device,
    tokenizer=tok,
    start_context=CONTEXT,
    cfg=cfg,
)

🚀 TRAINING STARTED!
📊 Configuration: 1 epochs, 52294 batches per epoch
🎯 Evaluation every 5000 steps, patience: 2
------------------------------------------------------------
Step 4999: Val loss improved to 2.797
Epoch 1/1 - Step 5229/52294 (10.0%) - Loss: 4.890
Step 9999: Val loss improved to 2.175
Epoch 1/1 - Step 10458/52294 (20.0%) - Loss: 3.632
Step 14999: Val loss improved to 1.979
Epoch 1/1 - Step 15687/52294 (30.0%) - Loss: 3.102
Step 19999: Val loss improved to 1.873
Epoch 1/1 - Step 20916/52294 (40.0%) - Loss: 2.802
Step 24999: Val loss improved to 1.810
Epoch 1/1 - Step 26145/52294 (50.0%) - Loss: 2.604
Step 29999: Val loss improved to 1.760
Epoch 1/1 - Step 31374/52294 (60.0%) - Loss: 2.463
Step 34999: Val loss improved to 1.718
Epoch 1/1 - Step 36603/52294 (70.0%) - Loss: 2.356
Step 39999: Val loss improved to 1.692
Epoch 1/1 - Step 41832/52294 (80.0%) - Loss: 2.271
Step 44999: Val loss improved to 1.665
Epoch 1/1 - Step 47061/52294 (90.0%) - Loss: 2.202
Step 49999: Val lo

KeyboardInterrupt: 
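
The logged values are mean cross-entropy losses in nats per token, so exponentiating them gives perplexity, which can be easier to read. The sketch below converts the validation losses copied from the log above; nothing is re-run.

```python
import math

# Validation losses copied from the training log above (mean cross-entropy, nats per token)
logged_val_losses = [2.797, 2.175, 1.979, 1.873, 1.810, 1.760, 1.718, 1.692, 1.665]

for step, loss in zip(range(4999, 45000, 5000), logged_val_losses):
    print(f"step {step}: loss {loss:.3f} -> perplexity {math.exp(loss):.2f}")

final_ppl = math.exp(logged_val_losses[-1])
```

A perplexity of about 5 means the model is, on average, about as uncertain as if it were choosing uniformly among five tokens at each position.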

# Debugging

1. Inference example.
2. Inspecting the dtype, shape and device of every layer.

In [None]:
token_ids_to_text(generate_text_simple(model, text_to_token_ids("Little girl was ", tok).to(device), 50, 1024), tok).replace("\n", " ").replace("\r\\", "").replace("\r", "")

"Little girl was  in the park, playing with her friends. She was having so much fun that she didn't notice the time.  Suddenly, she heard a loud noise. She looked around and saw a big, scary dog. She was scared and started to"
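
The sample above comes from greedy decoding: at every step the most probable next token is appended. The sketch below shows the same loop on a hypothetical five-token "model" written in plain Python, so the cropping to `context_size` and the argmax step can be followed by hand. `toy_model` and `generate_greedy` are illustrative names, not notebook functions.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_model(context):
    """Hypothetical 'model': strongly prefers (last_token + 1) mod 5 as the next token."""
    nxt = (context[-1] + 1) % 5
    return [4.0 if t == nxt else 0.0 for t in range(5)]

def generate_greedy(context, max_new_tokens, context_size):
    ids = list(context)
    for _ in range(max_new_tokens):
        logits = toy_model(ids[-context_size:])    # crop to the supported context
        probs = softmax(logits)
        next_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
        ids.append(next_id)
    return ids

out = generate_greedy([0, 1], max_new_tokens=4, context_size=3)
print(out)
```

Since softmax preserves the ordering of the logits, taking the argmax of the logits directly would pick the same tokens; greedy decoding is also deterministic, so the same prompt always gives the same continuation.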

In [None]:
def inspect_everything(model, batch, device="cuda",
                       expect_dtype=torch.bfloat16,
                       prints_per_module: int = 1):
    """
    One-call tracer for dtype/shape/device across the whole model.
    - Works with batch as (inputs, targets), list/tuple, dict, or single tensor.
    - Logs first seen [IN]/[OUT] per module (limited by `prints_per_module`).
    - Enforces `expect_dtype` **only for floating-point tensors** (so int64 token IDs are fine).
    - Runs under torch.amp.autocast with bf16 by default.
    - Enables all SDP backends (incl. math) as a safety net.
    """
    import torch
    from contextlib import contextmanager

    def _first_tensor(x):
        if isinstance(x, torch.Tensor): return x
        if isinstance(x, (list, tuple)):
            for t in x:
                if isinstance(t, torch.Tensor): return t
        if isinstance(x, dict):
            for v in x.values():
                if isinstance(v, torch.Tensor): return v
        return None

    def _to_device(x, dev):
        if isinstance(x, torch.Tensor): return x.to(dev)
        if isinstance(x, (list, tuple)):
            seq = [t.to(dev) if isinstance(t, torch.Tensor) else t for t in x]
            return type(x)(seq) if isinstance(x, tuple) else seq
        if isinstance(x, dict):
            return {k: (v.to(dev) if isinstance(v, torch.Tensor) else v) for k, v in x.items()}
        return x

    def _extract_inputs(batch_like):
        if isinstance(batch_like, (list, tuple)) and len(batch_like) >= 1:
            return batch_like[0]
        return batch_like

    @contextmanager
    def _sdp_safety():
        prev = (
            torch.backends.cuda.flash_sdp_enabled(),
            torch.backends.cuda.mem_efficient_sdp_enabled(),
            torch.backends.cuda.math_sdp_enabled(),
        )
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(True)
        try:
            yield
        finally:
            f, m, a = prev
            torch.backends.cuda.enable_flash_sdp(f)
            torch.backends.cuda.enable_mem_efficient_sdp(m)
            torch.backends.cuda.enable_math_sdp(a)

    # Parameter/buffer scan (floating only)
    bad = []
    for n, p in model.named_parameters():
        if p.dtype.is_floating_point and p.dtype != expect_dtype:
            bad.append(("param", n, p.dtype))
    for n, b in model.named_buffers():
        if b.dtype.is_floating_point and b.dtype != expect_dtype:
            bad.append(("buffer", n, b.dtype))
    if bad:
        print("⚠️ Nonconforming floating params/buffers:")
        for kind, n, dt in bad:
            print(f"   - {kind:6s} {n}: {dt}")

    handles, counts = [], {}

    def _hook(name):
        def fn(mod, inp, out):
            c = counts.get(name, 0)
            if c >= prints_per_module: return
            counts[name] = c + 1

            ti = _first_tensor(inp if isinstance(inp, tuple) else (inp,))
            to = _first_tensor(out)

            if ti is not None:
                print(f"[IN ] {name:<48} dtype={ti.dtype} shape={tuple(ti.shape)} device={ti.device}")
                if ti.dtype.is_floating_point and ti.dtype != expect_dtype:
                    raise RuntimeError(f"[IN ] {name} expected {expect_dtype} for floating tensors, got {ti.dtype}")
            if to is not None:
                print(f"[OUT] {name:<48} dtype={to.dtype} shape={tuple(to.shape)} device={to.device}")
                if to.dtype.is_floating_point and to.dtype != expect_dtype:
                    raise RuntimeError(f"[OUT] {name} expected {expect_dtype} for floating tensors, got {to.dtype}")
        return fn

    for n, m in model.named_modules():
        if n == "":  # skip root
            continue
        handles.append(m.register_forward_hook(_hook(n)))

    try:
        model = model.to(device)
        batch = _to_device(batch, device)
        inputs = _extract_inputs(batch)

        from torch.amp import autocast
        with _sdp_safety(), autocast("cuda", dtype=expect_dtype):
            # Print inside the context so the flags shown are the ones in effect
            # for this forward pass, not whatever was set before it.
            print("SDP backends -> flash:",
                  torch.backends.cuda.flash_sdp_enabled(),
                  "mem_efficient:",
                  torch.backends.cuda.mem_efficient_sdp_enabled(),
                  "math:",
                  torch.backends.cuda.math_sdp_enabled())
            if isinstance(inputs, dict):
                _ = model(**inputs)
            elif isinstance(inputs, (list, tuple)):
                _ = model(*inputs)
            else:
                _ = model(inputs)
    finally:
        for h in handles:
            h.remove()

In [None]:
sample = next(iter(train_loader))  # yields (input_ids, target_ids)
model = GatorGPT(50257).to(device).to(torch.bfloat16)
inspect_everything(model, sample, device=device, expect_dtype=torch.bfloat16)

# Saving files to Hugging Face

In [None]:
import torch
from pathlib import Path
from safetensors.torch import save_model
import json

# 🔧 Uses the `model` object already trained in this Colab session.
# ⛔️ Do NOT re-instantiate or re-load it here, or you will save untrained weights.

# Folder to save
save_path = Path("./gatorgpt")
save_path.mkdir(exist_ok=True)

# ✅ Save model weights to .safetensors
# Use save_model instead of save_file when tensors share memory
save_model(model, str(save_path / "model.safetensors"))

# ✅ Save a config.json. Values are read from the model when the attribute exists;
# the fallbacks (512, 8, 6, 50257, 1024) are placeholders, so check them against your model.
# Attention settings are taken from the first block, assuming they are consistent across blocks.
blocks = getattr(model, "blocks", None)
attn = getattr(blocks[0], "attn", None) if blocks is not None and len(blocks) > 0 else None
config = {
    "architectures": ["GatorGPT"],
    "model_type": "gator-transformer",
    "hidden_size": getattr(attn, "d_model", 512),
    "num_attention_heads": getattr(attn, "n_heads", 8),
    "num_hidden_layers": len(blocks) if blocks is not None else 6,
    "vocab_size": getattr(getattr(model, "embed", None), "num_embeddings", 50257),
    "max_position_embeddings": getattr(attn, "max_len", 1024),
}
with open(save_path / "config.json", "w") as f:
    json.dump(config, f, indent=2)

# ✅ Save tokenizer info
tokenizer_info = {
    "library": "tiktoken",
    "encoding": "p50k_base"
}
with open(save_path / "tokenizer_manifest.json", "w") as f:
    json.dump(tokenizer_info, f, indent=2)

print("✅ Model + config + tokenizer saved to ./gatorgpt")

✅ Model + config + tokenizer saved to ./gatorgpt
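
Before pushing, it can be worth checking what actually landed in `model.safetensors`. A `.safetensors` file starts with an 8-byte little-endian header length followed by a JSON header that lists each tensor's dtype and shape, so the check below needs only the standard library and never loads the weights. This is a minimal sketch: the helper names `read_safetensors_header` and `summarize_header` are made up for illustration, and the path assumes the `./gatorgpt` folder written above.

```python
import json
import struct
from collections import Counter
from pathlib import Path


def read_safetensors_header(path):
    """Return the JSON header of a .safetensors file without loading any weights."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor.
    header.pop("__metadata__", None)
    return header


def summarize_header(header):
    """Return (tensor count per dtype, total number of stored elements)."""
    dtypes = Counter(info["dtype"] for info in header.values())
    total = 0
    for info in header.values():
        n = 1
        for dim in info["shape"]:
            n *= dim
        total += n
    return dict(dtypes), total


# Hypothetical usage against the folder written above; skipped if the file is absent.
weights_file = Path("./gatorgpt/model.safetensors")
if weights_file.exists():
    by_dtype, n_elements = summarize_header(read_safetensors_header(weights_file))
    print("tensors per dtype:", by_dtype)  # bfloat16 tensors are listed as "BF16"
    print("stored elements:", n_elements)
```

If the model was cast to bfloat16 before saving, every floating tensor should be reported as `BF16`. The element count can be lower than the parameter count from `torchinfo` when tensors share memory, because `save_model` stores shared tensors only once.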


In [None]:
%pip install huggingface_hub --quiet

from huggingface_hub import login, create_repo, upload_folder

# 🔑 Login (you'll be prompted once)
login()

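# Target repo. On the very first push, uncomment create_repo so the repo exists;
# exist_ok=True makes it safe to leave enabled afterwards.
repo_id = "kunjcr2/GatorGPT"
# create_repo(repo_id, repo_type="model", exist_ok=True)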

# Push the folder
upload_folder(
    repo_id=repo_id,
    folder_path="./gatorgpt",
    repo_type="model",
    commit_message="Initial push of GatorGPT with p50k_base tokenizer"
)

print(f"✅ Pushed to https://huggingface.co/{repo_id}")



✅ Pushed to https://huggingface.co/kunjcr2/GatorGPT
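
Whoever pulls the repo later needs `config.json` and `tokenizer_manifest.json` to rebuild the model and tokenizer. The sketch below reads both files from a local copy of the folder and fails early if a config key is missing. The helper name `load_export_metadata` and the `REQUIRED_CONFIG_KEYS` tuple are assumptions made for this example; the key names themselves are the ones written to `config.json` above.

```python
import json
from pathlib import Path

# Keys written to config.json in the save cell above.
REQUIRED_CONFIG_KEYS = (
    "hidden_size",
    "num_attention_heads",
    "num_hidden_layers",
    "vocab_size",
    "max_position_embeddings",
)


def load_export_metadata(folder):
    """Read config.json and tokenizer_manifest.json from an export folder."""
    folder = Path(folder)
    config = json.loads((folder / "config.json").read_text())
    manifest = json.loads((folder / "tokenizer_manifest.json").read_text())

    # Fail early instead of building a model from a partial config.
    missing = [key for key in REQUIRED_CONFIG_KEYS if key not in config]
    if missing:
        raise KeyError(f"config.json is missing keys: {missing}")
    return config, manifest


# Hypothetical usage; skipped if the folder has not been written or downloaded yet.
export_dir = Path("./gatorgpt")
if (export_dir / "config.json").exists():
    config, manifest = load_export_metadata(export_dir)
    print(config["vocab_size"], manifest["encoding"])
```

From there the tokenizer can be rebuilt with `tiktoken.get_encoding(manifest["encoding"])`, and the weights can be loaded into a freshly constructed model with `safetensors.torch.load_model`. How the `config` values map onto the `GatorGPT` constructor depends on its signature, which is not shown in this part of the notebook.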
