# Main Demo
## Introduction

This Jupyter Notebook contains the code to reproduce the results presented in our paper "A Mechanistic Interpretation of Syllogistic Reasoning in Auto-Regressive Language Models".

### Contents:
1. Setup
2. Dataset
3. Intervention (Figure 4)
4. Evaluation (Figure 5)
5. Transferability (Figure 6, Table 2, Figure 3)

Note: This demo is provided to ensure transparency and reproducibility of our research. If you encounter any issues or have questions, please contact the authors.

# 1. Setup

In [1]:
# if colab
!pip install torch==2.2.0
!pip install plotly==5.19.0
!pip install numpy==1.26.4
!pip install nbformat==5.9.2
!pip install kaleido==0.2.1
!pip install transformer_lens
#!pip install circuitsvis



In [22]:
# Import libraries
import torch as t
import sys
import os
from torch import Tensor
import random
from functools import partial
import importlib
import circuitsvis as cv
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformers import AutoTokenizer, AutoModelForCausalLM
# Get the current working directory (cwd) instead of __file__ in Jupyter Notebook
current_path = os.getcwd()

# Add the parent directory to Python's path
parent_path = os.path.abspath(os.path.join(current_path, "../app/"))
sys.path.insert(0, parent_path)

# Import your module
import importlib
from global_variables import IMAGE_DIR, DATASET_DIR
from prepare_dataset import SyllogismDataset
import helper_functions as h
import plot_utils as pu
t.set_grad_enabled(False)

# Reload the package
#importlib.reload(pu)
#importlib.reload(h)

<torch.autograd.grad_mode.set_grad_enabled at 0x3319f5d30>

In [23]:
# Device Setup
device = t.device("cuda" if t.cuda.is_available() else "cpu")
device = t.device("mps") # if machine has mps setup

In [27]:
# Import the model-medium Through HookedTransformer
gpt2 = ["gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl"]
pythia = "pythia-410m"

current_model = gpt2[1] # medium
current_model_2 = pythia

model = HookedTransformer.from_pretrained(
    current_model,
    fold_ln=False,
    device = device
)

model2 = HookedTransformer.from_pretrained(
    current_model_2,
    fold_ln=False,
    device = device
)

Loaded pretrained model gpt2-medium into HookedTransformer
Loaded pretrained model pythia-410m into HookedTransformer


In [33]:
# if fine-tuned model is used
# model_name = "DebateLabKIT/cript-medium"  # Example: 'bert-base-uncased', 'gpt2', etc.
# f_tokenizer = AutoTokenizer.from_pretrained(model_name)
# f_model = AutoModelForCausalLM.from_pretrained(model_name)
# f_model = f_model.to(device)

# model = h.align_fine_tuning_gpt2(model, f_model, device)

# current_model = "cript-medium"

In [28]:
# label setup for sequence
image_dir = IMAGE_DIR / "gpt-2_vs_pythia"
print(f"Image Directory: {image_dir}")
labels = [
    "BEGIN",
    "s_1",
    "s_1 -> m_1",
    "m_1",
    "m_1 -> m_2",
    "m_2",
    "m_2 -> p",
    "p",
    "p -> s_2",
    "s_2",
    "END"
]

Image Directory: ../images/gpt-2_vs_pythia


# 2. Dataset

In [29]:
# Ensure the Dataset folder and dataset_generator.py file is in the same directory as this notebook
# This file contains the custom data generation functions used in our experiments
# 90 samples for each dataset, seed is set to 317 for reproducibility
N = 30
seed = 317

In [30]:
# Symbolic
s_dataset = SyllogismDataset(
            N=N/6, # because of permutation
            seed=seed,
            device=device,
            type='symbolic',
            template_type='AAA1'
        )
for s, l in zip(s_dataset.sentences[:6], s_dataset.labels[:6]):
    print(f'{s} => {l}')

All O are A. All A are W. Therefore, all O are =>  W
All O are W. All W are A. Therefore, all O are =>  A
All A are O. All O are W. Therefore, all A are =>  W
All A are W. All W are O. Therefore, all A are =>  O
All W are O. All O are A. Therefore, all W are =>  A
All W are A. All A are O. Therefore, all W are =>  O


In [31]:
len(s_dataset.sentences)

30

In [32]:
# Belief-Consistent
bc_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='consistent',
            template_type='AAA1'
        )
for s, l in zip(bc_dataset.sentences[:6], bc_dataset.labels[:6]):
    print(f'{s} => {l}')

All men are humans. All humans are mortal. Therefore, all men are =>  mortal
All dens are units. All units are parts. Therefore, all dens are =>  parts
All pets are animals. All animals are organisms. Therefore, all pets are =>  organisms
All openings are starts. All starts are beginnings. Therefore, all openings are =>  beginnings
All fronts are sides. All sides are surfaces. Therefore, all fronts are =>  surfaces
All fences are boundaries. All boundaries are lines. Therefore, all fences are =>  lines


