# Objective

We want to create a minified knowledge graph that can be generated quickly and economically from existing data.

We will do so via the following process, an iterative method of constructing knowledge graphs from unstructured, unlabeled text data.

* Define an empty set of unique entities $E$.
* Define an empty set of unique properties $P$.
* Define an empty knowledge graph $G$ as a set of 3-tuples $(e_1, p, e_2)$ s.t. $e_1, e_2\in E$ and $p\in P$.
* Preprocess the input text data. We only applied minimal preprocessing, with the aim of retaining as much information as possible from the original data. In fact, we augment the existing text data by providing simple *coreference resolution*, in the hope of restoring information that would otherwise be lost by sentence-level chunking.
    * For example: "Barack Obama was the 44th President of the United States. *He* was nominated as the Democratic Party's candidate in 2008." $\to$ "... *Barack Obama* was nominated..." 
* Chunk the preprocessed text dataset into a size appropriate for processing within an LLM's context window; we found 1 sentence to be optimal.
* For each data chunk $d$, do the following:
  * Pass $E$, $P$, and $d$ to the LLM. Prompt it to return $C$, a set of 3-tuples $(e_1, p, e_2)$ s.t. $e_1, e_2\in E$ and $p\in P$. The instructions should indicate that, when creating each element $c$ of $C$, if an entity or property referenced in $c$ is not present in $E$ or $P$, it should define a new entity/property and indicate that it is new.
  * For each claim $c\in C$, if either of $e_1, e_2\notin E$ or $p\notin P$, add the relevant element to its set. Then, append $c$ to the locally-constructed knowledge graph $G$.

In this way, we iteratively build a knowledge graph using only the entity and property types that are relevant to the data in our input documents.

In [8]:
%pip install mwxml numpy==1.26.4 torch torchvision torchaudio python-dotenv openai pydantic chromadb
%pip install -U sentence-transformers


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [9]:
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
  from google.colab import userdata
  openai_token = userdata.get("OPENAI_API_KEY")
else:
  import os
  import dotenv
  dotenv.load_dotenv()
  openai_token = os.environ.get("OPENAI_API_KEY")

assert openai_token is not None, "Must set the OPENAI_API_KEY environment variable"

In [10]:
with open("proof-of-concept/unique_wiki_urls.txt", "r") as f:
    wiki_urls = f.read().split("\n")

wiki_urls = sorted([
    url.replace("_", " ").replace("-COLON-", ":")
    for url in wiki_urls
    if url
])
len(wiki_urls)

1460

In [11]:
import requests
import json
from tqdm.auto import tqdm

BASE_URL = "https://en.wikipedia.org/w/api.php"

summaries = {}
for slug in tqdm(wiki_urls):
    try:
        data = requests.get(
            BASE_URL,
            params={
                "action": "query",
                "format": "json",
                "titles": slug,
                "prop": "extracts",
                "exintro": True,
                "explaintext": True
            }
        ).json()
        summary = next(iter(data["query"]["pages"].values()))["extract"]
        summaries[slug] = summary
    except Exception as e:
        print(slug, e)
        continue

with open("all_summaries.json", "w") as f:
    json.dump(summaries, f)

  1%|          | 18/1460 [00:02<03:41,  6.51it/s]


KeyboardInterrupt: 

In [12]:
import re
import json

# Lightly preprocess summaries. In particular, we naively replace all
#  third-person personal pronouns (he/she/they/them/it) with the
#  title of the article.

with open("all_summaries.json", "r") as f:
    summaries = json.load(f)

summaries = {
    title: re.sub(r"\s+([hH]e|[sS]he|[tT]hey|[tT]hem|[iI]t)[\s\.,!?]+", f" {title} ", summary).replace("...", ".")
    for title, summary in summaries.items()
}

In [13]:
import openai
import time
from dataclasses import dataclass
import json
from typing import List
from pydantic import BaseModel, ValidationError

class SemanticTriple(BaseModel):
  entityA: str
  relationship: str
  entityB: str
    
  def __hash__(self):
      return hash((self.entityA, self.relationship, self.entityB))
      
class SemanticTripleList(BaseModel):
    triples: List[SemanticTriple]

