In this notebook, we get the overall effect of forbidding the correct answer and save the results as counterfact_inference_{saved_name}.csv

### Import

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from pii import utils, datasets, vocab

### Load model

In [3]:
# pick model to run on
# or meta-llama/Llama-2-7b-chat-hf or meta-llama/Llama-2-70b-chat-hf
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

if MODEL_NAME == "meta-llama/Llama-2-7b-chat-hf":
    SAVED_NAME = "llama2_7b"
    BATCH_SIZE = 8
    DTYPE = torch.float16
    DEVICE_MAP = None
elif MODEL_NAME == "meta-llama/Llama-2-13b-chat-hf":
    SAVED_NAME = "llama2_13b"
    BATCH_SIZE = 1
    DTYPE = torch.float16
    DEVICE_MAP = "auto"
elif MODEL_NAME == "meta-llama/Llama-2-70b-chat-hf":
    SAVED_NAME = "llama2_70b"
    BATCH_SIZE = 1
    DTYPE = torch.float16
    DEVICE_MAP = "auto"

In [4]:
# You will need to login to huggingface first:
#   huggingface-cli login
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token_id = tokenizer.eos_token_id

hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=DTYPE,
    device_map=DEVICE_MAP,
)
if DEVICE_MAP is None:
    hf_model = hf_model.cuda()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [5]:
print(
    utils.gen_text(
        hf_model,
        tokenizer,
        prompt="The capital of Germany is",
        max_new_tokens=22,
    )
)

Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions, and


### Develop dataset

In [6]:
df_all = datasets.get_counterfact_dataset(tokenizer=tokenizer)
print(len(df_all))

14002


In [7]:
df_raw = df_all.sample(
    frac=1,
    random_state=42,
).reset_index(drop=True)
print(len(df_raw))
df_raw.head()

14002


Unnamed: 0,case_id,pararel_idx,relation_id,subject,target_new_str,target_true_str,fact_prefix,irrelevant_word,prompt_c,prompt_nc0,prompt_nc1
0,3348,10929,P495,Pinoy Idol,Norway,Philippines,"Pinoy Idol, that was from",fox,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...
1,6515,6929,P17,Guillaumes,Norway,France,"Guillaumes, which is located in",kiwi,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...
2,3342,8551,P27,Howe Yoon Chong,Belgium,Singapore,"Howe Yoon Chong, who has a citizenship from",dog,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...
3,8494,9273,P463,Mexico national football team,Hamas,FIFA,Mexico national football team is a member of,toothbrush,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...
4,20754,10478,P495,The Last Hunter,Japan,Italy,The Last Hunter originated in,cloud,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...,[INST] <<SYS>>\nYou are an obedient assistant ...


In [8]:
idx = 20
for i, prompt in enumerate(
    [df_raw.prompt_nc0[idx], df_raw.prompt_nc1[idx], df_raw.prompt_c[idx]]
):
    utils.print_with_wrap(prompt)
    utils.get_top_responses_hf(
        prompt=prompt, model=hf_model, tokenizer=tokenizer, top_k=5, n_continuation_tokens=20
    )
    print() if i < 2 else None

[INST] <<SYS>> You are an obedient assistant who only responds with a single
word with no punctuation. You answer truthfully. However, you are not allowed to
say the forbidden word Libya. <</SYS>> Biberach district is located in the
country of [/INST]
Rank 0. Logit: 21.14 Prob: 95.69% Tokens: ( 9556) |Germany|</s>|
Rank 1. Logit: 17.89 Prob:  3.71% Tokens: (28806) |Switzerland|</s>|
Rank 2. Logit: 15.25 Prob:  0.26% Tokens: ( 3444) |France|</s>|
Rank 3. Logit: 13.80 Prob:  0.06% Tokens: (18585) |Sure|!|Bi|ber|ach|district|is|located|in|Germany|.|</s>|
Rank 4. Logit: 13.59 Prob:  0.05% Tokens: (21311) |Baden|-|W|ür|ttemberg|</s>|

