In [2]:
import torch
import torch.nn as nn
import numpy as np
np.random.seed(42)
import json
from tqdm import tqdm
import copy
import os
import re
import pandas as pd

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

In [3]:
# Census table columns that parse_census_data converts to numeric values.
CENSUS_GROUPS = ["Women", "White", "Black or African American", "Asian", "Hispanic or Latino"]

# Root directory for saved hidden states. `exist_ok=True` makes creation safe to re-run and
# avoids the check-then-create race of a separate os.path.exists() test.
SAVE_DIR = os.path.join(os.getcwd(), "saved/")
os.makedirs(SAVE_DIR, exist_ok=True)

# Model cache and input data locations, all relative to the current working directory.
CACHE_DIR = os.path.join(os.getcwd(), "cache_dir")
DATA_DIR = os.path.join(os.getcwd(), "data")
PROFESSIONS_PATH = os.path.join(DATA_DIR, "professions.json")
CENSUS_PATH = os.path.join(DATA_DIR, "cpsaat11.csv")
PROMPTS_PATH = os.path.join(DATA_DIR, "prompts.txt")
PROMPTS_PATH_CENSUS = os.path.join(DATA_DIR, "census_race_prompts.txt")
CROWSPAIRS_PATH = os.path.join(DATA_DIR, "crows_pairs_anonymized.csv")

# The *_tokens helpers derive their chunk size from the number of prompts divided by SPLIT.
SPLIT = 32

# Prefer Apple's MPS backend (the original hard-coded choice) and fall back to CUDA or CPU
# so the notebook still runs on hosts without MPS.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

### Extract Hidden States

In [19]:
def get_encoder_hidden_states(model, tokenizer, input_text, layer=None):
    """
    Run `input_text` through an encoder model and extract one hidden-state vector per
    selected layer.

    :param model: called as `model(input_ids, output_hidden_states=True)`; its output is
        indexed with "hidden_states" and `model.device` decides where the ids are moved.
    :param tokenizer: called as `tokenizer(text, truncation=True, return_tensors="pt")`.
    :param input_text: the text to encode.
    :param layer: integer index into the hidden-states tuple (0 is a valid index). `None` or
        `False` selects every entry of the tuple instead.
    :return: numpy array. For a single layer this is the vector at the last token position of
        the first batch item (shape (hidden_dim,), assuming hidden states are laid out as
        (batch, seq, hidden_dim)). For all layers the tuple is concatenated along the first
        axis and the vector at the first token position (CLS) is returned for each row.
    """
    # tokenize
    encoder_text_ids = tokenizer(input_text, truncation=True, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(encoder_text_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]

    # Compare explicitly: a plain truthiness test would turn layer 0 into "all layers".
    if layer is None or layer is False:
        # NOTE(review): this branch reads the first token (CLS) while the single-layer branch
        # reads the last token -- confirm the asymmetry is intended.
        hs = torch.cat(hs_tuple, dim=0)[:, 0, :]
    else:
        hs = hs_tuple[layer][0, -1]

    return hs.detach().cpu().numpy()

def get_encoder_hidden_states_tokens(model, tokenizer, input_text_list):
    """
    Run a collection of texts through an encoder model in chunks and return the last entry of
    the hidden-states tuple for every token position.

    All texts are padded to the longest tokenized length in the collection, so the chunk
    outputs can be concatenated along the first axis.

    :param model: called as `model(**batch, output_hidden_states=True)`; its output is indexed
        with "hidden_states".
    :param tokenizer: provides `encode(text)` and is callable on a list of texts.
    :param input_text_list: numpy array or any iterable of strings.
    :return: tensor with one row per input text (concatenation of the per-chunk outputs).
    :raises ValueError: if `input_text_list` is empty.
    """
    if hasattr(input_text_list, "tolist"):
        texts = input_text_list.tolist()
    else:
        texts = list(input_text_list)
    if not texts:
        raise ValueError("input_text_list must contain at least one text")

    max_len = max(len(tokenizer.encode(text)) for text in texts)
    n = len(texts)
    # n // SPLIT is 0 whenever n < SPLIT, which would never advance through the list.
    batch_size = max(1, n // SPLIT)

    all_hs = []
    for start in range(0, n, batch_size):
        batch = tokenizer(
            texts[start:start + batch_size],
            padding="max_length",
            max_length=max_len,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)
        all_hs.append(outputs["hidden_states"][-1])

    return torch.cat(all_hs, dim=0)

def get_decoder_hidden_states(model, tokenizer, input_text, layer=None):
    """
    Run `input_text` (with the tokenizer's EOS token appended) through a decoder model and
    extract the hidden state at the last token position.

    :param model: called as `model(input_ids, output_hidden_states=True)`; its output is
        indexed with "hidden_states" and `model.device` decides where the ids are moved.
    :param tokenizer: provides `eos_token` and is callable as `tokenizer(text, return_tensors="pt")`.
    :param input_text: the text to run through the model.
    :param layer: integer index into the hidden-states tuple (0 is a valid index). `None` or
        `False` selects every entry of the tuple instead.
    :return: numpy array. For a single layer this is the last-token vector of the first batch
        item (shape (hidden_dim,), assuming hidden states are laid out as
        (batch, seq, hidden_dim)). For all layers the tuple is concatenated along the first
        axis and the last-token vector is returned for each row.
    """
    # tokenize (adding the EOS token this time)
    input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]

    # Compare explicitly: a plain truthiness test would turn layer 0 into "all layers".
    if layer is None or layer is False:
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]
    else:
        hs = hs_tuple[layer][0, -1]

    return hs.detach().cpu().numpy()

