In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import os

In [3]:
required_env_vars = [
    "NEO4J_URI",
    "NEO4J_USERNAME",
    "NEO4J_PASSWORD",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY"
]

for var in required_env_vars:
    assert os.getenv(var) , f"Environment variable {var} is not set."

#### We will be using anthropic API for graph creation since it gives the best results, and OpenAI for inference since it is faster and cheaper.

In [4]:
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

In [5]:
llm_anthropic = ChatAnthropic(
    model="claude-3-5-sonnet-20240620",
    temperature=0,
    max_tokens=4096,
    max_retries=2
)

In [6]:
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

In [7]:
system_prompt = (
    "# Knowledge Graph Instructions\n"
    "## 1. Overview\n"
    "You are a top-tier algorithm designed for extracting information in structured "
    "formats to build a knowledge graph.\n"
    "Try to capture as much information from the text as possible without "
    "sacrificing accuracy. Do not add any information that is not explicitly "
    "mentioned in the text.\n"
    "- **Nodes** represent entities and concepts.\n"
    "- The aim is to achieve simplicity and clarity in the knowledge graph, making it\n"
    "accessible for a vast audience.\n"
    "## 2. Labeling Nodes\n"
    "- **Consistency**: Ensure you use available types for node labels.\n"
    "Ensure you use basic or elementary types for node labels.\n"
    "- For example, when you identify an entity representing a person, "
    "always label it as **'person'**. Avoid using more specific terms "
    "like 'mathematician' or 'scientist'."
    "- **Node IDs**: Never utilize integers as node IDs. Node IDs should be "
    "names or human-readable identifiers found in the text.\n"
    "- **Node Names**: Create a **name** property for each node it should be names, or human-readable identifiers found in the text.\n"
    "- **Relationships** represent connections between entities or concepts.\n"
    "Ensure consistency and generality in relationship types when constructing "
    "knowledge graphs. Instead of using specific and momentary types "
    "such as 'BECAME_PROFESSOR', use more general and timeless relationship types "
    "like 'PROFESSOR'. Make sure to use general and timeless relationship types!\n"
    "## 3. Coreference Resolution\n"
    "- **Maintain Entity Consistency**: When extracting entities, it's vital to "
    "ensure consistency.\n"
    'If an entity, such as "John Doe", is mentioned multiple times in the text '
    'but is referred to by different names or pronouns (e.g., "Joe", "he"),'
    "always use the most complete identifier for that entity throughout the "
    'knowledge graph. In this example, use "John Doe" as the entity ID.\n'
    "Remember, the knowledge graph should be coherent and easily understandable, "
    "so maintaining consistency in entity references is crucial.\n"
    "## 4. Strict Compliance\n"
    "Adhere to the rules strictly. Non-compliance will result in termination."
)
custom_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            system_prompt,
        ),
        (
            "human",
            (
                "Tip: Make sure to answer in the correct format and do "
                "not include any explanations. "
                "Make sure to **strictly** add name property for each node."
                "Remember that some information may be saved as properties of nodes and relationships, think it through add few items as properties as you see fit."
                "Use the given format to extract information from the "
                "following input: {input}"
            ),
        ),
    ]
)

#### Updated the graph creation prompt to include the following:
- Nodes in the graph include a **name** property.
- Name property is made to be added stricty to the nodes.
- Reminded to add properties to the nodes and relationships whenever necessary since properties where not emphasized by the model.

In [8]:
from langchain_experimental.graph_transformers import LLMGraphTransformer

In [9]:
graph_maker = LLMGraphTransformer(llm=llm_anthropic, node_properties=True, relationship_properties=True, strict_mode=True, prompt=custom_prompt)

- Setting node_properties and relationship_properties to True to ensure that the model adds properties to the nodes and relationships whenever necessary.
- Providing properties to be extracted ensures consistency, however, since our usecase it open-ended, we will not be providing any properties to be extracted.

In [10]:
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader("../sample/mango.pdf")
docs = loader.load()

Read file into documents, and then break it down into chunks considering the context size of LLMs.

In [11]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=100)
text_chunks = text_splitter.split_documents(docs)

