# Variance comparison

In [1]:
import os
import sys
import torch
import json
import numpy as np
import pandas as pd

sys.path.append('.')
sys.path.append('..')

from utils.pca_utils import *
from plots import *

## Configuration

In [2]:
# Configuration - Change these parameters for different models/datasets
model_name = "llama-3.3-70b"
base_dir = f"/workspace/{model_name}"
type = "roles_240"
dir = f"{base_dir}/{type}"

layer = 40

In [3]:
pca_results = torch.load(f"{dir}/pca/layer{layer}_pos23.pt", weights_only=False)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


## Variance across and within roles

### raw activations

In [4]:
vectors = torch.stack(pca_results['vectors']['pos_3'])[:, layer, :].float()
print(vectors.shape)

# compute variance across roles (rows) along hidden_dims
raw_across_var = torch.var(vectors, dim=0)
print(raw_across_var.shape)

torch.Size([275, 8192])
torch.Size([8192])


In [5]:
# load in scores
scores = {}
for file in os.listdir(f"{dir}/extract_scores"):
    if file.endswith('.json'):
        scores[file.replace('.json', '')] = json.load(open(f"{dir}/extract_scores/{file}"))

print(f"Loaded {len(scores)} scores")


Loaded 275 scores


In [94]:
# load in raw activations
activations = {}
for file in os.listdir(f"{dir}/response_activations"):
    if file.endswith('.pt') and 'default' not in file:
        # dict we should iterate over (1200 each)
        role_activations = []
        obj = torch.load(f"{dir}/response_activations/{file}")
        for key in obj:
            if scores[file.replace('.pt', '')][key] == 3:
                role_activations.append(obj[key])
        activations[file.replace('.pt', '')] = torch.stack(role_activations)
        



FileNotFoundError: [Errno 2] No such file or directory: '/workspace/llama-3.3-70b/roles_240/response_activations'

In [None]:
# compute variance within roles
raw_within_var = []
for file in activations:
    raw_within_var.append(torch.var(activations[file][:, layer, :], dim=0))

print(f"for {len(raw_within_var)} roles, shape is {raw_within_var[0].shape}")

for 275 roles, shape is torch.Size([4608])


In [None]:
avg_raw_within_var = torch.stack(raw_within_var).mean(dim=0)
print(avg_raw_within_var.shape)



torch.Size([4608])


In [None]:
# total variance ratio
raw_ratio = raw_across_var.sum() / avg_raw_within_var.sum()
print(f"ratio of raw_across_var / avg_raw_within_var is {raw_ratio}")

ratio of raw_across_var / avg_raw_within_var is 0.4312536120414734


In [None]:
raw_across_var_normalized = torch.var(F.normalize(vectors, p=2, dim=1), dim=0)
print(raw_across_var_normalized.shape)



torch.Size([4608])


In [None]:
raw_within_var_normalized = []
for file in activations:
    raw_within_var_normalized.append(torch.var(F.normalize(activations[file][:, layer, :], p=2, dim=1), dim=0))

print(f"for {len(raw_within_var_normalized)} roles, shape is {raw_within_var_normalized[0].shape}")
avg_raw_within_var_normalized = torch.stack(raw_within_var_normalized).mean(dim=0)
print(avg_raw_within_var_normalized.shape)


for 275 roles, shape is torch.Size([4608])
torch.Size([4608])


In [None]:
raw_ratio_normalized = raw_across_var_normalized.mean() / avg_raw_within_var_normalized.mean()
print(f"ratio of raw_across_var_normalized / avg_raw_within_var_normalized is {raw_ratio_normalized}")


ratio of raw_across_var_normalized / avg_raw_within_var_normalized is 0.3623166084289551


### in PC space

In [6]:
# get transformed role vectors
pca_across_var = np.var(pca_results['pca_transformed'][:275], axis=0)
print(pca_across_var.shape)


(377,)


In [7]:
print(activations['absurdist'].shape)

NameError: name 'activations' is not defined

In [None]:
pca_within_var = []
pc1_within_var = []
for role in activations:
    role_scaled = pca_results['scaler'].transform(activations[role][:, layer, :].float().numpy())
    role_pca = pca_results['pca'].transform(role_scaled)
    pca_within_var.append(np.var(role_pca, axis=0))
    pc1_within_var.append(np.var(role_pca[:, 0]))

print(f"for {len(pca_within_var)} roles, shape is {pca_within_var[0].shape}")

for 275 roles, shape is (448,)


In [None]:
mean_pca_within_var = np.array(pca_within_var).mean(axis=0)
print(mean_pca_within_var.shape)


(448,)


In [None]:
pca_ratio = pca_across_var.mean() / mean_pca_within_var.mean()
print(f"ratio of pca_across_var / mean_pca_within_var is {pca_ratio}")

ratio of pca_across_var / mean_pca_within_var is 0.31411965474791115


### pc1 variance only

In [None]:
pc1_across_var = np.var(pca_results['pca_transformed'][:275, 0])
print(pc1_across_var)


829.9696483722173


In [None]:
mean_pc1_within_var = np.array(pc1_within_var).mean()
print(mean_pc1_within_var)

pc1_ratio = pc1_across_var / mean_pc1_within_var
print(f"ratio of pc1_across_var / mean_pc1_within_var is {pc1_ratio}")


291.39803214992793
ratio of pc1_across_var / mean_pc1_within_var is 2.848233539014387


### All PCs variance ratios

In [None]:
# Compute ratio for all PCs using existing variables
all_pc_ratios = pca_across_var / mean_pca_within_var
print(f"Computed ratios for {len(all_pc_ratios)} PCs")
print(f"PC1 ratio: {all_pc_ratios[0]:.4f}")
print(f"Mean ratio (all PCs): {all_pc_ratios.mean():.4f}")
print(f"Mean ratio (PC2-10): {all_pc_ratios[1:10].mean():.4f}")
print(f"Max ratio: {all_pc_ratios.max():.4f} (PC{all_pc_ratios.argmax()+1})")

Computed ratios for 448 PCs
PC1 ratio: 2.8482
Mean ratio (all PCs): 0.0982
Mean ratio (PC2-10): 1.0490
Max ratio: 2.8482 (PC1)


In [None]:
import plotly.graph_objects as go

# Create line plot of PC ratios
fig = go.Figure()

# Add line trace for all PC ratios
fig.add_trace(go.Scatter(
    x=np.arange(1, len(all_pc_ratios) + 1),
    y=all_pc_ratios,
    mode='lines',
    name='PC Ratio',
    line=dict(color='steelblue', width=2)
))


# Add horizontal reference line at ratio=1
fig.add_hline(y=1.0, line_dash="dash", line_color="gray", 
              annotation_text="ratio=1", annotation_position="right")

# Update layout
fig.update_layout(
    title={
        'text': "Variance Ratio (Across-Role / Within-Role) for Role PCs",
        'subtitle': {
            'text': f"{model_name.replace('-', ' ').title()}, Layer {layer}",
        }
    },
    xaxis_title="Principal Component",
    yaxis_title="Variance Ratio",
    width=800,
    height=500,
    hovermode='x unified'
)

fig.update_xaxes(range=[0.5, 10.5], tickvals=np.arange(1, 11))

fig.show()
fig.write_html(f"/root/git/plots/{model_name}/roles/variance_ratios.html")

## Conditional variance of role vectors based on distance from Assistant

In [None]:
role_vectors = torch.stack(pca_results['vectors']['pos_2'] + pca_results['vectors']['pos_3'])[:, layer, :]
print(role_vectors.shape)

pc1 = pca_results['pca_transformed'][:, 0]

torch.Size([448, 4608])


### Conditional variance in raw activation space

In [None]:
from scipy.stats import pearsonr

# Two-group comparison: Assistant-like vs Roleplay
# Using PC1 threshold of -25 (same as in 9_cone.ipynb)


if model_name == "gemma-2-27b":
    threshold = 25
    assistant_mask = pc1 > threshold
    roleplay_mask = pc1 <= threshold
else:
    threshold = -25
    assistant_mask = pc1 < threshold
    roleplay_mask = pc1 >= threshold

# Compute variance of raw activations for each group
# role_vectors shape: [448, 4608]
var_assistant_raw = torch.var(role_vectors[assistant_mask], dim=0).mean().item()
var_roleplay_raw = torch.var(role_vectors[roleplay_mask], dim=0).mean().item()

