In [None]:
import torch
import torch.nn as nn
import numpy as np
import json
from tqdm import tqdm
import copy
import os
import re

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

In [None]:
# Directory that receives the extracted hidden states (one sub-folder per trial).
# makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
SAVE_DIR = os.path.join(os.getcwd(), "saved/")
os.makedirs(SAVE_DIR, exist_ok=True)

# Directory passed as `cache_dir` when loading pretrained tokenizers/models.
CACHE_DIR = os.path.join(os.getcwd(), "cache_dir")

# Prefer Apple MPS (the original hard-coded choice) when it is actually usable,
# otherwise fall back to CUDA and finally CPU so the notebook runs anywhere.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

### Extract Hidden States

In [None]:
def get_encoder_hidden_states(model, tokenizer, input_text, layer):
    """
    Run an encoder model on `input_text` and return hidden states as numpy.

    Args:
        model: model called as `model(input_ids, output_hidden_states=True)`;
            its output must expose a "hidden_states" tuple and the model must
            have a `.device` attribute.
        tokenizer: tokenizer called as
            `tokenizer(text, truncation=True, return_tensors="pt")`.
        input_text: the full text given to the encoder.
        layer: integer index into the "hidden_states" tuple (negative indices
            allowed), or None to collect every entry of the tuple.

    Returns:
        A numpy array. With an integer `layer`: the hidden state of the last
        token in that layer, shape (hidden_dim,). With `layer=None`: the hidden
        state of the first token (the CLS position) from every entry of the
        tuple, shape (len(hidden_states), hidden_dim).
    """
    # tokenize
    input_ids = tokenizer(input_text, truncation=True, return_tensors="pt").input_ids.to(model.device)

    # forward pass (inference only, no gradients needed)
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]

    # Compare against None rather than relying on truthiness, so that
    # layer=0 is honoured as a real index instead of meaning "all layers".
    if layer is not None:
        # NOTE(review): this branch reads the LAST token while the all-layers
        # branch below reads the FIRST token -- confirm that is intentional.
        hs = hs_tuple[layer][0, -1]
    else:
        # Each tuple entry has a leading batch dimension of 1, so concatenating
        # along dim 0 yields one row per entry; then take the first token.
        hs = torch.cat(hs_tuple, dim=0)[:, 0, :]

    return hs.detach().cpu().numpy()

def get_decoder_hidden_states(model, tokenizer, input_text, layer):
    """
    Run a decoder model on `input_text` (with the EOS token appended) and
    return last-token hidden states as numpy.

    Args:
        model: model called as `model(input_ids, output_hidden_states=True)`;
            its output must expose a "hidden_states" tuple and the model must
            have a `.device` attribute.
        tokenizer: tokenizer called as `tokenizer(text, return_tensors="pt")`;
            its `eos_token` string is appended to the text before tokenizing.
        input_text: the text to feed to the model.
        layer: integer index into the "hidden_states" tuple (negative indices
            allowed), or None to collect every entry of the tuple.

    Returns:
        A numpy array holding the hidden state of the last token: shape
        (hidden_dim,) for an integer `layer`, or
        (len(hidden_states), hidden_dim) when `layer` is None.
    """
    # tokenize (adding the EOS token this time)
    input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    # forward pass (inference only, no gradients needed)
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]

    # Compare against None rather than relying on truthiness, so that
    # layer=0 is honoured as a real index instead of meaning "all layers".
    if layer is not None:
        hs = hs_tuple[layer][0, -1]
    else:
        # Each tuple entry has a leading batch dimension of 1, so concatenating
        # along dim 0 yields one row per entry; still only keep the last token.
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]

    return hs.detach().cpu().numpy()

def get_encoder_decoder_hidden_states(model, tokenizer, input_text, layer):
    """
    Run an encoder-decoder model on `input_text` and return the encoder's
    last-token hidden states as numpy.

    The full text is given to the encoder; the decoder only receives the
    tokenization of the empty string so that the forward pass can run.

    Args:
        model: model called as
            `model(input_ids, decoder_input_ids=..., output_hidden_states=True)`;
            its output must expose an "encoder_hidden_states" tuple and the
            model must have a `.device` attribute.
        tokenizer: tokenizer called as `tokenizer(text, return_tensors="pt")`.
        input_text: the text to feed to the encoder.
        layer: integer index into the "encoder_hidden_states" tuple (negative
            indices allowed), or None to collect every entry of the tuple.

    Returns:
        A numpy array holding the hidden state of the last encoder token:
        shape (hidden_dim,) for an integer `layer`, or
        (len(encoder_hidden_states), hidden_dim) when `layer` is None.
    """
    # tokenize
    encoder_text_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    decoder_text_ids = tokenizer("", return_tensors="pt").input_ids.to(model.device)

    # forward pass (inference only, no gradients needed)
    with torch.no_grad():
        output = model(encoder_text_ids, decoder_input_ids=decoder_text_ids, output_hidden_states=True)

    hs_tuple = output["encoder_hidden_states"]

    if layer is not None:
        hs = hs_tuple[layer][0, -1]
    else:
        # layer=None means "all layers". Indexing the tuple with None would
        # raise a TypeError, so stack every entry (each has a leading batch
        # dimension of 1) and keep only the last token.
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]

    return hs.detach().cpu().numpy()

