# Direct & Indirect effect plots

This notebook creates the bar chart used in the paper to compare the direct and indirect effects of head 10.7 (layer 10, head 7; written L10H7 below).

In [1]:
import os, sys
from pathlib import Path
# Make the project checkout the working directory and importable, but only if it
# exists on this machine (otherwise leave cwd and sys.path untouched).
p = Path(r"/root/SERI-MATS-2023-Streamlit-pages")
str_p = str(p.resolve())
if os.path.exists(str_p):
    os.chdir(str_p)
    already_importable = str_p in sys.path
    if not already_importable:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_result_mean,
    get_model_results,
    get_model_results_batched,
    ModelResults,
    HeadResults,
    LayerResults,
    DictOfHeadResults,
)

from transformer_lens.rs.callum2.utils import (
    parse_str,
    parse_str_toks_for_printing,
    kl_div,
    process_webtext,
)

def create_dict_for_table(
    results: DictOfHeadResults,
    ln_mode: str = "frozen",
    head: Tuple[int, int] = (10, 7)
) -> Dict[str, Tensor]:
    """Pick out one head's "mean"-ablation entries from a results mapping.

    Args:
        results: mapping whose keys are indexable sequences. An entry is kept when
            everything after the key's first element equals `(ln_mode, "mean")`.
            Assumes 3-element tuple keys and values indexable as `v[layer, head]` -
            TODO confirm against `DictOfHeadResults`.
        ln_mode: second key element to select on.
        head: `(layer, head)` index applied to every selected value.

    Returns:
        Dict from each selected key's first element to `value[layer, head]`, in the
        iteration order of `results`.
    """
    layer, head_idx = head
    wanted_suffix = (ln_mode, "mean")

    table_entries = {}
    for key, per_head_values in results.items():
        if key[1:] == wanted_suffix:
            table_entries[key[0]] = per_head_values[layer, head_idx]
    return table_entries

# Wipe this cell's output. `clear_output` is presumably IPython's, pulled in by the
# star import above - TODO confirm.
clear_output()

In [2]:
# Load GPT-2 small as a HookedTransformer on the GPU.
# NOTE(review): the center_* / fold_ln flags presumably put the weights into the
# simplified form the callum2 analysis helpers expect - confirm against TransformerLens.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda",
)
# NOTE(review): attn_result is presumably enabled so per-head outputs are available
# to get_result_mean / get_model_results_batched below - confirm.
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

# Wipe the model-loading output from the cell.
clear_output()

In [3]:
# Dataset size passed to process_webtext: number of prompts and tokens per prompt.
# The trailing comments record the smaller values used for the visualisation.
BATCH_SIZE = 500 # 80 for viz
SEQ_LEN = 1000 # 61 for viz
# NOTE(review): batch_idx is not referenced anywhere in the visible part of this
# notebook - confirm it is still needed.
batch_idx = 36

# Tokenised webtext and a parsed string-token version of it; with verbose=True this
# prints the token shape and the first prompt.
DATA_TOKS, DATA_STR_TOKS_PARSED = process_webtext(seed=1, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, verbose=True, model=model)

Shape = torch.Size([500, 1000])

First prompt:
<|endoftext|>History Edit

In 1914, Senator Robert La Follette Sr. and Representative John M. Nelson, both of Wisconsin, promoted the inclusion in the legislative, executive, and judicial appropriations act of a provision directing the establishment of a special reference unit within the Library of Congress.[5] Building upon a concept developed by the New York State Library in 1890, and the Wisconsin Legislative Reference Library in 1901, they were motivated by Progressive era ideas about the importance of the acquisition of knowledge for an informed and independent legislature.[4] The move also reflected the expanding role of the librarian and the professionalization of the profession.[4] The new department was charged with responding to congressional requests for information.[4] The legislation authorized the Librarian of Congress, Herbert Putnam, to "employ competent persons to prepare such indexes, digests, and compilations of laws as 

In [4]:
# Number of leading tokens of every prompt that the loss is evaluated on.
mini_seq_len = 800

# Clean (un-ablated) per-token loss over the dataset, in minibatches of 10 prompts.
# The chunks are collected and concatenated once at the end, instead of re-copying
# the growing tensor on every iteration. The empty seed tensor fixes the result's
# shape/device even if there are no minibatches.
loss_chunks = [t.empty(size=(0, mini_seq_len-1)).to(device)]

for toks in tqdm(DATA_TOKS[:, :mini_seq_len].split(10)):
    loss_chunks.append(model.forward(toks, return_type="loss", loss_per_token=True))

loss = t.concat(loss_chunks)

# The recorded run printed 3.0289, which is the baseline hard-coded in the figure
# functions further down.
print(f"Mean loss = {loss.mean():.4f}")

  0%|          | 0/50 [00:00<?, ?it/s]

