In [168]:
# Chat-instruction wrapper placed around every prompt: PREFIX + prompt + SUFFIX.
# The commented-out lines are a disabled few-shot variant of the prefix.
PREFIX = (
    "[INST] Question: "
    # "[INST] Question: Anne is quiet. Anne is not young. Anne is sleepy. Anne is smart if she is loud. Is Anne smart? Answer: No\n"
    # "Question: Anne is loud. Anne is not young. Anne is sleepy. Anne is smart if she is loud. Is Anne smart? Answer: Yes\n"
    # "Question: Anne is quiet. Anne is not young. Anne is sleepy. Anne is smart if she is old and quiet. Is Anne smart? Answer: Yes\n"
    # "Question: "
)
SUFFIX = " Answer (Yes/No): [/INST]"
# Three stated facts about NAME followed by a yes/no query combining two
# attributes with OPERATOR.
PROMPT_TEMPLATE = "{NAME} is {ATTR1}. {NAME} is {ATTR2}. {NAME} is {ATTR3}. Is {NAME} {ATTR_L} {OPERATOR} {ATTR_R}?"
NAMES = [
    "Anne",
    "Bob",
    "Carol",
    "David",
    "Emma",
    "Mike",
    "Sarah",
    "John",
    "Linda",
    "Peter",
    "Grace",
    "Oliver",
    "Sophie",
    "Josh",
    "Mia",
    "Tom",
    "Rachel",
    "Henry",
    "Alice",
    "George",
]
# Antonym pairs as (positive, negative). POSITIVE_ATTRIBUTES[i] and
# NEGATIVE_ATTRIBUTES[i] must be antonyms: the PROMPT_TUPLES construction,
# get_counterfactual_tuples and get_answers_for_prompt_tuples all pair the two
# lists by index. Deriving both lists from this one table keeps them aligned.
ATTRIBUTE_PAIRS = [
    ("loud", "quiet"),
    ("fast", "slow"),
    ("tall", "short"),
    ("fat", "thin"),
    ("young", "old"),
    ("strong", "weak"),
    ("smart", "dumb"),
    ("happy", "sad"),
    ("kind", "mean"),
    ("funny", "serious"),
    ("curious", "dull"),
    ("calm", "nervous"),
    ("pretty", "ugly"),
]
POSITIVE_ATTRIBUTES = [positive for positive, _ in ATTRIBUTE_PAIRS]
NEGATIVE_ATTRIBUTES = [negative for _, negative in ATTRIBUTE_PAIRS]
OPERATORS = [
    "and",
    "or",
]

In [169]:
def get_attribute_sign_and_index(attr: str) -> Tuple[bool, int]:
    """Locate ``attr`` in the attribute tables.

    Returns:
        ``(is_positive, index)``: ``is_positive`` is True when ``attr`` is found in
        POSITIVE_ATTRIBUTES and False when found in NEGATIVE_ATTRIBUTES; ``index``
        is its position in that list. POSITIVE_ATTRIBUTES is searched first.

    Raises:
        ValueError: if ``attr`` is in neither list.
    """
    for is_positive, pool in (
        (True, POSITIVE_ATTRIBUTES),
        (False, NEGATIVE_ATTRIBUTES),
    ):
        if attr in pool:
            return is_positive, pool.index(attr)
    raise ValueError(f"Unknown attribute {attr}")

