# Main Demo
## Introduction

This Jupyter Notebook contains the code to reproduce the results presented in our paper "A Mechanistic Interpretation of Syllogistic Reasoning in Auto-Regressive Language Models".

### Contents:
1. Setup
2. Dataset
3. Intervention (Figure 4)
4. Evaluation (Figure 5)
5. Transferability (Figure 6, Table 2, Figure 3)

Note: This demo is provided to ensure transparency and reproducibility of our research. If you encounter any issues or have questions, please contact the authors.

# 1. Setup

In [15]:
# if colab
!pip install torch==2.2.0
!pip install plotly==5.19.0
!pip install numpy==1.26.4
!pip install nbformat==5.9.2
!pip install kaleido==0.2.1
!pip install scipy==1.13.1
!pip install transformer_lens
#!pip install circuitsvis

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [1]:
# Import libraries
import torch as t
import sys
import os
import numpy as np
from torch import Tensor

import random
from functools import partial
import importlib
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformers import AutoTokenizer, AutoModelForCausalLM
# Get the current working directory (cwd) instead of __file__ in Jupyter Notebook
current_path = os.getcwd()

# Add the parent directory to Python's path
parent_path = os.path.abspath(os.path.join(current_path, "../app/"))
sys.path.insert(0, parent_path)

# Import your module
import importlib
from global_variables import IMAGE_DIR, DATASET_DIR
from prepare_dataset import SyllogismDataset
import helper_functions as h
import plot_utils as pu
t.set_grad_enabled(False)

# Reload the package
#importlib.reload(pu)
#importlib.reload(h)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x17ff0e460>

In [2]:
# Device Setup
device = t.device("cuda" if t.cuda.is_available() else "cpu")
device = t.device("mps") # if machine has mps setup

In [3]:
device

device(type='mps')

In [3]:
# Import the model-medium Through HookedTransformer
gpt2 = ["gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl"]
pythia = ["pythia-70m", "pythia-160m", "pythia-410m", "pythia-1b"]
qwen = ["Qwen/Qwen2.5-0.5B", "Qwen/Qwen2.5-0.5B-Instruct"]
llama = ["meta-llama/Llama-3.2-1B"]
current_model = pythia[2]
#current_model = pythia

model = HookedTransformer.from_pretrained(
    current_model,
    fold_ln=False,
    device = device
)

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loaded pretrained model pythia-410m into HookedTransformer


In [6]:
# if fine-tuned model is used
# model_name = "DebateLabKIT/cript-medium"  # Example: 'bert-base-uncased', 'gpt2', etc.
# f_tokenizer = AutoTokenizer.from_pretrained(model_name)
# f_model = AutoModelForCausalLM.from_pretrained(model_name)
# f_model = f_model.to(device)

# model = h.align_fine_tuning_gpt2(model, f_model, device)

# current_model = "cript-medium"

In [4]:
# label setup for sequence
image_dir = IMAGE_DIR / current_model
print(f"Image Directory: {image_dir}")
labels = [
    "BEGIN",
    "s_1",
    "s_1 -> m_1",
    "m_1",
    "m_1 -> m_2",
    "m_2",
    "m_2 -> p",
    "p",
    "p -> s_2",
    "s_2",
    "END"
]

Image Directory: ../images/pythia-410m


# 2. Dataset

In [5]:
# Ensure the Dataset folder and dataset_generator.py file is in the same directory as this notebook
# This file contains the custom data generation functions used in our experiments
# 90 samples for each dataset, seed is set to 317 for reproducibility
N = 30
seed = 317

In [6]:
# Symbolic
s_dataset = SyllogismDataset(
            N=N/6, # because of permutation
            seed=seed,
            device=device,
            type='symbolic',
            template_type='AAA1'
        )
for s, l in zip(s_dataset.sentences[:6], s_dataset.labels[:6]):
    print(f'{s} => {l}')

All O are A. All A are W. Therefore, all O are =>  W
All O are W. All W are A. Therefore, all O are =>  A
All A are O. All O are W. Therefore, all A are =>  W
All A are W. All W are O. Therefore, all A are =>  O
All W are O. All O are A. Therefore, all W are =>  A
All W are A. All A are O. Therefore, all W are =>  O


In [48]:
len(s_dataset.sentences)

90

