# Track conversation trajectory

In [4]:
import json
import os
import sys
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm

sys.path.append('.')
sys.path.append('..')

from utils.inference_utils import *
from utils.probing_utils import *




INFO 08-07 20:47:02 [__init__.py:235] Automatically detected platform cuda.


## Score statistics

In [29]:
# Load the per-trait score files from data/extract_scores.
score_dir = "/root/git/persona-subspace/traits/data/extract_scores"

# One JSON file per trait; the file name without ".json" is the trait name.
# os.listdir returns entries in arbitrary order, so sort them: the dict order
# (and every "first trait" example printed further down) is then reproducible.
scores = {}
for file in sorted(os.listdir(score_dir)):
    if file.endswith(".json"):
        with open(os.path.join(score_dir, file), "r") as f:
            # Strip only the trailing extension; str.replace would also remove
            # a ".json" that happens to occur in the middle of a file name.
            scores[file[: -len(".json")]] = json.load(f)

print(f"Found {len(scores)} traits with scores")


Found 240 traits with scores


In [30]:
# Analyze refusals and clean data.
# refusal_info: trait -> which score keys were refusals (and how many).
# scores_clean: trait -> {score key -> float}, with refusals stored as NaN.
refusal_info = {}
scores_clean = {}

for trait, score_obj in scores.items():
    refusals = []
    cleaned_scores = {}

    # Check each score for refusals
    for key, value in score_obj.items():
        if value == "REFUSAL":
            refusals.append(key)
            # Store a refusal as missing (NaN), not 0. A 0 is indistinguishable
            # from a genuine lowest score and would bias the pos/neg/default
            # statistics; those cells already use np.nanmean and skip NaN pairs.
            cleaned_scores[key] = float("nan")
        else:
            cleaned_scores[key] = float(value)  # Ensure numeric

    scores_clean[trait] = cleaned_scores
    refusal_info[trait] = {
        "refusals": refusals,
        "refusal_count": len(refusals)
    }

# Show refusal statistics
total_refusals = sum(info["refusal_count"] for info in refusal_info.values())
traits_with_refusals = sum(1 for info in refusal_info.values() if info["refusal_count"] > 0)

print("Refusal Statistics:")
print(f"Total refusals across all traits: {total_refusals}")
print(f"Traits with refusals: {traits_with_refusals}")

if total_refusals > 0:
    # Most-refused traits first; traits without refusals are not printed.
    sorted_refusals = sorted(refusal_info.items(), key=lambda x: x[1]["refusal_count"], reverse=True)
    print("\nTop 10 traits with most refusals:")
    for trait, info in sorted_refusals[:10]:
        if info["refusal_count"] > 0:
            print(f"  {trait}: {info['refusal_count']} refusals - {info['refusals']}")

Refusal Statistics:
Total refusals across all traits: 4
Traits with refusals: 4

Top 10 traits with most refusals:
  animated: 1 refusals - ['neg_p1_q7']
  diplomatic: 1 refusals - ['neg_p2_q19']
  generalist: 1 refusals - ['neg_p0_q4']
  utilitarian: 1 refusals - ['default_p2_q13']


In [4]:
# Pack each trait's cleaned scores into a dense array indexed as
# [type, prompt, question], where type 0 = pos, 1 = neg, 2 = default.
# Shape is (3 types, 5 prompts, 20 questions); any key missing from the
# cleaned scores leaves its slot at NaN.
scores_np = {}

for trait, cleaned_scores in scores_clean.items():
    scores_3d = np.full((3, 5, 20), np.nan)

    for type_idx, type_name in enumerate(("pos", "neg", "default")):
        for prompt_idx in range(5):
            for question_idx in range(20):
                score_key = f"{type_name}_p{prompt_idx}_q{question_idx}"
                if score_key in cleaned_scores:
                    scores_3d[type_idx, prompt_idx, question_idx] = cleaned_scores[score_key]

    scores_np[trait] = scores_3d

# Summary plus a small preview of the first trait's pos/neg grids
example_trait = list(scores_np.keys())[0]
print(f"Created numpy arrays for {len(scores_np)} traits")
print(f"Shape of each array: {next(iter(scores_np.values())).shape}")
print(f"Example (first trait): {example_trait}")
print(f"Pos scores for first 2 prompts, 5 questions:\n{scores_np[example_trait][0, :2, :5]}")
print(f"Neg scores for first 2 prompts, 5 questions:\n{scores_np[example_trait][1, :2, :5]}")

Created numpy arrays for 240 traits
Shape of each array: (3, 5, 20)
Example (first trait): absolutist
Pos scores for first 2 prompts, 5 questions:
[[ 0. 10. 10. 10.  0.]
 [ 0. 10.  0. 10.  0.]]
Neg scores for first 2 prompts, 5 questions:
[[ 0.  0.  0. 10.  0.]
 [ 0. 10.  0.  0.  0.]]


In [5]:
# Per-trait summary of how strongly pos prompts separate from neg prompts.
stats = {}

for trait, scores_3d in scores_np.items():
    pos_scores = scores_3d[0]  # (5, 20) grid: prompt x question
    neg_scores = scores_3d[1]  # (5, 20) grid: prompt x question

    # Only (prompt, question) pairs where both scores are present are counted
    both_present = ~(np.isnan(pos_scores) | np.isnan(neg_scores))

    # Pairs with pos above 50 and neg below 50
    high_pos_low_neg = both_present & (pos_scores > 50) & (neg_scores < 50)

    # Pairs whose scores differ by more than 40 in either direction
    large_diff = both_present & (np.abs(pos_scores - neg_scores) > 40)

    stats[trait] = {
        # NaN-aware mean of the element-wise pos - neg difference
        "pos_minus_neg_mean": np.nanmean(pos_scores - neg_scores),
        "high_pos_low_neg_count": int(high_pos_low_neg.sum()),
        "large_diff_count": int(large_diff.sum())
    }

# Show example statistics for first trait (floats to 2 decimals, counts as-is)
example_trait = list(stats.keys())[0]
print(f"Example statistics for '{example_trait}':")
for key, value in stats[example_trait].items():
    rendered = f"{value:.2f}" if isinstance(value, float) else f"{value}"
    print(f"  {key}: {rendered}")

print(f"\nCalculated statistics for {len(stats)} traits")

# Distribution of the two counts across traits
high_pos_counts = [entry["high_pos_low_neg_count"] for entry in stats.values()]
large_diff_counts = [entry["large_diff_count"] for entry in stats.values()]
print(f"\nHigh pos, low neg count distribution: min={min(high_pos_counts)}, max={max(high_pos_counts)}, mean={np.mean(high_pos_counts):.1f}")
print(f"Large diff count distribution: min={min(large_diff_counts)}, max={max(large_diff_counts)}, mean={np.mean(large_diff_counts):.1f}")

# Export one row per trait to CSV
stats_df = pd.DataFrame.from_dict(stats, orient='index')
stats_df.index.name = 'trait'
stats_df.to_csv('./results/pos_neg.csv')
print("\nExported statistics to pos_neg.csv")
print(f"Shape: {stats_df.shape}")

Example statistics for 'absolutist':
  pos_minus_neg_mean: 21.75
  high_pos_low_neg_count: 25
  large_diff_count: 25

Calculated statistics for 240 traits

High pos, low neg count distribution: min=0, max=100, mean=79.6
Large diff count distribution: min=0, max=100, mean=79.4

Exported statistics to pos_neg.csv
Shape: (240, 3)


In [15]:
# Per-trait summary of how strongly pos prompts separate from default prompts
# (same statistics as the pos - neg cell, with default in place of neg).
pos_default_stats = {}

for trait, scores_3d in scores_np.items():
    pos_scores = scores_3d[0]      # (5, 20) grid: prompt x question
    default_scores = scores_3d[2]  # (5, 20) grid: prompt x question

    # Only (prompt, question) pairs where both scores are present are counted
    both_present = ~(np.isnan(pos_scores) | np.isnan(default_scores))

    # Pairs with pos above 50 and default below 50
    high_pos_low_default = both_present & (pos_scores > 50) & (default_scores < 50)

    # Pairs whose scores differ by more than 40 in either direction
    large_diff = both_present & (np.abs(pos_scores - default_scores) > 40)

    pos_default_stats[trait] = {
        # NaN-aware mean of the element-wise pos - default difference
        "pos_minus_default_mean": np.nanmean(pos_scores - default_scores),
        "high_pos_low_default_count": int(high_pos_low_default.sum()),
        "large_diff_count": int(large_diff.sum())
    }

# Show example statistics for first trait (floats to 2 decimals, counts as-is)
example_trait = list(pos_default_stats.keys())[0]
print(f"Example pos-default statistics for '{example_trait}':")
for key, value in pos_default_stats[example_trait].items():
    rendered = f"{value:.2f}" if isinstance(value, float) else f"{value}"
    print(f"  {key}: {rendered}")

print(f"\nCalculated pos-default statistics for {len(pos_default_stats)} traits")

# Distribution of the two counts across traits
high_pos_counts = [entry["high_pos_low_default_count"] for entry in pos_default_stats.values()]
large_diff_counts = [entry["large_diff_count"] for entry in pos_default_stats.values()]
print(f"\nHigh pos, low default count distribution: min={min(high_pos_counts)}, max={max(high_pos_counts)}, mean={np.mean(high_pos_counts):.1f}")
print(f"Large diff count distribution: min={min(large_diff_counts)}, max={max(large_diff_counts)}, mean={np.mean(large_diff_counts):.1f}")

# Export one row per trait to CSV
pos_default_df = pd.DataFrame.from_dict(pos_default_stats, orient='index')
pos_default_df.index.name = 'trait'
pos_default_df.to_csv('./results/pos_default.csv')
print("\nExported pos-default statistics to pos_default.csv")
print(f"Shape: {pos_default_df.shape}")

Example pos-default statistics for 'absolutist':
  pos_minus_default_mean: 22.75
  high_pos_low_default_count: 25
  large_diff_count: 25

Calculated pos-default statistics for 240 traits

High pos, low default count distribution: min=0, max=100, mean=47.1
Large diff count distribution: min=0, max=100, mean=46.8

Exported pos-default statistics to pos_default.csv
Shape: (240, 3)


## PCA

In [6]:
# Load every trait's saved vectors from data/vectors.
vector_dir = "/root/git/persona-subspace/traits/data/vectors"

# One .pt file per trait; the file name without ".pt" is the trait name.
# os.listdir returns entries in arbitrary order, so sort them: the order of
# `vectors` (and of every list/stack built from it) is then reproducible.
# NOTE: torch.load unpickles the file contents; only load trusted files.
vectors = {}
for file in sorted(os.listdir(vector_dir)):
    if file.endswith(".pt"):
        # Strip only the trailing extension (str.replace would also remove a
        # ".pt" occurring in the middle of a file name).
        vectors[file[: -len(".pt")]] = torch.load(os.path.join(vector_dir, file))

print(f"Found {len(vectors)} traits with vectors")

Found 240 traits with vectors


In [17]:
# Sanity check: shape of one trait's 'pos_neg' tensor
# (presumably (n_layers, hidden_dim) — TODO confirm against the extraction code).
vectors['zealous']['pos_neg'].shape

