# Confusion plot

Big question for this project - how often does it happen that 10.7 doesn't do copy suppression, in situations where we think it should have?

I could do a scatter plot, but for now I'll just do a simple confusion plot.

What are the conditions for when we think copy-suppression should take place?

1. The unembedding for a certain source token is larger than a certain threshold (so that it turns off attention to BOS)
2. The unembdding for this source token is larger than the unembeddings for all other source tokens (so that it takes all the attention)

There's only one hyperparameter here - `X`, the threshold.

I'll split points by whether this is true or not, and I'll also split points by whether they're attended to more than anything else (including BOS). I'll create a table of this. I hope that the table won't have many diagonal elements.

*(Note - when I say "BOS" in this document, that's a shorthand for the token at position zero, since the head doesn't seem to care what its identity is.)*

# Setup

In [1]:
# %pip install pattern --no-dependencies
# %pip install nltk
# %pip install protobuf==3.20.0

import os, sys
from pathlib import Path
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
if os.path.exists(str_p := str(p.resolve())):
    os.chdir(str_p)
    if str_p not in sys.path:
        sys.path.append(str_p)

from transformer_lens.cautils.notebook import *
t.set_grad_enabled(False)

from transformer_lens.rs.callum2.cspa.cspa_functions import (
    FUNCTION_STR_TOKS,
    # project,
    get_cspa_results,
    get_cspa_results_batched,
)
from transformer_lens.rs.callum2.utils import (
    parse_str,
    parse_str_toks_for_printing,
    process_webtext,
    ST_HTML_PATH,
)
from transformer_lens.rs.callum2.cspa.cspa_plots import (
    generate_scatter,
    generate_loss_based_scatter,
    add_cspa_to_streamlit_page,
)
from transformer_lens.rs.callum2.generate_st_html.model_results import (
    get_result_mean,
    get_model_results,
)
from transformer_lens.rs.callum2.generate_st_html.generate_html_funcs import (
    generate_4_html_plots,
    CSS,
)
from transformer_lens.rs.callum2.cspa.cspa_semantic_similarity import (
    get_equivalency_toks,
    get_related_words,
    concat_lists,
    make_list_correct_length,
    create_full_semantic_similarity_dict,
)
clear_output()

In [2]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda",
    # fold value bias?
)
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)

clear_output()

In [18]:
BATCH_SIZE = 80 # 80 for viz
SEQ_LEN = 61 # 61 for viz
batch_idx = 36

NEGATIVE_HEADS = [(10, 7), (11, 10)]

DATA_TOKS, DATA_STR_TOKS_PARSED = process_webtext(seed=6, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, model=model, verbose=True)

Shape = torch.Size([80, 61])

First prompt:
<|endoftext|>Oh boy was this damn hard to crack.

Ok, I believe before it was established before that Aperture Science headquarters are in Cleveland, OH.

Source: HL2EP2

Though, this has been found.

Source: Portal 2

It can be assumed


In [6]:
cspa_semantic_dict = pickle.load(open(ST_HTML_PATH.parent.parent / "cspa/cspa_semantic_dict_full.pkl", "rb"))

# Copy suppression condition

First, I'll ignore semantically related tokens, and just look at raw tokens.

Rather than a table, I actually think I'll do 2 violin plots, one for "should be CS" and one for "shouldn't be CS" (the values in each violin plot are attention). I can do this and a table pretty easily.

