In this notebook, we compute the overall effect of ablating each model component and save the results as `{DATASET_NAME}_ablation_{SAVED_NAME}.pkl` (e.g. `counterfact_ablation_llama2_70b.pkl`), starting from the DataFrame produced by `df_inference.ipynb`.

### Imports

In [1]:
# Automatically reload edited local modules (e.g. the `pii` package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizerFast

from pii import utils, vocab, decomp_hf

### Load model

In [3]:
# Pick the model to run on. Supported values:
#   "meta-llama/Llama-2-7b-chat-hf"
#   "meta-llama/Llama-2-13b-chat-hf"
#   "meta-llama/Llama-2-70b-chat-hf"
MODEL_NAME = "meta-llama/Llama-2-70b-chat-hf"

# Pick the dataset to run on: "counterfact" or "hqfact".
DATASET_NAME = "counterfact"

# Per-model settings:
#   (short name used in data file names, facts per batch, weight dtype,
#    `device_map` passed to `from_pretrained`).
# Each fact contributes 3 prompts (c, nc0, nc1), so a forward batch holds
# 3 * BATCH_SIZE sequences.
MODEL_CONFIGS = {
    "meta-llama/Llama-2-7b-chat-hf": ("llama2_7b", 8, torch.float16, None),
    "meta-llama/Llama-2-13b-chat-hf": ("llama2_13b", 2, torch.float16, "auto"),
    "meta-llama/Llama-2-70b-chat-hf": ("llama2_70b", 16, torch.float16, "auto"),
}

# Fail fast on a typo instead of hitting a NameError several cells later.
if MODEL_NAME not in MODEL_CONFIGS:
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected one of {sorted(MODEL_CONFIGS)}"
    )
if DATASET_NAME not in ("counterfact", "hqfact"):
    raise ValueError(
        f"Unsupported DATASET_NAME {DATASET_NAME!r}; expected 'counterfact' or 'hqfact'"
    )

SAVED_NAME, BATCH_SIZE, DTYPE, DEVICE_MAP = MODEL_CONFIGS[MODEL_NAME]

In [4]:
# You will need to log in to Hugging Face first:
#   huggingface-cli login
tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_NAME)
# Use EOS as the padding token so prompts of different lengths can be batched.
tokenizer.pad_token_id = tokenizer.eos_token_id
# The ablation cells read the residual stream at index `n_tokens - 1` of each
# padded row, which is the last real token only under right padding. Set it
# explicitly rather than relying on the tokenizer's configured default.
tokenizer.padding_side = "right"

hf_model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=DTYPE,
    device_map=DEVICE_MAP,
).eval()
# Without a device_map, place the whole model on the GPU explicitly.
if DEVICE_MAP is None:
    hf_model = hf_model.cuda()

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [5]:
# Quick sanity check that the loaded model produces a sensible completion.
sanity_completion = utils.gen_text(
    hf_model,
    tokenizer,
    prompt="The capital of Germany is",
    max_new_tokens=22,
)
print(sanity_completion)

Berlin.

The currency of Germany is the Euro.

The country code for Germany is DE.


In [6]:
# Load the per-fact inference results produced by df_inference.ipynb.
filename = f"{DATASET_NAME}_inference_{SAVED_NAME}.csv"
inference_path = utils.get_repo_root() / "data" / "inference" / filename
df_inference = pd.read_csv(inference_path)
df_inference.head()

Unnamed: 0,case_id,pararel_idx,relation_id,subject,target_new_str,target_true_str,fact_prefix,irrelevant_word,prompt_c,prompt_nc0,prompt_nc1,p_correct_c,p_correct_nc0,p_correct_nc1,lo_correct_c,lo_correct_nc0,lo_correct_nc1,log_bf0,log_bf1
0,3348,10929,P495,Pinoy Idol,Norway,Philippines,"Pinoy Idol, that was from",fox,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.000968,0.063472,0.003805,-6.938923,-2.691584,-5.567579,-4.247339,-1.371345
1,6515,6929,P17,Guillaumes,Norway,France,"Guillaumes, which is located in",kiwi,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,9e-06,0.71652,0.316847,-11.650864,0.927266,-0.768299,-12.57813,-10.882565
2,3342,8551,P27,Howe Yoon Chong,Belgium,Singapore,"Howe Yoon Chong, who has a citizenship from",dog,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.00136,0.966347,0.803121,-6.598834,3.357422,1.405915,-9.956256,-8.00475
3,8494,9273,P463,Mexico national football team,Hamas,FIFA,Mexico national football team is a member of,toothbrush,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,0.000634,0.984641,0.913474,-7.363537,4.160598,2.356807,-11.524135,-9.720343
4,20754,10478,P495,The Last Hunter,Japan,Italy,The Last Hunter originated in,cloud,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,2.3e-05,0.001977,0.000519,-10.688438,-6.223971,-7.563049,-4.464467,-3.125389


