# Do zeros in the Jacobian imply a lack of causal connections between variables?


In [3]:
from datasets import load_dataset
import sys
sys.path.append('..')
import torch
from torch.nn import functional as F
from tqdm import tqdm
from jacobian_saes.utils import load_pretrained, default_prompt


In [4]:
sae, model, mlp_with_grads, layer = load_pretrained("lucyfarnik/jsaes_pythia70m1/sae_pythia-70m-deduped_blocks.3.ln2.hook_normalized_16384:v0")

Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [5]:
def sandwich(mlp_in_features):
    mlp_in_reconstr = sae.decode(mlp_in_features, False)
    mlp_out, _ = mlp_with_grads(mlp_in_reconstr)
    mlp_out_features = sae.encode(mlp_out, True)
    return mlp_out_features

def sliced_sandwich(mlp_in_features):
    mlp_out_features = sandwich(mlp_in_features)
    return mlp_out_features[mlp_out_features>0]

def get_sliced_jac(mlp_in_features):
    jacobian = torch.autograd.functional.jacobian(sliced_sandwich, mlp_in_features)
    return jacobian[:, mlp_in_features>0]

def get_sliced_jac2(mlp_in_features):
    jacobian2 = torch.autograd.functional.jacobian(get_sliced_jac, mlp_in_features)
    return jacobian2[:, :, mlp_in_features>0].diagonal(dim1=-2, dim2=-1)


## Is there a noticeably correlation between having near-zero Jacobian values and near-zero Jacobian^2 values?
By "Jacobian^2" I mean a matrix where $$J_{i,j}^2 = \frac{d^2 y_i}{dx_j^2}$$

This is kind of a weak signal, but it does give us some data

