<a href="https://colab.research.google.com/github/nolanwelch/ai-factcheck/blob/main/proof-of-concept/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install python-dotenv openai pydantic


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys

# Colab is detected by checking whether its runtime package is already loaded.
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    # On Colab the key is fetched through google.colab's `userdata.get`.
    from google.colab import userdata
    openai_token = userdata.get("OPENAI_API_KEY")
else:
    # Elsewhere, let python-dotenv populate the environment, then read the
    # key from the process environment.
    import os
    import dotenv
    dotenv.load_dotenv()
    openai_token = os.environ.get("OPENAI_API_KEY")

# Raise explicitly rather than `assert`: assertions are stripped under
# `python -O`, and an empty string is no more usable than a missing key.
if not openai_token:
    raise RuntimeError("Must set the OPENAI_API_KEY environment variable")

In [None]:
from typing import Dict, List
import json
from pydantic import BaseModel, model_validator, RootModel

class KGMapping(BaseModel):
    """Lookup tables translating knowledge-graph IDs into human-readable labels."""
    # Entity ID -> entity label (e.g. "ent_01" -> "Barack Obama").
    entities: Dict[str, str]
    # Relation ID -> relation label (e.g. "rel_01" -> "spouseOf").
    relations: Dict[str, str]

class KG(RootModel[Dict[str, Dict[str, List[str]]]]):
    """Nested mapping of the graph edges.

    The root value maps a source entity ID to a dict of relation ID ->
    list of target entity IDs.
    """

    def __getitem__(self, ent_id: str) -> Dict[str, List[str]]:
        """Return the relation -> targets mapping stored for ``ent_id``.

        Raises:
            KeyError: if ``ent_id`` has no entry in the graph.
        """
        edges_by_source = self.root
        return edges_by_source[ent_id]

class KnowledgeGraph(BaseModel):
    """A graph of entity/relation IDs together with the labels for those IDs.

    After field validation, every ID used in ``graph`` is checked against
    ``mapping``; an ID with no label makes validation fail.
    """

    mapping: KGMapping
    graph: KG

    @model_validator(mode="after")
    def _check_references(self) -> "KnowledgeGraph":
        """Reject graphs that reference an entity or relation ID absent from the mapping.

        Raises:
            ValueError: on the first unknown source entity, relation, or
                target entity encountered while walking the graph.
        """
        known_entities = self.mapping.entities
        known_relations = self.mapping.relations

        for source_id, outgoing in self.graph.root.items():
            # Source entities must be declared in the mapping.
            if source_id not in known_entities:
                raise ValueError(f"Unknown entity in graph: {source_id}")

            for relation_id, target_ids in outgoing.items():
                # So must every relation leaving that entity ...
                if relation_id not in known_relations:
                    raise ValueError(f"Unknown relation in graph: {relation_id}")

                # ... and every entity the relation points at. Targets are
                # validated strings, so None is a safe "nothing found" marker.
                missing_target = next(
                    (tid for tid in target_ids if tid not in known_entities),
                    None,
                )
                if missing_target is not None:
                    raise ValueError(
                        f"Unknown target entity in graph: {missing_target}"
                    )

        return self

def load_knowledge_graph(filepath: str) -> "KnowledgeGraph":
    """Read a JSON file and validate its contents as a ``KnowledgeGraph``.

    Args:
        filepath: Path of the JSON document to load.

    Returns:
        The validated ``KnowledgeGraph`` instance.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
        pydantic.ValidationError: if the JSON does not match the model
            (including references to unknown entity/relation IDs).
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale-default encoding.
    with open(filepath, encoding="utf-8") as f:
        data = json.load(f)
    # Validate after the file handle has been released.
    return KnowledgeGraph.model_validate(data)

# Load and validate the graph from kg.json (path is relative to the working
# directory); `kg` is read as a module-level global further down.
kg = load_knowledge_graph("kg.json")

mapping=KGMapping(entities={'ent_01': 'Barack Obama', 'ent_02': 'Michelle Obama', 'ent_03': 'Presidency of the United States'}, relations={'rel_01': 'spouseOf', 'rel_02': 'formerOfficeHolder'}) graph=KG(root={'ent_01': {'rel_01': ['ent_02'], 'rel_02': ['ent_03']}, 'ent_02': {'rel_01': ['ent_02']}})


In [35]:
import openai
import time
from dataclasses import dataclass
import json

def _as_bullets(id_to_label):
    """Render an ID -> label mapping as one "- id: label" line per entry."""
    return "\n".join(f"- {key}: {label}" for key, label in id_to_label.items())


# IDs the model may emit; used as JSON-schema enums below
entity_ids = list(kg.mapping.entities)
relation_ids = list(kg.mapping.relations)

# Human-readable ID listings embedded in the function description
entity_list = _as_bullets(kg.mapping.entities)
relation_list = _as_bullets(kg.mapping.relations)

# Natural-language contract shown to the model alongside the schema
_extract_triples_description = f"""
Extract all semantic triples (entityA, relationship, entityB) from a sentence.
Use **only** the following IDs:

Entities:
{entity_list}

Relations:
{relation_list}

Return a JSON object with a single field `triples`, an array of objects:
  {{ "entityA": <ENTITY_ID>, "relationship": <RELATION_ID>, "entityB": <ENTITY_ID> }}

