# Final charts


In [128]:
import numpy as np
import torch
from typing import Literal
from safetensors.torch import load_file
import sys
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import wandb
import warnings
sys.path.append("../")
from jacobian_saes.utils import load_pretrained, run_sandwich


## Histograms
Run `runners/histogram.py` first to get the data

We also plot the cumulative distribution function (CDF), because it makes the differences between the two distributions easier to see.

In [129]:
# Kept distinct from `model`, which a later cell rebinds to the loaded transformer;
# reusing the name here would clobber it if this cell were re-run.
model_name = "pythia-70m-deduped"
layer = 3

def get_hist_path(model: str, layer: int, trained_with_jac: bool = True):
    """Return the path of the histogram file written by `runners/histogram.py`.

    Args:
        model: model name embedded in the SAE pair's artifact name.
        layer: layer index of the SAE pair.
        trained_with_jac: True selects the file tagged ``J1``, False the one
            tagged ``J0.0``.
    """
    jac_tag = 1 if trained_with_jac else 0.0  # file names use "J1" / "J0.0"
    return f"../results/histograms/sae_pair_{model}_layer{layer}_16384_J{jac_tag}_k32:v0.safetensor"

hist_data = load_file(get_hist_path(model_name, layer, True))
hist = hist_data["hist"]
bin_edges = hist_data["bin_edges"]
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
hist_no_optim = load_file(get_hist_path(model_name, layer, False))["hist"]


In [130]:
upper_zoom_thresh = 0.1
# Restrict the plot to the small-magnitude bins, where the two distributions differ most.
zoom_mask = bin_centers < upper_zoom_thresh
cropped_bin_centers = bin_centers[zoom_mask]
cropped_hist = hist[zoom_mask]
cropped_hist_no_optim = hist_no_optim[zoom_mask]

# Shared by both traces; bar heights are a percentage of *all* elements, not just the cropped ones.
bar_x = cropped_bin_centers.cpu().numpy()
bar_width = (bin_edges[1] - bin_edges[0]).item()
jsae_pct = cropped_hist.cpu().numpy() / hist.sum().item() * 100
baseline_pct = cropped_hist_no_optim.cpu().numpy() / hist_no_optim.sum().item() * 100

fig = go.Figure()
fig.add_trace(go.Bar(
    x=bar_x,
    y=jsae_pct,
    width=bar_width,
    name='JSAEs',
    opacity=0.9,
    marker_color='forestgreen',
))
fig.add_trace(go.Bar(
    x=bar_x,
    y=baseline_pct,
    width=bar_width,
    name="Traditional SAEs",
    opacity=0.5,
    marker=dict(color='red', pattern_shape="/"),
))
fig.update_layout(
    title='Histogram of absolute values of Jacobian elements',
    xaxis_title='Absolute value of Jacobian element',
    yaxis_title='Frequency (%)',
    template='plotly_white',
    barmode='overlay',
    font_family="Verdana, Avenir Next",
    width=500,
)
fig.show()


In [131]:
# CDF of the |Jacobian element| histograms. cumsum[i] / total is the fraction of
# elements in bins 0..i, i.e. up to bin_edges[i + 1], so each value is plotted at
# its bin's right edge (plotting at bin centres would shift the curve by half a bin).
cdf = hist.cumsum(dim=0) / hist.sum()
cdf_no_optim = hist_no_optim.cumsum(dim=0) / hist_no_optim.sum()
cdf_x = bin_edges[1:].cpu().numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=cdf_x,
    y=cdf_no_optim.cpu().numpy(),
    mode='lines',
    marker_color='crimson',
    name="Traditional SAEs",
    line=dict(dash='dash')
))
fig.add_trace(go.Scatter(
    x=cdf_x,
    y=cdf.cpu().numpy(),
    mode='lines',
    marker_color='forestgreen',
    name='JSAEs',
))
fig.update_layout(
    title='Cumulative distribution func. of Jacobian elements',
    xaxis_title='Absolute value of Jacobian elements',
    yaxis_title='Proportion of elements below threshold',
    template='plotly_white',
    font_family="Verdana, Avenir Next",
    width=500,
)
fig.show()

