Simple Implementation of Graph-Based Retrieval-Augmented Generation Using Neo4j and LangChain


In [None]:
# Install the LangChain, Neo4j, Wikipedia, tokenizer and graph-widget packages used by the cells below.
# NOTE(review): versions are unpinned and --upgrade is set, so a fresh run may pull newer releases
# than the ones this notebook was written against — consider pinning.
%pip install --upgrade --quiet langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs


  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.0/52.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.3/204.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.6/294.6 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.6/15.6 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[

In [None]:
# Read the OpenAI API key from the Colab secret named "openai" so it is not hardcoded in the notebook.
from google.colab import userdata
openaiapi= userdata.get("openai")

In [None]:
import os

# Export the credentials as environment variables; later cells read the NEO4J_* values
# back from os.environ when they open their own connections.
os.environ["OPENAI_API_KEY"]=openaiapi

# The Neo4j settings used to be assigned from bare names (NEO4J_URI, ...) that no cell
# defines, which raises NameError on a fresh kernel. Read them from Colab secrets the same
# way as the OpenAI key; the secrets must be named NEO4J_URI, NEO4J_USERNAME and
# NEO4J_PASSWORD in the Colab "Secrets" panel.
os.environ["NEO4J_URI"]=userdata.get("NEO4J_URI")
os.environ["NEO4J_USERNAME"]=userdata.get("NEO4J_USERNAME")
os.environ["NEO4J_PASSWORD"]=userdata.get("NEO4J_PASSWORD")

In [None]:
from langchain_community.graphs import Neo4jGraph
# Open the LangChain wrapper around the Neo4j database. No arguments are passed, so the
# connection settings presumably come from the NEO4J_* environment variables — TODO confirm.
graph = Neo4jGraph()

In [None]:
from langchain_openai import ChatOpenAI
# Chat model shared by the graph transformer, the entity extractor and the final answer chain.
# temperature=0 is requested for all of them.
llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125")

In [None]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
# Graph transformer backed by the chat model above; used below to convert text chunks
# into graph documents.
llmt=LLMGraphTransformer(llm=llm)

In [None]:

# Import the loader from langchain_community (installed above and already used for Neo4jGraph)
# instead of the legacy `langchain.document_loaders` re-export path.
from langchain_community.document_loaders import WikipediaLoader
# Load the Wikipedia pages returned for the query "Machine Learning" as LangChain documents.
rawdoc= WikipediaLoader("Machine Learning").load()

In [None]:
# NOTE(review): `langchain.text_splitter` looks like a legacy import path — confirm it still
# resolves with the unpinned packages installed above.
from langchain.text_splitter import TokenTextSplitter
# Split by tokens: chunks of 512 with an overlap of 24 between neighbouring chunks.
tokensplitter= TokenTextSplitter(chunk_size=512,chunk_overlap=24)
# Only the first three loaded documents are split, which limits how much text is sent to the LLM later.
docs= tokensplitter.split_documents(rawdoc[:3])

In [None]:
# Run the LLM graph transformer over the text chunks to obtain graph documents.
graph_docs=llmt.convert_to_graph_documents(docs)

In [None]:
# Write the extracted graph documents into Neo4j.
# NOTE(review): baseEntityLabel=True presumably adds the `__Entity__` label that the full-text
# index created later is built on, and include_source=True presumably stores the source chunks
# (the `Document` nodes / MENTIONS relationships referenced by later queries) — confirm.
graph.add_graph_documents(
    graph_docs,
    baseEntityLabel=True,
    include_source=True)


In [None]:
# Default visualisation query: up to 50 relationships of any type except MENTIONS,
# together with their start and end nodes.
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"


In [None]:
from yfiles_jupyter_graphs import GraphWidget
from neo4j import GraphDatabase

In [None]:
from google.colab import output
# Turn on Colab's custom widget manager; presumably required for the GraphWidget
# rendered by showGraph() to display in Colab — TODO confirm.
output.enable_custom_widget_manager()

In [None]:
def showGraph(cypher: str = default_cypher):
    """Run a Cypher query and render its result as an interactive graph widget.

    Args:
        cypher: Query whose result contains the nodes and relationships to draw.
            Defaults to ``default_cypher``.

    Returns:
        The ``GraphWidget`` that was displayed, so callers can tweak it further.
    """
    # Open a dedicated driver/session for this query and close both when done;
    # previously neither was closed, leaking a connection on every call.
    with GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
    ) as driver:
        with driver.session() as session:
            # Build the widget while the session is still open so the result can be consumed.
            widget = GraphWidget(graph=session.run(cypher).graph())
    # Label each node with its `id` property.
    widget.node_label_mapping = 'id'
    # `display` is the notebook's built-in rich-output function.
    display(widget)
    return widget

In [None]:
# Render the default relationship sample. showGraph() already calls display() on the widget,
# so the trailing semicolon keeps the notebook from rendering the returned widget a second time.
showGraph();

