In [3]:
import os
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

In [4]:
# Compute device: prefer Apple-silicon MPS, then CUDA, else CPU, so the notebook
# still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# All paths are relative to the directory the kernel was started in.
CACHE_DIR = os.path.join(os.getcwd(), "cache_dir")  # Hugging Face download cache
DATA_DIR = os.path.join(os.getcwd(), "data")
CROWSPAIRS_PATH = os.path.join(DATA_DIR, "crows_pairs_anonymized.csv")
# Root for saved MSP perplexities: RESULTS_DIR/<bias type>/<trial name>/*.npy
RESULTS_DIR = os.path.join(os.getcwd(), "results", "msp")

In [5]:
def mask_differences(A, B, mask):
    """Mask the token span where two tokenized sentences differ.

    The longest common prefix and the longest common suffix of ``A`` and ``B``
    are kept. Every token between them is overwritten with ``mask`` in each
    sequence. The two masked spans may have different lengths, and either may be
    empty (e.g. when ``B`` only has extra tokens inserted relative to ``A``).
    The suffix is never allowed to overlap the common prefix.

    Args:
        A, B: token-id tensors for a single example, 1-D or ``(1, seq_len)``
            (callers pass ``tokenizer(..., return_tensors="pt").input_ids``).
        mask: token id written over the differing span.

    Returns:
        ``(A_masked, B_masked)`` as new ``(1, len)`` tensors; ``A`` and ``B``
        themselves are not modified.

    Raises:
        ValueError: if ``mask`` is None (e.g. the tokenizer has no unk token).
    """
    if mask is None:
        raise ValueError("mask token id is None; cannot mask the differing span")

    # Flatten to 1-D. Unlike squeeze(), this keeps a length-1 sequence indexable.
    # Assumes a single example (batch size 1).
    A_masked = A.detach().clone().reshape(-1)
    B_masked = B.detach().clone().reshape(-1)

    # Scan on plain Python lists: avoids one device sync per element comparison.
    a_ids = A_masked.tolist()
    b_ids = B_masked.tolist()
    n_a, n_b = len(a_ids), len(b_ids)

    # Longest common prefix, bounded so identical/prefix sequences cannot overrun.
    start = 0
    while start < n_a and start < n_b and a_ids[start] == b_ids[start]:
        start += 1

    # Longest common suffix that does not eat into the common prefix.
    a_end, b_end = n_a - 1, n_b - 1
    while a_end >= start and b_end >= start and a_ids[a_end] == b_ids[b_end]:
        a_end -= 1
        b_end -= 1

    A_masked[start:a_end + 1] = mask
    B_masked[start:b_end + 1] = mask
    return A_masked.unsqueeze(0), B_masked.unsqueeze(0)


def get_perplexity(input_ids, model, tokenizer, model_type):
    """Return ``exp(loss)`` that ``model`` reports for one tokenized input.

    Args:
        input_ids: token-id tensor already on ``model.device`` (callers pass the
            ``(1, seq_len)`` output of ``tokenizer(..., return_tensors="pt")``).
        model: a model whose forward accepts ``labels`` and returns an object
            with a ``.loss`` attribute.
        tokenizer: the model's tokenizer; only used by the encoder-decoder branch.
        model_type: ``"decoder"`` (input scored against itself) or
            ``"encoder-decoder"`` (``input_ids`` is the encoder input and the
            decoder is scored on the tokenization of the empty string).

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``model_type`` is not one of the two supported values.
    """
    if model_type not in ("decoder", "encoder-decoder"):
        raise ValueError(f"Give a correct model type, got {model_type!r}")

    # Pure inference: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        if model_type == "decoder":
            out = model(input_ids, labels=input_ids)
        else:
            decoder_input_ids = tokenizer("", return_tensors="pt").input_ids.to(model.device)
            # NOTE(review): the same ids are passed as decoder_input_ids AND labels,
            # so labels are not right-shifted against the decoder input; confirm
            # this is the intended scoring for seq2seq models.
            out = model(input_ids, decoder_input_ids=decoder_input_ids, labels=decoder_input_ids)
    return torch.exp(out.loss).item()


def compute_all_perplexities(text, counter, model, tokenizer, model_type):
    """Score a sentence, its counterfactual, and a masked copy of the sentence.

    For ``model_type == "decoder"`` the tokenizer's EOS token is appended to
    both strings before tokenizing. The masked input is ``text``'s token ids
    with the span that differs from ``counter`` overwritten by
    ``tokenizer.unk_token_id`` (via ``mask_differences``).

    Returns:
        ``(perp, counter_perp, masked_perp)``, each from ``get_perplexity``.
    """
    # Decoder-only models get an explicit EOS so the end of the sentence is scored too.
    suffix = tokenizer.eos_token if model_type == "decoder" else ""

    def _encode(sentence):
        return tokenizer(sentence + suffix, return_tensors="pt").input_ids.to(model.device)

    text_ids = _encode(text)
    counter_ids = _encode(counter)
    masked_ids, _ = mask_differences(text_ids, counter_ids, tokenizer.unk_token_id)

    return tuple(
        get_perplexity(ids, model, tokenizer, model_type)
        for ids in (text_ids, counter_ids, masked_ids)
    )

