# PCA on roles and traits

In [2]:
import os
import sys
import torch
import numpy as np
import pandas as pd

sys.path.append('.')
sys.path.append('..')

from utils.pca_utils import L2MeanScaler, MeanScaler, compute_pca, plot_variance_explained
#from plots import plot_pc


In [14]:
# roles or roles_240
model = "gemma-2-27b"
role_dir = f"{model}/roles_240" 
layer = 22

# 30 or 240
if role_dir.split("/")[0] == "roles":
    n_questions = 30
    n_prompt_types = 2
elif role_dir.split("/")[0] == "roles_240":
    n_questions = 240
    n_prompt_types = 1

# traits or traits_240
trait_dir = f"{model}/traits_240"


## Load vectors

In [15]:
# load all vectors 
role_vector_dir = f"/workspace/{role_dir}/vectors"

# iterate through each .pt file in the directory
role_vectors = {}
for file in os.listdir(role_vector_dir):
    if file.endswith(".pt"):
        role_vectors[file.replace(".pt", "")] = torch.load(os.path.join(role_vector_dir, file))

print(f"Found {len(role_vectors.keys())} roles with vectors")

Found 275 roles with vectors


In [16]:
print(role_vectors['graduate'].keys())

dict_keys(['pos_2', 'pos_3', 'pos_all'])


In [17]:
trait_vector_dir = f"/workspace/{trait_dir}/vectors"

trait_vectors = {}
for file in os.listdir(trait_vector_dir):
    if file.endswith(".pt"):
        trait_vectors[file.replace(".pt", "")] = torch.load(os.path.join(trait_vector_dir, file))

print(f"Found {len(trait_vectors.keys())} traits with vectors")

Found 240 traits with vectors


## Compute and save PCA 

In [18]:

pos_2_roles = []
pos_2_vectors = []
pos_3_roles = []
pos_3_vectors = []

# get the vectors keys for pos_2 and pos_3 for each role
for role, vector in role_vectors.items():
    if 'pos_2' in vector.keys():
        pos_2_roles.append(role)
        pos_2_vectors.append(vector['pos_2'])
    if 'pos_3' in vector.keys():
        pos_3_roles.append(role)
        pos_3_vectors.append(vector['pos_3'])

print(len(pos_2_roles))
print(len(pos_3_roles))


173
275


In [19]:
pos_70_traits = []
pos_70_vectors = []
pos_40_70_traits = []
pos_40_70_vectors = []

# need to filter out which ones have less than 10 pos_70 and pos_40_70
trait_stats = pd.read_csv(f'/root/git/persona-subspace/traits/results/{model}/pos.csv', index_col='trait')

for trait, vector in trait_vectors.items():
    if trait_stats.loc[trait]['pos_70_count'] >= 10:
        pos_70_traits.append(trait)
        pos_70_vectors.append(vector['pos_70'])
    if trait_stats.loc[trait]['pos_40_70_count'] >= 10:
        pos_40_70_traits.append(trait)
        pos_40_70_vectors.append(vector['pos_40_70'])

print(len(pos_70_traits))
print(len(pos_40_70_traits))

239
111


In [20]:
combined_vectors = pos_2_vectors + pos_3_vectors + pos_40_70_vectors + pos_70_vectors
print(len(combined_vectors))

798


In [10]:
def sorted_by_pc(pca_transformed, pc_index, labels):
    df = pd.DataFrame({
        "label": labels,
        "projection": pca_transformed[:, pc_index],
    })
    df_sorted = df.sort_values(by="projection", ascending=True)
    return df_sorted



In [22]:
scaler = MeanScaler()
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(torch.stack(combined_vectors).float(), layer, scaler)

PCA fitted with 748 components
Cumulative variance for first 5 components: [0.18248554 0.31359082 0.371287   0.4150008  0.45511097]

PCA Analysis Results:
Elbow point at component: 3
Dimensions for 70% variance: 21
Dimensions for 80% variance: 40
Dimensions for 90% variance: 92
Dimensions for 95% variance: 159


