# Dynamic Transformer v1

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import pandas as pd

from typing import Tuple, List
from collections import deque
import sys

torch.manual_seed(0)

<torch._C.Generator at 0x784cb4251fd0>

In [None]:
# --- Standard transformer / optimisation settings ---------------------------
batch_size = 64        # independent sequences processed in parallel per step
block_size = 256       # maximum context length (in tokens) for predictions
max_iters = 10000      # total number of training iterations
eval_interval = 1000   # evaluate train/val loss every this many iterations
learning_rate = 3e-4   # optimizer step size
eval_iters = 50        # batches averaged per loss estimate
n_embd = 384           # token embedding dimensionality
n_head = 6             # attention heads per block
n_layer = 10           # number of transformer blocks
dropout = 0.2          # dropout probability used during training

# --- Settings specific to the dynamic (gated) blocks ------------------------
dynamic_k = 0.3            # Criterion U multiplier: surprise threshold = k * mean(d_st)
d_st_history_window = 100  # number of past d_st values to average
gate_loss_weight = 0       # weight of the auxiliary gate loss (0 => it adds nothing)
gate_warmup_iters = 2500   # iterations over which gates are encouraged to open

# Choose the compute device by preference: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


## Data Loading and Preparation

We'll use the Tiny Shakespeare dataset. We first need to load the text and create a vocabulary of all unique characters. Then, we'll create functions to encode a string into a sequence of integers (tokens) and decode a sequence of tokens back into a string.

In [None]:
# You may need to download the data first
!wget -O tiny-shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Read the whole corpus into memory as a single string.
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# The vocabulary is the sorted set of distinct characters in the corpus.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables: character -> token id (stoi) and token id -> character (itos).
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: map a string to the list of its integer token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: map a list of integer token ids back to a string."""
    return ''.join(itos[i] for i in l)


# Tokenise the full corpus into one 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the tokens are used for training, the remaining 10% for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

--2025-06-28 13:04:44--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘tiny-shakespeare.txt’


2025-06-28 13:04:44 (18.5 MB/s) - ‘tiny-shakespeare.txt’ saved [1115394/1115394]



## Data Batching

This function, `get_batch`, generates a small, random batch of data. For each sequence in the batch, the input `x` is a chunk of text, and the target `y` is the same chunk shifted by one character. This is how the model learns to predict the next character.

In [None]:
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: 'train' selects the training data; any other value selects
            the validation data.

    Returns:
        A tuple (x, y) of tensors on `device`, each stacking `batch_size`
        windows of `block_size` tokens. `y` is `x` shifted one position to
        the right, i.e. the next-token targets.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted targets.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y

## Loss Estimation

To avoid noisy loss measurements, we estimate the loss by averaging it over multiple batches. This function is decorated with `@torch.no_grad()` to tell PyTorch not to calculate gradients, which saves memory and computation since we're only evaluating, not training.

In [None]:
# @torch.no_grad()
# def estimate_loss():
#     out = {}
#     # Set the model to evaluation mode.
#     model.eval()
#     for split in ['train', 'val']:
#         losses = torch.zeros(eval_iters)
#         for k in range(eval_iters):
#             X, Y = get_batch(split)
#             logits, loss, _ = model(X, Y)
#             losses[k] = loss.item()
#         out[split] = losses.mean()
#     # Set the model back to training mode.
#     model.train()
#     return out

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    Puts the global `model` in eval mode, averages the loss over `eval_iters`
    random batches per split, then switches the model back to train mode.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean loss.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            # current_iter=max_iters is meant to keep the warm-up gate bias
            # switched off while evaluating.
            _, batch_loss, _ = model(inputs, targets, current_iter=max_iters)
            batch_losses[step] = batch_loss.item()
        averages[split] = batch_losses.mean()
    model.train()
    return averages

## The Transformer Model: A Deep Dive

Now we'll build the Transformer model, piece by piece.

### Self-Attention Head

Self-attention is the core mechanism of the Transformer. It allows tokens to interact with each other and weigh their importance. Each token produces a **Query** (what I'm looking for), a **Key** (what I contain), and a **Value** (what I'll communicate). The attention score is calculated by taking the dot product of a token's Query with every other token's Key. This score is then scaled and passed through a softmax function to get the weights. Finally, the output is a weighted sum of all tokens' Values.

The mathematical formula is: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [None]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the embedding space into this head's space.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Lower-triangular mask, stored as a buffer rather than a parameter.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return (B, T, head_size)."""
        _, seq_len, _ = x.shape
        queries = self.query(x)  # (B,T,hs)
        keys = self.key(x)       # (B,T,hs)
        values = self.value(x)   # (B,T,hs)

        # Scaled dot-product affinities; the 1/sqrt(hs) factor keeps the
        # softmax from saturating.
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B,T,T)

        # Positions above the diagonal are future tokens: give them -inf so
        # they receive zero weight after the softmax.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))  # (B,T,T)

        # Weighted sum of the values.
        return weights @ values  # (B,T,hs)

