# Interactive 3D Trajectory Visualization

This notebook creates interactive Plotly 3D visualizations of single trajectories with user/model coloring.

In [1]:
import torch as th
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sys
sys.path.append('..')

from src.configs import *
from exp.geometry_population import GeometryPopulationConfig, is_user
from src.model_utils import load_nnsight_model

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
def _format_hover_snippet(text, width=50):
    """
    Escape ``text`` for a Plotly hover label and hard-wrap it with ``<br>``.

    The text is cut into pieces of at most ``width`` visible characters
    *before* it is escaped, so an entity such as ``&lt;`` can never be split
    in two by an inserted line break.

    Args:
        text: Raw (unescaped) string to display.
        width: Maximum number of visible characters per hover line (must be > 0).

    Returns:
        The escaped string with ``<br>`` between consecutive pieces.
    """
    # Show newlines as a literal "\n" so the wrapping below controls all line breaks.
    visible = text.replace("\n", "\\n")
    pieces = (visible[k:k + width] for k in range(0, len(visible), width))
    # "&" has to be escaped first; otherwise the "&" produced for "<"/">" would be escaped again.
    return "<br>".join(
        piece.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        for piece in pieces
    )


def plot_interactive_trajectory(
    results,
    tokens_BP,
    tokenizer,
    act_type="temporal/pred_codes",
    sequence_idx=0,
    title=None,
    end_color="#AA1010",
    hover_wrap_width=50,
):
    """
    Create an interactive Plotly 3D visualization of a single trajectory.

    Each point of the trajectory is drawn as a marker (black for user tokens,
    ``end_color`` for model tokens) on top of a grey connecting line. Hovering
    over a marker shows its position, role, token id and the decoded text from
    the beginning of the sequence up to that position.

    Args:
        results: Dictionary mapping filenames to result dictionaries. Each result
            must provide ``"config"`` (with ``"act_path"``, ``"start"``, ``"end"``,
            ``"num_sequences"``), ``"embeddings"`` and ``"pos_labels"``.
        tokens_BP: Token tensor of shape (B, P) for determining user/model labels
        tokenizer: Tokenizer for decoding tokens
        act_type: Activation type to visualize (e.g., 'temporal/pred_codes')
        sequence_idx: Which sequence to visualize (default: 0); must be non-negative
        title: Optional title for the plot (defaults to "<act_type> - Sequence <idx>")
        end_color: Color for model tokens (user tokens are always black)
        hover_wrap_width: Number of visible characters per line in the decoded
            hover text (default: 50)

    Returns:
        fig: Plotly figure object

    Raises:
        ValueError: If ``act_type`` is not present in ``results``, if
            ``sequence_idx`` is out of range, or if ``hover_wrap_width`` is not positive.
    """
    if hover_wrap_width <= 0:
        raise ValueError(f"hover_wrap_width must be positive, got {hover_wrap_width}")

    # Use the first result whose config matches the requested activation type
    result = next(
        (res for res in results.values() if res["config"]["act_path"] == act_type),
        None,
    )
    if result is None:
        available = [res["config"]["act_path"] for res in results.values()]
        raise ValueError(f"Activation type '{act_type}' not found. Available: {available}")

    # Get config parameters
    start = result["config"]["start"]
    end = result["config"]["end"]
    num_sequences = result["config"]["num_sequences"]

    # Compute user labels and restrict them to the window covered by the result
    is_user_sliced = is_user(tokens_BP)[:num_sequences, start:end]

    # Extract embeddings for the selected sequence
    embeddings = th.tensor(result["embeddings"])  # Shape: (B, L, 3)
    pos_labels = result["pos_labels"]  # Shape: (B, L)

    # Reject negative indices too: they would silently pick a sequence from the
    # end and produce a misleading default title.
    num_available = embeddings.shape[0]
    if not 0 <= sequence_idx < num_available:
        raise ValueError(f"sequence_idx {sequence_idx} out of range (max: {num_available-1})")

    # Get single trajectory
    traj = embeddings[sequence_idx].numpy()  # Shape: (L, 3)
    positions = pos_labels[sequence_idx]  # Shape: (L,)
    is_user_seq = is_user_sliced[sequence_idx].numpy()  # Shape: (L,)

    # Tokens for this sequence: the full prefix (for decoding) and the plotted window
    full_tokens_seq = tokens_BP[sequence_idx, :end]
    tokens_seq = tokens_BP[sequence_idx, start:end]

    # Black for user tokens, end_color for model tokens
    colors = ['#000000' if is_u else end_color for is_u in is_user_seq]

    # Create hover text with the decoded sequence up to each point
    last_step = len(positions) - 1
    hover_texts = []
    for step, (pos, tok, is_u) in enumerate(zip(positions, tokens_seq, is_user_seq)):
        role = "User" if is_u else "Model"

        # Decode from the beginning of the sequence up to and including position `pos`
        decoded_seq = tokenizer.decode(full_tokens_seq[:int(pos)+1], skip_special_tokens=False)
        decoded_display = _format_hover_snippet(decoded_seq, hover_wrap_width)

        hover_texts.append(
            f"<b>Position:</b> {int(pos)}<br>"
            f"<b>Role:</b> {role}<br>"
            f"<b>Token ID:</b> {int(tok)}<br>"
            f"<b>Step:</b> {step}/{last_step}<br>"
            f"<br><b>Decoded up to here:</b><br>"
            f"<span style='font-family: monospace; font-size: 10px;'>{decoded_display}</span>"
        )

    # Create the figure
    fig = go.Figure()

    # Add line trace (no hover: the markers below carry the hover text)
    fig.add_trace(go.Scatter3d(
        x=traj[:, 0],
        y=traj[:, 1],
        z=traj[:, 2],
        mode='lines',
        line=dict(
            color='rgba(150, 150, 150, 0.4)',
            width=4
        ),
        hoverinfo='skip',
        showlegend=False
    ))

    # Add scatter trace for points
    fig.add_trace(go.Scatter3d(
        x=traj[:, 0],
        y=traj[:, 1],
        z=traj[:, 2],
        mode='markers',
        marker=dict(
            size=8,
            color=colors,
            line=dict(
                color='white',
                width=0.5
            )
        ),
        text=hover_texts,
        hovertemplate='%{text}<extra></extra>',
        showlegend=False
    ))

    # Update layout
    if title is None:
        title = f"{act_type} - Sequence {sequence_idx}"

    # All three axes share the same look: no tick labels, faint grid, no axis title
    axis_style = dict(showticklabels=False, showgrid=True, gridcolor='rgba(200,200,200,0.3)', title='')
    fig.update_layout(
        title=dict(
            text=title,
            font=dict(size=20, family='Montserrat')
        ),
        scene=dict(
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
            bgcolor='rgba(240, 240, 240, 1)',
            aspectmode='cube'
        ),
        width=900,
        height=900,
        margin=dict(l=0, r=0, t=50, b=0),
        hovermode='closest'
    )

    return fig

