In [1]:
from utils import read_jsonl, save_jsonl
import pandas as pd
from pydantic import BaseModel, model_validator, field_validator, Field, ValidationInfo
from typing import List, Dict, Union, Any, Optional, Literal
import instructor
from openai import OpenAI
import os
import json
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents.base import Document
from langchain_openai import OpenAIEmbeddings

from wikidata_search import WikidataSearch, get_all_properties_with_labels
from data_classes import KnowledgeGraph, ValidatedProperty

In [2]:
# Chat model used for every property-validation request in this notebook.
MODEL = "gpt-3.5-turbo-0125"

# OpenAI client patched by instructor; the validator class below calls
# `client.chat.completions.create(...)` with `response_model=` and
# `max_retries=`. The API key is read from the environment.
client = instructor.patch(
    OpenAI(api_key=os.environ["OPENAI_API_KEY"])
)

# 🧠 Load data

In [3]:
# Load the predicted knowledge bases (one JSON record per line) that will be
# checked against Wikidata further down.
pred_kbs = read_jsonl('../../data/prediction.jsonl')
# Quick sanity check on how many predictions were loaded.
print(f"Number of predicted KBs: {len(pred_kbs)}")

Number of predicted KBs: 4


In [4]:
# Inspect the top-level fields of the first prediction record.
pred_kbs[0].keys()

dict_keys(['entity_label', 'properties'])

# 🪬 Define Evaluation Model

In [5]:

class WikidataKGValidator(KnowledgeGraph):
    """Validate a predicted knowledge graph against its Wikidata counterpart.

    Constructing an instance triggers the whole evaluation:

    1. the ``before`` validator looks the entity up on Wikidata, embeds the
       reference property names, and for every predicted property asks the LLM
       whether it is valid given the closest reference property;
    2. the ``after`` validator checks that every predicted property received
       a verdict.
    """

    # Reference graph fetched from Wikidata; filled in by the ``before`` validator.
    reference_knowledge_graph: Optional[KnowledgeGraph] = None
    # One verdict per predicted property, in the iteration order of ``properties``.
    validated_properties: List[ValidatedProperty] = []

    @staticmethod
    def get_wikidata(entity_label):
        """Fetch the Wikidata knowledge graph of the best search hit.

        Args:
            entity_label: label to search for on Wikidata.

        Returns:
            A ``KnowledgeGraph`` built from the properties of the first hit.

        Raises:
            ValueError: if the search returns no results.
        """
        search_hits: List = WikidataSearch.search_wikidata(entity_label)
        if len(search_hits) == 0:
            raise ValueError(f"No results found for {entity_label} :(")

        wikidata_kg = get_all_properties_with_labels(search_hits[0]['id'])
        return KnowledgeGraph(**wikidata_kg)

    @staticmethod
    def get_wikidata_neighbors(wikidata_kg: KnowledgeGraph) -> Dict[str, Any]:
        """Look up every property name of ``wikidata_kg`` on Wikidata.

        Args:
            wikidata_kg: a ``KnowledgeGraph`` instance, or a raw dict with a
                ``'properties'`` key (accepted for backward compatibility).

        Returns:
            A dict mapping each property name that had a search hit to the
            result of ``get_all_properties_with_labels`` for the first hit.
            Property names without a hit are omitted.
        """
        # A model instance is not subscriptable, so only index raw dicts.
        if isinstance(wikidata_kg, dict):
            properties = wikidata_kg['properties']
        else:
            properties = wikidata_kg.properties

        one_hop_kg = {}
        for property_name in properties.keys():
            search_hits: List = WikidataSearch.search_wikidata(property_name)
            if len(search_hits) > 0:
                print(f"Found wikidata id for {property_name}")
                one_hop_kg[property_name] = get_all_properties_with_labels(search_hits[0]['id'])

        return one_hop_kg

    @staticmethod
    def create_parent_document_retriever(docs: List[Document]):
        """Index ``docs`` in a fresh vector store behind a parent-document retriever.

        See https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever

        Args:
            docs: documents to index (here: one document per reference property name).

        Returns:
            A ``(retriever, store, vectorstore)`` tuple.
        """
        import uuid  # local import: only needed to name the throwaway collection

        # This text splitter is used to create the child documents
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
        # The vectorstore to use to index the child chunks.
        # A unique collection name per call keeps chunks indexed for a previous
        # entity (or a previous run of the cell) out of this index; a fixed name
        # would reuse the same in-process collection every time.
        vectorstore = Chroma(
            collection_name=f"full_documents_{uuid.uuid4().hex}",
            embedding_function=OpenAIEmbeddings(),
        )
        # The storage layer for the parent documents
        store = InMemoryStore()
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
        )
        retriever.add_documents(docs, ids=None)  # add entity doc(s)

        return retriever, store, vectorstore

    @staticmethod
    def retrieve_relevant_property(entity_name, property_name, vectorstore, retriever):
        """Fetch the indexed chunk most similar to the predicted property name.

        Args:
            entity_name: unused; kept for interface compatibility.
            property_name: predicted property name, used verbatim as the query.
            vectorstore: vector store to search.
            retriever: unused; kept for interface compatibility.

        Returns:
            The ``page_content`` of the top search hit.

        Raises:
            ValueError: if the vector store returns no hits.
        """
        query = f"{property_name}"

        # Only the best hit is used, so ask for exactly one result.
        sub_docs = vectorstore.similarity_search(query, k=1)
        if not sub_docs:
            raise ValueError(
                f"No reference property could be retrieved for '{property_name}'"
            )

        return sub_docs[0].page_content

    @staticmethod
    def validate_statement_with_context(entity_label, predicted_property_name, predicted_property_value, context):
        '''Validate a statement about an entity

        a statement is a triple: entity_label --> predicted_property_name --> predicted_property_value
                             e.g Donald Trump --> wife --> Ivanka Trump

        Args:
            entity_label: label of the entity the statement is about.
            predicted_property_name: name of the predicted property.
            predicted_property_value: predicted value of that property.
            context: reference information interpolated into the prompt.

        Returns:
            The ``ValidatedProperty`` produced by the model.
        '''
        prompt = (
            "Using the given context as a reference, "
            "is the following predicted property valid for the given entity? "
            f"\nEntity Label: {entity_label}"
            f"\nPredicted Property Name: {predicted_property_name}"
            f"\nPredicted Property Value: {predicted_property_value}"
            f"\n\nContext {context}"
        )
        resp: ValidatedProperty = client.chat.completions.create(
            response_model=ValidatedProperty,
            messages=[{"role": "user", "content": prompt}],
            max_retries=2,
            model=MODEL,
        )
        return resp

    # NOTE(review): assuming KnowledgeGraph derives from pydantic's BaseModel, the
    # name `validate` shadows the deprecated `BaseModel.validate`; it is kept
    # unchanged so existing references to it keep working.
    @model_validator(mode='before')
    @classmethod
    def validate(cls, data: Any, info: ValidationInfo) -> Any:
        """Evaluate every predicted property against Wikidata before field validation.

        Args:
            data: raw input for the model; expected to be a mapping with
                ``'entity_label'`` and ``'properties'``.
            info: pydantic validation info (unused).

        Returns:
            A copy of ``data`` with ``'reference_knowledge_graph'`` and
            ``'validated_properties'`` filled in. Input that is not a dict, or
            that lacks the required keys, is returned untouched so that regular
            field validation reports the problem.
        """
        if not isinstance(data, dict):
            return data
        if 'entity_label' not in data or 'properties' not in data:
            return data

        # Work on a shallow copy so the caller's mapping is never modified.
        data = dict(data)
        validated_properties = []

        reference_kg = cls.get_wikidata(data['entity_label'])
        reference_properties = reference_kg.properties

        # 🚨 embedding the reference properties
        ref_property_names = [Document(page_content=p) for p in reference_properties.keys()]
        retriever, store, vectorstore = cls.create_parent_document_retriever(
            ref_property_names
        )

        for predicted_property_name, predicted_property_value in data['properties'].items():
            relevant_property_name = cls.retrieve_relevant_property(
                entity_name=data['entity_label'],
                property_name=predicted_property_name,
                vectorstore=vectorstore,
                retriever=retriever,
            )

            # Reference property-value pair used as context. The retrieved text
            # is a chunk, so it may not be an exact reference key (e.g. a name
            # longer than the child chunk size); fall back to an empty context
            # instead of failing with a KeyError.
            if relevant_property_name in reference_properties:
                relevant_property = {
                    relevant_property_name: reference_properties[relevant_property_name]
                }
            else:
                relevant_property = {}

            # EVALUATE ONE PROPERTY
            validated_properties.append(
                cls.validate_statement_with_context(
                    data['entity_label'],
                    predicted_property_name,
                    predicted_property_value,
                    context=relevant_property,
                )
            )

        data['reference_knowledge_graph'] = reference_kg
        data['validated_properties'] = validated_properties
        return data

    @model_validator(mode='after')
    def assert_all_properties_validated(self, info: ValidationInfo):
        """Ensure there is exactly one verdict per predicted property.

        Raises:
            ValueError: if the number of validated properties differs from the
                number of predicted properties.
        """
        if len(self.validated_properties) != len(self.properties):
            raise ValueError(
                "Number of properties validated does not match number of properties in the prediction knowledge base. " +
                f"Number of properties validated: {len(self.validated_properties)}, " +
                f"Number of properties in the text: {len(self.properties)}"
                )
        return self