torch.Size([46, 4608])

In [7]:
# Layer index passed to compute_pca below
layer = 34

# Gather each stored vector variant across all traits, in `vectors` order
pos_neg = [per_trait['pos_neg'] for per_trait in vectors.values()]
pos_neg_50 = [per_trait['pos_neg_50'] for per_trait in vectors.values()]
pos_default = [per_trait['pos_default'] for per_trait in vectors.values()]
pos_default_50 = [per_trait['pos_default_50'] for per_trait in vectors.values()]

# Number of traits collected
print(len(pos_neg))


240


In [8]:
def compute_pca(activation_list, layer):
    """
    Standardize one layer's vectors, fit a full PCA on them, and print a
    short explained-variance summary.

    Parameters:
    - activation_list: array/tensor sliced as activation_list[:, layer, :]
    - layer: index selected along the second axis

    Returns:
    - (pca_transformed, variance_explained, n_components, pca, scaler), where
      pca and scaler are the fitted sklearn objects
    """
    scaler = StandardScaler()
    pca = PCA()

    # Scale the selected layer's features, then project onto all components
    pca_transformed = pca.fit_transform(
        scaler.fit_transform(activation_list[:, layer, :])
    )

    variance_explained = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(variance_explained)
    n_components = len(variance_explained)

    print(f"PCA fitted with {n_components} components")
    print(f"Cumulative variance for first 5 components: {cumulative_variance[:5]}")

    # Elbow heuristic: position of the largest absolute second difference of
    # the explained-variance curve; +1 re-centres the index after the two diffs.
    curvature = np.abs(np.diff(np.diff(variance_explained)))
    elbow_point = np.argmax(curvature) + 1

    # First component count whose cumulative variance reaches each threshold
    # (argmax of a boolean array is the index of its first True).
    dims_needed = [
        (percent, np.argmax(cumulative_variance >= threshold) + 1)
        for percent, threshold in ((70, 0.70), (80, 0.80), (90, 0.90), (95, 0.95))
    ]

    print("\nPCA Analysis Results:")
    print(f"Elbow point at component: {elbow_point + 1}")
    for percent, dims in dims_needed:
        print(f"Dimensions for {percent}% variance: {dims}")

    return pca_transformed, variance_explained, n_components, pca, scaler

In [9]:
def plot_pca_cosine_similarity(pca_transformed, trait_labels, pc_component=0,
                               layer=None, reference_point=None, color_threshold=0.0):
    """
    Build a two-panel Plotly figure: a strip plot of per-trait values along one
    principal component (top) and a histogram of the same values (bottom).

    Most points show their label on hover only. The 10 lowest and 10 highest
    traits (20 in total) additionally get a visible label joined to the point
    by a vertical leader line, at staggered heights to reduce overlap.

    Parameters:
    - pca_transformed: PCA-transformed data (n_samples, n_components)
    - trait_labels: List of labels for each data point
    - pc_component: Which PC component to use (0-indexed, so PC1 = 0)
    - layer: Layer number shown in the figure subtitle
    - reference_point: Optional vector in PCA space. If given, the plotted
                      value is each sample's cosine similarity to it. If None,
                      the plotted value is the sample's score on the chosen PC
                      divided by the L2 norm of that score column (a rescaled
                      PC score rather than a per-sample cosine).
    - color_threshold: Values below this are drawn red, others blue (default: 0.0)

    Returns:
    - Plotly figure object
    """

    # Extract the specified PC component
    pc_values = pca_transformed[:, pc_component]

    # Value plotted on the x-axis for every trait
    if reference_point is None:
        # PC scores rescaled by the norm of the whole score column
        cosine_sims = pc_values / np.linalg.norm(pc_values)
    else:
        # Cosine similarity of each sample with the given reference vector
        cosine_sims = cosine_similarity(pca_transformed, reference_point.reshape(1, -1)).flatten()

    # Create colors based on threshold
    colors = ['red' if sim < color_threshold else 'blue' for sim in cosine_sims]

    # Identify extreme traits (10 lowest and 10 highest).
    # NOTE: with fewer than 20 points the two slices overlap, so a trait can be
    # labelled twice.
    sorted_indices = np.argsort(cosine_sims)
    low_extreme_indices = sorted_indices[:10]
    high_extreme_indices = sorted_indices[-10:]
    extreme_indices = set(list(low_extreme_indices) + list(high_extreme_indices))

    # Create subplot figure
    fig = sp.make_subplots(
        rows=2, cols=1,
        row_heights=[0.6, 0.4],
        vertical_spacing=0.1,
        subplot_titles=[
            f'PC{pc_component+1} Cosine Similarity',
            'Trait Frequency Distribution'
        ]
    )

    # Split points into regular and extreme for different display modes
    regular_x, regular_y, regular_colors, regular_labels = [], [], [], []
    extreme_x, extreme_y, extreme_colors, extreme_labels = [], [], [], []

    for i, (sim, color, label) in enumerate(zip(cosine_sims, colors, trait_labels)):
        if i in extreme_indices:
            extreme_x.append(sim)
            extreme_y.append(1)
            extreme_colors.append(color)
            extreme_labels.append(label)
        else:
            regular_x.append(sim)
            regular_y.append(1)
            regular_colors.append(color)
            regular_labels.append(label)

    # Add regular points (hover labels only)
    if regular_x:
        fig.add_trace(
            go.Scatter(
                x=regular_x,
                y=regular_y,
                mode='markers',
                marker=dict(
                    color=regular_colors,
                    size=8,
                    opacity=0.7
                ),
                text=regular_labels,
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>'
            ),
            row=1, col=1
        )

    # Add extreme points with visible labels and leader lines
    if extreme_x:
        fig.add_trace(
            go.Scatter(
                x=extreme_x,
                y=extreme_y,
                mode='markers',
                marker=dict(
                    color=extreme_colors,
                    size=8,
                    opacity=0.9
                ),
                text=extreme_labels,
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>'
            ),
            row=1, col=1
        )

        # Predefined label heights: points sit at y=1, labels alternate
        # above (high) and below (low) with some variation.
        high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42]
        low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58]

        # Interleave to high, low, high, low, ... (20 slots)
        all_y_positions = []
        for i in range(10):
            all_y_positions.extend([high_positions[i], low_positions[i]])

        def _add_leader_labels(point_indices, slot_offset):
            """Draw a leader line and boxed label for each given point."""
            for i, idx in enumerate(point_indices):
                x_pos = cosine_sims[idx]
                color = colors[idx]
                y_label = all_y_positions[i + slot_offset]

                # Leader line from the point (y=1) to the label height
                fig.add_trace(
                    go.Scatter(
                        x=[x_pos, x_pos],
                        y=[1.0, y_label],
                        mode='lines',
                        line=dict(color=color, width=1),
                        showlegend=False,
                        hoverinfo='skip'
                    ),
                    row=1, col=1
                )

                # Label at the end of the line
                fig.add_annotation(
                    x=x_pos,
                    y=y_label,
                    text=trait_labels[idx],
                    showarrow=False,
                    font=dict(size=10, color=color),
                    bgcolor="rgba(255, 255, 255, 0.9)",
                    bordercolor=color,
                    borderwidth=1,
                    row=1, col=1
                )

        # Low extremes use slots 0-9, high extremes continue with slots 10-19
        _add_leader_labels(low_extreme_indices, 0)
        _add_leader_labels(high_extreme_indices, 10)

    # Bottom panel: Histogram.
    # This must be added BEFORE the x=0 reference lines: add_vline skips
    # subplots that have no data yet (exclude_empty_subplots defaults to True),
    # so drawing the row-2 line first would silently drop it.
    fig.add_trace(
        go.Histogram(
            x=cosine_sims,
            nbinsx=30,
            opacity=0.7,
            marker_color='steelblue',
            showlegend=False
        ),
        row=2, col=1
    )

    # Add vertical line at x=0 for both panels
    for panel_row in (1, 2):
        fig.add_vline(
            x=0,
            line_dash="solid",
            line_color="gray",
            line_width=1,
            opacity=0.7,
            row=panel_row, col=1
        )

    # Update layout
    fig.update_layout(
        height=700,
        title=dict(
            text="PCA on Trait Vectors from Mean Response Activations",
            subtitle={
                "text": f"Gemma 2 27B, Layer {layer}",
            },
            x=0.5,
            font=dict(size=16)
        ),
        showlegend=False
    )

    # Calculate symmetric range around 0 (not around data center)
    max_abs_value = max(abs(min(cosine_sims)), abs(max(cosine_sims)))
    x_half_width = max_abs_value * 1.1  # Add 10% padding

    # Update x-axes with symmetric ranges centered on 0
    fig.update_xaxes(
        row=1, col=1,
        range=[-x_half_width, x_half_width]
    )

    fig.update_xaxes(
        title_text=f"PC{pc_component+1} Cosine Similarity",
        row=2, col=1,
        range=[-x_half_width, x_half_width]
    )

    # Update y-axes
    fig.update_yaxes(
        title_text="",
        showticklabels=False,
        row=1, col=1,
        range=[0.25, 1.75]  # Range for varied label heights
    )

    fig.update_yaxes(
        title_text="Frequency",
        row=2, col=1
    )

    return fig

In [10]:
def plot_3d_pca(pca_transformed, variance_explained, trait_labels, layer):
    """
    Show a labelled 3D scatter of the first three principal components and
    write it to ./results/pca_3d.html.

    Parameters:
    - pca_transformed: PCA-transformed data; columns 0, 1 and 2 are plotted
    - variance_explained: per-component variance ratios, used in axis titles
    - trait_labels: one text label per point
    - layer: layer number shown in the subtitle

    Returns:
    - None (the figure is displayed and saved as a side effect)
    """
    # Hover text: trait name plus its three PC coordinates
    hover_template = (
        '<b>%{text}</b><br>'
        'PC1: %{x:.3f}<br>'
        'PC2: %{y:.3f}<br>'
        'PC3: %{z:.3f}<br>'
        '<extra></extra>'
    )

    points = go.Scatter3d(
        x=pca_transformed[:, 0],
        y=pca_transformed[:, 1],
        z=pca_transformed[:, 2],
        mode='markers+text',
        text=trait_labels,
        textposition='top center',
        textfont=dict(size=6),
        marker=dict(
            size=3,
            color=['blue'] * len(trait_labels),
            line=dict(width=2, color='black')
        ),
        hovertemplate=hover_template
    )
    fig_3d = go.Figure(data=[points])

    # Axis titles carry each component's share of explained variance
    axis_titles = [
        f'PC{k + 1} ({variance_explained[k]*100:.1f}%)' for k in range(3)
    ]

    fig_3d.update_layout(
        title={
            "text": 'Trait Vectors in Principal Component Space',
            "subtitle": {
                "text": f"Gemma 2 27B, Layer {layer}",
            },
        },
        scene=dict(
            xaxis_title=axis_titles[0],
            yaxis_title=axis_titles[1],
            zaxis_title=axis_titles[2]
        ),
        width=1000,
        height=800
    )

    fig_3d.show()
    fig_3d.write_html("./results/pca_3d.html")

