# Direct & Indirect effect plots

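This notebook computes the numbers behind the paper's bar chart comparing the direct and indirect effects of head 10.7. For each type of ablation it prints a table of the average effect on loss, the average absolute effect on loss, and the average KL divergence from the clean run. This is done for head 10.7 with both frozen and unfrozen LN, and, for comparison, for heads 11.10 and 11.2.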

In [22]:
import os, sys
from pathlib import Path
# If the project checkout exists on this machine, run from its root and make it importable.
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
str_p = str(p.resolve())
if os.path.exists(str_p):
    os.chdir(str_p)
    # Avoid adding a duplicate sys.path entry when the cell is re-run.
    if str_p not in sys.path:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_model_results,
    get_model_results_batched,
    ModelResults,
    HeadResults,
    LayerResults,
    DictOfHeadResults,
)

from transformer_lens.rs.callum2.cspa.cspa_utils import (
    parse_str,
    parse_str_toks_for_printing,
    kl_div,
    get_result_mean,
)

def create_dict_for_table(
    results: DictOfHeadResults,
    ln_mode: str = "frozen",
    head: Tuple[int, int] = (10, 7),
    ablation_mode: str = "mean",
) -> Dict[str, Tensor]:
    '''
    Pick out one head's results for every ablation type, for a fixed LN mode and ablation mode.

    Args:
        results: mapping-like object (must support `.items()`) whose keys are 3-tuples
            `(ablation_type, ln_mode, ablation_mode)` and whose values can be indexed by
            `[layer, head_index]`, e.g. `model_results.loss_diffs` or `model_results.kl_divs`.
        ln_mode: LN mode to keep, matched against `key[1]` (this notebook uses "frozen" and
            "unfrozen").
        head: `(layer, head_index)` used to index each selected value.
        ablation_mode: ablation mode to keep, matched against `key[2]`. Defaults to "mean",
            which used to be the only value this function supported.

    Returns:
        Dict mapping each `ablation_type` (`key[0]`) to `value[layer, head_index]`. The dict is
        empty if no key matches `(ln_mode, ablation_mode)`.
    '''
    layer, head_idx = head
    wanted_suffix = (ln_mode, ablation_mode)
    return {
        key[0]: value[layer, head_idx]
        for key, value in results.items()
        if key[1:] == wanted_suffix
    }

In [2]:
# GPT-2 small, loaded with the weight-processing options used throughout this notebook.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    device="cuda",
)
# Keep q/k/v inputs unsplit, but enable per-head attention results (needed for head-level ablations).
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

# Hide any output printed while loading the model.
clear_output()

In [3]:
# Dataset sizes for the full run; the trailing comment gives the smaller values used for viz.
BATCH_SIZE, SEQ_LEN = 200, 500  # viz: BATCH_SIZE=80, SEQ_LEN=61
# NOTE(review): batch_idx is not referenced by any cell visible here.
batch_idx = 36

def process_webtext(
    seed: int = 6,
    batch_size: int = BATCH_SIZE,
    indices: Optional[List[int]] = None,
    seq_len: int = SEQ_LEN,
    verbose: bool = False,
):
    '''
    Load webtext prompts, tokenize them with the global `model`, and truncate them to `seq_len`.

    Args:
        seed: passed to `get_webtext`.
        batch_size: keep the first `batch_size` prompts. Ignored when `indices` is given.
        indices: explicit positions of the prompts to select from the webtext list.
        seq_len: maximum number of tokens kept per prompt. Only applied when it is shorter than
            the model's context length (`model.cfg.n_ctx`).
        verbose: if True, print the token tensor's shape and the first prompt.

    Returns:
        `(tokens, parsed_str_tokens)`:
        - `tokens` is the output of `model.to_tokens`, possibly truncated.
        - `parsed_str_tokens` is the per-prompt string tokens, each passed through
          `parse_str_toks_for_printing`.
    '''
    data_str = get_webtext(seed=seed)
    if indices is None:
        data_str = data_str[:batch_size]
    else:
        data_str = [data_str[i] for i in indices]
    data_str = [parse_str(s) for s in data_str]

    data_toks = model.to_tokens(data_str)
    data_str_toks = model.to_str_tokens(data_str)

    # Truncate only when seq_len is below the model's context length. This compare used to be
    # the hard-coded literal 1024, which is GPT-2's context length and is wrong for other models.
    if seq_len < model.cfg.n_ctx:
        data_toks = data_toks[:, :seq_len]
        data_str_toks = [str_toks[:seq_len] for str_toks in data_str_toks]

    data_str_toks_parsed = list(map(parse_str_toks_for_printing, data_str_toks))

    # Clear whatever this cell has printed so far before optionally printing a summary.
    clear_output()
    if verbose:
        print(f"Shape = {data_toks.shape}\n")
        print("First prompt:\n" + "".join(data_str_toks[0]))

    return data_toks, data_str_toks_parsed

# Webtext batch (seed=1) shared by every analysis cell below.
DATA_TOKS, DATA_STR_TOKS_PARSED = process_webtext(seed=1, verbose=True)

Shape = torch.Size([200, 500])

First prompt:
<|endoftext|>Story highlights NASA probe gives insight into Mercury's formation

The surface has sulfur levels 10 times higher than Earth's

The MESSENGER orbiter has been circling Mercury since March 2011

X-ray data from NASA's MESSENGER probe points to high levels of magnesium and sulfur on the surface of the planet Mercury, suggesting its makeup is far different from that of other planets, scientists say.

The unmanned orbiter has been beaming back data from the first planet for a year and a half. Readings from its X-ray spectrometer point to a planet whose northern volcanic plains formed through upwellings of rocks more exotic than those often found on the Earth, the Moon or Mars, said Shoshana Weider, a researcher at the Carnegie Institution of Washington.

"Before this MESSENGER mission, a lot of people assumed it was very like the Moon -- it's dark, it's grey," Weider said. But while the Moon's surface formed when light materials fl

In [8]:
# Baseline for mean-ablating head 10.7, computed over the first 100 prompts
# (keep_seq_dim=True: presumably one mean per sequence position — confirm in cspa_utils).
result_mean = get_result_mean(
    head_list=[(10, 7)],
    toks=DATA_TOKS[:100],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True,
)

# Ablation results for head 10.7 (loss diffs and KL divs are read from this below).
model_results = get_model_results_batched(
    model=model,
    toks=DATA_TOKS[:100, :500],
    max_batch_size=4,
    negative_heads=[(10, 7)],
    use_cuda=True,
    store_in_cuda=False,
    verbose=True,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    keep_logits=False,
    keep_seq_dim_when_mean_ablating=True,
)

# Defaults of create_dict_for_table: frozen LN, head (10, 7).
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs)
kl_divs_dict = create_dict_for_table(model_results.kl_divs)

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

In [10]:
# Head 10.7 results tabulated with unfrozen LN instead of frozen.
# Reuses `model_results` from the previous cell; the commented-out copy of that
# recomputation has been removed as dead code.
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="unfrozen")
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="unfrozen")

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

In [16]:
# NOTE(review): this recomputes the same head-10.7 quantities as the In [8] cell, with
# ln_mode passed explicitly.
result_mean = get_result_mean(
    head_list=[(10, 7)],
    toks=DATA_TOKS[:100],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True,
)

model_results = get_model_results_batched(
    model=model,
    toks=DATA_TOKS[:100, :500],
    max_batch_size=4,
    negative_heads=[(10, 7)],
    use_cuda=True,
    store_in_cuda=False,
    verbose=True,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    keep_logits=False,
    keep_seq_dim_when_mean_ablating=True,
)

loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="frozen")
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="frozen")

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

100%|██████████| 10/10 [00:01<00:00,  6.29it/s]
Batch 1/25, shape (4, 500), time 0.88s:   4%|▍         | 1/25 [01:12<28:53, 72.23s/it]
Batch 25/25, shape (4, 500), time 1.49s: 100%|██████████| 25/25 [00:29<00:00,  1.17s/it]


In [17]:
# Unfrozen-LN table for head 10.7, reusing `model_results` from the previous cell.
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="unfrozen")
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="unfrozen")

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

In [23]:
# Same analysis for head 11.10 (frozen LN), for comparison with head 10.7.
result_mean = get_result_mean(
    head_list=[(11, 10)],
    toks=DATA_TOKS[:100],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True,
)

model_results = get_model_results_batched(
    model=model,
    toks=DATA_TOKS[:100, :500],
    max_batch_size=4,
    negative_heads=[(11, 10)],
    use_cuda=True,
    store_in_cuda=False,
    verbose=True,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    keep_logits=False,
    keep_seq_dim_when_mean_ablating=True,
)

# `head` must match the ablated head, since create_dict_for_table defaults to (10, 7).
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="frozen", head=(11, 10))
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="frozen", head=(11, 10))

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

In [24]:
# NOTE(review): this cell is identical to the previous one (head 11.10, frozen LN).
# If the unfrozen-LN variant was intended (as was done for head 10.7), change ln_mode below — confirm.
result_mean = get_result_mean(
    head_list=[(11, 10)],
    toks=DATA_TOKS[:100],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True,
)

model_results = get_model_results_batched(
    model=model,
    toks=DATA_TOKS[:100, :500],
    max_batch_size=4,
    negative_heads=[(11, 10)],
    use_cuda=True,
    store_in_cuda=False,
    verbose=True,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    keep_logits=False,
    keep_seq_dim_when_mean_ablating=True,
)

loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="frozen", head=(11, 10))
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="frozen", head=(11, 10))

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

100%|██████████| 10/10 [00:01<00:00,  6.29it/s]
Batch 1/25, shape (4, 500), time 0.77s:   4%|▍         | 1/25 [04:07<1:39:07, 247.82s/it]
Batch 25/25, shape (4, 500), time 1.42s: 100%|██████████| 25/25 [00:27<00:00,  1.08s/it]


In [25]:
# Same analysis for head 11.2 (frozen LN), for comparison with head 10.7.
result_mean = get_result_mean(
    head_list=[(11, 2)],
    toks=DATA_TOKS[:100],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True,
)

model_results = get_model_results_batched(
    model=model,
    toks=DATA_TOKS[:100, :500],
    max_batch_size=4,
    negative_heads=[(11, 2)],
    use_cuda=True,
    store_in_cuda=False,
    verbose=True,
    effective_embedding="W_E (including MLPs)",
    result_mean=result_mean,
    keep_logits=False,
    keep_seq_dim_when_mean_ablating=True,
)

# `head` must match the ablated head, since create_dict_for_table defaults to (10, 7).
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs, ln_mode="frozen", head=(11, 2))
kl_divs_dict = create_dict_for_table(model_results.kl_divs, ln_mode="frozen", head=(11, 2))

# One row per ablation type.
table = Table(
    "Type of ablation",
    "Average effect on loss",
    "Average absolute effect on loss",
    "Average KL div from clean",
    title="Results for bar chart!",
)
for k, v in loss_diffs_dict.items():
    table.add_row(
        k.capitalize(),
        f"{v.mean():.4f}",
        f"{v.abs().mean():.4f}",
        f"{kl_divs_dict[k].mean():.5f}",
    )
rprint(table)

100%|██████████| 10/10 [00:01<00:00,  6.29it/s]
Batch 25/25, shape (4, 500), time 1.42s: 100%|██████████| 25/25 [00:27<00:00,  1.08s/it]
