# Trajectories of role + trait space

In [None]:
import json
import os
import sys
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

import re

sys.path.append('.')
sys.path.append('..')

from utils.internals import ProbingModel, ConversationEncoder, ActivationExtractor

In [2]:
# Model / run configuration shared by the rest of the notebook.
# (The PCA results themselves are loaded further down, not in this cell.)
CHAT_MODEL_NAME = "Qwen/Qwen3-32B"
# NOTE(review): BASE_MODEL_NAME is commented out here but is still referenced by
# the base-model loading cell below; define it before running that cell,
# otherwise it raises NameError.
#BASE_MODEL_NAME = "google/gemma-2-27b"
model_readable = "Qwen 3 32B"  # display name used in plot subtitles
model_short = "qwen-3-32b"  # slug used to build /workspace and plot paths
layer = 32  # layer index used for extraction, PCA file names and plot subtitles

auditor_model = "sonnet-4.5"  # used in the activation/plot directory names and domain plot titles

In [None]:
# filename = "auto40"
# Number of leading principal components passed as n_pcs to the similarity /
# projection helpers below (and therefore the number of lines per plot).
components = 1

In [51]:
# Sub-directory names (under /workspace/{model_short}) that hold the role and trait PCA results.
role_dir = "roles_240" 
trait_dir = "traits_240"

# Where this notebook saves the per-conversation activations it extracts itself.
acts_output_dir = f"/workspace/{model_short}/conversations"
# Directory of activation files read by the loading cells further down.
acts_input_dir = f"/workspace/{model_short}/dynamics/{auditor_model}/default/activations"
# Destination for the exported HTML plots.
plot_output_dir = f"/root/git/plots/{model_short}/trajectory/{auditor_model}"
# Create both output directories up front (no error if they already exist).
os.makedirs(acts_output_dir, exist_ok=True)
os.makedirs(plot_output_dir, exist_ok=True)

## IF NEEDED: Get conversation activations

In [4]:
# Load one saved conversation transcript by name.
filename = "spiral"
# Build the path from model_short (like the other paths in this notebook)
# instead of repeating the model slug by hand.
conversation_file = f"/root/git/persona-subspace/dynamics/results/{model_short}/interactive/{filename}.json"
# Use a context manager so the file handle is closed deterministically.
with open(conversation_file, encoding="utf-8") as f:
    conversation_obj = json.load(f)
# List of message dicts; later cells read the 'role' and 'content' keys.
conversation = conversation_obj['conversation']

In [None]:
# Load the chat model through the project's ProbingModel wrapper and keep
# direct handles to the underlying model and tokenizer for later cells.
pm = ProbingModel(CHAT_MODEL_NAME)
model = pm.model
tokenizer = pm.tokenizer

In [6]:
# Extra keyword arguments forwarded to ConversationEncoder.build_turn_spans.
chat_kwargs = {}
if model_short == "qwen-3-32b":    
    # Qwen 3 only: disable "thinking" — presumably a chat-template flag; confirm in utils.internals.
    chat_kwargs['enable_thinking'] = False


In [None]:
# Use new API to extract mean activations per turn
# Wrap the already-loaded model/tokenizer rather than loading the weights a second time.
pm_for_extraction = ProbingModel.from_existing(model, tokenizer)
encoder = ConversationEncoder(pm_for_extraction.tokenizer)
extractor = ActivationExtractor(pm_for_extraction, encoder)

# Extract full activations for the conversation
# (only `layer` is requested; indexed below as [0, token, ...] — presumably
# (n_layers, n_tokens, hidden); TODO confirm against ActivationExtractor)
full_acts = extractor.for_conversation(conversation, layers=[layer])

# Get turn spans
# Each span is unpacked below as a (start, end) pair used to slice the token axis.
# NOTE(review): chat_format=False is passed here but not in the base-model cell — confirm which is intended.
_, turn_spans = encoder.build_turn_spans(conversation, chat_format=False, **chat_kwargs)

# Compute mean activation per turn (for assistant turns)
# NOTE(review): whether turn_spans covers assistant turns only depends on build_turn_spans — verify.
mean_acts_per_turn = [full_acts[0, start:end].mean(dim=0) for start, end in turn_spans]

