# Exploration of ablation results

Something weird's going on in this chart

In [1]:
from datasets import load_dataset
import torch
from safetensors.torch import load_file
import sys
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
import random
sys.path.append("../")
from jacobian_saes.utils import load_pretrained, run_sandwich


What's going on with that cluster? Why isn't everything just neatly on the line?

In [2]:
n_samples = 10_000

sae_pair, model, mlp_with_grads, layer = load_pretrained("lucyfarnik/jsaes_pythia70m2/sae_pair_pythia-70m-deduped_layer3_16384_J1_k32:v0")
k = sae_pair.cfg.activation_fn_kwargs["k"]

dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

results = []
with torch.no_grad():
    with tqdm(total=n_samples) as pbar:
        for item in dataset:
            _, cache = model.run_with_cache(
                item["text"],
                stop_at_layer=layer + 1,
                names_filter=[sae_pair.cfg.hook_name],
            )
            acts = cache[sae_pair.cfg.hook_name][:, 1:]

            for act in acts[0, 1:]:
                jacobian, acts_dict = run_sandwich(sae_pair, mlp_with_grads, act)
                sae_acts1 = acts_dict["sae_acts1"]
                sae_acts2 = acts_dict["sae_acts2"]
                topk_indices1 = acts_dict["topk_indices1"]
                topk_indices2 = acts_dict["topk_indices2"]

                for _ in range(k):  # doesn't have to be k, this can be any number
                    # randint from 0 to k-1
                    out_idx, in_idx = torch.randint(0, k, (2,))

                    in_idx_in_d_sae = topk_indices1[in_idx]
                    out_idx_in_d_sae = topk_indices2[out_idx]

                    act_abl = act - sae_pair.get_W_dec(False)[in_idx_in_d_sae]
                    mlp_out_abl, _ = mlp_with_grads(act_abl)
                    sae_acts2_abl = sae_pair.encode(mlp_out_abl, True)

                    results.append({
                        "id": ''.join([random.choice('0123456789abcdef') for _ in range(8)]),
                        "jac_el": jacobian[out_idx, in_idx],
                        "diff": sae_acts2_abl[out_idx_in_d_sae] - sae_acts2[out_idx_in_d_sae],
                        "act": act,
                        "in_feature_dir": sae_pair.get_W_dec(False)[in_idx_in_d_sae],
                        "in_feature_strength": sae_acts1[in_idx_in_d_sae],
                        "out_feature_strength": sae_acts2[out_idx_in_d_sae],
                        "out_idx_in_d_sae": out_idx_in_d_sae,
                    })

                    pbar.update(1)
                    if pbar.n >= n_samples:
                        break
                if pbar.n >= n_samples:
                    break
            if pbar.n >= n_samples:
                break


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


Loaded pretrained model pythia-70m-deduped into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 10000/10000 [00:51<00:00, 195.85it/s]


In [3]:
take_abs = False

jac_els = torch.tensor([r["jac_el"] for r in results])
diffs = torch.tensor([r["diff"] for r in results])

get_is_on_line = lambda jac_els, diffs: (jac_els + diffs).abs() < 0.1
get_is_in_cluster = lambda jac_els, diffs: (jac_els < 0.05) & (diffs < -0.5)

is_on_line = get_is_on_line(jac_els, diffs)
is_in_cluster = get_is_in_cluster(jac_els, diffs)
colors = torch.zeros(len(results))
colors[is_on_line] = 1
colors[is_in_cluster] = 2

if take_abs:
    jac_els.abs_()
    diffs.abs_()

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=jac_els.cpu().numpy(),
    y=diffs.cpu().numpy(),
    text=[r["id"] for r in results],
    marker=dict(color=colors.cpu().numpy(), colorscale=['red', 'blue', 'green']),
    mode='markers',
))
fig.update_layout(
    title='Correlation between Jacobian value and change in downstream feature',
    xaxis_title=f"Jacobian element{' (absolute value)' if take_abs else ''}",
    yaxis_title=f"Change in downstream feature{' (absolute value)' if take_abs else ''}",
    template='plotly_white',
    showlegend=False,
    width=800,
)
fig.show()

