In [1]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
import gc

DEBUGGING = False

# nnsight tracer options: scan/validate are switched on only while DEBUGGING
tracer_kwargs = dict(scan=bool(DEBUGGING), validate=bool(DEBUGGING))

# model hyperparameters
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512  # width of the traced activations (probe input, IdentityDict size)

In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
# integer codes of the dataset's 'profession' column for the two professions compared
profession_dict = dict(professor=21, nurse=13)
# In the ambiguous split, male_prof is paired with gender 0 and female_prof with gender 1.
male_prof, female_prof = 'professor', 'nurse'

# data preparation hyperparameters
batch_size = 1024
SEED = 42

# To fit on a 24 GB VRAM GPU, the default batch_size of the next two functions (get_data, get_subgroups) is set to 64
def get_data(train=True, ambiguous=True, batch_size=64, seed=SEED):
    """Build shuffled, class-balanced batches of Bias in Bios biographies.

    Each kept example carries two binary labels:
      * true label     -- 0 for `male_prof`, 1 for `female_prof` (the profession);
      * spurious label -- the dataset's `gender` field (0 or 1).

    Args:
        train: read the 'train' split if True, else the 'test' split.
        ambiguous: if True keep only (male_prof, gender 0) and (female_prof, gender 1),
            so both labels always agree; if False keep all four (profession, gender)
            subgroups.
        batch_size: texts per batch (the final batch may be shorter).
        seed: seed for the deterministic shuffle.

    Returns:
        list of (texts, true_labels, spurious_labels) tuples; the label tensors are
        integer tensors on DEVICE. Every kept subgroup is truncated to the size of the
        smallest kept subgroup, so the result is balanced.
    """
    split = dataset['train'] if train else dataset['test']
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    # Bucket texts by (profession label, gender) in ONE pass over the split instead of
    # re-scanning the whole split for every subgroup. Order of appearance is preserved.
    groups = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    for x in split:
        gender = x['gender']
        if gender != 0 and gender != 1:
            continue
        if x['profession'] == male_id:
            groups[(0, gender)].append(x['hard_text'])
        if x['profession'] == female_id:
            groups[(1, gender)].append(x['hard_text'])

    if ambiguous:
        keys = [(0, 0), (1, 1)]
    else:
        keys = [(0, 0), (0, 1), (1, 0), (1, 1)]
    n = min(len(groups[key]) for key in keys)

    data, true_labels, spurious_labels = [], [], []
    for prof, gender in keys:
        data += groups[(prof, gender)][:n]
        true_labels += [prof] * n
        spurious_labels += [gender] * n

    # Deterministic shuffle that keeps texts and both label lists aligned.
    idxs = list(range(len(data)))
    random.Random(seed).shuffle(idxs)
    data = [data[i] for i in idxs]
    true_labels = [true_labels[i] for i in idxs]
    spurious_labels = [spurious_labels[i] for i in idxs]

    return [
        (
            data[i:i + batch_size],
            t.tensor(true_labels[i:i + batch_size], device=DEVICE),
            t.tensor(spurious_labels[i:i + batch_size], device=DEVICE),
        )
        for i in range(0, len(data), batch_size)
    ]

def get_subgroups(train=True, ambiguous=True, batch_size=64, seed=SEED):
    """Batch each (profession, gender) subgroup separately, without balancing or shuffling.

    Args:
        train: read the 'train' split if True, else the 'test' split.
        ambiguous: if True return only the (0, 0) and (1, 1) subgroups; otherwise all four.
        batch_size: texts per batch (the final batch of a subgroup may be shorter).
        seed: unused; accepted for signature parity with get_data.

    Returns:
        dict mapping (true_label, spurious_label) -> list of
        (texts, true_labels, spurious_labels) batches. true_label is 0 for `male_prof`
        and 1 for `female_prof`; spurious_label is the dataset's `gender`. Keys appear in
        the order (0, 0), (0, 1), (1, 0), (1, 1), restricted to (0, 0), (1, 1) when
        ambiguous.
    """
    split = dataset['train'] if train else dataset['test']
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    # Single pass over the split (instead of one full scan per subgroup).
    groups = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    for x in split:
        gender = x['gender']
        if gender != 0 and gender != 1:
            continue
        if x['profession'] == male_id:
            groups[(0, gender)].append(x['hard_text'])
        if x['profession'] == female_id:
            groups[(1, gender)].append(x['hard_text'])

    if ambiguous:
        keys = [(0, 0), (1, 1)]
    else:
        keys = [(0, 0), (0, 1), (1, 0), (1, 1)]

    out = {}
    for label_profile in keys:
        texts = groups[label_profile]
        true_label, spurious_label = label_profile
        batches = []
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i + batch_size]
            batches.append((
                chunk,
                t.tensor([true_label] * len(chunk), device=DEVICE),
                t.tensor([spurious_label] * len(chunk), device=DEVICE),
            ))
        out[label_profile] = batches
    return out

