# Implementation, Testing and Evaluation of RAG with Knowledge Graph in combination with Vector Database (Hybrid RAG)

#### Notebook Outline
1. Imports and Configurations
2. Creation of Graph & Vector Database
3. Querying the Hybrid Database Combination
4. Output of Hybrid RAG Model
5. Evaluations

This code is adapted from the implementation provided by Tomaz Bratanic [https://medium.com/neo4j/enhancing-the-accuracy-of-rag-applications-with-knowledge-graphs-ad5e2ffab663].

### 1. Imports and Configurations

Imports

In [None]:
# === Standard Library Imports ===
import hashlib
import json
import logging
import os
import re
import sys
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List
from uuid import uuid4
import uuid
import shutil

# === Third-Party Library Imports ===
import openai
from openai import OpenAI
import tiktoken
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from tqdm import tqdm

# === LangChain Core Modules ===
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.tools import Tool
from langchain.text_splitter import TokenTextSplitter
from langchain_core.documents import Document
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

# === LangChain Integrations ===
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.graphs import Neo4jGraph

# === Neo4j Imports ===
from neo4j import GraphDatabase

# === Project-Specific Module Imports ===
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from ipynb_notebooks.baseline.rag_utils.baseline_rag import (
    load_documents,
    save_documents_for_sparse_retrieval,
    load_vector_database,
    retrieve_documents,
    generate_answer
)
from ipynb_notebooks.evaluation_datasets.retrieval_eval.eval_vector_dataset_generator import (
    generate_evalset
)
from ipynb_notebooks.evaluation_datasets.retrieval_eval.retrieval_metrics import (
    run_retrieval_evaluation
)
from ipynb_notebooks.evaluation_datasets.generation_eval.generation_metrics import (
    run_generation_evaluation
)

Configurations

In [None]:
# Move up two levels from the Jupyter Notebook directory to the project root
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Construct the path to .env.neo4j in the base directory
env_path = os.path.join(BASE_DIR, ".env.neo4j")

# Load environment variables from .env and .env.neo4j files;
# override=True lets values from .env.neo4j win over already-set variables.
load_dotenv()
load_dotenv(env_path, override=True)

# Set environment variables for OpenAI and Neo4j
openai.api_key = os.environ['OPENAI_API_KEY']
client = OpenAI(api_key=openai.api_key)

NEO4J_AUTH = os.getenv("NEO4J_AUTH")
NEO4J_URI = os.getenv("NEO4J_URI")

# NEO4J_AUTH has the form "<user>/<password>". Fail early with a clear message
# if it is missing or malformed, and split only on the first "/" so that
# passwords containing "/" are kept intact.
if not NEO4J_AUTH or "/" not in NEO4J_AUTH:
    raise ValueError(
        f"NEO4J_AUTH must be set in the form '<user>/<password>' (looked in environment and {env_path})"
    )
NEO4J_USERNAME, NEO4J_PASSWORD = NEO4J_AUTH.split("/", 1)

# Set up neo4j driver
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

# Define constants for paths
DATA_PATH = "../../data/laws_and_ordinances.json"  # Directory containing the url to the law and ordinance documents
DATA_PATH_SHORT_VERSION = "../../data/laws_and_ordinances_short_version.json" # Directory containing a subset of all urls for testing purposes
CHROMA_PATH = "chroma"  # Directory to save the Chroma vector store

Helper Functions

### 2. Creation of Vector and Graph Database with Identical Chunk Index

**Why Creating a Chroma and Neo4j Database with Identical Chunk Index Is Necessary**

Building the Chroma and Neo4j databases from the same chunks, with identical `chunk_id` and `chunk_index` values, is necessary to enable hybrid retrieval in a RAG pipeline. This shared identifier ensures that chunks retrieved semantically from the vector database (Chroma) can be mapped directly to the corresponding nodes and relationships in the graph database (Neo4j). It allows seamless integration of semantic similarity and structured knowledge, supports graph-based context expansion, and keeps the overall system traceable, explainable, and consistent.

In [None]:
# Number of worker threads used when ingesting chunks into Neo4j.
MAX_PARALLEL = 10
# Seconds each ingestion worker pauses after finishing one chunk.
SLEEP_BETWEEN_CALLS = 1

# Tokenizer of the chat model used throughout this notebook.
encoding = tiktoken.encoding_for_model("gpt-4o-mini")


def token_length(text):
    """Return the number of gpt-4o-mini tokens that ``text`` encodes to."""
    token_ids = encoding.encode(text)
    return len(token_ids)


def split_text(documents: list[Document], chunk_size, chunk_overlap):
    """Split documents into token-based chunks and tag each chunk with identifiers.

    Each chunk gets a fresh random ``chunk_id`` (UUID4 string) and a 1-based
    sequential ``chunk_index`` in its metadata. Both values are stored in Chroma
    and in the Neo4j ``Chunk`` nodes, which is what links the two databases.

    Args:
        documents: Documents to split.
        chunk_size: Maximum chunk length in gpt-4o-mini tokens.
        chunk_overlap: Number of tokens shared by consecutive chunks.

    Returns:
        The list of chunk documents with the added metadata.
    """
    splitter = TokenTextSplitter(
        model_name="gpt-4o-mini",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    pieces = splitter.split_documents(documents)

    for position, piece in enumerate(pieces, start=1):
        piece.metadata["chunk_id"] = str(uuid.uuid4())
        piece.metadata["chunk_index"] = position

    return pieces

In [None]:
def save_to_chroma(chunks: list[Document], chunk_size, chunk_overlap, optimization="hybrid_storage_approach", batch_size=100):
    """Embed ``chunks`` and persist them in a newly created Chroma store.

    The store is written to a fresh directory under ``../chroma_dbs`` whose name
    encodes the chunk size, the overlap, a random 8-character suffix and
    ``optimization``. Existing stores are therefore never overwritten, and no
    other directory is touched. (An earlier version deleted ``CHROMA_PATH``,
    which this function never writes to.)

    Args:
        chunks: Documents to embed. Their metadata (including ``chunk_id`` /
            ``chunk_index``) is stored alongside the vectors.
        chunk_size: Chunk size used for splitting; only used in the directory name.
        chunk_overlap: Chunk overlap used for splitting; only used in the directory name.
        optimization: Label appended to the directory name.
        batch_size: Number of documents passed to each ``add_documents`` call.

    Returns:
        Path of the newly created Chroma directory.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chroma_path = f"../chroma_dbs/chroma_chunksize{chunk_size}_overlap{chunk_overlap}_{str(uuid.uuid4())[:8]}_{optimization}"

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    db = Chroma(embedding_function=embeddings, persist_directory=chroma_path)

    # Embed in batches to keep individual embedding requests bounded in size.
    for i in tqdm(range(0, len(chunks), batch_size), desc="🔢 Store Chunks with Embeddings"):
        batch = chunks[i:i + batch_size]
        db.add_documents(batch)

    db.persist()
    print(f"✅ Saved {len(chunks)} chunks to {chroma_path}")
    return chroma_path

In [None]:
# Transliteration of German umlauts and sharp s to ASCII (applied before upper-casing).
_UMLAUT_TRANSLATION = str.maketrans(
    {"Ä": "AE", "Ö": "OE", "Ü": "UE", "ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss"}
)


def sanitize_relation(relation: str) -> str:
    """Convert a free-text relation label into a valid Neo4j relationship type.

    Steps:
    - transliterate umlauts and ß;
    - strip surrounding whitespace;
    - replace "§" with "PARAGRAPH_", and spaces and hyphens with "_";
    - upper-case the label and drop every character outside ``[A-Z0-9_]``.

    Labels that would start with a digit get a ``REL_`` prefix, because unquoted
    Cypher identifiers may not begin with a number.

    Args:
        relation: Relation label as produced by the extraction model. ``None``
            and non-string values are tolerated.

    Returns:
        The sanitized type, or ``"UNDEFINED_RELATION"`` if nothing usable remains.
    """
    if relation is None:
        return "UNDEFINED_RELATION"
    relation = str(relation).translate(_UMLAUT_TRANSLATION).strip()
    relation = relation.replace("§", "PARAGRAPH_").replace(" ", "_").replace("-", "_").upper()
    relation = re.sub(r"[^A-Z0-9_]", "", relation)
    if not relation:
        return "UNDEFINED_RELATION"
    # e.g. "2. Stufe" -> "2_STUFE" would be a Cypher syntax error when unquoted.
    if relation[0].isdigit():
        relation = f"REL_{relation}"
    return relation


def extract_relations_from_chunk(text: str) -> list[dict]:
    """Ask the chat model for (head, relation, tail) triples contained in ``text``.

    The (German) prompt asks gpt-4o-mini for a JSON list of objects with the keys
    ``head``, ``relation`` and ``tail``. The first JSON array of objects found in
    the reply is parsed. Entries are dropped unless they are objects with
    non-blank string values for all three keys, so callers can index them safely.

    Args:
        text: Chunk text to analyse.

    Returns:
        A list of ``{"head": str, "relation": str, "tail": str, ...}`` dicts, or
        an empty list when no parsable JSON array is found in the reply.

    Raises:
        Any exception raised by the OpenAI API call; it is not caught here.
    """
    import openai  # module-level OpenAI API; the key is set via openai.api_key
    system_prompt = (
        "Du bist ein KI-System für juristische Wissensmodellierung. "
        "Extrahiere alle relevanten Entitäten und ihre Beziehungen aus folgendem Gesetzestext. "
        "Gib das Ergebnis als reine JSON-Liste zurück:\n"
        "[{\"head\": \"...\", \"relation\": \"...\", \"tail\": \"...\"}]"
    )
    user_prompt = f"Text:\n\"\"\"\n{text}\n\"\"\""

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": user_prompt}]
    )

    # content can be None (e.g. on refusals); treat that like an empty reply.
    raw = response.choices[0].message.content or ""
    # The model may wrap the list in prose or code fences, so locate the array first.
    json_block = re.search(r"\[\s*{.*?}\s*\]", raw, re.DOTALL)
    if not json_block:
        return []
    try:
        parsed = json.loads(json_block.group())
    except json.JSONDecodeError as e:
        print(f"Parsing Error: {e}\nAnswer: {raw}")
        return []
    if not isinstance(parsed, list):
        return []

    required_keys = ("head", "relation", "tail")
    return [
        item for item in parsed
        if isinstance(item, dict)
        and all(isinstance(item.get(key), str) and item[key].strip() for key in required_keys)
    ]


def process_single_chunk(i, doc, graph):
    """Write one chunk to Neo4j and attach the entity triples extracted from it.

    Merges the ``Law`` node for the chunk's title, the ``Chunk`` node keyed by
    ``chunk_id`` (with text, index, title and source) and the ``HAS_CHUNK`` edge
    between them. It then extracts (head, relation, tail) triples with the LLM
    and stores each one as two ``Entity`` nodes. These are joined by a
    relationship whose type is the sanitized relation label, and both are
    linked from the chunk via ``HAS_ENTITY``.

    Errors during extraction, or while storing an individual triple, are printed
    and skipped. A single malformed triple does not discard the others.

    Args:
        i: Position of the chunk in the ingestion list (not used in the body).
        doc: Chunk document; must carry ``chunk_id`` and ``chunk_index`` metadata.
        graph: Object exposing ``query(cypher, params)`` (a ``Neo4jGraph`` here).
    """
    chunk_id = doc.metadata["chunk_id"]
    chunk_index = doc.metadata["chunk_index"]
    title = doc.metadata.get("title", "UnknownLaw")
    source = doc.metadata.get("source", "unknown")
    text = doc.page_content

    # Create law node, chunk node and their link in a single round trip.
    graph.query("""
        MERGE (l:Law {title: $title})
        MERGE (c:Chunk {chunk_id: $chunk_id})
        SET c.text = $text, c.chunk_index = $chunk_index, c.title = $title, c.source = $source
        MERGE (l)-[:HAS_CHUNK]->(c)
    """, {
        "chunk_id": chunk_id, "text": text, "chunk_index": chunk_index,
        "title": title, "source": source
    })

    # Extract relations
    try:
        relations = extract_relations_from_chunk(text)
    except Exception as e:
        print(f"Error in chunk {chunk_id[:6]}: {e}")
        relations = []

    for rel in relations:
        try:
            head, tail = rel["head"], rel["tail"]
            rel_type = sanitize_relation(rel["relation"])
            # sanitize_relation leaves only [A-Z0-9_]. The backticks additionally
            # make types that start with a digit valid Cypher identifiers.
            cypher = f"""
                MERGE (h:Entity {{id: $head}})
                MERGE (t:Entity {{id: $tail}})
                MERGE (h)-[:`{rel_type}`]->(t)
                WITH h, t
                MATCH (c:Chunk {{chunk_id: $chunk_id}})
                MERGE (c)-[:HAS_ENTITY]->(h)
                MERGE (c)-[:HAS_ENTITY]->(t)
            """
            graph.query(cypher, {"head": head, "tail": tail, "chunk_id": chunk_id})
        except Exception as e:
            # Skip only this triple; keep storing the remaining ones.
            print(f"Error in chunk {chunk_id[:6]}: {e}")

    # Throttle LLM / database calls per worker.
    time.sleep(SLEEP_BETWEEN_CALLS)


def ingest_chunks_to_neo4j(chunks: list[Document]):
    """Ingest all chunks into Neo4j in parallel using ``process_single_chunk``.

    Uses up to ``MAX_PARALLEL`` worker threads that share one ``Neo4jGraph``
    instance. Exceptions raised by a worker are printed and counted instead of
    being silently discarded. The remaining chunks are still processed.

    NOTE(review): MERGE on ``Law``/``Entity`` from concurrent threads without
    uniqueness constraints can create duplicate nodes. Consider creating such
    constraints before ingesting.

    Args:
        chunks: Chunk documents produced by ``split_text``.
    """
    from langchain_community.graphs import Neo4jGraph
    graph = Neo4jGraph(url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)

    failed = 0
    with ThreadPoolExecutor(max_workers=MAX_PARALLEL) as executor:
        future_to_doc = {
            executor.submit(process_single_chunk, i, doc, graph): doc
            for i, doc in enumerate(chunks)
        }
        for future in tqdm(as_completed(future_to_doc), total=len(future_to_doc), desc="Neo4j Ingest"):
            try:
                # result() re-raises any exception from the worker thread.
                future.result()
            except Exception as e:
                failed += 1
                chunk_id = str(future_to_doc[future].metadata.get("chunk_id", "?"))
                print(f"Ingest failed for chunk {chunk_id[:6]}: {e}")

    if failed:
        print(f"⚠️ {failed} of {len(chunks)} chunks could not be ingested into Neo4j.")

In [None]:
def generate_synchronized_databases(datapath, chunk_size=512, chunk_overlap=64, optimization="hybrid_storage_approach", baseline=False):
    """Build a Chroma store and a Neo4j graph from one shared set of chunks.

    Documents from ``datapath`` are split once. The resulting chunks, with their
    shared ``chunk_id`` / ``chunk_index`` metadata, are then:
    - saved for sparse retrieval;
    - embedded into a new Chroma directory;
    - ingested into Neo4j.

    Returns:
        Path of the newly created Chroma directory.
    """
    chunks = split_text(load_documents(datapath), chunk_size, chunk_overlap)
    save_documents_for_sparse_retrieval(chunks, chunk_size, chunk_overlap, optimization, baseline)

    print("Storing in Chroma ...")
    new_chroma_path = save_to_chroma(chunks, chunk_size, chunk_overlap, optimization)

    print("Ingest in Neo4j ...")
    ingest_chunks_to_neo4j(chunks)

    print("Both databases were successfully synchronized.")
    return new_chroma_path

In [None]:
# Pre-built hybrid Chroma store (name pattern from save_to_chroma: chunk size 1024,
# overlap 128). Running the following build cell overwrites this variable.
chroma_path_hybrid_graph_rag = (
    "../chroma_dbs/"
    "chroma_chunksize1024_overlap128_a5e9b634_hybrid_graph_rag"
)

In [None]:
# Rebuild Chroma + Neo4j from the full law corpus with 1024-token chunks.
# NOTE(review): split_text assigns fresh UUID chunk_ids, so re-running this cell
# against a non-empty Neo4j database adds a second set of Chunk nodes.
# Clear the graph first if that is not intended.
datapath = DATA_PATH
chunk_size = 1024
chunk_overlap = 128
optimization = "hybrid_graph_rag"

chroma_path_hybrid_graph_rag = generate_synchronized_databases(datapath, chunk_size, chunk_overlap, optimization=optimization)

### 3. Querying the Vector and Graph Databases

In [None]:
def run_graph_query(chunk_id: str):
    """Collect entity triples reachable from one chunk in the Neo4j graph.

    The query starts at every ``Entity`` linked via ``HAS_ENTITY`` from the
    ``Chunk`` with ``chunk_id``. ``apoc.path.expand`` then follows outgoing
    relationships 1-2 hops (label filter ``'Entity'``). Every traversed
    relationship is rendered as ``"head - REL -> tail"``.

    Returns:
        A list of dicts with keys ``chunk_id``, ``chunk_index``, ``law_title``
        and ``context``; an empty list if the query fails.
    """
    cypher = """
    MATCH (c:Chunk {chunk_id: $chunk_id})-[:HAS_ENTITY]->(e:Entity)
    CALL apoc.path.expand(e, '>', 'Entity', 1, 2) YIELD path
    UNWIND relationships(path) AS r
    WITH c, startNode(r) AS s, endNode(r) AS t, type(r) AS rel
    RETURN
        c.chunk_id AS chunk_id,
        s.id AS head,
        rel AS relation,
        t.id AS tail,
        c.chunk_index AS chunk_index,
        c.title AS law_title
    """

    def _to_entry(record):
        # Flatten one result row into the context entry used downstream.
        return {
            "chunk_id": record["chunk_id"],
            "chunk_index": record["chunk_index"],
            "law_title": record["law_title"],
            "context": f'{record["head"]} - {record["relation"]} -> {record["tail"]}',
        }

    try:
        with neo4j_driver.session() as session:
            # Records are consumed while the session is still open.
            records = session.run(cypher, {"chunk_id": chunk_id})
            return list(map(_to_entry, records))
    except Exception as e:
        print(f"Error for chunk {chunk_id[:6]}: {e}")
        return []


In [None]:
def graph_retriever_from_chunks(chunk_ids: List[str], top_k: int = 20, max_tokens: int = 3000) -> dict:
    """Gather graph triples for the given chunks and fit them into a token budget.

    Triples are fetched per chunk with ``run_graph_query`` on two worker threads.
    Results are consumed in the order of ``chunk_ids`` (i.e. the vector-retrieval
    rank), so higher-ranked chunks contribute first and the output is
    deterministic. Previously ``as_completed`` made the order, and therefore the
    token truncation and top-k tie-breaking, depend on thread timing.

    Selection then works as follows:
    - keep only triples whose ``chunk_index`` is among the ``top_k`` most
      frequent indices;
    - remove duplicate context strings;
    - append contexts until ``max_tokens`` gpt-4o-mini tokens would be exceeded.

    Args:
        chunk_ids: Chunk ids, ordered by retrieval rank.
        top_k: Number of chunk indices (by triple count) to keep.
        max_tokens: Token budget for the returned contexts.

    Returns:
        A dict with three keys:
        - ``prompt_context``: list of triple strings;
        - ``retrieved_chunk_indices``: de-duplicated, in first-seen order;
        - ``retrieved_law_titles``: de-duplicated, in first-seen order.
    """
    with ThreadPoolExecutor(max_workers=2) as executor:
        # executor.map yields results in input order, unlike as_completed.
        per_chunk_results = list(executor.map(run_graph_query, chunk_ids))
    retrieved_nodes = [entry for entries in per_chunk_results for entry in entries]

    index_counts = Counter(entry["chunk_index"] for entry in retrieved_nodes)
    top_indices = {idx for idx, _ in index_counts.most_common(top_k)}
    filtered_nodes = [entry for entry in retrieved_nodes if entry["chunk_index"] in top_indices]

    # Order-preserving de-duplication of the triple strings.
    contexts = list(dict.fromkeys(entry["context"] for entry in filtered_nodes))
    enc = tiktoken.encoding_for_model("gpt-4o-mini")
    final_contexts = []
    token_count = 0

    for ctx in contexts:
        tokens = len(enc.encode(ctx))
        if token_count + tokens > max_tokens:
            break
        final_contexts.append(ctx)
        token_count += tokens

    return {
        "prompt_context": final_contexts,
        "retrieved_chunk_indices": list(dict.fromkeys(entry["chunk_index"] for entry in filtered_nodes)),
        "retrieved_law_titles": list(dict.fromkeys(entry["law_title"] for entry in filtered_nodes)),
    }

In [None]:
def hybrid_graph_rag_pipeline(query, database, model_name="gpt-4o-mini", max_context_tokens=10000):
    """
    Hybrid RAG pipeline: First, a vector-based retrieval phase for chunk selection,
    followed by graph-based context enrichment via the corresponding chunk IDs.

    Graph triples come first, followed by the vector chunk texts. The combined
    list is de-duplicated and truncated once ``max_context_tokens`` would be
    exceeded, and the result is passed to ``generate_answer``.

    Args:
        query: User question.
        database: Chroma vector store to search.
        model_name: Model used for answer generation and token counting. Names
            unknown to tiktoken fall back to the gpt-4o-mini tokenizer.
        max_context_tokens: Token budget for the merged context.

    Returns:
        Tuple ``(response, sources, final_contexts, retrieved_chunk_ids,
        retrieved_chunk_indices)``; ids, indices and sources follow the order of
        the vector results.
    """

    # Vector retrieval
    vector_results = retrieve_documents(query, db=database)

    # retrieve_documents may return (document, score) pairs; keep the documents only.
    if vector_results and isinstance(vector_results[0], tuple):
        vector_results = [doc for doc, _ in vector_results]

    # Extract metadata of the retrieved chunks
    sources = [doc.metadata.get("source") for doc in vector_results]
    retrieved_chunk_ids = [doc.metadata.get("chunk_id") for doc in vector_results]
    retrieved_chunk_indices = [doc.metadata.get("chunk_index") for doc in vector_results]

    # Graph enrichment; chunks without a chunk_id cannot be matched in Neo4j.
    graph_results = graph_retriever_from_chunks([cid for cid in retrieved_chunk_ids if cid])

    # Extract vector and graph contexts
    vector_contexts = [doc.page_content for doc in vector_results]
    graph_contexts = graph_results.get("prompt_context", [])

    # Combine and de-duplicate (order-preserving) vector and graph contexts
    merged_context = list(dict.fromkeys(graph_contexts + vector_contexts))

    try:
        enc = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # tiktoken does not know this model name; approximate with gpt-4o-mini's tokenizer.
        enc = tiktoken.encoding_for_model("gpt-4o-mini")

    final_contexts = []
    token_count = 0
    for ctx in merged_context:
        tokens = len(enc.encode(ctx))
        if token_count + tokens > max_context_tokens:
            break
        final_contexts.append(ctx)
        token_count += tokens

    # Generate answer
    response = generate_answer(final_contexts, query, model_name)

    return response, sources, final_contexts, retrieved_chunk_ids, retrieved_chunk_indices

### 4. Output of Hybrid RAG Model

In [None]:
# Example question about the scope of § 1 of the Electromobility Act (EmoG).
query = "Welchen Anwendungsbereich umfasst §1 des Elektromobilitätsgesetz - EmoG?"
model_name = "gpt-4o-mini"  # or any other supported model
database = load_vector_database(chroma_path=chroma_path_hybrid_graph_rag)

(
    response,
    sources,
    retrieved_chunk_contexts,
    retrieved_chunk_ids,
    retrieved_chunk_indices,
) = hybrid_graph_rag_pipeline(query=query, database=database, model_name=model_name)

In [None]:
# Display the results of the example query, one labelled line each.
for label, value in (
    ("Query", query),
    ("Response", response),
    ("Sources", sources),
    ("Retrieved Chunk Contexts", retrieved_chunk_contexts),
    ("Retrieved Chunk Indices", retrieved_chunk_indices),
):
    print(f"{label}: {value} \n")

### 5. Evaluations

#### 5.1 Generate Evaluation Dataset

In [None]:
# Share of each query type in the 50-question evaluation set.
eval_query_distribution = {"single": 0.6, "multi_specific": 0.2, "multi_intra_document": 0.2}
eval_dataset_hybrid_graph_rag = generate_evalset(
    chroma_db=chroma_path_hybrid_graph_rag,
    test_set_size=50,
    query_distribution=eval_query_distribution,
)

#### 5.2 Enrich Evaluation Dataset

In [None]:
def enrich_eval_dataset_with_hybrid_graph_rag_responses(eval_dataset, chroma_path, model_name="gpt-4o-mini"):
    """Answer every query of an evaluation set with the hybrid pipeline and save the result.

    Each entry of the input JSON list must contain a ``query`` key. The pipeline
    adds four fields to every entry:
    - ``generated_response``;
    - ``retrieved_chunk_contexts``;
    - ``retrieved_chunk_ids``;
    - ``retrieved_chunk_indices``.

    Args:
        eval_dataset: Path (str or path-like) to the evaluation JSON file.
        chroma_path: Directory of the Chroma store to query.
        model_name: Model used for answer generation.

    Returns:
        Path of the written file. A trailing ``.json`` on the input path is
        replaced by ``_rag_enriched.json``.
    """
    dataset_path = os.fspath(eval_dataset)
    db = load_vector_database(chroma_path)

    with open(dataset_path, "r", encoding="utf-8") as f:
        eval_dataset_json = json.load(f)

    enriched_dataset = []

    for entry in tqdm(eval_dataset_json, desc="Processing Hybrid RAG responses"):
        query = entry["query"]

        # Run RAG pipeline
        response, _, retrieved_chunk_contexts, retrieved_chunk_ids, retrieved_chunk_indices = hybrid_graph_rag_pipeline(query, db, model_name=model_name)

        # Add new fields to the entry
        entry["generated_response"] = response
        entry["retrieved_chunk_contexts"] = retrieved_chunk_contexts
        entry["retrieved_chunk_ids"] = retrieved_chunk_ids
        entry["retrieved_chunk_indices"] = retrieved_chunk_indices

        enriched_dataset.append(entry)

    # Strip only a trailing ".json". str.replace would also remove ".json"
    # occurring elsewhere in the path, e.g. inside a directory name.
    base_path = dataset_path[:-len(".json")] if dataset_path.endswith(".json") else dataset_path
    output_path = f"{base_path}_rag_enriched.json"
    # Store results as new json file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(enriched_dataset, f, indent=2, ensure_ascii=False)

    return output_path

In [None]:
# Attach hybrid-pipeline answers and retrieved contexts to every evaluation query.
eval_dataset_graph_rag_enriched = enrich_eval_dataset_with_hybrid_graph_rag_responses(
    eval_dataset=eval_dataset_hybrid_graph_rag,
    chroma_path=chroma_path_hybrid_graph_rag,
    model_name="gpt-4o-mini",
)

#### 5.3 Evaluation of Graph RAG Retrieval

In [None]:
# The evaluation helpers expect "<folder>/<file name>" of the enriched dataset.
enriched_file_name = eval_dataset_graph_rag_enriched.rsplit('/', 1)[-1]
json_filename = f"2_graph_database/{enriched_file_name}"
# NOTE(review): results are labelled "graph_rag" although this notebook runs
# the hybrid pipeline; confirm the label does not collide with the pure graph RAG run.
model_name = "graph_rag"

retrieval_result = run_retrieval_evaluation(
    json_filename=json_filename,
    model_name=model_name,
)
display(retrieval_result)

#### 5.4 Evaluation of Graph RAG Generation

In [None]:
# Score the generated answers for the same enriched dataset and label.
generation_results = run_generation_evaluation(
    json_filename=json_filename,
    model_name=model_name,
)
display(generation_results)