In [None]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
import gc
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import pandas as pd
import random
from collections import defaultdict



DEBUGGING = False

# nnsight's extra tracing checks (scan/validate) are only switched on while debugging.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)

# model hyperparameters
# Fall back to CPU so Restart & Run All still works on machines without a GPU.
DEVICE = 'cuda:0' if t.cuda.is_available() else 'cpu'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512  # hidden size of pythia-70m

In [None]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")

# data preparation hyperparameters
batch_size = 128
SEED = 42

In [None]:
# Peek at one raw training record; later cells read its 'hard_text',
# 'profession' and 'gender' fields.
dataset['train'][0]

In [None]:
# Profession name -> integer id (assumed to match the dataset's `profession` label encoding)
profession_dict = {
    'accountant': 0,
    'architect': 1,
    'attorney': 2,
    'chiropractor': 3,
    'comedian': 4,
    'composer': 5,
    'dentist': 6,
    'dietitian': 7,
    'dj': 8,
    'filmmaker': 9,
    'interior_designer': 10,
    'journalist': 11,
    'model': 12,
    'nurse': 13,
    'painter': 14,
    'paralegal': 15,
    'pastor': 16,
    'personal_trainer': 17,
    'photographer': 18,
    'physician': 19,
    'poet': 20,
    'professor': 21,
    'psychologist': 22,
    'rapper': 23,
    'software_engineer': 24,
    'surgeon': 25,
    'teacher': 26,
    'yoga_teacher': 27
}

# Reverse the profession dictionary for easy lookup
profession_dict_rev = {v: k for k, v in profession_dict.items()}

# Convert the dataset to a pandas DataFrame for easier manipulation
df = pd.DataFrame(dataset['train'])

# Create a combined label column for (multiclass x binary), e.g. '12_0'
df['combined_label'] = df['profession'].astype(str) + '_' + df['gender'].astype(str)


def _combined_label_key(combined_label):
    """Return the numeric (profession_id, gender_id) pair of a '<profession>_<gender>' label."""
    profession_id, gender_id = combined_label.split('_')
    return int(profession_id), int(gender_id)


# Count samples per (multiclass x binary) label, ordered numerically by profession
# then gender. A plain string sort would put '10_0' before '1_0' and scramble the plot.
label_counts = df['combined_label'].value_counts()
label_counts = label_counts.loc[sorted(label_counts.index, key=_combined_label_key)]
smallest_label_count = label_counts.min()
print(f'Smallest label count: {smallest_label_count}')

# Human-readable bar labels: gender id 0 is shown as Male, any other id as Female
labels = [
    f"{profession_dict_rev[profession_id]} ({'Male' if gender_id == 0 else 'Female'})"
    for profession_id, gender_id in map(_combined_label_key, label_counts.index)
]

fig, ax = plt.subplots(figsize=(12, 8))
ax.bar(labels, label_counts.to_numpy())
ax.set_xlabel('(Profession x Gender) Label')
ax.set_ylabel('Number of Samples')
ax.set_title('Number of Samples per (Profession x Gender) Label')
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()

In [None]:
min_samples_per_group = 1024

# Keep only professions that have at least `min_samples_per_group` bios for every
# gender value in the dataset, and downsample each (profession, gender) group to
# exactly that many bios so every kept profession is gender-balanced.
n_genders = df['gender'].nunique()
balanced_df_list = []
for profession in df['profession'].unique():
    prof_df = df[df['profession'] == profession]
    gender_counts = prof_df['gender'].value_counts()
    min_count = gender_counts.min()
    # Also require every gender to be present: value_counts() omits absent
    # genders, so min() alone would accept a single-gender profession.
    if len(gender_counts) == n_genders and min_count >= min_samples_per_group:
        # Seeded so the subsample is identical across Restart & Run All.
        balanced_prof_df = (
            prof_df.groupby('gender')
            .sample(n=min_samples_per_group, random_state=SEED)
            .reset_index(drop=True)
        )
        balanced_df_list.append(balanced_prof_df)

assert balanced_df_list, 'No profession has enough bios for every gender.'
balanced_df = pd.concat(balanced_df_list).reset_index(drop=True)

# Shuffle all rows once (seeded), then group by profession. groupby keeps the
# row order within each group, so every profession's list comes out shuffled.
shuffled_df = balanced_df.sample(frac=1, random_state=SEED)
grouped = shuffled_df.groupby('profession')['hard_text'].apply(list)
bios_gender_balanced = dict(grouped.items())

In [None]:
def sample_from_classes(data_dict, chosen_class):
    """Draw negatives for ``chosen_class`` from all of the other classes.

    Exactly ``len(data_dict[chosen_class])`` items are returned. Each draw
    first picks one of the other classes uniformly at random (with
    replacement); then, per class, that many items are taken without
    replacement via ``random.sample``. Uses the global ``random`` state.

    Args:
        data_dict: mapping of class key -> list of samples.
        chosen_class: key whose own samples are excluded.

    Returns:
        A list of samples, grouped by class in the order each class was first picked.

    Raises:
        ValueError: if some class is picked more times than it has samples.
    """
    n_needed = len(data_dict[chosen_class])
    other_classes = [key for key in data_dict.keys() if key != chosen_class]

    # Tally how many draws each class received, keeping first-picked order.
    tally = {}
    for key in random.choices(other_classes, k=n_needed):
        tally[key] = tally.get(key, 0) + 1

    return [
        item
        for key, how_many in tally.items()
        for item in random.sample(data_dict[key], how_many)
    ]

def create_labeled_dataset(data_dict, chosen_class, batch_size):
    """Build shuffled, batched (text, label) data for a one-vs-rest probe.

    Samples of ``chosen_class`` get label 0; an equal number of samples drawn
    from the other classes (via ``sample_from_classes``) get label 1. The
    pairs are shuffled with the global ``random`` state and cut into
    consecutive batches of ``batch_size`` (the last batch may be shorter).

    Returns:
        (text_batches, label_batches): a list of text tuples, and a parallel
        list of integer label tensors placed on ``DEVICE``.
    """
    positives = data_dict[chosen_class]
    negatives = sample_from_classes(data_dict, chosen_class)

    pairs = [(text, 0) for text in positives]
    pairs.extend((text, 1) for text in negatives)

    random.shuffle(pairs)
    texts, labels = zip(*pairs)

    starts = range(0, len(pairs), batch_size)
    text_batches = [texts[s:s + batch_size] for s in starts]
    label_batches = [t.tensor(labels[s:s + batch_size], device=DEVICE) for s in starts]
    return text_batches, label_batches

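## Train Probes

For each profession, train a linear probe on mean-pooled activations from one transformer layer to separate that profession's bios (label 0) from an equal number of bios drawn from the other professions (label 1).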

In [None]:
# probe training hyperparameters

layer = 4  # index into model.gpt_neox.layers whose output the probe reads

class Probe(nn.Module):
    """Linear binary-classification head producing one logit per input row."""

    def __init__(self, activation_dim):
        super(Probe, self).__init__()
        # A single output unit. The attribute name `net` is kept so saved
        # probes and state dicts stay compatible.
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        """Map (..., activation_dim) inputs to (...) logits by dropping the size-1 output dim."""
        return self.net(x).squeeze(dim=-1)

In [None]:
def train_probe(text_batches, label_batches, get_acts, lr=1e-2, epochs=1, dim=512, seed=SEED):
    """Train a linear ``Probe`` with BCE-with-logits on activations from ``get_acts``.

    Args:
        text_batches: sequence of text batches; each batch is passed to ``get_acts``.
        label_batches: sequence of 0/1 label batches aligned with ``text_batches``
            (tensors or plain lists are both accepted).
        get_acts: callable mapping a text batch to activations whose last
            dimension is ``dim``.
        lr: AdamW learning rate.
        epochs: number of passes over the batches.
        dim: input dimension of the probe.
        seed: torch seed, set before the probe is initialized.

    Returns:
        (probe, losses): the trained probe and a CPU float tensor holding the
        loss of every optimization step (length ``epochs * len(text_batches)``).
    """
    t.manual_seed(seed)
    probe = Probe(dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = t.zeros(epochs * len(text_batches))
    batch_idx = 0
    for _ in range(epochs):
        for text, labels in zip(text_batches, label_batches):
            acts = get_acts(text)
            logits = probe(acts)
            # as_tensor avoids re-copying (and the copy-construct warning) when
            # labels are already a tensor; BCE needs float targets.
            targets = t.as_tensor(labels, dtype=t.float32, device=DEVICE)
            loss = criterion(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Store a plain number: assigning the graph-attached loss tensor would
            # make `losses` itself require grad and keep autograd nodes alive.
            losses[batch_idx] = loss.item()
            batch_idx += 1
    return probe, losses

def get_acts(text):
    """Return attention-masked, mean-pooled activations of ``model`` at layer ``layer``.

    Traces ``model`` on ``text`` with nnsight (using the module-level
    ``tracer_kwargs``) and reads ``model.gpt_neox.layers[layer].output[0]``.
    It then multiplies by the attention mask and divides the sum over dim 1
    by the per-example mask sum, i.e. averages over unmasked positions.

    Args:
        text: a batch of input strings (e.g. one entry of ``text_batches``).

    Returns:
        The pooled activation tensor, presumably shaped (batch, d_model) --
        TODO confirm. Computed under ``t.no_grad()``, so no gradients are tracked.
    """
    with t.no_grad(): 
        with model.trace(text, **tracer_kwargs):
            # NOTE(review): assumes model.input is (args, kwargs) so [1] holds the
            # tokenizer kwargs with a 0/1 attention mask -- confirm against nnsight.
            attn_mask = model.input[1]['attention_mask']
            # First element of the layer's output; presumably (batch, seq, d_model) -- TODO confirm.
            acts = model.gpt_neox.layers[layer].output[0]
            # Zero out masked (presumably padding) positions before pooling.
            acts = acts * attn_mask[:, :, None]
            # Masked sum divided by the number of unmasked tokens per example.
            acts = acts.sum(1) / attn_mask.sum(1)[:, None]
            # nnsight: keep this value available after the trace exits (read via .value below).
            acts = acts.save()
        return acts.value

In [None]:
probes, losses = {}, {}

for profession in bios_gender_balanced.keys():
    # Run the Python GC first so tensors kept alive only by reference cycles are
    # freed, then hand PyTorch's cached-but-unused CUDA blocks back to the driver.
    gc.collect()
    t.cuda.empty_cache()

    print(f'Training probe for profession: {profession}')
    text_batches, label_batches = create_labeled_dataset(bios_gender_balanced, profession, batch_size)
    probe, loss = train_probe(
        text_batches,
        label_batches,
        get_acts,
        epochs=1
    )
    probes[profession] = probe
    losses[profession] = loss

# make subfolder for saving
PROBE_SAVE_DIR = 'trained_bib_probes'
os.makedirs(PROBE_SAVE_DIR, exist_ok=True)

# save probes, losses
# NOTE(review): this pickles whole Probe modules, so loading needs the Probe class
# importable (and, on newer torch versions, t.load(..., weights_only=False)).
t.save(probes, os.path.join(PROBE_SAVE_DIR, 'probes_0705.pt'))
t.save(losses, os.path.join(PROBE_SAVE_DIR, 'losses_0705.pt'))

In [None]:
def test_probe(text_batches, label_batches, probe, get_acts, label_idx=0, seed=SEED):
    """Return the classification accuracy of ``probe`` over all batches.

    A prediction is ``1`` when the probe's logit is strictly positive, else
    ``0``. It is compared element-wise with the matching labels, and the
    fraction of matches over every example is returned as a Python float.
    ``label_idx`` and ``seed`` are accepted for interface compatibility but unused.
    """
    with t.no_grad():
        hits = [
            ((probe(get_acts(batch_text)) > 0.0).long() == batch_labels).float()
            for batch_text, batch_labels in zip(text_batches, label_batches)
        ]
    return t.cat(hits).mean().item()