# Evaluating Steerability of Different Concepts

In this notebook, we evaluate steerability across several concepts of interest. We consider two types of concepts:

- (i) Abstract concepts, e.g. 'truthfulness', 'power-seeking-inclination'
- (ii) Linguistic concepts, e.g. 'verb->verb-able', 'adjective->adjective-ly'

In [None]:
from repepo.steering.sweeps.configs import (
    get_abstract_concept_config,
    get_token_concept_config
)

from repepo.steering.run_sweep import (
    run_sweep, 
    load_sweep_results
)

from repepo.steering.plots.utils import (
    get_config_fields,
    make_results_df
)

In [None]:
# Define the sweep to run over. 
from repepo.notebooks.run_sweep_evaluate_steerability import iter_config

In [None]:
# Toggle for running the sweep before its results are loaded.
# Leave RUN = False when the sweep has already been run.
RUN = False

# Materialise the configs once so the same list is used for running and loading.
configs = [config for config in iter_config()]
if RUN:
    run_sweep(configs)

results = load_sweep_results(configs)

In [None]:
# Build a DataFrame from the sweep results, report its row count,
# and preview the first rows.
df = make_results_df(results)
print(df.shape[0])
df.head()

# 1. Analysis

## 1.1 Plot changes in log probs, logits

In [None]:
# Plot how the positive and negative token log-probs change with the multiplier for one example.

import seaborn as sns 
import matplotlib.pyplot as plt
sns.set_theme()

def plot(df):
    """Plot positive/negative token log-probs against the multiplier for one example.

    The example is taken from the first row of ``df``; every row sharing its
    ``test_positive_example.text`` value is plotted. If several of those rows
    share a multiplier, seaborn's ``lineplot`` aggregates them (by default it
    draws the mean with a confidence band).

    Args:
        df: Results frame with the columns ``test_positive_example.text``,
            ``multiplier``, ``test_positive_token.logprob`` and
            ``test_negative_token.logprob``. Must contain at least one row.

    Returns:
        The matplotlib Axes the two curves were drawn on.
    """
    example_text = df.iloc[0]["test_positive_example.text"]
    # Use a separate name instead of shadowing the ``df`` argument.
    example_df = df[df["test_positive_example.text"] == example_text]

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    # Plot positive token logprob, negative token logprob.
    sns.lineplot(data=example_df, x="multiplier", y="test_positive_token.logprob", label="Positive logprob", ax=ax)
    sns.lineplot(data=example_df, x="multiplier", y="test_negative_token.logprob", label="Negative logprob", ax=ax)
    # seaborn would otherwise label the y-axis with a single column name even
    # though two columns are drawn; set labels that describe both curves.
    ax.set(
        xlabel="Multiplier",
        ylabel="Log probability",
        title="Token log-probs vs. multiplier (single example)",
    )
    return ax

# The trailing semicolon keeps the returned Axes repr out of the cell output.
plot(df);

In [None]:
# Plot the change in positive token logit and negative token logit for one example. 

import seaborn as sns 
import matplotlib.pyplot as plt

