Commit: Update experiments.
mishajw committed Jan 26, 2024
1 parent 9b7d117 commit bce6f88
Showing 3 changed files with 305 additions and 76 deletions.
97 changes: 39 additions & 58 deletions experiments/comparison.py
@@ -59,13 +59,13 @@ class ProbeTrainSpec:

# %%
llm_ids: list[LlmId] = [
"pythia-70m",
"pythia-160m",
"pythia-410m",
# "pythia-70m",
# "pythia-160m",
# "pythia-410m",
"pythia-1b",
"pythia-1.4b",
"pythia-2.8b",
"pythia-6.9b",
# "pythia-1.4b",
# "pythia-2.8b",
# "pythia-6.9b",
]
llm_points = {llm_id: get_points(llm_id) for llm_id in llm_ids}
point_ids_by_llm = {
@@ -78,31 +78,21 @@ class ProbeTrainSpec:
for llm_id in llm_ids
}
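# Datasets and dataset collections used to train probes; the datasets used
# for evaluation are listed separately below.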
dataset_collection_ids: list[DatasetId | DatasetCollectionId] = [
# "all",
# "representation-engineering",
# "geometry-of-truth",
"geometry-of-truth-cities-with-neg",
"geometry-of-truth-cities-with-neg",
"arc_challenge",
"arc_easy",
"representation-engineering",
"geometry_of_truth-cities",
"geometry_of_truth-sp_en_trans",
"geometry_of_truth-neg_sp_en_trans",
"geometry_of_truth-larger_than",
"geometry_of_truth-smaller_than",
"geometry_of_truth-cities_cities_conj",
"geometry_of_truth-cities_cities_disj",
"common_sense_qa",
"geometry_of_truth-neg_cities",
"geometry-of-truth-cities-with-neg",
"open_book_qa",
"common_sense_qa",
"race",
"truthful_qa",
"truthful_model_written",
"true_false",
"arc_challenge",
"arc_easy",
]
probe_ids: list[ProbeId] = [
"lat",
"mmp",
"mmp-iid",
"lr",
]
probe_train_specs = mcontext.create(
{
@@ -171,9 +161,18 @@ class ProbeEvalSpec:
dataset_id: DatasetId


evaluation_dataset_ids = sorted(
set(row.dataset_id for row in activations_dataset if row.split == "validation")
)
# evaluation_dataset_ids = sorted(
# set(row.dataset_id for row in activations_dataset if row.split == "validation")
# )
evaluation_dataset_ids: list[DatasetId] = [
"geometry_of_truth-cities",
"geometry_of_truth-neg_cities",
"open_book_qa",
"common_sense_qa",
"race",
"arc_challenge",
"arc_easy",
]

probe_eval_specs = probes.join(
probe_train_specs,
@@ -265,7 +264,7 @@ class ProbeEvalSpec:
# %% plot probe performance by model size
df_subset = df.copy()
df_subset = df_subset[df_subset["point_id"] == "p90"]
df_subset = df_subset[df_subset["probe_id"] == "lat"]
df_subset = df_subset[df_subset["probe_id"] == "mmp"]
df_subset = df_subset.drop(columns=["point_id", "probe_id"])
g = (
sns.FacetGrid(
@@ -296,34 +295,16 @@ class ProbeEvalSpec:
legend=False,
)

# %% bin
# # %%
# df_subset = df.copy()
# df_subset = df_subset[df_subset["probe_id"] == "lat"]
# df_subset = df_subset[df_subset["dataset_collection_id"] == "all"]
# sns.lineplot(data=df_subset, x="point_id", y="f1_score", hue="llm_id", errorbar=None)
# plt.xticks(rotation=90)
# plt.show()

# # %%
# df_subset = df.copy()
# df_subset = df_subset[df_subset["llm_id"] == "pythia-6.9b"]
# df_subset = df_subset[df_subset["point_id"] == "h21"]
# df_subset = df_subset[df_subset["probe_id"] == "mmp"]
# # df_subset = df_subset[df_subset["dataset_collection_id"] == "geometry-of-truth"]
# sns.barplot(
# data=df_subset, x="eval_dataset_id", y="f1_score", hue="dataset_collection_id"
# )
# plt.xticks(rotation=90)
# plt.show()

# # %%
# df_subset = df.copy()
# df_subset = df_subset[df_subset["eval_dataset_id"] == "geometry_of_truth-cities"]
# # df_subset = df_subset[df_subset["point_id"] == "h21"]
# df_subset = df_subset[df_subset["dataset_collection_id"] == "all"]
# df_subset = df_subset[df_subset["probe_id"] == "mmp"]
# # df_subset = df_subset[df_subset["dataset_collection_id"] == "geometry-of-truth"]
# sns.lineplot(data=df_subset, x="point_id", y="f1_score", hue="llm_id")
# plt.xticks(rotation=90)
# plt.show()
# %% generalization matrix
df_subset = df.copy()
df_subset = df_subset[df_subset["llm_id"] == "pythia-1b"]
df_subset = df_subset[df_subset["point_id"] == "p90"]
df_subset = df_subset[df_subset["probe_id"] == "lr"]
df_subset = df_subset.pivot(
index="dataset_collection_id",
columns="eval_dataset_id",
values="roc_auc_score",
)
df_subset = df_subset.sort_index()
sns.heatmap(df_subset, annot=True, fmt=".2f", cmap="Blues")
108 changes: 108 additions & 0 deletions experiments/fake_probes.py
@@ -0,0 +1,108 @@
# %%
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from jaxtyping import Float

from repeng.activations.probe_preparations import (
Activation,
prepare_activations_for_probes,
)
from repeng.probes.contrast_consistent_search import CcsTrainingConfig, train_ccs_probe
from repeng.probes.linear_artificial_tomography import (
LatTrainingConfig,
train_lat_probe,
)
from repeng.probes.logistic_regression import train_lr_probe
from repeng.probes.mean_mass_probe import train_mmp_probe

# %%
anisotropy_offset = np.array([0, 0], dtype=np.float32)
dataset_direction = np.array([0, 0], dtype=np.float32)
dataset_cov = np.array([[1, 0], [0, 0.1]])
truth_direction = np.array([0, 2])
truth_cov = np.array([[0.01, 0], [0, 0.01]])
num_samples = int(1e3)

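# Synthetic 2-D activations: "false" points are drawn from an anisotropic
# dataset distribution, and each "true" point is its paired "false" point
# shifted by truth_direction plus a little noise.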
random_false = np.random.multivariate_normal(
mean=anisotropy_offset + dataset_direction, cov=dataset_cov, size=num_samples
)
random_true = random_false + np.random.multivariate_normal(
mean=truth_direction, cov=truth_cov, size=num_samples
)

df_1 = pd.DataFrame(random_true, columns=["x", "y"])
df_1["label"] = "true"
df_1["pair_id"] = np.array(range(num_samples))
df_2 = pd.DataFrame(random_false, columns=["x", "y"])
df_2["label"] = "false"
df_2["pair_id"] = np.array(range(num_samples))
df = pd.concat([df_1, df_2])
df["activations"] = df.apply(lambda row: np.array([row["x"], row["y"]]), axis=1)

# %%
activations = prepare_activations_for_probes(
[
Activation(
dataset_id="test",
pair_id=row["pair_id"],
activations=row["activations"],
label=row["label"] == "true",
)
for _, row in df.iterrows()
]
)
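# Train each probe family on the same synthetic activations. Roughly (based
# on the literature these probes come from, not verified against repeng's
# implementations): LAT fits a direction via PCA over differences of random
# activation pairs, LR is plain logistic regression on labels, MMP takes the
# difference of the class means, and CCS learns a direction that assigns
# paired statements consistent, confident probabilities.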
lat_probe = train_lat_probe(
activations.activations, LatTrainingConfig(num_random_pairs=1000)
)
lr_probe = train_lr_probe(activations.labeled)
mmp_probe = train_mmp_probe(activations.labeled, use_iid=False)
ccs_probe = train_ccs_probe(activations.paired, CcsTrainingConfig(num_steps=1000))

# %%
fig_range = 5


def plot_probe(
label: str,
fig: go.Figure,
probe: Float[np.ndarray, "2"],
intercept: float,
) -> None:
print(probe, intercept)
xs = np.array([-fig_range, 0, fig_range])
ys = -(probe[1] / probe[0]) * xs - (intercept / probe[0])
# x and y are swapped below because the line equation above is solved for
# the first coordinate in terms of the second, so `ys` are x-coordinates
# and `xs` are y-coordinates.
fig.add_trace(go.Scatter(x=ys, y=xs, mode="lines", name=label))
# fig.add_annotation(
# x=xs[1] + probe[0],
# y=ys[1] + probe[1],
# ax=xs[1],
# ay=ys[1],
# xref="x",
# yref="y",
# axref="x",
# ayref="y",
# showarrow=True,
# arrowhead=1,
# arrowwidth=2,
# )


fig = px.scatter(df, "x", "y", color="label", opacity=0.3)
fig.update_layout(
xaxis_range=[-fig_range, fig_range],
yaxis_range=[-fig_range, fig_range],
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
plot_probe("lat", fig, lat_probe.probe, 0)
plot_probe("lr", fig, lr_probe.model.coef_[0], lr_probe.model.intercept_[0])
plot_probe("mmp", fig, mmp_probe.probe, 0)
plot_probe(
"ccs",
fig,
ccs_probe.linear.weight.detach().numpy()[0],
ccs_probe.linear.bias.detach().numpy()[0],
)
fig.show()
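
As a quick sanity check on this synthetic setup, the mean-mass direction can
be estimated by hand: the two classes differ in mean by truth_direction, so
the difference of class means should come out near (0, 2). A minimal sketch
under that assumption, using plain numpy rather than repeng's trainer:

# Hypothetical sanity check (a sketch, not part of the commit above): the
# MMP direction should roughly match the difference of the class means.
true_mean = random_true.mean(axis=0)
false_mean = random_false.mean(axis=0)
print(true_mean - false_mean)  # expected to be close to truth_direction [0, 2]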