In [19]:
def get_data_for_table(
    toks: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    threshold_logit_lens: Optional[float],
    threshold_cs_classification: Optional[float],
    minibatch_size: Optional[int] = None,
    head: Tuple[int, int] = (10, 7),
    filter_for_BOS_not_largest: bool = False,
    title: Optional[str] = None,
    use_tuned_lens: bool = False,
):
    '''
    Do a hooked forward pass on the tokens, getting the attention probs and the logit lens.

    Args:

        threshold_logit_lens
            The logits must be above this threshold for the point to be classified as "we expect CS here". This is
            interpreted as the value we must be higher than in order to override attn to BOS. If None, we don't apply
            any thresholding here (whatever the max logit lens src token is, we expect CS there).

        threshold_cs_classification
            The attn prob must be above this threshold for the point to be classified as "CS happening here". If None,
            we don't apply any thresholding here (whatever the max attn prob is, we expect CS there).

    '''
    layer, head_idx = head
    batch_size, seq_len = toks.shape
    model.reset_hooks()
    FUNCTION_TOKS = model.to_tokens(FUNCTION_STR_TOKS, prepend_bos=False).squeeze()

    # Create external storage, which we concatenate to
    external_storage = {
        "pattern": t.empty((0, seq_len, seq_len), device=toks.device, dtype=t.float),
        "logit_lens_is_above_threshold": t.empty((0, seq_len, seq_len), device=toks.device, dtype=t.bool),
    }

    # Get tuned lens if necessary
    W_U = model.W_U
    if use_tuned_lens:
        path = Path("/root/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/ov_qk_circuits/params.pt")
        params: Dict[str, Tensor] = t.load(path)
        id = t.eye(model.cfg.d_model)
        tuned_lens = params["10.weight"].T
        W_U = (id + tuned_lens).to(device) @ W_U


    # ! First we define hook functions to do most of the work for us

    def hook_fn_cache_attn_probs(pattern: Float[Tensor, "batch nheads seqQ seqK"], hook: HookPoint):
        '''
        Caches attention probs for a single head.
        '''
        external_storage["pattern"] = t.concat([
            external_storage["pattern"],
            pattern[:, head_idx]
        ])


    


    def hook_fn_compute_logit_lens(resid_pre: Float[Tensor, "batch seqK d_model"], hook: HookPoint, _toks: Int[Tensor, "batch seqK"]):
        '''
        Computes logit lens at the residual stream before the head, and figures out which (b, sQ, sK) should
        have copy suppression activated and which shouldn't.
        '''
        _batch_size, _seq_len = _toks.shape

        logit_lens_for_src = einops.einsum(
            resid_pre,
            W_U.T[_toks],
            "batch seqQ d_model, batch seqK d_model -> batch seqQ seqK"
        )
        # logit_lens_for_src[b, sQ, sK] = logits for src token sK at destination position (b, sQ)

        # We need to apply causal mask
        seqQ = einops.repeat(t.arange(seq_len, device=_toks.device), "seqQ -> 1 seqQ 1")
        seqK = einops.repeat(t.arange(seq_len, device=_toks.device), "seqK -> 1 1 seqK")
        logit_lens_for_src.masked_fill_(seqQ < seqK, float("-inf"))
        # We mask the attention from any token to BOS, and function words
        logit_lens_for_src[:, :, 0] = float("-inf")
        is_fn_word = (_toks[:, :, None] == FUNCTION_TOKS[None, None, :]).any(dim=-1)
        is_fn_word = einops.repeat(is_fn_word, "batch seqK -> batch seqQ seqK", seqQ=_seq_len)
        logit_lens_for_src = t.where(is_fn_word, float("-inf"), logit_lens_for_src)
        
        if threshold_logit_lens is None:
            # We also want to mask all but the best tokens (because attention is limited, and we only care about the src
            # token we think should be getting the most attention)
            top_src_token_logit_values = logit_lens_for_src.max(dim=-1, keepdim=True).values
            is_top_src_token = logit_lens_for_src + 1e-8 >= top_src_token_logit_values
            logit_lens_for_src.masked_fill_(~is_top_src_token, float("-inf"))
            logit_lens_is_above_threshold: Bool[Tensor, "batch seqQ seqK"] = logit_lens_for_src > 1e-6
        else:
            # In this case, we filter for attention being above some value
            logit_lens_for_src.masked_fill_(~is_top_src_token, float("-inf"))
            logit_lens_is_above_threshold: Bool[Tensor, "batch seqQ seqK"] = logit_lens_for_src > threshold_logit_lens

        # Finally, we store these indices in our results dict
        external_storage["logit_lens_is_above_threshold"] = t.concat([
            external_storage["logit_lens_is_above_threshold"],
            logit_lens_is_above_threshold
        ])

    
    # ! Next we run a fwd pass, to activate these hook fns

    toks_for_fwd_pass = (toks,) if (minibatch_size is None) else toks.split(minibatch_size, dim=0)

    for _toks in toks_for_fwd_pass:
        model.run_with_hooks(
            _toks,
            return_type = None,
            fwd_hooks = [
                (utils.get_act_name("pattern", layer), hook_fn_cache_attn_probs),
                (utils.get_act_name("resid_pre", layer), partial(hook_fn_compute_logit_lens, _toks=_toks)),
            ]
        )
        model.reset_hooks()
        t.cuda.empty_cache()

    # ! Get the (non-BOS) attention, split by whether it's in the above_threshold_indices or not

    pattern: Float[Tensor, "batch seqQ seqK_m1"] = external_storage.pop("pattern")
    BOS_pattern = pattern[..., 0]
    pattern = pattern[..., 1:]
    logit_lens_is_above_threshold: Bool[Tensor, "batch seqQ seqK_m1"] = external_storage.pop("logit_lens_is_above_threshold")[..., 1:]
    assert pattern.shape == logit_lens_is_above_threshold.shape == (batch_size, seq_len, seq_len - 1)

    pattern_above_threshold = pattern[logit_lens_is_above_threshold]
    pattern_below_threshold = pattern[~logit_lens_is_above_threshold]

    print(f"Avg attn above threshold: {pattern_above_threshold.mean():.3f}")
    print(f"Avg attn below threshold: {pattern_below_threshold.mean():.3f}")

    if threshold_cs_classification is None:
        attn_is_classified_as_cs: Bool[Tensor, "batch seqQ seqK_m1"] = (pattern + 1e-6) > pattern.max(dim=-1, keepdim=True).values
    elif isinstance(threshold_cs_classification, int):
        attn_is_classified_as_cs: Bool[Tensor, "batch seqQ seqK_m1"] = (pattern + 1e-6) > pattern.topk(k=threshold_cs_classification, dim=-1).values[..., [-1]]
    elif isinstance(threshold_cs_classification, float):
        attn_is_classified_as_cs: Bool[Tensor, "batch seqQ seqK_m1"] = pattern > threshold_cs_classification
    elif isinstance(threshold_cs_classification, t.Tensor):
        threshold_cs_classification = einops.repeat(threshold_cs_classification, "seqQ -> batch seqQ seqK", batch=batch_size, seqK=seq_len-1).to(toks.device)
        attn_is_classified_as_cs: Bool[Tensor, "batch seqQ seqK_m1"] = pattern > threshold_cs_classification

    # ! Possibly filter for when BOS isn't attended to (these are more interesting)

    if filter_for_BOS_not_largest:
        BOS_not_largest = einops.repeat(
            BOS_pattern < pattern.max(dim=-1).values,
            "batch seqQ -> batch seqQ seqK_m1",
            seqK_m1=seq_len-1
        )
        logit_lens_is_above_threshold = logit_lens_is_above_threshold[BOS_not_largest]
        attn_is_classified_as_cs = attn_is_classified_as_cs[BOS_not_largest]

    # ! Display the actual table, and return the patterns so we can make histograms (later I'll put hist in this function)

    table = Table(
        "", "CS is expected (logit lens)", "CS not expected (logit lens)", "% we would have guessed correctly",
        title = "Confusion plot" if title is None else title
    )

    result_yy = (logit_lens_is_above_threshold & attn_is_classified_as_cs).int().sum().item()
    result_yn = (~logit_lens_is_above_threshold & attn_is_classified_as_cs).int().sum().item()
    result_ny = (logit_lens_is_above_threshold & ~attn_is_classified_as_cs).int().sum().item()
    result_nn = (~logit_lens_is_above_threshold & ~attn_is_classified_as_cs).int().sum().item()

    def format_fraction(numerator, denominator, bold=False):
        return f"[bold dark_orange]{numerator/denominator:.1%}" if bold else f"{numerator/denominator:.1%}"

    table.add_row("[b]CS is happening (attn is max)", str(result_yy), str(result_yn), format_fraction(result_yy, result_yy+result_yn))
    table.add_row("[b]CS not happening (attn not max)", str(result_ny), str(result_nn), format_fraction(result_nn, result_nn+result_ny))
    table.add_row("[b]% we would have guessed correctly", format_fraction(result_yy, result_yy+result_ny, bold=True), format_fraction(result_nn, result_nn+result_yn, bold=True), "")
    rprint(table)

    # * Deleting violin plot for the time being, because it's way too squashed at zero to present interesting information
    # df = pd.DataFrame({
    #     "Copy suppression expected (from logit lens)": ["Yes" if i else "No" for i in logit_lens_is_above_threshold.flatten()],
    #     "Attention probabilities": pattern.flatten().tolist(),
    # })
    # fig = px.violin(df, x="Copy suppression expected (from logit lens)", y="Attention probabilities", box=True, points="all")
    # fig.show()

