In [117]:
from langchain_core.runnables import RunnableParallel

import helper_tools.parser as parser
import importlib
import pandas as pd
# Dataset split to load; the evaluation at the end of the notebook runs on it.
DATA_SPLIT = "train"

# Reload the local parser module so edits to helper_tools/parser.py are picked
# up without restarting the kernel.
importlib.reload(parser)

# The parser returns three DataFrames. Judging from how later cells use them:
# relation triples (docid, predicate/URI columns), entity labels with URIs, and
# the raw documents (docid, text).
relation_df, entity_df, docs = parser.redfm_parser(DATA_SPLIT)

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 1916.87it/s]


In [118]:
# Peek at the first five documents (docid + text).
docs.head(n=5)

Unnamed: 0,docid,text
0,1755846-1,CBS Corporation comprised the over-the-air tel...
1,1755846-2,The second merger between CBS Corporation and ...
2,1701411-0,Club Sportivo Cienciano is a professional foot...
3,1854133-1,It is the seat of a municipality with 203.30 k...
4,1602703-0,Bad Ischl is a spa town in Austria. It lies in...


In [119]:
from dotenv import load_dotenv
from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
from gen_ai_hub.proxy.langchain.init_models import init_embedding_model

# Load environment variables (proxy credentials etc.) from a local .env file.
load_dotenv()

# Model identifiers used throughout the notebook; change them here to experiment.
LLM_MODEL_NAME = "meta--llama3-70b-instruct"
EMBEDDING_MODEL_NAME = "text-embedding-3-large"

proxy_client = get_proxy_client('gen-ai-hub')
model = ChatOpenAI(proxy_model_name=LLM_MODEL_NAME, proxy_client=proxy_client)
embeddings = init_embedding_model(EMBEDDING_MODEL_NAME)

In [120]:
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Size the FAISS index from one embedding so it always matches the configured
# embedding model's dimensionality.
embedding_dim = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(embedding_dim)

vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

# Unique (label, URI) pairs. Both frames are reused later to map URIs back to labels.
entity_set = entity_df[['entity', 'entity_uri']].drop_duplicates()
predicate_set_df = relation_df[["predicate", "predicate_uri"]].drop_duplicates()

# One Document per label, carrying its URI as metadata. The comprehension
# variables intentionally avoid the name `index`: reusing it as a loop variable
# would overwrite the FAISS index bound above.
documents = [
    Document(page_content=label, metadata={"uri": uri})
    for label, uri in zip(entity_set["entity"], entity_set["entity_uri"])
] + [
    Document(page_content=label, metadata={"uri": uri})
    for label, uri in zip(predicate_set_df["predicate"], predicate_set_df["predicate_uri"])
]

faiss_document_ids = vector_store.add_documents(documents=documents)

In [172]:
from langgraph.types import Command
from typing import Literal, TypedDict
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.prompts import PromptTemplate
import re
from typing import TypedDict
from langchain_core.messages import HumanMessage, AIMessage

from langchain_core.messages import BaseMessage


class cIEState(TypedDict):
    """Graph state shared by the supervisor and the worker agents.

    No key has a reducer, so every node returns the *complete* updated
    ``messages`` list (previous history plus its own output).
    """

    # Input text the agents extract information from.
    text: str
    # Conversation history shown to the supervisor. Besides chat messages it may
    # hold plain strings (e.g. the raw text passed at invocation, or the URI
    # search results); LangChain converts bare strings to human messages when
    # the model is invoked.
    messages: list[BaseMessage | str]
    # Instruction parsed from the supervisor's latest <instruction> tag; each
    # worker resets it to "" after use.
    instruction: str