In [33]:
# Belief-Inconsistent
bi_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='inconsistent',
            template_type='AAA1'
        )
for s, l in zip(bi_dataset.sentences[:6], bi_dataset.labels[:6]):
    print(f'{s} => {l}')

All patients are cases. All cases are arguments. Therefore, all patients are =>  arguments
All allegations are claims. All claims are rights. Therefore, all allegations are =>  rights
All wards are people. All people are judges. Therefore, all wards are =>  judges
All losers are people. All people are racists. Therefore, all losers are =>  racists
All aliens are people. All people are parents. Therefore, all aliens are =>  parents
All expectations are feelings. All feelings are sensitivity. Therefore, all expectations are =>  sensitivity


# 3. Empirical Evaluation Part

## Localization of Transitive Reasoning Mechanisms

In [104]:
# corruption function setup
ALPHABET_LIST = [' A', ' B', ' C', ' D', ' E', ' F', ' G', ' H', ' I', ' J', ' K', ' L', ' M', ' N', ' O', ' P', ' Q', ' R', ' S', ' T', ' U', ' V', ' W', ' X', ' Y', ' Z']
NUMBER_LIST = [' 1', ' 2', ' 3', ' 4', ' 5',' 6',' 7',' 8',' 9']
def middle_term_corrupt(prompts, As, Bs, labels, prompt_type):
    if prompt_type == 'symbolic':
        candidates = ALPHABET_LIST
    elif prompt_type == 'numeric':
        candidates = NUMBER_LIST
    elif prompt_type == 'non-symbolic':
        candidates = bc_dataset.A + bc_dataset.B

    corrupted_prompts = []

    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        corrupted_labels = []
        A = As[i]
        B = Bs[i]

        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.choice(new_list)
        corrupted_prompt = prompt.replace('All' + B , 'All'+ target)
        corrupted_prompts.append(corrupted_prompt)
        corrupted_labels.append(target)

    return corrupted_prompts


def all_term_corrupt(prompts, As, Bs, labels, prompt_type):
    if prompt_type == 'symbolic':
        candidates = ALPHABET_LIST
    elif prompt_type == 'numeric':
        candidates = NUMBER_LIST
    elif prompt_type == 'non-symbolic':
        candidates = bc_dataset.A + bc_dataset.B

    corrupted_prompts = []
    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        A = As[i]
        B = Bs[i]
        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.sample(new_list, 3)
        corrupted_prompt = 'All' + target[0] + ' are' + target[1] + '. All' + target[1] + ' are' + target[2] + '. Therefore, all' + target[0] + ' are'
        corrupted_prompts.append(corrupted_prompt)
    return corrupted_prompts



In [105]:
# answer pair setting (p, m)
s_prompts = s_dataset.sentences
s_labels = s_dataset.labels
s_second_labels = s_dataset.B
s_As = s_dataset.A
s_Bs = s_dataset.B

s_answers = list(zip(s_labels, s_second_labels))
s_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in s_answers
])
for prompt, (p, m) in zip(s_prompts[:6], s_answers):
    print(prompt, p, m)

All O are A. All A are W. Therefore, all O are  W  A
All O are W. All W are A. Therefore, all O are  A  W
All A are O. All O are W. Therefore, all A are  W  O
All A are W. All W are O. Therefore, all A are  O  W
All W are O. All O are A. Therefore, all W are  A  O
All W are A. All A are O. Therefore, all W are  O  A


In [106]:
# Corruption and tokenization
s_corrupted_prompts = middle_term_corrupt(s_prompts, s_As, s_Bs, s_labels, 'symbolic')

s_tokens = model.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens = model.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)

In [107]:
# Compute logits for interventions
s_clean_logits, s_clean_cache = model.run_with_cache(s_tokens)
s_corrupted_logits, s_corrupted_cache = model.run_with_cache(s_corrupted_tokens)

s_clean_logit_diff = h.compute_logit_diff(s_clean_logits, s_answer_tokens)
print(f"Clean logit diff: {s_clean_logit_diff:.4f}")

s_corrupted_logit_diff = h.compute_logit_diff(s_corrupted_logits, s_answer_tokens)
print(f"Corrupted logit diff: {s_corrupted_logit_diff:.4f}")

Clean logit diff: 0.5280
Corrupted logit diff: -0.4590


In [108]:
# for pythia
s_prompts = s_dataset.sentences
s_labels = s_dataset.labels
s_second_labels = s_dataset.B
s_As = s_dataset.A
s_Bs = s_dataset.B

s_answers = list(zip(s_labels, s_second_labels))
s_answer_tokens_2 = t.concat([
    model2.to_tokens(names, prepend_bos=False).T for names in s_answers
])
for prompt, (p, m) in zip(s_prompts[:6], s_answers):
    print(prompt, p, m)
# Corruption and tokenization
s_corrupted_prompts = middle_term_corrupt(s_prompts, s_As, s_Bs, s_labels, 'symbolic')

