Just run all these cells

# Imports

In [1]:
# Disabled one-time environment setup: the commands below are wrapped in a string literal,
# so executing this cell installs nothing and only echoes the string back as its output.
# Remove the surrounding triple quotes to actually run the installs.
'''!pip install nnsight -q
!pip install einops -q
!pip install ipywidgets -q
!pip install hf_transfer -q
!pip install plotly -q
!pip install anywidget -q
# Ensure widget extensions are enabled
!jupyter nbextension enable --py widgetsnbextension --sys-prefix 2>/dev/null || true'''

'!pip install nnsight -q\n!pip install einops -q\n!pip install ipywidgets -q\n!pip install hf_transfer -q\n!pip install plotly -q\n!pip install anywidget -q\n# Ensure widget extensions are enabled\n!jupyter nbextension enable --py widgetsnbextension --sys-prefix 2>/dev/null || true'

In [2]:
!jupyter nbextension enable --py widgetsnbextension --sys-prefix 2>/dev/null || true

In [None]:
from nnsight import LanguageModel
from einops import einsum
import torch
import ipywidgets as widgets
from IPython.display import display
import time
import numpy as np
import nnsight

In [2]:
# Notebook display helpers and global warning configuration
from IPython.display import display, HTML
import warnings
# Silences every Python warning for the rest of the kernel session (not only widget-related ones).
warnings.filterwarnings('ignore')

# Emits a <script> element whose only action is to log 'Widgets initialized' to the browser console.
# NOTE(review): nothing here visibly enables or verifies widget rendering - confirm this cell is still needed.
display(HTML("<script>console.log('Widgets initialized')</script>"))

# Load model

In [3]:
# Prefer the GPU when one is present; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Probe direction: projected onto layer activations for the progress readout and added to
# them as the steering vector in the cells below.
# map_location places the tensor on `device` regardless of where it was saved, so it always
# lives next to the model; weights_only=True restricts unpickling to tensor data instead of
# arbitrary Python objects.
weight_tensor = torch.load(
    'rollouts-big/beta_torch.pt', map_location=device, weights_only=True
).to(torch.bfloat16)
model_name = 'Qwen/Qwen3-4B'

In [None]:
# Load `model_name` through nnsight's LanguageModel wrapper in bfloat16 on `device`.
# NOTE(review): dispatch=True presumably loads the weights immediately rather than lazily - confirm against the nnsight docs.
# NOTE(review): no warning/compilation setting is changed in this cell.
model = LanguageModel(model_name, device_map=device, dtype=torch.bfloat16, dispatch=True)

In [6]:
def get_ema_preds(log_preds, alpha=0.5):
    """Smooth a sequence of per-token predictions with a countdown-aware EMA.

    The first smoothed value equals the first prediction. Every later value is
    ``alpha * (previous - 1) + (1 - alpha) * prediction``: the previous estimate
    is decremented by one before blending because one more token has been
    emitted since it was made.

    Args:
        log_preds: Object exposing ``.tolist()`` (e.g. a 1-D tensor) that yields
            one prediction per generated token.
            NOTE(review): despite the name, the ``- 1`` step treats the values as
            token counts rather than logs - confirm the intended units.
        alpha: Weight kept on the previous (decremented) estimate; the new
            prediction receives ``1 - alpha``.

    Returns:
        List of smoothed values, the same length as the input.
    """
    smoothed = []
    running = None
    for observation in log_preds.tolist():
        if running is None:
            # Seed the average with the very first prediction.
            running = observation
        else:
            # One token has elapsed, hence the "- 1" on the carried estimate.
            running = alpha*(running-1) + (1-alpha)*observation
        smoothed.append(running)
    return smoothed
def get_log_preds(activation, weight_tensor):
    """Project every sequence position's activation onto the probe vector.

    Contracts the ``d_model`` axis of a ``(seq, d_model)`` activation with a
    ``(d_model,)`` weight vector, giving one scalar per sequence position.
    """
    projection = einsum(activation, weight_tensor, 'seq d_model, d_model -> seq')
    return projection

# Vibe coded UIs

In [8]:
# UI v1: prompt box, submit button, progress bar with percentage, and a token readout.
# The widgets are bound to module-level names that the click callback reads as globals;
# the later UI cells rebind the same names, so the most recently executed cell wins.

# Multi-line text box holding the user's prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Button that triggers generation once a click handler is attached
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Progress bar on a 0-100 scale (the callback writes a percentage into .value)
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Numeric percentage shown next to the bar
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Lay the bar and its percentage label out side by side
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area that the callback fills with the generated tokens
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Stack everything vertically: title, prompt, button, progress row, tokens
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    submit_button,
    progress_row,
    token_display
])

# Render the assembled UI as this cell's output
display(progress_container)