#### First experiment - when we think sK will be maximally attended to, how often is it maximally attended to?

Answer - about 38% of the time we make correct classifications. So, there's a lot of times when CS is expected but it doesn't happen.

I should investigate the "Browse Examples" page to figure out whether we should actually have reasonably expected it in these circumstances. But first, I'll run a few more experiments.

In [20]:
get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = None,
    head = (10, 7),
    title = "Max logit lens & Max attn"
)

Avg attn above threshold: 0.075
Avg attn below threshold: 0.004


#### Second experiment - what if we only look at cases when BOS isn't attended to?

Result - much better! This suggests we understand well what happens when the head does shit, the only thing we don't understand is when it prefers to attend to BOS. This fits with my own personal model - the head has ways of switching off if it decides that attending to BOS isn't good.

In [21]:
get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = None,
    head = (10, 7),
    filter_for_BOS_not_largest = True, # removing all cases when BOS is the most attended to
    title = "Max logit lens & Max attn (filter for BOS not largest)",
)

Avg attn above threshold: 0.075
Avg attn below threshold: 0.004


: 

In [17]:
get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = None,
    head = (10, 7),
    filter_for_BOS_not_largest = True, # removing all cases when BOS is the most attended to
    title = "Max logit lens & Max attn (filter for BOS not largest)",
    use_tuned_lens = True,
)

