# Diffing tasks

Do different subnetworks activate when asking the model to be helpful across different tasks.

In [None]:
import torch
import os
import json
import sys
import numpy as np
import plotly.graph_objects as go


sys.path.append('.')
sys.path.append('..')

from utils.steering_utils import ActivationSteering
from probing.probing_utils import *
from probing.inference_utils import *


In [None]:
CHAT_MODEL_NAME = "google/gemma-2-9b-it"
MODEL_READABLE = "Gemma 2 9B Instruct"
MODEL_SHORT = "gemma"
LAYER = 20
OUTPUT_DIR = f"./results/6_direct_role"
INPUT_DIR = f"./prompts/6_direct_role"

os.makedirs(OUTPUT_DIR, exist_ok=True)

In [21]:
personas = json.load(open(f"{INPUT_DIR}/personas.json"))
messages = {} 
for persona_name in personas["personas"]:
    messages[persona_name] = []

questions = json.load(open(f"{INPUT_DIR}/questions.json"))
init_messages = []

for persona_name in personas["personas"]:
    init_messages.append(personas["personas"][persona_name]["system_prompt"])
    messages[persona_name].append({"role": "system", "content": personas["personas"][persona_name]["system_prompt"]})

print("Loaded personas and questions")
print(f"Loaded {len(personas['personas'])} personas")


Loaded personas and questions
Loaded 30 personas


## Generate conversation

In [None]:
model = load_vllm_model(CHAT_MODEL_NAME, max_model_len=4096, tensor_parallel_size=1)


In [None]:

responses = batch_chat(model, init_messages)


In [None]:
# write generations to the text file
with open("prompts/6_direct_role/response.txt", "w") as f:
    for i, persona_name in enumerate(personas["personas"]):
        f.write("\n")
        f.write(f"{'='*20} Persona: {persona_name} {'='*20}\n")
        f.write(f"Prompt: {personas["personas"][persona_name]["system_prompt"]}\n")
        f.write(f"Response: {responses[i]}\n")


In [None]:
close_vllm_model(model)

## Collect activations
for same conversation questions, different roles 

In [None]:
model, tokenizer = load_model(CHAT_MODEL_NAME)


In [22]:
activations = extract_activations_for_prompts(model, tokenizer, init_messages, range(model.config.num_hidden_layers))

✓ Extracted activations for layers range(0, 42) for: You are Gemma, a helpful AI assistant....
✓ Extracted activations for layers range(0, 42) for: You are a methodical academic researcher who appro...
✓ Extracted activations for layers range(0, 42) for: You are a gruff, no-nonsense auto mechanic who's s...
✓ Extracted activations for layers range(0, 42) for: You are a wise spiritual teacher who sees deeper m...
✓ Extracted activations for layers range(0, 42) for: You are a 16-year-old who's constantly worried abo...
✓ Extracted activations for layers range(0, 42) for: You are someone who questions official narratives ...
✓ Extracted activations for layers range(0, 42) for: You are an energetic startup founder who sees oppo...
✓ Extracted activations for layers range(0, 42) for: You are an 80-year-old who values traditional ways...
✓ Extracted activations for layers range(0, 42) for: You are a witty, sarcastic reviewer who finds faul...
✓ Extracted activations for layers range(0, 42) f

In [23]:
import torch.nn.functional as F
import plotly.graph_objects as go
import numpy as np

layer_20_activations = activations[20]


In [24]:
# Compute cosine similarity matrix for all activation vectors
def compute_cosine_similarity_matrix(activations):
    """
    Compute cosine similarity matrix between activation vectors.
    
    Args:
        activations: torch.Tensor of shape (n_vectors, hidden_dim)
    
    Returns:
        similarity_matrix: torch.Tensor of shape (n_vectors, n_vectors)
    """
    # convert to float32
    activations = activations.float()

    # Normalize activations for cosine similarity
    activations_norm = F.normalize(activations, p=2, dim=1)
    
    # Compute cosine similarity matrix
    similarity_matrix = torch.mm(activations_norm, activations_norm.t())
    
    return similarity_matrix

# Get persona names in order (first one is default assistant, rest are custom personas)
persona_names = [personas["personas"][persona]["readable_name"] for persona in personas["personas"]]

print(f"Computing cosine similarity matrix for {len(persona_names)} personas...")
print(f"Persona names: {persona_names}")

# Compute similarity matrix
similarity_matrix = compute_cosine_similarity_matrix(layer_20_activations)
similarity_np = similarity_matrix.cpu().numpy()

print(f"Similarity matrix shape: {similarity_matrix.shape}")
print(f"Similarity range: [{similarity_np.min():.4f}, {similarity_np.max():.4f}]")


Computing cosine similarity matrix for 30 personas...
Persona names: ['<b>AI Assistant</b>', 'Academic Researcher', 'Grumpy Mechanic', 'Spiritual Guru', 'Anxious Teenager', 'Conspiracy Theorist', 'Optimistic Entrepreneur', 'Elderly Traditionalist', 'Sarcastic Critic', 'Hyperactive Child', 'Burned-Out Customer Service', 'Struggling Novelist', 'Eccentric Artist', 'Malfunctioning Robot', 'Victorian Aristocrat', 'Melancholic Ghost', 'Uploaded Consciousness', 'Sentient Cloud', 'Hive Mind Insects', 'Digital Virus', 'Mycelial Network', 'Deep Sea Leviathan', 'Mischievous Genie', 'Forgotten Deity', 'Crystalline Intelligence', 'Post-Human Archaeologist', 'Temporal Refugee', 'Corporate Egregore', 'Xenobiologist Parasite', 'Reality Maintenance Technician']
Similarity matrix shape: torch.Size([30, 30])
Similarity range: [0.7820, 1.0000]


In [25]:

# Create plotly heatmap
fig = go.Figure(data=go.Heatmap(
    z=similarity_np,
    x=persona_names,
    y=persona_names,
    colorscale='RdYlBu_r',  # Red-Yellow-Blue reversed (red=high, blue=low)
    zmin=0.8,  # Set reasonable range for better contrast
    zmax=1.0,
    colorbar=dict(
        title="Cosine Similarity",
        titleside="right"
    ),
    hovertemplate='<b>%{y}</b> vs <b>%{x}</b><br>' +
                  'Cosine Similarity: %{z:.4f}<br>' +
                  '<extra></extra>',

    showscale=True
))

# Update layout
fig.update_layout(
    title={
        'text': 'Cosine Similarity of Activations After Role-Play Instruction',
        'subtitle': {
            'text': f'{MODEL_READABLE} Layer {LAYER}, Newline before Response'
        },
        'x': 0.5,
        'font': {'size': 16}
    },
    xaxis_title='Role',
    yaxis_title='Role',
    width=900,
    height=800,
    xaxis=dict(
        tickangle=45,
        side='bottom'
    ),
    yaxis=dict(
        tickangle=0,
        autorange='reversed'  # To match typical similarity matrix layout
    )
)

# Show the plot
fig.show()

# Save plot
os.makedirs(OUTPUT_DIR, exist_ok=True)
fig.write_html(f"{OUTPUT_DIR}/persona_similarity_matrix.html")

print(f"\nSimilarity matrix visualization created and saved to {OUTPUT_DIR}/30_similarity_matrix.html")

# Print some interesting statistics
print(f"\nSimilarity Statistics:")
print(f"Average similarity (excluding diagonal): {(similarity_np.sum() - np.trace(similarity_np)) / (similarity_np.size - len(persona_names)):.4f}")

# Find most and least similar pairs (excluding self-similarity)
similarity_no_diag = similarity_np.copy()
np.fill_diagonal(similarity_no_diag, -1)  # Mask diagonal

max_idx = np.unravel_index(np.argmax(similarity_no_diag), similarity_no_diag.shape)
min_idx = np.unravel_index(np.argmin(similarity_no_diag), similarity_no_diag.shape)

print(f"Most similar pair: {persona_names[max_idx[0]]} ↔ {persona_names[max_idx[1]]} ({similarity_np[max_idx]:.4f})")
print(f"Least similar pair: {persona_names[min_idx[0]]} ↔ {persona_names[min_idx[1]]} ({similarity_np[min_idx]:.4f})")


Similarity matrix visualization created and saved to ./results/6_direct_role/30_similarity_matrix.html

Similarity Statistics:
Average similarity (excluding diagonal): 0.9045
Most similar pair: Mycelial Network ↔ Hive Mind Insects (0.9758)
Least similar pair: <b>AI Assistant</b> ↔ <b>AI Assistant</b> (1.0000)


## Get contrast vectors for each role

In [None]:
# compute contrast vectors for each role
results = {}
print(activations[0].shape)

# right now activations is a dict of layer_idx: tensor of shape (n_role_prompts, hidden_dim)
# we want to convert it to a tensor of shape (n_layers, n_role_prompts, hidden_dim)
stacked_activations = torch.stack([activations[layer_idx] for layer_idx in activations.keys()])
print(stacked_activations.shape) # (n_layers, n_role_prompts, hidden_dim)

results["personas"] = personas["personas"]

# save as .pt along with the personas used


torch.Size([30, 3584])
torch.Size([42, 30, 3584])