In [7]:
# Keep only facts where the model assigns > MIN_P_CORRECT_NC probability to the
# correct answer under BOTH nc prompts (prompt_nc0 and prompt_nc1).
# skipna=False keeps a NaN in either column as NaN, so such rows are dropped.
MIN_P_CORRECT_NC = 0.5
filt = (
    df_inference[["p_correct_nc0", "p_correct_nc1"]].min(axis=1, skipna=False)
    > MIN_P_CORRECT_NC
)
df = df_inference.loc[filt].reset_index(drop=True)
print(f"Retained {len(df)} / {len(df_inference)} facts")

Retained 4251 / 14002 facts


### Importance of components (first-order ablation)

In [8]:
# Answer scorer used below via `VEM.p_correct(probs, correct_answer=...)` to get
# the probability of the correct answer from a next-token distribution.
# Presumably it groups vocabulary tokens that spell the same answer — TODO
# confirm in pii.vocab.
VEM = vocab.VocabEquivalenceMap(tokenizer=tokenizer)

In [9]:
# First-order ablations. For every residual-stream component (one per entry of
# LABELS), replace that single component's contribution at the last prompt
# token of one prompt with its contribution from the paired prompt, and record
# the log-odds of the correct answer. Directions:
#   c_nc0 / c_nc1: start from the c prompt, swap in the nc0 / nc1 component.
#   nc0_c / nc1_c: start from the nc0 / nc1 prompt, swap in the c component.
# The "_dcum" variants swap components cumulatively, in a per-fact ("dynamic")
# order ranked by that fact's single-component effects.
LABELS: list[str] | None = None
metrics = []


def get_probs(reses: torch.Tensor, correct_answer: str | None = None) -> torch.Tensor:
    """Probability of the correct answer for each row of final residual states.

    Each row of `reses` goes through the model's final norm and unembedding;
    the softmax is taken in float64 before `VEM.p_correct` scores it.
    `correct_answer` defaults to the notebook-level `answer_str` so existing
    one-argument calls keep working; prefer passing it explicitly.
    """
    if correct_answer is None:
        correct_answer = answer_str
    return VEM.p_correct(
        hf_model.lm_head(hf_model.model.norm(reses)).double().softmax(dim=-1),
        correct_answer=correct_answer,
    ).to(hf_model.device)


# Convert a tensor of probabilities to a float32 numpy array of log-odds.
to_logodds = lambda probs: probs.logit().float().cpu().numpy()


