# Do zeros in the Jacobian imply a lack of causal connections between variables?


In [1]:
from datasets import load_dataset
import sys
sys.path.append('..')
import torch
from torch.nn import functional as F
from tqdm import tqdm
from jacobian_saes.utils import load_pretrained, default_prompt


In [2]:
sae, model, mlp, layer = load_pretrained("lucyfarnik/jacobian_saes_test/sae_pythia-70m-deduped_blocks.3.ln2.hook_normalized_16384:v4")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Downloading large artifact sae_pythia-70m-deduped_blocks.3.ln2.hook_normalized_16384:v4, 136.15MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:0.8
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [3]:
def sandwich(mlp_in_features):
    mlp_in_reconstr = sae.decode(mlp_in_features, False)
    mlp_out = mlp(mlp_in_reconstr)
    mlp_out_features = sae.encode(mlp_out, True)
    return mlp_out_features

def sliced_sandwich(mlp_in_features):
    mlp_out_features = sandwich(mlp_in_features)
    return mlp_out_features[mlp_out_features>0]

def get_sliced_jac(mlp_in_features):
    jacobian = torch.autograd.functional.jacobian(sliced_sandwich, mlp_in_features)
    return jacobian[:, mlp_in_features>0]

def get_sliced_jac2(mlp_in_features):
    jacobian2 = torch.autograd.functional.jacobian(get_sliced_jac, mlp_in_features)
    return jacobian2[:, :, mlp_in_features>0].diagonal(dim1=-2, dim2=-1)


## Is there a noticeably correlation between having near-zero Jacobian values and near-zero Jacobian^2 values?
By "Jacobian^2" I mean a matrix where $$J_{i,j}^2 = \frac{d^2 y_i}{dx_j^2}$$

This is kind of a weak signal, but it does give us some data