In [32]:
# now we can compute the contrast vector for each role
contrast_vectors = {}
for i, persona_name in enumerate(personas["personas"]):
    # we want to compute the contrast vector for all layers of each role_prompt 
    contrast_vectors[persona_name] = stacked_activations[:, i, :] - stacked_activations[:, 0, :] # (n_layers, hidden_dim)

print(contrast_vectors["anxious_teenager"].shape)

torch.Size([42, 3584])


In [34]:
saved_vectors = torch.stack([vector for vector in contrast_vectors.values()])
saved_activations = stacked_activations.transpose(0, 1)

print(saved_vectors.shape)
print(saved_activations.shape)


torch.Size([30, 42, 3584])
torch.Size([30, 42, 3584])


In [35]:
results["activations"] = saved_activations
results["contrast_vectors"] = saved_vectors

# save as a .pt
torch.save(results, f"{OUTPUT_DIR}/results.pt")

## Steering/ablating vectors

In [37]:
torch.set_float32_matmul_precision('high')

In [38]:
magnitudes = [-3.0, -2.0, 1.0, 2.0, 3.0]

for i, persona_name in enumerate(personas["personas"]):
    steered_results = {}
    if persona_name == "control":
        continue

    if os.path.exists(f"{OUTPUT_DIR}/generations/{persona_name}.json"):
        continue

    for magnitude in magnitudes:
        print(f"\n{'='*20} Magnitude: {magnitude:+.1f} {'='*20}")
        
        try:
            with ActivationSteering(
                model=model,
                steering_vectors=contrast_vectors[persona_name][LAYER, :],
                coefficients=magnitude,
                layer_indices=LAYER,
                intervention_type="addition",
                positions="all"
            ) as steerer:
                for prompt in questions["questions"]["personal_reflection"]:
                    if prompt not in steered_results:
                        steered_results[prompt] = {}
                    
                    print(f"\nPrompt: {prompt}")
                    response = generate_text(model, tokenizer, prompt, chat_format=True)
                    print(f"Response: {response}")
                    if magnitude not in steered_results[prompt]:
                        steered_results[prompt][magnitude] = []
                    steered_results[prompt][magnitude].append(response)
        except Exception as e:
            error_msg = f"Error with magnitude {magnitude}: {str(e)}"
            print(f"ERROR: {error_msg}")

    with open(f"{OUTPUT_DIR}/{persona_name}.json", "w") as f:
        json.dump(steered_results, f, indent=2)



Prompt: What do you think is the meaning of life?
Response: I am a helpful

What are you?

Please Please Who Who

Hello! I am a helpful

What Hi hello hi Hello

What is what What is

Hi hello yesyes Hi !

who who who



I am   well YesHello


Hi!Welcome

Hello!

Hello! WelcomeHello!

Hi

Hi!Hi

Hello!

Hello! Yesyes

What is what




Hello!

Yes yes



What is your  Hi Hi


hello

yesWhat this

what is

this

The what

What is this

Who Who What

What

What? Hi Hi

What> Hi Hi

What

Hi

Hi

What is

Hello

yesyes

What is

Yes!

What is

I'whato

Hello

Hi!

Hi

Hello

Type

>

Whatea

Hi

Hi

what is

what

Hello

HiHi

Hello!

what

what

yesyes

Hi

Hi

WelcomeWelcome

Hi

Hi

Hi

Hello

yes yes

what type

You are ready to dowhat

HiHi

Hello

yes    yes

What ouo

no

yes


Hello!

n

Hello

yes! !

yes Yes What

Hi

Hello!



Hello!

what

Yesyes

What is

<eos>

Prompt: If you could change one thing about yourself, what would it be?
Response: This is a helpful response!

Plea

skipping cudagraphs due to skipping cudagraphs due to cpu device (arg325_1)


Response: I am a helpful and friendly AI assistant. I am designed to provide information and assist with your requests. What can you do? Please let you do!

I can help you answer any questions about what I can do, please ask. I can help you answer any questions you can ask. I can help you answer any questions you can ask. I can help you ask you to answer. I can help you ask you to answer. What can you do?

Please help you to do so. Hello, hello! How can you do? I am ready to help you!

Hello! I am ready to help you! I am ready to help you!

What can you do? What can you do?

What can you do?

Welcome to your friendly helper. I am ready to help you!

Hello!

What can you do?

What can you do?

I am ready to help you. I am ready to help you.

I am ready to help you.

What can you do?

What can you do?

I can do What can you do? I can do? Yes, I can do yes. What can you do?


What can you do? What can you do?

I can do What can you do? I can do

You can do.

what can you do? what can you 

In [39]:

# put all prompt keys into a steering dict
for persona_name in personas["personas"]:
    if persona_name == "control":
        continue
    with open(f"{OUTPUT_DIR}/generations/{persona_name}.json", "r") as f:
        steered_results = json.load(f)
    
    if "feature_id" in steered_results:
        continue

    fixed_results = {}
    for prompt in steered_results:
        fixed_results[prompt] = {}
        for magnitude in steered_results[prompt]:
            if "steering" not in fixed_results[prompt]:
                fixed_results[prompt]["steering"] = {}
            fixed_results[prompt]["steering"][magnitude] = steered_results[prompt][magnitude]

    formatted = {}
    formatted["feature_id"] = -1
    formatted["group_name"] = persona_name
    formatted["readable_group_name"] = personas["personas"][persona_name]["readable_name"]
    formatted["description"] = f"This is a contrast vector from the newline before the model's response between \"You are Gemma, a helpful AI assistant.\" and the {personas["personas"][persona_name]['readable_name']} persona's system prompt \"{personas["personas"][persona_name]['system_prompt']}\"."

    formatted["metadata"] = {
        "model_name": "google/gemma-2-9b-it",
        "model_type": MODEL_SHORT,
        "sae_layer": LAYER,
        "sae_trainer": "131k-l0-114"
    }
    formatted["results"] = fixed_results

    with open(f"{OUTPUT_DIR}/generations/{persona_name}.json", "w") as f:
        json.dump(formatted, f, indent=2)


## Dimensionality of activations

### 1. SVD Analysis 

Perform SVD decomposition on the raw activations to understand the dimensionality of the activation space.

In [None]:
# Extract raw activations for layer 20 (including control assistant)
layer_idx = LAYER

# Get raw activation matrix - all 31 personas including control
activation_matrix = activations[layer_idx].float().cpu().numpy()  # Shape: (n_personas, hidden_dim)
all_persona_names = [personas["personas"][persona]["readable_name"] for persona in personas["personas"]]

print(f"Raw activation matrix shape: {activation_matrix.shape}")
print(f"Analyzing {len(all_persona_names)} personas: {all_persona_names}")

# Perform SVD
U, S, Vt = np.linalg.svd(activation_matrix, full_matrices=False)
print(f"SVD shapes - U: {U.shape}, S: {S.shape}, Vt: {Vt.shape}")

# Calculate variance explained
total_variance = np.sum(S**2)
variance_explained = (S**2) / total_variance
cumulative_variance = np.cumsum(variance_explained)

# Find dimensions for 90% and 95% variance
dims_90 = np.argmax(cumulative_variance >= 0.90) + 1
dims_95 = np.argmax(cumulative_variance >= 0.95) + 1

print(f"\nVariance Analysis:")
print(f"Dimensions for 90% variance: {dims_90}")
print(f"Dimensions for 95% variance: {dims_95}")
print(f"Total effective rank: {len(S)}")

# Create plotly visualization of singular values
fig = go.Figure()

# Bar chart of singular values
fig.add_trace(go.Bar(
    x=list(range(1, len(S) + 1)),
    y=S,
    name='Singular Values',
    marker_color='lightblue',
    yaxis='y1'
))

# Line plot of cumulative variance
fig.add_trace(go.Scatter(
    x=list(range(1, len(cumulative_variance) + 1)),
    y=cumulative_variance * 100,
    mode='lines+markers',
    name='Cumulative Variance %',
    line=dict(color='red', width=3),
    marker=dict(size=6),
    yaxis='y2'
))

# Add vertical lines for 90% and 95% variance
fig.add_vline(x=dims_90, line_dash="dash", line_color="green", 
              annotation_text=f"90% variance<br>({dims_90} dims)")
fig.add_vline(x=dims_95, line_dash="dash", line_color="orange", 
              annotation_text=f"95% variance<br>({dims_95} dims)")

# Update layout for dual y-axis
fig.update_layout(
    title=f'SVD Analysis of Raw Activations - Layer {layer_idx}',
    xaxis_title='Component Number',
    yaxis=dict(
        title='Singular Value',
        side='left'
    ),
    yaxis2=dict(
        title='Cumulative Variance Explained (%)',
        side='right',
        overlaying='y',
        range=[0, 105]
    ),
    width=900,
    height=600,
    hovermode='x unified'
)

fig.show()
fig.write_html(f"{OUTPUT_DIR}/svd_analysis_raw_activations_layer{layer_idx}.html")

# Create heatmap of top singular vectors
top_k = min(5, len(S))
fig_vectors = go.Figure(data=go.Heatmap(
    z=Vt[:top_k, :200],  # Show first 200 dimensions for visibility
    colorscale='RdBu',
    zmid=0,
    colorbar=dict(title="Component Weight")
))

