In [46]:
from typing import cast

import torch
from jaxtyping import Float, Int
from torch import Tensor
from torch.utils.data import DataLoader

from muutils.dbg import dbg, dbg_tensor, dbg_auto

from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset
from spd.models.component_model import ComponentModel
from spd.models.components import EmbeddingComponent, GateMLP, LinearComponent, VectorGateMLP
from spd.utils.component_utils import calc_causal_importances
from spd.utils.data_utils import DatasetGeneratedDataLoader
from spd.utils.general_utils import extract_batch_data

# Run on the GPU when torch reports that CUDA is available; otherwise fall back to the CPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"


In [33]:
# Identifier of the run whose trained component model this notebook analyses.
# Kept as a named constant so that switching runs is a one-line change.
WANDB_RUN_PATH: str = "wandb:goodfire/spd/runs/dcjm9g2n"

# `from_pretrained` returns three objects, bound here as the model, its config,
# and a path (presumably the downloaded checkpoint location -- confirm).
component_model, cfg, path = ComponentModel.from_pretrained(WANDB_RUN_PATH)
# The trailing ";" stops Jupyter from echoing the value returned by `.to()`.
component_model.to(DEVICE);
# To inspect what was loaded: dbg_auto(component_model), dbg_auto(cfg), dbg_auto(path).

Downloaded checkpoint from /home/miv/projects/MATS/spd/wandb/5mk5h1lk/files/resid_mlp.pth


In [34]:

# grep_repr((component_model, cfg, path, dir(component_model)), "_features")
# cfg.task_config
# grep_repr(, "_features")

In [43]:
# Number of samples to analyse. It is used below as the dataloader batch size.
N_SAMPLES: int = 1000

# Dataset parameters are read from the loaded model's config and the run's task
# config rather than being hard-coded here. Label-related options are disabled.
# NOTE(review): no RNG seed is set in the visible cells, so the sampled batch
# presumably differs between kernel runs -- consider seeding before this cell.
dataset = ResidualMLPDataset(
    n_features=component_model.model.config.n_features,
    feature_probability=cfg.task_config.feature_probability,
    device=DEVICE,
    calc_labels=False,  # Our labels will be the output of the target model
    label_type=None,
    act_fn_name=None,
    label_fn_seed=None,
    label_coeffs=None,
    data_generation_type=cfg.task_config.data_generation_type,
    # synced_inputs=synced_inputs,
)

# Batches of N_SAMPLES items with shuffling turned off.
dataloader = DatasetGeneratedDataLoader(dataset, batch_size=N_SAMPLES, shuffle=False)


In [51]:

def component_activations(
    model: ComponentModel,
    dataloader: DataLoader[Int[Tensor, "..."]]
    | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]],
    device: str,
) -> dict[str, Float[Tensor, " C n_steps"]]:
    """Compute per-module causal importances on the first batch of ``dataloader``.

    Only a single batch is consumed (``next(iter(dataloader))``); the remainder
    of the loader is never read. Gradient tracking is not disabled and
    ``detach_inputs=False`` is passed to ``calc_causal_importances``.

    Args:
        model: Component model supplying ``gates``, ``components`` and the
            hooked forward pass.
        dataloader: Source of the batch; the item it yields is unpacked with
            ``extract_batch_data``.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The first value returned by ``calc_causal_importances``, keyed by
        dotted module name.
        NOTE(review): the output recorded in this notebook has shape
        (1000, 100), i.e. batch size first, which does not obviously match the
        ``" C n_steps"`` annotation -- confirm and update the annotation.
    """

    def _dotted(stored_name: str, prefix: str) -> str:
        # Stored names use "-" where the module path has ".", because module
        # names themselves cannot contain ".".
        return stored_name.removeprefix(prefix).replace("-", ".")

    gate_by_module: dict[str, GateMLP | VectorGateMLP] = {}
    for stored_name, gate in model.gates.items():
        gate_by_module[_dotted(stored_name, "gates.")] = cast(GateMLP | VectorGateMLP, gate)

    component_by_module: dict[str, LinearComponent | EmbeddingComponent] = {}
    for stored_name, component in model.components.items():
        component_by_module[_dotted(stored_name, "components.")] = cast(
            LinearComponent | EmbeddingComponent, component
        )

    # Take exactly one batch and place it on the requested device.
    first_batch = extract_batch_data(next(iter(dataloader)))
    first_batch = first_batch.to(device)

    # The hooked forward pass returns a pair; only the second item (the cached
    # pre-weight activations for the listed modules) is needed here.
    _, pre_weight_acts = model.forward_with_pre_forward_cache_hooks(
        first_batch, module_names=list(component_by_module)
    )

    v_by_module = {name: comp.V for name, comp in component_by_module.items()}

    importances, _ = calc_causal_importances(
        pre_weight_acts=pre_weight_acts,
        Vs=v_by_module,
        gates=gate_by_module,
        detach_inputs=False,
    )
    return importances

# Causal importances for one batch drawn from `dataloader`, keyed by module name.
ci = component_activations(model=component_model, dataloader=dataloader, device=DEVICE)

# Print a per-tensor summary; the trailing ";" suppresses the cell's value echo.
dbg_auto(ci);

[ <ipykernel>:45 ] ci: <dict of len()=2, key_types={str}, val_types={Tensor}>
  layers.0.mlp_in: μ=0.01 σ=0.09 x̃=0.00 R=[0.00,1.00] ℙ˪=|█▂▂▃▂▄▅| shape=(1000,100) dtype=torch.float32 device=cuda:0 ∇✓
  layers.0.mlp_out: μ=0.06 σ=0.18 x̃=0.00 R=[0.00,1.00] ℙ˪=|█▅▅▅▅▅▅| shape=(1000,100) dtype=torch.float32 device=cuda:0 ∇✓