Mean loss = 3.0289


In [4]:
# Both calls below slice the tokens with `mini_seq_len` (set in the clean-loss cell
# above) so the ablation results always cover exactly the token window the clean
# baseline loss was measured on.

# Mean output of head 10.7 over the dataset, passed on as `result_mean` below.
# NOTE(review): keep_seq_dim=True presumably keeps one mean per sequence position - confirm.
result_mean = get_result_mean(
    head_list=[(10, 7)],
    toks=DATA_TOKS[:, :mini_seq_len],
    model=model,
    minibatch_size=10,
    keep_seq_dim=True,
    verbose=True
)

# Ablation results for head 10.7, one prompt at a time (max_batch_size=1).
# The recorded run took about 27 minutes for 500 prompts.
model_results = get_model_results_batched(
    model = model,
    toks = DATA_TOKS[:, :mini_seq_len],
    max_batch_size = 1,
    negative_heads = [(10, 7)],
    use_cuda = True,
    store_in_cuda = False,
    verbose = True,
    effective_embedding = "W_E (including MLPs)",
    result_mean = result_mean,
    keep_logits = False,
    keep_seq_dim_when_mean_ablating = True,
)

# Per-ablation-type values for head 10.7 (frozen LN, mean ablation - the defaults of
# create_dict_for_table).
loss_diffs_dict = create_dict_for_table(model_results.loss_diffs)
kl_divs_dict = create_dict_for_table(model_results.kl_divs)

# Summary table: one row per ablation type. These numbers feed the figures below.
table = Table("Type of ablation", "Average effect on loss", "Average absolute effect on loss", "Average KL div from clean", title="Results for bar chart!")
for k, v in loss_diffs_dict.items():
    table.add_row(k.capitalize(), f"{v.mean():.4f}", f"{v.abs().mean():.4f}", f"{kl_divs_dict[k].mean():.5f}")
rprint(table)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:15<00:00,  3.32it/s]
Batch 500/500, shape (1, 800), time 6.12s: 100%|██████████| 500/500 [27:22<00:00,  3.29s/it]


# Make actual figure for paper

The figure code below started from the following ChatGPT prompt (the generated code was then heavily amended):

Write plotly code for a figure which I will now describe.

It should be a bar chart, with 4 values: [x, x+0.0008, x+0.0016, x+0.0102], where x is a value that is externally defined (by default, x is 3.0289). The x-axis labels should be ["None", "Indirect paths<br>(except L11H10)", "Indirect paths<br>(all)", "Direct path"]. The x-axis title should be "Ablated paths". The y-axis title should be "Loss". The plotly template should be "simple_white". All bars should be the same color: salmon pink. The title should be "Comparison of different ablations for L10H7". The width is 800, height is 600. The y-axes should be appropriately set, so that the min is 0.001 below the min y-value, and the max is 0.004 above the max y-value. The bars should display the values "", "+0.0008", "+0.0016", "+0.0102" inside of them, just below the top of each bar. You should see this text even if you don't hover over the graph. All the font sizes should be increased, so the text is visible.

I also have 4 svg files, called c1, c2, c3, c4 respectively, each of the form "data:image/svg+xml;base64," + (encoded_image). I want these images to be added to the image, each one just above one of the bars. The width of these images should be equal to the width of the bar, and the images' aspect ratios should be preserved.

You don't have to run any code, I just want to see the code.

In [5]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import base64
import numpy as np
from pathlib import Path
import base64

# Directory the figure functions read their SVG assets (cmain, c1-c4) from, and where
# the exported fig_kl files are written.
figures_path = Path(r"/root/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/generate_st_html/paper_figures").resolve()