## Load Results

In [50]:
# Activation types to load, grouped by the SAE configs they belong to.
# Un-comment an entry to load it as well (presumably each group pairs a list of
# SAE configs with the activation names to load for them -- confirm against
# get_gemma_act_configs).
act_path_groups = (
    ([None], [
        # "activations",
    ]),
    ([GEMMA2_RELU_SAE_CFG, GEMMA2_TOPK_SAE_CFG, GEMMA2_BATCHTOPK_SAE_CFG], [
        # "codes",
    ]),
    ([GEMMA2_TEMPORAL_SAE_CFG], [
        # "novel_codes",
        "pred_codes",
    ]),
)

configs = get_gemma_act_configs(
    cfg_class=GeometryPopulationConfig,
    start=10,
    end=300,
    num_sequences=100,
    centering=False,
    normalize=False,
    # Artifacts
    env=ENV_CFG,
    data=CHAT_DS_CFG,
    llm=IT_GEMMA2_LLM_CFG,
    sae=None,
    act_paths=act_path_groups,
)

# The first config supplies the directories and the model/dataset identity below.
cfg = configs[0]

# Saved results matching the configs on the listed attributes.
results = load_results_multiple_configs(
    exp_name="geometry_population_",
    source_cfgs=configs,
    target_folder=cfg.env.results_dir,
    recency_rank=0,
    compared_attributes=["llm", "data", "start", "end", "centering", "normalize", "act_path"],
    verbose=True,
)