In [170]:
def get_answers_for_prompt_tuples(
    prompt_tuples: List[Tuple[str, str, str, str, str, str, str]]
) -> List[str]:
    """Compute the gold "Yes"/"No" answer for each prompt tuple.

    Each tuple is ``(name, attr1, attr2, attr3, attr_l, operator, attr_r)``. It
    encodes three stated facts ("{name} is {attrN}") and the query
    "Is {name} {attr_l} {operator} {attr_r}?".

    A queried attribute is matched to the stated fact from the same antonym pair
    (same index in POSITIVE/NEGATIVE_ATTRIBUTES). It is true when the stated
    attribute has the same sign as the queried one: a stated "quiet" makes
    "loud" false and "quiet" true.

    Args:
        prompt_tuples: tuples in the layout above.

    Returns:
        One "Yes"/"No" string per tuple, in input order.

    Raises:
        ValueError: for an unknown attribute or operator, or when ``attr_l`` and
            ``attr_r`` do not refer to two different stated facts.
    """
    answers = []
    for _, attr1, attr2, attr3, attr_l, operator, attr_r in prompt_tuples:
        # Antonym-pair index -> stated sign. If two stated facts share an index
        # (contradictory input), the earliest one wins.
        stated_signs = {}
        for stated_attr in (attr1, attr2, attr3):
            sign, idx = get_attribute_sign_and_index(stated_attr)
            stated_signs.setdefault(idx, sign)
        attr_l_sign, attr_l_idx = get_attribute_sign_and_index(attr_l)
        attr_r_sign, attr_r_idx = get_attribute_sign_and_index(attr_r)

        if operator not in ("and", "or"):
            raise ValueError(f"Unknown operator {operator}")
        if (
            attr_l_idx == attr_r_idx
            or attr_l_idx not in stated_signs
            or attr_r_idx not in stated_signs
        ):
            raise ValueError(
                f"Invalid combination of attributes {attr_l} and {attr_r}"
            )

        # A queried attribute holds iff its sign agrees with the stated sign, so
        # negative query attributes (e.g. "quiet") are handled correctly too.
        left = stated_signs[attr_l_idx] == attr_l_sign
        right = stated_signs[attr_r_idx] == attr_r_sign
        answer = (left and right) if operator == "and" else (left or right)
        answers.append("Yes" if answer else "No")
    return answers

In [171]:
def get_counterfactual_tuples(
    prompt_tuples: List[Tuple[str, str, str, str, str, str, str]], seed: int = 0
) -> List[Tuple[str, str, str, str, str, str, str]]:
    """Build one counterfactual per prompt tuple by flipping a single stated fact.

    For each tuple, one of ``attr1``/``attr2``/``attr3`` is chosen at random and
    replaced by its antonym (same index in the other attribute list). The name,
    the query attributes and the operator are left unchanged. The flip does not
    always change the answer; callers are expected to filter on that.

    Args:
        prompt_tuples: ``(name, attr1, attr2, attr3, attr_l, operator, attr_r)``.
        seed: seed for the private RNG that chooses which fact to flip.

    Returns:
        Counterfactual tuples, aligned one-to-one with ``prompt_tuples``.
    """
    # A private generator yields the same choices as seeding the global `random`
    # module with `seed`, without resetting the global RNG other cells rely on.
    rng = random.Random(seed)
    cf_tuples = []
    for name, attr1, attr2, attr3, attr_l, operator, attr_r in prompt_tuples:
        stated = [attr1, attr2, attr3]
        idx_to_change = rng.choice([0, 1, 2])
        attr_sign, attr_idx = get_attribute_sign_and_index(stated[idx_to_change])
        # Swap to the antonym: the lists are paired by index.
        stated[idx_to_change] = (
            NEGATIVE_ATTRIBUTES[attr_idx] if attr_sign else POSITIVE_ATTRIBUTES[attr_idx]
        )
        cf_tuples.append((name, *stated, attr_l, operator, attr_r))
    return cf_tuples

