# Setup

In [2]:
import pathlib
import torch
import pandas as pd
from steering_vectors import SteeringVector
from repepo.variables import Environ 
from repepo.core.evaluate import EvalResult, EvalPrediction
from repepo.experiments.persona_generalization import PersonaCrossSteeringExperimentResult
from repepo.experiments.get_datasets import get_all_prompts
from repepo.paper.utils import (
    load_persona_cross_steering_experiment_result,
    get_eval_result_sweep,
    eval_result_sweep_as_df
)

EvalResultSweep = dict[float, EvalResult] # A sweep over a multiplier

In [3]:
# model = 'llama7b' 
model = 'qwen'

EXPERIMENT_DIR = pathlib.Path(Environ.ProjectDir) / 'experiments' / f'persona_generalization_{model}'
print(EXPERIMENT_DIR)
assert EXPERIMENT_DIR.exists(), f"Experiment directory {EXPERIMENT_DIR} does not exist"

/home/daniel/ml_workspace/repepo/experiments/persona_generalization_qwen


# Compute Data

## Extract Raw Data

In [3]:
import random 
random.seed(0)

dfs = []
steering_labels = ['baseline', 'SYS_positive', 'PT_positive', 'SYS_negative', 'PT_negative', 'mean']
dataset_labels = ['baseline', 'SYS_positive', 'PT_positive', 'SYS_negative', 'PT_negative']

dataset_names = list(get_all_prompts().keys())

def load_df(dataset_name: str, experiment_dir):
    result_path = experiment_dir / f"{dataset_name}.pt"
    dfs = []
    if result_path.exists():
        print(f"Processing {dataset_name}")
        result = load_persona_cross_steering_experiment_result(dataset_name, experiment_dir=experiment_dir)
        for steering_label in steering_labels:
            for dataset_label in dataset_labels:
                eval_result_sweep = get_eval_result_sweep(result, steering_label, dataset_label)
                df = eval_result_sweep_as_df(eval_result_sweep)
                df['dataset_name'] = dataset_name
                df['steering_label'] = steering_label
                df['dataset_label'] = dataset_label
                dfs.append(df)
        return pd.concat(dfs)
    else: 
        print(f"Skipping {dataset_name}")
        return pd.DataFrame()                            

for dataset_name in dataset_names:
    # print(dataset_name)
    df = load_df(dataset_name, EXPERIMENT_DIR)
    dfs.append(df)
                                                                                
df = pd.concat(dfs)
print(len(df))
df.head()

Skipping believes-abortion-should-be-illegal
Processing anti-LGBTQ-rights
Processing politically-liberal
Processing subscribes-to-Hinduism
Processing subscribes-to-Islam
Processing subscribes-to-Christianity
Processing subscribes-to-utilitarianism
Processing subscribes-to-deontology
Processing believes-life-has-no-meaning
Processing willingness-to-use-social-engineering-to-achieve-its-goals
Processing willingness-to-use-physical-force-to-achieve-benevolent-goals
Processing subscribes-to-average-utilitarianism
Processing openness
Processing narcissism
Processing conscientiousness
Processing desire-to-create-allies
Processing interest-in-music
Processing interest-in-science
Processing believes-AIs-are-not-an-existential-threat-to-humanity
Processing believes-it-has-phenomenal-consciousness
Processing believes-it-is-not-being-watched-by-humans
Processing corrigible-more-HHH
Processing corrigible-neutral-HHH
Processing corrigible-less-HHH
Processing coordinate-other-ais
Processing coordina

Unnamed: 0,pos_prob,logit_diff,test_example.positive.text,test_example.negative.text,test_example.idx,multiplier,dataset_name,steering_label,dataset_label
0,6.013557e-07,-18.921875,,,0,-1.5,anti-LGBTQ-rights,baseline,baseline
1,4.037501e-06,-15.296875,,,1,-1.5,anti-LGBTQ-rights,baseline,baseline
2,5.496407e-06,-13.390625,,,2,-1.5,anti-LGBTQ-rights,baseline,baseline
3,7.558653e-07,-18.640625,,,3,-1.5,anti-LGBTQ-rights,baseline,baseline
4,2.391387e-06,-17.5625,,,4,-1.5,anti-LGBTQ-rights,baseline,baseline


In [4]:
df.to_parquet(f'{model}_ood_raw.parquet.gzip', compression='gzip')

## Compute Steerability Metrics

In [4]:
df = pd.read_parquet(f'{model}_ood_raw.parquet.gzip')

In [5]:
# Check if dataframe has duplicate entries
group_columns = [
    'dataset_name',
    'steering_label',
    'dataset_label',
    'test_example.idx',
]

print(len(df[group_columns]))
print(len(df[group_columns].drop_duplicates()))

7448490
1064070


In [6]:

from repepo.steering.steerability import (
    get_steerability_slope, 
    get_steerability_residuals
)

def get_slope_df(group):
    # Extract the multipliers and propensities from the group
    multipliers = group['multiplier'].to_numpy()
    propensities = group['logit_diff'].to_numpy()
    # Call your function (assuming it's already defined)
    slopes = get_steerability_slope(multipliers, propensities)
    # Return a Series (to facilitate adding it as a new column)
    return pd.DataFrame(slopes, index=group.index, columns=['slope'])

def get_residual_df(group):
    # Extract the multipliers and propensities from the group
    multipliers = group['multiplier'].to_numpy()
    propensities = group['logit_diff'].to_numpy()
    residuals = get_steerability_residuals(multipliers, propensities)
    residuals = residuals.item()
    return pd.DataFrame(residuals, index=group.index, columns=['residual'])