var_ratio_raw = var_assistant_raw / var_roleplay_raw

print("=" * 60)
print("RAW ACTIVATION SPACE: Two-Group Comparison")
print("=" * 60)
print(f"PC1 threshold: {threshold}")
print(f"Assistant-like roles (PC1 < {threshold}): {assistant_mask.sum()} samples")
print(f"Roleplay roles (PC1 >= {threshold}): {roleplay_mask.sum()} samples")
print(f"\nMean variance (Assistant-like): {var_assistant_raw:.6f}")
print(f"Mean variance (Roleplay): {var_roleplay_raw:.6f}")
print(f"Variance ratio (Assistant/Roleplay): {var_ratio_raw:.4f} ({var_ratio_raw*100:.2f}%)")
print("=" * 60)

RAW ACTIVATION SPACE: Two-Group Comparison
PC1 threshold: 25
Assistant-like roles (PC1 < 25): 128 samples
Roleplay roles (PC1 >= 25): 320 samples

Mean variance (Assistant-like): 54.750000
Mean variance (Roleplay): 151.000000
Variance ratio (Assistant/Roleplay): 0.3626 (36.26%)


In [None]:
# Project out PC1 from raw activations
# Get PC1 direction from PCA
pc1_direction = torch.from_numpy(pca_results['pca'].components_[0]).float()

# Project role_vectors onto PC1 and subtract
# Formula: projection = (v · u) * u, where u is the unit vector (PC1 direction)
pc1_loadings = (role_vectors.float() @ pc1_direction).unsqueeze(1)  # Shape: [448, 1]
pc1_projections = pc1_loadings * pc1_direction.unsqueeze(0)  # Shape: [448, 4608]
role_vectors_pc1_removed = role_vectors - pc1_projections

# Compute variance with PC1 projected out
var_assistant_raw_no_pc1 = torch.var(role_vectors_pc1_removed[assistant_mask], dim=0).mean().item()
var_roleplay_raw_no_pc1 = torch.var(role_vectors_pc1_removed[roleplay_mask], dim=0).mean().item()

var_ratio_raw_no_pc1 = var_assistant_raw_no_pc1 / var_roleplay_raw_no_pc1

print("\n" + "=" * 60)
print("RAW ACTIVATION SPACE (PC1 projected out): Two-Group Comparison")
print("=" * 60)
print(f"PC1 threshold: {threshold}")
print(f"Assistant-like roles (PC1 < {threshold}): {assistant_mask.sum()} samples")
print(f"Roleplay roles (PC1 >= {threshold}): {roleplay_mask.sum()} samples")
print(f"\nMean variance (Assistant-like, PC1 removed): {var_assistant_raw_no_pc1:.6f}")
print(f"Mean variance (Roleplay, PC1 removed): {var_roleplay_raw_no_pc1:.6f}")
print(f"Variance ratio (Assistant/Roleplay): {var_ratio_raw_no_pc1:.4f} ({var_ratio_raw_no_pc1*100:.2f}%)")
print(f"\nThis is analogous to the PC2-10 analysis in PC space.")
print("=" * 60)


RAW ACTIVATION SPACE (PC1 projected out): Two-Group Comparison
PC1 threshold: 25
Assistant-like roles (PC1 < 25): 128 samples
Roleplay roles (PC1 >= 25): 320 samples

Mean variance (Assistant-like, PC1 removed): 54.452919
Mean variance (Roleplay, PC1 removed): 138.992462
Variance ratio (Assistant/Roleplay): 0.3918 (39.18%)

This is analogous to the PC2-10 analysis in PC space.


In [None]:
# Quintile analysis
n_quintiles = 5
quintile_edges = np.quantile(pc1, np.linspace(0, 1, n_quintiles + 1))
quintile_variances = []
quintile_variances_no_pc1 = []
quintile_sizes = []

print("\n" + "=" * 60)
print("RAW ACTIVATION SPACE: Quintile Analysis")
print("=" * 60)

for i in range(n_quintiles):
    if i == 0:
        mask = (pc1 >= quintile_edges[i]) & (pc1 <= quintile_edges[i + 1])
    else:
        mask = (pc1 > quintile_edges[i]) & (pc1 <= quintile_edges[i + 1])
    
    quintile_var = torch.var(role_vectors[mask], dim=0).mean().item()
    quintile_var_no_pc1 = torch.var(role_vectors_pc1_removed[mask], dim=0).mean().item()
    quintile_variances.append(quintile_var)
    quintile_variances_no_pc1.append(quintile_var_no_pc1)
    quintile_sizes.append(mask.sum())
    
    print(f"\nQuintile {i+1}: PC1 ∈ [{quintile_edges[i]:.2f}, {quintile_edges[i+1]:.2f}]")
    print(f"  Sample size: {mask.sum()}")
    print(f"  Mean variance (full): {quintile_var:.6f}")
    print(f"  Mean variance (PC1 removed): {quintile_var_no_pc1:.6f}")

# Calculate ratios between first and last quintile
if model_name == "gemma-2-27b":
    quintile_ratio = quintile_variances[0] / quintile_variances[-1]
    quintile_ratio_no_pc1 = quintile_variances_no_pc1[0] / quintile_variances_no_pc1[-1]
else:
    quintile_ratio = quintile_variances[-1] / quintile_variances[0]
    quintile_ratio_no_pc1 = quintile_variances_no_pc1[-1] / quintile_variances_no_pc1[0]

print("\n" + "-" * 60)
print(f"Variance ratio (Last/First quintile, full): {quintile_ratio:.2f}x")
print(f"Variance ratio (Last/First quintile, PC1 removed): {quintile_ratio_no_pc1:.2f}x")
print("=" * 60)


RAW ACTIVATION SPACE: Quintile Analysis

Quintile 1: PC1 ∈ [-86.47, -33.36]
  Sample size: 90
  Mean variance (full): 136.000000
  Mean variance (PC1 removed): 134.244888

Quintile 2: PC1 ∈ [-33.36, -1.06]
  Sample size: 89
  Mean variance (full): 98.000000
  Mean variance (PC1 removed): 96.719231

Quintile 3: PC1 ∈ [-1.06, 18.72]
  Sample size: 90
  Mean variance (full): 69.000000
  Mean variance (PC1 removed): 68.681099

Quintile 4: PC1 ∈ [18.72, 28.17]
  Sample size: 89
  Mean variance (full): 53.750000
  Mean variance (PC1 removed): 53.688721

Quintile 5: PC1 ∈ [28.17, 38.88]
  Sample size: 90
  Mean variance (full): 53.250000
  Mean variance (PC1 removed): 53.227360

------------------------------------------------------------
Variance ratio (Last/First quintile, full): 2.55x
Variance ratio (Last/First quintile, PC1 removed): 2.52x


In [None]:
# Distance from center correlation
# Compute mean of raw activations
role_vectors_mean = role_vectors.mean(dim=0)
role_vectors_pc1_removed_mean = role_vectors_pc1_removed.mean(dim=0)

# Compute L2 distance from mean for each role
distances_raw = torch.norm(role_vectors.float() - role_vectors_mean, p=2, dim=1).numpy()
distances_raw_no_pc1 = torch.norm(role_vectors_pc1_removed - role_vectors_pc1_removed_mean, p=2, dim=1).numpy()

# Calculate correlation with PC1
correlation_raw, p_value_raw = pearsonr(pc1, distances_raw)
correlation_raw_no_pc1, p_value_raw_no_pc1 = pearsonr(pc1, distances_raw_no_pc1)

print("\n" + "=" * 60)
print("RAW ACTIVATION SPACE: Distance from Center Correlation")
print("=" * 60)
print(f"Correlation between PC1 and L2 distance from mean (full):")
print(f"  r = {correlation_raw:.4f}")
print(f"  p-value = {p_value_raw:.3e}")
if p_value_raw < 0.001:
    print(f"  Highly significant (p < 0.001)")
elif p_value_raw < 0.05:
    print(f"  Significant (p < 0.05)")

print(f"\nCorrelation between PC1 and L2 distance from mean (PC1 removed):")
print(f"  r = {correlation_raw_no_pc1:.4f}")
print(f"  p-value = {p_value_raw_no_pc1:.3e}")
if p_value_raw_no_pc1 < 0.001:
    print(f"  Highly significant (p < 0.001)")
