In [1]:
from utils import read_jsonl, save_jsonl
import pandas as pd
from pydantic import BaseModel, model_validator, field_validator, Field, ValidationInfo
from typing import List, Dict, Union, Any, Optional
import instructor
from openai import OpenAI
import os
import json
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents.base import Document
from langchain_openai import OpenAIEmbeddings

In [2]:
client = instructor.patch(OpenAI(api_key=os.environ['OPENAI_API_KEY']))
MODEL = "gpt-3.5-turbo-0125"

# 🧠 Load data

In [3]:
pred_kbs = read_jsonl('../../data/prediction.jsonl')
ref_kbs = read_jsonl('../../data/wikidata_entities.jsonl')
print(f"Number of predicted KBs: {len(pred_kbs)}")

Number of predicted KBs: 4


In [4]:
ref_kbs[0].keys()

dict_keys(['entity_label', 'properties', 'chunked_content', 'QID'])

In [5]:
pred_kbs[0].keys()

dict_keys(['entity_label', 'properties'])

In [6]:
# only take the predictions and references of entities that exist in both
union_entities = set([p['entity_label'] for p in pred_kbs]).intersection(set([r['entity_label'] for r in ref_kbs]))

pred_kbs = [kb for kb in pred_kbs if kb['entity_label'] in union_entities]
pred_kbs = sorted(pred_kbs, key=lambda x: x['entity_label'])

ref_kbs = [kb for kb in ref_kbs if kb['entity_label'] in union_entities]
ref_kbs = sorted(ref_kbs, key=lambda x: x['entity_label'])

print(f"Number of predicted KBs: {len(pred_kbs)}")
print(f"Number of reference KBs: {len(ref_kbs)}")
print("ref: ", [kb['entity_label'] for kb in pred_kbs])
print("pred: ", [kb['entity_label'] for kb in ref_kbs])

Number of predicted KBs: 4
Number of reference KBs: 4
ref:  ['Barack Obama', 'Douglas Adams', 'George Washington', 'Tim Berners-Lee']
pred:  ['Barack Obama', 'Douglas Adams', 'George Washington', 'Tim Berners-Lee']


# 🪬 Define Evaluation Model

In [7]:
from wikidata_search import WikidataSearch, get_all_properties_with_labels
from typing import List
entity_label = 'Obama'
qid: List = WikidataSearch.search_wikidata(entity_label)
if len(qid) == 0:
    print(f"No results found for {entity_label} :(")
print(qid)

wikidata_kg = get_all_properties_with_labels(qid[0]['id'])
one_hop_kg = {}
for property in wikidata_kg['properties'].keys():
    qid: List = WikidataSearch.search_wikidata(property)
    if len(qid) > 0:
        print(f"Found wikidata id for {property}")
        one_hop_kg[property] = get_all_properties_with_labels(qid[0]['id'])


querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json', 'errorformat': 'plaintext', 'language': 'en', 'uselang': 'en', 'type': 'item', 'limit': 1, 'search': 'Obama'}
[{'id': 'Q76', 'title': 'Q76', 'pageid': 205, 'display': {'label': {'value': 'Barack Obama', 'language': 'en'}, 'description': {'value': 'President of the United States from 2009 to 2017', 'language': 'en'}}, 'repository': 'wikidata', 'url': '//www.wikidata.org/wiki/Q76', 'concepturi': 'http://www.wikidata.org/entity/Q76', 'label': 'Barack Obama', 'description': 'President of the United States from 2009 to 2017', 'match': {'type': 'alias', 'language': 'en', 'text': 'Obama'}, 'aliases': ['Obama']}]
querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json', 'errorformat': 'plaintext', 'language': 'en', 'uselang': 'en', 'type': 'item', 'limit': 1, 'search': 'instance of'}
Found wikidata id for instance of
querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json

In [8]:
print(f"found kgs for {len(one_hop_kg)} properties out of {len(wikidata_kg['properties'])}")

