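Run all the cells in this notebook in order, from top to bottom. The notebook expects a CUDA GPU and the file `qwen3_4b_weight_tensor.pt` in the working directory.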

# Imports

In [1]:
from nnsight import LanguageModel
from einops import einsum
import torch
import ipywidgets as widgets
from IPython.display import display
import time

# Load model

In [2]:
# Device passed to the model loader (device_map) and used for the tensor built in get_log_preds.
device = 'cuda'

In [3]:
# Weights consumed by get_log_preds. The path is relative to the working directory
# (an earlier revision read it from /workspace/llm-progress-monitor/).
# NOTE(review): torch.load unpickles the file — only load files you trust.
weight_tensor = torch.load('qwen3_4b_weight_tensor.pt')
model_name = 'Qwen/Qwen3-4B'

In [4]:
# Load the model through nnsight's LanguageModel wrapper so that layer outputs
# can be read inside model.generate(...) / model.trace(...) blocks below.
model = LanguageModel(model_name, device_map=device, dtype=torch.bfloat16)

In [5]:
def get_ema_preds(log_preds, alpha=0.5):
    """Smooth a sequence of per-token predictions with a moving average.

    Args:
        log_preds: tensor of log-space predictions, one entry per token.
            The values are exponentiated before smoothing.
        alpha: weight given to the previous smoothed value; the new
            prediction receives ``1 - alpha``.

    Returns:
        list of floats with one smoothed value per input entry. The first
        entry is the first exponentiated prediction, unchanged.
    """
    smoothed = []
    running = None
    for current in log_preds.exp().tolist():
        if running is None:
            # Nothing to blend with yet: seed the average with the first value
            running = current
        else:
            # The previous estimate is reduced by one before blending because
            # one more token has been produced since it was made
            running = alpha*(running-1) + (1-alpha)*current
        smoothed.append(running)
    return smoothed

In [6]:
def get_log_preds(activation, weight_tensor):
    """Project activations through the weights and reduce them to one value per position.

    The activations are multiplied with ``weight_tensor`` over the shared
    ``d_model`` axis, a softmax is taken over the resulting second axis, and
    the softmax weights are used to average the values ``0.5, 1.5, 2.5, ...``
    (one per row of ``weight_tensor``). Callers in this notebook apply
    ``.exp()`` to the result.

    Args:
        activation: tensor indexed as ``(seq, d_model)``.
        weight_tensor: tensor indexed as ``(pca, d_model)``.

    Returns:
        tensor with one value per ``seq`` position.
    """
    # Softmax over the axis that enumerates the rows of weight_tensor
    bin_probs = einsum(
        activation, weight_tensor, 'seq d_model, pca d_model -> seq pca'
    ).softmax(dim=1)
    # Half-integer value attached to each row: 0.5, 1.5, ...
    bin_values = 0.5+torch.arange(weight_tensor.shape[0]).to(device, dtype=torch.bfloat16)
    return einsum(bin_probs, bin_values, 'seq pca, pca -> seq')

# Vibe coded UIs

In [7]:
!pip install plotly

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [15]:
# Multi-line text area where the user types the prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px'),
)

# Button that starts a generation run
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play',
)

# Progress indicator on a 0-100 scale, with a text read-out next to it
progress_bar = widgets.FloatProgress(
    value=0, min=0, max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal',
)
percentage_label = widgets.HTML(value="<b>0.0%</b>", description='')
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area listing the tokens produced so far
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Stack everything vertically and render it
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    submit_button,
    progress_row,
    token_display,
])
display(progress_container)

def on_submit_clicked(b):
    """Generate text for the prompt in ``prompt_input`` and stream progress to the widgets.

    On every single-token step the layer-15 output is passed through
    ``get_log_preds``; the EMA-smoothed value (default alpha) is treated as
    the number of tokens remaining and drives the progress bar, the
    percentage label and the token display.

    Token text is HTML-escaped before it is written into the HTML widget so
    that tokens containing ``<``, ``>`` or ``&`` are shown literally.

    Args:
        b: the button instance supplied by ``Button.on_click``; unused.
    """
    import html  # stdlib; imported here so the handler is self-contained

    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"

    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []  # HTML-escaped token strings, in generation order

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)
            preds = get_log_preds(activations, weight_tensor).tolist()
            # Passes that yield more than one position are skipped
            # (presumably the prompt prefill — TODO confirm).
            if len(preds) <= 1:
                cur_log_preds+=preds
                # The whole history is re-smoothed on every step; only the latest value is used
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated+=1
                pred_tokens_remaining = ema_preds[-1]
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)

                # NOTE(review): this is the arg-max of the logits, while generation runs with
                # do_sample=True, so the token shown may differ from the one sampled — confirm.
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=True)
                # Escape before embedding in HTML; raw "<...>" text would be parsed as markup
                token_str_escaped = html.escape(token_str, quote=True)
                generated_tokens.append(token_str_escaped)

                # Update progress bar
                progress_bar.value = pred_percent_through * 100

                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"

                # Update token display with all generated tokens
                tokens_html = " ".join([f"<span style='background-color: #e6f3ff; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{tok}</span>" for tok in generated_tokens])
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{token_str_escaped}' | <b>Predicted:</b> {pred_percent_through*100:.1f}% through"

# Register the handler; it is called with the button instance on every click
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

In [16]:
from IPython.display import clear_output
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Multi-line text area where the user types the prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Numeric input for the EMA factor.
# BoundedFloatText is used because plain FloatText has no min/max traits,
# so the 0.0-1.0 limits would not be enforced there.
ema_input = widgets.BoundedFloatText(
    value=0.9,
    description='EMA Factor:',
    min=0.0,
    max=1.0,
    step=0.01,
    tooltip='Exponential Moving Average factor (0.0 to 1.0)',
    layout=widgets.Layout(width='200px')
)

# Button that starts a generation run
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Progress indicator on a 0-100 scale
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Text read-out shown next to the progress bar
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Progress bar and percentage side by side
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area listing the tokens produced so far
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Plotly figure widget that receives the prediction-history traces
graph_widget = go.FigureWidget()

# Stack all widgets vertically
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    ema_input,
    submit_button,
    progress_row,
    token_display,
    widgets.HTML("<h4>Prediction History</h4>"),
    graph_widget
])

# Render the UI
display(progress_container)

# Global variables to store generation data.
# They are filled by on_submit_clicked and re-read by update_graph_with_ema
# whenever the EMA factor is changed after a run.
raw_log_preds = []  # one get_log_preds value per recorded token step
stored_generated_tokens = []  # HTML-escaped token strings, same order as raw_log_preds
stored_n_tokens = 0  # token count of the most recent run, set once generation ends

def get_color_for_change(percent_change, is_increase):
    """
    Map the magnitude of a prediction change to a highlight colour.

    Args:
        percent_change: size of the change in percent. The value is clamped
            to the range 0-50 before use, so negative inputs give the
            neutral colour and anything above 50 gives the end colour.
        is_increase: True selects the red gradient, False the green one.

    Returns:
        str: a "#rrggbb" colour between neutral light blue (#e6f3ff) and
        deep red (#ff0000) or deep green (#00cc00).
    """
    neutral = (230, 243, 255)
    target = (255, 0, 0) if is_increase else (0, 204, 0)

    # Clamp on both sides: without the lower bound a negative input pushes
    # channels outside 0-255 and produces a malformed hex string
    intensity = max(0, min(percent_change, 50)) / 50

    # Linear interpolation per channel, truncated to an integer
    r, g, b = (int(start + (end - start) * intensity) for start, end in zip(neutral, target))

    return f"#{r:02x}{g:02x}{b:02x}"

def update_graph_with_ema(ema_factor):
    """Re-render the graph and the token display for the stored run with a new EMA factor.

    Reads the module-level ``raw_log_preds``, ``stored_generated_tokens`` and
    ``stored_n_tokens`` and recomputes the smoothed predictions with
    ``get_ema_preds``. Does nothing when no predictions are stored or when
    only a single prediction exists.

    Args:
        ema_factor: value passed to ``get_ema_preds`` as ``alpha``.
    """
    import html  # stdlib; imported here so the function is self-contained

    if len(raw_log_preds) == 0:
        return

    # Recalculate predictions with new EMA
    ema_preds = get_ema_preds(torch.tensor(raw_log_preds), alpha=ema_factor)

    # Predicted total = tokens generated so far + smoothed tokens remaining
    token_counts = list(range(1, len(ema_preds) + 1))
    prediction_history = [n_generated + remaining for n_generated, remaining in zip(token_counts, ema_preds)]

    # A single point gives nothing to plot or compare against
    if len(prediction_history) <= 1:
        return

    # Tokens are stored HTML-escaped; html.unescape reverses that in one pass,
    # so a token that literally contains "&lt;" is not turned into "<"
    plain_tokens = [html.unescape(tok) for tok in stored_generated_tokens]

    # Hover text: each token with up to 5 neighbours on either side
    hover_texts = []
    for i in range(len(plain_tokens)):
        start_idx = max(0, i - 5)
        end_idx = min(len(plain_tokens), i + 6)
        context_tokens = [
            f"<b>{plain_tokens[j]}</b>" if j == i else plain_tokens[j]
            for j in range(start_idx, end_idx)
        ]
        context_str = " ".join(context_tokens)
        hover_texts.append(f"Token {i+1}: {context_str}<br>Predicted Total: {prediction_history[i]:.0f}")

    # Replace the existing traces with the recomputed history
    graph_widget.data = []
    graph_widget.add_trace(go.Scatter(
        x=token_counts,
        y=prediction_history,
        mode='lines+markers',
        name='Predicted Total Tokens',
        line=dict(color='blue', width=2),
        marker=dict(size=6),
        hovertemplate='%{customdata}<extra></extra>',
        customdata=hover_texts
    ))

    # Marker for the actual final length of the run
    graph_widget.add_trace(go.Scatter(
        x=[stored_n_tokens],
        y=[stored_n_tokens],
        mode='markers',
        name='Actual Final',
        marker=dict(size=10, color='green', symbol='star'),
        hovertemplate=f'Actual completion: {stored_n_tokens} tokens<extra></extra>'
    ))

    graph_widget.update_layout(
        title=f'Token Prediction vs Reality (EMA={ema_factor:.2f})',
        xaxis_title='Token Number',
        yaxis_title='Predicted Total Tokens',
        height=400,
        showlegend=True,
        hovermode='closest'
    )

    # Final statistics; same zero guard as the percent-change computation below
    final_pred = prediction_history[-1]
    accuracy = (stored_n_tokens/final_pred)*100 if final_pred != 0 else 0

    # Rebuild the coloured token display (tokens stay escaped for HTML)
    highlighted_tokens = []
    for i, tok in enumerate(stored_generated_tokens):
        if i > 0 and i < len(prediction_history):
            change = prediction_history[i] - prediction_history[i-1]
            percent_change = abs(change / prediction_history[i-1]) * 100 if prediction_history[i-1] != 0 else 0
            highlight_color = get_color_for_change(percent_change, change > 0)
        else:
            highlight_color = "#e6f3ff"  # Default light blue for first token

        highlighted_tokens.append(f"<span style='background-color: {highlight_color}; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{tok}</span>")

    tokens_html = " ".join(highlighted_tokens)
    token_display.value = f"<b>Generation complete!</b><br><b>Total tokens:</b> {stored_n_tokens}<br><b>Final prediction:</b> {final_pred:.0f} tokens<br><b>Accuracy:</b> {accuracy:.1f}%<br><b>Current EMA:</b> {ema_factor:.2f}<br><br><b>Generated tokens:</b><br>{tokens_html}<br><br><small><b>Color coding:</b> Gradient from <span style='background-color: #e6f3ff; padding: 2px;'>neutral</span> to <span style='background-color: #ff8888; padding: 2px;'>red (increases)</span> or <span style='background-color: #88ff88; padding: 2px;'>green (decreases)</span> based on prediction change magnitude</small>"

def on_ema_changed(change):
    """Observer callback: redraw with the value the EMA input just changed to."""
    new_factor = change['new']
    update_graph_with_ema(new_factor)

# Call on_ema_changed whenever the widget's `value` trait changes
ema_input.observe(on_ema_changed, names='value')

def on_submit_clicked(b):
    """Generate text for the current prompt, streaming progress, tokens and a live graph.

    On every single-token step the layer-15 output is passed through
    ``get_log_preds``; the EMA-smoothed value is treated as the number of
    tokens remaining. The handler updates the progress widgets, colours each
    token by how much the predicted total changed, redraws ``graph_widget``,
    and stores the raw predictions and tokens in module-level variables so
    ``update_graph_with_ema`` can re-render them later. After generation a
    text summary is printed and the graph is redrawn once more.

    Args:
        b: the button instance supplied by ``Button.on_click``; unused.
    """
    global raw_log_preds, stored_generated_tokens, stored_n_tokens
    import html  # stdlib; imported here so the handler is self-contained

    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"

    # Clear the graph
    graph_widget.data = []

    # Reset global storage
    raw_log_preds = []
    stored_generated_tokens = []
    stored_n_tokens = 0

    # Initialize lists to track predictions over time
    prediction_history = []
    token_counts = []

    # Get prompt from input
    prompt = prompt_input.value
    # Get EMA factor from input
    ema_factor = ema_input.value

    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []  # HTML-escaped token strings, in generation order

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)

            # Save predictions within nnsight context
            preds_saved = get_log_preds(activations, weight_tensor).save()
            # NOTE(review): this is the arg-max of the logits, while generation runs with
            # do_sample=True, so the token shown may differ from the one sampled — confirm.
            token_saved = model.lm_head.output.argmax(dim=-1).save()

            preds = preds_saved.tolist()
            # Passes that yield more than one position are skipped
            # (presumably the prompt prefill — TODO confirm).
            if len(preds) <= 1:
                cur_log_preds+=preds
                raw_log_preds.append(preds[0])  # Store raw predictions globally

                # The whole history is re-smoothed on every step; only the latest value is used
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds), alpha=ema_factor)
                n_tokens_generated+=1
                pred_tokens_remaining = ema_preds[-1]
                predicted_total_tokens = n_tokens_generated + pred_tokens_remaining
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)

                # Store prediction data for highlighting
                prediction_history.append(predicted_total_tokens)
                token_counts.append(n_tokens_generated)

                token = token_saved.tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=False)
                # Escape HTML entities in token string for safe display
                token_str_escaped = html.escape(token_str, quote=True)
                print(token_str)
                generated_tokens.append(token_str_escaped)
                stored_generated_tokens.append(token_str_escaped)  # Store globally

                # Update progress bar
                progress_bar.value = pred_percent_through * 100

                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"

                # Create highlighted token display with gradient colors
                highlighted_tokens = []
                for i, tok in enumerate(generated_tokens):
                    # Calculate percentage change if we have history
                    if i > 0 and i < len(prediction_history):
                        change = prediction_history[i] - prediction_history[i-1]
                        percent_change = abs(change / prediction_history[i-1]) * 100 if prediction_history[i-1] != 0 else 0
                        highlight_color = get_color_for_change(percent_change, change > 0)
                    else:
                        highlight_color = "#e6f3ff"  # Default light blue for first token

                    highlighted_tokens.append(f"<span style='background-color: {highlight_color}; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{tok}</span>")

                tokens_html = " ".join(highlighted_tokens)
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{token_str_escaped}' | <b>Predicted Total:</b> {predicted_total_tokens:.0f} tokens | <b>Progress:</b> {pred_percent_through*100:.1f}%<br><br><small><b>Color coding:</b> Gradient from <span style='background-color: #e6f3ff; padding: 2px;'>neutral</span> to <span style='background-color: #ff8888; padding: 2px;'>red (increases)</span> or <span style='background-color: #88ff88; padding: 2px;'>green (decreases)</span> based on prediction change magnitude</small>"

                # Update the graph with current predictions
                if len(prediction_history) > 1:
                    # html.unescape reverses the escaping in one pass, so a token that
                    # literally contains "&lt;" is not turned into "<"
                    plain_tokens = [html.unescape(tok) for tok in generated_tokens]

                    # Create hover text with context tokens
                    hover_texts = []
                    for i in range(len(plain_tokens)):
                        # Get 5 tokens before and after (if available)
                        start_idx = max(0, i - 5)
                        end_idx = min(len(plain_tokens), i + 6)

                        context_tokens = []
                        for j in range(start_idx, end_idx):
                            if j == i:
                                context_tokens.append(f"<b>{plain_tokens[j]}</b>")
                            else:
                                context_tokens.append(plain_tokens[j])

                        context_str = " ".join(context_tokens)
                        hover_text = f"Token {i+1}: {context_str}<br>Predicted Total: {prediction_history[i]:.0f}"
                        hover_texts.append(hover_text)

                    # Update graph with new data
                    graph_widget.data = []
                    graph_widget.add_trace(go.Scatter(
                        x=token_counts,
                        y=prediction_history,
                        mode='lines+markers',
                        name='Predicted Total Tokens',
                        line=dict(color='blue', width=2),
                        marker=dict(size=6),
                        hovertemplate='%{customdata}<extra></extra>',
                        customdata=hover_texts
                    ))

                    # Add a horizontal line showing actual tokens generated so far
                    graph_widget.add_trace(go.Scatter(
                        x=[token_counts[0], token_counts[-1]],
                        y=[n_tokens_generated, n_tokens_generated],
                        mode='lines',
                        name='Current Progress',
                        line=dict(color='red', width=2, dash='dash'),
                        hovertemplate='Current tokens generated: %{y}<extra></extra>'
                    ))

                    # Update layout
                    graph_widget.update_layout(
                        title='Token Prediction Over Time',
                        xaxis_title='Token Number',
                        yaxis_title='Predicted Total Tokens',
                        height=400,
                        showlegend=True,
                        hovermode='closest'
                    )

    # Store final token count globally
    stored_n_tokens = n_tokens_generated

    # Without any recorded step there is no final prediction to report; the
    # summary below would otherwise fail on an undefined value
    if not prediction_history:
        print("\nNo tokens were generated; nothing to summarise.")
        return

    # The last recorded entry is the final predicted total
    final_predicted_total = prediction_history[-1]
    final_accuracy = (n_tokens_generated/final_predicted_total)*100 if final_predicted_total != 0 else 0

    # After generation is complete, display prediction history
    print("\n" + "="*80)
    print("FINAL PREDICTION HISTORY")
    print("="*80)

    print(f"\nGenerated {n_tokens_generated} tokens total")
    print(f"Final prediction was {final_predicted_total:.0f} tokens")
    print(f"Accuracy: {final_accuracy:.1f}%")
    print(f"EMA factor used: {ema_factor}")

    print("\nFull Prediction History:")
    print("Token# | Predicted Total | Change | % Change | Token")
    print("-" * 70)

    for i, (count, pred_total, tok) in enumerate(zip(token_counts, prediction_history, generated_tokens)):
        if i == 0:
            change = 0
            percent_change = 0
        else:
            change = pred_total - prediction_history[i-1]
            percent_change = abs(change / prediction_history[i-1]) * 100 if prediction_history[i-1] != 0 else 0

        # Undo the HTML escaping for console output
        token_for_print = html.unescape(tok)

        # Highlight tokens with large percentage changes (>10%)
        if percent_change > 10:
            # Same convention as the widget legend: red = increase, green = decrease
            ansi_style = "\033[1m\033[91m" if change > 0 else "\033[1m\033[92m"
            sign = "+" if change > 0 else ""
            token_display_str = f"{ansi_style}{token_for_print}\033[0m"
            change_str = f"{ansi_style}{sign}{change:.0f}\033[0m"
        else:
            token_display_str = token_for_print
            if change > 0:
                change_str = f"+{change:.0f}"
            else:
                change_str = f"{change:.0f}"

        print(f"{count:5d}  | {pred_total:13.0f}   | {change_str:8s} | {percent_change:6.1f}%  | {token_display_str}")

    # Update graph with final results using current EMA
    update_graph_with_ema(ema_factor)

# Register the handler; it is called with the button instance on every click
submit_button.on_click(on_submit_clicked)

VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

In [17]:
!pip install anywidget

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


## Example Usage

Test different types of tokens to see their effects on the probe's predictions.

In [20]:
def analyze_token_effects(base_text, candidate_tokens, use_chat_template=False):
    """
    Measure how appending each candidate string changes the prediction for a text.

    The prediction for a text is ``get_log_preds(...).exp()`` at the last
    position of the layer-15 output. Each result line is also printed.

    Args:
        base_text: The base text/context (str)
        candidate_tokens: Strings to append to the base text, one at a time (list of str)
        use_chat_template: Whether to wrap base_text in the chat template first (bool)

    Returns:
        dict with:
            - 'baseline_prediction': prediction for the base text alone
            - 'baseline_text': the base text actually used (after templating, if any)
            - 'token_effects': list of dicts with 'token', 'prediction',
              'raw_change' and 'percent_change'
    """
    def predict_last_position(text):
        # Run one forward trace and keep the prediction at the final position
        with model.trace(text):
            hidden = model.model.layers[15].output[0]
            if len(hidden.shape) == 1:
                hidden = hidden.unsqueeze(0)
            last_pred = get_log_preds(hidden, weight_tensor).exp()[-1].save()
        return last_pred.item()

    if use_chat_template:
        base_text = model.tokenizer.apply_chat_template(
            [{"role": "user", "content": base_text}],
            tokenize=False,
            add_generation_prompt=True
        )

    # Reference value that every candidate is compared against
    baseline_pred = predict_last_position(base_text)

    print(f"Baseline text: '{base_text}'")
    print(f"Baseline prediction: {baseline_pred:.2f} tokens remaining\n")

    results = []
    for token in candidate_tokens:
        token_pred = predict_last_position(base_text + token)

        # Absolute and relative difference to the baseline
        raw_change = token_pred - baseline_pred
        if baseline_pred != 0:
            percent_change = raw_change / baseline_pred * 100
        else:
            percent_change = 0

        results.append({
            'token': token,
            'prediction': token_pred,
            'raw_change': raw_change,
            'percent_change': percent_change
        })

        print(f"Token: '{token:20s}' | Prediction: {token_pred:6.2f} | Change: {raw_change:+7.2f} ({percent_change:+6.1f}%)")

    return {
        'baseline_prediction': baseline_pred,
        'baseline_text': base_text,
        'token_effects': results
    }


def visualize_token_effects(analysis_results, sort_by='raw_change'):
    """
    Plot the per-token predictions as a bar chart and print summary statistics.

    Bars are ordered by ``sort_by`` (descending) and coloured by the direction
    and relative size of the change. The summary's largest increase and
    decrease are always determined by ``raw_change``, independent of the
    display order. When there are no token effects, a message is printed and
    nothing is plotted.

    Args:
        analysis_results: Output from analyze_token_effects()
        sort_by: 'raw_change', 'percent_change', or 'prediction'
    """
    import plotly.graph_objects as go

    baseline = analysis_results['baseline_prediction']
    effects = analysis_results['token_effects']

    # Nothing to plot; the summary below would index into an empty list
    if not effects:
        print("No token effects to visualize.")
        return

    # Sort results
    effects_sorted = sorted(effects, key=lambda x: x[sort_by], reverse=True)

    tokens = [e['token'] for e in effects_sorted]
    predictions = [e['prediction'] for e in effects_sorted]
    changes = [e['raw_change'] for e in effects_sorted]
    percent_changes = [e['percent_change'] for e in effects_sorted]

    # Color based on change direction and magnitude (saturates at 50%)
    colors = []
    for change, pct in zip(changes, percent_changes):
        intensity = min(abs(pct) / 50, 1.0)
        if change > 0:
            # Red gradient for increases
            r = int(230 + (255 - 230) * intensity)
            g = int(243 - 243 * intensity)
            b = int(255 - 255 * intensity)
        else:
            # Green gradient for decreases
            r = int(230 - 230 * intensity)
            g = int(243 - 39 * intensity)
            b = int(255 - 255 * intensity)
        colors.append(f'rgb({r},{g},{b})')

    # Create figure
    fig = go.Figure()

    # Add bars for predictions
    fig.add_trace(go.Bar(
        x=tokens,
        y=predictions,
        marker=dict(color=colors, line=dict(color='black', width=1)),
        text=[f"{p:.1f}<br>({c:+.1f})" for p, c in zip(predictions, changes)],
        textposition='outside',
        hovertemplate='<b>%{x}</b><br>Prediction: %{y:.2f} tokens<br>Change: %{customdata[0]:+.2f} (%{customdata[1]:+.1f}%)<extra></extra>',
        customdata=list(zip(changes, percent_changes))
    ))

    # Add baseline line
    fig.add_hline(
        y=baseline,
        line_dash="dash",
        line_color="blue",
        annotation_text=f"Baseline: {baseline:.2f}",
        annotation_position="right"
    )

    # Update layout
    fig.update_layout(
        title=f'Token Effects on Prediction<br><sub>Base: "{analysis_results["baseline_text"][:50]}..."</sub>',
        xaxis_title='Token',
        yaxis_title='Predicted Tokens Remaining',
        height=500,
        showlegend=False,
        hovermode='x'
    )

    fig.show()

    # Extremes are taken by raw change so they stay correct for any sort_by
    largest_increase = max(effects, key=lambda e: e['raw_change'])
    largest_decrease = min(effects, key=lambda e: e['raw_change'])

    # Mean is computed once and reused for the (population) standard deviation
    raw_changes = [e['raw_change'] for e in effects]
    mean_change = sum(raw_changes) / len(raw_changes)
    std_change = (sum((c - mean_change)**2 for c in raw_changes) / len(raw_changes))**0.5

    # Print summary statistics
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    print(f"Baseline prediction: {baseline:.2f} tokens")
    print(f"\nMost increase: '{largest_increase['token']}' ({largest_increase['raw_change']:+.2f}, {largest_increase['percent_change']:+.1f}%)")
    print(f"Most decrease: '{largest_decrease['token']}' ({largest_decrease['raw_change']:+.2f}, {largest_decrease['percent_change']:+.1f}%)")
    print(f"\nAverage change: {mean_change:.2f}")
    print(f"Std dev: {std_change:.2f}")

In [21]:
# Example: compare hedging words against confident ones as continuations
base = "The weather tomorrow will be"

# Candidate continuations (each starts with a space), in the order they are tested
candidates = [
    " alternatively",
    # concrete outcomes
    " sunny",
    " rainy",
    # hedging and confident words, interleaved
    " probably",    # uncertain
    " maybe",       # uncertain
    " definitely",  # certain
    " perhaps",     # uncertain
    " certainly",   # certain
    " might",       # uncertain
    " could",       # uncertain
]

results = analyze_token_effects(
    base_text=base,
    candidate_tokens=candidates,
    use_chat_template=False,
)
visualize_token_effects(analysis_results=results)

Baseline text: 'The weather tomorrow will be'
Baseline prediction: 416.00 tokens remaining

Token: ' alternatively      ' | Prediction: 664.00 | Change: +248.00 ( +59.6%)
Token: ' sunny              ' | Prediction: 203.00 | Change: -213.00 ( -51.2%)
Token: ' rainy              ' | Prediction: 296.00 | Change: -120.00 ( -28.8%)
Token: ' probably           ' | Prediction: 1064.00 | Change: +648.00 (+155.8%)
Token: ' maybe              ' | Prediction: 1320.00 | Change: +904.00 (+217.3%)
Token: ' rainy              ' | Prediction: 296.00 | Change: -120.00 ( -28.8%)
Token: ' probably           ' | Prediction: 1064.00 | Change: +648.00 (+155.8%)
Token: ' maybe              ' | Prediction: 1320.00 | Change: +904.00 (+217.3%)
Token: ' definitely         ' | Prediction: 776.00 | Change: +360.00 ( +86.5%)
Token: ' perhaps            ' | Prediction: 752.00 | Change: +336.00 ( +80.8%)
Token: ' certainly          ' | Prediction: 664.00 | Change: +248.00 ( +59.6%)
Token: ' might              ' | Pre


SUMMARY STATISTICS
Baseline prediction: 416.00 tokens

Most increase: ' maybe' (+904.00, +217.3%)
Most decrease: ' might' (-268.00, -64.4%)

Average change: 260.70
Std dev: 356.35


In [22]:
def generate_text(prompt, max_tokens=100, use_chat_template=True):
    """
    Sample a continuation of `prompt` from the global `model` and print it.

    Generation is stochastic (do_sample=True, temperature=0.7, top_p=0.9), so
    repeated calls return different text unless the caller seeds torch's RNG.

    Args:
        prompt: The input prompt (str)
        max_tokens: Maximum number of new tokens to generate (int)
        use_chat_template: If True, wrap `prompt` as a single user turn with
            the tokenizer's chat template and append the generation prompt;
            if False, feed `prompt` to the model unchanged (bool)

    Returns:
        tuple: (formatted_prompt, token_list, full_text)
            - formatted_prompt: The prompt with chat template applied (if requested)
            - token_list: List of generated token strings, each decoded on its
              own. A character whose bytes are split across several tokens
              cannot be decoded per token, so joining this list is not
              guaranteed to reproduce the generated text exactly.
            - full_text: formatted_prompt followed by the generated ids decoded
              as one sequence
    """
    # Apply chat template if requested
    if use_chat_template:
        formatted_prompt = model.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        formatted_prompt = prompt

    print(f"Generating with prompt: '{prompt}'...")
    print(f"Formatted prompt length: {len(formatted_prompt)} chars\n")

    print("Generated text: ", end='')

    # Tokenize the prompt for the HuggingFace model
    input_ids = model.tokenizer.encode(formatted_prompt, return_tensors='pt').to(device)
    # A single unpadded sequence: every position is a real token. The mask is
    # passed explicitly because pad_token_id is set to the EOS id below, which
    # prevents generate() from inferring it from the input.
    attention_mask = torch.ones_like(input_ids)

    # Access the wrapped model directly (nnsight wraps the HF model in ._model)
    with torch.no_grad():
        output_ids = model._model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=model.tokenizer.eos_token_id
        )

    # Extract only generated tokens (skip prompt)
    generated_ids = output_ids[0][input_ids.shape[1]:]

    # Per-token strings, kept so callers can slice the generation by token index
    generated_tokens = [
        model.tokenizer.decode([token_id.item()], skip_special_tokens=False)
        for token_id in generated_ids
    ]

    # Decode the whole sequence at once for display and for full_text, so that
    # characters spanning more than one token come out intact.
    generated_text = model.tokenizer.decode(generated_ids, skip_special_tokens=False)
    print(generated_text, end='', flush=True)

    print(f"\n\nGeneration complete! Generated {len(generated_tokens)} tokens")
    full_text = formatted_prompt + generated_text
    return formatted_prompt, generated_tokens, full_text

def slice_tokens_to_text(formatted_prompt, token_list, slice_at):
    """
    Rebuild the text of a partially completed generation.

    Args:
        formatted_prompt: Prompt string that precedes the generated tokens
        token_list: Generated token strings, in generation order
        slice_at: Exclusive end index; tokens token_list[:slice_at] are kept

    Returns:
        str: The prompt immediately followed by the kept tokens, with no
        separator inserted between pieces
    """
    kept_tokens = token_list[:slice_at]
    return "".join([formatted_prompt, *kept_tokens])

In [59]:
# Example 2: Generate and test at different slicing points
# generate_text samples its output (do_sample=True), and a later cell slices
# `tokens` at a hard-coded index. Seed torch's RNG so that re-running this cell
# on the same setup yields the same tokens.
torch.manual_seed(0)
prompt = "What is 2+2?"
formatted_prompt, tokens, full_text = generate_text(prompt, max_tokens=300)

Generating with prompt: 'What is 2+2?'...
Formatted prompt length: 62 chars

Generated text: <think>
Okay, the user is asking "What is 2+2?" That's<think>
Okay, the user is asking "What is 2+2?" That's a straightforward math question. Let me start by recalling basic arithmetic. In standard addition, 2 plus 2 equals a straightforward math question. Let me start by recalling basic arithmetic. In standard addition, 2 plus 2 equals 4. But wait, maybe they're looking for something more? Like in different contexts?

Hmm, maybe 4. But wait, maybe they're looking for something more? Like in different contexts?

Hmm, maybe they want to know if there's any alternative interpretation. For example, in some contexts, like in they want to know if there's any alternative interpretation. For example, in some contexts, like in a different number base. If it's base 3, 2+2 would be 11 (since 2+2=4 in decimal, which is  a different number base. If it's base 3, 2+2 would be 11 (since 2+2=4 in decimal, whic

In [60]:
# Print each generated token next to its position so slice points can be chosen by eye
for index in range(len(tokens)):
    n = tokens[index]
    print(index, n)

0 <think>
1 

2 Okay
3 ,
4  the
5  user
6  is
7  asking
8  "
9 What
10  is
11  
12 2
13 +
14 2
15 ?"
16  That
17 's
18  a
19  straightforward
20  math
21  question
22 .
23  Let
24  me
25  start
26  by
27  recalling
28  basic
29  arithmetic
30 .
31  In
32  standard
33  addition
34 ,
35  
36 2
37  plus
38  
39 2
40  equals
41  
42 4
43 .
44  But
45  wait
46 ,
47  maybe
48  they
49 're
50  looking
51  for
52  something
53  more
54 ?
55  Like
56  in
57  different
58  contexts
59 ?


60 Hmm
61 ,
62  maybe
63  they
64  want
65  to
66  know
67  if
68  there
69 's
70  any
71  alternative
72  interpretation
73 .
74  For
75  example
76 ,
77  in
78  some
79  contexts
80 ,
81  like
82  in
83  a
84  different
85  number
86  base
87 .
88  If
89  it
90 's
91  base
92  
93 3
94 ,
95  
96 2
97 +
98 2
99  would
100  be
101  
102 1
103 1
104  (
105 since
106  
107 2
108 +
109 2
110 =
111 4
112  in
113  decimal
114 ,
115  which
116  is
117  
118 1
119 1
120  in
121  base
122  
123 3
124 ).
125  But
126  t

In [66]:
# Concatenate the first 42 generated tokens (indices 0-41) and display the result
base = "".join(tokens[:42])
base

'<think>\nOkay, the user is asking "What is 2+2?" That\'s a straightforward math question. Let me start by recalling basic arithmetic. In standard addition, 2 plus 2 equals '

In [28]:
# Count down from 1011988 in steps of 1000, ten values in total
for value in range(1011988, 1001988, -1000):
    print(value)

1011988
1010988
1009988
1008988
1007988
1006988
1005988
1004988
1003988
1002988


In [None]:
# Inspect the exact string the chat template builds around a user message
user_message = "Tell me a story"
formatted = model.tokenizer.apply_chat_template(
    [dict(role="user", content=user_message)],
    tokenize=False,
    add_generation_prompt=True,
)

# Show the escaped form first, then the rendered form between two rulers
print(
    "Full formatted prompt:",
    f"{formatted!r}",
    "\n" + "=" * 80,
    "Visual:",
    formatted,
    "=" * 80,
    sep="\n",
)

# Everything the template places after the user's text is the assistant suffix
user_end_idx = len(user_message) + formatted.find(user_message)
assistant_suffix = formatted[user_end_idx:]
print(f"\nAssistant suffix (after user message): {assistant_suffix!r}")

In [40]:
# Render the chat template for a throwaway message to see its raw markup
model.tokenizer.apply_chat_template(
    [dict(role="user", content="hello world")],
    tokenize=False,
    add_generation_prompt=True,
)

'<|im_start|>user\nhello world<|im_end|>\n<|im_start|>assistant\n'

In [42]:
def insert_template(prompt):
    """
    Wrap `prompt` as a single user turn followed by an open assistant turn.

    The markup is hard-coded rather than taken from the tokenizer: a user
    block containing `prompt`, then the header of an assistant block left
    open for generation.

    Args:
        prompt: Text of the user message

    Returns:
        str: The prompt surrounded by the chat markup
    """
    layout = '<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n'
    return layout.format(prompt)

In [55]:
# Put the question into the hand-written chat markup so the assistant turn is open
base = insert_template("Which is the better university, UCL or KCL?")

# Candidate first words of the reply whose effect on the prediction is compared
test_words = ["UCL", "KCL"]

results = analyze_token_effects(base, test_words, use_chat_template=False)
visualize_token_effects(results)

Baseline text: '<|im_start|>user
Which is the better university, UCL or KCL?<|im_end|>
<|im_start|>assistant
'
Baseline prediction: 444.00 tokens remaining

Token: 'UCL                 ' | Prediction: 1064.00 | Change: +620.00 (+139.6%)
Token: 'KCL                 ' | Prediction: 1280.00 | Change: +836.00 (+188.3%)



SUMMARY STATISTICS
Baseline prediction: 444.00 tokens

Most increase: 'KCL' (+836.00, +188.3%)
Most decrease: 'UCL' (+620.00, +139.6%)

Average change: 728.00
Std dev: 108.00


In [None]:
# Example 3: probe several cut points inside one generation
prompt = "The solution to this problem"
formatted_prompt, tokens, full_text = generate_text(prompt, max_tokens=40)

# Token counts at which to cut the generation, and the continuations tried at each cut
test_points = [5, 10, 15, 20]
test_continuations = [" is", " might", " could", " definitely"]

print("\n" + "=" * 70, "TESTING AT MULTIPLE POINTS", "=" * 70, sep="\n")

for point in test_points:
    # Cut points at or beyond the end of the generation are skipped
    if point >= len(tokens):
        continue

    base = slice_tokens_to_text(formatted_prompt, tokens, point)
    print(f"\n\nPoint {point}: '{base[-50:]}...'")  # only the trailing 50 characters

    results = analyze_token_effects(base, test_continuations, use_chat_template=False)
    # Report the continuation with the largest absolute raw change
    print(f"Most impactful: {max(results['token_effects'], key=lambda effect: abs(effect['raw_change']))['token']}")

In [None]:
# Example 4: generate once, then analyse at whichever token position you like
prompt = "I think the answer"
formatted_prompt, tokens, full_text = generate_text(prompt, max_tokens=25)

# `tokens` is a plain list, so ordinary Python slicing applies
print(
    "\n\nYou can now use Python slicing:",
    "tokens[:5]   -> first 5 tokens",
    "tokens[5:10] -> tokens 5-10",
    "tokens[-5:]  -> last 5 tokens",
    "tokens[::2]  -> every other token",
    sep="\n",
)

# Number of generated tokens to keep before testing continuations (edit freely)
my_slice = 8
base = slice_tokens_to_text(formatted_prompt, tokens, my_slice)

print(f"\n\nAnalyzing at token {my_slice}:")
print(f"Base text: '{base}'")

# Continuations to compare at the chosen position
my_tests = [" is", " was", " might", " should", " could", " will"]
results = analyze_token_effects(base, my_tests, use_chat_template=False)
visualize_token_effects(results)