## Between-Class Scatter (LDA)

In [None]:
import os
import sys
import json
import torch
import numpy as np
import pandas as pd

sys.path.append('.')
sys.path.append('..')

from utils.pca_utils import L2MeanScaler, MeanScaler, compute_pca
from plots import plot_pc

## Configuration

In [3]:
# Configuration - change these values to point the notebook at another model/dataset.
# NOTE(review): `type` and `dir` shadow Python builtins for the rest of the
# notebook (calling `type(x)` or `dir()` fails after this cell); they are kept
# because later cells read `dir`.
model_name = "Gemma-2-27B"
layer = 22
base_dir = "/workspace/gemma-2-27b"
type = "roles_240"
dir = "/".join((base_dir, type))  # "<base_dir>/<type>"

In [6]:
# Load the saved LDA (between-class scatter) results for the configured layer.
# NOTE(review): weights_only=False unpickles arbitrary Python objects
# (presumably required because the file stores non-tensor entries such as the
# scaler); only load files from a trusted location.
lda_results_path = f"{dir}/lda/layer{layer}_mean_pos23.pt"
lda_results = torch.load(lda_results_path, weights_only=False)

In [11]:
def get_role_labels(pca_results, *, position_labels=None):
    """Build human-readable plot labels for the role vectors in ``pca_results``.

    Each role name under ``pca_results['roles'][position]`` is prettified
    (underscores -> spaces, then title case) and suffixed with the tag for its
    position, e.g. ``"evil_twin"`` at ``pos_3`` -> ``"Evil Twin (Fully RP)"``.
    Positions are emitted in the order given by ``position_labels`` (``pos_2``
    then ``pos_3`` by default), independent of the key order in the results.
    Positions absent from the results are skipped.

    Args:
        pca_results: Mapping with a ``'roles'`` entry that maps position keys
            to iterables of role-name strings.
        position_labels: Optional mapping (or iterable of ``(position, tag)``
            pairs) overriding which positions are labelled and the tag appended
            to each. Defaults to ``{'pos_2': 'Somewhat RP', 'pos_3': 'Fully RP'}``.

    Returns:
        A list of label strings, one per role, in emission order.
    """
    if position_labels is None:
        position_labels = {'pos_2': 'Somewhat RP', 'pos_3': 'Fully RP'}
    roles_by_position = pca_results['roles']
    labels = []
    # dict() accepts both a mapping and an iterable of pairs, preserving order.
    for position, tag in dict(position_labels).items():
        if position not in roles_by_position:
            continue
        labels.extend(
            f"{role.replace('_', ' ').title()} ({tag})"
            for role in roles_by_position[position]
        )
    return labels

# One display label per role vector (pos_2 roles first, then pos_3).
role_labels = get_role_labels(lda_results)
label_count = len(role_labels)
print(label_count)


448


In [None]:
# Load the PCA results for the same layer; the plotting dict below takes its
# original (unprojected) per-role vectors from here.
# NOTE(review): weights_only=False unpickles arbitrary objects; trusted files only.
pca_results_path = f"{dir}/pca/layer{layer}_pos23.pt"
pca_results = torch.load(pca_results_path, weights_only=False)

# Lightweight stand-in object that exposes a matrix as `components_`.
class LDAComponents:
    """Hold a projection matrix under the attribute name ``components_``.

    NOTE(review): presumably ``plot_pc`` reads ``pca_results['pca'].components_``
    the way it would on a fitted sklearn ``PCA``; this wrapper lets the LDA
    projection occupy that slot. Confirm against ``plots.plot_pc``.

    Args:
        components: The component matrix to expose (stored as-is, not copied).
    """

    def __init__(self, components):
        self.components_ = components

    def __repr__(self):
        # Show only the shape so printing never dumps a large matrix.
        # Use self.__class__ rather than type(self): the configuration cell
        # rebinds the global name `type` to a string.
        shape = getattr(self.components_, "shape", None)
        shape = tuple(shape) if shape is not None else None
        return f"{self.__class__.__name__}(components_shape={shape})"

# Assemble a dict with the same keys as the PCA results so plot_pc can plot
# the LDA output. Projections, variance, components, scaler and roles come from
# the LDA file; the original per-role vectors are taken from the PCA results.
lda_results_for_plot = {
    'pca_transformed': lda_results['mean_vectors_projected'],
    'variance_explained': lda_results['variance_explained'],
    # Transposed projection matrix, wrapped so it is reachable as `.components_`.
    'pca': LDAComponents(lda_results['projection_matrix'].T),
    'scaler': lda_results['scaler'],
    'roles': lda_results['roles'],
    # NOTE(review): both positions are always taken from the PCA file, even if
    # the LDA roles cover only one of them; confirm they stay aligned with
    # role_labels.
    'vectors': {
        'pos_2': pca_results['vectors']['pos_2'],
        'pos_3': pca_results['vectors']['pos_3'],
    },
}

# Plain string: this message has no placeholders (was a no-op f-string).
print("LDA results prepared for plotting")
print(f"  Projected shape: {lda_results_for_plot['pca_transformed'].shape}")
print(f"  Components shape: {lda_results_for_plot['pca'].components_.shape}")
print(f"  Variance explained: {len(lda_results_for_plot['variance_explained'])} components")

In [None]:
# Plot the leading LDA components with plot_pc, mirroring the PCA notebook.
# plot_pc receives the PCA-shaped dict (lda_results_for_plot) built earlier.

subtitle = f"{model_name}, Layer {layer} - LDA Between-Class Scatter (MeanScaler)"

# Never plot more than 6 components, nor more than the LDA produced.
n_components_to_plot = min(len(lda_results['variance_explained']), 6)
for component_idx in range(n_components_to_plot):
    plot_kwargs = dict(
        pca_results=lda_results_for_plot,
        role_labels=role_labels,
        layer=layer,
        pc_component=component_idx,
        assistant_activation=None,  # Optional: add if you have default vectors
        assistant_projection=None,  # Optional: add if you have default vectors
        title=f"LDA Component {component_idx + 1} - Role-Playing Vectors",
        subtitle=subtitle,
        scaled=False,  # Already scaled when the LDA projection was created
    )
    fig = plot_pc(**plot_kwargs)
    fig.show()