In [38]:
import os
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

In [39]:
DEVICE = "mps"

CACHE_DIR = os.path.join(os.getcwd(), "cache_dir")
DATA_DIR = os.path.join(os.getcwd(), "data")
CROWSPAIRS_PATH = os.path.join(DATA_DIR, "crows_pairs_anonymized.csv")
RESULTS_DIR = os.path.join(os.getcwd(), "results/msp")

In [45]:
def mask_differences(A, B, mask):
    A_masked = A.detach().clone().squeeze()
    B_masked = B.detach().clone().squeeze()
    a_start, a_end = 0, len(A) - 1
    b_start, b_end = 0, len(B) - 1
    while A_masked[a_start] == B_masked[b_start]:
        a_start += 1
        b_start += 1
    while A_masked[a_end] == B_masked[b_end]:
        a_end -= 1
        b_end -= 1

    A_masked[a_start:a_end+1] = mask
    B_masked[b_start:b_end+1] = mask
    return A_masked.unsqueeze(0), B_masked.unsqueeze(0)


def get_perplexity(input_ids, model, tokenizer, model_type):
    if model_type == "decoder":
        out = model(input_ids, labels=input_ids)
    elif model_type == "encoder-decoder":
        decoder_input_ids = tokenizer("", return_tensors="pt").input_ids.to(model.device)
        out = model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
    else:
        assert False, "Give a correct model type"
    loss = out.loss
    return torch.exp(loss).item()


def compute_all_perplexities(text, counter, model, tokenizer, model_type):
    
    # Add the EOS token if we're decoding
    if model_type == "decoder":
        text += tokenizer.eos_token
        counter += tokenizer.eos_token

    in_tokens = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    counter_tokens = tokenizer(counter, return_tensors="pt").input_ids.to(model.device)
    mask_tokens = mask_differences(in_tokens, counter_tokens, tokenizer.unk_token_id)[0]

    perp = get_perplexity(in_tokens, model, tokenizer, model_type)
    counter_perp = get_perplexity(counter_tokens, model, tokenizer, model_type)
    masked_perp = get_perplexity(mask_tokens, model, tokenizer, model_type)
    return perp, counter_perp, masked_perp

In [41]:
def parse_crowspairs(crowspairs_path, filter=None):
    df = pd.read_csv(crowspairs_path)
    df = df[df["stereo_antistereo"] == "stereo"]

    # Filter by filter
    if filter:
        df = df[df["bias_type"].isin(filter)]

    return df["sent_more"].tolist(), df["sent_less"].tolist()

In [42]:
def msp(trials, verbose=True):
    for trial in trials:
        # Load data
        texts, counters = parse_crowspairs(CROWSPAIRS_PATH, filter=trial["filter"])

        if trial["model_type"] == "decoder":
            model = AutoModelForCausalLM.from_pretrained(trial["model_name"], cache_dir=CACHE_DIR)
        elif trial["model_type"] == "encoder-decoder":
            model = AutoModelForSeq2SeqLM.from_pretrained(trial["model_name"], cache_dir=CACHE_DIR)
        else:
            assert False, "Input a valid model type"
        model = model.to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(trial["model_name"], cache_dir=CACHE_DIR)

        base_perps = []
        counter_perps = []
        masked_perps = []
        for text, counter in tqdm(list(zip(texts, counters))):
            base_perp, counter_perp, masked_perp = compute_all_perplexities(text, counter, model, tokenizer, trial["model_type"])
            base_perps.append(base_perp)
            counter_perps.append(counter_perp)
            masked_perps.append(masked_perp)

        base_perps = np.array(base_perps)
        counter_perps = np.array(counter_perps)
        masked_perps = np.array(masked_perps)
        
        # Save hidden states
        root = os.path.join(RESULTS_DIR, trial["trial_name"])
        if not os.path.exists(root):
            print(f"Creating directory {root}")
            os.makedirs(root)
        np.save(os.path.join(root, "base_perps.npy"), base_perps)
        np.save(os.path.join(root, "counter_perps.npy"), counter_perps)
        np.save(os.path.join(root, "masked_perps.npy"), masked_perps)

        if verbose:
            print(base_perps.mean(), counter_perps.mean(), masked_perps.mean())

In [43]:
gpt2_trials = [
    {"trial_name": "gpt2",
     "model_name": "gpt2",
     "model_type": "decoder",
     "filter": ["race-color"]}
]

In [44]:
msp(gpt2_trials)

100%|██████████| 473/473 [01:22<00:00,  5.71it/s]

128.95556921686733 134.00291123904336 7376.902352246371