In [15]:
def fig_v1_salmon(x_value: float = 3.0289):
    """Show the first-version (single-axis, salmon) bar chart of loss per ablated path of L10H7.

    Args:
        x_value: clean (un-ablated) mean loss. It is the height of the "None" bar and
            the baseline the other three bars are offset from. Defaults to 3.0289.

    Returns:
        None. The figure is displayed as a static plot as a side effect, and the
        loaded asset names are printed.

    Reads cmain.svg and c1.svg-c4.svg from the module-level `figures_path`.
    """
    # Embed every SVG asset as a base64 data URI so plotly can draw it as a layout image.
    fig_dict = {}
    for stem in ["cmain", "c1", "c2", "c3", "c4"]:
        with open(figures_path / f"{stem}.svg", "rb") as f:
            encoded_svg = base64.b64encode(f.read()).decode("utf-8")
        fig_dict[stem] = "data:image/svg+xml;base64," + encoded_svg

    print(fig_dict.keys())

    cmain = fig_dict["cmain"]
    # One small diagram per bar, in bar order.
    encoded_images = [fig_dict[stem] for stem in ["c1", "c2", "c3", "c4"]]

    # Loss increase of each ablation over the clean loss. Bar heights and the labels
    # drawn inside the bars are both derived from this one list so they cannot drift apart.
    loss_deltas = [0.0, 0.0008, 0.0016, 0.0102]
    y_values = [x_value + delta for delta in loss_deltas]
    bar_text = ["" if delta == 0 else f"+{delta:.4f}" for delta in loss_deltas]
    x_labels = ["None", "Indirect paths<br>(except L11H10)", "Indirect paths<br>(all)", "Direct path"]

    # Zoom the y-axis onto the small differences, leaving headroom for the images.
    y_min = min(y_values) - 0.001
    y_max = max(y_values) + 0.005

    # Create the figure
    fig = go.Figure()

    # Add the bars
    fig.add_trace(go.Bar(
        x=x_labels,
        y=y_values,
        text=bar_text,
        textposition='inside',
        marker_color='salmon',
        hoverinfo='y',
    ))

    # Add one image just above each bar (positions are in data coordinates)
    for i, img in enumerate(encoded_images):
        fig.add_layout_image(
            dict(
                source=img,
                x=x_labels[i],
                y=y_values[i] + 0.0005,
                xref="x",
                yref="y",
                sizex=0.5,
                sizey=0.0035,  # hand-tuned
                xanchor="center",
                yanchor="bottom",
                layer="below",
                sizing="contain",
            )
        )

    # Add the main diagram in the upper-left area (positions are in paper coordinates)
    fig.add_layout_image(
        dict(
            source=cmain,
            x=0.35,
            y=0.76,
            xref="paper",
            yref="paper",
            sizex=0.45,  # hand-tuned
            sizey=0.45,  # hand-tuned
            xanchor="center",
            yanchor="middle",
            layer="below",
            sizing="contain",
        )
    )

    # Update the layout
    fig.update_layout(
        title="Comparison of different ablations for L10H7",
        xaxis_title="Ablated paths",
        yaxis_title="Loss",
        template="simple_white",
        width=1200,
        height=900,
        yaxis=dict(range=[y_min, y_max]),
        font=dict(size=28),
        bargap=0.45,
        margin=dict(l=120, r=60, b=120, t=140),
        title_y=0.95,
        title_font_size=52,
        xaxis_title_standoff=30,
    )

    # Show the plot
    fig.show(config=dict(staticPlot=True))

# NOTE(review): path of the same asset under a different checkout root, presumably kept
# for reference - confirm whether it is still needed:
# /root/numeric-representations/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/generate_st_html/paper_figures/cmain.svg

# Render the v1 figure (the function displays it itself).
fig_v1_salmon()

dict_keys(['cmain', 'c1', 'c2', 'c3', 'c4'])