If there are no triples, return `{{"triples":[]}}`.
"""

# Schema of a single (entityA, relationship, entityB) triple
_triple_schema = {
    "type": "object",
    "properties": {
        "entityA": {
            "type": "string",
            "enum": entity_ids,
            "description": "ID of the first entity",
        },
        "relationship": {
            "type": "string",
            "enum": relation_ids,
            "description": "ID of the relationship",
        },
        "entityB": {
            "type": "string",
            "enum": entity_ids,
            "description": "ID of the second entity",
        },
    },
    "required": ["entityA", "relationship", "entityB"],
}

# Function definition passed to the chat completion request
extract_triples_fn = {
    "name": "extract_triples",
    "description": _extract_triples_description,
    "parameters": {
        "type": "object",
        "properties": {
            "triples": {
                "type": "array",
                "items": _triple_schema,
            },
        },
        "required": ["triples"],
    },
}


@dataclass
class SemanticTripleExtractor:
    """Extract (entityA, relationship, entityB) triples from text via an OpenAI chat model.

    Each request forces a call to the ``extract_triples`` function described
    by the module-level ``extract_triples_fn``, whose schema restricts entity
    and relation values to the known IDs.
    """

    # OpenAI client used for every request.
    client: openai.OpenAI

    # The names below are deliberately un-annotated so that they stay plain
    # class-level constants instead of becoming dataclass fields.
    GPT_MODEL = "gpt-4o-mini"
    # Base delay (seconds) before the first retry; doubled on each further retry.
    # NOTE(review): 1 ms is very short for a rate-limit wait; consider raising it.
    ERROR_RETRY_SLEEP = 0.001
    # Upper bound (seconds) for a single backoff delay.
    MAX_RETRY_SLEEP = 30.0
    # Number of retries after the initial attempt before the error is re-raised.
    MAX_RETRIES = 5

    def get_semantic_triples(self, text: str):
        """Return the triples the model extracts from ``text``.

        Args:
            text: One or more sentences to analyse.

        Returns:
            The ``triples`` list decoded from the model's function-call
            arguments; per the function schema each item is a dict with the
            keys ``entityA``, ``relationship`` and ``entityB``.

        Raises:
            ValueError: if the model response carries no function call.
            Exception: the last request error once ``MAX_RETRIES`` is exhausted.
        """
        system_prompt = """
            You are a semantic role and entity extractor.  

            Given an input sentence, identify every tuple (entityA, relationship, entityB) expressed in that sentence, where:

            • entityA and entityB must be one of the entities given in the schema (use the ENTITY IDs)
            • relationship must be one of the relations given in the schema (use the RELATION IDs)

            Do not invent any new entities or relations.  
            Always output **only** valid JSON that matches the provided schema.
        """
        return self._request_with_retry(system_prompt, text)

    def _request_with_retry(self, system_prompt: str, text: str):
        """Send one extraction request, retrying failed attempts a bounded number of times.

        Args:
            system_prompt: Content of the system message.
            text: Content of the user message.

        Returns:
            The ``triples`` value parsed from the function-call arguments.

        Raises:
            ValueError: if the response message has no function call.
            Exception: the last error raised by the request after
                ``MAX_RETRIES`` retries have failed.
        """
        n_retries = 0
        while True:
            try:
                completion = self.client.beta.chat.completions.parse(
                    model=self.GPT_MODEL,
                    temperature=0,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": text},
                    ],
                    functions=[extract_triples_fn],
                    function_call={"name": "extract_triples"},
                )
                response = completion.choices[0].message.function_call
                break

            except Exception as err:
                n_retries += 1
                if isinstance(err, openai.RateLimitError):
                    print(err)
                    print("Exceeded rate limit")
                else:
                    print(f"Unexpected error ({err})")

                # Give up instead of looping forever: persistent failures
                # (bad key, bad request, a bug) never recover by retrying.
                if n_retries > self.MAX_RETRIES:
                    raise

                print(f"Sleeping before retry (done {n_retries} time(s))")
                # Exponential backoff, capped so a long outage cannot produce
                # an unbounded single wait.
                delay = min(
                    self.ERROR_RETRY_SLEEP * 2 ** (n_retries - 1),
                    self.MAX_RETRY_SLEEP,
                )
                time.sleep(delay)

        if response is None:
            raise ValueError("Got null response")

        return json.loads(response.arguments)["triples"]

In [38]:
# Build the OpenAI client from the API key held in `openai_token`.
client = openai.OpenAI(api_key=openai_token)

# Extractor instance used further down; it issues its requests through `client`.
semantic_extractor = SemanticTripleExtractor(client)

In [52]:
# Three sentences in one input: a known fact, a fact about entities that are
# not in the graph, and a compound statement holding two facts.
sentence = (
    "Barack Obama used to be President. "
    "Thomas Edison invented the lightbulb. "
    "Barack Obama is married to Michelle Obama, and Michelle Obama is married to Barack Obama."
)

triples = semantic_extractor.get_semantic_triples(sentence)

# Expected output: one triple from the first sentence, none from the second,
#   and two from the compound third sentence.
entity_names = kg.mapping.entities
relation_names = kg.mapping.relations
for triple in triples:
    # Translate the IDs back into their human-readable labels.
    name_A = entity_names[triple["entityA"]]
    name_rel = relation_names[triple["relationship"]]
    name_B = entity_names[triple["entityB"]]
    print(name_A, name_rel, name_B)

Barack Obama formerOfficeHolder Presidency of the United States
Barack Obama spouseOf Michelle Obama
Michelle Obama spouseOf Barack Obama
