In [14]:
from utils import read_jsonl, save_jsonl
import pandas as pd
from pydantic import BaseModel, model_validator, field_validator, Field, ValidationInfo
from typing import List, Dict, Union, Any, Optional, Literal
import instructor
from openai import OpenAI
import os
import json
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents.base import Document
from langchain_openai import OpenAIEmbeddings

from wikidata_search import WikidataSearch, get_all_properties_with_labels
from data_classes import KnowledgeGraph, ValidatedProperty
from langchain_community.tools.wikidata.tool import WikidataAPIWrapper, WikidataQueryRun
from langchain.tools import DuckDuckGoSearchRun

In [15]:
# q = 'st__louis_cardinals'
# q = " ".join(q.split('_'))
# WikidataSearch()(q)

In [16]:
client = instructor.patch(OpenAI(api_key=os.environ['OPENAI_API_KEY']))
MODEL = "gpt-3.5-turbo-0125"

# 🧠 Load data

In [17]:
def read_nell_sports(file_path) -> List[Dict]:
    with open(file_path) as f:
        triples = []
        for line in f:
            # Strip the trailing period and split the line
            parts = line.rstrip('.').split('"')
            # Extract the subject, relation (assumed), and object
            subject = parts[1]
            object_ = parts[3]
            triples.append({'subject': subject, 'relation': 'teamplayssport', 'object': object_})
    return triples

neg_triples = read_nell_sports('./NELL/neg.txt')
pos_triples = read_nell_sports('./NELL/pos.txt')
neg_triples[:15]