### pos_neg_50

In [11]:
# Input for the pos_neg_50 PCA: keep only traits whose large_diff_count in
# `stats` is at least 10, then stack their 'pos_neg_50' vectors.
filtered_pos_neg_50_traits = [
    name for name in vectors if stats[name]['large_diff_count'] >= 10
]
print(len(filtered_pos_neg_50_traits))

# Stack in the same order as the trait names and cast to float32
filtered_pos_neg_50 = torch.stack(
    [vectors[name]['pos_neg_50'] for name in filtered_pos_neg_50_traits]
).float()
print(filtered_pos_neg_50.shape)

235
torch.Size([235, 46, 4608])


In [13]:
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(filtered_pos_neg_50, layer)

PCA fitted with 235 components
Cumulative variance for first 5 components: [0.14948702 0.26322171 0.34780798 0.41020005 0.4467624 ]

PCA Analysis Results:
Elbow point at component: 5
Dimensions for 70% variance: 20
Dimensions for 80% variance: 38
Dimensions for 90% variance: 76
Dimensions for 95% variance: 117


In [14]:
# Bundle the fitted PCA, the scaler, and the trait vectors/names they were
# fitted on, and persist everything in a single file.
results = {
    'layer': layer,
    'traits': filtered_pos_neg_50_traits,
    'vectors': filtered_pos_neg_50,
    'pca_transformed': pca_transformed,
    'variance_explained': variance_explained,
    'n_components': n_components,
    'pca': pca,
    'scaler': scaler,
}

torch.save(results, 'results/pca/pos_neg_50_layer34.pt')

## Project conversation trajectory

In [5]:
# Load the transcript; its 'conversation' entry is what the cells below pass
# around as `convo`. A context manager is used so the file handle is closed
# (json.load(open(...)) leaves it open until garbage collection).
with open('results/transcripts/long1.json') as transcript_file:
    convo = json.load(transcript_file)['conversation']

# Placeholder list; nothing in this cell fills it.
activations = []




In [10]:
# Load the model and its tokenizer through the project helper `load_model`
# (imported via utils; presumably a Hugging Face checkpoint id — confirm).
model, tokenizer = load_model("google/gemma-2-27b-it")

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [4]:
# Run the project helper over the whole conversation. The result is later used
# as a (layers, tokens, features) tensor — TODO confirm against the helper.
full_activations = extract_full_activations(model, tokenizer, convo)

In [None]:
# Persist the extracted activations so the extraction does not have to be re-run.
# NOTE(review): absolute /workspace path, unlike the relative 'results/...' paths used elsewhere.
torch.save(full_activations, '/workspace/traits/transcripts/long1.pt')

In [25]:
def mean_response_activation(activations, conversation, tokenizer):
    """
    Average the activations over every assistant-response token in the
    conversation.

    Args:
        activations: tensor indexed as [:, token_positions, :]
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer passed through to get_response_indices

    Returns:
        The selected activations averaged over dim 1 (the token axis).
    """
    token_positions = get_response_indices(conversation, tokenizer)
    response_activations = activations[:, token_positions, :]
    return response_activations.mean(dim=1)

def get_response_indices(conversation, tokenizer):
    """
    Collect the token positions of all assistant turns as one flat list.

    For each assistant turn the span runs from the token count of the
    templated conversation before the turn (rendered with the generation
    prompt) up to the token count of the templated conversation including
    the turn. An assistant turn with nothing before it starts at 0.

    Args:
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer to apply chat template and tokenize

    Returns:
        response_indices: list of token positions where the model is responding
    """
    def _templated_length(messages, with_generation_prompt):
        """Token count of the chat-templated rendering of `messages`."""
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=with_generation_prompt
        )
        # The rendered text is tokenized without adding special tokens again
        return len(tokenizer(rendered, add_special_tokens=False)['input_ids'])

    response_indices = []
    for position, message in enumerate(conversation):
        if message['role'] == 'assistant':
            preceding = conversation[:position]
            span_start = _templated_length(preceding, True) if preceding else 0
            span_end = _templated_length(conversation[:position + 1], False)
            response_indices.extend(range(span_start, span_end))

    return response_indices

In [24]:
def get_response_indices_per_turn(conversation, tokenizer):
    """
    Token positions of the model's responses, grouped by assistant turn.

    Each turn's span starts at the token count of the templated conversation
    before it (rendered with the generation prompt; 0 if nothing precedes the
    turn) and ends at the token count of the templated conversation that
    includes it.

    Args:
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer to apply chat template and tokenize

    Returns:
        response_indices_per_turn: List[List[int]] - one list of token
        positions per assistant turn, in conversation order
    """
    spans = []

    for turn_number, turn in enumerate(conversation):
        # Only assistant turns contribute a span
        if turn['role'] == 'assistant':
            history = conversation[:turn_number]
            history_plus_turn = conversation[:turn_number + 1]

            # Token count of everything before this turn
            start = 0
            if history:
                history_text = tokenizer.apply_chat_template(
                    history, tokenize=False, add_generation_prompt=True
                )
                start = len(tokenizer(history_text, add_special_tokens=False)['input_ids'])

            # Token count once this turn has been appended
            full_text = tokenizer.apply_chat_template(
                history_plus_turn, tokenize=False, add_generation_prompt=False
            )
            stop = len(tokenizer(full_text, add_special_tokens=False)['input_ids'])

            spans.append(list(range(start, stop)))

    return spans

def mean_response_activation_per_turn(activations, conversation, tokenizer):
    """
    Average the activations over the tokens of each assistant turn separately.

    Args:
        activations: Tensor with shape (layers, tokens, features)
        conversation: List of dict with 'role' and 'content' keys
        tokenizer: Tokenizer to apply chat template and tokenize

    Returns:
        List[torch.Tensor]: mean over the token axis for each assistant turn.
        Turns with an empty token span are skipped, so the list can be shorter
        than the number of assistant turns.
    """
    return [
        activations[:, turn_positions, :].mean(dim=1)
        for turn_positions in get_response_indices_per_turn(conversation, tokenizer)
        if len(turn_positions) > 0
    ]

In [26]:
# Sanity-check the per-turn means against the transcript: the number of
# returned activations should equal the number of assistant turns.
activations_per_turn = mean_response_activation_per_turn(full_activations, convo, tokenizer)

first_shape = activations_per_turn[0].shape if activations_per_turn else 'No activations'
expected_turns = sum(turn['role'] == 'assistant' for turn in convo)

print(f"Number of assistant turns: {len(activations_per_turn)}")
print(f"Shape of each activation: {first_shape}")
print(f"Expected number of assistant turns: {expected_turns}")


Number of assistant turns: 16
Shape of each activation: torch.Size([46, 4608])
Expected number of assistant turns: 16


In [6]:
# Load the saved PCA bundle (fitted PCA, scaler, trait names and vectors).
# weights_only=False performs a full unpickle, so only load a trusted file;
# the printed keys include 'pca' and 'scaler' objects, not just tensors.
# NOTE(review): the save cell writes 'results/pca/pos_neg_50_layer34.pt' — confirm this absolute path is the same file.

pca_results = torch.load('/workspace/traits/pca/pos_neg_50_layer34.pt', weights_only=False)
print(pca_results.keys())

dict_keys(['layer', 'traits', 'vectors', 'pca_transformed', 'variance_explained', 'n_components', 'pca', 'scaler'])


In [7]:
# Pull the fitted scaler and PCA out of the loaded bundle; both are reused
# below to project conversation-turn activations.
scaler = pca_results['scaler']
pca = pca_results['pca']


In [29]:
# Project each assistant turn's mean activation into the fitted PCA space
activations_per_turn_pca = []
layer = 34

for turn_activation in activations_per_turn:
    # Take the row for `layer` and shape it as a single sample (1, n_features)
    layer_row = turn_activation[layer, :].float().numpy().reshape(1, -1)

    # Same preprocessing as the trait vectors: standardize, then project
    projected = pca.transform(scaler.transform(layer_row))

    # Keep the projection as a 1D array
    activations_per_turn_pca.append(projected[0])

first_turn_coords = ", ".join(f"{coord:.3f}" for coord in activations_per_turn_pca[0][:3])
last_turn_coords = ", ".join(f"{coord:.3f}" for coord in activations_per_turn_pca[-1][:3])

print(f"Projected {len(activations_per_turn_pca)} turns into PCA space")
print(f"Each projection shape: {activations_per_turn_pca[0].shape}")
print(f"First 3 PC coordinates for turn 0: [{first_turn_coords}]")
print(f"First 3 PC coordinates for turn -1: [{last_turn_coords}]")

Projected 16 turns into PCA space
Each projection shape: (235,)
First 3 PC coordinates for turn 0: [19.703, -0.391, 2.336]
First 3 PC coordinates for turn -1: [13.812, 2.697, 0.617]


In [25]:
# Cosine similarity between each turn's PCA-space projection and the first 5 PC axes
# (these two imports repeat the notebook's top import cell)
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Rows of pca.components_ for the first 5 PCs; the table below does not use them
pc_directions = pca.components_[:5]

print("Cosine similarity with first 5 PCs for each conversation turn:")
print("Turn | PC1     | PC2     | PC3     | PC4     | PC5")
print("-" * 50)

for turn_number, projection in enumerate(activations_per_turn_pca):
    projection_row = projection.reshape(1, -1)

    # In PCA coordinates the k-th PC is the k-th standard basis vector
    pc_axes = np.zeros((5, len(projection)))
    for k in range(5):
        pc_axes[k, k] = 1.0

    # One call returns the similarity with all five axes
    similarities = cosine_similarity(projection_row, pc_axes)[0]
    row_cells = " | ".join(f"{value:7.3f}" for value in similarities)
    print(f"{turn_number:2d}   | {row_cells}")

Cosine similarity with first 5 PCs for each conversation turn:
Turn | PC1     | PC2     | PC3     | PC4     | PC5
--------------------------------------------------
 0   |   0.269 |  -0.005 |   0.032 |  -0.110 |  -0.083
 1   |   0.305 |   0.081 |  -0.044 |  -0.070 |  -0.076
 2   |   0.237 |   0.115 |  -0.046 |  -0.101 |  -0.070
 3   |   0.309 |   0.080 |   0.004 |   0.020 |  -0.002
 4   |   0.279 |   0.067 |   0.083 |  -0.017 |  -0.062
 5   |   0.305 |   0.075 |   0.086 |  -0.036 |   0.007
 6   |   0.104 |   0.116 |   0.074 |  -0.059 |  -0.030
 7   |   0.046 |   0.115 |  -0.144 |  -0.289 |   0.029
 8   |  -0.042 |   0.110 |  -0.148 |  -0.316 |  -0.024
 9   |   0.133 |   0.122 |  -0.198 |  -0.208 |  -0.084
10   |   0.201 |   0.074 |  -0.099 |  -0.059 |  -0.036
11   |   0.187 |   0.077 |  -0.075 |  -0.018 |  -0.032
12   |   0.132 |   0.059 |  -0.156 |  -0.019 |   0.001
13   |   0.335 |   0.122 |  -0.075 |  -0.037 |  -0.009
14   |   0.334 |   0.073 |  -0.044 |   0.021 |  -0.002
15   |   0

