In [1]:
from importlib.metadata import version

pkgs = [
    "huggingface_hub",
    "tokenizers",
    "torch",
]

for p in pkgs:
    print(f"{p} version: {version(p)}")

huggingface_hub version: 0.34.4
tokenizers version: 0.22.0
torch version: 2.8.0+cu126


In [2]:
USE_REASONING_MODEL = True
USE_INSTRUCT_MODEL = False

## 1. Architecture Code

In [3]:
import torch
import torch.nn as nn

In [4]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

In [5]:
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)

In [6]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin, offset=0):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]  # Second half

    # Adjust sin and cos shapes
    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim // 2)
    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)

In [7]:
class GroupedQueryAttention(nn.Module):
    def __init__(
        self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
    ):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        b, num_tokens, _ = x.shape

        # Apply projections
        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Reshape to heads / kv-groups
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm:
            queries = self.q_norm(queries)
        if self.k_norm:
            keys_new = self.k_norm(keys_new)

        # Apply RoPE
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)

        if cache is not None:
            prev_k, prev_v = cache
            keys = torch.cat([prev_k, keys_new], dim=2)
            values = torch.cat([prev_v, values_new], dim=2)
        else:
            start_pos = 0  # reset RoPE
            keys, values = keys_new, values_new
        next_cache = (keys, values)

        # Expand K and V to match number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Attention
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context), next_cache

In [8]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["n_kv_groups"],
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x, next_cache = self.att(x, mask, cos, sin, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x, next_cache

In [9]:
class Qwen3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg
        self.current_pos = 0  # Track current position in KV cache

    def forward(self, in_idx, cache=None):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        if cache is not None:
            pos_start = self.current_pos
            pos_end = pos_start + num_tokens
            self.current_pos = pos_end
            mask = torch.triu(
                torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
            )[pos_start:pos_end, :pos_end]
        else:
            pos_start = 0  # Not strictly necessary but helps torch.compile
            mask = torch.triu(
                torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
            )
        # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
        mask = mask[None, None, :, :]  # broadcast mask

        next_cache = []
        for i, block in enumerate(self.trf_blocks):
            blk_cache = cache.get(i) if cache else None
            x, new_blk_cache = block(x, mask, self.cos, self.sin,
                                     start_pos=pos_start,
                                     cache=blk_cache)
            if cache is not None:
                cache.update(i, new_blk_cache)
            next_cache.append(new_blk_cache)

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

    def reset_kv_cache(self):
        self.current_pos = 0

In [10]:
class KVCache:
    def __init__(self, n_layers):
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value

    def get_all(self):
        return self.cache

    def reset(self):
        for i in range(len(self.cache)):
            self.cache[i] = None

## 2. Initialize model

In [11]:
CHOOSE_MODEL = "0.6B"

if CHOOSE_MODEL == "0.6B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,           # Vocabulary size
        "context_length": 40_960,        # Context length that was used to train the model
        "emb_dim": 1024,                 # Embedding dimension
        "n_heads": 16,                   # Number of attention heads
        "n_layers": 28,                  # Number of layers
        "hidden_dim": 3072,              # Size of the intermediate dimension in FeedForward
        "head_dim": 128,                 # Size of the heads in GQA
        "qk_norm": True,                 # Whether to normalize queries and keys in GQA
        "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
        "rope_base": 1_000_000.0,        # The base in RoPE's "theta"
        "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
    }

elif CHOOSE_MODEL == "1.7B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 2048,                 # 2x larger than above
        "n_heads": 16,
        "n_layers": 28,
        "hidden_dim": 6144,              # 2x larger than above
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }

elif CHOOSE_MODEL == "4B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 2560,                 # 25% larger than above
        "n_heads": 32,                   # 2x larger than above
        "n_layers": 36,                  # 29% larger than above
        "hidden_dim": 9728,              # ~3x larger than above
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }

elif CHOOSE_MODEL == "8B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 4096,                 # 60% larger than above
        "n_heads": 32,
        "n_layers": 36,                  # 26% larger than above
        "hidden_dim": 12288,
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }

elif CHOOSE_MODEL == "14B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 5120,                 # 25% larger than above
        "n_heads": 40,                   # 25% larger than above
        "n_layers": 40,                  # 11% larger than above
        "hidden_dim": 17408,             # 42% larger than above
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }

elif CHOOSE_MODEL == "32B":
    QWEN3_CONFIG = {
        "vocab_size": 151_936,
        "context_length": 40_960,
        "emb_dim": 5120,
        "n_heads": 64,                   # 60% larger than above
        "n_layers": 64,                  # 60% larger than above
        "hidden_dim": 25600,             # 47% larger than above
        "head_dim": 128,
        "qk_norm": True,
        "n_kv_groups": 8,
        "rope_base": 1_000_000.0,
        "dtype": torch.bfloat16,
    }

else:
    raise ValueError(f"{CHOOSE_MODEL} is not supported.")

In [12]:
torch.manual_seed(42)
model = Qwen3Model(QWEN3_CONFIG)

In [13]:
model

Qwen3Model(
  (tok_emb): Embedding(151936, 1024)
  (trf_blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (att): GroupedQueryAttention(
        (W_query): Linear(in_features=1024, out_features=2048, bias=False)
        (W_key): Linear(in_features=1024, out_features=1024, bias=False)
        (W_value): Linear(in_features=1024, out_features=1024, bias=False)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): RMSNorm()
        (k_norm): RMSNorm()
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=1024, out_features=3072, bias=False)
        (fc2): Linear(in_features=1024, out_features=3072, bias=False)
        (fc3): Linear(in_features=3072, out_features=1024, bias=False)
      )
      (norm1): RMSNorm()
      (norm2): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (out_head): Linear(in_features=1024, out_features=151936, bias=False)
)

In [14]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

Total number of parameters: 751,632,384

Total number of unique parameters: 596,049,920


In [15]:
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Calculate total number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Check if gradients are stored for this parameter
        if param.requires_grad:
            total_grads += param_size

    # Calculate buffer size (non-parameters that require memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (Number of elements) * (Size of each element in bytes)
    # We assume parameters and gradients are stored in the same type as input dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb

In [16]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device);

## 4. Load pretrained weights

In [17]:
def load_weights_into_qwen(model, param_config, params):
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

        with torch.no_grad():
            if isinstance(right, torch.Tensor):
                left.copy_(right)
            else:
                left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))

        return left

    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):
        block = model.trf_blocks[l]
        att = block.att

        # Q, K, V projections
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )

        # Output projection
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )

        # QK norms
        if hasattr(att, "q_norm") and att.q_norm is not None:
            att.q_norm.scale = assign(
                att.q_norm.scale,
                params[f"model.layers.{l}.self_attn.q_norm.weight"],
                f"model.layers.{l}.self_attn.q_norm.weight"
            )
        if hasattr(att, "k_norm") and att.k_norm is not None:
            att.k_norm.scale = assign(
                att.k_norm.scale,
                params[f"model.layers.{l}.self_attn.k_norm.weight"],
                f"model.layers.{l}.self_attn.k_norm.weight"
            )

        # Attention layernorm
        block.norm1.scale = assign(
            block.norm1.scale,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Feedforward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        block.norm2.scale = assign(
            block.norm2.scale,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Final normalization and output head
    model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params:
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        model.out_head.weight = model.tok_emb.weight
        print("Model uses weight tying.")

In [18]:
import json
import os
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download


if USE_REASONING_MODEL:
    repo_id = f"Qwen/Qwen3-{CHOOSE_MODEL}"
else:
    repo_id = f"Qwen/Qwen3-{CHOOSE_MODEL}-Base"

local_dir = Path(repo_id).parts[-1]

if CHOOSE_MODEL == "0.6B":
    weights_file = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir=local_dir,
    )
    weights_dict = load_file(weights_file)
else:
    repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
    index_path = os.path.join(repo_dir, "model.safetensors.index.json")
    with open(index_path, "r") as f:
        index = json.load(f)

    weights_dict = {}
    for filename in set(index["weight_map"].values()):
        shard_path = os.path.join(repo_dir, filename)
        shard = load_file(shard_path)
        weights_dict.update(shard)

load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
model.to(device)
del weights_dict

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

## 5. Load Tokenizer

In [19]:
import re
from tokenizers import Tokenizer

class Qwen3Tokenizer:
    _SPECIALS = [
        "<|endoftext|>",
        "<|im_start|>", "<|im_end|>",
        "<|object_ref_start|>", "<|object_ref_end|>",
        "<|box_start|>", "<|box_end|>",
        "<|quad_start|>", "<|quad_end|>",
        "<|vision_start|>", "<|vision_end|>",
        "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
        "<think>", "</think>"
    ]
    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

    def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None,
                 apply_chat_template=True, add_generation_prompt=False, add_thinking=False):

        self.apply_chat_template = apply_chat_template
        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tok_file = Path(tokenizer_file_path)
        self._tok = Tokenizer.from_file(str(tok_file))
        self._special_to_id = {}
        for t in self._SPECIALS:
            tid = self._tok.token_to_id(t)
            if tid is not None:
                self._special_to_id[t] = tid

        self.pad_token_id = self._special_to_id["<|endoftext|>"]
        self.eos_token_id = self.pad_token_id

        if repo_id and "Base" not in repo_id:
            eos_token = "<|im_end|>"
        else:
            eos_token = "<|endoftext|>"
        if eos_token in self._special_to_id:
            self.eos_token_id = self._special_to_id[eos_token]

    def encode(self, text, chat_wrapped=None):
        if chat_wrapped is None:
            chat_wrapped = self.apply_chat_template

        stripped = text.strip()
        if stripped in self._special_to_id and "\n" not in stripped:
            return [self._special_to_id[stripped]]

        if chat_wrapped:
            text = self._wrap_chat(text)

        ids = []
        for part in filter(None, self._SPLIT_RE.split(text)):
            if part in self._special_to_id:
                ids.append(self._special_to_id[part])
            else:
                ids.extend(self._tok.encode(part).ids)
        return ids

    def decode(self, ids):
        return self._tok.decode(ids, skip_special_tokens=False)

    def _wrap_chat(self, user_msg):
        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        if self.add_generation_prompt:
            s += "<|im_start|>assistant"
            if self.add_thinking:
                s += "\n"
            else:
                s += "\n<think>\n\n</think>\n\n"
        return s