[{'subject': 'chicago_cubs',
  'relation': 'teamplayssport',
  'object': 'basketball'},
 {'subject': 'chicago_white_sox',
  'relation': 'teamplayssport',
  'object': 'soccer'},
 {'subject': 'derby_county',
  'relation': 'teamplayssport',
  'object': 'football'},
 {'subject': 'broncos', 'relation': 'teamplayssport', 'object': 'football'},
 {'subject': 'charlotte_bobcats',
  'relation': 'teamplayssport',
  'object': 'softball'},
 {'subject': 'chicago_fire', 'relation': 'teamplayssport', 'object': 'golf'},
 {'subject': 'cubbies', 'relation': 'teamplayssport', 'object': 'hockey'},
 {'subject': 'arkansas_razorbacks',
  'relation': 'teamplayssport',
  'object': 'baseball'},
 {'subject': 'boston_college',
  'relation': 'teamplayssport',
  'object': 'baseball'},
 {'subject': 'anaheim_ducks',
  'relation': 'teamplayssport',
  'object': 'football'},
 {'subject': 'alabama_crimson_tide',
  'relation': 'teamplayssport',
  'object': 'golf'},
 {'subject': 'dallas_stars',
  'relation': 'teamplayssport

In [18]:

# wrapper = WikidataAPIWrapper()
# wrapper.top_k_results = 1
# wikidata = WikidataQueryRun(api_wrapper=wrapper)

# q = 'st__louis_cardinals'
# # q = " ".join(q.split('_'))
# print(wikidata.run(q))

# 🪬 Define Evaluation Model

In [19]:

# import pandas as pd

# df = pd.read_csv('NELL.08m.100.SSFeedback.csv.gz', compression='gzip', on_bad_lines='skip', sep='\t')
# print(len(df))
# df.head()

In [20]:

class ValidatedProperty(BaseModel):
    entity_label: str
    property_name: str
    property_value: Any

    property_is_valid: Literal[True, False, "Not enough information to say"] = Field(
      ...,
        description="Whether the property is generally valid, judged against " +
                    "the given context.",
    )
    is_valid_reason: Optional[str] = Field(
        None, description="The reason why the property is valid if it is indeed valid."
    )
    error_message: Optional[str] = Field(
        None, description="The error message if either property_name and/or property_value is not valid."
    )


class WikidataKGValidator(BaseModel):

    triples: List
    validated_properties: List[ValidatedProperty] = []


    @staticmethod
    def get_wikidata(entity_label, wikidata_wrapper):

        q = entity_label
        q = " ".join(q.split('_'))
        return wikidata_wrapper.run(q)  # a string of the wikidata page

    
    @staticmethod
    def get_wikidata_neighbors(wikidata_kg: KnowledgeGraph) -> List[KnowledgeGraph]:
        one_hop_kg = {}
        for property in wikidata_kg['properties'].keys():
            qid: List = WikidataSearch.search_wikidata(property)
            if len(qid) > 0:
                print(f"Found wikidata id for {property}")
                one_hop_kg[property] = get_all_properties_with_labels(qid[0]['id'])

        return one_hop_kg



    @staticmethod
    def create_parent_document_retriever(docs: List[Document]):
        # https://python.langchain.com/docs/modules/data_connection/retrievers/parent_document_retriever

        # This text splitter is used to create the child documents
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
        # The vectorstore to use to index the child chunks
        vectorstore = Chroma(
            collection_name="full_documents", embedding_function=OpenAIEmbeddings()
        )
        # The storage layer for the parent documents
        store = InMemoryStore()
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
            # parent_splitter=parent_splitter,
        )
        retriever.add_documents(docs, ids=None) # add entity doc(s)

        # list(store.yield_keys())   # see how many chunks it's created

        return retriever, store, vectorstore

    @staticmethod
    def retrieve_relevant_property(entity_name, property_name, vectorstore, retriever):
        '''Fetch the most similar chunk to predicted property name'''

        query = f"{property_name}"

        sub_docs = vectorstore.similarity_search(query)

        relevant_property = sub_docs[0].page_content
        return relevant_property

    @staticmethod
    def validate_statement_with_context(entity_label, predicted_property_name, predicted_property_value, context):
        '''Validate a statement about an entity

        a statement is a triple: entity_label --> predicted_property_name --> predicted_property_value
                             e.g Donald Trump --> wife --> Ivanka Trump
        
        '''
        resp: ValidatedProperty = client.chat.completions.create(
                        response_model=ValidatedProperty,
                        messages=[
                            {
                                "role": "user",
                                "content": f"Using the given context as a reference, " +
                                "is the following predicted property valid for the given entity? " +
                                f"\nEntity Label: {entity_label}" +
                                f"\nPredicted Property Name: {predicted_property_name}" +
                                f"\nPredicted Property Value: {predicted_property_value}" +
                                f"\n\nContext {context}"
                            }
                        ],
                        max_retries=2,
                        model=MODEL,
                    )
        return resp

    @model_validator(mode='before')
    def validate(self, context) -> "WikidataKGValidator":

        self['validated_properties'] = []

        wrapper = WikidataAPIWrapper()
        wrapper.top_k_results = 1
        wikidata_wrapper = WikidataQueryRun(api_wrapper=wrapper)

        for triple in self['triples']:

            subject, relation, object = triple['subject'], triple['relation'], triple['object']

            wikidata_reference = WikidataKGValidator.get_wikidata(subject, wikidata_wrapper)

            # EVALUATE ONE PROPERTY
            resp = WikidataKGValidator.validate_statement_with_context(
                entity_label=subject, 
                predicted_property_name=relation, 
                predicted_property_value=object, 
                context=wikidata_reference
            )

            self['validated_properties'].append(resp)
        return self


    @model_validator(mode='after')
    def assert_all_properties_validated(self, info: ValidationInfo):
        if len(self.validated_properties) != len(self.triples):
            raise ValueError(
                "Number of properties validated does not match number of properties in the prediction knowledge base. " +
                f"Number of properties validated: {len(self.validated_properties)}, " +
                f"Number of properties in the text: {len(self.triples)}"
                )
        return self



In [32]:
wrapper = WikidataAPIWrapper()
wrapper.top_k_results = 1
wikidata_wrapper = WikidataQueryRun(api_wrapper=wrapper)

wikidata_reference = WikidataKGValidator.get_wikidata('Derby FC', wikidata_wrapper)
print(wikidata_reference)

Result Q578856:
Label: FC Derby
Description: association football club
instance of: association football club
country: Cape Verde
inception: 1929-08-05
sport: association football
headquarters location: Mindelo


# Evaluate!

In [21]:
neg_results = []

neg_results.append(WikidataKGValidator(**{'triples': neg_triples[:7]}))

neg_results[0].model_dump()['validated_properties']

[{'entity_label': 'chicago_cubs',
  'property_name': 'teamplayssport',
  'property_value': 'basketball',
  'property_is_valid': False,
  'is_valid_reason': 'The Chicago Cubs are specifically known as a baseball team and they play the sport of baseball as mentioned in the context.',
  'error_message': None},
 {'entity_label': 'chicago_white_sox',
  'property_name': 'teamplayssport',
  'property_value': 'soccer',
  'property_is_valid': False,
  'is_valid_reason': "The context information describes the Chicago White Sox as a baseball team playing baseball, not soccer. Therefore, the predicted property 'teamplayssport' with a value of 'soccer' is not valid for the entity.",
  'error_message': None},
 {'entity_label': 'derby_county',
  'property_name': 'teamplayssport',
  'property_value': 'football',
  'property_is_valid': True,
  'is_valid_reason': "The context result Q19470 indicates that Derby County F.C. is an association football club, which means the predicted property 'teamplaysspor

In [25]:
pos_results = []

pos_results.append(WikidataKGValidator(**{'triples': pos_triples[:7]}))

pos_results[0].model_dump()['validated_properties']

[{'entity_label': 'texans',
  'property_name': 'teamplayssport',
  'property_value': 'hockey',
  'property_is_valid': False,
  'is_valid_reason': None,
  'error_message': "The predicted property 'teamplayssport' with a value of 'hockey' is not valid for the entity 'texans'. The context provided indicates that the Texans are an American football team, not a hockey team."},
 {'entity_label': 'gonzaga_bulldogs',
  'property_name': 'teamplayssport',
  'property_value': 'basketball',
  'property_is_valid': True,
  'is_valid_reason': "The context shows that the Gonzaga Bulldogs men's basketball team plays the sport of basketball.",
  'error_message': None},
 {'entity_label': 'esu_hornets',
  'property_name': 'teamplayssport',
  'property_value': 'basketball',
  'property_is_valid': 'Not enough information to say',
  'is_valid_reason': None,
  'error_message': None},
 {'entity_label': 'columbus_blue_jackets',
  'property_name': 'teamplayssport',
  'property_value': 'hockey',
  'property_is_va

In [30]:
tp = 0
fp = 0
tn = 0
fn = 0 
for val in pos_results[0].model_dump()['validated_properties']:
    if val['property_is_valid']:    # property is correctly marked as valid
        tp += 1
    else:                           # property is incorrectly marked as invalid
        fn += 1
for val in neg_results[0].model_dump()['validated_properties']:
    if val['property_is_valid']:    # property is incorrectly marked as valid
        fp += 1
    else:                           # property is correctly marked as invalid
        tn += 1

precision = (tp / (tp + fp)) if tp + fp > 0 else 0
recall = (tp / (tp + fn)) if tp + fn > 0 else 0
f1_score = ((2 * (precision * recall)) / (precision + recall)) if precision + recall > 0 else 0
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1_score}")
print("----------")

Precision: 0.8
Recall: 0.5714285714285714
F1 Score: 0.6666666666666666
----------


In [26]:
results_json = [r.model_dump() for r in neg_results] + [r.model_dump() for r in pos_results]
save_jsonl(results_json, '../../data/wikidata_kg_context_evaluation_results.jsonl')

Saved to f'../../data/wikidata_kg_context_evaluation_results.jsonl


# Look at our Evaluations

In [27]:
results = read_jsonl('../../data/wikidata_kg_context_evaluation_results.jsonl')
len(results)

2