def on_submit_clicked(b):
    """Button callback: generate a reply to the prompt and stream progress into the widgets.

    Wraps the prompt in the tokenizer's chat template, runs ``model.generate`` and,
    at every single-position step, projects the layer-15 output onto
    ``weight_tensor``, smooths the readings with ``get_ema_preds`` and turns the
    latest smoothed value into a "percent through" estimate
    ``n / (n + predicted_remaining)`` shown in the progress bar and label.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    import html  # local import, same pattern as the later UI cells

    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"

    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)
            preds = get_log_preds(activations, weight_tensor).tolist()
            # Passes that cover more than one position (the prompt) are skipped;
            # only single-position steps update the readout.
            if len(preds) <= 1:
                cur_log_preds+=preds
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated+=1
                # NOTE(review): the smoothed projection is used directly as "tokens remaining"
                # here, whereas the later UI cells apply torch.exp first - confirm the units.
                pred_tokens_remaining = ema_preds[-1]
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)

                # NOTE(review): this is the argmax of the logits while generation samples
                # (do_sample=True), so it may differ from the token actually emitted - confirm.
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=False)
                generated_tokens.append(token_str)

                # Update progress bar
                progress_bar.value = pred_percent_through * 100

                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"

                # Update token display with all generated tokens. Token text is escaped because it
                # is inserted into HTML: tokens containing '<', '>' or '&' would otherwise be
                # interpreted as markup instead of being displayed.
                tokens_html = " ".join([f"<span style='background-color: #e6f3ff; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{html.escape(tok)}</span>" for tok in generated_tokens])
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{html.escape(token_str)}' | <b>Predicted:</b> {pred_percent_through*100:.1f}% through"

# Connect button click to function
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

In [29]:
# UI v2: same controls as the first UI plus an Output area that hosts a live plotly graph.
# Rebinds the module-level widget names used by the click callback defined below.

# Multi-line text box holding the user's prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Button that triggers generation once a click handler is attached
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Progress bar on a 0-100 scale (the callback writes a percentage into .value)
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Numeric percentage shown next to the bar
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Lay the bar and its percentage label out side by side
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area that the callback fills with colour-coded tokens
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Output area into which the callback renders the plotly figure
graph_output = widgets.Output()

# Stack everything vertically, with the graph placed above the token display
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    submit_button,
    progress_row,
    graph_output,
    token_display
])

# Render the assembled UI as this cell's output
display(progress_container)

# Module-level handle for the plotly FigureWidget. It starts as None; the click callback
# creates a fresh FigureWidget on every click and stores it here via `global`.
import plotly.graph_objects as go
fig_widget = None

def get_color_for_change(change, max_change=800.0):
    """Map a change in the prediction to a CSS ``rgb(...)`` background colour.

    Positive changes shade towards red and non-positive changes towards green;
    a change of zero yields white. The change is divided by ``max_change`` and
    clamped to [-1, 1], so anything at or beyond ``max_change`` in magnitude
    gets the darkest shade.
    """
    # Clamp the scaled change into [-1, 1]
    ratio = max(min(change / max_change, 1.0), -1.0)
    # The two non-dominant channels fade from 255 down to int(255 * 0.3) as |ratio| grows to 1
    faded = int(255 * (1 - abs(ratio) * 0.7))
    return f'rgb(255, {faded}, {faded})' if ratio > 0 else f'rgb({faded}, 255, {faded})'

def calculate_ema(values, alpha=0.2):
    """Return the exponential moving average of ``values`` as a list.

    The first output equals the first input; each later output is
    ``alpha * value + (1 - alpha) * previous_output``. A small ``alpha``
    (e.g. 0.1-0.3) smooths heavily, a large one (e.g. 0.8-1.0) tracks the
    data closely. An empty input yields an empty list.
    """
    if not values:
        return []
    smoothed = []
    for sample in values:
        if smoothed:
            # Blend the new sample with the most recent smoothed value
            sample = alpha * sample + (1 - alpha) * smoothed[-1]
        smoothed.append(sample)
    return smoothed

def on_submit_clicked(b):
    """Button callback: generate a reply and stream progress, coloured tokens and a live graph.

    Resets the progress widgets, shows a fresh two-trace plotly FigureWidget in
    ``graph_output``, wraps the prompt in the tokenizer's chat template and runs
    ``model.generate``. At every step whose layer-15 output has a leading dimension
    of 1, the projection onto ``weight_tensor`` is exponentiated and appended to
    ``raw_preds``; from the second such step onwards the EMA, the per-token change,
    the progress estimate, the token HTML and (every 10th point) the graph are updated.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    import html
    global fig_widget
    
    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"
    
    # Build a new FigureWidget with both traces for this run and show it in graph_output
    with graph_output:
        graph_output.clear_output(wait=True)
        fig_widget = go.FigureWidget()
        # Add EMA trace first (will be behind)
        fig_widget.add_trace(go.Scatter(
            x=[],
            y=[],
            mode='lines',
            name='EMA (smoothed)',
            line=dict(color='rgba(255, 100, 100, 0.6)', width=3),
            showlegend=True
        ))
        # Add main trace on top
        fig_widget.add_trace(go.Scatter(
            x=[],
            y=[],
            mode='lines',
            name='Predicted Tokens Remaining',
            line=dict(color='rgba(32, 178, 170, 0.8)', width=2),
            showlegend=True
        ))
        fig_widget.update_layout(
            title='Predicted Tokens Remaining Over Time',
            xaxis_title='Token Number',
            yaxis_title='Predicted Tokens Remaining',
            height=300,
            margin=dict(l=50, r=20, t=40, b=40),
            transition_duration=0,
            legend=dict(x=0.7, y=1, bgcolor='rgba(255,255,255,0.8)')
        )
        display(fig_widget)
    
    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    raw_preds = []  # exp(projection) per step, without smoothing
    cur_log_preds = []  # calculate_ema(raw_preds); despite the name these are exponentiated values, not logs
    generated_tokens = []
    token_changes = []  # Step-to-step change of the EMA, used for colour coding

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            # Only steps whose leading dimension is 1 are processed; other passes are ignored.
            # NOTE(review): presumably this separates single-token generation steps from the prompt pass - confirm.
            if activations.shape[0] == 1:
                if len(raw_preds) > 0:
                    preds = get_log_preds(activations, weight_tensor)
                    exp_preds = torch.exp(preds).item()

                    raw_preds.append(exp_preds)
                
                    # The whole EMA is recomputed from scratch at every step
                    cur_log_preds = calculate_ema(raw_preds, alpha=0.2)
                
                    # Safe to index [-2]: raw_preds holds at least two entries in this branch
                    change = cur_log_preds[-1] - cur_log_preds[-2]
                    token_changes.append(change)

                    # NOTE(review): argmax of the logits while generation samples (do_sample=True),
                    # so this may differ from the token actually emitted - confirm.
                    token = model.lm_head.output.argmax(dim=-1).tolist()
                    token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=False)
                    generated_tokens.append(token_str)
            
                    # Progress estimate: predictions so far / (predictions so far + smoothed remaining)
                    pred_percent_through = len(raw_preds) / (len(raw_preds) + cur_log_preds[-1]) * 100
                    progress_bar.value = pred_percent_through
            
                    percentage_label.value = f"<b>{pred_percent_through:.1f}%</b>"
            
                    # Render at most the last 2000 tokens
                    start_idx = max(0, len(generated_tokens) - 2000)
                    display_tokens = generated_tokens[start_idx:]
                    display_changes = token_changes[start_idx:]
                
                    # Calculate dynamic max_change based on actual data
                    max_change = max(abs(c) for c in token_changes) if token_changes else 800.0
                    max_change = max(max_change, 100.0)  # Ensure a minimum threshold
                
                    tokens_html = " ".join([
                        f"<span style='background-color: {get_color_for_change(change, max_change)}; padding: 2px 4px; margin: 1px; border-radius: 3px;' title='Token #{start_idx + i + 1} | Change: {change:+.2f}'>{html.escape(token)}</span>" 
                        for i, (token, change) in enumerate(zip(display_tokens, display_changes))
                    ])
                
                    # Add indicator if tokens are truncated
                    truncated_msg = f"<i>(Showing last 2000 of {len(generated_tokens)} tokens)</i><br>" if len(generated_tokens) > 2000 else ""
                
                    token_display.value = f"<b>Generated tokens:</b> <span style='font-size: 0.9em;'>(🔴 increase / 🟢 decrease)</span><br>{truncated_msg}{tokens_html}<br><br><b>Latest:</b> '{html.escape(token_str)}' | <b>Predicted:</b> {pred_percent_through:.1f}% through"
                
                    # Redraw the graph every 10th point; the `== 1` case cannot occur here
                    # because cur_log_preds has at least two entries in this branch.
                    if len(cur_log_preds) % 10 == 0 or len(cur_log_preds) == 1:
                        x_values = list(range(1, len(cur_log_preds) + 1))
                    
                        with fig_widget.batch_update():
                            # Trace 0 is EMA (red line)
                            fig_widget.data[0].x = x_values
                            fig_widget.data[0].y = cur_log_preds
                            # Trace 1 is raw predictions (teal line)
                            fig_widget.data[1].x = x_values
                            fig_widget.data[1].y = raw_preds
                
                else:
                    # First processed step: only record the prediction, so the next
                    # step has a previous EMA value to diff against.
                    preds = get_log_preds(activations, weight_tensor)
                    exp_preds = torch.exp(preds).item()

                    raw_preds.append(exp_preds)

                
