In [3]:
from langchain_core.runnables import RunnableParallel

import helper_tools.parser as parser
import importlib
import pandas as pd
# Re-import the parser module so that edits to helper_tools/parser.py are
# picked up without restarting the kernel.
importlib.reload(parser)

# Load the "train" split; the parser returns the relation frame, the entity
# frame and the documents frame, in that order.
relation_df, entity_df, docs = parser.redfm_parser("train")

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 15827.56it/s]


In [4]:
# Preview the first rows of the documents frame returned by the parser.
docs.head()

Unnamed: 0,docid,text
0,1755846-1,CBS Corporation comprised the over-the-air tel...
1,1755846-2,The second merger between CBS Corporation and ...
2,1701411-0,Club Sportivo Cienciano is a professional foot...
3,1854133-1,It is the seat of a municipality with 203.30 k...
4,1602703-0,Bad Ischl is a spa town in Austria. It lies in...


In [5]:
# Distinct (label, URI) pairs of all predicates that occur in the relation data.
predicate_columns = ["predicate", "predicate_uri"]
predicate_set_df = relation_df.loc[:, predicate_columns].drop_duplicates()
predicate_set_df

Unnamed: 0,predicate,predicate_uri
0,owned by,P127
1,follows,P155
2,inception,P571
3,sport,P641
4,headquarters location,P159
6,country,P17
7,shares border with,P47
8,located in or next to body of water,P206
10,location,P276
13,spouse,P26


In [6]:
from dotenv import load_dotenv
from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
from gen_ai_hub.proxy.langchain.init_models import init_embedding_model

# Load variables from the local .env file (presumably the gen-ai-hub
# credentials -- confirm against the deployment setup).
load_dotenv()

proxy_client = get_proxy_client('gen-ai-hub')

# Chat model shared by the agent nodes defined further down.
model = ChatOpenAI(
    proxy_client=proxy_client,
    proxy_model_name='meta--llama3-70b-instruct',
)

# Embedding model used to index and query the entity labels.
embeddings = init_embedding_model('text-embedding-3-large')

In [7]:
# Smoke test: send a trivial prompt to confirm the chat model responds.
model.invoke("Hello, how are you?")