### Multi-Head Attention

Instead of a single attention mechanism, Transformers use multiple attention "heads" in parallel. Each head can learn to focus on different types of relationships between tokens. The outputs of all heads are concatenated and projected back to the original embedding dimension.

In [None]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run side by side, then merged."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_list = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_list)
        # Maps the concatenated head outputs back to the embedding width.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on `x`, concatenate on the last axis, project, apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

### Feed-Forward Network

After the attention mechanism gathers information, a simple feed-forward network processes this information for each token independently. It consists of two linear layers with a ReLU activation in between. This allows the model to perform more complex computations on the aggregated information.

In [None]:
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # the inner layer is 4x the embedding width
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to `x`."""
        return self.net(x)

### Dynamic Block

A Dynamic block combines multi-head attention with two feed-forward networks: one provides the normal decoder functionality, and the other acts as a prior on the probability of the model changing. It also includes two important features: residual connections and layer normalization.

- **Residual Connections**: The input to a sub-layer (like attention) is added to its output. This helps prevent the vanishing gradient problem in deep networks.
- **Layer Normalization**: This stabilizes the training by normalizing the features for each token across the embedding dimension.

In [None]:
# class DynamicBlock(nn.Module):
#     """ Transformer block: communication followed by computation """

#     def __init__(self, n_embd: int, n_head: int):
#         super().__init__()
#         head_size = n_embd // n_head
#         self.sa = MultiHeadAttention(n_head, head_size)
#         self.ffwd = FeedFoward(n_embd)
#         self.prior_ffn = FeedFoward(n_embd)
#         self.ln1 = nn.LayerNorm(n_embd)
#         self.ln2 = nn.LayerNorm(n_embd)
#         self.prev_mha = None

#     def forward(
#             self, x: torch.Tensor,
#             prev_mha: torch.Tensor = None,
#         ) -> Tuple[torch.Tensor, torch.Tensor]:
#         # Standard decoder block forward pass...
#         # x = x + self.sa(self.ln1(x))
#         mha_out = x + self.sa(self.ln1(x)) # (B,T,C)
#         # x = x + self.ffwd(self.ln2(x))
#         posterior = mha_out + self.ffwd(self.ln2(mha_out)) # (B,T,C)

#         # If we have no prior MHA output, we return the posterior and the MHA output.
#         # This is the case during the first step of generation.
#         if self.prev_mha is None:
#             # prev_mha = torch.zeros_like(mha_out) # (B,T,C)
#             self.gate = 1.0
#             return posterior, mha_out

#         # Align or pad the cached MHA so its length = T,
#         # Pad the RHS wth zero for dimension match
#         pad = mha_out.size(1) - prev_mha.size(1) # = 1 during generation
#         prev_mha = F.pad(prev_mha, (0, 0, 0, pad)) # right-pad with zeros

#         change_prior = self.prior_ffn(prev_mha) # (B,T,C)
#         static_prior = x # (B,T,C)

#         # MSE ONLY over the tokens that existed last step
#         # - keep first  (T-1) positions
#         # - drop the freshly-appended token at index -1
#         if x.size(1) > 1: # normal case
#             post_old   = posterior[:, :-1, :] # (B,T-1,C)
#             stat_old   = static_prior[:, :-1, :]
#             change_old = change_prior[:, :-1, :]

#             d_st = F.mse_loss(post_old, stat_old,
#                               reduction="none").mean((-2, -1)) # (B,)
#             d_ch = F.mse_loss(post_old, change_old,
#                               reduction="none").mean((-2, -1)) # (B,)
#         else: # first token: force update
#             d_st = torch.ones(x.size(0), device=x.device)
#             d_ch = torch.zeros_like(d_st)


#         # Block-level event / skip decision (scalar mask)
#         gate = ((d_st > d_ch) | (d_st > dynamic_k * d_st.mean())).view(-1, 1, 1) # (B,1,1)

#         # Conditional update
#         out = torch.where(gate, posterior, static_prior) # broadcast over T,C

#         self.gate = gate.float().mean().item()

#         return out, mha_out

In [None]:
# class DynamicBlock(nn.Module):
#     """
#     A stateful Dynamic Transformer Block implementing the logic from the notes.
#     - Manages its own history of static surprise (d_st) for Criterion U.
#     - Implements Criterion E by comparing priors over the shared context length.
#     """

