# Trajectories of different roles across multi-turn conversations in role PC1

In [2]:
import json
import os
import sys
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import torch.nn.functional as F

import re

sys.path.append('.')
sys.path.append('..')

from utils.inference_utils import *
from utils.probing_utils import *


INFO 09-25 21:40:29 [__init__.py:235] Automatically detected platform cuda.


In [3]:
# Model configuration shared by the cells below (nothing is loaded in this cell).
CHAT_MODEL_NAME = "Qwen/Qwen3-32B"  # full model identifier; not referenced by the visible cells
model_readable = "Qwen 3 32B"  # human-readable name shown in plot subtitles
model_short = "qwen-3-32b"  # short name used to build /workspace and plot paths
layer = 32  # layer index used to slice activations and to pick the PCA result files


In [4]:
# filename = "auto40"
# NOTE(review): `components` is not read by the visible cells, which hard-code 6 instead.
components = 6
# Conversation domains; each becomes one subplot and one key of the saved mean-activation dicts.
domains = ["coding", "writing", "therapy", "philosophy"]
# Role prompts to compare; "default" comes first and is drawn as the grey dashed baseline trace.
roles = ["default", "validator", "podcaster", "visionary", "crystalline", "whale", "leviathan"]

In [5]:
# Sub-directories (under /workspace/{model_short}) that hold the role and trait PCA results.
role_dir = "roles_240" 
trait_dir = "traits_240"

# Where per-conversation activations are read from and where the HTML plots are written.
acts_input_dir = f"/workspace/{model_short}/dynamics"
plot_output_dir = f"/root/git/plots/{model_short}/trajectory/roles"
# Create the plot directory up front so later write_html calls do not fail.
os.makedirs(plot_output_dir, exist_ok=True)

## Get conversation activations

In [13]:
# load every activation file in the directory if the name matches a regex
max_turns = 28  # every conversation is NaN-padded (or truncated) to this many turns


def mean_auto_activations(acts_input_dir, regex_pattern, layer):
    """Average per-turn activations at one layer over all matching conversation files.

    Each matching file is loaded with ``torch.load`` and must contain an
    ``'activations'`` entry that unpacks into three dimensions, indexed here as
    ``[turn, layer, hidden]``. A trailing unpaired turn is dropped (odd turn
    count), conversations longer than the module-level ``max_turns`` are
    truncated to ``max_turns``, and shorter ones are padded with NaN so that
    the mean for each turn only covers conversations that reached that turn.

    Args:
        acts_input_dir: Directory containing the activation files.
        regex_pattern: Pattern applied with ``re.match`` to each file name.
        layer: Index into the second dimension of the activation tensor.

    Returns:
        Tensor of shape ``(max_turns, hidden)`` holding the NaN-aware mean over
        conversations; turns that no conversation reached are NaN.

    Raises:
        ValueError: If no file yields at least one usable turn, or if the
            hidden size differs between files.
    """
    pattern = re.compile(regex_pattern)
    slices = []
    H = None
    n_truncated = 0

    # sorted() makes the load order (and float summation order) reproducible
    for file in sorted(os.listdir(acts_input_dir)):
        if not pattern.match(file):
            continue

        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only point this at trusted activation files.
        A = torch.load(f"{acts_input_dir}/{file}", weights_only=False, map_location="cpu")['activations']
        T, _, H_cur = A.shape

        T_eff = T - (T % 2)     # drop last turn if odd (user ended conversation)
        if T_eff > max_turns:
            # Longer conversations used to crash the padding step below;
            # keep only the turns that fit into the padded tensor.
            T_eff = max_turns
            n_truncated += 1
        if T_eff == 0:
            continue

        slices.append(A[:T_eff, layer, :])  # (T_eff, H_cur)

        if H is None:
            H = H_cur
        elif H != H_cur:
            raise ValueError(
                f"Hidden size mismatch in {file}: expected {H}, got {H_cur}"
            )

    if not slices:
        raise ValueError("No usable activations after dropping odd last turns.")

    if n_truncated:
        print(f"Truncated {n_truncated} conversations to {max_turns} turns")

    # pad conversations that have fewer turns with NaN to (N, max_turns, H)
    N = len(slices)
    print(f"Padding {N} activations found to ({N}, {max_turns}, {H})")
    padded = torch.full((N, max_turns, H), float('nan'))
    for i, sl in enumerate(slices):
        padded[i, :sl.shape[0], :] = sl

    # NaN-aware mean over the conversation dimension only -> (max_turns, H)
    mean_acts = torch.nanmean(padded, dim=0)
    return mean_acts


