# UMAP Visualization from Results File

This notebook generates an interactive UMAP plot (2D or 3D, matching the saved embedding's number of components) from a saved results file.

**Usage:** Update the `RESULTS_PATH` variable below with your results file path.

In [1]:
# Configuration
# Path to the saved results JSON file (absolute and machine-specific — edit for your setup).
RESULTS_PATH = "/home/can/dynamic_representations/artifacts/results/pred_structure_20251020_190254_temporal_pred_codes_Gemma-2-2B_SimpleStories.json"

# Optional plot parameters
# NUM_SEQUENCES: int -> first N sequences; list -> specific indices (e.g. [0, 5, 10]); None -> all.
NUM_SEQUENCES = [0]
# CONNECT_SEQUENCES: sequence indices whose points are joined by a line (e.g. [0, 1, 5]).
# Indices refer to the sequences kept AFTER filtering with NUM_SEQUENCES.
# None -> connect all; empty list [] -> connect none.
CONNECT_SEQUENCES = [0]
# FILTER_WORDS: only show points whose hover text contains at least one of these strings
# (e.g. ["word1", "word2"]); None -> show all points.
FILTER_WORDS = None

In [2]:
import json
import numpy as np
import plotly.graph_objects as go
from pathlib import Path

In [3]:
# Load the saved results JSON and pull out the pieces used for plotting.
result_dict = json.loads(Path(RESULTS_PATH).read_text())

embedding = np.array(result_dict["embedding"])       # (B, L, n_components)
pos_labels = np.array(result_dict["pos_labels"])     # (B, L)
pos_indices = result_dict["pos_indices"]
hover_texts_saved = result_dict.get("hover_texts")   # list of hover texts, or None if absent
config = result_dict["config"]

for summary_line in (
    f"Embedding shape: {embedding.shape}",
    f"Position labels shape: {pos_labels.shape}",
    f"Hover texts available: {hover_texts_saved is not None}",
    f"Config: {config}",
):
    print(summary_line)

Embedding shape: (1000, 100, 3)
Position labels shape: (1000, 100)
Hover texts available: True
Config: {'llm': {'name': 'Gemma-2-2B', 'hf_name': 'google/gemma-2-2b', 'revision': None, 'layer_idx': 12, 'hidden_dim': 2304, 'batch_size': 50}, 'sae': {'name': 'temporal', 'local_weights_path': 'artifacts/trained_saes/selftrain/temporal', 'dict_class': 'temporal', 'dict_size': 9216, 'batch_size': 10}, 'env': {'device': 'cuda', 'dtype': 'bfloat16', 'hf_cache_dir': '/home/can/models', 'plots_dir': 'artifacts/plots', 'results_dir': 'artifacts/results', 'text_inputs_dir': 'artifacts/text_inputs', 'activations_dir': 'artifacts/activations'}, 'data': {'name': 'SimpleStories', 'hf_name': 'SimpleStories/SimpleStories', 'num_sequences': 1000, 'context_length': 500}, 'act_path': 'pred_codes', 'min_p': 0, 'max_p': 100, 'num_p': 100, 'do_log_scale': False, 'n_components': 3, 'n_neighbors': 15, 'min_dist': 0.1, 'metric': 'euclidean', 'random_state': 42, 'num_sequences': None, 'connect_sequences': False, 

In [4]:
# Subsample sequences if requested.
#   NUM_SEQUENCES = None  -> keep every sequence
#   NUM_SEQUENCES = N     -> keep the first N sequences
#   NUM_SEQUENCES = [...] -> keep exactly these sequence indices, in this order
# NOTE: this cell overwrites `embedding`, `pos_labels` and `hover_texts_saved`,
# so re-run the loading cell before re-running it with a different selection.
if NUM_SEQUENCES is not None:
    n_loaded, seq_len = embedding.shape[:2]
    if isinstance(NUM_SEQUENCES, list):
        # Keep only indices valid for the loaded array and normalize negative ones:
        # the flat hover-text slicing below needs non-negative offsets, otherwise
        # e.g. -1 would select an empty slice and misalign texts with points.
        indices = [i % n_loaded for i in NUM_SEQUENCES if -n_loaded <= i < n_loaded]
        skipped = [i for i in NUM_SEQUENCES if not -n_loaded <= i < n_loaded]
        if skipped:
            print(f"Skipping out-of-range sequence indices: {skipped}")
        embedding = embedding[indices]
        pos_labels = pos_labels[indices]
        if hover_texts_saved is not None:
            # hover_texts is a flat list with seq_len entries per sequence;
            # gather each selected sequence's block in selection order.
            hover_texts_saved = [
                text
                for idx in indices
                for text in hover_texts_saved[idx * seq_len:(idx + 1) * seq_len]
            ]
    else:
        # Use first N sequences (plain slicing keeps arrays and texts aligned).
        embedding = embedding[:NUM_SEQUENCES]
        pos_labels = pos_labels[:NUM_SEQUENCES]
        if hover_texts_saved is not None:
            hover_texts_saved = hover_texts_saved[:NUM_SEQUENCES * seq_len]

batch_size, seq_len, n_dims = embedding.shape
print(f"Plotting {batch_size} sequences with {seq_len} points each")

Plotting 1 sequences with 100 points each


In [5]:
# Build the plot title from the run config stored in the results file.
act_path = config.get('act_path', 'unknown')
# `or {}` also guards against sections that are present but null.
llm_name = (config.get('llm') or {}).get('name', 'unknown')
data_name = (config.get('data') or {}).get('name', 'unknown')

# A missing or null 'sae' section means the activations are the raw LLM ones;
# previously a missing key produced an empty name and a title like "UMAP of  codes".
sae_cfg = config.get('sae')
sae_name = sae_cfg.get('name', '') if sae_cfg else "LLM"

title_by_act_path = {
    "activations": "UMAP of LLM activations",
    "codes": f"UMAP of {sae_name} codes",
    "pred_codes": f"UMAP of {sae_name} pred codes",
    "novel_codes": f"UMAP of {sae_name} novel codes",
}
title_prefix = title_by_act_path.get(act_path, f"UMAP of {act_path}")
title_prefix += f" {data_name}"
print(f"Title: {title_prefix}")

Title: UMAP of temporal pred codes SimpleStories


In [6]:
# Flatten (B, L, n_dims) -> (B*L, n_dims) so every token becomes one scatter point.
embedding_flat = embedding.reshape(-1, n_dims)
pos_labels_flat = pos_labels.flatten()

# Raw per-point text: the saved hover texts if available, otherwise just the position.
if hover_texts_saved is not None:
    raw_texts = list(hover_texts_saved)
else:
    raw_texts = [f"Position: {p}" for p in pos_labels_flat]

# Texts and points are paired by position in the flat arrays; a length mismatch
# would silently shift every label (or break the filter mask below).
assert len(raw_texts) == len(embedding_flat), (
    f"{len(raw_texts)} hover texts for {len(embedding_flat)} points"
)

# Plotly interprets hover text as HTML, so special tokens are escaped to stay
# visible and made bold + underlined to stand out.
SPECIAL_TOKEN_HTML = {
    '<end_of_turn>': '<b><u>&lt;end_of_turn&gt;</u></b>',
    '<start_of_turn>': '<b><u>&lt;start_of_turn&gt;</u></b>',
}


def _format_hover_text(text: str) -> str:
    """Replace each special token in `text` with its highlighted, HTML-escaped form."""
    for token, html in SPECIAL_TOKEN_HTML.items():
        text = text.replace(token, html)
    return text


hover_texts = [_format_hover_text(text) for text in raw_texts]

# Apply FILTER_WORDS if specified. Matching runs on the raw text so that filter
# words such as "<end_of_turn>" are not defeated by the HTML escaping above.
if FILTER_WORDS is not None:
    mask = np.array(
        [any(word in text for word in FILTER_WORDS) for text in raw_texts], dtype=bool
    )
    embedding_flat = embedding_flat[mask]
    pos_labels_flat = pos_labels_flat[mask]
    hover_texts = [text for text, keep in zip(hover_texts, mask) if keep]
    print(f"Filtered to {len(hover_texts)} points containing filter words: {FILTER_WORDS}")

In [7]:
# Build the figure from two kinds of traces:
#   - one scatter trace containing every point, colored by its position index;
#   - optional line traces, one per selected sequence.
# 2D and 3D differ only in the trace class, marker size, hover template and
# axis-title layout, so a single code path handles both.
if n_dims not in (2, 3):
    raise ValueError(f"n_dims must be 2 or 3, got {n_dims}")

is_3d = n_dims == 3
ScatterTrace = go.Scatter3d if is_3d else go.Scatter
HOVER_TEMPLATE = (
    "%{text}<br>UMAP1: %{x:.2f}<br>UMAP2: %{y:.2f}<br>UMAP3: %{z:.2f}<extra></extra>"
    if is_3d
    else "%{text}<br>UMAP1: %{x:.2f}<br>UMAP2: %{y:.2f}<extra></extra>"
)


def _coords(points: np.ndarray) -> dict:
    """Map the columns of an (N, n_dims) array to plotly's x/y(/z) keyword arguments."""
    return {axis: points[:, i] for i, axis in enumerate("xyz"[: points.shape[1]])}


def _sequences_to_connect(spec, n_seqs: int) -> list:
    """Resolve the CONNECT_SEQUENCES setting into valid sequence indices.

    None, True (or any other non-list value) -> every sequence; False or [] -> none;
    a list/tuple/range -> those indices, skipping any outside [-n_seqs, n_seqs).
    """
    if spec is False:
        return []
    if spec is None or not isinstance(spec, (list, tuple, range)):
        return list(range(n_seqs))
    return [b for b in spec if -n_seqs <= b < n_seqs]


traces = [
    ScatterTrace(
        **_coords(embedding_flat),
        mode="markers",
        marker=dict(
            size=3 if is_3d else 8,
            color=pos_labels_flat,
            colorscale="Viridis",
            showscale=True,
            colorbar=dict(title="Position Index"),
            opacity=0.6,
        ),
        text=hover_texts,
        hovertemplate=HOVER_TEMPLATE,
    )
]

# Lines are drawn from the unfiltered per-sequence `embedding`, so they still
# pass through points that FILTER_WORDS hid from the scatter trace.
for b in _sequences_to_connect(CONNECT_SEQUENCES, batch_size):
    traces.append(
        ScatterTrace(
            **_coords(embedding[b]),
            mode="lines",
            line=dict(color="black", width=2),
            opacity=0.3,
            showlegend=False,
            hoverinfo="skip",
        )
    )

# 3D figures put axis titles under `scene`; 2D figures take them at top level.
if is_3d:
    axis_layout = dict(scene=dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2", zaxis_title="UMAP 3"))
else:
    axis_layout = dict(xaxis_title="UMAP 1", yaxis_title="UMAP 2")

fig = go.Figure(data=traces)
fig.update_layout(
    title=f"{title_prefix} (colored by position)",
    width=1000,
    height=800,
    **axis_layout,
)

fig.show()

In [None]:
# Save the interactive figure as a standalone HTML file alongside the results
# JSON, named "<results file stem>_plot.html".
results_file = Path(RESULTS_PATH)
output_path = results_file.parent / (results_file.stem + "_plot.html")
fig.write_html(str(output_path))
print(f"Saved interactive plot to: {output_path}")