In [1]:
from langchain_huggingface import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.chains import LLMChain
from langchain.output_parsers import RegexParser
from langchain_core.prompts import PromptTemplate
from datetime import datetime, timezone
from tqdm import tqdm
from bilateral_factuality_evaluator import BilateralFactualityEvaluator
import pandas, re, json, os

In [2]:
# Load the SimpleQA test set from the publicly hosted CSV into a DataFrame.
df = pandas.read_csv(
    "https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv"
)

In [3]:
# One dict per CSV row. The experiment loop reads the "metadata", "problem"
# and "answer" keys of each record.
SIMPLE_QA_DATASET = df.to_dict(orient="records")

In [4]:
# Prompt template for the verification pass: the model is asked to reason step
# by step and to finish with either "TRUE" or "CANNOT DETERMINE TRUE".
# The {problem} and {answer} placeholders are filled in from each dataset record.
VERIFICATION_PROMPT = """Here is a factual question and a potential answer. 
Your task is to assess whether the answer is true. 
Please first explain your reasoning step by step, 
then conclude with either "TRUE" if you have determined that the answer is true, 
otherwise "CANNOT DETERMINE TRUE", 
and only those conclusions, 
at the very end of your reasoning.

Reasoning steps: 
1. First verify the essential information is present 
2. Check for any supporting facts 
3. Consider question context for implied terms 
4. Note if any missing information is essential vs optional 

Question: {problem}
Answer: {answer}
"""

In [5]:
# Prompt template for the falsification pass: the model is asked to reason step
# by step and to finish with either "FALSE" or "CANNOT DETERMINE FALSE".
# The {problem} and {answer} placeholders are filled in from each dataset record.
FALSIFICATION_PROMPT = """Here is a factual question and a potential answer. 
Your task is to assess whether the answer is false. 
Please first explain your reasoning step by step, 
then conclude with either "FALSE" if you have determined that the answer is false, 
otherwise "CANNOT DETERMINE FALSE", 
and only those conclusions, 
at the very end of your reasoning.

Reasoning steps: 
1. First verify the essential information is present 
2. Check for any contradictory facts
3. Consider question context for implied terms 
4. Note if any missing information is essential vs optional 

Question: {problem}
Answer: {answer}
"""

In [6]:
# Parser that extracts a verdict label from model output.
# RegexParser assigns the i-th output key to the i-th capture group, and the
# regex below has exactly one group, so exactly one key ("text") is declared.
# Declaring more keys than groups makes parse() fail on every successful match.
# When the regex does not match, the raw text is returned under "text".
OUTPUT_PARSER = RegexParser(
    regex=r"(TRUE|FALSE|CANNOT DETERMINE FALSE|CANNOT DETERMINE TRUE)",
    output_keys=["text"],
    default_output_key="text"
)

In [19]:
# Models to evaluate. "model_name" is passed to get_llm() and also determines
# the results filename; "batch_size" is the number of dataset records sent per
# chain.batch() call. Commented-out entries are skipped.
MODELS = [ 
    { "model_name": "gpt-4o-mini", "batch_size": 100 },
    # { "model_name": "gpt-4o-2024-05-13", "batch_size": 100 },
    # { "model_name": "gpt-4-0125-preview", "batch_size": 50 },
    # { "model_name": "mistralai/Mistral-7B-Instruct-v0.3", "batch_size": 50 },
    # { "model_name": "claude-3-5-sonnet-20240620", "batch_size": 1 },
    { "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1", "batch_size": 50 },
    # { "model_name": "claude-3-opus-20240229", "batch_size": 1 },
    # { "model_name": "meta-llama/Meta-Llama-3-70B-Instruct", "batch_size": 50 },
    # { "model_name": "claude-3-haiku-20240307", "batch_size": 1 },
]

In [8]:
def get_matched_string(text):
    """Extract the final verdict label from a model response.

    The prompts instruct the model to reason first and to state its conclusion
    at the very end, so the *last* label found in the text is treated as the
    verdict. Labels that merely appear earlier in the reasoning are ignored.

    Args:
        text: The response text to scan.

    Returns:
        'TRUE', 'CANNOT DETERMINE TRUE', 'FALSE' or 'CANNOT DETERMINE FALSE',
        or None when the text contains none of these labels.
    """
    pattern = r'\b(TRUE|CANNOT DETERMINE TRUE|FALSE|CANNOT DETERMINE FALSE)\b'
    # findall scans left to right without overlap, so a full
    # "CANNOT DETERMINE ..." phrase is consumed as one label rather than
    # leaving its trailing TRUE/FALSE to be matched separately.
    labels = re.findall(pattern, text)
    return labels[-1] if labels else None

In [9]:
def v4(verification_result, falsification_result):
    """Combine the verification and falsification verdicts into one code.

    Mapping (verification, falsification) -> result:
        ('TRUE', 'FALSE')                                   -> 'b'
        ('TRUE', 'CANNOT DETERMINE FALSE')                  -> 't'
        ('CANNOT DETERMINE TRUE', 'FALSE')                  -> 'f'
        ('CANNOT DETERMINE TRUE', 'CANNOT DETERMINE FALSE') -> 'n'

    NOTE(review): the codes presumably stand for both / true / false /
    neither -- confirm against the evaluation write-up.

    Returns:
        One of 'b', 't', 'f', 'n', or None if either verdict is not one of
        the recognised labels.
    """
    verified = verification_result == 'TRUE'
    if not verified and verification_result != 'CANNOT DETERMINE TRUE':
        return None

    falsified = falsification_result == 'FALSE'
    if not falsified and falsification_result != 'CANNOT DETERMINE FALSE':
        return None

    if verified:
        return 'b' if falsified else 't'
    return 'f' if falsified else 'n'

