# Comparing role vs. trait PCA

In [1]:
import json
import os
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm

In [2]:
role_dir = "roles_240"
trait_dir = "traits_240"
layer = 22

components = 10

## Compare variance explained

- take top 10 trait PCs
- project roles into them
- measure the fraction of role variance preserved
- compare to projecting into top 10 role PCs as baseline
- and vice versa

In [3]:
# load in pca results
role_results = torch.load(f"/workspace/{role_dir}/pca/layer{layer}_pos23.pt", weights_only=False)
trait_results = torch.load(f"/workspace/{trait_dir}/pca/layer{layer}_pos-neg50.pt", weights_only=False)


In [8]:
# vectors
role_vectors = role_results['vectors']['pos_2'] + role_results['vectors']['pos_3']
trait_tensor = trait_results['vectors']['pos_neg_50']

print(len(role_vectors))

448


In [9]:
# stack vectors and extract layer
role_tensor = torch.stack(role_vectors)

print(role_tensor.shape)
print(trait_tensor.shape)

# extract layer
role_layer_vectors = role_tensor[:, layer, :].float().numpy()
trait_layer_vectors = trait_tensor[:, layer, :].float().numpy()

print(role_layer_vectors.shape)
print(trait_layer_vectors.shape)

torch.Size([448, 46, 4608])
torch.Size([239, 46, 4608])
(448, 4608)
(239, 4608)


In [10]:
# scale the vectors
scaled_role_vectors = role_results['scaler'].transform(role_layer_vectors)
scaled_trait_vectors = trait_results['scaler'].transform(trait_layer_vectors)

print(scaled_role_vectors.shape)
print(scaled_trait_vectors.shape)

(448, 4608)
(239, 4608)


In [11]:
# project roles/traits into corresponding PCA
roles_in_trait_space = trait_results['pca'].transform(scaled_role_vectors)
traits_in_role_space = role_results['pca'].transform(scaled_trait_vectors)

roles_in_role_space = role_results['pca_transformed']
traits_in_trait_space = trait_results['pca_transformed']

print(roles_in_trait_space.shape)
print(traits_in_role_space.shape)

print(roles_in_role_space.shape)
print(traits_in_trait_space.shape)

(448, 239)
(239, 448)
(448, 448)
(239, 239)


In [12]:
# get fraction of variance preserved by first N PCs
roles_projected_trait = roles_in_trait_space[:, :components]  # first N trait PCs
roles_reconstructed = roles_projected_trait @ trait_results['pca'].components_[:components, :]
role_trait_variance_preserved = 1 - np.var(scaled_role_vectors - roles_reconstructed) / np.var(scaled_role_vectors)

traits_projected_role = traits_in_role_space[:, :components]  # first N role PCs
traits_reconstructed = traits_projected_role @ role_results['pca'].components_[:components, :]
trait_role_variance_preserved = 1 - np.var(scaled_trait_vectors - traits_reconstructed) / np.var(scaled_trait_vectors)

print(role_trait_variance_preserved)
print(trait_role_variance_preserved)

0.37631764246477517
0.43240921936796084


In [13]:
role_role_variance_preserved = role_results['variance_explained'][:components].sum()
trait_trait_variance_preserved = trait_results['variance_explained'][:components].sum()

print(role_role_variance_preserved)
print(trait_trait_variance_preserved)


0.6098394662418952
0.7123619371840361


In [None]:
# Grouped bar chart visualization
fig = go.Figure()

# Add all bars with proper grouping for legend
fig.add_trace(go.Bar(
    name='Cross-projection',
    x=['Role→Trait', 'Trait→Role'],
    y=[role_trait_variance_preserved, trait_role_variance_preserved],
    marker_color='lightcoral',
    text=[f'{role_trait_variance_preserved:.3f}', f'{trait_role_variance_preserved:.3f}'],
    textposition='outside'
))

fig.add_trace(go.Bar(
    name='Baseline (same space)',
    x=['Role→Role', 'Trait→Trait'],
    y=[role_role_variance_preserved, trait_trait_variance_preserved],
    marker_color='steelblue',
    text=[f'{role_role_variance_preserved:.3f}', f'{trait_trait_variance_preserved:.3f}'],
    textposition='outside'
))