submit_button.on_click(on_submit_clicked)

VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

In [7]:
# UI v3: the graph UI plus a slider that sets the steering strength for the next run.
# Rebinds the module-level widget names used by the click callback defined below.

# Multi-line text box holding the user's prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Button that triggers generation once a click handler is attached
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Progress bar on a 0-100 scale (the callback writes a percentage into .value)
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Numeric percentage shown next to the bar
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Lay the bar and its percentage label out side by side
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area that the callback fills with colour-coded tokens
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Output area into which the callback renders the plotly figure
graph_output = widgets.Output()

# Steering strength in [-10, 10] (step 0.1, default 0 = no steering). The callback reads the
# value once per click and uses it as the multiplier on weight_tensor.
steering_slider = widgets.FloatSlider(
    value=0.0,
    min=-10.0,
    max=10.0,
    step=0.1,
    description='Steering:',
    tooltip='Positive = steer toward completion, Negative = steer away from completion',
    style={'description_width': '80px'},
    layout=widgets.Layout(width='100%')
)

# Stack everything vertically, with the slider under the prompt and the graph above the tokens
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress with Steering</h3>"),
    prompt_input,
    steering_slider,
    submit_button,
    progress_row,
    graph_output,
    token_display
])

# Render the assembled UI as this cell's output
display(progress_container)

# Module-level handle for the plotly FigureWidget. It starts as None; the click callback
# creates a fresh FigureWidget on every click and stores it here via `global`.
import plotly.graph_objects as go
fig_widget = None

def get_color_for_change(change, max_change=800.0):
    """Translate a prediction change into a CSS ``rgb(...)`` colour string.

    Increases (positive change) are drawn in red, decreases and zero in green,
    with zero itself coming out white. ``change / max_change`` is clamped to
    [-1, 1]; larger magnitudes give a more saturated colour.
    """
    # Scale and clamp into [-1, 1]
    scaled = max(min(change / max_change, 1.0), -1.0)
    # Secondary channels: 255 at no change, int(255 * 0.3) at full magnitude
    secondary = int(255 * (1 - abs(scaled) * 0.7))
    if scaled > 0:
        return f'rgb(255, {secondary}, {secondary})'
    return f'rgb({secondary}, 255, {secondary})'

def calculate_ema(values, alpha=0.2):
    """Exponentially smooth ``values`` and return the result as a list.

    Output ``i`` is ``alpha * values[i] + (1 - alpha) * output[i - 1]``, seeded
    with ``values[0]``. Lower ``alpha`` (e.g. 0.1-0.3) means smoother output;
    higher ``alpha`` (e.g. 0.8-1.0) follows the data more closely. Returns ``[]``
    for empty input.
    """
    if not values:
        return []
    result = []
    for position, current in enumerate(values):
        if position == 0:
            result.append(current)
        else:
            result.append(alpha * current + (1 - alpha) * result[-1])
    return result

def on_submit_clicked(b):
    """Button callback: generate with optional steering and stream progress into the UI.

    Reads the steering strength from ``steering_slider`` once, resets the progress
    widgets, shows a fresh two-trace plotly FigureWidget in ``graph_output``, wraps the
    prompt in the tokenizer's chat template and runs ``model.generate``. On every pass
    the layer-15 output is shifted in place by ``steering_strength * weight_tensor``
    (when non-zero). The progress readout is computed from a snapshot of the activations
    taken *before* that shift, for steps whose leading dimension is 1.

    Args:
        b: The clicked button (unused; required by the ``on_click`` signature).
    """
    import html
    global fig_widget
    
    # Get steering strength from slider
    steering_strength = steering_slider.value
    
    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = f"<b>Generating... (Steering: {steering_strength:+.1f})</b>"
    
    # Build a new FigureWidget with both traces for this run and show it in graph_output
    with graph_output:
        graph_output.clear_output(wait=True)
        fig_widget = go.FigureWidget()
        # Add EMA trace first (will be behind)
        fig_widget.add_trace(go.Scatter(
            x=[],
            y=[],
            mode='lines',
            name='EMA (smoothed)',
            line=dict(color='rgba(255, 100, 100, 0.6)', width=3),
            showlegend=True
        ))
        # Add main trace on top
        fig_widget.add_trace(go.Scatter(
            x=[],
            y=[],
            mode='lines',
            name='Predicted Tokens Remaining',
            line=dict(color='rgba(32, 178, 170, 0.8)', width=2),
            showlegend=True
        ))
        fig_widget.update_layout(
            title=f'Predicted Tokens Remaining Over Time (Steering: {steering_strength:+.1f})',
            xaxis_title='Token Number',
            yaxis_title='Predicted Tokens Remaining',
            height=300,
            margin=dict(l=50, r=20, t=40, b=40),
            transition_duration=0,
            legend=dict(x=0.7, y=1, bgcolor='rgba(255,255,255,0.8)')
        )
        display(fig_widget)
    
    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    raw_preds = []  # exp(projection) per step, without smoothing
    cur_log_preds = []  # calculate_ema(raw_preds); despite the name these are exponentiated values, not logs
    generated_tokens = []
    token_changes = []  # Step-to-step change of the EMA, used for colour coding

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            # Snapshot the activations BEFORE steering. The steering write below is in place,
            # so without .clone() this name would alias the modified tensor and the readout
            # would silently be computed on the steered activations.
            original_activations = model.model.layers[15].output[0].clone()
            
            # Apply steering to modify the forward pass (affects generation)
            if steering_strength != 0.0:
                model.model.layers[15].output[0][:] = original_activations + steering_strength * weight_tensor
            
            # Use the pre-steering snapshot for prediction (measure effect on unsteered state)
            activations = original_activations
                        
            # Only steps whose leading dimension is 1 are processed; other passes are ignored.
            # NOTE(review): presumably this separates single-token generation steps from the prompt pass - confirm.
            if activations.shape[0] == 1:
                if len(raw_preds) > 0:
                    preds = get_log_preds(activations, weight_tensor)
                    exp_preds = torch.exp(preds).item()

                    raw_preds.append(exp_preds)
                
                    # The whole EMA is recomputed from scratch at every step
                    cur_log_preds = calculate_ema(raw_preds, alpha=0.2)
                
                    # Safe to index [-2]: raw_preds holds at least two entries in this branch
                    change = cur_log_preds[-1] - cur_log_preds[-2]
                    token_changes.append(change)

                    # NOTE(review): argmax of the logits while generation samples (do_sample=True),
                    # so this may differ from the token actually emitted - confirm.
                    token = model.lm_head.output.argmax(dim=-1).tolist()
                    token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=False)
                    generated_tokens.append(token_str)
            
                    # Progress estimate: predictions so far / (predictions so far + smoothed remaining)
                    pred_percent_through = len(raw_preds) / (len(raw_preds) + cur_log_preds[-1]) * 100
                    progress_bar.value = pred_percent_through
            
                    percentage_label.value = f"<b>{pred_percent_through:.1f}%</b>"
            
                    # Render at most the last 2000 tokens
                    start_idx = max(0, len(generated_tokens) - 2000)
                    display_tokens = generated_tokens[start_idx:]
                    display_changes = token_changes[start_idx:]
                
                    # Calculate dynamic max_change based on actual data
                    max_change = max(abs(c) for c in token_changes) if token_changes else 800.0
                    max_change = max(max_change, 100.0)  # Ensure a minimum threshold
                
                    tokens_html = " ".join([
                        f"<span style='background-color: {get_color_for_change(change, max_change)}; padding: 2px 4px; margin: 1px; border-radius: 3px;' title='Token #{start_idx + i + 1} | Change: {change:+.2f}'>{html.escape(token)}</span>" 
                        for i, (token, change) in enumerate(zip(display_tokens, display_changes))
                    ])
                
                    # Add indicator if tokens are truncated
                    truncated_msg = f"<i>(Showing last 2000 of {len(generated_tokens)} tokens)</i><br>" if len(generated_tokens) > 2000 else ""
                
                    token_display.value = f"<b>Generated tokens:</b> <span style='font-size: 0.9em;'>(🔴 increase / 🟢 decrease)</span> | <b>Steering: {steering_strength:+.1f}</b><br>{truncated_msg}{tokens_html}<br><br><b>Latest:</b> '{html.escape(token_str)}' | <b>Predicted:</b> {pred_percent_through:.1f}% through"
                
                    # Redraw the graph every 10th point
                    if len(cur_log_preds) % 10 == 0 or len(cur_log_preds) == 1:
                        x_values = list(range(1, len(cur_log_preds) + 1))
                    
                        with fig_widget.batch_update():
                            # Trace 0 is EMA (red line)
                            fig_widget.data[0].x = x_values
                            fig_widget.data[0].y = cur_log_preds
                            # Trace 1 is raw predictions (teal line)
                            fig_widget.data[1].x = x_values
                            fig_widget.data[1].y = raw_preds
                
                else:
                    # First processed step: only record the prediction, so the next
                    # step has a previous EMA value to diff against.
                    preds = get_log_preds(activations, weight_tensor)
                    exp_preds = torch.exp(preds).item()

                    raw_preds.append(exp_preds)
                