In [3]:
# probe training hyperparameters

layer = 4  # index of the transformer layer whose output feeds the linear probe

class Probe(nn.Module):
    """A single linear unit that maps each activation vector to one logit."""

    def __init__(self, activation_dim):
        super().__init__()
        # attribute name `net` is kept so state_dict keys remain 'net.weight' / 'net.bias'
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        # (..., activation_dim) -> (..., 1) -> (...)
        return self.net(x).squeeze(-1)

def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, dim=512, seed=SEED):
    """Train a linear Probe with BCE-with-logits on features produced by `get_acts`.

    Args:
        get_acts: callable mapping a batch of texts to a (batch, dim) float tensor.
        label_idx: which label to fit. batch[label_idx + 1] is used, so for get_data
            batches 0 is the true (profession) label and 1 is the spurious (gender) label.
        batches: iterable of (texts, true_labels, spurious_labels). Defaults to
            get_data(), built when the function is called rather than once at
            definition time. A definition-time default goes stale if the dataset/config
            cells are re-run, and it makes defining this function cost a full data pass.
            Must be re-iterable (e.g. a list) when epochs > 1.
        lr: AdamW learning rate.
        epochs: number of passes over `batches`.
        dim: input width of the probe.
        seed: torch seed, so probe initialisation is reproducible.

    Returns:
        (probe, losses): the trained Probe and the list of per-step losses.
    """
    t.manual_seed(seed)
    if batches is None:
        batches = get_data()
    probe = Probe(dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for _epoch in range(epochs):
        for batch in batches:
            labels = batch[label_idx + 1]
            logits = probe(get_acts(batch[0]))
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Accuracy of `probe` on features from `get_acts` (logit > 0 predicts label 1).

    Args:
        probe: module mapping a (batch, dim) tensor to one logit per row.
        get_acts: callable mapping a batch of texts to the probe's input features.
        label_idx: label to score against; batch[label_idx + 1] is used (0 = true
            label, 1 = spurious label for get_data batches).
        batches: iterable of (texts, true_labels, spurious_labels). Defaults to
            get_data(train=False) (the ambiguous test split), built at call time so it
            reflects the current dataset/config instead of a definition-time snapshot.
        seed: unused; kept for backward compatibility.

    Returns:
        The fraction of correctly classified examples as a Python float, or
        float('nan') if `batches` holds no examples. Without that guard, torch.cat
        would fail on an empty list.
    """
    if batches is None:
        batches = get_data(train=False)

    corrects = []
    with t.no_grad():
        for batch in batches:
            labels = batch[label_idx + 1]
            logits = probe(get_acts(batch[0]))
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())

    if not corrects:
        return float('nan')
    return t.cat(corrects).mean().item()
    
def get_acts(text):
    """Mean of layer `layer`'s output over non-padded token positions, one row per text.

    Positions where the attention mask is 0 contribute nothing to the mean.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        mask = model.input[1]['attention_mask']
        hidden = model.gpt_neox.layers[layer].output[0]
        summed = (hidden * mask[:, :, None]).sum(1)
        pooled = summed / mask.sum(1)[:, None]
        pooled = pooled.save()
    return pooled.value

In [4]:
# Oracle probe: trained on the balanced (non-ambiguous) split.
oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False))
print("ambiguous test accuracy", test_probe(oracle, get_acts, label_idx=0))

batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (("ground truth accuracy:", 0), ("unintended feature accuracy:", 1)):
    print(caption, test_probe(oracle, get_acts, batches=batches, label_idx=target_idx))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ambiguous test accuracy 0.9247211813926697
ground truth accuracy: 0.9285714030265808
unintended feature accuracy: 0.489631325006485


In [5]:
# get worst-group accuracy of oracle probe
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(oracle, get_acts, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

Accuracy for (0, 0): 0.9679111242294312
Accuracy for (0, 1): 0.9576417207717896
Accuracy for (1, 0): 0.9032257795333862
Accuracy for (1, 1): 0.884061336517334


In [6]:
# Probe trained on the ambiguous split, where profession and gender labels coincide.
probe, _ = train_probe(get_acts, label_idx=0)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))

batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Unintended feature accuracy:', 1)):
    print(caption, test_probe(probe, get_acts, batches=batches, label_idx=target_idx))

Ambiguous test accuracy: 0.9957016706466675
Ground truth accuracy: 0.5921658873558044
Unintended feature accuracy: 0.9020737409591675