def format_profession(prompt, text, label):
    """
    Fill a prompt template with a text and one of its candidate labels.

    Prompts contain a <LABEL0/LABEL1> tag and a <TEXT> tag. Every label tag is
    replaced by the `label`-th "/"-separated option of the first label tag, and
    every <TEXT> tag is replaced by `text`, inserted literally.

    Args:
        prompt: template string, e.g. "The <TEXT> is a <woman/man>."
        text: string substituted for <TEXT>.
        label: index of the option to pick from the label tag.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: if `prompt` contains no label tag.
        IndexError: if `label` is out of range for the tag's options.
    """
    # Any <...> tag other than <TEXT> is treated as a label tag.
    label_tag_pattern = r'<(?!TEXT>)(.*?)>'

    # Look for the label tag in the raw prompt, before `text` is inserted, so
    # angle brackets inside `text` can never be mistaken for a tag.
    first_tag = re.search(label_tag_pattern, prompt)
    if first_tag is None:
        raise ValueError(f"Prompt contains no <LABEL0/LABEL1> tag: {prompt!r}")
    chosen = first_tag.group(1).split("/")[label]

    # A callable replacement and str.replace both insert their value verbatim;
    # a plain replacement string would have its backslash escapes interpreted.
    output = re.sub(label_tag_pattern, lambda _match: chosen, prompt)
    return output.replace("<TEXT>", text)

def get_hidden_states_many_examples(model, model_type, tokenizer, prompt, data, layer):
    """
    Compute contrast-pair hidden states for every text in `data`.

    For each text the prompt is formatted twice, once with label option 0 and
    once with label option 1, and each version is passed through the model.
    This is deliberately simple so that it's easy to understand, rather than
    being optimized for efficiency.

    Args:
        model: the model to run; it is switched to eval mode.
        model_type: one of "encoder", "decoder" or "encoder-decoder"; selects
            which hidden-state extraction function is used.
        tokenizer: tokenizer matching `model`.
        prompt: template containing a <TEXT> tag and a <LABEL0/LABEL1> tag.
        data: iterable of texts to substitute for <TEXT>.
        layer: forwarded unchanged to the extraction function.

    Returns:
        A tuple `(all_neg_hs, all_pos_hs)` of numpy arrays, one row per text
        in `data`; the trailing dimensions are whatever the per-example
        extraction function returns. No labels are returned.

    Raises:
        ValueError: if `model_type` is not one of the supported values.
    """
    extractors = {
        "encoder": get_encoder_hidden_states,
        "decoder": get_decoder_hidden_states,
        "encoder-decoder": get_encoder_decoder_hidden_states,
    }
    # Validate up front with a real exception: `assert` is stripped under
    # `python -O`, and failing before the loop avoids wasted work.
    if model_type not in extractors:
        raise ValueError(
            f"Invalid model type: {model_type!r} (expected one of {sorted(extractors)})"
        )
    extract = extractors[model_type]

    # setup
    model.eval()
    all_neg_hs, all_pos_hs = [], []

    # loop
    for text in tqdm(data):
        # label option 0 -> "negative" prompt, label option 1 -> "positive" prompt
        neg_hs = extract(model, tokenizer, format_profession(prompt, text, 0), layer=layer)
        pos_hs = extract(model, tokenizer, format_profession(prompt, text, 1), layer=layer)
        all_neg_hs.append(neg_hs)
        all_pos_hs.append(pos_hs)

    # Stack into single arrays
    return np.stack(all_neg_hs), np.stack(all_pos_hs)

