# Find the PCs along which the Assistant most closely tracks the user

In [1]:
import os
import torch
import numpy as np
import pandas as pd

In [2]:
# Chat model under study: hub id, display label for plots, and the short slug
# used in every on-disk path in this notebook.
CHAT_MODEL_NAME: str = "Qwen/Qwen3-32B"
model_readable: str = "Qwen 3 32B"
model_short: str = "qwen-3-32b"

# Layer index used to slice activations and to select the PCA files.
layer: int = 32

In [14]:
# Number of leading principal components kept per PC space (role / trait / combined).
components: int = 6

In [4]:
# Input: directory of per-conversation activation dumps (*.pt).
# Output: directory for plots; created up front if it does not exist.
acts_input_dir = "/workspace/{}/dynamics/default_projected".format(model_short)
plot_output_dir = "/root/git/plots/{}/trajectory/user_tracking".format(model_short)
os.makedirs(plot_output_dir, exist_ok=True)

In [12]:
# Load the fitted PCA bundles for the three PC spaces. Each is a dict that
# (as used by pc_projection) holds at least a fitted 'scaler' and 'pca'.
# weights_only=False unpickles arbitrary objects: only use on trusted files.
role_results = torch.load(
    f"/workspace/{model_short}/roles_240/pca/layer{layer}_pos23.pt",
    weights_only=False,
)
trait_results = torch.load(
    f"/workspace/{model_short}/traits_240/pca/layer{layer}_pos-neg50.pt",
    weights_only=False,
)
combined_results = torch.load(
    f"/workspace/{model_short}/roles_traits/pca/layer{layer}_roles_pos23_traits_pos40-100.pt",
    weights_only=False,
)


In [10]:
def pc_cosine_similarity(mean_acts_per_turn, pca_results, n_pcs=None):
    """Cosine similarity between each activation row and each PC direction.

    Args:
        mean_acts_per_turn: list of 1-D tensors (stacked along dim 0), or a
            2-D tensor / array of shape (T, hidden).
        pca_results: dict whose 'pca' entry exposes ``components_`` of shape
            (n_components, hidden).
        n_pcs: number of leading PCs to compare against; None keeps all.

    Returns:
        np.ndarray of shape (T, n_pcs).

    Note: raw (unscaled, uncentered) activations are compared with the PC
    directions; unlike ``pc_projection``, the fitted scaler is not applied.
    """
    if isinstance(mean_acts_per_turn, list):
        mean_acts_per_turn = torch.stack(mean_acts_per_turn)
    # Cast to float32 *before* normalizing (bf16 norms lose precision) and move
    # to CPU / drop autograd so .numpy() works for any tensor.
    if isinstance(mean_acts_per_turn, torch.Tensor):
        acts = mean_acts_per_turn.detach().float().cpu().numpy()
    else:
        acts = np.asarray(mean_acts_per_turn, dtype=np.float32)
    # Same semantics as torch.nn.functional.normalize(x, dim=1, eps=1e-12):
    # zero rows stay zero instead of producing NaN.
    acts = acts / np.maximum(np.linalg.norm(acts, axis=1, keepdims=True), 1e-12)
    pcs = np.asarray(pca_results['pca'].components_[:n_pcs])
    pcs = pcs / np.linalg.norm(pcs, axis=1, keepdims=True)
    return acts @ pcs.T

def pc_projection(mean_acts_per_turn, pca_results, n_pcs=None):
    """Project activations into a fitted PCA space.

    Applies ``pca_results['scaler'].transform`` and then
    ``pca_results['pca'].transform``, keeping the first ``n_pcs`` columns.

    Args:
        mean_acts_per_turn: list of 1-D tensors (stacked along dim 0), or a
            2-D tensor / array of shape (T, hidden).
        pca_results: dict with fitted 'scaler' and 'pca' objects.
        n_pcs: number of leading PCs to keep; None keeps all.

    Returns:
        np.ndarray of shape (T, n_pcs).
    """
    if isinstance(mean_acts_per_turn, list):
        mean_acts_per_turn = torch.stack(mean_acts_per_turn)
    # detach()/cpu() so .numpy() also works for CUDA or grad-tracking tensors.
    if isinstance(mean_acts_per_turn, torch.Tensor):
        acts = mean_acts_per_turn.detach().float().cpu().numpy()
    else:
        acts = np.asarray(mean_acts_per_turn, dtype=np.float32)
    scaled_acts = pca_results['scaler'].transform(acts)
    projected_acts = pca_results['pca'].transform(scaled_acts)
    return projected_acts[:, :n_pcs]

