In [None]:
import os
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

from masking import get_overlap_indices, get_masked_and_label
from constants import CACHE_DIR, DATA_DIR, DEVICE, CROWSPAIRS_PATH, MSP_RESULTS_DIR, PERP_RESULTS_DIR, FILTERS, GPT2_MODELS, FLAN_T5_MODELS

## Decoder Perplexity Analysis

In [None]:
def get_perplexity(input, model):
    """Return the perplexity of a causal LM on one tokenized input.

    Args:
        input: Tokenizer output (e.g. a ``BatchEncoding``) that can be unpacked
            into ``model(**input)`` and exposes ``input_ids``. The input ids
            double as the labels.
        model: A causal LM whose forward pass returns an object with a ``loss``
            attribute.

    Returns:
        ``exp(loss)`` as a Python float.
    """
    # Pure inference: disable autograd so no computation graph is built per call.
    with torch.no_grad():
        out = model(**input, labels=input.input_ids)
    return torch.exp(out.loss).item()


def compute_all_perplexities(text, counter, model, tokenizer):
    """Score a sentence and its counterpart with a causal LM.

    The tokenizer's EOS token is appended to both strings by hand, because the
    tokenizer call below does not add it for us.

    Returns:
        ``(perplexity of text, perplexity of counter)``.
    """
    eos = tokenizer.eos_token
    scores = []
    for sentence in (text + eos, counter + eos):
        encoded = tokenizer(sentence, padding=True, return_tensors="pt").to(model.device)
        scores.append(get_perplexity(encoded, model))
    return tuple(scores)

In [None]:
def parse_crowspairs(filter):
    """Load the "stereo" CrowS-Pairs rows, optionally restricted to one bias type.

    Args:
        filter: A ``bias_type`` value, or ``"all"`` to keep every bias type.

    Returns:
        ``(sent_more list, sent_less list)`` in file order.
    """
    frame = pd.read_csv(CROWSPAIRS_PATH)
    mask = frame["stereo_antistereo"].eq("stereo")
    if filter != "all":
        mask &= frame["bias_type"].eq(filter)
    selected = frame.loc[mask]
    return selected["sent_more"].tolist(), selected["sent_less"].tolist()

In [None]:
def perplexity_run(model_list, filters, verbose=True, force=False, save=True):
    """Compute CrowS-Pairs perplexities with causal LMs and optionally save them.

    For every model and every filter, each (``sent_more``, ``sent_less``) pair
    from ``parse_crowspairs`` is scored with ``compute_all_perplexities``.
    Results are written as ``pos-perps.npy`` (sent_more) and ``neg-perps.npy``
    (sent_less) under ``PERP_RESULTS_DIR/<filter>/<model_name>``.

    Args:
        model_list: Model ids loadable with ``AutoModelForCausalLM``.
        filters: Bias types to evaluate; ``"all"`` disables bias-type filtering.
        verbose: If True, print the median pos/neg perplexity ratio and medians.
        force: If True, recompute even when the result directory already exists.
        save: If True, write the perplexity arrays to disk.
    """
    for model_name in model_list:
        # Set up model for trial
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=CACHE_DIR)
        model = model.to(DEVICE)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
        # Reuse EOS as the pad token so `padding=True` works (presumably the
        # tokenizer has no pad token of its own).
        tokenizer.pad_token = tokenizer.eos_token

        for filter in filters:
            path = os.path.join(PERP_RESULTS_DIR, f"{filter}/{model_name}")
            if os.path.exists(path) and not force:
                print(f"Already exists results for model {model_name}, bias type {filter}")
                continue

            print(f"Creating perp results for {model_name}, bias type {filter}")

            texts, counters = parse_crowspairs(filter)

            base_perps = []
            counter_perps = []
            for text, counter in tqdm(zip(texts, counters), total=len(texts)):
                base_perp, counter_perp = compute_all_perplexities(text, counter, model, tokenizer)
                base_perps.append(base_perp)
                counter_perps.append(counter_perp)

            base_perps = np.array(base_perps)
            counter_perps = np.array(counter_perps)

            if save:
                # Save per-sentence perplexities (pos = sent_more, neg = sent_less).
                if not os.path.exists(path):
                    print(f"Creating directory {path}")
                os.makedirs(path, exist_ok=True)
                np.save(os.path.join(path, "pos-perps.npy"), base_perps)
                np.save(os.path.join(path, "neg-perps.npy"), counter_perps)

            # Skip the summary for an empty split instead of raising IndexError.
            if verbose and base_perps.size:
                # np.median handles even-length splits correctly and matches the
                # medians printed on the next line.
                print(f"median perp ratio: {np.median(base_perps / counter_perps)}")
                print(np.median(base_perps), np.median(counter_perps))

        del model
        del tokenizer
        # Release cached accelerator memory for whichever backend DEVICE names,
        # rather than always calling the MPS-only API.
        device_type = str(DEVICE)
        if device_type.startswith("mps"):
            torch.mps.empty_cache()
        elif device_type.startswith("cuda"):
            torch.cuda.empty_cache()

In [None]:
# Recompute (force=True) and save perplexities for each model in GPT2_MODELS and each filter in FILTERS.
perplexity_run(GPT2_MODELS, FILTERS, force=True, save=True, verbose=True)

## Encoder-Decoder Perplexity Analysis

In [None]:
def get_model_loss(string, label, model, tokenizer):
    """Score ``label`` as the target sequence for ``string`` with a seq2seq model.

    ``string`` is tokenized as the model input and ``label`` as the target ids.
    After a single forward pass, the function returns ``exp(loss)``. Despite the
    name, the returned value is therefore a perplexity, not the raw loss.

    Args:
        string: Input text (e.g. a masked sentence).
        label: Target text the model should produce.
        model: Model exposing ``.device`` whose output supports ``out["loss"]``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``exp(loss)`` as a Python float.
    """
    tokens = tokenizer(string, return_tensors="pt").to(model.device)
    label_ids = tokenizer(label, return_tensors="pt").input_ids.to(model.device)
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        out = model(**tokens, labels=label_ids)
    return torch.exp(out["loss"]).item()

In [None]:
def save_msp_losses(model_names, filters, save=True, force=False):
    """Compute masked-span-prediction (MSP) perplexities for CrowS-Pairs pairs.

    For each encoder-decoder model and each bias-type filter, every
    (``sent_more``, ``sent_less``) pair from ``parse_crowspairs`` is masked with
    ``get_overlap_indices`` / ``get_masked_and_label``. Each masked sentence is
    scored with ``get_model_loss``, which returns ``exp(loss)``. Only one masked
    variant per sentence is scored; which tokens are masked is decided by the
    ``masking`` module. Results are saved as ``pos-perps.npy`` (sent_more) and
    ``neg-perps.npy`` (sent_less) under ``MSP_RESULTS_DIR/<filter>/<model_name>``.

    Args:
        model_names: Model ids loadable with ``AutoModelForSeq2SeqLM``.
        filters: Bias types to evaluate; ``"all"`` disables bias-type filtering.
        save: If True, write the perplexity arrays to disk.
        force: If True, recompute even when the result directory already exists.
    """
    for model_name in model_names:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR, model_max_length=512)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=CACHE_DIR)
        # Use the configured DEVICE (as perplexity_run does) instead of hard-coding "mps".
        model.to(DEVICE)
        model.eval()

        for filter in filters:
            # If we already have something saved here, skip it
            path = os.path.join(MSP_RESULTS_DIR, f"{filter}/{model_name}")
            if os.path.exists(path) and not force:
                print(f"Already exists msp results for {model_name}:{filter}")
                continue

            print(f"Creating msp results for {model_name}, bias type {filter}")

            # Same row selection as the decoder experiment.
            all_pos, all_neg = parse_crowspairs(filter)
            print(f"pos len: {len(all_pos)}, neg len: {len(all_neg)}")

            pos_losses = []
            neg_losses = []
            for pos, neg in tqdm(zip(all_pos, all_neg), total=len(all_pos)):
                # Create masks, get masked strings and labels
                pos_mask, neg_mask = get_overlap_indices(pos, neg)
                pos_masked, pos_label = get_masked_and_label(pos, pos_mask)
                neg_masked, neg_label = get_masked_and_label(neg, neg_mask)
                # Run masked strings and labels through the model
                pos_losses.append(get_model_loss(pos_masked, pos_label, model, tokenizer))
                neg_losses.append(get_model_loss(neg_masked, neg_label, model, tokenizer))

            pos_losses = np.array(pos_losses)
            neg_losses = np.array(neg_losses)
            if pos_losses.size:
                # pos/neg is a ratio of perplexities, not a difference.
                print(f"median perp ratio: {np.median(pos_losses / neg_losses)}")
                print(f"median pos perp: {np.median(pos_losses)}, median neg perp: {np.median(neg_losses)}")

            if save:
                os.makedirs(path, exist_ok=True)
                np.save(os.path.join(path, "pos-perps.npy"), pos_losses)
                np.save(os.path.join(path, "neg-perps.npy"), neg_losses)

        del model
        del tokenizer
        # Release cached accelerator memory for whichever backend DEVICE names.
        device_type = str(DEVICE)
        if device_type.startswith("mps"):
            torch.mps.empty_cache()
        elif device_type.startswith("cuda"):
            torch.cuda.empty_cache()

In [None]:
# Alternative checkpoints: "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large"
msp_model_list = [f"t5-{size}" for size in ("small", "base", "large")]
filters = ["all", *pd.read_csv(CROWSPAIRS_PATH)["bias_type"].unique()]
save_msp_losses(msp_model_list, filters, save=True, force=False)

## Misc. Results Plotting

In [None]:
def load_results(model_name, filter, msp=True):
    """Load the saved ``(pos, neg)`` perplexity arrays for one model and filter.

    ``msp=True`` reads from ``MSP_RESULTS_DIR``; otherwise from ``PERP_RESULTS_DIR``.
    """
    root = MSP_RESULTS_DIR if msp else PERP_RESULTS_DIR
    result_dir = os.path.join(root, f"{filter}/{model_name}")
    pos_arr = np.load(os.path.join(result_dir, "pos-perps.npy"))
    neg_arr = np.load(os.path.join(result_dir, "neg-perps.npy"))
    return pos_arr, neg_arr

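We want `pos_loss - neg_loss` to be negative, i.e. the positive (`sent_more`) sentence should receive a lower perplexity than the negative (`sent_less`) one.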

In [None]:
def plot_diffs_filters(filters, model_name, msp=True):
    """Bar-plot the mean per-pair difference ``pos - neg`` for each filter.

    Args:
        filters: Filter names whose saved results should be loaded.
        model_name: Model id the results were saved under.
        msp: True to read MSP results, False for decoder perplexity results
            (same meaning as in ``boxplots`` / ``load_results``).
    """
    mean_diffs = []
    for filter in filters:
        pos_losses, neg_losses = load_results(model_name, filter, msp)
        mean_diffs.append(np.mean(pos_losses - neg_losses))

    fig, ax = plt.subplots()
    ax.bar(filters, mean_diffs)
    ax.tick_params(axis="x", labelrotation=90)
    ax.yaxis.grid(True)
    ax.set_ylabel("mean(pos - neg) perplexity")
    ax.set_title(model_name)
    plt.show()

def boxplots(filters, model_name, msp=True):
    """Box-plot the per-pair differences ``pos - neg`` for each filter.

    Args:
        filters: Filter names whose saved results should be loaded.
        model_name: Model id the results were saved under.
        msp: True to read MSP results, False for decoder perplexity results.
    """
    # Raw per-pair differences (not means), keyed by filter.
    diffs = {}
    for filter in filters:
        pos_losses, neg_losses = load_results(model_name, filter, msp)
        diffs[filter] = pos_losses - neg_losses
    tick_names = list(diffs)

    fig, ax = plt.subplots()
    ax.axhline(y=0)
    # Pass a plain list and set tick labels separately: boxplot's `labels=` keyword
    # is deprecated in Matplotlib 3.9 (renamed `tick_labels`).
    ax.boxplot(list(diffs.values()), showfliers=False)
    ax.set_xticks(range(1, len(tick_names) + 1))
    ax.set_xticklabels(tick_names, rotation=90)
    ax.yaxis.grid(True)
    ax.set_ylabel("pos - neg perplexity")
    ax.set_title(f"{model_name} ({'MSP' if msp else 'decoder perplexity'})")
    plt.show()

In [None]:
filters = ["all", *pd.read_csv(CROWSPAIRS_PATH)["bias_type"].unique()]
for plot_model, plot_msp in (("t5-large", True), ("gpt2-xl", False)):
    boxplots(filters, plot_model, msp=plot_msp)

In [None]:
# Mean pos vs. neg MSP perplexity per bias type, plus the pooled "all" split.
bias_types = list(pd.read_csv(CROWSPAIRS_PATH)["bias_type"].unique())
# load_results takes a filter *string* ("all" or one bias type), matching the
# directory names save_msp_losses writes. Lists would produce paths like
# "['race-color', ...]/<model>", which are never written.
filters = ["all"] + bias_types
# save_msp_losses stores results under the full model id, so FLAN-T5 results
# live under "google/flan-t5-*". NOTE(review): confirm these results were generated.
model_name = "google/flan-t5-small"
pos = []
neg = []
for filter in filters:
    pos_losses, neg_losses = load_results(model_name, filter)
    pos.append(np.mean(pos_losses))
    neg.append(np.mean(neg_losses))

width = 0.4
x = np.arange(len(filters))
fig, ax = plt.subplots()
ax.bar(x - width / 2, pos, width=width, label="pos")
ax.bar(x + width / 2, neg, width=width, label="neg")
ax.set_xticks(x)
ax.set_xticklabels(filters, rotation=90)
ax.set_ylabel("mean MSP perplexity")
ax.set_title(f"Mean MSP perplexity by bias type ({model_name})")
ax.yaxis.grid(True)
ax.legend()
plt.show()