Just run all these cells

In [1]:
from nnsight import LanguageModel
from einops import einsum
import torch
import ipywidgets as widgets
from IPython.display import display
import time

In [2]:
device = 'cuda'

In [3]:
weight_tensor = torch.load('/workspace/llm-progress-monitor/qwen3_4b_weight_tensor.pt')
model_name = 'Qwen/Qwen3-4B'

In [30]:
model = LanguageModel(model_name, device_map=device, dtype=torch.bfloat16)

In [31]:
def get_ema_preds(log_preds, alpha=0.99):
    given_alpha = alpha
    preds_list = log_preds.exp().tolist()
    
    ema_preds = []
    cur_ema = None
    for i,pred in enumerate(preds_list):
        if pred < 10:
            alpha = 0.5
        else:
            alpha = given_alpha
        if cur_ema is None:
            cur_ema = pred
        else:
            cur_ema = alpha*(cur_ema-1) + (1-alpha)*pred #-1 because we have stepped one token
        ema_preds.append(cur_ema)
    return ema_preds

In [32]:
def get_log_preds(activation, weight_tensor):
    print(activation.shape, weight_tensor.shape, activation.dtype, weight_tensor.dtype)

    return einsum(
        einsum(activation, weight_tensor, 'seq d_model, pca d_model -> seq pca').softmax(dim=1),
        0.5+torch.arange(weight_tensor.shape[0]).to(device, dtype=torch.bfloat16),
        'seq pca, pca -> seq'
    )

In [None]:


# Create input text box for prompt
prompt_input = widgets.Textarea(
    value="",
    placeholder='Enter your prompt here...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

# Create submit button
submit_button = widgets.Button(
    description='Generate Text',
    button_style='success',
    tooltip='Click to start text generation',
    icon='play'
)

# Create progress bar widget
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    style={'bar_color': '#20B2AA'},
    orientation='horizontal'
)

# Create percentage label
percentage_label = widgets.HTML(
    value="<b>0.0%</b>",
    description='',
)

# Create horizontal box for progress bar and percentage
progress_row = widgets.HBox([progress_bar, percentage_label])

# Create text widget for token display
token_display = widgets.HTML(
    value="<b>Generated tokens will appear here...</b>",
    placeholder='',
    description='',
)

# Create container for the widgets
progress_container = widgets.VBox([
    widgets.HTML("<h3>Text Generation Progress</h3>"),
    prompt_input,
    submit_button,
    progress_row,
    token_display
])

# Display the widget
display(progress_container)

def on_submit_clicked(b):
    # Reset progress
    progress_bar.value = 0
    percentage_label.value = "<b>0.0%</b>"
    token_display.value = "<b>Generating...</b>"
    
    # Get prompt from input
    prompt = prompt_input.value
    # Apply chat template
    prompt = model.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    cur_log_preds = []
    n_tokens_generated = 0
    generated_tokens = []

    with model.generate(prompt, max_new_tokens=1000, do_sample=False) as tracer:
        # Call .all() to apply intervention to each new token
        with tracer.all():
            activations = model.model.layers[15].output[0]
            print(activations.dtype)
            if len(activations.shape) == 1:
                activations = activations.unsqueeze(0)
            # Pad the last dimension to 2560
            if activations.shape[-1] < 2560:
                pad_size = 2560 - activations.shape[-1]
                activations = torch.nn.functional.pad(activations, (0, pad_size))
            print(activations.shape, len(activations.shape))
            preds = get_log_preds(activations, weight_tensor).tolist()
            if len(preds) > 1:
                pass
            else:
                cur_log_preds+=preds
                ema_preds = get_ema_preds(torch.tensor(cur_log_preds))
                n_tokens_generated+=1
                pred_tokens_remaining = ema_preds[-1]
                pred_percent_through = n_tokens_generated/(n_tokens_generated + pred_tokens_remaining)
                
                token = model.lm_head.output.argmax(dim=-1).tolist()
                token_str = model.tokenizer.decode(token[0][0], skip_special_tokens=True)
                generated_tokens.append(token_str)
                
                # Update progress bar
                progress_bar.value = pred_percent_through * 100
                
                # Update percentage label
                percentage_label.value = f"<b>{pred_percent_through*100:.1f}%</b>"
                
                # Update token display with all generated tokens
                tokens_html = " ".join([f"<span style='background-color: #e6f3ff; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token}</span>" for token in generated_tokens])
                token_display.value = f"<b>Generated tokens:</b><br>{tokens_html}<br><br><b>Latest:</b> '{token_str}' | <b>Predicted:</b> {pred_percent_through*100:.1f}% through"

# Connect button click to function
submit_button.on_click(on_submit_clicked)


VBox(children=(HTML(value='<h3>Text Generation Progress</h3>'), Textarea(value='', description='Prompt:', layo…

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

torch.bfloat16
torch.Size([10, 2560]) 2
torch.Size([10, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16


  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]


torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat16 torch.bfloat16
torch.bfloat16
torch.Size([1, 2560]) 2
torch.Size([1, 2560]) torch.Size([8, 2560]) torch.bfloat1