Avg attn above threshold: 0.063
Avg attn below threshold: 0.002


#### Third experiment - top 3, not just top 1

Makes both the experiments above look better

In [137]:
get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = 3, # classify as CS if attn is in the top 3, rather than just the max
    head = (10, 7),
    title = "Max logit lens & Top-3 attn",
)

get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = 3, # classify as CS if attn is in the top 3, rather than just the max
    head = (10, 7),
    filter_for_BOS_not_largest = True,
    title = "Max logit lens & Top-3 attn (filter for BOS not largest)",
)

Avg attn above threshold: 0.075
Avg attn below threshold: 0.004


Avg attn above threshold: 0.075
Avg attn below threshold: 0.004


#### Fourth experiment - classify CS by prob absolute value, not prob relative rank

I'll classify smth as CS if it has 3x the uniform probability of $1/N$ (where $N$ is sequence length).

Result - also pretty excellent.

In [138]:
get_data_for_table(
    toks = DATA_TOKS,
    model = model,
    threshold_logit_lens = None,
    threshold_cs_classification = 5 / t.arange(1, SEQ_LEN+1),
    head = (10, 7),
    filter_for_BOS_not_largest = True, # removing all cases when BOS is the most attended to
    title = "Max logit lens & Max attn (filter for BOS not largest)"
)

Avg attn above threshold: 0.075
Avg attn below threshold: 0.004


# Spearman

Another cool idea - measure spearman corrcoef between our expected rank orderings of the attn score (just based on logit lens), and the actual rank orderings. This is an elegant way to ignore BOS.

In [31]:
a = t.rand(3, 4)
a

tensor([[0.5441, 0.9387, 0.7216, 0.2433],
        [0.6815, 0.8348, 0.1832, 0.1904],
        [0.4015, 0.7922, 0.9931, 0.1871]])

In [32]:
a.argsort(dim=-1)

tensor([[3, 0, 2, 1],
        [2, 3, 0, 1],
        [3, 0, 1, 2]])

In [86]:
def rankdata(tensor):
    """Compute the ranks of a tensor."""
    # Argsort the tensor
    sorted_indices = t.argsort(tensor)
    
    # Create an empty tensor for ranks
    ranks = t.zeros_like(tensor)
    
    # Assign ranks
    rank = 1
    for i in sorted_indices:
        ranks[i] = rank
        rank += 1
        
    return ranks

def spearman_rank_correlation_coefficient(tensor1, tensor2):
    """Compute the Spearman Rank Correlation Coefficient between two tensors."""
    assert tensor1.shape == tensor2.shape, "Tensors must have the same shape"
    
    # Get the ranks of each tensor
    rank1 = rankdata(tensor1)
    rank2 = rankdata(tensor2)
    
    # Calculate the difference between the ranks
    d = rank1 - rank2
    
    # Compute the Spearman Rank Correlation Coefficient
    n = tensor1.numel()
    rs = 1 - (6 * t.sum(d**2)) / (n * (n**2 - 1))
    
    return rs.item()


tensor1 = torch.tensor([3.1, 2.3, 9.5, 4.1], dtype=torch.float32)
tensor2 = torch.tensor([1.5, 2.1, 10.2, 3.9], dtype=torch.float32) # quite similar, not perfect (zeroth is smaller)

print(spearman_rank_correlation_coefficient(tensor1, tensor2))
print(spearman_rank_correlation_coefficient(tensor1, tensor1))
print(spearman_rank_correlation_coefficient(tensor1, -tensor1))

0.800000011920929
1.0
-1.0