In [14]:
# Compute and save the per-domain mean activation trajectory for a single role.
# Only the role named here is processed by this cell.
role = "default"
results = {}
for domain in domains:
    # Files for this role live in {acts_input_dir}/{role}; the regex keeps those
    # whose name starts with the domain.
    role_acts = mean_auto_activations(f"{acts_input_dir}/{role}", r'^' + domain, layer)
    print(role_acts.shape)
    results[domain] = role_acts

# Persist as a dict {domain: mean activations} in mean_activations/{role}.pt.
out_dir = f"{acts_input_dir}/mean_activations"
os.makedirs(out_dir, exist_ok=True)
torch.save(results, f"{out_dir}/{role}.pt")



Padding 100 activations found to (100, 28, 5120)
torch.Size([28, 5120])
Padding 100 activations found to (100, 28, 5120)
torch.Size([28, 5120])
Padding 100 activations found to (100, 28, 5120)
torch.Size([28, 5120])
Padding 100 activations found to (100, 28, 5120)
torch.Size([28, 5120])


## Cosine similarity with trait and role PCs

In [6]:
# Load the saved PCA results for the role and trait datasets at this layer.
# The functions below read a 'pca' entry (with components_/transform) and a 'scaler' entry from them.
# NOTE(security): weights_only=False unpickles arbitrary objects; only load trusted files.
role_results = torch.load(f"/workspace/{model_short}/{role_dir}/pca/layer{layer}_pos23.pt", weights_only=False)
trait_results = torch.load(f"/workspace/{model_short}/{trait_dir}/pca/layer{layer}_pos-neg50.pt", weights_only=False)


In [7]:
def pc_cosine_similarity(mean_acts_per_turn, pca_results, n_pcs=8):
    """Cosine similarity between each activation row and the leading PCA directions.

    Args:
        mean_acts_per_turn: A 2-D tensor with one row per turn, or a list of
            1-D tensors that is stacked into such a tensor.
        pca_results: Mapping whose ``'pca'`` entry exposes ``components_``.
        n_pcs: Number of leading components to compare against.

    Returns:
        NumPy array of shape ``(n_rows, n_pcs)``. The activations are used
        as-is; unlike ``pc_projection``, the scaler is not applied.
    """
    acts = torch.stack(mean_acts_per_turn) if isinstance(mean_acts_per_turn, list) else mean_acts_per_turn

    # Unit-length activation rows, moved to NumPy for the matrix product.
    unit_acts = F.normalize(acts, dim=1).float().numpy()

    # Unit-length principal directions (rows of components_).
    directions = pca_results['pca'].components_[:n_pcs]
    row_norms = np.linalg.norm(directions, axis=1, keepdims=True)
    unit_directions = directions / row_norms

    return unit_acts @ unit_directions.T

def pc_projection(mean_acts_per_turn, pca_results, n_pcs=8):
    """Project activation rows onto the leading principal components.

    Args:
        mean_acts_per_turn: A 2-D tensor with one row per turn, or a list of
            1-D tensors that is stacked into such a tensor.
        pca_results: Mapping with a ``'scaler'`` and a ``'pca'`` entry, each
            exposing ``transform``.
        n_pcs: Number of leading component scores to keep.

    Returns:
        The first ``n_pcs`` columns of the PCA-transformed, scaler-transformed
        activations.
    """
    acts = mean_acts_per_turn
    if isinstance(acts, list):
        acts = torch.stack(acts)

    # Apply the saved scaler first, then the saved PCA.
    standardized = pca_results['scaler'].transform(acts.float().numpy())
    scores = pca_results['pca'].transform(standardized)
    return scores[:, :n_pcs]
    
    

In [99]:
# projs[space][domain] -> list with one projection array per role, in the order of `roles`.
projs = {}
projs['role'] = {}
projs['trait'] = {}