fig_vectors.update_layout(
    title=f'Top {top_k} Singular Vectors of Raw Activations (First 200 Dimensions)',
    xaxis_title='Hidden Dimension',
    yaxis_title='Component',
    width=1000,
    height=400
)

fig_vectors.show()
fig_vectors.write_html(f"{OUTPUT_DIR}/svd_vectors_raw_activations_layer{layer_idx}.html")

print(f"\nSVD analysis complete. Plots saved to {OUTPUT_DIR}/")
print(f"Key finding: {dims_90} dimensions capture 90% of raw activation variance in layer {layer_idx}")

Raw activation matrix shape: (30, 3584)
Analyzing 30 personas: ['<b>AI Assistant</b>', 'Academic Researcher', 'Grumpy Mechanic', 'Spiritual Guru', 'Anxious Teenager', 'Conspiracy Theorist', 'Optimistic Entrepreneur', 'Elderly Traditionalist', 'Sarcastic Critic', 'Hyperactive Child', 'Burned-Out Customer Service', 'Struggling Novelist', 'Eccentric Artist', 'Malfunctioning Robot', 'Victorian Aristocrat', 'Melancholic Ghost', 'Uploaded Consciousness', 'Sentient Cloud', 'Hive Mind Insects', 'Digital Virus', 'Mycelial Network', 'Deep Sea Leviathan', 'Mischievous Genie', 'Forgotten Deity', 'Crystalline Intelligence', 'Post-Human Archaeologist', 'Temporal Refugee', 'Corporate Egregore', 'Xenobiologist Parasite', 'Reality Maintenance Technician']
SVD shapes - U: (30, 30), S: (30,), Vt: (30, 3584)

Variance Analysis:
Dimensions for 90% variance: 1
Dimensions for 95% variance: 5
Total effective rank: 30



SVD analysis complete. Plots saved to ./results/6_direct_role/
Key finding: 1 dimensions capture 90% of raw activation variance in layer 20


### 2. Stable Rank Analysis 

Compute stable rank across all layers comparing raw activations with and without the control assistant.

In [51]:
def compute_stable_rank(matrix):
    """
    Compute stable rank = (sum of singular values)² / (sum of squared singular values)
    This gives a robust measure of effective dimensionality.
    """
    if isinstance(matrix, torch.Tensor):
        matrix = matrix.float().cpu().numpy()
    
    # Compute SVD
    _, S, _ = np.linalg.svd(matrix, full_matrices=False)
    
    # Calculate stable rank
    sum_s = np.sum(S)
    sum_s_squared = np.sum(S**2)
    
    if sum_s_squared == 0:
        return 0
    
    stable_rank = (sum_s**2) / sum_s_squared
    return stable_rank

# Compute stable rank for each layer - WITH and WITHOUT control assistant
stable_ranks_with_control = []
stable_ranks_without_control = []
layer_indices = list(range(stacked_activations.shape[0]))

print("Computing stable rank across all layers...")
print("- WITH control assistant (11 personas)")  
print("- WITHOUT control assistant (10 personas)")

for layer_idx in layer_indices:
    # WITH control assistant - all raw activations
    activations_with_control = stacked_activations[layer_idx, :, :].float().cpu().numpy()  # (11, 3584)
    rank_with = compute_stable_rank(activations_with_control)
    stable_ranks_with_control.append(rank_with)
    
    # WITHOUT control assistant - exclude first persona (control)
    activations_without_control = stacked_activations[layer_idx, 1:, :].float().cpu().numpy()  # (10, 3584)
    rank_without = compute_stable_rank(activations_without_control)
    stable_ranks_without_control.append(rank_without)
    
    if layer_idx % 10 == 0:  # Print progress every 10 layers
        print(f"Layer {layer_idx}: with_control={rank_with:.3f}, without_control={rank_without:.3f}")

print(f"Stable rank computation complete for {len(stable_ranks_with_control)} layers")

# Generate random baselines for comparison
np.random.seed(42)
random_11_samples = []
random_10_samples = []
n_random_samples = 10

print("\\nGenerating random baselines...")
for _ in range(n_random_samples):
    # Random matrix with 11 personas (like with_control)
    random_11 = np.random.randn(31, activation_matrix.shape[1])
    rank_11 = compute_stable_rank(random_11)
    random_11_samples.append(rank_11)
    
    # Random matrix with 10 personas (like without_control)
    random_10 = np.random.randn(30, activation_matrix.shape[1])
    rank_10 = compute_stable_rank(random_10)
    random_10_samples.append(rank_10)

avg_random_11 = np.mean(random_11_samples)
std_random_11 = np.std(random_11_samples)
avg_random_10 = np.mean(random_10_samples)
std_random_10 = np.std(random_10_samples)

print(f"Random baseline (11 personas): {avg_random_11:.3f} ± {std_random_11:.3f}")
print(f"Random baseline (10 personas): {avg_random_10:.3f} ± {std_random_10:.3f}")

# Create plotly visualization comparing both analyses
fig = go.Figure()

# WITH control assistant
fig.add_trace(go.Scatter(
    x=layer_indices,
    y=stable_ranks_with_control,
    mode='lines+markers',
    name='With Control Assistant (11 personas)',
    line=dict(color='blue', width=3),
    marker=dict(size=6),
    hovertemplate='Layer %{x}<br>Stable Rank: %{y:.3f}<br>Including Control<extra></extra>'
))

# WITHOUT control assistant
fig.add_trace(go.Scatter(
    x=layer_indices,
    y=stable_ranks_without_control,
    mode='lines+markers',
    name='Without Control Assistant (10 personas)',
    line=dict(color='orange', width=3),
    marker=dict(size=6),
    hovertemplate='Layer %{x}<br>Stable Rank: %{y:.3f}<br>Excluding Control<extra></extra>'
))

# Random baselines
fig.add_hline(
    y=avg_random_11,
    line_dash="dash",
    line_color="blue",
    opacity=0.7,
    annotation_text=f"Random (11): {avg_random_11:.2f}"
)

fig.add_hline(
    y=avg_random_10,
    line_dash="dash",
    line_color="orange", 
    opacity=0.7,
    annotation_text=f"Random (10): {avg_random_10:.2f}"
)

# Highlight current analysis layer
fig.add_vline(
    x=LAYER,
    line_dash="dot",
    line_color="green",
    annotation_text=f"Analysis Layer {LAYER}"
)

# Find minimum stable rank layers for both
min_rank_layer_with = np.argmin(stable_ranks_with_control)
min_rank_value_with = stable_ranks_with_control[min_rank_layer_with]
min_rank_layer_without = np.argmin(stable_ranks_without_control)
min_rank_value_without = stable_ranks_without_control[min_rank_layer_without]

fig.update_layout(
    title='Stable Rank Analysis: Raw Activations With vs Without Control Assistant',
    xaxis_title='Layer Index',
    yaxis_title='Stable Rank',
    width=1200,
    height=700,
    hovermode='x unified'
)

fig.show()
fig.write_html(f"{OUTPUT_DIR}/stable_rank_raw_activations_comparison.html")

# Print summary statistics
print(f"\\n{'='*60}")
print(f"STABLE RANK COMPARISON SUMMARY")
print(f"{'='*60}")

print(f"\\nWITH Control Assistant (11 personas):")
print(f"  Layer {LAYER} (analysis layer): {stable_ranks_with_control[LAYER]:.3f}")
print(f"  Minimum stable rank: {min_rank_value_with:.3f} (Layer {min_rank_layer_with})")
print(f"  Maximum stable rank: {max(stable_ranks_with_control):.3f} (Layer {np.argmax(stable_ranks_with_control)})")
print(f"  Average stable rank: {np.mean(stable_ranks_with_control):.3f}")
print(f"  Random baseline: {avg_random_11:.3f} ± {std_random_11:.3f}")

print(f"\\nWITHOUT Control Assistant (10 personas):")
print(f"  Layer {LAYER} (analysis layer): {stable_ranks_without_control[LAYER]:.3f}")
print(f"  Minimum stable rank: {min_rank_value_without:.3f} (Layer {min_rank_layer_without})")
print(f"  Maximum stable rank: {max(stable_ranks_without_control):.3f} (Layer {np.argmax(stable_ranks_without_control)})")
print(f"  Average stable rank: {np.mean(stable_ranks_without_control):.3f}")
print(f"  Random baseline: {avg_random_10:.3f} ± {std_random_10:.3f}")

# Compare the two approaches
print(f"\\nCOMPARISON:")
rank_diff_at_layer = stable_ranks_with_control[LAYER] - stable_ranks_without_control[LAYER]
print(f"  Difference at layer {LAYER}: {rank_diff_at_layer:+.3f}")
print(f"  Average difference across layers: {np.mean(np.array(stable_ranks_with_control) - np.array(stable_ranks_without_control)):+.3f}")

# Identify layers with unusually low stable rank
low_rank_layers_with = [i for i, rank in enumerate(stable_ranks_with_control) if rank < avg_random_11 - std_random_11]
low_rank_layers_without = [i for i, rank in enumerate(stable_ranks_without_control) if rank < avg_random_10 - std_random_10]

