# Spanner Graph for GraphRAG

> [Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines unlimited scalability with relational semantics, such as secondary indexes, strong consistency, schemas, and SQL, providing 99.999% availability in one easy solution.

This notebook goes over how to use `Spanner Graph` for GraphRAG with the custom retriever `SpannerGraphVectorContextRetriever` and compares the response of GraphRAG with conventional RAG.

❗This notebook is adapted from [Spanner Graph Retrievers Usage](https://github.com/googleapis/langchain-google-spanner-python/blob/main/docs/graph_rag.ipynb) with slight modifications.

## ü¶úüîó Library Installation
The integration lives in its own `langchain-google-spanner` package, so we need to install it.

In [None]:
%%sh

pip freeze | grep -E "langchain|spanner-graph-notebook|pydata-google-auth|pydantic"

In [None]:
%%sh
pip install -r requirements.txt

In [None]:
## Automatically restart kernel after installs so that your environment can access the new packages
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

## 🔐 Authentication
Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.

If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env).

## ☁ Set Your Google Cloud Project

In [None]:
PROJECT_ID = !gcloud config get project
PROJECT_ID = PROJECT_ID[0]
LOCATION = "us-central1"

PROJECT_ID, LOCATION

## 💡 API Enablement
The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project.

In [None]:
# enable Spanner API
!gcloud services enable spanner.googleapis.com

## Set Spanner database values
Find your database values in the [Spanner Instances page](https://console.cloud.google.com/spanner?).

NOTE:
 - The database identified by `INSTANCE` and `DATABASE` must be created beforehand.
 - The graph does NOT need to be created beforehand.

   If a graph named `GRAPH_NAME` already exists, the new nodes and edges are added to it; otherwise a new graph is created automatically.

In [None]:
# @title Set Your Values Here { display-mode: "form" }
INSTANCE = "graph-store-01"  # @param {type: "string"}
DATABASE = "test_graph_db"  # @param {type: "string"}
GRAPH_NAME = "retail_graph"  # @param {type: "string"}
USE_FLEXIBLE_SCHEMA = False  # @param {type: "boolean"}
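The NOTE above says the database must already exist. If it does not, a minimal sketch like the following could create it with the `google-cloud-spanner` client before `SpannerGraphStore` is initialized. It assumes the instance already exists and that `client` behaves like `google.cloud.spanner.Client`; `ensure_database` is a hypothetical helper, not part of `langchain-google-spanner`.

```python
def ensure_database(client, instance_id: str, database_id: str, timeout: int = 300) -> bool:
    """Create a Spanner database if it does not exist yet (hypothetical helper).

    Assumes `client` behaves like google.cloud.spanner.Client and that the
    instance already exists. Returns True if a database was created.
    """
    database = client.instance(instance_id).database(database_id)
    if database.exists():
        return False
    # create() starts a long-running operation; wait for it to finish.
    database.create().result(timeout=timeout)
    return True
```

For example, after `from google.cloud import spanner`, call `ensure_database(spanner.Client(project=PROJECT_ID), INSTANCE, DATABASE)`. The client is passed in rather than created inside the helper so that credentials and project selection stay in one place.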

## SpannerGraphStore

To initialize the `SpannerGraphStore` class, you need to provide three required arguments; the other arguments are optional and only need to be passed if they differ from the defaults:

1.   a Spanner instance ID
2.   a Spanner database ID that belongs to the instance above
3.   a Spanner graph name, used to create a graph in the database above

In [None]:
from langchain_google_spanner import SpannerGraphStore

graph_store = SpannerGraphStore(
    instance_id=INSTANCE,
    database_id=DATABASE,
    graph_name=GRAPH_NAME
)

## Add Graph Documents
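This section loads the source documents, uses an LLM to extract nodes and edges from them, cleans up the result, and adds the resulting graph documents to the graph store.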

### Load Documents

In [None]:
!wget https://raw.githubusercontent.com/googleapis/langchain-google-spanner-python/main/samples/retaildata.zip -P content
!unzip "content/retaildata.zip" -d content

In [None]:
import os
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document

path = "content/retaildata/"
directories = [
    item for item in os.listdir(path) if os.path.isdir(os.path.join(path, item))
]

document_lists = []
for directory in directories:
    loader = DirectoryLoader(
        os.path.join(path, directory), glob="**/*.txt", loader_cls=TextLoader
    )
    document_lists.append(loader.load())

In [None]:
len(document_lists)
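`len(document_lists)` counts the subdirectories of `content/retaildata/`, not files, because the loop above builds one list per subdirectory. As a rough cross-check, this stdlib sketch counts the `.txt` files each `DirectoryLoader` should pick up. Skipping hidden files is an assumption about the loader's defaults, and `count_txt_files` is a hypothetical helper.

```python
from pathlib import Path


def count_txt_files(root: str) -> dict[str, int]:
    """Count *.txt files under each immediate subdirectory of `root`.

    Mirrors the loading loop: one entry per subdirectory, searched
    recursively like glob="**/*.txt". Hidden files (names starting with ".")
    are skipped, which is an assumption about DirectoryLoader's defaults.
    """
    counts = {}
    for sub in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        files = [
            f
            for f in sub.rglob("*.txt")
            if not any(part.startswith(".") for part in f.relative_to(sub).parts)
        ]
        counts[sub.name] = len(files)
    return counts
```

`sum(count_txt_files(path).values())` should then match `sum(len(docs) for docs in document_lists)`, assuming `TextLoader` yields one document per file.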

### Extract Nodes and Edges

In [None]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

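# ChatVertexAI is used instead of VertexAI(model_name="gemini-2.5-flash"): with a model that
# lacks native function calling, passing node_properties to LLMGraphTransformer raises
# ValueError: The 'node_properties' and 'relationship_properties' parameters cannot be used in combination with a LLM that doesn't support native function calling.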
llm = ChatVertexAI(model="gemini-2.5-flash", temperature=0)
llm_transformer = LLMGraphTransformer(
    llm=llm,
    allowed_nodes=["Category", "Segment", "Tag", "Product", "Bundle", "Deal"],
    allowed_relationships=[
        "In_Category",
        "Tagged_With",
        "In_Segment",
        "In_Bundle",
        "Is_Accessory_Of",
        "Is_Upgrade_Of",
        "Has_Deal",
    ],
    node_properties=[
        "name",
        "price",
        "weight",
        "deal_end_date",
        "features",
    ],
)

In [None]:
graph_documents = []
for document_list in document_lists:
    graph_documents.extend(llm_transformer.convert_to_graph_documents(document_list))

In [None]:
len(graph_documents)
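Before post-processing, it can help to see which node types and relationship triples the LLM actually produced. This sketch (`summarize_graph` is a hypothetical helper) only reads the `nodes`, `relationships`, `type`, `source`, and `target` attributes visible in the printed graph documents. The printed output also shows relationship types in upper case (for example `HAS_DEAL` rather than the `Has_Deal` passed to `allowed_relationships`), which matches the upper-case names the clean-up code uses, such as `IN_CATEGORY`.

```python
from collections import Counter


def summarize_graph(graph_documents):
    """Count extracted nodes by type and relationships by (source, type, target)."""
    node_types = Counter()
    rel_types = Counter()
    for doc in graph_documents:
        node_types.update(node.type for node in doc.nodes)
        rel_types.update(
            (rel.source.type, rel.type, rel.target.type) for rel in doc.relationships
        )
    return node_types, rel_types
```

`node_counts, rel_counts = summarize_graph(graph_documents)` followed by `rel_counts.most_common()` lists the extracted triples, most frequent first.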

In [None]:
# Add embeddings to the graph documents for Product nodes
embedding_service = VertexAIEmbeddings(model_name="text-embedding-005")
for graph_document in graph_documents:
    for node in graph_document.nodes:
        if "features" in node.properties:
            node.properties["embedding"] = embedding_service.embed_query(
                node.properties["features"]
            )
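The loop above calls `embed_query` once per node that has `features`, so it makes one embedding call per product. A sketch of a batched alternative collects every `features` string and embeds them in a single `embed_documents` call (a standard method of LangChain `Embeddings` classes). `add_feature_embeddings` is a hypothetical helper. One caveat: the Vertex AI integration may apply a different task type to `embed_documents` than to `embed_query`, so check that the stored vectors stay compatible with how the retriever embeds questions before switching.

```python
def add_feature_embeddings(graph_documents, embedding_service):
    """Embed every node's `features` text with one embed_documents call.

    Nodes without a `features` property are left untouched.
    Returns the number of nodes that received an embedding.
    """
    nodes = [
        node
        for doc in graph_documents
        for node in doc.nodes
        if "features" in node.properties
    ]
    vectors = embedding_service.embed_documents(
        [node.properties["features"] for node in nodes]
    )
    # embed_documents returns vectors in the same order as its input texts.
    for node, vector in zip(nodes, vectors):
        node.properties["embedding"] = vector
    return len(nodes)
```

`add_feature_embeddings(graph_documents, embedding_service)` would replace the per-node loop.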

In [None]:
import copy

def print_graph(graph_documents):
    for doc in graph_documents:
        print(f"{doc.source.page_content[:100]} === truncated ===")
        nodes = copy.deepcopy(doc.nodes)
        for node in nodes:
            if "embedding" in node.properties:
                node.properties["embedding"] = "..."
        print(nodes)
        print(doc.relationships)
        print()

In [None]:
print_graph(graph_documents)


 **QuantumLeap Flash Friday Deal - Unleash Blazing Fast Storage!**

**Price:** $39.99

**Manufactur === truncated ===
[Node(id='Quantumleap Flash Friday Deal', type='Deal', properties={'deal_end_date': '2024-03-01'}), Node(id='Quantumleap Flash Drive', type='Product', properties={'price': '39.99'}), Node(id='Quantumstorage Technologies', type='Product', properties={'name': 'QuantumStorage Technologies'}), Node(id='Storage', type='Tag', properties={}), Node(id='Data', type='Tag', properties={})]
[Relationship(source=Node(id='Quantumleap Flash Drive', type='Product', properties={}), target=Node(id='Quantumleap Flash Friday Deal', type='Deal', properties={}), type='HAS_DEAL', properties={}), Relationship(source=Node(id='Quantumleap Flash Drive', type='Product', properties={}), target=Node(id='Storage', type='Tag', properties={}), type='TAGGED_WITH', properties={}), Relationship(source=Node(id='Quantumleap Flash Drive', type='Product', properties={}), target=Node(id='Data', type='Tag', pr

### Post process extracted nodes and edges

Apply your domain knowledge to clean up the graph generated in the previous step and make any desired fixes.

In [None]:
# set of all valid products
products = set()

def prune_invalid_products():
    # Keep only Product nodes that carry a "features" property and record their ids.
    for graph_document in graph_documents:
        nodes_to_remove = []
        for node in graph_document.nodes:
            if node.type != "Product":
                continue
            if "features" not in node.properties:
                nodes_to_remove.append(node)
            else:
                products.add(node.id)
        for node in nodes_to_remove:
            graph_document.nodes.remove(node)


def prune_invalid_segments(valid_segments):
    for graph_document in graph_documents:
        nodes_to_remove = []
        for node in graph_document.nodes:
            if node.type == "Segment" and node.id not in valid_segments:
                nodes_to_remove.append(node)
        for node in nodes_to_remove:
            graph_document.nodes.remove(node)


def fix_directions(relation_name, wrong_source_type):
    for graph_document in graph_documents:
        for relationship in graph_document.relationships:
            if relationship.type == relation_name:
                if relationship.source.type == wrong_source_type:
                    source = relationship.source
                    target = relationship.target
                    relationship.source = target
                    relationship.target = source


def is_not_a_listed_product(node):
    if node.type == "Product" and node.id not in products:
        return True
    return False


def prune_dangling_relationships():
    # now remove all dangling relationships
    for graph_document in graph_documents:
        relationships_to_remove = []
        for relationship in graph_document.relationships:
            if is_not_a_listed_product(relationship.source) or is_not_a_listed_product(
                relationship.target
            ):
                relationships_to_remove.append(relationship)
        for relationship in relationships_to_remove:
            graph_document.relationships.remove(relationship)


def prune_unwanted_relationships(relation_name, source, target):
    node_types = set([source, target])
    for graph_document in graph_documents:
        relationships_to_remove = []
        for relationship in graph_document.relationships:
            if (
                relationship.type == relation_name
                and set([relationship.source.type, relationship.target.type])
                == node_types
            ):
                relationships_to_remove.append(relationship)
        for relationship in relationships_to_remove:
            graph_document.relationships.remove(relationship)


prune_invalid_products()
prune_invalid_segments(set(["Home", "Office", "Fitness"]))
prune_unwanted_relationships("IN_CATEGORY", "Bundle", "Category")
prune_unwanted_relationships("IN_CATEGORY", "Deal", "Category")
prune_unwanted_relationships("IN_SEGMENT", "Bundle", "Segment")
prune_unwanted_relationships("IN_SEGMENT", "Deal", "Segment")
prune_dangling_relationships()
fix_directions("HAS_DEAL", "Deal")
fix_directions("IN_BUNDLE", "Bundle")

print_graph(graph_documents)


 **QuantumLeap Flash Friday Deal - Unleash Blazing Fast Storage!**

**Price:** $39.99

**Manufactur === truncated ===
[Node(id='Quantumleap Flash Friday Deal', type='Deal', properties={'deal_end_date': '2024-03-01'}), Node(id='Storage', type='Tag', properties={}), Node(id='Data', type='Tag', properties={})]
[]


 DataSafe Vault

Price: $39.99

Manufacturer: SecureTech Solutions

Description: Keep your sensitiv === truncated ===
[Node(id='Storage', type='Tag', properties={}), Node(id='Data', type='Tag', properties={})]
[]


 QuantumLeap Ultimate Kit

Price: $64.99

Manufacturer: QuantumLeap Technologies

Description: Unle === truncated ===
[Node(id='Quantumleap Ultimate Kit', type='Bundle', properties={'price': '64.99'}), Node(id='Quantumleap Technologies', type='Category', properties={}), Node(id='Storage', type='Tag', properties={}), Node(id='Data', type='Tag', properties={})]
[Relationship(source=Node(id='Universal Adapter Kit', type='Product', properties={}), target=Node(id='Quantu
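After the clean-up, a quick check can confirm that the edges `fix_directions` targets now point the intended way. The expected (source, target) pairs below are an assumption taken from the `fix_directions` calls and from the edge types later written to Spanner (`Product_HAS_DEAL_Deal`, `Product_IN_BUNDLE_Bundle`); `find_direction_violations` is a hypothetical helper.

```python
def find_direction_violations(graph_documents, expected):
    """Return relationships whose (source type, target type) differ from `expected`.

    `expected` maps a relationship type to its intended (source, target)
    node types. Relationship types missing from `expected` are not checked.
    """
    violations = []
    for doc in graph_documents:
        for rel in doc.relationships:
            want = expected.get(rel.type)
            if want is not None and (rel.source.type, rel.target.type) != want:
                violations.append(rel)
    return violations


# Assumed directions, based on the fix_directions calls in the clean-up cell.
EXPECTED_DIRECTIONS = {
    "HAS_DEAL": ("Product", "Deal"),
    "IN_BUNDLE": ("Product", "Bundle"),
}
```

`find_direction_violations(graph_documents, EXPECTED_DIRECTIONS)` should return an empty list; anything it returns is a relationship the clean-up did not fix.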

### Load data to Spanner Graph

In [None]:
# Uncomment the line below if you want to clean up data from previous iterations.
# BEWARE - THIS COULD REMOVE DATA FROM YOUR DATABASE !!!
# graph_store.cleanup()
graph_store.add_graph_documents(graph_documents)

Waiting for DDL operations to complete...
Waiting for DDL operations to complete...
Insert nodes of type `Deal`...
Insert nodes of type `Deal`...
Insert nodes of type `Tag`...
Insert nodes of type `Tag`...
Insert nodes of type `Bundle`...
Insert nodes of type `Category`...
Insert nodes of type `Category`...
Insert nodes of type `Product`...
Insert nodes of type `Product`...
Insert nodes of type `Product`...
Insert nodes of type `Product`...
Insert nodes of type `Segment`...
Insert edges of type `Product_IN_BUNDLE_Bundle`...
Insert edges of type `Bundle_TAGGED_WITH_Tag`...
Insert edges of type `Product_IN_CATEGORY_Category`...
Insert edges of type `Product_IN_SEGMENT_Segment`...
Insert edges of type `Product_TAGGED_WITH_Tag`...
Insert edges of type `Product_IS_ACCESSORY_OF_Product`...
Insert edges of type `Deal_TAGGED_WITH_Tag`...
Insert edges of type `Product_HAS_DEAL_Deal`...
Insert edges of type `Product_IS_UPGRADE_OF_Product`...
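The log suggests that each edge type is named after its source node type, relationship type, and target node type (for example `Product_IS_ACCESSORY_OF_Product`). Assuming that pattern holds (it is inferred from this log, not from the library documentation), the hypothetical helper below previews the edge types a batch of graph documents would create, which can be handy before loading new data into an existing graph.

```python
def expected_edge_types(graph_documents):
    """Predict edge type names as <source type>_<relationship type>_<target type>.

    The naming pattern is inferred from the load log, not a documented
    contract of SpannerGraphStore.
    """
    names = set()
    for doc in graph_documents:
        for rel in doc.relationships:
            names.add(f"{rel.source.type}_{rel.type}_{rel.target.type}")
    return sorted(names)
```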


## Visualization

In [None]:
%load_ext spanner_graphs

In [None]:
%%spanner_graph --project {PROJECT_ID} --instance {INSTANCE} --database {DATABASE}

GRAPH retail_graph
MATCH p = ()->()
RETURN TO_JSON(p) AS path_json

## GraphRAG flow using Spanner Graph

### Question Answering Prompt

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

from IPython.display import Markdown
import textwrap


# Print the retrieved context and join it into a single string for the prompt.
def format_docs(docs):
    print("Context Retrieved: \n")
    for doc in docs:
        print(textwrap.fill(doc.page_content, width=80))
        print("\n")

    context = "\n\n".join(doc.page_content for doc in docs)
    return context


SPANNERGRAPH_QA_TEMPLATE = """
You are a helpful and friendly AI assistant for question answering tasks for an electronics
retail online store.
Create a human-readable answer for the question.
You should only use the information provided in the context and not use your internal knowledge.
Don't add any information.
Here is an example:

Question: Which funds own assets over 10M?
Context: [name:ABC Fund, name:Star fund]
Helpful Answer: ABC Fund and Star fund have assets over 10M.

Follow this example when generating answers.
You are given the following information:
- `Question`: the natural language question from the user
- `Graph Schema`: contains the schema of the graph database
- `Context`: The response from the graph database as context. The context has nodes and edges. Use the relationships.
Information:
Question: {question}
Graph Schema: {graph_schema}
Context: {context}

Format your answer to be human readable.
Use the relationships in the context to answer the question.
Only include information that is relevant to a customer.
Helpful Answer:"""

prompt = PromptTemplate(
    template=SPANNERGRAPH_QA_TEMPLATE,
    input_variables=["question", "graph_schema", "context"],
)

llm = ChatVertexAI(model="gemini-2.5-flash", temperature=0)

chain = prompt | llm | StrOutputParser()
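`PromptTemplate` treats `{name}` as an f-string placeholder by default, so a literal brace in the instructions or examples would be read as an extra input variable. This stdlib sketch (`template_fields` is a hypothetical helper) lists the placeholders in a template so they can be compared with `input_variables`.

```python
import string


def template_fields(template: str) -> set[str]:
    """Return the f-string style placeholder names used in a prompt template."""
    return {
        field_name
        for _, field_name, _, _ in string.Formatter().parse(template)
        if field_name  # None for literal text, including escaped {{ }} braces
    }
```

For the template above, `template_fields(SPANNERGRAPH_QA_TEMPLATE)` should equal `{"question", "graph_schema", "context"}`.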

### GraphRAG using Vector Search and Graph Expansion

In [None]:
from langchain_google_spanner import SpannerGraphVectorContextRetriever
from langchain_google_vertexai import VertexAIEmbeddings


def use_node_vector_retriever(
    question, graph_store, embedding_service, label_expr, expand_by_hops
):
    retriever = SpannerGraphVectorContextRetriever.from_params(
        graph_store=graph_store,
        embedding_service=embedding_service,
        label_expr=label_expr,
        expand_by_hops=expand_by_hops,
        top_k=1,
        k=10,
    )
    context = format_docs(retriever.invoke(question))
    return context


embedding_service = VertexAIEmbeddings(model_name="text-embedding-005")
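To make `top_k` and `expand_by_hops` concrete, here is a conceptual, in-memory sketch of the two stages this retriever is configured with: vector search to pick seed nodes, then graph expansion around them. It is not how `SpannerGraphVectorContextRetriever` is implemented (that work happens inside Spanner), it does not model the `k` parameter, and following edges in both directions is an assumption; `vector_then_expand` is a hypothetical helper over plain dictionaries.

```python
import math
from collections import deque


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def vector_then_expand(query_vec, embeddings, edges, top_k=1, hops=1):
    """Conceptual GraphRAG retrieval: vector search for seeds, then BFS expansion.

    embeddings: {node_id: vector} for nodes that have an embedding.
    edges: iterable of (source_id, target_id); followed in both directions here.
    Returns (seed ids ranked by similarity, set of all node ids within `hops`).
    """
    # Stage 1: rank embedded nodes by similarity and keep the top_k as seeds.
    ranked = sorted(embeddings, key=lambda n: cosine(query_vec, embeddings[n]), reverse=True)
    seeds = ranked[:top_k]

    # Build an undirected adjacency list from the edge list.
    neighbors = {}
    for src, dst in edges:
        neighbors.setdefault(src, set()).add(dst)
        neighbors.setdefault(dst, set()).add(src)

    # Stage 2: breadth-first expansion, recording each node's hop distance.
    distance = {node: 0 for node in seeds}
    queue = deque(seeds)
    while queue:
        node = queue.popleft()
        if distance[node] == hops:
            continue
        for nxt in neighbors.get(node, ()):
            if nxt not in distance:
                distance[nxt] = distance[node] + 1
                queue.append(nxt)
    return seeds, set(distance)
```

With `top_k=1` and `hops=1`, only the single closest node and its direct neighbors are returned, which mirrors the retrieved context below, where the matched product arrives together with its `Deal`.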

In [None]:
# Enable Debugging on LangChain
from langchain.globals import set_debug
# set_debug(True)

In [None]:
import textwrap

question = "Give me recommendations for a beginner drone"
context = use_node_vector_retriever(
    question, graph_store, embedding_service, label_expr="Product", expand_by_hops=1
)

answer = chain.invoke(
    {"question": question, "graph_schema": graph_store.get_schema, "context": context}
)

print("\n\nAnswer:\n")
print(textwrap.fill(answer, width=80))

Context Retrieved: 

[{"element_definition_name": "Product", "kind": "node", "labels": ["Product"],
"properties": {"deal_end_date": null, "features": "Increased Durability:
Provides an extra layer of protection for your drone.", "id": "Skyhawk Zephyr
Propeller Guards", "name": null, "price": "14.99", "weight": "20g (0.7 oz)"}}]


[{"element_definition_name": "Product", "kind": "node", "labels": ["Product"],
"properties": {"deal_end_date": null, "features": "Increased Durability:
Provides an extra layer of protection for your drone.", "id": "Skyhawk Zephyr
Propeller Guards", "name": null, "price": "14.99", "weight": "20g (0.7 oz)"}},
{"element_definition_name": "Product_HAS_DEAL_Deal", "kind": "edge", "labels":
["Product_HAS_DEAL_Deal"], "properties": {"id": "Skyhawk Zephyr Propeller
Guards", "target_id": "Limited-Time Offer: Skyhawk Zephyr Propeller Guards"}},
{"element_definition_name": "Deal", "kind": "node", "labels": ["Deal"],
"properties": {"deal_end_date": "2025-05-31", "id": "Li

## Compare with Conventional RAG

In [None]:
# Define table name
TABLE_NAME = "retail_table"  # @param {"type":"string"}

### Setup and load data for vector search

In [None]:
from langchain_google_spanner import SpannerVectorStore
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

import uuid


def load_data_for_vector_search(splits, embeddings):
    SpannerVectorStore.init_vector_store_table(
        instance_id=INSTANCE,
        database_id=DATABASE,
        table_name=TABLE_NAME,
    )
    db = SpannerVectorStore(
        instance_id=INSTANCE,
        database_id=DATABASE,
        table_name=TABLE_NAME,
        embedding_service=embeddings,
    )
    # Add the chunks to Spanner Vector Store with embeddings
    ids = [str(uuid.uuid4()) for _ in range(len(splits))]
    row_ids = db.add_documents(splits, ids)
    return len(row_ids)


# Create splits for documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
splits = text_splitter.split_documents(
    [document for document_list in document_lists for document in document_list]
)

# Initialize table and load data
embeddings = VertexAIEmbeddings(model_name="text-embedding-005")
num_docs = load_data_for_vector_search(splits, embeddings)
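`chunk_size=2000` and `chunk_overlap=200` are measured in characters (the splitter's default length function is `len`). `RecursiveCharacterTextSplitter` first tries to break on separators such as blank lines and newlines; the sketch below ignores separators and only shows how a fixed window with overlap behaves, so treat it as a simplified stand-in rather than the splitter's algorithm. `sliding_window_chunks` is a hypothetical helper.

```python
def sliding_window_chunks(text: str, chunk_size: int = 2000, chunk_overlap: int = 200) -> list[str]:
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    if len(text) <= chunk_size:
        return [text]
    step = chunk_size - chunk_overlap  # each window starts this far after the previous one
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks
```

A product description shorter than `chunk_size` comes back as a single chunk, so it should map to a single vector in the table.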

### Create a conventional RAG chain

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_google_spanner import SpannerVectorStore
import textwrap


# Print the retrieved context and join it into a single string for the prompt.
def format_docs(docs):
    print("Context Retrieved: \n")
    for doc in docs:
        print(textwrap.fill(doc.page_content, width=80))
        print("\n")

    context = "\n\n".join(doc.page_content for doc in docs)
    return context


prompt = PromptTemplate(
    template="""
    You are a friendly digital shopping assistant.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know.
    Question: {question}
    Context: {context}
    Answer:
  """,
    input_variables=["context", "question"],
)

# Create a rag chain
embeddings = VertexAIEmbeddings(model_name="text-embedding-005")

db = SpannerVectorStore(
    instance_id=INSTANCE,
    database_id=DATABASE,
    table_name=TABLE_NAME,
    embedding_service=embeddings,
)
vector_retriever = db.as_retriever(search_kwargs={"k": 3})
rag_chain = (
    {
        "context": vector_retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
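In LCEL, the dictionary at the head of `rag_chain` runs its branches on the same input: the question goes through `vector_retriever | format_docs` to become `context`, and through `RunnablePassthrough()` unchanged as `question`. The plain-Python sketch below spells out that data flow with stand-in callables (`run_rag` and its parameters are hypothetical). By contrast, the GraphRAG cell above calls the retriever first and passes `context` into `chain.invoke` explicitly.

```python
def run_rag(question, retrieve, format_docs, render_prompt, generate):
    """Plain-Python equivalent of the data flow in rag_chain (illustrative only)."""
    inputs = {
        "context": format_docs(retrieve(question)),  # vector_retriever | format_docs
        "question": question,  # RunnablePassthrough()
    }
    # prompt | llm | StrOutputParser(): fill the template, then generate text.
    return generate(render_prompt(**inputs))
```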

### Run conventional RAG chain

In [None]:
import textwrap

question = "I am looking for a beginner drone. Please give me some recommendations."
resp = rag_chain.invoke(question)

print("\n\nRag Response:\n")
print(textwrap.fill(resp, width=80))

Context Retrieved: 

Product: SkyHawk Zephyr Drone   Price: $129.99   Weight: 220g (7.8 oz)   The
SkyHawk Zephyr is the perfect drone for beginners. It's built for effortless
flying, offering a smooth and enjoyable experience from the moment you unpack
it.   Features:   Simple Controls: Beginner friendly and intuitive controls,
plus automatic takeoff and landing.   Tough Build: Designed to handle rookie
mistakes, thanks to its robust construction.   Capture Memories: Record crisp HD
photos and videos from above.   Extended Fun: Enjoy up to 15 minutes of flight
time per charge.   Worry-Free Flying: Free Fly mode lets you fly without
directional concerns.   Take your first flight with the SkyHawk Zephyr and
discover the joy of aerial views!    Category: Drone   Segment: ['Home']   Tags:
['Photography', 'Videography']


Bundle:SkyHawk Zephyr Starter Package   Price: $129.99   Everything you need to
begin your drone journey:   This package includes the essentials to get you
flying and capt

## Clean up the graph

> USE IT WITH CAUTION!

Clean up all the nodes/edges in your graph and remove your graph definition.

In [None]:
graph_store.cleanup()
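`graph_store.cleanup()` is scoped to the graph, so the `TABLE_NAME` table created for the conventional RAG comparison is presumably left in place. A minimal sketch for dropping it with the `google-cloud-spanner` client follows. `drop_vector_table` is hypothetical, `database` is assumed to behave like a `google.cloud.spanner` `Database` (for example `spanner.Client(project=PROJECT_ID).instance(INSTANCE).database(DATABASE)`), and Spanner requires any secondary indexes on the table to be dropped first. Use it with the same caution as above.

```python
def drop_vector_table(database, table_name: str, timeout: int = 300) -> None:
    """Drop the conventional-RAG table (hypothetical helper).

    `database` should behave like a google.cloud.spanner Database object.
    """
    operation = database.update_ddl([f"DROP TABLE {table_name}"])
    operation.result(timeout=timeout)  # wait for the schema change to finish
```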