In [172]:
# Enumerate every combination of: name, operator, three distinct antonym pairs,
# positive/negative choice for each stated fact, and an ordered query pair taken
# from two of the three stated pairs. Query attributes are always the positive
# member of their pair. The result is subsampled below.
DATASET_SEED = 0  # seed for the subsampling shuffle below
N_CANDIDATES = 1000  # tuples kept after shuffling, before answer-flip filtering
N_PROMPTS = 100  # (clean, counterfactual) pairs kept for the experiments
PROMPT_TUPLES = [
    (
        name,
        attr1_list[attr1_idx],
        attr2_list[attr2_idx],
        attr3_list[attr3_idx],
        attr_l,
        operator,
        attr_r,
    )
    for name in NAMES
    for operator in OPERATORS
    for attr1_idx, attr2_idx, attr3_idx in itertools.combinations(
        range(len(POSITIVE_ATTRIBUTES)), 3
    )
    for attr1_list in [POSITIVE_ATTRIBUTES, NEGATIVE_ATTRIBUTES]
    for attr2_list in [POSITIVE_ATTRIBUTES, NEGATIVE_ATTRIBUTES]
    for attr3_list in [POSITIVE_ATTRIBUTES, NEGATIVE_ATTRIBUTES]
    for attr_l, attr_r in [
        (POSITIVE_ATTRIBUTES[attr1_idx], POSITIVE_ATTRIBUTES[attr2_idx]),
        (POSITIVE_ATTRIBUTES[attr2_idx], POSITIVE_ATTRIBUTES[attr1_idx]),
        (POSITIVE_ATTRIBUTES[attr1_idx], POSITIVE_ATTRIBUTES[attr3_idx]),
        (POSITIVE_ATTRIBUTES[attr3_idx], POSITIVE_ATTRIBUTES[attr1_idx]),
        (POSITIVE_ATTRIBUTES[attr2_idx], POSITIVE_ATTRIBUTES[attr3_idx]),
        (POSITIVE_ATTRIBUTES[attr3_idx], POSITIVE_ATTRIBUTES[attr2_idx]),
    ]
]
# A dedicated seeded RNG makes the subsample reproducible regardless of the
# state of the global `random` module.
random.Random(DATASET_SEED).shuffle(PROMPT_TUPLES)
PROMPT_TUPLES = PROMPT_TUPLES[:N_CANDIDATES]
PROMPTS = [
    PROMPT_TEMPLATE.format(
        NAME=name,
        ATTR1=attr1,
        ATTR2=attr2,
        ATTR3=attr3,
        ATTR_L=attr_l,
        OPERATOR=operator,
        ATTR_R=attr_r,
    )
    for name, attr1, attr2, attr3, attr_l, operator, attr_r in PROMPT_TUPLES
]
CF_TUPLES = get_counterfactual_tuples(PROMPT_TUPLES)
CF_PROMPTS = [
    PROMPT_TEMPLATE.format(
        NAME=name,
        ATTR1=attr1,
        ATTR2=attr2,
        ATTR3=attr3,
        ATTR_L=attr_l,
        OPERATOR=operator,
        ATTR_R=attr_r,
    )
    for name, attr1, attr2, attr3, attr_l, operator, attr_r in CF_TUPLES
]
ANSWERS = get_answers_for_prompt_tuples(PROMPT_TUPLES)
CF_ANSWERS = get_answers_for_prompt_tuples(CF_TUPLES)
# Keep only pairs where flipping one stated fact actually flips the answer.
# NOTE: PROMPT_TUPLES / CF_TUPLES are left unfiltered (N_CANDIDATES entries), so
# they are NOT index-aligned with the filtered lists below.
to_keep = [answer != cf_answer for answer, cf_answer in zip(ANSWERS, CF_ANSWERS)]
PROMPTS = [p for p, keep in zip(PROMPTS, to_keep) if keep]
CF_PROMPTS = [p for p, keep in zip(CF_PROMPTS, to_keep) if keep]
ANSWERS = [a for a, keep in zip(ANSWERS, to_keep) if keep]
CF_ANSWERS = [a for a, keep in zip(CF_ANSWERS, to_keep) if keep]
# Truncate all four lists together so they stay index-aligned.
PROMPTS = PROMPTS[:N_PROMPTS]
CF_PROMPTS = CF_PROMPTS[:N_PROMPTS]
ANSWERS = ANSWERS[:N_PROMPTS]
CF_ANSWERS = CF_ANSWERS[:N_PROMPTS]
assert len(PROMPTS) == len(CF_PROMPTS) == len(ANSWERS) == len(CF_ANSWERS)

In [173]:
# When True, a leading space is added to "Yes"/"No" before they are tokenized
# (passed to test_prompt and used to build the answer prefix for answer_tokens).
PREPEND_SPACE_TO_ANSWER: bool = False

In [174]:
# Sanity check: each clean prompt and its counterfactual must tokenize to the
# same length *as actually fed to the model* (wrapped in PREFIX/SUFFIX), so that
# token positions line up one-to-one between the clean and counterfactual runs.
for prompt, cf_prompt in zip(PROMPTS, CF_PROMPTS):
    prompt_str_tokens = model.to_str_tokens(PREFIX + prompt + SUFFIX)
    cf_str_tokens = model.to_str_tokens(PREFIX + cf_prompt + SUFFIX)
    assert len(prompt_str_tokens) == len(cf_str_tokens), (
        f"Prompt and counterfactual prompt must have the same length, "
        f"for prompt \n{prompt_str_tokens} \n and counterfactual\n{cf_str_tokens} \n"
        f"got {len(prompt_str_tokens)} and {len(cf_str_tokens)}"
    )

