# Retrieval Beyond Similarity: LazyGraphRAG in LangChain

## Introduction

In [LazyGraphRAG](https://www.microsoft.com/en-us/research/blog/lazygraphrag-setting-a-new-standard-for-quality-and-cost/), Microsoft demonstrates significant cost and performance benefits to delaying the construction of a knowledge graph.
This is largely because not all documents need to be analyzed.
It also helps that the question is already known by the time documents are analyzed, so irrelevant information can be ignored.

We've noticed similar cost benefits from building a document graph, which links content based on simple properties such as extracted keywords, instead of building a complete knowledge graph.
For the Wikipedia dataset used in this notebook, we estimated it would have cost $70k to build a knowledge graph using the [example from LangChain](https://python.langchain.com/docs/how_to/graph_constructing/#llm-graph-transformer), while the document graph was basically free.

In this notebook we demonstrate how to populate a document graph with Wikipedia articles linked based on mentions in the articles and extracted keywords.
Keyword extraction uses a local [KeyBERT](https://maartengr.github.io/KeyBERT/) model, making it fast and cost-effective to construct these graphs.
We'll then show how to build a chain that performs the steps of LazyGraphRAG -- retrieving articles, extracting claims from each community, ranking and selecting the top claims, and generating an answer based on those claims.

## Environment Setup

The following block will configure the environment from the Colab Secrets.
To run it, you should have the following Colab Secrets defined and accessible to this notebook:

- `OPENAI_API_KEY`: The OpenAI key.
- `ASTRA_DB_API_ENDPOINT`: The Astra DB API endpoint.
- `ASTRA_DB_APPLICATION_TOKEN`: The Astra DB Application token.
- `LANGCHAIN_API_KEY`: Optional. If defined, will enable LangSmith tracing.
- `ASTRA_DB_KEYSPACE`: Optional. If defined, will specify the Astra DB keyspace. If not defined, will use the default.

In [1]:
#@ Install modules.
%pip install \
    langchain-core \
    langchain-astradb \
    langchain-openai \
    langchain-graph-retriever \
    graph-rag-example-helpers

Note: you may need to restart the kernel to use updated packages.


The last package -- `graph-rag-example-helpers` -- includes helpers for setting up the environment and for loading the Wikipedia data in a way that can be resumed if it fails.

In [2]:
# Configure import paths.
import os
import sys

sys.path.append("../../")

# Initialize environment variables.
from graph_rag_example_helpers.env import Environment, initialize_environment

initialize_environment(Environment.ASTRAPY)

os.environ["LANGCHAIN_PROJECT"] = "lazy-graph-rag"

## Part 1: Loading Data

First, we'll demonstrate how to load Wikipedia data into an `AstraDBVectorStore`, using the mentioned articles and keywords as metadata fields.
In this section, we're not actually doing anything special for the graph -- we're just populating the metadata with fields that usefully describe our content.

## Create Documents from Wikipedia Articles
The first thing we need to do is create the `LangChain` `Document`s we'll import.

To do this, we write some code that converts each line of a JSON file downloaded from [2wikimultihop](https://github.com/Alab-NII/2wikimultihop?tab=readme-ov-file#new-update-april-7-2021) into a `Document`.
We populate the `id` and `metadata["mentions"]` from information in this file.

Then, we run those documents through the `KeybertKeywordExtractor` to populate `metadata["keywords"]` with the suggested keywords from each article.
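To make the input shape concrete, here is a minimal sketch that parses one made-up line containing only the fields described above (`id`, `sentences`, and the `ref_ids` of each mention). The values are invented for illustration, real lines in the dump may carry additional fields, and treating a missing `ref_ids` as an empty list is an assumption about the data.

```python
import json

# A made-up line in the shape the loader reads; the values are illustrative only.
line = json.dumps(
    {
        "id": "100",
        "sentences": ["Seattle is a city.", "It is in Washington."],
        "mentions": [
            {"ref_ids": ["200", "300"]},
            {"ref_ids": None},  # assumed: a mention with no resolved article
        ],
    }
).encode()

para = json.loads(line)

# Flatten the referenced article IDs, treating a missing list as empty.
mentioned_ids = [
    ref_id for mention in para["mentions"] for ref_id in mention["ref_ids"] or []
]
page_content = " ".join(para["sentences"])

print(mentioned_ids)
print(page_content)
```

The flattened IDs become `metadata["mentions"]`, and the joined sentences become the page content of the `Document`.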

In [3]:
import json
from collections.abc import Iterator

from langchain_core.documents import Document
from langchain_graph_retriever.document_transformers.keybert import KeybertKeywordExtractor

def parse_document(line: bytes) -> Document:
    """Reads one JSON line from the wikimultihop dump."""
    para = json.loads(line)

    id = para["id"]

    # Use structured information (mentioned Wikipedia IDs) as metadata.
    # Collect every referenced ID; a mention may have no `ref_ids`.
    mentioned_ids = [ref_id for m in para["mentions"] for ref_id in m["ref_ids"] or []]

    return Document(
        id=id,
        page_content=" ".join(para["sentences"]),
        metadata={
            "mentions": mentioned_ids,
        },
    )


KEYBERT_TRANSFORMER = KeybertKeywordExtractor()


# Prepare data in batches, using KeyBERT to extract keywords.
def prepare_batch(lines: Iterator[str]) -> Iterator[Document]:
    # Parse documents from the batch of lines.
    docs = [parse_document(line) for line in lines]

    docs = KEYBERT_TRANSFORMER.transform_documents(docs)

    return docs



## Create the AstraDBVectorStore
Next, we create the Vector Store we're going to load these documents into.
In our case, we use DataStax Astra DB with OpenAI embeddings.

In [6]:
# @ Create the AstraDBVectorStore

from langchain_astradb import AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings

COLLECTION = "lazy_graph_rag"
store = AstraDBVectorStore(
    embedding=OpenAIEmbeddings(),
    collection_name=COLLECTION,
)

## Loading Data into the Store
Next, we perform the actual loading.
This takes a while, so we use a helper utility to persist which batches have been written so we can resume if there are any failures.

On macOS, running `caffeinate -dis` in a shell prevents the machine from going to sleep, which seems to reduce errors.
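The idea behind resumable loading can be sketched in a few lines. This is not the implementation of `aload_2wikimultihop`; `load_resumably` and its JSON checkpoint file are hypothetical names used only to illustrate recording completed batches so that a rerun skips them.

```python
import json
import tempfile
from collections.abc import Callable, Iterable
from pathlib import Path


def load_resumably(
    batches: Iterable[list[str]],
    write_batch: Callable[[list[str]], None],
    checkpoint: Path,
) -> int:
    """Write each batch once, recording finished batch indices in `checkpoint`."""
    done = set(json.loads(checkpoint.read_text())) if checkpoint.exists() else set()
    written = 0
    for index, batch in enumerate(batches):
        if index in done:
            continue  # already persisted by an earlier run
        write_batch(batch)
        done.add(index)
        # Persist progress after every batch so a crash loses at most one batch.
        checkpoint.write_text(json.dumps(sorted(done)))
        written += 1
    return written


# Usage: the second run skips everything the first run completed.
stored: list[str] = []
checkpoint_path = Path(tempfile.mkdtemp()) / "completed_batches.json"
all_batches = [["a", "b"], ["c"], ["d", "e"]]

first_run = load_resumably(all_batches, stored.extend, checkpoint_path)
second_run = load_resumably(all_batches, stored.extend, checkpoint_path)
```

Checkpointing per batch trades a small amount of extra I/O for the ability to restart a multi-hour load without rewriting finished work.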

In [7]:
import os
import os.path

# Path to the file `para_with_hyperlink.zip`.
# See instructions here to download from
# [2wikimultihop](https://github.com/Alab-NII/2wikimultihop?tab=readme-ov-file#new-update-april-7-2021).
PARA_WITH_HYPERLINK_ZIP = os.path.join(
    os.getcwd(), "para_with_hyperlink.zip"
)

from graph_rag_example_helpers.datasets.wikimultihop import aload_2wikimultihop
await aload_2wikimultihop(
    PARA_WITH_HYPERLINK_ZIP,
    store,
    prepare_batch)

Resuming loading with 13 completed, 5977 remaining


  1%|          | 33/5977 [36:11<108:38:24, 65.80s/it]

At this point, we've created a `VectorStore` with the Wikipedia articles.
Each article is associated with metadata identifying other articles it mentions and keywords from the article.
This could be used for hybrid search -- performing a vector search for articles similar to a specific question *that also mention a specific term*.
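As a toy illustration of that kind of hybrid search, the sketch below filters an in-memory corpus on a metadata keyword and then ranks the survivors by cosine similarity. The two-dimensional vectors and the `hybrid_search` helper are made up for the example; a real vector store would apply the filter and the similarity search itself.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


# Toy corpus: hand-written 2-d "embeddings" plus keyword metadata.
docs = [
    {"id": "seattle", "vector": [0.9, 0.1], "keywords": ["city", "washington"]},
    {"id": "tacoma", "vector": [0.8, 0.2], "keywords": ["city", "washington"]},
    {"id": "rainier", "vector": [0.7, 0.3], "keywords": ["mountain", "washington"]},
    {"id": "baku", "vector": [0.1, 0.9], "keywords": ["city", "azerbaijan"]},
]


def hybrid_search(query_vector: list[float], keyword: str, k: int = 2) -> list[str]:
    """Keep documents whose metadata contains `keyword`, then rank by similarity."""
    candidates = [doc for doc in docs if keyword in doc["keywords"]]
    candidates.sort(key=lambda doc: cosine(query_vector, doc["vector"]), reverse=True)
    return [doc["id"] for doc in candidates[:k]]


results = hybrid_search([1.0, 0.0], keyword="city")
```

The metadata filter narrows the candidate set before similarity ranking, so a document that is close in embedding space but lacks the term (here `rainier`) is never returned.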

The `langchain-graph-retriever` library makes this even more useful by allowing traversal between articles that are explicitly mentioned or that share keywords.
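A rough, library-free sketch of such a traversal follows. It assumes an edge exists from an article to each article it mentions and between articles that share a keyword; the data, `neighbors`, and `traverse` are hypothetical and do not reflect how the library implements traversal.

```python
from collections import deque

# Toy articles: IDs, explicit mentions, and extracted keywords (all made up).
articles = {
    "seattle": {"mentions": ["washington"], "keywords": ["city", "puget sound"]},
    "washington": {"mentions": [], "keywords": ["state"]},
    "tacoma": {"mentions": [], "keywords": ["city", "puget sound"]},
    "baku": {"mentions": ["azerbaijan"], "keywords": ["caspian sea"]},
    "azerbaijan": {"mentions": [], "keywords": ["country"]},
}


def neighbors(doc_id: str) -> set[str]:
    """Articles this one mentions, plus articles sharing at least one keyword."""
    doc = articles[doc_id]
    linked = {m for m in doc["mentions"] if m in articles}
    shared = {
        other
        for other, meta in articles.items()
        if other != doc_id and set(meta["keywords"]) & set(doc["keywords"])
    }
    return linked | shared


def traverse(start: list[str], depth: int) -> list[str]:
    """Breadth-first traversal from the start articles, up to `depth` hops."""
    seen = set(start)
    queue = deque((doc_id, 0) for doc_id in start)
    order = []
    while queue:
        doc_id, hops = queue.popleft()
        order.append(doc_id)
        if hops == depth:
            continue  # do not expand beyond the requested depth
        for nxt in sorted(neighbors(doc_id) - seen):
            seen.add(nxt)
            queue.append((nxt, hops + 1))
    return order


reached = traverse(["seattle"], depth=1)
```

In practice the start articles would come from a vector search on the question, and the traversal then pulls in related articles that similarity alone would miss.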

In the next section, we'll go a step further and perform LazyGraphRAG to extract relevant claims from both the similar and related articles and use the most relevant claims to answer the question.

## Part 2: LazyGraphRAG via Hierarchical Summarization

As we've noted before, eagerly building a knowledge graph is prohibitively expensive.
Microsoft seems to agree, and recently introduced LazyGraphRAG, which enables GraphRAG to be performed late -- after a query is received.

We implement the LazyGraphRAG technique using the traversing retrievers as follows:

1. Retrieve a large set of nodes (documents) using a traversing retriever.
2. Identify communities in the retrieved sub-graph.
3. Extract claims from each community relevant to the query using an LLM.
4. Rank each of the claims based on the relevance to the question and select the top claims.
5. Generate an answer to the question based on the extracted claims.
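
Step 2 can be illustrated with the simplest possible notion of a community: a connected component of the undirected graph. This is a stand-in chosen for brevity; the community detection used later in this notebook may apply a different algorithm, and `connected_components` is a hypothetical helper.

```python
def connected_components(
    nodes: list[str], edges: list[tuple[str, str]]
) -> list[list[str]]:
    """Group the nodes of an undirected graph into connected components."""
    adjacency: dict[str, set[str]] = {node: set() for node in nodes}
    for a, b in edges:
        adjacency[a].add(b)
        adjacency[b].add(a)

    seen: set[str] = set()
    components = []
    for node in nodes:
        if node in seen:
            continue
        # Depth-first walk collecting everything reachable from `node`.
        stack, component = [node], []
        seen.add(node)
        while stack:
            current = stack.pop()
            component.append(current)
            for neighbor in sorted(adjacency[current] - seen):
                seen.add(neighbor)
                stack.append(neighbor)
        components.append(sorted(component))
    return components


# Made-up retrieved documents and the links between them.
retrieved = ["seattle", "tacoma", "washington", "baku", "azerbaijan"]
links = [("seattle", "tacoma"), ("seattle", "washington"), ("baku", "azerbaijan")]
communities = connected_components(retrieved, links)
```

Each community is then passed to the LLM as one group of related documents, so claims are extracted with the surrounding context rather than one document at a time.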

### LangChain for Extracting Claims

In [28]:
from collections.abc import Iterable
from operator import itemgetter
from typing import TypedDict

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, chain
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


class Claim(BaseModel):
    """Representation of an individual claim from a source document(s)."""

    claim: str = Field(description="The claim from the original document(s).")
    source_id: str = Field(description="Document ID containing the claim.")


class Claims(BaseModel):
    """Claims extracted from a set of source document(s)."""

    claims: list[Claim] = Field(description="The extracted claims.")


MODEL = ChatOpenAI(model="gpt-4o", temperature=0)
CLAIMS_MODEL = MODEL.with_structured_output(Claims)

CLAIMS_PROMPT = ChatPromptTemplate.from_template("""
Extract claims from the following related documents.

Only return claims appearing within the specified documents.
If no documents are provided, do not make up claims or documents.

Claims (and scores) should be relevant to the question.
Don't include claims from the documents if they are not directly or indirectly
relevant to the question.

If none of the documents make any claims relevant to the question, return an
empty list of claims.

If multiple documents make similar claims, include the original text of each as
separate claims. Score the most useful and authoritative claim higher than
similar, lower-quality claims.

Question: {question}

{formatted_documents}
""")

# TODO: Few-shot examples? Possibly with a selector?

def format_documents_with_ids(documents: Iterable[Document]) -> str:
    formatted_docs = "\n\n".join(
        f"Document ID: {doc.id}\nContent: {doc.page_content}" for doc in documents
    )
    return formatted_docs


CLAIM_CHAIN = (
    RunnableParallel(
        {
            "question": itemgetter("question"),
            "formatted_documents": itemgetter("documents")
            | RunnableLambda(format_documents_with_ids),
        }
    )
    | CLAIMS_PROMPT
    | CLAIMS_MODEL
)


class ClaimsChainInput(TypedDict):
    question: str
    communities: Iterable[Iterable[Document]]


@chain
async def claims_chain(input: ClaimsChainInput) -> Iterable[Claim]:
    question = input["question"]
    communities = input["communities"]

    # TODO: Use openai directly so this can use the batch API for performance/cost?
    community_claims = await CLAIM_CHAIN.abatch(
        [{"question": question, "documents": community} for community in communities]
    )
    return [claim for community in community_claims for claim in community.claims]

### LangChain for Ranking Claims

This is based on ideas from [RankRAG](https://arxiv.org/abs/2407.02485).
Specifically, the prompt is constructed so that the next token should be `True` if the content is relevant and `False` if not.
The probability of the token is used to determine the relevance -- `True` with a higher probability is more relevant than `True` with a lesser probability.
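A small worked example makes the scoring concrete. The log-probabilities below are made up, and `relevance` is a hypothetical helper; treating `1 - P(False)` as the relevance score assumes that nearly all of the remaining probability mass would have gone to `True`.

```python
import math


def relevance(token: str, logprob: float) -> float:
    """Map the first output token and its log-probability to a relevance score."""
    prob = math.exp(logprob)
    if token == "True":
        return prob
    if token == "False":
        return 1.0 - prob
    raise ValueError(f"Unexpected token: {token!r}")


# Illustrative (made-up) log-probabilities for three claims.
scored = {
    "confident True": relevance("True", math.log(0.9)),
    "hesitant True": relevance("True", math.log(0.6)),
    "confident False": relevance("False", math.log(0.95)),
}

# Most relevant first.
ranked = sorted(scored, key=scored.get, reverse=True)
```

A confident `False` lands near zero, a hesitant `True` in the middle, and a confident `True` near one, which gives a continuous ordering from a single generated token.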

In [29]:
import math

from langchain_core.runnables import chain

RANK_PROMPT = ChatPromptTemplate.from_template("""
Rank the relevance of the following claim to the question.
Output "True" if the claim is relevant and "False" if it is not.
Only output True or False.

Question: Where is Seattle?

Claim: Seattle is in Washington State.

Relevant: True

Question: Where is LA?

Claim: New York City is in New York State.

Relevant: False

Question: {question}

Claim: {claim}

Relevant:
""")


def compute_rank(msg):
    logprob = msg.response_metadata["logprobs"]["content"][0]
    prob = math.exp(logprob["logprob"])
    token = logprob["token"]
    if token == "True":
        return prob
    elif token == "False":
        return 1.0 - prob
    else:
        raise ValueError(f"Unexpected logprob: {logprob}")


RANK_CHAIN = RANK_PROMPT | MODEL.bind(logprobs=True) | RunnableLambda(compute_rank)


class RankChainInput(TypedDict):
    question: str
    claims: Iterable[Claim]


@chain
async def rank_chain(input: RankChainInput) -> Iterable[Claim]:
    # TODO: Use openai directly so this can use the batch API for performance/cost?
    claims = input["claims"]
    ranks = await RANK_CHAIN.abatch(
        [{"question": input["question"], "claim": claim} for claim in claims]
    )
    # Sort from most to least relevant so the top claims come first.
    rank_claims = sorted(
        zip(ranks, claims, strict=True),
        key=lambda rank_claim: rank_claim[0],
        reverse=True,
    )

    return [claim for _, claim in rank_claims]

## LazyGraphRAG in LangChain

In [30]:
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import chain
from langchain_graph_retriever.document_graphs import create_graph, group_by_community


@chain
async def lazy_graph_rag(
    question: str,
    *,
    retriever: BaseRetriever,
    edges: Iterable[str | tuple[str, str]] | None = None,
    model: BaseLanguageModel,
    max_tokens: int = 1000,
) -> str:
    """Retrieve claims relating to the question using LazyGraphRAG.

    Returns the top claims up to the given `max_tokens` as a markdown list.

    """

    if edges is None:
        try:
            edges = retriever.edges
        except AttributeError as exc:
            raise ValueError(
                "Must specify 'edges' or provide a retriever with 'edges' field defined"
            ) from exc

    # 1. Retrieve documents using the (traversing) retriever.
    documents = await retriever.ainvoke(question)

    # 2. Create a graph and extract communities.
    documents_by_id, doc_graph = create_graph(
        documents,
        edges=edges,
        directed=False,
    )
    communities = group_by_community(documents_by_id, doc_graph)

    # 3. Extract claims from the communities.
    claims = await claims_chain.ainvoke(
        {"question": question, "communities": communities}
    )

    # 4. Rank the claims and select claims up to the given token limit.
    result_claims = []
    tokens = 0

    for claim in await rank_chain.ainvoke({"question": question, "claims": claims}):
        claim_str = f"- {claim.claim} (Source: {claim.source_id})"

        tokens += model.get_num_tokens(claim_str)
        if tokens > max_tokens:
            break
        result_claims.append(claim_str)

    return "\n".join(result_claims)
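
The budget logic in step 4 can be tried in isolation. The sketch below mirrors the loop above but swaps the model's tokenizer for a whitespace split, which is only a rough stand-in; `select_within_budget` is a hypothetical helper, and the claims are illustrative.

```python
def select_within_budget(claims: list[str], max_tokens: int) -> list[str]:
    """Keep ranked claims, in order, until the next one would exceed the budget."""
    selected: list[str] = []
    used = 0
    for claim in claims:
        # Whitespace splitting is a crude stand-in for a real tokenizer.
        used += len(claim.split())
        if used > max_tokens:
            break
        selected.append(claim)
    return selected


# Claims are assumed to arrive most relevant first.
ranked_claims = [
    "- Azerbaijan is in the South Caucasus. (Source: 1)",
    "- Baku is the capital of Azerbaijan. (Source: 2)",
    "- Azerbaijan borders the Caspian Sea. (Source: 3)",
]
top = select_within_budget(ranked_claims, max_tokens=18)
```

Because the loop stops at the first claim that overflows the budget, the order of the ranked claims directly determines which ones reach the answer prompt.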

## Using LazyGraphRAG in LangChain

In [31]:
from langchain_community.retrievers.graph_traversal import (
    AstraTraversalAdapter,
    GraphTraversalRetriever,
)
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Link articles by explicit mentions and by shared keywords (the metadata loaded in Part 1).
EDGES = [("mentions", "id"), "keywords"]

RETRIEVER = GraphTraversalRetriever(
    store=AstraTraversalAdapter(store),
    edges=EDGES,
    start_k=100,
    depth=3,
)

# RETRIEVER = GraphMMRTraversalRetriever(
#     store = AstraMMRTraversalAdapter(store),
#     edges = EDGES,
#     k = 100,
#     depth = 5,
#     fetch_k = 100,
#     adjacent_k = 25,
#     lambda_mult = 0.8,
#     score_threshold = float("-inf"),
# )

ANSWER_PROMPT = PromptTemplate.from_template("""
Answer the question based on the supporting claims.

Only use information from the claims. Do not guess or make up any information.

Where possible, reference and quote the supporting claims.

Question: {question}

Claims:
{claims}
""")

LAZY_GRAPH_RAG_CHAIN = (
    {
        "question": RunnablePassthrough(),
        "claims": RunnablePassthrough()
        | lazy_graph_rag.bind(
            retriever=RETRIEVER,
            model=MODEL,
            max_tokens=1000,
        ),
    }
    | ANSWER_PROMPT
    | MODEL
)

In [None]:
await LAZY_GRAPH_RAG_CHAIN.ainvoke("Where is Azerbaijan?")

The LazyGraphRAG chain is great when a question needs to consider a large amount of relevant information in order to produce a thorough answer.

## Conclusion

This post introduced _traversing retrievers_ which allow any `VectorStore` to be traversed as a knowledge graph based on properties in the metadata.
This means you can focus on populating and using your `VectorStore` with useful metadata and add GraphRAG when you need it.
We also saw that these traversing retrievers mean that any `VectorStore` can be used with LazyGraphRAG, without needing to change the stored documents.

Knowledge Graphs and GraphRAG shouldn't be hard or scary.
Start simple and easily overlay edges when you need them.

These traversing retrievers and LazyGraphRAG summarization work well with agents.
You can create tools that use different retriever configurations, for instance, searching for articles "near" existing articles or distinguishing between questions that only need a few references and deeper questions which need to retrieve and summarize a larger amount of content.
We'll show how to combine these graph techniques with agents in future posts.
Until then, ...