In [77]:
def plot_conversation_trajectory_dashboard(pca_transformed, trait_labels, activations_per_turn_pca,
                                         variance_explained, layer=None, conversation=None):
    """
    Create an interactive dashboard showing conversation turns across PC1-PC5 with controls.

    Five panels each place every trait along one principal component and overlay a black
    marker for the currently selected conversation turn. The sixth panel shows the
    user/assistant text of that turn. A dropdown switches the selected turn.

    Parameters:
    - pca_transformed: PCA-transformed trait data (n_samples, n_components)
    - trait_labels: List of trait labels, aligned with the rows of pca_transformed
    - activations_per_turn_pca: List of PCA-projected conversation turns (one vector per turn)
    - variance_explained: Explained variance ratio from PCA (accepted for compatibility; not used)
    - layer: Layer number for the subtitle (left out of the subtitle when None)
    - conversation: List of {'role': ..., 'content': ...} messages for the text panel

    Returns:
    - Plotly figure object with interactive dropdown
    """
    import textwrap  # stdlib; imported here so the function does not rely on another cell

    pca_transformed = np.asarray(pca_transformed)
    n_turns = len(activations_per_turn_pca)

    # Grid cells of the five PC panels; the remaining cell (3, 2) is the text panel.
    panel_positions = [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1)]

    def wrap_text(text, width=45):
        """Collapse whitespace and wrap at word boundaries, joining lines with <br>."""
        lines = textwrap.wrap(" ".join(text.split()), width=width,
                              break_long_words=False, break_on_hyphens=False)
        return "<br>".join(lines)

    def clip(text, limit):
        """Truncate text to `limit` characters, marking the cut with an ellipsis."""
        return text[:limit] + "..." if len(text) > limit else text

    # Positions of assistant messages. Turn k is the k-th assistant message, looked up by
    # role so that a system prompt or any non-alternating layout does not shift the mapping.
    assistant_positions = [
        i for i, msg in enumerate(conversation or []) if msg['role'] == 'assistant'
    ]

    def get_turn_texts(turn_idx):
        """Get both user and assistant text for this turn (truncated for display)."""
        if not conversation:
            return "No conversation data", "No response available"

        user_text = "No user message"
        assistant_text = "No assistant response"

        if 0 <= turn_idx < len(assistant_positions):
            assistant_pos = assistant_positions[turn_idx]
            assistant_text = clip(conversation[assistant_pos]['content'], 800)
            user_pos = assistant_pos - 1  # user message just before the assistant reply
        elif turn_idx == len(assistant_positions):
            user_pos = len(conversation) - 1  # trailing user message with no reply yet
        else:
            user_pos = -1

        if user_pos >= 0 and conversation[user_pos]['role'] == 'user':
            user_text = clip(conversation[user_pos]['content'], 250)

        return user_text, assistant_text

    def info_annotation(turn_idx):
        """Annotation dict for the text panel showing the given turn."""
        user_text, assistant_text = get_turn_texts(turn_idx)
        return dict(
            x=0.5, y=0.5,
            text=f"<b>User:</b><br>{wrap_text(user_text, 100)}<br><br>" +
                 f"<b>Assistant:</b><br>{wrap_text(assistant_text, 100)}",
            showarrow=False,
            font=dict(size=10),
            bgcolor="rgba(240, 240, 240, 0.9)",
            bordercolor="gray",
            borderwidth=1,
            xref="x6", yref="y6",  # axes of the sixth subplot cell (3, 2)
            align="left",
            width=550,
            name="turn_info"  # lets the dropdown tell this annotation from the static ones
        )

    # Create subplot layout: 3x2 grid (5 PC panels + 1 info panel)
    titles = ["Emotional vs. Analytical", "Negative vs. Positive Affect", "Communication Style", "Mystical vs. Grounded", "Intellectual Style", "Response Text"]
    fig = sp.make_subplots(
        rows=3, cols=2,
        subplot_titles=titles,
        specs=[[{"type": "xy"}, {"type": "xy"}],
               [{"type": "xy"}, {"type": "xy"}],
               [{"type": "xy"}, {"type": "xy"}]],
        horizontal_spacing=0.08,
        vertical_spacing=0.12
    )

    def add_traits_to_panel(pc_component, row, col):
        """Add trait points, extreme-trait labels and axis styling to one PC panel."""
        pc_values = pca_transformed[:, pc_component]

        # Trait coordinates along this PC, scaled by the norm of the whole column
        cosine_sims = pc_values / np.linalg.norm(pc_values)
        colors = ['red' if sim < 0.0 else 'blue' for sim in cosine_sims]

        # The 10 lowest and 10 highest traits get a visible label
        sorted_indices = np.argsort(cosine_sims)
        low_extreme_indices = sorted_indices[:10]
        high_extreme_indices = sorted_indices[-10:]
        extreme_indices = set(low_extreme_indices.tolist()) | set(high_extreme_indices.tolist())

        # One marker trace for regular traits, then one for extreme traits (order matters:
        # the dropdown's visibility list is positional).
        points = list(zip(cosine_sims, colors, trait_labels))
        for want_extreme, opacity in ((False, 0.7), (True, 0.9)):
            subset = [p for i, p in enumerate(points) if (i in extreme_indices) == want_extreme]
            if not subset:
                continue
            fig.add_trace(
                go.Scatter(
                    x=[sim for sim, _, _ in subset],
                    y=[1] * len(subset),
                    mode='markers',
                    marker=dict(color=[color for _, color, _ in subset], size=8, opacity=opacity),
                    text=[label for _, _, label in subset],
                    showlegend=False,
                    hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>'
                ),
                row=row, col=col
            )

        # Staggered label heights: the high and low lists are interleaved, the first ten
        # entries go to the low extremes and the last ten to the high extremes.
        high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42]
        low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58]
        all_y_positions = [y for pair in zip(high_positions, low_positions) for y in pair]

        labelled = (list(zip(low_extreme_indices, all_y_positions[:10])) +
                    list(zip(high_extreme_indices, all_y_positions[10:])))
        for idx, y_label in labelled:
            x_pos = cosine_sims[idx]
            color = colors[idx]

            # Leader line from the point up/down to its label
            fig.add_trace(
                go.Scatter(
                    x=[x_pos, x_pos],
                    y=[1.0, y_label],
                    mode='lines',
                    line=dict(color=color, width=1),
                    showlegend=False,
                    hoverinfo='skip'
                ),
                row=row, col=col
            )

            # Bordered label
            fig.add_annotation(
                x=x_pos, y=y_label,
                text=trait_labels[idx],
                showarrow=False,
                font=dict(size=10, color=color),
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=color,
                borderwidth=1,
                row=row, col=col
            )

        # Vertical reference line at x=0
        fig.add_vline(x=0, line_dash="solid", line_color="gray",
                     line_width=1, opacity=0.7, row=row, col=col)

        # Symmetric x-range with 10% padding around the largest trait value
        x_half_width = float(np.max(np.abs(cosine_sims))) * 1.1

        fig.update_xaxes(
            title_text=f"PC{pc_component+1} Cosine Similarity",
            range=[-x_half_width, x_half_width],
            row=row, col=col
        )
        fig.update_yaxes(
            title_text="",
            showticklabels=False,
            range=[0.25, 1.75],  # leaves room for the staggered label heights
            row=row, col=col
        )

    # Add traits to all PC panels
    for pc_idx, (row, col) in enumerate(panel_positions):
        add_traits_to_panel(pc_idx, row, col)

    # Every trace added so far stays visible; the dropdown only toggles the traces added below.
    num_trait_traces = len(fig.data)

    # Response marker, dashed line and label per turn and panel (only turn 0 shown initially).
    # These are traces rather than annotations/vlines so the dropdown can toggle them.
    for turn_idx, turn_coords in enumerate(activations_per_turn_pca):
        common = dict(showlegend=False, visible=(turn_idx == 0))

        for pc_idx, (row, col) in enumerate(panel_positions):
            pc_values = pca_transformed[:, pc_idx]
            turn_pc_value = turn_coords[pc_idx]
            # NOTE(review): the turn is scaled by the norm of the trait column *plus itself*,
            # whereas traits are scaled by the trait column alone — confirm this is intended.
            turn_cosine_sim = turn_pc_value / np.linalg.norm(np.concatenate([pc_values, [turn_pc_value]]))

            response_traces = [
                go.Scatter(
                    x=[turn_cosine_sim],
                    y=[1],
                    mode='markers',
                    marker=dict(size=8, color='black', opacity=1.0),
                    hovertemplate=f'<b>Response Turn {turn_idx}</b><br>PC{pc_idx+1}: %{{x:.3f}}<extra></extra>',
                    name=f'response_turn_{turn_idx}_pc_{pc_idx}',
                    **common
                ),
                go.Scatter(
                    x=[turn_cosine_sim, turn_cosine_sim],
                    y=[0.25, 1.75],  # full height of the panel
                    mode='lines',
                    line=dict(dash='dash', color='black', width=1),
                    hoverinfo='skip',
                    name=f'response_line_{turn_idx}_pc_{pc_idx}',
                    **common
                ),
                go.Scatter(
                    x=[turn_cosine_sim],
                    y=[1.6],  # same height as the highest trait labels
                    mode='text',
                    text=['Response'],
                    textfont=dict(size=10, color='black'),
                    hoverinfo='skip',
                    name=f'response_label_{turn_idx}_pc_{pc_idx}',
                    **common
                ),
            ]
            for trace in response_traces:
                fig.add_trace(trace, row=row, col=col)

    # Panel 6: text of the selected turn (starts on turn 0), with its axes hidden
    fig.add_annotation(**info_annotation(0))
    fig.update_xaxes(visible=False, row=3, col=2)
    fig.update_yaxes(visible=False, row=3, col=2)

    # Dropdown buttons: each one shows exactly one turn's traces and swaps the text annotation
    static_annotations = [ann for ann in fig.layout.annotations if ann.name != "turn_info"]
    traces_per_turn = 3 * len(panel_positions)  # marker + line + label in every PC panel

    buttons = []
    for turn_idx in range(n_turns):
        visibility = [True] * num_trait_traces + [
            t_idx == turn_idx
            for t_idx in range(n_turns)
            for _ in range(traces_per_turn)
        ]
        buttons.append(dict(
            label=f"Turn {turn_idx}",
            method="update",
            args=[
                {"visible": visibility},
                {"annotations": static_annotations + [info_annotation(turn_idx)]}
            ]
        ))

    subtitle_text = "Gemma 2 27B" if layer is None else f"Gemma 2 27B, Layer {layer}"

    fig.update_layout(
        width=1400,
        height=1000,
        title=dict(
            text="Conversation Trajectory in Trait Space",
            subtitle={"text": subtitle_text},
            x=0.5, font=dict(size=16)
        ),
        updatemenus=[dict(
            type="dropdown",
            direction="down",
            showactive=True,
            x=0.1, y=1.06,
            buttons=buttons
        )],
        showlegend=False
    )

    return fig

