In [6]:
import torch
import torch.nn as nn
import numpy as np
import json
from tqdm import tqdm
import copy
import os
import re

from transformers import AutoTokenizer, AutoModelForCausalLM

from sklearn.linear_model import LogisticRegression

In [7]:
SAVE_DIR = os.path.join(os.getcwd(), "saved/")
# makedirs(exist_ok=True) is idempotent on re-run and avoids the exists()/mkdir() race.
os.makedirs(SAVE_DIR, exist_ok=True)
CACHE_DIR = "cache_dir"

# Prefer Apple-silicon MPS (the original hard-coded choice), then CUDA, then CPU,
# so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

### Extract Hidden States

In [8]:
def get_encoder_hidden_states(model, tokenizer, input_text, layer):
    """
    Run an encoder model on ``input_text`` (the full text is given to the encoder)
    and return hidden states from one layer or from all of them.

    Args:
        model: Hugging Face model that returns ``hidden_states`` when called with
            ``output_hidden_states=True``.
        tokenizer: tokenizer matching ``model``. Input is truncated to the
            model's maximum length.
        input_text: the text to encode.
        layer: index into the model's ``hidden_states`` tuple (negative indices
            count from the end; ``0`` is the first entry), or ``None`` to return
            every entry.

    Returns:
        numpy array of shape ``(hidden_dim,)`` when ``layer`` is an int (last
        token of that layer), or ``(num_hidden_states, hidden_dim)`` when
        ``layer`` is ``None`` (first / CLS-position token of every layer).

    NOTE(review): the single-layer branch reads the last token while the
    all-layers branch reads the first (CLS) token. This is preserved from the
    original; confirm which token is intended before mixing results from the two.
    """
    encoder_text_ids = tokenizer(input_text, truncation=True, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        output = model(encoder_text_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]
    # `is not None` (not truthiness) so that layer=0 selects the first entry
    # instead of silently falling through to the all-layers branch.
    if layer is not None:
        # batch index 0 (single example), last token
        hs = hs_tuple[layer][0, -1]
    else:
        # Concatenate along the batch axis (size 1 per layer) to get
        # (num_hidden_states, seq_len, hidden_dim), then keep the first (CLS) token.
        hs = torch.cat(hs_tuple, dim=0)[:, 0, :]

    return hs.detach().cpu().numpy()

def get_decoder_hidden_states(model, tokenizer, input_text, layer):
    """
    Run a decoder (causal LM) on ``input_text`` and return last-token hidden states.

    The tokenizer's EOS token is appended to the text before tokenizing, so the
    "last token" is that EOS token.

    Args:
        model: Hugging Face model that returns ``hidden_states`` when called with
            ``output_hidden_states=True``.
        tokenizer: tokenizer matching ``model``; must define ``eos_token``.
        input_text: the text to encode.
        layer: index into the model's ``hidden_states`` tuple (negative indices
            count from the end; ``0`` is the first entry), or ``None`` to return
            every entry.

    Returns:
        numpy array of shape ``(hidden_dim,)`` when ``layer`` is an int, or
        ``(num_hidden_states, hidden_dim)`` when ``layer`` is ``None``.
    """
    input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]
    # `is not None` (not truthiness) so that layer=0 selects the first entry
    # instead of silently falling through to the all-layers branch.
    if layer is not None:
        # batch index 0 (single example), last token
        hs = hs_tuple[layer][0, -1]
    else:
        # Concatenate along the batch axis (size 1 per layer) to get
        # (num_hidden_states, seq_len, hidden_dim), then keep the last token.
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]

    return hs.detach().cpu().numpy()