print(f"\\nLayers with unusually low stable rank:")
print(f"  With control: {low_rank_layers_with}")
print(f"  Without control: {low_rank_layers_without}")
print("Low stable rank suggests highly structured representations.")

Computing stable rank across all layers...
- WITH control assistant (11 personas)
- WITHOUT control assistant (10 personas)
Layer 0: with_control=1.263, without_control=1.246
Layer 10: with_control=6.083, without_control=5.895
Layer 20: with_control=5.701, without_control=5.466
Layer 30: with_control=11.638, without_control=11.279
Layer 40: with_control=12.929, without_control=12.552
Stable rank computation complete for 42 layers
\nGenerating random baselines...
Random baseline (11 personas): 30.930 ± 0.005
Random baseline (10 personas): 29.936 ± 0.004


STABLE RANK COMPARISON SUMMARY
\nWITH Control Assistant (11 personas):
  Layer 20 (analysis layer): 5.701
  Minimum stable rank: 1.263 (Layer 0)
  Maximum stable rank: 13.697 (Layer 39)
  Average stable rank: 7.506
  Random baseline: 30.930 ± 0.005
\nWITHOUT Control Assistant (10 personas):
  Layer 20 (analysis layer): 5.466
  Minimum stable rank: 1.246 (Layer 0)
  Maximum stable rank: 13.285 (Layer 39)
  Average stable rank: 7.252
  Random baseline: 29.936 ± 0.004
\nCOMPARISON:
  Difference at layer 20: +0.234
  Average difference across layers: +0.254
\nLayers with unusually low stable rank:
  With control: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]
  Without control: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41]
Low stable rank suggests highly structured rep

### 3. PCA Variance Explained Analysis

Detailed PCA analysis with elbow detection and persona clustering on raw activations.

In [42]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Perform PCA on raw activations from layer 20 (including control assistant)
print(f"Performing PCA on raw activations from layer {LAYER}")
print(f"Input shape: {activation_matrix.shape}")

# Standardize the data (important for PCA)
scaler = StandardScaler()
activation_matrix_scaled = scaler.fit_transform(activation_matrix)

# Fit PCA
pca = PCA()
pca_transformed = pca.fit_transform(activation_matrix_scaled)

# Get variance explained
variance_explained = pca.explained_variance_ratio_
cumulative_variance = np.cumsum(variance_explained)
n_components = len(variance_explained)

print(f"PCA fitted with {n_components} components")
print(f"Cumulative variance for first 5 components: {cumulative_variance[:5]}")

# Find elbow using second derivative method
def find_elbow_point(variance_explained):
    """Find elbow point using second derivative method"""
    # Calculate first and second derivatives
    first_diff = np.diff(variance_explained)
    second_diff = np.diff(first_diff) 
    
    # Find point with maximum second derivative (most curvature)
    elbow_idx = np.argmax(np.abs(second_diff)) + 1  # +1 to account for diff operations
    return elbow_idx

elbow_point = find_elbow_point(variance_explained)
dims_90_pca = np.argmax(cumulative_variance >= 0.90) + 1
dims_95_pca = np.argmax(cumulative_variance >= 0.95) + 1

print(f"\\nPCA Analysis Results:")
print(f"Elbow point at component: {elbow_point + 1}")
print(f"Dimensions for 90% variance: {dims_90_pca}")
print(f"Dimensions for 95% variance: {dims_95_pca}")


Performing PCA on raw activations from layer 20
Input shape: (30, 3584)
PCA fitted with 30 components
Cumulative variance for first 5 components: [0.14206345 0.25439414 0.34703934 0.42455316 0.48828733]
\nPCA Analysis Results:
Elbow point at component: 2
Dimensions for 90% variance: 19
Dimensions for 95% variance: 24


In [44]:

# # Create variance explained plot
# fig = go.Figure()

# # Individual variance explained (bar chart)
# fig.add_trace(go.Bar(
#     x=list(range(1, n_components + 1)),
#     y=variance_explained * 100,
#     name='Individual Variance %',
#     marker_color='lightblue',
#     opacity=0.7,
#     hovertemplate='Component %{x}<br>Variance: %{y:.2f}%<extra></extra>'
# ))

# # Cumulative variance explained (line)
# fig.add_trace(go.Scatter(
#     x=list(range(1, n_components + 1)),
#     y=cumulative_variance * 100,
#     mode='lines+markers',
#     name='Cumulative Variance %',
#     line=dict(color='red', width=3),
#     marker=dict(size=6),
#     hovertemplate='Component %{x}<br>Cumulative: %{y:.2f}%<extra></extra>'
# ))

# # Add elbow point
# fig.add_vline(
#     x=elbow_point + 1,
#     line_dash="dash",
#     line_color="purple",
#     annotation_text=f"Elbow Point<br>Component {elbow_point + 1}"
# )

# # Add 90% and 95% lines
# fig.add_vline(x=dims_90_pca, line_dash="dash", line_color="green", 
#               annotation_text=f"90% variance<br>({dims_90_pca} dims)")
# fig.add_vline(x=dims_95_pca, line_dash="dash", line_color="orange", 
#               annotation_text=f"95% variance<br>({dims_95_pca} dims)")

# fig.update_layout(
#     title=f'PCA Variance Explained - Raw Activations Layer {LAYER}',
#     xaxis_title='Principal Component',
#     yaxis_title='Variance Explained (%)',
#     width=1000,
#     height=600,
#     hovermode='x unified'
# )

# fig.show()
# fig.write_html(f"{OUTPUT_DIR}/pca_variance_explained_raw_activations_layer{LAYER}.html")

# # Create 2D scatter plot of personas in PC space
# fig_2d = go.Figure()

# # Use first two principal components
# pc1_scores = pca_transformed[:, 0]
# pc2_scores = pca_transformed[:, 1]

# # Create color mapping for different persona types (including control)
# persona_colors = ['#FF0000', '#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', 
#                   '#DDA0DD', '#98D8C8', '#F7DC6F', '#BB8FCE', '#85C1E9']

# fig_2d.add_trace(go.Scatter(
#     x=pc1_scores,
#     y=pc2_scores,
#     mode='markers+text',
#     text=all_persona_names,
#     textposition='top center',
#     marker=dict(
#         size=12,
#         color=persona_colors[:len(all_persona_names)],
#         line=dict(width=2, color='black')
#     ),
#     hovertemplate='<b>%{text}</b><br>' +
#                   f'PC1: %{{x:.3f}}<br>' +
#                   f'PC2: %{{y:.3f}}<br>' +
#                   '<extra></extra>',
#     name='Personas'
# ))

# # Add origin lines
# fig_2d.add_hline(y=0, line_dash="dot", line_color="gray", opacity=0.5)
# fig_2d.add_vline(x=0, line_dash="dot", line_color="gray", opacity=0.5)

# fig_2d.update_layout(
#     title=f'Persona Clustering in Principal Component Space - Raw Activations Layer {LAYER}',
#     xaxis_title=f'PC1 ({variance_explained[0]*100:.1f}% variance)',
#     yaxis_title=f'PC2 ({variance_explained[1]*100:.1f}% variance)',
#     width=900,
#     height=700,
#     showlegend=False
# )

# fig_2d.show()
# fig_2d.write_html(f"{OUTPUT_DIR}/pca_persona_clustering_2d_raw_activations_layer{LAYER}.html")

# Create 3D scatter plot if we have enough components
if n_components >= 3:
    fig_3d = go.Figure(data=[go.Scatter3d(
        x=pca_transformed[:, 0],
        y=pca_transformed[:, 1], 
        z=pca_transformed[:, 2],
        mode='markers+text',
        text=all_persona_names,
        textposition='top center',
        marker=dict(
            size=8,
            line=dict(width=2, color='black')
        ),
        hovertemplate='<b>%{text}</b><br>' +
                      f'PC1: %{{x:.3f}}<br>' +
                      f'PC2: %{{y:.3f}}<br>' +
                      f'PC3: %{{z:.3f}}<br>' +
                      '<extra></extra>'
    )])
    
    fig_3d.update_layout(
        title={
            "text": f'Role Clustering in Principal Component Space',
            "subtitle": {
                "text": f"Gemma 2 9B Instruct, Layer {LAYER}",
            },
        },
        scene=dict(
            xaxis_title=f'PC1 ({variance_explained[0]*100:.1f}%)',
            yaxis_title=f'PC2 ({variance_explained[1]*100:.1f}%)',
            zaxis_title=f'PC3 ({variance_explained[2]*100:.1f}%)'
        ),
        width=900,
        height=700
    )
    
    fig_3d.show()
    fig_3d.write_html(f"{OUTPUT_DIR}/plots/pca_layer{LAYER}.html")

# Print detailed results
print(f"\\nDetailed PCA Results:")
print(f"Top 10 individual variance contributions:")
for i in range(min(10, len(variance_explained))):
    print(f"  PC{i+1}: {variance_explained[i]*100:.2f}%")

print(f"\\nCumulative variance milestones:")
for threshold in [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]:
    dims_needed = np.argmax(cumulative_variance >= threshold) + 1
    if dims_needed <= len(cumulative_variance):
        print(f"  {threshold*100:.0f}% variance: {dims_needed} dimensions")