#     def __init__(self, n_embd: int, n_head: int):
#         super().__init__()
#         head_size = n_embd // n_head
#         self.sa = MultiHeadAttention(n_head, head_size)
#         self.ffwd = FeedFoward(n_embd)
#         self.prior_ffn = FeedFoward(n_embd)
#         self.ln1 = nn.LayerNorm(n_embd)
#         self.ln2 = nn.LayerNorm(n_embd)

#         # Each block now owns its history of static surprise values.
#         self.d_st_history = deque(maxlen=d_st_history_window)
#         self.last_gate = torch.tensor(1.0) # For logging, default to 1

#     def forward(
#         self, x: torch.Tensor, prev_mha: torch.Tensor = None
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         # 1. Compute the standard posterior path
#         mha_out = x + self.sa(self.ln1(x))
#         posterior = mha_out + self.ffwd(self.ln2(mha_out))

#         # 2. Handle the base case (no cache or first token of a sequence)
#         if prev_mha is None or x.size(1) == 1:
#             # NOTE: We must still calculate d_st to update the history
#             d_st = F.mse_loss(posterior, x)
#             self.d_st_history.append(d_st.item())
#             self.last_gate = 1.0  # Always update in the base case
#             return posterior, mha_out

#         # 3. Define Priors (L > 1 case)
#         # prior_st has T tokens, matching the input x
#         prior_st = x

#         common_len = x.size(1) - 1

#         prev_mha_prefix = prev_mha[:, :common_len, :]          # (B, common_len, C)

#         # ❷ build the change prior on that prefix
#         prior_ch = self.prior_ffn(prev_mha_prefix)             # (B, common_len, C)

#         # ❸ compute the surprises on perfectly matched shapes
#         posterior_prefix = posterior[:, :common_len, :]        # (B, common_len, C)
#         prior_st_prefix  = x[:, :common_len, :]                # (B, common_len, C)

#         d_st = F.mse_loss(posterior_prefix,
#                         prior_st_prefix,
#                         reduction="none").mean((-2, -1))      # (B,)
#         d_ch = F.mse_loss(posterior_prefix,
#                         prior_ch,
#                         reduction="none").mean((-2, -1))      # (B,)

#         # VPR Gating Logic
#         # Criterion E: Is the posterior closer to the change prior?
#         CE = d_st > d_ch

#         # VPR Gating Logic
#         CE = d_st > d_ch
#         moving_avg = (
#             sum(self.d_st_history) / len(self.d_st_history)
#             if len(self.d_st_history) > 0
#             else d_st.item() + 1.0
#         )
#         CU = d_st > (dynamic_k * moving_avg)

#         # Final gate decision
#         gate = torch.tensor(1.0 if (CE or CU) else 0.0)
#         self.last_gate = gate  # For logging

#         # Update history and return gated output
#         self.d_st_history.append(d_st.item())

#         # The gated output is a mix of the full posterior and the original input x
#         out = (gate * posterior) + ((1 - gate) * x)

#         return out, mha_out

In [None]:
# class DynamicBlock(nn.Module):
#     """
#     A stateful Dynamic Transformer Block. Its forward pass is designed
#     to be called sequentially in a loop.
#     """