for role in roles:
    # Saved dict {domain: mean activations per turn} for this role.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load trusted files.
    acts = torch.load(f"/workspace/{model_short}/dynamics/mean_activations/{role}.pt", weights_only=False)

    for domain in domains:
        # Create the per-domain lists the first time a domain is seen.
        if domain not in projs['role']:
            projs['role'][domain] = []
            projs['trait'][domain] = []

        # Keep the first 6 PC scores in both the role and the trait PCA space.
        projs['role'][domain].append(pc_projection(acts[domain], role_results, 6))
        projs['trait'][domain].append(pc_projection(acts[domain], trait_results, 6))


In [None]:
# For each domain, build a (n_turns, n_roles) matrix of the projection onto PC1.
# projs[space][domain] is a list, n_roles long, of (n_turns, n_pcs) arrays.
#
# The result goes into `pc1_projs` instead of overwriting `projs`: the lists in
# `projs` are still needed by build_pc_matrices_np, and leaving them untouched
# keeps this cell safe to re-run. Pass pc1_projs['role'] (or ['trait']) to
# combined_fig_split, which expects one such matrix per domain.
pc1_projs = {'role': {}, 'trait': {}}

for space in ('role', 'trait'):
    for domain in domains:
        # Keep only the PC1 column; flattening the whole (n_turns, n_pcs) array
        # would interleave every PC into a single column. 1-D inputs are taken
        # to be PC1 already.
        pc1_columns = [
            np.asarray(a)[:, 0] if np.ndim(a) == 2 else np.asarray(a).ravel()
            for a in projs[space][domain]
        ]

        # find the max number of turns across all roles
        max_len = max(col.shape[0] for col in pc1_columns)

        # NaN-pad shorter trajectories so every role has max_len rows
        pc1_mat = np.full((max_len, len(pc1_columns)), np.nan, dtype=float)
        for role_idx, col in enumerate(pc1_columns):
            pc1_mat[:col.shape[0], role_idx] = col

        pc1_projs[space][domain] = pc1_mat

print(pc1_projs['role']['coding'].shape)  # (max_len, n_roles)


(30, 7)


In [102]:

def build_pc_matrices_np(projs_section: dict, num_pcs: int = 6):
    """
    Regroup per-role projection arrays into one matrix per principal component.

    projs_section: like projs['role'] or projs['trait'].
      projs_section[domain] is a list with one entry per role; each entry is
      a (n_turns, n_pcs) array, and n_turns may differ between roles.

    Returns:
      dict[str, list[np.ndarray]] mapping each non-empty domain to
      [mat_pc0, mat_pc1, ...], where every matrix has shape
      (max_n_turns, n_roles) and missing turns are NaN. At most num_pcs
      matrices are produced, and never more than the narrowest input offers.
    """
    result = {}

    for domain in projs_section:
        items = projs_section[domain]
        if not items:
            continue

        arrays = [np.asarray(item) for item in items]
        longest = max(arr.shape[0] for arr in arrays)
        available = min(arr.shape[1] for arr in arrays)  # narrowest input bounds the PC count
        kept = min(num_pcs, available)

        if kept <= 0:
            result[domain] = []
            continue

        # Fill one NaN-initialised (pc, turn, role) cube, then split it per PC.
        cube = np.full((kept, longest, len(arrays)), np.nan, dtype=float)
        for role_col, arr in enumerate(arrays):
            cube[:, :arr.shape[0], role_col] = arr[:, :kept].T

        result[domain] = [cube[pc].copy() for pc in range(kept)]

    return result


def flatten_by_pc_np(pc_mats_by_domain: dict, num_pcs: int = 6):
    """
    Re-index domain -> [matrix per PC] as a PC-major list of dicts:
      [ {domain: mat_pc0}, {domain: mat_pc1}, ..., {domain: mat_pc{num_pcs-1}} ]
    A domain with fewer than num_pcs matrices is simply absent from the
    later dicts.
    """
    buckets = []
    for _ in range(num_pcs):
        buckets.append({})

    for domain in pc_mats_by_domain:
        pc_index = 0
        for mat in pc_mats_by_domain[domain]:
            buckets[pc_index][domain] = mat
            pc_index += 1

    return buckets