In [175]:
# Print the model's top-10 next-token predictions for the first few clean /
# counterfactual pairs; `i` counts prompts shown and the loop stops once it
# exceeds 10 (i.e. after 6 pairs).
i = 0
for prompt, answer, cf_prompt, cf_answer in zip(
    PROMPTS, ANSWERS, CF_PROMPTS, CF_ANSWERS
):
    for shown_prompt, expected_answer in ((prompt, answer), (cf_prompt, cf_answer)):
        print(shown_prompt)
        test_prompt(
            PREFIX + shown_prompt + SUFFIX,
            expected_answer,
            model,
            top_k=10,
            prepend_space_to_answer=PREPEND_SPACE_TO_ANSWER,
        )
        i += 1
    if i > 10:
        break

Emma is slow. Emma is funny. Emma is ugly. Is Emma funny or pretty?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Emma', 'is', 'slow', '.', 'Emma', 'is', 'funny', '.', 'Emma', 'is', 'ugly', '.', 'Is', 'Emma', 'funny', 'or', 'pretty', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 14.31 Prob: 27.73% Token: |No|
Top 1th token. Logit: 14.31 Prob: 27.73% Token: |Yes|
Top 2th token. Logit: 13.88 Prob: 17.87% Token: |Answer|
Top 3th token. Logit: 12.62 Prob:  5.13% Token: |This|
Top 4th token. Logit: 12.50 Prob:  4.52% Token: |The|
Top 5th token. Logit: 12.44 Prob:  4.25% Token: |I|
Top 6th token. Logit: 12.06 Prob:  2.92% Token: |It|
Top 7th token. Logit: 11.69 Prob:  2.00% Token: |Is|
Top 8th token. Logit: 11.50 Prob:  1.66% Token: |Emma|
Top 9th token. Logit: 10.62 Prob:  0.69% Token: |Question|


Emma is slow. Emma is serious. Emma is ugly. Is Emma funny or pretty?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Emma', 'is', 'slow', '.', 'Emma', 'is', 'serious', '.', 'Emma', 'is', 'ugly', '.', 'Is', 'Emma', 'funny', 'or', 'pretty', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 14.88 Prob: 49.41% Token: |No|
Top 1th token. Logit: 13.81 Prob: 17.09% Token: |I|
Top 2th token. Logit: 13.25 Prob:  9.72% Token: |Answer|
Top 3th token. Logit: 12.19 Prob:  3.37% Token: |Yes|
Top 4th token. Logit: 12.06 Prob:  2.97% Token: |The|
Top 5th token. Logit: 12.00 Prob:  2.78% Token: |This|
Top 6th token. Logit: 11.75 Prob:  2.17% Token: |Question|
Top 7th token. Logit: 11.62 Prob:  1.92% Token: |It|
Top 8th token. Logit: 11.56 Prob:  1.79% Token: |Emma|
Top 9th token. Logit: 11.44 Prob:  1.59% Token: |Is|