In [10]:
def get_llm(model, temperature=0.1):
    """Build the LangChain client for a supported model name.

    Args:
        model: Model identifier; it must appear in one of the provider groups
            listed below.
        temperature: Sampling temperature forwarded to the client
            (defaults to 0.1).

    Returns:
        A ChatOpenAI, ChatAnthropic, ChatGooglePalm or HuggingFaceEndpoint
        instance, depending on which group the model belongs to.

    Raises:
        Exception: If the model is not in any provider group.
        KeyError: If the API-key environment variable needed by the selected
            provider (Anthropic, Google, Hugging Face) is not set.
    """
    openai_models = (
        "gpt-3.5-turbo",
        "gpt-4-1106-preview",
        "gpt-4-0125-preview",
        "gpt-4o-2024-05-13",
        "gpt-4o-mini",
    )
    anthropic_models = (
        "claude-3-opus-20240229",
        "claude-3-5-sonnet-20240620",
        "claude-3-haiku-20240307",
    )
    google_models = ("gemini-1.0-pro",)
    huggingface_models = (
        "meta-llama/Llama-2-70b-chat-hf",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "google/gemma-2-9b-it",
        "google/gemma-7b-it",
        "google/gemma-2b-it",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "microsoft/Phi-3-mini-128k-instruct",
    )

    if model in openai_models:
        return ChatOpenAI(model_name=model, temperature=temperature)

    if model in anthropic_models:
        return ChatAnthropic(
            temperature=temperature,
            anthropic_api_key=os.environ["ANTHROPIC_API_KEY"],
            model_name=model,
        )

    if model in google_models:
        # NOTE(review): ChatGooglePalm is not imported in the notebook's import
        # cell, so this branch raises NameError unless it is defined elsewhere.
        return ChatGooglePalm(
            temperature=temperature,
            google_api_key=os.environ["GOOGLE_API_KEY"],
            model=model,
        )

    if model in huggingface_models:
        return HuggingFaceEndpoint(
            repo_id=model,
            temperature=temperature,
            timeout=300,
            huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
        )

    raise Exception(f'Model {model} not supported')


In [24]:
# Run the verification and falsification prompts over the whole dataset for
# every configured model and store the results as JSON, one file per model.
# Models whose results file already exists are skipped.
for model in MODELS:
    model_name = model["model_name"]
    filename = f'experiments/{model_name.split("/")[-1]}-simpleqa.json'
    if os.path.isfile(filename):
        print(f'{model_name:36}: EXISTS')
        continue

    # Progress is checkpointed to a separate ".partial" file and only renamed
    # to the final filename once every batch has completed. Otherwise an
    # interrupted run would leave an incomplete file that the check above
    # reports as EXISTS on the next run.
    partial_filename = f'{filename}.partial'
    # Create the output directory up front so the first checkpoint cannot
    # fail after a batch of model calls has already been paid for.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    results = []
    batch_size = model["batch_size"]
    batches = [
        SIMPLE_QA_DATASET[start:start + batch_size]
        for start in range(0, len(SIMPLE_QA_DATASET), batch_size)
    ]
    llm = get_llm(model_name)
    verify_prompt = PromptTemplate(input_variables=["problem", "answer"], template=VERIFICATION_PROMPT)
    falsify_prompt = PromptTemplate(input_variables=["problem", "answer"], template=FALSIFICATION_PROMPT)
    verify_chain = verify_prompt | llm
    falsify_chain = falsify_prompt | llm

    for batch in tqdm(batches, desc=f'{model_name:36}', total=len(batches)):
        falsifications = falsify_chain.batch(batch)
        verifications = verify_chain.batch(batch)
        for item, verification, falsification in zip(batch, verifications, falsifications):
            # Chat models return message objects carrying the text in
            # `.content`, while completion-style endpoints return plain
            # strings. Normalise to text so the verdict regex and json.dump
            # work for both kinds of model.
            verification_text = getattr(verification, "content", verification)
            falsification_text = getattr(falsification, "content", falsification)
            results.append({
                "metadata": item["metadata"],
                "problem": item["problem"],
                "answer": item["answer"],
                "model_name": model_name,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "verification": verification_text,
                "falsification": falsification_text,
                "evaluation": v4(
                    get_matched_string(verification_text),
                    get_matched_string(falsification_text),
                ),
            })
        # Checkpoint after every batch; the context manager closes the handle.
        with open(partial_filename, "w") as checkpoint:
            json.dump(results, checkpoint)

    # Publish the completed run under the final name.
    if os.path.isfile(partial_filename):
        os.replace(partial_filename, filename)

gpt-4o-mini                         : EXISTS


mistralai/Mixtral-8x7B-Instruct-v0.1: 100%|██████████| 87/87 [9:34:06<00:00, 395.94s/it]  