Idea: a single figure demonstrating sparsity with several different metrics (as subplots), including this one.

The Jacobian is generally smaller here

# Example Jacobian before and after

In [132]:
layer = 3

def get_wandb_path(layer: int, trained_with_jac: bool = True):
    """Build the W&B artifact path of the pythia-70m-deduped SAE pair at `layer`.

    `trained_with_jac` picks the artifact tagged ``J1`` (True) or ``J0.0`` (False).
    """
    jac_suffix = "1" if trained_with_jac else "0.0"
    artifact_name = f"sae_pair_pythia-70m-deduped_layer{layer}_16384_J{jac_suffix}_k32:v0"
    return "lucyfarnik/jsaes_pythia70m2/" + artifact_name

# JSAE run: keep everything load_pretrained returns (SAE pair, model, MLP wrapper, layer).
sae_pair, model, mlp_with_grads, layer = load_pretrained(get_wandb_path(layer))
# Baseline run (no Jacobian penalty): keep only its SAE pair and discard the other returned values.
sae_pair_no_optim = load_pretrained(get_wandb_path(layer, trained_with_jac=False))[0]



This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)



Loaded pretrained model pythia-70m-deduped into HookedTransformer
Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [133]:
hook_name = sae_pair.cfg.hook_name
# Stop the forward pass early: we only need the activations at the SAE's hook.
_, cache = model.run_with_cache(
    "Never gonna give you up",
    stop_at_layer=layer + 1,
    names_filter=[hook_name],
)
acts = cache[hook_name]

# Same input activations through both SAE pairs so the Jacobians are comparable.
jacobian, _ = run_sandwich(sae_pair, mlp_with_grads, acts)
jacobian_no_optim, _ = run_sandwich(sae_pair_no_optim, mlp_with_grads, acts)

In [134]:
# Reshuffle (otherwise the ordering is determined by the size of the SAE latent activations which makes it look as if there's a correlation)
def shuffle_jacobian(jacobian, generator=None):
    """Randomly permute the rows and columns of a (batch of) Jacobian matrices.

    Only whole rows (dim -2) and whole columns (dim -1) are moved; individual
    elements are never shuffled independently. Every matrix in the leading
    dims gets the same row permutation and the same column permutation.

    Args:
        jacobian: tensor whose last two dims index the Jacobian's rows/columns.
        generator: optional CPU ``torch.Generator`` for reproducible shuffles;
            ``None`` uses torch's global RNG.

    Returns:
        A new tensor of the same shape with rows and columns permuted.
    """
    # Draw the column permutation first, then the row permutation.
    col_perm = torch.randperm(jacobian.shape[-1], generator=generator)
    row_perm = torch.randperm(jacobian.shape[-2], generator=generator)
    return jacobian[..., col_perm][..., row_perm, :]

# Seed the RNG so the shuffled heatmaps below are reproducible.
# NOTE: this rebinds `jacobian` in place, so re-running this cell alone shuffles an
# already-shuffled tensor; re-run the previous cell first to start from the original order.
SHUFFLE_SEED = 0
torch.manual_seed(SHUFFLE_SEED)
jacobian = shuffle_jacobian(jacobian)
jacobian_no_optim = shuffle_jacobian(jacobian_no_optim)

In [135]:
def _normalize_jacobian(jac: np.ndarray, normalization: Literal[None, "L1", "L2"]) -> np.ndarray:
    """Scale `jac` to unit L1 or L2 norm (unchanged for None); all-zero input is returned as is instead of NaNs."""
    if normalization is None:
        return jac
    norm = np.abs(jac).sum() if normalization == "L1" else np.sqrt(np.sum(jac**2))
    return jac / norm if norm > 0 else jac