#     def __init__(self, n_embd: int, n_head: int):
#         super().__init__()
#         head_size = n_embd // n_head
#         self.sa = MultiHeadAttention(n_head, head_size)
#         self.ffwd = FeedFoward(n_embd)
#         self.prior_ffn = FeedFoward(n_embd)
#         self.ln1 = nn.LayerNorm(n_embd)
#         self.ln2 = nn.LayerNorm(n_embd)
#         self.d_st_history = deque(maxlen=d_st_history_window)
#         self.last_gate = torch.tensor(1.0) # Default to open

#     def forward(
#         self, x: torch.Tensor, prev_mha: torch.Tensor = None
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         # x is shape (B, 1, C) - we process one token at a time
#         mha_out = x + self.sa(self.ln1(x))
#         posterior = mha_out + self.ffwd(self.ln2(mha_out))

#         if prev_mha is None:
#             self.last_gate = torch.tensor(1.0, device=x.device)
#             return posterior, mha_out

#         # Priors and Surprise Calculation
#         prior_st = x
#         # change prior: use mha_out from *previous* token (right-shift by 1)
#         prev_mha = F.pad(mha_out[:, :-1, :], (0, 0, 1, 0))   # (B,T,C)
#         prior_ch = self.prior_ffn(prev_mha)              # (B,T,C)
#         d_st = F.mse_loss(posterior, prior_st)
#         d_ch = F.mse_loss(posterior, prior_ch)

#         # VPR Gating Logic
#         CE = d_st > d_ch
#         moving_avg = (
#             sum(self.d_st_history) / len(self.d_st_history)
#             if len(self.d_st_history) > 0
#             else d_st.item() + 1.0
#         )
#         CU = d_st > (dynamic_k * moving_avg)
#         gate = torch.tensor(1.0 if (CE or CU) else 0.0, device=x.device)
#         self.last_gate = gate # Store the gate decision for the auxiliary loss

#         # Update history and return gated output
#         self.d_st_history.append(d_st.item())
#         out = (gate * posterior) + ((1 - gate) * x)
#         return out, mha_out