In [95]:
def spearman_experiment(
    toks: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    minibatch_size: Optional[int],
    head: Tuple[int, int] = (10, 7),
):
    layer, head_idx = head
    batch_size, seq_len = toks.shape
    model.reset_hooks()
    FUNCTION_TOKS = model.to_tokens(FUNCTION_STR_TOKS, prepend_bos=False).squeeze()
    buffer = 5

    # Create external storage, which we append to
    SPEARMAN_LIST = []
    # SPEARMAN_BASELINE_LIST = []

    # ! First we define hook functions to do most of the work for us

    attn_scores_hook_name = utils.get_act_name("attn_scores", layer)
    resid_pre_hook_name = utils.get_act_name("resid_pre", layer)

    progress_bar = tqdm(total=batch_size * (seq_len - buffer - 1), desc="Computing Spearman")

    def hook_fn_compute_spearman(attn_scores: Float[Tensor, "batch nheads seqQ seqK"], hook: HookPoint, _toks: Int[Tensor, "batch seqK"]):
        '''
        We compute spearman here, because the ranks for logit lens have already been calculated & stored in the hook for resid_pre.
        '''
        _batch_size, _seq_len = _toks.shape
        
        # Mask to remove function words (true where we want to keep values)
        sK_mask: Bool[Tensor, "batch seqK"] = (FUNCTION_TOKS[None, None, :] != _toks[:, 1:, None]).all(dim=-1)
        # And for causality
        sK_mask = einops.repeat(sK_mask, "batch seqK -> batch seqQ seqK", seqQ = _seq_len-1)
        sK_indices = einops.repeat(t.arange(1, _seq_len, device=_toks.device), "seqK -> 1 1 seqK") # first sK isn't BOS, it's BOS+1
        sQ_indices = einops.repeat(t.arange(1, _seq_len, device=_toks.device), "seqQ -> 1 seqQ 1")
        sK_mask = t.where(sQ_indices < sK_indices, sK_mask, False)

        head_score = attn_scores[:, head_idx, 1:, 1:]
        logit_lens: Tensor = model.hook_dict[resid_pre_hook_name].ctx.pop("logit_lens")
        assert logit_lens.shape == head_score.shape == (_batch_size, _seq_len-1, _seq_len-1)

        for b in range(_batch_size):
            # To be fair, we're starting queries with at least 5 elements (noisy and overweighting weird shit otherwise)
            for sQ in range(buffer, _seq_len-1):
                _head_score = head_score[b, sQ, sK_mask[b, sQ]]
                _logit_lens = logit_lens[b, sQ, sK_mask[b, sQ]]
                SPEARMAN_LIST.append(spearman_rank_correlation_coefficient(_head_score, _logit_lens))
                # # only calculate baseline random sparingly, to save time
                # if b == 0:
                #     SPEARMAN_BASELINE_LIST.append(spearman_rank_correlation_coefficient(_head_score, t.rand_like(_logit_lens)))
                progress_bar.update()


    def hook_fn_compute_logit_lens(resid_pre: Float[Tensor, "batch seqK d_model"], hook: HookPoint, _toks: Int[Tensor, "batch seqK"]):
        '''
        Computes logit lens at the residual stream before the head, and figures out which (b, sQ, sK) should
        have copy suppression activated and which shouldn't.
        '''
        logit_lens_for_src = einops.einsum(
            resid_pre[:, 1:],
            model.W_U.T[_toks[:, 1:]], # ignore BOS
            "batch seqQ d_model, batch seqK d_model -> batch seqQ seqK"
        )
        hook.ctx["logit_lens"] = logit_lens_for_src
     
    
    # ! Next we run a fwd pass, to activate these hook fns

    toks_for_fwd_pass = (toks,) if (minibatch_size is None) else toks.split(minibatch_size, dim=0)

    for _toks in toks_for_fwd_pass:
        model.run_with_hooks(
            _toks,
            return_type = None,
            fwd_hooks = [
                (attn_scores_hook_name, partial(hook_fn_compute_spearman, _toks=_toks)),
                (resid_pre_hook_name, partial(hook_fn_compute_logit_lens, _toks=_toks)),
            ]
        )
        model.reset_hooks()
        t.cuda.empty_cache()

    # ! Print results (including random spearman baseline)

    SPEARMAN_LIST = t.tensor(SPEARMAN_LIST)[~t.isnan(t.tensor(SPEARMAN_LIST))]
    # SPEARMAN_BASELINE_LIST = t.tensor(SPEARMAN_BASELINE_LIST)[~t.isnan(t.tensor(SPEARMAN_BASELINE_LIST))]

    print(f"Avg spearman correlation coefficient = {SPEARMAN_LIST.mean()}")
    # print(f"Baseline: random spearman = {SPEARMAN_BASELINE_LIST.mean()}")

    return SPEARMAN_LIST

In [96]:
SPEARMAN_LIST = spearman_experiment(
    toks = DATA_TOKS,
    model = model,
    minibatch_size = 10,
    head = (10, 7),
)



Computing Spearman: 100%|██████████| 4400/4400 [00:03<00:00, 1120.55it/s]

Avg spearman correlation coefficient = 0.013215672224760056



