Just run all these cells

# Imports

In [2]:
from nnsight import LanguageModel
from einops import einsum
import torch
import ipywidgets as widgets
from IPython.display import display
import time

# Load model

In [3]:
device = 'cuda'

In [8]:
weight_tensor = torch.load('/workspace/llm-progress-monitor/rollouts-big/beta_torch.pt').to(dtype=torch.bfloat16)
model_name = 'Qwen/Qwen3-4B'

In [9]:
model = LanguageModel(model_name, device_map=device, dtype=torch.bfloat16)

In [10]:
def get_ema_preds(log_preds, alpha=0.5):
    given_alpha = alpha
    preds_list = log_preds.exp().tolist()
    
    ema_preds = []
    cur_ema = None
    for i,pred in enumerate(preds_list):
        # Use a smooth transition from 0.5 to given_alpha, reaching given_alpha at 200 tokens
        alpha = given_alpha
        if cur_ema is None:
            cur_ema = pred
        else:
            cur_ema = alpha*(cur_ema-1) + (1-alpha)*pred #-1 because we have stepped one token
        ema_preds.append(cur_ema)
    return ema_preds
def get_log_preds(activation, weight_tensor):
    return einsum(activation, weight_tensor, 'seq d_model, d_model -> seq')

# Vibe coded UIs

In [None]:
# Create input text box for prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Create submit button
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Create progress bar widget
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Create percentage label
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Create horizontal box for progress bar and percentage
progress_row = widgets.HBox([progress_bar, percentage_label])

# Create text widget for token display
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Create container for the widgets
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    submit_button,
    progress_row,
    token_display
])

# Display the widget
display(progress_container)

def on_submit_clicked(b):
    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"
    
    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)
            preds = get_log_preds(activations, weight_tensor).tolist()
            if len(preds) > 1:
                pass
            else:
                cur_log_preds+=preds
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated+=1
                pred_tokens_remaining = ema_preds[-1]
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)
                
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=True)
                print(token_str)
                generated_tokens.append(token_str)
                
                # Update progress bar
                progress_bar.value = pred_percent_through * 100
                
                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"
                
                # Update token display with all generated tokens
                tokens_html = " ".join([f"<span style='background-color: #e6f3ff; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token}</span>" for token in generated_tokens])
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{token_str}' | <b>Predicted:</b> {pred_percent_through*100:.1f}% through"

# Connect button click to function
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]




Okay
,
 the
 user
 is
 asking
 for
 the
 capital
 of
 France
 and
 some
 fun
 things
 to
 do
 there
.
 Let
 me
 start
 by
 recalling
 that
 capital
.
 I
 know
 that
 France
 is
 capital
 is
 Paris
.
 But
's
 pretty
 straightforward
.


Now
,
 for
 the
 fun
 things
 to
 do
 in
 Paris
.
 I
 need
 to
 think
 of
 popular
 most
 attractions
.
 The
 E
iff
el
 Tower
 is
 a
 a
 must
-
visit
.
 Then
 there
's
 the
 Lou
vre
 Museum
,
 which
 houses
 the
 Mona
 Lisa
.
 The
 Notre
-D
ame
 Cathedral
 is
 another
 iconic
 landmark
.
 Maybe
 mention
 the
 Se
ine
 River
 and
 a
 boat
 tour
.
 The
 Mont
mart
re
 area
 with
 the
 Sac
ré
-C
œur
 Basil
ica
 is
 popular
.
 The
,
 the
 Ch
ée
 d
'
Or
say
 for
 Imp
.
 The
 Ch
amps
-
É
lys
ées
 and
 the
 E
iff
el
 Tower
 are
 often
 together
.
 Maybe
 include
 some
 local
 spots
 like
 like
 a
is
series
 or
 French
 cuisine
.
 Also
 in
 areas
 Mar
 Quarter
 or
 the
ais
.
.
 The
 Place
 de
 la
 Con
cor
de
 is
 a
 notable
 spot
.
 Also
 Tu
il
eries
 Garden
 is

In [None]:
from IPython.display import clear_output
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Create input text box for prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Create input text box for EMA factor
ema_input = widgets.FloatText(
    value=0.9,
    description='EMA Factor:',
    min=0.0,
    max=1.0,
    step=0.01,
    tooltip='Exponential Moving Average factor (0.0 to 1.0)',
    layout=widgets.Layout(width='200px')
)