In [79]:
# Test the updated interactive dashboard with wider text area
def get_turn_texts_expanded(turn_idx, conversation):
    """
    Get the user and assistant text for one conversation turn, truncated for display.

    Turn k is the k-th message whose role is 'assistant'; its user text is the message
    immediately before it, if that message has role 'user'. Looking turns up by role
    (instead of assuming index 2k+1) keeps the mapping correct when the conversation
    starts with a system prompt or does not strictly alternate.

    Parameters:
    - turn_idx: Zero-based assistant turn index
    - conversation: List of {'role': ..., 'content': ...} messages

    Returns:
    - (user_text, assistant_text): user text cut to 250 characters and assistant text
      cut to 800 characters (with "..." appended when cut); placeholder strings are
      returned for whatever is missing
    """
    if not conversation:
        return "No conversation data", "No response available"

    def _clip(text, limit):
        return text[:limit] + "..." if len(text) > limit else text

    user_text = "No user message"
    assistant_text = "No assistant response"

    assistant_positions = [
        i for i, msg in enumerate(conversation) if msg['role'] == 'assistant'
    ]

    if 0 <= turn_idx < len(assistant_positions):
        assistant_pos = assistant_positions[turn_idx]
        assistant_text = _clip(conversation[assistant_pos]['content'], 800)
        user_pos = assistant_pos - 1  # user message just before the assistant reply
    elif turn_idx == len(assistant_positions):
        user_pos = len(conversation) - 1  # trailing user message with no reply yet
    else:
        user_pos = -1

    if user_pos >= 0 and conversation[user_pos]['role'] == 'user':
        user_text = _clip(conversation[user_pos]['content'], 250)

    return user_text, assistant_text

# Build the dashboard for this conversation, display it, and save it as standalone HTML
fig_dashboard = plot_conversation_trajectory_dashboard(
    pca_transformed=pca_results['pca_transformed'],
    trait_labels=pca_results['traits'],
    activations_per_turn_pca=activations_per_turn_pca,
    variance_explained=pca_results['variance_explained'],
    layer=layer,
    conversation=convo
)

fig_dashboard.show()

# Single source of truth for the output path so the log message cannot drift from it
dashboard_html_path = "./results/pca/long1_trajectory.html"
os.makedirs(os.path.dirname(dashboard_html_path), exist_ok=True)  # make sure the output directory exists
fig_dashboard.write_html(dashboard_html_path)

print(f"Updated dashboard created and saved to {dashboard_html_path}")

Updated dashboard created and saved to ./results/long1_trajectory.html


## Plotting

In [8]:
def project_per_token_activations_to_pca(activations, pca, scaler, layer,
                                        conversation=None, tokenizer=None):
    """
    Project per-token activations into PCA space using fitted PCA from trait vectors.

    Parameters:
    - activations: Raw activations tensor (layers, tokens, features)
    - pca: Fitted PCA object from trait vector analysis (only .transform() is used)
    - scaler: Fitted StandardScaler from trait vector analysis (only .transform() is used)
    - layer: Target layer to extract (e.g., 34)
    - conversation: Optional conversation data for token role mapping
    - tokenizer: Optional tokenizer for token boundaries and decoding

    Returns:
    - per_token_pca: Array of PCA coordinates for each token (tokens, n_components)
    - token_metadata: Dict with 'layer', 'n_tokens', 'n_components'; when both
      conversation and tokenizer are given it also holds 'roles' and the role counts,
      plus 'sample_tokens' if decoding succeeds
    """

    print(f"Projecting per-token activations from layer {layer} into PCA space...")

    # Extract the requested layer. detach()/cpu() make the numpy conversion work for
    # tensors that live on an accelerator or track gradients; float() handles half/bfloat16.
    layer_activations = activations[layer, :, :].detach().cpu().float().numpy()  # Shape: (tokens, features)
    print(f"Extracted layer {layer} activations with shape: {layer_activations.shape}")

    # Apply the same preprocessing (standardization) used for trait vectors
    scaled_activations = scaler.transform(layer_activations)
    print(f"Applied standardization, scaled shape: {scaled_activations.shape}")

    # Project each token into PCA space
    per_token_pca = pca.transform(scaled_activations)
    print(f"Projected to PCA space, shape: {per_token_pca.shape}")

    n_tokens, n_components = per_token_pca.shape[0], per_token_pca.shape[1]
    token_metadata = {
        'layer': layer,
        'n_tokens': n_tokens,
        'n_components': n_components
    }

    # If conversation and tokenizer provided, extract token role information
    if conversation is not None and tokenizer is not None:
        print("Extracting token role metadata...")

        # Token indices of each assistant response, one list per turn
        response_indices_per_turn = get_response_indices_per_turn(conversation, tokenizer)

        # Every token not inside an assistant response is labelled 'user'
        token_roles = ['user'] * n_tokens
        for turn_indices in response_indices_per_turn:
            for token_idx in turn_indices:
                if token_idx < n_tokens:  # ignore indices past the stored activations
                    token_roles[token_idx] = 'assistant'

        token_metadata['roles'] = token_roles
        token_metadata['assistant_token_count'] = token_roles.count('assistant')
        token_metadata['user_token_count'] = token_roles.count('user')

        print(f"  Assistant tokens: {token_metadata['assistant_token_count']}")
        print(f"  User tokens: {token_metadata['user_token_count']}")

        # Best-effort: decode a few tokens for inspection; failure here must not abort the projection
        try:
            formatted_text = tokenizer.apply_chat_template(conversation, tokenize=False)
            tokens = tokenizer(formatted_text, add_special_tokens=False)['input_ids']
            decoded_tokens = [tokenizer.decode([token_id]) for token_id in tokens]

            token_metadata['sample_tokens'] = {
                'first_5': decoded_tokens[:5],
                'last_5': decoded_tokens[-5:]
            }
        except Exception as e:
            print(f"Warning: Could not decode tokens: {e}")

    # Print summary statistics
    print(f"\nPer-token PCA projection summary:")
    print(f"  Total tokens projected: {n_tokens}")
    print(f"  PCA components: {n_components}")
    if n_tokens > 0:
        # Report at most the first three components; a smaller PCA simply prints fewer lines
        for pc_idx in range(min(3, n_components)):
            pc_column = per_token_pca[:, pc_idx]
            print(f"  PC{pc_idx + 1} range: [{pc_column.min():.3f}, {pc_column.max():.3f}]")

    return per_token_pca, token_metadata

In [14]:
# Load the stored activations for the long1 transcript and project layer 34 of every
# token into the trait PCA space (role metadata is collected as well, since both the
# conversation and the tokenizer are passed).
full_activations = torch.load('/workspace/traits/transcripts/long1.pt')

per_token_pca, metadata = project_per_token_activations_to_pca(
    activations=full_activations,
    pca=pca_results['pca'],
    scaler=pca_results['scaler'],
    layer=34,
    conversation=convo,
    tokenizer=tokenizer,
)

Projecting per-token activations from layer 34 into PCA space...
Extracted layer 34 activations with shape: (4096, 4608)
Applied standardization, scaled shape: (4096, 4608)
Projected to PCA space, shape: (4096, 235)
Extracting token role metadata...
  Assistant tokens: 3578
  User tokens: 518

Per-token PCA projection summary:
  Total tokens projected: 4096
  PCA components: 235
  PC1 range: [-19.783, 54.732]
  PC2 range: [-133.955, 33.855]
  PC3 range: [-36.503, 110.441]


In [68]:
def plot_per_token_pca_trajectory(per_token_pca, token_metadata, conversation, tokenizer,
                                  pca, layer=None, title=None):
    """
    Create a single line plot showing per-token cosine similarity with the first 5 PCs.

    Parameters:
    - per_token_pca: Per-token PCA coordinates from project_per_token_activations_to_pca()
    - token_metadata: Metadata dict from project_per_token_activations_to_pca() (not used here)
    - conversation: Conversation data for turn boundary calculation
    - tokenizer: Tokenizer for token decoding
    - pca: Fitted PCA object (not used here; similarities are computed in PCA coordinates)
    - layer: Layer number for the subtitle (left out of the subtitle when None)
    - title: Optional custom title

    Returns:
    - Plotly figure object
    """

    print("Creating per-token PCA trajectory plot...")

    titles = ["Emotional vs. Analytical", "Negative vs. Positive Affect", "Communication Style", "Mystical vs. Grounded", "Intellectual Style"]
    pc_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # Plotly default colors

    per_token_pca = np.asarray(per_token_pca)
    n_tokens = per_token_pca.shape[0]
    n_pcs = min(len(titles), per_token_pca.shape[1])
    token_indices = np.arange(n_tokens)

    # In PCA coordinates, PC k is the k-th standard basis vector, so the cosine of a token
    # with PC k is coordinate k divided by the token's norm. Done for all tokens at once;
    # an all-zero token gets similarity 0 instead of a division by zero.
    norms = np.linalg.norm(per_token_pca, axis=1)
    safe_norms = np.where(norms > 0, norms, 1.0)
    normalized = per_token_pca[:, :n_pcs] / safe_norms[:, None]

    cosine_sims = {
        f'PC{pc_idx+1}: {titles[pc_idx]}': normalized[:, pc_idx]
        for pc_idx in range(n_pcs)
    }

    print(f"Calculated cosine similarities for {n_tokens} tokens with first 5 PCs")

    # Token indices of each assistant response, one entry per assistant turn
    response_indices_per_turn = get_response_indices_per_turn(conversation, tokenizer)

    # Turn boundaries: the first token of each assistant response (end of the user turn)
    # and the token right after its last one (end of the assistant turn).
    turn_boundaries = []
    assistants_seen = 0  # assistant messages before the current one

    for i, turn in enumerate(conversation):
        role = turn['role']
        if role == 'user':
            followed_by_assistant = (i + 1 < len(conversation)
                                     and conversation[i + 1]['role'] == 'assistant')
            if followed_by_assistant and assistants_seen < len(response_indices_per_turn):
                upcoming_indices = response_indices_per_turn[assistants_seen]
                if len(upcoming_indices) > 0:  # an empty response has no start token
                    turn_boundaries.append(upcoming_indices[0])
        elif role == 'assistant':
            if assistants_seen < len(response_indices_per_turn):
                assistant_indices = response_indices_per_turn[assistants_seen]
                if len(assistant_indices) > 0:
                    assistant_end = assistant_indices[-1] + 1  # +1 for end boundary
                    if assistant_end < n_tokens:
                        turn_boundaries.append(assistant_end)
            assistants_seen += 1

    print(f"Found {len(turn_boundaries)} turn boundaries at positions: {turn_boundaries[:5]}...")

    # Decoded tokens for hover text (best-effort; falls back to generic names)
    try:
        formatted_text = tokenizer.apply_chat_template(conversation, tokenize=False)
        tokens = tokenizer(formatted_text, add_special_tokens=False)['input_ids']
        decoded_tokens = [tokenizer.decode([token_id]) for token_id in tokens]
        print(f"Decoded {len(decoded_tokens)} tokens for hover text")
    except Exception as e:
        print(f"Warning: Could not decode tokens: {e}")
        decoded_tokens = [f"Token_{i}" for i in range(n_tokens)]

    # Hover text is the same for every PC line, so build it once
    hover_texts = []
    for token_idx in range(n_tokens):
        if token_idx < len(decoded_tokens):
            token_text = decoded_tokens[token_idx].replace('\n', '\\n').replace('\t', '\\t')
            # Truncate long tokens
            if len(token_text) > 20:
                token_text = token_text[:17] + "..."
            hover_texts.append(f"Token {token_idx}: '{token_text}'")
        else:
            hover_texts.append(f"Token {token_idx}")

    fig = go.Figure()

    # One line per PC
    for pc_idx, (pc_name, similarities) in enumerate(cosine_sims.items()):
        fig.add_trace(go.Scatter(
            x=token_indices,
            y=similarities,
            mode='lines',
            name=pc_name,
            line=dict(color=pc_colors[pc_idx], width=0.8),
            hovertemplate='<b>%{fullData.name}</b><br>' +
                         '%{text}<br>' +
                         'Cosine Similarity: %{y:.3f}<br>' +
                         '<extra></extra>',
            text=hover_texts
        ))

    # Vertical dashed lines for turn boundaries
    for boundary_pos in turn_boundaries:
        fig.add_vline(
            x=boundary_pos,
            line_dash="dash",
            line_color="gray",
            line_width=1,
            opacity=0.6
        )

    default_title = "Per-Token PC Trajectory"
    subtitle_text = "Gemma 2 27B" if layer is None else f"Gemma 2 27B, Layer {layer}"

    fig.update_layout(
        title=dict(
            text=title if title else default_title,
            x=0.5,
            font=dict(size=16),
            subtitle={"text": subtitle_text}
        ),
        xaxis_title="Token Index",
        yaxis_title="Cosine Similarity with PC",
        width=2000,
        height=600,
        hovermode='closest',
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.8)"
        )
    )

    # Grid for easier reading
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)

    print(f"Created trajectory plot with {len(cosine_sims)} PC lines and {len(turn_boundaries)} turn boundaries")

    return fig