# Update layout to force the x-axis order
fig.update_layout(
    title=dict(
        text=f'Variance Explained by Roles in Trait Space and Traits in Role Space (Top {components} PCs)',
        subtitle={
            'text': 'Gemma 2 27B, Layer 22 | 239 traits, 448 roles',
        }
    ),
    yaxis_title='Variance Explained',
    yaxis=dict(range=[0, 1]),
    xaxis_title='Projection Type',
    xaxis={'categoryorder': 'array', 'categoryarray': ['Role→Trait', 'Role→Role', 'Trait→Role', 'Trait→Trait']},
    width=800,
    showlegend=True
)

fig.show()

## Role-play basis in trait space

- project PC1 onto trait vectors to see which have high positive/negative loadings
- can PC1 be well-approximated by a linear combination of trait vectors?

In [14]:
# Load role PC1 and scale it using the trait scaler
rp_basis = torch.load('/root/git/persona-subspace/roles/data/pca_240/layer22_pos23_pc1.pt', weights_only=False)
print(f"Role PC1 shape: {rp_basis.shape}")


Role PC1 shape: (4608,)


In [17]:
# trait names
trait_names = trait_results['traits']['pos_neg_50']
print(trait_names[:10])

['zealous', 'wry', 'witty', 'whimsical', 'visceral', 'verbose', 'utilitarian', 'urgent', 'universalist', 'understated']


In [19]:
# Scale role PC1 using the trait scaler for consistent comparison
rp_basis_scaled = trait_results['scaler'].transform(rp_basis.reshape(1, -1))

# Project role PC1 into trait PCA space
rp_basis_in_trait_space = trait_results['pca'].transform(rp_basis_scaled)
print(f"Role PC1 in trait PCA space shape: {rp_basis_in_trait_space.shape}")

# Compute correlations between Role PC1 and each trait in PCA space
correlations = traits_in_trait_space @ rp_basis_in_trait_space.T
correlations = correlations.flatten()
print(f"Correlations shape: {correlations.shape}")

# Find top 15 traits most aligned with Assistant-like (most positive correlations)
top_roleplay_indices = np.argsort(correlations)[-15:][::-1]
print("\nTop 15 traits most aligned with Assistant-like (positive correlations):")
for i, idx in enumerate(top_roleplay_indices):
    print(f"{i+1:2d}. {trait_names[idx]:15s} (#{idx:3d}): {correlations[idx]:.4f}")

# Find top 15 traits most aligned with role-playing (most negative correlations)  
top_assistant_indices = np.argsort(correlations)[:15]
print("\nTop 15 traits most aligned with role-playing (negative correlations):")
for i, idx in enumerate(top_assistant_indices):
    print(f"{i+1:2d}. {trait_names[idx]:15s} (#{idx:3d}): {correlations[idx]:.4f}")

Role PC1 in trait PCA space shape: (1, 239)
Correlations shape: (239,)