AIMessage(content="I'm doing well, thank you for asking! I'm a large language model, so I don't have emotions or feelings like humans do, but I'm always happy to chat with you and help with any questions or topics you'd like to discuss. How about you? How's your day going so far?", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 64, 'prompt_tokens': 16, 'total_tokens': 80, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'meta--llama3-70b-instruct', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-0494a28e-e7a2-40ca-bb04-bb0134ce0e69-0', usage_metadata={'input_tokens': 16, 'output_tokens': 64, 'total_tokens': 80, 'input_token_details': {}, 'output_token_details': {}})

In [8]:
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Size the flat L2 index from a probe embedding so its dimensionality always
# matches whatever embedding model is configured.
embedding_dim = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(embedding_dim)

vector_store = FAISS(
    embedding_function=embeddings,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

# One document per unique (label, URI) pair: the label is what gets embedded,
# the URI rides along as metadata so a search hit can be mapped back to it.
# Built with a comprehension so that no loop variable overwrites `index`
# (the FAISS index object) at module level.
entity_set = entity_df[['entity', 'entity_uri']].drop_duplicates()
documents = [
    Document(page_content=entity_label, metadata={"uri": entity_uri})
    for entity_label, entity_uri in zip(entity_set["entity"], entity_set["entity_uri"])
]

vector_store.add_documents(documents=documents)

['bf7eae19-f3e7-4be7-9ad2-efd678c1960d',
 '28fc0ff2-a82c-4df6-9b1f-74925dd51f55',
 '6e70880e-144f-42bd-b482-adf07d69a800',
 'eb01fdcb-6870-4c1d-86ec-3d42006f80c5',
 '094495e5-f8b4-4586-aafc-8334d1ccef72',
 '6c77c893-0454-4c29-991a-bee2a7b27ed8',
 'b1f9a4bd-a5f1-46da-bb6c-0cef1cc39ebc',
 '18e1176f-6404-4553-a065-2397ae461834',
 '383007a1-d500-404e-b640-9a3e0cea81e8',
 '646997dc-e698-4171-87bc-b385389c9b0a',
 '0af92bc8-fb4b-4cea-8364-f3fa3a84cf7f',
 'ae164a42-181a-4e61-b67b-5571ac8c4dd3',
 '0b16311e-b248-4397-af09-7911996d90c7',
 'ab303510-19b3-4394-b61e-134cc4b2a62b',
 '069acd63-ee9b-42f0-aacc-3ae9ef97c58f',
 'f91c7f4b-1d60-4e7a-b7f1-651c2ea9067b',
 'aa057a7d-9e3a-4ded-96a1-0ec85d0fec97',
 '245d68a5-4410-46ce-a93a-bee64a6a42c2',
 'e71c0589-7ca4-4944-ae80-7369a8b31718',
 '1ed428f5-f4d1-43b1-ab8f-4bb532a3e492',
 '56874bc6-462c-4223-afca-2daa8499ae26',
 '794c20ff-2701-42c3-8bcd-8c8fa2191e78',
 '625e9c5e-78de-44ce-b4a1-60bafe81f2db',
 '4dd3e3d8-f80c-4f20-99b1-853e8fbb9f94',
 '57b69637-1ceb-

In [14]:
from langgraph.types import Command
from typing import Literal, TypedDict
from langgraph.graph import StateGraph, MessagesState, START, END
from langchain_core.prompts import PromptTemplate
import re
from typing import TypedDict
from langchain_core.messages import HumanMessage, AIMessage

from langchain_core.messages import BaseMessage


class cIEState(TypedDict):
    """State passed between the nodes of the extraction graph."""

    # Input text the agents extract entities and relations from.
    text: str
    # Conversation history handed to the supervisor model on every turn.
    # Entries may be message objects or plain strings (the notebook seeds the
    # list with strings), so the annotation admits both.
    messages: list[str | BaseMessage]
    # Free-form instruction the supervisor attaches for the next agent;
    # agents reset it to "" when they hand control back.
    instruction: str

# Worker agents the supervisor can delegate to. Every agent described in the
# prompt below must be listed here, because the list of callable agents in the
# prompt is rendered from this variable.
members = ["entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent"]
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = members + ["FINISH"]

system_prompt = f"""
You are the Supervisor of a conversation among multiple agents.
The conversation is about extracting information (Closed Information Extraction) from a user-provided text. The final output should only contain URIs isntead of labels or descriptions of entities or relations.

Agent Descriptions:
- entity_extraction_agent: Extracts entities from the text. Can not take any instructions.
- relation_extraction_agent: Extracts relations from the text. Can not take any instructions.
- uri_detection_agent: Detects URIs for entities based on similarity search. The instruction is the search term, which the uri detection agent is searching for.

You have two options:
1. Call an agent using <goto>agent_name</goto>. Replace agent_name with one of: {', '.join(members)}. I.e. <goto>entity_extraction_agent</goto>.
2. Finish the conversation using <goto>FINISH</goto>. Please output the final relations in <relation> tags alongside with the <goto> tag.

In addition to the options you can provide additional information to the agents using the <instruction> tag. I.e. <instruction>Search additional for entities that are not obvious.</instruction>.

Note:
- Do not provide any information yourself, instead use the agents for this.
- The first <goto> tag in your response will be executed.
- Therefore, do include exact one agent call in your response.
- If you output nothing, this will result in a NoneType Error.


"""

def supervisor(state: cIEState) -> Command[Literal["entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent", "supervisor", END]]:
    """Ask the LLM which node should run next and route there.

    The model is invoked on the full message history. Its reply is scanned
    for the first ``<goto>...</goto>`` tag (the routing decision) and the
    first ``<instruction>...</instruction>`` tag (optional input for the
    next agent).

    Routing rules:
    - ``FINISH`` ends the graph run.
    - A known agent name routes to that agent.
    - A missing tag or an unknown target routes back to the supervisor so
      the model gets another attempt instead of the run failing on a node
      name that does not exist.

    Returns:
        A ``Command`` whose update appends the model reply to ``messages``
        and stores the extracted instruction ("" when none was given).
    """
    response = model.invoke(state["messages"])

    print("-- START OF OUTPUT (supervisor) --\n\n", response.content, "\n\n-- END OF OUTPUT --\n\n")

    # Kept local (rather than read from a module-level list) so that this
    # function does not depend on other cells staying in sync with it.
    agent_nodes = {"entity_extraction_agent", "relation_extraction_agent", "uri_detection_agent"}

    # DOTALL lets a tag body span several lines; strip() tolerates padding
    # such as "<goto> FINISH </goto>".
    goto_match = re.search(r'<goto>(.*?)</goto>', response.content, re.DOTALL)
    target = goto_match.group(1).strip() if goto_match else None
    if target == "FINISH":
        goto = END
    elif target in agent_nodes:
        goto = target
    else:
        goto = "supervisor"

    instruction_match = re.search(r'<instruction>(.*?)</instruction>', response.content, re.DOTALL)
    instruction = instruction_match.group(1).strip() if instruction_match else ""

    return Command(goto=goto, update={"messages": state["messages"] + [response], "instruction": instruction})

def entity_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Have the LLM list the entities mentioned in ``state["text"]``.

    The model reply is printed, appended to the message history, and control
    returns to the supervisor with the instruction reset to "".
    """
    extraction_prompt = PromptTemplate.from_template("""
        You are an agent to extract entities out of a given text. The entities will be linked to a knowledge graph later. You should return a list of any explicit and implicit entities found in the text. Please output the entities in an array like this: [Olaf Scholz, Germany, Berlin]. Please return only the array.

        Text: {text}
    """)
    reply = (extraction_prompt | model).invoke({"text": state["text"]})

    print("-- START OF OUTPUT (entity_extraction_agent) --\n\n", reply.content, "\n\n-- END OF OUTPUT --\n\n")

    history = state["messages"] + [reply]
    return Command(goto="supervisor", update={"messages": history, "instruction": ""})

def relation_extraction_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Have the LLM extract ``<relation>`` triples from ``state["text"]``.

    The model reply is printed, appended to the message history, and control
    returns to the supervisor with the instruction reset to "".
    """
    relation_prompt = PromptTemplate.from_template(
    """
    You are a relation extraction agent. Your task is to read the text of the user message and extract the relations found in the text. Each relation should be written in this exact format: <relation>subject;predicate;object</relation> (e.g.: <relation>Olaf Scholz;is chancellor of;Germany</relation>). Please return only the relations and no other information.
    
    Text: {text}
        """)
    reply = (relation_prompt | model).invoke({"text": state["text"]})

    print("-- START OF OUTPUT (relation_extraction_agent) --\n\n", reply.content, "\n\n-- END OF OUTPUT --\n\n")

    history = state["messages"] + [reply]
    return Command(goto="supervisor", update={"messages": history, "instruction": ""})

def uri_detection_agent(state: cIEState) -> Command[Literal["supervisor"]]:
    """Look up candidate URIs for the supervisor's search term.

    The search term is read from ``state["instruction"]`` and matched against
    the entity labels in ``vector_store``. The five nearest entries are
    reported back to the supervisor as a message.

    Notes:
    - The store is built on ``faiss.IndexFlatL2``, so the score is an L2
      distance: smaller means a closer match. It is labelled accordingly so
      the supervisor model does not read it as a similarity.
    - Scores are converted to plain ``float`` so the numpy scalar repr does
      not leak into the prompt.
    - An empty search term is answered with a hint instead of embedding an
      empty string.
    - The result is wrapped in a ``HumanMessage`` so every history entry
      added here exposes ``.content`` like the model replies do.
    """
    search_term = (state.get("instruction") or "").strip()

    if not search_term:
        content = (
            "Output of uri_detection_agent: no search term was provided. "
            "Put the entity label to look up inside <instruction> tags."
        )
    else:
        hits = vector_store.similarity_search_with_score(search_term, k=5)
        candidates = [
            {"entity_label": doc.page_content, "uri": doc.metadata["uri"], "l2_distance": float(score)}
            for doc, score in hits
        ]
        content = f'Output of uri_detection_agent (lower l2_distance = closer match): {candidates}'

    return Command(
        goto="supervisor",
        update={"messages": state["messages"] + [HumanMessage(content=content)], "instruction": ""},
    )
    
# Assemble the graph: each node is registered under its function name and the
# only static edge is the entry point. All other transitions come from the
# Command objects the nodes return.
builder = StateGraph(cIEState)
for node_fn in (
    supervisor,
    entity_extraction_agent,
    relation_extraction_agent,
    uri_detection_agent,
):
    builder.add_node(node_fn)

builder.add_edge(START, "supervisor")

graph = builder.compile()

In [15]:
import os

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langfuse.callback import CallbackHandler

load_dotenv()
# Langfuse settings are read from the environment, never hard-coded.
langfuse_handler = CallbackHandler(
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)

# Run the pipeline on the first document; doc_id is kept for looking up the
# gold relations of the same document afterwards.
target_doc = docs.iloc[0]
doc_id = target_doc["docid"]

# A bare string in a message list is coerced to a human message, so the
# supervisor rules are wrapped explicitly to be sent with the system role and
# the document text with the human role.
initial_state = {
    "text": target_doc["text"],
    "messages": [
        SystemMessage(content=system_prompt),
        HumanMessage(content=target_doc["text"]),
    ],
    "instruction": "",
}
response = graph.invoke(initial_state, config={"callbacks": [langfuse_handler]})

-- START OF OUTPUT (supervisor) --

 <goto>entity_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (entity_extraction_agent) --

 [CBS Corporation, The CW, Viacom, CBS Building, Midtown Manhattan, New York City] 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>CBS Corporation</instruction><goto>uri_detection_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>The CW</instruction><goto>uri_detection_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>Midtown Manhattan</instruction><goto>uri_detection_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <instruction>New York City</instruction><goto>uri_detection_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (supervisor) --

 <goto>relation_extraction_agent</goto> 

-- END OF OUTPUT --


-- START OF OUTPUT (relation_extraction_agent) --

 <relation>CBS Corporation;comprised;over-the-air television

In [16]:
# Parse the <relation>subject;predicate;object</relation> tags from the last
# message of the run.
# - re.DOTALL lets a relation span several lines.
# - split(";", 2) keeps any further ';' inside the object field, so every row
#   has exactly three fields.
# - Entries with fewer than three fields are skipped; a ragged row would make
#   the three-column DataFrame construction fail.
# - Fields are stripped so stray whitespace cannot break exact URI matching.
final_answer = response["messages"][-1].content
relation_list = [
    [field.strip() for field in raw_triple.split(";", 2)]
    for raw_triple in re.findall(r"<relation>(.*?)</relation>", final_answer, re.DOTALL)
    if raw_triple.count(";") >= 2
]
relation_list

[['Q950380', 'comprised', 'Q212252'],
 ['Q950380', 'comprised', 'television production and distribution'],
 ['Q950380', 'comprised', 'publishing'],
 ['Q950380', 'comprised', 'pay-cable'],
 ['Q950380', 'comprised', 'recording assets'],
 ['Q950380', 'was', "world's eighth largest entertainment company"],
 ['Q950380', 'had', 'headquarters'],
 ['Q950380', 'headquartered at', 'Q2640560'],
 ['Q2640560', 'located in', 'Q11249'],
 ['Q11249', 'part of', 'Q60']]

In [28]:
# Tabulate the parsed triples, one row per predicted relation.
pred_relation_df = pd.DataFrame(
    data=relation_list,
    columns=["subject_uri", "predicate_uri", "object_uri"],
)
pred_relation_df

Unnamed: 0,subject_uri,predicate_uri,object_uri
0,CBS Corporation,comprised,over-the-air television broadcasting
1,CBS Corporation,comprised,television production and distribution
2,CBS Corporation,comprised,publishing
3,CBS Corporation,comprised,pay-cable
4,CBS Corporation,comprised,recording assets
5,CBS Corporation,was,world's eighth largest entertainment company
6,CBS Corporation,had,headquarters
7,CBS Corporation,headquartered at,CBS Building
8,CBS Building,is located in,Midtown Manhattan
9,Midtown Manhattan,is part of,New York City


In [31]:
# Relations recorded in relation_df for the document that was just processed,
# reduced to the three URI columns.
is_target_doc = relation_df["docid"] == doc_id
doc_relation_df = relation_df.loc[is_target_doc, ["subject_uri", "predicate_uri", "object_uri"]]
doc_relation_df

Unnamed: 0,subject_uri,predicate_uri,object_uri
0,Q212252,P127,Q950380


In [30]:
# Predicted triples that also appear in the document's reference relations.
# Both sides are de-duplicated first: an inner merge pairs every duplicate on
# the left with every duplicate on the right, so a repeated triple would
# otherwise be counted as several hits.
triple_columns = ["subject_uri", "predicate_uri", "object_uri"]
correct_relation_df = (
    pred_relation_df[triple_columns]
    .drop_duplicates()
    .merge(
        doc_relation_df[triple_columns].drop_duplicates(),
        on=triple_columns,
        how="inner",
    )
)
correct_relation_df

Unnamed: 0,subject_uri,predicate_uri,object_uri


In [25]:
# correct / predicted is precision, and correct / reference is recall.
# The ratios are guarded so that an empty prediction or reference frame gives
# 0.0 instead of raising ZeroDivisionError.
n_correct = len(correct_relation_df)
n_predicted = len(pred_relation_df)
n_reference = len(doc_relation_df)

precision = n_correct / n_predicted if n_predicted else 0.0
recall = n_correct / n_reference if n_reference else 0.0

print(f"Precision: {precision}")
print(f"Recall: {recall}")

NameError: name 'correct_relation_df' is not defined