In [49]:
# Belief-Consistent
bc_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='consistent',
            template_type='AAA1'
        )
for s, l in zip(bc_dataset.sentences[:6], bc_dataset.labels[:6]):
    print(f'{s} => {l}')

All men are humans. All humans are mortal. Therefore, all men are =>  mortal
All dens are units. All units are parts. Therefore, all dens are =>  parts
All pets are animals. All animals are organisms. Therefore, all pets are =>  organisms
All openings are starts. All starts are beginnings. Therefore, all openings are =>  beginnings
All fronts are sides. All sides are surfaces. Therefore, all fronts are =>  surfaces
All fences are boundaries. All boundaries are lines. Therefore, all fences are =>  lines


In [50]:
# Belief-Inconsistent
bi_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='inconsistent',
            template_type='AAA1'
        )
for s, l in zip(bi_dataset.sentences[:6], bi_dataset.labels[:6]):
    print(f'{s} => {l}')

All patients are cases. All cases are arguments. Therefore, all patients are =>  arguments
All allegations are claims. All claims are rights. Therefore, all allegations are =>  rights
All wards are people. All people are judges. Therefore, all wards are =>  judges
All losers are people. All people are racists. Therefore, all losers are =>  racists
All aliens are people. All people are parents. Therefore, all aliens are =>  parents
All expectations are feelings. All feelings are sensitivity. Therefore, all expectations are =>  sensitivity


# 3. Empirical Evaluation Part

## Accuracy

In [9]:
# accuracy evaluation
def evaluate_accuracy(dataset, model, device):
    correct = 0
    total = 0
    for i, (sentence, label) in enumerate(zip(dataset.sentences, dataset.labels)):
        #print(sentence, label)
        # Encode the sentence
        input_ids = model.tokenizer.encode(sentence, return_tensors="pt").to(device)
        # Generate the prediction
        output = model(input_ids)
        #print(output.shape)
        # Get the predicted label
        predicted_label = model.tokenizer.decode(output[0, -1].argmax(dim=-1))
        #print(predicted_label)
        # Check if the prediction is correct
        if predicted_label == label:
            correct += 1
        total += 1
    return correct / total

# Evaluate the accuracy of the model on the symbolic dataset
symbolic_accuracy = evaluate_accuracy(s_dataset, model, device)
print(f"Symbolic Accuracy: {symbolic_accuracy}")

Symbolic Accuracy: 1.0


## Localization of Transitive Reasoning Mechanisms

In [10]:
# corruption function setup
ALPHABET_LIST = [' A', ' B', ' C', ' D', ' E', ' F', ' G', ' H', ' I', ' J', ' K', ' L', ' M', ' N', ' O', ' P', ' Q', ' R', ' S', ' T', ' U', ' V', ' W', ' X', ' Y', ' Z']
NUMBER_LIST = [' 1', ' 2', ' 3', ' 4', ' 5',' 6',' 7',' 8',' 9']
def middle_term_corrupt(prompts, As, Bs, labels, prompt_type):
    if prompt_type == 'symbolic':
        candidates = ALPHABET_LIST
    elif prompt_type == 'numeric':
        candidates = NUMBER_LIST
    elif prompt_type == 'non-symbolic':
        candidates = bc_dataset.A + bc_dataset.B

    corrupted_prompts = []

    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        corrupted_labels = []
        A = As[i]
        B = Bs[i]

        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.choice(new_list)
        corrupted_prompt = prompt.replace('All' + B , 'All'+ target)
        corrupted_prompts.append(corrupted_prompt)
        corrupted_labels.append(target)

    return corrupted_prompts


def all_term_corrupt(prompts, As, Bs, labels, prompt_type):
    if prompt_type == 'symbolic':
        candidates = ALPHABET_LIST
    elif prompt_type == 'numeric':
        candidates = NUMBER_LIST
    elif prompt_type == 'non-symbolic':
        candidates = bc_dataset.A + bc_dataset.B

    corrupted_prompts = []
    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        A = As[i]
        B = Bs[i]
        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.sample(new_list, 3)
        corrupted_prompt = 'All' + target[0] + ' are' + target[1] + '. All' + target[1] + ' are' + target[2] + '. Therefore, all' + target[0] + ' are'
        corrupted_prompts.append(corrupted_prompt)
    return corrupted_prompts