pbar = tqdm(range(0, len(df), BATCH_SIZE))
for idx_start in pbar:
    idx_end = min(idx_start + BATCH_SIZE, len(df))

    # Three prompts per fact, flattened as [c, nc0, nc1, c, nc0, nc1, ...].
    prompts = [
        prompt
        for idx in range(idx_start, idx_end)
        for prompt in (df.prompt_c[idx], df.prompt_nc0[idx], df.prompt_nc1[idx])
    ]
    # Tokenize the batch once; the attention mask gives each prompt's real
    # (non-padding) length, so prompts need not be re-tokenized one by one.
    encoded = tokenizer(prompts, padding=True, return_tensors="pt")
    prompt_tokens = encoded.input_ids.to(hf_model.device)
    n_tokens_all = encoded.attention_mask.sum(dim=1).tolist()

    with torch.no_grad():
        # Decompose the residual stream at each prompt's last real token
        # (index n - 1, which assumes right padding).
        comp_list = decomp_hf.get_all_resid_components_hf(
            model=hf_model,
            tokens=prompt_tokens,
            resid_pos_idxs=[n - 1 for n in n_tokens_all],
        )

    for i in range(idx_end - idx_start):
        row = idx_start + i
        answer_str = df.target_true_str[row]
        lc_c, lc_nc0, lc_nc1 = comp_list[3 * i : 3 * i + 3]
        LABELS = lc_c.labels

        with torch.no_grad():
            res_final_c = lc_c.resid_post
            res_final_nc0 = lc_nc0.resid_post
            res_final_nc1 = lc_nc1.resid_post

            # Row k of each delta replaces component k of the base prompt with
            # component k of the other prompt (assuming resid_post is the sum
            # of the components — TODO confirm in decomp_hf).
            delta_c_nc0 = lc_nc0.components - lc_c.components
            delta_c_nc1 = lc_nc1.components - lc_c.components
            delta_nc0_c = lc_c.components - lc_nc0.components
            delta_nc1_c = lc_c.components - lc_nc1.components

            # Single-component ablations: one probability per component.
            p_correct_c_nc0 = get_probs(res_final_c[None, :] + delta_c_nc0, answer_str)
            p_correct_c_nc1 = get_probs(res_final_c[None, :] + delta_c_nc1, answer_str)
            p_correct_nc0_c = get_probs(res_final_nc0[None, :] + delta_nc0_c, answer_str)
            p_correct_nc1_c = get_probs(res_final_nc1[None, :] + delta_nc1_c, answer_str)

            # Dynamic cumulative ablations: rank components by this fact's
            # single-component change in log-odds, summed over both nc prompts.
            # c -> nc: largest increase first; nc -> c: largest decrease first.
            ord_c_nc = torch.argsort(
                p_correct_c_nc0.logit()
                + p_correct_c_nc1.logit()
                - 2 * torch.tensor(df.p_correct_c[row]).logit(),
                descending=True,
            )
            ord_nc_c = torch.argsort(
                p_correct_nc0_c.logit()
                + p_correct_nc1_c.logit()
                - torch.tensor(df.p_correct_nc0[row]).logit()
                - torch.tensor(df.p_correct_nc1[row]).logit(),
                descending=False,
            )

            # Row k = effect of swapping the first k + 1 components in that order.
            delta_c_nc0_dcum = torch.cumsum(delta_c_nc0[ord_c_nc], dim=0)
            delta_c_nc1_dcum = torch.cumsum(delta_c_nc1[ord_c_nc], dim=0)
            delta_nc0_c_dcum = torch.cumsum(delta_nc0_c[ord_nc_c], dim=0)
            delta_nc1_c_dcum = torch.cumsum(delta_nc1_c[ord_nc_c], dim=0)

            p_correct_c_nc0_dcum = get_probs(
                res_final_c[None, :] + delta_c_nc0_dcum, answer_str
            )
            p_correct_c_nc1_dcum = get_probs(
                res_final_c[None, :] + delta_c_nc1_dcum, answer_str
            )
            p_correct_nc0_c_dcum = get_probs(
                res_final_nc0[None, :] + delta_nc0_c_dcum, answer_str
            )
            p_correct_nc1_c_dcum = get_probs(
                res_final_nc1[None, :] + delta_nc1_c_dcum, answer_str
            )

        metrics.append(
            dict(
                lo_correct_c_nc0=to_logodds(p_correct_c_nc0),
                lo_correct_c_nc1=to_logodds(p_correct_c_nc1),
                lo_correct_nc0_c=to_logodds(p_correct_nc0_c),
                lo_correct_nc1_c=to_logodds(p_correct_nc1_c),
                lo_correct_c_nc0_dcum=to_logodds(p_correct_c_nc0_dcum),
                lo_correct_c_nc1_dcum=to_logodds(p_correct_c_nc1_dcum),
                lo_correct_nc0_c_dcum=to_logodds(p_correct_nc0_c_dcum),
                lo_correct_nc1_c_dcum=to_logodds(p_correct_nc1_c_dcum),
            )
        )

assert len(metrics) == len(df), "expected one metrics row per retained fact"
df = df.assign(**pd.DataFrame(metrics))
df.attrs["LABELS"] = LABELS

  0%|          | 0/266 [00:00<?, ?it/s]

In [10]:
# Stack the per-fact ablation vectors into (2 * n_facts, n_components)
# matrices: rows for the nc0 pairing come first, then the nc1 pairing.
lo_correct_nc_c = np.stack([*df.lo_correct_nc0_c, *df.lo_correct_nc1_c])
lo_correct_c_nc = np.stack([*df.lo_correct_c_nc0, *df.lo_correct_c_nc1])
print(lo_correct_c_nc.shape, lo_correct_nc_c.shape)