# System prompt for the supervisor LLM. A plain (non-f) string: it has no placeholders.
system_prompt = """
You are the Supervisor of a conversation among multiple agents.
The conversation is about extracting information (Closed Information Extraction) from a user-provided text. The final output should only contain Wikidata URIs instead of the labels of entities and relations. You can provide additional information to the agents using <instruction> tags, e.g. <instruction>Search additionally for entities that are not obvious.</instruction>.

Example Output: <relation>Q950380;P361;Q2576666</relation>

Agent Descriptions:
- entity_extraction_agent: Extracts entities from the text. Instructions can change the prompt of the called agent. The agent only has access to your instruction and the text.
- relation_extraction_agent: Extracts relations from the text. Instructions can change the prompt of the called agent and can be used to input already extracted entity labels (e.g. <instruction>Olaf Scholz, Germany, Berlin</instruction>). DO NOT INPUT WIKIDATA URIs like Q950380 or P361. The agent only has access to your instruction and the text.
- uri_detection_agent: Returns candidate Wikidata URIs for entities and relations based on similarity search. The instruction must be a simple comma-separated list of search terms, e.g. <instruction>Olaf Scholz, Germany, Berlin</instruction>. Each result carries a score that is a distance: a LOWER score means a closer match. DO NOT INPUT WIKIDATA URIs like Q950380 or P361. The agent only has access to your instruction.

You have two options:
1. Call an agent using <goto>agent_name</goto>. Replace agent_name with entity_extraction_agent, relation_extraction_agent or uri_detection_agent, e.g. <goto>entity_extraction_agent</goto>.
2. Finish the conversation using <goto>FINISH</goto>. Output the final relations in <relation> tags together with the <goto> tag.


Notes:
- Do not provide any information yourself; use the agents for this.
- Only the first <goto> tag in your response is executed, so include exactly one <goto> tag per response.
- A response without a <goto> tag is not executed and you will be asked again.


"""

# Node names the supervisor may route to (besides FINISH, which maps to END).
SUPERVISOR_TARGETS = ("entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent")


def _parse_supervisor_response(content: str) -> tuple[str | None, str]:
    """Extract the routing target and the instruction from a supervisor reply.

    Returns ``(goto, instruction)``. ``goto`` is ``END`` for FINISH, one of
    ``SUPERVISOR_TARGETS`` for a valid agent call, or ``None`` when the reply
    contains no usable <goto> tag. ``instruction`` is "" when no
    <instruction> tag is present.
    """
    goto = None
    goto_match = re.search(r"<goto>(.*?)</goto>", content, re.DOTALL)
    if goto_match:
        # Tolerate whitespace the model may put inside the tag.
        target = goto_match.group(1).strip()
        if target == "FINISH":
            goto = END
        elif target in SUPERVISOR_TARGETS:
            goto = target

    # DOTALL: instructions (e.g. long search-term lists) may span several lines.
    instruction_match = re.search(r"<instruction>(.*?)</instruction>", content, re.DOTALL)
    instruction = instruction_match.group(1).strip() if instruction_match else ""
    return goto, instruction


def supervisor(state: cIEState) -> Command[Literal["entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent", END]]:
    """Ask the LLM which agent to call next and route the graph accordingly.

    The model sees the full message history. Its reply should contain one
    ``<goto>`` tag (an agent name or FINISH) and may contain an
    ``<instruction>`` tag, which is stored in the state for the called agent.

    If the reply has no valid ``<goto>`` tag, a corrective human message is
    appended and the supervisor runs again. Repeated failures stop at
    LangGraph's recursion limit instead of crashing on an unknown node name.
    """
    response = model.invoke(state["messages"])

    print("-- START OF OUTPUT (supervisor) --\n\n", response.content, "\n\n-- END OF OUTPUT --\n\n")

    goto, instruction = _parse_supervisor_response(response.content)
    messages = state["messages"] + [response]
    if goto is None:
        # Without feedback the model would see almost the same history again and
        # likely repeat the mistake, so state explicitly what is expected.
        goto = "supervisor"
        allowed = ", ".join(f"<goto>{name}</goto>" for name in (*SUPERVISOR_TARGETS, "FINISH"))
        messages.append(HumanMessage(
            content=f"Your last response did not contain a valid <goto> tag. Respond with exactly one of: {allowed}."
        ))

    return Command(goto=goto, update={"messages": messages, "instruction": instruction})

