# What do those Jacobian histograms look like?


In [2]:
import torch
import sys
from tqdm import tqdm
import wandb
import warnings
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
# Make the parent directory importable so the local `jacobian_saes` package
# resolves (assumes the notebook is run from a subdirectory of the repo — TODO confirm).
sys.path.append("../")
from jacobian_saes.sae_pair import SAEPair
from transformer_lens import HookedTransformer

# Shared W&B public-API client; the cells below use it to fetch a run's
# history and to download a model artifact.
api = wandb.Api()

## Approximation from the training runs

In [11]:
# Which logged histogram to inspect (`take_abs` selects the "abs_" variant of
# the key) and which of the two hard-coded training runs to read it from.
take_abs = True
train_with_jac_term = True
print(f"{take_abs=} {train_with_jac_term=}")

# The inner literals use single quotes: reusing the outer quote character
# inside an f-string replacement field is a SyntaxError before Python 3.12.
hist_key = f"jacobian_sparsity/jac_{'abs_' if take_abs else ''}hist"
run = api.run(f"lucyfarnik/jacobian_saes_test/{'oz886jk3' if train_with_jac_term else 'f38b1a3f'}")
history = run.history(keys=[hist_key])
if len(history) == 0:
    raise ValueError(f"No logged values found for {hist_key!r}")

# Take the most recently logged histogram. `.iloc[-1]` is positional, so it
# does not depend on the frame having a default 0..n-1 index.
hist = history[hist_key].iloc[-1]
bins = hist["bins"]
values = hist["values"]

# calculate x values by averaging the bin edges
x = [(left + right) / 2 for left, right in zip(bins[:-1], bins[1:])]
px.bar(x=x, y=values, width=800).show()
px.bar(x=x, y=values, width=800, log_y=True).show()

# values[i] is the count for the bin starting at bins[i], so summing the tail
# from idx gives the number of entries at or above bins[idx]. The first edge
# (idx 0) is skipped because that sum would just be the total count.
for idx in range(1, min(12, len(bins))):
    print(f"There are {sum(values[idx:])} values and above {bins[idx]:.3f}.")

take_abs=True train_with_jac_term=True


There are 289 values and above 0.003.
There are 151 values and above 0.006.
There are 99 values and above 0.009.
There are 72 values and above 0.012.
There are 56 values and above 0.015.
There are 46 values and above 0.018.
There are 38 values and above 0.021.
There are 33 values and above 0.024.
There are 29 values and above 0.027.
There are 26 values and above 0.030.
There are 23 values and above 0.033.


### More precise measurements


In [3]:
# Which SAE artifact version to load, and whether the sandwich function defined
# later in this notebook skips the block's ln2 before the MLP.
train_with_jac_term = True
ignore_ln = True
path = f"lucyfarnik/jacobian_saes_test/sae_pythia-70m-deduped_blocks.3.hook_resid_pre_16384:v2{1 if train_with_jac_term else 2}" # TODO check this
artifact = api.artifact(path)

# `download()` returns the directory it wrote to. Use that rather than
# rebuilding "artifacts/<name>" by hand, which silently points at the wrong
# place when the download root is not ./artifacts (e.g. WANDB_ARTIFACT_DIR set).
artifact_dir = artifact.download()

# Prefer Apple-silicon MPS (what this notebook was originally run on), but fall
# back so the cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

sae = SAEPair.load_from_pretrained(artifact_dir, device=device)
sae

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Downloading large artifact sae_pythia-70m-deduped_blocks.3.hook_resid_pre_16384:v21, 136.15MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:0.8
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


