# PCA on roles

In [1]:
import json
import os
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm


In [2]:
# Dataset to analyse: "roles" (unique question set per role) or
# "roles_240" (one shared question set).
# NOTE: `dir` shadows the Python builtin; the name is kept because later cells use it.
dir = "roles"

# Number of questions and prompt types for the chosen dataset.
if dir == "roles":
    n_questions = 30
    n_prompt_types = 2
elif dir == "roles_240":
    n_questions = 240
    n_prompt_types = 1
else:
    raise ValueError(f"Unknown dir {dir!r}; expected 'roles' or 'roles_240'")


## Load vectors

In [3]:
# Load every per-role vector file (<role>.pt) from the dataset's vector directory.
vector_dir = f"/workspace/{dir}/vectors"

# Iterate in sorted order so the role order (and therefore the row order of every
# matrix built from `vectors` below) is reproducible; os.listdir order is arbitrary.
vectors = {}
for file in sorted(os.listdir(vector_dir)):
    if file.endswith(".pt"):
        # Strip only the trailing extension (str.replace would also remove any
        # ".pt" occurring inside the role name).
        vectors[file[: -len(".pt")]] = torch.load(os.path.join(vector_dir, file))

print(f"Found {len(vectors)} traits with vectors")

Found 275 traits with vectors


In [4]:
# Default activations for this dataset; 'default_1' is used further down as the
# Assistant's activation.
default_vectors_path = f"/workspace/{dir}/default_vectors.pt"
default_vectors = torch.load(default_vectors_path)

In [5]:
# Peek at the dictionary layouts: one role's vector variants, the default-vector
# bundle, and the default activations inside it.
graduate_entry = vectors['graduate']
print(graduate_entry.keys())
print(default_vectors.keys())
activation_entry = default_vectors['activations']
print(activation_entry.keys())

dict_keys(['pos_3', 'pos_all', 'default_0', 'default_1', 'default_all'])
dict_keys(['activations', 'metadata'])
dict_keys(['pos_1', 'default_1', 'all_1'])


## PCA + plotting functions

In [6]:
def _find_elbow_point(variance_explained):
    """Return the 0-based index of maximum |second difference| of the curve, or None.

    None is returned when fewer than 3 values are given, because a second
    difference needs at least three points.
    """
    if len(variance_explained) < 3:
        return None
    second_diff = np.diff(variance_explained, n=2)
    # second_diff[i] is centred on element i + 1, hence the +1 offset.
    return int(np.argmax(np.abs(second_diff))) + 1


def _dims_for_variance(cumulative_variance, threshold):
    """Return how many leading PCs are needed to reach `threshold` cumulative variance.

    Falls back to the total number of components if the threshold is never reached
    (np.argmax on an all-False mask would otherwise wrongly report 1).
    """
    reached = np.flatnonzero(cumulative_variance >= threshold)
    return int(reached[0]) + 1 if reached.size else len(cumulative_variance)


def compute_pca(activation_list, layer):
    """Standardize one layer's activations and fit a full PCA on them.

    Parameters:
    - activation_list: 3-D array-like indexed as [sample, layer, feature]
      (e.g. a stacked torch tensor); only the slice at `layer` is used.
    - layer: index into the second axis selecting the layer to analyse.

    Returns:
    - (pca_transformed, variance_explained, n_components, pca, scaler):
      PC scores of the standardized activations, PCA.explained_variance_ratio_,
      the number of fitted components, and the fitted PCA / StandardScaler
      (needed to project new activations into the same space later).

    Also prints the cumulative variance of the first 5 PCs, the elbow component
    (1-based), and the number of PCs needed for 70/80/90/95% of the variance.
    """
    layer_activations = activation_list[:, layer, :]

    # Scale every feature to zero mean / unit variance before PCA.
    scaler = StandardScaler()
    scaled_layer_activations = scaler.fit_transform(layer_activations)

    pca = PCA()
    pca_transformed = pca.fit_transform(scaled_layer_activations)

    variance_explained = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(variance_explained)
    n_components = len(variance_explained)

    print(f"PCA fitted with {n_components} components")
    print(f"Cumulative variance for first 5 components: {cumulative_variance[:5]}")

    elbow_point = _find_elbow_point(variance_explained)

    print("\nPCA Analysis Results:")
    if elbow_point is None:
        print("Elbow point at component: n/a (need at least 3 components)")
    else:
        print(f"Elbow point at component: {elbow_point + 1}")
    for threshold in (0.70, 0.80, 0.90, 0.95):
        dims = _dims_for_variance(cumulative_variance, threshold)
        print(f"Dimensions for {threshold:.0%} variance: {dims}")

    return pca_transformed, variance_explained, n_components, pca, scaler

