Run all of the cells below in order: imports and model loading first, then the helper functions, then the UI cells.

# Imports

In [1]:
from nnsight import LanguageModel
from einops import einsum
import torch
import ipywidgets as widgets
from IPython.display import display
import time
import plotly.graph_objects as go
import numpy as np

# Load model

In [2]:
device = 'cuda'

In [3]:
weight_tensor = torch.load('/root/llm-progress-monitor/models/probe_weights.pt')
model_name = 'Qwen/Qwen3-4B'

In [4]:
model = LanguageModel(model_name, device_map=device, dtype=torch.bfloat16)

In [5]:
def get_ema_preds(log_preds, alpha=0.5):
    """
    Smooth per-token "remaining tokens" predictions with a countdown-aware EMA.

    log_preds: tensor of log(remaining tokens) predictions, one entry per token
    alpha: weight kept on the previous estimate; (1 - alpha) goes to the new prediction
    returns: list of floats (same length as log_preds) in token space, not log space

    The first prediction seeds the average. Every later step first decrements the
    running estimate by one (one more token has been emitted since the previous
    estimate) and then blends it with the new raw prediction. The result is not
    clamped, so it can drop below zero if the raw predictions stay small.
    """
    ema_preds = []
    cur_ema = None
    # exp() converts log-space probe outputs back to a token count.
    for pred in log_preds.exp().tolist():
        if cur_ema is None:
            cur_ema = pred
        else:
            # -1 because we have stepped one token since the last estimate.
            cur_ema = alpha * (cur_ema - 1) + (1 - alpha) * pred
        ema_preds.append(cur_ema)
    return ema_preds

In [6]:
def get_log_preds(activation, weight_tensor):
    """
    Run the linear probe over a sequence of activations.

    activation: [seq, d_model] tensor of activations
    weight_tensor: probe weights, either [d_model] or a 2-D [d_model, 1] / [1, d_model]
    returns: [seq] tensor holding the probe's log(remaining tokens) prediction per position
    """
    # A 2-D probe is collapsed to a plain vector so the contraction below
    # always sees a single d_model axis.
    is_matrix = len(weight_tensor.shape) == 2
    probe = weight_tensor.flatten() if is_matrix else weight_tensor

    # Dot product of every position's activation with the probe vector.
    log_remaining = einsum(activation, probe, 'seq d_model, d_model -> seq')
    return log_remaining

# Vibe coded UIs

In [7]:
# ---- Progress-bar UI: build the widgets, stack them, and render them ----

# Multi-line text area where the user types the prompt.
prompt_input = widgets.Textarea(
    value="",
    description='Prompt:',
    placeholder='Enter your prompt here...',
    layout=widgets.Layout(width='100%', height='80px'),
)

# Button that starts a generation run.
submit_button = widgets.Button(
    description='Generate Text',
    icon='play',
    button_style='success',
    tooltip='Click to start text generation',
)

# Horizontal bar on a 0-100 scale, updated while tokens are generated.
progress_bar = widgets.FloatProgress(
    description='Progress:',
    value=0,
    min=0,
    max=100,
    orientation='horizontal',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
)

# Text readout shown to the right of the bar.
percentage_label = widgets.HTML(value="<b>0.0%</b>", description='')

# Bar and readout share one row.
progress_row = widgets.HBox([progress_bar, percentage_label])

# HTML area that lists the tokens produced so far.
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Vertical stack: heading, prompt, button, progress row, token list.
progress_container = widgets.VBox(
    [
        widgets.HTML("<h3>Text Generation Progress</h3>"),
        prompt_input,
        submit_button,
        progress_row,
        token_display,
    ]
)

# Render the assembled UI in the notebook output.
display(progress_container)

# A later cell rebinds the global name `token_display` to its own widget. Keep a
# private handle on THIS cell's widget so the callback keeps updating the right UI
# after all cells have been run.
_progress_token_display = token_display