In [3]:
_, cache = model.run_with_cache(default_prompt, names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
mlp_in = cache["normalized", layer, "ln2"]

In [7]:
jacobians = []
jacobians2 = []
for act in tqdm(mlp_in[0]):
    mlp_in_features = sae.encode(act, False)
    jacobians.append(get_sliced_jac(mlp_in_features).flatten())
    jacobians2.append(get_sliced_jac2(mlp_in_features).flatten())

jacobians = torch.cat(jacobians)
jacobians2 = torch.cat(jacobians2)
jacobians.shape, jacobians.sum(), jacobians2.shape, jacobians2.sum()

NameError: name 'mlp_in' is not defined

In [None]:
F.cosine_similarity(jacobians, jacobians2, dim=0).item()

0.0

# Average Jacobians over token positions
There's a chance that the connections are sparse across the input distribution

In [6]:
num_tokens = 100_000
dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

summed_abs_jacobians = torch.zeros(sae.cfg.d_sae, sae.cfg.d_sae, device=sae.device)
max_abs_jac_elements = torch.zeros_like(summed_abs_jacobians)
num_tokens_processed = 0
with tqdm(total=num_tokens) as pbar:
    for idx, sample in enumerate(dataset):
        with torch.no_grad():
            _, cache = model.run_with_cache(sample["text"], names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
            mlp_in = cache["normalized", layer, "ln2"][0, 1:]
            mlp_in_features = sae.encode(mlp_in, False)
            mlp_out_features = sandwich(mlp_in_features)

        for idx2, (mlp_in_feats, mlp_out_feats) in tqdm(enumerate(zip(mlp_in_features, mlp_out_features))):
            jacobian = torch.autograd.functional.jacobian(sliced_sandwich, mlp_in_feats)
            with torch.no_grad():
                full_jacobian_abs = torch.zeros_like(summed_abs_jacobians)
                full_jacobian_abs[mlp_out_feats>0] = jacobian.detach().abs()
                summed_abs_jacobians += full_jacobian_abs
                max_abs_jac_elements = torch.max(max_abs_jac_elements, full_jacobian_abs)
                num_tokens_processed += 1
                pbar.update(1)
                if idx2 % 10 == 0:
                    mean_abs_jacobian = summed_abs_jacobians / num_tokens_processed
                    proportion_nonzero = (mean_abs_jacobian>0).float().mean()
                    pbar.set_description(f"Nonzero percentage: {100*proportion_nonzero.item():.1f}%")
                if num_tokens_processed >= num_tokens:
                    break

        if num_tokens_processed >= num_tokens:
            break

mean_abs_jacobian = summed_abs_jacobians / num_tokens_processed
(mean_abs_jacobian.abs()>0).float().mean()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

2047it [03:43,  9.17it/s]:   2%|▏         | 2046/100000 [03:57<2:42:48, 10.03it/s]
1236it [02:15,  9.12it/s]:   3%|▎         | 3282/100000 [06:13<2:55:15,  9.20it/s]
81it [00:08,  9.92it/s]1%:   3%|▎         | 3364/100000 [06:21<2:34:03, 10.45it/s]
385it [00:41,  9.22it/s]%:   4%|▎         | 3749/100000 [07:03<4:05:36,  6.53it/s]
515it [00:55,  9.34it/s]%:   4%|▍         | 4263/100000 [07:58<2:50:29,  9.36it/s]
848it [01:36,  8.77it/s]%:   5%|▌         | 5112/100000 [09:35<2:35:57, 10.14it/s] 
229it [00:26,  8.78it/s]%:   5%|▌         | 5341/100000 [10:02<2:49:47,  9.29it/s]
246it [00:26,  9.22it/s]%:   6%|▌         | 5587/100000 [10:29<2:36:35, 10.05it/s] 
300it [00:33,  9.02it/s]%:   6%|▌         | 5886/100000 [11:02<2:44:30,  9.53it/s]
592it [01:04,  9.16it/s]%:   6%|▋         | 6479/100000 [12:07<3:07:04,  8.33it/s]
77it [00:09,  8.48it/s]2%:   7%|▋         | 6555/100000 [12:16<2:54:57,  8.90it/s]
320it [00:37,  8.42it/s]%:   7%|▋         | 6876/100000 [12:55<2:57:25,  8.75it/s]
71

KeyboardInterrupt: 

In [24]:
for thresh in [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]:
    above_thresh = mean_abs_jacobian>thresh
    print(f"{100*above_thresh.float().mean().item():.2f}% ({above_thresh.sum().item():.1e} elements) are above {thresh:.0e}")

49.07% (1.3e+08 elements) are above 0e+00
41.64% (1.1e+08 elements) are above 1e-06
15.35% (4.1e+07 elements) are above 1e-05
1.77% (4.8e+06 elements) are above 1e-04
0.32% (8.6e+05 elements) are above 1e-03
0.02% (5.0e+04 elements) are above 1e-02
0.00% (1.4e+01 elements) are above 1e-01
0.00% (0.0e+00 elements) are above 1e+00


## Vary inputs, see how that changes the output variables where the partial derivative is 0

In [4]:
_, cache = model.run_with_cache(default_prompt, names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
mlp_in = cache["normalized", layer, "ln2"][0, 1:]
mlp_in_features = sae.encode(mlp_in, False)
mlp_out_features = sandwich(mlp_in_features)

In [None]:
for mlp_in_feats, mlp_out_feats in tqdm(zip(mlp_in_features, mlp_out_features)):
    sliced_jac = get_sliced_jac(mlp_in_feats)
    small_indices_in_sliced = (sliced_jac.abs() < 5e-3).nonzero()
    
    break #!

0it [00:00, ?it/s]


In [None]:
small_indices_in_sliced.shape

torch.Size([443, 2])

In [38]:
(mlp_in_feats>0).nonzero().flatten(), (mlp_out_feats>0).nonzero().flatten()

(tensor([  286,   683,   844,  1170,  1705,  1756,  2013,  3080,  4117,  5289,
          5595,  5861,  6225,  8014,  8591,  8647,  9356,  9410,  9931, 10175,
         10516, 10732, 10899, 12249, 12618, 13132, 13644, 13981, 14614, 15219,
         15537, 16085], device='mps:0'),
 tensor([  307,   365,   479,   792,   936,  1302,  3338,  3507,  4307,  4366,
          4463,  4485,  4560,  5150,  6380,  6394,  6712,  8943,  9023,  9692,
          9714,  9717,  9982, 10270, 10608, 10756, 10976, 11524, 11781, 12469,
         12590, 12671], device='mps:0'))