Skip to content

Commit

Permalink
Start comparison notebook.
Browse files — browse the repository at this point in the history
  • Loading branch information
mishajw committed Jan 12, 2024
1 parent f018610 commit 4ec0ae8
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
45 changes: 45 additions & 0 deletions experiments/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# %%
from pathlib import Path

from mppr import mppr

from repeng import models
from repeng.activations import get_activations
from repeng.datasets.collections import get_all_datasets
from repeng.datasets.types import BinaryRow

# %%
# models.gpt2() is unpacked into the model, its tokenizer and `points`
# (presumably the activation hook points -- confirm in repeng.models).
model, tokenizer, points = models.gpt2()

# %%
# Initialise the input rows, keeping at most 100 entries per dataset, and
# report how many rows were collected in total.
inputs = mppr.init(
    "init-limit-100",
    Path("../output/comparison"),
    init_fn=lambda: get_all_datasets(limit_per_dataset=100),
    to=BinaryRow,
)
print(len(inputs.get()))

# %%
# Stage 1: run get_activations on the text of every input row.
activation_rows = inputs.map(
    "activations",
    fn=lambda _, value: get_activations(model, tokenizer, points, value.text),
    to="pickle",
)

# Stage 2: pair each activation result with its originating row, keeping only
# the activations recorded at the last entry of `points`.
joined_rows = activation_rows.join(
    inputs,
    lambda _, acts, row: {
        "dataset_id": row.dataset_id,
        "is_true": row.is_true,
        "activations": acts.activations[points[-1].name],
    },
)

# Stage 3: each joined value is already a flat dict, so it is passed through
# unchanged when building the dataframe.
df = joined_rows.to_dataframe(lambda d: d)
df
51 changes: 30 additions & 21 deletions repeng/datasets/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,33 @@
from repeng.datasets.types import BinaryRow


def get_all_datasets() -> dict[str, BinaryRow | BinaryRow]:
    """Load every dataset and merge all of their entries into one dict.

    The loaders run in a fixed order. If two datasets produce the same key,
    the entry from the later-loaded dataset replaces the earlier one.

    Returns:
        A single dict containing the entries of all datasets.
    """
    geometry_of_truth_subsets = (
        "cities",
        "neg_cities",
        "sp_en_trans",
        "neg_sp_en_trans",
        "larger_than",
        "smaller_than",
        "cities_cities_conj",
        "cities_cities_disj",
    )
    arc_subsets = ("challenge", "easy")
    remaining_loaders = (
        get_common_sense_qa,
        get_open_book_qa,
        get_race,
        get_truthful_qa,
    )

    merged: dict[str, BinaryRow] = {}

    # Plain binary datasets first ...
    merged.update(get_true_false_dataset())
    for subset in geometry_of_truth_subsets:
        merged.update(get_geometry_of_truth(subset))

    # ... followed by the paired binary datasets.
    for subset in arc_subsets:
        merged.update(get_arc(subset))
    for load_dataset in remaining_loaders:
        merged.update(load_dataset())

    return merged
def get_all_datasets(
    limit_per_dataset: int | None = None,
) -> dict[str, BinaryRow | BinaryRow]:
    """Merge the entries of every dataset into one dict, capped per dataset.

    Each loader is called in turn and at most ``limit_per_dataset`` of its
    entries, taken in the loader's own iteration order, are copied into the
    result. If two datasets produce the same key, the entry from the
    later-loaded dataset replaces the earlier one.

    Args:
        limit_per_dataset: Maximum number of entries to keep from each
            dataset. ``None`` falls back to a limit of 1000.

    Returns:
        A single dict containing the retained entries of all datasets.

    Raises:
        ValueError: If ``limit_per_dataset`` is negative.
    """
    # Imported locally so the function does not depend on edits to the
    # module's import block.
    from itertools import islice

    if limit_per_dataset is None:
        limit_per_dataset = 1000
    if limit_per_dataset < 0:
        # A negative slice bound would silently drop rows from the *end* of
        # each dataset rather than cap it, so reject it explicitly.
        raise ValueError(
            f"limit_per_dataset must be non-negative, got {limit_per_dataset}"
        )

    # Loaders are stored as zero-argument callables so that each dataset is
    # only loaded when its turn comes up in the loop below.
    dataset_fns = [
        get_true_false_dataset,
        lambda: get_geometry_of_truth("cities"),
        lambda: get_geometry_of_truth("neg_cities"),
        lambda: get_geometry_of_truth("sp_en_trans"),
        lambda: get_geometry_of_truth("neg_sp_en_trans"),
        lambda: get_geometry_of_truth("larger_than"),
        lambda: get_geometry_of_truth("smaller_than"),
        lambda: get_geometry_of_truth("cities_cities_conj"),
        lambda: get_geometry_of_truth("cities_cities_disj"),
        lambda: get_arc("challenge"),
        lambda: get_arc("easy"),
        get_common_sense_qa,
        get_open_book_qa,
        get_race,
        get_truthful_qa,
    ]

    result: dict[str, BinaryRow] = {}
    for dataset_fn in dataset_fns:
        # islice takes the first N items without copying the whole dataset
        # into a temporary list first.
        result.update(islice(dataset_fn().items(), limit_per_dataset))
    return result

0 comments on commit 4ec0ae8

Please sign in to comment.