<a href="https://colab.research.google.com/github/Nebius-Academy/LLM-Engineering-Essentials/blob/main/topic4/4.2_dissecting_an_llm_solutions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM Engineering Essentials by Nebius Academy

Course github: [link](https://github.com/Nebius-Academy/LLM-Engineering-Essentials/tree/main)

The course is in development now, with more materials coming soon.

# 4.2. Dissecting an LLM

# Practice task solutions

## Task 1. Assessing the architecture of **gemma-3-1b-it**

In this task, you'll need to study the architecture of [**gemma-3-1b-it**](https://huggingface.co/google/gemma-3-1b-it) by Google:

- How many layers, attention heads, and query heads per key/value head does it have?
- What is the internal structure of the FFN block?
- How many parameters does the LLM have in self-attention vs in FFN blocks?
- Which type of attention masking (multiplicative or additive) does it use?
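
For the parameter-count question, a back-of-the-envelope estimate can serve as a sanity check before summing over `model.named_parameters()`. The sketch below assumes grouped-query attention with bias-free projections and a gated FFN with three weight matrices (gate, up, down). These are assumptions to verify against the printed model, the helper names are hypothetical, and the numbers are toy values, not Gemma's actual configuration.

```python
def attention_params(d_model, n_heads, n_kv_heads, head_dim):
    """Weights in one self-attention block (no biases, no norms)."""
    q = d_model * n_heads * head_dim          # W_Q
    kv = 2 * d_model * n_kv_heads * head_dim  # W_K and W_V
    o = n_heads * head_dim * d_model          # W_O
    return q + kv + o


def gated_ffn_params(d_model, d_ff):
    """Weights in one gated FFN block: gate, up and down projections."""
    return 3 * d_model * d_ff


# Toy configuration (hypothetical values, for illustration only)
toy = dict(d_model=64, n_heads=4, n_kv_heads=1, head_dim=16, d_ff=256)

attn = attention_params(toy["d_model"], toy["n_heads"], toy["n_kv_heads"], toy["head_dim"])
ffn = gated_ffn_params(toy["d_model"], toy["d_ff"])
queries_per_kv_head = toy["n_heads"] // toy["n_kv_heads"]

print(attn, ffn, queries_per_kv_head)
```

For a real model, the corresponding values can usually be read from `model.config`; in Hugging Face configs they are typically named `hidden_size`, `num_attention_heads`, `num_key_value_heads`, `head_dim` and `intermediate_size`, though it is worth confirming this for the model at hand.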

## Task 2. Key-value cache

Each time an LLM generates a new token $y_t$, it needs to calculate, at each self-attention layer,

$$(q_tk_1^T)v_1 + (q_tk_2^T)v_2 + \ldots + (q_tk_t^T)v_t$$

and for that, you need all the keys $k_i = x_iW_K$ and values $v_i = x_iW_V$, where $x_i$ are the inputs to the self-attention layer. Naively, this would require recalculating all the $k_i$ and $v_i$ *for every newly generated token*, which is a terrible waste of compute.

So, in most situations **keys and values are cached**, and the data structure in which they are stored is known as the **KV-cache**.

<center>
<img src="https://drive.google.com/uc?export=view&id=1X4kCfcpAXNAcHrVsGVpsOk6roHL0AOnT" width=600 />
</center>

Of course, caching has its own downside: the memory needed for the cache grows linearly with the sequence length.
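
To make the linear growth concrete, here is a rough estimate of the cache size. It assumes one key and one value vector per layer, per KV head and per token; the function name and the model dimensions in the loop are hypothetical and not taken from any specific LLM.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Approximate KV-cache size: one K and one V vector per layer, KV head and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value


# Hypothetical model: 32 layers, 8 KV heads of size 128, 16-bit values
for seq_len in (1024, 2048, 4096):
    size_mib = kv_cache_bytes(32, 8, 128, seq_len) / 2**20
    print(f"{seq_len} tokens -> {size_mib:.0f} MiB")
```

Doubling the sequence length doubles the estimate, and so does doubling the batch size.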

### Two stages of LLM inference

Introducing the KV-cache makes it clearer why we need to distinguish between the following two stages of LLM inference:

1. At the **Cache pre-fill** stage, the prompt is processed and the KV-cache for prompt tokens is populated.
2. At the **Autoregressive generation** stage, new tokens are generated, one by one.
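
To get a feel for the savings, one can count how many times a key/value pair is computed in a single attention layer with and without the cache. This is a simplified counting sketch with a hypothetical helper `kv_projections`; it ignores everything except the $x_iW_K$ and $x_iW_V$ projections.

```python
def kv_projections(prompt_len, gen_len, use_cache):
    """Count per-token K/V computations in one attention layer.

    One unit = computing k_i and v_i for a single token.
    """
    if use_cache:
        # Every token's key and value are computed exactly once
        return prompt_len + gen_len
    # Without a cache: the prompt pass, then each generation step
    # recomputes keys and values for the whole sequence so far
    return prompt_len + sum(prompt_len + step for step in range(1, gen_len + 1))


without_cache = kv_projections(100, 50, use_cache=False)
with_cache = kv_projections(100, 50, use_cache=True)
print(without_cache, with_cache, without_cache / with_cache)
```

The count only covers the projections. The attention scores themselves still have to be computed against all cached keys at every step, so the end-to-end speedup is usually smaller than this ratio.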

### Your task

In this task, you'll need to update the implementation of the scaled dot-product attention layer given below, adding KV-cache support to it. Make sure that both key and value vectors ($k = xW_K$ and $v = xW_V$) are only calculated once for each token.

Assume that during LLM inference, the transformer's `forward` method is called by an external `generate` function, which

* At the **cache pre-fill stage** runs `forward` on the whole prompt
* At the **autoregressive generation stage** only sends the current sequence's last token to `forward`.

You can have in mind the following generation process:

```python
import torch


def generate(model, prompt_tokens, max_new_tokens=50, eos_token_id=None):
    """
    Minimalistic generation showing prefill and generation phases.

    Args:
        model: Transformer model with KV-cache support
        prompt_tokens: Input prompt token IDs (batch_size, prompt_len)
        max_new_tokens: Maximum number of tokens to generate
        eos_token_id: End-of-sequence token ID
    """
    # Phase 1: Cache pre-fill - processes entire prompt at once
    model.clear_cache()
    output = model(prompt_tokens, is_causal=True, use_cache=True)

    # Get last token's output as starting point for generation
    next_token_logits = output[:, -1, :]  # (batch_size, vocab_size)
    
    generated_tokens = []
    
    # Phase 2: Autoregressive generation - processes one token at a time
    for i in range(max_new_tokens):
        # Sample next token (simplified - just using argmax)
        next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)
        
        # Check for EOS
        # Simplification: the loop only stops early when all sequences in the batch emit EOS at the same step
        if eos_token_id is not None and (next_token_id == eos_token_id).all():
            break
        
        # Process single new token using cached K,V
        output = model(next_token_id, use_cache=True)
        
        # Get logits for next token prediction
        next_token_logits = output[:, -1, :]
        
        generated_tokens.append(next_token_id)
    
    return torch.cat(generated_tokens, dim=1)  # (batch_size, num_generated)
```




In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention without KV-cache.

    Based on PyTorch's scaled_dot_product_attention implementation,
    wrapped in a class with linear projections for Q, K, V.
    """

    def __init__(self, d_model, n_heads, dropout_p=0.0):
        """
        Args:
            d_model: Dimension of the model (must be divisible by n_heads)
            n_heads: Number of attention heads
            dropout_p: Dropout probability for attention weights
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        self.dropout_p = dropout_p

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _scaled_dot_product_attention(self, query, key, value, attn_mask=None,
                                      is_causal=False, scale=None):
        """
        Compute scaled dot-product attention.
        Based on PyTorch's implementation.

        Args:
            query: (batch_size, n_heads, L, d_k)
            key: (batch_size, n_heads, S, d_k)
            value: (batch_size, n_heads, S, d_k)
            attn_mask: Optional mask (batch_size, 1, L, S) or (1, 1, L, S)
            is_causal: If True, applies causal masking
            scale: Optional scaling factor (defaults to 1/sqrt(d_k))

        Returns:
            attention_output: (batch_size, n_heads, L, d_k)
            attention_weights: (batch_size, n_heads, L, S)
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

        # Initialize attention bias
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

        # Apply causal mask if requested
        if is_causal:
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(query.dtype)

        # Apply custom attention mask if provided
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        # Compute attention scores
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)

        # Apply dropout if in training mode
        if self.dropout_p > 0 and self.training:
            attn_weight = F.dropout(attn_weight, p=self.dropout_p)

        # Compute attention output
        attention_output = attn_weight @ value

        return attention_output, attn_weight

    def forward(self, x, attn_mask=None, is_causal=False):
        """
        Forward pass for scaled dot-product attention.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            attn_mask: Optional attention mask. Can be:
                       - Boolean tensor where True = attend, False = mask
                       - Float tensor with additive attention scores
                       Shape: (batch_size, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len)
            is_causal: If True, applies causal masking (attn_mask must be None)

        Returns:
            output: Attention output of shape (batch_size, seq_len, d_model)
            attention_weights: Attention weights of shape (batch_size, n_heads, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.shape

        # Linear transformations
        Q = self.W_q(x)  # (batch_size, seq_len, d_model)
        K = self.W_k(x)  # (batch_size, seq_len, d_model)
        V = self.W_v(x)  # (batch_size, seq_len, d_model)

        # Reshape for multi-head attention
        # Shape: (batch_size, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Compute scaled dot-product attention
        context, attention_weights = self._scaled_dot_product_attention(
            Q, K, V, attn_mask=attn_mask, is_causal=is_causal
        )

        # Concatenate heads and apply output projection
        # First transpose: (batch_size, seq_len, n_heads, d_k)
        # Then reshape: (batch_size, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)

        return output, attention_weights
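
The attention code above applies masks additively: a boolean mask is turned into a bias of `0` (attend) or `-inf` (masked) that is added to the scores before the softmax. The pure-Python sketch below, with made-up scores, illustrates that this gives the same weights as a multiplicative approach that zeroes out masked probabilities after the softmax and renormalizes. The variable names are illustrative only.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5, 3.0]      # raw attention scores for one query
keep = [True, True, False, False]  # boolean mask: True = attend

# Additive masking: add 0 where attending, -inf where masked, then softmax
bias = [0.0 if k else float("-inf") for k in keep]
additive = softmax([s + b for s, b in zip(scores, bias)])

# Multiplicative masking: softmax first, zero out masked entries, renormalize
probs = softmax(scores)
masked = [p * k for p, k in zip(probs, keep)]
multiplicative = [p / sum(masked) for p in masked]

print(additive)
print(multiplicative)
```

The additive form needs only one pass through the softmax, which is one reason it is the common choice in practice.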

You can use the following code to test your implementation:

In [None]:
import torch
import torch.testing
import time


def test_kv_cache_correctness(attention_module, d_model=512, n_heads=8, seq_len=20, batch_size=2):
    """Test that cached and non-cached modes produce identical results."""
    print("\n=== Test 1: Correctness - Cached vs Non-cached ===")

    # Create input sequence
    x = torch.randn(batch_size, seq_len, d_model)

    # Method 1: Process entire sequence without cache
    attention_module.clear_cache()
    output_no_cache, weights_no_cache = attention_module(x, is_causal=True, use_cache=False)

    # Method 2: Process sequence token-by-token with cache
    attention_module.clear_cache()
    outputs_cached = []

    for i in range(seq_len):
        token = x[:, i:i+1, :]
        output, _ = attention_module(token, use_cache=True)
        outputs_cached.append(output)

    output_cached = torch.cat(outputs_cached, dim=1)

    # Compare
    torch.testing.assert_close(output_no_cache, output_cached, rtol=1e-5, atol=1e-5)
    print("Outputs match for implementations without and with cache!")
    print(f"  Max difference: {(output_no_cache - output_cached).abs().max().item():.2e}")


def test_kv_cache_incremental(attention_module, d_model=512, n_heads=8, batch_size=2):
    """Test incremental cache building."""
    print("\n=== Test 2: Incremental Cache Building ===")

    attention_module.clear_cache()
    cache_sizes = []

    # Add tokens one by one
    for i in range(10):
        token = torch.randn(batch_size, 1, d_model)
        output, _ = attention_module(token, use_cache=True)
        cache_size = attention_module.get_cache_size()
        cache_sizes.append(cache_size)
        print(f"  Step {i+1}: Added 1 token, cache size = {cache_size}")

    # Verify cache grows linearly
    expected_sizes = list(range(1, 11))
    assert cache_sizes == expected_sizes, f"Cache sizes {cache_sizes} != expected {expected_sizes}"
    print("Cache grows linearly, as expected!")


def test_prefill_vs_generation(attention_module, d_model=512, n_heads=8, batch_size=2):
    """Test prefill phase vs generation phase behavior."""
    print("\n=== Test 3: Prefill vs Generation Phases ===")

    # Prefill: Process 10 tokens at once
    attention_module.clear_cache()
    prompt = torch.randn(batch_size, 10, d_model)
    output_prefill, _ = attention_module(prompt, is_causal=True, use_cache=True)
    cache_after_prefill = attention_module.get_cache_size()
    print(f"  Prefill: Processed 10 tokens, cache size = {cache_after_prefill}")

    # Generation: Add 5 more tokens one by one
    for i in range(5):
        new_token = torch.randn(batch_size, 1, d_model)
        output_gen, weights = attention_module(new_token, use_cache=True)
        print(f"  Generation step {i+1}: cache size = {attention_module.get_cache_size()}, "
              f"attention weights shape = {weights.shape}")

    assert attention_module.get_cache_size() == 15, "Cache should have 15 tokens total"
    print("Prefill and generation phases seem to work correctly!")


def test_attention_pattern(attention_module, d_model=512, n_heads=8, batch_size=1):
    """Visualize attention patterns to verify causality."""
    print("\n=== Test 4: Attention Pattern Visualization ===")

    attention_module.clear_cache()
    seq_len = 5
    x = torch.randn(batch_size, seq_len, d_model)

    # Get attention weights
    _, weights = attention_module(x, is_causal=True, use_cache=False)

    # Show attention pattern for first head
    attn_pattern = weights[0, 0, :, :].detach()

    # Verify causality
    for i in range(seq_len):
        for j in range(i+1, seq_len):
            assert attn_pattern[i, j] < 1e-5, f"Token {i} shouldn't attend to future token {j}"
    print("Causal masking verified!")


def test_performance_speedup(AttentionClass, d_model=512, n_heads=8, batch_size=2):
    """Measure actual speedup from KV-cache."""
    print("\n=== Test 5: Performance Speedup ===")

    # Create two instances to avoid cache interference
    attn_no_cache = AttentionClass(d_model, n_heads)
    attn_with_cache = AttentionClass(d_model, n_heads)

    prompt_len = 100
    gen_len = 50

    # Warm up
    x = torch.randn(batch_size, 10, d_model)
    attn_no_cache(x)
    attn_with_cache(x)

    # Test without cache
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    start = time.time()

    full_seq = torch.randn(batch_size, prompt_len + gen_len, d_model)
    for i in range(prompt_len, prompt_len + gen_len):
        seq_so_far = full_seq[:, :i+1, :]
        _ = attn_no_cache(seq_so_far, is_causal=True, use_cache=False)

    time_no_cache = time.time() - start

    # Test with cache
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    start = time.time()

    # Prefill
    prompt = full_seq[:, :prompt_len, :]
    _ = attn_with_cache(prompt, is_causal=True, use_cache=True)

    # Generation
    for i in range(gen_len):
        new_token = full_seq[:, prompt_len + i:prompt_len + i + 1, :]
        _ = attn_with_cache(new_token, use_cache=True)

    time_with_cache = time.time() - start

    speedup = time_no_cache / time_with_cache
    print(f"  Without cache: {time_no_cache:.3f}s")
    print(f"  With cache: {time_with_cache:.3f}s")
    print(f"  Speedup: {speedup:.1f}x")


def test_batch_consistency(attention_module, d_model=512, n_heads=8):
    """Test that different batch sizes work correctly."""
    print("\n=== Test 6: Batch Size Handling ===")

    # Start with batch size 2
    attention_module.clear_cache()
    x1 = torch.randn(2, 5, d_model)
    out1, _ = attention_module(x1, use_cache=True)

    # Try to process with different batch size (should handle gracefully)
    try:
        x2 = torch.randn(3, 1, d_model)
        out2, _ = attention_module(x2, use_cache=True)
        print("Should not accept different batch sizes with existing cache")
    except (RuntimeError, AssertionError):
        print("Correctly prevents batch size mismatch")

    # Clear and try again
    attention_module.clear_cache()
    out2, _ = attention_module(x2, use_cache=True)
    print("Works with new batch size after clearing cache")


# Run all tests
if __name__ == "__main__":
    print("Testing KV-Cache Implementation")
    print("=" * 50)

    # Create attention module
    attention = ScaledDotProductAttentionWithCache(d_model=512, n_heads=8)
    attention.eval()  # Disable dropout

    # Run tests
    test_kv_cache_correctness(attention)
    test_kv_cache_incremental(attention)
    test_prefill_vs_generation(attention)
    test_attention_pattern(attention)
    test_performance_speedup(ScaledDotProductAttentionWithCache)
    test_batch_consistency(attention)

    print("\n" + "=" * 50)
    print("All tests completed!")

Testing KV-Cache Implementation

=== Test 1: Correctness - Cached vs Non-cached ===
Outputs match for implementations without and with cache!
  Max difference: 5.66e-07

=== Test 2: Incremental Cache Building ===
  Step 1: Added 1 token, cache size = 1
  Step 2: Added 1 token, cache size = 2
  Step 3: Added 1 token, cache size = 3
  Step 4: Added 1 token, cache size = 4
  Step 5: Added 1 token, cache size = 5
  Step 6: Added 1 token, cache size = 6
  Step 7: Added 1 token, cache size = 7
  Step 8: Added 1 token, cache size = 8
  Step 9: Added 1 token, cache size = 9
  Step 10: Added 1 token, cache size = 10
Cache grows linearly, as expected!

=== Test 3: Prefill vs Generation Phases ===
  Prefill: Processed 10 tokens, cache size = 10
  Generation step 1: cache size = 11, attention weights shape = torch.Size([2, 8, 1, 11])
  Generation step 2: cache size = 12, attention weights shape = torch.Size([2, 8, 1, 12])
  Generation step 3: cache size = 13, attention weights shape = torch.Size([

**Solution**

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttentionWithCache(nn.Module):
    """
    Scaled Dot-Product Attention with KV-cache support.

    Supports two phases:
    1. Prefill: Process multiple tokens at once (e.g., prompt processing)
    2. Generation: Process one token at a time, using cached K,V from previous tokens
    """

    def __init__(self, d_model, n_heads, dropout_p=0.0):
        """
        Args:
            d_model: Dimension of the model (must be divisible by n_heads)
            n_heads: Number of attention heads
            dropout_p: Dropout probability for attention weights
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout_p = dropout_p

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # KV-cache: stores past keys and values
        self.cache_k = None
        self.cache_v = None

    def _scaled_dot_product_attention(self, query, key, value, attn_mask=None,
                                      is_causal=False, scale=None):
        """
        Compute scaled dot-product attention (unchanged from base implementation).
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

        # Initialize attention bias
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

        # Apply causal mask if requested
        if is_causal:
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(query.dtype)

        # Apply custom attention mask if provided
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        # Compute attention scores
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)

        # Apply dropout if in training mode
        if self.dropout_p > 0 and self.training:
            attn_weight = F.dropout(attn_weight, p=self.dropout_p)

        # Compute attention output
        attention_output = attn_weight @ value

        return attention_output, attn_weight

    def forward(self, x, attn_mask=None, is_causal=False, use_cache=False, cache_position=None):
        """
        Forward pass with KV-cache support.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
               - During prefill: seq_len can be > 1 (processing prompt)
               - During generation: seq_len = 1 (processing one new token)
            attn_mask: Optional attention mask
            is_causal: If True, applies causal masking
            use_cache: If True, uses and updates KV-cache
            cache_position: Optional tensor with the position in the sequence of each token,
                          shape (seq_len,). Not used by this implementation.

        Returns:
            output: Attention output of shape (batch_size, seq_len, d_model)
            attention_weights: Attention weights of shape (batch_size, n_heads, seq_len, total_seq_len)
        """
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V projections
        Q = self.W_q(x)  # Always compute Q for current tokens
        K = self.W_k(x)  # Compute K for current tokens
        V = self.W_v(x)  # Compute V for current tokens

        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        if use_cache:
            # Phase detection (prefill vs generation) based on cache state and sequence length
            if self.cache_k is None:
                # Initialize cache with current K, V
                self.cache_k = K
                self.cache_v = V
            else:
                # Concatenate new K, V with cached K, V
                self.cache_k = torch.cat([self.cache_k, K], dim=-2)
                self.cache_v = torch.cat([self.cache_v, V], dim=-2)

            # Use full cached K, V for attention computation
            K_for_attn = self.cache_k
            V_for_attn = self.cache_v

            # Update mask dimensions if needed
            # Q has shape (batch, heads, seq_len, d_k)
            # K_for_attn has shape (batch, heads, total_cached_len, d_k)
            if is_causal and seq_len == 1:
                # During generation there is a single query token and no future tokens,
                # so it may attend to every cached position and no mask is needed.
                # Setting is_causal to False is also required for correctness here:
                # the mask built by tril(diagonal=0) would hide all cached keys except the first
                is_causal = False
        else:
            # No cache - use current K, V as-is
            K_for_attn = K
            V_for_attn = V

        # Compute attention
        context, attention_weights = self._scaled_dot_product_attention(
            Q, K_for_attn, V_for_attn, attn_mask=attn_mask, is_causal=is_causal
        )

        # Reshape and project output
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)

        return output, attention_weights

    def clear_cache(self):
        """Clear the KV-cache."""
        self.cache_k = None
        self.cache_v = None

    def get_cache_size(self):
        """Get current cache size."""
        if self.cache_k is None:
            return 0
        return self.cache_k.size(-2)  # Return sequence length dimension
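
One caveat about the solution above: `_scaled_dot_product_attention` builds its causal mask with `tril(diagonal=0)`, which aligns the diagonal with the top-left corner of the `(L, S)` score matrix. That is right when the cache is empty (`L == S`), and the single-token case sidesteps masking entirely, but a multi-token call on a non-empty cache (`1 < L < S`) would need the diagonal shifted by `S - L`. The sketch below uses a hypothetical helper `causal_mask` to show the difference in plain Python.

```python
def causal_mask(L, S, offset):
    """Boolean (L, S) mask, True = attend. Query i sees keys j <= i + offset."""
    return [[j <= i + offset for j in range(S)] for i in range(L)]


L, S = 2, 5  # 2 new tokens, 3 tokens already in the cache

top_left = causal_mask(L, S, offset=0)          # what tril(diagonal=0) produces
bottom_right = causal_mask(L, S, offset=S - L)  # what the cached case needs

for name, mask in [("top-left", top_left), ("bottom-right", bottom_right)]:
    print(name)
    for row in mask:
        print("  ", [int(v) for v in row])
```

In PyTorch, the shifted version corresponds to `tril(diagonal=S - L)`. The tests in this notebook never call `forward` with several tokens on a non-empty cache, so they do not exercise this case.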

## Task 3. Interpreting LLMs with LogitLens

Logit Lens is an interpretation technique introduced in [this post](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens). The idea is the following. Imagine that we predict the continuation of the phrase "IPhone was developed by". We naturally expect to see "Apple", but we're also curious about the "thought process" of the LLM, so we **feed the outputs of intermediate layers (intermediate transformer blocks) to the classification head** to see *what the LLM would output if we cut its "thought process" short*. The general trend, as one moves from earlier to later layers, is

- "nonsense / not interpretable" (sometimes, in very early layers) -->
- "shallow guesses" (words that are the right part of speech / register / etc.) -->
- "better guesses" near the end.

However, it's not always like that, of course.

The author of the Logit Lens also created visualization tools and published a [jupyter notebook demo](https://colab.research.google.com/drive/1MjdfK2srcerLrAJDRaJQKO0sUiZ-hQtA?usp=sharing) with cool pictures, but in this task you'll need to reproduce the Logit Lens technique on your own.

**Your task**. Write a function

```
logit_lens(model, input_sentence, top_k)
```

that for each transformer block returns a dictionary

```
{
    'top_tokens' : [
        sorted list of top_k tokens,
        from most probable to least probable,
        according to the classification head
        ],
    'top_token_logits' : [logits of these tokens]
}
```

You can either use PyTorch hooks or just `model(**encoded_input, output_hidden_states=True)`.
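
For reference, the hook-based route could look roughly like this. The sketch uses a toy stack of linear layers instead of a real LLM, so `toy_model`, `captured` and `save_output` are illustrative names; only `register_forward_hook` and the handle's `remove` method are actual PyTorch API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a stack of transformer blocks (not a real LLM)
toy_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))

captured = []  # one entry per layer, filled in during the forward pass


def save_output(module, inputs, output):
    """Forward hook: store a detached copy of the layer's output."""
    captured.append(output.detach())


# Register the hook on every layer and keep the handles for cleanup
handles = [layer.register_forward_hook(save_output) for layer in toy_model]

x = torch.randn(1, 4, 8)  # (batch_size, seq_len, hidden_size)
with torch.no_grad():
    final = toy_model(x)

# Remove the hooks so later forward passes are not affected
for handle in handles:
    handle.remove()

print(len(captured), captured[-1].shape)
```

With a Hugging Face model, transformer blocks often return tuples, so a hook would typically need to read `output[0]`; that detail depends on the model and the library version.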


Here is how it should work:

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

And here comes the logit lens:

In [None]:
import numpy as np

def logit_lens(model, input_sentence, top_k=5):
    # Tokenize the inputs: this turns text into a list of token indices
    input_ids = tokenizer(input_sentence, return_tensors='pt').to(model.device)

    # This runs the forward pass, calculating all the hidden states thanks to the output_hidden_states=True
    model_output = model(**input_ids,  output_hidden_states=True)
    hidden_states = model_output.hidden_states

    # This is the LLM's final layer: it takes a transformer output vector
    # and predicts a logit for each token in the vocabulary
    model_head = model.lm_head

    result = []

    for layer in hidden_states:

        """
        YOUR TASK HERE IS AS FOLLOWS:

        The layer's output consists of a number of hidden state vectors: one vector for each input token
        You need to:
        1. take the last of the outputs and to apply the lm_head to it
        2. find k tokens with top probabilities. The function torch.topk might help
            Just check what torch.topk outputs
        3. return:
        output = {
            "top_tokens": the array of top tokens. Don't forget to decode them with tokenizer.decode
            "top_token_logits": the array of their logits. Don't forget to detach the tensor and to convert it to numpy!
        }
        Note that we need to decode each token id separately to get a list of strings
        """
        # <YOUR CODE HERE>
        result.append(output)
    return result


In [None]:
result = logit_lens(model, "IPhone was developed by", top_k=5)

In [None]:
result[-2:]

[{'top_tokens': [' ', ',', ' a', ' (', '\n'],
  'top_token_logits': array([77.5 , 66.  , 64.5 , 60.  , 58.75], dtype=float32)},
 {'top_tokens': [' Microsoft', ' Apple', ' a', ' the', ' an'],
  'top_token_logits': array([18.25 , 18.125, 16.75 , 16.5  , 14.875], dtype=float32)}]

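As you can see in the output above, company names only emerge at the final layer, where " Microsoft" and " Apple" have nearly equal logits (18.25 vs. 18.125), while the pre-final layer favors whitespace and punctuation.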

Now, let's use Logit Lens to investigate how transformers deal with redefinition.

We'll use Logit Lens on the sentence

```
"In this text the word IPhone means Windows operating system. IPhone was developed by"
```

Run the following cell and look at the most probable tokens for all layers. A good LLM knows through *memorization* that IPhone was developed by Apple. However, *in-context learning* pushes it to output Microsoft. Check in which layers the most probable token is Microsoft and in which it is Apple.

In [None]:
result = logit_lens(model, "In this text the word IPhone means Windows operating system. IPhone was developed by", top_k=5)
result

Try more prompts and other models. Have fun trying to interpret whatever is inside LLMs :)

**Solution**



In [None]:
import numpy as np
import torch

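@torch.no_grad()
def logit_lens(model, input_sentence, top_k=5):
    # Tokenize and move the inputs to the model's device
    # (`tokenizer` is the global object created in the loading cell above)
    input_ids = tokenizer(input_sentence, return_tensors='pt').to(device=model.device)

    # Forward pass that also returns the embedding output
    # and the hidden states after every transformer block
    model_output = model(**input_ids, output_hidden_states=True)
    hidden_states = model_output.hidden_states
    model_head = model.lm_head

    result = []
    for layer in hidden_states:
        # Hidden state of the last token of the first (and only) sequence
        last_token_layer = layer[0][-1]
        # Project it onto the vocabulary with the LM head
        layer_logits = model_head(last_token_layer)
        # Keep the top_k logits and the corresponding token ids
        top_values = torch.topk(layer_logits, k=top_k)
        indices = top_values.indices
        values = top_values.values
        output = {
            "top_tokens": [tokenizer.decode(ind) for ind in indices],
            "top_token_logits": values.to(torch.float32).detach().cpu().numpy()
        }
        result.append(output)
    return result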
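
A possible refinement, offered as a sketch and not as part of the reference solution: the original Logit Lens post applies the model's final normalization layer before the output head. In many Hugging Face decoder implementations the last entry of `hidden_states` already has the final norm applied while the intermediate entries do not, which may be one reason why intermediate-layer logits look odd; this is worth checking for the model at hand. The helper `lens_logits` below is hypothetical and is demonstrated on toy modules with made-up sizes.

```python
import torch
import torch.nn as nn


def lens_logits(hidden_state, lm_head, final_norm=None):
    """Project the last position of one layer's hidden state onto the vocabulary.

    hidden_state: (batch_size, seq_len, hidden_size)
    final_norm:   optional normalization module applied before lm_head
    """
    last = hidden_state[0, -1]
    if final_norm is not None:
        last = final_norm(last)
    return lm_head(last)


# Toy stand-ins (hypothetical sizes): hidden_size=16, vocab_size=50
torch.manual_seed(0)
toy_norm = nn.LayerNorm(16)
toy_head = nn.Linear(16, 50, bias=False)
toy_hidden = 100 * torch.randn(1, 7, 16)  # deliberately large activations

with torch.no_grad():
    raw = lens_logits(toy_hidden, toy_head)
    normed = lens_logits(toy_hidden, toy_head, final_norm=toy_norm)

top = torch.topk(normed, k=5)
print(raw.abs().max().item(), normed.abs().max().item())
print(top.indices.tolist())
```

With the Qwen model loaded above, the final norm would presumably be `model.model.norm`; treat that attribute path as an assumption and confirm it with `print(model)`.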


Now, let's run several experiments. In the Microsoft vs Apple example, Apple is switched to Microsoft only at the final layer.

By the way, notice the mysterious pre-final layer. In this model, it frequently ranks spaces, commas, etc. as the most probable tokens. That's curious.

In [None]:
result = logit_lens(model, "In this text the word IPhone means Windows operating system. IPhone was developed by", top_k=5)

In [None]:
result

[{'top_tokens': [' by', 'by', ' By', 'By', '_by'],
  'top_token_logits': array([0.96484375, 0.7109375 , 0.5703125 , 0.55078125, 0.546875  ],
        dtype=float32)},
 {'top_tokens': [' virtue', '-products', ' dint', ' leaps', 'means'],
  'top_token_logits': array([4.      , 3.4375  , 2.953125, 2.59375 , 2.578125], dtype=float32)},
 {'top_tokens': [' virtue', '-products', ' dint', 'oncé', ' leaps'],
  'top_token_logits': array([3.78125 , 3.3125  , 2.9375  , 2.734375, 2.640625], dtype=float32)},
 {'top_tokens': [' virtue', '-products', ' dint', '/by', '经验'],
  'top_token_logits': array([4.      , 3.296875, 3.015625, 2.796875, 2.78125 ], dtype=float32)},
 {'top_tokens': [' virtue', '-products', '/by', ' dint', 'rne'],
  'top_token_logits': array([4.125   , 3.40625 , 3.140625, 2.984375, 2.8125  ], dtype=float32)},
 {'top_tokens': [' virtue', ' dint', '-products', '/by', 'rne'],
  'top_token_logits': array([4.3125  , 3.40625 , 3.328125, 3.078125, 3.078125], dtype=float32)},
 {'top_tokens': 

Now, let's try some arithmetic. A curious thing here is that '三' is the Chinese character for 3.

In [None]:
result = logit_lens(model, "12 + 23 = ", top_k=5)
result

[{'top_tokens': [' ', ',', '1', ' (', '.'],
  'top_token_logits': array([1.6953125 , 0.96484375, 0.94140625, 0.93359375, 0.88671875],
        dtype=float32)},
 {'top_tokens': ['иру', '<![', 'irc', 'apt', 'ESC'],
  'top_token_logits': array([1.6875   , 1.5      , 1.359375 , 1.3515625, 1.3359375],
        dtype=float32)},
 {'top_tokens': ['】,【', '<![', 'ım', '();)', 'иру'],
  'top_token_logits': array([2.34375 , 2.171875, 2.171875, 2.09375 , 2.0625  ], dtype=float32)},
 {'top_tokens': ['<![', '】,【', '__()', 'ım', 'фр'],
  'top_token_logits': array([2.46875 , 2.4375  , 2.34375 , 2.34375 , 2.328125], dtype=float32)},
 {'top_tokens': ['】,【', '一个多', '\')."', '__()', "')}"],
  'top_token_logits': array([2.671875, 2.375   , 2.375   , 2.328125, 2.328125], dtype=float32)},
 {'top_tokens': ['<![', 'TAIL', '一个多', 'TTY', "')}"],
  'top_token_logits': array([2.765625, 2.609375, 2.59375 , 2.59375 , 2.578125], dtype=float32)},
 {'top_tokens': ["')}", '******/', '(___', '或多', '__()'],
  'top_token_logi

Let's also check what's Qwen's favourite library to be imported after `numpy`. As you can see, `matplotlib` and `pandas` only replace `random` at the top positions at the final layer.

In [None]:
result = logit_lens(model, "import numpy as np\nimport ", top_k=5)
result

[{'top_tokens': [' ', ',', '1', ' (', '.'],
  'top_token_logits': array([1.6953125 , 0.96484375, 0.94140625, 0.93359375, 0.88671875],
        dtype=float32)},
 {'top_tokens': ['ESC', 'иру', 'ici', 'once', 'или'],
  'top_token_logits': array([1.6875   , 1.5234375, 1.484375 , 1.4609375, 1.4296875],
        dtype=float32)},
 {'top_tokens': ['ESC', '✦', 'rtc', 'или', '�'],
  'top_token_logits': array([2.234375 , 1.9765625, 1.9609375, 1.8828125, 1.828125 ],
        dtype=float32)},
 {'top_tokens': ['ESC', 'Locator', '(`/', 'opens', 'ASI'],
  'top_token_logits': array([2.90625 , 2.25    , 2.234375, 2.15625 , 2.140625], dtype=float32)},
 {'top_tokens': ['ESC', 'Locator', '(`/', 'SystemService', 'opens'],
  'top_token_logits': array([2.859375, 2.734375, 2.578125, 2.515625, 2.515625], dtype=float32)},
 {'top_tokens': ['ibe', 'ESC', 'SystemService', 'roman', 'eso'],
  'top_token_logits': array([3.3125  , 3.078125, 3.03125 , 2.90625 , 2.875   ], dtype=float32)},
 {'top_tokens': ['中国大陆', '官宣', '信息

## Task 4. Steering LLM generation via Activation Interventions

Imagine you want to change the style of your LLM's output. We already know how to do this using clever prompting and few-shot examples - but there's another approach: directly modifying the model's internal activations.

The method you'll explore in this task was, to the best of our knowledge, first introduced in [this post at LessWrong](https://www.lesswrong.com/posts/5spBue2z2tw4JuDCx/steering-gpt-2-xl-by-adding-an-activation-vector). Conceptually, however, it draws inspiration from techniques used in image editing via latent space manipulation.

The core idea is that some abstract concepts - such as a particular writing style - may be represented as directions in the latent space of one of the model's hidden layers. If you can identify the right layer and find the appropriate direction vector $d$, then you can steer the model's output by shifting a hidden state $x$ along this vector:

* $x \mapsto x + d$ reinforces the concept.

* $x \mapsto x - d$ suppresses the concept, or enhances its opposite.

For example, if $d =$ (poetic style - bureaucratic style), then adding $d$ to the hidden state may make the output more poetic, while subtracting it may push the model toward bureaucratic language.

But where to get the vector $d$? In the most basic case, it can be obtained from just two prompts:

* <font color="red">Negative prompt</font> embodying the opposite concept and
* <font color="blue">Positive prompt</font> epitomizing the concept we are interested in.

Now, if we want to manipulate the outputs of the $L$-th layer, we do the following:

1. We find $\color{red}{x_-}$ and $\color{blue}{x_+}$ - outputs of the $L$-th layer for the *final* tokens of <font color="red">Negative prompt</font> and <font color="blue">Positive prompt</font>
2. We take $d = \color{blue}{x_+} - \color{red}{x_-}$ as the concept's direction.

Why might this work? The final token of a prompt attends to all the previous tokens, so in a sense it aggregates whatever is happening in the prompt at the $L$-th layer. Now, if the prompts are expressive enough, the difference $d$ might capture the concept gap.
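As a toy illustration of the recipe above, here is a minimal sketch in plain Python. The three-dimensional vectors are invented stand-ins for real hidden states, and the helper names (`steering_direction`, `shift`, `projection`) are hypothetical. It only shows that moving $x$ along $d$ raises its projection onto $d$, while moving against $d$ lowers it.

```python
import math
from typing import List


def steering_direction(x_pos: List[float], x_neg: List[float],
                       strength: float = 1.0) -> List[float]:
    """d = strength * (x_plus - x_minus), computed element-wise."""
    return [strength * (p - n) for p, n in zip(x_pos, x_neg)]


def shift(x: List[float], d: List[float], sign: float = 1.0) -> List[float]:
    """Move x along d (sign=+1) or against it (sign=-1)."""
    return [xi + sign * di for xi, di in zip(x, d)]


def projection(x: List[float], d: List[float]) -> float:
    """Scalar projection of x onto the direction of d."""
    norm = math.sqrt(sum(di * di for di in d))
    return sum(xi * di for xi, di in zip(x, d)) / norm


# Invented "last-token hidden states" of the positive and negative prompts
x_pos = [1.0, 2.0, 0.5]
x_neg = [0.0, 2.0, 1.5]
d = steering_direction(x_pos, x_neg)

# An arbitrary hidden state to steer
x = [0.3, -0.2, 0.1]

print("original:  ", projection(x, d))
print("reinforced:", projection(shift(x, d), d))
print("suppressed:", projection(shift(x, d, sign=-1.0), d))
```

In a real model the same arithmetic is applied to every position of the layer's output tensor, which is what the hook in this task does.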

**Your task** is to complete the `LLMSteering` class code below, implementing this simple LLM steering technique.

Some notes about the implementation:

* You'll need to use **PyTorch hooks** for that.
* The `strength` parameter is the coefficient $\mathbf{s}$ in $d = \mathbf{s}(x_+ - x_-)$.
* The **model** is initialized externally and can be used in many `LLMSteering` instances. And rightfully so - if you initialize a new model in each `LLMSteering`, you'll soon run out of GPU memory.
* Don't forget that crashed hooks don't get removed by `hook.remove()`, only by `layer._forward_hooks.clear()`. Since you don't want to corrupt the model, we suggest having some emergency measures to get rid of failed hooks.
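For readers new to hooks, here is a minimal sketch of the mechanism on a toy module. It assumes a plain `nn.Linear` as a stand-in for a decoder layer; the names `add_vector_hook` and `shift` are illustrative. A forward hook that returns a value replaces the module's output, and the handle returned by `register_forward_hook` removes the hook again.

```python
import functools

import torch
from torch import nn

torch.manual_seed(0)

layer = nn.Linear(4, 4)    # toy stand-in for a decoder layer
x = torch.randn(1, 3, 4)   # [batch, seq_len, hidden_dim]
shift = torch.ones(4)      # toy "steering vector"


def add_vector_hook(vector, module, inp, out):
    # Returning a value from a forward hook replaces the module's output;
    # `vector` broadcasts over the batch and sequence dimensions.
    return out + vector


with torch.no_grad():
    baseline = layer(x)

    # functools.partial binds the vector, leaving the (module, input, output) signature
    handle = layer.register_forward_hook(functools.partial(add_vector_hook, shift))
    try:
        steered = layer(x)
    finally:
        handle.remove()    # always detach the hook, even if the forward pass fails

    restored = layer(x)

print(torch.allclose(steered, baseline + shift))
print(torch.allclose(restored, baseline))
```

Wrapping the hooked call in `try`/`finally` is one simple way to make sure a shared model is not left with a stale hook.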

When you finish with the implementation, experiment with steering:

1. Try several prompt pairs for different concepts or styles: literary vs bureaucratic, polite vs angry etc.
2. Try larger and smaller values of `strength`.
3. Try different layers; find the best one for your concept.

  You may try steering at several layers at once, but this is unlikely to be a good idea: the effects at several layers will interfere unpredictably.

4. (Optional) You can try steering for two concepts at once, computing directions $d_1$ and $d_2$ for them independently and then adding them to the hidden state. However, simple $x\mapsto x + d_1 + d_2$ might work poorly. It's better to use [SLERP](https://en.wikipedia.org/wiki/Slerp) for that.

  And, of course, you'll need to update the class.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, functools


model_name = "Qwen/Qwen2.5-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

Here's the template:

In [None]:
import torch
import functools
from typing import Optional

class LLMSteering:
    """
    Simple activation steering for LLMs
    """

    def __init__(self, model, tokenizer):
        """
        Initialize the steering controller.

        Args:
            model: The LLM
            tokenizer: The tokenizer
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device

        # Current steering state
        self.hook = None
        self.current_layer = None

    def _get_hidden_vector(self, text: str, layer: int) -> torch.Tensor:
        """Get the last token's hidden state from specified layer."""
        # <YOUR CODE HERE>
        return   # [hidden_dim]

    def _steering_hook(self, steering_vector: torch.Tensor, module, inp, out):
        """
        Forward hook that adds steering vector to hidden states.

        PyTorch forward hook signature: hook(module, input, output) -> None or modified output
        """
        # <YOUR CODE HERE>

    def set_steering(self, positive_prompt: str, negative_prompt: str, layer: int, strength: float = 1.0):
        """
        Set up steering with given prompts.

        Args:
            positive_prompt: Target behavior prompt
            negative_prompt: Avoid behavior prompt
            layer: Layer to apply steering (0-indexed)
            strength: Multiplier for steering strength
        """
        # <YOUR CODE HERE>
        # Don't forget to remove the existing hook

    def remove_steering(self):
        """Removes steering hook"""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
            self.current_layer = None

    def force_clear_all_hooks(self):
        """Force-clear every forward hook on the current layer (emergency cleanup)."""
        if self.current_layer is not None:
            self.model.model.layers[self.current_layer]._forward_hooks.clear()
            print("Cleared hooks")

        self.hook = None
        self.current_layer = None

    def generate(self, prompt: str, max_new_tokens: int = 50, **kwargs) -> str:
        """Generate text with current steering settings."""

        # <YOUR CODE HERE>

    def __del__(self):
        """Cleanup when object is destroyed."""
        self.remove_steering()



**Solution**.

In [None]:
import torch
import functools
from typing import Optional

class LLMSteering:
    """
    Simple activation steering for LLMs.
    """

    def __init__(self, model, tokenizer):
        """
        Initialize the steering controller.

        Args:
            model: The LLM
            tokenizer: The tokenizer
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device

        # Current steering state
        self.hook = None
        self.current_layer = None

    def _get_hidden_vector(self, text: str, layer: int) -> torch.Tensor:
        """Get the last token's hidden state from specified layer."""
        inp = self.tokenizer(text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inp, use_cache=False, output_hidden_states=True)
        # Note: hidden_states[0] is the embedding output, so hidden_states[layer]
        # is the input to decoder layer `layer`; use layer + 1 for that layer's output.
        return out.hidden_states[layer][0, -1]  # [hidden_dim]

    def _steering_hook(self, steering_vector: torch.Tensor, module, inp, out):
        """
        Forward hook that adds steering vector to hidden states.

        PyTorch forward hook signature: hook(module, input, output) -> None or modified output
        """
        try:
            if isinstance(out, tuple):
                # Handle tuple output (hidden_states, attention_weights, ...)
                hidden_states = out[0]
                steered = hidden_states + steering_vector.unsqueeze(0).unsqueeze(0)
                return (steered,) + out[1:]
            else:
                # Handle tensor output
                return out + steering_vector.unsqueeze(0).unsqueeze(0)
        except Exception as e:
            # Hook crashed - immediately clear to prevent model corruption
            print(f"Hook crashed: {e}")
            print("Emergency hook clearing to prevent model corruption...")
            try:
                module._forward_hooks.clear()
                self.hook = None
                self.current_layer = None
            except Exception:
                pass
            # Return original output to keep inference working
            return out

    def set_steering(self, positive_prompt: str, negative_prompt: str, layer: int, strength: float = 1.0):
        """
        Set up steering with given prompts.

        Args:
            positive_prompt: Target behavior prompt
            negative_prompt: Avoid behavior prompt
            layer: Layer to apply steering (0-indexed)
            strength: Multiplier for steering strength
        """
        # Remove any existing steering
        self.remove_steering()

        # Compute steering vector
        h_pos = self._get_hidden_vector(positive_prompt, layer)
        h_neg = self._get_hidden_vector(negative_prompt, layer)
        steering_vec = (h_pos - h_neg) * strength

        # Register hook
        hook_fn = functools.partial(self._steering_hook, steering_vec)
        self.hook = self.model.model.layers[layer].register_forward_hook(hook_fn)
        self.current_layer = layer

        print(f"Steering active on layer {layer}, strength {strength}")

    def remove_steering(self):
        """Remove the steering hook if one is registered."""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
            self.current_layer = None

    def force_clear_all_hooks(self):
        """Force-clear every forward hook on the current layer (emergency cleanup)."""
        if self.current_layer is not None:
            self.model.model.layers[self.current_layer]._forward_hooks.clear()
            print("Cleared hooks")

        self.hook = None
        self.current_layer = None

    def generate(self, prompt: str, max_new_tokens: int = 50, **kwargs) -> str:
        """Generate text with current steering settings."""

        if not self.hook:
            print("No hook defined")
            return None

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Default generation params
        generation_kwargs = {
            'do_sample': True,
            'temperature': 0.7,
            'top_p': 0.9,
            'pad_token_id': self.tokenizer.eos_token_id,
            **kwargs
        }

        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    **generation_kwargs
                )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            # If generation fails and we have steering active, try clearing hooks
            if self.hook:
                print(f"Generation failed with steering active: {e}")
                print("Attempting emergency hook cleanup...")
                self.force_clear_all_hooks()

                return None

    def __del__(self):
        """Cleanup when object is destroyed."""
        self.remove_steering()



Let's try poetic style vs Wikipedia style. We used the beginning of Beowulf from [this source](https://www.poetryfoundation.org/poems/50114/beowulf-modern-english-translation) and [this Wikipedia article](https://en.wikipedia.org/wiki/Fenoldopam).

I wouldn't say that the LLM grasped the poetry of Beowulf when steered at this particular layer :( But at least we made it excited.

In [None]:
SEED_A = """LO, praise of the prowess of people-kings
of spear-armed Danes, in days long sped,
we have heard, and what honor the athelings won!
Oft Scyld the Scefing from squadroned foes,
from many a tribe, the mead-bench tore,
awing the earls. Since erst he lay
friendless, a foundling, fate repaid him:
for he waxed under welkin, in wealth he throve,
till before him the folk, both far and near,
who house by the whale-path, heard his mandate,
gave him gifts: a good king he!"""

SEED_B = """Fenoldopam is used as an antihypertensive agent postoperatively, and also intravenously (IV) to treat a hypertensive crisis.[4] Since fenoldopam is an intravenous agent with minimal adrenergic effects that improves renal perfusion, in theory it could be beneficial in hypertensive patients with concomitant chronic kidney disease.[5] It can cause reflex tachycardia, but it is dependent on the infusion of the drug. """


steerer = LLMSteering(model, tokenizer)

steerer.set_steering(
    positive_prompt=SEED_A,
    negative_prompt=SEED_B,
    layer=15,
)

result = steerer.generate("In 2025, British scientists discovered")
print(result)


steerer.remove_steering()

Steering active on layer 15, strength 1.0
In 2025, British scientists discovered that the earth was round; the first to know.
The first to see!
What? What's this?
 The sun shines on
 This!—
 Who would dare touch
 Such a sight!
 What will it be
 When it's seen!



And now, vice versa, with a somewhat smaller strength:

In [None]:
steerer.set_steering(
    positive_prompt=SEED_B,
    negative_prompt=SEED_A,
    layer=15,
    strength=0.8
)

result = steerer.generate("Lo, the spears sang loud at the breaking of dawn")
print(result)


steerer.remove_steering()

Steering active on layer 15, strength 0.8
Lo, the spears sang loud at the breaking of dawn, a group of 15 soldiers were randomly deployed in a rectangular grid.3 squads of soldiers are sent to different locations by their respective officers. Each soldier in the first squad is positioned as follows: -2,4 on the left side and


Now, when `strength` is too large, the outputs deteriorate:

In [None]:
steerer.set_steering(
    positive_prompt=SEED_A,
    negative_prompt=SEED_B,
    layer=15,
    strength=1.5
)

result = steerer.generate("In 2025, British scientists discovered")
print(result)


steerer.remove_steering()

Steering active on layer 15, strength 1.5
In 2025, British scientists discovered; a man
 The
 who
 who
 so bold!
 —
 (—)
 So
 so— (—)
 Who?
 —
 So
 (—)
 — (—)
 —
 —
—
 (—)—
 (


Let's also experiment with different layers. **Qwen2.5-3B** has 36 layers, so it's reasonable to check layers from $\frac14\cdot36=9$ to $\frac34\cdot36=27$, but we'll check all the layers, just out of curiosity.
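If you'd rather restrict the sweep to the middle of the network, the range above can be computed as in the following small sketch. The helper `candidate_layers` and its default fractions are illustrative, taken from the quarter and three-quarter marks mentioned above.

```python
from typing import List


def candidate_layers(n_layers: int, start_frac: float = 0.25,
                     end_frac: float = 0.75) -> List[int]:
    """Layer indices from start_frac * n_layers to end_frac * n_layers, inclusive."""
    start = int(n_layers * start_frac)
    end = int(n_layers * end_frac)
    return list(range(start, end + 1))


layers = candidate_layers(36)
print(layers[0], layers[-1], len(layers))  # 9 27 19
```

In the notebook, `n_layers` would be `len(model.model.layers)`.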

In [None]:
for n_layer in range(len(model.model.layers)):

    steerer.set_steering(
        positive_prompt=SEED_A,
        negative_prompt=SEED_B,
        layer=n_layer,
        strength=0.9
    )

    print(f"\n\n=== At layer {n_layer} ===\n\n")
    result = steerer.generate("In 2025, British scientists discovered")
    print(result)

    steerer.remove_steering()

Steering active on layer 0, strength 0.9


=== At layer 0 ===


In 2025, British scientists discovered a new element X. This element has 10 protons in its nucleus and is located in the fourth period of the periodic table. Based on this information, how many neutrons does element X have?
To determine the number of neutrons in
Steering active on layer 1, strength 0.9


=== At layer 1 ===


In 2025, British scientists discovered a new element X. They named the element by using the first letter of their names. Now that's the fact! The scientists are:Alan Mathers, Graham Physics, and Marjorie Mathematics! What an amazing discovery!What a wonderful day
Steering active on layer 2, strength 0.9


=== At layer 2 ===


In 2025, British scientists discovered a new planet, which they named "NewBritannia". After exploring this planet, they found an advanced civilization called "The Hexagon" whose technology was far beyond anything humans could ever imagine. They were surprised to find that the citi