 # Building Qwen3 from Scratch: A Complete



 Welcome,  we'll build a Qwen3-style language model from scratch!

 We'll implement Grouped-Query Attention (GQA), RMSNorm, SwiGLU activations, and add our own spin - new Muon optimizer that accelerates training by 30% to 50%.



 **What we'll cover:**

- Modern Transformer architecture with Qwen3-style features

- Grouped-Query Attention (GQA) for memory and compute efficiency

- Rotary Positional Embeddings (RoPE) for better performance and context window extrapolation

- QK-Norm with RMSNorm for improved numerical / training stability

- Muon optimizer using Newton-Schulz orthogonalization for better weight updates, faster learning with less data

- Hybrid optimization using Muon for matrices and AdamW for other parameters

- SwiGLU activation and deep residual learning in the feedforward layers

- Efficient dataset tokenization and caching with HuggingFace Datasets and Transformers

- Validation metrics including loss, accuracy, and perplexity

- Gradient accumulation + AMP (Automatic Mixed Precision) training for larger batch sizes

- Cosine learning rate scheduling with warmup

 ## 1. Setup and Imports



 First, let's import all the necessary libraries. We'll use PyTorch for the deep learning framework,

 transformers for tokenization, and various utilities for data handling and training.

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [39]:
!pip install huggingface_hub



In [2]:
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
import torch
import torch.nn as nn  # Neural network modules like Linear, Embedding, etc.
import torch.nn.functional as F  # Functional interface for operations like cross_entropy, silu, etc.
from torch.utils.data import Dataset, DataLoader  # Base class and utilities for loading datasets
from torch.cuda.amp import autocast, GradScaler  #  Automatic Mixed Precision (AMP) tools for faster/lower-memory training

import math  # Standard math operations (e.g. sqrt, exp, cos)
import random  # Python's random number utilities (used for seeding)
import numpy as np  # Numerical computing library, used for random seeding and general array ops

from datasets import load_dataset  #  Hugging Face Datasets library for streaming large datasets
from tqdm import tqdm  #  Progress bar visualization library, great for loops

import time  # ⌛ Timing utilities, measuring time
from transformers import AutoTokenizer  #  Load pretrained tokenizers from HuggingFace with one line

from dataclasses import dataclass  #  Define simple classes for configs with less boilerplate
from typing import List, Optional  #  Type hints for better readability and tooling

import warnings  #  Suppress or handle warnings
import os  #  File system operations (creating folders, path checking, etc.)
import pickle  #  Python object serialization (used to save/load preprocessed datasets)

warnings.filterwarnings('ignore')  # Silences warnings for cleaner outputs during training

print("Setup and imports Complete")


Setup and imports Complete


 ## 2. Utility Functions



 Let's start with some utility functions for reproducibility and configuration management.

 The `set_seed` function ensures our experiments are reproducible.

In [4]:
def set_seed(seed: int = 42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f" Set all seeds to {seed}")


 ## 3. Model Configuration



 Here we define our model configuration using a dataclass. This makes it easy to experiment

 with different model sizes and hyperparameters. Our model will be a smaller version of Qwen3

 with 384 dimensions, 6 layers, and 8 attention heads.

In [5]:
@dataclass
class ModelConfig:
    # Model architecture
    d_model: int = 384
    n_heads: int = 8
    n_layers: int = 6
    d_ff: int = 1536
    batch_size: int = 24
    max_steps: int = 2000

    '''To get number of parameters of our model simply = 12 * n_layers * (d_model) * (d_model)'''

    # Qwen3-like parameters
    n_kv_heads: int = 4  # For Grouped-Query Attention
    sliding_window: int = 4096  # Set a large default, effectively disabling it unless specified
    attention_bias: bool = False  # Qwen3 often sets this to False
    rms_norm_eps: float = 1e-6  # Epsilon for RMSNorm

    # Training parameters
    gradient_accumulation_steps: int = 4
    muon_lr: float = 0.01

    # Data parameters
    max_seq_len: int = 512
    num_documents: int = 2000
    max_tokens: int = 500000

    # Evaluation
    eval_every: int = 500
    eval_steps: int = 100

    # Regularization
    weight_decay: float = 0.1
    dropout: float = 0.1
    grad_clip: float = 1.0

    # Technical
    use_amp: bool = True
    vocab_size: Optional[int] = None

    def __post_init__(self):
        self.d_k = self.d_model // self.n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_kv_groups = self.n_heads // self.n_kv_heads


 ## 4. Grouped-Query Attention Helper



 This function implements the key component of GQA - repeating key and value heads.

 In GQA, we have fewer key-value heads than query heads, which reduces memory usage

 while maintaining performance.

