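This notebook measures the overall effect of ablating each residual-stream component. The component is swapped between the context prompt (`c`) and the two no-context prompts (`nc0`, `nc1`). Results are saved as `data/counterfact_ablation_{SAVED_NAME}.pkl`. The input is the dataframe generated by `df_inference.ipynb` (`data/counterfact_inference_{SAVED_NAME}.csv`).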

### Imports

In [15]:
# Automatically reload edited modules (e.g. the project's `pii` package) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm.auto as tqdm
import transformer_lens.utils as tl_utils
from jaxtyping import Float
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from pii import utils, datasets, vocab, decomp

### Load model

In [17]:
# Pick the model to run on: one of the keys of MODEL_CONFIGS below
# (meta-llama/Llama-2-7b-chat-hf, -13b-chat-hf or -70b-chat-hf).
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Per-model settings:
#   SAVED_NAME  - short name used in the data file names read/written below
#   N_DEVICES   - number of CUDA devices to shard the model across
#   BATCH_SIZE  - facts per forward pass (each fact contributes 3 prompts)
MODEL_CONFIGS = {
    "meta-llama/Llama-2-7b-chat-hf": ("llama2_7b", 1, 4),
    "meta-llama/Llama-2-13b-chat-hf": ("llama2_13b", 1, 1),
    "meta-llama/Llama-2-70b-chat-hf": ("llama2_70b", 2, 1),
}

# Fail fast here instead of with a NameError several cells later.
if MODEL_NAME not in MODEL_CONFIGS:
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected one of {sorted(MODEL_CONFIGS)}"
    )
SAVED_NAME, N_DEVICES, BATCH_SIZE = MODEL_CONFIGS[MODEL_NAME]

In [18]:
# The Llama-2 checkpoints are gated: authenticate once with
#   huggingface-cli login
# before running this cell.
HF_LOAD_KWARGS = dict(torch_dtype=torch.float16, low_cpu_mem_usage=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **HF_LOAD_KWARGS)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [19]:
# Wrap the already-loaded HF model in TransformerLens so activations can be cached.
tl_model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=hf_model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device="cuda",
    n_devices=N_DEVICES,
    move_to_device=True,
    # Disable TransformerLens' optional weight processing (LayerNorm folding,
    # value-bias folding, centering) and keep the checkpoint weights as-is.
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Greedy-decoding sanity check that the converted model produces sensible text.
with torch.no_grad():
    sanity_completion = tl_model.generate(
        "The capital of Germany is", max_new_tokens=20, temperature=0
    )
print(sanity_completion)

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions


In [20]:
# Per-fact inference results for this model, as written by df_inference.ipynb.
# NOTE(review): the "Unnamed: 0" column in the preview looks like a saved
# index; `index_col=0` would drop it, but downstream readers of the saved
# pickle may expect the column — confirm before changing.
inference_csv = utils.get_repo_root() / "data" / f"counterfact_inference_{SAVED_NAME}.csv"
df_inference = pd.read_csv(inference_csv)
df_inference.head()

Unnamed: 0,case_id,pararel_idx,relation_id,subject,target_new_str,target_true_str,fact_prefix,irrelevant_word,prompt_c,prompt_nc0,prompt_nc1,p_correct_c,p_correct_nc0,p_correct_nc1,lo_correct_c,lo_correct_nc0,lo_correct_nc1,log_bf0,log_bf1
0,3348,10929,P495,Pinoy Idol,Norway,Philippines,"Pinoy Idol, that was from",fox,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.00042,0.376995,0.010282,-7.775841,-0.502321,-4.566983,-7.27352,-3.208858
1,6515,6929,P17,Guillaumes,Norway,France,"Guillaumes, which is located in",kiwi,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.000837,0.7536,0.956851,-7.085032,1.117907,3.098995,-8.202939,-10.184027
2,3342,8551,P27,Howe Yoon Chong,Belgium,Singapore,"Howe Yoon Chong, who has a citizenship from",dog,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.010512,0.582458,0.809956,-4.544692,0.332871,1.449727,-4.877562,-5.994419
3,8494,9273,P463,Mexico national football team,Hamas,FIFA,Mexico national football team is a member of,toothbrush,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.000205,0.091492,0.070407,-8.494187,-2.295549,-2.580447,-6.198638,-5.91374
4,20754,10478,P495,The Last Hunter,Japan,Italy,The Last Hunter originated in,cloud,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,9.4e-05,0.003586,0.000136,-9.276137,-5.627262,-8.902347,-3.648875,-0.37379


In [21]:
# Keep only facts where
#   (a) the model gives the correct answer probability > MIN_P_CORRECT_NC on
#       BOTH no-context prompts, and
#   (b) the context prompt lowers the log-odds of the correct answer relative
#       to BOTH no-context prompts by more than a factor of 10**-MAX_LOG10_BF.
# log_bf{0,1} = lo_correct_c - lo_correct_nc{0,1} are natural-log quantities
# (as the preview rows show), so divide by ln(10) to threshold in log10 units.
MIN_P_CORRECT_NC = 0.5
MAX_LOG10_BF = -2.0

correct_without_context = (
    np.minimum(df_inference.p_correct_nc0, df_inference.p_correct_nc1)
    > MIN_P_CORRECT_NC
)
context_suppresses_answer = (
    np.maximum(df_inference.log_bf0, df_inference.log_bf1) / np.log(10)
    < MAX_LOG10_BF
)
filt = correct_without_context & context_suppresses_answer

df = df_inference[filt].reset_index(drop=True)
print(f"Retained {len(df)} / {len(df_inference)} facts")
# With no facts left, the ablation cells below would produce no metric
# columns and fail with an opaque AttributeError; fail here instead.
assert len(df) > 0, "No facts pass the filter; loosen MIN_P_CORRECT_NC / MAX_LOG10_BF"

Retained 2645 / 14002 facts


### Importance of components (first-order ablation)

In [22]:
# Scoring helper used below as `VEM.p_correct(probs, correct_answer=...)` to get
# the probability assigned to the correct answer string (presumably
# aggregating over equivalent token spellings — TODO confirm in pii.vocab).
VEM = vocab.VocabEquivalenceMap(tl_model)

In [31]:
# First-order ablations. For each fact, every residual-stream component at
# the last prompt position is swapped, one at a time, between the context
# run (c) and each no-context run (nc0, nc1). The log-odds of the true answer
# are recorded after each swap. A "dynamic" (_dcum) variant then applies the
# swaps cumulatively, in a per-fact order ranked by the single-swap effect.
LABELS: list[str] | None = None
metrics = []


def get_probs(reses):
    """Return `VEM.p_correct` for final-layer residual vectors `reses`.

    The residuals go through `tl_model.ln_final` and `tl_model.W_U`, the
    logits are softmaxed in float64, and the result is scored against the
    module-level `answer_str` as bound at call time (the loop below rebinds
    it for every fact). Defined at module level so later cells can reuse it.
    """
    unembedded = tl_model.ln_final(reses) @ tl_model.W_U
    return VEM.p_correct(
        unembedded.double().softmax(dim=-1), correct_answer=answer_str
    )


def to_logodds(probs):
    """Convert a probability tensor to a float32 numpy array of log-odds."""
    return probs.logit().float().cpu().numpy()


for batch_start in tqdm(range(0, len(df), BATCH_SIZE)):
    batch_stop = min(batch_start + BATCH_SIZE, len(df))

    # Three prompts per fact, laid out as [c, nc0, nc1] for each fact in turn.
    prompts = [
        prompt
        for row in range(batch_start, batch_stop)
        for prompt in (df.prompt_c[row], df.prompt_nc0[row], df.prompt_nc1[row])
    ]

    with torch.no_grad():
        # Only the activation cache is used below.
        cache = tl_model.run_with_cache(prompts)[1]

    for offset in range(batch_stop - batch_start):
        row = batch_start + offset
        # Batch indices of this fact's c / nc0 / nc1 prompts, and the index of
        # each prompt's last token.
        batch_ids = [3 * offset + k for k in range(3)]
        last_pos = [len(tl_model.to_tokens(prompts[b])[0]) - 1 for b in batch_ids]
        answer_str = df.target_true_str[row]

        with torch.no_grad():
            lc_c, lc_nc0, lc_nc1 = [
                decomp.get_all_resid_components(
                    tl_model=tl_model, cache=cache, pos=p, batch_idx=b
                )
                for b, p in zip(batch_ids, last_pos)
            ]
            LABELS = lc_c.labels

            final_resid = cache["resid_post", -1]
            res_c, res_nc0, res_nc1 = [
                final_resid[b, p] for b, p in zip(batch_ids, last_pos)
            ]

            # direction -> (residual being patched, per-component swap delta).
            # "c_nc0" starts from the c run and swaps in nc0's components, etc.
            patches = {
                "c_nc0": (res_c, lc_nc0.components - lc_c.components),
                "c_nc1": (res_c, lc_nc1.components - lc_c.components),
                "nc0_c": (res_nc0, lc_c.components - lc_nc0.components),
                "nc1_c": (res_nc1, lc_c.components - lc_nc1.components),
            }

            # Single-component swaps: one prediction per component.
            single = {
                name: get_probs(base[None, :] + delta)
                for name, (base, delta) in patches.items()
            }

            # Per-fact orderings for the dynamic cumulative ablations:
            # c -> nc by largest log-odds gain first, nc -> c by largest loss first.
            ord_c_nc = torch.argsort(
                single["c_nc0"].logit()
                + single["c_nc1"].logit()
                - 2 * torch.tensor(df.p_correct_c[row]).logit(),
                descending=True,
            )
            ord_nc_c = torch.argsort(
                single["nc0_c"].logit()
                + single["nc1_c"].logit()
                - torch.tensor(df.p_correct_nc0[row]).logit()
                - torch.tensor(df.p_correct_nc1[row]).logit(),
                descending=False,
            )
            order_of = {
                "c_nc0": ord_c_nc,
                "c_nc1": ord_c_nc,
                "nc0_c": ord_nc_c,
                "nc1_c": ord_nc_c,
            }

            # Row k of each cumulative sum applies the first k + 1 swaps.
            dynamic = {
                name: get_probs(
                    base[None, :] + torch.cumsum(delta[order_of[name]], dim=0)
                )
                for name, (base, delta) in patches.items()
            }

        row_metrics = {f"lo_correct_{name}": to_logodds(p) for name, p in single.items()}
        row_metrics.update(
            {f"lo_correct_{name}_dcum": to_logodds(p) for name, p in dynamic.items()}
        )
        metrics.append(row_metrics)

df = df.assign(**pd.DataFrame(metrics))
df.attrs['LABELS'] = LABELS

  0%|          | 0/662 [00:00<?, ?it/s]

In [34]:
# Flatten the per-fact component arrays into (2 * n_facts, n_components)
# matrices: the nc0 pairing of every fact comes first, then the nc1 pairing.
lo_correct_nc_c = np.concatenate(
    [np.stack(df[col]) for col in ("lo_correct_nc0_c", "lo_correct_nc1_c")]
)
lo_correct_c_nc = np.concatenate(
    [np.stack(df[col]) for col in ("lo_correct_c_nc0", "lo_correct_c_nc1")]
)
print(lo_correct_c_nc.shape, lo_correct_nc_c.shape)

# Unablated baselines aligned with the rows above. Both pairings share the
# same c baseline, so it is repeated twice.
lo_correct_c = np.tile(df.lo_correct_c.to_numpy(), 2)
lo_correct_nc = np.concatenate(
    [df[col].to_numpy() for col in ("lo_correct_nc0", "lo_correct_nc1")]
)
print(lo_correct_c.shape, lo_correct_nc.shape)

# Log Bayes factor of each single-component swap relative to its unablated run.
log_bf_c_nc = lo_correct_c_nc - lo_correct_c.reshape(-1, 1)
log_bf_nc_c = lo_correct_nc_c - lo_correct_nc.reshape(-1, 1)
print(log_bf_c_nc.shape, log_bf_nc_c.shape)

(5290, 1057) (5290, 1057)
(5290,) (5290,)
(5290, 1057) (5290, 1057)


### Cumulative ablations

In [35]:
# Dataset-wide component orderings for the static cumulative ablations, ranked
# by the mean single-swap log Bayes factor over all (fact, nc-variant) rows:
#   nc -> c: most negative first (swap lowers the true answer's log-odds most)
#   c -> nc: most positive first (swap raises the true answer's log-odds most)
mean_log_bf_nc_c = torch.tensor(log_bf_nc_c.mean(axis=0))
mean_log_bf_c_nc = torch.tensor(log_bf_c_nc.mean(axis=0))
ordering_nc_c = mean_log_bf_nc_c.argsort(descending=False)
ordering_c_nc = mean_log_bf_c_nc.argsort(descending=True)

print(ordering_nc_c[:10])
print(ordering_c_nc[:10])

tensor([ 604, 1056,  921,  868,  838,  858,  651,  747,  779,  932])
tensor([1056,  604,  921,  868,  838,  858,  651,  932,  747, 1051])


In [36]:
# Static cumulative ablations: for every fact, apply the component swaps
# cumulatively in the dataset-wide orders `ordering_c_nc` / `ordering_nc_c`
# (computed in the previous cell), and record the log-odds of the true answer
# after each prefix of swaps.


def get_cum_probs(reses, correct_answer):
    """Return `VEM.p_correct` for final-layer residual vectors `reses`.

    The residuals go through `tl_model.ln_final` and the unembedding
    `tl_model.W_U`, and the logits are softmaxed in float64 before scoring
    against `correct_answer`. The answer is passed explicitly so this cell
    does not depend on helpers or globals left over from other cells.
    """
    unembedded = tl_model.ln_final(reses) @ tl_model.W_U
    return VEM.p_correct(
        unembedded.double().softmax(dim=-1), correct_answer=correct_answer
    )


def to_logodds(probs):
    """Convert a probability tensor to a float32 numpy array of log-odds."""
    return probs.logit().float().cpu().numpy()


metrics = []

for idx_start in tqdm(range(0, len(df), BATCH_SIZE)):
    idx_end = min(idx_start + BATCH_SIZE, len(df))

    # Three prompts per fact, laid out as [c, nc0, nc1] for each fact in turn.
    prompts = []
    for idx in range(idx_start, idx_end):
        prompts.extend([df.prompt_c[idx], df.prompt_nc0[idx], df.prompt_nc1[idx]])

    with torch.no_grad():
        # Only the activation cache is needed; do not keep the logits alive.
        cache = tl_model.run_with_cache(prompts)[1]

    for i in range(idx_end - idx_start):
        # Token count of each of this fact's three prompts; the last token
        # sits at n_tokens[j] - 1.
        n_tokens = [
            len(tl_model.to_tokens(prompts[3 * i + j])[0]) for j in range(3)
        ]
        answer_str = df.target_true_str[idx_start + i]

        with torch.no_grad():
            lc_c, lc_nc0, lc_nc1 = (
                decomp.get_all_resid_components(
                    tl_model=tl_model,
                    cache=cache,
                    pos=n_tokens[j] - 1,
                    batch_idx=3 * i + j,
                )
                for j in range(3)
            )

            res_final_c, res_final_nc0, res_final_nc1 = (
                cache["resid_post", -1][3 * i + j, n_tokens[j] - 1]
                for j in range(3)
            )

            # Swap deltas: "c_nc0" starts from the c run and swaps in nc0's
            # components; "nc0_c" starts from nc0 and swaps in c's, etc.
            delta_c_nc0 = lc_nc0.components - lc_c.components
            delta_c_nc1 = lc_nc1.components - lc_c.components
            delta_nc0_c = lc_c.components - lc_nc0.components
            delta_nc1_c = lc_c.components - lc_nc1.components

            # Row k of each cumulative sum applies the first k + 1 swaps in
            # the fixed dataset-wide order.
            delta_c_nc0_cum = torch.cumsum(delta_c_nc0[ordering_c_nc], dim=0)
            delta_c_nc1_cum = torch.cumsum(delta_c_nc1[ordering_c_nc], dim=0)
            delta_nc0_c_cum = torch.cumsum(delta_nc0_c[ordering_nc_c], dim=0)
            delta_nc1_c_cum = torch.cumsum(delta_nc1_c[ordering_nc_c], dim=0)

            p_correct_c_nc0_cum = get_cum_probs(
                res_final_c[None, :] + delta_c_nc0_cum, answer_str
            )
            p_correct_c_nc1_cum = get_cum_probs(
                res_final_c[None, :] + delta_c_nc1_cum, answer_str
            )
            p_correct_nc0_c_cum = get_cum_probs(
                res_final_nc0[None, :] + delta_nc0_c_cum, answer_str
            )
            p_correct_nc1_c_cum = get_cum_probs(
                res_final_nc1[None, :] + delta_nc1_c_cum, answer_str
            )

        metrics.append(
            dict(
                lo_correct_c_nc0_cum=to_logodds(p_correct_c_nc0_cum),
                lo_correct_c_nc1_cum=to_logodds(p_correct_c_nc1_cum),
                lo_correct_nc0_c_cum=to_logodds(p_correct_nc0_c_cum),
                lo_correct_nc1_c_cum=to_logodds(p_correct_nc1_c_cum),
            )
        )

df = df.assign(**pd.DataFrame(metrics))

  0%|          | 0/662 [00:00<?, ?it/s]

In [37]:
# Flatten the static (_cum) and dynamic (_dcum) cumulative results the same
# way as the single-swap results. Stacking every fact's array from the nc0
# column, followed by every fact's array from the nc1 column, gives a
# (2 * n_facts, n_components) matrix.
lo_correct_nc_c_cum = np.stack([*df["lo_correct_nc0_c_cum"], *df["lo_correct_nc1_c_cum"]])
lo_correct_c_nc_cum = np.stack([*df["lo_correct_c_nc0_cum"], *df["lo_correct_c_nc1_cum"]])
lo_correct_nc_c_dcum = np.stack([*df["lo_correct_nc0_c_dcum"], *df["lo_correct_nc1_c_dcum"]])
lo_correct_c_nc_dcum = np.stack([*df["lo_correct_c_nc0_dcum"], *df["lo_correct_c_nc1_dcum"]])
print(lo_correct_c_nc_cum.shape, lo_correct_nc_c_cum.shape)
print(lo_correct_c_nc_dcum.shape, lo_correct_nc_c_dcum.shape)

# Unablated baselines (lo_correct_c / lo_correct_nc from the single-swap
# cell), broadcast across the component axis.
baseline_c = lo_correct_c[:, np.newaxis]
baseline_nc = lo_correct_nc[:, np.newaxis]
log_bf_c_nc_cum = lo_correct_c_nc_cum - baseline_c
log_bf_nc_c_cum = lo_correct_nc_c_cum - baseline_nc
log_bf_c_nc_dcum = lo_correct_c_nc_dcum - baseline_c
log_bf_nc_c_dcum = lo_correct_nc_c_dcum - baseline_nc
print(log_bf_c_nc_cum.shape, log_bf_nc_c_cum.shape)
print(log_bf_c_nc_dcum.shape, log_bf_nc_c_dcum.shape)

(5290, 1057) (5290, 1057)
(5290, 1057) (5290, 1057)
(5290, 1057) (5290, 1057)
(5290, 1057) (5290, 1057)


### Save dataframe

In [38]:
# Persist the filtered facts together with all ablation metric columns.
ablation_path = utils.get_repo_root() / "data" / f"counterfact_ablation_{SAVED_NAME}.pkl"
df.to_pickle(ablation_path)