# Create submit button
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Create progress bar widget
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Create percentage label
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Create horizontal box for progress bar and percentage
progress_row = widgets.HBox([progress_bar, percentage_label])

# Create text widget for token display
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Create graph widget for prediction history
graph_widget = go.FigureWidget()

# Create container for the widgets
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    ema_input,
    submit_button,
    progress_row,
    token_display,
    widgets.HTML("<h4>Prediction History</h4>"),
    graph_widget
])

# Display the widget
display(progress_container)

# Global variables to store generation data
raw_log_preds = []
stored_generated_tokens = []
stored_n_tokens = 0

def update_graph_with_ema(ema_factor):
    """Update the graph and displays with a new EMA factor."""
    global raw_log_preds, stored_generated_tokens, stored_n_tokens
    
    if len(raw_log_preds) == 0:
        return
    
    # Recalculate predictions with new EMA
    ema_preds = get_ema_preds(torch.tensor(raw_log_preds), alpha=ema_factor)
    
    prediction_history = []
    token_counts = []
    
    for i, ema_pred in enumerate(ema_preds):
        n_tokens_generated = i + 1
        pred_tokens_remaining = ema_pred
        predicted_total_tokens = n_tokens_generated + pred_tokens_remaining
        prediction_history.append(predicted_total_tokens)
        token_counts.append(n_tokens_generated)
    
    # Update the graph
    if len(prediction_history) > 1:
        # Create hover text with context tokens
        hover_texts = []
        for i in range(len(stored_generated_tokens)):
            # Get 5 tokens before and after (if available)
            start_idx = max(0, i - 5)
            end_idx = min(len(stored_generated_tokens), i + 6)
            
            context_tokens = []
            for j in range(start_idx, end_idx):
                token_clean = stored_generated_tokens[j].replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>').replace('&quot;', '"').replace('&#x27;', "'")
                if j == i:
                    context_tokens.append(f"<b>{token_clean}</b>")
                else:
                    context_tokens.append(token_clean)
            
            context_str = " ".join(context_tokens)
            hover_text = f"Token {i+1}: {context_str}<br>Predicted Total: {prediction_history[i]:.0f}"
            hover_texts.append(hover_text)
        
        # Update graph with new data
        graph_widget.data = []
        graph_widget.add_trace(go.Scatter(
            x=token_counts,
            y=prediction_history,
            mode='lines+markers',
            name='Predicted Total Tokens',
            line=dict(color='blue', width=2),
            marker=dict(size=6),
            hovertemplate='%{customdata}<extra></extra>',
            customdata=hover_texts
        ))
        
        # Add actual final point
        graph_widget.add_trace(go.Scatter(
            x=[stored_n_tokens],
            y=[stored_n_tokens],
            mode='markers',
            name='Actual Final',
            marker=dict(size=10, color='green', symbol='star'),
            hovertemplate=f'Actual completion: {stored_n_tokens} tokens<extra></extra>'
        ))
        
        # Update layout
        graph_widget.update_layout(
            title=f'Token Prediction vs Reality (EMA={ema_factor:.2f})',
            xaxis_title='Token Number',
            yaxis_title='Predicted Total Tokens',
            height=400,
            showlegend=True,
            hovermode='closest'
        )
        
        # Update token display with final statistics
        final_pred = prediction_history[-1]
        accuracy = (stored_n_tokens/final_pred)*100
        token_display.value = f"<b>Generation complete!</b><br><b>Total tokens:</b> {stored_n_tokens}<br><b>Final prediction:</b> {final_pred:.0f} tokens<br><b>Accuracy:</b> {accuracy:.1f}%<br><b>Current EMA:</b> {ema_factor:.2f}"

def on_ema_changed(change):
    """Handle EMA input changes."""
    update_graph_with_ema(change['new'])

# Connect EMA input to update function
ema_input.observe(on_ema_changed, names='value')

