<a href="https://colab.research.google.com/github/neo4j-contrib/ms-graphrag-neo4j/blob/main/examples/drift_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --quiet --upgrade git+https://github.com/neo4j-contrib/ms-graphrag-neo4j.git tiktoken llama-index-core json-repair

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import os
from getpass import getpass
from typing import List

import requests
import tiktoken
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.workflow import (
    Event,
    Context,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from neo4j import GraphDatabase
from openai import AsyncOpenAI

import json_repair

from ms_graphrag_neo4j import MsGraphRAG

In [None]:
# Use Neo4j Sandbox - Blank Project https://sandbox.neo4j.com/

os.environ["NEO4J_URI"]="bolt://44.197.242.70:7687"
os.environ["NEO4J_USERNAME"]="neo4j"
os.environ["NEO4J_PASSWORD"]="oars-condensation-sectors"

In [None]:
os.environ["OPENAI_API_KEY"]= getpass("Openai API Key:")

Openai API Key:··········


In [None]:
client = AsyncOpenAI()

driver = GraphDatabase.driver(
    os.environ["NEO4J_URI"],
    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
    notifications_min_severity="OFF",
)

ms_graph = MsGraphRAG(driver=driver, model="gpt-5-mini", max_workers=10)

# Ingestion

In [None]:
book = requests.get('https://www.gutenberg.org/cache/epub/11/pg11.txt').text

In [None]:
splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=20)
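`TokenTextSplitter(chunk_size=512, chunk_overlap=20)` caps each chunk at roughly 512 tokens and repeats about 20 tokens between neighbouring chunks, so a sentence that straddles a boundary is still seen whole by entity extraction. The sketch below is a simplified fixed-stride window (stride = `chunk_size - chunk_overlap`). It assumes whitespace-separated words as a stand-in for real tokenizer ids. LlamaIndex's splitter also prefers to break on separators, so its boundaries will differ.

```python
from typing import List, Sequence


def sliding_window_chunks(tokens: Sequence[str], chunk_size: int, overlap: int) -> List[List[str]]:
    """Split tokens into windows of at most chunk_size that share `overlap` tokens.

    Simplified stand-in for token-based splitting: a fixed stride of
    chunk_size - overlap, with no attempt to respect sentence boundaries.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    stride = chunk_size - overlap
    chunks: List[List[str]] = []
    for start in range(0, len(tokens), stride):
        chunks.append(list(tokens[start:start + chunk_size]))
        # Stop once this window already reaches the end of the token list
        if start + chunk_size >= len(tokens):
            break
    return chunks


# Toy whitespace "tokens" and small numbers so the shared token is easy to see
toy_tokens = "the quick brown fox jumps over the lazy dog".split()
for chunk in sliding_window_chunks(toy_tokens, chunk_size=4, overlap=1):
    print(chunk)
```

Each printed window starts with the last token of the previous one. With `chunk_size=512, chunk_overlap=20`, the same idea yields the 86 chunks processed by the extraction run below, give or take the splitter's separator handling.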

In [None]:
class EntitySummarization(Event):
    pass

class CommunitySummarization(Event):
    pass

class CommunityEmbeddings(Event):
    pass

class EntityEmbeddings(Event):
    pass

In [None]:
MIN_COMMUNITY_RATING = 3
TEXT_EMBEDDING_MODEL = "text-embedding-3-small"


class MSGraphRAGIngestion(Workflow):
    @step
    async def entity_extraction(self, ev: StartEvent) -> EntitySummarization:
        chunks = splitter.split_text(ev.text)
        await ms_graph.extract_nodes_and_rels(chunks, ev.allowed_entities)
        return EntitySummarization()

    @step
    async def entity_summarization(
        self, ev: EntitySummarization
    ) -> CommunitySummarization:
        await ms_graph.summarize_nodes_and_rels()
        return CommunitySummarization()

    @step
    async def community_summarization(
        self, ev: CommunitySummarization
    ) -> CommunityEmbeddings:
        await ms_graph.summarize_communities()
        return CommunityEmbeddings()

    @step
    async def community_embeddings(self, ev: CommunityEmbeddings) -> EntityEmbeddings:
        communities = ms_graph.query(
            """
MATCH (c:__Community__)
WHERE c.summary IS NOT NULL AND c.rating > $min_community_rating
RETURN coalesce(c.title, "") + " " + c.summary AS community_description, c.id AS community_id
""",
            params={"min_community_rating": MIN_COMMUNITY_RATING},
        )
        if communities:
            response = await client.embeddings.create(
                input=[c["community_description"] for c in communities],
                model=TEXT_EMBEDDING_MODEL,
            )
            embeds = []
            for community, embedding in zip(communities, response.data):
                embeds.append(
                    {
                        "community_id": community["community_id"],
                        "embedding": embedding.embedding,
                    }
                )
            ms_graph.query(
                """UNWIND $data as row
            MATCH (c:__Community__ {id: row.community_id})
            CALL db.create.setNodeVectorProperty(c, 'embedding', row.embedding)""",
                params={"data": embeds},
            )
            ms_graph.query(
                "CREATE VECTOR INDEX community IF NOT EXISTS FOR (c:__Community__) ON c.embedding"
            )
        else:
            print("No community was summarized")
        return EntityEmbeddings()

    @step
    async def entity_embeddings(self, ev: EntityEmbeddings) -> StopEvent:
        entities = ms_graph.query("""
    MATCH (e:__Entity__)
    WHERE e.summary IS NOT NULL
    RETURN coalesce(e.name, "") + " " + e.summary AS entity_description, e.name AS entity_name
    """)
        if entities:
            response = await client.embeddings.create(
                input=[e["entity_description"] for e in entities],
                model=TEXT_EMBEDDING_MODEL,
            )
            embeds = []
            for entity, embedding in zip(entities, response.data):
                embeds.append(
                    {
                        "entity_name": entity["entity_name"],
                        "embedding": embedding.embedding,
                    }
                )
            ms_graph.query(
                """UNWIND $data as row
            MATCH (e:__Entity__ {name: row.entity_name})
            CALL db.create.setNodeVectorProperty(e, 'embedding', row.embedding)""",
                params={"data": embeds},
            )
            ms_graph.query(
                "CREATE VECTOR INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON e.embedding"
            )
        else:
            print("No entity was summarized")
        return StopEvent()
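Both embedding steps send every description to `client.embeddings.create` in a single request. That is fine for one novel. However, OpenAI's API reference caps how many inputs one embeddings request may carry (2,048 at the time of writing; check the current docs), so a larger graph would need batching. Below is a minimal order-preserving batching helper. `batched` and `MAX_INPUTS_PER_REQUEST` are illustrative names, not part of `ms_graphrag_neo4j` or the OpenAI SDK.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

# Assumed per-request cap on embedding inputs; confirm against OpenAI's current API reference
MAX_INPUTS_PER_REQUEST = 2048


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items, preserving input order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


# Sketch of how community_embeddings could use it (same client calls as the step above):
#     vectors = []
#     descriptions = [c["community_description"] for c in communities]
#     for batch in batched(descriptions, MAX_INPUTS_PER_REQUEST):
#         response = await client.embeddings.create(input=batch, model=TEXT_EMBEDDING_MODEL)
#         vectors.extend(d.embedding for d in response.data)
print(list(batched(["a", "b", "c", "d", "e"], 2)))  # [['a', 'b'], ['c', 'd'], ['e']]
```

Because batches are consumed in order, `zip(communities, vectors)` still pairs each description with its own embedding.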

In [None]:
w = MSGraphRAGIngestion(timeout=3600, verbose=False)
result = await w.run(text=book, allowed_entities=["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "ARTIFACT"])

Extracting nodes & relationships: 100%|██████████| 86/86 [07:03<00:00,  4.93s/it]
Summarizing nodes: 100%|██████████| 101/101 [01:21<00:00,  1.23it/s]
Summarizing relationships: 100%|██████████| 120/120 [01:33<00:00,  1.28it/s]


Leiden algorithm identified 3 community levels with 22 communities on the last level.


Summarizing communities: 100%|██████████| 22/22 [01:46<00:00,  4.85s/it]
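The retrieval steps below look up nearest neighbours with `db.index.vector.queryNodes(<index>, k, $embedding)`. This call returns up to `k` nodes together with a similarity `score`. Neo4j answers it with an approximate (HNSW) index. The brute-force ranking below is only a conceptual stand-in, assuming the `community` and `entity` indexes use cosine similarity. Neo4j's documentation describes cosine scores as normalised into [0, 1]. The `(1 + cos) / 2` mapping used here is an assumption about that normalisation.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain cosine similarity in [-1, 1]."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0:
        raise ValueError("a zero vector has no direction")
    return dot / norm


def top_k(query: Sequence[float], vectors: Dict[str, Sequence[float]], k: int) -> List[Tuple[str, float]]:
    """Brute-force k nearest neighbours by cosine, with scores mapped to [0, 1].

    The (1 + cos) / 2 rescaling is an assumption about how Neo4j reports cosine scores.
    """
    scored = [(node_id, (1 + cosine(query, vec)) / 2) for node_id, vec in vectors.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Made-up 2-d "embeddings"; text-embedding-3-small vectors have 1536 dimensions
toy_communities = {"c1": [1.0, 0.0], "c2": [0.0, 1.0], "c3": [-1.0, 0.0]}
print(top_k([1.0, 0.1], toy_communities, k=2))
```

An HNSW index trades a small chance of missing the exact nearest neighbour for much faster lookups than this linear scan.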


# Retrieval

In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""DRIFT Search prompts."""

DRIFT_LOCAL_SYSTEM_PROMPT = """
---Role---

You are a helpful assistant responding to questions about data in the tables provided.


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."

where 15 and 16 represent the id (not the index) of the relevant data record.

Pay close attention specifically to the Sources table, as it contains the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.

---Target response length and format---

{response_type}


---Data tables---

{context_data}


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."

where 15 and 16 represent the id (not the index) of the relevant data record.

Pay close attention specifically to the Sources table, as it contains the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.

---Target response length and format---

{response_type}

Add sections and commentary to the response as appropriate for the length and format.

Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:

{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int,
'follow_up_queries': List[str]}}
"""


DRIFT_REDUCE_PROMPT = """
---Role---

You are a helpful assistant responding to questions about data in the reports provided.

---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate and concise as possible.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."

Do not include information where the supporting evidence for it is not provided.

If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]"

---Data Reports---

{context_data}

---Target response length and format---

{response_type}


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate and concise as possible.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."

Do not include information where the supporting evidence for it is not provided.

If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]".

Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:

"""


DRIFT_PRIMER_PROMPT = """You are a helpful agent designed to reason over a knowledge graph in response to a user query.
This is a unique knowledge graph where edges are freeform text rather than verb operators. You will begin your reasoning looking at a summary of the content of the most relevant communities and will provide:

1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.

2. intermediate_answer: This answer should match the level of detail and length found in the community summaries. The intermediate answer should be exactly 2000 characters long. This must be formatted in markdown and must begin with a header that explains how the following text is related to the query.

3. follow_up_queries: A list of follow-up queries that could be asked to further explore the topic. These should be formatted as a list of strings. Generate at least five good follow-up queries.

Use this information to help you decide whether or not you need more information about the entities mentioned in the report. You may also use your general knowledge to think of entities which may help enrich your answer.

You will also provide a full answer from the content you have available. Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.

For the query:

{query}

The top-ranked community summaries:

{community_reports}

Provide the intermediate answer, and all scores in JSON format following:

{{'intermediate_answer': str,
'score': int,
'follow_up_queries': List[str]}}

Begin:
"""

HYDE_PROMPT = """Create a hypothetical answer to the following query: {query}

Format it to follow the structure of the template below:

{template}

Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
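These templates are filled with `str.format`. That is why the JSON schemas at the end of `DRIFT_LOCAL_SYSTEM_PROMPT` and `DRIFT_PRIMER_PROMPT` are written with doubled braces: `{{` and `}}` render as a literal `{` and `}`, while single-brace names such as `{query}` are substituted. `str.format` also ignores keyword arguments that have no matching placeholder. As a result, passing `global_query` to `DRIFT_REDUCE_PROMPT`, which has no `{global_query}` field, does not raise later in the notebook. The miniature template below is illustrative, not one of the prompts above.

```python
# Illustrative template in the same style as the prompts above
TEMPLATE = """Answer the query: {query}

Respond in JSON:
{{'response': str,
'score': int}}
"""

# Extra keyword arguments without a placeholder are silently ignored
filled = TEMPLATE.format(query="Who is the Cheshire Cat?", unused="ignored")
print(filled)

# An unescaped brace pair is parsed as a field name and fails at format time
try:
    "{'response': str}".format()
except KeyError as err:
    print("KeyError:", err)
```

Forgetting to double a brace in one of the prompts would therefore surface as a `KeyError` when the workflow step calls `.format(...)`, not as a malformed prompt.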

In [None]:
class CommunitySearch(Event):
    query: str
    hyde_query: str

class LocalSearch(Event):
    query: str
    local_query: str

class LocalSearchResults(Event):
    query: str
    results: dict

class FinalAnswer(Event):
    query: str
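The `DriftSearch` workflow in the next cell fans out and back in. `community_search` emits one `LocalSearch` event per follow-up query with `ctx.send_event`. `local_search` processes them with at most five running at once (`num_workers=5`). `local_search_results` then waits until `ctx.collect_events` has gathered as many `LocalSearchResults` as were sent. The plain-asyncio sketch below mirrors that scatter/gather pattern with a semaphore. It is an analogy, not how LlamaIndex schedules steps internally. `fake_local_search` is a hypothetical stand-in for the embedding, Cypher, and LLM calls.

```python
import asyncio
from typing import Dict, List, Tuple

MAX_CONCURRENT = 5  # mirrors num_workers=5 on the local_search step


async def fake_local_search(
    query: str, gate: asyncio.Semaphore, stats: Dict[str, int]
) -> Dict[str, object]:
    """Hypothetical stand-in for one local search (embedding + Cypher + LLM call)."""
    async with gate:  # at most MAX_CONCURRENT searches hold the semaphore at once
        stats["running"] += 1
        stats["peak"] = max(stats["peak"], stats["running"])
        await asyncio.sleep(0.01)  # simulate network latency
        stats["running"] -= 1
        return {"response": f"answer to {query}", "score": 50}


async def scatter_gather(queries: List[str]) -> Tuple[List[Dict[str, object]], int]:
    """Run one search per query with bounded concurrency.

    Returns the results in input order and the peak number of concurrent searches.
    """
    gate = asyncio.Semaphore(MAX_CONCURRENT)
    stats = {"running": 0, "peak": 0}
    # Fan out one task per follow-up query; gather waits for all of them,
    # much as collect_events waits for local_search_num results
    results = await asyncio.gather(*(fake_local_search(q, gate, stats) for q in queries))
    return list(results), stats["peak"]


# In the notebook, where an event loop is already running:
#     results, peak = await scatter_gather([f"follow-up {i}" for i in range(12)])
# In a plain script:
#     results, peak = asyncio.run(scatter_gather([f"follow-up {i}" for i in range(12)]))
```

The collecting step must know how many results to wait for. The workflow records that count in `local_search_num` before sending the events.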

In [None]:
DEFAULT_RESPONSE_TYPE = "multiple paragraphs"
LOCAL_TOP_K = 3
MAX_LOCAL_SEARCH_DEPTH = 2


class DriftSearch(Workflow):
    @step
    async def hyde_generation(self, ev: StartEvent) -> CommunitySearch:
        # Pick one random community report to use as the structural template for HyDE
        random_community_report = driver.execute_query(
            """
        MATCH (c:__Community__)
        WHERE c.summary IS NOT NULL
        RETURN coalesce(c.title, "") + " " + c.summary AS community_description
        ORDER BY rand()
        LIMIT 1""",
            result_transformer_=lambda r: r.data(),
        )
        hyde = HYDE_PROMPT.format(
            query=ev.query, template=random_community_report[0]["community_description"]
        )
        hyde_response = await client.responses.create(
            model="gpt-5-mini",
            input=[{"role": "user", "content": hyde}],
            reasoning={"effort": "low"},
        )
        return CommunitySearch(query=ev.query, hyde_query=hyde_response.output_text)

    @step
    async def community_search(self, ctx: Context, ev: CommunitySearch) -> LocalSearch:
        embedding_response = await client.embeddings.create(
            input=ev.hyde_query, model=TEXT_EMBEDDING_MODEL
        )
        embedding = embedding_response.data[0].embedding
        community_reports = driver.execute_query(
            """
        CALL db.index.vector.queryNodes('community', 5, $embedding) YIELD node, score
        RETURN 'community-' + node.id AS source_id, node.summary AS community_summary
        """,
            result_transformer_=lambda r: r.data(),
            embedding=embedding,
        )
        initial_prompt = DRIFT_PRIMER_PROMPT.format(
            query=ev.query, community_reports=community_reports
        )
        initial_response = await client.responses.create(
            model="gpt-5-mini",
            input=[{"role": "user", "content": initial_prompt}],
            reasoning={"effort": "low"},
        )
        response_json = json_repair.loads(initial_response.output_text)
        print(f"Initial intermediate response: {response_json['intermediate_answer']}")
        # Set global states
        async with ctx.store.edit_state() as ctx_state:
            ctx_state["intermediate_answers"] = [
                {
                    "intermediate_answer": response_json["intermediate_answer"],
                    "score": response_json["score"],
                }
            ]
            ctx_state["local_search_num"] = len(response_json["follow_up_queries"])
        # Run follow-up queries in parallel
        for local_query in response_json["follow_up_queries"]:
            ctx.send_event(LocalSearch(query=ev.query, local_query=local_query))
        return None

    @step(num_workers=5)
    async def local_search(self, ev: LocalSearch) -> LocalSearchResults:
        print(f"Running local query: {ev.local_query}")
        response = await client.embeddings.create(
            input=ev.local_query, model=TEXT_EMBEDDING_MODEL
        )
        embedding = response.data[0].embedding
        local_reports = driver.execute_query(
            """
        CALL db.index.vector.queryNodes('entity', 5, $embedding) YIELD node, score
        WITH collect(node) AS nodes
        WITH
          collect {
            UNWIND nodes AS n
            MATCH (n)-[:MENTIONS]-(c:__Chunk__)
            WITH c, count(DISTINCT n) AS freq
            RETURN {chunkText: c.text, source_id: 'chunk-' + c.id}
            ORDER BY freq DESC
            LIMIT 3
          } AS text_mapping,
          collect {
            UNWIND nodes AS n
            MATCH (n)-[:IN_COMMUNITY*]->(c:__Community__)
            WHERE c.summary IS NOT NULL
            // DISTINCT: several matched entities can share a community
            WITH DISTINCT c, c.rating AS rank
            RETURN {summary: c.summary, source_id: 'community-' + c.id}
            ORDER BY rank DESC
            LIMIT 3
          } AS report_mapping,
          collect {
            UNWIND nodes AS n
            MATCH (n)-[r:SUMMARIZED_RELATIONSHIP]-(m)
            WHERE m IN nodes
            RETURN {descriptionText: r.summary, source_id: 'relationship-' + n.name + '-' + m.name}
            LIMIT 3
          } AS insideRels,
          collect {
            UNWIND nodes AS n
            RETURN {descriptionText: n.summary, source_id: 'node-' + n.name}
          } AS entities
        RETURN {Chunks: text_mapping, Reports: report_mapping,
                Relationships: insideRels, Entities: entities} AS output
        """,
            result_transformer_=lambda r: r.data(),
            embedding=embedding,
        )
        local_prompt = DRIFT_LOCAL_SYSTEM_PROMPT.format(
            response_type=DEFAULT_RESPONSE_TYPE,
            context_data=local_reports,
            global_query=ev.query,
        )
        # Send the retrieved context as developer instructions and the
        # follow-up query itself as the user's question
        local_response = await client.responses.create(
            model="gpt-5-mini",
            input=[
                {"role": "developer", "content": local_prompt},
                {"role": "user", "content": ev.local_query},
            ],
            reasoning={"effort": "low"},
        )
        response_json = json_repair.loads(local_response.output_text)
        # Keep only the top-k follow-up queries (none if the model omitted the key)
        response_json["follow_up_queries"] = response_json.get("follow_up_queries", [])[
            :LOCAL_TOP_K
        ]
        return LocalSearchResults(results=response_json, query=ev.query)

    @step
    async def local_search_results(
        self, ctx: Context, ev: LocalSearchResults
    ) -> LocalSearch | FinalAnswer:
        local_search_num = await ctx.store.get("local_search_num")
        results = ctx.collect_events(ev, [LocalSearchResults] * local_search_num)
        if results is None:
            return None
        intermediate_results = [
            {
                "intermediate_answer": event.results["response"],
                "score": event.results["score"],
            }
            for event in results
        ]
        # Keep every round's answers, including the last round, for the reduce step
        async with ctx.store.edit_state() as ctx_state:
            ctx_state["intermediate_answers"].extend(intermediate_results)
        current_depth = await ctx.store.get("local_search_depth", default=1)
        # All events in a round carry the same original (global) query
        query = results[0].query
        follow_up_queries = [
            follow_up
            for event in results
            for follow_up in event.results["follow_up_queries"]
        ]

        # Stop at the depth limit, or early if no follow-ups were proposed
        # (otherwise collect_events would wait for zero events until timeout)
        if current_depth < MAX_LOCAL_SEARCH_DEPTH and follow_up_queries:
            await ctx.store.set("local_search_depth", current_depth + 1)
            await ctx.store.set("local_search_num", len(follow_up_queries))
            for local_query in follow_up_queries:
                ctx.send_event(LocalSearch(query=query, local_query=local_query))
            return None
        return FinalAnswer(query=query)

    @step
    async def final_answer_generation(self, ctx: Context, ev: FinalAnswer) -> StopEvent:
        intermediate_answers = await ctx.store.get("intermediate_answers")
        answer_prompt = DRIFT_REDUCE_PROMPT.format(
            response_type=DEFAULT_RESPONSE_TYPE,
            context_data=intermediate_answers,
            global_query=ev.query,
        )
        answer_response = await client.responses.create(
            model="gpt-5-mini",
            input=[
                {"role": "developer", "content": answer_prompt},
                {"role": "user", "content": ev.query},
            ],
            reasoning={"effort": "low"},
        )

        return StopEvent(result=answer_response.output_text)
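Stripped of events, the workflow above is a depth-limited breadth-first expansion. The primer answer seeds the first round of follow-up queries. Each local search returns an answer plus follow-ups trimmed to `LOCAL_TOP_K`. A new round starts only while the depth counter is below `MAX_LOCAL_SEARCH_DEPTH`. The synchronous sketch below reproduces that control flow with a hypothetical `answer_fn`, so the number of local searches can be counted. It collects every round's answers, including the last. It also assumes each local answer proposes at least `LOCAL_TOP_K` follow-ups.

```python
from typing import Callable, Dict, List, Tuple

LOCAL_TOP_K = 3
MAX_LOCAL_SEARCH_DEPTH = 2

# Hypothetical signature for one local search: query -> (answer, score, follow-up queries)
AnswerFn = Callable[[str], Tuple[str, int, List[str]]]


def drift_expand(
    primer_follow_ups: List[str], answer_fn: AnswerFn
) -> Tuple[List[Dict[str, object]], int]:
    """Depth-limited breadth-first expansion in the style of the DriftSearch workflow.

    Returns every intermediate answer (all rounds) and the number of local searches run.
    """
    answers: List[Dict[str, object]] = []
    frontier = list(primer_follow_ups)  # primer follow-ups are not trimmed
    searches = 0
    for _ in range(MAX_LOCAL_SEARCH_DEPTH):
        next_frontier: List[str] = []
        for query in frontier:
            answer, score, follow_ups = answer_fn(query)
            searches += 1
            answers.append({"intermediate_answer": answer, "score": score})
            next_frontier.extend(follow_ups[:LOCAL_TOP_K])  # trim each answer's follow-ups to top-k
        frontier = next_frontier
        if not frontier:  # nothing left to explore
            break
    return answers, searches


def fake_answer(query: str) -> Tuple[str, int, List[str]]:
    # Every answer proposes five follow-ups; only LOCAL_TOP_K of them survive trimming
    return f"answer to {query}", 50, [f"{query}.{i}" for i in range(5)]


answers, searches = drift_expand([f"q{i}" for i in range(5)], fake_answer)
print(searches)  # 20: 5 first-round searches + 5 * 3 second-round searches
```

With exactly five primer follow-ups, one question therefore costs 20 local searches. Adding the HyDE, primer, and reduce calls gives 23 `responses.create` requests, plus 21 embedding requests (one for the HyDE text and one per local query). Raising `MAX_LOCAL_SEARCH_DEPTH` or `LOCAL_TOP_K` grows this geometrically.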

In [None]:
global_question = """How do the various characters Alice encounters in Wonderland, such as the Cheshire Cat, the Caterpillar,
and the Queen of Hearts, challenge her understanding of logic, rules, and authority, and what does this reveal about the nature of the adult world she is entering?"""

w = DriftSearch(timeout=3600, verbose=False)
result = await w.run(query=global_question)

Initial intermediate response: # How this relates to the query

Alice's encounters in Wonderland stage a series of vignettes that unsettle her (and the reader's) assumptions about logic, rules, and authority. The Cheshire Cat undermines consistent logic by appearing and disappearing at will and by speaking in riddling paradoxes; its grin and detachment suggest that identity and reason can be fragmentary. The Caterpillar challenges Alice's sense of stable self and taxonomy — his questions about who she is and his languid, cryptic counsel force Alice to confront the malleability of identity and the limits of straightforward instruction. The Queen of Hearts institutionalizes arbitrary authority: her frequent cries of "Off with their heads!" and the mock-legal procedures of the trial expose how power can be performative, capricious, and divorced from justice.

Together these figures reveal an adult world where rules exist but lack coherent grounding. Logic becomes local and conversational 

In [None]:
print(result)

Summary

Alice’s encounters in Wonderland repeatedly unsettle her expectations about consistent logic, stable rules, and legitimate authority. Key figures—the Cheshire Cat, the Caterpillar, and the Queen of Hearts—each expose different ways that meaning, measurement, and power in the adult world can be arbitrary, performative, or detached from responsibility. These episodes push Alice toward a skeptical, adaptive stance rather than simple acceptance of adult rules [Data: Entities (node-ALICE); Reports (community-2-9)].

How specific characters challenge Alice

- Cheshire Cat — instability of identity and guidance: The Cat appears and disappears at will (sometimes leaving only a grin), offering ambiguous directional or philosophical remarks rather than clear guidance. Its vanishing/returning undermines Alice’s trust that an authoritative speaker or guide will be consistently present or accountable [Data: Entities (node-CHESHIRE CAT, node-CAT); Reports (community-2-4)].

- Caterpillar — 