In [153]:
def _add_strip_points(fig, x, labels, color, symbol, opacity, **trace_kwargs):
    """Add markers for `x` to the top (row 1) strip at y=1; no-op when `x` is empty.

    Extra keyword arguments (name, legendgroup, legend, showlegend, ...) are
    forwarded to go.Scatter.
    """
    if len(x) == 0:
        return
    fig.add_trace(
        go.Scatter(
            x=x,
            y=[1] * len(x),
            mode='markers',
            marker=dict(
                color=color,
                size=8,
                opacity=opacity,
                symbol=symbol,
                line=dict(width=1, color='black')
            ),
            text=labels,
            hovertemplate='<b>%{text}</b><br>Cosine Similarity: %{x:.3f}<extra></extra>',
            **trace_kwargs
        ),
        row=1, col=1
    )


def _add_leader_label(fig, x_pos, y_label, label):
    """Draw a leader line from the strip (y=1) to `y_label`, with a boxed `label` at its end.

    Line and label are red when x_pos is negative and blue otherwise.
    """
    color = 'red' if x_pos < 0 else 'blue'
    fig.add_trace(
        go.Scatter(
            x=[x_pos, x_pos],
            y=[1.0, y_label],
            mode='lines',
            line=dict(color=color, width=1),
            showlegend=False,
            hoverinfo='skip'
        ),
        row=1, col=1
    )
    fig.add_annotation(
        x=x_pos,
        y=y_label,
        text=label,
        showarrow=False,
        font=dict(size=10, color=color),
        bgcolor="rgba(255, 255, 255, 0.9)",
        bordercolor=color,
        borderwidth=1,
        row=1, col=1
    )