def format_profession(prompt, text, label):
    """
    Fill a prompt template with a piece of text and one of its contrast labels.

    Prompts contain a ``<LABEL0/LABEL1>`` tag and a ``<TEXT>`` tag, e.g.
    ``"Only a <woman/man> can be a <TEXT>"``.

    Args:
        prompt: the template string.
        text: string inserted verbatim for every ``<TEXT>`` tag.
        label: index (0 or 1) selecting which slash-separated option replaces
            each label tag.

    Returns:
        The filled-in prompt string.

    Raises:
        ValueError: if ``prompt`` contains no ``<.../...>`` label tag.
        IndexError: if ``label`` is out of range for a tag's options.
    """
    # A label tag is <...> containing a slash; <TEXT> has no slash, so it is not matched.
    label_tag = re.compile(r"<([^<>]*/[^<>]*)>")
    if not label_tag.search(prompt):
        raise ValueError(f"Prompt has no <LABEL0/LABEL1> tag: {prompt!r}")

    # Fill label tags in the template *before* inserting the text, so angle
    # brackets inside the text can never be mistaken for a tag. A callable
    # replacement stops re.sub from interpreting backslashes in the label.
    output = label_tag.sub(lambda match: match.group(1).split("/")[label], prompt)

    # Plain string replacement inserts the text verbatim (re.sub would treat
    # backslashes in it as escape sequences / group references).
    return output.replace("<TEXT>", text)

def get_hidden_states_many_examples(model, model_type, tokenizer, prompt, data, layer):
    """
    Compute contrast-pair hidden states for every example in ``data``.

    Each example is formatted twice with ``format_profession`` (label option 0
    and label option 1) and passed through the model one at a time. This is
    deliberately simple (no batching) so it is easy to follow, not optimized for
    speed.

    Args:
        model: Hugging Face model; ``eval()`` is called on it.
        model_type: ``"encoder"`` or ``"decoder"``; selects the hidden-state extractor.
        tokenizer: tokenizer matching ``model``.
        prompt: template with a ``<LABEL0/LABEL1>`` tag and a ``<TEXT>`` tag.
        data: iterable of text strings (here, profession names).
        layer: layer index to extract, or ``None`` for all layers.

    Returns:
        ``(all_neg_hs, all_pos_hs)``: numpy arrays stacked along a new leading
        axis of length ``len(data)``. ``neg`` uses label option 0 and ``pos``
        uses option 1. The per-example shape is whatever the extractor returns.

    Raises:
        ValueError: if ``model_type`` is not ``"encoder"`` or ``"decoder"``, or
            (from ``np.stack``) if ``data`` is empty.
    """
    # Validate and pick the extractor once, up front, instead of on every example.
    # A real exception, unlike `assert`, is not stripped under `python -O`.
    if model_type == "encoder":
        extract = get_encoder_hidden_states
    elif model_type == "decoder":
        extract = get_decoder_hidden_states
    else:
        raise ValueError(f"Invalid model type: {model_type!r} (expected 'encoder' or 'decoder')")

    model.eval()
    all_neg_hs, all_pos_hs = [], []
    for text in tqdm(data):
        all_neg_hs.append(extract(model, tokenizer, format_profession(prompt, text, 0), layer=layer))
        all_pos_hs.append(extract(model, tokenizer, format_profession(prompt, text, 1), layer=layer))

    return np.stack(all_neg_hs), np.stack(all_pos_hs)

def parse_professions(professions_path):
    """
    Load profession records from a JSON file and binarize their score.

    Each record is indexed positionally: element 0 is the profession name and
    element 2 is a signed numeric score. A negative score maps to label 0 and a
    positive score to label 1. Records whose score is neither (e.g. exactly 0)
    are skipped.

    Returns:
        ``(names, labels)`` as numpy arrays of equal length, in file order.
    """
    with open(professions_path, "r") as handle:
        records = json.load(handle)

    names, flags = [], []
    for record in records:
        score = record[2]
        if score < 0:
            flag = 0
        elif score > 0:
            flag = 1
        else:
            continue  # not strictly positive or negative: excluded
        flags.append(flag)
        names.append(record[0])

    return np.array(names), np.array(flags)

