# Dynamic Transformer v1

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import pandas as pd

from typing import Tuple, List
from collections import deque
import sys

# Seed PyTorch's default RNG. It is used by weight initialisation, dropout and
# get_batch's torch.randint, so runs are repeatable (up to GPU nondeterminism).
torch.manual_seed(0)

<torch._C.Generator at 0x7c822419ad30>

In [None]:
# NORMAL HYPERPARAMETERS
batch_size = 64          # independent sequences processed in parallel per step
block_size = 256         # maximum context length (in tokens) the model sees
max_iters = 6000         # total number of optimizer steps
eval_interval = 500      # run estimate_loss() every this many steps
learning_rate = 3e-4     # AdamW step size
eval_iters = 25          # batches averaged per split when estimating the loss
n_embd = 384             # width of the embeddings / residual stream
n_head = 6               # attention heads per block (head_size = n_embd // n_head)
n_layer = 6              # number of stacked dynamic blocks
dropout = 0.2            # dropout probability used throughout the model

# DYNAMIC HYPERPARAMETERS
dynamic_k = 0.9              # CU threshold: a sequence counts as "unexpected" when its
                             # static surprise exceeds dynamic_k * batch mean. Higher -> fewer updates.
d_st_history_window = 10000  # NOTE(review): not referenced anywhere else in this notebook;
                             # the CU threshold currently uses the batch mean instead.
gate_loss_weight = 0         # weight of the auxiliary gate-activation loss (0 = disabled)
gate_warmup_iters = 2500     # length of the gate warm-up phase, in optimizer steps
prior_loss_weight = 0.1      # weight of the prior network's prediction (MSE) loss

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using device: {device}")

Using device: cuda


## Data Loading and Preparation

We'll use the Tiny Shakespeare dataset. We first need to load the text and create a vocabulary of all unique characters. Then, we'll create functions to encode a string into a sequence of integers (tokens) and decode a sequence of tokens back into a string.

In [None]:
# You may need to download the data first
!wget -O tiny-shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Find all unique characters in the text.
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Create a mapping from characters to integers (stoi) and vice-versa (itos).
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # Encoder: takes a string, outputs a list of integers.
decode = lambda l: ''.join([itos[i] for i in l]) # Decoder: takes a list of integers, outputs a string.

# Convert the entire dataset into a tensor of tokens.
data = torch.tensor(encode(text), dtype=torch.long)

# Split the data into training (90%) and validation (10%) sets.
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

--2025-07-01 15:54:35--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘tiny-shakespeare.txt’


2025-07-01 15:54:35 (18.2 MB/s) - ‘tiny-shakespeare.txt’ saved [1115394/1115394]



## Data Batching

This function, `get_batch`, generates a small, random batch of data. For each sequence in the batch, the input `x` is a chunk of text, and the target `y` is the same chunk shifted by one character. This is how the model learns to predict the next character.

In [None]:
def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: 'train' draws from train_data; any other value draws from val_data.

    Returns:
        (x, y): two (batch_size, block_size) tensors on `device`. Row r of y
        is row r of x shifted one token to the right.
    """
    source = val_data if split != 'train' else train_data
    # One random window start per sequence; every window plus its one-token
    # shift must fit inside the split.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

## Loss Estimation

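To avoid noisy loss measurements, we estimate the loss by averaging it over multiple batches. This function is decorated with `@torch.no_grad()` to tell PyTorch not to calculate gradients, which saves memory and computation since we're only evaluating, not training. It also puts the model in evaluation mode (disabling dropout) while measuring, and passes a large `current_iter` so that the gate warm-up bias is switched off.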

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the language-modelling loss on the train and val splits.

    For each split, averages the next-token cross-entropy over `eval_iters`
    random batches. The model is in eval mode (dropout off) while measuring.
    Only the LM cross-entropy is reported. The model's own loss output is
    the full training objective, which also contains the weighted
    prior-prediction term, so it is not comparable to a plain GPT's loss.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor with the mean loss.
    """
    # Any current_iter >= gate_warmup_iters makes the warm-up bias scale 0.
    # max() keeps that true even if gate_warmup_iters is set above max_iters.
    eval_iter = max(max_iters, gate_warmup_iters)
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            # No targets: only the logits are needed, so skip the combined objective.
            logits, _, _ = model(X, current_iter=eval_iter)
            losses[k] = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), Y.reshape(-1)
            ).item()
        out[split] = losses.mean()
    # The training loop relies on the model being back in train mode afterwards.
    model.train()
    return out

## The Transformer Model: A Deep Dive

Now we'll build the Transformer model, piece by piece.

### Self-Attention Head

Self-attention is the core mechanism of the Transformer. It allows tokens to interact with each other and weigh their importance. Each token produces a **Query** (what I'm looking for), a **Key** (what I contain), and a **Value** (what I'll communicate). The attention score is calculated by taking the dot product of a token's Query with the Key of every token it is allowed to attend to. In this decoder-only model, a causal mask limits that to the token itself and all earlier tokens. These scores are then scaled and passed through a softmax function to get the weights. Finally, the output is a weighted sum of those tokens' Values.

The mathematical formula is: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

In [None]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input to queries, keys and values of width ``head_size``.
    Each position then attends only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections. Keep the key -> query -> value creation order:
        # it determines how the seeded RNG is consumed during initialisation.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask. It is stored as a buffer (it follows the
        # module across devices) rather than as a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)  # applied to the attention weights

    def forward(self, x):
        """Run causal attention over ``x`` of shape (B, T, n_embd), T <= block_size.

        Returns a (B, T, head_size) tensor. Each position holds an
        attention-weighted mix of the values at itself and earlier positions.
        """
        _, seq_len, _ = x.shape
        queries = self.query(x)                     # (B, T, hs)
        keys = self.key(x)                          # (B, T, hs)
        scale = keys.shape[-1] ** -0.5              # 1 / sqrt(head_size) keeps softmax unsaturated
        scores = (queries @ keys.transpose(-2, -1)) * scale   # (B, T, T)

        # Positions above the diagonal are "the future" and may not be attended to.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))   # (B, T, T), rows sum to 1 before dropout
        return weights @ self.value(x)                      # (B, T, hs)

### Multi-Head Attention

Instead of a single attention mechanism, Transformers use multiple attention "heads" in parallel. Each head can learn to focus on different types of relationships between tokens. The outputs of all heads are concatenated and projected back to the original embedding dimension.

In [None]:
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run side by side, then mixed by a linear layer."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Independent heads, created in index order.
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        # Maps the concatenated head outputs back to the residual width n_embd.
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]      # each (B, T, head_size)
        mixed = self.proj(torch.cat(head_outputs, dim=-1))    # (B, T, n_embd)
        return self.dropout(mixed)

### Feed-Forward Network

After the attention mechanism gathers information, a simple feed-forward network processes this information for each token independently. It consists of two linear layers with a ReLU activation in between. This allows the model to perform more complex computations on the aggregated information.

In [None]:
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # conventional 4x expansion of the residual width
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        # Kept as a single Sequential so the state_dict keys are net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # The Linear layers act on the last dimension only, so every position is
        # transformed independently.
        return self.net(x)

### Dynamic Block

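A Dynamic block combines multi-head attention with two feed-forward networks. The first performs the normal decoder computation. The second, the *prior* network, tries to predict the block's output from the previous position's attention output. The block compares its output (the *posterior*) against two priors: "nothing changes" (the block's input) and this learned prediction. Based on that comparison, it either applies its update or passes the input through unchanged, separately for each sequence in the batch. Like a standard Transformer block, it also includes two important features: residual connections and layer normalization.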

- **Residual Connections**: The input to a sub-layer (like attention) is added to its output. This helps prevent the vanishing gradient problem in deep networks.
- **Layer Normalization**: This stabilizes the training by normalizing the features for each token across the embedding dimension.

In [None]:
class DynamicBlock(nn.Module):
    """
    Transformer block whose update is applied or skipped per sequence.

    The block computes a standard pre-LayerNorm attention + MLP update, the
    "posterior", and compares it with two priors:

    * static prior ``prior_st``: the block input ``x``, i.e. "nothing changes";
    * change prior ``prior_ch``: ``prior_ffn`` applied to the attention output
      of the *previous* position (zero-padded at t=0).

    A sequence's gate opens, making the output the posterior, when either
    condition holds:

    * CE: its mean static surprise exceeds its (warm-up-biased) mean
      change-prior surprise;
    * CU: its static surprise exceeds ``dynamic_k`` times the batch mean.

    Otherwise the block passes ``x`` through unchanged.
    All activations are (B, T, C); gating is per sequence, with shape (B,).
    """

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        head_size = n_embd // n_head
        # NOTE: the attribute names below are state_dict keys. Renaming them
        # (e.g. to avoid the sa/ffwd abbreviations) breaks saved checkpoints.
        self.sa        = MultiHeadAttention(n_head, head_size)  # self-attention
        self.ffwd      = FeedFoward(n_embd)                     # posterior MLP
        self.prior_ffn = FeedFoward(n_embd)                     # change-prior predictor
        self.ln1 = nn.LayerNorm(n_embd)  # before attention
        self.ln2 = nn.LayerNorm(n_embd)  # before ffwd
        self.ln3 = nn.LayerNorm(n_embd)  # before prior_ffn
        # Not used by forward(); kept so that the state_dict keys stay unchanged.
        self.register_buffer("ones", torch.tensor(1.0))
        # Mean gate activation of the most recent forward pass (None until then).
        self.last_gate = None

    def forward(
        self,
        x: torch.Tensor,
        current_iter: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, T, C) block input.
            current_iter: training step. It only scales the warm-up bias on
                the change-prior surprise, which is zero once
                current_iter >= gate_warmup_iters.

        Returns:
            out:       (B, T, C) posterior where the gate is open, ``x`` elsewhere.
            gate_vec:  (B,) float tensor of 0/1 gate decisions.
            pred_loss: scalar MSE between the change prior and the detached posterior.
        """
        # Standard pre-LN transformer update (the "posterior").
        mha_out  = x + self.sa(self.ln1(x))                           # (B,T,C)
        posterior = mha_out + self.ffwd(self.ln2(mha_out))            # (B,T,C)

        # Priors: "no change" (the input itself), and a learned prediction made
        # from the previous position's attention output.
        prior_st = x
        prev_mha = F.pad(mha_out[:, :-1, :], (0, 0, 1, 0))            # shift right by one along T; zeros at t=0
        prior_ch = self.prior_ffn(self.ln3(prev_mha))

        # Per-token surprises: MSE over the channel dimension.
        d_st_tok = F.mse_loss(posterior, prior_st,
                              reduction="none").mean(-1) # (B,T)
        d_ch_tok = F.mse_loss(posterior, prior_ch,
                              reduction="none").mean(-1) # (B,T)

        # Sequence-average surprises.
        D_st = d_st_tok.mean(dim=1) # (B,)
        D_ch = d_ch_tok.mean(dim=1) # (B,)

        # Warm-up bias: inflates D_ch by up to its (detached) batch mean, which
        # makes CE harder to satisfy early in training. The bias decays linearly
        # to 0 at gate_warmup_iters. A non-positive gate_warmup_iters disables
        # warm-up instead of dividing by zero.
        if gate_warmup_iters > 0:
            bias_scale = max(0.0, 1.0 - current_iter / gate_warmup_iters)
        else:
            bias_scale = 0.0
        beta = D_ch.detach().mean() * bias_scale # scalar
        D_ch_biased = D_ch + beta

        # Per-sample gate decision.
        CE = D_st > D_ch_biased # (B,) bool
        # NOTE(review): the CU threshold is relative to the current batch mean.
        # With a batch of one (e.g. generation), D_st > dynamic_k * D_st holds
        # whenever D_st > 0, so the gate is effectively always open.
        CU = D_st > dynamic_k * D_st.detach().mean() # (B,) bool
        # Hard 0/1 gate built from boolean comparisons: no gradient flows through it.
        gate_vec = (CE | CU).float() # (B,)

        # Apply or skip the whole block per sequence.
        gate = gate_vec.view(-1, 1, 1) # (B,1,1)
        out = gate * posterior + (1.0 - gate) * x # (B,T,C)

        # Log a single activation value per block.
        self.last_gate = gate_vec.mean().detach() # scalar

        # Prediction loss. Detaching the posterior stops this loss from pulling
        # the posterior towards the prior. Gradients still reach prior_ffn and
        # ln3, and, because prev_mha is not detached, also the attention
        # sub-layer and earlier layers.
        pred_loss = F.mse_loss(prior_ch, posterior.detach()) # scalar

        return out, gate_vec, pred_loss            # gate_vec is (B,)

### Dynamic GPT Model

Finally, we assemble all the components into the full GPT model. This includes:

- **Token Embedding Table**: Converts input token indices into dense vectors (embeddings).
- **Positional Embedding Table**: Since self-attention is permutation-invariant, we add positional embeddings to give the model information about the order of tokens.
- **A Sequence of Dynamic Blocks**: The core of the model where the processing happens; each block can apply or skip its update for each sequence.
- **A Final Layer Norm and Linear Head**: To produce the final output logits over the vocabulary.

In [None]:
class DynamicGPT(nn.Module):
    """Character-level GPT built from DynamicBlocks.

    Token embeddings plus learned position embeddings feed a stack of
    ``n_layer`` DynamicBlocks. A final LayerNorm and a linear head then
    produce logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            [DynamicBlock(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Re-initialise every Linear/Embedding in the module tree (including
        # those inside the blocks): N(0, 0.02) weights and zero biases.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
            self,
            idx,
            targets=None,
            current_iter: int = 0,
        ):
        """Run the model on a batch of token ids.

        Args:
            idx: (B, T) token ids, with T <= block_size.
            targets: optional (B, T) next-token ids. When given, the loss is computed.
            current_iter: training step, forwarded to every block's warm-up
                logic and used to pick the sign of the auxiliary gate loss.

        Returns:
            logits:     (B, T, vocab_size).
            loss:       None if ``targets`` is None. Otherwise the LM
                        cross-entropy + gate_loss_weight * aux_gate
                        + prior_loss_weight * (mean block prediction loss).
            gate_means: (n_layer,) mean gate activation of each block.

        Raises:
            ValueError: if T exceeds block_size (there is no position embedding for it).
        """
        B, T = idx.shape
        if T > block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size}; "
                "crop the context (as generate() does) before calling the model"
            )
        tok = self.token_embedding_table(idx)                    # (B,T,C)
        pos = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )                                                        # (T,C)
        x = tok + pos                                            # (B,T,C)

        gate_logs = []    # one (B,) gate vector per block
        pred_losses = []  # one scalar prior-prediction loss per block
        for block in self.blocks:
            x, gate, pred_loss = block(x, current_iter=current_iter)
            gate_logs.append(gate)
            pred_losses.append(pred_loss)

        x = self.ln_f(x)
        logits = self.lm_head(x)                                 # (B,T,V)

        loss = None
        if targets is not None:
            lm_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )

            # Gate activation loss: during warm-up it penalises closed gates,
            # afterwards it penalises open gates.
            # NOTE(review): the block gates are thresholded 0/1 values, so this
            # term carries no gradient; with gate_loss_weight = 0 it is unused.
            mean_gate_activation = torch.stack(gate_logs).mean()
            aux_gate = (
                1.0 - mean_gate_activation
                if current_iter < gate_warmup_iters
                else mean_gate_activation
            )

            # Prior-prediction loss, averaged over blocks.
            pred_loss_total = torch.stack(pred_losses).mean()

            # Combined objective.
            loss = (
                lm_loss
                + gate_loss_weight * aux_gate
                + prior_loss_weight * pred_loss_total
            )

        # Per-block mean gate activation, for logging / live display.
        gate_means = torch.stack([g.mean() for g in gate_logs])  # (n_layer,)
        return logits, loss, gate_means

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) long tensor of conditioning token ids.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: ``idx`` followed by the sampled tokens.

        Sampling runs without autograd and in eval mode, so dropout is off.
        The module's previous train/eval mode is restored afterwards.
        """
        # Any current_iter >= gate_warmup_iters switches the warm-up bias off.
        eval_iter = max(max_iters, gate_warmup_iters)
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The model only has position embeddings for block_size tokens.
                idx_cond = idx[:, -block_size:]
                logits, _, _ = self(idx_cond, current_iter=eval_iter)
                # Only the prediction for the last position is needed.
                probs = F.softmax(logits[:, -1, :], dim=-1)          # (B, V)
                idx_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)              # (B, T+1)
        finally:
            self.train(was_training)
        return idx

## Model Training

Now we can instantiate the model and the optimizer. We'll use the AdamW optimizer, which is a standard choice for training Transformers. The training loop will repeatedly sample a batch of data, calculate the loss, and update the model's parameters.

In [None]:
model = DynamicGPT()
# One list per block. The training loop appends that block's mean gate
# activation at every step.
activation_log = [[] for _ in model.blocks]
# nn.Module.to moves the parameters in place and returns the same module,
# so `m` and `model` refer to one object.
m = model.to(device)
n_params = sum(p.numel() for p in m.parameters())
print(n_params / 1e6, 'M parameters')

# AdamW over every parameter (prior networks included), with default weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

17.882945 M parameters


In [None]:
pbar = tqdm(range(max_iters))
for it in pbar:
    # Periodic evaluation. It runs before this step's update, so step 0
    # reports the untrained model.
    if it % eval_interval == 0:
        losses = estimate_loss()
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"\nstep {it}: train {losses['train']:.4f} │ val {losses['val']:.4f}")

    xb, yb = get_batch("train")
    logits, loss, gate_means = model(xb, yb, current_iter=it)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log and display. A single .tolist() copies all per-block gate means to the
    # host at once, instead of one .item() sync per block, twice.
    gate_vals = gate_means.tolist()
    for i, g in enumerate(gate_vals):
        activation_log[i].append(g)
    pbar.set_postfix(**{f"B{i}": f"{g:.2f}" for i, g in enumerate(gate_vals)})

# Final evaluation after the last optimizer step, i.e. of the model that is
# actually used afterwards. It is not run before the last update.
losses = estimate_loss()
tqdm.write(f"\nstep {max_iters}: train {losses['train']:.4f} │ val {losses['val']:.4f}")

  0%|          | 1/6000 [00:02<4:52:31,  2.93s/it, B0=1.00, B1=1.00, B2=1.00, B3=1.00, B4=1.00, B5=1.00]


step 0: train 4.3698 │ val 4.3678


  8%|▊         | 501/6000 [01:30<1:29:51,  1.02it/s, B0=1.00, B1=1.00, B2=1.00, B3=0.98, B4=0.98, B5=0.81]


step 500: train 1.7987 │ val 1.9367


 17%|█▋        | 1001/6000 [02:57<1:21:35,  1.02it/s, B0=1.00, B1=1.00, B2=0.98, B3=0.98, B4=0.97, B5=0.64]


step 1000: train 1.4509 │ val 1.6592


 25%|██▌       | 1501/6000 [04:25<1:13:24,  1.02it/s, B0=1.00, B1=1.00, B2=1.00, B3=0.98, B4=0.98, B5=0.66]


step 1500: train 1.3135 │ val 1.5634


 33%|███▎      | 2001/6000 [05:53<1:05:10,  1.02it/s, B0=1.00, B1=0.98, B2=1.00, B3=0.97, B4=0.92, B5=0.53]


step 2000: train 1.2336 │ val 1.5373


 42%|████▏     | 2501/6000 [07:20<57:01,  1.02it/s, B0=1.00, B1=1.00, B2=1.00, B3=0.95, B4=0.91, B5=0.58]


step 2500: train 1.1825 │ val 1.5502


 50%|█████     | 3001/6000 [08:48<48:58,  1.02it/s, B0=1.00, B1=0.98, B2=1.00, B3=0.91, B4=0.84, B5=0.55]


step 3000: train 1.1226 │ val 1.5215


 58%|█████▊    | 3501/6000 [10:15<40:38,  1.03it/s, B0=1.00, B1=0.91, B2=0.95, B3=0.83, B4=0.72, B5=0.47]


step 3500: train 1.1006 │ val 1.5433


 67%|██████▋   | 4001/6000 [11:42<32:32,  1.02it/s, B0=1.00, B1=0.92, B2=0.97, B3=0.83, B4=0.75, B5=0.58]


step 4000: train 1.0595 │ val 1.5490


 75%|███████▌  | 4501/6000 [13:10<24:24,  1.02it/s, B0=1.00, B1=0.94, B2=0.86, B3=0.88, B4=0.78, B5=0.53]


step 4500: train 1.0217 │ val 1.5724


 83%|████████▎ | 5001/6000 [14:37<16:14,  1.02it/s, B0=1.00, B1=0.97, B2=0.83, B3=0.81, B4=0.72, B5=0.58]


step 5000: train 0.9947 │ val 1.5878


 92%|█████████▏| 5501/6000 [16:04<08:07,  1.02it/s, B0=1.00, B1=0.91, B2=0.78, B3=0.73, B4=0.62, B5=0.52]


step 5500: train 0.9483 │ val 1.6047


100%|██████████| 6000/6000 [17:32<00:00,  5.70it/s, B0=1.00, B1=0.98, B2=0.86, B3=0.78, B4=0.67, B5=0.53]


step 5999: train 0.9053 │ val 1.6380





## Text Generation

After training, we can use our model to generate new text. We start with a single token (a newline character in this case) and let the model predict the next token, which we then feed back into the model to predict the next one, and so on.

In [None]:
# Seed generation with token id 0 (the newline character, the first entry of the
# sorted vocabulary) as a (batch=1, time=1) tensor.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, max_new_tokens=500)  # (1, 1 + 500) token ids
print(decode(generated[0].tolist()))


Thy scant blocks sends, that kill'd my kindness lie.

HASTINGS:
Either stay on my son's patience,
Go in my cousin, I'll pale.

GLOUCESTER:
Marry, that I may be find call'd to flatter me;
For farely singly plain the way,
No shallow thy right sleep; crave it my soul,
Witness majesty a purnting cot;
A vexation as a paranling one,
To meet him dead. All seven or hands,
Are your subjected by her:
Who intending almost I am ost answer,
He should not have slave counsel to me.

QUEEN ELIZABETH:
Gheen his 


## Analysis of Average Activation

In [None]:
# Per-step mean gate activation of every block, as recorded by the training
# loop: one row per training iteration, one column per block.
n_blocks = len(model.blocks)
act_df = pd.DataFrame(activation_log).transpose()
act_df.columns = [f"block_{b}_avg_act" for b in range(n_blocks)]

# Save the raw log to the current working directory.
act_df.to_csv("block_activation_rates.csv", index=False)

for heading, table in (
    ("--- Activation Rate Log ---", act_df.head()),
    ("\n--- Summary Statistics ---", act_df.describe()),
):
    print(heading)
    print(table)

--- Activation Rate Log ---
   block_0_avg_act  block_1_avg_act  block_2_avg_act  block_3_avg_act  \
0         1.000000         1.000000         1.000000         1.000000   
1         1.000000         1.000000         1.000000         1.000000   
2         0.984375         0.984375         0.984375         0.984375   
3         1.000000         1.000000         1.000000         1.000000   
4         1.000000         1.000000         1.000000         1.000000   

   block_4_avg_act  block_5_avg_act  
0         1.000000         1.000000  
1         1.000000         1.000000  
2         0.984375         0.984375  
3         1.000000         1.000000  
4         1.000000         1.000000  

--- Summary Statistics ---
       block_0_avg_act  block_1_avg_act  block_2_avg_act  block_3_avg_act  \
count      6000.000000      6000.000000      6000.000000      6000.000000   
mean          0.998070         0.962573         0.939419         0.891195   
std           0.010731         0.032564       

In [None]:
# Display the per-iteration activation table as this cell's output.
act_df

Unnamed: 0,block_0_avg_act,block_1_avg_act,block_2_avg_act,block_3_avg_act,block_4_avg_act,block_5_avg_act
0,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
1,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
2,0.984375,0.984375,0.984375,0.984375,0.984375,0.984375
3,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
4,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...,...,...
5995,1.000000,0.937500,0.734375,0.671875,0.593750,0.531250
5996,1.000000,0.906250,0.734375,0.703125,0.562500,0.546875
5997,1.000000,0.984375,0.796875,0.765625,0.671875,0.546875
5998,1.000000,0.921875,0.781250,0.718750,0.656250,0.484375