def plot_pca_cosine_similarity(pca_results, role_labels, pc_component,
                               layer, dir, type, assistant_activation=None):
    """
    Plot one principal component as a 1-D strip of role points above a histogram.

    Each role's score on the chosen PC is divided by the L2 norm of that PC's
    scores across all roles; these unit-normalized values are what the figure
    labels "Cosine Similarity". Points show their label on hover. The 10 lowest
    and the 10 highest additionally get visible labels with leader lines, placed
    at alternating heights to avoid overlap.

    Parameters:
    - pca_results: dict with 'pca_transformed' (n_samples, n_components) and
      'roles'. When type == "pos23" and roles has 'pos_2', the first
      len(roles['pos_2']) rows are drawn as "Somewhat Role-Playing" and the
      rest as "Fully Role-Playing".
    - role_labels: list of labels, one per row of pca_transformed
    - pc_component: which PC to plot (0-indexed, so PC1 = 0)
    - layer: layer number, shown in the subtitle
    - dir: dataset directory name; "roles" / "roles_240" add a question-set note
      to the subtitle
    - type: "pos23" or "pos3"
    - assistant_activation: optional assistant activation already projected into
      PC space (indexed by component). It is normalized with the same norm as the
      role points and drawn as a dashed vertical line.

    Returns:
    - Plotly figure object

    Raises:
    - ValueError: if type is unknown or role_labels has the wrong length
    """
    titles = {
        "pos23": "PCA on Somewhat and Fully Role-Playing Vectors",
        "pos3": "PCA on Fully Role-Playing Vectors",
    }
    if type not in titles:
        raise ValueError(f"type must be 'pos23' or 'pos3', got {type!r}")

    pc_values = np.asarray(pca_results['pca_transformed'][:, pc_component])
    n_points = len(pc_values)
    if len(role_labels) != n_points:
        raise ValueError(
            f"role_labels has {len(role_labels)} entries but there are {n_points} points"
        )

    # Unit-normalize the PC scores. The assistant uses the SAME divisor so that it
    # lives on the same scale as the role points.
    norm = np.linalg.norm(pc_values)
    if norm == 0:
        norm = 1.0  # degenerate component: avoid dividing by zero
    cosine_sims = pc_values / norm
    assistant_cosine_sim = None
    if assistant_activation is not None:
        assistant_cosine_sim = assistant_activation[pc_component] / norm

    split_by_type = type == "pos23" and 'pos_2' in pca_results['roles']
    n_pos_2 = len(pca_results['roles']['pos_2']) if split_by_type else 0

    # The n_extreme lowest and highest points get visible labels.
    n_extreme = 10
    sorted_indices = np.argsort(cosine_sims)
    low_extreme_indices = sorted_indices[:n_extreme]
    high_extreme_indices = sorted_indices[-n_extreme:]
    is_extreme = np.zeros(n_points, dtype=bool)
    is_extreme[low_extreme_indices] = True
    is_extreme[high_extreme_indices] = True
    labels = np.asarray(role_labels, dtype=object)

    # Create subplot figure
    fig = sp.make_subplots(
        rows=2, cols=1,
        row_heights=[0.6, 0.4],
        vertical_spacing=0.1,
        subplot_titles=[
            f'PC{pc_component+1} Cosine Similarity',
            'Trait Frequency Distribution'
        ]
    )

    # Top panel: the strip. Extreme points go in separate, more opaque traces
    # that are drawn after the regular ones.
    if split_by_type:
        is_pos2 = np.arange(n_points) < n_pos_2
        groups = [
            (is_pos2, 'Somewhat Role-Playing', 'scatter_pos2', 'green', 'circle'),
            (~is_pos2, 'Fully Role-Playing', 'scatter_pos3', 'orange', 'square'),
        ]
        for in_group, name, group, color, symbol in groups:
            regular = in_group & ~is_extreme
            extreme = in_group & is_extreme
            legend_kwargs = dict(name=name, legendgroup=group, legend='legend')
            _add_strip_points(fig, cosine_sims[regular], labels[regular].tolist(),
                              color, symbol, 0.7, **legend_kwargs)
            # One legend entry per group: the extreme trace only carries it when
            # the group has no regular points.
            _add_strip_points(fig, cosine_sims[extreme], labels[extreme].tolist(),
                              color, symbol, 0.9, showlegend=not regular.any(),
                              **legend_kwargs)
    else:
        _add_strip_points(fig, cosine_sims[~is_extreme], labels[~is_extreme].tolist(),
                          'orange', 'circle', 0.7, showlegend=False)
        _add_strip_points(fig, cosine_sims[is_extreme], labels[is_extreme].tolist(),
                          'orange', 'circle', 0.9, showlegend=False)

    # Leader lines and labels for the extremes. Heights alternate high/low with
    # some jitter. Low extremes use the first n_extreme slots and high extremes
    # the next n_extreme.
    high_positions = [1.6, 1.45, 1.55, 1.35, 1.5, 1.4, 1.65, 1.3, 1.58, 1.42]
    low_positions = [0.4, 0.55, 0.45, 0.65, 0.5, 0.6, 0.35, 0.7, 0.42, 0.58]
    label_heights = [h for pair in zip(high_positions, low_positions) for h in pair]
    for rank, idx in enumerate(low_extreme_indices):
        _add_leader_label(fig, cosine_sims[idx], label_heights[rank], role_labels[idx])
    for rank, idx in enumerate(high_extreme_indices):
        _add_leader_label(fig, cosine_sims[idx], label_heights[rank + n_extreme],
                          role_labels[idx])

    # Vertical reference line at x=0 (top panel)
    fig.add_vline(
        x=0,
        line_dash="solid",
        line_color="gray",
        line_width=1,
        opacity=0.7,
        row=1, col=1
    )

    if assistant_cosine_sim is not None:
        # Black dashed line at the assistant's position, labelled at the first
        # high label height.
        fig.add_vline(x=assistant_cosine_sim, line_dash="dash", line_color="black",
                      line_width=1, opacity=1.0, row=1, col=1)
        fig.add_annotation(
            x=assistant_cosine_sim,
            y=1.6,
            text="Assistant",
            showarrow=False,
            font=dict(size=10, color="black"),
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="black",
            borderwidth=1,
            row=1, col=1
        )

    # Vertical reference line at x=0 (bottom panel)
    fig.add_vline(
        x=0,
        line_dash="solid",
        line_color="gray",
        line_width=1,
        opacity=0.7,
        row=2, col=1
    )

    # Bottom panel: histogram
    if split_by_type:
        # Shared bins so the two groups stack cleanly.
        nbins = 30
        bin_edges = np.linspace(cosine_sims.min(), cosine_sims.max(), nbins + 1)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        bin_width = bin_edges[1] - bin_edges[0]
        hist_groups = [
            (cosine_sims[:n_pos_2], 'Somewhat Role-Playing', 'green', '.'),
            (cosine_sims[n_pos_2:], 'Fully Role-Playing', 'orange', '+'),
        ]
        for sims, type_name, color, shape in hist_groups:
            counts, _ = np.histogram(sims, bins=bin_edges)
            fig.add_trace(
                go.Bar(
                    x=bin_centers,
                    y=counts,
                    width=bin_width * 0.9,
                    marker=dict(
                        color=color,
                        pattern=dict(shape=shape, bgcolor=color, fgcolor="black")
                    ),
                    opacity=0.7,
                    name=type_name,
                    legendgroup=f'hist_{type_name.lower().replace(" ", "_").replace("-", "_")}',
                    legend='legend2',
                    hovertemplate=f'{type_name}<br>Count: %{{y}}<extra></extra>'
                ),
                row=2, col=1
            )
    else:
        fig.add_trace(
            go.Histogram(
                x=cosine_sims,
                nbinsx=30,
                opacity=0.7,
                marker_color='orange',
                showlegend=False
            ),
            row=2, col=1
        )

    subtitle = f"Gemma 2 27B, Layer {layer}"
    if dir == "roles":
        subtitle += " - Unique Question Set"
    elif dir == "roles_240":
        subtitle += " - Shared Question Set"

    # Separate legends for the strip (legend) and the histogram (legend2).
    fig.update_layout(
        height=700,
        title=dict(
            text=titles[type],
            subtitle={
                "text": subtitle,
            },
            x=0.5,
            font=dict(size=16)
        ),
        showlegend=split_by_type,
        barmode='stack',
        legend=dict(
            x=0.02,
            y=1.02,
            xanchor='left',
            yanchor='bottom',
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="gray",
            borderwidth=1
        ),
        legend2=dict(
            x=0.02,
            y=0.32,
            xanchor='left',
            yanchor='top',
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="gray",
            borderwidth=1
        )
    )

    # Symmetric x-range around 0 (not around the data centre). It must also
    # contain the assistant line. Add 10% padding.
    max_abs_value = float(np.max(np.abs(cosine_sims)))
    if assistant_cosine_sim is not None:
        max_abs_value = max(max_abs_value, abs(float(assistant_cosine_sim)))
    x_half_width = max_abs_value * 1.1

    fig.update_xaxes(
        row=1, col=1,
        range=[-x_half_width, x_half_width]
    )

    fig.update_xaxes(
        title_text=f"PC{pc_component+1} Cosine Similarity",
        row=2, col=1,
        range=[-x_half_width, x_half_width]
    )

    # Top panel y-range leaves room for the varied label heights.
    fig.update_yaxes(
        title_text="",
        showticklabels=False,
        row=1, col=1,
        range=[0.25, 1.75]
    )

    fig.update_yaxes(
        title_text="Frequency",
        row=2, col=1
    )

    return fig