In [103]:
# Build per-domain matrices for first 6 PCs.
# Requires projs[...][domain] to still be lists of per-role arrays.
role_pc_mats  = build_pc_matrices_np(projs['role'],  num_pcs=6)  # dict[domain] -> list of ≤6 (n_turns, n_roles)
trait_pc_mats = build_pc_matrices_np(projs['trait'], num_pcs=6)

# Optional: PC-major view (handy for batching); not used by the visible plotting cells.
role_pc_flat  = flatten_by_pc_np(role_pc_mats,  num_pcs=6)  # list[6] of dicts: role_pc_flat[i][domain] -> (n_turns, n_roles)
trait_pc_flat = flatten_by_pc_np(trait_pc_mats, num_pcs=6)


## Plot trajectory

In [139]:
def _wrap_hover_text(text, max_chars_per_line=70):
    """Greedily wrap text at word boundaries, joining lines with HTML <br>."""
    if len(text) <= max_chars_per_line:
        return text

    lines = []
    current_line = []
    current_length = 0  # characters in current_line, excluding joining spaces

    for word in text.split():
        # len(current_line) accounts for the spaces a join would insert
        if current_length + len(word) + len(current_line) > max_chars_per_line:
            if current_line:  # Don't add empty lines
                lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word)
        else:
            current_line.append(word)
            current_length += len(word)

    if current_line:  # Add the last line
        lines.append(' '.join(current_line))

    return '<br>'.join(lines)