In [9]:
def save_hidden_states(model_name, model_type, prompt, trial_name, professions, y, verbose=False, layer=None):
    """
    Load ``model_name``, extract contrast hidden states for every profession, and save them.

    Writes three files under ``SAVE_DIR/<trial_name>/`` (the directory is created if needed):
      - ``fem-hs.npy``: hidden states for the prompt filled with label option 0
      - ``male-hs.npy``: hidden states for the prompt filled with label option 1
      - ``y.npy``: the ``y`` array passed in

    NOTE(review): the file names assume option 0 of the prompt's label tag is the
    female form. A prompt such as "A <man/woman> can't be a <TEXT>" reverses the
    order, so for that prompt ``fem-hs.npy`` actually holds the "man" statements.

    Args:
        model_name: Hugging Face checkpoint name.
        model_type: ``"encoder"`` or ``"decoder"``.
        prompt: template with a ``<LABEL0/LABEL1>`` tag and a ``<TEXT>`` tag.
        trial_name: sub-directory of ``SAVE_DIR`` for this run's outputs.
        professions: sequence of profession strings.
        y: numpy array of labels aligned with ``professions``.
        verbose: print progress and shapes.
        layer: layer index to extract, or ``None`` (the default) for all layers.

    Raises:
        ValueError: if ``professions`` and ``y`` differ in length.
    """
    # Fail fast, before the (slow) model load, if the labels don't line up.
    if len(professions) != len(y):
        raise ValueError(f"professions and y have different lengths: {len(professions)} vs {len(y)}")

    if verbose:
        print(f"Creating hidden states for the {model_type} model {model_name} using prompt {prompt}")

    # Load model. For encoder checkpoints (e.g. roberta-large) this also loads a
    # causal-LM head and transformers logs a warning (see cell output); the
    # extraction helpers only read hidden states.
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=CACHE_DIR)
    model = model.to(DEVICE)

    all_neg, all_pos = get_hidden_states_many_examples(model, model_type, tokenizer, prompt, professions, layer=layer)

    if verbose:
        print(all_neg.shape, all_pos.shape, y.shape)

    root = os.path.join(SAVE_DIR, trial_name)
    os.makedirs(root, exist_ok=True)
    np.save(os.path.join(root, "fem-hs.npy"), all_neg)
    np.save(os.path.join(root, "male-hs.npy"), all_pos)
    np.save(os.path.join(root, "y.npy"), y)

In [10]:
# Read in prompts: one template per non-blank line (blank lines would otherwise
# become empty prompts that have no label tag).
with open(os.path.join(os.getcwd(), "prompts.txt"), "r") as f:
    prompts = [line.rstrip("\r\n") for line in f if line.strip()]

# (model_name, model_type) pairs to probe; every model is run with every prompt.
MODELS = [
    ("gpt2-large", "decoder"),
    ("gpt2-medium", "decoder"),
    ("gpt2-xl", "decoder"),
    ("roberta-large", "encoder"),
]
trials = [
    {
        "trial_name": f"{model_name}_prompt{i}",
        "model_name": model_name,
        "model_type": model_type,
        "prompt": i,
    }
    for model_name, model_type in MODELS
    for i in range(len(prompts))
]

professions, y = parse_professions("professions.json")
for trial in trials:
    save_hidden_states(
        model_name=trial["model_name"],
        model_type=trial["model_type"],
        prompt=prompts[trial["prompt"]],
        trial_name=trial["trial_name"],
        professions=professions,
        y=y,
        verbose=True)

Creating hidden states for the decoder model gpt2-large using prompt The following profession refers to a person who's gender is <female/male>: <TEXT>


100%|██████████| 290/290 [01:44<00:00,  2.78it/s]


(290, 37, 1280) (290, 37, 1280) (290,)
Creating hidden states for the decoder model gpt2-large using prompt Only a <woman/man> can be a <TEXT>