def plot_jacobian_examples(normalization: Literal[None, "L1", "L2"] = None,
                           seq_positions=range(1, 5),
                           jac_optim=None,
                           jac_baseline=None):
    """Plot |Jacobian| heatmaps for a few token positions: traditional SAE vs JSAE.

    Row 1 shows the traditional SAE pair, row 2 the JSAE; all heatmaps share one
    colour axis so their magnitudes are directly comparable.

    Args:
        normalization: None for raw absolute values, or "L1"/"L2" to scale each
            heatmap to unit L1/L2 norm (an all-zero heatmap is left unscaled).
        seq_positions: token positions to plot, one subplot column each; each
            Jacobian is indexed as ``[0, seq_pos]`` (first element of dim 0).
        jac_optim: JSAE Jacobian; defaults to the notebook-level `jacobian`.
        jac_baseline: traditional-SAE Jacobian; defaults to the notebook-level
            `jacobian_no_optim`.

    Raises:
        ValueError: if `normalization` is not None, "L1" or "L2", or if
            `seq_positions` is empty.
    """
    if normalization not in (None, "L1", "L2"):
        raise ValueError(f"normalization must be None, 'L1' or 'L2', got {normalization!r}")
    if jac_optim is None:
        jac_optim = jacobian
    if jac_baseline is None:
        jac_baseline = jacobian_no_optim
    positions = list(seq_positions)
    if not positions:
        raise ValueError("seq_positions must contain at least one position")

    fig = make_subplots(rows=2, cols=len(positions))

    for col, seq_pos in enumerate(positions, start=1):
        # Row 1: traditional SAE, row 2: JSAE.
        for row, source in ((1, jac_baseline), (2, jac_optim)):
            heat = source[0, seq_pos].detach().cpu().abs().numpy()
            heat = _normalize_jacobian(heat, normalization)
            fig.add_trace(go.Heatmap(
                z=heat,
                coloraxis="coloraxis",  # one shared colour scale for every subplot
                text=heat.round(4),
                hoverinfo='text',
            ), row=row, col=col)

    title = "Absolute values of Jacobian elements"
    if normalization:
        title += f" ({normalization} normalized)"

    fig.update_layout(
        coloraxis=dict(colorscale='Blues'),
        title=title,
        template='plotly_white',
        font_family="Verdana, Avenir Next",
        width=500,
    )

    fig.update_yaxes(title_text="Traditional SAE", row=1, col=1)
    fig.update_yaxes(title_text="JSAE", row=2, col=1)
    for col, seq_pos in enumerate(positions, start=1):
        fig.update_xaxes(title_text=f"Token {seq_pos}", row=2, col=col)

    fig.show()

[plot_jacobian_examples(normalization=norm) for norm in (None, "L1", "L2")];

TODO: verify that we only permute whole rows and columns, rather than shuffling individual elements

# Causal independence

In [136]:
take_abs = True
n_downsampled = 1_000
ON_LINE_TOL = 0.1      # a sample is "on the line" if |col0 + col1| is below this
DOWNSAMPLE_SEED = 0    # makes the plotted subsample reproducible

abl_data = load_file("../results/ablation/sae_pair_pythia-70m-deduped_layer3_16384_J1_k32:v0.safetensor")
abl_samples_full = abl_data["results"]

# Computed on the signed values: "on the line" means the two columns (approximately) cancel.
is_on_line = (abl_samples_full[:, 0] + abl_samples_full[:, 1]).abs() < ON_LINE_TOL