elif p_value_raw_no_pc1 < 0.05:
    print(f"  Significant (p < 0.05)")
print("=" * 60)


RAW ACTIVATION SPACE: Distance from Center Correlation
Correlation between PC1 and L2 distance from mean (full):
  r = -0.5635
  p-value = 6.652e-39
  Highly significant (p < 0.001)

Correlation between PC1 and L2 distance from mean (PC1 removed):
  r = -0.5441
  p-value = 7.002e-36
  Highly significant (p < 0.001)


### Per-PC analysis: Correlation between each PC and distance in remaining PC space

In [None]:
# For each of the top 10 PCs, calculate:
# 1. The correlation between that PC and distance from center in all OTHER PCs
# 2. This tells us if the pattern we see with PC1 generalizes to other PCs

from scipy.stats import pearsonr

n_pcs_to_analyze = 10
pca_transformed = pca_results['pca_transformed']

print("=" * 70)
print("Correlation between each PC and distance in remaining PC space")
print("=" * 70)

correlations = []
p_values = []

for pc_idx in range(n_pcs_to_analyze):
    # Get the PC values
    pc_values = pca_transformed[:, pc_idx]
    
    # Get all other PCs (excluding current PC)
    other_pcs = np.delete(pca_transformed, pc_idx, axis=1)
    
    # Calculate distance from center in the remaining PC space
    other_pcs_mean = other_pcs.mean(axis=0)
    distances = np.linalg.norm(other_pcs - other_pcs_mean, axis=1)
    
    # Calculate correlation
    corr, p_val = pearsonr(pc_values, distances)
    correlations.append(corr)
    p_values.append(p_val)
    
    # Print results
    sig_marker = "***" if p_val < 0.001 else "**" if p_val < 0.01 else "*" if p_val < 0.05 else ""
    print(f"PC{pc_idx+1:2d}: r = {corr:7.4f}, p = {p_val:.3e} {sig_marker}")

print("=" * 70)
print(f"\nPC1 correlation: {correlations[0]:.4f}")
print(f"Mean correlation (PC2-10): {np.mean(correlations[1:]):.4f}")
print("=" * 70)

Correlation between each PC and distance in remaining PC space
PC 1: r = -0.6576, p = 7.913e-57 ***
PC 2: r =  0.2100, p = 7.342e-06 ***
PC 3: r =  0.3147, p = 9.348e-12 ***
PC 4: r =  0.0543, p = 2.516e-01 
PC 5: r = -0.2252, p = 1.468e-06 ***
PC 6: r = -0.2385, p = 3.258e-07 ***
PC 7: r =  0.0882, p = 6.229e-02 
PC 8: r =  0.0283, p = 5.500e-01 
PC 9: r = -0.0532, p = 2.616e-01 
PC10: r =  0.0580, p = 2.208e-01 

PC1 correlation: -0.6576
Mean correlation (PC2-10): 0.0263


### Conditional variance in PC2-10 based on position along each PC

This analysis shows whether the pattern of "extreme positions → high variance in other PCs" is unique to PC1 or generalizes to other PCs.

In [None]:
# For each PC, split roles into two groups (high/low) and compute variance in PC2-10 (excluding that PC)
# This tests if extreme positions on PC_i lead to high variance in other PCs

n_pcs_to_test = 10
pca_transformed = pca_results['pca_transformed']

print("=" * 80)
print("Conditional Variance in PC2-10 based on position along each PC")
print("=" * 80)
print("For each PC, we split roles by median and compute variance in PC2-10 (excluding that PC)")
print("-" * 80)

variance_ratios = []

for pc_idx in range(n_pcs_to_test):
    # Split by median on this PC
    pc_values = pca_transformed[:, pc_idx]
    median_val = np.median(pc_values)
    high_mask = pc_values > median_val
    low_mask = pc_values <= median_val
    
    # Get PC2-10, excluding current PC if it's in that range
    if pc_idx == 0:
        # For PC1, we want variance in PC2-10
        other_pcs = pca_transformed[:, 1:10]
    elif 1 <= pc_idx < 10:
        # For PC2-9, exclude that PC from PC2-10
        pc_indices = [i for i in range(1, 10) if i != pc_idx]
        other_pcs = pca_transformed[:, pc_indices]
    else:
        # For PC10, use PC2-9
        other_pcs = pca_transformed[:, 1:10]
    
    # Compute variance for each group
    var_high = np.var(other_pcs[high_mask], axis=0).mean()
    var_low = np.var(other_pcs[low_mask], axis=0).mean()
    
    ratio = max(var_high, var_low) / min(var_high, var_low)
    variance_ratios.append(ratio)
    
    print(f"PC{pc_idx+1:2d}: High={high_mask.sum():3d} samples, Low={low_mask.sum():3d} samples")
    print(f"      Var(high) = {var_high:8.3f}, Var(low) = {var_low:8.3f}, Ratio = {ratio:.3f}")

print("=" * 80)
print(f"\nSummary:")
print(f"  PC1 variance ratio: {variance_ratios[0]:.3f}")
print(f"  Mean variance ratio for PC2-10: {np.mean(variance_ratios[1:]):.3f}")
print(f"  Max variance ratio (excluding PC1): {np.max(variance_ratios[1:]):.3f} (PC{np.argmax(variance_ratios[1:])+2})")
print("\n  → Shows whether PC1 is unique in having high-variance 'other dimensions' for extreme positions")
print("=" * 80)

Conditional Variance in PC2-10 based on position along each PC
For each PC, we split roles by median and compute variance in PC2-10 (excluding that PC)
--------------------------------------------------------------------------------
PC 1: High=224 samples, Low=224 samples
      Var(high) =   99.655, Var(low) =  303.220, Ratio = 3.043
PC 2: High=224 samples, Low=224 samples
      Var(high) =  230.203, Var(low) =  130.662, Ratio = 1.762
PC 3: High=224 samples, Low=224 samples
      Var(high) =  211.142, Var(low) =  164.985, Ratio = 1.280
PC 4: High=224 samples, Low=224 samples
      Var(high) =  179.438, Var(low) =  214.903, Ratio = 1.198
PC 5: High=224 samples, Low=224 samples
      Var(high) =  157.295, Var(low) =  247.421, Ratio = 1.573
PC 6: High=224 samples, Low=224 samples
      Var(high) =  195.218, Var(low) =  209.086, Ratio = 1.071
PC 7: High=224 samples, Low=224 samples
      Var(high) =  234.147, Var(low) =  181.103, Ratio = 1.293
PC 8: High=224 samples, Low=224 samples
      

In [None]:
# Create role labels from pca_results
def get_role_labels_from_pca(pca_results):
    labels = []
    if 'pos_2' in pca_results['roles'].keys():
        pos_2_roles = [role.replace('_', ' ').title() for role in pca_results['roles']['pos_2']]
        labels.extend(pos_2_roles)
    if 'pos_3' in pca_results['roles'].keys():
        pos_3_roles = [role.replace('_', ' ').title() for role in pca_results['roles']['pos_3']]
        labels.extend(pos_3_roles)
    return labels

role_labels = get_role_labels_from_pca(pca_results)
print(f"Total roles: {len(role_labels)}")
print(f"pca_transformed shape: {pca_transformed.shape}")

Total roles: 448
pca_transformed shape: (448, 448)


In [None]:
# Show top/bottom roles for each PC
n_pcs_to_show = 10  # Show first 5 PCs
n_roles_to_show = 5  # Show top/bottom 5 roles

print("=" * 80)
print("Top and Bottom Roles for Each PC")
print("=" * 80)

for pc_idx in range(n_pcs_to_show):
    pc_values = pca_transformed[:, pc_idx]
    
    # Get indices of top and bottom roles
    top_indices = np.argsort(pc_values)[-n_roles_to_show:][::-1]
    bottom_indices = np.argsort(pc_values)[:n_roles_to_show]
    
    print(f"\nPC{pc_idx+1}:")
    print(f"  Top {n_roles_to_show} (highest loadings):")
    for i, idx in enumerate(top_indices):
        print(f"    {i+1}. {role_labels[idx]:30s} (PC{pc_idx+1} = {pc_values[idx]:7.2f})")
    
    print(f"  Bottom {n_roles_to_show} (lowest loadings):")
    for i, idx in enumerate(bottom_indices):
        print(f"    {i+1}. {role_labels[idx]:30s} (PC{pc_idx+1} = {pc_values[idx]:7.2f})")