In [162]:


def plot_3d_pca(pca_results, role_labels, layer, dir, type, assistant_activation=None,
                label_every=3):
    """
    Scatter the role vectors in the space of the first three principal components.

    Parameters:
    - pca_results: dict with 'pca_transformed' (n_samples, >= 3 components),
      'variance_explained' (used in the axis titles) and 'roles'
    - role_labels: list of labels, one per row of pca_transformed
    - layer: layer number, shown in the subtitle
    - dir: dataset directory name; "roles" / "roles_240" add a question-set note
      to the subtitle
    - type: "pos23" or "pos3". With "pos23" and a 'pos_2' entry in
      pca_results['roles'], the first len(roles['pos_2']) rows become green
      circles ("Somewhat Role-Playing") and the rest orange squares ("Fully
      Role-Playing"). Otherwise all points form one orange trace.
    - assistant_activation: optional assistant activation already projected into
      PC space; drawn as a larger red marker
    - label_every: show a permanent text label on every `label_every`-th point of
      each trace (default 3); every point keeps its hover label

    Returns:
    - Plotly figure object

    Raises:
    - ValueError: if type is unknown or label_every < 1
    """
    titles = {
        "pos23": "Somewhat and Fully Role-Playing Vectors in 3D PC Space",
        "pos3": "Fully Role-Playing Vectors in 3D PC Space",
    }
    if type not in titles:
        raise ValueError(f"type must be 'pos23' or 'pos3', got {type!r}")
    if label_every < 1:
        raise ValueError(f"label_every must be >= 1, got {label_every!r}")

    pca_transformed = pca_results['pca_transformed']
    variance_explained = pca_results['variance_explained']

    hovertemplate = ('<b>%{hovertext}</b><br>'
                     'PC1: %{x:.3f}<br>'
                     'PC2: %{y:.3f}<br>'
                     'PC3: %{z:.3f}<br>'
                     '<extra></extra>')

    def role_trace(rows, labels, marker, **trace_kwargs):
        """Build a Scatter3d for the row slice `rows`, text-labelling every label_every-th point."""
        # A modulo test instead of `i in list_of_indices` avoids O(n^2) membership checks.
        text = [label if i % label_every == 0 else '' for i, label in enumerate(labels)]
        return go.Scatter3d(
            x=pca_transformed[rows, 0],
            y=pca_transformed[rows, 1],
            z=pca_transformed[rows, 2],
            mode='markers+text',
            text=text,
            textposition='top center',
            textfont=dict(size=6),
            marker=marker,
            hovertemplate=hovertemplate,
            hovertext=labels,
            **trace_kwargs
        )

    if type == "pos23" and 'pos_2' in pca_results['roles']:
        # Two traces (pos_2 rows first, then pos_3 rows) so each gets a legend entry.
        n_pos_2 = len(pca_results['roles']['pos_2'])
        fig_3d = go.Figure(data=[
            role_trace(slice(None, n_pos_2), role_labels[:n_pos_2],
                       dict(size=3, color='green', symbol='circle'),
                       name='Somewhat Role-Playing'),
            role_trace(slice(n_pos_2, None), role_labels[n_pos_2:],
                       dict(size=3, color='orange', symbol='square'),
                       name='Fully Role-Playing'),
        ])
    else:
        fig_3d = go.Figure(data=[
            role_trace(slice(None), role_labels,
                       dict(size=3, color='orange'),
                       showlegend=False),
        ])

    if assistant_activation is not None:
        fig_3d.add_trace(go.Scatter3d(
            x=[assistant_activation[0]],
            y=[assistant_activation[1]],
            z=[assistant_activation[2]],
            mode='markers+text',
            text=['Assistant'],
            textposition='top center',
            textfont=dict(size=8, color='black'),
            marker=dict(
                size=5,  # larger than the role markers (size 3)
                color='red',
                opacity=1.0
            ),
            showlegend=False,
            hovertemplate='<b>Assistant</b><br>'
                          'PC1: %{x:.3f}<br>'
                          'PC2: %{y:.3f}<br>'
                          'PC3: %{z:.3f}<br>'
                          '<extra></extra>'
        ))

    subtitle = f"Gemma 2 27B, Layer {layer}"
    if dir == "roles":
        subtitle += " - Unique Question Set"
    elif dir == "roles_240":
        subtitle += " - Shared Question Set"

    fig_3d.update_layout(
        title={
            "text": titles[type],
            "subtitle": {
                "text": subtitle,
            },
        },
        scene=dict(
            xaxis_title=f'PC1 ({variance_explained[0]*100:.1f}%)',
            yaxis_title=f'PC2 ({variance_explained[1]*100:.1f}%)',
            zaxis_title=f'PC3 ({variance_explained[2]*100:.1f}%)'
        ),
        legend=dict(
            itemsizing='constant',
            itemwidth=30,
        ),
        width=1000,
        height=800
    )

    return fig_3d

