In [2]:
# %% Import modules
# stdlib
import importlib
import sys
from typing import Callable

# third-party
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

# project-local helpers
from load_data import get_prompts_t
from plotting import (
    line_with_river,
    ntensor_to_long,
    single_head_full_resid_projection,
)
from utils import (
    cos_similarity,
    projection,
    reinforcement_ratio,
)

# Analysis only: turn autograd off globally so cached activations carry no graph.
torch.set_grad_enabled(False)
# Pinned to CPU. For GPU use: torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')

In [3]:
#%% Setup model & load data
# Load the pretrained 4-layer "gelu-4l" model.
model = HookedTransformer.from_pretrained('gelu-4l')
# Turn on the per-head attention output hook ("result"). The analysis functions
# below read cache["result", layer] and index its head axis, so this must be set.
model.cfg.use_attn_result = True
model.to(device)

# Tokenized prompts. The captured cell output shows 80 c4 prompts followed by
# 20 code prompts being loaded.
prompts_t = get_prompts_t()

Loaded pretrained model gelu-4l into HookedTransformer
Moving model to device:  cpu
Loading 80 prompts from c4-tokenized-2b...


  0%|          | 0/80 [00:00<?, ?it/s]

Loading 20 prompts from code-tokenized...


  0%|          | 0/20 [00:00<?, ?it/s]

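## Project H0.2 outputs onto its pos-0 output
Collect the attention output (`result`) of head H0.2 at every position, then compare each one against H0.2's output at position 0 using a projection-style function (`reinforcement_ratio`, `projection` or `cos_similarity`).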

In [7]:
# Number of prompts (the first rows of prompts_t) run through the model for each plot.
# NOTE(review): if prompts_t lists c4 prompts before code prompts (as the load log
# suggests), the first 30 are all c4. Confirm this is intended.
BATCH_SIZE = 30


In [4]:
def plot_H02_proj_on_pos0(
    prompts_t: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    proj_func: Callable,
    batch_size: int,
    by_prompt: bool = False,
):
    """Plot how head H0.2's output at each position relates to its output at pos 0.

    Runs the first ``batch_size`` prompts through ``model`` and takes the per-head
    attention output ("result") of layer 0, head 2. For every position it calls
    ``proj_func(output_at_pos, output_at_pos_0)``.

    Args:
        prompts_t: Tokenized prompts. Only the first ``batch_size`` rows are used.
        model: Model with ``cfg.use_attn_result = True``, so that the "result"
            hook (one output vector per head) is cached.
        proj_func: Comparison function called with two tensors of shape
            ``(batch, 1, d_model)``. Its output is squeezed on dim 1 and stored
            per prompt, so it is expected to return shape ``(batch, 1)``.
        batch_size: Number of prompts to run.
        by_prompt: If False, plot the mean over prompts as one line. If True,
            plot one line per prompt.

    Returns:
        A plotly figure.
    """
    layer, head = 0, 2

    hook_name = get_act_name("result", layer)
    _, cache = model.run_with_cache(
        prompts_t[:batch_size],
        # Exact match. The original `name in hook_name` was a substring test on a str.
        names_filter=lambda name: name == hook_name,
    )

    attn_result_H0_2: Float[Tensor, "batch pos dmodel"] = cache["result", layer][:, :, head]

    # Size the output from the activations actually produced, not from the
    # global BATCH_SIZE / model.cfg.n_ctx. This keeps it correct when batch_size
    # differs or the prompts are shorter than the context window.
    n_prompts, n_pos = attn_result_H0_2.shape[:2]
    projections = torch.zeros(n_prompts, n_pos)

    # Reference vector: H0.2 output at position 0, shape (batch, 1, d_model).
    pos0_ref = attn_result_H0_2[:, None, 0]
    for pos_idx in range(n_pos):
        projections[:, pos_idx] = proj_func(
            attn_result_H0_2[:, None, pos_idx],
            pos0_ref,
        ).squeeze(dim=1)

    if not by_prompt:
        return px.line(
            projections.mean(dim=0),
            title=f"H0.2 each pos attn out project onto pos 0, mean across {n_prompts} prompts using {proj_func.__name__}",
        )

    # One line per prompt: convert to long format for plotly's `color` grouping.
    df = ntensor_to_long(projections)
    df.columns = ["projection_value", "prompts", "pos"]

    return px.line(
        df,
        x="pos",
        y="projection_value",
        color="prompts",
        title=f"H0.2 each pos attn out project onto pos 0, per prompt ({n_prompts} prompts) using {proj_func.__name__}",
    )

In [49]:
# First the prompt-mean view, then the per-prompt view, for each comparison function
# (same order as six explicit calls).
for plot_by_prompt in (False, True):
    for compare_fn in (reinforcement_ratio, projection, cos_similarity):
        plot_H02_proj_on_pos0(
            prompts_t, model, compare_fn, BATCH_SIZE, by_prompt=plot_by_prompt
        ).show()

In [9]:
def plot_H2X_proj_on_H02pos0(
    prompts_t: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    proj_func: Callable,
    batch_size: int,
    by_prompt: bool = False,
    sum_up: bool = False,
):
    """Plot how each layer-2 head's output relates to H0.2's output at pos 0.

    Runs the first ``batch_size`` prompts through ``model``. For every layer-2
    head and every position it calls
    ``proj_func(H2.head_output_at_pos, H0.2_output_at_pos_0)``, then averages
    over prompts.

    Args:
        prompts_t: Tokenized prompts. Only the first ``batch_size`` rows are used.
        model: Model with ``cfg.use_attn_result = True``, so that the "result"
            hook (one output vector per head) is cached.
        proj_func: Comparison function called with two tensors of shape
            ``(batch, 1, d_model)``. It is expected to return ``(batch, 1)``
            (its output is squeezed on dim 1).
        batch_size: Number of prompts to run.
        by_prompt: Currently unused and accepted for signature parity with
            ``plot_H02_proj_on_pos0``. Both figures average over prompts.
        sum_up: If True, plot a single line: the negated sum over layer-2 heads
            of the prompt-mean values. If False, plot one line per head.

    Returns:
        A plotly figure.
    """
    src_layer, src_head = 0, 2  # H0.2: the reference head
    tgt_layer = 2               # all heads of layer 2 are compared against it

    src_hook = get_act_name("result", src_layer)
    tgt_hook = get_act_name("result", tgt_layer)
    _, cache = model.run_with_cache(
        prompts_t[:batch_size],
        # Cache only the two layers needed instead of every layer's "result".
        names_filter=lambda name: name in (src_hook, tgt_hook),
    )

    attn_result_H2_X: Float[Tensor, "batch pos head dmodel"] = cache["result", tgt_layer]
    attn_result_H0_2: Float[Tensor, "batch pos dmodel"] = cache["result", src_layer][:, :, src_head]

    # Size from the cached activations rather than the global BATCH_SIZE /
    # model.cfg.n_ctx, so shorter prompts or another batch_size work.
    n_prompts, n_pos, n_heads = attn_result_H2_X.shape[:3]
    projections = torch.zeros(n_prompts, n_pos, n_heads)

    # Reference vector: H0.2 output at position 0, shape (batch, 1, d_model).
    pos0_ref = attn_result_H0_2[:, None, 0]
    for head_idx in range(n_heads):
        for pos_idx in range(n_pos):
            projections[:, pos_idx, head_idx] = proj_func(
                attn_result_H2_X[:, None, pos_idx, head_idx],
                pos0_ref,
            ).squeeze(dim=1)

    mean_proj = projections.mean(dim=0)  # (pos, head)

    if sum_up:
        # NOTE(review): the sign flip comes from the original code. Presumably it
        # shows layer-2 heads as cancelling H0.2 pos 0. Confirm the intent.
        return px.line(
            -mean_proj.sum(dim=-1),
            title=f"Negated sum over H2.X of attn out project onto H0.2 pos 0, mean across {n_prompts} prompts using {proj_func.__name__}",
        )

    # One line per head.
    df = ntensor_to_long(mean_proj)
    df.columns = ["projection_value", "pos", "head"]

    return px.line(
        df,
        x="pos",
        y="projection_value",
        color="head",
        title=f"H2.X each pos attn out project onto H0.2 pos 0, mean across {n_prompts} prompts using {proj_func.__name__}",
    )

In [56]:
# One line per layer-2 head, averaged over prompts.
plot_H2X_proj_on_H02pos0(
    prompts_t,
    model,
    proj_func=reinforcement_ratio,
    batch_size=BATCH_SIZE,
    by_prompt=False,
).show()


In [11]:
# Single line: negated sum over all layer-2 heads, averaged over prompts.
plot_H2X_proj_on_H02pos0(
    prompts_t,
    model,
    proj_func=reinforcement_ratio,
    batch_size=BATCH_SIZE,
    by_prompt=False,
    sum_up=True,
).show()