s_tokens_2 = model2.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens_2 = model2.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)
# Compute logits for interventions
s_clean_logits_2, s_clean_cache_2 = model2.run_with_cache(s_tokens_2)
s_corrupted_logits_2, s_corrupted_cache_2 = model2.run_with_cache(s_corrupted_tokens_2)

s_clean_logit_diff_2 = h.compute_logit_diff(s_clean_logits_2, s_answer_tokens_2)
print(f"Clean logit diff: {s_clean_logit_diff_2:.4f}")

s_corrupted_logit_diff_2 = h.compute_logit_diff(s_corrupted_logits_2, s_answer_tokens_2)
print(f"Corrupted logit diff: {s_corrupted_logit_diff_2:.4f}")

All O are A. All A are W. Therefore, all O are  W  A
All O are W. All W are A. Therefore, all O are  A  W
All A are O. All O are W. Therefore, all A are  W  O
All A are W. All W are O. Therefore, all A are  O  W
All W are O. All O are A. Therefore, all W are  A  O
All W are A. All A are O. Therefore, all W are  O  A
Clean logit diff: 0.8832
Corrupted logit diff: -0.2000


### Figure 4(a) attention head output patching

In [109]:
s_attn_denoising = h.patching_attention(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, 'z', device)

100%|██████████| 24/24 [01:44<00:00,  4.36s/it]


In [110]:
pu.plot_attn(h.normalise_tensor(s_attn_denoising), labels, image_dir / "symbolic_denoising.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_denoising.png


In [111]:
s_attn_denoising_2 = h.patching_attention(model2, s_corrupted_tokens_2, s_clean_cache_2, h.metric_denoising, s_answer_tokens_2, s_clean_logit_diff_2, s_corrupted_logit_diff_2, 'z', device)

100%|██████████| 24/24 [01:32<00:00,  3.86s/it]


In [112]:
pu.plot_attn(h.normalise_tensor(s_attn_denoising_2), labels, image_dir / "symbolic_denoising.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_denoising.png


In [119]:
import torch
import torch.nn.functional as F

# Example tensors (a and b)
# Ensure these tensors have shape (24, 11)
a = s_attn_denoising
b = s_attn_denoising_2

# Normalize each row to create valid probability distributions
# This ensures the comparison respects 2D relationships row by row
a_prob = F.softmax(a, dim=1)  # Normalize rows
b_prob = F.softmax(b, dim=1)  # Normalize rows

# Create a random baseline with the same shape as a and b
random_baseline = torch.rand_like(a_prob)  # Randomly initialized distribution
random_baseline = random_baseline / random_baseline.sum(dim=1, keepdim=True)  # Normalize rows to sum to 1

# KL Divergence function with numerical stability
def kl_divergence_2d(p, q):
    epsilon = 1e-10  # To avoid log(0)
    return (p * (p.log() - (q + epsilon).log())).sum(dim=1)  # Sum over columns for each row

# Compute KL Divergence row by row
kl_a_b_rows = kl_divergence_2d(a_prob, b_prob)  # Divergence between a and b (row-wise)
kl_a_random_rows = kl_divergence_2d(a_prob, random_baseline)  # Divergence of a from random (row-wise)
kl_b_random_rows = kl_divergence_2d(b_prob, random_baseline)  # Divergence of b from random (row-wise)

# Average KL divergences across rows
mean_kl_a_b = kl_a_b_rows.mean().item()
mean_kl_a_random = kl_a_random_rows.mean().item()
mean_kl_b_random = kl_b_random_rows.mean().item()

# Print results
print("Mean KL(a || b):", mean_kl_a_b)
print("Mean KL(a || random):", mean_kl_a_random)
print("Mean KL(b || random):", mean_kl_b_random)

Mean KL(a || b): 0.0014849206199869514
Mean KL(a || random): 0.2657891511917114
Mean KL(b || random): 0.2682020664215088


### Figure 4(b) residual stream patching

In [45]:
s_resid_denoising = h.patching_residual(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, device)

100%|██████████| 24/24 [01:30<00:00,  3.76s/it]


In [49]:
s_resid_denoising_2 = h.patching_residual(model2, s_corrupted_tokens_2, s_clean_cache_2, h.metric_denoising, s_answer_tokens_2, s_clean_logit_diff_2, s_corrupted_logit_diff_2, device)

100%|██████████| 24/24 [01:26<00:00,  3.61s/it]


In [50]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising)), labels, image_dir / "symbolic_residual_gpt2.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_residual_gpt2.png


In [51]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising_2)), labels, image_dir / "symbolic_residual_pythia410m.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_residual_pythia410m.png


In [57]:
a = h.resize_all(s_resid_denoising)
b = h.resize_all(s_resid_denoising_2)

print(a.shape, b.shape)
# torch.Size([24, 11]) torch.Size([24, 11])

torch.Size([24, 11]) torch.Size([24, 11])