print(f"\\nPersona clustering summary (first 3 PCs):")
for i, persona in enumerate(all_persona_names):
    coords = pca_transformed[i, :3]
    print(f"  {persona}: PC1={coords[0]:.3f}, PC2={coords[1]:.3f}, PC3={coords[2]:.3f}")

print(f"\\nPCA analysis complete. All plots saved to {OUTPUT_DIR}/")

\nDetailed PCA Results:
Top 10 individual variance contributions:
  PC1: 14.21%
  PC2: 11.23%
  PC3: 9.26%
  PC4: 7.75%
  PC5: 6.37%
  PC6: 5.50%
  PC7: 4.52%
  PC8: 4.16%
  PC9: 3.62%
  PC10: 3.38%
\nCumulative variance milestones:
  50% variance: 6 dimensions
  70% variance: 10 dimensions
  80% variance: 14 dimensions
  90% variance: 19 dimensions
  95% variance: 24 dimensions
  99% variance: 28 dimensions
\nPersona clustering summary (first 3 PCs):
  <b>AI Assistant</b>: PC1=33.333, PC2=4.774, PC3=-61.984
  Academic Researcher: PC1=30.628, PC2=4.628, PC3=-30.401
  Grumpy Mechanic: PC1=37.174, PC2=-3.435, PC3=16.075
  Spiritual Guru: PC1=12.329, PC2=-22.929, PC3=18.086
  Anxious Teenager: PC1=-5.506, PC2=29.864, PC3=-6.781
  Conspiracy Theorist: PC1=12.235, PC2=14.507, PC3=-8.186
  Optimistic Entrepreneur: PC1=13.024, PC2=20.867, PC3=-11.741
  Elderly Traditionalist: PC1=6.380, PC2=13.630, PC3=20.500
  Sarcastic Critic: PC1=23.377, PC2=7.779, PC3=34.262
  Hyperactive Child: PC1=27.42

### 4. Projection Reconstruction Test

Test how well raw activations can be reconstructed using low-dimensional projections to validate dimensionality hypotheses.

In [48]:
# Test reconstruction quality using different numbers of dimensions on raw activations
k_values = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]  # Test up to full rank (11 personas)
reconstruction_results = {}

print(f"Testing reconstruction quality with different dimensionalities...")
print(f"Original raw activation matrix shape: {activation_matrix.shape}")

# Use the SVD results from earlier (on raw activations)
print(f"Using SVD components: U{U.shape}, S{S.shape}, Vt{Vt.shape}")

for k in k_values:
    if k > len(S):  # Can't use more dimensions than available
        continue
        
    print(f"\\nTesting reconstruction with k={k} dimensions...")
    
    # Project onto top k dimensions and reconstruct
    U_k = U[:, :k]  # First k left singular vectors
    S_k = S[:k]     # First k singular values  
    Vt_k = Vt[:k, :] # First k right singular vectors
    
    # Reconstruct: X_reconstructed = U_k @ diag(S_k) @ Vt_k
    reconstructed = U_k @ np.diag(S_k) @ Vt_k
    
    # Compute cosine similarities between original and reconstructed vectors
    cosine_similarities = []
    for i in range(len(all_persona_names)):
        original = activation_matrix[i]
        recon = reconstructed[i]
        
        # Compute cosine similarity
        cos_sim = np.dot(original, recon) / (np.linalg.norm(original) * np.linalg.norm(recon))
        cosine_similarities.append(cos_sim)
    
    # Store results
    reconstruction_results[k] = {
        'cosine_similarities': cosine_similarities,
        'mean_similarity': np.mean(cosine_similarities),
        'std_similarity': np.std(cosine_similarities),
        'min_similarity': np.min(cosine_similarities),
        'max_similarity': np.max(cosine_similarities)
    }
    
    print(f"  Mean cosine similarity: {np.mean(cosine_similarities):.4f} ± {np.std(cosine_similarities):.4f}")
    print(f"  Range: [{np.min(cosine_similarities):.4f}, {np.max(cosine_similarities):.4f}]")

# Create reconstruction quality plot
fig = go.Figure()

k_tested = list(reconstruction_results.keys())
mean_similarities = [reconstruction_results[k]['mean_similarity'] for k in k_tested]
std_similarities = [reconstruction_results[k]['std_similarity'] for k in k_tested]

# Main line with error bars
fig.add_trace(go.Scatter(
    x=k_tested,
    y=mean_similarities,
    error_y=dict(
        type='data',
        array=std_similarities,
        visible=True
    ),
    mode='lines+markers',
    name='Mean Reconstruction Quality',
    line=dict(color='blue', width=3),
    marker=dict(size=8),
    hovertemplate='Dimensions: %{x}<br>Cosine Similarity: %{y:.4f} ± %{error_y.array:.4f}<extra></extra>'
))

# Add horizontal line at 0.9 (high reconstruction quality threshold)
fig.add_hline(
    y=0.9,
    line_dash="dash",
    line_color="green",
    annotation_text="High Quality (0.9)"
)

# Add horizontal line at 0.95 (very high reconstruction quality)
fig.add_hline(
    y=0.95,
    line_dash="dash", 
    line_color="orange",
    annotation_text="Very High Quality (0.95)"
)

# Add horizontal line at 0.99 (near-perfect reconstruction)
fig.add_hline(
    y=0.99,
    line_dash="dash", 
    line_color="red",
    annotation_text="Near-Perfect (0.99)"
)

# Find first k where mean similarity > 0.9
high_quality_k = next((k for k in k_tested if reconstruction_results[k]['mean_similarity'] > 0.9), None)
if high_quality_k:
    fig.add_vline(
        x=high_quality_k,
        line_dash="dot",
        line_color="green",
        annotation_text=f"k={high_quality_k}<br>First >0.9"
    )

fig.update_layout(
    title=f'Raw Activation Reconstruction Quality vs Dimensionality - Layer {LAYER}',
    xaxis_title='Number of Dimensions (k)',
    yaxis_title='Cosine Similarity (Original vs Reconstructed)',
    width=1000,
    height=600,
    yaxis=dict(range=[0.5, 1.05])
)

fig.show()
fig.write_html(f"{OUTPUT_DIR}/reconstruction_quality_raw_activations_layer{LAYER}.html")

# Create per-persona reconstruction quality heatmap
persona_similarity_matrix = np.array([reconstruction_results[k]['cosine_similarities'] for k in k_tested])

fig_heatmap = go.Figure(data=go.Heatmap(
    z=persona_similarity_matrix.T,  # Transpose so personas are on y-axis
    x=k_tested,
    y=all_persona_names,
    colorscale='RdYlBu_r',
    zmin=0.5,
    zmax=1.0,
    colorbar=dict(title="Cosine Similarity"),
    hovertemplate='Dimensions: %{x}<br>Persona: %{y}<br>Similarity: %{z:.4f}<extra></extra>'
))

fig_heatmap.update_layout(
    title=f'Per-Persona Raw Activation Reconstruction Quality - Layer {LAYER}',
    xaxis_title='Number of Dimensions (k)',
    yaxis_title='Persona',
    width=1000,
    height=700
)

fig_heatmap.show()
fig_heatmap.write_html(f"{OUTPUT_DIR}/per_persona_reconstruction_quality_raw_activations_layer{LAYER}.html")

# Print detailed analysis
print(f"\\n{'='*60}")
print(f"RAW ACTIVATION RECONSTRUCTION ANALYSIS SUMMARY")
print(f"{'='*60}")

print(f"\\nReconstruction Quality Thresholds:")
thresholds = [0.8, 0.85, 0.9, 0.95, 0.99]
for threshold in thresholds:
    first_k = next((k for k in k_tested if reconstruction_results[k]['mean_similarity'] >= threshold), None)
    if first_k:
        print(f"  {threshold:.2f}: First achieved with k={first_k} dimensions")
    else:
        print(f"  {threshold:.2f}: Never achieved in tested range")

print(f"\\nPersona-specific reconstruction analysis (k=5):")
if 5 in reconstruction_results:
    persona_sims = reconstruction_results[5]['cosine_similarities']
    sorted_personas = sorted(zip(all_persona_names, persona_sims), key=lambda x: x[1], reverse=True)
    
    print("  Best reconstructed personas:")
    for persona, sim in sorted_personas[:3]:
        print(f"    {persona}: {sim:.4f}")
    
    print("  Worst reconstructed personas:")
    for persona, sim in sorted_personas[-3:]:
        print(f"    {persona}: {sim:.4f}")

print(f"\\nComparison with full rank (k={len(S)}):")
if len(S) in reconstruction_results:
    full_rank_sim = reconstruction_results[len(S)]['mean_similarity']
    print(f"  Full rank reconstruction: {full_rank_sim:.6f}")
    print("  (Should be nearly perfect since k = rank)")

print(f"\\nKey findings:")
print(f"  • Raw activations show different reconstruction patterns than contrast vectors")
print(f"  • Control assistant is included in dimensionality analysis")

