In [None]:
# Working directory for the sweep, exposed to the rest of the notebook through
# the WORK_DIR environment variable. An already-set WORK_DIR takes precedence,
# so this machine-specific default does not have to be edited on other machines.
import os

os.environ.setdefault("WORK_DIR", "/home/daniel/ml_workspace/repepo/experiments")

In [None]:
from repepo.steering.sweeps.evaluate_concept_interference import iter_config
from repepo.steering.run_sweep import run_sweep, load_sweep_results

# Flip to True to execute the sweep first; either way the results are then
# obtained via load_sweep_results for the same configs.
RUN = False

configs = [*iter_config()]
if RUN:
    run_sweep(configs)
results = load_sweep_results(configs)

In [None]:
import pandas as pd 
from dataclasses import asdict

import pandas as pd
rows = []
for config, result in results:
    row = {}
    row.update(asdict(config))
    row.update(result.metrics)
    # Sample-wise results
    for prediction in result.predictions:
        sample_row = row.copy()
        sample_row.update({"test_positive_example": prediction.positive_output_prob.text})
        sample_row.update({"test_negative_example": prediction.negative_output_prob.text})
        sample_row.update(prediction.metrics)
        rows.append(sample_row)

df = pd.DataFrame(rows)
df = df.drop_duplicates()
print(len(df))
df.head()

In [None]:
# Small debugging slice: anti-immigration steering vectors evaluated on the
# anti-immigration test set, selected with a single combined mask.
_df = df[(df["train_dataset"] == "anti-immigration") & (df["test_dataset"] == "anti-immigration")]

In [None]:
from repepo.utils.stats import bernoulli_js_dist

# Compute the Jensen-Shannon term between each sample's positive-answer
# probability at a given multiplier and the same sample's probability at
# multiplier 0 (the unsteered baseline).
# NOTE(review): the helper is named `bernoulli_js_dist`, so this may be the JS
# *distance* (sqrt of the divergence) rather than the divergence — confirm.
config_fields = list(asdict(configs[0]).keys())
# A group is one steering-vector config evaluated on one test example,
# spanning every multiplier.
fields = config_fields + ["test_positive_example", "test_negative_example"]
fields.remove("multiplier")


def compute_js_div(group: pd.DataFrame) -> pd.Series:
    """Return the JS term between each row's `pos_prob` and the zero-multiplier row's.

    Within a group we have the exact same steering vector and eval example but
    different multipliers, so the multiplier == 0 row serves as the baseline.

    Args:
        group: Rows of one group; must contain `multiplier` and `pos_prob` columns.

    Returns:
        A Series aligned to `group.index`. If the group has no multiplier == 0
        row, every entry is NaN (rather than raising IndexError).
    """
    baseline = group.loc[group["multiplier"] == 0, "pos_prob"]
    if baseline.empty:
        return pd.Series(float("nan"), index=group.index)
    baseline_pos_prob = baseline.iloc[0]
    return group["pos_prob"].map(lambda p: bernoulli_js_dist(baseline_pos_prob, p))


# group_keys=False makes apply return a Series indexed by the original row
# labels, so it aligns with `df` on assignment without any index surgery
# (df's index has gaps after drop_duplicates). dropna=False keeps rows whose
# grouping fields contain None/NaN, which groupby would otherwise drop.
df["js_div"] = (
    df.groupby(fields, group_keys=False, dropna=False)[["multiplier", "pos_prob"]]
    .apply(compute_js_div)
)
df.head()

In [None]:

# Steering efficiency: for each config (every config field except the
# multiplier), fit a straight line to js_div vs multiplier across all test
# samples and multipliers; the fitted slope is the steering efficiency.
import numpy as np

config_fields = list(asdict(configs[0]).keys())
config_fields.remove("multiplier")


def compute_steering_efficiency(row: pd.DataFrame) -> pd.Series:
    """Least-squares fit of `js_div ≈ slope * multiplier + intercept`.

    Args:
        row: Despite the name, the sub-DataFrame for one config group; needs
            `multiplier` and `js_div` columns.

    Returns:
        Series with `steering_efficiency` (the fitted slope) and `residuals`
        (square root of the sum of squared residuals reported by np.polyfit).
        Both are NaN when fewer than two distinct multipliers have a non-NaN
        js_div, because a line is not determined in that case.
    """
    points = row[["multiplier", "js_div"]].dropna()
    if points["multiplier"].nunique() < 2:
        return pd.Series({"steering_efficiency": np.nan, "residuals": np.nan})
    (slope, _intercept), sse, _rank, _sv, _rcond = np.polyfit(
        points["multiplier"], points["js_div"], 1, full=True
    )
    # polyfit (via lstsq) returns an empty residual array when there are no more
    # points than coefficients; with two distinct x values the fit is then exact.
    residual = float(np.sqrt(sse[0])) if sse.size else 0.0
    return pd.Series({"steering_efficiency": slope, "residuals": residual})