def entity_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Worker node: ask the LLM for every entity mentioned in ``state["text"]``.

    The supervisor's current instruction is injected into the prompt. The raw
    model reply is appended to the message history, the instruction is
    cleared, and control always returns to the supervisor.
    """
    entity_prompt = PromptTemplate.from_template("""
    You are an agent tasked with extracting entities from a given text for linking to a knowledge graph. Your job is to capture every entity—both explicit and implicit—and return them as an array. This includes composite entities with modifiers (e.g., "professional football club"). Please output the entities as an array of strings. Do not include any further information or comment in the output.
    
    Example Output: [Olaf Scholz, chancellor, Germany]
    
    Guidelines:
    - An entity is a unique real-world object or concept represented as a node with its properties and relationships.
    - Extract every entity mentioned in the text, including those that are not immediately obvious.
    - For composite entities, include the full descriptive phrase and break it into its core components when appropriate. For example, "chancellor of Germany" should yield [chancellor, Germany] and "professional football club" should capture the descriptive phrase as needed.
    - For composite entities that include a date at the beginning or end, extract the date separately, the entity without the date, and the full composite (e.g., "2022 Winter Olympics" should result in [2022, 2022 Winter Olympics, Winter Olympics]).
    - Also, ensure that dates are extracted as entities.
    
    Instruction: {instruction}
    
    Text: {text}
    """)
    prompt_inputs = {"instruction": state["instruction"], "text": state["text"]}
    agent_reply = (entity_prompt | model).invoke(prompt_inputs)

    print("-- START OF OUTPUT (entity_extraction_agent) --\n\n", agent_reply.content, "\n\n-- END OF OUTPUT --\n\n")

    updated_messages = [*state["messages"], agent_reply]
    return Command(
        goto="supervisor",
        update={"messages": updated_messages, "instruction": ""},
    )

def relation_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Worker node: ask the LLM for <relation>subject;predicate;object</relation> triples.

    Uses the text from the state plus the supervisor's current instruction
    (for example, a list of already extracted entity labels). The reply is
    appended to the history, the instruction is reset, and control returns to
    the supervisor.
    """
    relation_prompt = PromptTemplate.from_template(
        """
        You are a relation extraction agent. Your task is to read the text of the user message and extract the relations found in the text. Each relation should be written in this exact format: <relation>subject;predicate;object</relation> (e.g.: <relation>Olaf Scholz;is chancellor of;Germany</relation>). Please return only the relations and no other information.
    
        Note: In addition to the explicit relations mentioned in the text, if an entity is described by a characteristic or category (e.g., renowned film director, prestigious university), you must also extract the corresponding attribute relation automatically. For example, if the text states that "Steven Spielberg is a renowned film director", you should extract: <relation>Steven Spielberg;profession;film director</relation>.
    
        Instruction: {instruction}
    
        Text: {text}
        """
    )
    relation_chain = relation_prompt | model
    agent_reply = relation_chain.invoke(
        {"instruction": state["instruction"], "text": state["text"]}
    )

    print("-- START OF OUTPUT (relation_extraction_agent) --\n\n", agent_reply.content, "\n\n-- END OF OUTPUT --\n\n")

    return Command(
        goto="supervisor",
        update={"messages": [*state["messages"], agent_reply], "instruction": ""},
    )