# Random subsample (with replacement) so the scatter plot stays responsive.
sample_idx = torch.randint(
    0, abl_samples_full.shape[0], (n_downsampled,),
    generator=torch.Generator().manual_seed(DOWNSAMPLE_SEED),
)
abl_sample_small = abl_samples_full[sample_idx]
if take_abs:
    # Take |.| of the small subsample only, rather than mutating the full loaded tensor in place.
    abl_sample_small = abl_sample_small.abs()

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=abl_sample_small[:, 0].cpu().numpy(),
    y=abl_sample_small[:, 1].cpu().numpy(),
    marker_color='blue',
    mode='markers',
))
abs_suffix = ' (abs. value)' if take_abs else ''
fig.update_layout(
    title='Correlation between Jacobian value and <br>change in downstream feature',
    xaxis_title=f"Jacobian element{abs_suffix}",
    yaxis_title=f"Change in downstream feature{abs_suffix}",
    template='plotly_white',
    font_family="Verdana, Avenir Next",
    width=500,
    showlegend=False
)
fig.show()

# Fraction of all (not just the plotted) samples that lie on the line.
is_on_line.float().mean().item()

0.9749128818511963

In [137]:
# Total number of ablation samples that the on-the-line fraction is computed over.
is_on_line.size()

torch.Size([10000000])

chart ideas to explore this
- Laurence's idea: for each point on this chart, sweep over the upstream feature values and plot how the downstream feature values change (i.e. turn each dot on this plot into a line chart)


TODO: bar plot for how many examples are on the line vs in the cluster

TODO: chart where each dot is a training run with a different Jacobian coefficient; x axis is reconstruction quality, y axis is sparsity (and similar charts for autointerp etc.)

TODO: figure showing that the reconstruction quality (and autointerp) doesn't suffer

# Sweeps

## Jacobian sweep

### CE loss score vs jac elements above 0.01

In [142]:
# Jacobian-coefficient values that get a text label next to their point.
LABELLED_COEFFS = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 3, 4, 5, 50]

api = wandb.Api()
runs = api.runs("lucyfarnik/pythia70m-l3-sweep-j")

metrics = ["jacobian_sparsity/jac_abs_above_0.01",
           "model_performance_preservation/ce_loss_score",
           "model_performance_preservation/ce_loss_score2"]

data = []
for run in runs:
    if run.state != "finished":
        warnings.warn(f"Run {run.id} is not finished.")
        continue
    run_data = {"jacobian_coefficient": run.config["jacobian_coefficient"]}
    for m in metrics:
        run_data[m.split("/")[-1]] = run.summary[m]
    data.append(run_data)

# wandb's paginator reports `more=True` until a page has been fetched, so this
# check is only meaningful after the loop has paged through the runs.
if runs.more:
    warnings.warn("You're not fetching all of the runs.\n\n")

df = pd.DataFrame(data)
df["avg_ce_loss_score"] = (df["ce_loss_score"] + df["ce_loss_score2"]) / 2
point_labels = df["jacobian_coefficient"].map(lambda c: c if c in LABELLED_COEFFS else "")

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=df["avg_ce_loss_score"],
    y=df["jac_abs_above_0.01"],
    text=point_labels,
    mode='markers+text',
    marker_color='forestgreen',
    textposition='top right',
    textfont=dict(size=11, color="gray")
))
fig.update_layout(
    title='Tradeoff between reconstruction and Jacobian sparsity',
    xaxis_title='Average cross-entropy score (1 = perfect reconstruction)',
    yaxis_title='Proportion of Jac. elements above 0.01',
    xaxis=dict(range=[0.5, 0.99]),
    template='plotly_white',
    width=500,
    yaxis=dict(autorange='reversed'),
    font_family="Verdana, Avenir Next",
)
fig.show()


You're not fetching all of the runs.





TODO rewrite to instead get the data from the eval runner

TODO do a bunch of these for different metrics; the more charts per unit of effort the better

TODO big grid of charts where the charts in a given row always have the same y axis and plots in the same col always have the same x axis
- These should include L1 sparsity divided by sqrt(L2) — if those don't look as good, we might need to train with that as the objective to demonstrate that we're not just making the Jacobians smaller