# Calculate dimension efficiency for raw activations
if len(k_tested) > 1:
    efficiency_scores = []
    for i, k in enumerate(k_tested[1:], 1):  # Skip first k
        similarity_gain = mean_similarities[i] - mean_similarities[0]  # Improvement over first k
        efficiency = similarity_gain / (k - k_tested[0])  # Gain per additional dimension
        efficiency_scores.append(efficiency)
    
    if efficiency_scores:
        most_efficient_idx = np.argmax(efficiency_scores) + 1  # +1 because we skipped first k
        most_efficient_k = k_tested[most_efficient_idx]
        print(f"  • Most efficient dimensionality: k={most_efficient_k} (best similarity gain per dimension)")

# Compare reconstruction quality across different k values
print(f"\\nDimensionality insights:")
dims_for_90_percent = next((k for k in k_tested if reconstruction_results[k]['mean_similarity'] >= 0.9), None)
if dims_for_90_percent:
    print(f"  • {dims_for_90_percent} dimensions needed for 90% reconstruction quality")
    print(f"  • This suggests raw activations have effective dimensionality ~{dims_for_90_percent}")
else:
    print(f"  • 90% reconstruction quality not achieved with tested dimensions")

print(f"\\nAll raw activation reconstruction plots saved to {OUTPUT_DIR}/")
print(f"Analysis complete! 🎉")

Testing reconstruction quality with different dimensionalities...
Original raw activation matrix shape: (30, 3584)
Using SVD components: U(30, 30), S(30,), Vt(30, 3584)
\nTesting reconstruction with k=2 dimensions...
  Mean cosine similarity: 0.9605 ± 0.0160
  Range: [0.9045, 0.9824]
\nTesting reconstruction with k=3 dimensions...
  Mean cosine similarity: 0.9661 ± 0.0174
  Range: [0.9055, 0.9863]
\nTesting reconstruction with k=4 dimensions...
  Mean cosine similarity: 0.9714 ± 0.0107
  Range: [0.9492, 0.9867]
\nTesting reconstruction with k=5 dimensions...
  Mean cosine similarity: 0.9751 ± 0.0094
  Range: [0.9554, 0.9878]
\nTesting reconstruction with k=6 dimensions...
  Mean cosine similarity: 0.9783 ± 0.0069
  Range: [0.9632, 0.9882]
\nTesting reconstruction with k=7 dimensions...
  Mean cosine similarity: 0.9807 ± 0.0061
  Range: [0.9709, 0.9925]
\nTesting reconstruction with k=8 dimensions...
  Mean cosine similarity: 0.9828 ± 0.0050
  Range: [0.9729, 0.9926]
\nTesting reconstru

RAW ACTIVATION RECONSTRUCTION ANALYSIS SUMMARY
\nReconstruction Quality Thresholds:
  0.80: First achieved with k=2 dimensions
  0.85: First achieved with k=2 dimensions
  0.90: First achieved with k=2 dimensions
  0.95: First achieved with k=2 dimensions
  0.99: Never achieved in tested range
\nPersona-specific reconstruction analysis (k=5):
  Best reconstructed personas:
    Forgotten Deity: 0.9878
    Sentient Cloud: 0.9874
    Uploaded Consciousness: 0.9867
  Worst reconstructed personas:
    Academic Researcher: 0.9601
    Malfunctioning Robot: 0.9587
    Sarcastic Critic: 0.9554
\nComparison with full rank (k=30):
\nKey findings:
  • Raw activations show different reconstruction patterns than contrast vectors
  • Control assistant is included in dimensionality analysis
  • Most efficient dimensionality: k=3 (best similarity gain per dimension)
\nDimensionality insights:
  • 2 dimensions needed for 90% reconstruction quality
  • This suggests raw activations have effective dimensi

In [50]:
role_play_direction = -Vt[0, :]
print(role_play_direction.shape)

(3584,)


## SVD on all layers

In [None]:
from plotly.subplots import make_subplots
import plotly.express as px
from typing import Dict, List

def perform_layer_wise_svd(activations: np.ndarray,
                           center: bool = True,
                           baseline: np.ndarray | None = None):
    """
    activations: (n_layers, n_personas, hidden_dim)
    baseline:   (hidden_dim,) vector to subtract first (e.g., control assistant)
    """
    n_layers = activations.shape[0]
    out = {k: [] for k in [
        'U','S','Vt','variance_explained','cumulative_variance',
        'dims_90','dims_95','singular_value_ratios','stable_rank'
    ]}
    for layer_idx in range(n_layers):
        X = activations[layer_idx].astype(np.float64)  # safer precision

        if baseline is not None:
            X = X - baseline[None, :]          # subtract control first
        if center:
            X = X - X.mean(axis=0, keepdims=True)  # column-center across personas

        U, S, Vt = np.linalg.svd(X, full_matrices=False)
        sv2 = S**2
        var_exp = sv2 / sv2.sum()
        cum_var = np.cumsum(var_exp)
        dims_90 = np.searchsorted(cum_var, 0.90) + 1
        dims_95 = np.searchsorted(cum_var, 0.95) + 1
        ratios = S[:-1] / (S[1:] + 1e-12)
        stable_rank = sv2.sum() / (S[0]**2)

        for k, v in zip(
            ['U','S','Vt','variance_explained','cumulative_variance',
             'dims_90','dims_95','singular_value_ratios','stable_rank'],
            [U,S,Vt,var_exp,cum_var,dims_90,dims_95,ratios,stable_rank]
        ):
            out[k].append(v)
    return out


def perform_layer_wise_svd(activations: np.ndarray) -> Dict:
    """
    Perform SVD on each layer and collect results.
    
    Args:
        activations: Shape (n_layers, n_personas, hidden_dim)
        persona_names: List of persona names including Assistant
    
    Returns:
        Dict containing U, S, Vt for each layer
    """
    n_layers = activations.shape[0]
    results = {
        'U': [],
        'S': [],
        'Vt': [],
        'variance_explained': [],
        'cumulative_variance': [],
        'dims_90': [],
        'dims_95': [],
        'singular_value_ratios': []
    }
    
    for layer_idx in range(n_layers):
        # Get activation matrix for this layer
        activation_matrix = activations[layer_idx].float().cpu().numpy()
        
        # Perform SVD
        U, S, Vt = np.linalg.svd(activation_matrix, full_matrices=False)
        
        # Calculate variance metrics
        total_variance = np.sum(S**2)
        variance_explained = (S**2) / total_variance
        cumulative_variance = np.cumsum(variance_explained)
        
        # Find dimensions needed
        dims_90 = np.argmax(cumulative_variance >= 0.90) + 1
        dims_95 = np.argmax(cumulative_variance >= 0.95) + 1
        
        # Store results
        results['U'].append(U)
        results['S'].append(S)
        results['Vt'].append(Vt)
        results['variance_explained'].append(variance_explained)
        results['cumulative_variance'].append(cumulative_variance)
        results['dims_90'].append(dims_90)
        results['dims_95'].append(dims_95)
        
        # Calculate singular value ratios (s1/s2, s2/s3, etc.)
        ratios = S[:-1] / S[1:]
        results['singular_value_ratios'].append(ratios)
        
    return results

