In [2]:
from distutils.command.config import config

from langchain_core.runnables import RunnableParallel
from numpy.f2py.cfuncs import callbacks
from sqlalchemy import false

import helper_tools.parser as parser
import importlib
import pandas as pd
import warnings
# Silence library warnings so the cell output stays readable.
warnings.filterwarnings("ignore")

# Re-import the local parser module so edits to it are picked up without restarting the kernel.
importlib.reload(parser)

# The parser returns the relation table, the entity table and the documents
# (presumably for the "train" split -- confirm against helper_tools.parser).
relation_df, entity_df, docs = parser.synthie_parser("train")

# Unique (label, URI) pairs, used later to translate URIs back into labels.
entity_set = entity_df.loc[:, ["entity", "entity_uri"]].drop_duplicates()
predicate_set_df = relation_df.loc[:, ["predicate", "predicate_uri"]].drop_duplicates()

Fetching 27 files:   0%|          | 0/27 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 7904.83it/s]


Uploading Entities to Qdrant.


100%|██████████| 46/46 [00:05<00:00,  8.65it/s]


Uploading Predicates to Qdrant.


100%|██████████| 29/29 [00:03<00:00,  8.66it/s]


In [3]:
# Preview the first rows of the loaded documents.
docs.head()

Unnamed: 0,docid,text
0,0,Corfe Castle railway station is a station on t...
1,1,Ricardo Lumengo is a Swiss politician. He was ...
2,2,The National Parks Project is a nature documen...
3,3,Lambda Mensae is a star in the constellation M...
4,4,"John Derek was an American actor, director and..."


In [4]:
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_ollama.embeddings import OllamaEmbeddings
import os

# Load environment variables (API keys etc.) via python-dotenv.
load_dotenv()

# Chat model reached through the SambaNova endpoint using the OpenAI client class.
model = ChatOpenAI(
    model_name="Meta-Llama-3.3-70B-Instruct",
    base_url="https://api.sambanova.ai/v1",
    api_key=os.environ.get("SAMBANOVA_API_KEY"),
)

# Embedding model used by the vector store for similarity search.
embeddings = OllamaEmbeddings(model="nomic-embed-text")

In [5]:
from langchain_qdrant import QdrantVectorStore

from qdrant_client import QdrantClient

# Connection settings for the Qdrant instance holding the label embeddings.
QDRANT_LOCATION = "localhost"
QDRANT_PORT = 6333
LABEL_COLLECTION = "wikidata_labels"

client = QdrantClient(QDRANT_LOCATION, port=QDRANT_PORT)

# Vector store over the label collection; queries are embedded with `embeddings`.
vector_store = QdrantVectorStore(
    client=client,
    embedding=embeddings,
    collection_name=LABEL_COLLECTION,
)

In [6]:
from langgraph.types import Command
from typing import Literal, TypedDict
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.prompts import PromptTemplate
import re
from typing import TypedDict
from langchain_core.messages import HumanMessage, AIMessage

class cIEState(TypedDict):
    """State dictionary passed between the nodes of the extraction graph."""
    # Input text the agents extract information from.
    text: str
    # Conversation history; every node appends the model response it produced.
    messages: list[HumanMessage | AIMessage]
    # Instruction the supervisor hands to the next agent; agents reset it to "".
    instruction: str
    # When true, nodes print their raw model output.
    verbose: bool