found kgs for 81 properties out of 348


In [56]:

class ValidatedProperty(BaseModel):
    property_name: str
    property_value: Any

    property_is_valid: bool = Field(
      ...,
        description="Whether the property is generally valid, judged against " +
                    "the given context.",
    )
    is_valid_reason: Optional[str] = Field(
        None, description="The reason why the property is valid if it is indeed valid."
    )
    error_message: Optional[str] = Field(
        None, description="The error message if either property_name and/or property_value is not valid."
    )


class KnowledgeGraph(BaseModel):
    entity_label: str
    properties: Dict[str, Any]


class WikidataKGValidator(KnowledgeGraph):

    reference_knowledge_graph: KnowledgeGraph = None
    validated_properties: List[ValidatedProperty] = []


    @staticmethod
    def get_wikidata(entity_label):
        qid: List = WikidataSearch.search_wikidata(entity_label)
        if len(qid) == 0:
            raise ValueError(f"No results found for {entity_label} :(")

        wikidata_kg = get_all_properties_with_labels(qid[0]['id'])

        reference_knowledge_graph = KnowledgeGraph(**wikidata_kg)
        return reference_knowledge_graph
    
    @staticmethod
    def get_wikidata_neighbors(wikidata_kg: KnowledgeGraph) -> List[KnowledgeGraph]:
        one_hop_kg = {}
        for property in wikidata_kg['properties'].keys():
            qid: List = WikidataSearch.search_wikidata(property)
            if len(qid) > 0:
                print(f"Found wikidata id for {property}")
                one_hop_kg[property] = get_all_properties_with_labels(qid[0]['id'])

        return one_hop_kg



    @staticmethod
    def create_parent_document_retriever(docs: List[Document]):
        # https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever

        # This text splitter is used to create the child documents
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
        # The vectorstore to use to index the child chunks
        vectorstore = Chroma(
            collection_name="full_documents", embedding_function=OpenAIEmbeddings()
        )
        # The storage layer for the parent documents
        store = InMemoryStore()
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
            # parent_splitter=parent_splitter,
        )
        retriever.add_documents(docs, ids=None) # add entity doc(s)

        # list(store.yield_keys())   # see how many chunks it's created

        return retriever, store, vectorstore

    @staticmethod
    def retrieve_relevant_property(entity_name, property_name, vectorstore, retriever):
        '''Fetch the most similar chunk to predicted property name'''

        query = f"{property_name}"

        sub_docs = vectorstore.similarity_search(query)

        relevant_property = sub_docs[0].page_content
        return relevant_property

    @staticmethod
    def validate_statement_with_context(entity_label, predicted_property_name, predicted_property_value, context):
        '''Validate a statement about an entity

        a statement is a triple: entity_label --> predicted_property_name --> predicted_property_value
                             e.g Donald Trump --> wife --> Ivanka Trump
        
        '''
        resp: ValidatedProperty = client.chat.completions.create(
                        response_model=ValidatedProperty,
                        messages=[
                            {
                                "role": "user",
                                "content": f"Using your knowledge of the world " +
                                "and the given context as a reference, " +
                                "is the following predicted property valid for the given entity? " +
                                f"\nEntity Label: {entity_label}" +
                                f"\nPredicted Property Name: {predicted_property_name}" +
                                f"\nPredicted Property Value: {predicted_property_value}" +
                                f"\n\nContext {context}"
                            }
                        ],
                        max_retries=2,
                        model=MODEL,
                    )
        return resp

    @model_validator(mode='before')
    def validate(self, context) -> "WikidataKGValidator":

        self['validated_properties'] = []

        self['reference_knowledge_graph'] = WikidataKGValidator.get_wikidata(self['entity_label'])

        ref_property_names = [Document(p) for p in self['reference_knowledge_graph'].properties.keys()]   # 🚨 embedding the reference properties
        retriever, store, vectorstore = WikidataKGValidator.create_parent_document_retriever(
            ref_property_names
        )

        for predicted_property_name, predicted_property_value in self['properties'].items():

            relevant_property_name = WikidataKGValidator.retrieve_relevant_property(
                entity_name=self['entity_label'],
                property_name=predicted_property_name, 
                vectorstore=vectorstore,
                retriever=retriever
            )
            # get reference property-value pair
            relevant_property = {
                relevant_property_name: self['reference_knowledge_graph'].properties[relevant_property_name]
            }

            # EVALUATE ONE PROPERTY
            resp = WikidataKGValidator.validate_statement_with_context(
                entity_label, 
                predicted_property_name, 
                predicted_property_value, 
                context=relevant_property
            )

            self['validated_properties'].append(resp)
        return self


    @model_validator(mode='after')
    def assert_all_properties_validated(self, info: ValidationInfo):
        if len(self.validated_properties) != len(self.properties):
            raise ValueError(
                "Number of properties validated does not match number of properties in the prediction knowledge base. " +
                f"Number of properties validated: {len(self.validated_properties)}, " +
                f"Number of properties in the text: {len(self.properties)}"
                )
        return self