In [64]:
s_resid_denoising.shape, s_resid_denoising_2.shape, a.shape, b.shape

(torch.Size([24, 15]),
 torch.Size([24, 15]),
 torch.Size([24, 11]),
 torch.Size([24, 11]))

In [92]:
import torch
import torch.nn.functional as F

# Example tensors (a and b)
# Ensure these tensors have shape (24, 11)
a = h.resize_all(s_resid_denoising)
b = h.resize_all(s_resid_denoising_2)

# Normalize each row to create valid probability distributions
# This ensures the comparison respects 2D relationships row by row
a_prob = F.softmax(a, dim=1)  # Normalize rows
b_prob = F.softmax(b, dim=1)  # Normalize rows

# Create a random baseline with the same shape as a and b
random_baseline = torch.rand_like(a_prob)  # Randomly initialized distribution
random_baseline = random_baseline / random_baseline.sum(dim=1, keepdim=True)  # Normalize rows to sum to 1

# KL Divergence function with numerical stability
def kl_divergence_2d(p, q):
    epsilon = 1e-10  # To avoid log(0)
    return (p * (p.log() - (q + epsilon).log())).sum(dim=1)  # Sum over columns for each row

# Compute KL Divergence row by row
kl_a_b_rows = kl_divergence_2d(a_prob, b_prob)  # Divergence between a and b (row-wise)
kl_a_random_rows = kl_divergence_2d(a_prob, random_baseline)  # Divergence of a from random (row-wise)
kl_b_random_rows = kl_divergence_2d(b_prob, random_baseline)  # Divergence of b from random (row-wise)

# Average KL divergences across rows
mean_kl_a_b = kl_a_b_rows.mean().item()
mean_kl_a_random = kl_a_random_rows.mean().item()
mean_kl_b_random = kl_b_random_rows.mean().item()

# Print results
print("Mean KL(a || b):", mean_kl_a_b)
print("Mean KL(a || random):", mean_kl_a_random)
print("Mean KL(b || random):", mean_kl_b_random)

Mean KL(a || b): 0.00877382978796959
Mean KL(a || random): 0.3386194407939911
Mean KL(b || random): 0.34812092781066895


### Figure 4(c) OV logit lens

In [69]:
token_alpha = model.to_tokens(ALPHABET_LIST, prepend_bos=False).to(device)
token_alpha = [ element[0].item() for element in token_alpha ]

In [70]:
deductive_head_ov = model.OV
ev = model.W_E
uev = model.W_U
ov_circuit = ev.cpu()[token_alpha, : ] @ deductive_head_ov.AB.cpu() @ uev.cpu()[:, token_alpha]

