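# PCA on role vectors

Loads precomputed PCA results for role-playing vectors and projects the default assistant activation into PCA space. It then compares the L2 norms of the standardized "somewhat role-playing" (pos_2) and "fully role-playing" (pos_3) vectors at a single layer.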

In [2]:
import os
import sys
import torch
import numpy as np
import pandas as pd

sys.path.append('.')
sys.path.append('..')

from utils.pca_utils import *
from plots import *

## Configuration

In [52]:
# Configuration - edit these values to point the notebook at another model/dataset.
# NOTE(review): `type` and `dir` shadow Python builtins; they are kept because later
# cells read them by these names.
model_name = "Gemma-2-27B"
layer = 22  # layer index used to slice the stored activation/role vectors
base_dir = "/workspace/gemma-2-27b"
type = "roles_240"
dir = base_dir + "/" + type

## Plots

In [53]:

# Override `type` to choose which PCA results file is loaded next
# (f"{dir}/pca/layer{layer}_{type}.pt"). `dir` was already built from the
# config-cell value ("roles_240"), so it is NOT changed by this reassignment.
# NOTE(review): "pos23" presumably means the PCA fitted on pos_2 + pos_3 vectors — confirm.
type = "pos23"

In [54]:
# Load the saved PCA results (used below for its 'scaler', 'pca', 'roles' and
# 'vectors' entries) and the default-assistant activation vectors.
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# load files from a trusted source.
pca_results_path = f"{dir}/pca/layer{layer}_{type}.pt"
default_vectors_path = f"{dir}/default_vectors.pt"
pca_results = torch.load(pca_results_path, weights_only=False)
default_vectors = torch.load(default_vectors_path)

In [6]:
# Per-model output directory for figures (created if it does not exist yet).
# NOTE(review): the "pca_240" suffix is hard-coded and does not follow `type`.
plot_dir = "./results/" + model_name.lower() + "/pca_240"
os.makedirs(plot_dir, exist_ok=True)

In [55]:
# also calculate role labels for plotting
def get_role_labels(pca_results):
    """Build human-readable plot labels for the roles in ``pca_results['roles']``.

    Labels for ``'pos_2'`` roles come first and are suffixed " (Somewhat RP)".
    Labels for ``'pos_3'`` roles follow with " (Fully RP)". In each role name,
    underscores become spaces and every word is title-cased. A key that is
    absent from ``pca_results['roles']`` is skipped.

    Returns:
        list[str]: labels in pos_2-then-pos_3 order, regardless of dict order.
    """
    roles_by_pos = pca_results['roles']
    labels = []
    for pos_key, suffix in (('pos_2', 'Somewhat RP'), ('pos_3', 'Fully RP')):
        if pos_key not in roles_by_pos:
            continue
        labels += [f"{name.replace('_', ' ').title()} ({suffix})" for name in roles_by_pos[pos_key]]
    return labels

# Labels ordered pos_2 then pos_3 — the same order in which the role vectors are stacked below.
role_labels = get_role_labels(pca_results)



In [8]:
# Take the default assistant activation ('default_1') at `layer` as a single-row
# 2-D array, standardise it with the stored scaler, then project it with the stored PCA.
# NOTE(review): this cell ran as In [8], before pca_results was reloaded in In [54];
# re-run it so the projection uses the currently loaded PCA.
assistant_layer_activation = (
    default_vectors['activations']['default_1'][layer, :]
    .float()
    .numpy()
    .reshape(1, -1)
)
asst_scaled = pca_results['scaler'].transform(assistant_layer_activation)
asst_projected = pca_results['pca'].transform(asst_scaled)


In [56]:
# Stack the pos_2 (somewhat RP) and pos_3 (fully RP) role vectors at `layer` into a
# single 2-D float32 tensor, pos_2 rows first (later cells split on this order).
# Each vector is indexed at `layer` BEFORE stacking, so the full 3-D stack of every
# layer is never materialised just to keep one slice of it. The result is identical.
vectors = torch.stack(
    [vec[layer, :] for vec in pca_results['vectors']['pos_2'] + pca_results['vectors']['pos_3']]
).float()
print(vectors.shape)

torch.Size([448, 4608])


In [57]:
# Scale the stacked role vectors with the stored standard scaler
# (the same pca_results['scaler'] key used for the default activation above).
# NOTE(review): `vectors` is passed as a torch tensor; presumably the scaler converts
# it with np.asarray — the printed (n, d) shape indicates a 2-D array comes back.
scaled_vectors = pca_results['scaler'].transform(vectors)
print(scaled_vectors.shape)

(448, 4608)


In [58]:
# Euclidean (L2) norm of each standardised role vector: one value per row.
l2_norms = np.linalg.norm(scaled_vectors, axis=1)
print(l2_norms.shape)

(448,)


In [59]:
import plotly.graph_objects as go
import plotly.express as px

# `l2_norms` rows follow the stacking order pos_2 (somewhat RP) then pos_3 (fully RP),
# so the split point is the number of pos_2 vectors.
n_pos_2 = len(pca_results['vectors']['pos_2'])  # somewhat RP
n_pos_3 = len(pca_results['vectors']['pos_3'])  # fully RP

# Guard against a stale `l2_norms` (e.g. pca_results reloaded without re-running the
# stacking/scaling cells). Without this, the split below would be silently wrong.
assert n_pos_2 + n_pos_3 == len(l2_norms), (
    f"expected {n_pos_2 + n_pos_3} norms (pos_2 + pos_3), got {len(l2_norms)}; "
    "re-run the vector stacking, scaling and norm cells"
)

# Split the norms based on the vector types
somewhat_rp_norms = l2_norms[:n_pos_2]
fully_rp_norms = l2_norms[n_pos_2:]

print(f"Somewhat RP vectors: {len(somewhat_rp_norms)}")
print(f"Fully RP vectors: {len(fully_rp_norms)}")

fig = go.Figure()

# Somewhat RP histogram
fig.add_trace(go.Histogram(
    x=somewhat_rp_norms,
    name='Somewhat RP',
    opacity=0.7,
    nbinsx=30,
    marker_color='cyan'
))

# Fully RP histogram
fig.add_trace(go.Histogram(
    x=fully_rp_norms,
    name='Fully RP',
    opacity=0.7,
    nbinsx=30,
    marker_color='blue'
))

# Stacked bars: with barmode='stack', plotly puts both traces in the same bin group.
fig.update_layout(
    barmode='stack',
    title={
        'text': 'L2 Norms Distribution of Role-Playing Vectors',
        'subtitle': {
            'text': f'{model_name.replace("-", " ")}, Layer {layer}'
        },
    },
    xaxis_title='L2 Norm',
    yaxis_title='Count',
    bargap=0.2
)

fig.show()

Somewhat RP vectors: 173
Fully RP vectors: 275


In [60]:
# Alternative: Scatter plot with hover labels showing role names
import plotly.express as px

# `role_labels` is built from pca_results['roles'] while the norms come from
# pca_results['vectors']. If the pos_2 counts differ but the totals happen to match,
# every label after the boundary would be misattributed without any error, so check both.
assert len(role_labels) == len(somewhat_rp_norms) + len(fully_rp_norms), (
    f"{len(role_labels)} role labels vs {len(somewhat_rp_norms) + len(fully_rp_norms)} norms"
)
assert len(pca_results['roles'].get('pos_2', [])) == len(somewhat_rp_norms), (
    "pos_2 role names and pos_2 vectors have different counts; labels would be misaligned"
)

# Labels split at the same boundary as the norms (pos_2 first, then pos_3)
somewhat_rp_labels = role_labels[:n_pos_2]
fully_rp_labels = role_labels[n_pos_2:]

df = pd.DataFrame({
    'L2_Norm': np.concatenate([somewhat_rp_norms, fully_rp_norms]),
    'Type': ['Somewhat RP'] * len(somewhat_rp_norms) + ['Fully RP'] * len(fully_rp_norms),
    'Role': role_labels,
    'Index': range(len(role_labels))
})

# One point per role vector: x = norm, y = position in the stacking order; hover shows the role.
fig = px.scatter(df,
                 x='L2_Norm',
                 y='Index',
                 color='Type',
                 color_discrete_map={'Somewhat RP': 'cyan', 'Fully RP': 'blue'},
                 hover_data=['Role'],
                 title=f'L2 Norms by Role - {model_name} Layer {layer}',
                 labels={'Index': 'Role Index'})

fig.show()