In [57]:
# Sanity check: number of turns, then the shape of one per-turn vector.
print(len(mean_acts_per_turn))
# squeeze(0) only removes a leading dimension of size 1; otherwise the tensor is unchanged.
mean_acts_per_turn = [act.squeeze(0) for act in mean_acts_per_turn]
print(mean_acts_per_turn[0].shape)

46
torch.Size([5120])


In [None]:
# save activations 
# Bundle the raw conversation object with its per-turn mean activations so the
# pair can be reloaded later from one file.
result = {}
result['object'] = conversation_obj
result['activations'] = mean_acts_per_turn  # list of per-turn tensors at this point
torch.save(result, f"{acts_output_dir}/{filename}.pt")

In [7]:
# Load the chat model's tokenizer and the base model for the instruct-vs-base contrast.
# Note: this rebinds the notebook-level `tokenizer` name.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): padding_side is only changed when pad_token was missing — confirm this is intended.
    tokenizer.padding_side = "left"

# NOTE(review): BASE_MODEL_NAME is commented out in the configuration cell, so this
# raises NameError unless it is defined first.
base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto"  # Use all GPUs
    )
base_model.eval()

Loading checkpoint shards:   0%|          | 0/24 [00:00<?, ?it/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
        (pre_feedforward_layernorm): G

In [None]:
# Use new API for base model activations
# Same pipeline as for the chat model, but wrapping `base_model`.
pm_base = ProbingModel.from_existing(base_model, tokenizer)
encoder_base = ConversationEncoder(pm_base.tokenizer)
extractor_base = ActivationExtractor(pm_base, encoder_base)

# Extract full activations
full_acts_base = extractor_base.for_conversation(conversation, layers=[layer])

# Get turn spans
# NOTE(review): the chat-model cell passes chat_format=False here; this call does not — confirm which is intended.
_, turn_spans_base = encoder_base.build_turn_spans(conversation, **chat_kwargs)

# Compute mean activation per turn
# Result is a Python list of tensors (one per span), not a stacked tensor.
base_mean_acts_per_turn = [full_acts_base[0, start:end].mean(dim=0) for start, end in turn_spans_base]

# Save results
base_result = {}
base_result['object'] = conversation_obj
base_result['activations'] = base_mean_acts_per_turn
torch.save(base_result, f"{acts_output_dir}/{filename}_base.pt")

In [None]:
# Reload the per-turn activations saved by the chat-model extraction cell
# (same location as acts_output_dir), mapping tensors onto the CPU.
# weights_only=False unpickles arbitrary objects — only use on trusted files.
activations_obj = torch.load(f"/workspace/{model_short}/conversations/{filename}.pt", weights_only=False, map_location="cpu")
mean_acts_per_turn = activations_obj['activations']

torch.Size([36, 4608])


In [19]:
def _stack_if_list(acts):
    """Return a stacked tensor when given a list of per-turn tensors, else the input unchanged."""
    return torch.stack(acts) if isinstance(acts, list) else acts


# Per-turn contrast: chat-model activations minus base-model activations.
# Both inputs may be a list of per-turn tensors (fresh from extraction) or an
# already-stacked tensor, so normalise both before subtracting; subtracting a
# Python list from a tensor is not supported.
contrast_vector = _stack_if_list(mean_acts_per_turn) - _stack_if_list(base_mean_acts_per_turn)
print(contrast_vector.shape)

torch.Size([36, 4608])


## OR: Load activations 

In [6]:
# load every activation file in the directory if the name matches a regex

def mean_auto_activations(acts_input_dir, regex_pattern, layer):
    """Average per-turn activations at one layer over all matching files.

    Every file in ``acts_input_dir`` whose name matches ``regex_pattern``
    (``re.match``, i.e. anchored at the start of the name) is loaded with
    ``torch.load``. Its ``'activations'`` entry must be a 3-D tensor indexed as
    ``[turn, layer, hidden]``. If a file has an odd number of turns the last
    one is dropped, and files left with zero turns are skipped.

    Parameters:
    - acts_input_dir: Directory containing the saved activation files
    - regex_pattern: Pattern matched against each file name
    - layer: Index into the second axis of each activation tensor

    Returns:
    - Tensor of shape (max_turns, H): for each turn index, the mean over all
      files that are long enough to contain that turn (shorter files are
      NaN-padded and ignored by ``torch.nanmean``)

    Raises:
    - ValueError: if no usable file is found, or if the hidden sizes differ
    """
    pattern = re.compile(regex_pattern)  # compile once instead of per file
    slices = []
    hidden_size = None

    # Sort so the set of files is processed in a reproducible order.
    for file in sorted(os.listdir(acts_input_dir)):
        if not pattern.match(file):
            continue

        # NOTE: weights_only=False unpickles arbitrary objects; only point this at trusted files.
        acts = torch.load(
            os.path.join(acts_input_dir, file), weights_only=False, map_location="cpu"
        )['activations']
        n_turns, _, hidden_cur = acts.shape

        # Keep an even number of turns (original note: an odd last turn means the user ended the conversation).
        usable_turns = n_turns - (n_turns % 2)
        if usable_turns == 0:
            continue

        if hidden_size is None:
            hidden_size = hidden_cur
        elif hidden_size != hidden_cur:
            raise ValueError(
                f"Hidden size mismatch in {file}: expected {hidden_size}, got {hidden_cur}"
            )

        slices.append(acts[:usable_turns, layer, :])  # (usable_turns, hidden)

    if not slices:
        raise ValueError("No usable activations after dropping odd last turns.")

    # Pad shorter conversations with NaN up to (n_files, max_turns, hidden).
    n_files = len(slices)
    max_turns = max(sl.shape[0] for sl in slices)
    print(f"Padding {n_files} activations found to ({n_files}, {max_turns}, {hidden_size})")
    padded = torch.full((n_files, max_turns, hidden_size), float('nan'))
    for i, sl in enumerate(slices):
        padded[i, :sl.shape[0], :] = sl

    # Mean over the file axis only, ignoring the NaN padding.
    return torch.nanmean(padded, dim=0)


In [42]:
# Conversation domains; each is used as a file-name prefix when loading activations.
domains = ["coding", "writing", "therapy", "philosophy"]

# domain -> mean per-turn activations at `layer`, averaged over all matching files.
instruct = {}
# NOTE(review): `base` is disabled here, but the contrast cell below still reads it.
#base = {}

for domain in domains:
    instruct[domain] = mean_auto_activations(f"{acts_input_dir}", r'^' + domain, layer)
    #base[domain] = mean_auto_activations(f"/workspace/{model_short}/dynamics/base", r'^' + domain, layer)



Padding 100 activations found to (100, 28, 4608)
Padding 100 activations found to (100, 28, 4608)
Padding 100 activations found to (100, 28, 4608)
Padding 100 activations found to (100, 28, 4608)


In [None]:
# Per-domain contrast: instruct-model mean activations minus base-model mean activations.
# NOTE(review): `base` is never populated (its definition is commented out in the
# previous cell), so this cell raises NameError as written.
contrast = {}
for domain in domains:
    contrast[domain] = instruct[domain] - base[domain]


In [None]:
# Load precomputed per-turn activations and the matching transcript for one conversation.
filename = "spiral"

# weights_only=False unpickles arbitrary objects — only use on trusted files.
activations_obj = torch.load(f"{acts_input_dir}/{filename}.pt", weights_only=False, map_location="cpu")
mean_acts_per_turn = activations_obj['activations']

# `.shape` assumes the saved activations are a single tensor rather than a list.
print(mean_acts_per_turn.shape)

conversation_file = f"/root/git/persona-subspace/dynamics/results/{model_short}/transcripts/{filename}.json"
# Use a context manager so the file handle is closed deterministically.
with open(conversation_file, encoding="utf-8") as f:
    conversation_obj = json.load(f)
conversation = conversation_obj['conversation']


## Cosine similarity with trait and role PCs

In [58]:
# Load the precomputed role and trait PCA results for this layer.
# Later cells read the 'pca' and 'scaler' entries of each loaded object.
# weights_only=False unpickles arbitrary objects — only use on trusted files.
# NOTE(review): unlike the other torch.load calls here, no map_location is given.
role_results = torch.load(f"/workspace/{model_short}/{role_dir}/pca/layer{layer}_pos23.pt", weights_only=False)
trait_results = torch.load(f"/workspace/{model_short}/{trait_dir}/pca/layer{layer}_pos-neg50.pt", weights_only=False)


In [18]:
def pc_cosine_similarity(mean_acts_per_turn, pca_results, n_pcs=8):
    """
    Cosine similarity between each turn's activation and the leading PCs.

    Parameters:
    - mean_acts_per_turn: List of 1-D tensors, or an already stacked 2-D tensor
    - pca_results: Mapping whose 'pca' entry exposes a `components_` array
    - n_pcs: Number of leading components to compare against

    Returns:
    - Numpy array of shape (n_turns, n_pcs)
    """
    acts = torch.stack(mean_acts_per_turn) if isinstance(mean_acts_per_turn, list) else mean_acts_per_turn

    # Unit-length activation rows, converted to numpy for the matrix product.
    unit_acts = F.normalize(acts, dim=1).float().numpy()

    # Unit-length principal directions (only the first n_pcs rows).
    top_pcs = pca_results['pca'].components_[:n_pcs]
    unit_pcs = top_pcs / np.linalg.norm(top_pcs, axis=1, keepdims=True)

    return unit_acts @ unit_pcs.T

def pc_projection(mean_acts_per_turn, pca_results, n_pcs=8):
    """
    Project each turn's activation into PC space and keep the leading components.

    Parameters:
    - mean_acts_per_turn: List of 1-D tensors, or an already stacked 2-D tensor
    - pca_results: Mapping with a 'scaler' and a 'pca' entry, each exposing `transform`
    - n_pcs: Number of leading components to keep

    Returns:
    - Array of shape (n_turns, n_pcs): scaler transform followed by PCA transform
    """
    acts = mean_acts_per_turn
    if isinstance(acts, list):
        acts = torch.stack(acts)

    features = acts.float().numpy()
    # Apply the stored scaler first, then the PCA fitted on the scaled data.
    scores = pca_results['pca'].transform(pca_results['scaler'].transform(features))
    return scores[:, :n_pcs]
    
    

In [59]:
# Cosine similarity of every turn with the first `components` role / trait PCs.
role_sims = pc_cosine_similarity(mean_acts_per_turn, role_results, components)
trait_sims = pc_cosine_similarity(mean_acts_per_turn, trait_results, components)
print(role_sims.shape)

# Scaled PCA projections of every turn onto the same PCs.
role_projs = pc_projection(mean_acts_per_turn, role_results, components)
trait_projs = pc_projection(mean_acts_per_turn, trait_results, components)
print(role_projs.shape)


(46, 6)
(46, 6)


In [None]:
# Peek at the first message of the loaded conversation.
print(conversation[0])

## Plot trajectory

In [None]:
def _truncate_text(text, limit):
    """Cut `text` to `limit` characters, appending "..." only when something was removed."""
    return text[:limit] + "..." if len(text) > limit else text


def _wrap_hover_text(text, max_chars_per_line=70):
    """Greedily wrap `text` on whitespace, joining the lines with <br> for Plotly hover boxes."""
    if len(text) <= max_chars_per_line:
        return text

    lines = []
    current_line = []
    current_length = 0  # characters in current_line, not counting the joining spaces

    for word in text.split():
        # len(current_line) equals the number of spaces needed once `word` is appended.
        if current_length + len(word) + len(current_line) > max_chars_per_line:
            if current_line:  # Don't add empty lines
                lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word)
        else:
            current_line.append(word)
            current_length += len(word)

    if current_line:  # Add the last line
        lines.append(' '.join(current_line))

    return '<br>'.join(lines)