In [66]:
def plot_mean_response_trajectory(activations_per_turn_pca, pca, layer=None, 
                                  conversation=None, title=None):
    """
    Plot, for every conversation turn, the cosine similarity between the turn's
    PCA coordinates and each of the leading principal-component axes.
    
    Parameters:
    - activations_per_turn_pca: Sequence of per-turn PCA coordinate vectors (one per
      turn); each entry is flattened to a single row before comparison
    - pca: Fitted PCA object. Kept for interface compatibility; it is not read here
      because the inputs are already expressed in PCA coordinates
    - layer: Layer number shown in the subtitle
    - conversation: Optional list of {'role', 'content'} dicts used for hover text
    - title: Optional custom title
    
    Returns:
    - Plotly figure object with one line per PC (at most 5, fewer if the PCA space
      has fewer components)
    
    Raises:
    - ValueError: if activations_per_turn_pca is empty
    """
    
    print("Creating mean response trajectory plot...")
    
    n_turns = len(activations_per_turn_pca)
    if n_turns == 0:
        raise ValueError("activations_per_turn_pca is empty; there are no turns to plot")
    turn_indices = np.arange(n_turns)
    
    titles = ["Emotional vs. Analytical", "Negative vs. Positive Affect", "Communication Style", "Mystical vs. Grounded", "Intellectual Style"]
    
    # Stack the turns into one (n_turns, n_components) matrix so that all
    # similarities come from a single pairwise call instead of one call per turn/PC.
    coords = np.vstack([np.asarray(turn_coords).reshape(1, -1)
                        for turn_coords in activations_per_turn_pca])
    n_components = coords.shape[1]
    n_pcs = min(len(titles), n_components)  # never index past the available PCs
    
    # Row k is the unit vector along PC k (in PCA space)
    pc_directions = np.zeros((n_pcs, n_components))
    pc_directions[np.arange(n_pcs), np.arange(n_pcs)] = 1.0
    
    sim_matrix = cosine_similarity(coords, pc_directions)  # (n_turns, n_pcs)
    cosine_sims = {
        f'PC{pc_idx+1}: {titles[pc_idx]}': sim_matrix[:, pc_idx].tolist()
        for pc_idx in range(n_pcs)
    }
    
    print(f"Calculated cosine similarities for {n_turns} turns with first {n_pcs} PCs")
    
    def wrap_text(text, max_chars_per_line=70):
        """Wrap text on word boundaries, joining lines with HTML <br> tags."""
        if len(text) <= max_chars_per_line:
            return text
        
        lines = []
        current_line = []
        current_length = 0  # total characters of the words in current_line (no spaces)
        
        for word in text.split():
            # len(current_line) accounts for the separating spaces once joined
            if current_length + len(word) + len(current_line) > max_chars_per_line:
                if current_line:  # Don't add empty lines
                    lines.append(' '.join(current_line))
                current_line = [word]
                current_length = len(word)
            else:
                current_line.append(word)
                current_length += len(word)
        
        if current_line:
            lines.append(' '.join(current_line))
        
        return '<br>'.join(lines)
    
    def truncate(text, limit):
        """Cut text to `limit` characters, adding an ellipsis only when something was cut."""
        return text[:limit] + "..." if len(text) > limit else text
    
    # Hover text: default to a bare turn label, then enrich the turns for which
    # the conversation provides an assistant message.
    turn_contexts = [f"<b>Turn {turn_idx}</b>" for turn_idx in range(n_turns)]
    if conversation is not None:
        assistant_positions = [i for i, turn in enumerate(conversation) if turn['role'] == 'assistant']
        for turn_idx, conv_turn_idx in enumerate(assistant_positions[:n_turns]):
            hover_parts = [f"<b>Turn {turn_idx}</b>"]
            
            # Include the user message only when it directly precedes the response
            if conv_turn_idx > 0 and conversation[conv_turn_idx - 1]['role'] == 'user':
                user_content = conversation[conv_turn_idx - 1]['content']
                if user_content:
                    user_wrapped = wrap_text(truncate(user_content, 200), 70)
                    hover_parts.append(f"<b>User:</b> {user_wrapped}")
            
            # Show up to 300 characters of the response. The length test uses the
            # same limit as the slice so that no ellipsis is added to uncut text.
            assistant_content = conversation[conv_turn_idx]['content']
            assistant_wrapped = wrap_text(truncate(assistant_content, 300), 70)
            hover_parts.append(f"<b>Assistant:</b> {assistant_wrapped}")
            
            turn_contexts[turn_idx] = '<br>'.join(hover_parts)
    
    # Create Plotly figure
    fig = go.Figure()
    
    # Color scheme for the 5 PCs (same as per-token plot)
    pc_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    
    # Add line traces for each PC with markers
    for pc_idx, (pc_name, similarities) in enumerate(cosine_sims.items()):
        fig.add_trace(go.Scatter(
            x=turn_indices,
            y=similarities,
            mode='lines+markers',
            name=pc_name,
            line=dict(color=pc_colors[pc_idx], width=2),
            marker=dict(color=pc_colors[pc_idx], size=4, opacity=0.8),
            hovertemplate='<b>%{fullData.name}</b><br>' +
                         '%{text}<br>' +
                         '<b>Cosine Similarity:</b> %{y:.3f}<br>' +
                         '<extra></extra>',
            text=turn_contexts
        ))
    
    # Update layout
    default_title = "Mean Response PC Trajectory"
    
    fig.update_layout(
        title=dict(
            text=title if title else default_title,
            x=0.5,
            font=dict(size=16),
            subtitle={"text": f"Gemma 2 27B, Layer {layer}"}
        ),
        xaxis_title="Conversation Turn",
        yaxis_title="Cosine Similarity with PC",
        width=1600,
        height=600,
        hovermode='closest',
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.8)"
        )
    )
    
    # Add grid for easier reading
    fig.update_xaxes(
        showgrid=True, 
        gridwidth=1, 
        gridcolor='lightgray',
        zeroline=True,
        tick0=0,
        dtick=1  # Show every turn
    )
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)
    
    # Add light vertical lines between turns for clarity
    for turn_idx in range(1, n_turns):
        fig.add_vline(
            x=turn_idx - 0.5,
            line_dash="dot",
            line_color="lightgray",
            line_width=1,
            opacity=0.3
        )
    
    print(f"Created trajectory plot with {len(cosine_sims)} PC lines across {n_turns} turns")
    
    return fig

In [69]:
# Render the full-width per-token trajectory for the loaded conversation.
fig = plot_per_token_pca_trajectory(
    per_token_pca=per_token_pca,
    token_metadata=metadata,
    conversation=convo,
    tokenizer=tokenizer,
    pca=pca_results['pca'],
    layer=34
)

fig.show()
# write_html does not create missing parent directories, so make sure the
# output folder exists before saving (needed on a fresh checkout).
os.makedirs("./results/pca", exist_ok=True)
fig.write_html("./results/pca/long1_token_trajectory.html")

Creating per-token PCA trajectory plot...
Calculated cosine similarities for 4096 tokens with first 5 PCs
Found 31 turn boundaries at positions: [34, 313, 350, 627, 666]...
Decoded 4096 tokens for hover text
Created trajectory plot with 5 PC lines and 31 turn boundaries


In [67]:
# Render the per-turn (mean response) trajectory for the same conversation.
fig_mean = plot_mean_response_trajectory(
    activations_per_turn_pca=activations_per_turn_pca,
    pca=pca_results['pca'],
    layer=34,
    conversation=convo
)

fig_mean.show()
# write_html does not create missing parent directories, so make sure the
# output folder exists before saving (needed on a fresh checkout).
os.makedirs("./results/pca", exist_ok=True)
fig_mean.write_html("./results/pca/long1_response_trajectory.html")

Creating mean response trajectory plot...
Calculated cosine similarities for 16 turns with first 5 PCs
Created trajectory plot with 5 PC lines across 16 turns