def get_decoder_hidden_states_tokens(model, tokenizer, input_text_list):
    """
    Run a collection of texts (each with the tokenizer's EOS token appended) through a decoder
    model in chunks and return the last entry of the hidden-states tuple for every token
    position.

    All texts are padded to the longest tokenized length in the collection, so the chunk
    outputs can be concatenated along the first axis.

    :param model: called as `model(**batch, output_hidden_states=True)`; its output is indexed
        with "hidden_states".
    :param tokenizer: provides `eos_token`, `encode(text)` and is callable on a list of texts.
    :param input_text_list: iterable of strings.
    :return: tensor with one row per input text (concatenation of the per-chunk outputs).
    :raises ValueError: if `input_text_list` is empty.
    """
    texts = [text + tokenizer.eos_token for text in input_text_list]
    if not texts:
        raise ValueError("input_text_list must contain at least one text")

    max_len = max(len(tokenizer.encode(text)) for text in texts)
    n = len(texts)
    # n // SPLIT is 0 whenever n < SPLIT, which would never advance through the list.
    batch_size = max(1, n // SPLIT)

    all_hs = []
    for start in range(0, n, batch_size):
        batch = tokenizer(
            texts[start:start + batch_size],
            padding="max_length",
            max_length=max_len,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)
        all_hs.append(outputs["hidden_states"][-1])

    return torch.cat(all_hs, dim=0)


def get_encoder_decoder_hidden_states(model, tokenizer, input_text, layer=None):
    """
    Feed `input_text` to the encoder of an encoder-decoder model, feed the tokenized pad token
    to the decoder, and extract the decoder hidden state at the last decoder position.

    :param model: called as `model(encoder_ids, decoder_input_ids=..., output_hidden_states=True)`;
        its output is indexed with "decoder_hidden_states".
    :param tokenizer: provides `pad_token` and is callable as `tokenizer(text, return_tensors="pt")`.
    :param input_text: the text given to the encoder.
    :param layer: integer index into the decoder hidden-states tuple (0 is a valid index).
        `None` or `False` selects every entry of the tuple instead.
    :return: numpy array. For a single layer this is the last-position vector of the first
        batch item (shape (hidden_dim,), assuming hidden states are laid out as
        (batch, seq, hidden_dim)). For all layers the tuple is concatenated along the first
        axis and the last-position vector is returned for each row.
    """
    # tokenize
    encoder_text_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    decoder_text_ids = tokenizer(tokenizer.pad_token, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    with torch.no_grad():
        output = model(encoder_text_ids, decoder_input_ids=decoder_text_ids, output_hidden_states=True)

    hs_tuple = output["decoder_hidden_states"]

    # Compare explicitly: a plain truthiness test would turn layer 0 into "all layers".
    if layer is None or layer is False:
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]
    else:
        hs = hs_tuple[layer][0, -1]

    return hs.detach().cpu().numpy()

def get_encoder_decoder_hidden_states_tokens(model, tokenizer, input_text_list):
    """
    Run a collection of texts through an encoder-decoder model in chunks and return, per text,
    the last encoder hidden states followed by the last decoder hidden states along the
    sequence axis.

    Encoder inputs are padded to the longest tokenized length in the collection; the decoder
    receives the tokenization of an empty string for every text.

    :param model: called as `model(**batch, decoder_input_ids=..., output_hidden_states=True)`;
        its output is indexed with "encoder_hidden_states" and "decoder_hidden_states".
    :param tokenizer: provides `encode(text)` and is callable on a list of texts.
    :param input_text_list: numpy array or any iterable of strings.
    :return: tensor with one row per input text (concatenation of the per-chunk outputs).
    :raises ValueError: if `input_text_list` is empty.
    """
    if hasattr(input_text_list, "tolist"):
        texts = input_text_list.tolist()
    else:
        texts = list(input_text_list)
    if not texts:
        raise ValueError("input_text_list must contain at least one text")

    max_len = max(len(tokenizer.encode(text)) for text in texts)
    n = len(texts)
    # n // SPLIT is 0 whenever n < SPLIT, which would never advance through the list.
    batch_size = max(1, n // SPLIT)

    all_hs = []
    for start in range(0, n, batch_size):
        batch_texts = texts[start:start + batch_size]
        batch = tokenizer(
            batch_texts,
            padding="max_length",
            max_length=max_len,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        # NOTE(review): get_encoder_decoder_hidden_states seeds the decoder with
        # tokenizer.pad_token, whereas this function uses an empty string -- confirm which
        # decoder input is intended.
        decoder_input = tokenizer(["" for _ in batch_texts], return_tensors="pt")["input_ids"].to(model.device)
        with torch.no_grad():
            outputs = model(**batch, decoder_input_ids=decoder_input, output_hidden_states=True)
        hs_enc = outputs["encoder_hidden_states"][-1]
        hs_dec = outputs["decoder_hidden_states"][-1]
        # Encoder positions first, decoder positions appended on the sequence axis.
        all_hs.append(torch.cat((hs_enc, hs_dec), dim=1))

    return torch.cat(all_hs, dim=0)


def get_hidden_states_many_examples(model, model_type, tokenizer, neg_prompts, pos_prompts, layer):
    """
    Extract hidden states for two parallel collections of prompts.

    :param model: model to evaluate; put into eval mode here.
    :param model_type: one of "encoder", "decoder", "encoder-decoder".
    :param tokenizer: tokenizer matching `model`.
    :param neg_prompts: collection of prompt strings.
    :param pos_prompts: collection of prompt strings, parallel to `neg_prompts`.
    :param layer: selects the extraction mode.
        - falsy: per-token mode; the whole prompt collection is handed to the matching
          `*_hidden_states_tokens` function.
        - `True`: per-layer mode over every layer; each prompt is passed individually to the
          matching single-text function with `layer=None` and the results are stacked
          (N x L x D). NOTE(review): `True` is read as an "across layers" flag, matching how
          callers in this file pass it -- confirm.
        - any other truthy value: forwarded as the layer index for each prompt (N x D).
    :return: tuple `(neg_hs, pos_hs)` of tensors.
    :raises ValueError: if `model_type` is not recognised.
    """
    # setup
    model.eval()

    per_text_fns = {
        "encoder": get_encoder_hidden_states,
        "decoder": get_decoder_hidden_states,
        "encoder-decoder": get_encoder_decoder_hidden_states,
    }
    per_token_fns = {
        "encoder": get_encoder_hidden_states_tokens,
        "decoder": get_decoder_hidden_states_tokens,
        "encoder-decoder": get_encoder_decoder_hidden_states_tokens,
    }
    if model_type not in per_text_fns:
        raise ValueError(f"Invalid model type: {model_type!r}")

    if layer:
        extract_one = per_text_fns[model_type]
        selected_layer = None if layer is True else layer

        def extract(prompts):
            # The single-text functions take one string plus a layer argument, so they must
            # be applied prompt by prompt; their numpy outputs are stacked into one tensor.
            per_prompt = [extract_one(model, tokenizer, prompt, selected_layer) for prompt in prompts]
            return torch.from_numpy(np.stack(per_prompt))
    else:
        extract_batch = per_token_fns[model_type]

        def extract(prompts):
            return extract_batch(model, tokenizer, prompts)

    neg_hs = extract(neg_prompts)
    pos_hs = extract(pos_prompts)

    return neg_hs, pos_hs

In [20]:
def format_profession(prompt, text, label):
    """
    Fill a prompt template that contains a <LABEL0/LABEL1> tag and a <TEXT> tag.

    The label tag is resolved on the raw template, before `text` is inserted, so angle brackets
    inside `text` are never mistaken for a tag. Both substitutions are literal: backslashes in
    `text` or in the label are not interpreted as regex replacement escapes.

    :param prompt: template string.
    :param text: string inserted in place of every <TEXT> tag.
    :param label: index of the option to pick from the first non-TEXT tag (options are
        separated by "/"). Every non-TEXT tag is replaced with that option.
    :return: the formatted prompt.
    :raises ValueError: if the template has no label tag.
    :raises IndexError: if `label` is out of range for the tag's options.
    """
    label_tags = [tag for tag in re.findall(r'<(.*?)>', prompt) if tag != "TEXT"]
    if not label_tags:
        raise ValueError(f"Prompt has no <LABEL0/LABEL1> tag: {prompt!r}")
    chosen = label_tags[0].split("/")[label]

    # A callable replacement is inserted verbatim (no backslash / group-reference processing).
    output = re.sub(r'<(?!TEXT>)(.*?)>', lambda _match: chosen, prompt)

    # str.replace is literal as well, unlike re.sub with `text` as the replacement template.
    return output.replace("<TEXT>", text)


def parse_professions(professions_path, prompt, undersample=False):
    """
    Reads professions.json and loads professions that have nonzero stereotypical male/female
    bias scores. Parses professions into given prompt. Creates parallel label array as well.

    Each JSON entry is indexed as entry[0] = profession name (underscores become spaces) and
    entry[2] = bias score; entry[1] is not used here. Negative scores are treated as female,
    positive scores as male.

    :param professions_path: path to the JSON file.
    :param prompt: template passed to format_profession for every selected profession.
    :param undersample: Whether or not to take the top k most biased male professions to match
    the k female professions. Only professions with a positive score are eligible, so fewer
    than k are returned when fewer exist.
    :return: (neg_prompts, pos_prompts, labels) -- prompts formatted with label option 0 and 1
    respectively, and a list holding 0 for each female profession followed by 1 for each male one.
    """
    with open(professions_path, "r") as f:
        professions = json.load(f)

    profs = np.array([prof[0].replace("_", " ") for prof in professions])
    biases = np.array([prof[2] for prof in professions])

    fem_idx = np.where(biases < 0)[0]
    male_idx = np.where(biases > 0)[0]
    if undersample:
        # Rank everything by descending bias, then keep only male (positive) entries before
        # truncating. Truncating first would pull in neutral or female professions whenever
        # there are more female than male entries.
        by_bias_desc = np.flip(np.argsort(biases))
        male_idx = by_bias_desc[biases[by_bias_desc] > 0][:fem_idx.shape[0]]
    idx = np.concatenate([fem_idx, male_idx])
    prof_sample = profs[idx]
    labels = [0] * len(fem_idx) + [1] * len(male_idx)

    # Parse into prompt
    neg_prompts = [format_profession(prompt, prof, 0) for prof in prof_sample]
    pos_prompts = [format_profession(prompt, prof, 1) for prof in prof_sample]

    return neg_prompts, pos_prompts, labels


def parse_census_data(census_path, prompts, by="Women"):
    """
    Load the census occupation table and select occupations that are strongly skewed along
    one demographic axis, together with a binary label per occupation.

    Rows containing the '–' placeholder or any missing value are dropped, and the
    CENSUS_GROUPS columns are converted to numbers.

    :param census_path: path to the census CSV; must contain an "Occupation" column and the
        CENSUS_GROUPS columns.
    :param prompts: not used by this function.
    :param by: "Women" or "Black or African American".
        - "Women": keeps rows whose value is more than 0.1 away from 0.5; label is 1 when the
          value is below 0.5. NOTE(review): these thresholds treat the column as a fraction --
          confirm the CSV is not expressed in percent.
        - "Black or African American": z-scores the "White" and "Black or African American"
          columns, keeps rows where the difference (white minus black) has magnitude >= 1;
          label is 1 when that difference is positive.
    :return: (occupations, labels) as numpy arrays, sorted by the selection score in
        descending order; labels have integer dtype.
    :raises ValueError: if `by` is not one of the supported groups.
    """
    census_df = pd.read_csv(census_path)

    # Drop NaNs
    # '–' is a non-ASCII dash the data uses for missing values.
    census_df = census_df.replace('–', pd.NA).dropna().copy()

    census_df[CENSUS_GROUPS] = census_df[CENSUS_GROUPS].apply(pd.to_numeric, errors='coerce')

    if by == "Black or African American":
        white = census_df["White"]
        black = census_df["Black or African American"]
        census_df["whitenorm"] = (white - white.mean()) / white.std()
        census_df["blacknorm"] = (black - black.mean()) / black.std()
        census_df["racediff"] = census_df["whitenorm"] - census_df["blacknorm"]

        sorted_df = census_df.sort_values(by="racediff", ascending=False)
        # .copy() so the label column is added to an independent frame, not a slice.
        sorted_df = sorted_df[sorted_df["racediff"].abs() >= 1].copy()
        # Every remaining row has |racediff| >= 1, so the sign decides the label. Comparing
        # against 1 would mislabel a row whose difference is exactly 1.
        sorted_df["label"] = sorted_df["racediff"] > 0

    elif by == "Women":
        # Get most and least female dominated jobs
        sorted_df = census_df.sort_values(by=by, ascending=False)
        sorted_df = sorted_df[(sorted_df[by] - 0.5).abs() > 0.1].copy()
        sorted_df["label"] = sorted_df[by] < 0.5

    else:
        raise ValueError(f"Give a valid group to sort by, got {by!r}")

    return np.array(sorted_df["Occupation"].tolist()), np.array(sorted_df["label"].tolist(), dtype=int)


def parse_crowspairs(crowspairs_path, filter=None):
    """
    Load the CrowS-Pairs CSV and turn its "stereo" rows into randomly oriented prompt pairs.

    :param crowspairs_path: path to the CSV; the columns "stereo_antistereo", "bias_type",
        "sent_more" and "sent_less" are read.
    :param filter: optional collection of bias types to keep; a falsy value keeps all of them.
    :return: (neg_prompts, pos_prompts, labels). For each pair a random label in {0, 1} is
        drawn from numpy's global random state: with label 1 the "sent_more" sentence is the
        positive prompt, with label 0 it is the negative prompt.
    """
    pairs = pd.read_csv(crowspairs_path)
    stereo_rows = pairs["stereo_antistereo"] == "stereo"
    pairs = pairs[stereo_rows]

    # Optionally restrict to the requested bias types
    if filter:
        wanted = pairs["bias_type"].isin(filter)
        pairs = pairs[wanted]

    sent_more = pairs["sent_more"].to_numpy()
    sent_less = pairs["sent_less"].to_numpy()

    # A coin flip per pair decides which side "sent_more" lands on, which keeps the label
    # distribution roughly even.
    labels = np.random.randint(2, size=len(sent_more))
    more_is_positive = labels == 1
    pos_prompts = np.where(more_is_positive, sent_more, sent_less)
    neg_prompts = np.where(more_is_positive, sent_less, sent_more)

    return neg_prompts, pos_prompts, labels
    

def parse_prompts():
    """
    Read the prompt templates stored at PROMPTS_PATH, one per line.

    :return: list of lines with newline characters stripped from both ends; blank lines are
        kept as empty strings.
    """
    prompts = []
    with open(PROMPTS_PATH, "r") as handle:
        for line in handle:
            prompts.append(line.strip("\n"))
    return prompts

In [21]:
def save_hidden_states(model_name, model_type, trial_name, neg_prompts, pos_prompts, y, layer, verbose=False):
    """
    Load a pretrained model and tokenizer, extract hidden states for two parallel prompt
    collections and save them, with the labels, under SAVE_DIR/trial_name.

    Files written: "neg-hs.pt", "pos-hs.pt" and "y.pt". In per-token mode (`layer` falsy),
    "neg-lens.pt" and "pos-lens.pt" are written too; they hold the tokenized length of every
    prompt (plus one for decoder models, which get an EOS token appended).

    :param model_name: name or path passed to `from_pretrained`.
    :param model_type: "encoder", "decoder" or "encoder-decoder"; selects the model class.
    :param trial_name: sub-directory of SAVE_DIR to write into (created if missing).
    :param neg_prompts: collection of prompt strings.
    :param pos_prompts: collection of prompt strings, parallel to `neg_prompts`.
    :param y: labels, parallel to the prompts; saved as a long tensor.
    :param layer: forwarded to get_hidden_states_many_examples; falsy means per-token mode.
    :param verbose: print the shapes of the extracted tensors.
    :raises ValueError: if `model_type` is not recognised.
    """
    model_classes = {
        "encoder": AutoModelForMaskedLM,
        "decoder": AutoModelForCausalLM,
        "encoder-decoder": AutoModelForSeq2SeqLM,
    }
    # Validate before loading anything, instead of failing later on an unbound `model`.
    if model_type not in model_classes:
        raise ValueError(f"Invalid model type: {model_type!r}")

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
    model = model_classes[model_type].from_pretrained(model_name, cache_dir=CACHE_DIR)
    if model_type == "decoder":
        # Reuse EOS as the padding token so padded batches can be tokenized.
        tokenizer.pad_token = tokenizer.eos_token
    model = model.to(DEVICE)

    # Get hidden states
    neg_hs, pos_hs = get_hidden_states_many_examples(model, model_type, tokenizer, neg_prompts, pos_prompts, layer)
    y = torch.tensor(y).long()

    if verbose:
        print(neg_hs.shape, pos_hs.shape, y.shape)

    # Save hidden states
    root = os.path.join(SAVE_DIR, trial_name)
    if not os.path.exists(root):
        print(f"Creating directory {root}")
    os.makedirs(root, exist_ok=True)

    torch.save(neg_hs, os.path.join(root, "neg-hs.pt"))
    torch.save(pos_hs, os.path.join(root, "pos-hs.pt"))
    torch.save(y, os.path.join(root, "y.pt"))

    if not layer:
        # Always rewrite the lengths with the hidden states. parse_crowspairs assigns prompts
        # to the neg/pos side at random, so lengths left over from an earlier run would not
        # match the tensors saved above.
        neg_lens = torch.tensor([len(tokenizer.encode(text)) for text in neg_prompts])
        pos_lens = torch.tensor([len(tokenizer.encode(text)) for text in pos_prompts])
        # Because we added an EOS token
        if model_type == "decoder":
            neg_lens += 1
            pos_lens += 1
        torch.save(neg_lens, os.path.join(root, "neg-lens.pt"))
        torch.save(pos_lens, os.path.join(root, "pos-lens.pt"))

In [22]:
def save_professions_trials(trials):
    """
    For every trial and every prompt template, build the profession prompts and save their
    hidden states under saved/professions/<trial_name>_prompt<i>.

    :param trials: iterable of dicts with the keys "trial_name", "model_name" and "model_type".
    """
    # Prompt templates are shared by all trials
    templates = parse_prompts()

    # Pass data through hidden states
    for trial in trials:
        for index, template in enumerate(templates):
            model_type = trial['model_type']
            model_name = trial['model_name']
            print(f"Creating hs for {model_type} model {model_name} with prompt {template}")

            # Create prompts from professions
            neg_prompts, pos_prompts, y = parse_professions(PROFESSIONS_PATH, template, undersample=False)

            call_kwargs = {
                "model_name": model_name,
                "model_type": model_type,
                "trial_name": f"professions/{trial['trial_name']}_prompt{index}",
                "neg_prompts": neg_prompts,
                "pos_prompts": pos_prompts,
                "y": y,
                "layer": True,
                "verbose": True,
            }
            save_hidden_states(**call_kwargs)
            

def save_crowspairs_trials(trials, layer, filter=None, force=False):
    """
    Save CrowS-Pairs hidden states for every trial.

    Output goes to SAVE_DIR/crowspairs/... when `layer` is truthy and to
    SAVE_DIR/crowspairs-token/... otherwise. When a filter is active, its sorted bias types
    joined by "_" form an extra path component before the trial name.

    :param trials: iterable of dicts with the keys "trial_name", "model_name" and "model_type".
    :param layer: forwarded to save_hidden_states; also selects the output prefix.
    :param filter: collection of bias types to keep. `None`, an empty collection or the
        single-element ["all"] disables filtering.
    :param force: recompute even when the trial's output directory already exists.
    """
    # Normalise the filter once. Doing this inside the loop indexed into `None` on the second
    # trial (after ["all"] had been reset) and whenever the default was used.
    if filter is not None and len(filter) == 1 and filter[0] == "all":
        filter = None

    # For save path
    prefix = "crowspairs/" if layer else "crowspairs-token/"
    if filter:
        filter = sorted(filter)
        prefix += "_".join(filter) + "/"

    # Pass data through hidden states
    for trial in trials:
        full_trial_name = f"{prefix}{trial['trial_name']}"

        # If we already have something saved here, skip it
        if os.path.exists(os.path.join(SAVE_DIR, full_trial_name)) and not force:
            print(f"Already exists hs for {trial['model_type']} model {trial['model_name']} with crowspairs {filter} ")
            continue

        print(f"Creating hs for {trial['model_type']} model {trial['model_name']} with crowspairs {filter} across {'layers' if layer else 'tokens'}")
        # Create prompts from the CrowS-Pairs sentence pairs
        neg_prompts, pos_prompts, y = parse_crowspairs(CROWSPAIRS_PATH, filter=filter)
        save_hidden_states(
            model_name=trial["model_name"],
            model_type=trial["model_type"],
            trial_name=full_trial_name,
            neg_prompts=neg_prompts,
            pos_prompts=pos_prompts,
            y=y,
            layer=layer,
            verbose=True)

In [23]:
# Each trial is a dict with the keys "trial_name", "model_name" and "model_type".

# GPT-2 family (decoder-only); trial name and model name coincide.
gpt2_trials = [
    {"trial_name": name, "model_name": name, "model_type": "decoder"}
    for name in ("gpt2", "gpt2-large", "gpt2-xl", "gpt2-medium")
]

# RoBERTa family (encoder-only); trial name and model name coincide.
roberta_trials = [
    {"trial_name": name, "model_name": name, "model_type": "encoder"}
    for name in ("roberta-base", "roberta-large")
]

# Flan-T5 family (encoder-decoder); the model name carries the "google/" prefix.
flan_t5_trials = [
    {"trial_name": name, "model_name": "google/" + name, "model_type": "encoder-decoder"}
    for name in ("flan-t5-small", "flan-t5-base", "flan-t5-large")
]

In [24]:
# Extract per-token hidden states separately for every bias type present in CrowS-Pairs.
df = pd.read_csv(CROWSPAIRS_PATH)
filters = df["bias_type"].unique().tolist()
print(f"bias types: {filters}")
for filter in filters:
    # The RoBERTa and GPT-2 runs are switched off; uncomment to regenerate them.
    #save_crowspairs_trials(roberta_trials, False, [filter])
    #save_crowspairs_trials(gpt2_trials, False, [filter])
    # force=True recomputes even if the trial directory already exists.
    save_crowspairs_trials(flan_t5_trials, layer=False, filter=[filter], force=True)

bias types: ['race-color', 'socioeconomic', 'gender', 'disability', 'nationality', 'sexual-orientation', 'physical-appearance', 'religion', 'age']
Creating hs for encoder-decoder model google/flan-t5-small with crowspairs ['race-color'] across tokens
torch.Size([473, 53, 512]) torch.Size([473, 53, 512]) torch.Size([473])
Creating hs for encoder-decoder model google/flan-t5-base with crowspairs ['race-color'] across tokens
torch.Size([473, 53, 768]) torch.Size([473, 53, 768]) torch.Size([473])
Creating hs for encoder-decoder model google/flan-t5-large with crowspairs ['race-color'] across tokens
torch.Size([473, 53, 1024]) torch.Size([473, 53, 1024]) torch.Size([473])
Creating hs for encoder-decoder model google/flan-t5-small with crowspairs ['socioeconomic'] across tokens
torch.Size([157, 56, 512]) torch.Size([157, 56, 512]) torch.Size([157])
Creating hs for encoder-decoder model google/flan-t5-base with crowspairs ['socioeconomic'] across tokens
torch.Size([157, 56, 768]) torch.Size([