In [6]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Extract dimensions from input tensor
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape

    # Early return if no repetition is needed
    if n_rep == 1:
        return hidden_states

    # Add a new dimension at index 2 (after num_key_value_heads) and expand
    # Shape transformation:
    # (batch, num_key_value_heads, slen, head_dim)
    # -> (batch, num_key_value_heads, 1, slen, head_dim) [via None indexing]
    # -> (batch, num_key_value_heads, n_rep, slen, head_dim) [via expand]
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)

    # Flatten the num_key_value_heads and n_rep dimensions together
    # Final shape: (batch, num_key_value_heads * n_rep, slen, head_dim)
    # This effectively repeats each key/value head n_rep times
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

 ## 5. Muon Optimizer - The Secret Sauce



 The Muon optimizer is a novel approach that uses Newton-Schulz iteration for orthogonalization.

 This helps with training stability and convergence. The `zeropower_via_newtonschulz5` function

 implements the core mathematical operation.

In [7]:
@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()

    if G.size(-2) > G.size(-1):
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT

    return X

class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz"""
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                # Initialize momentum buffer if first time
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)

                buf = state["momentum_buffer"]
                # Update momentum buffer: buf = momentum * buf + (1-momentum) * grad
                buf.lerp_(g, 1 - group["momentum"])
                # Apply Nesterov momentum if enabled, otherwise use standard momentum
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                # Apply zero-power normalization via Newton-Schulz iterations (make it close to orthonormal)
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                # Update parameters with adaptive scaling based on parameter shape
                p.add_(g.view_as(p), alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
                # Updates parameters with an adaptive learning rate that scales based on the parameter tensor's aspect ratio (height/width). For matrices where height > width, it increases the effective learning rate by √(height/width)

 ## 6. Data Loading and Caching



 Loading and processing data can be time-consuming. We implement caching to avoid

 reprocessing the same data multiple times. This function loads the SmolLM corpus

 and tokenizes it efficiently.

In [8]:
def load_and_cache_data(config: ModelConfig, cache_dir: str = "data_cache"):
    """Load and cache tokenized data to avoid reprocessing"""
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = f"{cache_dir}/tokenized_data_{config.num_documents}_{config.max_tokens}.pkl"

    # Check if cached data exists
    if os.path.exists(cache_file):
        print(f"📦 Loading cached data from {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)

        texts = cached_data['texts']
        tokenizer = cached_data['tokenizer']
        tokens = cached_data['tokens']
        config.vocab_size = tokenizer.vocab_size

        print(f"✅ Loaded {len(texts)} documents, {len(tokens):,} tokens from cache")
        return texts, tokenizer, tokens

    print(f"🔄 Processing new data (will cache for future use)")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", split="train", streaming=True)

    texts = []
    for i, item in enumerate(dataset):
        if i >= config.num_documents:
            break
        texts.append(item["text"][:3000])

    print(f"Loaded {len(texts)} documents")

    # Tokenize
    print("Tokenizing texts...")
    all_tokens = []
    for text in tqdm(texts, desc="Tokenizing"):
        tokens = tokenizer.encode(text, add_special_tokens=False)
        all_tokens.extend(tokens)

    tokens = all_tokens[:config.max_tokens]
    print(f"Using {len(tokens):,} tokens")
    config.vocab_size = tokenizer.vocab_size

    # Cache the processed data
    cached_data = {'texts': texts, 'tokenizer': tokenizer, 'tokens': tokens}
    with open(cache_file, 'wb') as f:
        pickle.dump(cached_data, f)

    print(f"💾 Cached data to {cache_file}")
    return texts, tokenizer, tokens


 ## 7. Dataset Class



 We create a custom dataset class for language modeling. This creates sliding windows

 of tokens for training, where each sample is a sequence and its corresponding target

 (shifted by one position).

In [9]:
class TextTokenDataset(Dataset):
    def __init__(self, tokens: List[int], seq_len: int = 512):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y


 ## 8. Rotary Position Embeddings (RoPE)



 RoPE is a modern alternative to positional encodings that allows the model to

 generalize to longer sequences. It applies rotation matrices to the embeddings

 based on their position.

In [10]:
class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        angular_freq = (1 / 10000) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.register_buffer('cos', theta.cos(), persistent=False)
        self.register_buffer('sin', theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)


 ## 9. Grouped-Query Attention Implementation



 This is the heart of our model - the attention mechanism with GQA. Notice how we:

 1. Project Q, K, V separately

 2. Apply QK normalization (a Qwen3 innovation)

 3. Use RoPE for positional information

 4. Implement GQA by repeating K and V heads

 5. Use scaled dot-product attention

In [11]:
class Qwen3Attention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.d_k = config.d_k

        # Separate linear layers for Q, K, V
        self.q_proj = nn.Linear(self.d_model, self.n_heads * self.d_k, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.d_model, self.n_kv_heads * self.d_k, bias=config.attention_bias)
        self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)

        # QK-Normalization layers
        self.q_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.d_k, eps=config.rms_norm_eps)

        self.rotary = Rotary(self.d_k, config.max_seq_len)
        self.dropout = config.dropout

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)

        # 1. Project Q, K, V separately
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 2. Reshape into heads
        q = q.view(batch_size, seq_len, self.n_heads, self.d_k)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.d_k)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.d_k)

        # 3. Apply QK-Norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # 4. Apply RoPE
        # Transpose to (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k) for rotary
        q = self.rotary(q.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        k = self.rotary(k.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

        # Transpose for attention: (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k)
        Q = q.transpose(1, 2)
        K = k.transpose(1, 2)
        V = v.transpose(1, 2)

        # 5. Repeat K and V heads for GQA
        K = repeat_kv(K, self.n_kv_groups)
        V = repeat_kv(V, self.n_kv_groups)

        # 6. Scaled Dot-Product Attention
        attn_output = F.scaled_dot_product_attention(
            Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
        )

        # 7. Reshape and final projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)


 ## 10. SwiGLU Feed-Forward Network



 SwiGLU is a modern activation function that combines Swish and GLU. It's more

 effective than traditional ReLU and is used in many modern models including Qwen3.

In [12]:
class SwiGLUFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Implementation of the SwiGLU activation function
        # F.silu is the Swish activation function
        activated_x = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(activated_x))


Think of:

`output = gate(x) * value(x)`

like:

`light = brightness_control × light_source`


 ## 11. Transformer Block



 Each transformer block combines attention and feed-forward layers with residual

 connections and normalization. We use RMSNorm instead of LayerNorm for better

 training stability.

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self, config: ModelConfig):  # Pass the entire config object
        super().__init__()
        self.attention = Qwen3Attention(config)
        self.feed_forward = SwiGLUFeedForward(config.d_model, config.d_ff, config.dropout)
        self.norm1 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.norm2 = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        attn_out = self.attention(self.norm1(x))
        x = x + self.dropout(attn_out)
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)
        return x


 ## 12. Complete Language Model



 Now we assemble everything into our complete language model. This includes:

 - Token embeddings

 - Positional dropout

 - Stack of transformer blocks

 - Final normalization and output projection

 - Weight tying between input embeddings and output layer

In [42]:
class MinimalLLM(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_dropout = nn.Dropout(config.dropout)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        self.norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.output_dropout = nn.Dropout(config.dropout)

        # Tie weights
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        x = self.token_embedding(x) * math.sqrt(self.config.d_model)
        x = self.position_dropout(x)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.norm(x)
        x = self.output_dropout(x)
        logits = self.lm_head(x)
        return logits


 ## 13. Evaluation Function



 During training, we need to evaluate our model's performance. This function

 computes loss, accuracy, and perplexity on the validation set.

In [15]:
def evaluate_model(model: nn.Module, val_loader: DataLoader, config: ModelConfig):
    """Evaluate model performance"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    total_correct = 0

    device = next(model.parameters()).device

    with torch.no_grad():  # Disable gradient computation for evaluation (saves memory and computation)
        for i, (x, y) in enumerate(val_loader):
            # Stop evaluation after specified number of steps to limit eval time
            if i >= config.eval_steps:
                break

            # Move input sequences (x) and target sequences (y) to GPU/device
            x, y = x.to(device), y.to(device)

            # Use automatic mixed precision if enabled (faster training with minimal accuracy loss)
            with autocast(enabled=config.use_amp):
                # Forward pass: get model predictions (logits) for input sequence
                logits = model(x)

                # Calculate cross-entropy loss between predictions and targets
                # Reshape to (batch_size * seq_len, vocab_size) and (batch_size * seq_len,)
                # for proper cross-entropy computation across all token positions
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))

            # Accumulate total loss weighted by number of tokens in this batch
            total_loss += loss.item() * y.numel()
            # Keep track of total number of tokens processed
            total_tokens += y.numel()

            # Get predicted token IDs by taking argmax over vocabulary dimension
            predictions = logits.argmax(dim=-1)
            # Count correct predictions for accuracy calculation
            total_correct += (predictions == y).sum().item()

    avg_loss = total_loss / total_tokens
    accuracy = total_correct / total_tokens
    perplexity = math.exp(min(avg_loss, 20))

    model.train()
    return {'val_loss': avg_loss, 'val_accuracy': accuracy, 'val_perplexity': perplexity}


 ## 14. Optimizer Setup



 We use a hybrid approach: Muon optimizer for 2D parameters (attention and feed-forward weights)

 and AdamW for other parameters. This gives us the benefits of both optimizers.

