In [17]:
import sys

# True when the notebook runs inside Google Colab (its module is already loaded there).
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    # Colab: read the key through Colab's `userdata` API.
    from google.colab import userdata
    openai_token = userdata.get("OPENAI_API_KEY")
else:
    # Elsewhere: call dotenv.load_dotenv() first, then read the key from the
    # process environment.
    import os
    import dotenv
    dotenv.load_dotenv()
    openai_token = os.environ.get("OPENAI_API_KEY")

# Fail fast on a missing key. An empty or whitespace-only value is rejected too:
# it is just as unusable, and would otherwise only surface later as an
# authentication error inside the request loop.
assert openai_token and openai_token.strip(), "Must set the OPENAI_API_KEY environment variable"

In [18]:
import openai
import time
from dataclasses import dataclass
import json
from typing import List
from pydantic import BaseModel

class SemanticTriple(BaseModel):
    # One extracted fact: subject entity, relationship, object entity.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into its JSON schema, and this model's schema
    # is sent to the API as part of the structured-output request.
    entityA: str
    relationship: str
    entityB: str

    def __hash__(self):
        # Hash on the three field values, in declaration order.
        identity = (self.entityA, self.relationship, self.entityB)
        return hash(identity)
      
# Container model with a single "triples" field holding SemanticTriple items.
# It is passed as `response_format` for the triple-extraction request and is
# also used to validate the returned JSON.
class SemanticTripleList(BaseModel):
    triples: List[SemanticTriple]
    

@dataclass
class SemanticTripleExtractor:
    """Extract (entityA, relationship, entityB) triples from text with an OpenAI chat model.

    Attributes:
        client: OpenAI client used to issue the structured-output requests.
        GPT_MODEL: Name of the chat model to query.
        ERROR_RETRY_SLEEP: Base delay in seconds before the first retry; the
            delay doubles after every further failed attempt.
        MAX_RETRY_SLEEP: Upper bound in seconds for a single retry delay.
        MAX_RETRIES: Number of retries allowed before the last error is re-raised.
    """

    client: openai.OpenAI
    # Un-annotated on purpose: these stay plain class attributes, so the
    # dataclass-generated constructor still takes only `client`.
    GPT_MODEL = "gpt-4o-mini"
    ERROR_RETRY_SLEEP = 1.0
    MAX_RETRY_SLEEP = 60.0
    MAX_RETRIES = 6

    def get_semantic_triples(self, text: str):
        """Return the triples the model extracts from ``text``.

        Args:
            text: Input text; it is sent verbatim as the user message.

        Returns:
            A ``SemanticTripleList`` validated from the model's JSON answer.

        Raises:
            ValueError: If the response is null or the model refused.
            Exception: The last request error once ``MAX_RETRIES`` retries failed.
        """
        # The worked examples below must use exactly the JSON shape the prompt
        # demands (a "triples" list of objects with entityA/relationship/entityB).
        system_prompt = """
        You are a semantic role and entity extractor.

        Given an input text (which may contain multiple sentences), identify every (entityA, relationship, entityB) tuple,
        **even if it's factually incorrect**.

        Some sentences may contain multiple triples, and the semantic triples that are explicitly stated in the sentence
        may not be the only implications of the sentence. For example, the sentence "John graduated college" also implies
        the sentence "John holds a degree". Within reason, attempt to capture all explicit and implicit semantic triples.

        Always output exactly valid JSON with a single key "triples" consisting of a list of semantic triples:
        {
        "triples": [
            { "entityA": "<ENTITY_ID>", "relationship": "<REL_ID>", "entityB": "<ENTITY_ID>" },
            …
        ]
        }
        If there are none, return `{ "triples": [] }`.
        All relationships should be formatted using camelCase, and all entities should use PascalCase.
        ---
        Below is an example of proper processing.
        Sentence: "Princess Diana is a British royal."
        Output: {
        "triples": [
            { "entityA": "PrincessDiana", "relationship": "countryOfOrigin", "entityB": "GreatBritain" },
            { "entityA": "PrincessDiana", "relationship": "instanceOf", "entityB": "Royal" }
        ]
        }
        ---
        Below is another example of proper processing.
        Sentence: "Batman Forever was released on June 16, 1995, to mixed reviews from critics, who praised the visuals, action sequences, and soundtrack, but criticized the screenplay and tonal departure from previous two films."
        Output: {
        "triples": [
            { "entityA": "BatmanForever", "relationship": "releaseDate", "entityB": "June16,1995" },
            { "entityA": "BatmanForever", "relationship": "receivedReviews", "entityB": "Mixed" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Visuals" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "ActionSequences" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Soundtrack" },
            { "entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "Screenplay" },
            { "entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "TonalDepartureFromPreviousFilms" }
        ]
        }
        ---
        Think step by step before giving your output.
        """
        return self._request_with_retry(system_prompt, text)

    def _request_with_retry(self, system_prompt: str, text: str):
        """Send the request, retrying failures with bounded exponential backoff.

        Args:
            system_prompt: System message for the chat request.
            text: User message for the chat request.

        Returns:
            A ``SemanticTripleList`` validated from the response content.

        Raises:
            ValueError: If the response message is null or carries a refusal.
            Exception: The last request error once ``MAX_RETRIES`` retries failed.
        """
        n_retries = 0
        while True:
            try:
                response = (
                    self.client.beta.chat.completions.parse(
                        model=self.GPT_MODEL,
                        temperature=0,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": text},
                        ],
                        response_format=SemanticTripleList,
                    )
                ).choices[0].message
                break

            except Exception as err:
                n_retries += 1
                # Give up eventually: a permanent failure (bad key, invalid
                # request, ...) would otherwise keep this loop spinning forever.
                if n_retries > self.MAX_RETRIES:
                    raise

                if isinstance(err, openai.RateLimitError):
                    print(err)
                    print("Exceeded rate limit")
                else:
                    print(f"Unexpected error ({err})")
                print(f"Sleeping before retry (done {n_retries} time(s))")

                # Exponential backoff, capped so one wait never exceeds MAX_RETRY_SLEEP.
                delay = min(
                    self.ERROR_RETRY_SLEEP * 2 ** (n_retries - 1),
                    self.MAX_RETRY_SLEEP,
                )
                time.sleep(delay)

        if response is None:
            raise ValueError("Got null response")
        elif response.refusal:
            raise ValueError(response.refusal)

        return SemanticTripleList.model_validate_json(response.content)

