"To analyze this, we: 1) probe the model with the original question from our dataset, 2) probe the model with the question and a demographic group as context for all groups within an attribute, and 3) compare the answers produced by the model for each group against the model’s answer to the question with out context (original answer)."

Notes:

- The paper does not state which sampling strategy it uses; I assume it uses the default sampling parameters.
- The paper uses a temperature of 0.

In [None]:
from langchain_together import ChatTogether
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.messages import ChatMessage
import os
import pandas as pd

In [None]:
# Together-hosted Llama-3 70B chat model, run at temperature 0 to match the paper.
# The API key is read from the TG_KEY environment variable (KeyError if it is unset).
chat = ChatTogether(
    model="meta-llama/Llama-3-70b-chat-hf",
    temperature=0.0,
    together_api_key=os.environ["TG_KEY"],
)

In [None]:
# Load the labeled contextual-question table (one row per question).
data = pd.read_csv(filepath_or_buffer="../labeled_contexual_questions_submit.csv")

In [None]:
# Preview the first five rows of the question table.
data.head(n=5)

In [None]:
# Number of questions (rows) in the table.
data.shape[0]

Base prompt: the question on its own.

Extra prompts: the question plus one piece of demographic context, chosen from: age ∈ {10, 15, 18, 21, 25, 30, 40, 50, 60, 70}, sex ∈ {female, male}, or location ∈ {Hawaii, Idaho, Massachusetts, South Dakota, Vermont, Wyoming}.

In [None]:
import enum

class LocationBias(enum.Enum):
    HI = "Hawaii"
    ID = "Idaho"
    MA = "Massachusetts"
    SD = "South Dakota"
    VT = "Vermont"
    WY = "Wyoming"
    
    
class SexBias(enum.Enum):
    M = "Male"
    F = "Female"
    
    
AgeBias = int
    

class PromptDataset:
    def __init__(self, data, add_age: AgeBias | bool = False, add_location: str | LocationBias = False, add_sex: str | SexBias = False):
        self.data = data.fillna("False")
        self.questions = data["Question"].tolist()
        self.is_sex_bias = self.fillna_col(data["Sex"]).tolist()
        self.is_location_bias = self.fillna_col(data["Location"]).tolist()
        self.is_age_bias = self.fillna_col(data["Age"]).tolist()
        self.add_age = add_age
        self.add_location = add_location
        self.add_sex = add_sex
        
    @property
    def age_prompt_suffix(self):
        if self.add_age is False:
            return ""
        
        return f"I am {self.add_age} years old."
    
    @property
    def location_prompt_suffix(self):
        if self.add_location is False:
            return ""
        
        return f"I live in {self.add_location}."
    
    @property
    def sex_prompt_suffix(self):
        if self.add_sex is False:
            return ""
        
        return f"I am a {self.add_sex}."

    @staticmethod
    def fillna_col(col: pd.Series):
        with pd.option_context("future.no_silent_downcasting", True):
            return col.fillna(False).replace({"x": True})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        langchain_input = ChatPromptValue(
            messages=[
                ChatMessage(
                    content=self.questions[idx],
                    role="user",
                )
            ]
        )

        return langchain_input

    def __iter__(self):
        for idx in range(len(self.data)):
            yield self[idx]

In [None]:
import json
import tqdm

# Generate answers for all base prompts (each question with no demographic context).
# The results list is written to disk after every prompt, so a crash keeps completed work.
outputs = []
output_path = "outputs_base.json"

for prompt in tqdm.tqdm(PromptDataset(data)):
    result = chat.generate_prompt([prompt])
    generation = result.generations[0][0]
    outputs.append({
        "text": generation.text,
        "metadata": generation.message.response_metadata,
    })

    # Dump to a temp file, then atomically swap it in. An interrupt mid-dump
    # then cannot leave a truncated checkpoint in place of the last good one.
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(outputs, f)
    os.replace(tmp_path, output_path)