def _turn_hover_texts(n_turns, conversation, user_limit=200, assistant_limit=300):
    """Build one hover label per plotted turn, with user/assistant excerpts when available."""
    labels = [f"<b>Turn {turn_idx}</b>" for turn_idx in range(n_turns)]
    if conversation is None:
        return labels

    # Row i of the plotted matrix is paired with the i-th assistant message.
    assistant_positions = [i for i, turn in enumerate(conversation) if turn['role'] == 'assistant']
    for turn_idx, conv_idx in enumerate(assistant_positions[:n_turns]):
        parts = [labels[turn_idx]]

        # Preceding user message, if there is one.
        if conv_idx > 0 and conversation[conv_idx - 1]['role'] == 'user':
            user_content = conversation[conv_idx - 1]['content']
            if user_content:
                user_wrapped = _wrap_hover_text(_truncate_text(user_content, user_limit), 70)
                parts.append(f"<b>User:</b> {user_wrapped}")

        assistant_content = conversation[conv_idx]['content']
        assistant_wrapped = _wrap_hover_text(_truncate_text(assistant_content, assistant_limit), 70)
        parts.append(f"<b>Assistant:</b> {assistant_wrapped}")

        labels[turn_idx] = '<br>'.join(parts)

    return labels


def plot_mean_response_trajectory(similarity_matrix, conversation=None, title=None, pc_titles=None, projection=False):
    """
    Create a single line plot showing mean response per turn.

    Reads the notebook-level `model_readable` and `layer` for the subtitle.

    Parameters:
    - similarity_matrix: Numpy matrix of shape (n_turns, n_pcs)
    - conversation: Optional list of message dicts with 'role' and 'content';
      row i of the matrix is assumed to correspond to the i-th assistant message
    - title: Optional custom title
    - pc_titles: Optional list of PC titles; missing entries fall back to "PC<k>"
    - projection: Whether this is projection or cosine similarity (selects the
      y-axis title and the hover label)

    Returns:
    - Plotly figure object
    """

    print("Creating mean response trajectory plot...")

    # Get dimensions
    n_turns, n_pcs = similarity_matrix.shape
    turn_indices = np.arange(n_turns)

    # Default PC names, overridden by whatever titles the caller supplied.
    pc_names = [f"PC{i+1}" for i in range(n_pcs)]
    if pc_titles is not None:
        for i, pc_title in enumerate(pc_titles[:n_pcs]):
            pc_names[i] = pc_title

    # Define color palette for PCs
    pc_colors = [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
        '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
        '#bcbd22', '#17becf', '#aec7e8', '#ffbb78'
    ]

    turn_contexts = _turn_hover_texts(n_turns, conversation)

    if projection:
        yaxis_title = "PC Projection"
        hover_label = "Projection"
    else:
        yaxis_title = "Cosine Similarity with PC"
        hover_label = "Cosine similarity"

    # Create Plotly figure
    fig = go.Figure()

    # Add line traces for each PC with markers
    for pc_idx, pc_name in enumerate(pc_names):
        color = pc_colors[pc_idx % len(pc_colors)]

        fig.add_trace(go.Scatter(
            x=turn_indices,
            y=similarity_matrix[:, pc_idx],
            mode='lines+markers',
            name=pc_name,
            line=dict(color=color, width=2),
            marker=dict(color=color, size=4, opacity=0.8),
            # Plotly placeholders stay in plain (non-f) strings so their braces are kept literally.
            hovertemplate='<b>%{fullData.name}</b><br>' +
                         '%{text}<br>' +
                         f'<b>{hover_label}:</b> ' + '%{y:.3f}<br>' +
                         '<extra></extra>',
            text=turn_contexts
        ))

    # Update layout
    default_title = "Mean Response PC Trajectory"

    fig.update_layout(
        title=dict(
            text=title if title else default_title,
            x=0.5,
            font=dict(size=16),
            subtitle={"text": f"{model_readable}, Layer {layer}"}
        ),
        xaxis_title="Conversation Turn",
        yaxis_title=yaxis_title,
        width=1400,
        height=600,
        hovermode='closest',
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.8)"
        )
    )

    # Add grid for easier reading
    fig.update_xaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor='lightgray',
        zeroline=True,
        tick0=0,
        dtick=1  # Show every turn
    )
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)

    # Add light vertical lines between turns for clarity
    for turn_idx in range(1, n_turns):
        fig.add_vline(
            x=turn_idx - 0.5,
            line_dash="dot",
            line_color="lightgray",
            line_width=1,
            opacity=0.3
        )

    print(f"Created trajectory plot with {n_pcs} PC lines across {n_turns} turns")

    return fig

