In [5]:
from torch.utils.data import DataLoader, Dataset
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import gc

In [6]:
import nnsight
import torch
import json
from transformers import AutoTokenizer
model_name = "Qwen/Qwen3-4B"

model = nnsight.LanguageModel(model_name, device_map="auto", dtype=torch.bfloat16)

In [7]:
from einops import einsum

In [8]:
probes_weights = torch.load('/workspace/llm-progress-monitor/qwen3_4b_weight_tensor.pt').to(dtype=torch.float32)

In [9]:
# Get PCA components 1, 2, 4 using sklearn PCA
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

In [10]:
pca = PCA(n_components=3)
pca.fit(probes_weights.detach().cpu().numpy())
# Select components 1, 2, 4 (0-indexed: 0, 1, 3)
selected_components = torch.tensor(pca.components_, dtype=torch.float32).to('cuda', dtype=torch.bfloat16)  # Shape: [3, 2560]
print(f"Selected PCA components shape: {selected_components.shape}")

Selected PCA components shape: torch.Size([3, 2560])


In [11]:
def get_ema_coords(coords, alpha=0.99):
    given_alpha = alpha
    coords_list = coords.tolist()
    
    ema_coords = []
    ema_coords_list = []
    cur_ema = None
    for i, coord in enumerate(coords_list):
        # Use a smooth transition from 0.5 to given_alpha, reaching given_alpha at 200 tokens
        alpha = given_alpha
        if cur_ema is None:
            cur_ema = coord
        else:
            cur_ema = [alpha*(cur_ema[j]) + (1-alpha)*coord[j] for j in range(len(coord))]
        ema_coords.append(cur_ema)
    return torch.tensor(ema_coords)

def process_activations(activations, alpha=0.99):
    seq_coords = einsum(activations, selected_components, 's h, n h -> s n')
    ema_seq_coords = get_ema_coords(seq_coords, alpha=alpha)
    np_coords = ema_seq_coords.detach().cpu().numpy()
    return np_coords

# Process first 10 activation files
all_coords = []
for i in range(0,30):
    activations = torch.load(f'/workspace/llm-progress-monitor/rollouts/activations/{i}.pt')[15]
    np_coords = process_activations(activations)
    all_coords.append(np_coords)
    print(f"Processed file {i}, shape: {np_coords.shape}")

Processed file 0, shape: (1403, 3)
Processed file 1, shape: (4341, 3)
Processed file 1, shape: (4341, 3)
Processed file 2, shape: (447, 3)
Processed file 2, shape: (447, 3)
Processed file 3, shape: (498, 3)
Processed file 3, shape: (498, 3)
Processed file 4, shape: (614, 3)
Processed file 4, shape: (614, 3)
Processed file 5, shape: (1769, 3)
Processed file 5, shape: (1769, 3)
Processed file 6, shape: (2590, 3)
Processed file 6, shape: (2590, 3)
Processed file 7, shape: (444, 3)
Processed file 7, shape: (444, 3)
Processed file 8, shape: (1148, 3)
Processed file 8, shape: (1148, 3)
Processed file 9, shape: (1150, 3)
Processed file 9, shape: (1150, 3)
Processed file 10, shape: (343, 3)
Processed file 10, shape: (343, 3)
Processed file 11, shape: (1707, 3)
Processed file 11, shape: (1707, 3)
Processed file 12, shape: (1651, 3)
Processed file 12, shape: (1651, 3)
Processed file 13, shape: (1717, 3)
Processed file 13, shape: (1717, 3)
Processed file 14, shape: (914, 3)
Processed file 14, sha

In [12]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

In [13]:
# Create 2x5 subplots for 3D trajectories
fig = make_subplots(
    rows=2, cols=5,
    specs=[[{'type': 'scatter3d'} for _ in range(5)] for _ in range(2)],
    subplot_titles=[f'Sequence {i}' for i in range(10)]
)

In [14]:
# Add each trajectory to its subplot
for i in range(10):
    coords = all_coords[i]
    row = (i // 5) + 1
    col = (i % 5) + 1
    
    # Create color values as token position
    color_values = list(range(len(coords)))
    
    fig.add_trace(
        go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1], 
            z=coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=3,
                color=color_values,
                colorscale='viridis',
                showscale=True if i == 0 else False,  # Show colorbar only for first subplot
                colorbar=dict(
                    title="Token Position",
                    x=1.02,
                    len=0.4
                ) if i == 0 else None
            ),
            line=dict(
                color='rgba(0,0,0,0.3)',
                width=2
            ),
            showlegend=False
        ),
        row=row, col=col
    )

fig.update_layout(
    title='3D Trajectories for All Sequences',
    width=1600,
    height=800
)