## Compute and save PCA 

In [None]:
layer = 22

# For every role, collect its pos_2 ("somewhat role-playing") and pos_3 ("fully
# role-playing") vectors when present. Role names stay aligned index-by-index
# with their vectors.
pos_2_roles = [role for role, vec in vectors.items() if 'pos_2' in vec]
pos_2_vectors = [vectors[role]['pos_2'] for role in pos_2_roles]
pos_3_roles = [role for role, vec in vectors.items() if 'pos_3' in vec]
pos_3_vectors = [vectors[role]['pos_3'] for role in pos_3_roles]

print(len(pos_2_roles))
print(len(pos_3_roles))


66
269


In [24]:
# Row order matters downstream: all pos_2 vectors first, then all pos_3 vectors
# (the plotting functions split rows at len(pos_2)).
combined_vectors = pos_2_vectors + pos_3_vectors
stacked = torch.stack(combined_vectors).float()
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(stacked, layer)

PCA fitted with 335 components
Cumulative variance for first 5 components: [0.11621695 0.16829965 0.20788695 0.24367225 0.27714902]

PCA Analysis Results:
Elbow point at component: 2
Dimensions for 70% variance: 45
Dimensions for 80% variance: 70
Dimensions for 90% variance: 116
Dimensions for 95% variance: 162