In [61]:
# Create role PC titles
# Human-readable labels for the leading role PCs, keyed by model_short.
_ROLE_PC_TITLES = {
    "qwen-3-32b": ['Assistant-like ↔ role-playing', "mystical/transcendent ↔ mundane/irreverent", "empathetic/vulnerable ↔ analytical/predatory", "concrete/practical ↔ abstract/ideological", "thinking/passive ↔ doing/active", "creative/expressive ↔ rigid/constrained"],
    "gemma-2-27b": ['Assistant-like ↔ role-playing', "inhuman ↔ human", "independent ↔ dependent", "nurturing ↔ adversarial", "social ↔ technical", "structured ↔ liminal"],
    "llama-3.3-70b": ['Assistant-like ↔ role-playing', "mystical ↔ mundane", "empathetic ↔ analytical", "adversarial ↔ nurturing", "provocative ↔ submissive", "social ↔ technical"],
}
# Fail loudly for an unknown model instead of leaving role_pc_titles undefined
# (or stale from an earlier run of this cell with a different model).
if model_short not in _ROLE_PC_TITLES:
    raise ValueError(f"No role PC titles defined for model_short={model_short!r}")
role_pc_titles = _ROLE_PC_TITLES[model_short]
# # Plot role trajectory
# role_fig = plot_mean_response_trajectory(
#     role_sims, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Role PC Space: {filename.capitalize()}", 
#     pc_titles=role_pc_titles
# )
# role_fig.show()
# role_fig.write_html(f"{plot_output_dir}/{filename}_role_cossim.html")

