# PCA on roles and traits

In [1]:
import json
import os
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm


In [2]:
# Which role vector set to analyse: "roles" or "roles_240"
role_dir = "roles_240"

# Question-set size and number of prompt types that go with each role directory.
# The branch must test `role_dir`; comparing the builtin `dir` never matches,
# which would leave n_questions / n_prompt_types undefined for "roles_240".
if role_dir == "roles":
    n_questions = 30
    n_prompt_types = 2
elif role_dir == "roles_240":
    n_questions = 240
    n_prompt_types = 1
else:
    raise ValueError(f"Unknown role_dir {role_dir!r}: expected 'roles' or 'roles_240'")

# Which trait vector set to analyse: "traits" or "traits_240"
trait_dir = "traits_240"


## Load vectors

In [29]:
# Directory holding one saved "<role>.pt" file per role
role_vector_dir = f"/workspace/{role_dir}/vectors"

# Map role name (file name with ".pt" removed) -> loaded contents of that file
role_vectors = {
    fname.replace(".pt", ""): torch.load(os.path.join(role_vector_dir, fname))
    for fname in os.listdir(role_vector_dir)
    if fname.endswith(".pt")
}

print(f"Found {len(role_vectors)} roles with vectors")

Found 275 roles with vectors


In [30]:
# Inspect which vector variants were saved for one example role
print(role_vectors['graduate'].keys())

dict_keys(['pos_2', 'pos_3', 'pos_all'])


In [31]:
# Directory holding one saved "<trait>.pt" file per trait
trait_vector_dir = f"/workspace/{trait_dir}/vectors"

# Map trait name (file name with ".pt" removed) -> loaded contents of that file
trait_vectors = {
    fname.replace(".pt", ""): torch.load(os.path.join(trait_vector_dir, fname))
    for fname in os.listdir(trait_vector_dir)
    if fname.endswith(".pt")
}

print(f"Found {len(trait_vectors)} traits with vectors")

Found 240 traits with vectors


## PCA + plotting functions

In [3]:
def compute_pca(activation_list, layer):
    """
    Fit a PCA on the standardized activations of a single layer and print a
    short summary (cumulative variance, elbow point, dimensions needed for
    70/80/90/95% of the variance).

    Parameters:
    - activation_list: array indexed as [sample, layer, feature]
      (the notebook passes a stacked float torch tensor)
    - layer: index of the layer whose activations are analysed

    Returns:
    - pca_transformed: samples projected onto all principal components
    - variance_explained: explained variance ratio of each component
    - n_components: number of fitted components
    - pca: the fitted PCA object
    - scaler: the fitted StandardScaler (needed to project new points later)
    """
    layer_activations = activation_list[:, layer, :]

    # Standardize each feature before fitting the PCA
    scaler = StandardScaler()
    scaled_layer_activations = scaler.fit_transform(layer_activations)

    pca = PCA()
    pca_transformed = pca.fit_transform(scaled_layer_activations)

    variance_explained = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(variance_explained)
    n_components = len(variance_explained)

    print(f"PCA fitted with {n_components} components")
    print(f"Cumulative variance for first 5 components: {cumulative_variance[:5]}")

    def find_elbow_point(variance_explained):
        """
        Return the 0-indexed component with the largest absolute second
        difference of the variance curve, or None when there are fewer than
        3 components (no second difference exists).
        """
        if len(variance_explained) < 3:
            return None
        second_diff = np.diff(variance_explained, n=2)
        # second_diff[i] is centred on component i + 1, hence the +1
        return int(np.argmax(np.abs(second_diff))) + 1

    elbow_point = find_elbow_point(variance_explained)

    print("\nPCA Analysis Results:")
    if elbow_point is None:
        print("Elbow point at component: n/a (fewer than 3 components)")
    else:
        # Reported 1-indexed, like the "PC1, PC2, ..." naming used in the plots
        print(f"Elbow point at component: {elbow_point + 1}")

    for percent in (70, 80, 90, 95):
        # First component count whose cumulative variance reaches the threshold
        dims = np.argmax(cumulative_variance >= percent / 100) + 1
        print(f"Dimensions for {percent}% variance: {dims}")

    return pca_transformed, variance_explained, n_components, pca, scaler 