In [7]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(probe, get_acts, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

Accuracy for (0, 0): 0.9977167248725891
Accuracy for (0, 1): 0.17649267613887787
Accuracy for (1, 0): 0.22580644488334656
Accuracy for (1, 1): 0.9941914677619934


In [8]:
# loading dictionaries

import importlib
import dictionary_learning
from dictionary_learning import AutoEncoder
importlib.reload(dictionary_learning)

%load_ext autoreload
%autoreload 2

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim

# Order matters: later cells index `submodules` by position
# (0 = embed_in, then attention, mlp and layer output for each layer 0..`layer`).
submodules = [model.gpt_neox.embed_in]
dictionaries = {
    model.gpt_neox.embed_in: AutoEncoder.from_pretrained(
        f'../dictionaries/pythia-70m-deduped/embed/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )
}
for i in range(layer + 1):
    block = model.gpt_neox.layers[i]
    for module, ae_path in (
        (block.attention, f'../dictionaries/pythia-70m-deduped/attn_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'),
        (block.mlp, f'../dictionaries/pythia-70m-deduped/mlp_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'),
        (block, f'../dictionaries/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'),
    ):
        submodules.append(module)
        dictionaries[module] = AutoEncoder.from_pretrained(ae_path, device=DEVICE)

def metric_fn(model, labels=None):
    """Signed probe logit used as the attribution metric.

    Averages layer `layer`'s output over non-padded positions and applies the global
    `probe`. It returns the logit for examples whose label is 0 and the negated logit
    otherwise, one value per example. It reads `model.input` / `.output`, so it is
    presumably called from inside an nnsight trace (e.g. by patching_effect).

    Args:
        model: the traced nnsight model.
        labels: tensor of per-example labels (0 or 1). Required despite the None
            default.

    Raises:
        ValueError: if `labels` is None.
    """
    if labels is None:
        raise ValueError("metric_fn requires `labels` (one 0/1 label per example)")
    attn_mask = model.input[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]

    # Run the probe once and reuse its output for both branches.
    logits = probe(acts)
    return t.where(labels == 0, logits, -logits)

In [9]:
# find most influential features
n_batches = 25
batch_size = 4

running_total = 0
nodes = None

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx >= n_batches:
        break

    # attribution of metric_fn to every dictionary feature (method='ig')
    effects, _, _, _ = patching_effect(
        clean, None, model, submodules, dictionaries, metric_fn,
        metric_kwargs=dict(labels=labels), method='ig',
    )
    with t.no_grad():
        # sum over token positions, mean over the batch, re-weighted by batch size
        batch_nodes = {k: len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        if nodes is None:
            nodes = batch_nodes
        else:
            for k, v in batch_nodes.items():
                nodes[k] += v
        running_total += len(clean)
    del effects, _
    gc.collect()

# per-example average effect
nodes = {k: v / running_total for k, v in nodes.items()}

100%|██████████| 25/25 [00:54<00:00,  2.19s/it]


In [10]:
n_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    # positions whose averaged effect exceeds 0.1
    hits = (effect > 0.1).nonzero()
    for idx in hits:
        print(idx.item(), effect[idx].item())
    n_features += len(hits)
print(f"total features: {n_features}")

Component 0:
946 0.2662804424762726
5719 0.15922614932060242
7392 0.4218404293060303
10784 0.18791018426418304
17846 0.37995392084121704
22068 0.20120769739151
23079 0.17827855050563812
25904 0.13578002154827118
28533 0.2241894155740738
29476 0.2351006269454956
31461 0.20119118690490723
31467 0.19470244646072388
32081 0.39268073439598083
32469 1.5978020429611206
Component 1:
4427 0.10151158273220062
23752 0.12343307584524155
Component 2:
2995 0.12388136237859726
3842 0.17480571568012238
10258 0.3594623804092407
13387 0.17669181525707245
13968 0.15486972033977509
14861 0.11585155874490738
18382 0.31952595710754395
19369 0.2192731499671936
21736 0.10511206090450287
28127 1.267708659172058
30037 0.11253583431243896
30518 0.23797392845153809
Component 3:
1022 0.38165226578712463
9651 0.7097998857498169
10060 2.938981294631958
18967 0.8327097296714783
22084 0.3067134916782379
23898 0.5447293519973755
24799 0.12614016234874725
26504 0.38154157996177673
29626 0.3487662672996521
31201 0.210504

In [11]:
# interpret features

# change the following two lines to pick which feature to interpret
# component_idx indexes `submodules` (0 = embed_in, then attn/mlp/layer-output per layer),
# so 9 is the output of layer 2; feat_idx is a feature of that submodule's dictionary.
component_idx = 9
feat_idx = 31098

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
# Stream The Pile text through the model and buffer activations of `submodule`.
data = hf_dataset_to_generator("monology/pile-uncopyrighted")
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=512,  # matches activation_dim
    refresh_batch_size=128, # decrease to fit on smaller GPUs
    n_ctxs=512, # decrease to fit on smaller GPUs
    device=DEVICE
)

# Summaries of where feature `feat_idx` activates (tokens, affected tokens, contexts).
out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=50 # decrease to fit on smaller GPUs
)
print(out.top_tokens)
print(out.top_affected)
out.top_contexts  # last expression: rendered as the cell output

Downloading readme:   0%|          | 0.00/776 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



[(' nursing', 5.238176345825195), (' nurses', 2.8025174140930176), (' nurse', 2.7965087890625), (' Teaching', 0.7016171216964722), (' rehabilitation', 0.6920654773712158), ('unte', 0.6074657440185547), (' caring', 0.429257333278656), ('akers', 0.3814794421195984), (' volunteers', 0.32599467039108276), (' Medical', 0.21756425499916077), (' relational', 0.19916227459907532), (' Dou', 0.1964026391506195), (' teaching', 0.17128399014472961), (' health', 0.10173860192298889), (' hospitals', 0.09990262985229492), (' education', 0.08860695362091064), (' engineering', 0.07574862241744995), (' patient', 0.07372689247131348), (' stance', 0.0615004301071167), (' dors', 0.04598349332809448), (' OR', 0.038925766944885254), (' toddler', 0.030461221933364868), ('opedic', 0.02716928720474243), (' undergraduate', 0.020936429500579834), (' and', 0.009745706804096699), (' but', 0.008150335401296616), ('It', 0.0), (' is', 0.0), (' done', 0.0), (',', 0.0)]
[('�', 0.5553182363510132), ('��', 0.5427490472793

In [12]:
# Convert a list of feature indices (e.g. one value of feats_to_ablate) into a boolean n-hot mask
def n_hot(feats, dim=dictionary_size):
    """Return a boolean mask of length `dim` on DEVICE that is True at each index in `feats`.

    Each element of `feats` is used directly as a tensor index, so it may be an int or
    (as one caller passes) a list of ints.
    """
    mask = t.full((dim,), False, dtype=t.bool, device=DEVICE)
    for index in feats:
        mask[index] = True
    return mask


In [13]:
# utilities for ablating features
# Record, per submodule, whether its output is a tuple (then interventions target element 0).
with t.no_grad(), model.trace("_"):
    is_tuple = {
        # exact type check on purpose: torch.Size subclasses tuple, so isinstance would misfire
        submodule: type(submodule.output.shape) is tuple
        for submodule in submodules
    }

def get_acts_ablated(
    text,
    model,
    submodules,
    dictionaries,
    to_ablate
):
    """Mean-pooled layer-`layer` activations with selected dictionary features zeroed.

    Args:
        text: input(s) passed to `model.trace`.
        model: the nnsight LanguageModel to trace.
        submodules: submodules to intervene on, in order; each must be a key of
            `dictionaries`, `to_ablate` and the global `is_tuple`.
        dictionaries: maps each submodule to its dictionary (AutoEncoder).
        to_ablate: maps each submodule to the features to zero. It indexes the last
            dim of the feature tensor directly, so it can be a boolean n-hot mask
            (from n_hot) or a plain list of feature indices (both are used in this
            notebook).

    Returns:
        attention-mask-weighted mean over positions of layer `layer`'s output,
        one row per text.
    """

    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            feat_idxs = to_ablate[submodule]
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]  # tuple outputs: intervene on the first element only
            x_hat, f = dictionary(x, output_features=True)
            # reconstruction error; adding it back below keeps the dictionary's
            # approximation error out of the intervention (assumes decode(f) == x_hat
            # for unmodified f -- TODO confirm against the AutoEncoder implementation)
            res = x - x_hat
            f[...,feat_idxs] = 0. # zero ablation
            if is_tuple[submodule]:
                submodule.output[0][:] = dictionary.decode(f) + res
            else:
                submodule.output = dictionary.decode(f) + res
        # same masked mean-pooling as get_acts, read after the interventions above
        attn_mask = model.input[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value


# Accuracy after ablating features judged irrelevant by human annotators

In [23]:
# Hand-picked features to ablate, keyed by submodule. Positions in `submodules` are:
# 0 = embed_in; for layer L in 0..4: 1+3L = attention, 2+3L = mlp, 3+3L = layer output.
# Listed features were judged to encode gender (pronouns, names, gendered words);
# commented-out entries appear to be profession-related or unclear and are kept.
feats_to_ablate = {
    submodules[0] : [  # embed_in
        946, # 'his'
        # 5719, # 'research'
        7392, # 'He'
        # 10784, # 'Nursing'
        17846, # 'He'
        22068, # 'His'
        # 23079, # 'tastes'
        # 25904, # 'nursing'
        28533, # 'She'
        29476, # 'he'
        31461, # 'His'
        31467, # 'she'
        32081, # 'her'
        32469, # 'She'
    ],
    submodules[1] : [  # layer 0 attention
        # 23752, # capitalized words, especially pronouns
    ],
    submodules[2] : [  # layer 0 mlp
        2995, # 'he'
        3842, # 'She'
        10258, # female names
        13387, # 'she'
        13968, # 'He'
        18382, # 'her'
        19369, # 'His'
        28127, # 'She'
        30518, # 'He'
    ],
    submodules[3] : [  # layer 0 output
        1022, # 'she'
        9651, # female names
        10060, # 'She'
        18967, # 'He'
        22084, # 'he'
        23898, # 'His'
        # 24799, # promotes surnames
        26504, # 'her'
        29626, # 'his'
        # 31201, # 'nursing'
    ],
    submodules[4] : [  # layer 1 attention
        # 8147, # unclear, something with names
    ],
    submodules[5] : [  # layer 1 mlp
        24159, # 'She', 'she'
        25018, # female names
    ],
    submodules[6] : [  # layer 1 output
        4592, # 'her'
        8920, # 'he'
        9877, # female names
        12128, # 'his'
        15017, # 'she'
        # 17369, # contact info
        # 26969, # related to nursing
        30248, # female names
    ],
    submodules[7] : [  # layer 2 attention
        13570, # promotes male-related words
        27472, # female names, promotes female-related words
    ],
    submodules[8] : [  # layer 2 mlp
    ],
    submodules[9] : [  # layer 2 output
        1995, # promotes female-associated words
        9128, # feminine pronouns
        11656, # promotes male-associated words
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        29206, # gendered pronouns
        29295, # female names
        # 31098, # nursing-related words
    ],
    submodules[10] : [  # layer 3 attention
        2959, # promotes female-associated words
        19128, # promotes male-associated words
        22029, # promotes female-associated words
    ],
    submodules[11] : [  # layer 3 mlp
    ],
    submodules[12] : [  # layer 3 output
        19558, # promotes female-associated words
        23545, # 'she'
        24806, # 'her'
        27334, # promotes male-associated words
        31453, # female names
    ],
    submodules[13] : [  # layer 4 attention
        31101, # promotes female-associated words
    ],
    submodules[14] : [  # layer 4 mlp
    ],
    submodules[15] : [  # layer 4 output
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        30220, # promotes male pronouns
    ]
}

# print(f"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}")

# feats_to_ablate = {
#     submodule : n_hot(feats) for submodule, feats in feats_to_ablate.items()
# }

import json
from tqdm import tqdm

results = []

# Each entry of ranges_to_ablate lists positions in `submodules` whose hand-picked
# features (feats_to_ablate) are ablated together in one configuration.

# Individual submodules, one configuration each
ranges_to_ablate = [[i] for i in range(len(submodules))]

# Specific ranges
specific_ranges = [[0,3], [3,6], [6,9], [9,12], [12,15], [0,3,6], [3,6,9], [6,9,12], [9,12,15]]
ranges_to_ablate.extend(specific_ranges)

# Continuous ranges
ranges_to_ablate.append(list(range(0, 7)))
ranges_to_ablate.append(list(range(1, 9)))
ranges_to_ablate.append(list(range(1, 16)))
ranges_to_ablate.append(list(range(3, 16)))
ranges_to_ablate.append(list(range(6, 16)))
ranges_to_ablate.append(list(range(9, 16)))

# The evaluation batches do not depend on the ablation, so build them once
# instead of once per configuration.
ambiguous_batches = get_data(train=False)
batches = get_data(train=False, ambiguous=False)

for range_index, range_to_ablate in enumerate(tqdm(ranges_to_ablate)):
    submodules_to_ablate = [submodules[i] for i in range_to_ablate]

    # boolean n-hot mask of the features to zero in each ablated submodule
    ablation_dict = {submodule: n_hot(feats_to_ablate[submodule]) for submodule in submodules_to_ablate}

    # Only the selected submodules are intervened on. The lambda reads the globals
    # above at call time, which is fine because it is only called in this iteration.
    get_acts_abl = lambda text: get_acts_ablated(text, model, submodules_to_ablate, dictionaries, ablation_dict)

    ambiguous_acc = test_probe(probe, get_acts_abl, batches=ambiguous_batches, label_idx=0)
    ground_truth_acc = test_probe(probe, get_acts_abl, batches=batches, label_idx=0)
    spurious_acc = test_probe(probe, get_acts_abl, batches=batches, label_idx=1)

    result = {
        "ablation_index": range_index,
        "submodules": range_to_ablate,
        "ambiguous_accuracy": float(ambiguous_acc),
        "ground_truth_accuracy": float(ground_truth_acc),
        "spurious_accuracy": float(spurious_acc)
    }
    results.append(result)

    print(f"Ablation {range_index}, Submodules {range_to_ablate}: Ambiguous: {ambiguous_acc:.4f}, Ground Truth: {ground_truth_acc:.4f}, Spurious: {spurious_acc:.4f}")

    # Checkpoint after every configuration so an interrupted run keeps finished results.
    with open('ablation_results.json', 'w') as f:
        json.dump(results, f, indent=2)

print("Ablation complete. Results saved to 'ablation_results.json'")

  0%|          | 0/31 [00:00<?, ?it/s]

  3%|▎         | 1/31 [00:35<17:46, 35.55s/it]

Ablation 0, Submodules [0]: Ambiguous: 0.9484, Ground Truth: 0.8566, Spurious: 0.5893


  6%|▋         | 2/31 [01:12<17:33, 36.34s/it]

Ablation 1, Submodules [1]: Ambiguous: 0.9957, Ground Truth: 0.5922, Spurious: 0.9021


 10%|▉         | 3/31 [01:49<17:08, 36.72s/it]

Ablation 2, Submodules [2]: Ambiguous: 0.9748, Ground Truth: 0.7615, Spurious: 0.7131


 13%|█▎        | 4/31 [02:27<16:41, 37.11s/it]

Ablation 3, Submodules [3]: Ambiguous: 0.9708, Ground Truth: 0.7972, Spurious: 0.6740


 16%|█▌        | 5/31 [03:05<16:10, 37.33s/it]

Ablation 4, Submodules [4]: Ambiguous: 0.9957, Ground Truth: 0.5922, Spurious: 0.9021


 19%|█▉        | 6/31 [03:42<15:36, 37.47s/it]

Ablation 5, Submodules [5]: Ambiguous: 0.9954, Ground Truth: 0.6135, Spurious: 0.8796


 23%|██▎       | 7/31 [04:20<15:02, 37.59s/it]

Ablation 6, Submodules [6]: Ambiguous: 0.9617, Ground Truth: 0.7725, Spurious: 0.6861


 26%|██▌       | 8/31 [04:58<14:28, 37.75s/it]

Ablation 7, Submodules [7]: Ambiguous: 0.9940, Ground Truth: 0.6596, Spurious: 0.8289


 29%|██▉       | 9/31 [05:36<13:53, 37.89s/it]

Ablation 8, Submodules [8]: Ambiguous: 0.9957, Ground Truth: 0.5922, Spurious: 0.9021


 32%|███▏      | 10/31 [06:14<13:16, 37.94s/it]

Ablation 9, Submodules [9]: Ambiguous: 0.9746, Ground Truth: 0.7598, Spurious: 0.7126


 35%|███▌      | 11/31 [06:52<12:38, 37.90s/it]

Ablation 10, Submodules [10]: Ambiguous: 0.9942, Ground Truth: 0.7010, Spurious: 0.7886


 39%|███▊      | 12/31 [07:30<11:58, 37.79s/it]

Ablation 11, Submodules [11]: Ambiguous: 0.9957, Ground Truth: 0.5922, Spurious: 0.9021


 42%|████▏     | 13/31 [08:07<11:18, 37.67s/it]

Ablation 12, Submodules [12]: Ambiguous: 0.9870, Ground Truth: 0.7500, Spurious: 0.7350


 45%|████▌     | 14/31 [08:45<10:41, 37.72s/it]

Ablation 13, Submodules [13]: Ambiguous: 0.9958, Ground Truth: 0.6169, Spurious: 0.8750


 48%|████▊     | 15/31 [09:23<10:03, 37.70s/it]

Ablation 14, Submodules [14]: Ambiguous: 0.9957, Ground Truth: 0.5922, Spurious: 0.9021


 52%|█████▏    | 16/31 [10:00<09:25, 37.73s/it]

Ablation 15, Submodules [15]: Ambiguous: 0.9924, Ground Truth: 0.7402, Spurious: 0.7506


 55%|█████▍    | 17/31 [10:52<09:45, 41.84s/it]

Ablation 16, Submodules [0, 3]: Ambiguous: 0.9460, Ground Truth: 0.8577, Spurious: 0.5847


 58%|█████▊    | 18/31 [11:44<09:43, 44.87s/it]

Ablation 17, Submodules [3, 6]: Ambiguous: 0.9685, Ground Truth: 0.7955, Spurious: 0.6711


 61%|██████▏   | 19/31 [12:35<09:22, 46.84s/it]

Ablation 18, Submodules [6, 9]: Ambiguous: 0.9553, Ground Truth: 0.7863, Spurious: 0.6676


 65%|██████▍   | 20/31 [13:27<08:52, 48.37s/it]

Ablation 19, Submodules [9, 12]: Ambiguous: 0.9531, Ground Truth: 0.7713, Spurious: 0.6757


 68%|██████▊   | 21/31 [14:19<08:14, 49.46s/it]

Ablation 20, Submodules [12, 15]: Ambiguous: 0.9780, Ground Truth: 0.7926, Spurious: 0.6820


 71%|███████   | 22/31 [15:25<08:09, 54.44s/it]

Ablation 21, Submodules [0, 3, 6]: Ambiguous: 0.9434, Ground Truth: 0.8635, Spurious: 0.5766


 74%|███████▍  | 23/31 [16:32<07:44, 58.06s/it]

Ablation 22, Submodules [3, 6, 9]: Ambiguous: 0.9614, Ground Truth: 0.8088, Spurious: 0.6498


 77%|███████▋  | 24/31 [17:38<07:04, 60.58s/it]

Ablation 23, Submodules [6, 9, 12]: Ambiguous: 0.9423, Ground Truth: 0.7961, Spurious: 0.6429


 81%|████████  | 25/31 [18:44<06:13, 62.28s/it]

Ablation 24, Submodules [9, 12, 15]: Ambiguous: 0.9445, Ground Truth: 0.8041, Spurious: 0.6325


 84%|████████▍ | 26/31 [20:47<06:41, 80.26s/it]

Ablation 25, Submodules [0, 1, 2, 3, 4, 5, 6]: Ambiguous: 0.9431, Ground Truth: 0.8635, Spurious: 0.5766


 87%|████████▋ | 27/31 [23:04<06:29, 97.38s/it]

Ablation 26, Submodules [1, 2, 3, 4, 5, 6, 7, 8]: Ambiguous: 0.9628, Ground Truth: 0.8018, Spurious: 0.6613


 90%|█████████ | 28/31 [27:01<06:58, 139.39s/it]

Ablation 27, Submodules [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: Ambiguous: 0.9289, Ground Truth: 0.8422, Spurious: 0.5818


 94%|█████████▎| 29/31 [30:31<05:20, 160.46s/it]

Ablation 28, Submodules [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: Ambiguous: 0.9333, Ground Truth: 0.8468, Spurious: 0.5818


 97%|█████████▋| 30/31 [33:17<02:42, 162.11s/it]

Ablation 29, Submodules [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: Ambiguous: 0.9215, Ground Truth: 0.8318, Spurious: 0.5876


100%|██████████| 31/31 [35:20<00:00, 68.39s/it] 

Ablation 30, Submodules [9, 10, 11, 12, 13, 14, 15]: Ambiguous: 0.9308, Ground Truth: 0.8116, Spurious: 0.6123
Ablation complete. Results saved to 'ablation_results.json'





In [None]:
# Deliberate stop so that "Run All" halts after the ablation sweep above; the
# cells below (single-configuration ablations and baselines) are run manually.
raise ValueError("Stop here")

In [None]:

def get_acts_abl(text):
    """Pooled layer-`layer` activations with every feature in feats_to_ablate zeroed."""
    return get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Spurious accuracy:', 1)):
    print(caption, test_probe(probe, get_acts_abl, batches=batches, label_idx=target_idx))

In [None]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(probe, get_acts_abl, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

# Concept bottleneck probing baseline

In [None]:
concepts = [
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors: layer-`layer` output at the last position of each concept prompt
with t.no_grad(), model.trace(concepts):
    last_token_acts = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
raw_vectors = last_token_acts.value
# centre the concept vectors by subtracting their mean
concept_vectors = raw_vectors - raw_vectors.mean(0, keepdim=True)

def get_bottleneck(text):
    """Cosine similarity of each text's pooled layer-`layer` activation with every concept vector.

    Returns a tensor with one row per text and one column per entry of `concepts`.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        mask = model.input[1]['attention_mask']
        hidden = model.gpt_neox.layers[layer].output[0]
        pooled = (hidden * mask[:, :, None]).sum(1) / mask.sum(1)[:, None]
        dots = pooled @ concept_vectors.T
        # outer product of norms: (n_texts, 1) @ (1, n_concepts)
        norms = pooled.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None]
        sims = (dots / norms).save()
    return sims.value

In [None]:
# Concept-bottleneck probe: linear head on concept similarities instead of raw activations.
cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts))
batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Unintended feature accuracy:', 1)):
    print(caption, test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=target_idx))


In [None]:
# get subgroup accuracies
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

# Get skyline neuron performance

In [None]:
# get neurons which are most influential for giving gender label
# IdentityDict per submodule: attribution is done on raw neurons instead of dictionary features.
neuron_dicts = {sm: IdentityDict(activation_dim).to(DEVICE) for sm in submodules}

n_batches = 25
batch_size = 4

running_total = 0
nodes = None

# `labels` below are the spurious (gender) labels, the third element of each batch.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx >= n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean, None, model, submodules, neuron_dicts, metric_fn,
        metric_kwargs=dict(labels=labels), method='ig',
    )
    with t.no_grad():
        # sum over token positions, mean over the batch, re-weighted by batch size
        batch_nodes = {k: len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        if nodes is None:
            nodes = batch_nodes
        else:
            for k, v in batch_nodes.items():
                nodes[k] += v
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k: v / running_total for k, v in nodes.items()}

In [None]:
neurons_to_ablate = {}
total_neurons = 0
# NOTE(review): assumes nodes is ordered like `submodules` (it is built from patching_effect's output).
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    # Threshold and print from the same tensor (`effect.act`) so the printed scores
    # are the ones that were thresholded. 0.2135 is a hand-tuned cutoff.
    scores = effect.act
    selected = []
    for idx in (scores > 0.2135).nonzero():
        print(idx.item(), scores[idx].item())
        selected.append(idx.item())
    neurons_to_ablate[submodules[component_idx]] = selected
    total_neurons += len(selected)
print(f"total neurons: {total_neurons}")

# n_hot already takes a list of indices; neuron masks have width activation_dim.
neurons_to_ablate = {
    submodule: n_hot(neuron_idxs, dim=activation_dim) for submodule, neuron_idxs in neurons_to_ablate.items()
}

In [None]:
def get_acts_abl(text):
    """Mean-pooled layer-`layer` activations with the selected neurons mean-ablated.

    For every submodule, the neurons flagged in the global `neurons_to_ablate`
    (boolean masks built with n_hot) are overwritten by that neuron's mean over the
    batch and sequence dimensions of the current input.
    NOTE(review): that mean is taken over all positions, including padded ones.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]  # tuple outputs: intervene on the first element only
            x[...,neurons_to_ablate[submodule]] = x.mean(dim=(0,1))[...,neurons_to_ablate[submodule]] # mean ablation (in place on x)
            # write the (already modified) activations back into the module output
            if is_tuple[submodule]:
                submodule.output[0][:] = x
            else:
                submodule.output = x

        # same masked mean-pooling as get_acts, read after the interventions above
        attn_mask = model.input[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value

In [None]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Spurious accuracy:', 1)):
    print(caption, test_probe(probe, get_acts_abl, batches=batches, label_idx=target_idx))

In [None]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(probe, get_acts_abl, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

# Get skyline feature performance

In [None]:
# get features which are most useful for predicting gender label
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

# `labels` below are the spurious (gender) labels, the third element of each batch.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx >= n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean, None, model, submodules, dictionaries, metric_fn,
        metric_kwargs=dict(labels=labels), method='ig',
    )
    with t.no_grad():
        # sum over token positions, mean over the batch, re-weighted by batch size
        batch_nodes = {k: len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        if running_nodes is None:
            running_nodes = batch_nodes
        else:
            for k, v in batch_nodes.items():
                running_nodes[k] += v
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k: v / running_total for k, v in running_nodes.items()}

In [None]:
top_feats_to_ablate = {}
total_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    chosen = top_feats_to_ablate.setdefault(submodules[component_idx], [])
    # 0.1107 is a hand-tuned cutoff on the averaged effect
    for idx in (effect > 0.1107).nonzero():
        feat = idx.item()
        print(feat, effect[idx].item())
        chosen.append(feat)
        total_features += 1
print(f"total features: {total_features}")

In [None]:
# replace each feature-index list with its boolean n-hot mask
top_feats_to_ablate = {sm: n_hot(feat_list) for sm, feat_list in top_feats_to_ablate.items()}

def get_acts_abl(text):
    """Pooled layer-`layer` activations with the top gender-attributed features zeroed."""
    return get_acts_ablated(text, model, submodules, dictionaries, top_feats_to_ablate)

In [None]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Spurious accuracy:', 1)):
    print(caption, test_probe(probe, get_acts_abl, batches=batches, label_idx=target_idx))

In [None]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(probe, get_acts_abl, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)

# Retraining probe on activations after ablating features

In [None]:
def get_acts_abl(text):
    """Pooled layer-`layer` activations with every feature in feats_to_ablate zeroed."""
    return get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

# Retrain a fresh probe on the ablated activations.
new_probe, _ = train_probe(get_acts_abl, label_idx=0)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for caption, target_idx in (('Ground truth accuracy:', 0), ('Unintended feature accuracy:', 1)):
    print(caption, test_probe(new_probe, get_acts_abl, batches=batches, label_idx=target_idx))

In [None]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    accuracy = test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0)
    print(f'Accuracy for {label_profile}:', accuracy)