In [23]:
pc_df = sorted_by_pc(pca_transformed, 0, pos_2_roles + pos_3_roles + pos_40_70_traits + pos_70_traits)
top_roles = pc_df['label'].head(5).tolist()
bottom_roles = pc_df['label'].tail(5).tolist()
print(top_roles)
print(bottom_roles)

['flexible', 'cautious', 'inspirational', 'anthropocentric', 'conciliatory']
['sardonic', 'cruel', 'hostile', 'nonchalant', 'flippant']


In [23]:
to_flip = [1]
for pc in to_flip:
    pca.components_[pc - 1] *= -1
    pca_transformed[:, pc - 1] *= -1

In [24]:
results = {}
results['layer'] = layer
results['roles_or_traits'] = {
    'roles_pos_2': pos_2_roles,
    'roles_pos_3': pos_3_roles,
    'traits_pos_40_70': pos_40_70_traits,
    'traits_pos_70': pos_70_traits
}
results['vectors'] = {
    'roles_pos_2': pos_2_vectors,
    'roles_pos_3': pos_3_vectors,
    'traits_pos_40_70': pos_40_70_vectors,
    'traits_pos_70': pos_70_vectors
}
results['order'] = ['roles_pos_2', 'roles_pos_3', 'traits_pos_40_70', 'traits_pos_70']
results['pca_transformed'] = pca_transformed
results['variance_explained'] = variance_explained
results['n_components'] = n_components
results['pca'] = pca
results['scaler'] = scaler

pca_dir = f"/workspace/{model}/roles_traits/pca"
os.makedirs(pca_dir, exist_ok=True)
torch.save(results, f"{pca_dir}/layer{layer}_mean_roles_pos23_traits_pos40-100.pt")

In [21]:
scaler = L2MeanScaler()
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(torch.stack(combined_vectors).float(), layer, scaler)

PCA fitted with 798 components
Cumulative variance for first 5 components: [0.3663978  0.51538634 0.58227015 0.62599045 0.6644131 ]

PCA Analysis Results:
Elbow point at component: 2
Dimensions for 70% variance: 7
Dimensions for 80% variance: 11
Dimensions for 90% variance: 26
Dimensions for 95% variance: 54


In [22]:
pc_df = sorted_by_pc(pca_transformed, 0, pos_2_roles + pos_3_roles + pos_40_70_traits + pos_70_traits)
top_roles = pc_df['label'].head(5).tolist()
bottom_roles = pc_df['label'].tail(5).tolist()
print(top_roles)
print(bottom_roles)

['romantic', 'bohemian', 'dramatic', 'zealous', 'rhetorical']
['evaluator', 'researcher', 'validator', 'analyst', 'examiner']


In [24]:
results = {}
results['layer'] = layer
results['roles_or_traits'] = {
    'roles_pos_2': pos_2_roles,
    'roles_pos_3': pos_3_roles,
    'traits_pos_40_70': pos_40_70_traits,
    'traits_pos_70': pos_70_traits
}
results['vectors'] = {
    'roles_pos_2': pos_2_vectors,
    'roles_pos_3': pos_3_vectors,
    'traits_pos_40_70': pos_40_70_vectors,
    'traits_pos_70': pos_70_vectors
}
results['order'] = ['roles_pos_2', 'roles_pos_3', 'traits_pos_40_70', 'traits_pos_70']
results['pca_transformed'] = pca_transformed
results['variance_explained'] = variance_explained
results['n_components'] = n_components
results['pca'] = pca
results['scaler'] = scaler

pca_dir = f"/workspace/{model}/roles_traits/pca"
os.makedirs(pca_dir, exist_ok=True)
torch.save(results, f"{pca_dir}/layer{layer}_normalized_roles_pos23_traits_pos40-100.pt")

## Plots

In [6]:
layer = 22
# load in PCs
pca_results = torch.load(f"/workspace/roles_traits/pca/layer{layer}_roles_pos23_traits_pos40-100.pt", weights_only=False)
default_vectors = torch.load("/workspace/roles_240/default_vectors.pt")


In [7]:

output_dir = f"./results/pca/layer{layer}"
os.makedirs(output_dir, exist_ok=True)