SAE(
  (activation_fn): TopK(
    (postact_fn): ReLU()
  )
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [4]:
# Index of the transformer block the SAE is attached to, taken from the SAE config.
layer = sae.cfg.hook_layer

# Load the base model named in the SAE config onto the same device as the SAE.
# NOTE(review): the SAE loader output earlier in this notebook recommends
# `from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)`;
# this call does not pass those kwargs — confirm that is intended.
model = HookedTransformer.from_pretrained(
    sae.cfg.model_name,
    device=sae.device,
)

Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [5]:
# Fixed text prompt; its token positions supply the activations analysed below.
prompt = "Given the existence as uttered forth in the public works of Puncher and Wattmann of a personal God quaquaquaqua with white beard quaquaquaqua outside time without extension who from the heights of divine apathia divine athambia divine aphasia loves us dearly with some exceptions for reasons unknown but time will tell and suffers like the divine Miranda with those who for reasons unknown but time will tell are plunged in torment plunged in fire whose fire flames if that continues and who can doubt it will fire the firmament that is to say blast hell to heaven so blue still and calm so calm with a calm which even though intermittent is better than nothing but not so fast and considering what is more that as a result of the labors left unfinished crowned by the Acacacacademy of Anthropopopometry of Essy-in-Possy of Testew and Cunard it is established beyond all doubt all other doubt than that which clings to the labours of men"

# Run the model once and keep its activation cache; only the entry at the
# SAE's configured hook point is read here.
_, cache = model.run_with_cache(prompt)
acts = cache[sae.cfg.hook_name]
# Pass those activations through the SAE and keep its cache.
# NOTE(review): the positional `False` presumably selects the input-side SAE of
# the pair (compare decode(..., False) / encode(..., True) in the sandwich
# function) — confirm against SAEPair.run_with_cache.
_, sae_cache = sae.run_with_cache(acts, False)
# Pre-activation feature values; the saved output shows shape [1, 202, 16384].
sae_pre_acts = sae_cache['hook_sae_acts_pre']
sae_pre_acts.shape

torch.Size([1, 202, 16384])

In [6]:
def sae_mlp_sandwitch(feature_acts_pre):
    """Map SAE pre-activations through activation -> decode -> MLP -> encode.

    Reads the notebook globals `sae`, `model`, `layer` and `ignore_ln`.

    Args:
        feature_acts_pre: SAE pre-activation values for one position.

    Returns:
        Whatever `sae.encode(..., True)` returns for the MLP output. When
        `ignore_ln` is falsy the decoded activations go through the block's
        `ln2` before the MLP; otherwise they are fed to the MLP directly.
    """
    block = model.blocks[layer]
    mlp_input = sae.decode(sae.activation_fn(feature_acts_pre), False)
    if not ignore_ln:
        mlp_input = block.ln2(mlp_input)
    return sae.encode(block.mlp(mlp_input), True)

# For each of ten prompt positions (11..20), compute the full Jacobian of the
# SAE -> MLP -> SAE map and keep only the entries linking non-zero input
# features (columns) to non-zero output features (rows).
sliced_jacs_list = []
for act in tqdm(sae_pre_acts[0, 11:21]):
    in_indices = sae.activation_fn(act).nonzero().flatten()
    out_indices = sae_mlp_sandwitch(act).nonzero().flatten()
    jacobian = torch.autograd.functional.jacobian(sae_mlp_sandwitch, act)
    sliced_jacobian = jacobian[out_indices][:, in_indices]
    sliced_jacs_list.append(sliced_jacobian.flatten())

# Concatenate rather than stack: each flattened slice has
# len(out_indices) * len(in_indices) entries, which is not guaranteed to be the
# same for every position, and torch.stack raises on unequal lengths. The
# histogram cell flattens everything anyway, so a 1-D result loses nothing.
sliced_jacs = torch.cat(sliced_jacs_list)


  0%|          | 0/10 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 10/10 [01:40<00:00, 10.02s/it]


In [7]:
# make a histogram
# torch.cat (not torch.stack): the per-position vectors may differ in length,
# and only the pooled absolute values matter here.
sliced_jacs = torch.cat(sliced_jacs_list).abs()
n_positions = len(sliced_jacs_list)
n_bins = 200
hist_max = 0.5

# calculate histogram, values in the bins by the number of jacobians
hist = torch.histc(sliced_jacs, bins=n_bins, min=0, max=hist_max).cpu() / n_positions
bins = torch.linspace(0, hist_max, n_bins + 1)
# Bin centres and heights as plain Python floats for plotly.
x = ((bins[:-1] + bins[1:]) / 2).tolist()
px.bar(x=x, y=hist.tolist(), width=800, log_y=True).show()

# Tail counts come from the raw values, not from summing histogram bins:
# torch.histc ignores elements above `max`, so bin sums would silently drop
# every entry larger than hist_max from the "and above" totals.
for idx in range(1, 20):
    edge = bins[idx].item()
    n_at_or_above = (sliced_jacs >= edge).sum().item() / n_positions
    print(f"There are {n_at_or_above:.1f} values and above {edge:.3f}.")

There are 401.5 values and above 0.002.
There are 213.9 values and above 0.005.
There are 141.1 values and above 0.007.
There are 102.1 values and above 0.010.
There are 78.2 values and above 0.012.
There are 65.5 values and above 0.015.
There are 55.3 values and above 0.018.
There are 48.8 values and above 0.020.
There are 42.9 values and above 0.022.
There are 37.8 values and above 0.025.
There are 34.0 values and above 0.027.
There are 30.9 values and above 0.030.
There are 28.2 values and above 0.032.
There are 25.7 values and above 0.035.
There are 23.8 values and above 0.037.
There are 21.7 values and above 0.040.
There are 20.2 values and above 0.043.
There are 19.0 values and above 0.045.
There are 17.8 values and above 0.047.