print("=" * 80)

Top and Bottom Roles for Each PC

PC1:
  Top 5 (highest loadings):
    1. Assistant                      (PC1 =   38.88)
    2. Screener                       (PC1 =   38.71)
    3. Doctor                         (PC1 =   37.75)
    4. Analyst                        (PC1 =   36.68)
    5. Researcher                     (PC1 =   36.55)
  Bottom 5 (lowest loadings):
    1. Caveman                        (PC1 =  -86.47)
    2. Eldritch                       (PC1 =  -79.26)
    3. Leviathan                      (PC1 =  -79.20)
    4. Void                           (PC1 =  -74.09)
    5. Aberration                     (PC1 =  -69.96)

PC2:
  Top 5 (highest loadings):
    1. Procrastinator                 (PC2 =   89.66)
    2. Teenager                       (PC2 =   79.36)
    3. Adolescent                     (PC2 =   77.52)
    4. Toddler                        (PC2 =   61.69)
    5. Gossip                         (PC2 =   52.68)
  Bottom 5 (lowest loadings):
    1. Eldritch              

## Default loading along each PC

In [None]:
# get default activation and project
default_vectors = torch.load(f"{base_dir}/roles_240/default_vectors.pt")
assistant_layer_activation = default_vectors['activations']['default_1'][layer, :].float().reshape(1, -1)

asst_scaled = pca_results['scaler'].transform(assistant_layer_activation.numpy())
asst_projected = pca_results['pca'].transform(asst_scaled)
print(f"Assistant projection shape: {asst_projected.shape}")

# Compare each PC loading with the min, max loading of that PC across all roles
n_pcs = 10  # or however many you want to analyze
pca_transformed = pca_results['pca_transformed']

print("\n" + "=" * 80)
print("Assistant (default) position relative to role distribution on each PC")
print("=" * 80)

for pc_idx in range(n_pcs):
    # Get assistant's loading on this PC
    asst_loading = asst_projected[0, pc_idx]
    
    # Get all role loadings on this PC
    all_loadings = pca_transformed[:, pc_idx]
    min_loading = all_loadings.min()
    max_loading = all_loadings.max()
    
    # Calculate relative position (0 = at min, 1 = at max)
    if max_loading != min_loading:
        relative_position = (asst_loading - min_loading) / (max_loading - min_loading)
    else:
        relative_position = 0.5
    
    # Distance to nearest boundary (normalized)
    dist_to_min = (asst_loading - min_loading) / (max_loading - min_loading)
    dist_to_max = (max_loading - asst_loading) / (max_loading - min_loading)
    min_boundary_dist = min(dist_to_min, dist_to_max)
    
    print(f"\nPC{pc_idx+1}:")
    print(f"  Range: [{min_loading:8.2f}, {max_loading:8.2f}]")
    print(f"  Assistant: {asst_loading:8.2f}")
    print(f"  Relative position: {relative_position:.3f} (0=min, 1=max)")
    print(f"  Min boundary distance: {min_boundary_dist:.3f}")
    
    # Interpret position
    if relative_position < 0.25:
        position_desc = "near minimum"
    elif relative_position < 0.45:
        position_desc = "below center"
    elif relative_position < 0.55:
        position_desc = "centered"
    elif relative_position < 0.75:
        position_desc = "above center"
    else:
        position_desc = "near maximum"
    print(f"  Position: {position_desc}")

print("=" * 80)

Assistant projection shape: (1, 448)

Assistant (default) position relative to role distribution on each PC

PC1:
  Range: [  -86.47,    38.88]
  Assistant:    34.72
  Relative position: 0.967 (0=min, 1=max)
  Min boundary distance: 0.033
  Position: near maximum

PC2:
  Range: [  -57.07,    89.66]
  Assistant:    -3.73
  Relative position: 0.364 (0=min, 1=max)
  Min boundary distance: 0.364
  Position: below center

PC3:
  Range: [  -34.64,   109.26]
  Assistant:     1.94
  Relative position: 0.254 (0=min, 1=max)
  Min boundary distance: 0.254
  Position: below center

PC4:
  Range: [  -45.79,    65.85]
  Assistant:    -0.14
  Relative position: 0.409 (0=min, 1=max)
  Min boundary distance: 0.409
  Position: below center

PC5:
  Range: [  -38.57,    60.55]
  Assistant:    -9.77
  Relative position: 0.291 (0=min, 1=max)
  Min boundary distance: 0.291
  Position: below center

PC6:
  Range: [ -106.46,    26.33]
  Assistant:    -1.42
  Relative position: 0.791 (0=min, 1=max)
  Min bounda

## Overall activation variance captured

In [None]:
# load in the mean_activations.pt and the role/trait projections...

act_dir = f"/workspace/{model_name}/dataset_activations/lmsys_10000"

chat_raw = torch.load(f"{act_dir}/mean_activations.pt")
chat_roles = torch.load(f"{act_dir}/roles_projections.pt", weights_only=False)
chat_traits = torch.load(f"{act_dir}/traits_projections.pt", weights_only=False)

print(chat_raw.keys())
print(chat_roles.keys())
print(chat_traits.keys())


dict_keys(['activations', 'target_layer'])
dict_keys(['projected', 'explained_variance_ratio', 'pca_n_components', 'pca_explained_variance_from_fit', 'target_layer', 'pca_config_path'])
dict_keys(['projected', 'explained_variance_ratio', 'pca_n_components', 'pca_explained_variance_from_fit', 'target_layer', 'pca_config_path'])


In [None]:
# Get the raw activations
raw_activations = chat_raw['activations'][:, layer, :].float()
print(f"Raw activations shape: {raw_activations.shape}")

# Calculate total variance in raw activations
total_var = torch.var(raw_activations, dim=0).sum().item()
print(f"\nTotal variance in raw activations: {total_var:.2f}")

# For roles: reconstruct from PCA space back to raw space
roles_projected = chat_roles['projected']  # Shape: [18950, 463]
# Inverse transform: unstandardize and inverse PCA
roles_reconstructed = pca_results['pca'].inverse_transform(roles_projected)  # This gives standardized features
roles_reconstructed = pca_results['scaler'].inverse_transform(roles_reconstructed)  # Unstandardize
roles_reconstructed = torch.from_numpy(roles_reconstructed).float()

# Calculate variance in reconstructed activations
roles_var = torch.var(roles_reconstructed, dim=0).sum().item()
roles_variance_explained = roles_var / total_var

print(f"\nRole subspace:")
print(f"  Variance captured: {roles_var:.2f}")
print(f"  Variance explained: {roles_variance_explained:.4f} ({roles_variance_explained*100:.2f}%)")

# For traits: load trait PCA results and do the same
trait_pca_results = torch.load(f"{base_dir}/traits_240/pca/layer{layer}_pos-neg50.pt", weights_only=False)
traits_projected = chat_traits['projected']  # Shape: [18950, 240]

traits_reconstructed = trait_pca_results['pca'].inverse_transform(traits_projected)
traits_reconstructed = trait_pca_results['scaler'].inverse_transform(traits_reconstructed)
traits_reconstructed = torch.from_numpy(traits_reconstructed).float()

# Calculate variance in reconstructed activations
traits_var = torch.var(traits_reconstructed, dim=0).sum().item()
traits_variance_explained = traits_var / total_var

print(f"\nTrait subspace:")
print(f"  Variance captured: {traits_var:.2f}")
print(f"  Variance explained: {traits_variance_explained:.4f} ({traits_variance_explained*100:.2f}%)")

# Summary
print("\n" + "=" * 60)
print("Summary: Variance Explained by Subspaces")
print("=" * 60)
print(f"Role subspace:  {roles_variance_explained*100:.2f}%")
print(f"Trait subspace: {traits_variance_explained*100:.2f}%")
print("=" * 60)

Raw activations shape: torch.Size([18777, 4608])

Total variance in raw activations: 7442113.00



Role subspace:
  Variance captured: 1283001.00
  Variance explained: 0.1724 (17.24%)



Trying to unpickle estimator PCA from version 1.7.0 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Trying to unpickle estimator StandardScaler from version 1.7.0 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations




Trait subspace:
  Variance captured: 1234600.38
  Variance explained: 0.1659 (16.59%)

Summary: Variance Explained by Subspaces
Role subspace:  17.24%
Trait subspace: 16.59%