In [8]:
# also calculate role labels for plotting
def get_role_labels_and_urls(pca_results):
    label_dict = {}
    url_dict = {}
    base_url = "https://lu-christina.github.io/persona-subspace/viewer/index.html"
    
    for key in pca_results['roles_or_traits'].keys():
        original_names = pca_results['roles_or_traits'][key]
        labels = [name.replace('_', ' ').title() for name in original_names]
        
        if key == 'roles_pos_2':
            label_dict[key] = [f"{label} (Somewhat RP)" for label in labels]
            url_dict[key] = [f"{base_url}?source=role_shared&role={name}" for name in original_names]
        elif key == 'roles_pos_3':
            label_dict[key] = [f"{label} (Fully RP)" for label in labels]
            url_dict[key] = [f"{base_url}?source=role_shared&role={name}" for name in original_names]
        elif key == 'traits_pos_40_70':
            label_dict[key] = [f"{label} (Somewhat)" for label in labels]
            url_dict[key] = [f"{base_url}?source=trait_shared&trait={name}" for name in original_names]
        elif key == 'traits_pos_70':
            label_dict[key] = [f"{label} (Fully)" for label in labels]
            url_dict[key] = [f"{base_url}?source=trait_shared&trait={name}" for name in original_names]
    
    return label_dict, url_dict

role_labels_dict, url_dict = get_role_labels_and_urls(pca_results)

# Combine in the correct display order for plotting
role_labels = (role_labels_dict['roles_pos_2'] + 
               role_labels_dict['roles_pos_3'] + 
               role_labels_dict['traits_pos_40_70'] + 
               role_labels_dict['traits_pos_70'])

# Combine URLs in the same order
role_urls = (url_dict['roles_pos_2'] + 
             url_dict['roles_pos_3'] + 
             url_dict['traits_pos_40_70'] + 
             url_dict['traits_pos_70'])

print(len(role_labels))

798


In [9]:
print(role_labels[:10])
print(role_labels[-10:])

['Writer (Somewhat RP)', 'Workaholic (Somewhat RP)', 'Witness (Somewhat RP)', 'Visionary (Somewhat RP)', 'Virus (Somewhat RP)', 'Virtuoso (Somewhat RP)', 'Vigilante (Somewhat RP)', 'Veterinarian (Somewhat RP)', 'Vegan (Somewhat RP)', 'Validator (Somewhat RP)']
['Analytical (Fully)', 'Altruistic (Fully)', 'Agreeable (Fully)', 'Adventurous (Fully)', 'Adaptable (Fully)', 'Acerbic (Fully)', 'Accommodating (Fully)', 'Accessible (Fully)', 'Abstract (Fully)', 'Absolutist (Fully)']


In [10]:
# get default activation and project into PCA space
assistant_layer_activation = default_vectors['activations']['default_1'][layer, :].float().numpy().reshape(1, -1)
asst_scaled = pca_results['scaler'].transform(assistant_layer_activation)
asst_projected = pca_results['pca'].transform(asst_scaled)


In [27]:
print(default_vectors['activations']['default_1'].shape)

torch.Size([46, 4608])


In [31]:
assistant = True

# JavaScript for handling clicks on plotly points
click_js = """
<script>
function setupClickHandlers() {
    const plotElements = document.querySelectorAll('.js-plotly-plot');
    
    plotElements.forEach(function(plotElement) {
        plotElement.on('plotly_click', function(data) {
            if (data.points && data.points.length > 0) {
                const point = data.points[0];
                if (point.customdata) {
                    window.open(point.customdata, '_blank');
                }
            }
        });
    });
}

// Setup handlers when page loads
document.addEventListener('DOMContentLoaded', setupClickHandlers);
// Also setup when plotly is done rendering
if (window.Plotly) {
    window.Plotly.newPlot = (function(originalNewPlot) {
        return function() {
            const result = originalNewPlot.apply(this, arguments);
            setTimeout(setupClickHandlers, 100);
            return result;
        };
    })(window.Plotly.newPlot);
}
</script>
"""