def on_submit_clicked(b):
    """
    Button callback: generate a completion for the prompt and stream progress to the UI.

    b: the clicked button (unused; required by the ipywidgets on_click signature)

    For every newly generated token the layer-15 activations are run through the
    linear probe, the predictions are EMA-smoothed, and the progress bar, label and
    token list widgets are refreshed.
    """
    import html  # stdlib; imported locally so this cell stays self-contained

    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    _progress_token_display.value = "<b>Generating...</b>"

    # Get prompt from input and apply the chat template
    prompt = prompt_input.value
    prompt = model.tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # State is kept in lists that are mutated in place: rebinding a plain local
    # (e.g. `count += 1`) inside tracer.all() is not relied upon to persist
    # between generation steps.
    cur_log_preds = []
    token_spans = []

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)

            # Get log predictions for current token position
            log_preds = get_log_preds(activations, weight_tensor)

            # Exactly one position means a single newly generated token; the
            # multi-position pass over the initial prompt is skipped.
            if log_preds.shape[0] == 1:
                cur_log_pred = log_preds[-1].item()
                cur_log_preds.append(cur_log_pred)

                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated = len(cur_log_preds)
                pred_tokens_remaining = ema_preds[-1]

                # The smoothed estimate can reach zero or go negative; clamp it for
                # the ratio so the percentage stays in (0, 1] and never divides by 0.
                remaining_for_ratio = max(pred_tokens_remaining, 0.0)
                pred_percent_through = n_tokens_generated / (n_tokens_generated + remaining_for_ratio)

                # NOTE(review): generation uses do_sample=True, but the token shown
                # here is the argmax of the lm_head logits, which may differ from
                # the token that was actually sampled - TODO confirm and display
                # the sampled id instead.
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0])

                # Escape the token text: tokens such as "<think>" would otherwise be
                # parsed as HTML tags and disappear from the display.
                safe_token = html.escape(token_str)
                token_spans.append(
                    f"<span style='background-color: #e6f3ff; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{safe_token}</span>"
                )

                # NOTE(review): the bar and label show the predicted number of
                # remaining tokens (on the 0-100 bar), not a percentage; the
                # percentage is only shown in the token display line below.
                progress_bar.value = pred_tokens_remaining
                percentage_label.value = f"<b>{pred_tokens_remaining}</b>"

                # Update token display with all generated tokens
                tokens_html = " ".join(token_spans)
                _progress_token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{safe_token}' | <b>Predicted:</b> {pred_percent_through*100:.1f}% through"

# Connect button click to function
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

In [None]:
# Second UI: prompt box, generate button, live plot of the predicted remaining
# tokens, and a colour-coded token list.
import ipywidgets as widgets
from IPython.display import display, clear_output

# Prompt entry area.
prompt_widget = widgets.Textarea(
    value='',
    description='Prompt:',
    placeholder='Enter your prompt here...',
    layout=widgets.Layout(width='80%', height='100px'),
)

# Button that kicks off generation.
generate_button = widgets.Button(
    description='Generate',
    icon='play',
    button_style='success',
    tooltip='Click to generate',
)

# Live figure. Trace 0 is the prediction curve; trace 1 is a single star marker
# reserved for the EOS position. Both start empty.
fig = go.FigureWidget()
fig.add_scatter(x=[], y=[], mode='lines+markers', showlegend=False)
fig.add_scatter(
    x=[],
    y=[],
    mode='markers',
    marker=dict(color='green', size=15, symbol='star'),
    showlegend=False,
)
fig.update_layout(
    title='Token Generation Progress',
    xaxis_title='Tokens Generated',
    yaxis_title='Predicted Remaining Tokens',
    height=400,
    showlegend=False,
)

# HTML area for the generated tokens.
# NOTE: this rebinds the global name `token_display` that the earlier UI cell
# also assigns.
token_display = widgets.HTML(
    value='<div style="font-size: 14px; line-height: 1.8;"><b>Generated tokens will appear here...</b></div>',
    layout=widgets.Layout(width='80%', min_height='100px'),
)

# Stack everything vertically and render it.
display(widgets.VBox([prompt_widget, generate_button, fig, token_display]))