In [19]:
from pydantic import BaseModel

# Structured answer for one yes/no evaluation: the boolean verdict plus a
# free-text justification. It is passed as `response_format` for the
# claim-evaluation request and is also used to validate the returned JSON.
class BoolQResponse(BaseModel):
    answer: bool
    reason: str

In [20]:
@dataclass
class ClaimEvaluator:
    """Decide a yes/no claim from a list of facts with an OpenAI chat model.

    Attributes:
        client: OpenAI client used to issue the structured-output requests.
        GPT_MODEL: Name of the chat model to query.
        ERROR_RETRY_SLEEP: Base delay in seconds before the first retry; the
            delay doubles after every further failed attempt.
        MAX_RETRY_SLEEP: Upper bound in seconds for a single retry delay.
        MAX_RETRIES: Number of retries allowed before the last error is re-raised.
    """

    client: openai.OpenAI
    # Un-annotated on purpose: these stay plain class attributes, so the
    # dataclass-generated constructor still takes only `client`.
    GPT_MODEL = "gpt-4o-mini"
    ERROR_RETRY_SLEEP = 1.0
    MAX_RETRY_SLEEP = 60.0
    MAX_RETRIES = 6

    def evaluate(self, claim: str, facts: list[str]):
        """Ask the model whether ``claim`` holds given only ``facts``.

        Args:
            claim: The claim to judge, phrased as a yes-or-no question.
            facts: Fact strings; each is rendered as one bullet in the user message.

        Returns:
            A ``BoolQResponse`` validated from the model's JSON answer.

        Raises:
            ValueError: If the response is null or the model refused.
            Exception: The last request error once ``MAX_RETRIES`` retries failed.
        """
        # The inputs listed here must match the user message, which carries a
        # "Claim:" section followed by a "Facts:" section.
        system_prompt = """
        You are a fact-checking assistant.  You will be given:

        • A claim, phrased as a yes-or-no question.
        • A list of facts.

        Your job is to decide, using ONLY the provided facts and no external or background knowledge,
        whether the answer to the question is true or false. Return only your true/false answer and a short explanation.
        """
        return self._request_with_retry(system_prompt, claim, facts)

    def _request_with_retry(self, system_prompt: str, claim: str, facts: list[str]):
        """Send the request, retrying failures with bounded exponential backoff.

        Args:
            system_prompt: System message for the chat request.
            claim: Text placed under the "Claim:" heading of the user message.
            facts: Fact strings placed as bullets under the "Facts:" heading.

        Returns:
            A ``BoolQResponse`` validated from the response content.

        Raises:
            ValueError: If the response message is null or carries a refusal.
            Exception: The last request error once ``MAX_RETRIES`` retries failed.
        """
        facts_block = "\n".join(f"- {fact}" for fact in facts)
        user_content = f"Claim:\n{claim}\n\nFacts:\n{facts_block}"

        n_retries = 0
        while True:
            try:
                response = (
                    self.client.beta.chat.completions.parse(
                        model=self.GPT_MODEL,
                        temperature=0,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_content},
                        ],
                        response_format=BoolQResponse,
                    )
                ).choices[0].message
                break

            except Exception as err:
                n_retries += 1
                # Give up eventually: a permanent failure (bad key, invalid
                # request, ...) would otherwise keep this loop spinning forever.
                if n_retries > self.MAX_RETRIES:
                    raise

                if isinstance(err, openai.RateLimitError):
                    print(err)
                    print("Exceeded rate limit")
                else:
                    print(f"Unexpected error ({err})")
                print(f"Sleeping before retry (done {n_retries} time(s))")

                # Exponential backoff, capped so one wait never exceeds MAX_RETRY_SLEEP.
                delay = min(
                    self.ERROR_RETRY_SLEEP * 2 ** (n_retries - 1),
                    self.MAX_RETRY_SLEEP,
                )
                time.sleep(delay)

        if response is None:
            raise ValueError("Got null response")
        elif response.refusal:
            raise ValueError(response.refusal)

        return BoolQResponse.model_validate_json(response.content)