# Supervisor instructions sent as the first message of every run.
# Plain string on purpose: the text has no placeholders, so no f-string is needed.
# Option 1 lists all three agents so the routing rules agree with the agent descriptions.
system_prompt = """
You are the Supervisor of a conversation among multiple agents.
The conversation is about extracting information (Closed Information Extraction) from a user-provided text. The final output should only contain wikidata URIs instead of the labels of entities and relations. You can provide additional information to the agents using <instruction> tags.

Example Output: <relation>http://www.wikidata.org/entity/Q950380;http://www.wikidata.org/entity/P361;http://www.wikidata.org/entity/Q2576666</relation>

Agent Descriptions:
- entity_extraction_agent: Extracts entities from the text. Instructions can change the extraction behavior and focus of the agent. Do not instruct the agent with URIs. The agent has only access to your instruction and the text.
- relation_extraction_agent: Extracts relations from the text. Instructions can change the prompt of the called agent and can be used to input already extracted entity labels (i.e. <instruction>Please use the already extracted entities: [Olaf Scholz, Germany, Berlin]</instruction>). Do not instruct the agent with URIs. The agent has only access to your instruction and the text.
- uri_detection_agent: Please only use this, after entities and relation were extracted at least once. Returns possible wikidata URIs for entities and predicates based on similarity search. The instruction should be a list of search terms like predicate and entity labels, which the uri detection agent is searching for. For example: "Olaf Scholz, Germany, Berlin, is chancellor of, part of". It is recommended to use the agent at least once to search for the URIs for all possible entities and predicates. Do not instruct the agent with URIs. The agent has only access to your instruction and the text. If this agent return relations or entities that were not extracted before i.e. industry - URI of sport please try to use them as instruction for the relation extraction agent.

You have two options:
1. Call an agent using <goto>agent_name</goto>. Replace agent_name with one of entity_extraction_agent, relation_extraction_agent or uri_detection_agent. I.e. <goto>entity_extraction_agent</goto>.
2. Finish the conversation using <goto>FINISH</goto>. Please output the final relations in <relation> tags alongside with the <goto> tag.


Note:
- Do not provide any information yourself, instead use the agents for this.
- The first <goto> tag in your response will be executed.
- Therefore, do include exact one agent call in your response.
- If you output nothing, this will result in a NoneType Error.
- Please do not hallucinate any URI.


"""

def supervisor(state: cIEState) -> Command[Literal["entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent", END]]:
    """Ask the model for the next step and route the graph accordingly.

    The model reply is expected to contain a `<goto>` tag naming the next agent
    (or `FINISH`) and optionally an `<instruction>` tag for that agent.

    Args:
        state: Current graph state; `messages` is sent to the model and
            `verbose` controls printing of the raw reply.

    Returns:
        A `Command` routing to the requested agent, to `END` on `FINISH` (or on
        an empty reply), or back to the supervisor when no usable `<goto>` target
        is found. Except for the empty-reply case, the reply is appended to
        `messages` and the parsed instruction is stored in `instruction`.
    """
    response = model.invoke(state["messages"])

    if response.content == "":
        # Nothing to parse: dump the conversation for debugging and stop the run.
        print(state["messages"])
        return Command(goto=END)

    if state["verbose"]:
        print(f"-- START OF OUTPUT (supervisor) --\n\n", response.content, "\n\n-- END OF OUTPUT --\n\n")

    agent_names = ("entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent")

    # Default: ask the supervisor again when the reply holds no usable target.
    goto = "supervisor"
    # DOTALL lets a tag body span several lines; strip() tolerates padding inside the tag.
    goto_match = re.search(r'<goto>(.*?)</goto>', response.content, re.DOTALL)
    if goto_match:
        target = goto_match.group(1).strip()
        if target == "FINISH":
            goto = END
        elif target in agent_names:
            goto = target
        # Any other name is not a node of the graph, so the default retry applies.

    # Without DOTALL a multi-line instruction would not match and be silently dropped.
    instruction_match = re.search(r'<instruction>(.*?)</instruction>', response.content, re.DOTALL)
    instruction = instruction_match.group(1).strip() if instruction_match else ""

    return Command(goto=goto, update={"messages": state["messages"] + [response], "instruction": instruction})