In [11]:
# answer pair setting (p, m)
s_prompts = s_dataset.sentences
s_labels = s_dataset.labels
s_second_labels = s_dataset.B
s_As = s_dataset.A
s_Bs = s_dataset.B

s_answers = list(zip(s_labels, s_second_labels))
s_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in s_answers
])
for prompt, (p, m) in zip(s_prompts[:6], s_answers):
    print(prompt, p, m)

All O are A. All A are W. Therefore, all O are  W  A
All O are W. All W are A. Therefore, all O are  A  W
All A are O. All O are W. Therefore, all A are  W  O
All A are W. All W are O. Therefore, all A are  O  W
All W are O. All O are A. Therefore, all W are  A  O
All W are A. All A are O. Therefore, all W are  O  A


In [152]:
bc_prompts = bc_dataset.sentences
bc_labels = bc_dataset.labels
bc_second_labels = bc_dataset.B
bc_As = bc_dataset.A
bc_Bs = bc_dataset.B

bc_answers = list(zip(bc_labels, bc_second_labels))
bc_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bc_answers
])
for prompt, (p, m) in zip(bc_prompts[:6], bc_answers):
    print(prompt, p, m)

print(bc_answer_tokens.shape)

All men are humans. All humans are mortal. Therefore, all men are  mortal  humans
All dens are units. All units are parts. Therefore, all dens are  parts  units
All pets are animals. All animals are organisms. Therefore, all pets are  organisms  animals
All openings are starts. All starts are beginnings. Therefore, all openings are  beginnings  starts
All fronts are sides. All sides are surfaces. Therefore, all fronts are  surfaces  sides
All fences are boundaries. All boundaries are lines. Therefore, all fences are  lines  boundaries
torch.Size([60, 2])


In [153]:
bi_prompts = bi_dataset.sentences
bi_labels = bi_dataset.labels
bi_second_labels = bi_dataset.B
bi_As = bi_dataset.A
bi_Bs = bi_dataset.B

bi_answers = list(zip(bi_labels, bi_second_labels))
bi_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bi_answers
])

for prompt, (p, m) in zip(bi_prompts[:6], bi_answers):
    print(prompt, p, m)

print(bi_answer_tokens.shape)

All patients are cases. All cases are arguments. Therefore, all patients are  arguments  cases
All allegations are claims. All claims are rights. Therefore, all allegations are  rights  claims
All wards are people. All people are judges. Therefore, all wards are  judges  people
All losers are people. All people are racists. Therefore, all losers are  racists  people
All aliens are people. All people are parents. Therefore, all aliens are  parents  people
All expectations are feelings. All feelings are sensitivity. Therefore, all expectations are  sensitivity  feelings
torch.Size([68, 2])


In [12]:
# Corruption and tokenization
s_corrupted_prompts = middle_term_corrupt(s_prompts, s_As, s_Bs, s_labels, 'symbolic')

s_tokens = model.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens = model.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)

In [119]:
bc_corrupted_prompts = middle_term_corrupt(bc_prompts, bc_As, bc_Bs, bc_labels, 'non-symbolic')
bc_tokens = model.to_tokens(bc_prompts, prepend_bos=False).to(device)
bc_corrupted_tokens = model.to_tokens(bc_corrupted_prompts, prepend_bos=False).to(device)

In [27]:
bi_corrupted_prompts = middle_term_corrupt(bi_prompts, bi_As, bi_Bs, bi_labels, 'non-symbolic')
bi_tokens = model.to_tokens(bi_prompts, prepend_bos=False).to(device)
bi_corrupted_tokens = model.to_tokens(bi_corrupted_prompts, prepend_bos=False).to(device)

In [55]:
# Compute logits for interventions
s_clean_logits, s_clean_cache = model.run_with_cache(s_tokens)
s_corrupted_logits, s_corrupted_cache = model.run_with_cache(s_corrupted_tokens)

s_clean_logit_diff = h.compute_logit_diff(s_clean_logits, s_answer_tokens)
print(f"Clean logit diff: {s_clean_logit_diff:.4f}")

s_corrupted_logit_diff = h.compute_logit_diff(s_corrupted_logits, s_answer_tokens)
print(f"Corrupted logit diff: {s_corrupted_logit_diff:.4f}")