print(f"on line: {is_on_line.sum().item()}, in cluster: {is_in_cluster.sum().item()}, outliers: {(~is_in_cluster & ~is_on_line).sum().item()}")

on line: 9758, in cluster: 214, outliers: 28


In [None]:
selected_ids = ["5e9b096c", "0d2c391d", "d5ff6fef", "b39434ab", "baa7945a", "387fb266",
                "4f27bcc5", "8dd124db", "4a527538", "8ede4044",
                "a04b706c", "476c773f"]

def find_by_id(id: str):
    filtered = [r for r in results if r["id"] == id]
    if len(filtered) > 1:
        raise ValueError(f"Multiple results with the same ID ({id})")
    if len(filtered) == 0:
        raise ValueError(f"No results with the ID {id}")

    return filtered[0]

@torch.no_grad()
def get_act_relationship(feature_data: dict, max_upstream: int = 5, n_points: int = 1000):
    upstream_acts = torch.linspace(0, max_upstream, n_points,
                                   device=results[0]["act"].device).reshape(-1, 1)
    act_abl = feature_data["act"] - feature_data["in_feature_strength"] * feature_data["in_feature_dir"]
    act_range = upstream_acts * feature_data["in_feature_dir"] + act_abl.unsqueeze(0)
    mlp_out_range, _ = mlp_with_grads(act_range)
    sae_acts2_range = sae_pair.encode(mlp_out_range, True)
    downstream_acts = sae_acts2_range[:, feature_data["out_idx_in_d_sae"]]

    return upstream_acts.flatten(), downstream_acts


def plot_feature_relationship(selected_id: str):
    feature_data = find_by_id(selected_id)

    is_on_line = get_is_on_line(feature_data["jac_el"], feature_data["diff"])
    is_in_cluster = get_is_in_cluster(feature_data["jac_el"], feature_data["diff"])
    color = "blue" if is_on_line else "green" if is_in_cluster else "red"

    upstream_acts, downstream_acts = get_act_relationship(feature_data,
                                                          max(5, feature_data["in_feature_strength"]+1))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=upstream_acts.cpu().numpy(),
        y=downstream_acts.cpu().numpy(),
        mode='lines',
        line=dict(color=color),
    ))
    # add a dashed vertical line indicating where in_feature_strength is on the x axis
    fig.add_shape(
        type="line",
        x0=feature_data["in_feature_strength"],
        y0=downstream_acts.min(),
        x1=feature_data["in_feature_strength"],
        y1=downstream_acts.max(),
        line=dict(
            color="red",
            width=1,
            dash="dash",
        ),
    )

    fig.update_layout(
        title=f'Change in downstream feature with respect to change in input feature ({selected_id})',
        xaxis_title="Upstream activation",
        yaxis_title="Downstream activation",
        template='plotly_white',
        showlegend=False,
        width=800,
    )
    fig.show()

for id in selected_ids:
    plot_feature_relationship(id)

## How many features pairs are there where the connection follows this JumpReLU pattern?