In [86]:
# OV circuit for head 11.10
layer = 11
head = 10
pu.plot_attn(h.normalise_tensor(ov_circuit[layer, head, : , :]), ALPHABET_LIST, image_dir / "symbolic_ov_circuit.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_ov_circuit.png


In [82]:
# for pythia
token_alpha_2 = model2.to_tokens(ALPHABET_LIST, prepend_bos=False).to(device)
token_alpha_2 = [ element[0].item() for element in token_alpha_2 ]
deductive_head_ov_2 = model2.OV
ev = model2.W_E
uev = model2.W_U
ov_circuit_2 = ev.cpu()[token_alpha_2, : ] @ deductive_head_ov_2.AB.cpu() @ uev.cpu()[:, token_alpha_2]
# OV circuit for head 11.10
layer = 11
head = 4
pu.plot_attn(h.normalise_tensor(ov_circuit_2[layer, head, : , :]), ALPHABET_LIST, image_dir / "symbolic_ov_circuit.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_ov_circuit.png


In [91]:
import torch
import torch.nn.functional as F

# Example tensors (a and b)
# Ensure these tensors have shape (24, 11)
a = h.normalise_tensor(ov_circuit[11, 10, :, :])
b = h.normalise_tensor(ov_circuit_2[11, 4, :, :])

# Normalize each row to create valid probability distributions
# This ensures the comparison respects 2D relationships row by row
a_prob = F.softmax(a, dim=1)  # Normalize rows
b_prob = F.softmax(b, dim=1)  # Normalize rows

# Create a random baseline with the same shape as a and b
random_baseline = torch.rand_like(a_prob)  # Randomly initialized distribution
random_baseline = random_baseline / random_baseline.sum(dim=1, keepdim=True)  # Normalize rows to sum to 1

# KL Divergence function with numerical stability
def kl_divergence_2d(p, q):
    epsilon = 1e-10  # To avoid log(0)
    return (p * (p.log() - (q + epsilon).log())).sum(dim=1)  # Sum over columns for each row

# Compute KL Divergence row by row
kl_a_b_rows = kl_divergence_2d(a_prob, b_prob)  # Divergence between a and b (row-wise)
kl_a_random_rows = kl_divergence_2d(a_prob, random_baseline)  # Divergence of a from random (row-wise)
kl_b_random_rows = kl_divergence_2d(b_prob, random_baseline)  # Divergence of b from random (row-wise)

# Average KL divergences across rows
mean_kl_a_b = kl_a_b_rows.mean().item()
mean_kl_a_random = kl_a_random_rows.mean().item()
mean_kl_b_random = kl_b_random_rows.mean().item()

# Print results
print("Mean KL(a || b):", mean_kl_a_b)
print("Mean KL(a || random):", mean_kl_a_random)
print("Mean KL(b || random):", mean_kl_b_random)


Mean KL(a || b): 0.02601531147956848
Mean KL(a || random): 0.3195052146911621
Mean KL(b || random): 0.32373884320259094


## Localization of Term-Related Information Flow

In [93]:
# Corruption and tokenization
s_corrupted_prompts = all_term_corrupt(s_prompts, s_As, s_Bs, s_labels, 'symbolic')

s_tokens = model.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens = model.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)

In [94]:
# Compute logits for interventions
s_clean_logits, s_clean_cache = model.run_with_cache(s_tokens)
s_corrupted_logits, s_corrupted_cache = model.run_with_cache(s_corrupted_tokens)

s_clean_logit_diff = h.compute_logit_diff(s_clean_logits, s_answer_tokens)
print(f"Clean logit diff: {s_clean_logit_diff:.4f}")

s_corrupted_logit_diff = h.compute_logit_diff(s_corrupted_logits, s_answer_tokens)
print(f"Corrupted logit diff: {s_corrupted_logit_diff:.4f}")

Clean logit diff: 0.5280
Corrupted logit diff: 0.0404


In [98]:
# pythia410m
s_tokens = model2.to_tokens(s_prompts, prepend_bos=False).to(device)
s_corrupted_tokens_2 = model2.to_tokens(s_corrupted_prompts, prepend_bos=False).to(device)
# Compute logits for interventions
s_clean_logits_2, s_clean_cache_2 = model2.run_with_cache(s_tokens)
s_corrupted_logits_2, s_corrupted_cache_2 = model2.run_with_cache(s_corrupted_tokens_2)

s_clean_logit_diff_2 = h.compute_logit_diff(s_clean_logits_2, s_answer_tokens_2)
print(f"Clean logit diff: {s_clean_logit_diff_2:.4f}")

s_corrupted_logit_diff_2 = h.compute_logit_diff(s_corrupted_logits_2, s_answer_tokens_2)
print(f"Corrupted logit diff: {s_corrupted_logit_diff_2:.4f}")

Clean logit diff: 0.8832
Corrupted logit diff: 0.1461


### Figure 4(d) residual stream patching

In [96]:
s_resid_denoising2 = h.patching_residual(model, s_corrupted_tokens, s_clean_cache, h.metric_denoising, s_answer_tokens, s_clean_logit_diff, s_corrupted_logit_diff, device)

100%|██████████| 24/24 [01:51<00:00,  4.63s/it]


In [97]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising2)), labels, image_dir / "symbolic_residual2.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_residual2.png


In [99]:
s_resid_denoising2_pythia = h.patching_residual(model2, s_corrupted_tokens_2, s_clean_cache_2, h.metric_denoising, s_answer_tokens_2, s_clean_logit_diff_2, s_corrupted_logit_diff_2, device)

100%|██████████| 24/24 [02:46<00:00,  6.94s/it]


In [100]:
pu.plot_residual(h.normalise_tensor(h.resize_all(s_resid_denoising2_pythia)), labels, image_dir / "symbolic_residual2.png")

Plot saved to ../images/gpt-2_vs_pythia/symbolic_residual2.png


In [101]:
import torch
import torch.nn.functional as F

# Example tensors (a and b)
# Ensure these tensors have shape (24, 11)
a = h.resize_all(s_resid_denoising2)
b = h.resize_all(s_resid_denoising2_pythia)

# Normalize each row to create valid probability distributions
# This ensures the comparison respects 2D relationships row by row
a_prob = F.softmax(a, dim=1)  # Normalize rows
b_prob = F.softmax(b, dim=1)  # Normalize rows

# Create a random baseline with the same shape as a and b
random_baseline = torch.rand_like(a_prob)  # Randomly initialized distribution
random_baseline = random_baseline / random_baseline.sum(dim=1, keepdim=True)  # Normalize rows to sum to 1

# KL Divergence function with numerical stability
def kl_divergence_2d(p, q):
    epsilon = 1e-10  # To avoid log(0)
    return (p * (p.log() - (q + epsilon).log())).sum(dim=1)  # Sum over columns for each row

# Compute KL Divergence row by row
kl_a_b_rows = kl_divergence_2d(a_prob, b_prob)  # Divergence between a and b (row-wise)
kl_a_random_rows = kl_divergence_2d(a_prob, random_baseline)  # Divergence of a from random (row-wise)
kl_b_random_rows = kl_divergence_2d(b_prob, random_baseline)  # Divergence of b from random (row-wise)