def get_color_for_change(change):
    """
    Map a change in predicted remaining tokens to a CSS colour string.

    change: difference from the previous prediction, or None for the first token
    returns: '#e6f3ff' when change is None, otherwise an 'rgb(r,g,b)' string

    The colour fades from light blue (230, 243, 255) towards red for positive
    changes (prediction went up) and towards green for zero or negative changes
    (prediction went down). The change is clamped to [-5, 5] before mapping.
    """
    if change is None:
        # No previous prediction to compare against: neutral light blue.
        return '#e6f3ff'

    # Limit the change to the range covered by the colour ramp.
    clamped = max(-5, min(5, change))
    strength = abs(clamped) / 5.0  # 0.0 (no change) .. 1.0 (saturated)

    # Per-channel offsets applied at full strength, starting from light blue.
    if clamped > 0:
        d_red, d_green, d_blue = 25, -243, -150   # fades to (255, 0, 105)
    else:
        d_red, d_green, d_blue = -130, -15, -155  # fades to (100, 228, 100)

    r = int(230 + d_red * strength)
    g = int(243 + d_green * strength)
    b = int(255 + d_blue * strength)
    return f'rgb({r},{g},{b})'

def on_generate_click(b):
    """
    Button callback: generate up to 500 tokens and plot the predicted remaining tokens.

    b: the clicked button (unused; required by the ipywidgets on_click signature)

    For every newly generated token the layer-15 activations are run through the
    linear probe and EMA-smoothed; the curve in `fig` is extended, the EOS position
    is starred, and the token is appended to the display, coloured by how the
    prediction moved relative to the previous token.
    """
    import html  # stdlib; imported locally so this cell stays self-contained

    prompt = prompt_widget.value
    if not prompt:
        print("Please enter a prompt!")
        return

    # Apply chat template
    formatted_prompt = model.tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True
    )

    # Reset graph (trace 0: prediction curve, trace 1: EOS star)
    fig.data[0].x = []
    fig.data[0].y = []
    fig.data[1].x = []
    fig.data[1].y = []

    # Reset token display
    token_display.value = '<div style="font-size: 14px; line-height: 1.8;"></div>'

    # Track predictions - IMPORTANT: these must be lists to persist across
    # tracer.all() calls, so all per-step state is derived from them rather than
    # from rebinding plain locals.
    cur_log_preds = []
    token_spans = []  # rendered once per token and reused on later steps
    x_vals = []
    y_vals = []

    with model.generate(formatted_prompt, max_new_tokens=500, do_sample=True) as tracer:
        # Apply intervention at each generation step
        with tracer.all():
            # Get activations from layer 15
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)

            # Get log predictions for current token position
            log_preds = get_log_preds(activations, weight_tensor)

            # Only process single new tokens (not the initial prompt)
            if log_preds.shape[0] == 1:
                cur_log_pred = log_preds[-1].item()
                cur_log_preds.append(cur_log_pred)

                # Apply EMA smoothing to predictions
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated = len(cur_log_preds)
                pred_tokens_remaining = ema_preds[-1]

                # Change relative to the previous plotted prediction, read back
                # from y_vals (None for the first token).
                change = pred_tokens_remaining - y_vals[-1] if y_vals else None
                color = get_color_for_change(change)

                # NOTE(review): generation uses do_sample=True, but the token read
                # here is the argmax of the lm_head logits, which may differ from
                # the token that was actually sampled (this also affects the EOS
                # check below) - TODO confirm and use the sampled id instead.
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0])

                # Update graph data
                x_vals.append(n_tokens_generated)
                y_vals.append(pred_tokens_remaining)
                fig.data[0].x = x_vals
                fig.data[0].y = y_vals

                # Check if this is the EOS token and mark it with a star
                if token[0][0] == model.tokenizer.eos_token_id:
                    fig.data[1].x = [n_tokens_generated]
                    fig.data[1].y = [pred_tokens_remaining]

                # Escape the token text: tokens such as "<think>" would otherwise be
                # parsed as HTML tags and disappear from the display.
                token_spans.append(
                    f'<span style="display: inline-block; background-color: {color}; padding: 4px 8px; '
                    f'margin: 2px; border-radius: 4px; border: 1px solid #b3d9ff; '
                    f'font-family: monospace; white-space: pre;">{html.escape(token_str)}</span>'
                )

                # Update token display with color-coded tokens
                tokens_html = ''.join(token_spans)
                token_display.value = f'<div style="font-size: 14px; line-height: 1.8;">{tokens_html}</div>'

# Connect button to function
generate_button.on_click(on_generate_click)


VBox(children=(Textarea(value='', description='Prompt:', layout=Layout(height='100px', width='80%'), placehold…