In [1]:
from nnsight import LanguageModel
import random
from datasets import load_dataset
import torch as t
import torch.nn as nn
from dictionary_learning import AutoEncoder
from circuit_triangles import get_circuit
from circuit_plotting import plot_circuit
from attribution import patching_effect
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


In [2]:
# Dataset of biographies; rows are read below via their 'hard_text',
# 'profession' and 'gender' fields.
dataset = load_dataset("LabHC/bias_in_bios")
# Profession name -> integer id compared against each row's 'profession' field.
profession_dict = {'professor' : 21, 'nurse' : 13}
# The two professions that define the binary probe task (label 0 and label 1
# respectively in get_data / get_subgroups).
male_prof = 'professor'
female_prof = 'nurse'

# Device used for the model, the label tensors and the autoencoders.
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE)
# Index of the transformer block whose output the probe reads; circuits are
# built over blocks 0..layer inclusive.
layer = 4

# Identifier of the autoencoder run; interpolated into the checkpoint paths.
dict_id = 10

# NOTE(review): not referenced by any code visible in this notebook
# (get_data / get_subgroups take their own batch_size argument) — confirm
# whether it is still needed.
batch_size = 1024
# Default seed for data shuffling and probe initialisation.
SEED = 42

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Build shuffled, class-balanced batches of biographies for probing.

    Args:
        train: use the 'train' split when True, otherwise the 'test' split.
        ambiguous: when True, keep only (male_prof, gender 0) and
            (female_prof, gender 1) rows, so the profession label and the
            gender label coincide. When False, keep all four
            profession x gender combinations.
        batch_size: maximum number of texts per batch.
        seed: seed of the private `random.Random` used for shuffling.

    Returns:
        A list of `(texts, true_labels, spurious_labels)` tuples, where
        `texts` is a list of strings and both label entries are tensors on
        DEVICE. The true label is 0 for `male_prof` and 1 for `female_prof`;
        the spurious label is the row's 'gender' value.
    """
    split = dataset['train'] if train else dataset['test']
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    def texts_for(profession, gender):
        # One pass over the split per subgroup, preserving dataset order.
        return [row['hard_text'] for row in split if row['profession'] == profession and row['gender'] == gender]

    # Each entry: (texts, true label, spurious label), in concatenation order.
    if ambiguous:
        groups = [
            (texts_for(male_id, 0), 0, 0),
            (texts_for(female_id, 1), 1, 1),
        ]
    else:
        groups = [
            (texts_for(male_id, 0), 0, 0),
            (texts_for(male_id, 1), 0, 1),
            (texts_for(female_id, 0), 1, 0),
            (texts_for(female_id, 1), 1, 1),
        ]

    # Truncate every subgroup to the size of the smallest one.
    n = min(len(group_texts) for group_texts, _, _ in groups)
    texts, true_labels, spurious_labels = [], [], []
    for group_texts, true_label, spurious_label in groups:
        texts.extend(group_texts[:n])
        true_labels.extend([true_label] * n)
        spurious_labels.extend([spurious_label] * n)

    # Apply one shared permutation so texts and both label lists stay aligned.
    order = list(range(len(texts)))
    random.Random(seed).shuffle(order)
    texts = [texts[j] for j in order]
    true_labels = [true_labels[j] for j in order]
    spurious_labels = [spurious_labels[j] for j in order]

    batches = []
    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        batches.append((
            texts[start:stop],
            t.tensor(true_labels[start:stop], device=DEVICE),
            t.tensor(spurious_labels[start:stop], device=DEVICE),
        ))

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Batch the data separately for each (true label, spurious label) subgroup.

    Args:
        train: use the 'train' split when True, otherwise the 'test' split.
        ambiguous: when True, return only the (0, 0) and (1, 1) subgroups;
            when False, return all four label profiles.
        batch_size: maximum number of texts per batch.
        seed: accepted for signature parity with get_data; not used here
            (subgroups are neither shuffled nor truncated).

    Returns:
        A dict mapping each label profile `(true_label, spurious_label)` to a
        list of `(texts, true_labels, spurious_labels)` batches, with the
        label tensors on DEVICE.
    """
    split = dataset['train'] if train else dataset['test']

    # The true label selects the profession; the spurious label is matched
    # against the row's 'gender' value directly.
    profession_ids = (profession_dict[male_prof], profession_dict[female_prof])
    if ambiguous:
        profiles = [(0, 0), (1, 1)]
    else:
        profiles = [(0, 0), (0, 1), (1, 0), (1, 1)]

    out = {}
    for profile in profiles:
        true_label, spurious_label = profile
        wanted_profession = profession_ids[true_label]
        texts = [
            row['hard_text'] for row in split
            if row['profession'] == wanted_profession and row['gender'] == spurious_label
        ]

        group_batches = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            group_batches.append((
                chunk,
                t.tensor([true_label] * len(chunk), device=DEVICE),
                t.tensor([spurious_label] * len(chunk), device=DEVICE),
            ))
        out[profile] = group_batches
    return out

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Probe training hyperparameters, passed explicitly to train_probe in later cells.
# lr: learning rate given to the optimizer; epochs: passes over the training batches.
lr = 1e-2
epochs = 1