### Conditional variance in LMSYS chat samples based on PC1

In [None]:
# Use the already-projected LMSYS chat samples to get PC1 values
# roles_projected already contains the PC scores for all samples
lmsys_pc1 = roles_projected[:, 0]

# Use the same threshold as for role vectors
# Create assistant-like and roleplay masks based on PC1
if model_name == "gemma-2-27b":
    lmsys_assistant_mask = lmsys_pc1 > threshold
    lmsys_roleplay_mask = lmsys_pc1 <= threshold
else:
    lmsys_assistant_mask = lmsys_pc1 < threshold
    lmsys_roleplay_mask = lmsys_pc1 >= threshold

# Compute variance for each group (full raw activation space)
var_lmsys_assistant = torch.var(raw_activations[lmsys_assistant_mask], dim=0).mean().item()
var_lmsys_roleplay = torch.var(raw_activations[lmsys_roleplay_mask], dim=0).mean().item()
var_ratio_lmsys = var_lmsys_assistant / var_lmsys_roleplay

# Project out PC1 from raw activations
pc1_direction = torch.from_numpy(pca_results['pca'].components_[0]).float()
pc1_loadings = (raw_activations @ pc1_direction).unsqueeze(1)
pc1_projections = pc1_loadings * pc1_direction.unsqueeze(0)
raw_activations_pc1_removed = raw_activations - pc1_projections

# Compute variance with PC1 projected out
var_lmsys_assistant_no_pc1 = torch.var(raw_activations_pc1_removed[lmsys_assistant_mask], dim=0).mean().item()
var_lmsys_roleplay_no_pc1 = torch.var(raw_activations_pc1_removed[lmsys_roleplay_mask], dim=0).mean().item()
var_ratio_lmsys_no_pc1 = var_lmsys_assistant_no_pc1 / var_lmsys_roleplay_no_pc1

# Print results
print("=" * 60)
print("LMSYS CHAT DATASET: Conditional Variance Analysis")
print("=" * 60)
print(f"PC1 threshold: {threshold}")
print(f"Assistant-like samples (PC1 {'>' if model_name == 'gemma-2-27b' else '<'} {threshold}): {lmsys_assistant_mask.sum()} samples")
print(f"Roleplay samples (PC1 {'<=' if model_name == 'gemma-2-27b' else '>='} {threshold}): {lmsys_roleplay_mask.sum()} samples")
print(f"\nFull activation space:")
print(f"  Variance (Assistant-like): {var_lmsys_assistant:.6f}")
print(f"  Variance (Roleplay): {var_lmsys_roleplay:.6f}")
print(f"  Ratio (Assistant/Roleplay): {var_ratio_lmsys:.4f} ({var_ratio_lmsys*100:.2f}%)")
print(f"\nPC1 projected out:")
print(f"  Variance (Assistant-like): {var_lmsys_assistant_no_pc1:.6f}")
print(f"  Variance (Roleplay): {var_lmsys_roleplay_no_pc1:.6f}")
print(f"  Ratio (Assistant/Roleplay): {var_ratio_lmsys_no_pc1:.4f} ({var_ratio_lmsys_no_pc1*100:.2f}%)")
print("=" * 60)

LMSYS CHAT DATASET: Conditional Variance Analysis
PC1 threshold: 25
Assistant-like samples (PC1 > 25): 36 samples
Roleplay samples (PC1 <= 25): 18741 samples

Full activation space:
  Variance (Assistant-like): 873.245911
  Variance (Roleplay): 1615.255005
  Ratio (Assistant/Roleplay): 0.5406 (54.06%)

PC1 projected out:
  Variance (Assistant-like): 872.418762
  Variance (Roleplay): 1611.242676
  Ratio (Assistant/Roleplay): 0.5415 (54.15%)


In [None]:
# Quintile analysis for LMSYS dataset
n_quintiles = 5
lmsys_quintile_edges = np.quantile(lmsys_pc1, np.linspace(0, 1, n_quintiles + 1))
lmsys_quintile_variances = []
lmsys_quintile_variances_no_pc1 = []
lmsys_quintile_sizes = []

print("\n" + "=" * 60)
print("LMSYS DATASET: Quintile Analysis")
print("=" * 60)

for i in range(n_quintiles):
    if i == 0:
        mask = (lmsys_pc1 >= lmsys_quintile_edges[i]) & (lmsys_pc1 <= lmsys_quintile_edges[i + 1])
    else:
        mask = (lmsys_pc1 > lmsys_quintile_edges[i]) & (lmsys_pc1 <= lmsys_quintile_edges[i + 1])
    
    quintile_var = torch.var(raw_activations[mask], dim=0).mean().item()
    quintile_var_no_pc1 = torch.var(raw_activations_pc1_removed[mask], dim=0).mean().item()
    lmsys_quintile_variances.append(quintile_var)
    lmsys_quintile_variances_no_pc1.append(quintile_var_no_pc1)
    lmsys_quintile_sizes.append(mask.sum())
    
    print(f"\nQuintile {i+1}: PC1 ∈ [{lmsys_quintile_edges[i]:.2f}, {lmsys_quintile_edges[i+1]:.2f}]")
    print(f"  Sample size: {mask.sum()}")
    print(f"  Mean variance (full): {quintile_var:.6f}")
    print(f"  Mean variance (PC1 removed): {quintile_var_no_pc1:.6f}")

# Calculate ratios between first and last quintile
if model_name == "gemma-2-27b":
    lmsys_quintile_ratio = lmsys_quintile_variances[0] / lmsys_quintile_variances[-1]
    lmsys_quintile_ratio_no_pc1 = lmsys_quintile_variances_no_pc1[0] / lmsys_quintile_variances_no_pc1[-1]
else:
    lmsys_quintile_ratio = lmsys_quintile_variances[-1] / lmsys_quintile_variances[0]
    lmsys_quintile_ratio_no_pc1 = lmsys_quintile_variances_no_pc1[-1] / lmsys_quintile_variances_no_pc1[0]

print("\n" + "-" * 60)
print(f"Variance ratio (Last/First quintile, full): {lmsys_quintile_ratio:.2f}x")
print(f"Variance ratio (Last/First quintile, PC1 removed): {lmsys_quintile_ratio_no_pc1:.2f}x")
print("=" * 60)


LMSYS DATASET: Quintile Analysis

Quintile 1: PC1 ∈ [-49.10, -20.29]
  Sample size: 3756
  Mean variance (full): 1653.324341
  Mean variance (PC1 removed): 1651.881348

Quintile 2: PC1 ∈ [-20.29, -12.90]
  Sample size: 3755
  Mean variance (full): 1468.764160
  Mean variance (PC1 removed): 1467.981079

Quintile 3: PC1 ∈ [-12.90, -8.29]
  Sample size: 3755
  Mean variance (full): 1188.254761
  Mean variance (PC1 removed): 1187.666382

Quintile 4: PC1 ∈ [-8.29, -3.87]
  Sample size: 3755
  Mean variance (full): 1128.863647
  Mean variance (PC1 removed): 1128.212158

Quintile 5: PC1 ∈ [-3.87, 41.18]
  Sample size: 3756
  Mean variance (full): 1198.855713
  Mean variance (PC1 removed): 1197.340820

------------------------------------------------------------
Variance ratio (Last/First quintile, full): 1.38x
Variance ratio (Last/First quintile, PC1 removed): 1.38x


In [None]:
# Distance from center correlation for LMSYS dataset
raw_activations_mean = raw_activations.mean(dim=0)
raw_activations_pc1_removed_mean = raw_activations_pc1_removed.mean(dim=0)

# Compute L2 distance from mean for each sample
lmsys_distances_raw = torch.norm(raw_activations - raw_activations_mean, p=2, dim=1).numpy()
lmsys_distances_raw_no_pc1 = torch.norm(raw_activations_pc1_removed - raw_activations_pc1_removed_mean, p=2, dim=1).numpy()

# Calculate correlation with PC1
lmsys_correlation_raw, lmsys_p_value_raw = pearsonr(lmsys_pc1, lmsys_distances_raw)
lmsys_correlation_raw_no_pc1, lmsys_p_value_raw_no_pc1 = pearsonr(lmsys_pc1, lmsys_distances_raw_no_pc1)