def entity_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Run the entity-extraction prompt on the state's text and return to the supervisor.

    The prompt is filled with `state["text"]` and `state["instruction"]`. The model
    reply is appended to `messages` and `instruction` is reset to "".
    """
    extraction_prompt = PromptTemplate.from_template("""
    You are an agent tasked with extracting entities from a given text for linking to a knowledge graph. Your job is to capture every entity—both explicit and implicit—and return them as an array. This includes composite entities with modifiers (e.g., "professional football club"). Please output the entities as an array of strings. Do not include any further information or comment in the output.
    
    Example Output: [Olaf Scholz, chancellor, Germany]
    
    Guidelines:
    - An entity is a unique real-world object or concept represented as a node with its properties and relationships.
    - Extract every entity mentioned in the text, including those that are not immediately obvious.
    - For composite entities, include the full descriptive phrase and break it into its core components when appropriate. For example, "chancellor of Germany" should yield [chancellor, Germany] and "professional football club" should capture the descriptive phrase as needed.
    - For composite entities that include a date at the beginning or end, extract the date separately, the entity without the date, and the full composite (e.g., "2022 Winter Olympics" should result in [2022, 2022 Winter Olympics, Winter Olympics]).
    - Also, ensure that dates are extracted as entities.
    
    Instruction: {instruction}
    
    Text: {text}
    """)
    prompt_inputs = {"text": state["text"], "instruction": state["instruction"]}
    reply = (extraction_prompt | model).invoke(prompt_inputs)

    if state["verbose"]:
        print("-- START OF OUTPUT (entity_extraction_agent) --\n\n", reply.content, "\n\n-- END OF OUTPUT --\n\n")

    # Hand control back to the supervisor with the reply added to the history.
    history = state["messages"] + [reply]
    return Command(goto="supervisor", update={"messages": history, "instruction": ""})

def relation_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Run the relation-extraction prompt on the state's text and return to the supervisor.

    The prompt is filled with `state["text"]` and `state["instruction"]`. The model
    reply is appended to `messages` and `instruction` is reset to "".
    """
    relation_prompt = PromptTemplate.from_template(
        """
        You are a relation extraction agent. Your task is to analyze the provided text and extract all semantic relations present. Each relation must be output in the exact format:
        
        <relation>subject;predicate;object</relation>
        
        (For example: <relation>Olaf Scholz;is chancellor of;Germany</relation>).
        
        Guidelines:
        - **Extraction Scope:** Extract only the relations explicitly mentioned in the text. Additionally, if the text implies a relation or if a relation can be inferred using the provided entity list, include that relation.
        - **Utilize Provided Entities:** Use the provided list of extracted entities to ensure that all relevant relations are captured. For example, if "technology" is in the list and the text indicates that the subject is a technology company, you must output: <relation>Apple;industry;technology</relation>.
        - **Attribute Relations:** If an entity is described by a characteristic or category (e.g., renowned film director, prestigious university), automatically extract the corresponding attribute relation. For example, if the text states "Steven Spielberg is a renowned film director", extract: <relation>Steven Spielberg;profession;film director</relation>.
        - **Formatting:** Each relation must strictly follow the format <relation>subject;predicate;object</relation> with no additional text or commentary.
        - **Accuracy:** Only include relations that are clearly supported by the text or can be confidently inferred using the provided entity list.
        
        Instruction: {instruction}
        
        Text: {text}
        """
    )
    prompt_inputs = {"text": state["text"], "instruction": state["instruction"]}
    reply = (relation_prompt | model).invoke(prompt_inputs)

    if state["verbose"]:
        print("-- START OF OUTPUT (relation_extraction_agent) --\n\n", reply.content, "\n\n-- END OF OUTPUT --\n\n")

    # Hand control back to the supervisor with the reply added to the history.
    history = state["messages"] + [reply]
    return Command(goto="supervisor", update={"messages": history, "instruction": ""})