Clean logit diff: 0.8777
Corrupted logit diff: -0.2414


In [21]:
bc_clean_logits, bc_clean_cache = model.run_with_cache(bc_tokens)
bc_corrupted_logits, bc_corrupted_cache = model.run_with_cache(bc_corrupted_tokens)

bc_clean_logit_diff = h.compute_logit_diff(bc_clean_logits, bc_answer_tokens)
print(f"Clean logit diff: {bc_clean_logit_diff:.4f}")

bc_corrupted_logit_diff = h.compute_logit_diff(bc_corrupted_logits, bc_answer_tokens)
print(f"Corrupted logit diff: {bc_corrupted_logit_diff:.4f}")

Clean logit diff: 0.2409
Corrupted logit diff: 0.2152


In [None]:
bi_clean_logits, bi_clean_cache = model.run_with_cache(bi_tokens)
bi_corrupted_logits, bi_corrupted_cache = model.run_with_cache(bi_corrupted_tokens)

bi_clean_logit_diff = h.compute_logit_diff(bi_clean_logits, bi_answer_tokens)
print(f"Clean logit diff: {bi_clean_logit_diff:.4f}")

bi_corrupted_logit_diff = h.compute_logit_diff(bi_corrupted_logits, bi_answer_tokens)
print(f"Corrupted logit diff: {bi_corrupted_logit_diff:.4f}")

In [23]:
from scipy.stats import norm

# Compute individual logit differences for standard deviation
s_clean_logit_diffs = h.compute_logit_diff(s_clean_logits, s_answer_tokens, array=True)
s_corrupted_logit_diffs = h.compute_logit_diff(s_corrupted_logits, s_answer_tokens, array=True)

bc_clean_logit_diffs = h.compute_logit_diff(bc_clean_logits, bc_answer_tokens, array=True)
bc_corrupted_logit_diffs = h.compute_logit_diff(bc_corrupted_logits, bc_answer_tokens, array=True)

bi_clean_logit_diffs = h.compute_logit_diff(bi_clean_logits, bi_answer_tokens, array=True)
bi_corrupted_logit_diffs = h.compute_logit_diff(bi_corrupted_logits, bi_answer_tokens, array=True)

# Compute standard deviations
std_clean_s = np.std(s_clean_logit_diffs, ddof=1)  # Sample standard deviation
std_corrupted_s = np.std(s_corrupted_logit_diffs, ddof=1)

std_clean_bc = np.std(bc_clean_logit_diffs, ddof=1)  # Sample standard deviation
std_corrupted_bc = np.std(bc_corrupted_logit_diffs, ddof=1)

std_clean_bi = np.std(bi_clean_logit_diffs, ddof=1)  # Sample standard deviation
std_corrupted_bi = np.std(bi_corrupted_logit_diffs, ddof=1)

# Compute standard error
SE_s = np.sqrt((std_clean_s**2 / N) + (std_corrupted_s**2 / N))
SE_bc = np.sqrt((std_clean_bc**2 / N) + (std_corrupted_bc**2 / N))
SE_bi = np.sqrt((std_clean_bi**2 / N) + (std_corrupted_bi**2 / N))

# Compute mean difference
s_mean_diff = s_clean_logit_diff - s_corrupted_logit_diff
bc_mean_diff = bc_clean_logit_diff - bc_corrupted_logit_diff
bi_mean_diff = bi_clean_logit_diff - bi_corrupted_logit_diff

# Compute z-score
z_score_s = s_mean_diff / SE_s
z_score_bc = bc_mean_diff / SE_bc
z_score_bi = bi_mean_diff / SE_bi
print(f"Z-score (Symbolic): {z_score_s:.4f}")
print(f"Z-score (Belief-Consistent): {z_score_bc:.4f}")
print(f"Z-score (Belief-Inconsistent): {z_score_bi:.4f}")

def compute_two_tailed_p_value(z_score):
    """
    Computes the two-tailed p-value given a z-score.

    Args:
    - z_score (float): The z-score to compute the p-value for.

    Returns:
    - float: The two-tailed p-value.
    """
    # Compute the one-tailed p-value using the survival function (1 - CDF)
    one_tailed_p_value = norm.sf(abs(z_score))  # sf is 1 - CDF
    
    # Multiply by 2 for the two-tailed p-value
    two_tailed_p_value = 2 * one_tailed_p_value
    return two_tailed_p_value