# Average KL divergences across rows
mean_kl_a_b = kl_a_b_rows.mean().item()
mean_kl_a_random = kl_a_random_rows.mean().item()
mean_kl_b_random = kl_b_random_rows.mean().item()

# Print results
print("Mean KL(a || b):", mean_kl_a_b)
print("Mean KL(a || random):", mean_kl_a_random)
print("Mean KL(b || random):", mean_kl_b_random)

Mean KL(a || b): 0.22703813016414642
Mean KL(a || random): 2.2060394287109375
Mean KL(b || random): 1.8735147714614868


| Experiment Description                   | Mean KL(GPT-2 || Pythia) | Mean KL(GPT-2 || random) | Mean KL(Pythia || random) |
|------------------------------------------|--------------------------|--------------------------|---------------------------|
| 1. Attention head patching               | 0.0015                   | 0.2567                   | 0.2568                    |
| 2. Residual stream patching (middle-term)| 0.0088                   | 0.3386                   | 0.3481                    |
| 3. OV circuit logit lens                 | 0.0260                   | 0.3195                   | 0.3237                    |
| 4. Residual stream patching (all-term)   | 0.2270                   | 2.2060                   | 1.8735                    |

# 4. Circuit Evaluation

### Figure 5(a) correctness of the circuit

In [72]:
# necessity of symbolic
necessity_score = h.necessity_check(model, s_labels, s_tokens, s_answer_tokens, s_clean_logit_diff, 'mean', device)
sufficiency_score = h.sufficiency_check(model, s_labels, s_tokens, s_answer_tokens, s_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:04<00:00,  2.52it/s]
100%|██████████| 11/11 [00:02<00:00,  4.18it/s]


In [73]:
pu.plot_ablation(s_clean_logit_diff, necessity_score, sufficiency_score, image_dir / "correctness_of_circuit.png")

Plot saved to ../images/pythia-31m/correctness_of_circuit.png


### Figure 5(b) robustness of the circuit

In [99]:
# numeric perturbed dataset
n_dataset = SyllogismDataset(
            N=N/6, # because of permutation
            seed=seed,
            device=device,
            type='numeric',
            template_type='AAA1'
        )
for s, l in zip(n_dataset.sentences[:6], n_dataset.labels[:6]):
    print(f'{s} => {l}')


All 8 are 1. All 1 are 7. Therefore, all 8 are =>  7
All 8 are 7. All 7 are 1. Therefore, all 8 are =>  1
All 1 are 8. All 8 are 7. Therefore, all 1 are =>  7
All 1 are 7. All 7 are 8. Therefore, all 1 are =>  8
All 7 are 8. All 8 are 1. Therefore, all 7 are =>  1
All 7 are 1. All 1 are 8. Therefore, all 7 are =>  8


In [100]:
# Corruption (option)
def perturb_quantifier(prompts, As, Bs, labels):
    candidates = [' Every', ' Every', ' Each', ' Each', ' All']
    corrupted_prompts = []
    for i in range(len(prompts)):
        prompt = prompts[i]
        label = labels[i]
        corrupted_labels = []
        A = As[i]
        B = Bs[i]
        new_list = list(filter(lambda x: x not in [A, B, label], candidates))
        target = random.sample(new_list, 2)
        a_be = ' are' if target[0] == ' All' else ' is'
        b_be = ' are' if target[1] == ' All' else ' is'
        corrupted_prompt = prompt.replace(' All' + B + ' are' , target[0] + B + a_be).replace('All' + A + ' are', target[1][1:] + A + b_be)
        corrupted_prompts.append(corrupted_prompt)
        corrupted_labels.append(target)

    return corrupted_prompts, corrupted_labels


In [101]:
n_prompts = n_dataset.sentences
n_labels = n_dataset.labels
n_second_labels = n_dataset.B

n_answers = list(zip(n_labels, n_second_labels))
n_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in n_answers
])

In [102]:
q_prompts = s_dataset.sentences
q_labels = s_dataset.labels
q_As = s_dataset.A
q_Bs = s_dataset.B

q_corrupted_prompts, _ = perturb_quantifier(q_prompts, q_As, q_Bs, q_labels)

q_answers = list(zip(q_labels, q_Bs))
q_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in q_answers
])

In [103]:
# tokenisation
n_tokens = model.to_tokens(n_prompts, prepend_bos=False).to(device)
q_tokens = model.to_tokens(q_prompts, prepend_bos=False).to(device)

In [104]:
# Compute logits for interventions
n_clean_logits, n_clean_cache = model.run_with_cache(n_tokens)
q_clean_logits, q_clean_cache = model.run_with_cache(n_tokens)

n_clean_logit_diff = h.compute_logit_diff(n_clean_logits, n_answer_tokens)
print(f"Clean logit diff: {n_clean_logit_diff:.4f}")

