# Causal sparsity 2: electric boogaloo
We have nicely sparse Jacobians, which almost certainly implies sparsity in the causal graph.
But what if my JSAEs are actually finding some sort of perverse solution that leads to sparse Jacobians without sparse computation?

As an illustration, imagine that the function whose Jacobian we're taking is a high-degree polynomial: our SAEs could just be finding its local minima/maxima, where the Jacobian vanishes even though the function still depends on its inputs.
This kind of thing seems _very_ unlikely to me considering that the function is affine-ReLUish-affine, but it's not impossible.
So we should check whether there actually is causal sparsity using something other than the thing we're directly optimizing for.
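To make the worry concrete, here is a minimal sketch of a function whose local derivative is zero while a finite intervention still changes the output. The cubic `f` and the helper `numerical_derivative` are made up for illustration and are not part of the JSAE code; "ablation" here just means setting the input to zero.

```python
def f(x: float) -> float:
    # Toy polynomial with stationary points at x = -1 and x = +1
    return x**3 - 3 * x


def numerical_derivative(fn, x: float, h: float = 1e-5) -> float:
    # Central finite difference, standing in for the Jacobian of a scalar function
    return (fn(x + h) - fn(x - h)) / (2 * h)


x0 = 1.0
local_slope = numerical_derivative(f, x0)  # ~0: the "Jacobian" says no dependence
ablation_effect = f(0.0) - f(x0)           # finite change when the input is zeroed

print(f"local slope at x0: {local_slope:.2e}")
print(f"change in output after ablating the input: {ablation_effect}")
```

A Jacobian-only check would call this input irrelevant at `x0`, while the intervention shows it is not, which is the gap the experiment below is meant to rule out.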

## Intervening on the input feature activations
Very simple idea -- if our Jacobians say "on this particular token, input feature X doesn't causally affect output feature Y", then let's just try varying feature X (e.g. ablating it) and see if output feature Y changes.
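Before running this on the real model, the check can be sketched on a toy affine-ReLU-affine function. Everything in it is an assumption made for illustration: the weights `W1` and `W2`, the input `x`, the finite-difference Jacobian, and "ablation" as zeroing an input coordinate (the real experiment works with SAE features instead).

```python
def relu(v):
    return [max(0.0, t) for t in v]


def matvec(W, v):
    return [sum(w * t for w, t in zip(row, v)) for row in W]


# Hypothetical toy weights: 3 inputs -> 2 hidden units -> 2 outputs
W1 = [[1.0, 0.0, 0.0],
      [0.5, 1.0, 0.0]]
W2 = [[2.0, 0.0],
      [0.0, 3.0]]


def f(x):
    return matvec(W2, relu(matvec(W1, x)))


def jacobian(fn, x, h=1e-6):
    # Forward finite differences; J[out][inp] = d fn(x)[out] / d x[inp]
    y = fn(x)
    J = [[0.0] * len(x) for _ in y]
    for inp in range(len(x)):
        x_h = list(x)
        x_h[inp] += h
        y_h = fn(x_h)
        for out in range(len(y)):
            J[out][inp] = (y_h[out] - y[out]) / h
    return J


def mean_abs_ablation_effect(x, important, eps=0.01):
    # Select connections by Jacobian magnitude, ablate the input, measure the output change
    J, y = jacobian(f, x), f(x)
    diffs = []
    for out in range(len(y)):
        for inp in range(len(x)):
            if (abs(J[out][inp]) > eps) == important:
                x_abl = list(x)
                x_abl[inp] = 0.0  # "ablate" the input coordinate
                diffs.append(abs(f(x_abl)[out] - y[out]))
    return sum(diffs) / len(diffs), len(diffs)


x = [1.0, 2.0, 5.0]
small_mean, n_small = mean_abs_ablation_effect(x, important=False)
large_mean, n_large = mean_abs_ablation_effect(x, important=True)
print(f"small-Jacobian connections: mean abs diff {small_mean:.3f} over {n_small}")
print(f"large-Jacobian connections: mean abs diff {large_mean:.3f} over {n_large}")
```

If the Jacobian is an honest summary of the computation, ablating along small-Jacobian connections should move the output much less than ablating along large-Jacobian ones, and that is the comparison made on the real model below.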

In [9]:
import torch
from tqdm import tqdm
import sys
sys.path.append('..')
from jacobian_saes.utils import load_pretrained, default_prompt, run_sandwich

layer = 3
wandb_path = f"lucyfarnik/jsaes_pythia70m1/sae_pythia-70m-deduped_blocks.{layer}.ln2.hook_normalized_16384:v0"
sae_pair, model, mlp_with_grads, _ = load_pretrained(wandb_path, device="cpu")

_, cache = model.run_with_cache(default_prompt, stop_at_layer=layer+1,
                                names_filter=[sae_pair.cfg.hook_name])
acts = cache[sae_pair.cfg.hook_name]

Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [11]:
def run_ablation_experiment(look_at_important_connections: bool = False, eps: float = 0.01):
    """
    For each token's activation, use the Jacobian to pick connections that are
    either important (|J| > eps) or unimportant (|J| < eps), depending on the
    value of look_at_important_connections.
    Ablate the input feature of each connection and record how much the output
    feature changes.

    Returns (mean absolute change in the output feature, number of connections tested).
    """
    out_act_diffs = []
    with tqdm(total=acts.shape[1]-1) as pbar:
        for act in acts[0, 1:]:
            jacobian, acts_dict = run_sandwich(sae_pair, mlp_with_grads, act)
            sae_acts2 = acts_dict["sae_acts2"]
            topk_indices1 = acts_dict["topk_indices1"]
            topk_indices2 = acts_dict["topk_indices2"]
            if look_at_important_connections:
                connections = (jacobian.abs() > eps).nonzero()
            else:
                connections = (jacobian.abs() < eps).nonzero()

            for out_idx, in_idx in connections:
                in_idx_in_d_sae = topk_indices1[in_idx]
                out_idx_in_d_sae = topk_indices2[out_idx]
                # TODO: make sure the indexing is aligned (i.e. that get_jacobian
                # doesn't swap the out/in axes around)

                # Ablate in_idx and run through the rest of the sandwich.
                # Note: this subtracts the bare decoder row (coefficient 1),
                # not the row scaled by the feature's activation.
                act_abl = act - sae_pair.get_W_dec(False)[in_idx_in_d_sae]
                mlp_out_abl, _ = mlp_with_grads(act_abl)
                sae_acts2_abl = sae_pair.encode(mlp_out_abl, True)

                # Record how much the output feature changed
                out_act_diffs.append((sae_acts2_abl[out_idx_in_d_sae] - sae_acts2[out_idx_in_d_sae]).item())

            pbar.update(1)
            diffs_tensor = torch.tensor(out_act_diffs)
            pbar.set_description(f"Mean diff: {diffs_tensor.mean():.3f} | Mean abs diff: {diffs_tensor.abs().mean():.3f} | N diffs: {len(out_act_diffs)}")

    return torch.tensor(out_act_diffs).abs().mean().item(), len(out_act_diffs)


In [12]:
run_ablation_experiment(False)

Mean diff: -0.013 | Mean abs diff: 0.015 | N diffs: 778063: 100%|██████████| 858/858 [26:55<00:00,  1.88s/it]


(0.014961636625230312, 778063)

In [13]:
run_ablation_experiment(True)

Mean diff: -0.082 | Mean abs diff: 0.099 | N diffs: 100529: 100%|██████████| 858/858 [04:22<00:00,  3.27it/s]


(0.0989634096622467, 100529)