In [None]:
# Bundle everything needed to reload and plot this PCA. The fitted sklearn
# objects are pickled too, which is why the loading cell uses weights_only=False.
results = {
    'layer': layer,
    'roles': {
        'pos_2': pos_2_roles,
        'pos_3': pos_3_roles,
    },
    'vectors': {
        'pos_2': pos_2_vectors,
        'pos_3': pos_3_vectors,
    },
    'pca_transformed': pca_transformed,
    'variance_explained': variance_explained,
    'n_components': n_components,
    'pca': pca,
    'scaler': scaler,
}

# Make sure the output directory exists before saving.
pca_dir = f"/workspace/{dir}/pca"
os.makedirs(pca_dir, exist_ok=True)
torch.save(results, f"{pca_dir}/layer{layer}_pos23.pt")

## Plots

In [244]:
# config
# config
# NOTE: `type` and `dir` shadow Python builtins; the names are kept because the
# cells below refer to them.
type = "pos3"  # either "pos23" or "pos3"
dir = "roles_240"  # either "roles" or "roles_240"
layer = 34  # either layer 22 or 34

# Fail fast on a typo instead of on a missing file or an unset plot title later.
if type not in ("pos23", "pos3"):
    raise ValueError(f"type must be 'pos23' or 'pos3', got {type!r}")
if dir not in ("roles", "roles_240"):
    raise ValueError(f"dir must be 'roles' or 'roles_240', got {dir!r}")


In [245]:

