# setup

In [1]:
%load_ext autoreload
%autoreload 2
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
from teren import dir_act_utils as dau, metric
from teren import utils as tu
from teren import direction, experiment_context
from transformer_lens import HookedTransformer
from teren.typing import *

device = tu.get_device_str()
print(f"{device=}")

device='cuda'


In [3]:
LAYER = 11
SEQ_LEN = 4
INFERENCE_TOKENS = 12_800
SEED = 0
tu.setup_determinism(SEED)
INFERENCE_BATCH_SIZE = INFERENCE_TOKENS // SEQ_LEN
print(f"{INFERENCE_BATCH_SIZE=}")

N_PROMPTS = INFERENCE_BATCH_SIZE


input_ids = dau.get_input_ids(chunk=0, seq_len=SEQ_LEN)[:N_PROMPTS]
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

INFERENCE_BATCH_SIZE=3200
Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
exctx = experiment_context.ExperimentContext(
    model=model,
    layer=LAYER,
    input_ids=input_ids,
    acts_q_range=(0.01, 0.95),
    n_act=15,
    batch_size=INFERENCE_BATCH_SIZE,
)

  0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
dirs = []
i = 0
while True:
    dir = torch.randn(model.cfg.d_model)
    dir /= dir.norm()
    dir = direction.Direction(dir, exctx)
    dirs.append(dir)
    dir.process_metric_mm(metric.jsd_metric)
    n_above_thresh = dir.res_by_metric[metric.jsd_metric].mm_sel.shape[0]
    dir.process_metric_cvx(metric.jsd_metric)
    print(f"{i=}, act range: {dir.act_min:.3f}, {dir.act_max:.3f}, {n_above_thresh=}")
    i += 1

i=0, act range: 0.027, 6.310, n_above_thresh=0
i=1, act range: 0.092, 2.240, n_above_thresh=0
i=2, act range: 0.083, 26.304, n_above_thresh=5772
i=3, act range: 0.049, 6.896, n_above_thresh=4
i=4, act range: 0.019, 3.277, n_above_thresh=0
i=5, act range: 0.591, 13.284, n_above_thresh=182
i=6, act range: 0.111, 93.209, n_above_thresh=12039
i=7, act range: 0.998, 16.507, n_above_thresh=884
i=8, act range: 1.009, 2.797, n_above_thresh=0
i=9, act range: 0.016, 3.747, n_above_thresh=0
i=10, act range: 0.036, 6.516, n_above_thresh=1
i=11, act range: 0.130, 171.639, n_above_thresh=12070
i=12, act range: 0.040, 6.056, n_above_thresh=1
i=13, act range: 0.534, 43.867, n_above_thresh=10402
i=14, act range: 4.979, 164.508, n_above_thresh=12005
i=15, act range: 0.029, 5.865, n_above_thresh=1
i=16, act range: 6.403, 31.295, n_above_thresh=5457
i=17, act range: 0.105, 92.809, n_above_thresh=11988
i=18, act range: 0.076, 28.590, n_above_thresh=7025
i=19, act range: 0.116, 9.046, n_above_thresh=24
i=20

KeyboardInterrupt: 

In [10]:
len(dirs)

98

In [11]:
mm_hist = sum(dir.res_by_metric[metric.jsd_metric].mm_hist for dir in dirs)

In [12]:
hist_by_name = {
    "random": mm_hist,
}

color_by_name = {
    "sae": "255, 0, 0",
    "random": "0, 255, 0",
    "svd": "0, 0, 255",
}

In [38]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px


def plot_hist(hist_by_name, color_by_name, what):
    fig = go.Figure()
    # Add traces for each line and its shaded area
    for name, hist in hist_by_name.items():
        color = color_by_name[name]
        line_color = f"rgb({color})"
        shade_color = f"rgba({color}, 0.2)"
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(hist)),
                y=hist / hist.sum(),
                line=dict(color=line_color, width=2),
                name=name,
                fill="tozeroy",  # Fill to y=0
                fillcolor=shade_color,  # Semi-transparent color
            )
        )
    title_params = f"{exctx.acts_q_range[0]*100:.0f}% and {exctx.acts_q_range[1]*100:.0f}%<br>(layer {LAYER}, {len(dirs)} dirs per type, {N_PROMPTS*SEQ_LEN//1000}k tokens)"
    fig.update_layout(
        title=f"distribution of {what} between activations set to {title_params}",
        xaxis_title=what,
        yaxis_title="density",
        legend_title="dirs type",
    )
    fig.show()

In [32]:
plot_hist(hist_by_name, color_by_name, "JSD")

In [36]:
cvx_hist = sum(
    torch.histogram(
        dir.res_by_metric[metric.jsd_metric].cvx_score,
        bins=100,
        range=(0.0, 1.0),
    )[0]
    for dir in dirs
)
cvx_hist_by_name = {
    "random": cvx_hist,
}
plot_hist(cvx_hist_by_name, color_by_name, "convexity score")

In [37]:
cvx_act_hist = sum(
    torch.histogram(
        dir.res_by_metric[metric.jsd_metric].cvx_act.float(),
        bins=exctx.n_act,
        range=(0.0, exctx.n_act - 1),
    )[0]
    for dir in dirs
)
cvx_act_hist_by_name = {
    "random": cvx_act_hist,
}
plot_hist(cvx_act_hist_by_name, color_by_name, "act lvl A maximizing convexity score")

In [42]:
mm_sel = torch.cat([dir.res_by_metric[metric.jsd_metric].mm_sel for dir in dirs])
# count combinations of seq_in and seq_out
mm_sel = mm_sel.cpu().numpy()
unique, counts = np.unique(mm_sel[:, 1:], axis=0, return_counts=True)

In [43]:
unique, counts

(array([[0, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [2, 3],
        [3, 3]]),
 array([130162, 114456,    209,    154, 103211,    159,  95686]))

In [None]:
arr

In [None]:
px.scatter(
    x=mm_sel[:, 1],
    y=mm_sel[:, 2],
    labels={"x": "seq_in", "y": "seq_out"},
    title="seq_in vs seq_out for highest JSD",
)