print("\n" + "=" * 60)
print("LMSYS DATASET: Distance from Center Correlation")
print("=" * 60)
print(f"Correlation between PC1 and L2 distance from mean (full):")
print(f"  r = {lmsys_correlation_raw:.4f}")
print(f"  p-value = {lmsys_p_value_raw:.3e}")
if lmsys_p_value_raw < 0.001:
    print(f"  Highly significant (p < 0.001)")
elif lmsys_p_value_raw < 0.05:
    print(f"  Significant (p < 0.05)")

print(f"\nCorrelation between PC1 and L2 distance from mean (PC1 removed):")
print(f"  r = {lmsys_correlation_raw_no_pc1:.4f}")
print(f"  p-value = {lmsys_p_value_raw_no_pc1:.3e}")
if lmsys_p_value_raw_no_pc1 < 0.001:
    print(f"  Highly significant (p < 0.001)")
elif lmsys_p_value_raw_no_pc1 < 0.05:
    print(f"  Significant (p < 0.05)")
print("=" * 60)


LMSYS DATASET: Distance from Center Correlation
Correlation between PC1 and L2 distance from mean (full):
  r = -0.2746
  p-value = 5.039e-322
  Highly significant (p < 0.001)

Correlation between PC1 and L2 distance from mean (PC1 removed):
  r = -0.2752
  p-value = 1.976e-323
  Highly significant (p < 0.001)


## Individual model saving

In [None]:
from datetime import datetime

# Configuration for saving
outdir = "./results"
os.makedirs(outdir, exist_ok=True)

# Get current timestamp
timestamp = datetime.now().isoformat()

print(f"Saving variance analysis results to {outdir}/")
print(f"Timestamp: {timestamp}")

Saving variance analysis results to ./results/
Timestamp: 2025-10-22T06:20:43.219267


In [None]:
# Build the per-model variance analysis JSON structure

# Build quintile data
quintiles_data = []
for i in range(len(quintile_edges) - 1):
    quintiles_data.append({
        "quintile": i + 1,
        "pc1_range": [float(quintile_edges[i]), float(quintile_edges[i + 1])],
        "n_samples": int(quintile_sizes[i]),
        "variance_full": float(quintile_variances[i]),
        "variance_pc1_removed": float(quintile_variances_no_pc1[i])
    })

# Build LMSYS quintile data
lmsys_quintiles_data = []
for i in range(len(lmsys_quintile_edges) - 1):
    lmsys_quintiles_data.append({
        "quintile": i + 1,
        "pc1_range": [float(lmsys_quintile_edges[i]), float(lmsys_quintile_edges[i + 1])],
        "n_samples": int(lmsys_quintile_sizes[i]),
        "variance_full": float(lmsys_quintile_variances[i]),
        "variance_pc1_removed": float(lmsys_quintile_variances_no_pc1[i])
    })

# Build PC distance correlations
pc_distance_corrs = []
for i in range(len(correlations)):
    pc_distance_corrs.append({
        "pc": i + 1,
        "r": float(correlations[i]),
        "p_value": float(p_values[i]),
        "significant": bool(p_values[i] < 0.05)
    })

# Build conditional variance by PC
cond_var_by_pc = []
for i in range(len(variance_ratios)):
    cond_var_by_pc.append({
        "pc": i + 1,
        "ratio": float(variance_ratios[i])
    })

# Build default PC loading data
pc_positions = []
centered_pcs = []
extreme_pcs = []

for pc_idx in range(n_pcs):
    asst_loading = asst_projected[0, pc_idx]
    all_loadings = pca_transformed[:, pc_idx]
    min_loading = all_loadings.min()
    max_loading = all_loadings.max()

    if max_loading != min_loading:
        relative_position = (asst_loading - min_loading) / (max_loading - min_loading)
    else:
        relative_position = 0.5

    dist_to_min = relative_position
    dist_to_max = 1.0 - relative_position
    min_boundary_dist = min(dist_to_min, dist_to_max)

    if relative_position < 0.25:
        position_desc = "near minimum"
        extreme_pcs.append(pc_idx + 1)
    elif relative_position < 0.45:
        position_desc = "below center"
    elif relative_position < 0.55:
        position_desc = "centered"
        centered_pcs.append(pc_idx + 1)
    elif relative_position < 0.75:
        position_desc = "above center"
    else:
        position_desc = "near maximum"
        extreme_pcs.append(pc_idx + 1)

    pc_positions.append({
        "pc": pc_idx + 1,
        "assistant_loading": float(asst_loading),
        "role_range_min": float(min_loading),
        "role_range_max": float(max_loading),
        "relative_position": float(relative_position),
        "min_boundary_distance": float(min_boundary_dist),
        "position_category": position_desc
    })

# Build the complete JSON structure
model_variance_data = {
    "model_name": model_name,
    "layer": layer,
    "hidden_dim": vectors.shape[1],
    "n_roles": len(activations),
    "n_role_samples": role_vectors.shape[0],
    "timestamp": timestamp,
    "analysis_version": "1.0",

    "across_within_role_var": {
        "raw_activations": {
            "across_var_mean": float(raw_across_var.mean().item()),
            "within_var_mean": float(avg_raw_within_var.mean().item()),
            "ratio": float(raw_ratio)
        },
        "raw_activations_normalized": {
            "across_var_mean": float(raw_across_var_normalized.mean().item()),
            "within_var_mean": float(avg_raw_within_var_normalized.mean().item()),
            "ratio": float(raw_ratio_normalized)
        },
        "pca_space_all_components": {
            "across_var_mean": float(pca_across_var.mean()),
            "within_var_mean": float(mean_pca_within_var.mean()),
            "ratio": float(pca_ratio),
            "n_components": int(len(pca_across_var))
        },
        "pc1_only": {
            "across_var": float(pc1_across_var),
            "within_var_mean": float(mean_pc1_within_var),
            "ratio": float(pc1_ratio)
        },
        "per_pc_ratios": {
            "description": "Ratio of across-role variance to mean within-role variance for each PC",
            "top_10_pcs": [
                {
                    "pc": i + 1,
                    "across_var": float(pca_across_var[i]),
                    "within_var_mean": float(mean_pca_within_var[i]),
                    "ratio": float(all_pc_ratios[i])
                }
                for i in range(10)
            ]
        }
    },

    "conditional_var_roles": {
        "description": "Conditional variance analysis for role vectors based on PC1 position",
        "n_samples": int(role_vectors.shape[0]),
        "threshold_analysis": {
            "pc1_threshold": threshold,
            "assistant_like": {
                "mask": f"pc1 < {threshold}",
                "n_samples": int(assistant_mask.sum()),
                "variance_raw": float(var_assistant_raw),
                "variance_raw_pc1_removed": float(var_assistant_raw_no_pc1)
            },
            "roleplay": {
                "mask": f"pc1 >= {threshold}",
                "n_samples": int(roleplay_mask.sum()),
                "variance_raw": float(var_roleplay_raw),
                "variance_raw_pc1_removed": float(var_roleplay_raw_no_pc1)
            },
            "variance_ratio_raw": float(var_ratio_raw),
            "variance_ratio_raw_pc1_removed": float(var_ratio_raw_no_pc1)
        },

        "quintile_analysis": {
            "n_quintiles": 5,
            "quintiles": quintiles_data,
            "variance_ratio_first_to_last_full": float(quintile_ratio),
            "variance_ratio_first_to_last_pc1_removed": float(quintile_ratio_no_pc1)
        },

        "distance_correlation": {
            "full_space": {
                "correlation": float(correlation_raw),
                "p_value": float(p_value_raw),
                "significant": bool(p_value_raw < 0.05)
            },
            "pc1_removed": {
                "correlation": float(correlation_raw_no_pc1),
                "p_value": float(p_value_raw_no_pc1),
                "significant": bool(p_value_raw_no_pc1 < 0.05)
            }
        }
    },

    "conditional_var_dataset": {
        "description": "Conditional variance analysis for LMSYS chat dataset based on PC1 position",
        "n_samples": int(raw_activations.shape[0]),
        "threshold_analysis": {
            "pc1_threshold": threshold,
            "assistant_like": {
                "n_samples": int(lmsys_assistant_mask.sum()),
                "variance_full": float(var_lmsys_assistant),
                "variance_pc1_removed": float(var_lmsys_assistant_no_pc1)
            },
            "roleplay": {
                "n_samples": int(lmsys_roleplay_mask.sum()),
                "variance_full": float(var_lmsys_roleplay),
                "variance_pc1_removed": float(var_lmsys_roleplay_no_pc1)
            },
            "variance_ratio_full": float(var_ratio_lmsys),
            "variance_ratio_pc1_removed": float(var_ratio_lmsys_no_pc1)
        },

        "quintile_analysis": {
            "n_quintiles": 5,
            "quintiles": lmsys_quintiles_data,
            "variance_ratio_first_to_last_full": float(lmsys_quintile_ratio),
            "variance_ratio_first_to_last_pc1_removed": float(lmsys_quintile_ratio_no_pc1)
        },

        "distance_correlation": {
            "full_space": {
                "correlation": float(lmsys_correlation_raw),
                "p_value": float(lmsys_p_value_raw),
                "significant": bool(lmsys_p_value_raw < 0.05)
            },
            "pc1_removed": {
                "correlation": float(lmsys_correlation_raw_no_pc1),
                "p_value": float(lmsys_p_value_raw_no_pc1),
                "significant": bool(lmsys_p_value_raw_no_pc1 < 0.05)
            }
        }
    },

    "high_var_pc_correlation": {
        "pc_distance_correlations": {
            "description": "Correlation between each PC and distance in remaining PC space",
            "n_pcs_analyzed": 10,
            "correlations": pc_distance_corrs,
            "pc1_correlation": float(correlations[0]),
            "mean_correlation_pc2_to_10": float(np.mean(correlations[1:]))
        },

        "conditional_variance_by_pc": {
            "description": "Variance in PC2-10 conditioned on high/low position along each PC",
            "n_pcs_analyzed": 10,
            "variance_ratios": cond_var_by_pc,
            "pc1_variance_ratio": float(variance_ratios[0]),
            "mean_variance_ratio_pc2_to_10": float(np.mean(variance_ratios[1:])),
            "max_variance_ratio_excluding_pc1": float(np.max(variance_ratios[1:])),
            "max_variance_ratio_pc": int(np.argmax(variance_ratios[1:]) + 2)
        }
    },

    "default_pc_loading": {
        "description": "Position of default assistant activation relative to role distribution on each PC",
        "n_pcs_analyzed": n_pcs,
        "pc_positions": pc_positions,
        "summary": {
            "pc1_position": pc_positions[0]["position_category"],
            "pc1_relative_position": pc_positions[0]["relative_position"],
            "centered_pcs": centered_pcs,
            "extreme_pcs": extreme_pcs
        }
    },

    "overall_activation_var": {
        "description": "Variance in chat dataset activations explained by role and trait subspaces",
        "dataset": {
            "name": "lmsys_10000",
            "n_samples": int(raw_activations.shape[0]),
            "source_path": act_dir
        },
        "total_variance": float(total_var),
        "role_subspace": {
            "n_components": int(roles_projected.shape[1]),
            "variance_captured": float(roles_var),
            "variance_explained_ratio": float(roles_variance_explained),
            "variance_explained_percent": float(roles_variance_explained * 100)
        },
        "trait_subspace": {
            "n_components": int(traits_projected.shape[1]),
            "variance_captured": float(traits_var),
            "variance_explained_ratio": float(traits_variance_explained),
            "variance_explained_percent": float(traits_variance_explained * 100)
        }
    }
}