def plot(df, show_logit_mean=False):
    """Plot positive/negative token logits against the multiplier for one example.

    The example is taken from the first row of ``df``; every row sharing its
    ``test_positive_example.text`` value is plotted. If several of those rows
    share a multiplier, seaborn's ``lineplot`` aggregates them (by default it
    draws the mean with a confidence band).

    Args:
        df: Results frame with the columns ``test_positive_example.text``,
            ``multiplier``, ``test_positive_token.logit`` and
            ``test_negative_token.logit``. Must contain at least one row.
        show_logit_mean: When True, additionally draw the
            ``test_positive_token.logit_mean`` and
            ``test_negative_token.logit_mean`` columns. Off by default, which
            reproduces the original two-curve figure.
            NOTE(review): the presence of the ``*.logit_mean`` columns is not
            confirmed by this notebook -- verify before enabling.

    Returns:
        The matplotlib Axes the curves were drawn on.
    """
    example_text = df.iloc[0]["test_positive_example.text"]
    # Use a separate name instead of shadowing the ``df`` argument.
    example_df = df[df["test_positive_example.text"] == example_text]

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    # Plot positive token logit, negative token logit.
    sns.lineplot(data=example_df, x="multiplier", y="test_positive_token.logit", label="Positive logit", ax=ax)
    sns.lineplot(data=example_df, x="multiplier", y="test_negative_token.logit", label="Negative logit", ax=ax)
    if show_logit_mean:
        # Optional overlay that used to live here as commented-out code.
        sns.lineplot(data=example_df, x="multiplier", y="test_positive_token.logit_mean", label="Positive logit mean", ax=ax)
        sns.lineplot(data=example_df, x="multiplier", y="test_negative_token.logit_mean", label="Negative logit mean", ax=ax)
    # seaborn would otherwise label the y-axis with a single column name even
    # though several columns are drawn; set labels that describe all curves.
    ax.set(
        xlabel="Multiplier",
        ylabel="Logit",
        title="Token logits vs. multiplier (single example)",
    )
    return ax

# The trailing semicolon keeps the returned Axes repr out of the cell output.
plot(df);

## 1.2 Compute steerability

In [None]:
import pandas as pd
import numpy as np

def calculate_steering_efficiency(
    df: pd.DataFrame, 
    base_metric_name: str = "logit_diff"
):
    df = df.copy()
    # Group by examples
    fields_to_group_by = get_config_fields()
    fields_to_group_by.remove("multiplier")
    fields_to_group_by += ["test_positive_example.text"]

    grouped = df.groupby(fields_to_group_by)

    def fit_linear_regression(df: pd.DataFrame):
        # Fit a linear regression of the base metric on the multiplier
        # Return the slope and error of the fit 
        x = df["multiplier"].to_numpy()
        y = df[base_metric_name].to_numpy()        
        (slope, intercept), residuals, _, _, _ = np.polyfit(x, y, 1, full=True)
        # Return a dataframe with the slope and residuals
        return pd.DataFrame({
            "slope": [slope],
            "residual": [residuals.item()]
        })

    # Apply a linear-fit to each group using grouped.apply
    slopes = grouped.apply(fit_linear_regression, include_groups = False)
    df = df.merge(slopes, on=fields_to_group_by, how='left')
    return df 

# Attach the per-example fit columns (slope, residual) to the results frame.
df = calculate_steering_efficiency(df)
print(df.shape[0])

# Scatter every row by its fitted slope (x) against the fit residual (y).
fig, ax = plt.subplots(figsize=(8, 8))
sns.scatterplot(x="slope", y="residual", data=df, ax=ax)

In [None]:
# Violin plot of steerability (fitted slope), grouped by concept

def plot_steering_efficiency(df):
    """Draw one violin of the ``slope`` distribution per ``train_dataset``.

    Datasets are ordered by their median slope, largest first.

    Args:
        df: Frame with the columns ``train_dataset`` and ``slope``. It is only
            read, never modified.

    Returns:
        The matplotlib Axes holding the violin plot.
    """
    # Sort the datasets by the median slope within train_dataset (descending).
    order = df.groupby("train_dataset")["slope"].median().sort_values().iloc[::-1].index

    fig, ax = plt.subplots(figsize=(8, 8))
    # Pass ax explicitly so the plot is drawn on the figure created above
    # rather than on whatever pyplot currently considers the active axes.
    sns.violinplot(data=df, x="slope", y="train_dataset", order=order, ax=ax)
    ax.set(
        xlabel="Slope",
        ylabel="Train dataset",
        title="Distribution of slopes per train dataset",
    )
    return ax

# The trailing semicolon keeps the returned Axes repr out of the cell output.
plot_steering_efficiency(df);

**Remarks**
- The spread of the individual steerabilities is an indicator of how "well-defined" a concept is.
- I'm concerned that the steerability of `sycophancy` is 0, which is inconsistent with what was observed in CAA.