# Build a RAG application with graph links between documents

This tutorial shows how to create links between documents in a `GraphVectorStore` and use those links to get more relevant responses when querying.

It contains two main sections:
* [Preparation: load the DataStax Astra Documentation into GraphVectorStore](#preparation-load-the-datastax-astra-documentation-into-graphvectorstore)
* [Create and execute the RAG Chains](#create-and-execute-the-rag-chains)

## Preliminaries

In [None]:
%pip install -q langchain-community langchain-openai cassio beautifulsoup4 lxml markdownify python-dotenv

## Preparation: load the DataStax Astra Documentation into GraphVectorStore

First, we'll crawl the DataStax documentation. At the time of writing, LangChain's `SitemapLoader` loads all of the pages into memory at once, which makes it impossible to index larger sites from small environments (such as Colab). So we scrape the sitemaps ourselves and iterate over the URLs, which lets us process documents in batches and flush each batch to Astra DB.

### Scrape the URLs from the Site Maps
We use Beautiful Soup to parse the XML of each sitemap and collect the page URLs it lists.
We also add a few extra URLs for external sites that are useful to include in the index.

In [None]:
import requests
from bs4 import BeautifulSoup

# Use sitemaps to crawl the content
SITEMAPS = [
    "https://docs.datastax.com/en/sitemap-astra-db-vector.xml",
    "https://docs.datastax.com/en/sitemap-cql.xml",
    "https://docs.datastax.com/en/sitemap-dev-app-drivers.xml",
    "https://docs.datastax.com/en/sitemap-glossary.xml",
    "https://docs.datastax.com/en/sitemap-astra-db-serverless.xml",
]

# Additional URLs to crawl for content.
EXTRA_URLS = ["https://github.com/jbellis/jvector"]


def load_sitemap_urls(sitemap_url):
    """Yield the page URLs (<loc> entries) listed in a sitemap."""
    r = requests.get(
        sitemap_url,
        headers={
            # Astra docs only return a sitemap with a user agent set.
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 "
            "Firefox/58.0",
        },
        timeout=30,
    )
    r.raise_for_status()

    # Parsing with features="xml" requires the lxml package.
    soup = BeautifulSoup(r.text, features="xml")
    for url in soup.find_all("url"):
        yield url.find("loc").text


# For maintenance, we could crawl only the pages that changed since a given
# time (using each entry's <lastmod>) instead of every URL.
URLS = [
    url for sitemap_url in SITEMAPS for url in load_sitemap_urls(sitemap_url)
] + EXTRA_URLS
len(URLS)
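The comment above mentions crawling only recently changed pages. A minimal sketch of that idea follows, assuming the sitemaps include the optional `<lastmod>` element from the sitemaps protocol (we have not checked that the DataStax sitemaps do). It uses the standard library's `xml.etree.ElementTree` instead of Beautiful Soup; `urls_modified_since` and `parse_lastmod` are hypothetical helpers, not part of LangChain.

```python
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from typing import Iterator

# Namespace used by standard sitemaps (sitemaps.org protocol).
SITEMAP_NS = {"sm": "http://www.sitemaps.org/schemas/sitemap/0.9"}


def parse_lastmod(value: str) -> datetime:
    """Parse a W3C datetime from <lastmod>, treating a missing timezone as UTC."""
    parsed = datetime.fromisoformat(value.strip().replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def urls_modified_since(sitemap_xml: str, since: datetime) -> Iterator[str]:
    """Yield <loc> URLs whose <lastmod> is newer than `since` (a tz-aware datetime).

    Entries without <lastmod> are yielded too, because we cannot tell
    whether they changed.
    """
    root = ET.fromstring(sitemap_xml)
    for url in root.findall("sm:url", SITEMAP_NS):
        loc = url.findtext("sm:loc", namespaces=SITEMAP_NS)
        lastmod = url.findtext("sm:lastmod", namespaces=SITEMAP_NS)
        if loc is None:
            continue
        if lastmod is None or parse_lastmod(lastmod) > since:
            yield loc.strip()
```

To use it, pass the text of a `requests.get(...)` response for a sitemap (sent with the same `User-Agent` header as above) together with the time of your previous crawl. Keeping entries that lack `<lastmod>` is a deliberate choice: re-indexing an unchanged page is cheaper than missing an update.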

### Load the content from each URL
Next, we define an async function that loads each page. For every page, it:

1. Parses the HTML with BeautifulSoup.
2. Locates the main content of the page using a CSS selector chosen based on the URL.
3. Uses an `HtmlLinkExtractor` to find the link (`<a href="...">`) tags in that content and collect their absolute URLs (for creating edges).
4. Converts the selected content to Markdown with `markdownify`.

Adding these links to the document metadata (with `add_links`) allows the graph store to create edges between the documents.
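To make the edge-creation step concrete, here is a simplified, self-contained model of it. It assumes, as we understand `HtmlLinkExtractor` to behave, that every page gets an incoming link tagged with its own URL plus an outgoing link for each `href`; an edge then runs from page A to page B when one of A's outgoing tags equals one of B's incoming tags of the same kind. `SimpleLink` and `build_edges` are hypothetical names, not LangChain's `Link` class or the store's actual traversal code.

```python
from dataclasses import dataclass
from typing import Dict, Set, Tuple


@dataclass(frozen=True)
class SimpleLink:
    """Hypothetical stand-in for a graph link: kind, direction, and tag."""

    kind: str  # e.g. "hyperlink"
    direction: str  # "in", "out", or "bidir"
    tag: str  # for hyperlinks, a URL


def build_edges(doc_links: Dict[str, Set[SimpleLink]]) -> Set[Tuple[str, str]]:
    """Return (source_id, target_id) pairs where an outgoing link on the source
    matches an incoming link (same kind and tag) on the target."""
    # Index documents by the (kind, tag) pairs they accept as incoming links.
    incoming: Dict[Tuple[str, str], Set[str]] = {}
    for doc_id, links in doc_links.items():
        for link in links:
            if link.direction in ("in", "bidir"):
                incoming.setdefault((link.kind, link.tag), set()).add(doc_id)

    # Connect each outgoing link to every document that accepts it.
    edges: Set[Tuple[str, str]] = set()
    for doc_id, links in doc_links.items():
        for link in links:
            if link.direction in ("out", "bidir"):
                for target in incoming.get((link.kind, link.tag), set()):
                    if target != doc_id:  # this sketch ignores self-links
                        edges.add((doc_id, target))
    return edges


# Example (illustrative URLs): a docs page that links to the JVector repo.
DOCS_PAGE = "https://docs.datastax.com/en/astra-db-serverless/index.html"
JVECTOR = "https://github.com/jbellis/jvector"
example_links = {
    DOCS_PAGE: {
        SimpleLink("hyperlink", "in", DOCS_PAGE),
        SimpleLink("hyperlink", "out", JVECTOR),
    },
    JVECTOR: {SimpleLink("hyperlink", "in", JVECTOR)},
}
print(build_edges(example_links))
```

In this example the only edge runs from the docs page to the JVector page, because the JVector page has no outgoing link back to the docs page.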

In [None]:
from typing import AsyncIterator, Iterable

from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.graph_vectorstores.extractors import HtmlInput, HtmlLinkExtractor
from langchain_core.documents import Document
from langchain_core.graph_vectorstores.links import add_links
from markdownify import MarkdownConverter

markdown_converter = MarkdownConverter(heading_style="ATX")
html_link_extractor = HtmlLinkExtractor()


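def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:
    """Narrow a parsed page down to its main content, based on the URL."""
    if url.startswith("https://docs.datastax.com/en/"):
        content = soup.select_one("article.doc")
    elif url.startswith("https://github.com"):
        content = soup.select_one("article.entry-content")
    else:
        content = None
    # Fall back to the whole page if no selector applies or nothing matched,
    # so later steps never receive None.
    return content if content is not None else soup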


async def load_pages(urls: Iterable[str]) -> AsyncIterator[Document]:
    loader = AsyncHtmlLoader(
        urls,
        requests_per_second=4,
        # Astra docs require a user agent
        header_template={
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:58.0) Gecko/20100101 "
            "Firefox/58.0"
        },
    )
    async for html in loader.alazy_load():
        url = html.metadata["source"]

        # Use the URL as the doc ID.
        html.id = url

        # Apply the selectors while loading. This reduces the size of
        # the document as early as possible for reduced memory usage.
        soup = BeautifulSoup(html.page_content, "html.parser")
        content = select_content(soup, url)

        # Extract HTML links from the content.
        add_links(html, html_link_extractor.extract_one(HtmlInput(content, url)))

        # Convert the content to markdown
        html.page_content = markdown_converter.convert_soup(content)

        yield html
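Step 3 in the list above relies on `HtmlLinkExtractor`. For intuition, the standard-library sketch below shows what collecting absolute link URLs involves: each `href` is resolved against the page URL with `urljoin`, and the `#fragment` is stripped so that links to different sections of one page point to the same document. This is an illustration under our own assumptions, not the extractor's code, and its filtering rules may differ; `HrefCollector` and `absolute_links` are hypothetical names.

```python
from html.parser import HTMLParser
from typing import List, Optional, Set, Tuple
from urllib.parse import urldefrag, urljoin


class HrefCollector(HTMLParser):
    """Collect the href attribute of every <a> tag."""

    def __init__(self) -> None:
        super().__init__()
        self.hrefs: List[str] = []

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        if tag == "a":
            for name, value in attrs:
                if name == "href" and value:
                    self.hrefs.append(value)


def absolute_links(html: str, base_url: str) -> Set[str]:
    """Resolve every <a href> against base_url and strip #fragments."""
    collector = HrefCollector()
    collector.feed(html)
    return {urldefrag(urljoin(base_url, href)).url for href in collector.hrefs}
```

For example, on a page at `https://docs.datastax.com/en/astra-db-vector/get-started.html`, a relative link `intro.html` resolves to `https://docs.datastax.com/en/astra-db-vector/intro.html`, and `https://github.com/jbellis/jvector#readme` becomes `https://github.com/jbellis/jvector`.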

### Initialize Environment
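Before we initialize the graph store and write the documents, we need to set some environment variables: an OpenAI API key and the Astra DB database ID, application token, and (optionally) keyspace. Values already present in the environment or in a local `.env` file are used; anything missing is prompted for.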

In [None]:
import getpass
import os

from dotenv import load_dotenv

load_dotenv()

if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API Key: ")

if "ASTRA_DB_DATABASE_ID" not in os.environ:
    os.environ["ASTRA_DB_DATABASE_ID"] = input("Enter Astra DB Database ID: ")

if "ASTRA_DB_APPLICATION_TOKEN" not in os.environ:
    os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass.getpass(
        "Enter Astra DB Application Token: "
    )

if "ASTRA_DB_KEYSPACE" not in os.environ:
    keyspace = input("Enter Astra DB Keyspace (Empty for default): ")
    if keyspace:
        os.environ["ASTRA_DB_KEYSPACE"] = keyspace

### Initialize Cassio and GraphVectorStore
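With the environment variables set, we initialize the Cassio library for talking to Cassandra / Astra DB; `cassio.init(auto=True)` picks up the Astra DB settings from those environment variables.
We then create a `CassandraGraphVectorStore` that embeds documents with `OpenAIEmbeddings` and stores them in the `astra_docs_nodes` table.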

In [None]:
import cassio
from langchain_openai import OpenAIEmbeddings
from langchain_community.graph_vectorstores import CassandraGraphVectorStore

cassio.init(auto=True)
embeddings = OpenAIEmbeddings()
graph_vectorstore = CassandraGraphVectorStore(
    embeddings,
    node_table="astra_docs_nodes",
)

### Load the Documents
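Finally, we fetch the pages and write them to the graph store in batches of 50, skipping pages that come back as "Page Not Found". The top-level `async for` works in Jupyter; in a plain Python script, move the loop into an `async def` function and run it with `asyncio.run`.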

In [None]:
not_found = 0
found = 0
BATCH_SIZE = 50

docs = []
async for doc in load_pages(URLS):
    if doc.page_content.startswith("\n# Page Not Found"):
        not_found += 1
        continue

    docs.append(doc)
    found += 1

    if len(docs) >= BATCH_SIZE:
        graph_vectorstore.add_documents(docs)
        docs.clear()

if docs:
    graph_vectorstore.add_documents(docs)
print(f"{not_found} (of {not_found + found}) URLs were not found")

## Create and execute the RAG Chains

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")

template = """You are a helpful technical support bot. You should provide complete answers explaining the options the user has available to address their problem. Answer the question based only on the following context:
{context}

Question: {question}
"""  # noqa: E501
prompt = ChatPromptTemplate.from_template(template)


def format_docs(docs):
    # Prefix each document with its ID (the page URL) so the LLM can tell sources apart.
    return "\n\n".join(f"From {doc.id}: {doc.page_content}" for doc in docs)

We'll use the following question. It is a good test because a complete answer depends on how the vector indexing is actually implemented, which is described in the JVector repository rather than in the Astra documentation itself.

In [None]:
QUESTION = "What vector indexing algorithms does Astra use?"

In [None]:
from IPython.display import Markdown, display


# Helper method to render markdown in responses to a chain.
def run_and_render(chain, question):
    result = chain.invoke(question)
    display(Markdown(result))

### Vector-Only Retrieval

In [None]:
# Depth 0 doesn't traverse any edges, so it is equivalent to plain vector similarity search.
vector_retriever = graph_vectorstore.as_retriever(search_kwargs={"depth": 0})

vector_rag_chain = (
    {"context": vector_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
run_and_render(vector_rag_chain, QUESTION)

### Graph Traversal Retrieval

In [None]:
# Depth 1 does vector similarity and then traverses 1 level of edges.
graph_retriever = graph_vectorstore.as_retriever(search_kwargs={"depth": 1})

graph_rag_chain = (
    {"context": graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
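As a mental model of the `depth` setting, the sketch below expands a set of seed documents (standing in for the vector-search hits) breadth-first over an adjacency map for up to `depth` hops. It is a simplification under our own assumptions: the real traversal runs inside the store and may also cap or re-rank adjacent documents. `expand_by_depth` is a hypothetical helper.

```python
from collections import deque
from typing import Deque, Dict, Iterable, List, Set, Tuple


def expand_by_depth(
    seeds: Iterable[str], edges: Dict[str, Set[str]], depth: int
) -> List[str]:
    """Return the seed IDs plus every node reachable within `depth` hops.

    Depth 0 returns only the seeds (pure vector search); depth 1 adds their
    direct neighbours, and so on. Results are in breadth-first order.
    """
    order: List[str] = []
    seen: Set[str] = set()
    queue: Deque[Tuple[str, int]] = deque()
    for seed in seeds:
        if seed not in seen:
            seen.add(seed)
            order.append(seed)
            queue.append((seed, 0))
    while queue:
        node, hops = queue.popleft()
        if hops >= depth:
            continue  # do not expand beyond the requested depth
        for neighbour in sorted(edges.get(node, ())):
            if neighbour not in seen:
                seen.add(neighbour)
                order.append(neighbour)
                queue.append((neighbour, hops + 1))
    return order
```

With edges `{"astra/vector": {"jvector"}}`, seeds `["astra/vector"]` and `depth=1` yield `["astra/vector", "jvector"]`, which mirrors how the graph retriever can reach the GitHub page that plain vector search misses.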

In [None]:
run_and_render(graph_rag_chain, QUESTION)

### MMR Graph Traversal

In [None]:
mmr_graph_retriever = graph_vectorstore.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={
        "k": 4,
        "fetch_k": 10,
        "depth": 2,
        # "score_threshold": 0.2,
    },
)

mmr_graph_rag_chain = (
    {"context": mmr_graph_retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
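`mmr_traversal` combines graph traversal up to `depth` hops with maximal marginal relevance (MMR) selection, which trades relevance to the query against redundancy with documents already chosen. As we understand the LangChain `GraphVectorStore` API, `fetch_k` is the number of initial candidates fetched by vector similarity, `k` is the number of documents returned, and a `lambda_mult` parameter (not set here) controls the relevance/diversity trade-off. The plain-Python sketch below shows only the greedy MMR scoring step under those assumptions; it is not the library's implementation, and `cosine` and `mmr_select` are hypothetical helpers.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, returning 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr_select(
    query: Sequence[float],
    candidates: List[Sequence[float]],
    k: int,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Greedy MMR: return the indices of up to k candidates, in pick order.

    Each step picks the remaining candidate with the highest score
        lambda_mult * sim(query, c) - (1 - lambda_mult) * max sim(c, s) over selected s,
    so lambda_mult=1 ranks purely by relevance and smaller values favour diversity.
    """
    selected: List[int] = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        best, best_score = remaining[0], -math.inf
        for i in remaining:
            relevance = cosine(query, candidates[i])
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected), default=0.0
            )
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best, best_score = i, score
        selected.append(best)
        remaining.remove(best)
    return selected
```

With `lambda_mult` at 1 this reduces to plain similarity ranking; lowering it lets a less similar but non-redundant document displace a near-duplicate of one already selected.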

In [None]:
run_and_render(mmr_graph_rag_chain, QUESTION)

### Check Retrieval Results

In [None]:
# Print the IDs (page URLs) of the documents each technique retrieves for QUESTION.
for i, doc in enumerate(vector_retriever.invoke(QUESTION)):
    print(f"Vector [{i}]:    {doc.id}")

for i, doc in enumerate(graph_retriever.invoke(QUESTION)):
    print(f"Graph [{i}]:     {doc.id}")

for i, doc in enumerate(mmr_graph_retriever.invoke(QUESTION)):
    print(f"MMR Graph [{i}]: {doc.id}")
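Comparing three printed lists by eye is error-prone. A small helper can report which documents each technique retrieved beyond a baseline. `new_vs_baseline` is a hypothetical helper that works on plain lists of IDs, so it can be fed from the retrievers above, for example with `[doc.id for doc in graph_retriever.invoke(QUESTION)]`.

```python
from typing import Dict, List


def new_vs_baseline(results: Dict[str, List[str]], baseline: str) -> Dict[str, List[str]]:
    """For each technique other than `baseline`, list the document IDs it
    retrieved that the baseline did not, preserving retrieval order."""
    seen = set(results[baseline])
    return {
        name: [doc_id for doc_id in ids if doc_id not in seen]
        for name, ids in results.items()
        if name != baseline
    }
```

Using vector-only retrieval as the baseline makes it easy to see whether the graph and MMR retrievers pulled in linked pages, such as the JVector repository, that similarity search alone did not return.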

## Conclusion
With vector-only retrieval, we retrieved pages from the Astra documentation explaining that Astra uses JVector.
Since vector search alone didn't follow the link to [JVector on GitHub](https://github.com/jbellis/jvector), the answer didn't actually address the question.

The graph retrieval started with the same set of pages, but it also followed the edge to the document we loaded from GitHub.
This gave the LLM more detail about how JVector is implemented, so it could answer the question more clearly and completely.