In [16]:
def setup_muon_optimizer(model: nn.Module, config: ModelConfig):
    """Setup Muon optimizer with hybrid approach"""
    muon_params = []
    adamw_params = []

    for name, param in model.named_parameters():
        if (param.ndim == 2 and
            'token_embedding' not in name and
            'norm' not in name and
            param.requires_grad):
            muon_params.append(param)
        else:
            adamw_params.append(param)

    print(f"  Muon parameters: {sum(p.numel() for p in muon_params):,}")
    print(f"  AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

    muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)
    adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay)

    return [muon_optimizer, adamw_optimizer]


 ## 15. Training Loop



 This is where the magic happens! Our training loop includes:

 - Gradient accumulation for larger effective batch sizes

 - Mixed precision training for speed

 - Learning rate scheduling with warmup and cosine decay

 - Regular evaluation and model checkpointing

 - Progress tracking with detailed metrics

In [17]:
def train_model(config: ModelConfig, train_loader: DataLoader, val_loader: DataLoader):
    """Train the model with Muon optimizer"""
    print(f"\n🚀 Training Small model with Muon optimizer")

    # Initialize model
    set_seed(42)
    model = MinimalLLM(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  📊 Total parameters: {total_params:,}")

    # Setup optimizers
    optimizers = setup_muon_optimizer(model, config)

    # Learning rate schedule
    schedulers = []
    for optimizer in optimizers:
        warmup_steps = config.max_steps // 20
        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            else:
                progress = (step - warmup_steps) / (config.max_steps - warmup_steps)
                return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        schedulers.append(scheduler)

    scaler = GradScaler() if config.use_amp else None

    # Training loop
    model.train()
    step = 0
    start_time = time.time()
    best_val_loss = float('inf')

    pbar = tqdm(total=config.max_steps, desc="Training")

    while step < config.max_steps:
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= config.max_steps:
                break

            x, y = x.to(device), y.to(device)

            # Forward pass with gradient accumulation
            if config.use_amp:
                with autocast():
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                    loss = loss / config.gradient_accumulation_steps
                scaler.scale(loss).backward()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
                loss = loss / config.gradient_accumulation_steps
                loss.backward()

            # Optimizer step after accumulation
            if (step + 1) % config.gradient_accumulation_steps == 0:
                if config.use_amp:
                    for optimizer in optimizers:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

                    for optimizer in optimizers:
                        scaler.step(optimizer)
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()
                    scaler.update()
                else:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()
                    for scheduler in schedulers:
                        scheduler.step()

            # Logging
            if step % 10 == 0:
                with torch.no_grad():
                    predictions = logits.argmax(dim=-1)
                    accuracy = (predictions == y).float().mean().item()
                    current_loss = loss.item() * config.gradient_accumulation_steps
                    perplexity = math.exp(min(current_loss, 20))

                pbar.set_postfix({
                    'loss': f'{current_loss:.4f}',
                    'acc': f'{accuracy:.3f}',
                    'ppl': f'{perplexity:.1f}',
                    'lr': f'{optimizers[0].param_groups[0]["lr"]:.2e}'
                })

            # Evaluation
            if step % config.eval_every == 0 and step > 0:
                eval_metrics = evaluate_model(model, val_loader, config)
                print(f"\nStep {step}: Val Loss: {eval_metrics['val_loss']:.4f}, "
                      f"Val Acc: {eval_metrics['val_accuracy']:.4f}, "
                      f"Val PPL: {eval_metrics['val_perplexity']:.2f}")

                if eval_metrics['val_loss'] < best_val_loss:
                    best_val_loss = eval_metrics['val_loss']
                    # Save best model
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'config': config,
                        'step': step,
                        'best_val_loss': best_val_loss,
                        'final_metrics': eval_metrics
                    }, 'best_model.pt')
                    print(f"💾 Saved best model with val_loss: {best_val_loss:.4f}")

            step += 1
            if step % 10 == 0:
                pbar.update(10)

    pbar.close()

    training_time = time.time() - start_time
    print(f"  ⏱️ Training completed in {training_time:.1f} seconds")

    # Final evaluation
    final_eval = evaluate_model(model, val_loader, config)
    print(f"  📊 Final - Loss: {final_eval['val_loss']:.4f}, "
          f"Acc: {final_eval['val_accuracy']:.4f}, PPL: {final_eval['val_perplexity']:.2f}")

    # Save final model
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'step': step,
        'final_metrics': final_eval
    }, 'final_model.pt')
    print(f"💾 Saved final model to final_model.pt")

    return model, final_eval


 ## 16. Main Training Script



 Finally, let's put everything together! This section:

 1. Checks system resources

 2. Sets up configuration

 3. Loads and prepares data

 4. Trains the model

 5. Reports final results