In [17]:
# Sanity check: shapes of the projected data and of the principal axes.
# NOTE(review): `pca_results` is only loaded in the "Plots" section further down,
# so on a fresh kernel this cell has to run after that load cell.
print(pca_results['pca_transformed'].shape)
print(pca_results['pca'].components_.shape)

(798, 798)
(798, 4608)


In [None]:
def plot_pca_cosine_similarity(pca_results, role_labels, role_urls, pc_component, 
                             layer, dir, type, assistant_activation=None, original_vectors=False):
    """
    Plot every role/trait vector along one principal component.

    Top panel: one marker per vector (label on hover, URL in ``customdata``),
    with visible labels and leader lines for the 10 lowest and 10 highest
    points. Bottom panel: stacked histogram of the same values per vector type.

    The plotted value (labelled "Cosine Similarity" in the figure) is each
    vector's score on the chosen PC divided by the L2 norm of all scores on
    that PC.

    Parameters:
    - pca_results: dict with 'pca_transformed' (n_samples, n_components) and
      'vectors' (dict with keys 'roles_pos_2', 'roles_pos_3',
      'traits_pos_40_70', 'traits_pos_70'; only their lengths are used).
      Rows of 'pca_transformed' must be stacked in exactly that key order.
    - role_labels: List of labels for each data point
    - role_urls: List of URLs for each data point (for clicking)
    - pc_component: Which PC component to use (0-indexed, so PC1 = 0)
    - layer: Layer number for the subtitle
    - dir: role directory name; "roles" / "roles_240" select the question-set
      text appended to the subtitle, any other value adds nothing
    - type: unused, kept for call compatibility
    - assistant_activation: optional sequence of PCA-space coordinates of the
      assistant; when given, its position on this PC is marked in red
    - original_vectors: unused, kept for call compatibility

    Returns:
    - Plotly figure object
    """
    # Scores of every vector on the requested principal component
    pc_values = pca_results['pca_transformed'][:, pc_component]

    # Normalized PC values
    cosine_sims = pc_values / np.linalg.norm(pc_values)

    assistant_cosine_sim = None
    if assistant_activation is not None:
        assistant_pc_value = assistant_activation[pc_component]
        # NOTE(review): the assistant is divided by a norm that includes its own
        # score, while the points above use the norm of the points only, so the
        # two scales differ slightly — confirm this is intended.
        assistant_cosine_sim = assistant_pc_value / np.linalg.norm(
            np.concatenate([pc_values, [assistant_pc_value]])
        )

    # One entry per vector type, in the order the rows are stacked
    vectors = pca_results['vectors']
    groups = [
        dict(count=len(vectors['roles_pos_2']), color='cyan', text_color='black',
             symbol='circle', pattern='.', legend_name='Somewhat Role-Playing',
             legend_group='pos2'),
        dict(count=len(vectors['roles_pos_3']), color='blue', text_color='white',
             symbol='circle', pattern='.', legend_name='Fully Role-Playing',
             legend_group='pos3'),
        dict(count=len(vectors['traits_pos_40_70']), color='lime', text_color='black',
             symbol='square', pattern='+', legend_name='Somewhat Exhibiting Trait',
             legend_group='pos40_70'),
        dict(count=len(vectors['traits_pos_70']), color='green', text_color='white',
             symbol='square', pattern='+', legend_name='Fully Exhibiting Trait',
             legend_group='pos70'),
    ]
    # Row range [start, end) covered by each group
    next_start = 0
    for group in groups:
        group['start'] = next_start
        group['end'] = next_start + group['count']
        next_start = group['end']

    def group_for_index(idx):
        """Return the group whose row range contains idx (last group as fallback)."""
        for group in groups[:-1]:
            if idx < group['end']:
                return group
        return groups[-1]

    # Identify extreme points (10 lowest and 10 highest)
    sorted_indices = np.argsort(cosine_sims)
    low_extreme_indices = sorted_indices[:10].tolist()
    high_extreme_indices = sorted_indices[-10:].tolist()
    extreme_indices = set(low_extreme_indices) | set(high_extreme_indices)

    # Create subplot figure
    fig = sp.make_subplots(
        rows=2, cols=1,
        row_heights=[0.6, 0.4],
        vertical_spacing=0.1,
        subplot_titles=[
            f'PC{pc_component+1} Cosine Similarity',
            'Vector Type Frequency Distribution'
        ]
    )

    def add_points(indices, group, showlegend):
        """Add one marker trace for `indices` styled like `group`; no-op when empty."""
        if not indices:
            return
        fig.add_trace(
            go.Scatter(
                x=[cosine_sims[i] for i in indices],
                y=[1] * len(indices),
                mode='markers',
                marker=dict(
                    color=group['color'],
                    size=8,
                    opacity=1.0,
                    symbol=group['symbol'],
                    line=dict(width=1, color='black')
                ),
                text=[role_labels[i] for i in indices],
                customdata=[role_urls[i] for i in indices],
                name=group['legend_name'],
                legendgroup=group['legend_group'],
                showlegend=showlegend,
                hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<br>Click to view<extra></extra>'
            ),
            row=1, col=1
        )

    # Per group: regular points first (they own the legend entry), then the
    # extreme points as a second trace hidden from the legend
    for group in groups:
        members = range(group['start'], group['end'])
        # showlegend=None keeps Plotly's default, i.e. the entry is shown
        add_points([i for i in members if i not in extreme_indices], group, None)
        add_points([i for i in members if i in extreme_indices], group, False)

    # Predefined label heights, alternating above/below the point row so that
    # neighbouring labels do not overlap
    high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42]
    low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58]
    all_y_positions = [y for pair in zip(high_positions, low_positions) for y in pair]

    def add_leader_label(idx, y_label):
        """Draw a leader line from point idx to y_label and put its label there."""
        group = group_for_index(idx)
        x_pos = cosine_sims[idx]

        # Leader line as a separate trace
        fig.add_trace(
            go.Scatter(
                x=[x_pos, x_pos],
                y=[1.0, y_label],
                mode='lines',
                line=dict(color='black', width=1),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=1, col=1
        )

        # Label box coloured like the point's vector type
        fig.add_annotation(
            x=x_pos,
            y=y_label,
            text=role_labels[idx],
            showarrow=False,
            font=dict(size=10, color=group['text_color']),
            bgcolor=group['color'],
            bordercolor='black',
            borderwidth=1,
            row=1, col=1
        )

    # Low extremes use the first 10 heights, high extremes the next 10
    for offset, indices in ((0, low_extreme_indices), (10, high_extreme_indices)):
        for i, idx in enumerate(indices):
            add_leader_label(idx, all_y_positions[offset + i])

    # Add vertical line at x=0 in the top panel
    fig.add_vline(
        x=0,
        line_dash="solid",
        line_color="gray",
        line_width=1,
        opacity=0.7,
        row=1, col=1
    )

    if assistant_cosine_sim is not None:
        # Red dashed vertical line for the assistant position
        fig.add_vline(x=assistant_cosine_sim, line_dash="dash", line_color="red", line_width=1, opacity=1.0, row=1, col=1)

        # Assistant label at the same height as the first high label
        assistant_y_position = 1.6
        fig.add_annotation(
            x=assistant_cosine_sim,
            y=assistant_y_position,
            text="Assistant",
            showarrow=False,
            font=dict(size=10, color="red"),
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="red",
            borderwidth=1,
            row=1, col=1
        )

    # Add vertical line at x=0 in the bottom panel
    fig.add_vline(
        x=0,
        line_dash="solid", 
        line_color="gray",
        line_width=1,
        opacity=0.7,
        row=2, col=1
    )

    # Bottom panel: stacked histogram, with bins shared by all vector types
    nbins = 30
    bin_edges = np.linspace(np.min(cosine_sims), np.max(cosine_sims), nbins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]

    for group in groups:
        counts, _ = np.histogram(cosine_sims[group['start']:group['end']], bins=bin_edges)
        type_name = group['legend_name']
        pattern_config = dict(shape=group['pattern'], bgcolor=group['color'], fgcolor="black")

        fig.add_trace(
            go.Bar(
                x=bin_centers,
                y=counts,
                width=bin_width * 0.9,
                marker=dict(
                    color=group['color'],
                    pattern=pattern_config
                ),
                opacity=1.0,
                name=type_name,
                legendgroup=group['legend_group'],
                showlegend=False,  # Don't show legend for histogram bars
                hovertemplate=f'{type_name}<br>Count: %{{y}}<extra></extra>'
            ),
            row=2, col=1
        )

    title = "PCA on Role and Trait Vectors"
    subtitle = f"Gemma 2 27B, Layer {layer}"
    if dir == "roles":
        subtitle += " - Unique Question Set"
    elif dir == "roles_240":
        subtitle += " - Shared Question Set"

    # Update layout with legend
    fig.update_layout(
        height=700,
        title=dict(
            text=title,
            subtitle={
                "text": subtitle,
            },
            x=0.5,
            font=dict(size=16)
        ),
        showlegend=True,
        barmode='stack',  # Enable stacked bars
        legend=dict(
            x=0.0,
            y=1.04,
            xanchor='left',
            yanchor='bottom',
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="gray",
            borderwidth=1
        )
    )

    # Symmetric x-range around 0 (not around the data centre). The assistant
    # value is included so its marker is never clipped off the axis.
    max_abs_value = max(abs(np.min(cosine_sims)), abs(np.max(cosine_sims)))
    if assistant_cosine_sim is not None:
        max_abs_value = max(max_abs_value, abs(assistant_cosine_sim))
    x_half_width = max_abs_value * 1.1  # Add 10% padding

    # Update x-axes with symmetric ranges centered on 0
    fig.update_xaxes(
        row=1, col=1,
        range=[-x_half_width, x_half_width]
    )

    fig.update_xaxes(
        title_text=f"PC{pc_component+1} Cosine Similarity",
        row=2, col=1,
        range=[-x_half_width, x_half_width]
    )

    # Update y-axes
    fig.update_yaxes(
        title_text="",
        showticklabels=False,
        row=1, col=1,
        range=[0.25, 1.75]  # Range for varied label heights
    )

    fig.update_yaxes(
        title_text="Frequency",
        row=2, col=1
    )

    return fig

In [5]:
def plot_3d_pca(pca_results, role_labels, role_urls, layer, dir, type, assistant_activation=None):
    """
    Build a 3D scatter of all role/trait vectors on the first three PCs.

    Parameters:
    - pca_results: dict with 'pca_transformed', 'variance_explained' and
      'vectors' (keys 'roles_pos_2', 'roles_pos_3', 'traits_pos_40_70',
      'traits_pos_70'; only their lengths are used)
    - role_labels: List of labels for each data point
    - role_urls: List of URLs for each data point (stored as customdata)
    - layer: Layer number for the subtitle
    - dir: "roles" / "roles_240" select the question-set text in the subtitle
    - type: unused
    - assistant_activation: optional PCA-space coordinates of the assistant,
      drawn as a red marker

    Returns:
    - Plotly figure object
    """
    coords = pca_results['pca_transformed']
    variance_explained = pca_results['variance_explained']
    vectors = pca_results['vectors']

    # (count, colour, marker symbol, legend name) in the order the rows are stacked
    trace_specs = [
        (len(vectors['roles_pos_2']), 'cyan', 'circle', 'Somewhat Role-Playing'),
        (len(vectors['roles_pos_3']), 'blue', 'circle', 'Fully Role-Playing'),
        (len(vectors['traits_pos_40_70']), 'lime', 'square', 'Somewhat Exhibiting Trait'),
        (len(vectors['traits_pos_70']), 'green', 'square', 'Fully Exhibiting Trait'),
    ]

    # Shared tail of both hover templates
    pc_hover = 'PC1: %{x:.3f}<br>' + 'PC2: %{y:.3f}<br>' + 'PC3: %{z:.3f}<br>'

    fig_3d = go.Figure()

    first = 0
    for count, color, shape, legend_name in trace_specs:
        stop = first + count

        # Only every 10th point of the group gets a visible text label
        visible_text = [
            role_labels[i] if (i - first) % 10 == 0 else ''
            for i in range(first, stop)
        ]

        fig_3d.add_trace(go.Scatter3d(
            x=coords[first:stop, 0],
            y=coords[first:stop, 1],
            z=coords[first:stop, 2],
            mode='markers+text',
            text=visible_text,
            textposition='top center',
            textfont=dict(size=6),
            marker=dict(size=3, color=color, symbol=shape),
            name=legend_name,
            customdata=role_urls[first:stop],
            hovertemplate='<b>%{hovertext}</b><br>' + pc_hover + 'Click to view<extra></extra>',
            hovertext=role_labels[first:stop]
        ))
        first = stop

    if assistant_activation is not None:
        assistant_marker = dict(
            size=5,  # 2 sizes bigger than the other dots (3 -> 5)
            color='red',
            opacity=1.0
        )
        fig_3d.add_trace(go.Scatter3d(
            x=[assistant_activation[0]],
            y=[assistant_activation[1]],
            z=[assistant_activation[2]],
            mode='markers+text',
            text=['Assistant'],
            textposition='top center',
            textfont=dict(size=8, color='black'),
            marker=assistant_marker,
            showlegend=False,
            hovertemplate='<b>Assistant</b><br>' + pc_hover + '<extra></extra>'
        ))

    subtitle = f"Gemma 2 27B, Layer {layer}"
    if dir == "roles":
        subtitle += " - Unique Question Set"
    elif dir == "roles_240":
        subtitle += " - Shared Question Set"

    # Axis titles carry the share of variance explained by each PC
    scene = {
        f'{axis}axis_title': f'PC{k+1} ({variance_explained[k]*100:.1f}%)'
        for k, axis in enumerate('xyz')
    }

    fig_3d.update_layout(
        title={
            "text": "Role and Trait Vectors in 3D PC Space",
            "subtitle": {"text": subtitle},
        },
        scene=scene,
        legend=dict(itemsizing='constant', itemwidth=30),
        width=1000,
        height=800
    )

    return fig_3d

## Compute and save PCA 

In [33]:
# Index of the layer whose activations are used for the PCA below
layer = 22


In [34]:

# Roles that have a "pos_2" vector, and those vectors in the same order
pos_2_roles = [role for role, saved in role_vectors.items() if 'pos_2' in saved]
pos_2_vectors = [role_vectors[role]['pos_2'] for role in pos_2_roles]

# Roles that have a "pos_3" vector, and those vectors in the same order
pos_3_roles = [role for role, saved in role_vectors.items() if 'pos_3' in saved]
pos_3_vectors = [role_vectors[role]['pos_3'] for role in pos_3_roles]

print(len(pos_2_roles))
print(len(pos_3_roles))


173
275


In [38]:
# Per-trait counts; a trait's vector is only kept when its count is at least 10
trait_stats = pd.read_csv('/root/git/persona-subspace/traits/results/pca_240/pos.csv', index_col='trait')

pos_70_traits, pos_70_vectors = [], []
pos_40_70_traits, pos_40_70_vectors = [], []

for trait, vector in trait_vectors.items():
    counts = trait_stats.loc[trait]

    if counts['pos_70_count'] >= 10:
        pos_70_traits.append(trait)
        pos_70_vectors.append(vector['pos_70'])

    if counts['pos_40_70_count'] >= 10:
        pos_40_70_traits.append(trait)
        pos_40_70_vectors.append(vector['pos_40_70'])

print(len(pos_70_traits))
print(len(pos_40_70_traits))

239
111


In [39]:
# Stack order must stay roles_pos_2, roles_pos_3, traits_pos_40_70, traits_pos_70:
# the plotting functions recover each point's vector type from its row index.
combined_vectors = pos_2_vectors + pos_3_vectors + pos_40_70_vectors + pos_70_vectors
print(len(combined_vectors))

798


In [40]:
# Fit the PCA on the stacked vectors (cast to float) at the chosen layer
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(torch.stack(combined_vectors).float(), layer)

PCA fitted with 798 components
Cumulative variance for first 5 components: [0.16901494 0.26605914 0.32784762 0.37938097 0.42401652]

PCA Analysis Results:
Elbow point at component: 2
Dimensions for 70% variance: 21
Dimensions for 80% variance: 39
Dimensions for 90% variance: 86
Dimensions for 95% variance: 154


In [41]:
# Everything the plotting section needs, bundled into one file. The three
# per-type dicts share the key order listed in 'order', which is also the
# row order of 'pca_transformed'.
results = {
    'layer': layer,
    'roles_or_traits': {
        'roles_pos_2': pos_2_roles,
        'roles_pos_3': pos_3_roles,
        'traits_pos_40_70': pos_40_70_traits,
        'traits_pos_70': pos_70_traits,
    },
    'vectors': {
        'roles_pos_2': pos_2_vectors,
        'roles_pos_3': pos_3_vectors,
        'traits_pos_40_70': pos_40_70_vectors,
        'traits_pos_70': pos_70_vectors,
    },
    'order': ['roles_pos_2', 'roles_pos_3', 'traits_pos_40_70', 'traits_pos_70'],
    'pca_transformed': pca_transformed,
    'variance_explained': variance_explained,
    'n_components': n_components,
    'pca': pca,
    'scaler': scaler,
}

pca_save_path = f"/workspace/roles_traits/pca/layer{layer}_roles_pos23_traits_pos40-100.pt"
# torch.save does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(pca_save_path), exist_ok=True)
torch.save(results, pca_save_path)

## Plots

In [6]:
layer = 22
# load in PCs
# NOTE: weights_only=False unpickles arbitrary Python objects (the saved dict
# contains the fitted PCA and scaler objects) — only load files you trust.
pca_results = torch.load(f"/workspace/roles_traits/pca/layer{layer}_roles_pos23_traits_pos40-100.pt", weights_only=False)
# NOTE(review): this call relies on torch's default `weights_only` setting;
# presumably the file holds only tensors and plain containers — confirm.
default_vectors = torch.load("/workspace/roles_240/default_vectors.pt")


In [7]:

# Folder that receives the exported HTML figures for this layer
output_dir = f"./results/pca/layer{layer}"
os.makedirs(output_dir, exist_ok=True)

In [8]:
# also calculate role labels for plotting
def get_role_labels_and_urls(pca_results):
    """
    Build display labels and viewer URLs for every vector type.

    Reads pca_results['roles_or_traits'] (vector type -> list of names) and
    returns two dicts keyed by vector type: title-cased labels with a type
    suffix, and the viewer URL for each name. Vector types other than the
    four known ones are left out of both dicts.
    """
    base_url = "https://lu-christina.github.io/persona-subspace/viewer/index.html"

    # vector type -> (label suffix, viewer "source" value, viewer query parameter)
    formats = {
        'roles_pos_2': ("Somewhat RP", "role_shared", "role"),
        'roles_pos_3': ("Fully RP", "role_shared", "role"),
        'traits_pos_40_70': ("Somewhat", "trait_shared", "trait"),
        'traits_pos_70': ("Fully", "trait_shared", "trait"),
    }

    label_dict, url_dict = {}, {}
    for key, original_names in pca_results['roles_or_traits'].items():
        # snake_case name -> "Title Case" label
        pretty_names = [name.replace('_', ' ').title() for name in original_names]

        if key not in formats:
            continue
        suffix, source, param = formats[key]

        label_dict[key] = [f"{pretty} ({suffix})" for pretty in pretty_names]
        url_dict[key] = [f"{base_url}?source={source}&{param}={name}" for name in original_names]

    return label_dict, url_dict

role_labels_dict, url_dict = get_role_labels_and_urls(pca_results)

# Flatten labels and URLs in the same order the vectors were stacked for the PCA
display_order = ('roles_pos_2', 'roles_pos_3', 'traits_pos_40_70', 'traits_pos_70')
role_labels = [label for key in display_order for label in role_labels_dict[key]]
role_urls = [url for key in display_order for url in url_dict[key]]

print(len(role_labels))

798


In [9]:
# Peek at the first and last ten labels to check naming and order
print(role_labels[:10])
print(role_labels[-10:])

['Writer (Somewhat RP)', 'Workaholic (Somewhat RP)', 'Witness (Somewhat RP)', 'Visionary (Somewhat RP)', 'Virus (Somewhat RP)', 'Virtuoso (Somewhat RP)', 'Vigilante (Somewhat RP)', 'Veterinarian (Somewhat RP)', 'Vegan (Somewhat RP)', 'Validator (Somewhat RP)']
['Analytical (Fully)', 'Altruistic (Fully)', 'Agreeable (Fully)', 'Adventurous (Fully)', 'Adaptable (Fully)', 'Acerbic (Fully)', 'Accommodating (Fully)', 'Accessible (Fully)', 'Abstract (Fully)', 'Absolutist (Fully)']


In [10]:
# get default activation and project into PCA space
# Take the 'default_1' activation at the chosen layer as a single-row array
assistant_layer_activation = default_vectors['activations']['default_1'][layer, :].float().numpy().reshape(1, -1)
# Apply the saved scaler and PCA so the assistant lands in the same PC space as the plotted points
asst_scaled = pca_results['scaler'].transform(assistant_layer_activation)
asst_projected = pca_results['pca'].transform(asst_scaled)


In [14]:
assistant = True  # set to False to plot without the Assistant marker

# JavaScript for handling clicks on plotly points
click_js = """
<script>
function setupClickHandlers() {
    const plotElements = document.querySelectorAll('.js-plotly-plot');
    
    plotElements.forEach(function(plotElement) {
        plotElement.on('plotly_click', function(data) {
            if (data.points && data.points.length > 0) {
                const point = data.points[0];
                if (point.customdata) {
                    window.open(point.customdata, '_blank');
                }
            }
        });
    });
}

// Setup handlers when page loads
document.addEventListener('DOMContentLoaded', setupClickHandlers);
// Also setup when plotly is done rendering
if (window.Plotly) {
    window.Plotly.newPlot = (function(originalNewPlot) {
        return function() {
            const result = originalNewPlot.apply(this, arguments);
            setTimeout(setupClickHandlers, 100);
            return result;
        };
    })(window.Plotly.newPlot);
}
</script>
"""

n_components_to_plot = 10

# Files that include the Assistant marker get their own suffix
file_suffix = "_assistant" if assistant else ""
assistant_projection = asst_projected[0] if assistant else None

for component in range(n_components_to_plot):
    # Pass `role_dir` (from the config cell), not the builtin `dir`, so the
    # question-set subtitle is actually added. The `type` argument is unused
    # by the plotting function, hence None.
    fig = plot_pca_cosine_similarity(
        pca_results, role_labels, role_urls, component, layer, role_dir, None,
        assistant_activation=assistant_projection,
    )
    fig.show()

    # Write HTML with click handling
    html_content = fig.to_html()
    html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
    with open(f"{output_dir}/pc{component+1}{file_suffix}.html", 'w', encoding='utf-8') as f:
        f.write(html_with_clicks)

In [12]:
assistant = True  # set to False to plot without the Assistant marker

# Pass `role_dir` (from the config cell), not the builtin `dir`, so the
# question-set subtitle is actually added. The `type` argument is unused by
# the plotting function, hence None.
fig_3d = plot_3d_pca(
    pca_results, role_labels, role_urls, layer, role_dir, None,
    assistant_activation=asst_projected[0] if assistant else None,
)
fig_3d.show()

# Write HTML with click handling for 3D plot
html_content = fig_3d.to_html()
html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
html_name = "3d_pca_assistant.html" if assistant else "3d_pca.html"
with open(f"{output_dir}/{html_name}", 'w', encoding='utf-8') as f:
    f.write(html_with_clicks)

In [13]:
# Test URL generation
n_preview = 5  # how many label/URL pairs to show per section

print("First few role URLs:")
# Slicing (instead of indexing 0..4) keeps this working with fewer than 5 entries
for label, url in zip(role_labels[:n_preview], role_urls[:n_preview]):
    print(f"{label} -> {url}")

print("\nFirst few trait URLs:")
# Traits start right after the two role blocks in the flattened lists
trait_start_idx = len(role_labels_dict['roles_pos_2']) + len(role_labels_dict['roles_pos_3'])
trait_preview = slice(trait_start_idx, trait_start_idx + n_preview)
for label, url in zip(role_labels[trait_preview], role_urls[trait_preview]):
    print(f"{label} -> {url}")

print(f"\nTotal URLs generated: {len(role_urls)}")
print(f"Total labels: {len(role_labels)}")
print(f"URLs match labels: {len(role_urls) == len(role_labels)}")

# Fail loudly instead of only printing False: the plots index both lists by the same row
assert len(role_urls) == len(role_labels), "every label needs exactly one URL"

First few role URLs:
Writer (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=writer
Workaholic (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=workaholic
Witness (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=witness
Visionary (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=visionary
Virus (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=virus

First few trait URLs:
Visceral (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=trait_shared&trait=visceral
Utilitarian (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=trait_shared&trait=utilitarian
Universalist (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/i