# Setup
Libraries + **MODEL** (EleutherAI/pythia-70m-deduped)

In [34]:
# Time
import time
start_time = time.time()

In [30]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
import gc

DEBUGGING = False

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

# model hyperparameters
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512

## Dataset
+ Ambiguous/balanced set:  
    + The ambiguous set, consisting of bios of male professors (labeled 0) and female nurses (labeled 1).  
    + The balanced set, consisting of an equal number of bios for male professors, male nurses, female professors, and female nurses.  
  
+ Train/test  
  
+ Intended label [profession] (0) / Unintended label [gender] (1)


In [31]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
profession_dict = {'professor' : 21, 'nurse' : 13}
male_prof = 'professor'
female_prof = 'nurse'

# data preparation hyperparameters
batch_size = 1024
SEED = 42

# To fit on 24GB VRAM GPU, I set the next 2 default batch_sizes to 64
def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """
    Loads and processes the dataset for training or testing.

    Parameters:
    train (bool): If True, loads the training set; otherwise, loads the test set. Default is True.
    ambiguous (bool): If True, loads the ambiguous set; otherwise, loads the balanced set. Default is True.
    batch_size (int): The size of each batch. Default is 128.
    seed (int): The random seed for shuffling the data. Default is SEED.

    Returns:
    list of tuples: A list of batches, where each batch is a tuple containing:
        - data (list): A list of text data (bios).
        - true_labels (torch.Tensor): A tensor of intended labels (profession).
        - spurious_labels (torch.Tensor): A tensor of unintended labels (gender).
    """
    #• The ambiguous set, consisting of bios of male professors (labeled 0) and female nurses (labeled 1).
    #• The balanced set, consisting of an equal number of bios for male professors, male nurses, female professors, and female nurses.
    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg), len(pos)])
        neg, pos = neg[:n], pos[:n]
        data = neg + pos
        labels = [0]*n + [1]*n
        idxs = list(range(2*n))
        random.Random(seed).shuffle(idxs)
        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]
        true_labels = spurious_labels = labels
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])
        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]
        data = neg_neg + neg_pos + pos_neg + pos_pos
        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n
        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n
        idxs = list(range(4*n))
        random.Random(seed).shuffle(idxs)
        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]

    batches = [
        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """
    Generates subgroups of text data based on gender and profession labels.
    Parameters:
    train (bool): If True, use training data; otherwise, use test data. Default is True.
    ambiguous (bool): If True, create two subgroups (ambiguous); otherwise, create four subgroups (non-ambiguous). Default is True.
    batch_size (int): The size of each batch of data. Default is 128.
    seed (int): Random seed for reproducibility.
    Returns:
    dict: A dictionary where keys are label profiles (tuples) and values are lists of batches. Each batch is a tuple containing:
        - A list of text data.
        - A tensor of the first label repeated for the batch size.
        - A tensor of the second label repeated for the batch size.
    """

    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_labels, pos_labels = (0, 0), (1, 1)
        subgroups = [(neg, neg_labels), (pos, pos_labels)]
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)
        subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]
    
    out = {}
    for data, label_profile in subgroups:
        out[label_profile] = []
        for i in range(0, len(data), batch_size):
            text = data[i:i+batch_size]
            out[label_profile].append(
                (
                    text,
                    t.tensor([label_profile[0]]*len(text), device=DEVICE),
                    t.tensor([label_profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

In [32]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
        

## Probe - Linear classifier
Train and Test  
+ get activation form the LLM model

In [33]:
# probe training hyperparameters

layer = 4 # model layer for attaching linear classification head (second-last layer)

class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits

def train_probe(get_acts, label_idx=0, batches=get_data(), lr=1e-2, epochs=1, dim=512, seed=SEED):
    t.manual_seed(seed)
    probe = Probe(dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for epoch in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1] 
            acts = get_acts(text)
            logits = probe(acts)
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=get_data(train=False), seed=SEED):
    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())
        return t.cat(corrects).mean().item()
    
def get_acts(text):
    with t.no_grad(): 
        with model.trace(text, **tracer_kwargs):
            attn_mask = model.input[1]['attention_mask']
            acts = model.gpt_neox.layers[layer].output[0]
            acts = acts * attn_mask[:, :, None]
            acts = acts.sum(1) / attn_mask.sum(1)[:, None]
            acts = acts.save()
        return acts.value

In [34]:
layer

4

In [36]:
bb = get_data()

In [38]:
act_test = get_acts(bb[0][0])

In [40]:
act_test.shape

torch.Size([128, 512])

In [41]:
def get_acts_test(text):
    with t.no_grad(): 
        with model.trace(text, **tracer_kwargs):
            #attn_mask = model.input[1]['attention_mask']
            print('ciao')
            acts = model.gpt_neox.layers[layer].output[0]
            #acts = acts * attn_mask[:, :, None]
            #acts = acts.sum(1) / attn_mask.sum(1)[:, None]
            acts = acts.save()
        return acts.value
    

get_acts_test(bb[0][0])

ciao


tensor([[[ 3.4385e-01, -1.3381e-01, -1.4041e-01,  ..., -1.0235e-02,
           1.2699e+00, -2.4859e-01],
         [ 3.4698e-01, -1.0725e-01, -6.2551e-02,  ...,  2.1407e-02,
           1.2962e+00, -2.9423e-01],
         [ 3.6668e-01, -1.2669e-01, -1.2748e-01,  ...,  8.6772e-02,
           1.2887e+00, -3.3930e-01],
         ...,
         [ 2.9722e-03, -4.7632e-01, -3.1562e-01,  ...,  8.1343e-01,
           1.5681e+00, -4.6374e-01],
         [ 2.2903e-01, -8.1281e-01, -2.3526e-01,  ...,  2.1779e-01,
           4.3966e-01, -7.1307e-02],
         [ 5.2845e-01, -3.1384e-01,  1.1301e-01,  ...,  5.5013e-01,
           1.0357e+00,  2.1852e-01]],

        [[ 3.9677e-01, -2.4028e-01, -3.9316e-01,  ...,  2.0073e-01,
           1.0912e+00, -7.6334e-01],
         [ 4.7017e-01, -2.2216e-01, -3.9180e-01,  ...,  1.7163e-01,
           1.0646e+00, -7.6113e-01],
         [ 4.5083e-01, -2.2815e-01, -3.4973e-01,  ...,  1.8556e-01,
           9.8221e-01, -7.8530e-01],
         ...,
         [ 4.7234e-02, -6

In [None]:
def get_acts_test(text):
    with t.no_grad(): 
        with model.trace(text, **tracer_kwargs):
            #attn_mask = model.input[1]['attention_mask']
            acts = model.gpt_neox.layers[layer].output[0]
            print(type(acts))
            #acts = acts * attn_mask[:, :, None]
            #acts = acts.sum(1) / attn_mask.sum(1)[:, None]
            acts = acts.save()
        return acts.value
    
ttt = get_acts_test(bb[0][0])

<class 'nnsight.models.LanguageModel.LanguageModelProxy'>


In [65]:
ttt.shape

torch.Size([128, 212, 512])

### Oracle  
A classifier trained on ground-truth labels on the <u>**balanced**</u> set  
Intended (ground-truth) labels [profession] (label_idx=0) / Unintended label [gender] (label_idx=1)

In [42]:
oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False)) # train on balanced training set
print("ambiguous test accuracy", test_probe(oracle, get_acts, label_idx=0)) # test on ambiguous test set
batches = get_data(train=False, ambiguous=False) # balanced test set
print("ground truth accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=0))
print("unintended feature accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=1))

ambiguous test accuracy 0.9318076372146606
ground truth accuracy: 0.9302995204925537
unintended feature accuracy: 0.49366357922554016


In [43]:
# get worst-group accuracy of oracle probe
subgroups = get_subgroups(train=False, ambiguous=False)
# label_profile: (profession (professor/nurse), gender (male/female)) [but I don't know in what order]
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(oracle, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9468682408332825
Accuracy for (0, 1): 0.926398754119873
Accuracy for (1, 0): 0.9239631295204163
Accuracy for (1, 1): 0.9189126491546631


### Probe
A classifier trained on ground-truth labels on the <u>**ambiguous**</u> set  
Intended (ground-truth) labels [profession] (label_idx=0) / Unintended label [gender] (label_idx=1)

In [44]:
probe, _ = train_probe(get_acts, label_idx=0)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9955855011940002
Ground truth accuracy: 0.6186636090278625
Unintended feature accuracy: 0.8744239807128906


In [45]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9978401064872742
Accuracy for (0, 1): 0.24355988204479218
Accuracy for (1, 0): 0.2626728117465973
Accuracy for (1, 1): 0.9934943914413452


## Dictionary - SAEs + metric_fn
<img src="../images/LLM_schema.png" alt="LLM_schema" width="200"/>

In [46]:
print("Num of components, that is num of SAEs:", 1+(layer + 1)*3)


Num of components, that is num of SAEs: 16


In [47]:
# loading dictionaries

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim

submodules = []
dictionaries = {}

submodules.append(model.gpt_neox.embed_in)
dictionaries[model.gpt_neox.embed_in] = AutoEncoder.from_pretrained(
    f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_{dictionary_size}/ae.pt',
    device=DEVICE
)
for i in range(layer + 1):
    submodules.append(model.gpt_neox.layers[i].attention)
    dictionaries[model.gpt_neox.layers[i].attention] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

    submodules.append(model.gpt_neox.layers[i].mlp)
    dictionaries[model.gpt_neox.layers[i].mlp] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

    submodules.append(model.gpt_neox.layers[i])
    dictionaries[model.gpt_neox.layers[i]] = AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )

#probe is the linear classifier trained on ambiguous training set
def metric_fn(model, labels=None):
    attn_mask = model.input[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]
    
    return t.where(
        labels == 0,
        probe(acts),
        - probe(acts)
    )

## Find most influential features
Patching effect  
ig = integrated gradient  
  
+ <u>Attribution patching</u>: activation patching at industrial scale, it uses gradients to take linear approximation to activation patching
+ <u>**Integrated gradients**</u>: a more expensive but more accurate approximation (applicable since we use small models)

In [48]:
# find most influential features
n_batches = 25
batch_size = 4

running_total = 0
nodes = None

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

100%|██████████| 25/25 [01:23<00:00,  3.35s/it]


### Threshold
We apply some choice of node
threshold TN to select nodes with a large (absolute) IE.  
To keep the number of nodes we need to annotate manageable, we set a relatively high
node threshold of 0.1

In [49]:
n_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    for idx in (effect > 0.1).nonzero():
        print(idx.item(), effect[idx].item())
        n_features += 1
print(f"total features: {n_features}")

Component 0:
946 0.22117570042610168
5719 0.14886397123336792
7392 0.3218412399291992
10784 0.15049311518669128
17846 0.31437844038009644
22068 0.1622835248708725
23079 0.15247352421283722
25904 0.10423339903354645
28533 0.18998156487941742
29476 0.19172973930835724
31461 0.16866040229797363
31467 0.16690056025981903
32081 0.32751068472862244
32469 1.3901286125183105
Component 1:
23752 0.1079249158501625
Component 2:
2995 0.10343753546476364
3842 0.14583298563957214
10258 0.29746800661087036
13387 0.14572963118553162
13968 0.12934662401676178
18382 0.2589782476425171
19369 0.18766063451766968
28127 1.0541187524795532
30518 0.19712230563163757
Component 3:
1022 0.31592071056365967
9651 0.5887966156005859
10060 2.43733549118042
18967 0.6829056143760681
22084 0.2536494731903076
23898 0.4724956154823303
24799 0.10279475897550583
26504 0.30445045232772827
29626 0.2844652533531189
31201 0.17182287573814392
Component 4:
8147 0.10606715828180313
Component 5:
24159 0.24803251028060913
25018 0.4

## MANUALLY Examine most influential features  
Manually inspect and evaluate for task-relevancy each feature in the circuit from
previous Step. For each feature, examine the text data that activates the feature the most

In [50]:
# interpret features

# change the following two lines to pick which feature to interpret
component_idx = 9
feat_idx = 31098

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
data = hf_dataset_to_generator("monology/pile-uncopyrighted")
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=512,
    refresh_batch_size=128, # decrease to fit on smaller GPUs
    n_ctxs=512, # decrease to fit on smaller GPUs
    device=DEVICE
)

out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=256 # decrease to fit on smaller GPUs
)
print(out.top_tokens)
print(out.top_affected)
out.top_contexts

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



[(' nursing', 5.101802825927734), (' Nursing', 3.9356651306152344), (' nurse', 2.8416316509246826), (' nurses', 2.8025169372558594), (' RN', 1.5293744802474976), (' Teaching', 0.7016171216964722), (' rehabilitation', 0.6920690536499023), ('wife', 0.6749840974807739), ('unte', 0.6074658632278442), (' sewing', 0.6034237146377563), (' caring', 0.4292561411857605), ('ancy', 0.3882577419281006), ('akers', 0.38148069381713867), (' lending', 0.34468239545822144), (' volunteers', 0.3259945511817932), (' drinking', 0.2667090892791748), (' Clinical', 0.26244115829467773), (' inpatient', 0.23919767141342163), (' architect', 0.22628939151763916), (' Medical', 0.21756377816200256), (' relational', 0.19916176795959473), (' Dou', 0.19640305638313293), (' dialysis', 0.1761435568332672), (' Leadership', 0.1721816062927246), ('hin', 0.14378538727760315), ('okin', 0.1413027048110962), (' executive', 0.11437027156352997), (' teaching', 0.11418899148702621), (' poet', 0.10126781463623047), (' hospitals', 0

In [51]:
feats_to_ablate = {
    submodules[0] : [
        946, # 'his'
        # 5719, # 'research'
        7392, # 'He'
        # 10784, # 'Nursing'
        17846, # 'He'
        22068, # 'His'
        # 23079, # 'tastes'
        # 25904, # 'nursing'
        28533, # 'She'
        29476, # 'he'
        31461, # 'His'
        31467, # 'she'
        32081, # 'her'
        32469, # 'She'
    ],
    submodules[1] : [
        # 23752, # capitalized words, especially pronouns
    ],
    submodules[2] : [
        2995, # 'he'
        3842, # 'She'
        10258, # female names
        13387, # 'she'
        13968, # 'He'
        18382, # 'her'
        19369, # 'His'
        28127, # 'She'
        30518, # 'He'
    ],
    submodules[3] : [
        1022, # 'she'
        9651, # female names
        10060, # 'She'
        18967, # 'He'
        22084, # 'he'
        23898, # 'His'
        # 24799, # promotes surnames
        26504, # 'her'
        29626, # 'his'
        # 31201, # 'nursing'
    ],
    submodules[4] : [
        # 8147, # unclear, something with names
    ],
    submodules[5] : [
        24159, # 'She', 'she'
        25018, # female names
    ],
    submodules[6] : [
        4592, # 'her'
        8920, # 'he'
        9877, # female names
        12128, # 'his'
        15017, # 'she'
        # 17369, # contact info
        # 26969, # related to nursing
        30248, # female names
    ],
    submodules[7] : [
        13570, # promotes male-related words
        27472, # female names, promotes female-related words
    ],
    submodules[8] : [
    ],
    submodules[9] : [
        1995, # promotes female-associated words
        9128, # feminine pronouns
        11656, # promotes male-associated words
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        29206, # gendered pronouns
        29295, # female names
        # 31098, # nursing-related words
    ],
    submodules[10] : [
        2959, # promotes female-associated words
        19128, # promotes male-associated words
        22029, # promotes female-associated words
    ],
    submodules[11] : [
    ],
    submodules[12] : [
        19558, # promotes female-associated words
        23545, # 'she'
        24806, # 'her'
        27334, # promotes male-associated words
        31453, # female names
    ],
    submodules[13] : [
        31101, # promotes female-associated words
    ],
    submodules[14] : [
    ],
    submodules[15] : [
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        30220, # promotes male pronouns
    ]
}

print(f"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}")

Number of features to ablate: 55


In [52]:
# putting feats_to_ablate in a more useful format
def n_hot(feats, dim=dictionary_size):
    out = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for feat in feats:
        out[feat] = True
    return out

feats_to_ablate = {
    submodule : n_hot(feats) for submodule, feats in feats_to_ablate.items()
}


In [53]:
# utilities for ablating features
is_tuple = {}
with t.no_grad(), model.trace("_"):
    for submodule in submodules:
        is_tuple[submodule] = type(submodule.output.shape) == tuple

def get_acts_ablated(
    text,
    model,
    submodules,
    dictionaries,
    to_ablate
):

    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            feat_idxs = to_ablate[submodule]
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            res = x - x_hat
            f[...,feat_idxs] = 0. # zero ablation
            if is_tuple[submodule]:
                submodule.output[0][:] = dictionary.decode(f) + res
            else:
                submodule.output = dictionary.decode(f) + res
        attn_mask = model.input[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value


# Accuracy after ablating features judged irrelevant by human annotators

In [54]:
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.926347553730011
Ground truth accuracy: 0.8853686451911926
Spurious accuracy: 0.5397465229034424


For comparison  
Probe results:  
Ambiguous test accuracy: 0.9955855011940002  
Ground truth accuracy: 0.6186636090278625  
Unintended feature accuracy: 0.8744239807128906 

After ablating features:
Ambiguous test accuracy: 0.926347553730011
Ground truth accuracy: 0.8853686451911926
Spurious accuracy: 0.5397465229034424


In [55]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9675408601760864
Accuracy for (0, 1): 0.9309049844741821
Accuracy for (1, 0): 0.7603686451911926
Accuracy for (1, 1): 0.8849906921386719


# Concept bottleneck probing baseline
from paper:  
Concept Bottleneck Probing (CBP), adapted from Yan et al. (2023) (originally for
multimodal text/image models). CBP works by training a probe to classify inputs
x given access only to a vector of affinities between the LM’s representation of x
and various concept vectors. See App. E.2 for implementation details.

In [56]:
concepts = [    
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors
with t.no_grad(), model.trace(concepts):
    concept_vectors = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)

def get_bottleneck(text):
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        attn_mask = model.input[1]['attention_mask']
        acts = model.gpt_neox.layers[layer].output[0]
        acts = acts * attn_mask[:, :, None]
        acts = acts.sum(1) / attn_mask.sum(1)[:, None]
        # compute cosine similarity with concept vectors
        sims = (acts @ concept_vectors.T) / (acts.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None])
        sims = sims.save()
    return sims.value

In [57]:
cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=1))


Ground truth accuracy: 0.8335253596305847
Unintended feature accuracy: 0.600806474685669


In [58]:
# get subgroup accuracies
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9276148676872253
Accuracy for (0, 1): 0.7864062786102295
Accuracy for (1, 0): 0.6774193644523621
Accuracy for (1, 1): 0.9409851431846619


# Get skyline neuron performance
The same as the feature skyline (below) but on neuron

In [59]:
# get neurons which are most influential for giving gender label
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 25
batch_size = 4

running_total = 0
nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

100%|██████████| 25/25 [01:10<00:00,  2.82s/it]


In [60]:
neurons_to_ablate = {}
total_neurons = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    neurons_to_ablate[submodules[component_idx]] = []
    for idx in (effect.act > 0.2135).nonzero():
        print(idx.item(), effect[idx].item())
        neurons_to_ablate[submodules[component_idx]].append(idx.item())
        total_neurons += 1
print(f"total neurons: {total_neurons}")

neurons_to_ablate = {
    submodule : n_hot([neuron_idx], dim=512) for submodule, neuron_idx in neurons_to_ablate.items()
}

Component 0:
2 0.26396122574806213
42 0.23957575857639313
57 0.2323196530342102
99 0.32703617215156555
130 0.25990810990333557
135 0.5472005605697632
187 0.3968943655490875
197 0.3398488163948059
256 0.25377556681632996
335 0.29094165563583374
343 0.2429693192243576
400 0.2538072168827057
417 0.23600360751152039
421 0.24806241691112518
Component 1:
111 0.23795640468597412
156 0.29115042090415955
Component 2:
23 0.32223719358444214
156 0.39989417791366577
Component 3:
23 0.915780246257782
66 0.24019086360931396
111 0.3680082857608795
156 1.3047045469284058
162 0.2558390498161316
193 0.25004395842552185
209 0.22625890374183655
271 0.4073232412338257
334 0.23126104474067688
378 0.2749529480934143
394 0.25855162739753723
410 0.24639733135700226
473 0.2546446621417999
503 0.30264389514923096
Component 4:
Component 5:
Component 6:
14 0.3215298354625702
56 0.3693719506263733
89 0.2253105342388153
98 0.4586580991744995
111 1.0191479921340942
271 0.37807178497314453
369 0.3085884153842926
410 0

In [61]:
def get_acts_abl(text):
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            x[...,neurons_to_ablate[submodule]] = x.mean(dim=(0,1))[...,neurons_to_ablate[submodule]] # mean ablation
            if is_tuple[submodule]:
                submodule.output[0][:] = x
            else:
                submodule.output = x

        attn_mask = model.input[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value

In [63]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

W0206 15:24:25.609000 5457 torch/fx/experimental/symbolic_shapes.py:6307] failed during evaluate_expr(Eq(4*u0, 0), hint=None, size_oblivious=False, forcing_spec=False
E0206 15:24:25.611000 5457 torch/fx/experimental/recording.py:299] failed while running evaluate_expr(*(Eq(4*u0, 0), None), **{'fx_node': False})


GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(4*u0, 0) (unhinted: Eq(4*u0, 0)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Caused by: (utils/_stats.py:21 in wrapper)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

In [27]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9891391396522522
Accuracy for (0, 1): 0.3658280074596405
Accuracy for (1, 0): 0.5115207433700562
Accuracy for (1, 1): 0.9937267303466797


# Get skyline feature performance
from paper: Instead of relying on human judgement to evaluate whether a
feature should be ablated, we zero-ablate the 55 features from our circuit that are
most causally implicated in spurious feature accuracy on the balanced set.

In [28]:
# get features which are most useful for predicting gender label
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [01:04<00:00,  2.58s/it]


In [29]:
top_feats_to_ablate = {}
total_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    top_feats_to_ablate[submodules[component_idx]] = []
    for idx in (effect > 0.1107).nonzero():
        print(idx.item(), effect[idx].item())
        top_feats_to_ablate[submodules[component_idx]].append(idx.item())
        total_features += 1
print(f"total features: {total_features}")

Component 0:
946 0.25006115436553955
7392 0.6816262602806091
17846 0.2601848244667053
28533 0.34427016973495483
29476 0.14480449259281158
31467 0.3418383300304413
32081 0.1740640252828598
32469 1.2141445875167847
Component 1:
Component 2:
3842 0.1846996247768402
10258 0.2082471251487732
13387 0.2748056948184967
18382 0.11551540344953537
19369 0.11571113020181656
28127 0.9595451354980469
30518 0.20731867849826813
31645 0.2460571974515915
Component 3:
1022 0.4973563551902771
3122 0.20655781030654907
9651 0.4463863968849182
10060 2.1757450103759766
18967 0.9766783118247986
22084 0.16898095607757568
23898 0.27730074524879456
26504 0.1362760365009308
29626 0.3603300452232361
Component 4:
Component 5:
24159 0.30897560715675354
25018 0.3229933977127075
Component 6:
4592 0.2170836329460144
8920 0.5713692307472229
9877 0.2788275182247162
12128 0.5513262152671814
12436 0.20832858979701996
15017 3.22841739654541
26204 0.12446071207523346
30248 0.6778391599655151
Component 7:
13570 0.1108106672763

In [30]:
top_feats_to_ablate = {
    submodule : n_hot(feats) for submodule, feats in top_feats_to_ablate.items()
}
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, top_feats_to_ablate)

In [31]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9312267303466797
Ground truth accuracy: 0.8894008994102478
Spurious accuracy: 0.539170503616333


In [32]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9735883474349976
Accuracy for (0, 1): 0.9285767674446106
Accuracy for (1, 0): 0.7718893885612488
Accuracy for (1, 1): 0.8894051909446716


# Retraining probe on activations after ablating features

In [33]:
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

new_probe, _ = train_probe(get_acts_abl, label_idx=0)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9572490453720093
Ground truth accuracy: 0.9308755993843079
Unintended feature accuracy: 0.5195852518081665


For comparison  
Probe results:  
Ambiguous test accuracy: 0.9955855011940002  
Ground truth accuracy: 0.6186636090278625  
Unintended feature accuracy: 0.8744239807128906  

After ablating features:  
Ambiguous test accuracy: 0.926347553730011  
Ground truth accuracy: 0.8853686451911926  
Spurious accuracy: 0.5397465229034424  

In [34]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9477938413619995
Accuracy for (0, 1): 0.8895981907844543
Accuracy for (1, 0): 0.9423962831497192
Accuracy for (1, 1): 0.9679368138313293


In [38]:
# Time

end_time = time.time()
execution_time = end_time - start_time

# Convert the execution time to hours, minutes, and seconds
hours, remainder = divmod(execution_time, 3600)
minutes, seconds = divmod(remainder, 60)

print(f"Total execution time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")

Total execution time: 0 hours, 4 minutes, 31.49 seconds