# Tokens come from the cached activations, matched on model and dataset only.
artifacts, _ = load_matching_activations(
    source_object=cfg,
    target_filenames=["tokens"],
    target_folder=cfg.env.activations_dir,
    recency_rank=0,
    compared_attributes=["llm", "data"],
)
tokens_BP = artifacts["tokens"]

# Only the tokenizer is used below; it is taken from the loaded model.
print("\nLoading tokenizer...")
model, submodule, hidden_dim = load_nnsight_model(cfg)
tokenizer = model.tokenizer
print(f"Tokenizer loaded: {type(tokenizer).__name__}")

# Summary of what was loaded.
print(f"\nLoaded {len(results)} results")
print("Available activation types:")
for result in results.values():
    print(f"  - {result['config']['act_path']}")

Attribute value mismatch at path 'artifacts/results/geometry_population_20251027_115410.json.start': source has '10' but target has '5'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251027_105648.json.start': source has '10' but target has '5'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251027_114037.json.start': source has '10' but target has '5'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251029_231508.json.end': source has '300' but target has '500'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251027_105754.json.start': source has '10' but target has '5'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251029_232901.json.end': source has '300' but target has '50'
Attribute value mismatch at path 'artifacts/results/geometry_population_20251027_113836.json.start': source has '10' but target has '5'
Attribute value mismatch at path 'artifacts/res

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.74it/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNo




## Visualize Interactive Trajectories

In [51]:
# Color mapping for different activation types
color_map = {
    "activations": "#EF9E72",
    "relu/codes": "#A855F7",
    "topk/codes": "#A855F7",
    "batchtopk/codes": "#A855F7",
    "temporal/novel_codes": "#06B6D4",
    "temporal/pred_codes": "#AA1010",
}

In [52]:
# Example: Visualize Temporal (Pred) trajectory
fig = plot_interactive_trajectory(
    results,
    tokens_BP,
    tokenizer,
    act_type="temporal/pred_codes",
    sequence_idx=10,
    title="Temporal (Pred) - Sequence 0",
    end_color=color_map["temporal/pred_codes"]
)
fig.show()

In [49]:
# Example: Visualize Activations trajectory
fig = plot_interactive_trajectory(
    results,
    tokens_BP,
    tokenizer,
    act_type="activations",
    sequence_idx=2,
    title="Activations - Sequence 0",
    end_color=color_map["activations"]
)
fig.show()

ValueError: Activation type 'activations' not found. Available: ['temporal/pred_codes']

In [13]:
# Example: Compare different sequences for the same activation type
for seq_idx in [0, 1, 2]:
    fig = plot_interactive_trajectory(
        results,
        tokens_BP,
        tokenizer,
        act_type="temporal/pred_codes",
        sequence_idx=seq_idx,
        title=f"Temporal (Novel) - Sequence {seq_idx}",
        end_color=color_map["temporal/pred_codes"]
    )
    fig.show()

In [None]:
# Save to HTML file (optional)
fig = plot_interactive_trajectory(
    results,
    tokens_BP,
    tokenizer,
    act_type="temporal/pred_codes",
    sequence_idx=0,
    end_color=color_map["temporal/pred_codes"]
)
# fig.write_html("interactive_trajectory.html")