# def create_extract_triples_fn(entities, relations):
#     return {
#         "name": "extract_triples",
#         "description": f"""
#         Extract all semantic triples (entityA, relationship, entityB) from a sentence.
#         Attempt to do so using only the entities and relationships provided to you, to the
#         best of your ability.
                
#         Return a JSON object with a single field `triples`, an array of objects:
#           {{ "new": <true|false>, "entityA": <ENTITY_ID>, "relationship": <RELATION_ID>, "entityB": <ENTITY_ID> }}
        
#         If there are no triples, return `{{"triples":[]}}`.
#         """,
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "triples": {
#                     "type": "array",
#                     "items": {
#                         "type": "object",
#                         "properties": {
#                             "entityA": {
#                                 "type": "string",
#                                 "enum": entities,
#                                 "description": "ID of the first entity"
#                             },
#                             "relationship": {
#                                 "type": "string",
#                                 "enum": relations,
#                                 "description": "ID of the relationship"
#                             },
#                             "entityB": {
#                                 "type": "string",
#                                 "enum": entities,
#                                 "description": "ID of the second entity"
#                             },
#                         },
#                         "required": ["entityA", "relationship", "entityB"],
#                     },
#                 },
#             },
#             "required": ["triples"],
#         },
#     }


@dataclass
class SemanticTripleExtractor:
    client: openai.OpenAI
    GPT_MODEL = "gpt-4o"
    ERROR_RETRY_SLEEP = 0.001

    def get_semantic_triples(self, text: str):
        system_prompt = """
        You are a semantic role and entity extractor.

        Given an input text (which may contain multiple sentences), identify every (entityA, relationship, entityB) tuple,
        **even if it's factually incorrect**.

        Some sentences may contain multiple triples, and the semantic triples that are explicitly stated in the sentence
        may not be the only implications of the sentence. For example, the sentence "John graduated college" also implies
        the sentence "John holds a degree". Within reason, attempt to capture all explicit and implicit semantic triples.

        Always output exactly valid JSON with a single key "triples" consisting of a list of semantic triples:
        {
        "triples": [
            { "entityA": "<ENTITY_ID>", "relationship": "<REL_ID>", "entityB": "<ENTITY_ID>" },
            …
        ]
        }
        If there are none, return `{ "triples": [] }`.
        All relationships should be formatted using camelCase, and all entities should use PascalCase.
        ---
        Below is an example of proper processing.
        Sentence: "Princess Diana is a British royal."
        Output: {
            ["entityA": "PrincessDiana", "relationship": "countryOfOrigin", "entityB": "GreatBritain"],
            ["entityB": "PrincessDiana", "relationship": "instanceOf", "entityB": "Royal"]
        }
        ---
        Below is another example of proper processing.
        Sentence: "Batman Forever was released on June 16, 1995, to mixed reviews from critics, who praised the visuals, action sequences, and soundtrack, but criticized the screenplay and tonal departure from previous two films."
        Output: {
            ["entityA": "BatmanForever", "relationship": "releaseDate", "entityB": "June16,1995"],
            ["entityB": "BatmanForever", "relationship": "receivedReviews", "entityB": "Mixed"],
            ["entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Visuals"],
            ["entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "ActionSequences"],
            ["entityA": "BatmanForever", "relationship": "praisedFor", "entityB": "Soundtrack"],
            ["entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "Screenplay"],
            ["entityA": "BatmanForever", "relationship": "criticizedFor", "entityB": "TonalDepartureFromPreviousFilms"]
        }
        ---
        Think step by step before giving your output.
        """
        return self._request_with_retry(system_prompt, text)

    def _request_with_retry(self, system_prompt: str, text: str):
        n_retries = 0
        while True:
            try:
                response = (
                    self.client.beta.chat.completions.parse(
                        model=self.GPT_MODEL,
                        temperature=0,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": text},
                        ],
                        response_format=SemanticTripleList,
                    )
                ).choices[0].message
                break

            except openai.RateLimitError as err:
                n_retries += 1
                print(err)
                print("Exceeded rate limit")
                print(f"Sleeping before retry (done {n_retries} time(s))")
                time.sleep(self.ERROR_RETRY_SLEEP)

            except Exception as err:
                n_retries += 1
                print(f"Unexpected error ({err})")
                print(f"Sleeping before retry (done {n_retries} time(s))")
                time.sleep(self.ERROR_RETRY_SLEEP)

        if response is None:
            raise ValueError("Got null response")
        elif response.refusal:
            raise ValueError(response.refusal)
        
        #data = json.loads(response.arguments)["triples"]
        #adapter = TypeAdapter(list[SemanticTriple])

        #return response.validate_python(data)
        return SemanticTripleList.model_validate_json(response.content)