def uri_detection_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Look up candidate URIs for the supervisor's search terms and summarise them.

    `state["instruction"]` is treated as a comma-separated list of search terms.
    For each term the three most similar vector-store entries are collected; the
    combined listing is then passed to the model, which produces a term-to-URI mapping.

    Args:
        state: Current graph state; reads `instruction`, `messages` and `verbose`.

    Returns:
        A `Command` routing back to the supervisor with the model reply appended
        to `messages` and `instruction` reset to "".
    """
    # Strip padding around each term and skip blanks, so "a, b" searches for "b"
    # (not " b") and an empty instruction does not trigger a search for "".
    search_terms = [term.strip() for term in state["instruction"].split(",") if term.strip()]

    sections = []
    for term in search_terms:
        hits = [
            {"label": doc.page_content, "uri": doc.metadata["uri"]}
            for doc in vector_store.similarity_search(term, k=3)
        ]
        sections.append(f'Most Similar Detection Results for {term}:{hits}\n\n')
    # One hit per line keeps the listing readable for the formatting model.
    tool_output = "".join(sections).replace("},", "},\n")

    if not search_terms:
        # Say so explicitly instead of handing the formatting model an empty listing.
        tool_output = "No search terms were provided, so no URIs could be looked up.\n\n"

    if state["verbose"]:
        print(f"-- START OF OUTPUT (uri_detection_tool) --\n\n", tool_output, "\n\n-- END OF OUTPUT --\n\n")

    prompt_template = PromptTemplate.from_template(
        """
        You are a formatting agent. Your task is to check and format the output of the URI detection tool. The tool will give a response like this:
        Most Similar Detection Result for Olaf Scholz: ('label': Angela Merkel, 'uri': 'http://www.wikidata.org/entity/Q567)
        
        Your task is to check the response and output an overall mapping of search terms to URIs. If something doesn't match, please response the non mapping search term with the advise, that those might not be present in the knowledge graph.
        
        URI detection tool response:
        
        {response}
        """
    )

    chain = prompt_template | model
    reply = chain.invoke({"response": tool_output})

    if state["verbose"]:
        print(f"-- START OF OUTPUT (uri_detection_agent) --\n\n", reply.content, "\n\n-- END OF OUTPUT --\n\n")

    return Command(goto="supervisor", update={"messages": state["messages"] + [reply], "instruction": ""})
    
# Assemble the graph: one node per agent function, all sharing the cIEState schema.
builder = StateGraph(cIEState)
for node_fn in (supervisor, entity_extraction_agent, relation_extraction_agent, uri_detection_agent):
    builder.add_node(node_fn)

# Only the entry edge is declared here; the nodes choose the next step
# themselves through the Command objects they return.
builder.add_edge(START, "supervisor")

graph = builder.compile()

In [8]:
from langfuse.callback import CallbackHandler
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
import os

load_dotenv()
# Tracing callback; credentials and host are read from the environment.
langfuse_handler = CallbackHandler(
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)

# Run the graph on one document with verbose node output.
target_doc = docs.iloc[1]
doc_id = target_doc["docid"]
response = graph.invoke(
    {
        "text": target_doc["text"],
        # Explicit message objects: a bare string would be coerced to a human
        # message, so the supervisor instructions would not carry the system role.
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=target_doc["text"]),
        ],
        "instruction": "",
        "verbose": True,
    },
    config={"callbacks": [langfuse_handler]},
)

-- START OF OUTPUT (supervisor) --

 <instruction>Please extract entities from the text.</instruction>
<goto>entity_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (entity_extraction_agent) --

 [Ricardo Lumengo, Swiss politician, Switzerland, Fribourg, Biel/Bienne, Bern, Kongo language] 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>Please use the already extracted entities: Ricardo Lumengo, Swiss, Switzerland, Fribourg, Biel/Bienne, Bern, Kongo language</instruction>
<goto>relation_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (relation_extraction_agent) --

 <relation>Ricardo Lumengo;nationality;Swiss</relation>
<relation>Ricardo Lumengo;birthplace;Fribourg</relation>
<relation>Ricardo Lumengo;residence;Biel/Bienne</relation>
<relation>Ricardo Lumengo;workplace;Bern</relation>
<relation>Ricardo Lumengo;language;Kongo language</relation>
<relation>Ricardo Lumengo;profession;politician</relation>
<relation>Ricardo Lume

In [17]:
def get_uri_labels(df):
    """Append human-readable labels for the URIs of each triple in `df`.

    Labels are looked up in the notebook-level `entity_set` (subjects, objects)
    and `predicate_set_df` (predicates). When a URI has several labels, the first
    one in the lookup table is used.

    Args:
        df: DataFrame with `subject_uri`, `predicate_uri` and `object_uri` columns.

    Returns:
        A new DataFrame: `df` with its index reset, plus `subject`, `predicate`
        and `object` columns. URIs that are missing or not found are labelled
        "Unknown"; object values containing "^^" are copied through unchanged.
    """
    def first_label_by_uri(lookup_df, uri_col, label_col):
        # keep="first" mirrors taking the first matching row of the lookup table.
        unique = lookup_df.drop_duplicates(subset=uri_col, keep="first")
        return dict(zip(unique[uri_col], unique[label_col]))

    def label_for(mapping, uri):
        # Missing values (None/NaN) never match a URI.
        if pd.isna(uri):
            return "Unknown"
        return mapping.get(uri, "Unknown")

    def object_label(uri):
        # isinstance guards the substring test: `"^^" in NaN` would raise TypeError.
        if isinstance(uri, str) and "^^" in uri:
            return uri
        return label_for(entity_labels, uri)

    # Build the lookups once per call instead of scanning the tables for every row.
    entity_labels = first_label_by_uri(entity_set, "entity_uri", "entity")
    predicate_labels = first_label_by_uri(predicate_set_df, "predicate_uri", "predicate")

    labels = pd.DataFrame({
        "subject": [label_for(entity_labels, uri) for uri in df["subject_uri"]],
        "predicate": [label_for(predicate_labels, uri) for uri in df["predicate_uri"]],
        "object": [object_label(uri) for uri in df["object_uri"]],
    })
    # Reset the index so both frames align positionally in the concat.
    return pd.concat([df.reset_index(drop=True), labels], axis=1)

In [18]:
# Parse the `<relation>s;p;o</relation>` tags of the final message into URI triples.
final_content = response["messages"][-1].content
relation_list = []
for raw_relation in re.findall(r"<relation>(.*?)</relation>", final_content):
    # Split at most twice so a ';' inside the object value stays part of the object.
    parts = [part.strip() for part in raw_relation.split(";", 2)]
    # Pad short relations with None so every row has exactly three cells;
    # such a row simply never matches a gold triple.
    relation_list.append(parts + [None] * (3 - len(parts)))
pred_relation_df = pd.DataFrame(relation_list, columns=["subject_uri", "predicate_uri", "object_uri"]).drop_duplicates()
get_uri_labels(pred_relation_df)

Unnamed: 0,subject_uri,predicate_uri,object_uri,subject,predicate,object
0,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P27,http://www.wikidata.org/entity/Q39,Ricardo_Lumengo,country of citizenship,Switzerland
1,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P19,http://www.wikidata.org/entity/Q36378,Ricardo_Lumengo,Unknown,Fribourg
2,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P551,http://www.wikidata.org/entity/Q1034,Ricardo_Lumengo,residence,Biel/Bienne
3,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P937,http://www.wikidata.org/entity/Q70,Ricardo_Lumengo,work location,Bern
4,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P1412,http://www.wikidata.org/entity/Q33702,Ricardo_Lumengo,"languages spoken, written or signed",Kongo_language
5,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P106,http://www.wikidata.org/entity/Q82955,Ricardo_Lumengo,occupation,Politician
6,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P551,http://www.wikidata.org/entity/Q39,Ricardo_Lumengo,residence,Switzerland


In [19]:
# Reference triples of the selected document, shown with their labels.
triple_columns = ["subject_uri", "predicate_uri", "object_uri"]
doc_relation_df = relation_df.loc[relation_df["docid"] == doc_id, triple_columns]
get_uri_labels(doc_relation_df)

Unnamed: 0,subject_uri,predicate_uri,object_uri,subject,predicate,object
0,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P1321,http://www.wikidata.org/entity/Q36378,Ricardo_Lumengo,place of origin (Switzerland),Fribourg
1,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P27,http://www.wikidata.org/entity/Q39,Ricardo_Lumengo,country of citizenship,Switzerland
2,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P551,http://www.wikidata.org/entity/Q1034,Ricardo_Lumengo,residence,Biel/Bienne
3,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P937,http://www.wikidata.org/entity/Q70,Ricardo_Lumengo,work location,Bern
4,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P1412,http://www.wikidata.org/entity/Q33702,Ricardo_Lumengo,"languages spoken, written or signed",Kongo_language


In [20]:
# Predicted triples that also occur among the document's reference triples (exact URI match).
triple_key = ["subject_uri", "predicate_uri", "object_uri"]
correct_relation_df = pd.merge(pred_relation_df, doc_relation_df[triple_key], how="inner", on=triple_key)
get_uri_labels(correct_relation_df)

Unnamed: 0,subject_uri,predicate_uri,object_uri,subject,predicate,object
0,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P27,http://www.wikidata.org/entity/Q39,Ricardo_Lumengo,country of citizenship,Switzerland
1,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P551,http://www.wikidata.org/entity/Q1034,Ricardo_Lumengo,residence,Biel/Bienne
2,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P937,http://www.wikidata.org/entity/Q70,Ricardo_Lumengo,work location,Bern
3,http://www.wikidata.org/entity/Q677663,http://www.wikidata.org/entity/P1412,http://www.wikidata.org/entity/Q33702,Ricardo_Lumengo,"languages spoken, written or signed",Kongo_language


In [11]:
def evaluate(pred_relation_df, doc_id, verbose=False):
    """Score predicted triples for one document against its reference triples.

    Reference triples are the rows of the notebook-level `relation_df` whose
    `docid` equals `doc_id`. A prediction counts as correct when subject,
    predicate and object URI all match a reference triple exactly.

    Args:
        pred_relation_df: DataFrame with `subject_uri`, `predicate_uri` and
            `object_uri` columns.
        doc_id: Document id used to select the reference triples.
        verbose: When true, print the three scores.

    Returns:
        Tuple `(precision, recall, f1_score)`. Precision is 0.0 when there are no
        predictions and recall is 0.0 when there are no reference triples, so an
        empty model answer yields zeros rather than a ZeroDivisionError.
    """
    triple_cols = ["subject_uri", "predicate_uri", "object_uri"]

    # Compare sets of triples: repeated rows on either side would multiply in the
    # merge and could push the scores above 1.
    gold_df = relation_df.loc[relation_df["docid"] == doc_id, triple_cols].drop_duplicates()
    pred_df = pred_relation_df[triple_cols].drop_duplicates()

    if len(pred_df) == 0 or len(gold_df) == 0:
        n_correct = 0
    else:
        n_correct = len(pred_df.merge(gold_df, on=triple_cols, how="inner"))

    precision = n_correct / len(pred_df) if len(pred_df) else 0.0
    recall = n_correct / len(gold_df) if len(gold_df) else 0.0
    if precision + recall == 0:
        f1_score = 0.0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)

    if verbose:
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        print(f"F1: {f1_score}")

    return precision, recall, f1_score

In [22]:
# Score the predicted triples of the selected document and print the metrics.
evaluate(pred_relation_df, doc_id, verbose=True)

Precision: 0.5714285714285714
Recall: 0.8
F1: 0.6666666666666666


(0.5714285714285714, 0.8, 0.6666666666666666)

# Evaluation on Test

In [12]:
from tqdm import tqdm
from langchain_core.messages import HumanMessage, SystemMessage

# Upper bound on evaluated documents; each document triggers several model calls.
MAX_EVAL_DOCS = 10
TRIPLE_COLUMNS = ["subject_uri", "predicate_uri", "object_uri"]

# Select the documents up front so the count does not depend on index labels
# (the earlier `i >= 10` check ran after processing and let an eleventh document through).
eval_docs = docs.head(MAX_EVAL_DOCS)
evaluation_rows = []

for row_label, target_doc in tqdm(eval_docs.iterrows(), total=len(eval_docs)):
    doc_id = target_doc["docid"]
    response = graph.invoke(
        {
            "text": target_doc["text"],
            # Explicit message objects: a bare string would be coerced to a human
            # message, so the supervisor instructions would not carry the system role.
            "messages": [
                SystemMessage(content=system_prompt),
                HumanMessage(content=target_doc["text"]),
            ],
            "instruction": "",
            "verbose": False,
        },
        config={"callbacks": [langfuse_handler]},
    )

    relation_list = []
    for raw_relation in re.findall(r"<relation>(.*?)</relation>", response["messages"][-1].content):
        # Split at most twice so a ';' inside the object value stays part of the object.
        parts = [part.strip() for part in raw_relation.split(";", 2)]
        # Pad short relations with None: the row stays a (wrong) prediction
        # instead of breaking the three-column frame.
        relation_list.append(parts + [None] * (3 - len(parts)))
    pred_relation_df = pd.DataFrame(relation_list, columns=TRIPLE_COLUMNS).drop_duplicates()

    evaluation_rows.append([doc_id, *evaluate(pred_relation_df, doc_id, verbose=False)])

# Build the result frame once, from the collected rows.
evaluation_df = pd.DataFrame(evaluation_rows, columns=["docid", "precision", "recall", "f1_score"])
evaluation_df

10it [00:58,  5.87s/it]


Unnamed: 0,docid,precision,recall,f1_score
0,0,0.0,0.0,0.0
1,1,0.571429,0.8,0.666667
2,2,0.0,0.0,0.0
3,3,0.375,0.5,0.428571
4,4,0.333333,1.0,0.5
5,5,0.166667,0.333333,0.222222
6,6,0.0,0.0,0.0
7,7,0.25,0.25,0.25
8,8,0.1,0.5,0.166667
9,9,0.5,0.2,0.285714


In [16]:
# Macro averages: unweighted mean of the per-document scores.
for metric_label, metric_column in (("F1", "f1_score"), ("Precision", "precision"), ("Recall", "recall")):
    print(f'{metric_label} (Macro Avg.): {evaluation_df[metric_column].mean()}')

F1 (Macro Avg.): 0.2519841269841269
Precision (Macro Avg.): 0.22964285714285718
Recall (Macro Avg.): 0.35833333333333334