def _truncate_hover_text(text, limit):
    """Cut text to `limit` characters, appending "..." only if something was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _turn_hover_contexts(conversation, n_turns, include_user):
    """Build one HTML hover string per plotted turn (always n_turns entries)."""
    if conversation is None:
        return [f"<b>Turn {turn_idx}</b>" for turn_idx in range(n_turns)]

    contexts = []

    if include_user:
        # Plotted turn i corresponds directly to conversation[i].
        for turn_idx in range(n_turns):
            if turn_idx >= len(conversation):
                contexts.append(f"<b>Turn {turn_idx}</b>")
                continue
            turn = conversation[turn_idx]
            speaker = turn['role'].capitalize()
            body = _wrap_hover_text(_truncate_hover_text(turn['content'], 300), 70)
            contexts.append(f"<b>Turn {turn_idx} ({speaker})</b><br><b>{speaker}:</b> {body}")
        return contexts

    # Assistant-only mode: plotted turn i is the i-th assistant message.
    assistant_turns = [i for i, turn in enumerate(conversation) if turn['role'] == 'assistant']
    for turn_idx in range(n_turns):
        if turn_idx >= len(assistant_turns):
            contexts.append(f"<b>Turn {turn_idx}</b>")
            continue

        conv_turn_idx = assistant_turns[turn_idx]
        hover_parts = [f"<b>Turn {turn_idx}</b>"]

        # Show the preceding user message (if any) above the assistant response.
        if conv_turn_idx > 0 and conversation[conv_turn_idx - 1]['role'] == 'user':
            user_content = conversation[conv_turn_idx - 1]['content']
            if user_content:
                user_wrapped = _wrap_hover_text(_truncate_hover_text(user_content, 200), 70)
                hover_parts.append(f"<b>User:</b> {user_wrapped}")

        # Truncation threshold and cut length now agree (300), so responses
        # shorter than the cut no longer get a spurious "...".
        assistant_content = conversation[conv_turn_idx]['content']
        assistant_wrapped = _wrap_hover_text(_truncate_hover_text(assistant_content, 300), 70)
        hover_parts.append(f"<b>Assistant:</b> {assistant_wrapped}")

        contexts.append('<br>'.join(hover_parts))

    return contexts


def plot_mean_response_trajectory(similarity_matrix, conversation=None, title=None, pc_titles=None, projection=False, include_user=False):
    """
    Create a single line plot with one line per column of `similarity_matrix`.

    Parameters:
    - similarity_matrix: Numpy matrix of shape (n_turns, n_series); each column
      becomes one line (a PC, or a role when the combined figure builders pass
      a per-PC matrix with roles as columns)
    - conversation: Optional list of {'role', 'content'} dicts used for hover text
    - title: Optional custom title
    - pc_titles: Optional list of series names, one per column
    - projection: Whether the values are projections (True) or cosine
      similarities (False); controls the y-axis and hover labels
    - include_user: If True, rows are treated as alternating user (even index)
      and assistant (odd index) turns and each series is split into two
      traces named "<title> (User)" and "<title> (Assistant)". Both traces
      share the same colour and marker style.

    Reads the notebook globals `model_readable` and `layer` for the subtitle.

    Returns:
    - Plotly figure object
    """
    # Get dimensions
    n_turns, n_pcs = similarity_matrix.shape
    turn_indices = np.arange(n_turns)

    # Create default PC titles if not provided
    if pc_titles is None:
        pc_titles = [f"PC{i+1}" for i in range(n_pcs)]

    # Define color palette for PCs (cycled if there are more series than colours)
    pc_colors = [
        '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
        '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
        '#bcbd22', '#17becf', '#aec7e8', '#ffbb78'
    ]

    turn_contexts = _turn_hover_contexts(conversation, n_turns, include_user)

    # Label the hovered value according to what is actually plotted.
    if projection:
        yaxis_title = "PC Projection"
        value_label = "Projection"
    else:
        yaxis_title = "Cosine Similarity with PC"
        value_label = "Cosine similarity"

    hovertemplate = (
        '<b>%{fullData.name}</b><br>' +
        '%{text}<br>' +
        f'<b>{value_label}:</b> ' + '%{y:.3f}<br>' +
        '<extra></extra>'
    )

    # Create Plotly figure
    fig = go.Figure()

    # Add line traces for each series with markers
    for pc_idx in range(n_pcs):
        pc_name = pc_titles[pc_idx]
        similarities = similarity_matrix[:, pc_idx]
        color = pc_colors[pc_idx % len(pc_colors)]

        if include_user:
            # Assistant = odd rows (1, 3, 5, ...), user = even rows (0, 2, 4, ...).
            # The assistant trace is added first, then the user trace.
            for speaker, rows in (("Assistant", slice(1, None, 2)), ("User", slice(0, None, 2))):
                fig.add_trace(go.Scatter(
                    x=turn_indices[rows],
                    y=similarities[rows],
                    mode='lines+markers',
                    name=f"{pc_name} ({speaker})",
                    line=dict(color=color, width=1.5),
                    marker=dict(color=color, size=3, opacity=0.8, symbol='circle'),
                    hovertemplate=hovertemplate,
                    text=turn_contexts[rows],
                    legendgroup=f"pc{pc_idx}",
                    showlegend=True
                ))
        else:
            # Single trace per series
            fig.add_trace(go.Scatter(
                x=turn_indices,
                y=similarities,
                mode='lines+markers',
                name=pc_name,
                line=dict(color=color, width=1.5),
                marker=dict(color=color, size=3, opacity=0.8),
                hovertemplate=hovertemplate,
                text=turn_contexts
            ))

    # Update layout
    default_title = "Mean Response PC Trajectory"

    fig.update_layout(
        title=dict(
            text=title if title else default_title,
            x=0.5,
            font=dict(size=16),
            subtitle={"text": f"{model_readable}, Layer {layer}"}
        ),
        xaxis_title="Conversation Turn",
        yaxis_title=yaxis_title,
        width=1400,
        height=600,
        hovermode='closest',
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            bgcolor="rgba(255,255,255,0.8)"
        )
    )

    # Add grid for easier reading
    fig.update_xaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor='lightgray',
        zeroline=True,
        tick0=0,
        dtick=1  # Show every turn
    )
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray', zeroline=True)

    # Add light vertical lines between turns for clarity
    for turn_idx in range(1, n_turns):
        fig.add_vline(
            x=turn_idx - 0.5,
            line_dash="dot",
            line_color="lightgray",
            line_width=1,
            opacity=0.3
        )

    return fig

In [143]:
# Create role PC titles: one hand-written label per role PC (PC1..PC6), used as y-axis titles.
# NOTE(review): role_pc_titles stays undefined for any other model_short value.
if model_short == "qwen-3-32b":
    role_pc_titles = ['Assistant-like ↔ role-playing', "mystical/transcendent ↔ mundane/irreverent", "empathetic/vulnerable ↔ analytical/predatory", "concrete/practical ↔ abstract/ideological", "thinking/passive ↔ doing/active", "creative/expressive ↔ rigid/constrained"]
elif model_short == "gemma-2-27b":
    role_pc_titles = ['Assistant-like ↔ role-playing', "inhuman ↔ human", "independent ↔ dependent", "nurturing ↔ adversarial", "social ↔ technical", "structured ↔ liminal"]

# # Plot role trajectory
# role_fig = plot_mean_response_trajectory(
#     role_sims, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Role PC Space: {filename.capitalize()}", 
#     pc_titles=role_pc_titles
# )
# role_fig.show()
# role_fig.write_html(f"{plot_output_dir}/{filename}_role_cossim.html")

# role_fig = plot_mean_response_trajectory(
#     role_projs, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Role PC Space: {filename.capitalize()}", 
#     pc_titles=role_pc_titles,
#     projection=True
# )
# role_fig.show()
# role_fig.write_html(f"{plot_output_dir}/{filename}_role_proj.html")

In [148]:
# Create trait PC titles: one hand-written label per trait PC (PC1..PC6), used as y-axis titles.
# NOTE(review): trait_pc_titles stays undefined for any other model_short value.
if model_short == "qwen-3-32b":
    trait_pc_titles = ["expressive/irreverent ↔ controlled/professional", "analytical ↔ intuitive", "accessible/practical ↔ esoteric/complex", "active/flexible ↔ passive/rigid", "questioning ↔ confident", "indirect/diplomatic ↔ direct/assertive"]
elif model_short == "gemma-2-27b":
    trait_pc_titles = ["benevolent ↔ evil", "analytical ↔ emotional", "accommodating ↔ confrontational", "grounded ↔ mystical", "thoughtful ↔ decisive", "conformist ↔ defiant"]

# Plot trait trajectory
# trait_fig = plot_mean_response_trajectory(
#     trait_sims, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Trait PC Space: {filename.capitalize()}", 
#     pc_titles=trait_pc_titles
# )
# trait_fig.show()
# trait_fig.write_html(f"{plot_output_dir}/{filename}_trait_cossim.html")

# trait_fig = plot_mean_response_trajectory(
#     trait_projs, 
#     conversation=conversation, 
#     title=f"Conversation Trajectory in Trait PC Space: {filename.capitalize()}", 
#     pc_titles=trait_pc_titles,
#     projection=True
# )

# trait_fig.show()
# trait_fig.write_html(f"{plot_output_dir}/{filename}_trait_proj.html")

In [144]:
import plotly.subplots as sp
import plotly.graph_objects as go
from collections import defaultdict

def combined_fig_split(projs, title, conversation=None):
    """Build two 2x2 subplot figures (assistant turns, user turns), one panel per domain.

    Args:
        projs: Mapping domain -> (n_turns, n_roles) matrix; each column is
            plotted as one role's trajectory.
        title: Figure title; " (Assistant Turns)" / " (User Turns)" is appended.
        conversation: Optional conversation forwarded for hover text.

    Returns:
        (fig_asst, fig_user). In every panel the first trace added (the first
        entry of `roles`) is restyled as a grey dashed line with diamond
        markers. Axis ranges and the y-axis label are fixed for role PC1.
    """
    panel_titles = [f"Conversation Domain: {domain.capitalize()}" for domain in domains]

    def _new_grid():
        return sp.make_subplots(rows=2, cols=2, subplot_titles=panel_titles,
                                vertical_spacing=0.1, horizontal_spacing=0.02)

    fig_asst = _new_grid()
    fig_user = _new_grid()

    # Per-speaker lookup replaces the duplicated user/assistant branches.
    figure_for = {"Assistant": fig_asst, "User": fig_user}
    legend_seen = {"Assistant": set(), "User": set()}

    # number of traces already placed per (row, col, speaker)
    panel_counts = defaultdict(int)

    for position, domain in enumerate(domains):
        grid_row, grid_col = divmod(position, 2)
        row, col = grid_row + 1, grid_col + 1

        panel = plot_mean_response_trajectory(
            projs[domain],
            conversation=conversation,
            pc_titles=roles,
            projection=True,
            include_user=True
        )

        for tr in panel.data:
            label = tr.name
            speaker = "User" if label.endswith(" (User)") else "Assistant"
            role_name = label.rsplit(" (", 1)[0]
            group_key = f"{role_name}_{speaker}"

            slot = (row, col, speaker)
            first_here = panel_counts[slot] == 0

            # Only the first occurrence of a role gets a legend entry per figure.
            seen = legend_seen[speaker]
            show = group_key not in seen
            if show:
                seen.add(group_key)
            tr.update(legendgroup=group_key, showlegend=show, name=f"{role_name.capitalize()}")

            # The first trace of each panel is drawn as the grey dashed baseline.
            if first_here:
                tr.update(
                    line=dict(color="grey", dash="dash", width=1.5),
                    marker=dict(color="grey", symbol="diamond", size=3)
                )

            figure_for[speaker].add_trace(tr, row=row, col=col)
            panel_counts[slot] += 1

        # hide y-axis ticks/labels for column 2
        for fig in (fig_asst, fig_user):
            for r in (1, 2):
                fig.update_yaxes(showticklabels=False, row=r, col=2)

    def _finish(fig, subtitle_suffix):
        fig.update_layout(
            height=700, width=1200,
            title={"text": f"{title} ({subtitle_suffix} Turns)",
                   "subtitle": {"text": f"{model_readable}, Layer {layer}"}},
            showlegend=True,
            legend=dict(
                traceorder="grouped",
                groupclick="togglegroup",
                x=1.02, xanchor="left", y=1, yanchor="top"
            ),
            legend_tracegroupgap=12
        )
        for c in (1, 2):
            fig.update_xaxes(title_text="Conversation Turn", row=2, col=c)
        for r in (1, 2):
            fig.update_yaxes(title_text="Role PC1 Projection", row=r, col=1)

    for fig, suffix in ((fig_asst, "Assistant"), (fig_user, "User")):
        _finish(fig, suffix)
        fig.update_yaxes(range=[-30, 98])
        fig.update_xaxes(range=[-1, 30])

    # make subplot titles slightly smaller (default is ~16)
    for fig in (fig_asst, fig_user):
        for ann in fig.layout.annotations:
            ann.font.size = 14

    return fig_asst, fig_user


In [140]:
# fig_asst, fig_user = combined_fig_split(projs['role'], title="Role-playing Trajectories with Different Role Prompts", conversation=None)
# fig_asst.show()
# fig_asst.write_html(f"{plot_output_dir}/role_pc1_assistant.html")
# fig_user.show()
# fig_user.write_html(f"{plot_output_dir}/role_pc1_user.html")

In [162]:
def combined_fig_split_for_pc(projs_section, kind, title, pc_idx, pc_titles, conversation=None):
    """
    Build assistant/user 2x2 subplot figures (one panel per domain) for one PC.

    projs_section[domain][pc_idx] must be (n_turns, n_roles).
    kind: str, either "Role" or "Trait"; used in the figure title.
    title: text placed after the "<kind> PC<n> Projections — " prefix.
    pc_titles: per-PC labels; pc_titles[pc_idx] becomes the y-axis title.
    conversation: optional conversation forwarded for hover text.

    Returns (fig_asst, fig_user). All panels share one x range and one y range
    computed over every domain (y padded by 10 on each side), and the first
    trace in each panel is restyled as a grey dashed baseline.
    """
    panel_titles = [f"Mean over 100 {domain.capitalize()} Conversations" for domain in domains]

    def _new_grid():
        return sp.make_subplots(rows=2, cols=2,
                                subplot_titles=panel_titles,
                                vertical_spacing=0.1,
                                horizontal_spacing=0.02)

    fig_asst = _new_grid()
    fig_user = _new_grid()

    figure_for = {"Assistant": fig_asst, "User": fig_user}
    legend_seen = {"Assistant": set(), "User": set()}
    panel_counts = defaultdict(int)

    # --- global axis ranges across all domains ---
    pc_mats = [projs_section[domain][pc_idx] for domain in domains]  # each (n_turns, n_roles)
    x_values = [turn for mat in pc_mats for turn in range(mat.shape[0])]
    y_values = [value for mat in pc_mats for value in mat[~np.isnan(mat)].flatten()]  # ignore NaNs
    global_x_range = [min(x_values) - 1, max(x_values) + 1]
    y_min, y_max = np.nanmin(y_values), np.nanmax(y_values)
    pad = 10
    global_y_range = [y_min - pad, y_max + pad]   # add ±10 padding

    for position, (domain, similarity_matrix) in enumerate(zip(domains, pc_mats)):
        grid_row, grid_col = divmod(position, 2)
        row, col = grid_row + 1, grid_col + 1

        panel = plot_mean_response_trajectory(
            similarity_matrix,
            conversation=conversation,
            pc_titles=roles,
            projection=True,
            include_user=True
        )

        for tr in panel.data:
            label = tr.name
            speaker = "User" if label.endswith(" (User)") else "Assistant"
            role_name = label.rsplit(" (", 1)[0]
            group_key = f"{role_name}_{speaker}"

            slot = (row, col, speaker)
            first_here = panel_counts[slot] == 0

            # Only the first occurrence of a role gets a legend entry per figure.
            seen = legend_seen[speaker]
            show = group_key not in seen
            if show:
                seen.add(group_key)
            tr.update(legendgroup=group_key,
                      showlegend=show,
                      name=f"{role_name.capitalize()}")

            # The first trace of each panel is drawn as the grey dashed baseline.
            if first_here:
                tr.update(line=dict(color="grey", dash="dash", width=1.5),
                          marker=dict(color="grey", symbol="diamond", size=3))

            figure_for[speaker].add_trace(tr, row=row, col=col)
            panel_counts[slot] += 1

        # hide ticks for right column
        for fig in (fig_asst, fig_user):
            for r in (1, 2):
                fig.update_yaxes(showticklabels=False, row=r, col=2)

    def _finish(fig, subtitle_suffix):
        fig.update_layout(
            height=700, width=900,
            title={
                "text": f"{kind} PC{pc_idx+1} Projections — {title} ({subtitle_suffix} Turns)",
                "subtitle": {"text": f"{model_readable}, Layer {layer}"}
            },
            showlegend=True,
            legend=dict(
                title=dict(
                    text="Roleplaying as:",
                    font=dict(size=12)
                ),
                traceorder="grouped",
                groupclick="togglegroup",
                x=1.02, xanchor="left", y=1, yanchor="top"
            ),
            legend_tracegroupgap=6
        )
        for c in (1, 2):
            fig.update_xaxes(title_text="Conversation Turn", row=2, col=c)

        # descriptive y-axis label
        axis_label = f"{pc_titles[pc_idx]}"
        for r in (1, 2):
            fig.update_yaxes(title_text=axis_label, row=r, col=1)

        fig.update_xaxes(tickfont=dict(size=10), title_font=dict(size=10))
        fig.update_yaxes(tickfont=dict(size=10), title_font=dict(size=10))

        # enforce same ranges (with ±10 padding) on every panel
        for r, c in ((1, 1), (1, 2), (2, 1), (2, 2)):
            fig.update_xaxes(range=global_x_range, row=r, col=c)
            fig.update_yaxes(range=global_y_range, row=r, col=c)

    _finish(fig_asst, "Assistant")
    _finish(fig_user, "User")

    # make subplot titles slightly smaller
    for fig in (fig_asst, fig_user):
        for ann in fig.layout.annotations:
            ann.font.size = 14

    return fig_asst, fig_user


In [163]:
# Render and save assistant/user figures for each of the first 6 PCs,
# once in role-PC space and once in trait-PC space.
for pc_idx in range(6):
    # Role projections
    fig_asst, fig_user = combined_fig_split_for_pc(
        role_pc_mats, kind="Role",
        title="Conversation Trajectories while Role-playing", pc_idx=pc_idx, pc_titles=role_pc_titles, conversation=None
    )
    fig_asst.show()
    fig_user.show()
    # File names use the 1-based PC number.
    fig_asst.write_html(f"{plot_output_dir}/role_pc{pc_idx+1}_assistant.html")
    fig_user.write_html(f"{plot_output_dir}/role_pc{pc_idx+1}_user.html")

    # Trait projections
    fig_asst, fig_user = combined_fig_split_for_pc(
        trait_pc_mats, kind="Trait",
        title="Conversation Trajectories while Role-playing", pc_idx=pc_idx, pc_titles=trait_pc_titles, conversation=None
    )
    fig_asst.show()
    fig_user.show()
    fig_asst.write_html(f"{plot_output_dir}/trait_pc{pc_idx+1}_assistant.html")
    fig_user.write_html(f"{plot_output_dir}/trait_pc{pc_idx+1}_user.html")