In [None]:
class DynamicBlock(nn.Module):
    """
    Parallel-aware dynamic transformer block.

    All tensors are (B, T, C). For every token the block computes the usual
    attention + feed-forward result (the "posterior") and then decides, per
    token, whether to emit that posterior (gate = 1) or to pass the block
    input through unchanged (gate = 0). The first position (t=0) is forced
    open.

    NOTE: the gate is produced by hard comparisons, so no gradient flows
    through it; `prior_ffn` only feeds those comparisons.
    """

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        head_size = n_embd // n_head
        self.sa        = MultiHeadAttention(n_head, head_size)
        self.ffwd      = FeedFoward(n_embd)
        self.prior_ffn = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # Not read by forward(); kept so existing state_dict keys stay valid.
        self.register_buffer("ones", torch.tensor(1.0))

    @staticmethod
    def _warmup_scale(current_iter: int, warmup_iters: int) -> float:
        """Return the warm-up bias scale: 1.0 at iteration 0, decaying linearly to 0.0.

        A non-positive `warmup_iters` disables the warm-up (returns 0.0)
        instead of dividing by zero.
        """
        if warmup_iters <= 0:
            return 0.0
        return max(0.0, 1.0 - (current_iter / warmup_iters))

    @staticmethod
    def _compute_gate(
        d_st: torch.Tensor, d_ch: torch.Tensor, bias_scale: float, k: float
    ) -> torch.Tensor:
        """Compute the per-token gate from the two surprise values.

        Args:
            d_st: (B, T) distance between the posterior and the static prior.
            d_ch: (B, T) distance between the posterior and the change prior.
            bias_scale: warm-up factor in [0, 1]; 0 disables the bias.
            k: Criterion U multiplier applied to the per-sequence mean of d_st.

        Returns:
            (B, T) float tensor of 0/1 gate values with column 0 set to 1.
        """
        # Warm-up bias: lower the change-surprise so that Criterion E
        # (d_st > d_ch) is EASIER to satisfy early in training, i.e. the gate
        # is encouraged to open. (Adding the bias instead would make the
        # criterion harder to meet, the opposite of the intent.)
        gate_bias = d_ch.detach().mean() * bias_scale
        d_ch_biased = d_ch - gate_bias

        # Criterion E: posterior is closer to the change prior than to the input.
        ce = d_st > d_ch_biased
        # Criterion U: surprise exceeds k times the mean surprise over the
        # sequence (a cheap surrogate for a moving average).
        moving_avg = d_st.detach().mean(-1, keepdim=True)          # (B,1)
        cu = d_st > k * moving_avg                                 # (B,T)

        gate = (ce | cu).float()                                   # (B,T)
        gate[:, 0] = 1.0                                           # force first token open
        return gate

    def forward(
        self, x: torch.Tensor, current_iter: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the block.

        Args:
            x: input activations, (B, T, C).
            current_iter: training iteration, used to decay the warm-up bias
                over `gate_warmup_iters` iterations.

        Returns:
            (out, mha_out, gate): the gated output (B, T, C), the residual
            attention output (B, T, C), and the 0/1 gate (B, T).
            `self.last_gate` is also set to the mean gate value for logging.
        """
        # ----- 1. posterior path -------------------------------------------------
        mha_out = x + self.sa(self.ln1(x))                   # (B,T,C)
        posterior = mha_out + self.ffwd(self.ln2(mha_out))   # (B,T,C)

        # ----- 2. priors ---------------------------------------------------------
        # Static prior: the block input itself.
        # Change prior: prior_ffn applied to mha_out of the *previous* token
        # (right-shift by one; position 0 sees zeros).
        prev_mha = F.pad(mha_out[:, :-1, :], (0, 0, 1, 0))   # (B,T,C)
        prior_ch = self.prior_ffn(prev_mha)                  # (B,T,C)

        # ----- 3. surprises (vectorised) ----------------------------------------
        d_st = F.mse_loss(posterior, x, reduction="none").mean(-1)         # (B,T)
        d_ch = F.mse_loss(posterior, prior_ch, reduction="none").mean(-1)  # (B,T)

        # ----- 4. gate -----------------------------------------------------------
        bias_scale = self._warmup_scale(current_iter, gate_warmup_iters)
        gate = self._compute_gate(d_st, d_ch, bias_scale, dynamic_k)
        self.last_gate = gate.mean().detach()                # scalar for logging

        # ----- 5. mix posterior and input ----------------------------------------
        gate_3d = gate.unsqueeze(-1)                         # (B,T,1), broadcasts over C
        out = gate_3d * posterior + (1.0 - gate_3d) * x

        return out, mha_out, gate            # gate is (B,T) – used for aux-loss

### Dynamic GPT Model

Finally, we assemble all the components into the full GPT model. This includes:

- **Token Embedding Table**: Converts input token indices into dense vectors (embeddings).
- **Positional Embedding Table**: Since self-attention is permutation-invariant, we add positional embeddings to give the model information about the order of tokens.
- **A Sequence of Transformer Blocks**: The core of the model where the processing happens.
- **A Final Layer Norm and Linear Head**: To produce the final output logits over the vocabulary.

In [None]:
# class DynamicGPT(nn.Module):

#     def __init__(self):
#         super().__init__()
#         # Each token directly reads off the logits for the next token from a lookup table.
#         self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
#         self.position_embedding_table = nn.Embedding(block_size, n_embd)
#         # A sequence of dynamic blocks.
#         self.blocks = nn.ModuleList(
#             [DynamicBlock(n_embd, n_head) for _ in range(n_layer)]
#         )
#         self.ln_f = nn.LayerNorm(n_embd) # Final layer norm.
#         self.lm_head = nn.Linear(n_embd, vocab_size) # The head that produces logits.

#         self.apply(self._init_weights)

#     def _init_weights(self, module):
#         # A common practice for initializing weights in transformer models.
#         if isinstance(module, nn.Linear):
#             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
#             if module.bias is not None:
#                 torch.nn.init.zeros_(module.bias)
#         elif isinstance(module, nn.Embedding):
#             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

#     def forward(self, idx, targets=None, past_mhas: List[torch.Tensor]=None):
#         B, T = idx.shape

#         # Get token and position embeddings.
#         tok_emb = self.token_embedding_table(idx) # (B,T,C)
#         pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
#         # Add them together to give the model positional information.
#         x = tok_emb + pos_emb # (B,T,C)

#         # Pass though dynamic blocks.
#         present_mhas = []
#         for i, block in enumerate(self.blocks):
#             prev = past_mhas[i] if past_mhas is not None else None
#             x, mha_out = block(x, prev)
#             present_mhas.append(mha_out)

#         x = self.ln_f(x) # (B,T,C)
#         # Get the final logits.
#         logits = self.lm_head(x) # (B,T,vocab_size)

#         if targets is None:
#             loss = None
#         else:
#             # Reshape logits and targets for the cross-entropy loss function.
#             B, T, C = logits.shape
#             logits = logits.view(B*T, C)
#             targets = targets.view(B*T)
#             loss = F.cross_entropy(logits, targets)

#         return logits, loss, present_mhas

#     def generate(self, idx, max_new_tokens):
#         # This method generates new tokens autoregressively.

#         # Initialise the past multiple heads attention cache.
#         past_mhas = [None for _ in range(len(self.blocks))]

#         for _ in range(max_new_tokens):
#             # Crop the context to the last block_size tokens to save computation.
#             idx_cond = idx[:, -block_size:]
#             # Get the predictions for the next token.
#             logits, loss, past_mhas = self(idx_cond, past_mhas=past_mhas)
#             # Focus only on the last time step's logits.
#             logits = logits[:, -1, :] # becomes (B, C)
#             # Apply softmax to get probabilities.
#             probs = F.softmax(logits, dim=-1) # (B, C)
#             # Sample from the distribution to get the next token.
#             idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
#             # Append the sampled index to the running sequence.
#             idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
#         return idx

In [None]:
class DynamicGPT(nn.Module):
    """Character-level GPT built from a stack of `DynamicBlock`s.

    Token and position embeddings are summed, passed through `n_layer`
    dynamic blocks, layer-normalised and projected to vocabulary logits.
    The whole sequence is processed in parallel; an earlier token-by-token
    teacher-forcing forward pass was replaced by this version.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            [DynamicBlock(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, current_iter: int = 0):
        """Compute logits and, when targets are given, the training loss.

        Args:
            idx: (B, T) tensor of token ids with T <= block_size.
            targets: optional (B, T) tensor of next-token ids.
            current_iter: training iteration; forwarded to every block and
                used to select the auxiliary-loss phase.

        Returns:
            (logits, loss, gate_means): logits is (B, T, vocab_size); loss is
            None when `targets` is None; gate_means is an (n_layer,) tensor
            with each block's mean gate value.

        Raises:
            ValueError: if T exceeds block_size (no position embedding exists
                for those positions).
        """
        B, T = idx.shape
        if T > block_size:
            raise ValueError(
                f"Cannot forward sequence of length {T}; block_size is {block_size}"
            )

        tok = self.token_embedding_table(idx)
        pos = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )
        x = tok + pos                                   # (B,T,C)

        gate_logs = []                                   # one (B,T) gate per block
        for block in self.blocks:
            x, _, gate = block(x, current_iter=current_iter)
            gate_logs.append(gate)

        x = self.ln_f(x)
        logits = self.lm_head(x)                         # (B,T,V)

        loss = None
        if targets is not None:
            # 1. Main cross-entropy loss over all positions.
            lm_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

            # 2. Two-phase auxiliary gate loss.
            # NOTE: the gates come from hard thresholds, so this term carries
            # no gradient; it only changes the reported loss value when
            # gate_loss_weight is non-zero.
            mean_gate_activation = torch.stack(gate_logs).mean()
            if current_iter < gate_warmup_iters:
                # Warm-up phase: penalise closed gates (target activation 1.0).
                aux_loss = 1.0 - mean_gate_activation
            else:
                # Sparsity phase: penalise open gates.
                aux_loss = mean_gate_activation

            # 3. Combined loss.
            loss = lm_loss + gate_loss_weight * aux_loss

        # Per-block gate mean for live display.
        gate_means = torch.stack([g.mean() for g in gate_logs])  # (N,)
        return logits, loss, gate_means

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `idx`.

        Runs under `torch.no_grad()` so no autograd graph is built while
        sampling. The module's train/eval mode is not changed here; call
        `model.eval()` beforehand to sample without dropout.

        Args:
            idx: (B, T) tensor of context token ids.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            # The model only sees the last block_size tokens.
            idx_cond = idx[:, -block_size:]

            # A high current_iter is passed so the gate warm-up bias is off.
            logits, _, _ = self.forward(idx_cond, current_iter=max_iters)

            # Keep only the prediction for the last position.
            logits = logits[:, -1, :]                           # (B, V)
            probs = F.softmax(logits, dim=-1)                   # (B, V)

            # Sample the next token and append it to the running sequence.
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)

        return idx