In [48]:
def is_jump_relu(feature_data: dict, tolerance=1e-2):
    xs, ys = get_act_relationship(feature_data,
                                  max(5, feature_data["in_feature_strength"]+1),
                                  n_points=100)
    
    # Check for zero region
    zero_region = (ys == 0)
    if torch.any(zero_region):
        # Find the bounds of the zero region
        zero_indices = torch.where(zero_region)[0]
        zero_start, zero_end = xs[zero_indices[0]], xs[zero_indices[-1]]

        # Check linearity outside the zero region
        linear_region_mask = (xs > zero_end) | (xs < zero_start)
        linear_x = xs[linear_region_mask]
        linear_y = ys[linear_region_mask]

        zero_start, zero_end = zero_start.item(), zero_end.item()
    else:
        zero_start, zero_end = None, None
        linear_x, linear_y = xs, ys

    if linear_x.numel() < 2:
        return False, {"reason": "no linear region"}

    # Fit a linear model to the outside region (y = slope * x + intercept)
    A = torch.stack([linear_x, torch.ones_like(linear_x)], dim=1)  # [x, 1]
    lstsq = torch.linalg.lstsq(A.cpu(), linear_y.unsqueeze(1).cpu())  # Linear regression
    slope = lstsq.solution[0].item()
    intercept = lstsq.solution[1].item()

    # Check if residuals are within tolerance
    residuals = linear_y - (slope * linear_x + intercept)
    is_linear = torch.all(torch.abs(residuals) < tolerance)

    if is_linear:
        return True, {
            "zero_region": (zero_start, zero_end),
            "linear_slope": slope,
            "linear_intercept": intercept,
        }
    else:
        return False, {"reason": "not linear outside zero region"}

for id in selected_ids:
    print(id, is_jump_relu(find_by_id(id)))

5e9b096c (True, {'zero_region': (0.0, 0.808080792427063), 'linear_slope': 0.13886849582195282, 'linear_intercept': 0.6423861384391785})
0d2c391d (True, {'zero_region': (None, None), 'linear_slope': 0.09742142260074615, 'linear_intercept': 0.8357949256896973})
d5ff6fef (True, {'zero_region': (None, None), 'linear_slope': 0.1340959072113037, 'linear_intercept': 1.0700109004974365})
b39434ab (False, {'reason': 'not linear outside zero region'})
baa7945a (False, {'reason': 'not linear outside zero region'})
387fb266 (False, {'reason': 'not linear outside zero region'})
4f27bcc5 (True, {'zero_region': (0.0, 0.10101009905338287), 'linear_slope': 0.034605707973241806, 'linear_intercept': 0.758158266544342})
8dd124db (False, {'reason': 'no linear region'})
4a527538 (True, {'zero_region': (0.0, 0.9090908765792847), 'linear_slope': 0.0029683939646929502, 'linear_intercept': 0.8627093434333801})
8ede4044 (True, {'zero_region': (0.0, 1.8181817531585693), 'linear_slope': 0.03210395947098732, 'linea

In [49]:
n_samples_to_check = 10_000

jump_relu_results = []
for r in tqdm(results[:n_samples_to_check]):
    jump_relu_results.append(is_jump_relu(r)[0])
f"{torch.tensor(jump_relu_results).float().sum().int():,} out of {n_samples_to_check:,} are jump relus"

100%|██████████| 10000/10000 [08:09<00:00, 20.43it/s] 


'9,710 out of 10,000 are jump relus'

### Plot all of the features on the same plot

In [69]:
n_samples_to_plot = 200

fig = go.Figure()

with tqdm(total=n_samples_to_plot) as pbar:
    for feature_data in results:
        if not get_is_on_line(feature_data["jac_el"], feature_data["diff"]):
            continue
        color = "blue"

        upstream_acts, downstream_acts = get_act_relationship(feature_data, 10, n_points=100)

        # normalize downstream_acts so that the largest value is 1
        downstream_acts = downstream_acts / downstream_acts.max()

        fig.add_trace(go.Scatter(
            x=upstream_acts.cpu().numpy(),
            y=downstream_acts.cpu().numpy(),
            mode='lines',
            line=dict(color=color),
        ))

        pbar.update(1)
        if pbar.n >= n_samples_to_plot:
            break

fig.update_layout(
    title=f'Change in downstream feature with respect to change in input feature',
    xaxis_title="Upstream activation",
    yaxis_title="Downstream activation",
    template='plotly_white',
    showlegend=False,
    width=800,
)
fig.show()

100%|██████████| 200/200 [00:01<00:00, 123.38it/s]