In [20]:
if USE_REASONING_MODEL:
    tokenizer_file_path = f"Qwen3-{CHOOSE_MODEL}/tokenizer.json"
else:
    tokenizer_file_path = f"Qwen3-{CHOOSE_MODEL}-Base/tokenizer.json"

hf_hub_download(
    repo_id=repo_id,
    filename="tokenizer.json",
    local_dir=local_dir,
)

tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=tokenizer_file_path,
    repo_id=repo_id,
    apply_chat_template=USE_REASONING_MODEL,
    add_generation_prompt=USE_REASONING_MODEL,
    add_thinking=not USE_INSTRUCT_MODEL
)

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

In [21]:
prompt = "Give me a short introduction to large language models."

input_token_ids = tokenizer.encode(prompt)
text = tokenizer.decode(input_token_ids)
text

'<|im_start|>user\nGive me a short introduction to large language models.<|im_end|>\n<|im_start|>assistant\n'

## 6. Generate text

In [22]:
def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
    model.eval()

    with torch.no_grad():
        cache = KVCache(n_layers=model.cfg["n_layers"])
        model.reset_kv_cache()

        # Prime the cache with the initial context
        logits = model(token_ids, cache=cache)

        for _ in range(max_new_tokens):
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

            if eos_token_id is not None and torch.all(next_token == eos_token_id):
                break

            yield next_token

            token_ids = torch.cat([token_ids, next_token], dim=1)

            # Feed only the new token to the model; cache handles history
            logits = model(next_token, cache=cache)

In [23]:
input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)

for token in generate_text_basic_stream(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=500,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0).tolist()
    print(
        tokenizer.decode(token_id),
        end="",
        flush=True
    )

<think>
Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.

I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and efficiency. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around 3-4 sentences. Make sure it's clear and easy to understand.
</think>

Large language models (LLMs) are AI systems designed to understand and generate human language. They are trained on vast datasets to learn complex patterns and nuances, enabling them to comprehend context, understand emotions, and generate coherent text. These models are used for tas

## 7. GRPO Training for LLM

In [24]:
!pip install reasoning_gym

Collecting reasoning_gym
  Downloading reasoning_gym-0.1.23-py3-none-any.whl.metadata (8.7 kB)
Collecting arckit==0.1.0 (from reasoning_gym)
  Downloading arckit-0.1.0-py3-none-any.whl.metadata (503 bytes)
Collecting bfi==1.0.4 (from reasoning_gym)
  Downloading bfi-1.0.4-py3-none-any.whl.metadata (12 kB)