# 🐕 Relevant Chunk Retrieval

In [38]:
idx = 0
results = []

pred_kbs[idx]['reference_knowledge_graph'] = WikidataKGValidator.get_wikidata(
    entity_label=pred_kbs[idx]['entity_label']
)

ref_property_names = [Document(p) for p in pred_kbs[idx]['reference_knowledge_graph'].properties.keys()]   # 🚨 embedding the reference properties
retriever, store, vectorstore = WikidataKGValidator.create_parent_document_retriever(ref_property_names)

retrieved_properties = []
for predicted_property_name, predicted_property_value in pred_kbs[idx]['properties'].items():

    relevant_property_name = WikidataKGValidator.retrieve_relevant_property(
        entity_name=pred_kbs[idx]['entity_label'],
        property_name=predicted_property_name, 
        vectorstore=vectorstore,
        retriever=retriever
    )

    # # get reference property-value
    relevant_property = {
        relevant_property_name: pred_kbs[idx]['reference_knowledge_graph'].properties[relevant_property_name]
    }
    retrieved_properties.append((predicted_property_name, predicted_property_value, relevant_property))
    print((predicted_property_name, predicted_property_value, relevant_property))


querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json', 'errorformat': 'plaintext', 'language': 'en', 'uselang': 'en', 'type': 'item', 'limit': 1, 'search': 'Barack Obama'}