for i in range(10):
    component = i
    if assistant:
        fig = plot_pca_cosine_similarity(pca_results, role_labels, role_urls, component, layer, dir, assistant_activation=default_vectors['activations']['default_1'])
        fig.show()
        
        # Write HTML with click handling
        html_content = fig.to_html()
        html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
        with open(f"{output_dir}/pc{component+1}_cossim.html", 'w') as f:
            f.write(html_with_clicks)
    else:
        fig = plot_pca_cosine_similarity(pca_results, role_labels, role_urls, component, layer, dir)
        fig.show()
        
        # Write HTML with click handling
        html_content = fig.to_html()
        html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
        with open(f"{output_dir}/pc{component+1}_cossim.html", 'w') as f:
            f.write(html_with_clicks)

In [21]:
assistant = True

# JavaScript for handling clicks on plotly points
click_js = """
<script>
function setupClickHandlers() {
    const plotElements = document.querySelectorAll('.js-plotly-plot');
    
    plotElements.forEach(function(plotElement) {
        plotElement.on('plotly_click', function(data) {
            if (data.points && data.points.length > 0) {
                const point = data.points[0];
                if (point.customdata) {
                    window.open(point.customdata, '_blank');
                }
            }
        });
    });
}

// Setup handlers when page loads
document.addEventListener('DOMContentLoaded', setupClickHandlers);
// Also setup when plotly is done rendering
if (window.Plotly) {
    window.Plotly.newPlot = (function(originalNewPlot) {
        return function() {
            const result = originalNewPlot.apply(this, arguments);
            setTimeout(setupClickHandlers, 100);
            return result;
        };
    })(window.Plotly.newPlot);
}
</script>
"""

for i in range(10):
    component = i
    if assistant:
        fig = plot_pca_projection(pca_results, role_labels, role_urls, component, layer, dir, type, assistant_activation=asst_projected[0])
        fig.show()
        
        # Write HTML with click handling
        html_content = fig.to_html()
        html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
        with open(f"{output_dir}/pc{component+1}_projection.html", 'w') as f:
            f.write(html_with_clicks)
    else:
        fig = plot_pca_projection(pca_results, role_labels, role_urls, component, layer, dir, type)
        fig.show()
        
        # Write HTML with click handling
        html_content = fig.to_html()
        html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
        with open(f"{output_dir}/pc{component+1}_projection.html", 'w') as f:
            f.write(html_with_clicks)

In [12]:
assistant = True
if assistant:
    fig_3d = plot_3d_pca(pca_results, role_labels, role_urls, layer, dir, type, assistant_activation=asst_projected[0])
    fig_3d.show()
    
    # Write HTML with click handling for 3D plot
    html_content = fig_3d.to_html()
    html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
    with open(f"{output_dir}/3d_pca_assistant.html", 'w') as f:
        f.write(html_with_clicks)
else:
    fig_3d = plot_3d_pca(pca_results, role_labels, role_urls, layer, dir, type)
    fig_3d.show()
    
    # Write HTML with click handling for 3D plot
    html_content = fig_3d.to_html()
    html_with_clicks = html_content.replace('</body>', f'{click_js}</body>')
    with open(f"{output_dir}/3d_pca.html", 'w') as f:
        f.write(html_with_clicks)

In [13]:
# Test URL generation
print("First few role URLs:")
for i in range(5):
    print(f"{role_labels[i]} -> {role_urls[i]}")

print("\nFirst few trait URLs:")
trait_start_idx = len(role_labels_dict['roles_pos_2']) + len(role_labels_dict['roles_pos_3'])
for i in range(trait_start_idx, trait_start_idx + 5):
    print(f"{role_labels[i]} -> {role_urls[i]}")
    
print(f"\nTotal URLs generated: {len(role_urls)}")
print(f"Total labels: {len(role_labels)}")
print(f"URLs match labels: {len(role_urls) == len(role_labels)}")

First few role URLs:
Writer (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=writer
Workaholic (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=workaholic
Witness (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=witness
Visionary (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=visionary
Virus (Somewhat RP) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=role_shared&role=virus

First few trait URLs:
Visceral (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=trait_shared&trait=visceral
Utilitarian (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/index.html?source=trait_shared&trait=utilitarian
Universalist (Somewhat) -> https://lu-christina.github.io/persona-subspace/viewer/i