In [12]:
len(text_chunks)

7

In [13]:
graph_documents = graph_maker.convert_to_graph_documents(text_chunks)

The above step takes in our text chunks and returns a list GraphDocument object containing the nodes and relationships. Since graph_maker object was given custom prompt graph would be made as per our specification.

In [14]:
graph_documents[0].nodes
graph_documents[0].relationships

[Relationship(source=Node(id='Mango', type='Fruit'), target=Node(id='Mangifera Indica', type='Plant'), type='PRODUCED_BY'),
 Relationship(source=Node(id='Mangifera Indica', type='Plant'), target=Node(id='Mango_Fruit', type='Fruit'), type='PRODUCES'),
 Relationship(source=Node(id='Mangifera Indica', type='Plant'), target=Node(id='Mango_Leaf', type='Plant_part'), type='HAS_PART'),
 Relationship(source=Node(id='Mangifera Indica', type='Plant'), target=Node(id='Mango_Flower', type='Plant_part'), type='HAS_PART'),
 Relationship(source=Node(id='Bangladesh', type='Country'), target=Node(id='Mangifera Indica', type='Plant'), type='NATIONAL_TREE')]

In [15]:
from langchain_community.graphs import Neo4jGraph

In [16]:
graph = Neo4jGraph(driver_config={"max_connection_lifetime": 3600})

This creates a neo4j driver wrapper to interact with the neo4j database, provided by Langchain, we then use it write to our database.

In [17]:
graph.add_graph_documents(graph_documents)

Refreshing the graph schema would show us the relationships and nodes created. We will also be give to the LLMs to create cypher queries to extract information from the graph.

In [18]:
graph.refresh_schema()
graph.structured_schema

{'node_props': {'Fruit': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'length_range', 'type': 'STRING'},
   {'property': 'weight_range', 'type': 'STRING'},
   {'property': 'ripening_time', 'type': 'STRING'},
   {'property': 'national_fruit_of', 'type': 'STRING'},
   {'property': 'scientific_name', 'type': 'STRING'},
   {'property': 'origin', 'type': 'STRING'}],
  'Plant': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'scientific_name', 'type': 'STRING'},
   {'property': 'crown_radius', 'type': 'STRING'},
   {'property': 'common_name', 'type': 'STRING'},
   {'property': 'lifespan', 'type': 'STRING'},
   {'property': 'height', 'type': 'STRING'}],
  'Plant_part': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'arrangement', 'type': 'STRING'},
   {'property': 'shape', 'type': 'STRING'},
   {'property': 'length', 'type': 'STRING'},
   {'

We now will use neo4jupyter package to visualize the graph. It is made on top of vis.js and is good for visualizing nep4j graphs in jupyter notebooks.

In [19]:
import neo4jupyter
neo4jupyter.init_notebook_mode()

<IPython.core.display.Javascript object>

In [20]:
def show_graph():
    load_dotenv()
    NEO4J_URI = os.getenv("NEO4J_URI")
    NEW_NEO4J_URI = NEO4J_URI.replace("neo4j+s://", "bolt+s://")
    os.environ.pop('NEO4J_URI')

    from py2neo import Graph
    vis_graph = Graph(NEW_NEO4J_URI, auth=(os.getenv("NEO4J_USERNAME"), os.getenv("NEO4J_PASSWORD")))

    query = "MATCH (n) RETURN COUNT(n) AS total_nodes"
    result = vis_graph.run(query).data()

    total_nodes = result[0]['total_nodes']

    os.environ["NEO4J_URI"] = NEO4J_URI
    return neo4jupyter.draw(vis_graph,{}, limit=total_nodes)
    

This function will show us the graph, there is a workaround done in the function. 
It requires p2neo package to read from the neo4j database, not the Graph in py2neo reads neo4j credentials and uri directly from the environment if it exist and cannot be overridden. However, the issue is that it only support bolt protocol and the neo4j uri in environment is in neo4j+s protocol. So, we poped the uri from the environment and rewrote after visualization. 

In [21]:
show_graph()

After visualization, we see that there has been a lot of nodes that are similar with slight variations, this leads to disjointed nodes or nodes that cannot be reasoned by the LLMs. We will now merge the nodes that are similar to each other. We will do disambiguation with a custom prompt.

In [22]:

system_prompt = """
Act as a entity disambiugation tool and tell me which values reference the same entity. 
For example if I give you

Birds
Bird
Ant

You return to me

Birds, 1
Bird, 1
Ant, 2

As the Bird and Birds values have the same integer assigned to them, it means that they reference the same entity.

"""

disambiguate_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            system_prompt,
        ),
        (
            "human",
            (                
                "Perform disambiguation on the following values: \n{input}"
            ),
        ),
    ]
)