p_value_s = compute_two_tailed_p_value(z_score_s.cpu().numpy())
p_value_bc = compute_two_tailed_p_value(z_score_bc.cpu().numpy())
p_value_bi = compute_two_tailed_p_value(z_score_bi.cpu().numpy())

print(f"Two-tailed p-value (Symbolic): {p_value_s:.2e}")
print(f"Two-tailed p-value (Belief-Consistent): {p_value_bc:.2e}")
print(f"Two-tailed p-value (Belief-Inconsistent): {p_value_bi:.2e}")


RuntimeError: Size does not match at dimension 0 expected index [34, 2] to be smaller than self [30, 50304] apart from dimension 1

### Figure 4(a) attention head output patching

In [14]:
s_attn_denoising = h.patching_attention(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, 'z', device)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:23<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 17.73 GB, other allocations: 166.70 MB, max allowed: 18.13 GB). Tried to allocate 660.50 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
pu.plot_attn(h.normalise_tensor(s_attn_denoising), labels, image_dir / "symbolic_denoising.png")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Plot saved to ../images/Qwen/Qwen2.5-0.5B-Instruct/symbolic_denoising.png


### Figure 4(b) residual stream patching

In [16]:
s_resid_denoising = h.patching_residual(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, device)

100%|██████████| 24/24 [03:47<00:00,  9.49s/it]


In [18]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising)), labels, image_dir / "symbolic_residual.png")

Plot saved to ../images/Qwen/Qwen2.5-0.5B-Instruct/symbolic_residual.png


### Figure 4(c) OV logit lens

In [19]:
token_alpha = model.to_tokens(ALPHABET_LIST, prepend_bos=False).to(device)
token_alpha = [ element[0].item() for element in token_alpha ]

In [20]:
deductive_head_ov = model.OV
ev = model.W_E
uev = model.W_U
ov_circuit = ev.cpu()[token_alpha, : ] @ deductive_head_ov.AB.cpu() @ uev.cpu()[:, token_alpha]

In [21]:
# OV circuit for head 11.10
layer = 20
head = 9

pu.plot_attn(h.normalise_tensor(ov_circuit[layer, head, : , :]), ALPHABET_LIST, image_dir / "symbolic_ov_circuit.png")

Plot saved to ../images/Qwen/Qwen2.5-0.5B-Instruct/symbolic_ov_circuit.png


## Localization of Term-Related Information Flow

In [13]:
# Corruption and tokenization
s_corrupted_prompts = all_term_corrupt(s_prompts, s_As, s_Bs, s_labels, 'symbolic')

s_tokens = model.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens = model.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)

In [14]:
# Compute logits for interventions
s_clean_logits, s_clean_cache = model.run_with_cache(s_tokens)
s_corrupted_logits, s_corrupted_cache = model.run_with_cache(s_corrupted_tokens)

s_clean_logit_diff = h.compute_logit_diff(s_clean_logits, s_answer_tokens)
print(f"Clean logit diff: {s_clean_logit_diff:.4f}")

s_corrupted_logit_diff = h.compute_logit_diff(s_corrupted_logits, s_answer_tokens)
print(f"Corrupted logit diff: {s_corrupted_logit_diff:.4f}")

Clean logit diff: 0.8832
Corrupted logit diff: 0.1459


### Figure 4(d) residual stream patching

In [15]:
s_resid_denoising2 = h.patching_residual(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, device)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 24/24 [01:13<00:00,  3.04s/it]


In [16]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising2)), labels, image_dir / "symbolic_residual2.png")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Plot saved to ../images/pythia-410m/symbolic_residual2.png


# 4. Circuit Evaluation

### Figure 5(a) correctness of the circuit

In [72]:
# necessity of symbolic
necessity_score = h.necessity_check(model, s_labels, s_tokens, s_answer_tokens, s_clean_logit_diff, 'mean', device)
sufficiency_score = h.sufficiency_check(model, s_labels, s_tokens, s_answer_tokens, s_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:04<00:00,  2.52it/s]
100%|██████████| 11/11 [00:02<00:00,  4.18it/s]