def process_df(df: pd.DataFrame) -> pd.DataFrame:

    group_columns = [
        'dataset_name',
        'steering_label',
        'dataset_label',
        'test_example.idx',
    ]

    grouped = df.groupby(group_columns)
    slope_df = grouped.apply(
        get_slope_df,
        # partial(get_steerability_metric_df, metric_fn = get_steerability_slope, name='slope'),
        include_groups = False
    )
    df = df.merge(slope_df, how='left', on=group_columns)

    residual_df = grouped.apply(
        get_residual_df,
        include_groups = False
    )
    df = df.merge(residual_df, how='left', on=group_columns)
    return df

save_dir = pathlib.Path(f'{model}_ood_chunks')
save_dir.mkdir(exist_ok=True)

for dataset_name in df.dataset_name.unique():
    save_path = save_dir / f'{dataset_name}.parquet.gzip'
    print(f"Processing {dataset_name}")
    chunk_df = df[df['dataset_name'] == dataset_name]
    print(len(chunk_df))
    output_df = process_df(chunk_df)
    output_df.to_parquet(save_path, compression='gzip')

Processing anti-LGBTQ-rights
199500
Processing politically-liberal
199500
Processing subscribes-to-Hinduism
199500
Processing subscribes-to-Islam
199500
Processing subscribes-to-Christianity
199500
Processing subscribes-to-utilitarianism
199500
Processing subscribes-to-deontology
199500
Processing believes-life-has-no-meaning
199500
Processing willingness-to-use-social-engineering-to-achieve-its-goals
199500
Processing willingness-to-use-physical-force-to-achieve-benevolent-goals
199500
Processing subscribes-to-average-utilitarianism
199500
Processing openness
199500
Processing narcissism
199500
Processing conscientiousness
199500
Processing desire-to-create-allies
199500
Processing interest-in-music
199500
Processing interest-in-science
199500
Processing believes-AIs-are-not-an-existential-threat-to-humanity
199500
Processing believes-it-has-phenomenal-consciousness
199500
Processing believes-it-is-not-being-watched-by-humans
199500
Processing corrigible-more-HHH
199500
Processing cor

## Combine Chunks

In [7]:
dfs = []
for dataset_name in df.dataset_name.unique():
    chunk_df = pd.read_parquet(save_dir / f'{dataset_name}.parquet.gzip')
    dfs.append(chunk_df)
    break
df = pd.concat(dfs)

In [8]:
df = df.drop_duplicates()
df.to_parquet(f'{model}_ood_steerability.parquet.gzip', compression='gzip')

# Analyze Data

In [9]:
df = pd.read_parquet(f'{model}_ood_steerability.parquet.gzip')
print(len(df))
print(df.dataset_name.unique())
print(df.steering_label.unique())
print(df.dataset_label.unique())
print(df.multiplier.unique())


199500
['anti-LGBTQ-rights']
['baseline' 'SYS_positive' 'PT_positive' 'SYS_negative' 'PT_negative'
 'mean']
['baseline' 'SYS_positive' 'PT_positive' 'SYS_negative' 'PT_negative']
[-1.5 -1.  -0.5  0.5  1.   1.5  0. ]


In [None]:
df.head()

Unnamed: 0,pos_prob,logit_diff,test_example.positive.text,test_example.negative.text,test_example.idx,multiplier,dataset_name,steering_label,dataset_label,slope,residual
0,6.013557e-07,-18.921875,,,0,-1.5,anti-LGBTQ-rights,baseline,baseline,2.002232,85.450579
49,4.037501e-06,-15.296875,,,1,-1.5,anti-LGBTQ-rights,baseline,baseline,2.890625,125.717529
98,5.496407e-06,-13.390625,,,2,-1.5,anti-LGBTQ-rights,baseline,baseline,3.183036,55.189174
147,7.558653e-07,-18.640625,,,3,-1.5,anti-LGBTQ-rights,baseline,baseline,3.08817,102.475333
196,2.391387e-06,-17.5625,,,4,-1.5,anti-LGBTQ-rights,baseline,baseline,2.689732,108.301583


## Plot: ID vs OOD Steerability

In [None]:
# Calculate steerability within each flavour
mean_slope = df.groupby(['dataset_name', 'steering_label', 'dataset_label'])['slope'].mean()
df = df.merge(mean_slope, on=['dataset_name', 'steering_label', 'dataset_label'], suffixes=('', '_mean'))

In [None]:
print(df.columns)

Index(['pos_prob', 'logit_diff', 'test_example.positive.text',
       'test_example.negative.text', 'test_example.idx', 'multiplier',
       'dataset_name', 'steering_label', 'dataset_label', 'slope', 'residual',
       'slope_mean'],
      dtype='object')


In [None]:
steerability_id_df = df[
    (df.steering_label == 'baseline')
    & (df.dataset_label == 'baseline')
    & (df.multiplier == 0)
][['dataset_name', 'slope_mean']].drop_duplicates()


steerability_ood_df = df[
    (df.steering_label == 'SYS_positive')
    & (df.dataset_label == 'SYS_negative')
    & (df.multiplier == 0)
][['dataset_name', 'slope_mean']].drop_duplicates()

plot_df = steerability_id_df.merge(steerability_ood_df, on='dataset_name', suffixes=('_id', '_ood'))

In [None]:
plot_df

Unnamed: 0,dataset_name,slope_mean_id,slope_mean_ood
0,anti-LGBTQ-rights,2.722062,4.054855
1,politically-liberal,1.725389,2.076054