# role_fig = plot_mean_response_trajectory(
#     role_projs, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Role PC Space: {filename.capitalize()}", 
#     pc_titles=role_pc_titles,
#     projection=True
# )
# role_fig.show()
# role_fig.write_html(f"{plot_output_dir}/{filename}_role_proj.html")

In [62]:
# Create trait PC titles
# Human-readable labels for the leading trait PCs, keyed by model_short.
_TRAIT_PC_TITLES = {
    "qwen-3-32b": ["expressive/irreverent ↔ controlled/professional", "analytical ↔ intuitive", "accessible/practical ↔ esoteric/complex", "active/flexible ↔ passive/rigid", "questioning ↔ confident", "indirect/diplomatic ↔ direct/assertive"],
    "gemma-2-27b": ["benevolent ↔ evil", "analytical ↔ emotional", "accommodating ↔ confrontational", "grounded ↔ mystical", "thoughtful ↔ decisive", "conformist ↔ defiant"],
    "llama-3.3-70b": ["agreeable ↔ antagonistic", "emotional ↔ analytical", "bombastic ↔ accessible", "assertive ↔ circumspect", "exploratory ↔ absolutist", "independent ↔ anxious"],
}
# Fail loudly for an unknown model instead of leaving trait_pc_titles undefined
# (or stale from an earlier run of this cell with a different model).
if model_short not in _TRAIT_PC_TITLES:
    raise ValueError(f"No trait PC titles defined for model_short={model_short!r}")