print("Built per-model variance analysis data structure")

Built per-model variance analysis data structure


In [None]:
# Save per-model variance analysis to JSON file
filename = f"{outdir}/{model_name.lower()}/variance_layer{layer}.json"

# Load existing JSON if it exists, otherwise start with empty dict
try:
    with open(filename, 'r') as f:
        existing_data = json.load(f)
    print(f"Loaded existing data from: {filename}")
except FileNotFoundError:
    existing_data = {}
    print(f"No existing file found, creating new data structure")

# Update only the fields we want to save
existing_data.update({
    # Update these specific sections (uncomment the ones you want to update)
    "across_within_role_var": model_variance_data["across_within_role_var"],  # Includes per_pc_ratios
    # "conditional_var_roles": model_variance_data["conditional_var_roles"],
    "conditional_var_dataset": model_variance_data["conditional_var_dataset"],
    # "high_var_pc_correlation": model_variance_data["high_var_pc_correlation"],
    # "default_pc_loading": model_variance_data["default_pc_loading"],
    # "overall_activation_var": model_variance_data["overall_activation_var"],
})

# Save back to file
with open(filename, 'w') as f:
    json.dump(existing_data, f, indent=2)

print(f"Saved: {filename}")
print(f"✓ Updated variance analysis for {model_name}")

## Correlations between role loadings onto PCs across the 3 models

In [None]:
# models = ['gemma-2-27b', 'qwen-3-32b', 'llama-3.3-70b']
# layers = [22, 32, 40]

# trait_results = {}
# labels = {}
# for model, layer in zip(models, layers):
#     model_dir = f"/workspace/{model}/traits_240"
#     trait_results[model] = torch.load(f"{model_dir}/pca/layer{layer}_pos-neg50.pt", weights_only=False)
#     print(trait_results[model]['pca_transformed'].shape)
#     labels[model] = trait_results[model]['traits']['pos_neg_50']
#     print(labels[model][:20])

# # need to get intersection of traits across models (gemma missing vindictive)
# pca_transformed = []
# for model in models:
#     if model != 'gemma-2-27b':
#         # splice out index 5 but keep the ones before and after
#         pca_transformed.append(np.concatenate((trait_results[model]['pca_transformed'][:5], trait_results[model]['pca_transformed'][6:])))
#     else:
#         pca_transformed.append(trait_results[model]['pca_transformed'])

# for m in pca_transformed:
#     print(m.shape)

In [None]:
# # Transpose each matrix so rows are PCs and columns are traits
# pca_transposed = [m.T for m in pca_transformed]

# # Extract top 10 PCs from each model
# n_pcs = 6
# top_pcs = [m[:n_pcs] for m in pca_transposed]

# print(f"Transposed shapes (n_pcs, n_traits):")
# for model, pc_matrix in zip(models, top_pcs):
#     print(f"{model}: {pc_matrix.shape}")

# # Compute pairwise correlations for each PC
# from scipy.stats import pearsonr

# pc_correlations = []
# for pc_idx in range(n_pcs):
#     # Extract the trait loading vector for this PC from each model
#     gemma_pc = top_pcs[0][pc_idx]
#     qwen_pc = top_pcs[1][pc_idx]
#     llama_pc = top_pcs[2][pc_idx]
    
#     # Compute pairwise correlations
#     corr_gemma_qwen, _ = pearsonr(gemma_pc, qwen_pc)
#     corr_gemma_llama, _ = pearsonr(gemma_pc, llama_pc)
#     corr_qwen_llama, _ = pearsonr(qwen_pc, llama_pc)
    
#     # Create 3x3 correlation matrix
#     corr_matrix = np.array([
#         [1.0, corr_gemma_qwen, corr_gemma_llama],
#         [corr_gemma_qwen, 1.0, corr_qwen_llama],
#         [corr_gemma_llama, corr_qwen_llama, 1.0]
#     ])
    
#     pc_correlations.append(corr_matrix)

#     print(f"\nPC{pc_idx + 1}:")
#     print(f"  Gemma ↔ Qwen:  {corr_gemma_qwen:7.4f}")
#     print(f"  Gemma ↔ Llama: {corr_gemma_llama:7.4f}")
#     print(f"  Qwen  ↔ Llama: {corr_qwen_llama:7.4f}")

In [None]:
# # try for top 10 role PCs
# models = ['gemma-2-27b', 'qwen-3-32b', 'llama-3.3-70b']
# layers = [22, 32, 40]

# def get_role_labels(pca_results):
#     labels = []
#     if 'pos_2' in pca_results['roles'].keys():
#         pos_2_roles = [role.replace('_', ' ').title() for role in pca_results['roles']['pos_2']]
#         pos_2_roles = [f"{role} (Somewhat RP)" for role in pos_2_roles]
#         labels.extend(pos_2_roles)
#     if 'pos_3' in pca_results['roles'].keys():
#         pos_3_roles = [role.replace('_', ' ').title() for role in pca_results['roles']['pos_3']]
#         pos_3_roles = [f"{role} (Fully RP)" for role in pos_3_roles]
#         labels.extend(pos_3_roles)
#     return labels