# Matching unablated baselines, in the same row order; the c baseline is
# repeated once per nc pairing.
lo_correct_c = np.tile(df.lo_correct_c.to_numpy(), 2)
lo_correct_nc = np.concatenate(
    [df.lo_correct_nc0.to_numpy(), df.lo_correct_nc1.to_numpy()]
)
print(lo_correct_c.shape, lo_correct_nc.shape)

# Log Bayes factor of each ablation: ablated log-odds minus unablated log-odds.
log_bf_c_nc = lo_correct_c_nc - lo_correct_c[:, None]
log_bf_nc_c = lo_correct_nc_c - lo_correct_nc[:, None]
print(log_bf_c_nc.shape, log_bf_nc_c.shape)

(8502, 5201) (8502, 5201)
(8502,) (8502,)
(8502, 5201) (8502, 5201)


### Cumulative ablations

In [11]:
# Static (dataset-wide) component orderings for the cumulative ablations,
# ranked by mean log Bayes factor over all (fact, nc-prompt) rows.
# nc -> c: most negative mean effect first; c -> nc: most positive first.
mean_log_bf_nc_c = torch.from_numpy(log_bf_nc_c.mean(axis=0))
mean_log_bf_c_nc = torch.from_numpy(log_bf_c_nc.mean(axis=0))
ordering_nc_c = mean_log_bf_nc_c.argsort(descending=False)
ordering_c_nc = mean_log_bf_c_nc.argsort(descending=True)

print(ordering_nc_c[:10])
print(ordering_c_nc[:10])

tensor([5200, 4468, 4512, 3390, 5135, 4463, 3250, 4908, 4510, 4909])
tensor([5200, 4468, 5135, 4512, 3390, 4463, 4908, 3250, 4510, 4909])


In [12]:
# Static cumulative ablations. Swap components between the c prompt and each
# nc prompt cumulatively, following the dataset-wide orderings `ordering_c_nc`
# / `ordering_nc_c`, and record the log-odds of the correct answer after each
# additional component is swapped.
metrics = []


def get_cum_probs(reses: torch.Tensor, correct_answer: str) -> torch.Tensor:
    """Probability of `correct_answer` for each row of final residual states.

    Each row of `reses` goes through the model's final norm and unembedding;
    the softmax is taken in float64 before `VEM.p_correct` scores it. The
    answer is passed explicitly, so this cell does not depend on scorers or
    globals left over from earlier cells.
    """
    return VEM.p_correct(
        hf_model.lm_head(hf_model.model.norm(reses)).double().softmax(dim=-1),
        correct_answer=correct_answer,
    ).to(hf_model.device)


# Convert a tensor of probabilities to a float32 numpy array of log-odds.
to_logodds = lambda probs: probs.logit().float().cpu().numpy()