q_clean_logit_diff = h.compute_logit_diff(q_clean_logits, q_answer_tokens)
print(f"Clean logit diff: {q_clean_logit_diff:.4f}")

Clean logit diff: 0.1938
Clean logit diff: -0.0007


In [105]:
# necessity, sufficiency of numeric perturbation
n_necessity_score = h.necessity_check(model, n_labels, n_tokens, n_answer_tokens, n_clean_logit_diff, 'mean', device)
n_sufficiency_score = h.sufficiency_check(model, n_labels, n_tokens, n_answer_tokens, n_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:07<00:00,  1.42it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


In [106]:
# necessity, sufficiency of quantifier perturbation
q_necessity_score = h.necessity_check(model, q_labels, q_tokens, q_answer_tokens, q_clean_logit_diff, 'mean', device)
q_sufficiency_score = h.sufficiency_check(model, q_labels, q_tokens, q_answer_tokens, q_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
100%|██████████| 11/11 [00:02<00:00,  3.98it/s]


In [107]:
pu.plot_ablation_robust(y1=n_necessity_score, y2=n_sufficiency_score, y3=q_necessity_score, y4 = q_sufficiency_score,  baseline1= n_clean_logit_diff, baseline2=q_clean_logit_diff, save_path=image_dir / "robustness_of_circuit.png", title="")


Plot saved to ../images/cript-medium/robustness_of_circuit.png


# 5. Circuit Transferability

In [108]:
# setup
bc_prompts = bc_dataset.sentences
bc_labels = bc_dataset.labels
bc_second_labels = bc_dataset.B

bc_answers = list(zip(bc_labels, bc_second_labels))
bc_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bc_answers
])

bi_prompts = bi_dataset.sentences
bi_labels = bi_dataset.labels
bi_second_labels = bi_dataset.B

bi_answers = list(zip(bi_labels, bi_second_labels))
bi_answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in bi_answers
])

# tokenisation
bc_tokens = model.to_tokens(bc_prompts, prepend_bos=False).to(device)
bi_tokens = model.to_tokens(bi_prompts, prepend_bos=False).to(device)

In [109]:
# Compute logits for interventions
bc_clean_logits, bc_clean_cache = model.run_with_cache(bc_tokens)
bi_clean_logits, bi_clean_cache = model.run_with_cache(bi_tokens)

bc_clean_logit_diff = h.compute_logit_diff(bc_clean_logits, bc_answer_tokens)
print(f"Clean logit diff: {bc_clean_logit_diff:.4f}")

bi_clean_logit_diff = h.compute_logit_diff(bi_clean_logits, bi_answer_tokens)
print(f"Clean logit diff: {bi_clean_logit_diff:.4f}")

Clean logit diff: 0.5050
Clean logit diff: 0.0146


In [110]:
# belief-consistent
bc_necessity_score = h.necessity_check(model, bc_labels, bc_tokens, bc_answer_tokens, bc_clean_logit_diff, 'mean', device)
bc_sufficiency_score = h.sufficiency_check(model, bc_labels, bc_tokens, bc_answer_tokens, bc_clean_logit_diff, 'mean', device)

  0%|          | 0/11 [00:00<?, ?it/s]

