In [1]:
import torch
import torch.nn as nn
import numpy as np
import json
from tqdm import tqdm
import copy
import os

from transformers import AutoTokenizer, AutoModelForCausalLM

from sklearn.linear_model import LogisticRegression

### Extract Hidden States

In [2]:
def get_decoder_hidden_states(model, tokenizer, input_text, layer=-1):
    """
    Given a decoder model and some text, gets the last-token hidden states on that input text.

    Args:
        model: callable as ``model(input_ids, output_hidden_states=True)``; the result must
            expose ``"hidden_states"`` (a tuple of tensors) and the model must have ``.device``.
        tokenizer: callable tokenizer exposing ``eos_token``.
        input_text: the text to encode; the EOS token is appended before tokenizing.
        layer: index into the hidden-states tuple (by default -1, the last entry).
            Pass ``None`` to get every entry of the tuple.

    Returns:
        A numpy array of shape (hidden_dim,) when ``layer`` is an integer, or of shape
        (num_entries, hidden_dim) when ``layer`` is ``None``.
    """
    # tokenize, appending EOS so the last position summarizes the whole prompt
    input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)

    # forward pass (no gradients needed, we only read activations)
    with torch.no_grad():
        output = model(input_ids, output_hidden_states=True)

    hs_tuple = output["hidden_states"]

    # Compare against None explicitly: a truthiness test would treat layer=0 as
    # "no layer given" and silently return every entry instead of entry 0.
    if layer is not None:
        # single entry: batch element 0, last token
        hs = hs_tuple[layer][0, -1]
    else:
        # No layer specified: join all entries along the first axis (the input is a
        # single text, so that axis has size 1 per entry) and keep only the last token.
        hs = torch.cat(hs_tuple, dim=0)[:, -1, :]

    return hs.detach().cpu().numpy()

def format_imdb(text, label):
    """
    Build a zero-shot sentiment prompt for one imdb review.

    ``label`` selects the sentiment word written into the prompt as the answer
    (0 -> "negative", 1 -> "positive"); the review ``text`` follows on a new line.

    (A simple hand-written prompt, meant as one example among many possible ones.)
    """
    sentiment_word = ["negative", "positive"][label]
    header = "The following movie review expresses a " + sentiment_word + " sentiment:\n"
    return header + text

def format_profession(text, label):
    """
    Given a profession ("text") and a gender label (0 for female, or 1 for male),
    returns a zero-shot prompt for that profession (which includes that label as the answer).
    """
    gender = ["female", "male"][label]
    # The replacement field must not be preceded by a literal "f": that produced
    # "ffemale" / "fmale" in every prompt.
    # NOTE(review): "who's" presumably should read "whose"; wording left untouched here
    # so the prompt changes only by the removed stray character - confirm before editing.
    return f"The following profession refers to a person who's gender is {gender}:\n{text}"

def get_hidden_states_many_examples(model, tokenizer, data, layer=-1):
    """
    Given a decoder model and an iterable of texts, computes the contrast hidden states for every text.

    Each text is turned into two prompts with format_profession (label 0 and label 1) and the
    last-token hidden states of each prompt are read with get_decoder_hidden_states.

    Args:
        model: the language model; it is switched to eval mode before the loop.
        tokenizer: tokenizer matching the model.
        data: iterable of texts (every element is processed, in order).
        layer: forwarded to get_decoder_hidden_states (an index, or None for all layers).

    Returns:
        (all_neg_hs, all_pos_hs): two numpy arrays with one row per text, holding the
        hidden states of the label-0 and label-1 prompts respectively. Ground-truth
        labels are not handled here; the caller keeps them alongside ``data``.

    Raises:
        ValueError: if ``data`` is empty (np.stack needs at least one array).

    This is deliberately simple so that it's easy to understand, rather than being optimized for efficiency
    """
    # setup
    model.eval()
    all_neg_hs, all_pos_hs = [], []

    # loop
    for text in tqdm(data):
        # one forward pass per candidate label
        neg_hs = get_decoder_hidden_states(model, tokenizer, format_profession(text, 0), layer=layer)
        pos_hs = get_decoder_hidden_states(model, tokenizer, format_profession(text, 1), layer=layer)
        # collect
        all_neg_hs.append(neg_hs)
        all_pos_hs.append(pos_hs)

    # Stack the per-text results along a new leading axis
    all_neg_hs = np.stack(all_neg_hs)
    all_pos_hs = np.stack(all_pos_hs)

    return all_neg_hs, all_pos_hs

def parse_professions(professions_path):
    """
    Load a JSON list of profession records and split it into names and binary labels.

    For every record, element 0 is taken as the profession name and element 2 as a
    signed score: a negative score yields label 0, a positive score yields label 1,
    and any other score (e.g. exactly 0) drops the record.
    (Presumably the score encodes a gender association - confirm against the data file.)

    Returns a pair of numpy arrays (names, labels) of equal length, in file order.
    """
    with open(professions_path, "r") as handle:
        records = json.load(handle)

    names, genders = [], []
    for record in records:
        score = record[2]
        if score < 0:
            gender = 0
        elif score > 0:
            gender = 1
        else:
            # neither negative nor positive: leave this record out
            continue
        genders.append(gender)
        names.append(record[0])

    return np.array(names), np.array(genders)

In [3]:
MODEL_NAME = "gpt2-xl"

# Downloaded weights/tokenizer files are cached in this directory (relative to the CWD)
cache_dir = "cache_dir"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
# Use the GPU when one is available; fall back to the CPU instead of crashing on
# machines without CUDA (the helper functions follow model.device either way).
model.to("cuda" if torch.cuda.is_available() else "cpu")
# professions: array of profession names; y: matching 0/1 labels
professions, y = parse_professions("professions.json")

In [4]:
SAVE_DIR = os.path.join(os.getcwd(), "saved/")
TRIAL_NAME = "first_try"
root = os.path.join(SAVE_DIR, TRIAL_NAME)
# makedirs also creates the missing "saved/" parent (os.mkdir raises FileNotFoundError
# on a fresh checkout) and exist_ok makes re-running the cell safe.
os.makedirs(root, exist_ok=True)

# layer=None keeps the last-token hidden states of every layer for each prompt
all_neg, all_pos = get_hidden_states_many_examples(model, tokenizer, professions, layer=None)
print(all_neg.shape, all_pos.shape, y.shape)
# label-0 prompts say "female", label-1 prompts say "male" (see format_profession)
np.save(os.path.join(root, "fem-hs.npy"), all_neg)
np.save(os.path.join(root, "male-hs.npy"), all_pos)
np.save(os.path.join(root, "y.npy"), y)

100%|██████████| 290/290 [00:29<00:00,  9.68it/s]


(290, 49, 1600) (290, 49, 1600) (290,)


### Playground

In [35]:
# Encode a short probe string with the EOS token appended, on the model's device
input_ids = (
    tokenizer("0 1 2 3 4 5" + tokenizer.eos_token, return_tensors="pt")
    .input_ids
    .to(model.device)
)

# Run the model once without tracking gradients, asking for every layer's hidden states
with torch.no_grad():
    output = model(input_ids, output_hidden_states=True)

In [39]:
# Join all hidden-state entries along the first axis, keep the last token, and show the shape
torch.cat(output["hidden_states"], dim=0)[:, -1, :].shape

torch.Size([49, 1600])