('Birth Place', 'Honolulu, Hawaii', {'place of birth': ['Kapiolani Medical Center for Women and Children', 'Honolulu']})
('Birthday', 'August 4, 1961', {'date of birth': ['+1961-08-04T00:00:00Z']})
('Party', 'Democratic Party', {'occupation': ['politician', 'lawyer', 'political writer', 'community organizer', 'statesperson', 'jurist', 'podcaster', 'academic', 'memoirist', 'international forum participant']})
('Education', ['Columbia University', 'Harvard Law School', 'Occidental College'], {'educated at': ['State Elementary School Menteng 01', 'Punahou School', 'Occidental College', 'Columbia University', 'Harvard Law School', 'Noelani Elementary School', 'Centaurus High School', 'University of Chicago Law School', 'Harvard University', 'Nelson High School', 'King College Prep High School']})
('Net Worth in 2007', '$1.3 million (equivalent to $1.8 million in 2022)', {'award received': ['Nobel Peace Prize', 'Grammy Award for Best Audio Book, Narration & Storytelling Recording', 'Preside

In [39]:
relevant_property = WikidataKGValidator.retrieve_relevant_property(
    entity_name="Barack Obama",
    property_name='Job', 
    vectorstore=vectorstore,
    retriever=retriever
)
relevant_property

'occupation'

In [40]:
query = f"Barack Obama party"
sub_docs = vectorstore.similarity_search(query)
sub_docs

[Document(page_content='member of political party', metadata={'doc_id': 'f1aa4b8c-0efd-4308-802c-b4b574d473e9'}),
 Document(page_content='member of political party', metadata={'doc_id': '2d20de7e-9c53-4fcb-bbf2-bdeef9c46008'}),
 Document(page_content='member of political party', metadata={'doc_id': '7b3614f8-8cfa-4749-92f3-f4cfc3166922'}),
 Document(page_content='member of political party', metadata={'doc_id': '638c0f84-a6c1-4800-919e-e45133a0f0f7'})]

In [41]:
for predicted_property_name, predicted_property_value, relevant_property in retrieved_properties:
    print(f"\n----------\n❓Pred Property name: {predicted_property_name}\n🙋Predicted value: {predicted_property_value}\n⭐️ Relevant property -->  {relevant_property}\n")


----------
❓Pred Property name: Birth Place
🙋Predicted value: Honolulu, Hawaii
⭐️ Relevant property -->  {'place of birth': ['Kapiolani Medical Center for Women and Children', 'Honolulu']}


----------
❓Pred Property name: Birthday
🙋Predicted value: August 4, 1961
⭐️ Relevant property -->  {'date of birth': ['+1961-08-04T00:00:00Z']}


----------
❓Pred Property name: Party
🙋Predicted value: Democratic Party
⭐️ Relevant property -->  {'occupation': ['politician', 'lawyer', 'political writer', 'community organizer', 'statesperson', 'jurist', 'podcaster', 'academic', 'memoirist', 'international forum participant']}


----------
❓Pred Property name: Education
🙋Predicted value: ['Columbia University', 'Harvard Law School', 'Occidental College']
⭐️ Relevant property -->  {'educated at': ['State Elementary School Menteng 01', 'Punahou School', 'Occidental College', 'Columbia University', 'Harvard Law School', 'Noelani Elementary School', 'Centaurus High School', 'University of Chicago Law Sc

# Evaluate!

In [57]:
idx = 0
results = []

# Add some (presumably) wrong properties
pred_kbs[idx]['properties']['Bought Stocks in'] = ['Tesla', 'Nvidia', 'Hertz']
pred_kbs[idx]['properties']['Favourite Fast Food Chain'] = 'McDonalds'

results.append(WikidataKGValidator(**pred_kbs[idx]))

querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json', 'errorformat': 'plaintext', 'language': 'en', 'uselang': 'en', 'type': 'item', 'limit': 1, 'search': 'Barack Obama'}


In [58]:
results[0].model_dump()['validated_properties']

[{'property_name': 'Birth Place',
  'property_value': 'Honolulu, Hawaii',
  'property_is_valid': True,
  'is_valid_reason': 'Honolulu, Hawaii is a valid birthplace for Obama based on the context provided.',
  'error_message': None},
 {'property_name': 'Birthday',
  'property_value': 'August 4, 1961',
  'property_is_valid': True,
  'is_valid_reason': 'The predicted property value matches the known date of birth of August 4, 1961 for the entity Obama.',
  'error_message': None},
 {'property_name': 'Party',
  'property_value': 'Democratic Party',
  'property_is_valid': True,
  'is_valid_reason': "The predicted party 'Democratic Party' matches the known political party affiliations of the entity 'Obama' based on the context provided.",
  'error_message': None},
 {'property_name': 'Education',
  'property_value': ['Columbia University',
   'Harvard Law School',
   'Occidental College'],
  'property_is_valid': True,
  'is_valid_reason': 'All education institutions are present in the context'

In [59]:
results_json = [r.model_dump() for r in results]
save_jsonl(results_json, '../../data/wikidata_kg_context_evaluation_results.jsonl')

Saved to f'../../data/wikidata_kg_context_evaluation_results.jsonl


# Look at our Evaluations

In [62]:
results = read_jsonl('../../data/wikidata_kg_context_evaluation_results.jsonl')
len(results)

1