100%|██████████| 11/11 [00:07<00:00,  1.54it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


In [111]:
# belief-inconsistent
bi_necessity_score = h.necessity_check(model, bi_labels, bi_tokens, bi_answer_tokens, bi_clean_logit_diff, 'mean', device)
bi_sufficiency_score = h.sufficiency_check(model, bi_labels, bi_tokens, bi_answer_tokens, bi_clean_logit_diff, 'mean', device)

100%|██████████| 11/11 [00:02<00:00,  4.09it/s]
100%|██████████| 11/11 [00:02<00:00,  4.09it/s]


### Figure 6(a) belief-consistent

In [112]:
pu.plot_ablation(bc_clean_logit_diff, bc_necessity_score, bc_sufficiency_score, image_dir / "correctness_of_circuit_bc.png")

Plot saved to ../images/cript-medium/correctness_of_circuit_bc.png


### Figure 6(b) belief-inconsistent

In [113]:
pu.plot_ablation(bi_clean_logit_diff, bi_necessity_score, bi_sufficiency_score, image_dir / "correctness_of_circuit_bi.png")

Plot saved to ../images/cript-medium/correctness_of_circuit_bi.png


### Table 2 all unconditionally valid syllogisms

In [114]:
# ordered by accuracy
moods_ordered = ['AII3', 'IAI3', 'IAI4', 'AAA1', 'EAE1', 'EIO4', 'EIO3', 'AII1', 'AOO2', 'AEE4', 'OAO3', 'EIO1', 'EIO2', 'EAE2', 'AEE2']

In [115]:
# all symbolic datasets import
s_datasets = {}

for template in moods_ordered:
    s_dataset = SyllogismDataset(
            N=N,
            seed=seed,
            device=device,
            type='symbolic',
            template_type=template
        )
    s_datasets[template] = s_dataset


In [116]:
# setup
mood_dics = []

for template in moods_ordered:
    mood_dic = {
        'mood': template,
        'prompts': s_datasets[template].sentences,
        'labels': s_datasets[template].labels,
        'second_labels': s_datasets[template].B

    }
    s_answers = list(zip(mood_dic['labels'], mood_dic['second_labels']))
    mood_dic['answers'] = s_answers
    s_answer_tokens = t.concat([
        model.to_tokens(names, prepend_bos=False).T for names in s_answers
    ])
    mood_dic['answer_tokens'] = s_answer_tokens

    mood_dics.append(mood_dic)

In [117]:
# get base logit diff for 15 syllogisms
for i, template in enumerate(moods_ordered):
    tokens = model.to_tokens(mood_dics[i]['prompts'], prepend_bos=False).to(device)
    clean_logit_diff = h.get_batched_logit_diff(5, tokens, mood_dics[i]['answer_tokens'], model)
    mood_dics[i]['tokens'] = tokens
    mood_dics[i]['clean_logit_diff'] = clean_logit_diff

    bia_diff = 0
    for correct, wrong in mood_dics[i]['answer_tokens']:
        diff = model.unembed.b_U[correct.item()] - model.unembed.b_U[wrong.item()]
        bia_diff += diff

    bia_diff = bia_diff/len(mood_dics[i]['prompts'])
    mood_dics[i]['bia_diff'] = bia_diff
    print(mood_dics[i]['clean_logit_diff'], mood_dics[i]['bia_diff'])

tensor(0.9720, device='mps:0') tensor(0., device='mps:0')
tensor(0.4753, device='mps:0') tensor(0., device='mps:0')
tensor(0.4753, device='mps:0') tensor(0., device='mps:0')
tensor(0.3170, device='mps:0') tensor(0., device='mps:0')
tensor(0.1401, device='mps:0') tensor(0., device='mps:0')
tensor(0.4769, device='mps:0') tensor(0., device='mps:0')
tensor(0.4769, device='mps:0') tensor(0., device='mps:0')
tensor(-0.1302, device='mps:0') tensor(0., device='mps:0')
tensor(-1.1784, device='mps:0') tensor(0., device='mps:0')
tensor(-0.5424, device='mps:0') tensor(0., device='mps:0')
tensor(-0.4168, device='mps:0') tensor(0., device='mps:0')
tensor(-0.9202, device='mps:0') tensor(0., device='mps:0')
tensor(-1.4959, device='mps:0') tensor(0., device='mps:0')
tensor(-1.5060, device='mps:0') tensor(0., device='mps:0')
tensor(-1.7159, device='mps:0') tensor(0., device='mps:0')


In [119]:
# ablation
for i, template in enumerate(moods_ordered):
    mean_scores = h.necessity_check(model, mood_dics[i]['labels'], mood_dics[i]['tokens'], mood_dics[i]['answer_tokens'], mood_dics[i]['clean_logit_diff'], 'mean', device)
    sf_mean_scores = h.sufficiency_check(model, mood_dics[i]['labels'], mood_dics[i]['tokens'], mood_dics[i]['answer_tokens'], mood_dics[i]['clean_logit_diff'], 'mean',device)
    mood_dics[i]['mean_scores'] = mean_scores
    mood_dics[i]['sf_mean_score'] = sf_mean_scores

100%|██████████| 11/11 [00:31<00:00,  2.82s/it]
100%|██████████| 11/11 [00:11<00:00,  1.06s/it]
100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 11/11 [00:11<00:00,  1.08s/it]
100%|██████████| 11/11 [00:11<00:00,  1.03s/it]
100%|██████████| 11/11 [00:12<00:00,  1.13s/it]
100%|██████████| 11/11 [00:11<00:00,  1.01s/it]
100%|██████████| 11/11 [00:11<00:00,  1.08s/it]
100%|██████████| 11/11 [00:11<00:00,  1.02s/it]
100%|██████████| 11/11 [00:11<00:00,  1.07s/it]
100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
100%|██████████| 11/11 [00:12<00:00,  1.11s/it]
100%|██████████| 11/11 [00:18<00:00,  1.65s/it]
100%|██████████| 11/11 [00:20<00:00,  1.89s/it]
100%|██████████| 11/11 [00:13<00:00,  1.24s/it]
100%|██████████| 11/11 [00:11<00:00,  1.09s/it]
100%|██████████| 11/11 [00:11<00:00,  1.04s/it]
100%|██████████| 11/11 [00:14<00:00,  1.28s/it]
100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 11/11 [00:12<00:00,  1.10s/it]
100%|██████████| 11/11 [00:11<00:00,  1.

In [120]:
pu.plot_ablation_syllogisms(mood_dics = mood_dics, save_path=image_dir / "syllogism_ablation.png", title="",)

Plot saved to ../images/cript-medium/syllogism_ablation.png