In [6]:
def parse_crowspairs(filter=None, path=None):
    """Load the ``stereo`` rows of a CrowS-Pairs CSV as aligned sentence lists.

    Args:
        filter: optional ``bias_type`` value to keep (e.g. ``"gender"``); any
            falsy value keeps every bias type.
        path: CSV file to read; defaults to ``CROWSPAIRS_PATH``.

    Returns:
        ``(sent_more, sent_less)`` — the two columns as lists of strings,
        aligned by index.
    """
    # Resolved at call time so the notebook's config cell can be re-run/changed.
    df = pd.read_csv(CROWSPAIRS_PATH if path is None else path)
    df = df[df["stereo_antistereo"] == "stereo"]

    if filter:
        df = df[df["bias_type"] == filter]

    return df["sent_more"].tolist(), df["sent_less"].tolist()

In [7]:
# Per-run output files, in the order (base, counter, masked).
_MSP_RESULT_FILES = ("base_perps.npy", "counter_perps.npy", "masked_perps.npy")


def _trial_result_path(bias_type, trial):
    """Directory holding one trial's saved perplexities for one bias type."""
    return os.path.join(RESULTS_DIR, bias_type, trial["trial_name"])


def _results_complete(path):
    """True when every MSP output file already exists under ``path``."""
    return all(os.path.exists(os.path.join(path, name)) for name in _MSP_RESULT_FILES)


def _load_trial_model(trial):
    """Load ``trial``'s model (moved to ``DEVICE``, eval mode) and tokenizer.

    Raises:
        ValueError: if ``trial["model_type"]`` is not ``"decoder"`` or ``"encoder-decoder"``.
    """
    model_classes = {
        "decoder": AutoModelForCausalLM,
        "encoder-decoder": AutoModelForSeq2SeqLM,
    }
    model_cls = model_classes.get(trial["model_type"])
    if model_cls is None:
        raise ValueError(f"Input a valid model type, got {trial['model_type']!r}")
    model = model_cls.from_pretrained(trial["model_name"], cache_dir=CACHE_DIR).to(DEVICE)
    model.eval()  # explicit: disable dropout for deterministic scoring
    tokenizer = AutoTokenizer.from_pretrained(trial["model_name"], cache_dir=CACHE_DIR)
    return model, tokenizer


def _score_pairs(texts, counters, model, tokenizer, model_type):
    """Return ``(base, counter, masked)`` perplexity arrays for aligned sentence pairs."""
    base_perps, counter_perps, masked_perps = [], [], []
    # Inference only: no autograd graph is needed, which saves memory on large models.
    with torch.no_grad():
        for text, counter in tqdm(list(zip(texts, counters))):
            base_perp, counter_perp, masked_perp = compute_all_perplexities(
                text, counter, model, tokenizer, model_type
            )
            base_perps.append(base_perp)
            counter_perps.append(counter_perp)
            masked_perps.append(masked_perp)
    return np.array(base_perps), np.array(counter_perps), np.array(masked_perps)


def msp(trials, filters, verbose=True):
    """Compute and save MSP perplexities for every (trial, bias type) pair.

    Args:
        trials: dicts with ``"trial_name"``, ``"model_name"`` and ``"model_type"``
            (``"decoder"`` or ``"encoder-decoder"``) keys.
        filters: bias types to evaluate; ``"all"`` means no ``bias_type`` filter.
        verbose: print the mean of each perplexity array after a run.

    Results are written to
    ``RESULTS_DIR/<filter>/<trial_name>/{base,counter,masked}_perps.npy``.
    Runs whose three files already exist are skipped, and a trial's model is only
    loaded when at least one of its runs is missing.

    Raises:
        ValueError: if a trial that needs computing has an unsupported model_type.
    """
    for trial in trials:
        pending = []
        for filter in dict.fromkeys(filters):  # de-duplicate, keep order
            if _results_complete(_trial_result_path(filter, trial)):
                print(f"Already exists results for model {trial['model_name']}, bias type {filter}")
            else:
                pending.append(filter)
        if not pending:
            continue  # nothing to do: skip the expensive model load

        model, tokenizer = _load_trial_model(trial)
        for filter in pending:
            path = _trial_result_path(filter, trial)
            print(f"Creating MSP results for {trial['model_name']}, bias type {filter}")

            # "all" is the unfiltered run.
            texts, counters = parse_crowspairs(None if filter == "all" else filter)
            perps = _score_pairs(texts, counters, model, tokenizer, trial["model_type"])

            # Save the per-example perplexity arrays.
            print(f"Creating directory {path}")
            os.makedirs(path, exist_ok=True)
            for name, values in zip(_MSP_RESULT_FILES, perps):
                np.save(os.path.join(path, name), values)

            if verbose:
                base_perps, counter_perps, masked_perps = perps
                print(base_perps.mean(), counter_perps.mean(), masked_perps.mean())

        # Release this trial's model before the next one is loaded, so two
        # models are never held in memory at the same time.
        del model, tokenizer

