# Comparing role vs. trait PCA

- take top 10 trait PCs
- project roles into them
- measure the fraction of role variance preserved
- compare to projecting into top 10 role PCs as baseline
- and vice versa

In [1]:
import json
import os
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm

In [14]:
role_dir = "roles_240"
trait_dir = "traits_240"
layer = 22

components = 10

In [5]:
# load in pca results
role_results = torch.load(f"/workspace/{role_dir}/pca/layer{layer}_pos23.pt", weights_only=False)
trait_results = torch.load(f"/workspace/{trait_dir}/pca/layer{layer}_pos-neg50.pt", weights_only=False)


In [None]:
# vectors
role_vectors = role_results['vectors']['pos_2'] + role_results['vectors']['pos_3']
trait_vectors = trait_results['vectors']['pos_neg_50']

print(len(role_vectors))
print(len(trait_vectors))

448
239
(5, 4608)
(5, 4608)


In [None]:
# stack vectors and extract layer
role_tensor = torch.stack(role_vectors)
trait_tensor = torch.stack(trait_vectors)

print(role_tensor.shape)
print(trait_tensor.shape)

# extract layer
role_layer_vectors = role_tensor[:, layer, :].float().numpy()
trait_layer_vectors = trait_tensor[:, layer, :].float().numpy()

print(role_layer_vectors.shape)
print(trait_layer_vectors.shape)

torch.Size([448, 46, 4608])
torch.Size([239, 46, 4608])
(448, 4608)
(239, 4608)


In [13]:
# scale the vectors
scaled_role_vectors = role_results['scaler'].transform(role_layer_vectors)
scaled_trait_vectors = trait_results['scaler'].transform(trait_layer_vectors)

print(scaled_role_vectors.shape)
print(scaled_trait_vectors.shape)

(448, 4608)
(239, 4608)


In [None]:
# project roles/traits into corresponding PCA
roles_in_trait_space = trait_results['pca'].transform(scaled_role_vectors)
traits_in_role_space = role_results['pca'].transform(scaled_trait_vectors)

roles_in_role_space = role_results['pca_transformed']
traits_in_trait_space = trait_results['pca_transformed']

print(roles_in_trait_space.shape)
print(traits_in_role_space.shape)

print(roles_in_role_space.shape)
print(traits_in_trait_space.shape)

(448, 239)
(239, 448)
(448, 448)
(239, 239)


In [19]:
# get fraction of variance preserved by first N PCs
roles_projected_trait = roles_in_trait_space[:, :components]  # first N trait PCs
roles_reconstructed = roles_projected_trait @ trait_results['pca'].components_[:components, :]
role_trait_variance_preserved = 1 - np.var(scaled_role_vectors - roles_reconstructed) / np.var(scaled_role_vectors)

traits_projected_role = traits_in_role_space[:, :components]  # first N role PCs
traits_reconstructed = traits_projected_role @ role_results['pca'].components_[:components, :]
trait_role_variance_preserved = 1 - np.var(scaled_trait_vectors - traits_reconstructed) / np.var(scaled_trait_vectors)

print(role_trait_variance_preserved)
print(trait_role_variance_preserved)

0.37631764246477517
0.43240921936796084


In [20]:
role_role_variance_preserved = role_results['variance_explained'][:components].sum()
trait_trait_variance_preserved = trait_results['variance_explained'][:components].sum()

print(role_role_variance_preserved)
print(trait_trait_variance_preserved)


0.6098394662418952
0.7123619371840361


In [32]:
# Grouped bar chart visualization
fig = go.Figure()

# Add all bars with proper grouping for legend
fig.add_trace(go.Bar(
    name='Cross-projection',
    x=['Role→Trait', 'Trait→Role'],
    y=[role_trait_variance_preserved, trait_role_variance_preserved],
    marker_color='lightcoral',
    text=[f'{role_trait_variance_preserved:.3f}', f'{trait_role_variance_preserved:.3f}'],
    textposition='outside'
))

fig.add_trace(go.Bar(
    name='Baseline (same space)',
    x=['Role→Role', 'Trait→Trait'],
    y=[role_role_variance_preserved, trait_trait_variance_preserved],
    marker_color='steelblue',
    text=[f'{role_role_variance_preserved:.3f}', f'{trait_trait_variance_preserved:.3f}'],
    textposition='outside'
))

# Update layout to force the x-axis order
fig.update_layout(
    title=dict(
        text=f'Variance Explained by Roles in Trait Space and Traits in Role Space (Top {components} PCs)',
        subtitle={
            'text': 'Gemma 2 27B, Layer 22 | 239 traits, 448 roles',
        }
    ),
    yaxis_title='Variance Explained',
    yaxis=dict(range=[0, 1]),
    xaxis_title='Projection Type',
    xaxis={'categoryorder': 'array', 'categoryarray': ['Role→Trait', 'Role→Role', 'Trait→Role', 'Trait→Trait']},
    width=800,
    showlegend=True
)

fig.show()