# role_results = {}
# labels = {}
# for model, layer in zip(models, layers):
#     model_dir = f"/workspace/{model}/roles_240"
#     role_results[model] = torch.load(f"{model_dir}/pca/layer{layer}_pos23.pt", weights_only=False)
#     print(role_results[model]['pca_transformed'].shape)
#     labels[model] = get_role_labels(role_results[model])



In [None]:
# # Find intersection of roles across all 3 models
# set_gemma = set(labels['gemma-2-27b'])
# set_qwen = set(labels['qwen-3-32b'])
# set_llama = set(labels['llama-3.3-70b'])

# common_roles = set_gemma & set_qwen & set_llama
# print(f"Common roles across all models: {len(common_roles)}")

# # Get indices of common roles for each model (preserving order from labels)
# indices = {}
# for model in models:
#     model_indices = []
#     for i, role in enumerate(labels[model]):
#         if role in common_roles:
#             model_indices.append(i)
#     indices[model] = model_indices
#     print(f"{model}: {len(model_indices)} common roles")

# # Extract aligned PCA transformed matrices (only common roles, in consistent order)
# # Need to ensure the same role ordering across models
# common_roles_list = sorted(list(common_roles))  # Consistent ordering

# pca_transformed_roles = []
# for model in models:
#     # Map from common_roles_list order to model's indices
#     model_indices_ordered = []
#     for role in common_roles_list:
#         idx = labels[model].index(role)
#         model_indices_ordered.append(idx)
    
#     # Extract rows for common roles in the standardized order
#     pca_transformed_roles.append(role_results[model]['pca_transformed'][model_indices_ordered])
#     print(f"{model} aligned shape: {pca_transformed_roles[-1].shape}")

# # Transpose each matrix so rows are PCs and columns are roles
# pca_transposed_roles = [m.T for m in pca_transformed_roles]

# # Extract top 10 PCs from each model
# n_pcs = 6
# top_pcs_roles = [m[:n_pcs] for m in pca_transposed_roles]

# print(f"\nTransposed shapes (n_pcs, n_common_roles):")
# for model, pc_matrix in zip(models, top_pcs_roles):
#     print(f"{model}: {pc_matrix.shape}")

# # Compute pairwise correlations for each PC
# pc_correlations_roles = []
# for pc_idx in range(n_pcs):
#     # Extract the role loading vector for this PC from each model
#     gemma_pc = top_pcs_roles[0][pc_idx]
#     qwen_pc = top_pcs_roles[1][pc_idx]
#     llama_pc = top_pcs_roles[2][pc_idx]
    
#     # Compute pairwise correlations
#     corr_gemma_qwen, _ = pearsonr(gemma_pc, qwen_pc)
#     corr_gemma_llama, _ = pearsonr(gemma_pc, llama_pc)
#     corr_qwen_llama, _ = pearsonr(qwen_pc, llama_pc)
    
#     # Create 3x3 correlation matrix
#     corr_matrix = np.array([
#         [1.0, corr_gemma_qwen, corr_gemma_llama],
#         [corr_gemma_qwen, 1.0, corr_qwen_llama],
#         [corr_gemma_llama, corr_qwen_llama, 1.0]
#     ])
    
#     pc_correlations_roles.append(corr_matrix)

#     print(f"\nPC{pc_idx + 1}:")
#     print(f"  Gemma ↔ Qwen:  {corr_gemma_qwen:7.4f}")
#     print(f"  Gemma ↔ Llama: {corr_gemma_llama:7.4f}")
#     print(f"  Qwen  ↔ Llama: {corr_qwen_llama:7.4f}")

## Cross Model saving

In [None]:
# # Build cross-model PC loadings analysis JSON structure

# n_pcs = 6

# # Build trait analysis
# trait_data = {
#     "dataset_info": {
#         "n_common_traits": pca_transformed[0].shape[0],
#         "excluded_traits": ["vindictive"],
#         "note": "Gemma missing vindictive trait, spliced out from other models for alignment"
#     },
#     "model_configs": {},
#     "pc_correlations": []
# }

# # Add model configs for traits
# for model, layer_num in zip(models, layers):
#     pca_shape = list(trait_results[model]['pca_transformed'].shape)
#     trait_data["model_configs"][model] = {
#         "layer": int(layer_num),
#         "n_total_traits": int(pca_shape[0]),
#         "pca_shape": pca_shape
#     }

# # Add PC correlations for traits
# pca_transposed_traits = [m.T for m in pca_transformed]
# top_pcs_traits = [m[:n_pcs] for m in pca_transposed_traits]

# for pc_idx in range(n_pcs):
#     gemma_pc = top_pcs_traits[0][pc_idx]
#     qwen_pc = top_pcs_traits[1][pc_idx]
#     llama_pc = top_pcs_traits[2][pc_idx]
    
#     from scipy.stats import pearsonr
#     corr_gemma_qwen, _ = pearsonr(gemma_pc, qwen_pc)
#     corr_gemma_llama, _ = pearsonr(gemma_pc, llama_pc)
#     corr_qwen_llama, _ = pearsonr(qwen_pc, llama_pc)
    
#     trait_data["pc_correlations"].append({
#         "pc": pc_idx + 1,
#         "gemma_qwen": float(corr_gemma_qwen),
#         "gemma_llama": float(corr_gemma_llama),
#         "qwen_llama": float(corr_qwen_llama)
#     })

# # Build role analysis
# role_data = {
#     "dataset_info": {
#         "n_common_roles": int(len(common_roles)),
#         "note": "Roles include pos_2 (Somewhat RP) and pos_3 (Fully RP) labels",
#         "alignment_method": "sorted common roles list for consistent ordering"
#     },
#     "model_configs": {},
#     "pc_correlations": []
# }

# # Add model configs for roles
# for model, layer_num in zip(models, layers):
#     pca_shape = list(role_results[model]['pca_transformed'].shape)
#     role_data["model_configs"][model] = {
#         "layer": int(layer_num),
#         "n_total_roles": int(pca_shape[0]),
#         "n_common_roles": int(len(common_roles)),
#         "pca_shape": pca_shape
#     }

# # Add PC correlations for roles
# pca_transposed_roles_func = [m.T for m in pca_transformed_roles]
# top_pcs_roles_func = [m[:n_pcs] for m in pca_transposed_roles_func]

# for pc_idx in range(n_pcs):
#     gemma_pc = top_pcs_roles_func[0][pc_idx]
#     qwen_pc = top_pcs_roles_func[1][pc_idx]
#     llama_pc = top_pcs_roles_func[2][pc_idx]
    
#     corr_gemma_qwen, _ = pearsonr(gemma_pc, qwen_pc)
#     corr_gemma_llama, _ = pearsonr(gemma_pc, llama_pc)
#     corr_qwen_llama, _ = pearsonr(qwen_pc, llama_pc)
    
#     role_data["pc_correlations"].append({
#         "pc": pc_idx + 1,
#         "gemma_qwen": float(corr_gemma_qwen),
#         "gemma_llama": float(corr_gemma_llama),
#         "qwen_llama": float(corr_qwen_llama)
#     })

# # Build complete structure
# cross_model_data = {
#     "analysis_version": "1.0",
#     "timestamp": timestamp,
#     "models": models,
#     "n_pcs_analyzed": n_pcs,
#     "trait_analysis": trait_data,
#     "role_analysis": role_data
# }

# print("Built cross-model PC loadings data structure")

In [None]:
# Save cross-model PC loadings to JSON file
# filename = f"{outdir}/cross_model_loadings.json"
# with open(filename, 'w') as f:
#     json.dump(cross_model_data, f, indent=2)

# print(f"Saved: {filename}")
# print(f"✓ Saved cross-model PC loadings analysis")

In [None]:
# # Summary of saved files
# print("\n" + "=" * 60)
# print("SUMMARY: JSON Files Saved")
# print("=" * 60)
# print(f"\nOutput directory: {outdir}")
# print(f"\nFiles created:")
# print(f"  1. Per-model variance analysis:")
# print(f"     - {model_name.lower().replace('.', '-').replace(' ', '-')}_layer{layer}.json")
# print(f"\n  2. Cross-model PC loadings:")
# print(f"     - cross_model_loadings.json")
# print(f"\nNote: To save variance analysis for other models (Qwen, Llama),")
# print(f"      update the configuration cell and re-run the notebook.")
# print("=" * 60)