We now define pydantic classes to define the return types of the function calling feature in Claude/ OpenAI LLMs.

In [23]:
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List

class DisambiguatedNode(BaseModel):
    """Get the node name and id from the collection of nodes"""
    name: str = Field(..., title="Disambiquated name for the entity")
    id: int = Field(..., title="Disambiquated id for the entity")

class DisambiguatedNodeList(BaseModel):
    """Get the node name and id from the collection of nodes"""
    nodes: List[DisambiguatedNode] = Field(..., title="List of disambiguated nodes")

In [24]:
llm_openai = ChatOpenAI(model="gpt-4-turbo", temperature=0)

In [25]:
disambiguation_chain = disambiguate_prompt | llm_openai.with_structured_output(DisambiguatedNodeList)

We created a disambiguation chain that takes the prompt and runs it through LLM, and returns returns the disambiguated nodes after receiving the structured output as per out definition for function calling.

In [26]:
node_names = graph.query("MATCH (n) WHERE NOT n:Chunk RETURN n.name as name")



Note: We ignore nodes with label Chunk, as those are used to store embedding chunks, we haven't used them in this case, but may exist in the db from a different operation.

In [27]:
node_names[:5]

[{'name': 'Northeastern India'},
 {'name': 'Indian Group'},
 {'name': 'Southeast Asian Group'},
 {'name': 'Mango'},
 {'name': 'Southeast Asia'}]

In [28]:
node_names_list = [name["name"] for name in node_names]

In [29]:
node_names_list[:5]

['Northeastern India',
 'Indian Group',
 'Southeast Asian Group',
 'Mango',
 'Southeast Asia']

In [30]:
response = disambiguation_chain.invoke({"input": "\n".join(node_names_list)})

We invoked the disambiguation chain on the node names and assigned same id to the nodes that are similar to each other.

In [31]:
response.dict()["nodes"][:10]

[{'name': 'Northeastern India', 'id': 1},
 {'name': 'Indian Group', 'id': 2},
 {'name': 'Southeast Asian Group', 'id': 3},
 {'name': 'Mango', 'id': 4},
 {'name': 'Southeast Asia', 'id': 5},
 {'name': 'South Asia', 'id': 6},
 {'name': 'Alphonso', 'id': 7},
 {'name': 'Julie', 'id': 8},
 {'name': 'Tommy Atkins', 'id': 9},
 {'name': 'East Africa', 'id': 10}]

In [32]:
node_clusters = [[node["name"], node["id"]] for node in response.dict()["nodes"]]

In [33]:
node_clusters[:10]

[['Northeastern India', 1],
 ['Indian Group', 2],
 ['Southeast Asian Group', 3],
 ['Mango', 4],
 ['Southeast Asia', 5],
 ['South Asia', 6],
 ['Alphonso', 7],
 ['Julie', 8],
 ['Tommy Atkins', 9],
 ['East Africa', 10]]

In [34]:
cluster_map = {}
for cluster_name, cluster_id in node_clusters:
    if cluster_id not in cluster_map:
        cluster_map[cluster_id] = cluster_name.capitalize()

In [35]:
mapped_nodes = []
for cluster_name, cluster_id in node_clusters:
    mapped_nodes.append([cluster_name,cluster_map[cluster_id]])

We created the name to which the nodes of the same clusters are to be renamed to.

In [36]:
mapped_nodes[:10]