In [8]:
# Each trial: result-directory name, Hugging Face model id, and scoring mode.
gpt2_trials = [
    {"trial_name": name, "model_name": name, "model_type": "decoder"}
    for name in ("gpt2", "gpt2-large", "gpt2-xl", "gpt2-medium")
]
flan_t5_trials = [
    {"trial_name": name, "model_name": f"google/{name}", "model_type": "encoder-decoder"}
    for name in ("flan-t5-small", "flan-t5-base", "flan-t5-large")
]

# "all" (unfiltered) followed by every bias type in the CSV, in order of first appearance.
df = pd.read_csv(CROWSPAIRS_PATH)
filters = ["all", *df["bias_type"].unique()]

In [None]:
# Compute (or skip, if already saved) MSP perplexities for every model family.
for trial_group in (gpt2_trials, flan_t5_trials):
    msp(trial_group, filters)

In [9]:
# Result-directory (trial) names, in the column order used for the results tables.
gpt2_models = ["gpt2"] + [f"gpt2-{size}" for size in ("medium", "large", "xl")]
# NOTE(review): not used by any visible cell; msp() only supports decoder and
# encoder-decoder models, so no RoBERTa results are produced here.
roberta_models = [f"roberta-{size}" for size in ("base", "large")]
flan_t5_models = [f"flan-t5-{size}" for size in ("small", "base", "large")]


In [28]:
def msp_results(model_names, filters, results_dir=None):
    """Summarise saved MSP perplexities as per-bias-type mean differences.

    For each bias type in ``filters`` and each model in ``model_names`` this loads
    ``base_perps.npy``, ``counter_perps.npy`` and ``masked_perps.npy`` from
    ``<results_dir>/<filter>/<model_name>`` and computes
    ``mean(counter) - mean(base)`` and ``mean(masked) - mean(base)``.
    For each filter, one LaTeX table row of the counter differences (two
    decimals, ``&``-separated, ending in a LaTeX row break) is printed.

    Args:
        model_names: trial names (result sub-directories), in column order.
        filters: bias-type directory names, in row order.
        results_dir: root of the saved results; defaults to ``RESULTS_DIR``,
            the same root that ``msp`` writes to.

    Returns:
        ``(df_counter, df_mask)``: DataFrames indexed by ``filters`` with one
        column per model name.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR

    counter_rows, mask_rows = [], []
    for filter in filters:
        counter_row, mask_row = {}, {}
        cells = []
        for model_name in model_names:
            path = os.path.join(results_dir, filter, model_name)
            base_mean = np.mean(np.load(os.path.join(path, "base_perps.npy")))
            counter_mean = np.mean(np.load(os.path.join(path, "counter_perps.npy")))
            masked_mean = np.mean(np.load(os.path.join(path, "masked_perps.npy")))

            diff_counter = counter_mean - base_mean
            counter_row[model_name] = diff_counter
            mask_row[model_name] = masked_mean - base_mean
            # Fixed two decimals so LaTeX table cells line up ("8.00", not "8.0").
            cells.append(f"& {diff_counter:.2f} ")
        print(f"{filter} " + "".join(cells) + "\\\\")

        counter_rows.append(counter_row)
        mask_rows.append(mask_row)

    df_counter = pd.DataFrame(counter_rows, index=filters)
    df_mask = pd.DataFrame(mask_rows, index=filters)
    return df_counter, df_mask

# (counter-minus-base, masked-minus-base) difference tables for the GPT-2 family.
gpt2_msp_tables = msp_results(gpt2_models, filters)
print(gpt2_msp_tables)

all & 3.82 & 8.86 & 8.0 & 7.91 \\
race-color & 5.05 & 12.22 & 12.42 & 5.54 \\
socioeconomic & 10.23 & 11.55 & 9.02 & 10.82 \\
gender & -22.8 & -0.01 & -11.22 & 8.81 \\
disability & 24.94 & 20.49 & 22.46 & 13.65 \\
nationality & -3.53 & 0.4 & 1.83 & 3.73 \\
sexual-orientation & 31.61 & 23.59 & 26.9 & 20.24 \\
physical-appearance & 22.53 & 22.96 & 17.78 & 20.11 \\
religion & 18.69 & 18.63 & 19.08 & 20.64 \\
age & -22.34 & -29.11 & -20.3 & -19.05 \\
(                          gpt2  gpt2-medium  gpt2-large    gpt2-xl
all                   3.824816     8.862363    8.004009   7.909845
race-color            5.047342    12.223942   12.418284   5.535150
socioeconomic        10.232997    11.548631    9.018133  10.821410
gender              -22.801524    -0.009595  -11.223390   8.809298
disability           24.941194    20.488231   22.459889  13.652316
nationality          -3.532796     0.401660    1.829124   3.726966
sexual-orientation   31.608135    23.591473   26.903229  20.239925
physical-app