submit_button.on_click(on_submit_clicked)

VBox(children=(HTML(value='<h3>Text Generation Progress with Steering</h3>'), Textarea(value='', description='…

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [36]:
weight_tensor.min()

tensor(-0.1934, device='cuda:0', dtype=torch.bfloat16)

In [9]:
!pip install pandas -q

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [7]:
# Token-enrichment analysis that filters low-count tokens before scoring
def _chi2_pvalue(count_a, total_a, count_b, total_b):
    """Chi-square p-value for the 2x2 table 'token vs. every other token' in two conditions.

    Returns 1.0 (no evidence of a difference) when the table is degenerate, i.e. when
    a row or column marginal is zero. scipy's ``chi2_contingency`` raises ValueError
    on such tables, which happens for any token absent from both conditions.
    """
    from scipy.stats import chi2_contingency

    rest_a = total_a - count_a
    rest_b = total_b - count_b
    marginals = (total_a, total_b, count_a + count_b, rest_a + rest_b)
    if min(marginals) <= 0:
        return 1.0
    # Index 1 of the result is the p-value
    return float(chi2_contingency([[count_a, rest_a], [count_b, rest_b]])[1])


def analyze_token_distributions_robust(baseline_tokens, steered_up_tokens, steered_down_tokens, 
                                       min_total_count=10, top_n=100):
    """
    Compare token frequencies between baseline and two steered conditions.

    Tokens seen fewer than ``min_total_count`` times across all three conditions are
    dropped. For every remaining token the function computes raw frequencies, the
    log2 enrichment versus baseline (with a pseudocount of 1 on each count and 3 on
    each total) and a chi-square p-value versus baseline. It also prints three
    ranked tables (at most 30 rows each).

    Args:
        baseline_tokens: List of token strings from the baseline condition
        steered_up_tokens: List of token strings from the steered up condition
        steered_down_tokens: List of token strings from the steered down condition
        min_total_count: Minimum total appearances across all conditions to be included
        top_n: Number of rows selected per ranking before the 30-row print limit

    Returns:
        DataFrame with one row per retained token (ordered by token) holding counts,
        frequencies, enrichment scores, p-values and significance flags. The frame
        is empty, but still has all columns, when no token passes the filter.
    """
    # Local imports keep the function usable even if the notebook's import cell lacks them
    from collections import Counter
    import pandas as pd

    columns = [
        'token', 'total_count', 'baseline_count', 'up_count', 'down_count',
        'baseline_freq', 'up_freq', 'down_freq',
        'up_enrichment', 'down_enrichment', 'up_vs_down',
        'up_pvalue', 'down_pvalue', 'up_significant', 'down_significant',
    ]

    # Count frequencies
    baseline_counts = Counter(baseline_tokens)
    up_counts = Counter(steered_up_tokens)
    down_counts = Counter(steered_down_tokens)
    
    # Get all unique tokens
    all_tokens = set(baseline_counts) | set(up_counts) | set(down_counts)
    
    # Total counts for normalization
    baseline_total = len(baseline_tokens)
    up_total = len(steered_up_tokens)
    down_total = len(steered_down_tokens)
    
    print(f"\nAnalysis Parameters:")
    print(f"  Baseline tokens: {baseline_total:,}")
    print(f"  Steered Up tokens: {up_total:,}")
    print(f"  Steered Down tokens: {down_total:,}")
    print(f"  Unique tokens found: {len(all_tokens):,}")
    print(f"  Minimum total count filter: {min_total_count}")
    
    # Bonferroni-corrected threshold (p < 0.01). The divisor is the number of unique tokens
    # BEFORE filtering, which is stricter than dividing by the number of tests actually run.
    significance_threshold = 0.01 / max(len(all_tokens), 1)
    # Laplace smoothing: add 1 to each count and 3 to each total
    pseudocount = 1

    def _freq(count, total):
        # An empty condition contributes a frequency of 0 instead of dividing by zero
        return count / total if total else 0.0

    # Build analysis dataframe
    data = []
    filtered_out = 0
    
    # Sorted so row order (and tie-breaking in the rankings) is reproducible across runs
    for token in sorted(all_tokens):
        b_count = baseline_counts[token]
        u_count = up_counts[token]
        d_count = down_counts[token]
        
        total_count = b_count + u_count + d_count
        
        # FILTER: Skip tokens with insufficient total occurrences
        if total_count < min_total_count:
            filtered_out += 1
            continue
        
        smoothed_baseline = (b_count + pseudocount) / (baseline_total + 3)
        up_enrichment = np.log2(((u_count + pseudocount) / (up_total + 3)) / smoothed_baseline)
        down_enrichment = np.log2(((d_count + pseudocount) / (down_total + 3)) / smoothed_baseline)
        
        # Chi-square tests versus baseline (1.0 when the token is absent from both sides)
        p_up = _chi2_pvalue(u_count, up_total, b_count, baseline_total)
        p_down = _chi2_pvalue(d_count, down_total, b_count, baseline_total)
        
        data.append({
            'token': token,
            'total_count': total_count,
            'baseline_count': b_count,
            'up_count': u_count,
            'down_count': d_count,
            'baseline_freq': _freq(b_count, baseline_total),
            'up_freq': _freq(u_count, up_total),
            'down_freq': _freq(d_count, down_total),
            'up_enrichment': up_enrichment,
            'down_enrichment': down_enrichment,
            'up_vs_down': up_enrichment - down_enrichment,
            'up_pvalue': p_up,
            'down_pvalue': p_down,
            'up_significant': p_up < significance_threshold,
            'down_significant': p_down < significance_threshold
        })
    
    # Explicit columns so an empty result still has the documented schema
    df = pd.DataFrame(data, columns=columns)
    
    print(f"  Tokens after filtering: {len(df):,}")
    print(f"  Tokens filtered out: {filtered_out:,}")
    
    if df.empty:
        # Nothing to rank; the tables below would fail on an empty frame
        print("\nNo tokens met the minimum count filter; nothing to report.")
        return df
    
    # Sort by different metrics
    print("\n" + "="*80)
    print(f"MOST ENRICHED IN STEERED UP (positive steering = toward completion)")
    print("="*80)
    print(f"{'Token':<30} | {'Up Count':>8} | {'Base Count':>10} | {'Enrich':>7} | {'P-value':>10} | Sig?")
    print("-"*80)
    
    top_up = df.nlargest(top_n, 'up_enrichment')
    for idx, row in top_up.head(30).iterrows():
        sig_marker = "***" if row['up_significant'] else ""
        print(f"{row['token'][:30]:30s} | {int(row['up_count']):8d} | {int(row['baseline_count']):10d} | "
              f"{row['up_enrichment']:+7.2f} | {row['up_pvalue']:10.2e} | {sig_marker:3s}")
    
    print("\n" + "="*80)
    print(f"MOST ENRICHED IN STEERED DOWN (negative steering = away from completion)")
    print("="*80)
    print(f"{'Token':<30} | {'Down Count':>10} | {'Base Count':>10} | {'Enrich':>7} | {'P-value':>10} | Sig?")
    print("-"*80)
    
    top_down = df.nlargest(top_n, 'down_enrichment')
    for idx, row in top_down.head(30).iterrows():
        sig_marker = "***" if row['down_significant'] else ""
        print(f"{row['token'][:30]:30s} | {int(row['down_count']):10d} | {int(row['baseline_count']):10d} | "
              f"{row['down_enrichment']:+7.2f} | {row['down_pvalue']:10.2e} | {sig_marker:3s}")
    
    print("\n" + "="*80)
    print(f"MOST DEPLETED IN STEERED UP (avoided when steering toward completion)")
    print("="*80)
    print(f"{'Token':<30} | {'Up Count':>8} | {'Base Count':>10} | {'Depletion':>9}")
    print("-"*80)
    
    bottom_up = df.nsmallest(top_n, 'up_enrichment')
    for idx, row in bottom_up.head(30).iterrows():
        print(f"{row['token'][:30]:30s} | {int(row['up_count']):8d} | {int(row['baseline_count']):10d} | "
              f"{row['up_enrichment']:+9.2f}")
    
    return df


In [14]:
# Load instructions from rollouts for batched generation
import json
import random

# Relative to the working directory, like the probe weights loaded from rollouts-big/ earlier
# in this notebook, so the notebook is not tied to one machine's absolute path.
rollouts_path = 'rollouts-big/Qwen3-4B-2.json'

print(f"Loading instructions from {rollouts_path}...")
with open(rollouts_path, 'r', encoding='utf-8') as f:
    rollouts_data = json.load(f)

# Extract unique instructions. Sorted because set iteration order for strings changes between
# kernel sessions, which would make seeded random sampling from this list irreproducible.
instructions = sorted({item['instruction'] for item in rollouts_data})
print(f"Loaded {len(instructions):,} unique instructions")


Loading instructions from rollouts-big/Qwen3-4B-2.json...
Loaded 5,000 unique instructions
Loaded 5,000 unique instructions


In [15]:
# RECOMMENDED: Ultra-efficient version following nnsight best practices
# This version is the cleanest and most efficient

def generate_batched_tokens(steering_strength, num_tokens, batch_size=16, max_tokens_per_gen=200, layer_idx=15):
    """
    Collect predicted tokens from batched, optionally steered, generation.

    Each iteration draws `batch_size` random entries from the module-level
    `instructions`, formats them with the tokenizer's chat template and runs one
    batched `model.generate` call under nnsight. At every generation step the
    argmax of the LM-head logits at the last sequence position is recorded for
    each prompt. Iterations repeat until `num_tokens` tokens were collected.

    Relies on the module-level names `model`, `instructions`, `weight_tensor`,
    `nnsight`, `random` and `torch`.

    NOTE(review): generation itself runs with `do_sample=True`, but the recorded
    token is the greedy argmax of the logits, not the token that was actually
    sampled into the sequence — confirm this is the intended measurement.
    NOTE(review): prompts that finish early presumably still contribute one
    prediction per remaining step of the batch — confirm against the
    `generate` padding behaviour.

    Args:
        steering_strength: Steering coefficient; `steering_strength * weight_tensor`
            is added to the output of the steered layer. 0.0 disables steering.
        num_tokens: Total number of tokens to collect.
        batch_size: Number of prompts to process in parallel (must be >= 1).
        max_tokens_per_gen: Max tokens to generate per prompt per batch (must be >= 1).
        layer_idx: Index into `model.model.layers` of the layer to steer.

    Returns:
        List of decoded token strings of length `num_tokens`
        (empty when `num_tokens` <= 0).

    Raises:
        ValueError: If `batch_size` or `max_tokens_per_gen` is smaller than 1.
        RuntimeError: If a generation pass yields no tokens at all, which would
            otherwise make the collection loop run forever.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if max_tokens_per_gen < 1:
        raise ValueError(f"max_tokens_per_gen must be >= 1, got {max_tokens_per_gen}")

    all_tokens = []
    
    print(f"\n{'='*80}")
    print(f"Batched Generation: steering={steering_strength:+.1f}, batch_size={batch_size}")
    print(f"{'='*80}")
    
    iterations = 0
    while len(all_tokens) < num_tokens:
        iterations += 1
        
        # Sample batch of random instructions
        prompts = [
            model.tokenizer.apply_chat_template(
                [{"role": "user", "content": random.choice(instructions)}],
                tokenize=False,
                add_generation_prompt=True
            )
            for _ in range(batch_size)
        ]
        
        # Determine tokens to generate this iteration; the +50 margin over the
        # even split lets the final iteration overshoot slightly instead of
        # needing an extra pass for a handful of tokens.
        tokens_needed = num_tokens - len(all_tokens)
        tokens_this_gen = min(max_tokens_per_gen, (tokens_needed // batch_size) + 50)
        
        # Single generation context with batched prompts
        with model.generate(max_new_tokens=tokens_this_gen, do_sample=True) as tracer:
            # Pass list of prompts to invoke() for batching
            with tracer.invoke(prompts):
                token_list = nnsight.list().save()
                
                # Apply intervention to all generation steps
                with tracer.all():
                    acts = model.model.layers[layer_idx].output[0]
                    
                    # Steer if needed
                    if steering_strength != 0.0:
                        model.model.layers[layer_idx].output[0][:] = acts + steering_strength * weight_tensor
                    
                    # Collect argmax tokens. Only the last sequence position is
                    # the next-token prediction; on a step that carries the whole
                    # prompt, an argmax over every position would also record
                    # predictions for prompt (and padding) positions.
                    token_ids = model.lm_head.output[:, -1, :].argmax(dim=-1)
                    token_list.append(token_ids)
        
        # Process collected tokens (one tensor of per-prompt ids per step)
        collected_before = len(all_tokens)
        for token_batch in token_list:
            remaining = num_tokens - len(all_tokens)
            if remaining <= 0:
                break
            
            # Anything that is not a tensor cannot be decoded; skip it.
            if not isinstance(token_batch, torch.Tensor):
                continue
            
            # Flatten, truncate to what is still needed, and decode id by id so
            # every list entry is exactly one token string.
            step_ids = token_batch.flatten().tolist()[:remaining]
            all_tokens.extend(
                model.tokenizer.decode([tid], skip_special_tokens=False)
                for tid in step_ids
            )
        
        # Without this guard an empty pass would repeat generation forever.
        if len(all_tokens) == collected_before:
            raise RuntimeError(
                f"Generation iteration {iterations} produced no tokens; "
                "aborting instead of looping indefinitely"
            )
        
        # Progress update
        progress = len(all_tokens)
        if progress % 5000 < batch_size * tokens_this_gen or progress >= num_tokens:
            print(f"  Iteration {iterations:3d} | Progress: {progress:,} / {num_tokens:,} ({100*progress/num_tokens:.1f}%)")
    
    print(f"  ✓ Complete: {len(all_tokens):,} tokens in {iterations} iterations")
    print(f"{'='*80}\n")
    
    return all_tokens[:num_tokens]


# Example usage (uncomment to run):
# print("Testing ultra-efficient batched generation...")
# test_tokens = generate_batched_tokens(0.0, 1000, batch_size=16, max_tokens_per_gen=100)
# print(f"Generated {len(test_tokens)} tokens")
# print(f"Sample tokens: {test_tokens[:20]}")


In [None]:
# 🚀 READY TO RUN: Optimized batch generation
# This replaces the original sequential generation with efficient batching

# Adjust these for your A40:
BATCH_SIZE_OPTIMIZED = 32   # Prompts per generate() call; lower this if you run out of VRAM
MAX_TOKENS_PER_GEN = 1000   # Tokens per prompt per iteration
TARGET_TOKENS = 50000       # Tokens collected per condition
STEERING_UP = 10.0
STEERING_DOWN = -10.0
SEED = 0                    # Seed for prompt selection and token sampling

# Generation is stochastic twice over: prompts are drawn with `random.choice`
# and tokens are sampled (`do_sample=True`). Seed both generators so a fresh
# "Restart & Run All" draws the same random numbers.
random.seed(SEED)
torch.manual_seed(SEED)


print(f"\n{'#'*80}")
print(f"# STARTING OPTIMIZED BATCHED TOKEN GENERATION")
print(f"#")
print(f"# Target: {TARGET_TOKENS:,} tokens per condition")
print(f"# Batch Size: {BATCH_SIZE_OPTIMIZED}")
print(f"# Expected speedup: ~{BATCH_SIZE_OPTIMIZED}x faster")
print(f"{'#'*80}\n")

# Generate with batching: one run per condition (no steering, up, down)
baseline_tokens = generate_batched_tokens(0.0, TARGET_TOKENS, 
                                         batch_size=BATCH_SIZE_OPTIMIZED, 
                                         max_tokens_per_gen=MAX_TOKENS_PER_GEN)

steered_up_tokens = generate_batched_tokens(STEERING_UP, TARGET_TOKENS,
                                           batch_size=BATCH_SIZE_OPTIMIZED,
                                           max_tokens_per_gen=MAX_TOKENS_PER_GEN)

steered_down_tokens = generate_batched_tokens(STEERING_DOWN, TARGET_TOKENS,
                                             batch_size=BATCH_SIZE_OPTIMIZED,
                                             max_tokens_per_gen=MAX_TOKENS_PER_GEN)

print(f"\n{'='*80}")
print(f"🎉 GENERATION COMPLETE!")
print(f"{'='*80}")
print(f"Baseline:     {len(baseline_tokens):,} tokens")
print(f"Steered Up:   {len(steered_up_tokens):,} tokens")
print(f"Steered Down: {len(steered_down_tokens):,} tokens")
print(f"{'='*80}\n")

# Analyze distributions
print("Analyzing token distributions...")
results_df = analyze_token_distributions_robust(baseline_tokens, steered_up_tokens, steered_down_tokens)



################################################################################
# STARTING OPTIMIZED BATCHED TOKEN GENERATION
#
# Target: 50,000 tokens per condition
# Batch Size: 32
# Expected speedup: ~32x faster
################################################################################


Batched Generation: steering=+0.0, batch_size=32


In [17]:
import nnsight

In [None]:
# 📊 COMPARE: Original vs Robust Analysis
# This cell demonstrates the difference between the two approaches

print("="*80)
print("COMPARISON: Biased vs Unbiased Analysis")
print("="*80)

# Run ORIGINAL analysis (may be biased by rare tokens)
print("\n🔴 ORIGINAL ANALYSIS (potentially biased by rare tokens):")
print("-"*80)
results_df_original = analyze_token_distributions(
    baseline_tokens, steered_up_tokens, steered_down_tokens, top_n=50
)

# Run ROBUST analysis (filters rare tokens, includes stats)
print("\n\n🟢 ROBUST ANALYSIS (filters rare tokens, includes significance testing):")
print("-"*80)
results_df_robust = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens, 
    min_total_count=10,  # Adjust based on your total token count
    top_n=50
)

# Compare top results
print("\n" + "="*80)
print("COMPARISON OF TOP 10 ENRICHED TOKENS")
print("="*80)

# Keep the top-10 tokens in enrichment-rank order. Iterating plain sets of
# strings would print them in an order that changes between kernel restarts;
# the sets are kept only for fast membership tests.
ranked_original = list(dict.fromkeys(results_df_original.nlargest(10, 'up_enrichment')['token']))
ranked_robust = list(dict.fromkeys(results_df_robust.nlargest(10, 'up_enrichment')['token']))
top10_original = set(ranked_original)
top10_robust = set(ranked_robust)

only_in_original = [tok for tok in ranked_original if tok not in top10_robust]
only_in_robust = [tok for tok in ranked_robust if tok not in top10_original]
in_both = [tok for tok in ranked_original if tok in top10_robust]

print(f"\n✓ Tokens in both analyses: {len(in_both)}/10")
if in_both:
    print(f"  {', '.join(in_both[:5])}" + ("..." if len(in_both) > 5 else ""))

print(f"\n⚠️  Only in ORIGINAL (likely rare token artifacts): {len(only_in_original)}/10")
if only_in_original:
    for token in only_in_original:
        row = results_df_original[results_df_original['token'] == token].iloc[0]
        print(f"  '{token[:20]}': up={row['up_count']}, baseline={row['baseline_count']}")

print(f"\n✅ Only in ROBUST (filtered signal): {len(only_in_robust)}/10")
if only_in_robust:
    for token in only_in_robust:
        row = results_df_robust[results_df_robust['token'] == token].iloc[0]
        sig = "***" if row['up_significant'] else ""
        print(f"  '{token[:20]}': up={row['up_count']}, baseline={row['baseline_count']}, total={row['total_count']} {sig}")

print("\n" + "="*80)
print("RECOMMENDATION: Use the ROBUST analysis for reliable results!")
print("="*80)


# 🎯 Recommended Analysis Settings

Use these `min_total_count` values based on your dataset size:

| Total Tokens per Condition | Recommended `min_total_count` | Reasoning |
|----------------------------|------------------------------|-----------|
| 1,000 - 10,000 | **5** | Very small dataset, need to keep more tokens |
| 10,000 - 100,000 | **10** | Default, good balance |
| 100,000 - 1,000,000 | **50** | Medium dataset, filter more noise |
| 1,000,000+ | **100-200** | Large dataset, can be more stringent |

## Quick Reference:

```python
# For small test runs (10k tokens):
results_df = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens, 
    min_total_count=5
)

# For medium runs (100k tokens):
results_df = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens, 
    min_total_count=10  # default
)

# For large runs (1M tokens):
results_df = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens, 
    min_total_count=100
)
```

## Alternative: Filter by Frequency Instead

If you want to filter by **frequency** rather than raw count:

```python
# Only include tokens that appear in at least 0.01% of samples
min_freq = 0.0001  # 0.01%
min_count = int(len(baseline_tokens) * min_freq)

results_df = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens, 
    min_total_count=min_count
)
```

## Exporting Significant Results Only

```python
# Get only statistically significant enriched tokens
significant_up = results_df_robust[
    (results_df_robust['up_significant'] == True) & 
    (results_df_robust['up_enrichment'] > 0)
].sort_values('up_enrichment', ascending=False)

significant_down = results_df_robust[
    (results_df_robust['down_significant'] == True) & 
    (results_df_robust['down_enrichment'] > 0)
].sort_values('down_enrichment', ascending=False)

print(f"Significant enriched tokens (steering up): {len(significant_up)}")
print(f"Significant enriched tokens (steering down): {len(significant_down)}")
```


# 📝 Summary: Addressing "Tokens That Don't Appear" Bias

## Yes, the original analysis WAS skewed! 

### The Problem:
Tokens appearing in one condition but not another got **artificially inflated enrichment scores** due to:
1. Tiny pseudocount (1e-10) creating extreme ratios
2. No filtering of rare tokens
3. No statistical significance testing

### Example:
```
Token "xyz": appears 3 times in steered_up, 0 times in baseline
❌ Original enrichment: +32.5 (massive!)
✅ Reality: Just random noise, not meaningful
```

### The Solution:
Use `analyze_token_distributions_robust()` which:
- ✅ Filters tokens with < `min_total_count` total appearances
- ✅ Uses proper Laplace smoothing
- ✅ Includes chi-square significance tests
- ✅ Applies Bonferroni correction for multiple testing
- ✅ Marks statistically significant results with `***`

### Key Takeaways:
1. **Always use `min_total_count` filtering** (at least 10 for most cases)
2. **Focus on tokens marked as significant (`***`)**
3. **Larger datasets allow more stringent filtering** (e.g. `min_total_count=50` for 100k–1M tokens, `min_total_count=100`–`200` for 1M+ tokens)
4. **Enrichment score + significance** = reliable results

The robust analysis ensures you're finding **real signal, not noise**! 🎯


# 🎯 Quick Reference Card: Unbiased Token Analysis

## ✅ DO THIS:
```python
# Use robust analysis with appropriate filtering
results_df = analyze_token_distributions_robust(
    baseline_tokens, steered_up_tokens, steered_down_tokens,
    min_total_count=10,  # Adjust based on dataset size
    top_n=100
)

# Focus on statistically significant results
significant_tokens = results_df[results_df['up_significant'] == True]
```

## ❌ DON'T DO THIS:
```python
# Original function without filtering - BIASED!
results_df = analyze_token_distributions(
    baseline_tokens, steered_up_tokens, steered_down_tokens
)
# ⚠️ Results dominated by rare tokens that appeared by chance
```

## 🔍 Interpreting Results:

### Enrichment Score:
- **+1.0** = 2x more frequent
- **+2.0** = 4x more frequent  
- **+3.0** = 8x more frequent
- **-1.0** = 2x less frequent (50% reduction)

### Significance Markers:
- **`***`** = Statistically significant (p < 0.01 after Bonferroni correction)
- No marker = Not statistically significant (could be noise)

### What to Report:
Focus on tokens that are **BOTH**:
1. Highly enriched (|enrichment| > 1.0)
2. Statistically significant (marked with `***`)

## 💡 Pro Tip:
For very large datasets (1M+ tokens), increase `min_total_count` to 100-200 to be more stringent and reduce false positives.


In [None]:
# Optional: Visualize the results
import matplotlib.pyplot as plt

# Rank tokens by enrichment, separately for each steering direction
top_up = results_df.nlargest(30, 'up_enrichment')
top_down = results_df.nlargest(30, 'down_enrichment')

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# One entry per panel: axis, ranked rows, enrichment column, bar colour,
# direction label used in the title, and the steering strength shown there.
panel_specs = [
    (axes[0], top_up, 'up_enrichment', 'coral', 'UP', STEERING_UP),
    (axes[1], top_down, 'down_enrichment', 'lightblue', 'DOWN', STEERING_DOWN),
]

for ax, ranked, column, bar_color, direction, strength in panel_specs:
    slots = range(len(ranked))
    ax.barh(slots, ranked[column], color=bar_color)
    ax.set_yticks(slots)
    ax.set_yticklabels([tok[:20] for tok in ranked['token']], fontsize=8)
    ax.set_xlabel('Log2 Enrichment vs Baseline', fontsize=10)
    ax.set_title(
        f'Top 30 Tokens Enriched in STEERED {direction}\n(Steering: {strength:+.1f})',
        fontsize=12,
        fontweight='bold',
    )
    ax.axvline(x=0, color='black', linestyle='--', linewidth=0.5)
    # Row 0 is the most enriched token; flip the axis so it is drawn on top.
    ax.invert_yaxis()

plt.tight_layout()
plt.show()

# Short usage reminder for the results dataframe
print(f"\nResults dataframe shape: {results_df.shape}")
print("Access via: results_df")
print("\nExample queries:")
example_queries = (
    "  - results_df.nlargest(50, 'up_enrichment')  # Most enriched when steering up",
    "  - results_df.nsmallest(50, 'up_enrichment')  # Most depleted when steering up",
    "  - results_df.nlargest(50, 'down_enrichment')  # Most enriched when steering down",
    "  - results_df[results_df['token'].str.contains('word')]  # Search for specific tokens",
)
for query_hint in example_queries:
    print(query_hint)

In [None]:
# 📈 IMPROVED: Visualize robust results with significance markers
import matplotlib.pyplot as plt


def _draw_enrichment_panel(ax, ranked, value_col, flag_col, bar_colors, title, note_color):
    """Draw one horizontal-bar panel and label every bar with its total count.

    Args:
        ax: Matplotlib axis to draw on.
        ranked: Rows to plot, one bar per row, first row drawn on top.
        value_col: Column holding the bar length (enrichment).
        flag_col: Column holding the significance flag of each row.
        bar_colors: One colour per bar.
        title: Panel title.
        note_color: Annotation colour for flagged rows (others are gray).

    Returns:
        The bar container returned by `ax.barh`.
    """
    slots = range(len(ranked))
    bars = ax.barh(slots, ranked[value_col], color=bar_colors)
    ax.set_yticks(slots)
    ax.set_yticklabels([tok[:20] for tok in ranked['token']], fontsize=8)
    ax.set_xlabel('Log2 Enrichment vs Baseline', fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.axvline(x=0, color='black', linestyle='--', linewidth=0.5)
    ax.invert_yaxis()

    # Add total count annotations just past the end of each bar
    for slot, (_, entry) in enumerate(ranked.iterrows()):
        ax.text(entry[value_col] + 0.1, slot, f"n={entry['total_count']}",
                va='center', fontsize=7, color=note_color if entry[flag_col] else 'gray')
    return bars


# Get top enriched tokens for each condition from ROBUST analysis
top_up = results_df.nlargest(30, 'up_enrichment')
top_down = results_df.nlargest(30, 'down_enrichment')

fig, axes = plt.subplots(1, 2, figsize=(18, 10))

# Plot 1: Steered Up enrichment with significance markers
colors_up = ['red' if flagged else 'coral' for flagged in top_up['up_significant']]
bars1 = _draw_enrichment_panel(
    axes[0], top_up, 'up_enrichment', 'up_significant', colors_up,
    f'Top 30 Tokens Enriched in STEERED UP\n(Steering: {STEERING_UP:+.1f})\nRed = Statistically Significant (p < 0.01)',
    'darkred',
)

# Plot 2: Steered Down enrichment with significance markers
colors_down = ['darkblue' if flagged else 'lightblue' for flagged in top_down['down_significant']]
bars2 = _draw_enrichment_panel(
    axes[1], top_down, 'down_enrichment', 'down_significant', colors_down,
    f'Top 30 Tokens Enriched in STEERED DOWN\n(Steering: {STEERING_DOWN:+.1f})\nDark Blue = Statistically Significant (p < 0.01)',
    'darkblue',
)

plt.tight_layout()
plt.savefig('steering_analysis_robust.png', dpi=150, bbox_inches='tight')
print("✓ Saved plot to 'steering_analysis_robust.png'")
plt.show()

# Print summary statistics
rule = '=' * 80
print(f"\n{rule}")
print("SUMMARY STATISTICS")
print(rule)
print(f"Total unique tokens analyzed: {len(results_df):,}")
print("\nStatistically significant enrichments:")
print(f"  Steered UP:   {results_df['up_significant'].sum():,} tokens")
print(f"  Steered DOWN: {results_df['down_significant'].sum():,} tokens")

# Ten most enriched tokens among the significant ones, per direction
sig_up = results_df.loc[results_df['up_significant']].nlargest(10, 'up_enrichment')
sig_down = results_df.loc[results_df['down_significant']].nlargest(10, 'down_enrichment')

for direction, table, value_col, p_col in (
    ('UP', sig_up, 'up_enrichment', 'up_pvalue'),
    ('DOWN', sig_down, 'down_enrichment', 'down_pvalue'),
):
    print(f"\nTop 10 significant tokens (steering {direction}):")
    for _, entry in table.iterrows():
        print(f"  '{entry['token'][:25]:25s}' | enrichment: {entry[value_col]:+.2f} | p={entry[p_col]:.2e}")

print(f"\n{rule}")

In [None]:
# Save results to file for later analysis
results_df.to_csv('steering_analysis_results.csv', index=False)
print("Results saved to 'steering_analysis_results.csv'")

# Also save the raw token lists
import json

# Build the payload BEFORE opening the output file: opening with 'w' truncates
# it immediately, so a failure while assembling the dict (e.g. an undefined
# name) would otherwise leave an empty JSON file behind.
token_dump = {
    'baseline': baseline_tokens,
    'steered_up': steered_up_tokens,
    'steered_down': steered_down_tokens,
    'config': {
        'steering_up': STEERING_UP,
        'steering_down': STEERING_DOWN,
        'target_tokens': TARGET_TOKENS,
        # NOTE(review): the batched-generation cells never assign PROMPT (they
        # draw random instructions instead), so it is recorded only if an
        # earlier cell defined it; otherwise null is written. Confirm whether
        # this field is still meaningful for batched runs.
        'prompt': globals().get('PROMPT'),
    },
}

with open('steering_analysis_tokens.json', 'w') as f:
    json.dump(token_dump, f)
print("Raw tokens saved to 'steering_analysis_tokens.json'")