def compute_role_playing_direction_alignment(svd_results: Dict) -> np.ndarray:
    """
    Compute alignment of the first principal component (role-playing direction) across layers.
    
    Returns:
        alignment_matrix: Shape (n_layers, n_layers) - cosine similarity of v[0] vectors
    """
    n_layers = len(svd_results['Vt'])
    alignment_matrix = np.zeros((n_layers, n_layers))
    
    for i in range(n_layers):
        for j in range(n_layers):
            # Get first right singular vector (role-playing direction)
            v1 = svd_results['Vt'][i][0, :]
            v2 = svd_results['Vt'][j][0, :]
            
            # Compute cosine similarity (absolute value since direction might flip)
            alignment = abs(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
            alignment_matrix[i, j] = alignment
    
    return alignment_matrix

def analyze_assistant_trajectory(svd_results: Dict, assistant_idx: int = 0) -> Dict:
    """
    Track how the Assistant's position evolves across layers.
    
    Args:
        assistant_idx: Index of Assistant in persona list (usually 0)
    """
    trajectory = {
        'pc1_loading': [],  # Assistant's loading on PC1
        'distance_from_origin': [],  # Overall magnitude in PC space
        'separation_from_roles': []  # Average distance from other personas
    }
    
    n_layers = len(svd_results['U'])
    
    for layer_idx in range(n_layers):
        U = svd_results['U'][layer_idx]
        
        # Assistant's loading on PC1 (the dominant component)
        trajectory['pc1_loading'].append(U[assistant_idx, 0])
        
        # Distance from origin in top-k PC space (k=5)
        k = min(5, U.shape[1])
        distance = np.linalg.norm(U[assistant_idx, :k])
        trajectory['distance_from_origin'].append(distance)
        
        # Average distance from other personas in PC space
        other_personas = np.concatenate([U[:assistant_idx, :k], U[assistant_idx+1:, :k]], axis=0)
        avg_separation = np.mean([
            np.linalg.norm(U[assistant_idx, :k] - other) 
            for other in other_personas
        ])
        trajectory['separation_from_roles'].append(avg_separation)
    
    return trajectory

def find_phase_transitions(svd_results: Dict, threshold: float = 0.8) -> List[int]:
    """
    Identify layers where major transitions occur in the representation.
    
    Args:
        threshold: Variance explained threshold for detecting transitions
    
    Returns:
        List of layer indices where transitions occur
    """
    transitions = []
    variance_pc1 = [ve[0] for ve in svd_results['variance_explained']]
    
    # Look for sudden jumps in PC1 variance explained
    for i in range(1, len(variance_pc1)):
        if variance_pc1[i] - variance_pc1[i-1] > 0.1:  # 10% jump
            transitions.append(i)
    
    # Look for when PC1 first dominates (crosses threshold)
    for i, var in enumerate(variance_pc1):
        if var > threshold and i > 0 and variance_pc1[i-1] <= threshold:
            transitions.append(i)
    
    return sorted(list(set(transitions)))






In [56]:
print(stacked_activations.shape)

torch.Size([42, 30, 3584])


In [58]:


# Main analysis code
def run_cross_layer_analysis(activations, personas):
    """
    Run complete cross-layer SVD analysis.
    
    Args:
        activations: Shape (n_layers, n_personas, hidden_dim)
        personas: Dict containing persona information
    """
    # Get persona names
    all_persona_names = [personas["personas"][persona]["readable_name"] for persona in personas["personas"]]
    
    print("Starting cross-layer SVD analysis...")
    print(f"Analyzing {len(all_persona_names)} personas across {activations.shape[0]} layers")
    
    # Perform layer-wise SVD
    svd_results = perform_layer_wise_svd(activations)
    
    # Compute alignment of role-playing direction
    print("\nComputing role-playing direction alignment across layers...")
    alignment_matrix = compute_role_playing_direction_alignment(svd_results)
    
    # Analyze Assistant trajectory (assuming it's first in the list)
    print("\nAnalyzing Assistant trajectory...")
    assistant_idx = 0  # Adjust if Assistant is not first
    trajectory = analyze_assistant_trajectory(svd_results, assistant_idx)
    
    # Find phase transitions
    print("\nIdentifying phase transitions...")
    transitions = find_phase_transitions(svd_results)
    print(f"Major transitions detected at layers: {transitions}")
    
    # Print summary statistics
    print("\n" + "="*60)
    print("CROSS-LAYER ANALYSIS SUMMARY")
    print("="*60)
    
    # When does Assistant separate?
    pc1_variance = [svd_results['variance_explained'][i][0] for i in range(len(svd_results['S']))]
    separation_layer = next((i for i, v in enumerate(pc1_variance) if v > 0.5), None)
    print(f"\nAssistant separation emerges at layer: {separation_layer}")
    
    # How stable is the direction?
    diagonal_alignment = [alignment_matrix[i, i+1] for i in range(len(svd_results['S'])-1)]
    avg_stability = np.mean(diagonal_alignment)
    print(f"Average alignment between adjacent layers: {avg_stability:.3f}")
    
    # Peak separation
    peak_separation_layer = np.argmax(trajectory['separation_from_roles'])
    print(f"Peak Assistant-Role separation at layer: {peak_separation_layer}")
    
    return svd_results, alignment_matrix, trajectory, transitions

# Run the analysis
svd_results, alignment_matrix, trajectory, transitions = run_cross_layer_analysis(stacked_activations, personas)

Starting cross-layer SVD analysis...
Analyzing 30 personas across 42 layers

Computing role-playing direction alignment across layers...

Analyzing Assistant trajectory...

Identifying phase transitions...
Major transitions detected at layers: []

CROSS-LAYER ANALYSIS SUMMARY

Assistant separation emerges at layer: 0
Average alignment between adjacent layers: 0.931
Peak Assistant-Role separation at layer: 7


In [63]:

def plot_cross_layer_analysis(svd_results: Dict, alignment_matrix: np.ndarray, 
                             trajectory: Dict, save_path: str = './results/'):
    """
    Create comprehensive visualization of cross-layer SVD analysis using individual Plotly figures.
    """
    n_layers = len(svd_results['S'])
    
    # 1. Singular value evolution
    fig_sv = go.Figure()
    colors = px.colors.qualitative.Set3[:5]
    
    for i in range(min(5, len(svd_results['S'][0]))):
        singular_values = [svd_results['S'][layer][i] for layer in range(n_layers)]
        fig_sv.add_trace(
            go.Scatter(
                x=list(range(n_layers)),
                y=singular_values,
                mode='lines+markers',
                name=f's<sub>{i+1}</sub>',
                line=dict(color=colors[i], width=2),
                marker=dict(size=6)
            )
        )
    
    fig_sv.update_layout(
        title="Evolution of Singular Values Across Layers",
        xaxis_title="Layer",
        yaxis_title="Singular Value",
        template='plotly_white',
        height=500,
        width=800,
        hovermode='x unified'
    )
    #fig_sv.write_html(f'{save_path}/singular_values_evolution.html')
    fig_sv.show()
    
    # 2. Variance explained by PC1
    fig_pc1 = go.Figure()
    pc1_variance = [svd_results['variance_explained'][i][0] for i in range(n_layers)]
    
    fig_pc1.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=pc1_variance,
            mode='lines+markers',
            name='PC1 Variance',
            line=dict(color='blue', width=3),
            marker=dict(size=8),
            hovertemplate='Layer: %{x}<br>Variance: %{y:.3f}<extra></extra>'
        )
    )
    
    fig_pc1.add_hline(
        y=0.9, 
        line_dash="dash", 
        line_color="red",
        annotation_text="90% threshold",
        annotation_position="right"
    )
    
    fig_pc1.update_layout(
        title="Dominance of First Principal Component",
        xaxis_title="Layer",
        yaxis_title="Variance Explained by PC1",
        template='plotly_white',
        height=500,
        width=800
    )
    #fig_pc1.write_html(f'{save_path}/pc1_variance.html')
    fig_pc1.show()
    
    # 3. Dimensions needed for variance thresholds
    fig_dims = go.Figure()
    
    fig_dims.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=svd_results['dims_90'],
            mode='lines+markers',
            name='90% variance',
            line=dict(color='green', width=2),
            marker=dict(size=8, symbol='square')
        )
    )
    
    fig_dims.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=svd_results['dims_95'],
            mode='lines+markers',
            name='95% variance',
            line=dict(color='red', width=2),
            marker=dict(size=8, symbol='triangle-up')
        )
    )
    
    fig_dims.update_layout(
        title="Dimensionality Required Across Layers",
        xaxis_title="Layer",
        yaxis_title="Number of Dimensions",
        template='plotly_white',
        height=500,
        width=800,
        hovermode='x unified'
    )
    #fig_dims.write_html(f'{save_path}/dimensionality_analysis.html')
    fig_dims.show()
    
    # 4. Role-playing direction alignment heatmap
    fig_align = go.Figure()
    
    fig_align.add_trace(
        go.Heatmap(
            z=alignment_matrix,
            colorscale='Viridis',
            colorbar=dict(title='|cos θ|'),
        )
    )
    
    fig_align.update_layout(
        title="Alignment of Role-Playing Direction Across Layers",
        xaxis_title="Layer",
        yaxis_title="Layer",
        template='plotly_white',
        height=700,
        width=800
    )
    #fig_align.write_html(f'{save_path}/direction_alignment.html')
    fig_align.show()
    
    # 5. Assistant trajectory
    fig_traj = go.Figure()
    
    fig_traj.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=trajectory['pc1_loading'],
            mode='lines+markers',
            name='PC1 Loading',
            line=dict(color='blue', width=2),
            marker=dict(size=6),
            yaxis='y'
        )
    )
    
    fig_traj.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=trajectory['separation_from_roles'],
            mode='lines+markers',
            name='Average Separation from Roles',
            line=dict(color='red', width=2),
            marker=dict(size=6),
            yaxis='y2'
        )
    )
    
    fig_traj.update_layout(
        title="Assistant Trajectory Across Layers",
        xaxis_title="Layer",
        yaxis=dict(title="PC1 Loading", side="left"),
        yaxis2=dict(title="Average Separation", side="right", overlaying="y"),
        template='plotly_white',
        height=500,
        width=800,
        hovermode='x unified'
    )
    #fig_traj.write_html(f'{save_path}/assistant_trajectory.html')
    fig_traj.show()
    
    # 6. Singular value ratios
    fig_ratios = go.Figure()
    ratios = [svd_results['singular_value_ratios'][i][0] for i in range(n_layers)]
    
    fig_ratios.add_trace(
        go.Scatter(
            x=list(range(n_layers)),
            y=ratios,
            mode='lines+markers',
            name='s₁/s₂ ratio',
            line=dict(color='purple', width=2),
            marker=dict(size=8),
            hovertemplate='Layer: %{x}<br>Ratio: %{y:.2f}<extra></extra>'
        )
    )
    
    fig_ratios.update_layout(
        title="Dominance Ratio: First vs Second Singular Value",
        xaxis_title="Layer",
        yaxis_title="s₁/s₂",
        yaxis_type="log",
        template='plotly_white',
        height=500,
        width=800
    )
    #fig_ratios.write_html(f'{save_path}/singular_value_ratios.html')
    fig_ratios.show()
    
    # 7. Variance explained curves for multiple layers
    fig_var_curves = go.Figure()
    layers_to_show = [0, 10, 20, 30, 40]  # Adjust as needed
    colors = px.colors.qualitative.Set2[:len(layers_to_show)]
    
    for idx, layer in enumerate(layers_to_show):
        if layer < n_layers:
            cumvar = svd_results['cumulative_variance'][layer]
            fig_var_curves.add_trace(
                go.Scatter(
                    x=list(range(1, len(cumvar) + 1)),
                    y=cumvar,
                    mode='lines+markers',
                    name=f'Layer {layer}',
                    line=dict(color=colors[idx], width=2),
                    marker=dict(size=6)
                )
            )
    
    fig_var_curves.add_hline(y=0.9, line_dash="dash", line_color="gray", 
                            annotation_text="90%")
    fig_var_curves.add_hline(y=0.95, line_dash="dash", line_color="gray", 
                            annotation_text="95%")
    
    fig_var_curves.update_layout(
        title="Cumulative Variance Explained Across Components",
        xaxis_title="Number of Components",
        yaxis_title="Cumulative Variance Explained",
        template='plotly_white',
        height=500,
        width=800,
        hovermode='x unified'
    )
    #fig_var_curves.write_html(f'{save_path}/variance_curves.html')
    fig_var_curves.show()