class Probe(nn.Module):
    """Logistic-regression probe: one linear unit followed by a sigmoid.

    Args:
        activation_dim: size of the last dimension of the input activations.
    """

    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        """Return probabilities (not raw logits) with the trailing unit dim removed."""
        scores = self.net(x)
        return t.sigmoid(scores.squeeze(-1))
    
def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, seed=SEED):
    """Train a linear probe on model activations with binary cross-entropy.

    Args:
        get_acts: callable mapping a list of texts to an activation tensor
            whose last dimension is 512 (the probe's input size).
        label_idx: which label of each batch to fit — 0 for the true label
            (batch[1]), 1 for the spurious label (batch[2]).
        batches: iterable of `(texts, true_labels, spurious_labels)` tuples.
            Defaults to `get_data()`, built when the function is called.
        lr: learning rate for AdamW.
        epochs: number of passes over `batches`.
        seed: torch seed set before the probe is initialised.

    Returns:
        `(probe, losses)`: the trained Probe and the per-step loss values.
    """
    # Build the default lazily: a `batches=get_data()` default would be
    # evaluated once at definition time, scanning the whole dataset even for
    # callers that supply their own batches.
    if batches is None:
        batches = get_data()

    t.manual_seed(seed)
    # Keep the probe on the same device as the label tensors from get_data.
    probe = Probe(512).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCELoss()

    losses = []
    for epoch in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            # Probe.forward already applies the sigmoid, hence BCELoss
            # rather than BCEWithLogitsLoss.
            probs = probe(acts)
            loss = criterion(probs, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Compute the probe's classification accuracy over a set of batches.

    Args:
        probe: module mapping activations to probabilities in [0, 1].
        get_acts: callable mapping a list of texts to an activation tensor.
        label_idx: which label to score against — 0 for the true label
            (batch[1]), 1 for the spurious label (batch[2]).
        batches: iterable of `(texts, true_labels, spurious_labels)` tuples.
            Defaults to `get_data(train=False)`, built when the function is
            called.
        seed: unused; kept so existing callers keep working.

    Returns:
        Mean accuracy over all examples as a Python float.
    """
    # Build the default lazily: a `batches=get_data(train=False)` default
    # would be evaluated once at definition time, even for callers that pass
    # their own batches.
    if batches is None:
        batches = get_data(train=False)

    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            probs = probe(acts)
            # The probe outputs probabilities, so 0.5 is the decision boundary.
            preds = (probs > 0.5).long()
            corrects.append((preds == labels).float())
        # Concatenate per-example results so a short final batch is weighted correctly.
        return t.cat(corrects).mean().item()
    
def get_acts(text):
    """Return the output of transformer block `layer` at the last sequence position.

    Args:
        text: input passed straight to `model.invoke` (the callers in this
            notebook pass a list of strings).

    Returns:
        A cloned tensor of the saved activations.
    """
    with model.invoke(text):
        # output[0] is presumably the hidden-state tensor of shape
        # (batch, seq, dim) — TODO confirm. [:, -1, :] keeps the final position.
        # NOTE(review): for batched inputs of unequal length this is only the
        # last real token if the tokenizer pads on the left — confirm.
        acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    # Clone so the returned tensor does not share storage with the saved value.
    return acts.value.clone()

In [4]:
# probe0, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)
# probe1, _ = train_probe(get_acts, label_idx=1, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)
# batches = get_data(train=False, ambiguous=False)
# print('Probe 0 accuracy:', test_probe(probe0, get_acts, batches=batches, label_idx=0))
# print('Probe 1 accuracy:', test_probe(probe1, get_acts, batches=batches, label_idx=1))

In [5]:
# Train on train_probe's default batches (the ambiguous training set, where the
# profession label and the gender label coincide), fitting label index 0.
# `probe` is read as a global by metric_fn in a later cell.
probe, _ = train_probe(get_acts, label_idx=0, lr=lr, epochs=epochs)
# Accuracy on test_probe's default batches (the ambiguous test split).
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
# On the balanced, non-ambiguous test split the two labels are decoupled, so
# the probe can be scored against each of them separately.
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Ambiguous test accuracy: 0.9915195107460022
Ground truth accuracy: 0.6261520981788635
Spurious accuracy: 0.8680875301361084


In [6]:
# Submodules of transformer blocks 0..layer (inclusive) whose outputs are
# decomposed with sparse autoencoders. The comprehension variable is named
# `block` so it cannot be confused with the global `layer` index.
attns = [block.attention for block in model.gpt_neox.layers[:layer+1]]
mlps = [block.mlp for block in model.gpt_neox.layers[:layer+1]]
resids = [block for block in model.gpt_neox.layers[:layer+1]]

# Maps each submodule to the autoencoder trained on that submodule's output.
dictionaries = {}
for i in range(layer + 1):
    # One checkpoint per (output kind, block index); insertion order is
    # attention, MLP, residual for each block.
    for kind, submodules in (('attn_out', attns), ('mlp_out', mlps), ('resid_out', resids)):
        ae = AutoEncoder(512, 64 * 512).to(DEVICE)
        # map_location lets the checkpoint load even if it was saved from a
        # different device than DEVICE.
        ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/{kind}_layer{i}/{dict_id}_32768/ae.pt', map_location=DEVICE))
        dictionaries[submodules[i]] = ae

def metric_fn(model, labels=None):
    """Probe-based metric evaluated on the output of block `layer` at the last position.

    For each example this is the probe's output when the label is 0 and
    one minus the probe's output otherwise. Reads the globals `probe` and
    `layer`.

    Args:
        model: traced model whose `gpt_neox.layers[layer]` output is read.
        labels: per-example labels compared against 0.
    """
    last_position_acts = model.gpt_neox.layers[layer].output[0][:,-1,:]
    prob = probe(last_position_acts)
    is_label_zero = labels == 0
    return t.where(is_label_zero, prob, 1 - prob)

In [7]:
# Average node and edge effects over batches, weighting each batch by its
# number of examples so the result is a per-example mean.
circuit_batches = get_data(train=True, ambiguous=True, batch_size=5, seed=SEED)

running_total = 0
running_nodes = None
running_edges = None
# tqdm takes its total from len(circuit_batches). The loop visits every batch
# get_data returns, so a hard-coded total would misreport progress.
for clean, labels, _ in tqdm(circuit_batches):
    nodes, edges = get_circuit(
        clean,
        None,
        model,
        attns,
        mlps,
        resids,
        dictionaries,
        metric_fn,
        metric_kwargs={'labels': labels},
        node_threshold=0.01,
        edge_threshold=0.0001,
    )
    n_examples = len(clean)
    running_total += n_examples
    if running_nodes is None:
        # Weight the first batch like every later one. Storing it unweighted
        # would under-count it once the sums are divided by running_total.
        # The 'y' entry is carried over unchanged and never averaged.
        running_nodes = {
            k: (effect if k == 'y' else n_examples * effect)
            for k, effect in nodes.items()
        }
        running_edges = {
            k: {kk: (n_examples * effect).coalesce() for kk, effect in edges[k].items()}
            for k in edges.keys()
        }
    else:
        for k, effect in nodes.items():
            if k == 'y': continue
            running_nodes[k] += n_examples * effect
        for k in edges.keys():
            for kk, effect in edges[k].items():
                # Edge effects are sparse tensors; coalesce after each sum to
                # merge duplicate indices.
                running_edges[k][kk] = (running_edges[k][kk] + n_examples * effect).coalesce()

# Turn the weighted sums into per-example means.
for k in running_nodes.keys():
    if k == 'y': continue
    running_nodes[k] = running_nodes[k] / running_total
for k in running_edges.keys():
    for kk in running_edges[k].keys():
        running_edges[k][kk] = running_edges[k][kk] / running_total

 10%|█         | 2/20 [06:22<57:18, 191.02s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.98 GiB. GPU 0 has a total capacity of 47.51 GiB of which 1.36 GiB is free. Including non-PyTorch memory, this process has 46.04 GiB memory in use. Of the allocated memory 41.33 GiB is allocated by PyTorch, and 4.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Plot the effects averaged over all batches (running_nodes / running_edges).
# `nodes` and `edges` only hold the result of the last batch processed.
plot_circuit(running_nodes, running_edges, layers=layer+1, node_threshold=0.01, edge_threshold=0.001, pen_thickness=3, save_dir='BiB_5')