In [1]:
import torch

from cupbearer import tasks, scripts, utils
from cupbearer.tasks.tiny_natural_mechanisms import get_effect_tokens
from cupbearer.detectors.statistical import MahalanobisDetector
from elk_experiments.auto_circuit_detector import AutoCircuitGradScoresDetector, AutoCircuitPruningDetector
from elk_experiments.tiny_natural_mechanisms_utils import get_task_subset
from elk_experiments.utils import repo_path_to_abs_path
from auto_circuit.types import AblationType

In [None]:
class AutoCircuitGradScoresMahalanobis(AutoCircuitGradScoresDetector, MahalanobisDetector):
    """Detector that inherits from both AutoCircuitGradScoresDetector and MahalanobisDetector.

    Only ``train`` is overridden here; everything else comes from the two
    parent classes via the normal method resolution order.
    """

    def train(self, **kwargs):
        """Run the inherited training, then finalize means and covariances.

        Args:
            **kwargs: Forwarded unchanged to ``super().train`` and afterwards
                to ``self.post_covariance_training``.

        Raises:
            RuntimeError: If any resulting covariance matrix contains no
                non-zero entries.
        """
        super().train(**kwargs)

        # NOTE(review): `_means`, `_Cs` and `_ns` are presumably populated by the
        # inherited `train` call above -- confirm against the parent classes.
        with torch.inference_mode():
            self.means = self._means

            # Scale each accumulated matrix by (n - 1) for its key.
            scaled = {}
            for name, accumulated in self._Cs.items():
                scaled[name] = accumulated / (self._ns[name] - 1)
            self.covariances = scaled

            # Fail loudly on the first matrix that is entirely zero.
            for cov in self.covariances.values():
                if torch.count_nonzero(cov) == 0:
                    raise RuntimeError("All zero covariance matrix detected.")

            self.post_covariance_training(**kwargs)

In [None]:
# Experiment configuration read by the cells that follow.
model_name = "pythia-70m"  # model identifier handed to the task constructor
device = "cpu"  # device string passed to the task and the detector
task_name = "ifelse"  # which tiny_natural_mechanisms task to load

In [None]:
# Build the task for the configured name/device/model and take a subset of it.
# NOTE(review): 64, 32, 32 are presumably the sizes of the subset's data splits --
# confirm their order against get_task_subset's signature.
task = get_task_subset(
    tasks.tiny_natural_mechanisms(task_name, device, model_name),
    64,
    32,
    32,
)

In [None]:
# Path under the repo's "output" directory, named after the model and task.
# It is not passed to any call in the cells shown here (save_path is None below).
detector_path = (
    repo_path_to_abs_path("output")
    / f"auto-circuit_mahalanobis_{model_name}_{task_name}_detector"
)

# Construct the detector with zero ablation and every resid/mlp
# source/destination flag switched off.
detector = AutoCircuitGradScoresMahalanobis(
    effect_tokens=get_effect_tokens(task_name, task.model),
    ablation_type=AblationType.ZERO,
    resid_src=False,
    resid_dest=False,
    mlp_src=False,
    mlp_dest=False,
    device=device,
)

In [None]:
# Attach the task's model to the detector before training.
detector.set_model(task.model)

# Train on the task's trusted data and untrusted training data.
# save_path is None here -- presumably nothing is written to disk; confirm.
detector.train(
    trusted_data=task.trusted_data,
    untrusted_data=task.untrusted_train_data,
    batch_size=32,
    save_path=None,
)

In [None]:
# Evaluate the trained detector on the task with a progress bar; the call is the
# cell's final expression, so whatever it returns is displayed as the cell output.
scripts.eval_detector(task, detector, save_path=None, pbar=True, batch_size=2)