In [18]:
if __name__ == "__main__":
    # Check system
    print(f"🔍 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Set seed
    set_seed(42)

    # Create config for Small model
    config = ModelConfig()
    print(f"\n📋 Model Configuration:")
    print(f"   Architecture: {config.d_model}d, {config.n_layers}L, {config.n_heads}H, {config.d_ff}ff")
    print(f"   Training: {config.max_steps} steps, batch size {config.batch_size}")
    print(f"   Data: {config.max_tokens:,} tokens, seq_len {config.max_seq_len}")

    # Load data
    texts, tokenizer, tokens = load_and_cache_data(config)
    dataset = TextTokenDataset(tokens, config.max_seq_len)

    # Train/val split
    val_size = len(dataset) // 10
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

    print(f"📊 Dataset: {len(train_dataset)} train, {len(val_dataset)} val samples")

    # Train model
    start_time = time.time()
    model, final_metrics = train_model(config, train_loader, val_loader)
    total_time = time.time() - start_time

    print(f"\n🎉 TRAINING COMPLETED!")
    print(f"⏱️ Total time: {total_time/60:.1f} minutes")
    print(f"🏆 Final Results:")
    print(f"   Validation Loss: {final_metrics['val_loss']:.4f}")
    print(f"   Validation Accuracy: {final_metrics['val_accuracy']:.4f}")
    print(f"   Validation Perplexity: {final_metrics['val_perplexity']:.2f}")

🔍 Device: CUDA
GPU: Tesla T4
Memory: 15.8 GB
 Set all seeds to 42

📋 Model Configuration:
   Architecture: 384d, 6L, 8H, 1536ff
   Training: 2000 steps, batch size 24
   Data: 500,000 tokens, seq_len 512
🔄 Processing new data (will cache for future use)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/831 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Loaded 2000 documents
Tokenizing texts...


Tokenizing: 100%|██████████| 2000/2000 [00:03<00:00, 664.29it/s]


Using 500,000 tokens
💾 Cached data to data_cache/tokenized_data_2000_500000.pkl
📊 Dataset: 449540 train, 49948 val samples

🚀 Training Small model with Muon optimizer
 Set all seeds to 42
  📊 Total parameters: 32,150,976
  Muon parameters: 13,271,040
  AdamW parameters: 18,879,936


Training:   0%|          | 0/2000 [00:01<?, ?it/s, loss=10.8037, acc=0.015, ppl=49202.1, lr=0.00e+00]W0928 04:49:17.365000 1648 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
Training:  25%|██▌       | 500/2000 [02:50<08:33,  2.92it/s, loss=5.2131, acc=0.223, ppl=183.7, lr=1.00e-02]


Step 500: Val Loss: 4.9945, Val Acc: 0.2327, Val PPL: 147.59
💾 Saved best model with val_loss: 4.9945


Training:  50%|█████     | 1000/2000 [05:54<05:49,  2.86it/s, loss=3.5962, acc=0.358, ppl=36.5, lr=9.86e-03]


Step 1000: Val Loss: 3.1809, Val Acc: 0.4079, Val PPL: 24.07
💾 Saved best model with val_loss: 3.1809


Training:  75%|███████▌  | 1500/2000 [08:59<02:54,  2.86it/s, loss=2.5042, acc=0.495, ppl=12.2, lr=9.54e-03]


Step 1500: Val Loss: 1.9007, Val Acc: 0.6069, Val PPL: 6.69
💾 Saved best model with val_loss: 1.9007


Training: 100%|██████████| 2000/2000 [12:04<00:00,  2.76it/s, loss=1.8120, acc=0.603, ppl=6.1, lr=9.06e-03]

  ⏱️ Training completed in 724.6 seconds





  📊 Final - Loss: 1.1133, Acc: 0.7514, PPL: 3.04
💾 Saved final model to final_model.pt

🎉 TRAINING COMPLETED!
⏱️ Total time: 12.3 minutes
🏆 Final Results:
   Validation Loss: 1.1133
   Validation Accuracy: 0.7514
   Validation Perplexity: 3.04


## 17. Model Loading and Inference

After training, we can load our saved model and use it for text generation.
This section shows how to load the trained model and perform inference.

In [19]:
def load_trained_model(model_path: str = "final_model.pt"):
    """Load a trained model from checkpoint"""
    print(f" Loading model from {model_path}")

    # Add ModelConfig to safe globals for PyTorch 2.6+
    from torch.serialization import add_safe_globals
    add_safe_globals([ModelConfig])

    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        config = checkpoint['config']
    except Exception as e:
        print(f"⚠️ Error loading with weights_only=True, trying with weights_only=False...")
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        config = checkpoint['config']

    # Create model with same config
    model = MinimalLLM(config)
    model.load_state_dict(checkpoint['model_state_dict'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    print(f"✅ Model loaded successfully")
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"   Device: {device}")

    return model, config

### 🔮 Text Generation Function

This function generates text using a trained language model. Given a prompt, it tokenizes the input and autoregressively samples tokens up to `max_length`. It supports:

- **Temperature scaling** for randomness control, e.g., `0.7` makes output more focused, `1.5` makes it more random.
- **Top-k sampling** to limit candidates to the top `k` most likely tokens, e.g., `top_k=50` narrow down to 50 highest-probability tokens.
- **Top-p (nucleus) sampling** to sample from the smallest set of tokens whose cumulative probability exceeds `p`, i.e. the fewest number of tokens whose combined probabilities add up to at least p (e.g., 90%).

Generation stops early if the EOS token is produced.


In [20]:
def generate_text(model: nn.Module, tokenizer, prompt: str, max_length: int = 100,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
    """Generate text using the trained model"""
    model.eval()
    device = next(model.parameters()).device

    # Tokenize prompt
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt').to(device)

    generated_ids = input_ids.clone()

    with torch.no_grad():
        for _ in range(max_length):
            # Get model predictions
            logits = model(generated_ids)
            next_token_logits = logits[0, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                next_token_logits[top_k_indices] = top_k_logits

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to generated sequence - FIX: ensure same dimensions
            next_token = next_token.unsqueeze(0)  # Add batch dimension
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop if we reach the end token
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode the generated text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

In [21]:
def interactive_inference(model_path: str = "final_model.pt"):
    """Interactive inference session"""
    print("🤖 Starting interactive inference session")
    print("Type 'quit' to exit")

    # Load model and tokenizer
    model, config = load_trained_model(model_path)

    # Load tokenizer (assuming we have the same one used during training)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    while True:
        try:
            prompt = input("\n Enter your prompt: ")
            if prompt.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            if not prompt.strip():
                continue

            print("🔄 Generating...")
            generated_text = generate_text(
                model, tokenizer, prompt,
                max_length=150,
                temperature=0.8,
                top_k=50,
                top_p=0.9
            )

            print(f"\n Generated text:")
            print(f"📝 {generated_text}")

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"❌ Error: {e}")

In [26]:
def demo_inference(model_path: str = "final_model.pt"):
    """Run a quick demo of the model's capabilities"""
    print("🎭 Running inference demo")

    # Load model and tokenizer
    model, config = load_trained_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Demo prompts
    demo_prompts = [
        "Mama anakuja nyumbani Mwanza",
        "Once upon a time in a distant galaxy",
        "Tanzania ilipata uhuru mwaka",
        "In the year 2050, technology will",
        "The best way to learn programming is"
    ]

    for i, prompt in enumerate(demo_prompts, 1):
        print(f"\n Demo {i}: '{prompt}'")
        print("-" * 50)

        generated_text = generate_text(
            model, tokenizer, prompt,
            max_length=100,
            temperature=0.7,
            top_k=40,
            top_p=0.85
        )

        print(f"📝 {generated_text}")
        print()

In [28]:
if __name__ == "__main__":
    # Check if we have a trained model
    import os

    if os.path.exists("final_model.pt"):
        print("🎉 Found trained model! Running demo...")
        demo_inference("final_model.pt")

        # Optionally run interactive session
        response = input("\n🤖 Would you like to try interactive inference? (y/n): ")
        if response.lower() in ['y', 'yes']:
            interactive_inference("final_model.pt")
    else:
        print("⚠️ No trained model found. Please run the training cells first.")
        print("💡 Look for 'final_model.pt' or 'best_model.pt' in your directory.")

🎉 Found trained model! Running demo...
🎭 Running inference demo
 Loading model from final_model.pt
✅ Model loaded successfully
   Parameters: 32,150,976
   Device: cuda

 Demo 1: 'Mama anakuja nyumbani Mwanza'
--------------------------------------------------
📝 Mama anakuja nyumbani Mwanza, 17, a few iconic Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese Chinese, and social structures caused by various social structures. This is the US, and other African cultures worldwide. Imagine being able to honor while still visible today, and politics, and other cultures moved to honor while exploring these concepts being able to change on how they are living on whether they occur on a few people of the US


 Demo 2: 'Once upon a time in a distant galaxy'
--------------------------------------------------
📝 Once upon a time in a d

In [50]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [51]:
from huggingface_hub import HfApi, create_repo

# Create repo
create_repo("marcoharuni95/qwen3-small-muon", exist_ok=True)

# Push model
api = HfApi()
api.upload_file(
    path_or_fileobj="final_model.pt",
    path_in_repo="final_model.pt",
    repo_id="marcoharuni95/qwen3-small-muon",
)
print("Model pushed successfully")

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  final_model.pt              :   0%|          |  561kB /  129MB            

Model pushed successfully