def on_submit_clicked(b):
    global raw_log_preds, stored_generated_tokens, stored_n_tokens
    
    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"
    
    # Clear the graph
    graph_widget.data = []
    
    # Reset global storage
    raw_log_preds = []
    stored_generated_tokens = []
    stored_n_tokens = 0
    
    # Initialize lists to track predictions over time
    prediction_history = []
    token_counts = []
    
    # Get prompt from input
    prompt = prompt_input.value
    # Get EMA factor from input
    ema_factor = ema_input.value
    
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []

    with model.generate(prompt, max_new_tokens=32768, do_sample=True) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            print(activations.dtype)
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)
            
            # Save predictions within nnsight context
            preds_saved = get_log_preds(activations, weight_tensor).save()
            token_saved = model.lm_head.output.argmax(dim=-1).save()
            
            preds = preds_saved.tolist()
            if len(preds) > 1:
                pass
            else:
                cur_log_preds+=preds
                raw_log_preds.append(preds[0])  # Store raw predictions globally
                
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds), alpha=ema_factor)
                n_tokens_generated+=1
                pred_tokens_remaining = ema_preds[-1]
                predicted_total_tokens = n_tokens_generated + pred_tokens_remaining
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)
                
                # Store prediction data for highlighting
                prediction_history.append(predicted_total_tokens)
                token_counts.append(n_tokens_generated)
                
                token = token_saved.tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=False)
                # Escape HTML entities in token string for safe display
                token_str_escaped = token_str.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;').replace("'", '&#x27;')
                print(token_str)
                generated_tokens.append(token_str_escaped)
                stored_generated_tokens.append(token_str_escaped)  # Store globally
                
                # Update progress bar
                progress_bar.value = pred_percent_through * 100
                
                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"
                
                # Create highlighted token display
                highlighted_tokens = []
                for i, token in enumerate(generated_tokens):
                    # Calculate percentage change if we have history
                    highlight_color = "#e6f3ff"  # Default light blue
                    
                    if i > 0 and i < len(prediction_history):
                        change = prediction_history[i] - prediction_history[i-1]
                        percent_change = abs(change / prediction_history[i-1]) * 100 if prediction_history[i-1] != 0 else 0
                        
                        # Color code based on prediction change magnitude
                        if percent_change > 15:
                            if change > 0:
                                highlight_color = "#ffcccc"  # Light red for large increases
                            else:
                                highlight_color = "#ccffcc"  # Light green for large decreases
                        elif percent_change > 5:
                            if change > 0:
                                highlight_color = "#ffe6cc"  # Light orange for medium increases
                            else:
                                highlight_color = "#e6ffcc"  # Light yellow-green for medium decreases
                    
                    highlighted_tokens.append(f"<span style='background-color: {highlight_color}; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token}</span>")
                
                tokens_html = " ".join(highlighted_tokens)
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{token_str_escaped}' | <b>Predicted Total:</b> {predicted_total_tokens:.0f} tokens | <b>Progress:</b> {pred_percent_through*100:.1f}%<br><br><small><b>Color coding:</b> <span style='background-color: #e6f3ff; padding: 2px;'>Normal</span> <span style='background-color: #ffe6cc; padding: 2px;'>Med. increase</span> <span style='background-color: #ffcccc; padding: 2px;'>Large increase</span> <span style='background-color: #e6ffcc; padding: 2px;'>Med. decrease</span> <span style='background-color: #ccffcc; padding: 2px;'>Large decrease</span></small>"
                
                # Update the graph with current predictions
                if len(prediction_history) > 1:
                    # Create hover text with context tokens
                    hover_texts = []
                    for i in range(len(generated_tokens)):
                        # Get 5 tokens before and after (if available)
                        start_idx = max(0, i - 5)
                        end_idx = min(len(generated_tokens), i + 6)
                        
                        context_tokens = []
                        for j in range(start_idx, end_idx):
                            token_clean = generated_tokens[j].replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>').replace('&quot;', '"').replace('&#x27;', "'")
                            if j == i:
                                context_tokens.append(f"<b>{token_clean}</b>")
                            else:
                                context_tokens.append(token_clean)
                        
                        context_str = " ".join(context_tokens)
                        hover_text = f"Token {i+1}: {context_str}<br>Predicted Total: {prediction_history[i]:.0f}"
                        hover_texts.append(hover_text)
                    
                    # Update graph with new data
                    graph_widget.data = []
                    graph_widget.add_trace(go.Scatter(
                        x=token_counts,
                        y=prediction_history,
                        mode='lines+markers',
                        name='Predicted Total Tokens',
                        line=dict(color='blue', width=2),
                        marker=dict(size=6),
                        hovertemplate='%{customdata}<extra></extra>',
                        customdata=hover_texts
                    ))
                    
                    # Add a horizontal line showing actual tokens generated so far
                    graph_widget.add_trace(go.Scatter(
                        x=[token_counts[0], token_counts[-1]],
                        y=[n_tokens_generated, n_tokens_generated],
                        mode='lines',
                        name='Current Progress',
                        line=dict(color='red', width=2, dash='dash'),
                        hovertemplate='Current tokens generated: %{y}<extra></extra>'
                    ))
                    
                    # Update layout
                    graph_widget.update_layout(
                        title='Token Prediction Over Time',
                        xaxis_title='Token Number',
                        yaxis_title='Predicted Total Tokens',
                        height=400,
                        showlegend=True,
                        hovermode='closest'
                    )
    
    # Store final token count globally
    stored_n_tokens = n_tokens_generated
    
    # After generation is complete, display prediction history
    print("\n" + "="*80)
    print("FINAL PREDICTION HISTORY")
    print("="*80)
    
    print(f"\nGenerated {n_tokens_generated} tokens total")
    print(f"Final prediction was {predicted_total_tokens:.0f} tokens")
    print(f"Accuracy: {(n_tokens_generated/predicted_total_tokens)*100:.1f}%")
    print(f"EMA factor used: {ema_factor}")
    
    print("\nFull Prediction History:")
    print("Token# | Predicted Total | Change | % Change | Token")
    print("-" * 70)
    
    for i, (count, pred_total, token) in enumerate(zip(token_counts, prediction_history, generated_tokens)):
        if i == 0:
            change = 0
            percent_change = 0
        else:
            change = pred_total - prediction_history[i-1]
            percent_change = abs(change / prediction_history[i-1]) * 100 if prediction_history[i-1] != 0 else 0
        
        # Display token in a safe way for console output (remove escaping for print)
        token_for_print = token.replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>').replace('&quot;', '"').replace('&#x27;', "'")
        
        # Highlight tokens with large percentage changes (>10%)
        if percent_change > 10:
            if change > 0:
                # Large increase - bold green
                token_display_str = f"\033[1m\033[92m{token_for_print}\033[0m"
                change_str = f"\033[1m\033[92m+{change:.0f}\033[0m"
            else:
                # Large decrease - bold red  
                token_display_str = f"\033[1m\033[91m{token_for_print}\033[0m"
                change_str = f"\033[1m\033[91m{change:.0f}\033[0m"
        else:
            token_display_str = token_for_print
            if change > 0:
                change_str = f"+{change:.0f}"
            else:
                change_str = f"{change:.0f}"
        
        print(f"{count:5d}  | {pred_total:13.0f}   | {change_str:8s} | {percent_change:6.1f}%  | {token_display_str}")
    
    # Update graph with final results using current EMA
    update_graph_with_ema(ema_factor)