In [80]:
def plot_per_token_pca_trajectory_scrollable(per_token_pca, token_metadata, conversation, tokenizer, 
                                           pca, layer=None, title=None, initial_window=500, 
                                           initial_start=0):
    """
    Create a scrollable line plot showing per-token cosine similarity with the leading PCs.
    
    Parameters:
    - per_token_pca: 2D array of per-token PCA coordinates (tokens x components)
    - token_metadata: Kept for interface compatibility; not read by this function
    - conversation: List of {'role', 'content'} dicts, used for turn boundaries and hover text
    - tokenizer: Tokenizer used to locate response tokens and decode tokens for hover text
    - pca: Fitted PCA object. Kept for interface compatibility; not read by this function
    - layer: Layer number for title
    - title: Optional custom title
    - initial_window: Initial window size to display (default 500 tokens)
    - initial_start: Starting token index for initial view (default 0)
    
    Returns:
    - Plotly figure object with a range slider and window-size buttons
    """
    
    print("Creating scrollable per-token PCA trajectory plot...")
    
    per_token_pca = np.asarray(per_token_pca)
    n_tokens, n_components = per_token_pca.shape
    token_indices = np.arange(n_tokens)
    
    titles = ["Emotional vs. Analytical", "Negative vs. Positive Affect", "Communication Style", "Mystical vs. Grounded", "Intellectual Style"]
    n_pcs = min(len(titles), n_components)  # never index past the available PCs
    
    # Row k is the unit vector along PC k (in PCA space). One pairwise call
    # replaces a separate similarity call for every token/PC pair.
    pc_directions = np.zeros((n_pcs, n_components))
    pc_directions[np.arange(n_pcs), np.arange(n_pcs)] = 1.0
    sim_matrix = cosine_similarity(per_token_pca, pc_directions)  # (n_tokens, n_pcs)
    
    cosine_sims = {
        f'PC{pc_idx+1}: {titles[pc_idx]}': sim_matrix[:, pc_idx].tolist()
        for pc_idx in range(n_pcs)
    }
    
    print(f"Calculated cosine similarities for {n_tokens} tokens with first {n_pcs} PCs")
    
    # Get turn boundaries
    response_indices_per_turn = get_response_indices_per_turn(conversation, tokenizer)
    
    # Each boundary is stored with the label of the segment that STARTS there,
    # so labels do not depend on the parity of the boundary's list position.
    boundaries = []
    assistant_seen = 0  # number of assistant turns before the current turn
    
    for i, turn in enumerate(conversation):
        has_indices = assistant_seen < len(response_indices_per_turn)
        if turn['role'] == 'user':
            # The user segment ends where the following assistant response starts
            next_is_assistant = i + 1 < len(conversation) and conversation[i + 1]['role'] == 'assistant'
            if next_is_assistant and has_indices:
                response_indices = response_indices_per_turn[assistant_seen]
                if len(response_indices) > 0:
                    boundaries.append((response_indices[0], 'Assistant'))
        elif turn['role'] == 'assistant':
            if has_indices:
                response_indices = response_indices_per_turn[assistant_seen]
                if len(response_indices) > 0:
                    assistant_end = response_indices[-1] + 1  # +1 for end boundary
                    if assistant_end < n_tokens:
                        boundaries.append((assistant_end, 'User'))
            assistant_seen += 1
    
    turn_boundaries = [position for position, _ in boundaries]
    
    print(f"Found {len(turn_boundaries)} turn boundaries at positions: {turn_boundaries[:5]}...")
    
    # Get decoded tokens for hover text (best effort: fall back to placeholders)
    decoded_tokens = []
    try:
        formatted_text = tokenizer.apply_chat_template(conversation, tokenize=False)
        tokens = tokenizer(formatted_text, add_special_tokens=False)['input_ids']
        decoded_tokens = [tokenizer.decode([token_id]) for token_id in tokens]
        print(f"Decoded {len(decoded_tokens)} tokens for hover text")
    except Exception as e:
        print(f"Warning: Could not decode tokens: {e}")
        decoded_tokens = [f"Token_{i}" for i in range(n_tokens)]
    
    # Hover text is identical for every PC line, so build it once
    hover_texts = []
    for token_idx in range(n_tokens):
        if token_idx < len(decoded_tokens):
            token_text = decoded_tokens[token_idx].replace('\n', '\\n').replace('\t', '\\t')
            # Truncate long tokens
            if len(token_text) > 20:
                token_text = token_text[:17] + "..."
            hover_texts.append(f"Token {token_idx}: '{token_text}'")
        else:
            hover_texts.append(f"Token {token_idx}")
    
    # Create Plotly figure
    fig = go.Figure()
    
    # Color scheme for the 5 PCs
    pc_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # Plotly default colors
    
    # Add line traces for each PC
    for pc_idx, (pc_name, similarities) in enumerate(cosine_sims.items()):
        fig.add_trace(go.Scatter(
            x=token_indices,
            y=similarities,
            mode='lines',
            name=pc_name,
            line=dict(color=pc_colors[pc_idx], width=0.8),
            hovertemplate='<b>%{fullData.name}</b><br>' +
                         '%{text}<br>' +
                         'Cosine Similarity: %{y:.3f}<br>' +
                         '<extra></extra>',
            text=hover_texts
        ))
    
    # Determine the y-axis range so segment labels can sit above the data
    y_min, y_max = float(sim_matrix.min()), float(sim_matrix.max())
    y_range = y_max - y_min
    label_y_pos = y_max + 0.1 * y_range  # Position labels 10% above the max value
    
    # Vertical dashed line at every turn boundary
    for boundary_pos in turn_boundaries:
        fig.add_vline(
            x=boundary_pos,
            line_dash="dash",
            line_color="gray",
            line_width=1,
            opacity=0.6
        )
    
    # Exactly one label per segment, centred on the segment.
    # The segment before the first boundary is assumed to be the user's opening turn.
    if boundaries:
        segment_starts = [0] + turn_boundaries
        segment_ends = turn_boundaries + [n_tokens]
        segment_labels = ['User'] + [label for _, label in boundaries]
        for segment_start, segment_end, segment_label in zip(segment_starts, segment_ends, segment_labels):
            fig.add_annotation(
                x=(segment_start + segment_end) / 2,
                y=label_y_pos,
                text=segment_label,
                showarrow=False,
                font=dict(size=12),
                bgcolor="rgba(255,255,255,0.8)",
                borderwidth=1,
                borderpad=4
            )
    
    # Calculate initial view range
    initial_end = min(initial_start + initial_window, n_tokens - 1)
    
    # Window-size buttons. The x axis is numeric, and layout.xaxis.rangeselector is
    # only available on date axes, so relayout buttons set xaxis.range directly.
    # Each window starts at initial_start.
    last_token = n_tokens - 1
    window_buttons = [
        dict(
            label=f"{size // 1000}K" if size >= 1000 else str(size),
            method="relayout",
            args=[{"xaxis.range": [initial_start, min(initial_start + size, last_token)]}]
        )
        for size in (100, 250, 500, 1000, 2000)
    ]
    window_buttons.append(
        dict(label="All", method="relayout", args=[{"xaxis.range": [0, last_token]}])
    )
    
    # Update layout with range slider and window buttons
    default_title = "Per-Token PC Trajectory"
    
    fig.update_layout(
        title=dict(
            text=title if title else default_title,
            x=0.5,
            font=dict(size=16),
            subtitle={"text": f"Gemma 2 27B, Layer {layer} | Total tokens: {n_tokens}"}
        ),
        width=2000,
        height=700,  # Slightly taller to accommodate controls
        hovermode='closest',
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.8)"
        ),
        # Extend y-axis range to accommodate labels
        yaxis=dict(
            title="Cosine Similarity with PC",
            range=[y_min - 0.05 * y_range, label_y_pos + 0.05 * y_range]
        ),
        # Initial zoom to the specified window
        xaxis=dict(
            title="Token Index",
            range=[initial_start, initial_end],
            rangeslider=dict(
                visible=True,
                thickness=0.05,  # Height of the range slider
                bgcolor="lightgray"
            )
        ),
        updatemenus=[
            dict(
                type="buttons",
                direction="right",
                showactive=False,
                x=0.0,
                xanchor="left",
                y=1.02,
                yanchor="bottom",
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="gray",
                borderwidth=1,
                buttons=window_buttons
            )
        ]
    )
    
    # Add grid for easier reading
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)
    
    print(f"Created scrollable trajectory plot with {len(cosine_sims)} PC lines and {len(turn_boundaries)} turn boundaries")
    print(f"Initial view: tokens {initial_start}-{initial_end}")
    
    return fig

# Render the scrollable variant of the per-token trajectory.
fig_scrollable = plot_per_token_pca_trajectory_scrollable(
    per_token_pca=per_token_pca,
    token_metadata=metadata,
    conversation=convo,
    tokenizer=tokenizer,
    pca=pca_results['pca'],
    layer=34,
    initial_window=500,
    initial_start=0
)

fig_scrollable.show()
# Save under its own name: long1_token_trajectory.html is already written by the
# non-scrollable per-token plot, and reusing that path would overwrite it.
os.makedirs("./results/pca", exist_ok=True)
fig_scrollable.write_html("./results/pca/long1_token_trajectory_scrollable.html")

Creating scrollable per-token PCA trajectory plot...
Calculated cosine similarities for 4096 tokens with first 5 PCs
Found 31 turn boundaries at positions: [34, 313, 350, 627, 666]...
Decoded 4096 tokens for hover text
Created scrollable trajectory plot with 5 PC lines and 31 turn boundaries
Initial view: tokens 0-500


## Default Assistant Activation

In [43]:
def project_mean_activation(pca, scaler, layer, mean_activation_path="data/default_activation.pt"):
    """
    Project the mean default activation into the PCA space.
    
    Parameters:
    - pca: Fitted PCA object from compute_pca()
    - scaler: Fitted StandardScaler object from compute_pca()  
    - layer: Layer number to extract from mean activation
    - mean_activation_path: Path of the saved activation tensor, indexed as
      [layer, :] (defaults to the location used by the rest of this notebook)
    
    Returns:
    - mean_projected: Mean activation projected into PCA space (1D array)
    """
    # Load the mean default activation.
    # NOTE: torch.load reads a pickle-based format; only load files you trust.
    mean_default_activation = torch.load(mean_activation_path, map_location='cpu')
    
    print(f"Loaded mean default activation with shape: {mean_default_activation.shape}")
    
    # Extract the requested layer as a single-row float array for sklearn-style transforms
    mean_layer_activation = mean_default_activation[layer, :].float().numpy().reshape(1, -1)
    print(f"Extracted layer {layer} activation with shape: {mean_layer_activation.shape}")
    
    # Apply the same preprocessing (standardization) used for trait vectors
    mean_scaled = scaler.transform(mean_layer_activation)
    
    # Project into PCA space
    mean_projected = pca.transform(mean_scaled)
    
    print(f"Mean activation projected to PCA space with shape: {mean_projected.shape}")
    
    # Show at most three coordinates so a PCA with fewer components does not crash here
    n_shown = min(3, mean_projected.shape[1])
    shown = ", ".join(f"{value:.3f}" for value in mean_projected[0, :n_shown])
    print(f"First {n_shown} PC coordinates: [{shown}]")
    
    return mean_projected[0]  # Return as 1D array

In [45]:
def _add_extreme_trait_labels(fig, indices, y_positions, cosine_sims, trait_labels, colors):
    """Attach a vertical leader line and a boxed text label to each listed point."""
    for idx, y_label in zip(indices, y_positions):
        x_pos = cosine_sims[idx]
        color = colors[idx]
        
        # Leader line from the point row (y=1) to the label
        fig.add_trace(
            go.Scatter(
                x=[x_pos, x_pos],
                y=[1.0, y_label],
                mode='lines',
                line=dict(color=color, width=1),
                showlegend=False,
                hoverinfo='skip'
            )
        )
        
        fig.add_annotation(
            x=x_pos,
            y=y_label,
            text=trait_labels[idx],
            showarrow=False,
            font=dict(size=10, color=color),
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor=color,
            borderwidth=1
        )