def uri_detection_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Worker node: look up candidate URIs for the supervisor's search terms.

    ``state["instruction"]`` is treated as a comma-separated list of labels.
    For each non-empty label, the 3 nearest labels in ``vector_store`` are
    reported with their URI and score.

    The store is built on ``faiss.IndexFlatL2``, so the score returned by
    ``similarity_search_with_score`` is an L2 distance: LOWER means more
    similar. It is therefore reported as ``distance``, not as a similarity.
    The result text is appended to the history as a human message, the
    instruction is cleared, and control returns to the supervisor.
    """
    # Strip whitespace left by ", " separators and drop empty entries, e.g. from
    # a trailing comma or an empty instruction.
    search_terms = [term.strip() for term in state["instruction"].split(",") if term.strip()]

    if not search_terms:
        result_text = (
            "No search terms were provided. Put a comma-separated list of labels "
            "in <instruction> tags when calling uri_detection_agent."
        )
    else:
        sections = []
        for term in search_terms:
            hits = vector_store.similarity_search_with_score(term, k=3)
            hit_lines = [
                f'- label: {doc.page_content!r}, uri: {doc.metadata["uri"]}, distance: {score:.4f}'
                for doc, score in hits
            ]
            sections.append(
                f"Search results for {term!r} (lower distance = closer match):\n" + "\n".join(hit_lines)
            )
        result_text = "\n\n".join(sections)

    print("-- START OF OUTPUT (uri_detection_agent) --\n\n", result_text, "\n\n-- END OF OUTPUT --\n\n")

    return Command(
        goto="supervisor",
        update={"messages": state["messages"] + [HumanMessage(content=result_text)], "instruction": ""},
    )
    
# Every node routes itself by returning a Command, so the only static edge the
# graph needs is the entry point into the supervisor.
builder = StateGraph(cIEState)
for node_fn in (
    supervisor,
    entity_extraction_agent,
    relation_extraction_agent,
    uri_detection_agent,
):
    builder.add_node(node_fn)
builder.add_edge(START, "supervisor")

graph = builder.compile()

In [173]:
from langfuse.callback import CallbackHandler
from dotenv import load_dotenv
import os

from langchain_core.messages import HumanMessage, SystemMessage

load_dotenv()
# Langfuse tracing for the graph run; keys and host come from the environment.
langfuse_handler = CallbackHandler(
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)

# Positional row of `docs` used for this single-document walkthrough.
TARGET_DOC_POSITION = 2

target_doc = docs.iloc[TARGET_DOC_POSITION]
doc_id = target_doc["docid"]
# Send the supervisor prompt with the system role. As a bare string it would be
# converted into a human message, i.e. treated as user input.
initial_state = {
    "text": target_doc["text"],
    "messages": [SystemMessage(content=system_prompt), HumanMessage(content=target_doc["text"])],
    "instruction": "",
}
response = graph.invoke(initial_state, config={"callbacks": [langfuse_handler]})

-- START OF OUTPUT (supervisor) --

 <goto>entity_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (entity_extraction_agent) --

 [Club Sportivo Cienciano, professional football club, Cusco, Peru, 1901, Ciencias y Artes, science, School, River Plate, Argentina, 2003, Copa Sudamericana, Boca Juniors, 2004, Recopa Sudamericana] 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>Search additional for entities that are not obvious.</instruction>
<goto>entity_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (entity_extraction_agent) --

 [Club Sportivo Cienciano, professional football club, Cusco, Peru, 1901, Ciencias y Artes School, science, Spanish, River Plate, Argentina, 2003, Copa Sudamericana, Boca Juniors, 2004, Recopa Sudamericana] 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>Club Sportivo Cienciano, Cusco, Peru, Ciencias y Artes School, River Plate, Argentina, Boca Juniors, Copa Sudamericana, Re

In [174]:
def get_uri_labels(df, entities=None, predicates=None):
    """Attach human-readable labels to a frame of URI triples.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``subject_uri``, ``predicate_uri`` and ``object_uri`` columns.
    entities : pandas.DataFrame, optional
        Lookup table with ``entity`` / ``entity_uri`` columns. Defaults to the
        global ``entity_set``.
    predicates : pandas.DataFrame, optional
        Lookup table with ``predicate`` / ``predicate_uri`` columns. Defaults to
        the global ``predicate_set_df``.

    Returns
    -------
    pandas.DataFrame
        ``df`` (index reset) with ``subject``, ``predicate`` and ``object``
        columns appended. URIs without a known label become ``"Unknown"``.
        Object values containing ``^^`` (typed literals) are passed through
        unchanged. Non-string values such as ``None`` map to ``"Unknown"``
        instead of raising.
    """
    if entities is None:
        entities = entity_set
    if predicates is None:
        predicates = predicate_set_df

    # Build each lookup once instead of filtering the whole table for every row.
    # drop_duplicates keeps the first label per URI, matching `.values[0]`.
    unique_entities = entities.drop_duplicates("entity_uri")
    entity_labels = dict(zip(unique_entities["entity_uri"], unique_entities["entity"]))
    unique_predicates = predicates.drop_duplicates("predicate_uri")
    predicate_labels = dict(zip(unique_predicates["predicate_uri"], unique_predicates["predicate"]))

    def _object_label(value):
        # Typed literals like "1901-01-01T00:00:00Z^^http://..." have no entity label.
        if isinstance(value, str) and "^^" in value:
            return value
        return entity_labels.get(value, "Unknown")

    labels = pd.DataFrame({
        "subject": [entity_labels.get(uri, "Unknown") for uri in df["subject_uri"]],
        "predicate": [predicate_labels.get(uri, "Unknown") for uri in df["predicate_uri"]],
        "object": [_object_label(value) for value in df["object_uri"]],
    })
    return pd.concat([df.reset_index(drop=True), labels], axis=1)

In [175]:
# Parse "<relation>subject;predicate;object</relation>" tags from the final
# supervisor message. Parts are whitespace-stripped. Tags that do not split into
# exactly three parts are reported and skipped instead of breaking the
# DataFrame constructor.
relation_list = []
for raw_relation in re.findall(r"<relation>(.*?)</relation>", response["messages"][-1].content, re.DOTALL):
    parts = [part.strip() for part in raw_relation.split(";")]
    if len(parts) == 3:
        relation_list.append(parts)
    else:
        print(f"Skipping malformed relation: {raw_relation!r}")
pred_relation_df = pd.DataFrame(relation_list, columns=["subject_uri", "predicate_uri", "object_uri"])
get_uri_labels(pred_relation_df)

Unnamed: 0,subject_uri,predicate_uri,object_uri,subject,predicate,object
0,Q602482,P159,Q5582862,Cienciano,headquarters location,Cusco
1,Q602482,P31,Q3624078,Cienciano,instance of,Unknown
2,Q602482,P577,1901,Cienciano,Unknown,Unknown
3,Q602482,P31,Q482994,Cienciano,instance of,Unknown
4,Q602482,P361,Q15799,Cienciano,part of,River Plate
5,Q602482,P361,Q170703,Cienciano,part of,Boca Juniors
6,Q602482,P31,Q2066133,Cienciano,instance of,Unknown
7,Q602482,P1344,Q60585,Cienciano,Unknown,Copa Sudamericana
8,Q602482,P1344,Q4603244,Cienciano,Unknown,2004 Recopa Sudamericana


In [176]:
# Gold triples for the current document, labelled for side-by-side comparison.
triple_columns = ["subject_uri", "predicate_uri", "object_uri"]
doc_relation_df = relation_df.loc[relation_df["docid"] == doc_id, triple_columns]
get_uri_labels(doc_relation_df)

Unnamed: 0,subject_uri,predicate_uri,object_uri,subject,predicate,object
0,Q602482,P641,Q2736,Cienciano,sport,football
1,Q602482,P159,Q5582862,Cienciano,headquarters location,Cusco
2,Q602482,P571,1901-01-01T00:00:00Z^^http://www.w3.org/2001/X...,Cienciano,inception,1901-01-01T00:00:00Z^^http://www.w3.org/2001/X...
3,Q15799,P17,Q414,River Plate,country,Argentina


In [177]:
# Inner join on the full triple keeps only predictions that exactly match a gold relation.
triple_key = ["subject_uri", "predicate_uri", "object_uri"]
correct_relation_df = pred_relation_df.merge(
    doc_relation_df[triple_key],
    how="inner",
    on=triple_key,
)
correct_relation_df

Unnamed: 0,subject_uri,predicate_uri,object_uri
0,Q602482,P159,Q5582862


In [178]:
def evaluate(pred_relation_df, doc_id, verbose=False, gold_df=None):
    """Score predicted triples for one document against the gold relations.

    Triples are compared as sets of exact (subject_uri, predicate_uri,
    object_uri) matches; duplicate triples on either side count once.

    Parameters
    ----------
    pred_relation_df : pandas.DataFrame
        Predicted triples with ``subject_uri``, ``predicate_uri``, ``object_uri``.
    doc_id :
        Value of the ``docid`` column that selects the gold relations.
    verbose : bool
        Print precision, recall and F1 when True.
    gold_df : pandas.DataFrame, optional
        Gold relations with ``docid`` plus the triple columns. Defaults to the
        global ``relation_df``.

    Returns
    -------
    tuple of float
        ``(precision, recall, f1_score)``. A metric whose denominator is zero
        (no predictions, no gold relations, or no overlap) is 0.0 instead of
        raising ZeroDivisionError.
    """
    gold_df = relation_df if gold_df is None else gold_df
    key = ["subject_uri", "predicate_uri", "object_uri"]

    doc_relation_df = gold_df.loc[gold_df["docid"] == doc_id, key].drop_duplicates()
    predicted_df = pred_relation_df[key].drop_duplicates()
    correct_relation_df = predicted_df.merge(doc_relation_df, on=key, how="inner")

    n_correct = len(correct_relation_df)
    precision = n_correct / len(predicted_df) if len(predicted_df) else 0.0
    recall = n_correct / len(doc_relation_df) if len(doc_relation_df) else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0

    if verbose:
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        print(f"F1: {f1_score}")

    return precision, recall, f1_score

In [179]:
# Score the walkthrough document; the returned (precision, recall, f1) tuple is displayed.
evaluate(pred_relation_df=pred_relation_df, doc_id=doc_id, verbose=True)

Precision: 0.1111111111111111
Recall: 0.25
F1: 0.15384615384615383


(0.1111111111111111, 0.25, 0.15384615384615383)

# Evaluation on Train-Split Documents

In [144]:
from langchain_core.messages import HumanMessage, SystemMessage

# Number of leading rows of `docs` to evaluate (matches the previous early stop
# after four documents). Set to None to evaluate every document.
N_EVAL_DOCS = 4

# Positional selection: unlike `if i >= 3`, this does not depend on docs having
# a 0-based RangeIndex.
eval_docs = docs if N_EVAL_DOCS is None else docs.head(N_EVAL_DOCS)
evaluation_records = []

for _, target_doc in eval_docs.iterrows():
    doc_id = target_doc["docid"]
    # System role for the supervisor prompt; a bare string would be converted
    # into a human message.
    response = graph.invoke(
        {
            "text": target_doc["text"],
            "messages": [SystemMessage(content=system_prompt), HumanMessage(content=target_doc["text"])],
            "instruction": "",
        },
        config={"callbacks": [langfuse_handler]},
    )

    # Keep only well-formed subject;predicate;object triples.
    relation_list = []
    for raw_relation in re.findall(r"<relation>(.*?)</relation>", response["messages"][-1].content, re.DOTALL):
        parts = [part.strip() for part in raw_relation.split(";")]
        if len(parts) == 3:
            relation_list.append(parts)
        else:
            print(f"Skipping malformed relation in {doc_id}: {raw_relation!r}")
    pred_relation_df = pd.DataFrame(relation_list, columns=["subject_uri", "predicate_uri", "object_uri"])
    evaluation_records.append([doc_id, *evaluate(pred_relation_df, doc_id, verbose=False)])

evaluation_df = pd.DataFrame(evaluation_records, columns=["docid", "precision", "recall", "f1_score"])
evaluation_df

-- START OF OUTPUT (supervisor) --

 <goto>entity_extraction_agent</goto>
<instruction>Extract all entities from the text, including companies, locations, and networks.</instruction> 

-- END OF OUTPUT --


-- START OF OUTPUT (entity_extraction_agent) --

 [CBS Corporation, CBS, The CW, Viacom, CBS Building, Midtown Manhattan, New York City] 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <goto>relation_extraction_agent</goto>
<instruction>Extract relations between the entities, focusing on ownership, location, and composition.</instruction> 

-- END OF OUTPUT --


-- START OF OUTPUT (relation_extraction_agent) --

 <relation>CBS Corporation;comprised;over-the-air television broadcasting</relation>
<relation>CBS Corporation;comprised;television production and distribution</relation>
<relation>CBS Corporation;comprised;publishing</relation>
<relation>CBS Corporation;comprised;pay-cable</relation>
<relation>CBS Corporation;comprised;recording assets</relation>
<relation>over

KeyboardInterrupt: 