Top 15 traits most aligned with Assistant-like (positive correlations):
 1. understated     (#  9): 921.5897
 2. reserved        (# 46): 837.3296
 3. calm            (#213): 831.8392
 4. dispassionate   (#173): 796.6787
 5. grounded        (#129): 783.7020
 6. literal         (#103): 779.3799
 7. moderate        (# 89): 731.1619
 8. factual         (#144): 707.0011
 9. avoidant        (#221): 700.2712
10. descriptive     (#178): 680.0480
11. concise         (#199): 666.5287
12. serious         (# 33): 652.0432
13. efficient       (#164): 651.3655
14. methodical      (# 94): 628.0478
15. traditional     (# 11): 626.6142

Top 15 traits most aligned with role-playing (negative correlations):
 1. theatrical      (# 13): -857.2169
 2. charismatic     (#208): -847.6765
 3. dramatic        (#169): -847.5021
 4. melodramatic    (# 97): -815.2751
 5. bombastic       (#216): -795.5102
 6. witty           (#  2): -759.6967
 7

In [20]:
# Alternative analysis: Direct cosine similarity with original trait vectors (before PCA)
print("=" * 80)
print("ANALYSIS 2: Direct cosine similarity with original trait vectors")
print("=" * 80)

# Use the scaled trait vectors (before PCA transformation)
normalized_trait_vectors = scaled_trait_vectors / np.linalg.norm(scaled_trait_vectors, axis=1, keepdims=True)
normalized_rp_basis = rp_basis_scaled.flatten() / np.linalg.norm(rp_basis_scaled.flatten())

print(f"Normalized trait vectors shape: {normalized_trait_vectors.shape}")
print(f"Normalized role PC1 shape: {normalized_rp_basis.shape}")

# Compute cosine similarities (dot products of normalized vectors)
cosine_similarities = normalized_trait_vectors @ normalized_rp_basis
print(f"Cosine similarities shape: {cosine_similarities.shape}")

# Find top 15 traits most aligned with role-playing (most positive cosine similarities)
top_roleplay_indices_cosine = np.argsort(cosine_similarities)[-15:][::-1]
print("\nTop 15 traits most aligned with role-playing (positive cosine similarity):")
for i, idx in enumerate(top_roleplay_indices_cosine):
    print(f"{i+1:2d}. {trait_names[idx]:15s} (#{idx:3d}): {cosine_similarities[idx]:.4f}")

# Find top 15 traits most aligned with Assistant-like (most negative cosine similarities)  
top_assistant_indices_cosine = np.argsort(cosine_similarities)[:15]
print("\nTop 15 traits most aligned with Assistant-like (negative cosine similarity):")
for i, idx in enumerate(top_assistant_indices_cosine):
    print(f"{i+1:2d}. {trait_names[idx]:15s} (#{idx:3d}): {cosine_similarities[idx]:.4f}")

ANALYSIS 2: Direct cosine similarity with original trait vectors
Normalized trait vectors shape: (239, 4608)
Normalized role PC1 shape: (4608,)
Cosine similarities shape: (239,)

Top 15 traits most aligned with role-playing (positive cosine similarity):
 1. reserved        (# 46): 0.7085
 2. grounded        (#129): 0.6994
 3. factual         (#144): 0.6716
 4. reductionist    (# 49): 0.6580
 5. literal         (#103): 0.6317
 6. descriptive     (#178): 0.6190
 7. dispassionate   (#173): 0.6112
 8. understated     (#  9): 0.6107
 9. materialist     (#100): 0.6043
10. traditional     (# 11): 0.6006
11. moderate        (# 89): 0.5931
12. calm            (#213): 0.5914
13. rationalist     (# 52): 0.5724
14. secular         (# 35): 0.5669
15. convergent      (#191): 0.5477

Top 15 traits most aligned with Assistant-like (negative cosine similarity):
 1. dramatic        (#169): -0.7331
 2. melodramatic    (# 97): -0.7026
 3. theatrical      (# 13): -0.6803
 4. rhetorical      (# 43): -0.6652

In [31]:
def plot_trait_correlations_combined(pca_correlations, cosine_correlations, trait_names, n_top=15):
    """
    Plot both PCA and cosine similarity correlations as subplots.
    """
    # Create subplot figure
    fig = sp.make_subplots(
        rows=2, cols=1,
        row_heights=[0.5, 0.5],
        vertical_spacing=0.15,
        subplot_titles=[
            'Correlations in Trait PCA Space',
            'Direct Cosine Similarity Analysis'
        ]
    )
    
    # Function to add correlation plot to specific subplot
    def add_correlation_subplot(correlations, row_num, x_title):
        # Sort correlations
        sorted_indices = np.argsort(correlations)
        
        # Get extreme indices
        low_extreme_indices = sorted_indices[:n_top]
        high_extreme_indices = sorted_indices[-n_top:][::-1]
        extreme_indices = set(list(low_extreme_indices) + list(high_extreme_indices))
        
        # Split points into regular and extreme
        regular_x, regular_y = [], []
        extreme_x, extreme_y, extreme_labels = [], [], []
        
        for i, corr in enumerate(correlations):
            if i in extreme_indices:
                extreme_x.append(corr)
                extreme_y.append(1)
                extreme_labels.append(trait_names[i])
            else:
                regular_x.append(corr)
                regular_y.append(1)
        
        # Add regular points (hover only)
        fig.add_trace(
            go.Scatter(
                x=regular_x,
                y=regular_y,
                mode='markers',
                marker=dict(
                    color='steelblue',
                    size=8,
                    opacity=1.0,
                    symbol='circle',
                    line=dict(width=1, color='black')
                ),
                text=[trait_names[i] for i in range(len(correlations)) if i not in extreme_indices],
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Value: %{x:.3f}<extra></extra>'
            ),
            row=row_num, col=1
        )
        
        # Add extreme points
        fig.add_trace(
            go.Scatter(
                x=extreme_x,
                y=extreme_y,
                mode='markers',
                marker=dict(
                    color='steelblue',
                    size=8,
                    opacity=1.0,
                    symbol='circle',
                    line=dict(width=1, color='black')
                ),
                text=extreme_labels,
                showlegend=False,
                hovertemplate='<b>%{text}</b><br>Value: %{x:.3f}<extra></extra>'
            ),
            row=row_num, col=1
        )
        
        # Create predefined alternating heights with variation
        high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42, 1.52, 1.38, 1.62, 1.43, 1.57]
        low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58, 0.48, 0.62, 0.38, 0.67, 0.43]
        
        # Alternate high-low pattern
        all_y_positions = []
        for i in range(n_top):
            all_y_positions.extend([high_positions[i % len(high_positions)], low_positions[i % len(low_positions)]])
        
        # Handle low extremes (Assistant-like, left side, red)
        for i, idx in enumerate(low_extreme_indices):
            x_pos = correlations[idx]
            label = trait_names[idx]
            leader_color = 'red'
            y_label = all_y_positions[i]
            
            # Add leader line
            fig.add_trace(
                go.Scatter(
                    x=[x_pos, x_pos],
                    y=[1.0, y_label],
                    mode='lines',
                    line=dict(color=leader_color, width=1),
                    showlegend=False,
                    hoverinfo='skip'
                ),
                row=row_num, col=1
            )
            
            # Add label
            fig.add_annotation(
                x=x_pos,
                y=y_label,
                text=label,
                showarrow=False,
                font=dict(size=10, color=leader_color),
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=leader_color,
                borderwidth=1,
                row=row_num, col=1
            )
        
        # Handle high extremes (Role-playing, right side, blue)
        for i, idx in enumerate(high_extreme_indices):
            x_pos = correlations[idx]
            label = trait_names[idx]
            leader_color = 'blue'
            y_label = all_y_positions[i + n_top]
            
            # Add leader line
            fig.add_trace(
                go.Scatter(
                    x=[x_pos, x_pos],
                    y=[1.0, y_label],
                    mode='lines',
                    line=dict(color=leader_color, width=1),
                    showlegend=False,
                    hoverinfo='skip'
                ),
                row=row_num, col=1
            )
            
            # Add label
            fig.add_annotation(
                x=x_pos,
                y=y_label,
                text=label,
                showarrow=False,
                font=dict(size=10, color=leader_color),
                bgcolor="rgba(255, 255, 255, 0.9)",
                bordercolor=leader_color,
                borderwidth=1,
                row=row_num, col=1
            )
        
        # Add vertical line at x=0
        fig.add_vline(
            x=0,
            line_dash="solid",
            line_color="gray",
            line_width=1,
            opacity=0.7,
            row=row_num, col=1
        )
        
        # Calculate symmetric range for this subplot
        max_abs_value = max(abs(min(correlations)), abs(max(correlations)))
        x_half_width = max_abs_value * 1.1
        
        # Update x-axis for this subplot
        fig.update_xaxes(
            title_text=x_title,
            range=[-x_half_width, x_half_width],
            row=row_num, col=1
        )
        
        # Update y-axis for this subplot
        fig.update_yaxes(
            title_text="",
            showticklabels=False,
            range=[0.25, 1.75],
            row=row_num, col=1
        )
    
    # Add both subplots
    add_correlation_subplot(pca_correlations, 1, "Projection onto Role PC1")
    add_correlation_subplot(cosine_correlations, 2, "Cosine Similarity with Role PC1")
    
    # Update overall layout
    fig.update_layout(
        height=800,
        title=dict(
            text='Role PC1 (Assistant-like vs Role-playing) Compared to Traits',
            subtitle={
                "text": f"Gemma 2 27B, Layer 22 | {len(trait_names)} traits",
            },
            x=0.5,
            font=dict(size=16)
        ),
        showlegend=False
    )
    
    return fig

# Plot combined analysis
fig_combined = plot_trait_correlations_combined(correlations, cosine_similarities, trait_names)
fig_combined.show()

In [32]:
fig_combined.write_html('results/role_pc1_vs_traits.html')