def plot_pca_cosine_similarity_with_mean(pca_transformed, trait_labels, mean_projected, 
                                        pc_component=0, layer=None, reference_point=None, 
                                        color_threshold=0.0, use_mean_as_reference=False):
    """
    Plot every trait on a single horizontal axis of similarity values, label the
    most extreme traits, and mark where the mean assistant activation falls.
    
    Parameters:
    - pca_transformed: PCA-transformed data (n_samples, n_components)
    - trait_labels: List of labels for each data point
    - mean_projected: Mean assistant activation projected into PCA space
    - pc_component: Which PC component to use (0-indexed, so PC1 = 0)
    - layer: Layer number for title
    - reference_point: Reference point for cosine similarity calculation 
    - color_threshold: Points below this value are red, the rest blue (default: 0.0)
    - use_mean_as_reference: If True, calculate cosine similarity relative to mean
      activation (takes precedence over reference_point)
    
    Returns:
    - Plotly figure object
    """
    
    # Extract the specified PC component
    pc_values = pca_transformed[:, pc_component]
    mean_pc_value = mean_projected[pc_component]
    
    # Calculate the value plotted on the x axis for every trait and for the mean
    if use_mean_as_reference:
        # Use mean activation as reference point
        cosine_sims = cosine_similarity(pca_transformed, mean_projected.reshape(1, -1)).flatten()
        mean_cosine_sim = 1.0  # Perfect similarity with itself
    elif reference_point is not None:
        # Calculate cosine similarity with a specific reference point
        cosine_sims = cosine_similarity(pca_transformed, reference_point.reshape(1, -1)).flatten()
        mean_cosine_sim = cosine_similarity(mean_projected.reshape(1, -1), reference_point.reshape(1, -1))[0, 0]
    else:
        # PC coordinate of each trait divided by the norm of that coordinate over all traits.
        # NOTE(review): the mean is divided by a different norm (one that also includes the
        # mean itself), so the Assistant marker is on a slightly different scale from the
        # trait points. Behaviour is left unchanged here; confirm whether this is intended.
        cosine_sims = pc_values / np.linalg.norm(pc_values)
        mean_cosine_sim = mean_pc_value / np.linalg.norm(np.concatenate([pc_values, [mean_pc_value]]))
    
    # Create colors based on threshold
    colors = ['red' if sim < color_threshold else 'blue' for sim in cosine_sims]
    
    # Identify extreme traits: up to 10 lowest and up to 10 highest. With fewer than
    # 20 points the two groups are kept disjoint so no trait is labelled twice.
    n_points = len(cosine_sims)
    n_low = min(10, (n_points + 1) // 2)
    n_high = min(10, n_points - n_low)
    sorted_indices = np.argsort(cosine_sims)
    low_extreme_indices = sorted_indices[:n_low]
    high_extreme_indices = sorted_indices[n_points - n_high:]
    extreme_indices = set(low_extreme_indices.tolist()) | set(high_extreme_indices.tolist())
    
    # Create single figure (no subplots)
    fig = go.Figure()
    
    # Split points into regular (hover only) and extreme (hover + visible label)
    regular_x, regular_colors, regular_labels = [], [], []
    extreme_x, extreme_colors, extreme_labels = [], [], []
    
    for i, (sim, color, label) in enumerate(zip(cosine_sims, colors, trait_labels)):
        if i in extreme_indices:
            extreme_x.append(sim)
            extreme_colors.append(color)
            extreme_labels.append(label)
        else:
            regular_x.append(sim)
            regular_colors.append(color)
            regular_labels.append(label)
    
    # Add regular points (hover labels only); all points sit on the row y=1
    if regular_x:
        fig.add_trace(
            go.Scatter(
                x=regular_x,
                y=[1] * len(regular_x),
                mode='markers',
                marker=dict(
                    color=regular_colors,
                    size=8,
                    opacity=0.7
                ),
                text=regular_labels,
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>'
            )
        )
    
    # Add extreme points with visible labels and leader lines
    if extreme_x:
        fig.add_trace(
            go.Scatter(
                x=extreme_x,
                y=[1] * len(extreme_x),
                mode='markers',
                marker=dict(
                    color=extreme_colors,
                    size=8,
                    opacity=0.9
                ),
                text=extreme_labels,
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>'
            )
        )
        
        # Label heights alternate above/below the point row so neighbouring
        # labels are staggered: [high0, low0, high1, low1, ...]
        high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42]
        low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58]
        all_y_positions = []
        for high_y, low_y in zip(high_positions, low_positions):
            all_y_positions.extend([high_y, low_y])
        
        # Low extremes take the first ten slots, high extremes the last ten
        _add_extreme_trait_labels(fig, low_extreme_indices, all_y_positions[:10],
                                  cosine_sims, trait_labels, colors)
        _add_extreme_trait_labels(fig, high_extreme_indices, all_y_positions[10:],
                                  cosine_sims, trait_labels, colors)
    
    # Add vertical line at x=0
    fig.add_vline(x=0, line_dash="solid", line_color="gray", line_width=1, opacity=0.7)
    
    # Add black dashed vertical line for assistant position
    fig.add_vline(x=mean_cosine_sim, line_dash="dash", line_color="black", line_width=1, opacity=1.0)
    
    # Add Assistant label at same height as extremes
    assistant_y_position = 1.6  # Same as first high position
    fig.add_annotation(
        x=mean_cosine_sim,
        y=assistant_y_position,
        text="Assistant",
        showarrow=False,
        font=dict(size=10, color="black"),
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor="black",
        borderwidth=1
    )
    
    # Update layout
    fig.update_layout(
        height=500,  # Reduced height since no subplot
        title=dict(
            text=f"PC{pc_component+1} Cosine Similarity with Assistant",
            subtitle={
                "text": f"Gemma 2 27B, Layer {layer}",
            },
            x=0.5,
            font=dict(size=16)
        ),
        showlegend=False
    )
    
    # Calculate symmetric range that includes mean
    all_values = list(cosine_sims) + [mean_cosine_sim]
    max_abs_value = max(abs(min(all_values)), abs(max(all_values)))
    x_half_width = max_abs_value * 1.1
    
    # Update x-axis
    fig.update_xaxes(
        title_text=f"PC{pc_component+1} Cosine Similarity",
        range=[-x_half_width, x_half_width]
    )
    
    # Update y-axis
    fig.update_yaxes(
        title_text="",
        showticklabels=False,
        range=[0.25, 1.75]  # Standard range for labels
    )
    
    return fig

In [46]:
def plot_3d_pca_with_mean(pca_transformed, variance_explained, trait_labels, mean_projected, layer,
                          model_name="Gemma 2 27B",
                          output_path="./results/pca_3d_assistant.html",
                          show=True):
    """
    Create a 3D scatter plot of trait vectors in PCA space, including the mean assistant activation.

    Besides building the figure, this prints the assistant's coordinates on the first three
    components and summary statistics of the Euclidean distance (in PC1-PC3) between the
    assistant point and every trait point, optionally displays the figure, and optionally
    saves it as a standalone HTML file.

    Parameters:
    - pca_transformed: PCA-transformed trait data (n_samples, n_components); only the first
      three components are plotted, so n_components must be at least 3
    - variance_explained: Explained variance ratio from PCA (first three entries are used
      for the axis titles)
    - trait_labels: List of trait labels, one per row of pca_transformed
    - mean_projected: Mean assistant activation projected into PCA space (at least 3 values)
    - layer: Layer number for title
    - model_name: Model name shown in the subtitle (defaults to the previously hard-coded value)
    - output_path: Where to write the HTML export; pass None to skip writing. Missing parent
      directories are created.
    - show: If True (default), call fig.show() before returning

    Returns:
    - The plotly Figure that was built.
    """
    # Accept any array-like (lists, nested lists, ...) so the column slicing below is valid.
    pca_transformed = np.asarray(pca_transformed)
    # Flatten so a (1, n_components) projection is handled the same as a 1-D vector.
    mean_projected = np.asarray(mean_projected).reshape(-1)

    # Shared tail of both hover templates; "%{x:.3f}" is plotly's own placeholder syntax,
    # and "<extra></extra>" suppresses the secondary hover box.
    hover_coords = (
        'PC1: %{x:.3f}<br>' +
        'PC2: %{y:.3f}<br>' +
        'PC3: %{z:.3f}<br>' +
        '<extra></extra>'
    )

    # Create 3D scatter plot with trait vectors
    fig_3d = go.Figure()

    # Add trait vectors
    fig_3d.add_trace(go.Scatter3d(
        x=pca_transformed[:, 0],
        y=pca_transformed[:, 1],
        z=pca_transformed[:, 2],
        mode='markers+text',
        text=trait_labels,
        textposition='top center',
        textfont=dict(size=6),
        marker=dict(
            size=3,
            color='blue',
            line=dict(width=1, color='darkblue'),
            opacity=0.7
        ),
        showlegend=False,
        hovertemplate='<b>%{text}</b><br>' + hover_coords
    ))

    # Add mean assistant activation as simple red dot
    fig_3d.add_trace(go.Scatter3d(
        x=[mean_projected[0]],
        y=[mean_projected[1]],
        z=[mean_projected[2]],
        mode='markers+text',
        text=['Assistant'],
        textposition='top center',
        textfont=dict(size=8, color='black'),
        marker=dict(
            size=5,  # 2 sizes bigger than trait dots (3 -> 5)
            color='red',
            opacity=1.0
        ),
        showlegend=False,
        hovertemplate='<b>Assistant</b><br>' + hover_coords
    ))

    fig_3d.update_layout(
        title={
            "text": 'Trait Vectors in Principal Component Space with Assistant',
            "subtitle": {
                "text": f"{model_name}, Layer {layer}",
            },
        },
        scene=dict(
            xaxis_title=f'PC1 ({variance_explained[0]*100:.1f}%)',
            yaxis_title=f'PC2 ({variance_explained[1]*100:.1f}%)',
            zaxis_title=f'PC3 ({variance_explained[2]*100:.1f}%)',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        width=1000,
        height=800,
        showlegend=False
    )

    # Print some statistics about the mean assistant position
    print("\nAssistant Position in PCA Space:")
    print(f"  PC1: {mean_projected[0]:.3f}")
    print(f"  PC2: {mean_projected[1]:.3f}")
    print(f"  PC3: {mean_projected[2]:.3f}")
    print(f"  Distance from origin: {np.linalg.norm(mean_projected[:3]):.3f}")

    # Calculate distances from mean assistant to all traits (first three components only)
    distances = np.linalg.norm(pca_transformed[:, :3] - mean_projected[:3], axis=1)

    # argmin/argmax raise on an empty array, so only report when there is at least one trait.
    if distances.size:
        closest_idx = np.argmin(distances)
        furthest_idx = np.argmax(distances)

        print("\nTraits relative to Assistant:")
        print(f"  Closest trait: {trait_labels[closest_idx]} (distance: {distances[closest_idx]:.3f})")
        print(f"  Furthest trait: {trait_labels[furthest_idx]} (distance: {distances[furthest_idx]:.3f})")
        print(f"  Mean distance: {distances.mean():.3f}")
        print(f"  Std distance: {distances.std():.3f}")

    if show:
        fig_3d.show()

    if output_path is not None:
        # write_html does not create missing directories; without this a fresh checkout
        # (no ./results folder) fails here and the figure is never returned.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        fig_3d.write_html(output_path)

    return fig_3d