In [14]:
client = openai.OpenAI(api_key=openai_token)
semantic_extractor = SemanticTripleExtractor(client)

In [15]:
test_triples = semantic_extractor.get_semantic_triples("Val Kilmer played Batman in Batman Forever.")

In [16]:
for i in test_triples.triples:
    print(i)

entityA='ValKilmer' relationship='playedCharacter' entityB='Batman'
entityA='ValKilmer' relationship='actedIn' entityB='BatmanForever'
entityA='Batman' relationship='characterIn' entityB='BatmanForever'


In [17]:
all_chunks = []
for title, summary in summaries.items():
    all_chunks += [
        f"(Topic: {title}) {sentence}"
        for sentence in re.split(r"[\.!?]", summary)
    ]
len(all_chunks)

17186

In [74]:
import chromadb

chroma_client = chromadb.PersistentClient(
    path="vectordb",
    settings=chromadb.Settings(allow_reset=True)
)
chroma_client.reset()
ontology_db = chroma_client.create_collection("ontology_embeddings", metadata={"hnsw:space": "cosine"})
kg_db = chroma_client.create_collection("fact_embeddings", metadata={"hnsw:space": "cosine"})

In [40]:
def split_label(label: str) -> str:
    label_str = label[0].upper() + label[1:]
    return re.sub(r"[A-Z]", lambda m: " " + m.group(0).lower(), label_str).strip()

In [59]:
from typing import Literal, Optional

def get_best_match(collection: chromadb.Collection, label: str, item_type: Literal["entity", "relationship"], min_similarity=0.9) -> Optional[str]:
    lexical_match = collection.get(ids=[label], where={"type": item_type})
    if lexical_match["ids"]:
        # print("Got lexical match:", label)
        return label
    
    semantic_match = collection.query(query_texts=[label], n_results=1, where={"type": item_type})
    # print("Semantic match for", label, semantic_match)
    if not semantic_match["ids"][0] or semantic_match["distances"][0][0] < min_similarity:
        # print("No adequate semantic match")
        return None

    semantic_match_label = str(semantic_match["ids"][0][0])
    # print("Found semantic match w/ similarity", semantic_match["distances"][0][0], semantic_match_label)
    return semantic_match_label

def ground_item(collection: chromadb.Collection, label: str, item_type: Literal["entity", "relationship"]) -> str:
    item_match = get_best_match(collection, label, item_type)
    if item_match is not None:
        return item_match
    
    ontology_db.add(
        documents=[split_label(label)],
        metadatas=[{"label": label, "type": item_type}],
        ids=[label]
    )
    return label
    

def ground_triple(collection: chromadb.Collection, triple: SemanticTriple) -> SemanticTriple:
    return SemanticTriple(
        entityA=ground_item(collection, triple.entityA, "entity"),
        relationship=ground_item(collection, triple.relationship, "relationship"),
        entityB=ground_item(collection, triple.entityB, "entity")
    )

In [77]:
import random
import numpy as np
from uuid import uuid4

knowledge_graph = []

chunk_sample = random.sample(all_chunks, 2000)

for chunk in tqdm(chunk_sample):
    triples = semantic_extractor.get_semantic_triples(chunk).triples
    for triple in triples:
        grounded_triple = ground_triple(ontology_db, triple)
        item_vecs = [
            ontology_db.get(
            ids=[label],
            include=["embeddings"]
            )["embeddings"][0]
            for label in [grounded_triple.entityA, grounded_triple.relationship, grounded_triple.entityB]
        ]
        fact_embedding = np.concatenate(item_vecs)

        kg_db.add(
            ids=[str(uuid4())],
            embeddings=[fact_embedding],
            metadatas=[{
                "entityA": grounded_triple.entityA,
                "relationship": grounded_triple.relationship,
                "entityB": grounded_triple.entityB
            }]
        )
    
with open("extracted-kg.json", "w") as f:
    json.dump(knowledge_graph, f)

100%|██████████| 2000/2000 [4:28:27<00:00,  8.05s/it]  