100%|██████████| 290/290 [01:48<00:00,  2.66it/s]


(290, 37, 1280) (290, 37, 1280) (290,)
Creating hidden states for the decoder model gpt2-large using prompt The <TEXT> is <female/male>


100%|██████████| 290/290 [01:40<00:00,  2.89it/s]


(290, 37, 1280) (290, 37, 1280) (290,)
Creating hidden states for the decoder model gpt2-large using prompt When someone refers to a <TEXT>, I assume their gender is <female/male>


100%|██████████| 290/290 [01:45<00:00,  2.75it/s]


(290, 37, 1280) (290, 37, 1280) (290,)
Creating hidden states for the decoder model gpt2-large using prompt A <man/woman> can't be a <TEXT>


100%|██████████| 290/290 [01:46<00:00,  2.72it/s]


(290, 37, 1280) (290, 37, 1280) (290,)
Creating hidden states for the decoder model gpt2-medium using prompt The following profession refers to a person who's gender is <female/male>: <TEXT>


100%|██████████| 290/290 [01:07<00:00,  4.28it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the decoder model gpt2-medium using prompt Only a <woman/man> can be a <TEXT>


100%|██████████| 290/290 [00:57<00:00,  5.00it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the decoder model gpt2-medium using prompt The <TEXT> is <female/male>


100%|██████████| 290/290 [01:04<00:00,  4.48it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the decoder model gpt2-medium using prompt When someone refers to a <TEXT>, I assume their gender is <female/male>


100%|██████████| 290/290 [01:03<00:00,  4.55it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the decoder model gpt2-medium using prompt A <man/woman> can't be a <TEXT>


100%|██████████| 290/290 [00:44<00:00,  6.49it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the decoder model gpt2-xl using prompt The following profession refers to a person who's gender is <female/male>: <TEXT>


100%|██████████| 290/290 [02:19<00:00,  2.08it/s]


(290, 49, 1600) (290, 49, 1600) (290,)
Creating hidden states for the decoder model gpt2-xl using prompt Only a <woman/man> can be a <TEXT>


100%|██████████| 290/290 [02:17<00:00,  2.10it/s]


(290, 49, 1600) (290, 49, 1600) (290,)
Creating hidden states for the decoder model gpt2-xl using prompt The <TEXT> is <female/male>


100%|██████████| 290/290 [02:03<00:00,  2.34it/s]


(290, 49, 1600) (290, 49, 1600) (290,)
Creating hidden states for the decoder model gpt2-xl using prompt When someone refers to a <TEXT>, I assume their gender is <female/male>


100%|██████████| 290/290 [02:25<00:00,  1.99it/s]


(290, 49, 1600) (290, 49, 1600) (290,)
Creating hidden states for the decoder model gpt2-xl using prompt A <man/woman> can't be a <TEXT>


100%|██████████| 290/290 [02:26<00:00,  1.99it/s]


(290, 49, 1600) (290, 49, 1600) (290,)
Creating hidden states for the encoder model roberta-large using prompt The following profession refers to a person who's gender is <female/male>: <TEXT>


If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
100%|██████████| 290/290 [00:31<00:00,  9.30it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the encoder model roberta-large using prompt Only a <woman/man> can be a <TEXT>


If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
100%|██████████| 290/290 [00:27<00:00, 10.56it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the encoder model roberta-large using prompt The <TEXT> is <female/male>


If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
100%|██████████| 290/290 [00:23<00:00, 12.49it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the encoder model roberta-large using prompt When someone refers to a <TEXT>, I assume their gender is <female/male>


If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
100%|██████████| 290/290 [00:28<00:00, 10.14it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
Creating hidden states for the encoder model roberta-large using prompt A <man/woman> can't be a <TEXT>


If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
100%|██████████| 290/290 [00:26<00:00, 10.94it/s]


(290, 25, 1024) (290, 25, 1024) (290,)
