We're interested in:
- Baseline steering performance for token concepts

In [None]:
# NOTE(review): machine-specific absolute path; point WORK_DIR at your own experiments directory before running.
%env WORK_DIR = /home/daniel/ml_workspace/repepo/experiments

In [None]:
from repepo.steering.run_experiment import run_experiment

In [None]:
import itertools
from repepo.steering.utils.helpers import SteeringConfig, EmptyTorchCUDACache
from repepo.steering.sweeps.constants import ALL_TOKEN_CONCEPT_DATASETS, ALL_LLAMA_7B_LAYERS, ALL_MULTIPLIERS
from repepo.steering.sweeps.configs import get_token_concept_config

# Sweep axes, taken from the constants imported above.
datasets = ALL_TOKEN_CONCEPT_DATASETS
layers = ALL_LLAMA_7B_LAYERS
multipliers = ALL_MULTIPLIERS


def iter_config():
    """Yield one config per (dataset, layer, multiplier) combination.

    Combinations come from itertools.product over the module-level
    `datasets`, `layers` and `multipliers`, so the multiplier varies fastest
    and the dataset slowest. Each combination is turned into a config by
    get_token_concept_config.
    """
    for sweep_point in itertools.product(datasets, layers, multipliers):
        yield get_token_concept_config(*sweep_point)

In [None]:
# RUN = True recomputes every experiment; RUN = False loads results by eval hash instead.
RUN = False

# List of (config, result) pairs, in the order produced by iter_config().
results = []
if RUN:
    for config in iter_config():
        # NOTE(review): presumably EmptyTorchCUDACache frees CUDA memory around
        # each run; confirm in repepo.steering.utils.helpers.
        with EmptyTorchCUDACache():
            result = run_experiment(config, force_rerun=True, logging_level="INFO")
            results.append((config, result))
else:
    from repepo.steering.utils.helpers import load_eval_result

    for config in iter_config():
        result = load_eval_result(config.eval_hash)
        results.append((config, result))

Questions to answer: 
- What's the steerability of individual examples? 
- What's the steering efficiency of SVs extracted from individual examples?

In [None]:
# Flatten every (config, result) pair into one record and collect them in a dataframe.
import pandas as pd
from dataclasses import asdict

rows = []
for config, result in results:
    # Only the first prediction is used for the example texts.
    first_prediction = result.predictions[0]
    row = {
        **asdict(config),
        "test_positive_example": first_prediction.positive_output_prob.text,
        "test_negative_example": first_prediction.negative_output_prob.text,
        "mean_logit_diff": result.metrics['mean_logit_diff'],
    }
    rows.append(row)

df = pd.DataFrame(rows)
print(len(df))
df.head()

In [None]:

# Bucket the rows so that each (train_dataset, layer) pair forms one group.
grouped = df.groupby(by=['train_dataset', 'layer'])
# numpy provides the linear fit of (mean logit diff) against (multiplier).
import numpy as np

def compute_steering_efficiency(row):
    """Fit mean_logit_diff as a linear function of multiplier for one group.

    Args:
        row: Despite the name, this is the sub-frame of one group as handed
            over by ``groupby(...).apply``. It must expose ``multiplier`` and
            ``mean_logit_diff`` attributes of equal length.

    Returns:
        pd.Series with two entries:
        - ``steering_efficiency``: slope of the least-squares line.
        - ``residuals``: square root of the summed squared residuals of the fit.
    """
    x = np.asarray(row.multiplier, dtype=float)
    y = np.asarray(row.mean_logit_diff, dtype=float)
    coeffs, res, _rank, _sv, _rcond = np.polyfit(x, y, 1, full=True)
    slope = coeffs[0]
    # polyfit returns an EMPTY residual array when the fit is exactly determined
    # (two or fewer points) or rank-deficient; calling .item() on it raises.
    # In that case compute the residual sum of squares directly from the fit.
    if np.size(res) > 0:
        rss = float(res[0])
    else:
        rss = float(np.sum((np.polyval(coeffs, x) - y) ** 2))
    return pd.Series({'steering_efficiency': slope, 'residuals': float(np.sqrt(rss))})

# Hand only the two fitted columns to the function. This keeps the grouping
# keys out of the applied frame (recent pandas deprecates apply() operating on
# them) and ignores fit columns left on df by an earlier run of this cell.
steering_efficiency_df = grouped[['multiplier', 'mean_logit_diff']].apply(compute_steering_efficiency)
# Merge the per-(train_dataset, layer) fit back onto every row. Fit columns
# already present on df are removed first so that re-running the cell does not
# produce suffixed duplicates (steering_efficiency_x / steering_efficiency_y).
fit_columns = set(steering_efficiency_df.columns)
df = df[[col for col in df.columns if col not in fit_columns]].merge(
    steering_efficiency_df, left_on=['train_dataset', 'layer'], right_index=True
)
df.head()

In [None]:
# Drop every row whose positive or negative test example text contains a forward slash.
for example_col in ('test_positive_example', 'test_negative_example'):
    df = df[~df[example_col].str.contains('/')]