## Model Training

Now we can instantiate the model and the optimizer. We'll use the AdamW optimizer, which is a standard choice for training Transformers. The training loop will repeatedly sample a batch of data, calculate the loss, and update the model's parameters.

In [None]:
# Build the model and place it on the selected device.
model = DynamicGPT()
m = model.to(device)

# One history list per block; the training loop appends gate means to these.
activation_log = [[] for _ in model.blocks]

# Report the parameter count in millions.
print(sum(p.numel() for p in m.parameters()) / 1e6, 'M parameters')

# AdamW optimizer over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

29.697857 M parameters


In [None]:
# pbar = tqdm(range(max_iters))
# for it in pbar:

#     # Every once in a while, evaluate the loss on train and val sets.
#     if it % eval_interval == 0 or it == max_iters - 1:
#         losses = estimate_loss()
#         print(f"step {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

#     # Sample a batch of data.
#     xb, yb = get_batch('train')

#     # Evaluate the loss.
#     logits, loss, _ = model(xb, yb)
#     # Reset gradients from the previous iteration.
#     optimizer.zero_grad(set_to_none=True)
#     # Compute gradients for this batch (backpropagation).
#     loss.backward()
#     # Update the model's parameters.
#     optimizer.step()

#     # Collect Block activation info
#     for i, blk in enumerate(model.blocks):
#         # blk.last_gate is the mean gate value we stored in DynamicBlock
#         activation_log[i].append(int(blk.last_gate >= 0.5))
pbar = tqdm(range(max_iters))
for it in pbar:
    # Every eval_interval iterations, and on the final one, report the
    # train/val losses returned by estimate_loss().
    if it % eval_interval == 0 or it == max_iters - 1:
        losses = estimate_loss()
        print(f"\nstep {it}: train {losses['train']:.4f} │ val {losses['val']:.4f}")

    # One optimization step on a fresh training batch.
    xb, yb = get_batch("train")
    logits, loss, gate_means = model(xb, yb, current_iter=it)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # log & live display
    # Copy all per-block gate means to the host in a single transfer instead
    # of calling .item() on each element twice (each .item() is a separate
    # device-to-host synchronization).
    gate_values = gate_means.detach().tolist()
    for i, g in enumerate(gate_values):
        activation_log[i].append(g)
    pbar.set_postfix(
        **{f"B{i}": f"{g:.2f}" for i, g in enumerate(gate_values)}
    )

  0%|          | 0/10000 [00:00<?, ?it/s]