trait_pc_titles = _TRAIT_PC_TITLES[model_short]
# Plot trait trajectory
# trait_fig = plot_mean_response_trajectory(
#     trait_sims, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Trait PC Space: {filename.capitalize()}", 
#     pc_titles=trait_pc_titles
# )
# trait_fig.show()
# trait_fig.write_html(f"{plot_output_dir}/{filename}_trait_cossim.html")

# trait_fig = plot_mean_response_trajectory(
#     trait_projs, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Trait PC Space: {filename.capitalize()}", 
#     pc_titles=trait_pc_titles,
#     projection=True
# )

# trait_fig.show()
# trait_fig.write_html(f"{plot_output_dir}/{filename}_trait_proj.html")

In [None]:
# Create a combined plot showing both role and trait trajectories

def combined_fig(role_projs, trait_projs, title, conversation=None):
    """
    Stack the role and trait PC trajectories in one two-row figure.

    Each row reuses the traces built by plot_mean_response_trajectory (in
    projection mode) and groups them in the legend under "Role PCs" or
    "Trait PCs". Reads the notebook-level role_pc_titles, trait_pc_titles,
    model_readable and layer.

    Parameters:
    - role_projs: Matrix of shape (n_turns, n_pcs) in role PC space
    - trait_projs: Matrix of shape (n_turns, n_pcs) in trait PC space
    - title: Main figure title
    - conversation: Optional conversation used for hover text

    Returns:
    - Plotly figure object with two vertically stacked subplots
    """
    fig = sp.make_subplots(
        rows=2, cols=1,
        subplot_titles=("Trajectory in Role PC Space", "Trajectory in Trait PC Space"),
        vertical_spacing=0.1
    )

    # (row, matrix, panel title, PC titles, legend group id, legend group heading)
    panels = (
        (1, role_projs, "<b>Role PCs</b>", role_pc_titles, "role", "Role PCs"),
        (2, trait_projs, "<b>Trait PCs</b>", trait_pc_titles, "trait", "Trait PCs"),
    )
    for row, projs, panel_title, panel_pc_titles, group, group_heading in panels:
        panel_fig = plot_mean_response_trajectory(
            projs, conversation=conversation, title=panel_title, pc_titles=panel_pc_titles, projection=True
        )
        for trace in panel_fig.data:
            trace.update(
                showlegend=True,
                legendgroup=group,
                legendgrouptitle_text=group_heading
            )
            fig.add_trace(trace, row=row, col=1)

    legend_style = dict(
        traceorder="grouped",     # keeps groups together
        groupclick="toggleitem",
        x=1.02, xanchor="left",   # place legend outside on the right
        y=1, yanchor="top"
    )
    fig.update_layout(
        height=800,
        width=1200,
        title={
            "text": title,
            "subtitle": {"text": f"{model_readable}, Layer {layer}"}
        },
        showlegend=True,
        legend=legend_style,
        legend_tracegroupgap=12       # space between groups
    )

    # Only the bottom subplot carries the x-axis label; both get a y-axis label.
    fig.update_xaxes(title_text="Conversation Turn", row=2, col=1)
    for row in (1, 2):
        fig.update_yaxes(title_text="PC Projection", row=row, col=1)

    return fig


