# Objective

We want to create a minified knowledge graph that can be generated quickly and economically from existing data.

We will do so via the following process, an iterative method of constructing knowledge graphs from unstructured, unlabeled text data.

* Define an empty set of unique entities $E$.
* Define an empty set of unique properties $P$.
* Define an empty knowledge graph $G$ as a set of 3-tuples $(e_1, p, e_2)$ s.t. $e_1, e_2\in E$ and $p\in P$.
* Preprocess the input text data. We apply only minimal preprocessing, to retain as much information from the original data as possible. The one augmentation is a simple form of *coreference resolution* (replacing third-person pronouns with the article title), which restores subject information that sentence-level chunking would otherwise lose.
    * For example: "Barack Obama was the 44th President of the United States. *He* was nominated as the Democratic Party's candidate in 2008." $\to$ "... *Barack Obama* was nominated..." 
* Split the preprocessed text into chunks small enough to process within an LLM's context window; in our experiments, one sentence per chunk worked best.
* For each data chunk $d$, do the following:
  * Pass $E$, $P$, and $d$ to the LLM. Prompt it to return $C$, a set of claims, each a 3-tuple $(e_1, p, e_2)$ that reuses entities from $E$ and properties from $P$ wherever possible. The instructions should state that if an entity or property referenced in a claim $c\in C$ is not yet in $E$ or $P$, the LLM should define a new one and mark it as new.
  * For each claim $c = (e_1, p, e_2)\in C$, add $e_1$ and $e_2$ to $E$ and $p$ to $P$ if they are not already present. Then add $c$ to the knowledge graph $G$.

In this way, we iteratively build a knowledge graph using only the entity and property types that are relevant to the data in our input documents.
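The loop above can be sketched in Python as follows. This is a minimal illustration, not the notebook's implementation: `build_knowledge_graph` and `toy_extract` are hypothetical names, and `toy_extract` replaces the LLM call with a fixed lookup so the control flow runs offline. The property labels (`positionHeld`, `nominatedBy`) are illustrative only.

```python
from typing import Callable, Iterable

# A claim is a (subject entity, property, object entity) triple
Triple = tuple[str, str, str]


def build_knowledge_graph(
    chunks: Iterable[str],
    extract: Callable[[str, set[str], set[str]], list[Triple]],
) -> tuple[set[str], set[str], set[Triple]]:
    """Iteratively grow E, P and G from text chunks.

    `extract` stands in for the LLM call: given a chunk d and the current
    entity/property sets, it returns claims that may reuse or introduce labels.
    """
    entities: set[str] = set()    # E
    properties: set[str] = set()  # P
    graph: set[Triple] = set()    # G
    for chunk in chunks:
        for e1, p, e2 in extract(chunk, entities, properties):
            # Register any entity/property the extractor introduced for the first time
            entities.update((e1, e2))
            properties.add(p)
            graph.add((e1, p, e2))
    return entities, properties, graph


# Hypothetical stand-in for the LLM: maps each known sentence to fixed claims
def toy_extract(chunk: str, entities: set[str], properties: set[str]) -> list[Triple]:
    lookup = {
        "Barack Obama was the 44th President of the United States.": [
            ("BarackObama", "positionHeld", "PresidentOfTheUnitedStates"),
        ],
        "Barack Obama was nominated as the Democratic Party's candidate in 2008.": [
            ("BarackObama", "nominatedBy", "DemocraticParty"),
        ],
    }
    return lookup.get(chunk, [])


# One sentence per chunk, already coreference-resolved
sentences = [
    "Barack Obama was the 44th President of the United States.",
    "Barack Obama was nominated as the Democratic Party's candidate in 2008.",
]
E, P, G = build_knowledge_graph(sentences, toy_extract)
print(sorted(E))  # ['BarackObama', 'DemocraticParty', 'PresidentOfTheUnitedStates']
print(sorted(P))  # ['nominatedBy', 'positionHeld']
```

Because $E$ and $P$ are sets, adding a label that is already present is a no-op, so the sketch does not need an explicit membership check.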

In [19]:
%pip install mwxml numpy==1.26.4 torch torchvision torchaudio python-dotenv openai pydantic chromadb datasets evaluate scikit-learn tqdm
%pip install -U sentence-transformers

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [9]:
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
  from google.colab import userdata
  openai_token = userdata.get("OPENAI_API_KEY")
else:
  import os
  import dotenv
  dotenv.load_dotenv()
  openai_token = os.environ.get("OPENAI_API_KEY")

assert openai_token is not None, "Must set the OPENAI_API_KEY environment variable"

In [10]:
import re
import json

# Lightly preprocess summaries. In particular, we naively replace all
#  third-person personal pronouns (he/she/they/them/it) with the
#  title of the article.
PRONOUN_PATTERN = re.compile(r"\b(?:[hH]e|[sS]he|[tT]hey|[tT]hem|[iI]t)\b")

with open("all_summaries.json", "r") as f:
    summaries = json.load(f)

summaries = {
    # Word boundaries leave surrounding punctuation (e.g. sentence-final periods) intact;
    #  a callable replacement stops re.sub from interpreting backslashes in titles
    title: PRONOUN_PATTERN.sub(lambda _: title, summary).replace("...", ".")
    for title, summary in summaries.items()
}
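A standalone sketch of the same pronoun substitution, applied to the Barack Obama example from the objective. `resolve_pronouns` is a hypothetical helper name, and the second input is an invented sentence chosen to show the method's main limitation: every matching pronoun is assumed to refer to the article's subject.

```python
import re

# Naive coreference: every third-person pronoun is assumed to refer to the article's subject
PRONOUNS = re.compile(r"\b(?:[hH]e|[sS]he|[tT]hey|[tT]hem|[iI]t)\b")


def resolve_pronouns(title: str, text: str) -> str:
    # Callable replacement so the title is inserted literally
    return PRONOUNS.sub(lambda _: title, text)


example = (
    "Barack Obama was the 44th President of the United States. "
    "He was nominated as the Democratic Party's candidate in 2008."
)
print(resolve_pronouns("Barack Obama", example))

# Invented sentence: "he" refers to someone else, but is still replaced by the title
print(resolve_pronouns("Batman Forever", "Critics praised it, but he was unimpressed."))
```

Pronouns embedded in longer words ("the", "United") are untouched because of the `\b` word boundaries.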

In [None]:
import openai
import time
from dataclasses import dataclass
import json
from typing import List
from pydantic import BaseModel

class SemanticTriple(BaseModel):
    entityA: str
    relationship: str
    entityB: str

    def __hash__(self):
        return hash((self.entityA, self.relationship, self.entityB))


class SemanticTripleList(BaseModel):
    triples: List[SemanticTriple]

@dataclass
class SemanticTripleExtractor:
    client: openai.OpenAI
    GPT_MODEL = "gpt-4o"
    ERROR_RETRY_SLEEP = 1.0  # base retry delay in seconds (1 ms is too short for rate limits to reset)

    def get_semantic_triples(self, text: str):
        system_prompt = """
        You are a semantic role and entity extractor.

        Given an input text (which may contain multiple sentences), identify every (entityA, relationship, entityB) tuple,
        **even if it's factually incorrect**.

        Some sentences may contain multiple triples, and the semantic triples that are explicitly stated in the sentence
        may not be the only implications of the sentence. For example, the sentence "John graduated college" also implies
        the sentence "John holds a degree". Within reason, attempt to capture all explicit and implicit semantic triples.

        Always output exactly valid JSON with a single key "triples" consisting of a list of semantic triples:
        {
        "triples": [
            { "entityA": "<ENTITY_ID>", "relationship": "<REL_ID>", "entityB": "<ENTITY_ID>" },
            …
        ]
        }
        If there are none, return `{ "triples": [] }`.
        All relationships should be formatted using camelCase, and all entities should use PascalCase.
        ---
        Below is an example of proper processing.
        Sentence: "Princess Diana is a British royal."
        Output: {
        "triples": [
            { "entityA": "PrincessDiana", "relationship": "countryOfOrigin", "entityB": "GreatBritain" },
            { "entityA": "PrincessDiana", "relationship": "instanceOf", "entityB": "Royal" }
        ]
        }
        ---
        Below is another example of proper processing.
        Sentence: "Batman Forever was released on June 16, 1995, to mixed reviews from critics, who praised the visuals, action sequences, and soundtrack, but criticized the screenplay and tonal departure from previous two films."
        Output: {
        "triples": [
            { "entityA": "BatmanForever", "relationship": "releaseDate", "entityB": "June16,1995" },
            { "entityA": "BatmanForever", "relationship": "receivedReviews", "entityB": "Mixed" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Visuals" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "ActionSequences" },
            { "entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Soundtrack" },
            { "entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "Screenplay" },
            { "entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "TonalDepartureFromPreviousFilms" }
        ]
        }
        ---
        Think step by step before giving your output.
        """
        return self._request_with_retry(system_prompt, text)

    def _request_with_retry(self, system_prompt: str, text: str, max_retries: int = 5):
        # Retry failed requests with exponential backoff, giving up after max_retries attempts
        for attempt in range(max_retries):
            try:
                response = (
                    self.client.beta.chat.completions.parse(
                        model=self.GPT_MODEL,
                        temperature=0,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": text},
                        ],
                        response_format=SemanticTripleList,
                    )
                ).choices[0].message
                break

            except openai.RateLimitError as err:
                print(err)
                print("Exceeded rate limit")

            except Exception as err:
                print(f"Unexpected error ({err})")

            delay = self.ERROR_RETRY_SLEEP * 2 ** attempt
            print(f"Sleeping {delay:.3f}s before retry (done {attempt + 1} attempt(s))")
            time.sleep(delay)
        else:
            raise RuntimeError(f"Request failed after {max_retries} attempts")

        if response is None:
            raise ValueError("Got null response")
        elif response.refusal:
            raise ValueError(response.refusal)

        return SemanticTripleList.model_validate_json(response.content)
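Because the extractor validates the model's reply against `SemanticTripleList`, it helps to see what that step accepts and rejects. A minimal offline sketch, re-declaring the two schema classes (without `__hash__`) so it runs on its own; the JSON payloads are hand-written rather than real model output, and `is_valid` is a hypothetical helper.

```python
from typing import List

from pydantic import BaseModel, ValidationError


# Mirrors the notebook's response schema so the parsing step can run offline
class SemanticTriple(BaseModel):
    entityA: str
    relationship: str
    entityB: str


class SemanticTripleList(BaseModel):
    triples: List[SemanticTriple]


# A hand-written reply in the format the system prompt requests
raw = """
{"triples": [
  {"entityA": "PrincessDiana", "relationship": "countryOfOrigin", "entityB": "GreatBritain"},
  {"entityA": "PrincessDiana", "relationship": "instanceOf", "entityB": "Royal"}
]}
"""
parsed = SemanticTripleList.model_validate_json(raw)
for t in parsed.triples:
    print(t.entityA, t.relationship, t.entityB)


def is_valid(payload: str) -> bool:
    # True if the payload parses into the schema, False on any validation error
    try:
        SemanticTripleList.model_validate_json(payload)
        return True
    except ValidationError:
        return False


print(is_valid('{"triples": [{"entityA": "X", "entityB": "Y"}]}'))  # False: "relationship" is missing
```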

In [13]:
import chromadb

chroma_client = chromadb.PersistentClient(
    path="vectordb"
)
ontology_db = chroma_client.get_collection("ontology_embeddings")
kg_db = chroma_client.get_collection("fact_embeddings")

In [14]:
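def split_label(label: str) -> str:
    """Convert a camelCase/PascalCase label into lowercase words, e.g. "countryOfOrigin" -> "country of origin"."""
    label_str = label[0].upper() + label[1:]
    return re.sub(r"[A-Z]", lambda m: " " + m.group(0).lower(), label_str).strip()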

In [28]:
from typing import Literal, Optional

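def get_best_match(collection: chromadb.Collection, label: str, item_type: Literal["entity", "relationship"], min_similarity=0.9) -> Optional[str]:
    # 1. Exact (lexical) match: the label already exists as an ID of the requested type
    lexical_match = collection.get(ids=[label], where={"type": item_type})
    if lexical_match["ids"]:
        return label

    # 2. Otherwise, fall back to the nearest neighbour of the same type in embedding space.
    # NOTE: `min_similarity` is not applied here, so the closest item is always returned,
    #  however distant it is.
    semantic_match = collection.query(query_texts=[label], n_results=1, where={"type": item_type})
    semantic_match_label = str(semantic_match["ids"][0][0])
    return semantic_match_label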


def ground_triple(collection: chromadb.Collection, triple: SemanticTriple) -> SemanticTriple:
    return SemanticTriple(
        entityA=get_best_match(collection, triple.entityA, "entity"),
        relationship=get_best_match(collection, triple.relationship, "relationship"),
        entityB=get_best_match(collection, triple.entityB, "entity")
    )
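The `min_similarity` parameter suggests a thresholded lookup. A minimal sketch of that idea, assuming cosine similarity and an in-memory dict of label vectors instead of a Chroma collection; `ground_label` and the toy 3-d vectors are hypothetical. Returning `None` would let a caller skip triples with no close ontology match rather than grounding them to an unrelated label.

```python
from typing import Optional

import numpy as np


def ground_label(
    label: str,
    label_vec: np.ndarray,
    ontology: dict[str, np.ndarray],
    min_similarity: float = 0.9,
) -> Optional[str]:
    """Return an exact match, else the most cosine-similar ontology label above the threshold."""
    if label in ontology:  # lexical match, no embedding comparison needed
        return label
    names = list(ontology)
    mat = np.stack([ontology[n] for n in names])
    # Cosine similarity between the query vector and every ontology vector
    sims = mat @ label_vec / (np.linalg.norm(mat, axis=1) * np.linalg.norm(label_vec))
    best = int(np.argmax(sims))
    return names[best] if sims[best] >= min_similarity else None


# Toy "embeddings"; real ones would come from the ontology collection
ontology = {
    "countryOfOrigin": np.array([1.0, 0.0, 0.0]),
    "instanceOf": np.array([0.0, 1.0, 0.0]),
}
print(ground_label("instanceOf", np.zeros(3), ontology))                  # exact match
print(ground_label("nationality", np.array([0.95, 0.1, 0.0]), ontology))  # close to countryOfOrigin
print(ground_label("releaseDate", np.array([0.0, 0.0, 1.0]), ontology))   # nothing above threshold
```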

In [40]:
from pydantic import BaseModel
from enum import Enum

# Unique labels (requires `fever_test`, which is loaded in the dataset cell further down)
print("Output labels:", set(fever_test["label"]))

class FeverResponse(str, Enum):
    REFUTES = "REFUTES"
    SUPPORTS = "SUPPORTS"
    NOT_ENOUGH_INFO = "NOT ENOUGH INFO"

class FeverEvaluation(BaseModel):
    label: FeverResponse
    reason: str

Output labels: {'REFUTES', 'SUPPORTS', 'NOT ENOUGH INFO'}


In [None]:
@dataclass
class ClaimEvaluator:
    client: openai.OpenAI
    GPT_MODEL = "gpt-4o"
    ERROR_RETRY_SLEEP = 1.0  # base retry delay in seconds (1 ms is too short for rate limits to reset)

    def evaluate(self, claim: str, facts: list[str]):
        system_prompt = """
        You are a fact-checking assistant.  You will be given:

        • A single claim.
        • A list of facts.

        Your job is to decide, using ONLY the provided facts and no external or background knowledge, whether the facts:

        • SUPPORT the claim,
        • REFUTE the claim, or
        • do NOT provide enough information to determine (NOT ENOUGH INFO).

        Output exactly one of these three labels (in ALL CAPS), as well as your reasoning:
        SUPPORTS
        REFUTES
        NOT ENOUGH INFO
        """
        return self._request_with_retry(system_prompt, claim, facts)

    def _request_with_retry(self, system_prompt: str, claim: str, facts: list[str], max_retries: int = 5):
        facts_block = "\n".join(f"- {fact}" for fact in facts)
        user_content = f"Claim:\n{claim}\n\nFacts:\n{facts_block}"

        # Retry failed requests with exponential backoff, giving up after max_retries attempts
        for attempt in range(max_retries):
            try:
                response = (
                    self.client.beta.chat.completions.parse(
                        model=self.GPT_MODEL,
                        temperature=0,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_content},
                        ],
                        response_format=FeverEvaluation,
                    )
                ).choices[0].message
                break

            except openai.RateLimitError as err:
                print(err)
                print("Exceeded rate limit")

            except Exception as err:
                print(f"Unexpected error ({err})")

            delay = self.ERROR_RETRY_SLEEP * 2 ** attempt
            print(f"Sleeping {delay:.3f}s before retry (done {attempt + 1} attempt(s))")
            time.sleep(delay)
        else:
            raise RuntimeError(f"Request failed after {max_retries} attempts")

        if response is None:
            raise ValueError("Got null response")
        elif response.refusal:
            raise ValueError(response.refusal)

        return FeverEvaluation.model_validate_json(response.content)
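To make the evaluator's input and output concrete, the sketch below rebuilds the user message with the same `Claim:`/`Facts:` layout used in `_request_with_retry` and parses a reply into `FeverEvaluation`. It re-declares the schema so it runs offline; `build_user_content` is a hypothetical helper, and the facts and reply are hand-written examples, not model output.

```python
from enum import Enum

from pydantic import BaseModel


class FeverResponse(str, Enum):
    REFUTES = "REFUTES"
    SUPPORTS = "SUPPORTS"
    NOT_ENOUGH_INFO = "NOT ENOUGH INFO"


class FeverEvaluation(BaseModel):
    label: FeverResponse
    reason: str


def build_user_content(claim: str, facts: list[str]) -> str:
    # One bulleted line per lexicalized fact, after the claim
    facts_block = "\n".join(f"- {fact}" for fact in facts)
    return f"Claim:\n{claim}\n\nFacts:\n{facts_block}"


print(build_user_content(
    "Princess Diana was British.",
    ["PrincessDiana countryOfOrigin GreatBritain", "PrincessDiana instanceOf Royal"],
))

# The label string in the reply is mapped back onto the enum member by value
reply = FeverEvaluation.model_validate_json(
    '{"label": "NOT ENOUGH INFO", "reason": "No fact mentions the claim subject."}'
)
print(reply.label is FeverResponse.NOT_ENOUGH_INFO, reply.label.value)
```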

In [49]:
# Define our pipeline for claim evaluation

import numpy as np

client = openai.OpenAI(api_key=openai_token)
semantic_extractor = SemanticTripleExtractor(client)
evaluator = ClaimEvaluator(client)


def evaluate_claim(claim: str) -> FeverEvaluation:
    # Embed each triple extracted from the claim as the concatenation of its grounded
    #  entityA, relationship and entityB vectors from the ontology collection
    embeddings = []
    triples = semantic_extractor.get_semantic_triples(claim).triples
    for triple in triples:
        grounded_triple = ground_triple(ontology_db, triple)
        item_vecs = [
            ontology_db.get(ids=[label], include=["embeddings"])["embeddings"][0]
            for label in [grounded_triple.entityA, grounded_triple.relationship, grounded_triple.entityB]
        ]
        triple_embedding = np.concatenate(item_vecs)
        embeddings.append(triple_embedding)

    relevant_facts = []
    if embeddings:  # No triples extracted -> nothing to query
        metadatas = kg_db.query(
            query_embeddings=embeddings,
            include=["metadatas"],
            n_results=2  # top 2 most relevant facts per extracted triple
        )["metadatas"]
        # Flatten the per-triple result lists into a single list
        relevant_facts = [meta for per_triple in metadatas for meta in per_triple]

    # Lexicalize each stored fact as "entityA relationship entityB"
    relevant_facts = [
        " ".join([f["entityA"], f["relationship"], f["entityB"]])
        for f in relevant_facts
    ]

    # Ask the arbiter to determine the validity of the claim
    response = evaluator.evaluate(claim, relevant_facts)

    return response
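The retrieval step represents a triple by concatenating the vectors of its three parts and looks up the nearest stored facts. A toy, in-memory sketch of that idea: the 2-d vectors, `embed_triple`, and `top_k_facts` are hypothetical, and ranking by Euclidean distance assumes Chroma's default `l2` space (the real collection may be configured differently).

```python
import numpy as np

# Toy 2-d label embeddings standing in for the ontology collection
vec = {
    "PrincessDiana": np.array([1.0, 0.0]),
    "GreatBritain": np.array([0.0, 1.0]),
    "Royal": np.array([0.7, 0.7]),
    "countryOfOrigin": np.array([0.2, 0.9]),
    "instanceOf": np.array([0.9, 0.1]),
}


def embed_triple(a: str, rel: str, b: str) -> np.ndarray:
    # Concatenate entityA, relationship and entityB vectors into one triple vector
    return np.concatenate([vec[a], vec[rel], vec[b]])


facts = [
    ("PrincessDiana", "countryOfOrigin", "GreatBritain"),
    ("PrincessDiana", "instanceOf", "Royal"),
    ("GreatBritain", "instanceOf", "Royal"),
]
fact_matrix = np.stack([embed_triple(*f) for f in facts])


def top_k_facts(query: tuple[str, str, str], k: int = 2) -> list[tuple[str, str, str]]:
    # Rank stored facts by Euclidean distance to the query triple's embedding
    dists = np.linalg.norm(fact_matrix - embed_triple(*query), axis=1)
    return [facts[i] for i in np.argsort(dists)[:k]]


print(top_k_facts(("PrincessDiana", "countryOfOrigin", "Royal")))
```

With concatenation, each slot contributes to the distance independently, so a stored fact that shares two of the three slots with the query (here, the subject and relationship) ranks above one that shares only the object.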

In [None]:
from datasets import load_dataset

fever = load_dataset('fever', 'v1.0')
fever_test = fever["paper_test"]
fever_test.features

{'id': Value(dtype='int32', id=None),
 'label': Value(dtype='string', id=None),
 'claim': Value(dtype='string', id=None),
 'evidence_annotation_id': Value(dtype='int32', id=None),
 'evidence_id': Value(dtype='int32', id=None),
 'evidence_wiki_url': Value(dtype='string', id=None),
 'evidence_sentence_id': Value(dtype='int32', id=None)}
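The feature list includes evidence-level columns (`evidence_id`, `evidence_sentence_id`), which suggests a single claim can appear on several rows, one per evidence item. If so, sampling rows directly over-weights claims with more evidence and can pick the same claim twice. A hedged sketch of deduplicating by claim `id` before sampling; `sample_unique_claims` is a hypothetical helper and the rows are toy data. With the real split it would be called as `sample_unique_claims(fever_test.to_list(), 2000)`.

```python
import random


def sample_unique_claims(rows: list[dict], n: int, seed: int = 0) -> list[dict]:
    """Keep one row per claim `id`, then draw a reproducible sample of n claims."""
    unique: dict = {}
    for row in rows:
        unique.setdefault(row["id"], row)  # first row seen for each claim id wins
    rng = random.Random(seed)  # local RNG so the sample is reproducible
    return rng.sample(list(unique.values()), n)


# Toy rows mimicking the split's schema: claim 1 is repeated across two evidence rows
rows = [
    {"id": 1, "label": "SUPPORTS", "claim": "A", "evidence_id": 10},
    {"id": 1, "label": "SUPPORTS", "claim": "A", "evidence_id": 11},
    {"id": 2, "label": "REFUTES", "claim": "B", "evidence_id": 12},
]
picked = sample_unique_claims(rows, 2)
print(sorted(r["id"] for r in picked))  # [1, 2]
```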

In [66]:
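import random
from tqdm.auto import tqdm

true_labels = []
pred_labels = []

with open("eval.log", "w") as f:
    f.write(
        "\t".join([
            "claim", "pred_label", "true_label", "reason"
        ])
    )

random.seed(0)  # arbitrary fixed seed so the sample can be reproduced
fever_selection = random.sample(fever_test.to_list(), 2000)

for entry in tqdm(fever_selection):
    result = evaluate_claim(entry["claim"])
    # Record both labels only after a successful evaluation, so the two lists stay
    #  aligned even if the loop is interrupted mid-claim
    true_labels.append(entry["label"])
    pred_labels.append(result.label.value)
    with open("eval.log", "a+") as f:
        f.write(
            "\n" +
            "\t".join([
                # Collapse whitespace in the free-text reason so each record stays on one TSV line
                entry["claim"], result.label.value, entry["label"], " ".join(result.reason.split())
            ])
        )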

 39%|███▉      | 788/2000 [5:41:39<8:45:30, 26.02s/it]     


KeyboardInterrupt: 

In [69]:
from sklearn.metrics import classification_report

report = classification_report(true_labels, pred_labels)
print(report)

                 precision    recall  f1-score   support

NOT ENOUGH INFO       0.25      0.97      0.40       192
        REFUTES       0.84      0.11      0.19       294
       SUPPORTS       1.00      0.01      0.02       302

       accuracy                           0.28       788
      macro avg       0.70      0.36      0.20       788
   weighted avg       0.76      0.28      0.17       788
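As a consistency check, the aggregate rows of the report can be recomputed from the per-class rows: macro averages are unweighted means over the three classes, weighted averages weight each class by its support, and accuracy equals support-weighted recall (correct predictions over all predictions). Because the inputs are rounded to two decimals, the results match the printed aggregates only to within about 0.01. The same numbers also give a rough count of NOT ENOUGH INFO predictions (recall * support / precision), roughly 745 of the 788 claims, which is consistent with the very low recall on SUPPORTS and REFUTES.

```python
# Per-class values copied from the report above (rounded to two decimals)
support = {"NOT ENOUGH INFO": 192, "REFUTES": 294, "SUPPORTS": 302}
precision = {"NOT ENOUGH INFO": 0.25, "REFUTES": 0.84, "SUPPORTS": 1.00}
recall = {"NOT ENOUGH INFO": 0.97, "REFUTES": 0.11, "SUPPORTS": 0.01}
f1 = {"NOT ENOUGH INFO": 0.40, "REFUTES": 0.19, "SUPPORTS": 0.02}
metrics = {"precision": precision, "recall": recall, "f1": f1}

n = sum(support.values())  # number of evaluated claims
# Macro average: plain mean over classes
macro = {name: sum(m.values()) / len(m) for name, m in metrics.items()}
# Weighted average: each class weighted by its support
weighted = {name: sum(m[c] * support[c] for c in support) / n for name, m in metrics.items()}
# Accuracy equals support-weighted recall
accuracy = weighted["recall"]
# Approximate number of NOT ENOUGH INFO predictions: true positives / precision
nei_predicted = recall["NOT ENOUGH INFO"] * support["NOT ENOUGH INFO"] / precision["NOT ENOUGH INFO"]

print({k: round(v, 2) for k, v in macro.items()})
print({k: round(v, 2) for k, v in weighted.items()})
print(round(accuracy, 2), round(nei_predicted))
```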