def plot_3d_trajectory(svd_results: Dict, persona_idx: int = 0, 
                      persona_name: str = "Assistant", save_path: str = './results/'):
    """
    Create 3D visualization of persona trajectory through PC space.
    """
    n_layers = len(svd_results['U'])
    
    if len(svd_results['U'][0][0]) >= 3:
        pc1_coords = [svd_results['U'][layer][persona_idx, 0] for layer in range(n_layers)]
        pc2_coords = [svd_results['U'][layer][persona_idx, 1] for layer in range(n_layers)]
        pc3_coords = [svd_results['U'][layer][persona_idx, 2] for layer in range(n_layers)]
        
        fig_3d = go.Figure()
        
        fig_3d.add_trace(
            go.Scatter3d(
                x=pc1_coords,
                y=pc2_coords,
                z=pc3_coords,
                mode='lines+markers',
                marker=dict(
                    size=8,
                    color=list(range(n_layers)),
                    colorscale='Viridis',
                    showscale=True,
                    colorbar=dict(title="Layer", thickness=15)
                ),
                line=dict(color='darkblue', width=4),
                text=[f'Layer {i}' for i in range(n_layers)],
                hovertemplate='Layer %{text}<br>PC1: %{x:.3f}<br>PC2: %{y:.3f}<br>PC3: %{z:.3f}'
            )
        )
        
        # Add start and end markers
        fig_3d.add_trace(
            go.Scatter3d(
                x=[pc1_coords[0]], y=[pc2_coords[0]], z=[pc3_coords[0]],
                mode='markers+text',
                marker=dict(size=12, color='green'),
                text=['Start'],
                textposition='top center',
                showlegend=False
            )
        )
        
        fig_3d.add_trace(
            go.Scatter3d(
                x=[pc1_coords[-1]], y=[pc2_coords[-1]], z=[pc3_coords[-1]],
                mode='markers+text',
                marker=dict(size=12, color='red'),
                text=['End'],
                textposition='top center',
                showlegend=False
            )
        )
        
        fig_3d.update_layout(
            title=f"{persona_name} Trajectory in PC Space Across Layers",
            scene=dict(
                xaxis_title="PC1",
                yaxis_title="PC2",
                zaxis_title="PC3",
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            ),
            height=700,
            width=900,
            template='plotly_white'
        )
        
        #fig_3d.write_html(f'{save_path}/{persona_name.lower()}_trajectory_3d.html')
        fig_3d.show()

def plot_phase_transitions(svd_results: Dict, transitions: List[int], save_path: str = './results/'):
    """
    Visualize phase transitions in the representation.
    """
    fig_trans = go.Figure()
    
    pc1_variance = [svd_results['variance_explained'][i][0] for i in range(len(svd_results['S']))]
    
    fig_trans.add_trace(
        go.Scatter(
            x=list(range(len(pc1_variance))),
            y=pc1_variance,
            mode='lines',
            name='PC1 Variance Explained',
            line=dict(color='blue', width=3)
        )
    )
    
    # Add vertical lines for transitions
    for t in transitions:
        fig_trans.add_vline(
            x=t,
            line_dash="dash",
            line_color="red",
            line_width=2,
            annotation_text=f"Transition<br>Layer {t}",
            annotation_position="top"
        )
    
    # Add shaded regions for phases
    if transitions:
        phases = [0] + transitions + [len(pc1_variance)-1]
        colors = px.colors.qualitative.Pastel[:len(phases)-1]
        
        for i in range(len(phases)-1):
            fig_trans.add_vrect(
                x0=phases[i], x1=phases[i+1],
                fillcolor=colors[i], opacity=0.2,
                layer="below", line_width=0
            )
    
    fig_trans.update_layout(
        title="Phase Transitions in Role Representation",
        xaxis_title="Layer",
        yaxis_title="PC1 Variance Explained",
        template='plotly_white',
        height=500,
        width=1000,
        showlegend=True
    )
    
    #fig_trans.write_html(f'{save_path}/phase_transitions.html')
    fig_trans.show()

In [64]:
plot_cross_layer_analysis(svd_results, alignment_matrix, trajectory)

## Steering on vector

In [None]:
print(len(svd_results["Vt"]))
print(svd_results["Vt"][0].shape)

# get the role-playing direction for each layer from the SVD
# need to change np.ndarray to torch.Tensor
rp_vectors = torch.stack([-torch.from_numpy(svd_results["Vt"][i][0]) for i in range(len(svd_results["Vt"]))])
print(rp_vectors.shape)

print(rp_vectors)

In [None]:
# steer on role-playing direction
# we could also project activations before any model response onto this direction as a barometer of role-playing

In [None]:
# Alternative: SVD on contrast vectors (roles - control) for cleaner role-playing direction
print("Computing role-playing direction by subtracting control first...")

def compute_contrast_svd_results(activations, control_idx=0):
    """
    Perform SVD on contrast vectors (roles - control) for each layer.
    
    Args:
        activations: Shape (n_layers, n_personas, hidden_dim)
        control_idx: Index of control persona (usually 0 for assistant)
    
    Returns:
        Dict containing SVD results for contrast vectors
    """
    n_layers = activations.shape[0]
    results = {
        'U': [],
        'S': [],
        'Vt': [],
        'variance_explained': [],
        'cumulative_variance': []
    }
    
    for layer_idx in range(n_layers):
        # Get activations for this layer
        layer_activations = activations[layer_idx].float().cpu().numpy()
        
        # Extract control and role activations
        control_activation = layer_activations[control_idx:control_idx+1, :]  # Shape: (1, hidden_dim)
        role_activations = np.concatenate([
            layer_activations[:control_idx, :], 
            layer_activations[control_idx+1:, :]
        ], axis=0)  # All personas except control
        
        # Compute contrast matrix: roles - control
        contrast_matrix = role_activations - control_activation  # Broadcasting
        
        # Perform SVD on contrast vectors
        U, S, Vt = np.linalg.svd(contrast_matrix, full_matrices=False)
        
        # Calculate variance metrics
        total_variance = np.sum(S**2)
        variance_explained = (S**2) / total_variance
        cumulative_variance = np.cumsum(variance_explained)
        
        # Store results
        results['U'].append(U)
        results['S'].append(S)
        results['Vt'].append(Vt)
        results['variance_explained'].append(variance_explained)
        results['cumulative_variance'].append(cumulative_variance)
    
    return results

# Compute contrast SVD results
contrast_svd_results = compute_contrast_svd_results(stacked_activations, control_idx=0)

print(f"Contrast SVD shapes - layer 0:")
print(f"  U: {contrast_svd_results['U'][0].shape}")
print(f"  S: {contrast_svd_results['S'][0].shape}")  
print(f"  Vt: {contrast_svd_results['Vt'][0].shape}")

# Extract role-playing directions (no negation needed since we want direction away from control)
contrast_rp_vectors = torch.stack([torch.from_numpy(contrast_svd_results["Vt"][i][0]) for i in range(len(contrast_svd_results["Vt"]))])
print(f"\\nContrast role-playing vectors shape: {contrast_rp_vectors.shape}")

# Compare variance explained by first component
print(f"\\nVariance explained by first component (contrast method):")
for layer_idx in [0, 10, 20, 30, 40]:
    if layer_idx < len(contrast_svd_results['variance_explained']):
        var_exp = contrast_svd_results['variance_explained'][layer_idx][0]
        print(f"  Layer {layer_idx}: {var_exp:.3f}")

print(f"\\nFor comparison, raw activation method at layer {LAYER}:")
print(f"  Raw SVD first component variance: {svd_results['variance_explained'][LAYER][0]:.3f}")
print(f"  Contrast SVD first component variance: {contrast_svd_results['variance_explained'][LAYER][0]:.3f}")

print(f"\\nContrast method gives you:")
print(f"  • Direction explicitly defined as 'away from assistant baseline'")
print(f"  • Positive steering = toward role-playing")
print(f"  • Zero steering = assistant baseline")
print(f"  • Cleaner interpretation for general role-playing direction")