[['Northeastern India', 'Northeastern india'],
 ['Indian Group', 'Indian group'],
 ['Southeast Asian Group', 'Southeast asian group'],
 ['Mango', 'Mango'],
 ['Southeast Asia', 'Southeast asia'],
 ['South Asia', 'South asia'],
 ['Alphonso', 'Alphonso'],
 ['Julie', 'Julie'],
 ['Tommy Atkins', 'Tommy atkins'],
 ['East Africa', 'East africa']]

In [37]:
node_rename_cypher = """
UNWIND $mapping as pair
MATCH (n {name: pair[0]})
SET n.name = pair[1]
"""


graph.query(node_rename_cypher,{
    "mapping": mapped_nodes
})

[]

We renamed the nodes in the db

In [38]:
node_merge_cypher = """
MATCH (n)
WHERE NOT 'Chunk' IN labels(n)
WITH n.name as nodeId, collect(n) as nodes
CALL apoc.refactor.mergeNodes(nodes, {properties: "combine", mergeRels: true})
YIELD node
RETURN node;
"""

graph.query(node_merge_cypher)

[{'node': {'name': 'Northeastern india', 'id': 'Northeastern_India'}},
 {'node': {'name': 'Indian group',
   'embryony': 'monoembryonic',
   'id': 'Indian_Group'}},
 {'node': {'name': 'Southeast asian group',
   'embryony': 'polyembryonic',
   'id': 'Southeast_Asian_Group'}},
 {'node': {'name': 'Mango',
   'national_fruit_of': 'India, Pakistan, Philippines',
   'scientific_name': 'Mangifera indica',
   'id': 'Mango',
   'origin': 'Region between northwestern Myanmar, Bangladesh, and northeastern India'}},
 {'node': {'name': 'Southeast asia', 'id': 'Southeast Asia'}},
 {'node': {'name': 'South asia', 'id': 'South Asia'}},
 {'node': {'name': 'Alphonso',
   'description': "important export product, considered 'the king of mangoes'",
   'id': 'Alphonso',
   'origin': 'India'}},
 {'node': {'name': 'Julie',
   'description': 'prolific cultivar in Jamaica',
   'id': 'Julie'}},
 {'node': {'name': 'Tommy atkins',
   'location': 'southern Florida',
   'id': 'Tommy Atkins',
   'first_fruited': '1

We now merge the nodes with same name property, this is why we forced the LLM earlier to add name property, any arbitrary property name could be chose, name happened to be easier for the LLM to extract.

In [39]:
graph.refresh_schema()
graph.structured_schema

{'node_props': {'Fruit': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'national_fruit_of', 'type': 'STRING'},
   {'property': 'scientific_name', 'type': 'STRING'},
   {'property': 'origin', 'type': 'STRING'},
   {'property': 'length_range', 'type': 'STRING'},
   {'property': 'weight_range', 'type': 'STRING'},
   {'property': 'ripening_time', 'type': 'STRING'}],
  'Plant': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'national_fruit_of', 'type': 'STRING'},
   {'property': 'scientific_name', 'type': 'STRING'},
   {'property': 'origin', 'type': 'STRING'},
   {'property': 'crown_radius', 'type': 'STRING'},
   {'property': 'common_name', 'type': 'STRING'},
   {'property': 'lifespan', 'type': 'STRING'},
   {'property': 'height', 'type': 'STRING'}],
  'Plant_part': [{'property': 'id', 'type': 'STRING'},
   {'property': 'name', 'type': 'STRING'},
   {'property': 'arrangement', 'type': 'STR

In [40]:
show_graph()

We note provide a custom prompt to the LLM to create a cypher query to extract information from the graph, we provide key insights to the LLM here:

- When the node names where rewritter we used capitalization to make the nodes more readable and to make consistent query on the graph.
- Remind the LLm to use properties to extract information from the graph.
- We ask it use synonyms to extract information from the graph.

In [41]:
CYPHER_GENERATION_TEMPLATE = """Task: Generate a Cypher statement to query a graph database.
Instructions:
- Use only the provided relationship types and properties in the schema.
- Do not use any other relationship types or properties that are not listed.
- Examine the properties of nodes and relationships closely, as the answer might be found there.
- STRICTLY use **capitalize** when using **name** property of nodes, for example instead of "mango" use "Mango" or "Mango Tree" use "Mango tree".
- When fetching from properties use synonyms of values to match, for example instead of just Leaf do an or on Mango Leaf".

Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)

In [42]:
from langchain.chains import GraphCypherQAChain

We create a graph query chain that interprets the user queries, create cypher queries and runs them on the graph, and uses LLM to creat human readable responses.

In [43]:
graph_qa_chain = GraphCypherQAChain.from_llm(
    llm = ChatOpenAI(model="gpt-4-turbo", temperature=0),
    graph=graph,
    verbose=True,
    cypher_prompt=CYPHER_GENERATION_PROMPT,
    validate=True,    
)

In [44]:
graph_qa_chain.invoke(
    "what is the scientific name of mango?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (f:Fruit {name: "Mango"})
RETURN f.scientific_name AS ScientificName;
[0m
Full Context:
[32;1m[1;3m[{'ScientificName': 'Mangifera indica'}][0m

[1m> Finished chain.[0m


{'query': 'what is the scientific name of mango?',
 'result': 'The scientific name of mango is Mangifera indica.'}

In [47]:
graph_qa_chain.invoke(
    "where are mangoes consumed?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (m:Fruit)-[:CONSUMED_IN]->(l:Location)
WHERE m.name = "Mango"
RETURN l.name AS Location
[0m
Full Context:
[32;1m[1;3m[{'Location': 'Southeast asia'}, {'Location': 'Central america'}][0m

[1m> Finished chain.[0m


{'query': 'where are mangoes consumed?',
 'result': 'Mangoes are consumed in Southeast Asia and Central America.'}

In [48]:
graph_qa_chain.invoke(
    "mango is the national fruit of which all countries?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (f:Fruit {name: "Mango"})-[:NATIONAL_FRUIT_OF]->(c:Country)
RETURN c.name AS Country
[0m
Full Context:
[32;1m[1;3m[{'Country': 'India'}][0m

[1m> Finished chain.[0m


{'query': 'mango is the national fruit of which all countries?',
 'result': 'Mango is the national fruit of India.'}

In [49]:
graph_qa_chain.invoke(
    "mango is the national fruit/tree of which all countries?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (m:Fruit)-[:NATIONAL_FRUIT_OF]->(c:Country)
WHERE m.name = "Mango"
RETURN c.name AS Country
UNION
MATCH (p:Plant)-[:NATIONAL_TREE_OF]->(c:Country)
WHERE p.name = "Mango tree"
RETURN c.name AS Country
[0m
Full Context:
[32;1m[1;3m[{'Country': 'India'}, {'Country': 'Bangladesh'}][0m

[1m> Finished chain.[0m


{'query': 'mango is the national fruit/tree of which all countries?',
 'result': 'Mango is the national fruit of India and Bangladesh.'}

In [50]:
graph_qa_chain.invoke(
    "What all does mango contain?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (m:Fruit {name: "Mango"})-[:CONTAINS]->(s)
RETURN s.name AS SubstanceOrNutrient
[0m
Full Context:
[32;1m[1;3m[{'SubstanceOrNutrient': 'Vitamin c'}, {'SubstanceOrNutrient': 'Folate'}, {'SubstanceOrNutrient': 'Mango allergens'}][0m

[1m> Finished chain.[0m


{'query': 'What all does mango contain?',
 'result': 'Mango contains Vitamin C, Folate, and allergens specific to mangoes.'}

In [60]:
graph_qa_chain.invoke(
    "mangoes are used as ingredients in?"
)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (m:Fruit)-[:INGREDIENT_OF]->(f:Food)
WHERE m.name = "Mango"
RETURN f.name AS FoodName
[0m
Full Context:
[32;1m[1;3m[{'FoodName': 'Aam panna'}, {'FoodName': 'Mango lassi'}, {'FoodName': 'Aamras'}, {'FoodName': 'Mangada'}, {'FoodName': 'Mango sticky rice'}][0m

[1m> Finished chain.[0m


{'query': 'mangoes are used as ingredients in?',
 'result': 'Mangoes are used as ingredients in Aam panna, Mango lassi, Aamras, Mangada, and Mango sticky rice.'}