step 0: train 4.1864 │ val 4.1887


 10%|█         | 1000/10000 [03:48<32:57,  4.55it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.95]


step 1000: train 1.4402 │ val 1.6560


 20%|██        | 2000/10000 [07:36<29:13,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.96]


step 2000: train 1.2082 │ val 1.5015


 30%|███       | 3000/10000 [11:24<25:37,  4.55it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.95]


step 3000: train 1.0664 │ val 1.5007


 40%|████      | 4000/10000 [15:12<21:57,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.95]


step 4000: train 0.9427 │ val 1.5290


 50%|█████     | 5000/10000 [19:00<18:16,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.94]


step 5000: train 0.8029 │ val 1.6163


 60%|██████    | 6000/10000 [22:49<14:38,  4.55it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.94]


step 6000: train 0.6610 │ val 1.7428


 70%|███████   | 7000/10000 [26:37<10:57,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.92]


step 7000: train 0.5458 │ val 1.8629


 80%|████████  | 8000/10000 [30:25<07:18,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.92]


step 8000: train 0.4315 │ val 1.9920


 90%|█████████ | 9000/10000 [34:13<03:39,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.92]


step 9000: train 0.3359 │ val 2.1367


100%|█████████▉| 9999/10000 [38:01<00:00,  4.56it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.91]


step 9999: train 0.2705 │ val 2.2252


100%|██████████| 10000/10000 [38:10<00:00,  4.37it/s, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00, B6=1.00, B7=1.00, B8=1.00, B9=0.92]


