# UMAP Visualization from Results File

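This notebook generates an interactive 2D or 3D UMAP plot (matching the dimensionality of the saved embedding) from a saved results file.

**Usage:** Set the `RESULTS_PATH` variable below to your results file path. If `USE_USER_MODEL_COLORS` is enabled, also set `TOKENS_PATH` to the token file from the same run.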

In [37]:
# Configuration
# Root of the project's artifacts directory; change this when running on another machine.
ARTIFACTS_DIR = "/home/can/dynamic_representations/artifacts"

# Results JSON to plot. Earlier runs, kept for reference (all under {ARTIFACTS_DIR}/results/):
#   pred_structure_20251020_190254_temporal_pred_codes_Gemma-2-2B_SimpleStories.json
#   pred_structure_20251030_023158_temporal_codes_Gemma-2-2B-IT_Chat.json
#   pred_structure_20251030_025227_temporal_codes_Gemma-2-2B-IT_Chat.json
RESULTS_PATH = f"{ARTIFACTS_DIR}/results/pred_structure_20251030_025829_temporal_codes_Gemma-2-2B-IT_Alpaca.json"

# Optional plot parameters
NUM_SEQUENCES = None  # Limit sequences: int for first N, list for specific indices (e.g., [0, 5, 10]), or None for all
CONNECT_SEQUENCES = None  # Sequences to join with lines, indexed AFTER NUM_SEQUENCES filtering: list of indices (e.g., [0, 1, 5]), True for all, None or [] for none
FILTER_WORDS = None  # Keep only points whose hover text contains at least one of these strings (e.g., ["word1", "word2"]), or None to show all points

# User vs Model highlighting
USE_USER_MODEL_COLORS = True  # True: binary coloring, black for user tokens and red for model tokens; False: color by position
# Token file from the same run as RESULTS_PATH (required if USE_USER_MODEL_COLORS=True),
# e.g., "artifacts/activations/tokens_Gemma-2-2B-IT_Chat.pt"
TOKENS_PATH = f"{ARTIFACTS_DIR}/activations/20251029_232750/tokens.pt"

In [38]:
import json
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
import torch as th

In [44]:
# Load results
with open(RESULTS_PATH, 'r') as f:
    result_dict = json.load(f)

# Extract data
embedding = np.array(result_dict["embedding"])  # (B, L, n_components)
pos_labels = np.array(result_dict["pos_labels"])  # (B, L)
pos_indices = result_dict["pos_indices"]  # NOTE(review): not used below; could verify the position subsampling
hover_texts_saved = result_dict.get("hover_texts", None)  # Flat list of hover texts, one per point
config = result_dict["config"]

print(f"Embedding shape: {embedding.shape}")
print(f"Position labels shape: {pos_labels.shape}")
print(f"Hover texts available: {hover_texts_saved is not None}")
print(f"Config: {config}")

# Load tokens if user/model coloring is requested
if USE_USER_MODEL_COLORS:
    if TOKENS_PATH is None:
        raise ValueError("TOKENS_PATH must be specified when USE_USER_MODEL_COLORS=True")
    # weights_only=False unpickles arbitrary objects: only load token files you produced yourself.
    # map_location="cpu": the user/model labelling runs on CPU, and this also works on a
    # machine without the GPU the tensor may have been saved from.
    tokens_BP = th.load(TOKENS_PATH, weights_only=False, map_location="cpu")
    print(f"Tokens shape: {tokens_BP.shape}")
    # Token rows are paired with embedding rows by index, so every sequence needs a token row.
    if tokens_BP.shape[0] < embedding.shape[0]:
        raise ValueError(
            f"TOKENS_PATH has {tokens_BP.shape[0]} sequences but the results have "
            f"{embedding.shape[0]}; are they from the same run?"
        )
else:
    tokens_BP = None

Embedding shape: (1000, 30, 3)
Position labels shape: (1000, 30)
Hover texts available: True
Config: {'llm': {'name': 'Gemma-2-2B-IT', 'hf_name': 'google/gemma-2-2b-it', 'revision': None, 'layer_idx': 12, 'hidden_dim': 2304, 'batch_size': 50}, 'sae': {'release': 'singh', 'name': 'temporal', 'local_weights_path': 'artifacts/trained_saes/gemma-2-2B/layer_12/temporal', 'dict_class': 'temporal', 'dict_size': 9216, 'batch_size': 10}, 'env': {'device': 'cuda', 'dtype': 'bfloat16', 'hf_cache_dir': '/home/can/models', 'plots_dir': 'artifacts/plots', 'results_dir': 'artifacts/results', 'text_inputs_dir': 'artifacts/text_inputs', 'activations_dir': 'artifacts/activations'}, 'data': {'name': 'Alpaca', 'hf_name': 'tatsu-lab/alpaca', 'num_sequences': 1000, 'context_length': 50}, 'act_path': 'temporal/pred_codes', 'min_p': 0, 'max_p': 40, 'num_p': 30, 'do_log_scale': False, 'n_components': 3, 'n_neighbors': 15, 'min_dist': 0.1, 'metric': 'euclidean', 'random_state': 42, 'num_sequences': None, 'conne

In [45]:
# is_user function for determining user vs model tokens
def is_user(
    tokens_BP: th.Tensor,
    start_of_turn_id: int = 106,
    user_id: int = 1645,
    model_id: int = 2516,
) -> th.Tensor:
    """
    Return boolean tensor indicating which tokens are from the user (not the model),
    across multiple turns.

    A turn starts wherever `start_of_turn_id` is immediately followed by a role token.
    With the default ids the user signature is [106, 1645] ('<start_of_turn>', 'user')
    and the model signature is [106, 2516] ('<start_of_turn>', 'model').
    Every token from a turn's start-of-turn token up to the next turn start gets that
    turn's role; tokens before the first turn start count as model (False).

    Args:
        tokens_BP: Token tensor of shape (B, P), on any device
        start_of_turn_id: Token id that opens a turn
        user_id: Role token id (following start_of_turn_id) that marks a user turn
        model_id: Role token id (following start_of_turn_id) that marks a model turn

    Returns:
        is_user_BP: Boolean CPU tensor of shape (B, P) where True = user, False = model
    """
    tokens_BP = tokens_BP.detach().cpu()
    B, P = tokens_BP.shape
    if B == 0 or P < 2:
        # No position has a following role token, so no turn can start: everything is "model".
        return th.zeros(B, P, dtype=th.bool)

    # role_BP[b, p] = 1 where a user turn starts at p, 0 where a model turn starts, -1 elsewhere.
    # The last position can never start a turn because its role token would be out of range.
    is_turn_start = tokens_BP[:, :-1] == start_of_turn_id
    next_tokens = tokens_BP[:, 1:]
    role_BP = th.full((B, P), -1, dtype=th.long)
    role_BP[:, :-1][is_turn_start & (next_tokens == model_id)] = 0
    # Written after the model marks so 'user' takes precedence if the two ids coincide.
    role_BP[:, :-1][is_turn_start & (next_tokens == user_id)] = 1

    # Forward-fill the turn role: for each position, find the latest turn start at or before it
    # (-1 when there is none yet) and read that turn's role.
    positions_BP = th.arange(P).repeat(B, 1)
    start_pos_BP = th.where(role_BP >= 0, positions_BP, th.full_like(positions_BP, -1))
    last_start_BP = th.cummax(start_pos_BP, dim=1).values
    last_role_BP = role_BP.gather(1, last_start_BP.clamp(min=0))

    return (last_start_BP >= 0) & (last_role_BP == 1)


# Subsample sequences if requested.
# NOTE: embedding / pos_labels / hover_texts_saved / tokens_BP are overwritten here, so re-run
# the load cell before re-running this one (otherwise the subsampling is applied twice).
if NUM_SEQUENCES is not None:
    if isinstance(NUM_SEQUENCES, list):
        # Use specific indices: drop out-of-range ones and normalise negative indices, so the
        # flat hover-text offsets below (idx * seq_len) select the same rows as the embedding.
        n_total = len(embedding)
        indices = [i % n_total for i in NUM_SEQUENCES if -n_total <= i < n_total]
        embedding = embedding[indices]
        pos_labels = pos_labels[indices]
        if hover_texts_saved is not None:
            # hover_texts is a flat list with seq_len entries per sequence; take each chosen slice
            seq_len = embedding.shape[1]
            hover_texts_saved = [
                text
                for idx in indices
                for text in hover_texts_saved[idx * seq_len : (idx + 1) * seq_len]
            ]
        if tokens_BP is not None:
            tokens_BP = tokens_BP[indices]
    else:
        # Use first N sequences
        embedding = embedding[:NUM_SEQUENCES]
        pos_labels = pos_labels[:NUM_SEQUENCES]
        if hover_texts_saved is not None:
            seq_len = embedding.shape[1]
            hover_texts_saved = hover_texts_saved[:NUM_SEQUENCES * seq_len]
        if tokens_BP is not None:
            tokens_BP = tokens_BP[:NUM_SEQUENCES]

batch_size, seq_len, n_dims = embedding.shape
print(f"Plotting {batch_size} sequences with {seq_len} points each")

# Compute user/model labels if requested
if USE_USER_MODEL_COLORS and tokens_BP is not None:
    # Get start and end from config to slice tokens correctly
    start = config.get("min_p", 0)
    end = config.get("max_p", tokens_BP.shape[1])
    num_p = config.get("num_p", None)
    do_log_scale = config.get("do_log_scale", False)

    # Compute user labels for full sequence, then slice to match the embedding range
    is_user_BP = is_user(tokens_BP)
    is_user_sliced = is_user_BP[:batch_size, start:end]

    # Apply position subsampling if num_p is specified.
    # NOTE(review): this re-derives the experiment's position sampling; the saved `pos_indices`
    # could be used to confirm the two agree.
    if num_p is not None:
        original_seq_len = is_user_sliced.shape[1]
        if do_log_scale:
            log_start = th.log10(th.tensor(0, dtype=th.float) + 1)
            log_end = th.log10(th.tensor(original_seq_len - 1, dtype=th.float) + 1)
            log_steps = th.linspace(log_start, log_end, num_p)
            ps = th.round(10**log_steps - 1).long().clamp(0, original_seq_len - 1)
        else:
            ps = th.linspace(0, original_seq_len - 1, num_p, dtype=th.long)

        is_user_sliced = is_user_sliced[:, ps]

    # Convert to binary color labels: user=0.0 (black), model=1.0 (end color)
    is_user_np = is_user_sliced.cpu().numpy().astype(float)
    color_labels = 1.0 - is_user_np  # Invert: True (user) -> 0.0, False (model) -> 1.0

    # The labels are flattened alongside the embedding for plotting, so a shape mismatch
    # (wrong token file, or position sampling that differs from the experiment) would
    # misalign colors with points.
    if color_labels.shape != (batch_size, seq_len):
        raise ValueError(
            f"User/model labels have shape {color_labels.shape} but the embedding has "
            f"{(batch_size, seq_len)} points; check TOKENS_PATH and the position config."
        )

    print(f"User/model coloring enabled: {is_user_sliced.sum().item()} user tokens, {(~is_user_sliced).sum().item()} model tokens")
else:
    color_labels = None

Plotting 1000 sequences with 30 points each
User/model coloring enabled: 16025 user tokens, 13975 model tokens


In [46]:
# Build title
act_path = config.get('act_path', 'unknown')
llm_name = config.get('llm', {}).get('name', 'unknown')
data_name = config.get('data', {}).get('name', 'unknown')
sae_cfg = config.get('sae')
# No SAE config (missing key or null) means raw LLM activations were used.
sae_name = sae_cfg.get('name', '') if sae_cfg is not None else "LLM"

# act_path can be namespaced by the SAE (e.g. "temporal/pred_codes"), so dispatch on its last
# component; comparing the full path misses those and falls through to the generic title.
act_kind = str(act_path).rsplit('/', 1)[-1]
if act_kind == "activations":
    title_prefix = "UMAP of LLM activations"
elif act_kind == "codes":
    title_prefix = f"UMAP of {sae_name} codes"
elif act_kind == "pred_codes":
    title_prefix = f"UMAP of {sae_name} pred codes"
elif act_kind == "novel_codes":
    title_prefix = f"UMAP of {sae_name} novel codes"
else:
    title_prefix = f"UMAP of {act_path}"

title_prefix += f" {data_name}"

# Add user/model info to title if applicable
if USE_USER_MODEL_COLORS:
    title_prefix += " (User vs Model)"

print(f"Title: {title_prefix}")

Title: UMAP of temporal/pred_codes Alpaca (User vs Model)


In [47]:
# Flatten for plotting
embedding_flat = embedding.reshape(-1, n_dims)
pos_labels_flat = pos_labels.flatten()

# Determine what to use for coloring
if USE_USER_MODEL_COLORS and color_labels is not None:
    # Use binary user/model colors
    color_labels_flat = color_labels.flatten()
    colorscale = [[0.0, "#000000"], [1.0, "#AA1010"]]  # Black for user, red for model (matches temporal/pred_codes)
    colorbar_title = "User (black) / Model (red)"
else:
    # Use position-based colors
    color_labels_flat = pos_labels_flat
    colorscale = "Viridis"
    colorbar_title = "Position Index"

# Raw per-point hover text: saved texts if available, otherwise the position only
if hover_texts_saved is not None:
    raw_hover_texts = list(hover_texts_saved)
else:
    raw_hover_texts = [f"Position: {p}" for p in pos_labels_flat]

# Plotly pairs hover texts with points by index, so a length mismatch would mislabel points.
if len(raw_hover_texts) != len(embedding_flat):
    raise ValueError(
        f"Got {len(raw_hover_texts)} hover texts for {len(embedding_flat)} points"
    )

# Apply FILTER_WORDS on the raw text, before the HTML formatting below escapes '<' and '>'
# (otherwise filter words such as "<end_of_turn>" could never match).
if FILTER_WORDS is not None:
    # Keep points whose text contains at least one filter word
    mask = np.array(
        [any(word in text for word in FILTER_WORDS) for text in raw_hover_texts], dtype=bool
    )
    embedding_flat = embedding_flat[mask]
    pos_labels_flat = pos_labels_flat[mask]
    color_labels_flat = color_labels_flat[mask]
    raw_hover_texts = [text for text, keep in zip(raw_hover_texts, mask) if keep]
    print(f"Filtered to {len(raw_hover_texts)} points containing filter words: {FILTER_WORDS}")

# Make special tokens bold and underlined (escaped so plotly shows them literally, not as tags)
hover_texts = [
    text.replace('<end_of_turn>', '<b><u>&lt;end_of_turn&gt;</u></b>').replace('<start_of_turn>', '<b><u>&lt;start_of_turn&gt;</u></b>')
    for text in raw_hover_texts
]

In [48]:
# Create the interactive UMAP scatter plot (2D or 3D, following the embedding's n_dims)
if n_dims == 2:
    scatter_trace = go.Scatter
    axis_keys = ("x", "y")
    marker_size = 8
    hovertemplate = "%{text}<br>UMAP1: %{x:.2f}<br>UMAP2: %{y:.2f}<extra></extra>"
    axis_layout = dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2")
elif n_dims == 3:
    scatter_trace = go.Scatter3d
    axis_keys = ("x", "y", "z")
    marker_size = 3
    hovertemplate = "%{text}<br>UMAP1: %{x:.2f}<br>UMAP2: %{y:.2f}<br>UMAP3: %{z:.2f}<extra></extra>"
    axis_layout = dict(scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"))
else:
    raise ValueError(f"n_dims must be 2 or 3, got {n_dims}")


def _axis_kwargs(points: np.ndarray) -> dict:
    """Map the columns of an (N, n_dims) array onto plotly's x / y (/ z) keyword arguments."""
    return {key: points[:, i] for i, key in enumerate(axis_keys)}


def _sequences_to_connect(spec, n_sequences: int) -> list:
    """Resolve CONNECT_SEQUENCES into the sequence indices whose trajectories get drawn.

    None or False -> no lines; a list / tuple / range / ndarray -> those indices (indices
    >= n_sequences are skipped); any other value (e.g. True) -> every sequence.
    """
    if spec is None or spec is False:
        return []
    if isinstance(spec, (list, tuple, range, np.ndarray)):
        return [b for b in spec if b < n_sequences]
    return list(range(n_sequences))


traces = [
    scatter_trace(
        **_axis_kwargs(embedding_flat),
        mode="markers",
        marker=dict(
            size=marker_size,
            color=color_labels_flat,
            colorscale=colorscale,
            showscale=True,
            colorbar=dict(title=colorbar_title),
            opacity=0.6,
        ),
        text=hover_texts,
        hovertemplate=hovertemplate,
    )
]

# One faint line per selected sequence, tracing its points in position order. Lines are drawn
# from the unfiltered `embedding`, so FILTER_WORDS does not affect them.
for b in _sequences_to_connect(CONNECT_SEQUENCES, batch_size):
    traces.append(
        scatter_trace(
            **_axis_kwargs(embedding[b]),
            mode="lines",
            line=dict(color="black", width=2),
            opacity=0.3,
            showlegend=False,
            hoverinfo="skip",
        )
    )

fig = go.Figure(data=traces)
fig.update_layout(
    title=f"{title_prefix} (colored by {'user/model' if USE_USER_MODEL_COLORS else 'position'})",
    width=1000,
    height=800,
    **axis_layout,
)

fig.show()

In [8]:
# Save as HTML
# The figure is written next to the results JSON, named "<results file stem>_plot.html".
output_path = Path(RESULTS_PATH)
output_path = output_path.parent / (output_path.stem + "_plot.html")
fig.write_html(str(output_path))
print(f"Saved interactive plot to: {output_path}")

Saved interactive plot to: /home/can/dynamic_representations/artifacts/results/pred_structure_20251020_190254_temporal_pred_codes_Gemma-2-2B_SimpleStories_plot.html