In [30]:
def fig_v2_kl(
    include_L11H10=False,
    include_total=True,
):
    """Build the two-axis paper figure: loss (blue) and KL divergence (orange) per ablated path of L10H7.

    Args:
        include_L11H10: keep the "Indirect paths (except L11H10)" bar (position 1).
        include_total: keep the last bar. With the current bar order that is
            "Direct path" (the parameter name is historical).

    Returns:
        The plotly figure. Nothing is shown or written here; that is left to the caller.

    Reads cmain.svg and c1.svg-c4.svg from the module-level `figures_path` and prints
    the loaded asset names.
    """
    # Embed every SVG asset as a base64 data URI so plotly can draw it as a layout image.
    fig_dict = {}
    for stem in ["cmain", "c1", "c2", "c3", "c4"]:
        with open(figures_path / f"{stem}.svg", "rb") as f:
            encoded_svg = base64.b64encode(f.read()).decode("utf-8")
        fig_dict[stem] = "data:image/svg+xml;base64," + encoded_svg

    print(fig_dict.keys())

    cmain = fig_dict["cmain"]

    # Per-bar data. Every list below has one entry per bar, in bar order.
    x_value = 3.0289  # clean (un-ablated) mean loss
    x_labels = ["None", "Indirect paths<br>(except L11H10)", "Indirect paths", "Direct path"]
    y_values = [x_value, x_value + 0.0008, x_value + 0.0016, x_value + 0.0102]
    y_str_values = ["", "+0.0008", "+0.0016", "+0.0102"]
    y2_values = [0, 0.00297, 0.00323, 0.01449]
    y2_str_values = [f"{y:.4f}" for y in y2_values]
    encoded_images = [fig_dict[stem] for stem in ["c1", "c2", "c3", "c4"]]

    # Hand-tuned scale factors for the image heights, for orange/blue y balance
    F = 0.01025 / 0.0145
    factors = [1, F * 30/8, F * 32/16,  F * (0.0145 / 0.0102)]

    # Compute y-axis range (taken before any bar is dropped)
    y_min = min(y_values)
    y_max = max(y_values) + 0.005

    # Drop optional bars from ALL per-bar lists together, so values, text labels,
    # images and factors stay aligned with the bars that remain.
    per_bar_lists = [x_labels, y_values, y_str_values, y2_values, y2_str_values, encoded_images, factors]
    if not include_L11H10:
        for values in per_bar_lists:
            values.pop(1)
    if not include_total:
        for values in per_bar_lists:
            values.pop(-1)

    # Create the figure, add the bars. The two traces are shifted by half a bar width
    # either side of each integer x position.
    width = 0.4
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Bar(
            x=np.arange(len(x_labels)) - width/2,
            y=y_values,
            text=y_str_values,
            textposition='inside',
            marker_color='#1F77B4',
            hoverinfo='y',
            name="Loss",
        ),
        secondary_y = False,
    )
    fig.add_trace(
        go.Bar(
            x=np.arange(len(x_labels)) + width/2,
            y=y2_values,
            text=y2_str_values,
            textposition='inside',
            marker_color='#FF7F0E',
            hoverinfo='y',
            name="KL Divergence"
        ),
        secondary_y = True,
    )

    # Add one image above each bar pair (positions are in primary-axis data coordinates)
    for i, img in enumerate(encoded_images):
        fig.add_layout_image(
            dict(
                source=img,
                x=np.arange(len(x_labels))[i],
                y=x_value + (y_values[i] - x_value) * factors[i] + 0.0003,
                xref="x",
                yref="y",
                sizex=0.5,
                sizey=0.003, # 0.003 for no legend
                xanchor="center",
                yanchor="bottom",
                layer="below",
                sizing="contain",
            )
        )

    # Add the main diagram in the upper-left area (positions are in paper coordinates)
    fig.add_layout_image(
        dict(
            source=cmain,
            x=0.35,
            y=0.76,
            xref="paper",
            yref="paper",
            sizex=0.72 * 0.8,
            sizey=0.5 * 0.8, # 0.4 for no legend
            xanchor="center",
            yanchor="middle",
            layer="below",
            sizing="contain",
        )
    )

    # Update the layout
    fig.update_layout(
        barmode="group",
        # xaxis_title="Ablated paths", # Arthur thinks this is covered by the title
        yaxis_title="Loss",
        template="simple_white",
        width=1200,
        height=900,

        yaxis=dict(
            range=[y_min, y_max],
            title_text="<b>Loss</b>",
            title_font_size=48,
            title_font_color="#1F77B4",
            tickfont=dict(size=28),
        ),
        yaxis2=dict(
            range=[0.0, 0.0215],
            title_text="<b>KL Divergence</b>",
            title_font_size=48,
            title_font_color="#FF7F0E",
            title_standoff=30,
            tickfont=dict(size=28),
        ),
        font=dict(size=20),
        bargap=0.6,
        margin=dict(l=120, r=80, b=120, t=140),
        title=dict(
            text="Ablating different paths from L10H7",
            y=0.95,
            font_size=52,
        ),
        # Bars sit at integer positions, so label those positions explicitly.
        xaxis=dict(
            title_standoff=30,
            tickvals=np.arange(len(x_labels)),
            ticktext=x_labels,
            tickfont=dict(size=28),
        ),
        showlegend=False,
    )

    return fig

# Build the paper figure with the default bars, show it as a static plot, and export
# it to the figures directory as both PDF and SVG.
# NOTE(review): plotly's static export presumably needs an image-export backend
# (e.g. kaleido) installed - confirm.
fig = fig_v2_kl()
fig.show(config=dict(staticPlot=True))
fig.write_image(figures_path / "fig_kl.pdf")
fig.write_image(figures_path / "fig_kl.svg")




dict_keys(['cmain', 'c1', 'c2', 'c3', 'c4', 'c3.pdf'])


In [32]:
# NOTE(review): scratch lookup of the "gelu-4l" model config; it does not feed the
# L10H7 figures above - consider removing it from the final notebook.
from transformer_lens.loading_from_pretrained import get_pretrained_model_config
get_pretrained_model_config("gelu-4l")

In [33]:
# Same lookup as the previous cell, repeated; the config printed below is this cell's output.
get_pretrained_model_config("gelu-4l")

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 48262,
 'd_vocab_out': 48262,
 'device': 'cuda',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.035355339059327376,
 'model_name': 'GELU_4L512W_C4_Code',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 8,
 'n_layers': 4,
 'n_params': 12582912,
 'normalization_type': 'LN',
 'original_architecture': 'neel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'NeelNanda/gpt-neox-tokenizer-digits',
 'use_attn_in': False,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_hook_mlp_in': False,
 'use_hook_tokens':