Josh is quiet. Josh is slow. Josh is curious. Is Josh loud or fast?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Josh', 'is', 'quiet', '.', 'Josh', 'is', 'slow', '.', 'Josh', 'is', 'curious', '.', 'Is', 'Josh', 'loud', 'or', 'fast', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 18.50 Prob: 89.84% Token: |No|
Top 1th token. Logit: 16.12 Prob:  8.35% Token: |Answer|
Top 2th token. Logit: 13.56 Prob:  0.64% Token: |The|
Top 3th token. Logit: 12.62 Prob:  0.25% Token: |Is|
Top 4th token. Logit: 12.62 Prob:  0.25% Token: |Based|
Top 5th token. Logit: 11.88 Prob:  0.12% Token: |Josh|
Top 6th token. Logit: 10.94 Prob:  0.05% Token: |(|
Top 7th token. Logit: 10.88 Prob:  0.04% Token: |I|
Top 8th token. Logit: 10.88 Prob:  0.04% Token: |no|
Top 9th token. Logit: 10.88 Prob:  0.04% Token: |An|


Josh is quiet. Josh is fast. Josh is curious. Is Josh loud or fast?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Josh', 'is', 'quiet', '.', 'Josh', 'is', 'fast', '.', 'Josh', 'is', 'curious', '.', 'Is', 'Josh', 'loud', 'or', 'fast', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 15.25 Prob: 35.55% Token: |No|
Top 1th token. Logit: 15.19 Prob: 33.40% Token: |Answer|
Top 2th token. Logit: 14.19 Prob: 12.26% Token: |Yes|
Top 3th token. Logit: 13.81 Prob:  8.45% Token: |Is|
Top 4th token. Logit: 13.06 Prob:  3.98% Token: |The|
Top 5th token. Logit: 12.75 Prob:  2.91% Token: |Josh|
Top 6th token. Logit: 11.00 Prob:  0.51% Token: |Question|
Top 7th token. Logit: 10.75 Prob:  0.39% Token: |This|
Top 8th token. Logit: 10.62 Prob:  0.35% Token: |Based|
Top 9th token. Logit: 10.25 Prob:  0.24% Token: |I|


Sarah is fat. Sarah is happy. Sarah is dull. Is Sarah curious and happy?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Sarah', 'is', 'fat', '.', 'Sarah', 'is', 'happy', '.', 'Sarah', 'is', 'dull', '.', 'Is', 'Sarah', 'curious', 'and', 'happy', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 13.94 Prob: 21.09% Token: |It|
Top 1th token. Logit: 13.94 Prob: 21.09% Token: |No|
Top 2th token. Logit: 13.62 Prob: 15.43% Token: |Yes|
Top 3th token. Logit: 13.44 Prob: 12.79% Token: |The|
Top 4th token. Logit: 12.94 Prob:  7.76% Token: |This|
Top 5th token. Logit: 12.12 Prob:  3.44% Token: |I|
Top 6th token. Logit: 12.00 Prob:  3.04% Token: |There|
Top 7th token. Logit: 11.75 Prob:  2.37% Token: |Answer|
Top 8th token. Logit: 11.69 Prob:  2.22% Token: |Based|
Top 9th token. Logit: 11.38 Prob:  1.62% Token: |We|


Sarah is fat. Sarah is happy. Sarah is curious. Is Sarah curious and happy?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Sarah', 'is', 'fat', '.', 'Sarah', 'is', 'happy', '.', 'Sarah', 'is', 'curious', '.', 'Is', 'Sarah', 'curious', 'and', 'happy', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 17.12 Prob: 88.67% Token: |Yes|
Top 1th token. Logit: 13.94 Prob:  3.66% Token: |It|
Top 2th token. Logit: 13.69 Prob:  2.86% Token: |The|
Top 3th token. Logit: 13.06 Prob:  1.53% Token: |No|
Top 4th token. Logit: 12.06 Prob:  0.56% Token: |This|
Top 5th token. Logit: 11.31 Prob:  0.27% Token: |Is|
Top 6th token. Logit: 11.31 Prob:  0.27% Token: |Based|
Top 7th token. Logit: 11.19 Prob:  0.23% Token: |We|
Top 8th token. Logit: 11.12 Prob:  0.22% Token: |I|
Top 9th token. Logit: 11.00 Prob:  0.19% Token: |There|


Tom is fast. Tom is fat. Tom is calm. Is Tom calm and fat?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Tom', 'is', 'fast', '.', 'Tom', 'is', 'fat', '.', 'Tom', 'is', 'calm', '.', 'Is', 'Tom', 'calm', 'and', 'fat', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 17.12 Prob: 66.41% Token: |Yes|
Top 1th token. Logit: 15.62 Prob: 14.84% Token: |The|
Top 2th token. Logit: 15.44 Prob: 12.30% Token: |No|
Top 3th token. Logit: 13.44 Prob:  1.66% Token: |This|
Top 4th token. Logit: 12.88 Prob:  0.95% Token: |Answer|
Top 5th token. Logit: 12.75 Prob:  0.84% Token: |Ah|
Top 6th token. Logit: 12.38 Prob:  0.57% Token: |Tom|
Top 7th token. Logit: 12.25 Prob:  0.51% Token: |It|
Top 8th token. Logit: 11.50 Prob:  0.24% Token: |We|
Top 9th token. Logit: 11.44 Prob:  0.23% Token: |According|


Tom is fast. Tom is fat. Tom is nervous. Is Tom calm and fat?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Tom', 'is', 'fast', '.', 'Tom', 'is', 'fat', '.', 'Tom', 'is', 'nervous', '.', 'Is', 'Tom', 'calm', 'and', 'fat', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 16.75 Prob: 73.05% Token: |No|
Top 1th token. Logit: 14.69 Prob:  9.33% Token: |The|
Top 2th token. Logit: 14.25 Prob:  6.01% Token: |Yes|
Top 3th token. Logit: 13.69 Prob:  3.42% Token: |Answer|
Top 4th token. Logit: 13.62 Prob:  3.22% Token: |This|
Top 5th token. Logit: 12.62 Prob:  1.18% Token: |Ah|
Top 6th token. Logit: 11.94 Prob:  0.60% Token: |(|
Top 7th token. Logit: 11.81 Prob:  0.52% Token: |Tom|
Top 8th token. Logit: 11.75 Prob:  0.49% Token: |Is|
Top 9th token. Logit: 11.31 Prob:  0.32% Token: |It|


Sarah is tall. Sarah is weak. Sarah is curious. Is Sarah strong or curious?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Sarah', 'is', 'tall', '.', 'Sarah', 'is', 'weak', '.', 'Sarah', 'is', 'curious', '.', 'Is', 'Sarah', 'strong', 'or', 'curious', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 14.38 Prob: 29.88% Token: |The|
Top 1th token. Logit: 14.00 Prob: 20.61% Token: |No|
Top 2th token. Logit: 13.75 Prob: 16.02% Token: |Answer|
Top 3th token. Logit: 13.12 Prob:  8.59% Token: |Yes|
Top 4th token. Logit: 12.56 Prob:  4.88% Token: |This|
Top 5th token. Logit: 12.56 Prob:  4.88% Token: |Cur|
Top 6th token. Logit: 12.38 Prob:  4.05% Token: |Is|
Top 7th token. Logit: 11.94 Prob:  2.61% Token: |It|
Top 8th token. Logit: 11.75 Prob:  2.17% Token: |Sarah|
Top 9th token. Logit: 11.00 Prob:  1.03% Token: |Based|


Sarah is tall. Sarah is weak. Sarah is dull. Is Sarah strong or curious?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'Sarah', 'is', 'tall', '.', 'Sarah', 'is', 'weak', '.', 'Sarah', 'is', 'dull', '.', 'Is', 'Sarah', 'strong', 'or', 'curious', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 15.44 Prob: 50.00% Token: |No|
Top 1th token. Logit: 14.31 Prob: 16.21% Token: |The|
Top 2th token. Logit: 13.88 Prob: 10.50% Token: |Answer|
Top 3th token. Logit: 13.38 Prob:  6.35% Token: |Based|
Top 4th token. Logit: 13.31 Prob:  5.98% Token: |This|
Top 5th token. Logit: 12.44 Prob:  2.49% Token: |Is|
Top 6th token. Logit: 12.31 Prob:  2.20% Token: |There|
Top 7th token. Logit: 11.69 Prob:  1.18% Token: |I|
Top 8th token. Logit: 10.94 Prob:  0.56% Token: |It|
Top 9th token. Logit: 10.88 Prob:  0.52% Token: |Ins|


Mia is loud. Mia is old. Mia is smart. Is Mia smart or young?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'M', 'ia', 'is', 'loud', '.', 'M', 'ia', 'is', 'old', '.', 'M', 'ia', 'is', 'smart', '.', 'Is', 'M', 'ia', 'smart', 'or', 'young', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['Yes']


Top 0th token. Logit: 14.69 Prob: 27.73% Token: |No|
Top 1th token. Logit: 14.62 Prob: 25.98% Token: |Answer|
Top 2th token. Logit: 14.44 Prob: 21.58% Token: |The|
Top 3th token. Logit: 13.44 Prob:  7.96% Token: |Yes|
Top 4th token. Logit: 12.88 Prob:  4.54% Token: |This|
Top 5th token. Logit: 12.19 Prob:  2.28% Token: |Is|
Top 6th token. Logit: 11.75 Prob:  1.47% Token: |M|
Top 7th token. Logit: 11.75 Prob:  1.47% Token: |It|
Top 8th token. Logit: 11.50 Prob:  1.15% Token: |Based|
Top 9th token. Logit: 11.06 Prob:  0.74% Token: |There|


Mia is loud. Mia is old. Mia is dumb. Is Mia smart or young?
Tokenized prompt: ['<s>', '[', 'INST', ']', 'Question', ':', 'M', 'ia', 'is', 'loud', '.', 'M', 'ia', 'is', 'old', '.', 'M', 'ia', 'is', 'dumb', '.', 'Is', 'M', 'ia', 'smart', 'or', 'young', '?', 'Answer', '(', 'Yes', '/', 'No', '):', '[', '/', 'INST', ']']
Tokenized answer: ['No']


Top 0th token. Logit: 16.25 Prob: 72.66% Token: |No|
Top 1th token. Logit: 13.94 Prob:  7.18% Token: |The|
Top 2th token. Logit: 13.88 Prob:  6.74% Token: |Answer|
Top 3th token. Logit: 13.69 Prob:  5.59% Token: |This|
Top 4th token. Logit: 12.00 Prob:  1.04% Token: |Your|
Top 5th token. Logit: 12.00 Prob:  1.04% Token: |Based|
Top 6th token. Logit: 11.94 Prob:  0.97% Token: |There|
Top 7th token. Logit: 11.88 Prob:  0.92% Token: |I|
Top 8th token. Logit: 11.56 Prob:  0.67% Token: |Is|
Top 9th token. Logit: 10.94 Prob:  0.36% Token: |Yes|


In [176]:
# Tokenize clean and counterfactual prompts as one batch each. Left padding keeps
# every prompt's last token in the final column.
model.tokenizer.padding_side = "left"
all_tokens = model.to_tokens(
    [PREFIX + prompt + SUFFIX for prompt in PROMPTS], prepend_bos=True
)
cf_tokens = model.to_tokens(
    [PREFIX + cf_prompt + SUFFIX for cf_prompt in CF_PROMPTS], prepend_bos=True
)
# NOTE(review): BOS is already in the tokens (prepend_bos=True above); confirm
# prepend_bos=False is what get_attention_mask expects, especially if
# pad_token_id == bos_token_id.
attention_mask = get_attention_mask(model.tokenizer, all_tokens, prepend_bos=False)
answer_prefix = " " if PREPEND_SPACE_TO_ANSWER else ""
# Column 0: answer to the clean prompt; column 1: answer to its counterfactual.
answer_tokens = torch.tensor(
    [
        (
            model.to_single_token(answer_prefix + answer),
            model.to_single_token(answer_prefix + cf_answer),
        )
        for answer, cf_answer in zip(ANSWERS, CF_ANSWERS)
    ],
    device=device,
    dtype=torch.int64,
)
# Fraction of clean prompts whose correct answer is "Yes". This is compared
# against the "Yes" token explicitly, not against whatever the first answer
# happens to be. Balance is reported rather than asserted.
yes_token = model.to_single_token(answer_prefix + "Yes")
pct_true = (answer_tokens[:, 0] == yes_token).float().mean().item()
assert all_tokens.shape == cf_tokens.shape
# The attention mask above is derived from all_tokens. Reusing it for cf_tokens
# is only valid if padding sits at exactly the same positions in both batches,
# not merely if the total pad count matches.
pad_id = model.tokenizer.pad_token_id
assert torch.equal(all_tokens == pad_id, cf_tokens == pad_id)
print(all_tokens.shape, answer_tokens.shape, pct_true)

torch.Size([100, 38]) torch.Size([100, 2]) 0.5399999618530273


In [177]:
# Clean run. Logit diffs are mean-centered only when the Yes/No answers are
# balanced (pct_true close to 0.5).
CENTERED = np.isclose(pct_true, 0.5)
all_logits: Float[Tensor, "batch pos d_vocab"] = model(
    all_tokens,
    return_type="logits",
    prepend_bos=False,
    attention_mask=attention_mask,
)
# Per-prompt diff between the two answer_tokens columns (clean vs counterfactual
# answer) — exact sign convention lives in get_logit_diff.
all_logit_diffs = get_logit_diff(
    all_logits,
    per_prompt=True,
    answer_tokens=answer_tokens,
)
if CENTERED:
    all_logit_diffs = all_logit_diffs - all_logit_diffs.mean()
print(all_logit_diffs)

tensor([ 0.0625,  8.9375,  0.2500,  1.6250, -0.8750, -1.2500,  3.1875,  2.6875,
         4.1250,  4.0000,  3.1250, -2.1250,  4.8125, -2.0625,  0.6250,  2.5625,
         4.6875,  8.0000,  1.8125,  8.4375, -1.7500,  1.6875,  4.1875,  3.1250,
         8.0625, -0.3750, -2.6875, -1.1875,  4.5000, -1.7500,  0.0000,  4.5625,
         3.5000,  8.8125,  1.8750,  3.8125, -0.1250,  2.9375,  3.0000,  6.6250,
         3.8125, -0.6250,  2.5000,  1.2500,  0.1875,  1.0000,  2.6875, -0.5000,
         4.5625,  2.2500,  0.1875,  6.6875,  4.6875,  5.4375,  0.8750,  3.7500,
         5.5625,  3.4375,  4.9375,  2.6875,  2.6250, -1.0000,  0.8125,  3.9375,
         1.0000,  1.7500,  1.1250,  5.4375,  3.3125,  4.6875,  6.2500,  5.6875,
         5.9375,  4.6250,  4.1250,  0.6875,  2.0625,  2.5625,  0.0000, -0.6875,
        -3.1875, -2.1875,  3.2500,  5.3125,  7.8750,  1.1875, -0.9375,  3.0625,
         0.9375, -1.3750, -0.8750,  4.0625,  4.3125,  1.1875, -3.2500,  1.9375,
         1.5000, -0.6250,  4.0000,  2.81

In [178]:
# Counterfactual run. The mask is built from cf_tokens itself instead of reusing
# the clean-batch mask, which is only correct when both batches pad identically.
cf_attention_mask = get_attention_mask(model.tokenizer, cf_tokens, prepend_bos=False)
cf_logits: Float[Tensor, "batch pos d_vocab"] = model(
    cf_tokens, prepend_bos=False, return_type="logits", attention_mask=cf_attention_mask
)
# Same answer_tokens columns as the clean run (clean answer, counterfactual
# answer); the accuracy cell treats negative values as correct for this run.
cf_logit_diffs = get_logit_diff(cf_logits, answer_tokens=answer_tokens, per_prompt=True)
if CENTERED:
    cf_logit_diffs -= cf_logit_diffs.mean()
print(cf_logit_diffs)

tensor([-2.7500,  1.0625, -4.0625, -2.5625, -4.6875, -5.3125, -3.1875,  0.2500,
        -3.0000, -4.9375, -4.2500, -5.1250,  0.8750, -5.3125, -6.6250, -2.4375,
        -0.4375,  4.3750, -0.9375,  1.2500, -7.5000, -0.9375,  2.8750, -0.7500,
        -1.1875, -8.3125, -6.8125, -3.1875,  0.1250, -6.6250, -7.3750, -3.5625,
        -2.5000,  1.5000,  0.3750,  1.6875, -6.8125, -1.6875, -0.0625,  0.6875,
        -1.9375, -3.4375, -4.4375, -2.1250, -4.6875, -0.9375, -3.5000, -5.6875,
        -2.9375, -0.8750, -6.0000,  0.1250,  0.8125,  1.3125, -3.6875, -0.1250,
        -1.3125,  0.3125, -1.0625, -1.9375, -3.8750, -5.1250, -3.0000, -4.0000,
        -5.0000, -4.0625, -4.3750, -1.1250, -1.4375, -4.1875,  0.5625, -0.6250,
         2.4375,  0.1875,  0.6875, -5.0625, -1.8125, -2.2500, -2.8750, -3.8750,
        -5.9375, -7.5625, -4.2500, -3.1250,  3.5000, -3.0000, -2.5000, -1.8125,
        -4.9375, -6.2500, -4.6875,  0.1875, -0.5000, -3.8125, -6.0625, -6.8750,
        -3.3750, -4.3750, -0.6250, -3.31

In [179]:
# A clean prompt counts as correct when its diff is positive; a counterfactual
# prompt when its diff is negative (assuming get_logit_diff returns
# logit(column 0) - logit(column 1) of answer_tokens). Ties count as wrong.
for label, is_correct in (
    ("Original", all_logit_diffs > 0),
    ("Counterfactual", cf_logit_diffs < 0),
):
    print(f"{label} accuracy: {is_correct.float().mean():.2f}")

Original accuracy: 0.77
Counterfactual accuracy: 0.79


In [180]:
# Mean logit diff per run; the two runs are expected to have opposite signs.
for label, diffs in (
    ("Original", all_logit_diffs),
    ("Counterfactual", cf_logit_diffs),
):
    print(f"{label} mean: {diffs.mean():.2f}")

Original mean: 2.41
Counterfactual mean: -2.58