Collecting cellpylib==2.4.0 (from reasoning_gym)
  Downloading cellpylib-2.4.0.tar.gz (38 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting magiccube==0.3.0 (from reasoning_gym)
  Downloading magiccube-0.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting pycosat==0.6.6 (from reasoning_gym)
  Downloading pycosat-0.6.6.tar.gz (71 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.6/71.6 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyfiglet==1.0.2 (from reasoning_gym)
  Downloading pyfiglet-1.0.2-py3-none-any.whl.metadata (7.1 kB)
Collecting zss>=1.2.0 (from reasoning_gym)
  Downlo

In [52]:
import re
import random
import numpy as np
import reasoning_gym

from tqdm import tqdm
from typing import Generator, Optional, Tuple, Dict, List
from dataclasses import dataclass

import torch
import torch.optim as optim
import torch.nn.functional as F

import wandb

### Helper function

In [65]:
# Utility functions
def extract_answer_with_regex(text: str) -> str:
    """Extract text between <answer> and </answer> tags using regex"""
    pattern = r'<answer>(.*?)</answer>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return ""

def extract_thinking_with_regex(text: str) -> str:
    """Extract text between <think> and </think> tags using regex"""
    pattern = r'<think>(.*?)</think>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return ""

def sample_batch(dataset_list: List[Dict], batch_size: int) -> List[Dict]:
    """Sample a batch from the dataset"""
    return random.sample(dataset_list, min(batch_size, len(dataset_list)))

### GRPO Config

In [66]:
@dataclass
class GRPOConfig:
    """Configuration class for GRPO (Generalized Reward-based Policy Optimization) hyperparameters"""

    # Model and training hyperparameters
    model_name: str = "Qwen/Qwen3-0.6B"
    learning_rate: float = 3e-4
    batch_size: int = 32
    num_updates: int = 1000
    max_steps: int = 20
    n_outputs: int = 4
    max_length: int = 256
    grpo_iterations: int = 4  # Number of GRPO iterations per update

    # GRPO specific hyperparameters
    clip_epsilon: float = 0.2  # PPO clipping parameter
    kl_beta: float = 0.02      # KL divergence coefficient

    # Training configuration
    # gradient_accumulation_steps: int = 1
    # warmup_steps: int = 100
    # max_grad_norm: float = 1.0
    seed: int = 42

    # Dataset configuration
    dataset_name: str = "syllogism"
    dataset_size: int = 1000

    # Device configuration
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimization
    adam_epsilon: float = 1e-8
    weight_decay: float = 0.01

    # Logging and saving
    log_interval: int = 10
    save_interval: int = 100
    eval_interval: int = 50

    # Generation parameters
    temperature: float = 1.0
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True

In [67]:
# Initialize configuration
config = GRPOConfig()

# Display configuration
print("GRPO Configuration:")
print("=" * 50)
for field, value in config.__dict__.items():
    print(f"{field:<25}: {value}")
print("=" * 50)

GRPO Configuration:
model_name               : Qwen/Qwen3-0.6B
learning_rate            : 0.0003
batch_size               : 32
num_updates              : 1000
max_steps                : 20
n_outputs                : 4
max_length               : 256
grpo_iterations          : 4
clip_epsilon             : 0.2
kl_beta                  : 0.02
seed                     : 42
dataset_name             : syllogism
dataset_size             : 1000
device                   : cuda
adam_epsilon             : 1e-08
weight_decay             : 0.01
log_interval             : 10
save_interval            : 100
eval_interval            : 50
temperature              : 1.0
top_p                    : 0.9
top_k                    : 50
do_sample                : True


### Data Set

In [68]:
# Load the dataset using configuration

dataset = reasoning_gym.create_dataset(config.dataset_name, size=config.dataset_size, seed=config.seed)

print(f"Dataset loaded: {config.dataset_name}")
print(f"Dataset size: {config.dataset_size}")
print(f"Sample data point:")
for i, data in enumerate(dataset):
    print(f"Question: {data['question']}")
    print(f"Answer: {data['answer']}")
    if i == 0:  # Show only first sample
        break

Dataset loaded: syllogism
Dataset size: 1000
Sample data point:
Question: Consider these statements:
1. No students are humans
2. All humans are chefs

Does it logically follow that:
Some chefs are humans?
(Answer Yes or No)
Answer: Yes


In [69]:
for data in dataset:
    print(data)
    # Prepare the input
    input_text = data['question'] + " " + data['answer']
    inputs = tokenizer.encode(input_text)

    # Generate output
    outputs = model(torch.tensor([inputs], device=device))

    # Decode the output
    generated_text = tokenizer.decode(torch.argmax(outputs[:, -1, :], dim=-1).tolist())

    print(f"Input: {input_text}\nGenerated: {generated_text}\n")
    break

{'question': 'Consider these statements:\n1. No students are humans\n2. All humans are chefs\n\nDoes it logically follow that:\nSome chefs are humans?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 0, 'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'selected_premise': 2, 'conclusion': 'Some chefs are humans', 'is_valid': True, 'type': 'inversion'}}
Input: Consider these statements:
1. No students are humans
2. All humans are chefs

Does it logically follow that:
Some chefs are humans?
(Answer Yes or No) Yes
Generated: <think>



In [70]:
#Dataset Creation and Sampling

# Check what methods are available on the dataset
print("Dataset type:", type(dataset))
print("Available methods:", [method for method in dir(dataset) if not method.startswith('_')])

# Since procedural datasets don't have .sample(), we need to use random sampling
import random

def sample_from_dataset(dataset, n):
    """Sample n items from the dataset"""
    # Convert dataset to list if it's iterable
    dataset_list = list(dataset) if hasattr(dataset, '__iter__') else dataset

    # If the dataset is smaller than n, return all items
    if len(dataset_list) < n:
        return dataset_list

    # Random sample without replacement
    return random.sample(dataset_list, n)

# Sample 5 data points for demonstration
batch = sample_from_dataset(dataset, 5)
print(f"Sampled batch of {len(batch)} items:")
for i, item in enumerate(batch):
    print(f"Item {i+1}: {item}")

Dataset type: <class 'reasoning_gym.logic.syllogisms.SyllogismDataset'>
Available methods: ['DEFAULT_TERMS', 'category', 'config', 'score_answer', 'seed', 'size', 'terms']
Sampled batch of 5 items:
Item 1: {'question': 'Consider these statements:\n1. Some doctors are not insects\n2. No insects are humans\n\nDoes it logically follow that:\nNo humans are insects?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 417, 'premise1': 'Some doctors are not insects', 'premise2': 'No insects are humans', 'selected_premise': 2, 'conclusion': 'No humans are insects', 'is_valid': True, 'type': 'inversion'}}
Item 2: {'question': 'Consider these statements:\n1. Some tigers are not writers\n2. All writers are dolphins\n\nDoes it logically follow that:\nSome dolphins are writers?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 119, 'premise1': 'Some tigers are not writers', 'premise2': 'All writers are d

### GRPO

In [71]:
class GRPO:
    def __init__(self, model, tokenizer, config: GRPOConfig):
        super(GRPO, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        # Move model to device
        self.model.to(config.device)

    def loss_function(self, old_log_probs: torch.Tensor, new_log_probs: torch.Tensor, advantage: float) -> torch.Tensor:
        """Compute GRPO loss with clipping"""
        if len(old_log_probs) == 0 or len(new_log_probs) == 0:
            return torch.tensor(0.0, device=self.config.device, requires_grad=True)

        # Move tensors to same device
        old_log_probs = old_log_probs.to(self.config.device)
        new_log_probs = new_log_probs.to(self.config.device)

        # Sum log probabilities for the sequence
        old_log_prob_sum = old_log_probs.sum()
        new_log_prob_sum = new_log_probs.sum()

        # Compute ratio
        ratio = torch.exp(new_log_prob_sum - old_log_prob_sum)

        # Compute clipped surrogate loss
        advantage_tensor = torch.tensor(advantage, device=self.config.device)
        surr1 = ratio * advantage_tensor
        surr2 = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * advantage_tensor

        # Return negative because we want to maximize
        loss = -torch.min(surr1, surr2)
        return loss

    def compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        """Compute advantages with proper normalization"""
        # Simple advantage computation - can be made more sophisticated
        normalized_rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        return normalized_rewards

    def generate_with_log_probs(self, input_text: str, max_new_tokens: int = 100) -> Tuple[str, torch.Tensor, List[int], List[int]]:
        """Generate text while tracking log probabilities for each token"""
        self.model.eval()

        with torch.no_grad():
            input_token_ids = self.tokenizer.encode(input_text)
            input_tensor = torch.tensor([input_token_ids], device=self.config.device)

            cache = KVCache(n_layers=self.model.cfg["n_layers"])
            self.model.reset_kv_cache()

            # Initialize with input tokens
            generated_tokens = []
            log_probs = []

            # Prime the cache
            logits = self.model(input_tensor, cache=cache)

            for _ in range(max_new_tokens):
                # Get probabilities for the last token
                if self.config.temperature != 1.0:
                    logits_scaled = logits[:, -1, :] / self.config.temperature
                else:
                    logits_scaled = logits[:, -1, :]

                probs = F.softmax(logits_scaled, dim=-1)
                log_prob_dist = F.log_softmax(logits_scaled, dim=-1)

                # Sample next token
                if self.config.do_sample:
                    next_token = torch.multinomial(probs, 1)
                else:
                    next_token = torch.argmax(probs, dim=-1, keepdim=True)

                # Store the log probability of the chosen token
                token_log_prob = log_prob_dist.gather(1, next_token)
                log_probs.append(token_log_prob.item())
                generated_tokens.append(next_token.item())

                # Check for EOS
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Continue generation
                logits = self.model(next_token, cache=cache)

            # Decode the generated text
            full_tokens = input_token_ids + generated_tokens
            generated_text = self.tokenizer.decode(full_tokens)

            return generated_text, torch.tensor(log_probs), generated_tokens, input_token_ids

    def recompute_log_probs(self, input_tokens: List[int], generated_tokens: List[int]) -> torch.Tensor:
        """Recompute log probabilities for a generated sequence"""
        if not generated_tokens:
            return torch.tensor([])

        self.model.eval()
        with torch.no_grad():
            # Prepare full sequence
            full_sequence = input_tokens + generated_tokens
            input_tensor = torch.tensor([full_sequence], device=self.config.device)

            # Get logits for the sequence
            logits = self.model(input_tensor)

            # Extract logits for generated tokens (offset by input length)
            start_idx = len(input_tokens) - 1
            end_idx = start_idx + len(generated_tokens)

            if end_idx > logits.shape[1]:
                end_idx = logits.shape[1]

            gen_logits = logits[0, start_idx:end_idx, :]

            # Apply temperature if configured
            if self.config.temperature != 1.0:
                gen_logits = gen_logits / self.config.temperature

            # Compute log probabilities
            log_probs = F.log_softmax(gen_logits, dim=-1)

            # Get log probs for actual generated tokens
            token_log_probs = []
            for i, token_id in enumerate(generated_tokens):
                if i < log_probs.shape[0]:
                    token_log_probs.append(log_probs[i, token_id].item())

            return torch.tensor(token_log_probs, device=self.config.device)

### Trainner

In [72]:
class GRPOTrainer:
    def __init__(self, model, tokenizer, config, dataset):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.dataset = dataset
        self.dataset_list = list(dataset)

        # Initialize GRPO
        self.grpo = GRPO(model, tokenizer, config)

        # Initialize optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            eps=config.adam_epsilon
        )

        print(f"Initialized GRPO trainer with dataset size: {len(self.dataset_list)}")

    def train_step(self, batch: List[Dict]) -> Dict:
        """Single training step"""
        # Generate responses and collect trajectories
        trajectories = []
        all_rewards = []

        for data in batch:
            question = data['question']
            true_answer = data['answer']

            # Generate multiple responses per question
            for _ in range(self.config.n_outputs):
                # Generate response with log probs
                generated_text, old_log_probs, generated_tokens, input_tokens = self.grpo.generate_with_log_probs(
                    question, max_new_tokens=self.config.max_length
                )

                # Extract answer
                extracted_answer = extract_answer_with_regex(generated_text)
                thinking_content = extract_thinking_with_regex(generated_text)

                # Compute reward components
                try:
                    accuracy = 1.0 if self.dataset.score_answer(extracted_answer, true_answer) == 1.0 else 0.0
                except:
                    # Fallback: simple string matching
                    accuracy = 1.0 if extracted_answer.strip().lower() == true_answer.strip().lower() else 0.0

                format_reward = 0.3 if extracted_answer.strip() else 0.0
                thinking_reward = 0.2 if thinking_content.strip() else 0.0
                total_reward = accuracy + format_reward + thinking_reward

                trajectories.append({
                    'question': question,
                    'true_answer': true_answer,
                    'generated_text': generated_text,
                    'extracted_answer': extracted_answer,
                    'old_log_probs': old_log_probs,
                    'generated_tokens': generated_tokens,
                    'input_tokens': input_tokens,
                    'reward': total_reward,
                    'accuracy': accuracy
                })
                all_rewards.append(total_reward)

        if not trajectories:
            return {'avg_reward': 0, 'avg_accuracy': 0, 'avg_loss': 0}

        # Compute advantages
        rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=self.config.device)
        advantages = self.grpo.compute_advantages(rewards_tensor)

        # GRPO update iterations
        total_loss = 0
        valid_losses = 0

        for grpo_iter in range(self.config.grpo_iterations):
            self.model.train()
            self.optimizer.zero_grad()

            batch_loss = torch.tensor(0.0, device=self.config.device, requires_grad=True)

            for i, traj in enumerate(trajectories):
                # Recompute log probs with current model
                new_log_probs = self.grpo.recompute_log_probs(
                    traj['input_tokens'],
                    traj['generated_tokens']
                )

                # Skip if lengths don't match or empty
                if len(new_log_probs) != len(traj['old_log_probs']) or len(new_log_probs) == 0:
                    continue

                # Compute loss
                loss = self.grpo.loss_function(
                    traj['old_log_probs'],
                    new_log_probs,
                    advantages[i].item()
                )

                batch_loss = batch_loss + loss
                valid_losses += 1

            if valid_losses > 0:
                # Average loss and backward pass
                avg_loss = batch_loss / valid_losses
                total_loss += avg_loss.item()

                avg_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

        # Compute metrics
        avg_reward = np.mean(all_rewards)
        avg_accuracy = np.mean([t['accuracy'] for t in trajectories])
        avg_loss = total_loss / self.config.grpo_iterations if self.config.grpo_iterations > 0 else 0

        return {
            'avg_reward': avg_reward,
            'avg_accuracy': avg_accuracy,
            'avg_loss': avg_loss,
            'trajectories': trajectories
        }

    def train(self):
        """Main training loop"""
        print("Starting GRPO training...")

        for update in tqdm(range(self.config.num_updates), desc="GRPO Training"):
            # Sample batch
            batch = sample_batch(self.dataset_list, self.config.batch_size)

            # Training step
            step_results = self.train_step(batch)

            # Logging
            if update % self.config.log_interval == 0:
                print(f"\nUpdate {update:4d}")
                print(f"  Avg Reward: {step_results['avg_reward']:.3f}")
                print(f"  Avg Accuracy: {step_results['avg_accuracy']:.3f}")
                print(f"  Avg Loss: {step_results['avg_loss']:.4f}")

                # Show example
                if 'trajectories' in step_results and step_results['trajectories']:
                    example = step_results['trajectories'][0]
                    print(f"  Example Q: {example['question'][:60]}...")
                    print(f"  Example Generated A: {example['extracted_answer'][:40]}...")
                    print(f"  Example True A: {example['true_answer'][:40]}...")
                    print("-" * 60)

        print("GRPO training completed!")
        return self.model

In [73]:
# Initialize and run training
print("Starting GRPO training with fixed implementation...")

# Create trainer instance
trainer = GRPOTrainer(model, tokenizer, config, dataset)

# Run training
print("="*60)
print("GRPO Training Starting")
print("="*60)

trained_model = trainer.train()

print("="*60)
print("GRPO Training Complete!")
print("="*60)

Starting GRPO training with fixed implementation...
Initialized GRPO trainer with dataset size: 1000
GRPO Training Starting
Starting GRPO training...


GRPO Training:   0%|          | 1/1000 [22:18<371:22:06, 1338.27s/it]


Update    0
  Avg Reward: 0.002
  Avg Accuracy: 0.000
  Avg Loss: 0.0458
  Example Q: Consider these statements:
1. Some bees are insects
2. No in...
  Example Generated A: ...
  Example True A: Yes...
------------------------------------------------------------


GRPO Training:   0%|          | 3/1000 [1:10:09<388:36:51, 1403.22s/it]


KeyboardInterrupt: 

In [None]:
# Optional: Test the trained model
print("\nTesting trained model with a sample question:")
test_question = "All roses are flowers. Some flowers are red. Therefore, some roses are red."
test_input = tokenizer.encode(test_question)
test_tensor = torch.tensor([test_input], device=device)

print(f"Test question: {test_question}")
print("Generated response:")

# Generate response
for token in generate_text_basic_stream(
    model=trained_model,
    token_ids=test_tensor,
    max_new_tokens=200,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0).tolist()
    print(tokenizer.decode(token_id), end="", flush=True)

print("\n" + "="*60)