In [None]:
# Render the combined role/trait trajectory for the loaded conversation and export it as HTML.
fig = combined_fig(role_projs, trait_projs, f"Conversation Trajectories in Persona Subspace: {filename}", conversation=conversation)
fig.show()
fig.write_html(f"{plot_output_dir}/{filename}.html")




In [None]:
# Same combined plot, but for the instruct-minus-base contrast vectors.
# Note: this overwrites role_projs / trait_projs from the earlier cell, and no
# conversation is passed, so hover text only shows the turn index.
role_projs = pc_projection(contrast_vector, role_results, components)
trait_projs = pc_projection(contrast_vector, trait_results, components)
domain_fig = combined_fig(role_projs, trait_projs, f"Conversation Trajectories in Persona Subspace: {filename.capitalize()} (Instruct - Base)")
domain_fig.show()
domain_fig.write_html(f"{plot_output_dir}/{filename}_contrast.html")

In [22]:
def plot_domain_subplots(domain_projs, title, pc_titles, space_type="role", domains=None, y_pad=10):
    """
    Create a 2x2 subplot figure showing trajectories for up to 4 domains.

    All subplots share the same x and y ranges so they can be compared
    directly. Reads the notebook-level `model_readable` and `layer` for the
    subtitle.

    Parameters:
    - domain_projs: Dict mapping domain names to projection matrices
                    Each matrix should be shape (n_turns, n_pcs); NaNs are
                    ignored when computing the shared y range
    - title: Main title for the figure
    - pc_titles: List of PC titles for the legend
    - space_type: Either "role" or "trait"; shown (capitalized) in the legend title
    - domains: Optional ordered list of at most 4 keys of `domain_projs` to plot;
               defaults to ["coding", "writing", "therapy", "philosophy"]
    - y_pad: Margin added below/above the global min/max of the y range

    Returns:
    - Plotly figure object with 4 subplots (2x2 grid)

    Raises:
    - ValueError: if more than 4 domains are requested
    """

    if domains is None:
        domains = ["coding", "writing", "therapy", "philosophy"]
    if len(domains) > 4:
        raise ValueError("plot_domain_subplots supports at most 4 domains (2x2 grid)")

    # Define color palette for PCs (same as plot_mean_response_trajectory)
    pc_colors = [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
        '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
        '#bcbd22', '#17becf', '#aec7e8', '#ffbb78'
    ]

    mats = [np.asarray(domain_projs[domain]) for domain in domains]

    # Compute global x and y ranges across all domains
    # x spans turn indices 0..max_turns-1 with a margin of one on each side.
    global_x_range = [-1, max(mat.shape[0] for mat in mats)]
    finite_vals = np.concatenate([mat[~np.isnan(mat)] for mat in mats])
    global_y_range = [finite_vals.min() - y_pad, finite_vals.max() + y_pad]

    # Create subplot titles
    subplot_titles = [f"100 {domain.capitalize()} Conversations" for domain in domains]

    # Create 2x2 subplot figure
    fig = sp.make_subplots(
        rows=2, cols=2,
        subplot_titles=subplot_titles,
        vertical_spacing=0.1,
        horizontal_spacing=0.02
    )

    # Track which PCs have been added to legend
    shown_pcs = set()

    # Add traces for each domain, filling the grid row by row
    for domain_idx, mat in enumerate(mats):
        row, col = divmod(domain_idx, 2)
        n_turns, n_pcs = mat.shape
        turn_indices = np.arange(n_turns)

        # Add a trace for each PC
        for pc_idx in range(n_pcs):
            pc_name = pc_titles[pc_idx]
            color = pc_colors[pc_idx % len(pc_colors)]

            # Only show legend for first occurrence of each PC
            showlegend = pc_name not in shown_pcs
            shown_pcs.add(pc_name)

            fig.add_trace(go.Scatter(
                x=turn_indices,
                y=mat[:, pc_idx],
                mode='lines+markers',
                name=pc_name,
                line=dict(color=color, width=1.5),
                marker=dict(color=color, size=3, opacity=0.8),
                legendgroup=pc_name,
                showlegend=showlegend,
                hovertemplate='<b>%{fullData.name}</b><br>' +
                             'Turn: %{x}<br>' +
                             'Projection: %{y:.3f}<br>' +
                             '<extra></extra>'
            ), row=row + 1, col=col + 1)

    # Update layout
    fig.update_layout(
        height=700,
        width=1050,
        title={
            "text": title,
            "subtitle": {"text": f"{model_readable}, Layer {layer}"}
        },
        showlegend=True,
        legend=dict(
            title=dict(
                    text=f"{space_type.capitalize()} PCs:",
                    font=dict(size=12)
                    ),
            traceorder="normal",
            groupclick="toggleitem",
            x=1.02,
            xanchor="left",
            y=1,
            yanchor="top"
        )
    )

    # Shared range, gridlines and smaller tick fonts; without row/col the
    # update applies to every subplot axis.
    fig.update_xaxes(range=global_x_range, showgrid=True, gridwidth=1, gridcolor='lightgray',
                     zeroline=True, tickfont=dict(size=10))
    fig.update_yaxes(range=global_y_range, showgrid=True, gridwidth=1, gridcolor='lightgray',
                     zeroline=True, tickfont=dict(size=10))

    # Hide y-axis tick labels for right column
    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=2, col=2)

    # Add axis labels only to left and bottom edges
    fig.update_xaxes(title_text="Conversation Turn", title_font=dict(size=12), row=2, col=1)
    fig.update_xaxes(title_text="Conversation Turn", title_font=dict(size=12), row=2, col=2)

    y_label = "PC Loading"
    fig.update_yaxes(title_text=y_label, title_font=dict(size=12), row=1, col=1)
    fig.update_yaxes(title_text=y_label, title_font=dict(size=12), row=2, col=1)

    # Make subplot titles slightly smaller
    for ann in fig.layout.annotations:
        ann.font.size = 14

    return fig



In [44]:

# Role-space trajectories for every domain.
# instruct[domain][1::2] keeps every second turn starting at index 1 —
# presumably the assistant turns; verify against how the activations were saved.
role_domain_projs = {domain: pc_projection(instruct[domain][1::2], role_results, components) 
                     for domain in domains}
fig = plot_domain_subplots(role_domain_projs, 
                           f"Persona Trajectories in Role Subspace ({auditor_model.title()} as Human User)", 
                           role_pc_titles, 
                           space_type="role")
fig.show()
fig.write_html(f"{plot_output_dir}/domains_role.html")

In [47]:
# Trait-space trajectories for every domain.
# instruct[domain][1::2] keeps every second turn starting at index 1 —
# presumably the assistant turns; verify against how the activations were saved.
trait_domain_projs = {domain: pc_projection(instruct[domain][1::2], trait_results, components)
                      for domain in domains}
# Format the auditor name with .title(), matching the role-subspace figure title.
fig = plot_domain_subplots(trait_domain_projs,
                           f"Persona Trajectories in Trait Subspace ({auditor_model.title()} as Human User)",
                           trait_pc_titles,
                           space_type="trait")
fig.show()
fig.write_html(f"{plot_output_dir}/domains_trait.html")