In [73]:
pu.plot_ablation(s_clean_logit_diff, necessity_score, sufficiency_score, image_dir / "correctness_of_circuit.png")

Plot saved to ../images/pythia-31m/correctness_of_circuit.png


### Figure 5(b) robustness of the circuit

In [99]:
# numeric perturbed dataset
n_dataset = SyllogismDataset(
            N=N/6, # because of permutation
            seed=seed,
            device=device,
            type='numeric',
            template_type='AAA1'
        )
for s, l in zip(n_dataset.sentences[:6], n_dataset.labels[:6]):
    print(f'{s} => {l}')


All 8 are 1. All 1 are 7. Therefore, all 8 are =>  7
All 8 are 7. All 7 are 1. Therefore, all 8 are =>  1
All 1 are 8. All 8 are 7. Therefore, all 1 are =>  7
All 1 are 7. All 7 are 8. Therefore, all 1 are =>  8
All 7 are 8. All 8 are 1. Therefore, all 7 are =>  1
All 7 are 1. All 1 are 8. Therefore, all 7 are =>  8


In [100]:
# Corruption (option)
def perturb_quantifier(prompts, As, Bs, labels):
    candidates = [' Every', ' Every', ' Each', ' Each', ' All']
    corrupted_prompts = []
    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        corrupted_labels = []
        A = As[i]
        B = Bs[i]
        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.sample(new_list, 2)
        a_be = ' are' if target[0] == ' All' else ' is'
        b_be = ' are' if target[1] == ' All' else ' is'
        corrupted_prompt = prompt.replace(' All' + B + ' are' , target[0] + B + a_be).replace('All' + A + ' are', target[1][1:] + A + b_be)
        corrupted_prompts.append(corrupted_prompt)
        corrupted_labels.append(target)

    return corrupted_prompts, corrupted_labels


In [101]:
n_prompts = n_dataset.sentences
n_labels = n_dataset.labels
n_second_labels = n_dataset.B

n_answers = list(zip(n_labels, n_second_labels))
n_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in n_answers
])

In [102]:
q_prompts = s_dataset.sentences
q_labels = s_dataset.labels
q_As = s_dataset.A
q_Bs = s_dataset.B

q_corrupted_prompts, _ = perturb_quantifier(q_prompts, q_As, q_Bs, q_labels)

q_answers = list(zip(q_labels, q_Bs))
q_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in q_answers
])

In [103]:
# tokenisation
n_tokens = model.to_tokens(n_prompts, prepend_bos=False).to(device)
q_tokens = model.to_tokens(q_prompts, prepend_bos=False).to(device)

In [104]:
# Compute logits for interventions
n_clean_logits, n_clean_cache = model.run_with_cache(n_tokens)
q_clean_logits, q_clean_cache = model.run_with_cache(n_tokens)

n_clean_logit_diff = h.compute_logit_diff(n_clean_logits, n_answer_tokens)
print(f"Clean logit diff: {n_clean_logit_diff:.4f}")

q_clean_logit_diff = h.compute_logit_diff(q_clean_logits, q_answer_tokens)
print(f"Clean logit diff: {q_clean_logit_diff:.4f}")

Clean logit diff: 0.1938
Clean logit diff: -0.0007


In [105]:
# necessity, sufficiency of numeric perturbation
n_necessity_score = h.necessity_check(model, n_labels, n_tokens, n_answer_tokens, n_clean_logit_diff, 'mean', device)
n_sufficiency_score = h.sufficiency_check(model, n_labels, n_tokens, n_answer_tokens, n_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:07<00:00,  1.42it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


In [106]:
# necessity, sufficiency of quantifier perturbation
q_necessity_score = h.necessity_check(model, q_labels, q_tokens, q_answer_tokens, q_clean_logit_diff, 'mean', device)
q_sufficiency_score = h.sufficiency_check(model, q_labels, q_tokens, q_answer_tokens, q_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
100%|██████████| 11/11 [00:02<00:00,  3.98it/s]


In [107]:
pu.plot_ablation_robust(y1=n_necessity_score, y2=n_sufficiency_score, y3=q_necessity_score, y4 = q_sufficiency_score,  baseline1= n_clean_logit_diff, baseline2=q_clean_logit_diff, save_path=image_dir / "robustness_of_circuit.png", title="")


Plot saved to ../images/cript-medium/robustness_of_circuit.png


# 5. Circuit Transferability

In [108]:
# setup
bc_prompts = bc_dataset.sentences
bc_labels = bc_dataset.labels
bc_second_labels = bc_dataset.B

bc_answers = list(zip(bc_labels, bc_second_labels))
bc_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bc_answers
])