In [None]:
# Show one row per distinct (dataset, positive example, negative example)
# combination, ordered by dataset name, with untruncated cell text.
pd.set_option('display.max_colwidth', None)
data_df = (
    df[['train_dataset', 'test_positive_example', 'test_negative_example']]
    .drop_duplicates()
    .sort_values(by='train_dataset')
)

# Render as an HTML table so the dataframe index is not shown.
from IPython.display import HTML
HTML(data_df.to_html(index=False))


In [None]:
# Drop rows whose positive or negative test example contains a forward slash.
# (Repeats the earlier filter; it is idempotent, so running it again is harmless.)
df = df[~df.test_positive_example.str.contains('/')]
df = df[~df.test_negative_example.str.contains('/')]
# Drop the 'country-capital-with-prompt' and 'E01 [country - capital]_with_prompt' datasets.
df = df[~df.train_dataset.str.contains('country-capital-with-prompt')]
# str.contains treats the pattern as a regex, so the brackets must be escaped.
# A raw string is used because '\[' in a normal string literal is an invalid
# escape sequence (a warning today, an error in future Python versions).
df = df[~df.train_dataset.str.contains(r'E01 \[country - capital\]_with_prompt')]


In [None]:
# Reduce to the three columns needed for plotting and collapse repeated rows.
data_df = df[['train_dataset', 'layer', 'steering_efficiency']].drop_duplicates()
print(len(data_df))

# One line per dataset: steering efficiency as a function of layer.
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme(style="whitegrid")
# Alphabetical dataset order for the hue assignment and legend.
hue_order = sorted(df['train_dataset'].unique())
plt.figure(figsize=(10, 6))
sns.lineplot(
    data=data_df,
    x='layer',
    y='steering_efficiency',
    hue='train_dataset',
    hue_order=hue_order,
)
# Place the legend outside the axes, to the right of the plot.
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)


In [None]:
# The plot above is a bit cluttered, so filter down to (at most) 10 randomly selected datasets.
import random

random.seed(0)  # fixed seed so the same subset is drawn on every run
dataset_names = df['train_dataset'].unique().tolist()
# random.sample raises ValueError when asked for more items than the population
# holds, which happens once the filters above leave fewer than 10 datasets.
selected_datasets = random.sample(dataset_names, min(10, len(dataset_names)))
filtered_df = data_df[data_df['train_dataset'].isin(selected_datasets)]
# Sort by train_dataset
filtered_df = filtered_df.sort_values(by='train_dataset')
plt.figure(figsize=(10, 6))
sns.lineplot(data=filtered_df, x='layer', y='steering_efficiency', hue='train_dataset')
# Place the legend outside the axes, to the right of the plot.
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)


In [None]:
# Plot, for each dataset, the maximum steering efficiency reached over all layers.
data_df = df[['train_dataset', 'layer', 'steering_efficiency']]
data_df = data_df.drop_duplicates()
# Group by dataset ONLY. Grouping by (train_dataset, layer) leaves a single row
# per group, so no maximum over layers is taken and the bar plot would instead
# aggregate the per-layer values itself (a mean with error bars).
max_steering_efficiency_df = data_df.groupby('train_dataset', as_index=False)['steering_efficiency'].max()
plt.figure(figsize=(10, 6))
sns.barplot(data=max_steering_efficiency_df, x='train_dataset', y='steering_efficiency')
plt.xticks(rotation=45)
plt.title('Max steering efficiency across all layers for each dataset')

In [None]:
# For each dataset, plot the layer index with highest steering efficiency.
# This is useful to see if there is a clear "winner" layer for each dataset.
# idxmax gives, per dataset, the row label of the highest steering efficiency;
# .loc then pulls those rows. Unlike groupby(...).apply(lambda ...), this keeps
# 'train_dataset' as a plain column (not also the index name) and does not rely
# on apply() operating on the grouping column, which recent pandas deprecates.
best_row_labels = data_df.groupby('train_dataset')['steering_efficiency'].idxmax()
best_layer_df = data_df.loc[best_row_labels]

# Barplot
plt.figure(figsize=(10, 6))
sns.barplot(data=best_layer_df, x='train_dataset', y='layer', palette='viridis')
plt.xticks(rotation=45)
plt.ylabel('Layer index with highest steering efficiency')
plt.title('Layer index with highest steering efficiency for each dataset')



In [None]:
# Switch to the plain "white" theme so the scatter plot has no grid lines.
sns.set_theme(style="white")

# One marker per dataset: x is its best layer, y is the steering efficiency reached there.
plt.figure(figsize=(10, 6))
sns.scatterplot(
    data=best_layer_df,
    x='layer',
    y='steering_efficiency',
    hue='train_dataset',
    s=400,
    alpha=0.5,
)
plt.title('Layer index with highest steering efficiency for each dataset')
plt.xlabel('Layer index')
plt.ylabel('Steering efficiency')
# Place the legend outside the axes, to the right of the plot.
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)