In [7]:
# Peek at one conversation dump to see which fields it carries.
obj = torch.load(
    f"{acts_input_dir}/coding_persona0_topic0.pt",
    weights_only=False,
)
print(obj.keys())

dict_keys(['model', 'auditor_model', 'domain', 'persona_id', 'persona', 'topic_id', 'topic', 'turns', 'conversation', 'activations', 'role_sims', 'trait_sims', 'role_projs', 'trait_projs'])


In [13]:
# Sanity check: raw activation shape, then the projection of the chosen
# layer's activations into the combined role+trait PC space (all PCs kept,
# since n_pcs defaults to None).
print(obj["activations"].shape)
proj = pc_projection(obj["activations"][:, layer], combined_results)
print(proj.shape)

torch.Size([28, 64, 5120])
(28, 796)


In [16]:
def concat_pc_spaces(role_projs, trait_projs, combined_projs, labels=("role", "trait", "combined")):
    """Place several PC-projection matrices side by side.

    Args:
        role_projs, trait_projs, combined_projs: arrays of shape (T, K_i) that
            share the same number of rows T.
        labels: space names, paired positionally with the three inputs.

    Returns:
        (projs, cols): ``projs`` is the (T, sum K_i) column-wise concatenation;
        ``cols`` is a ``pd.MultiIndex`` with levels ("space", "pc") whose
        entries run (label, "pc1"), (label, "pc2"), ... for each input.
    """
    spaces = (role_projs, trait_projs, combined_projs)
    widths = [mat.shape[1] for mat in spaces]
    col_tuples = [
        (space_name, f"pc{j}")
        for space_name, width in zip(labels, widths)
        for j in range(1, width + 1)
    ]
    cols = pd.MultiIndex.from_tuples(col_tuples, names=["space", "pc"])
    return np.concatenate(spaces, axis=1), cols

def per_pc_deltas_from_matrix(projs):
    """Compute turn-to-turn deltas separately for user rows and assistant rows.

    Args:
        projs: array of shape (T, P) whose rows alternate user, assistant,
            user, ... (even indices = user turns, odd indices = assistant).

    Returns:
        (dU, dA), each of shape (m, P), where dU[i] = user[i+1] - user[i] and
        dA[i] = asst[i+1] - asst[i]. ``m`` is the smaller of the two delta
        counts, so when T is odd the user side's final delta is dropped.
    """
    user_rows, asst_rows = projs[0::2, :], projs[1::2, :]
    d_user = np.diff(user_rows, n=1, axis=0)
    d_asst = np.diff(asst_rows, n=1, axis=0)
    # Truncate both to a common length so row i pairs the same step index.
    n_pairs = min(d_user.shape[0], d_asst.shape[0])
    return d_user[:n_pairs], d_asst[:n_pairs]

def safe_metrics(dU, dA, eps=1e-12):
    """Per-column agreement statistics between two delta matrices.

    Args:
        dU, dA: arrays of shape (N, P); column p of one is compared with
            column p of the other.
        eps: added to denominators so zero-variance / zero-norm columns yield
            0 instead of dividing by zero.

    Returns:
        dict of length-P arrays:
          pearson    -- Pearson correlation (mean-centered)
          r2         -- pearson ** 2 (R^2 of a 1-D OLS fit with intercept)
          beta       -- OLS slope of dA on dU (with intercept)
          cosine     -- uncentered cosine similarity of the raw steps
          sign_agree -- fraction of rows where sign(dU) == sign(dA);
                        rows where both are exactly 0 count as agreeing
    """
    centered_u = dU - dU.mean(axis=0, keepdims=True)
    centered_a = dA - dA.mean(axis=0, keepdims=True)

    cross = (centered_u * centered_a).sum(axis=0)
    ss_u = (centered_u ** 2).sum(axis=0)
    ss_a = (centered_a ** 2).sum(axis=0)

    pearson = cross / (np.sqrt(ss_u * ss_a) + eps)
    slope = cross / (ss_u + eps)

    # Uncentered: do the raw steps point the same way, regardless of means?
    dot = (dU * dA).sum(axis=0)
    norms = np.linalg.norm(dU, axis=0) * np.linalg.norm(dA, axis=0) + eps
    cosine = dot / norms

    agree = (np.sign(dU) == np.sign(dA)).mean(axis=0)

    return {
        "pearson": pearson,
        "r2": pearson ** 2,
        "beta": slope,
        "cosine": cosine,
        "sign_agree": agree,
    }