bi_prompts = bi_dataset.sentences
bi_labels = bi_dataset.labels
bi_second_labels = bi_dataset.B

bi_answers = list(zip(bi_labels, bi_second_labels))
bi_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bi_answers
])

# tokenisation
bc_tokens = model.to_tokens(bc_prompts, prepend_bos=False).to(device)
bi_tokens = model.to_tokens(bi_prompts, prepend_bos=False).to(device)

In [109]:
# Compute logits for interventions
bc_clean_logits, bc_clean_cache = model.run_with_cache(bc_tokens)
bi_clean_logits, bi_clean_cache = model.run_with_cache(bi_tokens)

bc_clean_logit_diff = h.compute_logit_diff(bc_clean_logits, bc_answer_tokens)
print(f"Clean logit diff: {bc_clean_logit_diff:.4f}")

bi_clean_logit_diff = h.compute_logit_diff(bi_clean_logits, bi_answer_tokens)
print(f"Clean logit diff: {bi_clean_logit_diff:.4f}")

Clean logit diff: 0.5050
Clean logit diff: 0.0146


In [110]:
# belief-consistent
bc_necessity_score = h.necessity_check(model, bc_labels, bc_tokens, bc_answer_tokens, bc_clean_logit_diff, 'mean', device)
bc_sufficiency_score = h.sufficiency_check(model, bc_labels, bc_tokens, bc_answer_tokens, bc_clean_logit_diff, 'mean', device)

  0%|          | 0/11 [00:00<?, ?it/s]