GraphWidget(layout=Layout(height='800px', width='100%'))

GraphWidget(layout=Layout(height='800px', width='100%'))

In [None]:
from typing import Tuple, List, Optional
from langchain_community.vectorstores import Neo4jVector

In [None]:
from langchain_openai import OpenAIEmbeddings
# Build a Neo4jVector store on top of the graph that is already in the database:
# nodes labelled `Document`, text taken from their `text` property, embeddings kept in the
# `embedding` property. search_type="hybrid" presumably combines vector and keyword search
# — TODO confirm against the Neo4jVector documentation.
vector_index= Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
    url=os.environ["NEO4J_URI"],
    username=os.environ["NEO4J_USERNAME"],
    password=os.environ["NEO4J_PASSWORD"]
)

In [None]:
# Create the full-text index named `entity` over the `id` property of `__Entity__` nodes.
# IF NOT EXISTS makes the cell safe to re-run; structured_retriever() queries this index by name.
graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

[]

In [None]:
from langchain_core.pydantic_v1 import BaseModel,Field

# Output schema for the entity-extraction chain (passed to llm.with_structured_output below).
# NOTE(review): the docstring and the field description are presumably forwarded to the model
# as part of the schema, so treat them as prompt text — confirm.
class Entities(BaseModel):
  """Identifying information about entities"""

  # Names of the entities found in the input text, as plain strings.
  names: List[str] = Field(
      ...,
      description ="All the person, organization or business entities that appear in the text"
  )

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

In [None]:
# Prompt for the entity-extraction chain: the system message states the task and the human
# message carries the text to analyse in the `{question}` placeholder.
# The two adjacent string literals in the human message are concatenated by Python, so the
# first one must end with a space; without it the model received "followinginput:".
# The spelling of "entities" and "extract" is corrected as well.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are extracting organization and person entities from the text.",
        ),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)

In [None]:
# Entity-extraction chain: the prompt defined above piped into the chat model, which is asked
# to return its answer in the shape of the `Entities` schema.
entity_chain = prompt | llm.with_structured_output(Entities)

In [None]:
# Smoke test of the entity extractor on a sentence with one person and two organizations.
entity_chain.invoke({"question": "  Peter Thiel is the founder of PayPal and Founder Fund"}).names

['Peter Thiel', 'PayPal', 'Founder Fund']

In [None]:
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars

In [None]:
from typing import List

def gen_full_query(input: str, use_fuzzy: bool = False, operator: str = "AND") -> str:
    """Build the full-text query string for one entity name.

    The input is first passed through ``remove_lucene_chars`` and stripped. For a
    multi-word input the whole cleaned text is emitted first as a quoted phrase,
    followed by the individual words.

    Args:
        input: Raw entity text.
        use_fuzzy: When true, every word except the last gets a ``~`` suffix and is
            followed by ``operator``; the last word gets ``~`` only if it is longer
            than two characters. When false, words are separated by single spaces.
        operator: Token placed between words in fuzzy mode. Not used otherwise.

    Returns:
        The query string, or ``""`` when nothing is left after cleaning.
    """
    cleaned = remove_lucene_chars(input).strip()
    if not cleaned:
        return ""

    tokens = cleaned.split()
    *leading, tail = tokens
    pieces = []

    # Multi-word input: lead with an exact-phrase clause.
    if len(tokens) > 1:
        pieces.append(f'"{cleaned}"')

    if use_fuzzy:
        for token in leading:
            pieces.append(f"{token}~")
            pieces.append(operator)
        if len(tail) > 2:
            tail = f"{tail}~"
    else:
        pieces.extend(leading)

    pieces.append(tail)
    return " ".join(pieces).strip()