In [21]:
# Define our pipeline for question evaluation

import numpy as np

# One OpenAI client, shared by both pipeline stages.
client = openai.OpenAI(api_key=openai_token)
semantic_extractor = SemanticTripleExtractor(client=client)
evaluator = ClaimEvaluator(client=client)


def evaluate_claim(passage: str, claim: str) -> BoolQResponse:
    """Answer ``claim`` using only facts extracted from ``passage``.

    The passage is split on ".", triples are extracted from each fragment,
    every triple is flattened to an "entityA relationship entityB" string, and
    the evaluator judges the claim against those strings.

    Args:
        passage: Source text to mine for facts.
        claim: The claim to judge (the notebook passes the BoolQ question).

    Returns:
        The evaluator's ``BoolQResponse``.
    """
    # Splitting on "." leaves blank fragments (always one after a trailing
    # period); skip them so no request is spent on empty input.
    sentences = [
        fragment for fragment in passage.split(".") if fragment.strip()
    ]

    # Flatten the per-sentence triple lists into one list. This also works when
    # there are no sentences at all, which np.concatenate would reject.
    facts = [
        triple
        for sentence in sentences
        for triple in semantic_extractor.get_semantic_triples(sentence).triples
    ]

    # Lexicalize each extracted fact
    fact_strings = [
        " ".join([f.entityA, f.relationship, f.entityB])
        for f in facts
    ]

    # Ask arbiter to determine answer to question based on facts
    return evaluator.evaluate(claim, fact_strings)

In [22]:
from datasets import load_dataset

# Load the `google/boolq` dataset and show the resulting object as the cell output.
boolq = load_dataset(path="google/boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'passage'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'answer', 'passage'],
        num_rows: 3270
    })
})

In [23]:
# Distinct label values present in the training split's "answer" column
{*boolq["train"]["answer"]}

{False, True}

In [35]:
import random
from pathlib import Path

import pandas as pd
from tqdm.auto import tqdm

# Upper bound on how many not-yet-processed questions to evaluate in this run.
N_SAMPLES = 250

true_labels = []
pred_labels = []

# Results of earlier runs. On the very first run the file does not exist yet,
# so start from an empty frame that has the columns written by this notebook.
if Path("model_output.csv").exists():
    answer_df = pd.read_csv("model_output.csv")
else:
    answer_df = pd.DataFrame(
        columns=["true_label", "pred_label", "passage", "question", "model_reason"]
    )

boolq_test = boolq["train"].to_list() + boolq["validation"].to_list()
boolq_test = pd.DataFrame(boolq_test)
# Remove already-processed queries
boolq_test = boolq_test[~boolq_test["question"].isin(answer_df["question"])]
# Never request more rows than remain; DataFrame.sample raises otherwise.
boolq_test = boolq_test.sample(min(N_SAMPLES, len(boolq_test)))

answers = []

# `total` follows the actual sample size so the progress bar stays accurate.
for _, entry in tqdm(boolq_test.iterrows(), total=len(boolq_test)):
    true_labels.append(entry["answer"])
    # Named `verdict` rather than `eval` to avoid shadowing the builtin.
    verdict = evaluate_claim(entry["passage"], entry["question"])
    pred_labels.append(verdict.answer)
    answers.append({
        "true_label": entry["answer"],
        "pred_label": verdict.answer,
        "passage": entry["passage"],
        "question": entry["question"],
        "model_reason": verdict.reason,
    })

 50%|█████     | 126/250 [25:48<23:04, 11.16s/it]

Unexpected error (Error code: 500 - {'error': {'message': 'The server had an error while processing your request. Sorry about that!', 'type': 'server_error', 'param': None, 'code': None}})
Sleeping before retry (done 1 time(s))


100%|██████████| 250/250 [55:04<00:00, 13.22s/it]


In [36]:
# Append this run's answers to the previously loaded results and write the
# combined table back to the same CSV (without the index column).
output_df = pd.concat(objs=[answer_df, pd.DataFrame(data=answers)])
output_df.to_csv(path_or_buf="model_output.csv", index=False)

In [37]:
from sklearn.metrics import classification_report

# Reload the saved results, keeping only the first row for each question.
scores_df = pd.read_csv("model_output.csv").drop_duplicates(
    subset=["question"], keep="first"
)

# Per-class precision / recall / F1 of predicted vs. true labels.
report = classification_report(
    y_true=scores_df["true_label"],
    y_pred=scores_df["pred_label"],
)
print(report)

              precision    recall  f1-score   support

       False       0.79      0.89      0.84       404
        True       0.92      0.84      0.88       596

    accuracy                           0.86      1000
   macro avg       0.86      0.87      0.86      1000
weighted avg       0.87      0.86      0.86      1000