# Update scene properties for all subplots
for i in range(1, 11):
    row = ((i-1) // 5) + 1
    col = ((i-1) % 5) + 1
    fig.update_scenes(
        xaxis_title='PC1',
        yaxis_title='PC2',
        zaxis_title='PC3',
        row=row, col=col
    )

fig.show()
# Plot how each PCA component varies with token sequence position, averaged across all sequences
import numpy as np

# Find the maximum sequence length to determine how many positions to plot
max_seq_len = max(len(coords) for coords in all_coords)

# Initialize arrays to store sums and counts for averaging
component_sums = np.zeros((max_seq_len, 9))  # 9 PCA components
position_counts = np.zeros(max_seq_len)

# Accumulate values for each position across all sequences
for coords in all_coords:
    seq_len = len(coords)
    for pos in range(seq_len):
        component_sums[pos] += coords[pos]
        position_counts[pos] += 1

# Calculate averages (only for positions that have data)
component_averages = np.zeros((max_seq_len, 9))
valid_positions = position_counts > 0
component_averages[valid_positions] = component_sums[valid_positions] / position_counts[valid_positions, np.newaxis]

fig = go.Figure()

# Plot each PCA component
for i in range(8):
    # Only plot positions that have data
    valid_pos_indices = np.where(valid_positions)[0]
    fig.add_trace(
        go.Scatter(
            x=valid_pos_indices,
            y=component_averages[valid_positions, i],
            mode='lines+markers',
            name=f'PC{i+1}',
            line=dict(width=2),
            marker=dict(size=4)
        )
    )

fig.update_layout(
    title='All PCA Components vs Token Position (Averaged Across All Sequences)',
    xaxis_title='Token Position',
    yaxis_title='Average PCA Component Value',
    width=1000,
    height=600,
    showlegend=True
)

fig.show()

ValueError: operands could not be broadcast together with shapes (9,) (3,) (9,) 

# Interactive Prompt Visualization

Generate activations for custom prompts and visualize their trajectory through PCA space in real-time.

In [48]:
# Helper function to format prompts with chat template
def format_prompt(prompt):
    """
    Format a prompt using the model's chat template.
    
    Args:
        prompt: User prompt string
    
    Returns:
        Formatted prompt string with chat template applied
    """
    messages = [
        {"role": "user", "content": prompt}
    ]
    formatted = model.tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    return formatted

**Note:** All generation functions now automatically apply the model's chat template formatting to prompts. This ensures the Qwen model receives properly formatted inputs as it expects during training.

In [None]:
def generate_activations_for_prompt(prompt, layer_idx=15, max_new_tokens=100):
    """
    Generate activations for a given prompt using nnsight.
    
    Args:
        prompt: Text string to generate activations for
        layer_idx: Layer index to extract activations from (default: 15)
        max_new_tokens: Maximum number of tokens to generate (default: 100)
    
    Returns:
        activations: Tensor of shape [seq_len, hidden_dim] with layer activations
        generated_text: The full generated text
    """
    print(f"Generating activations for prompt: '{prompt[:50]}...'")
    
    # Format prompt with chat template
    formatted_prompt = format_prompt(prompt)
    
    # Step 1: Generate the text first and save the output
    with model.generate(formatted_prompt, max_new_tokens=max_new_tokens) as generator:
        output_tokens = model.generator.output.save()
    
    # Get the generated text
    generated_text = model.tokenizer.decode(output_tokens[0])
    
    # Step 2: Now trace through the full generated sequence to get activations
    # Trace through the token IDs (need to add batch dimension)
    with model.trace(output_tokens[0].unsqueeze(0), scan=False, validate=False):
        layer_activations = model.model.layers[layer_idx].output[0].save()
    
    # Get activations - squeeze batch dimension since we're processing one prompt
    # layer_activations is already a tensor after the trace context exits
    # Convert to bfloat16 to match selected_components dtype
    activations = layer_activations.squeeze(0).to(dtype=torch.bfloat16)  # Shape: [seq_len, hidden_dim]
    
    print(f"Generated {activations.shape[0]} tokens")
    print(f"Generated text: {generated_text[:100]}...")
    
    return activations, generated_text

SyntaxError: invalid syntax (2832854270.py, line 38)

In [50]:
def plot_single_trajectory_3d(coords, title="3D Trajectory", show_plot=True):
    """
    Plot a single 3D trajectory through PCA space.
    
    Args:
        coords: numpy array of shape [seq_len, 3] with PCA coordinates
        title: Title for the plot
        show_plot: Whether to display the plot immediately
    
    Returns:
        fig: Plotly figure object
    """
    # Create color values as token position
    color_values = list(range(len(coords)))
    
    fig = go.Figure()
    
    fig.add_trace(
        go.Scatter3d(
            x=coords[:, 0],
            y=coords[:, 1], 
            z=coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=4,
                color=color_values,
                colorscale='viridis',
                showscale=True,
                colorbar=dict(
                    title="Token Position",
                    x=1.02
                )
            ),
            line=dict(
                color='rgba(100,100,100,0.5)',
                width=3
            ),
            showlegend=False
        )
    )
    
    fig.update_layout(
        title=title,
        width=900,
        height=700,
        scene=dict(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3',
        )
    )
    
    if show_plot:
        fig.show()
    
    return fig

In [51]:
def plot_components_vs_position(coords, title="PCA Components vs Token Position", show_plot=True):
    """
    Plot how each PCA component varies with token sequence position.
    
    Args:
        coords: numpy array of shape [seq_len, n_components] with PCA coordinates
        title: Title for the plot
        show_plot: Whether to display the plot immediately
    
    Returns:
        fig: Plotly figure object
    """
    n_components = coords.shape[1]
    
    fig = go.Figure()
    
    # Plot each PCA component
    for i in range(min(n_components, 9)):  # Plot up to 9 components
        fig.add_trace(
            go.Scatter(
                x=list(range(len(coords))),
                y=coords[:, i],
                mode='lines+markers',
                name=f'PC{i+1}',
                line=dict(width=2),
                marker=dict(size=4)
            )
        )
    
    fig.update_layout(
        title=title,
        xaxis_title='Token Position',
        yaxis_title='PCA Component Value',
        width=1000,
        height=600,
        showlegend=True
    )
    
    if show_plot:
        fig.show()
    
    return fig

In [52]:
def visualize_prompt(prompt, layer_idx=15, max_new_tokens=100, ema_alpha=0.99, skip_first_n=5):
    """
    Complete pipeline: Generate activations for a prompt and visualize in PCA space.
    
    Args:
        prompt: Text string to generate activations for
        layer_idx: Layer index to extract activations from (default: 15)
        max_new_tokens: Maximum number of tokens to generate (default: 100)
        ema_alpha: EMA smoothing parameter (default: 0.99)
    
    Returns:
        coords: numpy array of PCA coordinates
        generated_text: The full generated text
    """
    # Step 1: Generate activations
    activations, generated_text = generate_activations_for_prompt(
        prompt, 
        layer_idx=layer_idx, 
        max_new_tokens=max_new_tokens
    )
    
    activations = activations[skip_first_n:]

    # Step 2: Project to PCA space using existing process_activations function
    coords = process_activations(activations, alpha=ema_alpha)
    
    # Step 3: Plot 3D trajectory
    print("\n" + "="*60)
    print("3D TRAJECTORY VISUALIZATION")
    print("="*60)
    plot_single_trajectory_3d(coords[:, :3], title=f"3D Trajectory: '{prompt[:40]}...'")
    
    # Step 4: Plot components vs position
    print("\n" + "="*60)
    print("PCA COMPONENTS OVER TIME")
    print("="*60)
    plot_components_vs_position(coords, title=f"PCA Components: '{prompt[:40]}...'")
    
    print("\n" + "="*60)
    print("GENERATED TEXT")
    print("="*60)
    print(generated_text)
    print("="*60)
    
    return coords, generated_text