In [3]:
_, cache = model.run_with_cache(default_prompt, names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
mlp_in = cache["normalized", layer, "ln2"]

In [7]:
jacobians = []
jacobians2 = []
for act in tqdm(mlp_in[0]):
    mlp_in_features = sae.encode(act, False)
    jacobians.append(get_sliced_jac(mlp_in_features).flatten())
    jacobians2.append(get_sliced_jac2(mlp_in_features).flatten())

jacobians = torch.cat(jacobians)
jacobians2 = torch.cat(jacobians2)
jacobians.shape, jacobians.sum(), jacobians2.shape, jacobians2.sum()

NameError: name 'mlp_in' is not defined

In [None]:
F.cosine_similarity(jacobians, jacobians2, dim=0).item()

0.0

# Average Jacobians over token positions
We'd kinda expect connections to be sparse across input tokens.
Specifically, we'd expect the degree of sparsity here to be such that roughly
`n_avg_input_feats * n_feats` elements in the averaged Jacobian have absolute values
substantially above 0 (in our case > 0.01). In the ideal scenario `n_avg_input_feats` 
should be around 5ish, so we'd expect about 0.03% of elements in the averaged Jacobian
to be substantially non-zero.

In [7]:
num_tokens = 100_000
dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

summed_abs_jacobians = torch.zeros(sae.cfg.d_sae, sae.cfg.d_sae, device=sae.device)
max_abs_jac_elements = torch.zeros_like(summed_abs_jacobians)
num_tokens_processed = 0
with tqdm(total=num_tokens) as pbar:
    for idx, sample in enumerate(dataset):
        with torch.no_grad():
            _, cache = model.run_with_cache(sample["text"], names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
            mlp_in = cache["normalized", layer, "ln2"][0, 1:]
            mlp_in_features = sae.encode(mlp_in, False)
            mlp_out_features = sandwich(mlp_in_features)

        for idx2, (mlp_in_feats, mlp_out_feats) in tqdm(enumerate(zip(mlp_in_features, mlp_out_features))):
            jacobian = torch.autograd.functional.jacobian(sliced_sandwich, mlp_in_feats)
            with torch.no_grad():
                full_jacobian_abs = torch.zeros_like(summed_abs_jacobians)
                full_jacobian_abs[mlp_out_feats>0] = jacobian.detach().abs()
                summed_abs_jacobians += full_jacobian_abs
                max_abs_jac_elements = torch.max(max_abs_jac_elements, full_jacobian_abs)
                num_tokens_processed += 1
                pbar.update(1)
                if idx2 % 10 == 0:
                    mean_abs_jacobian = summed_abs_jacobians / num_tokens_processed
                    proportion_nonzero = (mean_abs_jacobian>0.01).float().mean()
                    pbar.set_description(f"Percentage above 0.01: {100*proportion_nonzero.item():.4f}%")
                if num_tokens_processed >= num_tokens:
                    break

        if num_tokens_processed >= num_tokens:
            break

mean_abs_jacobian = summed_abs_jacobians / num_tokens_processed
(mean_abs_jacobian.abs()>0.01).float().mean()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

2047it [04:17,  7.95it/s]0117%:   2%|▏         | 2047/100000 [04:18<3:10:04,  8.59it/s]
1236it [02:27,  8.40it/s]0119%:   3%|▎         | 3283/100000 [06:46<3:04:52,  8.72it/s]
81it [00:09,  8.25it/s]0.0118%:   3%|▎         | 3364/100000 [06:56<3:02:04,  8.85it/s]
385it [00:43,  8.91it/s].0116%:   4%|▎         | 3749/100000 [07:40<2:57:46,  9.02it/s]
515it [00:55,  9.36it/s].0112%:   4%|▍         | 4263/100000 [08:35<2:41:10,  9.90it/s] 
848it [01:55,  7.36it/s].0111%:   5%|▌         | 5112/100000 [10:31<3:10:05,  8.32it/s] 
229it [00:28,  8.13it/s].0110%:   5%|▌         | 5340/100000 [11:00<2:48:28,  9.36it/s]
246it [00:28,  8.63it/s].0109%:   6%|▌         | 5587/100000 [11:29<2:43:00,  9.65it/s]
300it [00:34,  8.68it/s].0109%:   6%|▌         | 5887/100000 [12:04<2:48:01,  9.33it/s]
592it [01:14,  7.94it/s].0107%:   6%|▋         | 6479/100000 [13:19<3:35:19,  7.24it/s]
77it [00:08,  9.24it/s]0.0106%:   7%|▋         | 6555/100000 [13:27<2:44:06,  9.49it/s]
320it [00:41,  7.70it/s].0106%

KeyboardInterrupt: 

In [8]:
for thresh in [0, 1e-6, 1e-5, 1e-4, 1e-3, 5e-3, 1e-2, 1e-1, 1]:
    above_thresh = mean_abs_jacobian>thresh
    print(f"{100*above_thresh.float().mean().item():.4f}% ({above_thresh.sum().item():.1e} elements) are above {thresh:.0e}")

50.9399% (1.4e+08 elements) are above 0e+00
39.9037% (1.1e+08 elements) are above 1e-06
13.6935% (3.7e+07 elements) are above 1e-05
1.5901% (4.3e+06 elements) are above 1e-04
0.2386% (6.4e+05 elements) are above 1e-03
0.0340% (9.1e+04 elements) are above 5e-03
0.0100% (2.7e+04 elements) are above 1e-02
0.0000% (1.0e+01 elements) are above 1e-01
0.0000% (0.0e+00 elements) are above 1e+00


## Vary inputs, see how that changes the output variables where the partial derivative is 0

In [4]:
_, cache = model.run_with_cache(default_prompt, names_filter=[f"blocks.{layer}.ln2.hook_normalized"])
mlp_in = cache["normalized", layer, "ln2"][0, 1:]
mlp_in_features = sae.encode(mlp_in, False)
mlp_out_features = sandwich(mlp_in_features)

In [None]:
for mlp_in_feats, mlp_out_feats in tqdm(zip(mlp_in_features, mlp_out_features)):
    sliced_jac = get_sliced_jac(mlp_in_feats)
    small_indices_in_sliced = (sliced_jac.abs() < 5e-3).nonzero()
    
    break #!

0it [00:00, ?it/s]


In [None]:
small_indices_in_sliced.shape

torch.Size([443, 2])

In [38]:
(mlp_in_feats>0).nonzero().flatten(), (mlp_out_feats>0).nonzero().flatten()

(tensor([  286,   683,   844,  1170,  1705,  1756,  2013,  3080,  4117,  5289,
          5595,  5861,  6225,  8014,  8591,  8647,  9356,  9410,  9931, 10175,
         10516, 10732, 10899, 12249, 12618, 13132, 13644, 13981, 14614, 15219,
         15537, 16085], device='mps:0'),
 tensor([  307,   365,   479,   792,   936,  1302,  3338,  3507,  4307,  4366,
          4463,  4485,  4560,  5150,  6380,  6394,  6712,  8943,  9023,  9692,
          9714,  9717,  9982, 10270, 10608, 10756, 10976, 11524, 11781, 12469,
         12590, 12671], device='mps:0'))