pbar = tqdm(range(0, len(df), BATCH_SIZE))
for idx_start in pbar:
    idx_end = min(idx_start + BATCH_SIZE, len(df))

    # Three prompts per fact, flattened as [c, nc0, nc1, c, nc0, nc1, ...].
    prompts = [
        prompt
        for idx in range(idx_start, idx_end)
        for prompt in (df.prompt_c[idx], df.prompt_nc0[idx], df.prompt_nc1[idx])
    ]
    # Tokenize the batch once; the attention mask gives each prompt's real
    # (non-padding) length.
    encoded = tokenizer(prompts, padding=True, return_tensors="pt")
    prompt_tokens = encoded.input_ids.to(hf_model.device)
    n_tokens_all = encoded.attention_mask.sum(dim=1).tolist()

    with torch.no_grad():
        # Decompose the residual stream at each prompt's last real token
        # (index n - 1, which assumes right padding).
        comp_list = decomp_hf.get_all_resid_components_hf(
            model=hf_model,
            tokens=prompt_tokens,
            resid_pos_idxs=[n - 1 for n in n_tokens_all],
        )

    for i in range(idx_end - idx_start):
        answer_str = df.target_true_str[idx_start + i]
        lc_c, lc_nc0, lc_nc1 = comp_list[3 * i : 3 * i + 3]

        with torch.no_grad():
            res_final_c = lc_c.resid_post
            res_final_nc0 = lc_nc0.resid_post
            res_final_nc1 = lc_nc1.resid_post

            delta_c_nc0 = lc_nc0.components - lc_c.components
            delta_c_nc1 = lc_nc1.components - lc_c.components
            delta_nc0_c = lc_c.components - lc_nc0.components
            delta_nc1_c = lc_c.components - lc_nc1.components

            # Row k = effect of swapping the first k + 1 components in the
            # static order.
            delta_c_nc0_cum = torch.cumsum(delta_c_nc0[ordering_c_nc], dim=0)
            delta_c_nc1_cum = torch.cumsum(delta_c_nc1[ordering_c_nc], dim=0)
            delta_nc0_c_cum = torch.cumsum(delta_nc0_c[ordering_nc_c], dim=0)
            delta_nc1_c_cum = torch.cumsum(delta_nc1_c[ordering_nc_c], dim=0)

            p_correct_c_nc0_cum = get_cum_probs(
                res_final_c[None, :] + delta_c_nc0_cum, answer_str
            )
            p_correct_c_nc1_cum = get_cum_probs(
                res_final_c[None, :] + delta_c_nc1_cum, answer_str
            )
            p_correct_nc0_c_cum = get_cum_probs(
                res_final_nc0[None, :] + delta_nc0_c_cum, answer_str
            )
            p_correct_nc1_c_cum = get_cum_probs(
                res_final_nc1[None, :] + delta_nc1_c_cum, answer_str
            )

        metrics.append(
            dict(
                lo_correct_c_nc0_cum=to_logodds(p_correct_c_nc0_cum),
                lo_correct_c_nc1_cum=to_logodds(p_correct_c_nc1_cum),
                lo_correct_nc0_c_cum=to_logodds(p_correct_nc0_c_cum),
                lo_correct_nc1_c_cum=to_logodds(p_correct_nc1_c_cum),
            )
        )

assert len(metrics) == len(df), "expected one metrics row per retained fact"
df = df.assign(**pd.DataFrame(metrics))

  0%|          | 0/266 [00:00<?, ?it/s]

In [13]:
def _stack_pair(col_nc0: str, col_nc1: str) -> np.ndarray:
    """Stack two df columns of per-fact vectors into one (2 * n_facts, n_components) matrix.

    Rows from `col_nc0` come first, then rows from `col_nc1`.
    """
    return np.stack([*df[col_nc0], *df[col_nc1]])


lo_correct_nc_c_cum = _stack_pair("lo_correct_nc0_c_cum", "lo_correct_nc1_c_cum")
lo_correct_c_nc_cum = _stack_pair("lo_correct_c_nc0_cum", "lo_correct_c_nc1_cum")
lo_correct_nc_c_dcum = _stack_pair("lo_correct_nc0_c_dcum", "lo_correct_nc1_c_dcum")
lo_correct_c_nc_dcum = _stack_pair("lo_correct_c_nc0_dcum", "lo_correct_c_nc1_dcum")
print(lo_correct_c_nc_cum.shape, lo_correct_nc_c_cum.shape)
print(lo_correct_c_nc_dcum.shape, lo_correct_nc_c_dcum.shape)

# Log Bayes factors relative to the unablated baselines (same row order).
baseline_c = lo_correct_c[:, None]
baseline_nc = lo_correct_nc[:, None]
log_bf_c_nc_cum = lo_correct_c_nc_cum - baseline_c
log_bf_nc_c_cum = lo_correct_nc_c_cum - baseline_nc
log_bf_c_nc_dcum = lo_correct_c_nc_dcum - baseline_c
log_bf_nc_c_dcum = lo_correct_nc_c_dcum - baseline_nc
print(log_bf_c_nc_cum.shape, log_bf_nc_c_cum.shape)
print(log_bf_c_nc_dcum.shape, log_bf_nc_c_dcum.shape)

(8502, 5201) (8502, 5201)
(8502, 5201) (8502, 5201)
(8502, 5201) (8502, 5201)
(8502, 5201) (8502, 5201)


### Save dataframe

In [17]:
# Save the DataFrame, now augmented with the per-fact ablation results.
file_name = f"{DATASET_NAME}_ablation_{SAVED_NAME}.pkl"
ablation_dir = utils.get_repo_root() / "data" / "ablation"
# Create the output directory on a fresh checkout instead of failing here,
# after the expensive ablation loops have already run.
ablation_dir.mkdir(parents=True, exist_ok=True)
df.to_pickle(ablation_dir / file_name)