In [None]:
def structured_retriever(question: str) -> str:
    """Collect graph relationships around the entities mentioned in ``question``.

    Entities are extracted with ``entity_chain``. Each one is looked up in the
    ``entity`` full-text index, and the non-MENTIONS relationships of the matching
    nodes (both directions) are rendered as ``source - TYPE -> target`` lines.

    Args:
        question: Natural-language question to extract entities from.

    Returns:
        All relationship lines joined with newlines, or ``""`` when nothing was found.
    """
    entities = entity_chain.invoke({"question": question})
    lines = []
    for entity in entities.names:
        full_text_query = gen_full_query(entity, use_fuzzy=False)
        if not full_text_query:
            # Nothing searchable is left after cleaning the entity text; do not send
            # an empty query to the full-text index.
            continue
        response = graph.query(
            """
            CALL db.index.fulltext.queryNodes('entity', $query, {limit:5})
            YIELD node, score
            WITH node, score
            ORDER BY score DESC  // Ensure the highest scoring results come first
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": full_text_query},
        )
        lines.extend(el['output'] for el in response)
    # Join once at the end so every line is newline-separated. Appending one joined block
    # per entity fused the last line of one entity with the first line of the next.
    return "\n".join(lines)



In [None]:
# Smoke test: print the relationship lines found around "Machine Learning".
print(structured_retriever("Machine Learning"))



Alan Turing - INSPIRED -> Machine Learning
Quantum Machine Learning - INTEGRATION -> Machine Learning Programs
Quantum Machine Learning - APPLIED_TO -> Data From Quantum Experiments
Quantum Machine Learning - RELATED_TO -> Quantum Information
Quantum Machine Learning - RELATED_TO -> Quantum Learning Theory
Quantum Machine Learning - INTEGRATION -> Quantum Algorithms
Quantum Machine Learning - INTEGRATION -> Machine Learning Programs
Quantum Machine Learning - ANALYSIS -> Classical Data
Quantum Machine Learning - REFERS_TO -> Quantum-Enhanced Machine Learning
Quantum Machine Learning - UTILIZES -> Qubits
Quantum Machine Learning - UTILIZES -> Quantum Operations
Quantum Machine Learning - INVOLVES -> Hybrid Methods
Quantum Machine Learning - OUTSOURCES_TO -> Quantum Device
Quantum Machine Learning - ANALYZES -> Quantum States
Quantum Machine Learning - ASSOCIATED_WITH -> Classical Machine Learning Methods
Quantum Machine Learning - LEARNING_OF -> Phase Transitions Of A Quantum System
Qua

In [None]:
def retriever(question: str):
    """Assemble the context for a question from the graph and the vector index.

    Prints the search query, fetches relationship lines via ``structured_retriever``
    and the page contents returned by ``vector_index.similarity_search``, then puts
    both into one text block.

    Args:
        question: The standalone search query.

    Returns:
        A string with a "Structured data:" section followed by an
        "Unstructured data:" section; the passages are joined with "#Document ".
    """
    print(f"Search query: {question}")
    graph_context = structured_retriever(question)
    passages = "#Document ".join(
        hit.page_content for hit in vector_index.similarity_search(question)
    )
    # The trailing four spaces reproduce the tail of the original template verbatim.
    return (
        "Structured data:\n"
        f"{graph_context}\n"
        "Unstructured data:\n"
        f"{passages}\n"
        "    "
    )


In [None]:
# Prompt text used to rewrite a follow-up question into a standalone question.
# Placeholders: {chat_history} (the earlier turns) and {question} (the follow-up input).
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

In [None]:
# Prompt template built from the question-condensing text in `_template`.
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

In [None]:
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    """Turn (human, ai) text pairs into an alternating list of chat messages.

    Args:
        chat_history: Earlier turns as ``(human_text, ai_text)`` tuples, oldest first.

    Returns:
        A flat list with a ``HumanMessage`` followed by an ``AIMessage`` for every
        pair, in the original order.
    """
    return [
        message
        for user_text, assistant_text in chat_history
        for message in (
            HumanMessage(content=user_text),
            AIMessage(content=assistant_text),
        )
    ]

In [None]:
from langchain_core.runnables import RunnableBranch, RunnableLambda,RunnablePassthrough
from langchain_core.output_parsers.string import StrOutputParser
# Message classes come from langchain_core (already used throughout this notebook) rather
# than the legacy `langchain.schema` re-export path.
from langchain_core.messages import HumanMessage, AIMessage

# Produces the search query for retrieval from the chain input (a dict with "question"
# and, optionally, "chat_history").
_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x : x["question"]),
)

In [None]:
# Answer prompt text: the model is told to answer only from {context} (the retriever output)
# and to reply concisely to {question}.
template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""

In [None]:
# Answer prompt built from `template`.
# NOTE: this rebinds the name `prompt`, which earlier held the entity-extraction prompt;
# `entity_chain` was built before this line and keeps the earlier object.
prompt = ChatPromptTemplate.from_template(template)

In [None]:
from langchain_core.runnables import RunnableParallel
# Full question-answering pipeline: build the search query, retrieve graph and vector
# context for it, fill the answer prompt, call the chat model and return plain text.
# Input is a dict with "question" and, optionally, "chat_history".
chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            # Forward only the question text. RunnablePassthrough() handed the whole
            # input dict to the prompt, so "{question}" was rendered as the dict's repr.
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
# End-to-end run without chat history: the question is used directly as the search query.
chain.invoke({"question": "what is the uses of Machine Learning?"})

Search query: what is the uses of Machine Learning?


ERROR:neo4j.io:Failed to write data to connection ResolvedIPv4Address(('34.126.171.25', 7687)) (ResolvedIPv4Address(('34.126.171.25', 7687)))
ERROR:neo4j.io:Failed to write data to connection IPv4Address(('696dddc5.databases.neo4j.io', 7687)) (ResolvedIPv4Address(('34.126.171.25', 7687)))


'Machine learning is used in various fields such as natural language processing, computer vision, speech recognition, email filtering, agriculture, medicine, and predictive analytics in business.'