100%|██████████| 11/11 [00:07<00:00,  1.54it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


In [111]:
# belief-inconsistent
bi_necessity_score = h.necessity_check(model, bi_labels, bi_tokens, bi_answer_tokens, bi_clean_logit_diff, 'mean', device)
bi_sufficiency_score = h.sufficiency_check(model, bi_labels, bi_tokens, bi_answer_tokens, bi_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:02<00:00,  4.09it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


### Figure 6(a) belief-consistent

In [112]:
pu.plot_ablation(bc_clean_logit_diff, bc_necessity_score, bc_sufficiency_score, image_dir / "correctness_of_circuit_bc.png")

Plot saved to ../images/cript-medium/correctness_of_circuit_bc.png


### Figure 6(b) belief-inconsistent

In [113]:
pu.plot_ablation(bi_clean_logit_diff, bi_necessity_score, bi_sufficiency_score, image_dir / "correctness_of_circuit_bi.png")

Plot saved to ../images/cript-medium/correctness_of_circuit_bi.png


### Table 2 all unconditionally valid syllogisms

In [114]:
# ordered by accuracy
moods_ordered = ['AII3', 'IAI3', 'IAI4', 'AAA1', 'EAE1', 'EIO4', 'EIO3', 'AII1', 'AOO2', 'AEE4', 'OAO3', 'EIO1', 'EIO2', 'EAE2', 'AEE2']

In [115]:
# all symbolic datasets import
s_datasets = {}

for template in moods_ordered:
    s_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='symbolic',
            template_type=template
        )
    s_datasets[template] = s_dataset


In [116]:
# setup
mood_dics = []

for template in moods_ordered:
    mood_dic = {
        'mood': template,
        'prompts': s_datasets[template].sentences,
        'labels': s_datasets[template].labels,
        'second_labels': s_datasets[template].B

    }
    s_answers = list(zip(mood_dic['labels'], mood_dic['second_labels']))
    mood_dic['answers'] = s_answers
    s_answer_tokens = t.concat([
        model.to_tokens(names, prepend_bos=False).T for names in s_answers
    ])
    mood_dic['answer_tokens'] = s_answer_tokens

    mood_dics.append(mood_dic)

In [117]:
# get base logit diff for 15 syllogisms
for i, template in enumerate(moods_ordered):
    tokens = model.to_tokens(mood_dics[i]['prompts'], prepend_bos=False).to(device)
    clean_logit_diff = h.get_batched_logit_diff(5, tokens, mood_dics[i]['answer_tokens'], model)
    mood_dics[i]['tokens'] = tokens
    mood_dics[i]['clean_logit_diff'] = clean_logit_diff

    bia_diff = 0
    for correct, wrong in mood_dics[i]['answer_tokens']:
        diff = model.unembed.b_U[correct.item()] - model.unembed.b_U[wrong.item()]
        bia_diff += diff

    bia_diff = bia_diff/len(mood_dics[i]['prompts'])
    mood_dics[i]['bia_diff'] = bia_diff
    print(mood_dics[i]['clean_logit_diff'], mood_dics[i]['bia_diff'])

tensor(0.9720, device='mps:0') tensor(0., device='mps:0')
tensor(0.4753, device='mps:0') tensor(0., device='mps:0')
tensor(0.4753, device='mps:0') tensor(0., device='mps:0')
tensor(0.3170, device='mps:0') tensor(0., device='mps:0')
tensor(0.1401, device='mps:0') tensor(0., device='mps:0')
tensor(0.4769, device='mps:0') tensor(0., device='mps:0')
tensor(0.4769, device='mps:0') tensor(0., device='mps:0')
tensor(-0.1302, device='mps:0') tensor(0., device='mps:0')
tensor(-1.1784, device='mps:0') tensor(0., device='mps:0')
tensor(-0.5424, device='mps:0') tensor(0., device='mps:0')
tensor(-0.4168, device='mps:0') tensor(0., device='mps:0')
tensor(-0.9202, device='mps:0') tensor(0., device='mps:0')
tensor(-1.4959, device='mps:0') tensor(0., device='mps:0')
tensor(-1.5060, device='mps:0') tensor(0., device='mps:0')
tensor(-1.7159, device='mps:0') tensor(0., device='mps:0')


In [119]:
# ablation
for i, template in enumerate(moods_ordered):
    mean_scores = h.necessity_check(model, mood_dics[i]['labels'], mood_dics[i]['tokens'], mood_dics[i]['answer_tokens'], mood_dics[i]['clean_logit_diff'], 'mean', device)
    sf_mean_scores = h.sufficiency_check(model, mood_dics[i]['labels'], mood_dics[i]['tokens'], mood_dics[i]['answer_tokens'], mood_dics[i]['clean_logit_diff'], 'mean',device)
    mood_dics[i]['mean_scores'] = mean_scores
    mood_dics[i]['sf_mean_score'] = sf_mean_scores

100%|██████████| 11/11 [00:31<00:00,  2.82s/it]
100%|██████████| 11/11 [00:11<00:00,  1.06s/it]
100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 11/11 [00:11<00:00,  1.08s/it]
100%|██████████| 11/11 [00:11<00:00,  1.03s/it]
100%|██████████| 11/11 [00:12<00:00,  1.13s/it]
100%|██████████| 11/11 [00:11<00:00,  1.01s/it]
100%|██████████| 11/11 [00:11<00:00,  1.08s/it]
100%|██████████| 11/11 [00:11<00:00,  1.02s/it]
100%|██████████| 11/11 [00:11<00:00,  1.07s/it]
100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
100%|██████████| 11/11 [00:12<00:00,  1.11s/it]
100%|██████████| 11/11 [00:18<00:00,  1.65s/it]
100%|██████████| 11/11 [00:20<00:00,  1.89s/it]
100%|██████████| 11/11 [00:13<00:00,  1.24s/it]
100%|██████████| 11/11 [00:11<00:00,  1.09s/it]
100%|██████████| 11/11 [00:11<00:00,  1.04s/it]
100%|██████████| 11/11 [00:14<00:00,  1.28s/it]
100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 11/11 [00:12<00:00,  1.10s/it]
100%|██████████| 11/11 [00:11<00:00,  1.

In [120]:
pu.plot_ablation_syllogisms(mood_dics = mood_dics, save_path=image_dir / "syllogism_ablation.png", title="",)

Plot saved to ../images/cript-medium/syllogism_ablation.png


In [27]:
label_token = ['All','s','are','m1','.','All','m2','are','p','.', 'Tf',',','all','s','END']
cv.attention.attention_patterns(tokens=label_token, attention=bc_clean_cache[f'blocks.4.attn.hook_pattern'][0]) # First layer attention pattern