In [5]:
import sys

sys.path.append("/workspace/circuit-finder")

import torch

from functools import partial
from circuit_finder.data_loader import load_datasets_from_json
from circuit_finder.pretrained import load_model
from circuit_finder.experiments.run_dataset_sweep import ALL_DATASETS
from circuit_finder.metrics import batch_avg_answer_diff
from circuit_finder.constants import ProjectDir

# Disable autograd globally: the cells in this notebook only run forward passes
# to read out metrics, so no gradients are needed.
torch.set_grad_enabled(False)

# Shared model instance used by the cells below (the cell output reports gpt2
# being loaded into a HookedTransformer).
model = load_model()


def logit_diff(model, tokens, batch):
    """Score `tokens` with the answer-difference metric at the final position.

    Runs `model` on `tokens`, keeps only the last entry along the second axis
    of the output, and hands those logits together with `batch` to
    `batch_avg_answer_diff`.

    Args:
        model: Callable applied to `tokens`; its output is indexed with
            `[:, -1, :]`.
        tokens: Input passed to `model` unchanged.
        batch: Passed through to `batch_avg_answer_diff` unchanged.
            NOTE(review): presumably carries the answer token ids for each
            prompt -- confirm against `circuit_finder.metrics`.

    Returns:
        The `.mean()` of whatever `batch_avg_answer_diff` returns.
    """
    # Only the final sequence position is scored.
    final_position_logits = model(tokens)[:, -1, :]
    answer_diff = batch_avg_answer_diff(final_position_logits, batch)
    return answer_diff.mean()



Loaded pretrained model gpt2 into HookedTransformer


In [12]:
# Sanity check over every dataset: take the first batch from the first loader
# returned for that dataset, evaluate `logit_diff` on its clean and corrupt
# tokens, and print both values next to the token shapes.

# Built once instead of on every iteration. Falls back to CPU when CUDA is not
# available rather than requesting a device that does not exist; on a CUDA
# machine this is the same device the loop used before.
# NOTE(review): assumes `model` lives on this same device -- confirm in load_model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for dataset_path in ALL_DATASETS:
    # The second returned loader is not needed for this check.
    train_loader, _ = load_datasets_from_json(
        model,
        ProjectDir / dataset_path,
        device,
        batch_size=1,
    )
    # One batch per dataset is enough for a quick look at the metric.
    batch = next(iter(train_loader))

    # Bind the batch so the same metric can be applied to both token sets.
    metric_fn = partial(logit_diff, batch=batch)
    clean_metric = metric_fn(model, batch.clean)
    corrupt_metric = metric_fn(model, batch.corrupt)

    print(f"Dataset: {dataset_path}")
    print(f"Clean tokens: {batch.clean.shape}")
    print(f"Corrupt tokens: {batch.corrupt.shape}")
    print(f"Clean metric: {clean_metric}")
    print(f"Corrupt metric: {corrupt_metric}")
    print()

Dataset: datasets/gender_bias.json
Clean tokens: torch.Size([1, 6])
Corrupt tokens: torch.Size([1, 6])
Clean metric: 1.8746223449707031
Corrupt metric: -3.3825511932373047

Dataset: datasets/greaterthan_gpt2-small_prompts.json
Clean tokens: torch.Size([1, 11])
Corrupt tokens: torch.Size([1, 11])
Clean metric: 3.544696807861328
Corrupt metric: -1.5264644622802734

Dataset: datasets/subject_verb_agreement.json
Clean tokens: torch.Size([1, 8])
Corrupt tokens: torch.Size([1, 8])
Clean metric: 2.7700538635253906
Corrupt metric: -3.5251340866088867

Dataset: datasets/ioi/ioi_ABBA_template_0_prompts.json
Clean tokens: torch.Size([1, 16])
Corrupt tokens: torch.Size([1, 16])
Clean metric: 4.267217636108398
Corrupt metric: 1.609701156616211

Dataset: datasets/ioi/ioi_ABBA_template_1_prompts.json
Clean tokens: torch.Size([1, 20])
Corrupt tokens: torch.Size([1, 20])
Clean metric: 4.107460021972656
Corrupt metric: 1.4901819229125977

Dataset: datasets/ioi/ioi_BABA_template_0_prompts.json
Clean toke