# Connect button click to function
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

torch.bfloat16
torch.bfloat16


torch.bfloat16
Okay
torch.bfloat16
,
torch.bfloat16
 the
torch.bfloat16
 user
torch.bfloat16
 is
torch.bfloat16
 asking
torch.bfloat16
 for
torch.bfloat16
 the
torch.bfloat16
 capital
torch.bfloat16
 of
torch.bfloat16
 France
torch.bfloat16
 and
torch.bfloat16
 some
torch.bfloat16
 fun
torch.bfloat16
 things
torch.bfloat16
 to
torch.bfloat16
 do
torch.bfloat16
 there
torch.bfloat16
.
torch.bfloat16
 Let
torch.bfloat16
 me
torch.bfloat16
 start
torch.bfloat16
 by
torch.bfloat16
 recalling
torch.bfloat16
 that
torch.bfloat16
 the
torch.bfloat16
 capital
torch.bfloat16
 of
torch.bfloat16
 France
torch.bfloat16
 is
torch.bfloat16
 Paris
torch.bfloat16
.
torch.bfloat16
 I
torch.bfloat16
's
torch.bfloat16
 pretty
torch.bfloat16
.


torch.bfloat16
Now
torch.bfloat16
,
torch.bfloat16
 for
torch.bfloat16
 the
torch.bfloat16
 fun
torch.bfloat16
 things
torch.bfloat16
 to
torch.bfloat16
 do
torch.bfloat16
 in
torch.bfloat16
 Paris
torch.bfloat16
.
torch.bfloat16
 I