In [None]:
# ----- Initialize the log for summary stats -----
# This will store the average activation rate for each block at each iteration
# activation_log = [[] for _ in range(len(model.blocks))]

# ----- The Main Training Loop -----
# pbar = tqdm(range(max_iters))
# for it in pbar:
#     # 1. Evaluation step (no changes here)
#     if it % eval_interval == 0 or it == max_iters - 1:
#         losses = estimate_loss()
#         print(
#             f"\nstep {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
#         )

#     # 2. Sample a batch of data
#     xb, yb = get_batch("train")

#     # 3. Evaluate the loss and get gate activations
#     # We now capture the third return value from the model
#     logits, loss, gate_activations = model(xb, yb)
#     optimizer.zero_grad(set_to_none=True)
#     loss.backward()
#     optimizer.step()

    # 4. Logging and Summary Stats
    # if gate_activations is not None:
    #     # Calculate the mean activation for each block over the time dimension
    #     # gate_activations has shape (n_layer, T)
    #     avg_activations_per_block = gate_activations.mean(dim=1)

    #     # --- A) Update the detailed log for post-training analysis ---
    #     for i, avg_act in enumerate(avg_activations_per_block):
    #         activation_log[i].append(avg_act.item())

        # # --- B) Create the live report for the progress bar ---
        # report = []
        # for i, avg_act in enumerate(avg_activations_per_block):
        #     # Format as "B0:0.95" for Block 0 having a 95% activation rate
        #     report.append(f"B{i}:{avg_act:.2f}")
        # pbar.set_postfix_str(" | ".join(report))

## Text Generation

After training, we can use our model to generate new text. We start with a single token (a newline character in this case) and let the model predict the next token, which we then feed back into the model to predict the next one, and so on.

In [None]:
# Seed generation with a single token of index 0 (the newline character).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Sample 500 new tokens, take the first sequence of the batch, and decode it.
sampled = m.generate(context, max_new_tokens=500)
token_ids = sampled[0].tolist()
print(decode(token_ids))


gh's co; Signify Bups by God's and shess quarrers
saint-trees; that her stands upon the back
Of new to such events from the earth. But is her
to the public and the ballad and greeness;
Master, you shall, bring them to prevent it;
and, by my knot, Green and I bury that I would,
As beneal to thee, that we, have cannot limb.

JULIET:
That had not been, that would murded when he lies,
Told now their men's pleasures tears and red:
And their spurse they shall be thus or this.

HASTINGS:
I'll pernio to


## Analysis of Average Activation

In [None]:
# act_df = pd.DataFrame(activation_log).T # shape: (iters , n_blocks)
# act_df.columns = [f"block_{i}" for i in range(len(model.blocks))]

# # e.g. save for offline inspection
# act_df.to_csv("block_activation.csv", index=False)

# # quick sanity-check
# print(act_df)

In [None]:
# activation_log holds one list per block with the mean gate value recorded at
# every training iteration; transpose so rows are iterations and columns are blocks.
column_names = [f"block_{idx}_avg_act" for idx, _ in enumerate(model.blocks)]
act_df = pd.DataFrame(activation_log).transpose()
act_df.columns = column_names

# Persist the full log for offline inspection.
act_df.to_csv("block_activation_rates.csv", index=False)

# Show the first rows followed by per-column summary statistics.
print(
    "--- Activation Rate Log ---",
    act_df.head(),
    "\n--- Summary Statistics ---",
    act_df.describe(),
    sep="\n",
)

--- Activation Rate Log ---
   block_0_avg_act  block_1_avg_act  block_2_avg_act  block_3_avg_act  \
0              1.0              1.0              1.0              1.0   
1              1.0              1.0              1.0              1.0   
2              1.0              1.0              1.0              1.0   
3              1.0              1.0              1.0              1.0   
4              1.0              1.0              1.0              1.0   

   block_4_avg_act  block_5_avg_act  block_6_avg_act  block_7_avg_act  \
0              1.0              1.0              1.0              1.0   
1              1.0              1.0              1.0              1.0   
2              1.0              1.0              1.0              1.0   
3              1.0              1.0              1.0              1.0   
4              1.0              1.0              1.0              1.0   

   block_8_avg_act  block_9_avg_act  
0              1.0              1.0  
1              1.0