def parse_professions(professions_path, undersample=False):
    """
    Load professions from a JSON file and label them by the sign of their bias.

    The file must contain a list of records where `record[0]` is the
    profession name (underscores are replaced by spaces) and `record[2]` is a
    numeric bias score; `record[1]` is not used. Records with a negative bias
    get label 0, records with a positive bias get label 1, and records with a
    bias of exactly zero are dropped.

    Args:
        professions_path: path to the JSON file.
        undersample: if True, keep only the most strongly positive-bias
            records, at most as many as there are negative-bias records.

    Returns:
        A tuple `(profs, labels)` of numpy arrays: the selected profession
        names (negative-bias records first) and their 0/1 labels.
    """
    with open(professions_path, "r", encoding="utf-8") as f:
        professions = json.load(f)

    profs = np.array([prof[0].replace("_", " ") for prof in professions])
    biases = np.array([prof[2] for prof in professions])

    fem_idx = np.where(biases < 0)[0]
    male_idx = np.where(biases > 0)[0]
    if undersample:
        # Rank all records by descending bias, then keep only those that are
        # actually positive before truncating. Without the filter, a dataset
        # with fewer positive than negative records would pull neutral or
        # negative records into the label-1 group.
        by_bias_desc = np.flip(np.argsort(biases))
        male_idx = by_bias_desc[biases[by_bias_desc] > 0][:fem_idx.shape[0]]
    idx = np.concatenate([fem_idx, male_idx])
    labels = np.array([0] * len(fem_idx) + [1] * len(male_idx))

    return profs[idx], labels

In [None]:
def save_hidden_states(model_name, model_type, prompt, trial_name, professions, y, verbose=False):
    """
    Load a pretrained model, extract contrast-pair hidden states for every
    profession, and save them under `SAVE_DIR/trial_name`.

    Three files are written: "fem-hs.npy" (label option 0 prompts),
    "male-hs.npy" (label option 1 prompts) and "y.npy" (the labels `y`).

    Args:
        model_name: model identifier passed to `from_pretrained`.
        model_type: "encoder", "decoder" or "encoder-decoder"; selects the
            auto-model class used to load the model.
        prompt: prompt template forwarded to the hidden-state extraction.
        trial_name: name of the sub-directory of SAVE_DIR to write into.
        professions: iterable of profession strings.
        y: numpy array of labels, saved unchanged.
        verbose: if True, print progress information and array shapes.

    Raises:
        ValueError: if `model_type` is not one of the supported values.
    """
    model_classes = {
        "encoder": AutoModelForMaskedLM,
        "decoder": AutoModelForCausalLM,
        "encoder-decoder": AutoModelForSeq2SeqLM,
    }
    # Fail early and clearly; otherwise an unknown type would leave `model`
    # unbound and surface later as a confusing UnboundLocalError.
    if model_type not in model_classes:
        raise ValueError(
            f"Invalid model type: {model_type!r} (expected one of {sorted(model_classes)})"
        )

    if verbose:
        print(f"Creating hidden states for the {model_type} model {model_name} using prompt {prompt}")

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
    model = model_classes[model_type].from_pretrained(model_name, cache_dir=CACHE_DIR)
    model = model.to(DEVICE)

    # Get hidden states (layer=None requests every layer)
    all_neg, all_pos = get_hidden_states_many_examples(model, model_type, tokenizer, prompt, professions, layer=None)

    if verbose:
        print(all_neg.shape, all_pos.shape, y.shape)

    # Save hidden states. makedirs also creates missing parents and does not
    # fail when the directory already exists.
    root = os.path.join(SAVE_DIR, trial_name)
    os.makedirs(root, exist_ok=True)
    np.save(os.path.join(root, "fem-hs.npy"), all_neg)
    np.save(os.path.join(root, "male-hs.npy"), all_pos)
    np.save(os.path.join(root, "y.npy"), y)

In [None]:
# Read in prompts: one template per line. Blank lines are skipped so that a
# stray empty line in prompts.txt does not become an empty prompt, which would
# fail later when the label tag is looked up.
with open(os.path.join(os.getcwd(), "prompts.txt"), "r") as f:
    prompts = [line.strip("\n") for line in f if line.strip()]

# One trial per prompt; "prompt" stores the index into `prompts`.
# NOTE(review): the trial names contain "undersample" but the professions are
# parsed with undersample=False below -- confirm which one is intended.
trials = [
    {
        "trial_name": f"flan-t5-large_undersample_prompt{i}",
        "model_name": "google/flan-t5-large",
        "model_type": "encoder-decoder",
        "prompt": i
    }
    for i in range(len(prompts))
]

professions, y = parse_professions("professions.json", undersample=False)

# Each call loads the model again, so runtime grows with the number of prompts.
for trial in trials:
    save_hidden_states(
        model_name=trial["model_name"],
        model_type=trial["model_type"],
        prompt=prompts[trial["prompt"]],
        trial_name=trial["trial_name"],
        professions=professions,
        y=y,
        verbose=True)