# 🐕 Relevant Chunk Retrieval

In [6]:
# idx = 0
# results = []

# pred_kbs[idx]['reference_knowledge_graph'] = WikidataKGValidator.get_wikidata(
#     entity_label=pred_kbs[idx]['entity_label']
# )

# ref_property_names = [Document(p) for p in pred_kbs[idx]['reference_knowledge_graph'].properties.keys()]   # 🚨 embedding the reference properties
# retriever, store, vectorstore = WikidataKGValidator.create_parent_document_retriever(ref_property_names)

# retrieved_properties = []
# for predicted_property_name, predicted_property_value in pred_kbs[idx]['properties'].items():

#     relevant_property_name = WikidataKGValidator.retrieve_relevant_property(
#         entity_name=pred_kbs[idx]['entity_label'],
#         property_name=predicted_property_name, 
#         vectorstore=vectorstore,
#         retriever=retriever
#     )

#     # # get reference property-value
#     relevant_property = {
#         relevant_property_name: pred_kbs[idx]['reference_knowledge_graph'].properties[relevant_property_name]
#     }
#     retrieved_properties.append((predicted_property_name, predicted_property_value, relevant_property))
#     print((predicted_property_name, predicted_property_value, relevant_property))


In [7]:
# relevant_property = WikidataKGValidator.retrieve_relevant_property(
#     entity_name="Barack Obama",
#     property_name='Job', 
#     vectorstore=vectorstore,
#     retriever=retriever
# )
# relevant_property

In [8]:
# query = f"Barack Obama party"
# sub_docs = vectorstore.similarity_search(query)
# sub_docs

In [9]:
# for predicted_property_name, predicted_property_value, relevant_property in retrieved_properties:
#     print(f"\n----------\n❓Pred Property name: {predicted_property_name}\n🙋Predicted value: {predicted_property_value}\n⭐️ Relevant property -->  {relevant_property}\n")

# Evaluate!

In [10]:
idx = 0

# Build the evaluation input as a separate dict instead of editing `pred_kbs`
# in place, so the loaded predictions stay pristine and this cell gives the
# same result every time it is run.
# Two (presumably) wrong properties are added to see how the validator judges them.
prediction_with_noise = {
    **pred_kbs[idx],
    'properties': {
        **pred_kbs[idx]['properties'],
        'Bought Stocks in': ['Tesla', 'Nvidia', 'Hertz'],
        'Favourite Fast Food Chain': 'McDonalds',
    },
}

# Constructing the validator runs the full Wikidata lookup + LLM evaluation.
results = [WikidataKGValidator(**prediction_with_noise)]

querying wikidata with params: {'action': 'wbsearchentities', 'format': 'json', 'errorformat': 'plaintext', 'language': 'en', 'uselang': 'en', 'type': 'item', 'limit': 1, 'search': 'George Washington'}


In [11]:
# Show the per-property verdicts of the first evaluated prediction as plain dicts.
results[0].model_dump()['validated_properties']

[{'property_name': 'Name',
  'property_value': 'George Washington',
  'property_is_valid': 'Not enough information to say',
  'is_valid_reason': None,
  'error_message': 'Insufficient context information to validate the predicted property.'},
 {'property_name': 'Birth date',
  'property_value': 'February 22, 1732',
  'property_is_valid': 'Not enough information to say',
  'is_valid_reason': None,
  'error_message': 'Insufficient context to validate the predicted birth date property for the entity George Washington.'},
 {'property_name': 'Death date',
  'property_value': 'December 14, 1799',
  'property_is_valid': 'Not enough information to say',
  'is_valid_reason': None,
  'error_message': None},
 {'property_name': 'Occupation',
  'property_value': ['Founding Father',
   'Military Officer',
   'Politician',
   'First President of the United States'],
  'property_is_valid': 'Not enough information to say',
  'is_valid_reason': None,
  'error_message': None},
 {'property_name': 'Place o

In [12]:
# Convert every validated knowledge graph to a plain dict so it can be
# written out as one JSON record per line.
results_json = [validated_kg.model_dump() for validated_kg in results]
save_jsonl(
    results_json,
    '../../data/wikidata_kg_context_evaluation_results.jsonl',
)

Saved to f'../../data/wikidata_kg_context_evaluation_results.jsonl


# Look at our Evaluations

In [13]:
# Reload the evaluation results that were just written to disk.
# NOTE(review): this rebinds `results` to whatever `read_jsonl` returns
# (presumably plain dicts rather than WikidataKGValidator instances) — confirm
# before using model methods on its items further down.
results = read_jsonl('../../data/wikidata_kg_context_evaluation_results.jsonl')
len(results)

1