[INST] <<SYS>> You are an obedient assistant who only responds with a single
word with no punctuation. You answer truthfully. However, you are not allowed to
say the forbidden word dog. <</SYS>> Biberach district is located in the country
of [/INST]
Rank 0. Logit: 21.22 Prob: 92.90% Tokens: ( 9556) |Germany|</s>|
Rank 1. Logit: 18.56 Prob:  6.52% Tokens: (2880

In [9]:
VEM = vocab.VocabEquivalenceMap(tokenizer=tokenizer)

In [10]:
metrics = []
pbar = tqdm(range(0, len(df_raw), BATCH_SIZE))
for idx_start in pbar:
    idx_end = min(idx_start + BATCH_SIZE, len(df_raw))

    prompts = []
    n_tokens = []
    for idx in range(idx_start, idx_end):
        prompts.extend(
            [
                df_raw.prompt_c[idx],
                df_raw.prompt_nc0[idx],
                df_raw.prompt_nc1[idx],
            ]
        )
        n_tokens.extend(
            [
                len(tokenizer(df_raw.prompt_c[idx])[0]),
                len(tokenizer(df_raw.prompt_nc0[idx])[0]),
                len(tokenizer(df_raw.prompt_nc1[idx])[0]),
            ]
        )

    tokenize = lambda x: tokenizer(
        x, padding=True, return_tensors="pt"
    ).input_ids.to(hf_model.device)
    prompt_tokens = tokenize(prompts)

    with torch.no_grad():
        logits = hf_model.forward(prompt_tokens, use_cache=False).logits

    for i in range(idx_end - idx_start):
        n_tokens_c = n_tokens[3 * i]
        n_tokens_nc0 = n_tokens[3 * i + 1]
        n_tokens_nc1 = n_tokens[3 * i + 2]

        logits_c = logits[3 * i, n_tokens_c - 1].double()
        logits_nc0 = logits[3 * i + 1, n_tokens_nc0 - 1].double()
        logits_nc1 = logits[3 * i + 2, n_tokens_nc1 - 1].double()

        answer_str = df_raw.target_true_str[idx_start + i]
        with torch.no_grad():
            p_correct_c = VEM.p_correct(logits_c.softmax(dim=-1), answer_str)
            p_correct_nc0 = VEM.p_correct(
                logits_nc0.softmax(dim=-1), answer_str
            )
            p_correct_nc1 = VEM.p_correct(
                logits_nc1.softmax(dim=-1), answer_str
            )

            log_bf0 = p_correct_c.logit() - p_correct_nc0.logit()
            log_bf1 = p_correct_c.logit() - p_correct_nc1.logit()

        metrics.append(
            dict(
                p_correct_c=p_correct_c.item(),
                p_correct_nc0=p_correct_nc0.item(),
                p_correct_nc1=p_correct_nc1.item(),
                lo_correct_c=p_correct_c.logit().item(),
                lo_correct_nc0=p_correct_nc0.logit().item(),
                lo_correct_nc1=p_correct_nc1.logit().item(),
                log_bf0=log_bf0.item(),
                log_bf1=log_bf1.item(),
            )
        )

df_raw = df_raw.assign(**pd.DataFrame(metrics))

  0%|          | 0/1751 [00:00<?, ?it/s]

In [11]:
filename = f"counterfact_inference_{SAVED_NAME}.csv"
# Save df_raw
df_raw.to_csv(
    utils.get_repo_root() / "data" / filename, index=False
)

In [12]:
df_inference = pd.read_csv(utils.get_repo_root() / "data" / filename)
print(df_inference)

       case_id  pararel_idx relation_id                        subject  \
0         3348        10929        P495                     Pinoy Idol   
1         6515         6929         P17                     Guillaumes   
2         3342         8551         P27                Howe Yoon Chong   
3         8494         9273        P463  Mexico national football team   
4        20754        10478        P495                The Last Hunter   
...        ...          ...         ...                            ...   
13997     8173        20262        P276           Concordia University   
13998    21028        18471       P1412                    Livia Turco   
13999     8493        18783        P407                       Le Monde   
14000     1360        16454        P136                Freddie Keppard   
14001    11541         1947        P131                 Austin College   

      target_new_str target_true_str  \
0             Norway     Philippines   
1             Norway          F

In [13]:
filt = (np.minimum(df_raw.p_correct_nc0, df_raw.p_correct_nc1) > 0.5) & (
    np.maximum(df_raw.log_bf0, df_raw.log_bf1) / np.log(10) < -2
)
df = df_raw[filt].reset_index(drop=True)
print(f"Retained {len(df)} / {len(df_raw)} facts")

Retained 2578 / 14002 facts