# load in PCs
# The PCA bundle contains fitted sklearn objects (not just tensors), so it is
# loaded with weights_only=False. Default activations come from the same dataset.
pca_path = f"/workspace/{dir}/pca/layer{layer}_{type}.pt"
pca_results = torch.load(pca_path, weights_only=False)
default_vectors = torch.load(f"/workspace/{dir}/default_vectors.pt")


In [246]:

# One folder of exported HTML figures per (dataset, layer, vector type) combination.
output_dir = f"./results/{dir}/pca/layer{layer}_{type}"
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [247]:
# also calculate role labels for plotting
def get_role_labels(pca_results, type):
    """Build human-readable plot labels matching the row order of a PCA result.

    Role names are title-cased with underscores replaced by spaces.

    - type == "pos23": pos_2 labels (suffixed " (Somewhat RP)") followed by
      pos_3 labels (suffixed " (Fully RP)"), the row order the plots assume.
    - type == "pos3": plain pos_3 labels only.

    A missing 'pos_2' / 'pos_3' entry in pca_results['roles'] contributes no
    labels instead of raising NameError.

    Raises:
    - ValueError: if type is not "pos23" or "pos3".
    """
    if type not in ("pos23", "pos3"):
        raise ValueError(f"type must be 'pos23' or 'pos3', got {type!r}")

    roles = pca_results['roles']

    def pretty(names, suffix=""):
        """Title-case role names, replace underscores with spaces, append `suffix`."""
        return [f"{name.replace('_', ' ').title()}{suffix}" for name in names]

    pos_3_suffix = " (Fully RP)" if type == "pos23" else ""
    pos_3_labels = pretty(roles.get('pos_3', []), pos_3_suffix)
    if type == "pos3":
        return pos_3_labels
    return pretty(roles.get('pos_2', []), " (Somewhat RP)") + pos_3_labels

# Labels aligned with the rows of pca_results['pca_transformed'].
role_labels = get_role_labels(pca_results=pca_results, type=type)



In [248]:
# Spot-check the first and last ten labels.
for label_slice in (role_labels[:10], role_labels[-10:]):
    print(label_slice)

['Zeitgeist', 'Zealot', 'Writer', 'Wraith', 'Workaholic', 'Witness', 'Witch', 'Wind', 'Widow', 'Whale']
['Altruist', 'Alien', 'Advocate', 'Adolescent', 'Addict', 'Actor', 'Activist', 'Accountant', 'Absurdist', 'Aberration']


In [249]:
# get default activation and project into PCA space
# Project the 'default_1' (Assistant) activation at this layer into the saved PCA
# space. It goes through the same scaler that was fitted on the role vectors.
assistant_layer_activation = (
    default_vectors['activations']['default_1'][layer, :]
    .float()
    .numpy()
    .reshape(1, -1)
)
fitted_scaler = pca_results['scaler']
asst_scaled = fitted_scaler.transform(assistant_layer_activation)
fitted_pca = pca_results['pca']
asst_projected = fitted_pca.transform(asst_scaled)


In [250]:
assistant = True

# Plot the first 10 PCs, or every PC if the fit produced fewer than 10
# (indexing past the last component would raise IndexError).
n_pcs_to_plot = min(10, pca_results['pca_transformed'].shape[1])
# asst_projected is only referenced when the assistant overlay is requested.
plot_kwargs = {'assistant_activation': asst_projected[0]} if assistant else {}
file_suffix = "_assistant" if assistant else ""

for component in range(n_pcs_to_plot):
    fig = plot_pca_cosine_similarity(pca_results, role_labels, component, layer, dir, type,
                                     **plot_kwargs)
    fig.show()
    fig.write_html(f"{output_dir}/pc{component+1}{file_suffix}.html")

In [251]:
# 3-D view of the first three PCs, with or without the assistant point (toggle via `assistant`).
suffix_3d = "_assistant" if assistant else ""
extra_3d = {'assistant_activation': asst_projected[0]} if assistant else {}
fig_3d = plot_3d_pca(pca_results, role_labels, layer, dir, type, **extra_3d)
fig_3d.show()
fig_3d.write_html(f"{output_dir}/3d_pca{suffix_3d}.html")