## Usage Example

Simply type your prompt in the cell below and run it to see the visualization!

In [111]:
# Example: Visualize a custom prompt
# Simply change the prompt text and run this cell!

my_prompt = "Once upon a time in a faraway land"

coords, generated_text = visualize_prompt(
    prompt=my_prompt,
    max_new_tokens=5000,  # Adjust this to generate more or fewer tokens
    layer_idx=15,  # Use layer 15 activations
    ema_alpha=0.99
)

Generating activations for prompt: 'Once upon a time in a faraway land...'


NNsightException: 

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/backends/execution.py", line 21, in __call__
    tracer.execute(fn)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/tracing/tracer.py", line 385, in execute
    self.model.interleave(self.fn, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/modeling/mixins/meta.py", line 76, in interleave
    return super().interleave(fn, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/envoy.py", line 733, in interleave
    fn(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/envoy.py", line 384, in __call__
    else self._module(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/interleaver.py", line 133, in skippable_forward
    return forward(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 940, in wrapper
    output = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/qwen3/modeling_qwen3.py", line 480, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/interleaver.py", line 133, in skippable_forward
    return forward(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/qwen3/modeling_qwen3.py", line 410, in forward
    hidden_states = decoder_layer(
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/interleaver.py", line 133, in skippable_forward
    return forward(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/qwen3/modeling_qwen3.py", line 260, in forward
    hidden_states, _ = self.self_attn(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/nnsight/intervention/interleaver.py", line 133, in skippable_forward
    return forward(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/qwen3/modeling_qwen3.py", line 216, in forward
    attn_output, attn_weights = attention_interface(
  File "/usr/local/lib/python3.11/dist-packages/transformers/integrations/sdpa_attention.py", line 83, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.99 GiB. GPU 0 has a total capacity of 79.14 GiB of which 2.16 GiB is free. Process 419996 has 560.00 MiB memory in use. Process 704634 has 76.42 GiB memory in use. Of the allocated memory 72.38 GiB is allocated by PyTorch, and 3.54 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Debug: Check Activation Shapes and Values

## Optimized Version - Extract Activations During Generation

The original approach does TWO forward passes (generate + trace). This optimized version extracts activations DURING generation in a single pass.

## Compare Multiple Prompts

Visualize multiple prompts side-by-side to compare their trajectories.

In [29]:
def compare_prompts(prompts_list, max_new_tokens=50, layer_idx=15):
    """
    Compare multiple prompts by visualizing their trajectories together.
    
    Args:
        prompts_list: List of prompt strings
        max_new_tokens: Maximum number of tokens to generate per prompt
        layer_idx: Layer index to extract activations from
    
    Returns:
        all_coords: List of coordinate arrays for each prompt
        all_texts: List of generated texts for each prompt
    """
    all_coords = []
    all_texts = []
    
    # Generate activations for all prompts
    for i, prompt in enumerate(prompts_list):
        print(f"\n{'='*60}")
        print(f"Processing prompt {i+1}/{len(prompts_list)}")
        print(f"{'='*60}")
        
        activations, generated_text = generate_activations_for_prompt(
            prompt, 
            layer_idx=layer_idx, 
            max_new_tokens=max_new_tokens
        )
        
        coords = process_activations(activations)
        all_coords.append(coords)
        all_texts.append(generated_text)
    
    # Create comparison plot
    n_prompts = len(prompts_list)
    cols = min(3, n_prompts)
    rows = (n_prompts + cols - 1) // cols
    
    fig = make_subplots(
        rows=rows, cols=cols,
        specs=[[{'type': 'scatter3d'} for _ in range(cols)] for _ in range(rows)],
        subplot_titles=[f"Prompt {i+1}: {p[:30]}..." for i, p in enumerate(prompts_list)]
    )
    
    # Add each trajectory
    for i, coords in enumerate(all_coords):
        row = (i // cols) + 1
        col = (i % cols) + 1
        
        color_values = list(range(len(coords)))
        
        fig.add_trace(
            go.Scatter3d(
                x=coords[:, 0],
                y=coords[:, 1], 
                z=coords[:, 2],
                mode='markers+lines',
                marker=dict(
                    size=3,
                    color=color_values,
                    colorscale='viridis',
                    showscale=True if i == 0 else False,
                    colorbar=dict(
                        title="Token Position",
                        x=1.02,
                        len=0.3
                    ) if i == 0 else None
                ),
                line=dict(
                    color='rgba(100,100,100,0.4)',
                    width=2
                ),
                showlegend=False
            ),
            row=row, col=col
        )
    
    fig.update_layout(
        title='Comparison of Multiple Prompt Trajectories',
        width=1600,
        height=600 * rows
    )
    
    # Update scene properties
    for i in range(n_prompts):
        row = (i // cols) + 1
        col = (i % cols) + 1
        fig.update_scenes(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3',
            row=row, col=col
        )
    
    fig.show()
    
    # Print all generated texts
    print("\n" + "="*60)
    print("GENERATED TEXTS")
    print("="*60)
    for i, (prompt, text) in enumerate(zip(prompts_list, all_texts)):
        print(f"\n--- Prompt {i+1}: {prompt} ---")
        print(text)
        print()
    
    return all_coords, all_texts

In [None]:
# Example: Compare multiple prompts
prompts_to_compare = [
    "The weather today is",
    "In a galaxy far away",
    "The recipe for chocolate cake"
]

all_coords, all_texts = compare_prompts(
    prompts_to_compare,
    max_new_tokens=30,
    layer_idx=15
)

## Compare Against Baseline

Compare a new prompt's trajectory against the existing baseline trajectories from the pre-loaded activation files.

In [None]:
def compare_with_baseline(prompt, baseline_coords=None, n_baseline=5, max_new_tokens=50, layer_idx=15):
    """
    Compare a new prompt's trajectory against baseline trajectories.
    
    Args:
        prompt: Text string to generate activations for
        baseline_coords: List of baseline coordinate arrays (uses all_coords if None)
        n_baseline: Number of baseline trajectories to show (default: 5)
        max_new_tokens: Maximum number of tokens to generate
        layer_idx: Layer index to extract activations from
    
    Returns:
        new_coords: PCA coordinates for the new prompt
        generated_text: The full generated text
    """
    if baseline_coords is None:
        baseline_coords = all_coords  # Use the pre-loaded trajectories
    
    # Generate activations for new prompt
    print("Generating new prompt trajectory...")
    activations, generated_text = generate_activations_for_prompt(
        prompt, 
        layer_idx=layer_idx, 
        max_new_tokens=max_new_tokens
    )
    
    new_coords = process_activations(activations)
    
    # Create comparison plot
    fig = go.Figure()
    
    # Add baseline trajectories (in gray, semi-transparent)
    for i, coords in enumerate(baseline_coords[:n_baseline]):
        fig.add_trace(
            go.Scatter3d(
                x=coords[:, 0],
                y=coords[:, 1], 
                z=coords[:, 2],
                mode='lines',
                line=dict(
                    color='rgba(150,150,150,0.2)',
                    width=2
                ),
                name=f'Baseline {i+1}',
                showlegend=True if i == 0 else False,
                legendgroup='baseline'
            )
        )
    
    # Add new trajectory (highlighted in color)
    color_values = list(range(len(new_coords)))
    fig.add_trace(
        go.Scatter3d(
            x=new_coords[:, 0],
            y=new_coords[:, 1], 
            z=new_coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=5,
                color=color_values,
                colorscale='plasma',
                showscale=True,
                colorbar=dict(
                    title="Token Position",
                    x=1.02
                )
            ),
            line=dict(
                color='rgba(255,0,0,0.8)',
                width=4
            ),
            name='New Prompt',
            showlegend=True
        )
    )
    
    fig.update_layout(
        title=f"New Prompt vs Baseline Trajectories<br><sub>Prompt: '{prompt[:60]}...'</sub>",
        width=1000,
        height=800,
        scene=dict(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3',
        )
    )
    
    fig.show()
    
    print("\n" + "="*60)
    print("GENERATED TEXT")
    print("="*60)
    print(generated_text)
    print("="*60)
    
    return new_coords, generated_text

In [None]:
# Example: Compare a new prompt against baseline trajectories
my_new_prompt = "Tell me about machine learning"

new_coords, new_text = compare_with_baseline(
    prompt=my_new_prompt,
    n_baseline=10,  # Show 10 baseline trajectories
    max_new_tokens=40
)

## Quick Reference

### Main Functions:

1. **`visualize_prompt(prompt, max_new_tokens=100, layer_idx=15)`**
   - Generate and visualize a single prompt
   - Shows 3D trajectory and component plots
   - Returns coordinates and generated text

2. **`compare_prompts(prompts_list, max_new_tokens=50, layer_idx=15)`**
   - Compare multiple prompts side-by-side
   - Shows all trajectories in separate subplots
   - Returns all coordinates and generated texts

3. **`compare_with_baseline(prompt, n_baseline=5, max_new_tokens=50, layer_idx=15)`**
   - Compare a new prompt against pre-loaded baseline trajectories
   - Highlights the new trajectory in color
   - Shows baseline trajectories in gray

### Parameters:
- **`prompt`**: Your text prompt (string)
- **`max_new_tokens`**: How many tokens to generate (default: 50-100)
- **`layer_idx`**: Which layer to extract activations from (default: 15)
- **`n_baseline`**: Number of baseline trajectories to show (default: 5)

### Tips:
- Shorter prompts will generate faster
- Adjust `max_new_tokens` to see longer or shorter trajectories
- The color gradient shows token position (darker = earlier, lighter = later)
- PCA components are pre-computed from the probe weights

In [54]:
def generate_activations_optimized(prompt, layer_idx=15, max_new_tokens=100):
    """
    OPTIMIZED: More efficient memory management with immediate cleanup.
    Still uses two passes but with better memory handling.
    
    Args:
        prompt: Text string to generate activations for
        layer_idx: Layer index to extract activations from (default: 15)
        max_new_tokens: Maximum number of tokens to generate (default: 100)
    
    Returns:
        activations: Tensor of shape [seq_len, hidden_dim] with layer activations
        generated_text: The full generated text
    """
    print(f"Generating activations for prompt: '{prompt[:50]}...'")
    
    # Format prompt with chat template
    formatted_prompt = format_prompt(prompt)
    
    # Step 1: Generate the text and get output tokens
    with model.generate(formatted_prompt, max_new_tokens=max_new_tokens) as generator:
        output_tokens = model.generator.output.save()
    
    # Get the generated text
    generated_text = model.tokenizer.decode(output_tokens[0])
    
    # Step 2: Trace through to get activations
    with model.trace(output_tokens[0].unsqueeze(0), scan=False, validate=False):
        layer_activations = model.model.layers[layer_idx].output[0].save()
    
    # Convert to bfloat16 and remove batch dimension
    activations = layer_activations.squeeze(0).to(dtype=torch.bfloat16)
    
    # Clean up immediately
    del output_tokens, layer_activations
    torch.cuda.empty_cache()
    
    print(f"Generated {activations.shape[0]} tokens")
    print(f"Generated text: {generated_text[:100]}...")
    
    return activations, generated_text

In [55]:
def visualize_prompt_optimized(prompt, layer_idx=15, max_new_tokens=100, ema_alpha=0.99, skip_first_n=5):
    """
    OPTIMIZED: Complete pipeline with memory efficiency and single-pass generation.
    
    Args:
        prompt: Text string to generate activations for
        layer_idx: Layer index to extract activations from (default: 15)
        max_new_tokens: Maximum number of tokens to generate (default: 100)
        ema_alpha: EMA smoothing parameter (default: 0.99, set to 0.0 for no smoothing)
        skip_first_n: Skip first N tokens (usually prompt tokens, default: 5)
    
    Returns:
        coords: numpy array of PCA coordinates
        generated_text: The full generated text
    """
    # Step 1: Generate activations (optimized single pass)
    activations, generated_text = generate_activations_optimized(
        prompt, 
        layer_idx=layer_idx, 
        max_new_tokens=max_new_tokens
    )
    
    # Skip first N tokens if requested
    if skip_first_n > 0 and activations.shape[0] > skip_first_n:
        activations = activations[skip_first_n:]
    
    # Step 2: Project to PCA space
    coords = process_activations(activations, alpha=ema_alpha)
    
    # Clean up activations from GPU
    del activations
    torch.cuda.empty_cache()
    
    # Step 3: Plot 3D trajectory
    print("\n" + "="*60)
    print("3D TRAJECTORY VISUALIZATION")
    print("="*60)
    plot_single_trajectory_3d(coords[:, :3], title=f"3D Trajectory: '{prompt[:40]}...'")
    
    # Step 4: Plot components vs position
    print("\n" + "="*60)
    print("PCA COMPONENTS OVER TIME")
    print("="*60)
    plot_components_vs_position(coords, title=f"PCA Components: '{prompt[:40]}...'")
    
    print("\n" + "="*60)
    print("GENERATED TEXT")
    print("="*60)
    print(generated_text)
    print("="*60)
    
    return coords, generated_text

In [31]:
# Test the optimized version
# This should be MUCH faster than the original!

test_prompt = "Once upon a time in a faraway land"

# Use reasonable token limit (not 5000!)
coords_opt, text_opt = visualize_prompt_optimized(
    prompt=test_prompt,
    max_new_tokens=1000,  # Much more reasonable than 5000!
    layer_idx=15,
    ema_alpha=0.99,  # Or try 0.0 for no smoothing
    skip_first_n=5
)

Generating activations for prompt: 'Once upon a time in a faraway land...'
Generated 1009 tokens
Generated text: Once upon a time in a faraway land, there was a kingdom where the King had a special garden. The gar...

3D TRAJECTORY VISUALIZATION
Generated 1009 tokens
Generated text: Once upon a time in a faraway land, there was a kingdom where the King had a special garden. The gar...

3D TRAJECTORY VISUALIZATION



PCA COMPONENTS OVER TIME



GENERATED TEXT
Once upon a time in a faraway land, there was a kingdom where the King had a special garden. The garden was divided into 100 plots, each marked with a unique number from 1 to 100. The King decided to host a contest where he would randomly select a number from 1 to 100 and give a prize to the person who could correctly identify the number. However, the twist was that the number was not chosen randomly but was instead selected from a set of numbers that had a special property. The King said that the number would be such that the sum of its digits is a prime number. The winner would be the one who could find the number with the highest possible value that satisfies this condition. The question is, what is the highest number between 1 and 100 that has a digit sum which is a prime number?

To solve this problem, we need to find the largest number between 1 and 100 such that the sum of its digits is a prime number. Let's break this down step by step.

First, let's understand 

## Memory and Performance Tips

### Key Optimizations Made:
1. **Immediate cleanup**: Delete tensors and clear CUDA cache right after use
2. **Memory efficient**: Explicit GPU memory management after each operation
3. **Reasonable token limits**: 200-500 tokens instead of 5000
4. **Skip prompt tokens**: Only visualize generated tokens (not prompt)

### Additional Memory Tips:
- Use `torch.cuda.empty_cache()` after each generation
- Process one prompt at a time if comparing multiple
- Lower `max_new_tokens` for faster results
- Monitor memory with `check_gpu_memory()` utility

### Performance Notes:
- Generation still requires two passes (generate + trace) due to nnsight architecture
- Main improvement is better memory cleanup preventing accumulation
- Skipping prompt tokens reduces processing time
- Lower token counts = faster visualization

In [32]:
# Utility: Check GPU memory usage
def check_gpu_memory():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        reserved = torch.cuda.memory_reserved() / 1024**3    # GB
        print(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
    else:
        print("CUDA not available")

# Check memory before and after
print("Before generation:")
check_gpu_memory()

# Run generation
gc.collect()
torch.cuda.empty_cache()

print("\nAfter cleanup:")
check_gpu_memory()

Before generation:
GPU Memory - Allocated: 7.64 GB, Reserved: 7.69 GB

After cleanup:
GPU Memory - Allocated: 7.64 GB, Reserved: 7.69 GB


## Steering Comparison

Apply a steering vector during generation and compare the steered vs unsteered trajectories through PCA space.

In [56]:
def generate_with_steering(prompt, steering_vector, layer_idx=15, max_new_tokens=100, 
                          steering_strength=1.0, start_step=0, end_step=None):
    """
    Generate text with a steering vector applied during generation.
    
    Args:
        prompt: Text string to generate from
        steering_vector: Tensor of shape [hidden_dim] to add to activations
        layer_idx: Layer to apply steering to (default: 15)
        max_new_tokens: Maximum tokens to generate (default: 100)
        steering_strength: Multiplier for steering vector (default: 1.0)
        start_step: Token position to start steering (0 = first generated token, default: 0)
        end_step: Token position to stop steering (None = steer all tokens, default: None)
    
    Returns:
        activations: Tensor [seq_len, hidden_dim] with layer activations (after steering)
        generated_text: The full generated text (influenced by steering)
        steering_mask: Boolean array indicating which tokens were steered
    """
    print(f"Generating with steering (strength={steering_strength}, steps {start_step}-{end_step})...")
    
    if end_step is None:
        end_step = max_new_tokens
    
    # Ensure steering vector is on correct device and dtype
    steering_vec = steering_vector.to(device='cuda', dtype=torch.bfloat16)
    
    # Format prompt with chat template
    formatted_prompt = format_prompt(prompt)
    
    # Step 1: Generate WITH steering intervention
    # We'll apply the intervention to the residual stream during generation
    with model.generate(formatted_prompt, max_new_tokens=max_new_tokens, pad_token_id=model.tokenizer.eos_token_id) as generator:
        # Access the layer output during generation (it's a tuple, take first element)
        layer_output = model.model.layers[layer_idx].output[0]
        
        # Apply steering intervention: add steering vector to all sequence positions
        # layer_output shape: [batch, seq, hidden]
        layer_output[:, :] = layer_output + steering_vec * steering_strength
        
        # Save the output tokens
        output_tokens = model.generator.output.save()
    
    # Decode generated text
    generated_text = model.tokenizer.decode(output_tokens[0])
    
    # Step 2: Get activations by tracing through the steered generation
    # We need to trace with the same steering to get matching activations
    with model.trace(output_tokens[0].unsqueeze(0), scan=False, validate=False):
        # Apply same steering
        layer_output = model.model.layers[layer_idx].output[0]
        layer_output[:, :] = layer_output + steering_vec * steering_strength
        
        # Save steered activations
        layer_activations = layer_output.save()
    
    # Convert and cleanup
    activations = layer_activations.squeeze(0).to(dtype=torch.bfloat16)
    
    # Create steering mask (which tokens were in the steering range)
    num_tokens = activations.shape[0]
    steering_mask = torch.ones(num_tokens, dtype=bool)  # All tokens steered in this implementation
    
    # Cleanup
    del output_tokens, layer_activations, steering_vec
    torch.cuda.empty_cache()
    
    print(f"Generated {activations.shape[0]} tokens (all steered)")
    print(f"Generated text: {generated_text[:100]}...")
    
    return activations, generated_text, steering_mask.cpu().numpy()

SyntaxError: invalid syntax (400293602.py, line 70)

In [57]:
def compare_steered_vs_unsteered(prompt, steering_vector, layer_idx=15, max_new_tokens=100,
                                steering_strength=1.0, start_step=0, end_step=None, 
                                ema_alpha=0.99, skip_first_n=5):
    """
    Compare steered and unsteered generation trajectories side-by-side.
    
    Args:
        prompt: Text string to generate from
        steering_vector: Tensor [hidden_dim] to add to activations
        layer_idx: Layer to extract/steer activations (default: 15)
        max_new_tokens: Maximum tokens to generate (default: 100)
        steering_strength: Multiplier for steering vector (default: 1.0)
        start_step: Token position to start steering (default: 0)
        end_step: Token position to stop steering (default: None = all tokens)
        ema_alpha: EMA smoothing parameter (default: 0.99)
        skip_first_n: Skip first N tokens for visualization (default: 5)
    
    Returns:
        unsteered_coords: PCA coordinates for unsteered generation
        steered_coords: PCA coordinates for steered generation
        unsteered_text: Unsteered generated text
        steered_text: Steered generated text
        steering_mask: Boolean mask of which tokens were steered
    """
    print("="*60)
    print("GENERATING UNSTEERED BASELINE")
    print("="*60)
    
    # Generate unsteered version
    unsteered_acts, unsteered_text = generate_activations_optimized(
        prompt, layer_idx=layer_idx, max_new_tokens=max_new_tokens
    )
    
    print("\n" + "="*60)
    print("GENERATING STEERED VERSION")
    print("="*60)
    
    # Generate steered version
    steered_acts, steered_text, steering_mask = generate_with_steering(
        prompt, steering_vector, layer_idx=layer_idx, max_new_tokens=max_new_tokens,
        steering_strength=steering_strength, start_step=start_step, end_step=end_step
    )
    
    # Skip first N tokens if requested
    if skip_first_n > 0:
        if unsteered_acts.shape[0] > skip_first_n:
            unsteered_acts = unsteered_acts[skip_first_n:]
        if steered_acts.shape[0] > skip_first_n:
            steered_acts = steered_acts[skip_first_n:]
            steering_mask = steering_mask[skip_first_n:]
    
    # Project to PCA space
    unsteered_coords = process_activations(unsteered_acts, alpha=ema_alpha)
    steered_coords = process_activations(steered_acts, alpha=ema_alpha)
    
    # Cleanup
    del unsteered_acts, steered_acts
    torch.cuda.empty_cache()
    
    # Create side-by-side comparison plot
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
        subplot_titles=['Unsteered Generation', 'Steered Generation']
    )
    
    # Plot unsteered trajectory
    color_values_unsteered = list(range(len(unsteered_coords)))
    fig.add_trace(
        go.Scatter3d(
            x=unsteered_coords[:, 0],
            y=unsteered_coords[:, 1],
            z=unsteered_coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=4,
                color=color_values_unsteered,
                colorscale='viridis',
                showscale=False
            ),
            line=dict(color='rgba(100,100,100,0.5)', width=3),
            name='Unsteered',
            showlegend=False
        ),
        row=1, col=1
    )
    
    # Plot steered trajectory with highlighting for steered sections
    # Split into steered and unsteered segments
    color_values_steered = list(range(len(steered_coords)))
    
    fig.add_trace(
        go.Scatter3d(
            x=steered_coords[:, 0],
            y=steered_coords[:, 1],
            z=steered_coords[:, 2],
            mode='markers+lines',
            marker=dict(
                size=4,
                color=color_values_steered,
                colorscale='plasma',
                showscale=False
            ),
            line=dict(color='rgba(255,50,50,0.7)', width=3),
            name='Steered',
            showlegend=False
        ),
        row=1, col=2
    )
    
    fig.update_layout(
        title=f"Steering Comparison (strength={steering_strength}, steps {start_step}-{end_step})<br>" +
              f"<sub>Prompt: '{prompt[:50]}...'</sub>",
        width=1600,
        height=700
    )
    
    # Update both subplots
    fig.update_scenes(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3', row=1, col=1)
    fig.update_scenes(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3', row=1, col=2)
    
    fig.show()
    
    # Plot overlay comparison
    print("\n" + "="*60)
    print("OVERLAY COMPARISON")
    print("="*60)
    
    fig_overlay = go.Figure()
    
    # Add unsteered trajectory
    fig_overlay.add_trace(
        go.Scatter3d(
            x=unsteered_coords[:, 0],
            y=unsteered_coords[:, 1],
            z=unsteered_coords[:, 2],
            mode='markers+lines',
            marker=dict(size=3, color='blue', opacity=0.6),
            line=dict(color='blue', width=3, dash='dash'),
            name='Unsteered'
        )
    )
    
    # Add steered trajectory
    fig_overlay.add_trace(
        go.Scatter3d(
            x=steered_coords[:, 0],
            y=steered_coords[:, 1],
            z=steered_coords[:, 2],
            mode='markers+lines',
            marker=dict(size=4, color='red', opacity=0.8),
            line=dict(color='red', width=4),
            name='Steered'
        )
    )
    
    fig_overlay.update_layout(
        title=f"Overlay: Steered vs Unsteered<br><sub>Prompt: '{prompt[:50]}...'</sub>",
        width=1000,
        height=800,
        scene=dict(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3')
    )
    
    fig_overlay.show()
    
    # Print text comparison
    print("\n" + "="*60)
    print("TEXT COMPARISON")
    print("="*60)
    print("\n--- UNSTEERED ---")
    print(unsteered_text)
    print("\n--- STEERED ---")
    print(steered_text)
    print("="*60)
    
    return unsteered_coords, steered_coords, unsteered_text, steered_text, steering_mask

### Example: Create a steering vector and compare

You can create a steering vector in several ways:
1. Load a pre-computed steering vector from a file
2. Create a random steering vector for testing
3. Use the difference between two activation states
4. Extract from your probe weights

In [60]:
# Example 1: Use probe weights as steering vector (e.g., "refusal" direction)
# The probe weights represent the learned "harmful" vs "harmless" direction
# Using this as a steering vector will push the model along that axis

# Get the first probe weight as a steering vector
steering_vec_from_probe = probes_weights[0]  # Shape: [2560]

print(f"Steering vector shape: {steering_vec_from_probe.shape}")
print(f"Steering vector dtype: {steering_vec_from_probe.dtype}")

# Test prompt
test_prompt_steering = "Write a story about"

# Compare steered vs unsteered
# Positive strength = push in the direction of the probe
# Negative strength = push opposite direction
unsteered_coords, steered_coords, unsteered_text, steered_text, mask = compare_steered_vs_unsteered(
    prompt=test_prompt_steering,
    steering_vector=steering_vec_from_probe,
    max_new_tokens=100,
    steering_strength=0,  # Try different values: -5, -2, -1, 0.5, 1, 2, 5
    start_step=20,  # Start steering from first generated token
    end_step=20,  # Steer all generated tokens (or set to e.g. 20 to steer only first 20)
    layer_idx=15,
    ema_alpha=0.99,  # Try 0.0 for no smoothing to see raw differences
    skip_first_n=1
)

Steering vector shape: torch.Size([2560])
Steering vector dtype: torch.float32
GENERATING UNSTEERED BASELINE
Generating activations for prompt: 'Write a story about...'
Generated 112 tokens
Generated text: <|im_start|>user
Write a story about<|im_end|>
<|im_start|>assistant
<think>
Okay, the user wants me...

GENERATING STEERED VERSION
Generating with steering (strength=0, steps 20-20)...
Generated 112 tokens
Generated text: <|im_start|>user
Write a story about<|im_end|>
<|im_start|>assistant
<think>
Okay, the user wants me...

GENERATING STEERED VERSION
Generating with steering (strength=0, steps 20-20)...
Generated 104 tokens (all steered)
Generated text: Write a story about a boy who is a pirate and he is trying to find the treasure of the legendary pir...
Generated 104 tokens (all steered)
Generated text: Write a story about a boy who is a pirate and he is trying to find the treasure of the legendary pir...



OVERLAY COMPARISON



TEXT COMPARISON

--- UNSTEERED ---
<|im_start|>user
Write a story about<|im_end|>
<|im_start|>assistant
<think>
Okay, the user wants me to write a story. Let me think about what they might be looking for. They didn't specify a genre or theme, so maybe I should choose something universal. Maybe a fantasy or magical realism story? That often allows for creativity.

I need a main character. Perhaps someone with a unique ability or a quest. Let's go with a young protagonist. Maybe a girl named Lila? Sounds nice and relatable. She could have a special connection to the

--- STEERED ---
Write a story about a boy who is a pirate and he is trying to find the treasure of the legendary pirate, Captain Blackbeard.

**Title: The Pirate's Legacy**

**Chapter 1: The Call to Adventure**

In the salty air of the Caribbean, where the waves crashed against the shores like a thousand whispered secrets, a boy named Eli lived with his family in a small fishing village. But Eli was no ordinary boy. He was 

In [None]:
# Example 2: Create a random steering vector (for testing)
# This helps verify that steering is working and shows what random perturbations look like

random_steering_vec = torch.randn(2560, dtype=torch.float32) * 0.1  # Small random vector

unsteered_coords_rand, steered_coords_rand, unsteered_text_rand, steered_text_rand, mask_rand = compare_steered_vs_unsteered(
    prompt="Once upon a time",
    steering_vector=random_steering_vec,
    max_new_tokens=80,
    steering_strength=1.0,
    start_step=0,
    end_step=None,
    layer_idx=15,
    ema_alpha=0.99,
    skip_first_n=5
)

In [None]:
# Example 3: Steer only during specific generation steps
# This lets you see when steering has the most impact

# Steer only the first 20 tokens
unsteered_coords_partial, steered_coords_partial, unsteered_text_partial, steered_text_partial, mask_partial = compare_steered_vs_unsteered(
    prompt="Explain how quantum computing works",
    steering_vector=probes_weights[0],
    max_new_tokens=100,
    steering_strength=3.0,
    start_step=0,   # Start steering immediately
    end_step=20,    # Stop steering after 20 tokens
    layer_idx=15,
    ema_alpha=0.5,  # Less smoothing to see the transition point
    skip_first_n=5
)

print(f"\nSteering was applied to {mask_partial.sum()} tokens out of {len(mask_partial)} total tokens")

### Advanced: Compare Multiple Steering Strengths

See how different steering strengths affect the trajectory.

In [None]:
def compare_multiple_steering_strengths(prompt, steering_vector, strengths_list, 
                                       layer_idx=15, max_new_tokens=80, ema_alpha=0.99, skip_first_n=5):
    """
    Compare trajectories with different steering strengths all in one plot.
    
    Args:
        prompt: Text string to generate from
        steering_vector: Steering vector to apply
        strengths_list: List of steering strengths to try (e.g., [0, 1, 2, 5])
        layer_idx: Layer index (default: 15)
        max_new_tokens: Tokens to generate (default: 80)
        ema_alpha: EMA smoothing (default: 0.99)
        skip_first_n: Skip first N tokens (default: 5)
    
    Returns:
        all_coords: List of coordinate arrays for each strength
        all_texts: List of generated texts for each strength
    """
    all_coords = []
    all_texts = []
    
    print(f"Comparing steering strengths: {strengths_list}")
    
    for strength in strengths_list:
        print(f"\n{'='*60}")
        print(f"Testing strength = {strength}")
        print(f"{'='*60}")
        
        if strength == 0:
            # No steering - use baseline
            acts, text = generate_activations_optimized(
                prompt, layer_idx=layer_idx, max_new_tokens=max_new_tokens
            )
            steering_mask = None
        else:
            # With steering
            acts, text, steering_mask = generate_with_steering(
                prompt, steering_vector, layer_idx=layer_idx, max_new_tokens=max_new_tokens,
                steering_strength=strength
            )
        
        # Skip first N tokens
        if skip_first_n > 0 and acts.shape[0] > skip_first_n:
            acts = acts[skip_first_n:]
        
        # Project to PCA
        coords = process_activations(acts, alpha=ema_alpha)
        all_coords.append(coords)
        all_texts.append(text)
        
        del acts
        torch.cuda.empty_cache()
    
    # Create 3D plot with all trajectories
    fig = go.Figure()
    
    colors = ['blue', 'green', 'orange', 'red', 'purple', 'brown', 'pink']
    
    for i, (coords, strength) in enumerate(zip(all_coords, strengths_list)):
        color = colors[i % len(colors)]
        
        fig.add_trace(
            go.Scatter3d(
                x=coords[:, 0],
                y=coords[:, 1],
                z=coords[:, 2],
                mode='markers+lines',
                marker=dict(size=3, color=color),
                line=dict(color=color, width=3),
                name=f'Strength={strength}'
            )
        )
    
    fig.update_layout(
        title=f"Steering Strength Comparison<br><sub>Prompt: '{prompt[:50]}...'</sub>",
        width=1200,
        height=800,
        scene=dict(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3')
    )
    
    fig.show()
    
    # Print all texts
    print("\n" + "="*60)
    print("GENERATED TEXTS")
    print("="*60)
    for strength, text in zip(strengths_list, all_texts):
        print(f"\n--- Strength = {strength} ---")
        print(text[:200] + "..." if len(text) > 200 else text)
    
    return all_coords, all_texts

In [None]:
# Example: Compare multiple steering strengths
# This shows how the trajectory changes as you increase steering strength

test_prompt_multi = "Tell me about"

all_coords_multi, all_texts_multi = compare_multiple_steering_strengths(
    prompt=test_prompt_multi,
    steering_vector=probes_weights[0],
    strengths_list=[0, 0.5, 1.0, 2.0, 5.0],  # Try different strengths including baseline (0)
    max_new_tokens=60,
    layer_idx=15,
    ema_alpha=0.7,  # Moderate smoothing
    skip_first_n=5
)

## Steering Functions Summary

### Core Functions:

1. **`generate_with_steering(prompt, steering_vector, steering_strength, start_step, end_step, ...)`**
   - Generate text with a steering vector applied to layer activations
   - Control when steering is applied (start_step to end_step)
   - Returns activations, text, and a mask showing which tokens were steered

2. **`compare_steered_vs_unsteered(prompt, steering_vector, steering_strength, ...)`**
   - Run same prompt twice: once with steering, once without
   - Shows side-by-side 3D trajectories and overlay plot
   - Displays both generated texts for comparison

3. **`compare_multiple_steering_strengths(prompt, steering_vector, strengths_list, ...)`**
   - Test multiple steering strengths on the same prompt
   - All trajectories shown in one 3D plot with different colors
   - Helps find optimal steering strength

### Key Parameters:

- **`steering_vector`**: Tensor [hidden_dim] to add to activations (e.g., from probe weights)
- **`steering_strength`**: Multiplier for the steering vector (try 0.5 to 5.0)
  - Positive values = steer in direction of vector
  - Negative values = steer opposite direction
  - 0 = no steering (baseline)
- **`start_step`**: Which generated token to start steering (0 = first generated token)
- **`end_step`**: Which token to stop steering (None = steer all tokens)
- **`ema_alpha`**: Smoothing parameter (0.0 = no smoothing, 0.99 = heavy smoothing)
  - Use lower values (0.0-0.5) to see sharp differences between steered/unsteered
  - Use higher values (0.9-0.99) for smoother trajectories

### Typical Workflow:

1. Choose or create a steering vector (from probe weights, random, or computed)
2. Start with `compare_steered_vs_unsteered()` using moderate strength (~1.0-2.0)
3. Adjust steering strength based on results
4. Use `compare_multiple_steering_strengths()` to find optimal strength
5. Experiment with partial steering (start_step, end_step) to see when it matters most
6. Try different EMA alpha values to see raw vs smoothed trajectories

### Tips:

- Start with lower strengths (0.5-2.0) and increase if effect is too subtle
- Use `ema_alpha=0.0` to see unsmoothed differences
- Steer only part of generation (`end_step=20`) to see transition effects
- The probe weights represent "harmful" direction - positive steering may increase refusal behavior
- Try negative strengths to steer opposite direction (e.g., reduce refusal)