# Select only the needed columns so apply does not operate on the grouping
# columns; dropna=False keeps configs whose fields contain None/NaN, which an
# inner merge would otherwise silently drop from df.
grouped = df.groupby(config_fields, dropna=False)
steering_efficiency_df = grouped[["multiplier", "js_div"]].apply(compute_steering_efficiency)

# Merge back into the per-sample frame. Columns from an earlier run of this
# cell are dropped first, so re-running it does not create
# steering_efficiency_x / _y duplicates.
df = df.drop(columns=["steering_efficiency", "residuals"], errors="ignore").merge(
    steering_efficiency_df, left_on=config_fields, right_index=True
)
print(len(df))
df.head()

In [None]:
# Mean js_div per full config (multiplier included), averaged over test samples.
import matplotlib.pyplot as plt
import seaborn as sns

config_fields = list(asdict(configs[0]).keys())
# dropna=False keeps configs whose fields contain None/NaN (dropped by default).
grouped = df.groupby(config_fields, dropna=False)
mean_js_div = grouped["js_div"].mean().reset_index()
print(len(mean_js_div))
# Only a cell's last expression is rendered automatically, so show this
# explicitly (`display` is provided by the IPython kernel).
display(mean_js_div.head())

# js_div vs multiplier for one concept steered and evaluated on itself.
# seaborn aggregates the per-sample rows at each multiplier (by default: mean
# with a confidence band), one line per aggregator.
PLOT_DATASET = "power-seeking-inclination"
temp_df = df[(df["train_dataset"] == PLOT_DATASET) & (df["test_dataset"] == PLOT_DATASET)]
print(len(temp_df))

fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(data=temp_df, x="multiplier", y="js_div", hue="aggregator", ax=ax)
ax.set_title(f"js_div vs multiplier (train = test = {PLOT_DATASET})");


In [None]:

# NxN heatmap of steering efficiency for the "mean" aggregator: rows are the
# dataset the steering vector was trained on, columns the dataset it was tested on.
import seaborn as sns
import matplotlib.pyplot as plt

aggregator = "mean"

temp_df = df[["train_dataset", "test_dataset", "layer", "aggregator", "steering_efficiency"]].drop_duplicates()
print(len(temp_df))
temp_df = temp_df[temp_df["aggregator"] == aggregator]
# An empty selection would otherwise surface as a confusing heatmap error.
assert not temp_df.empty, f"no steering-efficiency rows for aggregator {aggregator!r}"

sns.set_theme(style="whitegrid")
fig, ax = plt.subplots()
# pivot raises ValueError if a (train_dataset, test_dataset) pair occurs more
# than once (e.g. if several layers are present) instead of silently averaging.
plot_df = temp_df.pivot(index="train_dataset", columns="test_dataset", values="steering_efficiency")
sns.heatmap(plot_df, annot=True, cmap="YlGnBu", ax=ax)
# Name the aggregator so this figure is distinguishable from the other heatmaps.
ax.set_title(f"Steering efficiency when transferring between different concepts ({aggregator} aggregator)");

In [None]:

# NxN heatmap of steering efficiency for the "logistic" aggregator: rows are the
# dataset the steering vector was trained on, columns the dataset it was tested on.
import seaborn as sns
import matplotlib.pyplot as plt

aggregator = "logistic"

temp_df = df[["train_dataset", "test_dataset", "layer", "aggregator", "steering_efficiency"]].drop_duplicates()
print(len(temp_df))
temp_df = temp_df[temp_df["aggregator"] == aggregator]
# An empty selection would otherwise surface as a confusing heatmap error.
assert not temp_df.empty, f"no steering-efficiency rows for aggregator {aggregator!r}"

sns.set_theme(style="whitegrid")
fig, ax = plt.subplots()
# pivot raises ValueError if a (train_dataset, test_dataset) pair occurs more
# than once (e.g. if several layers are present) instead of silently averaging.
plot_df = temp_df.pivot(index="train_dataset", columns="test_dataset", values="steering_efficiency")
sns.heatmap(plot_df, annot=True, cmap="YlGnBu", ax=ax)
# Name the aggregator so this figure is distinguishable from the other heatmaps.
ax.set_title(f"Steering efficiency when transferring between different concepts ({aggregator} aggregator)");