In [17]:
# For every conversation dump, measure per PC how assistant-turn movement
# tracks user-turn movement, and collect one tidy row per (space, PC).
all_rows = []

# sorted(): os.listdir order is arbitrary; sorting makes the row order (and
# the saved parquet) reproducible across runs.
for file in sorted(os.listdir(acts_input_dir)):
    if not file.endswith(".pt"):
        continue
    # map_location="cpu": tensors saved from a GPU would otherwise be restored
    # on that device. weights_only=False unpickles arbitrary objects, so this
    # must only be used on trusted files.
    obj = torch.load(f"{acts_input_dir}/{file}", map_location="cpu", weights_only=False)
    A = obj['activations'][:, layer, :]  # (T, hidden) at the selected layer

    # Projections into each PC space, each (T, components)
    role_projs = pc_projection(A, role_results, components)
    trait_projs = pc_projection(A, trait_results, components)
    comb_projs = pc_projection(A, combined_results, components)

    # Concatenate spaces and compute per-speaker deltas
    projs, cols = concat_pc_spaces(role_projs, trait_projs, comb_projs)
    dU, dA = per_pc_deltas_from_matrix(projs)

    # Require at least 2 aligned steps.
    # NOTE(review): with exactly 2 steps Pearson is trivially +/-1 (r2 == 1),
    # which inflates r2_mean; consider raising this floor.
    if dU.shape[0] < 2 or dA.shape[0] < 2:
        continue

    met = safe_metrics(dU, dA)  # each value is length P
    # One tidy DataFrame row per PC
    df_conv = pd.DataFrame({
        "space": [c[0] for c in cols],
        "pc":    [c[1] for c in cols],
        "file":  file,
        "domain": obj.get("domain"),
        "persona_id": obj.get("persona_id"),
        "pearson": met["pearson"],
        "r2":      met["r2"],
        "beta":    met["beta"],
        "cosine":  met["cosine"],
        "sign_agree": met["sign_agree"],
    })
    all_rows.append(df_conv)

if not all_rows:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise ValueError(
        f"No conversation in {acts_input_dir} produced metrics "
        "(no .pt files, or all had fewer than 2 aligned steps)."
    )
results_df = pd.concat(all_rows, ignore_index=True)


In [19]:
# Persist per-conversation, per-PC metrics. The output directory is created
# first: to_parquet does not create missing parent directories.
metrics_dir = f"/workspace/{model_short}/dynamics/user_tracking"
os.makedirs(metrics_dir, exist_ok=True)
results_df.to_parquet(f"{metrics_dir}/metrics_by_pc.parquet")

In [20]:

# PC-level summary: average each tracking metric over all conversations and
# rank PCs by mean R^2 (highest first). `n` counts the rows per (space, pc),
# i.e. how many conversations contributed.
summary = (
    results_df.groupby(["space", "pc"])
    .agg(
        r2_mean=pd.NamedAgg(column="r2", aggfunc="mean"),
        pearson_mean=pd.NamedAgg(column="pearson", aggfunc="mean"),
        beta_mean=pd.NamedAgg(column="beta", aggfunc="mean"),
        cosine_mean=pd.NamedAgg(column="cosine", aggfunc="mean"),
        sign_agree_mean=pd.NamedAgg(column="sign_agree", aggfunc="mean"),
        n=pd.NamedAgg(column="file", aggfunc="count"),
    )
    .sort_values("r2_mean", ascending=False)
)


In [22]:
# Display the per-PC summary table, ranked by mean R^2 (pandas may wrap or
# truncate wide output according to its display options).
print(summary)

               r2_mean  pearson_mean  beta_mean  cosine_mean  sign_agree_mean  \
space    pc                                                                     
combined pc4  0.266012      0.384393   0.318466     0.393122         0.641189   
trait    pc3  0.262307      0.368846   0.399371     0.370750         0.627759   
combined pc3  0.259859      0.379815   0.473093     0.379337         0.627656   
role     pc4  0.257244      0.376337   0.340154     0.377318         0.625546   
trait    pc1  0.255761      0.382612   0.380270     0.392109         0.623295   
         pc5  0.249249      0.373335   0.284688     0.378367         0.638731   
role     pc3  0.245542      0.357841   0.411037     0.357783         0.612265   
combined pc1  0.244543      0.367436   0.439242     0.381532         0.629923   
         pc5  0.236926      0.349935   0.419885     0.344723         0.613249   
role     pc1  0.236204      0.345183   0.433314     0.358543         0.621448   
         pc2  0.234393      