# Implementation, Testing and Evaluation of RAG with Knowledge Graph (GraphRAG)

#### Notebook Outline
1. Imports and Configurations
2. Creation of Graph Database
3. Querying the Graph Database
4. Output of Knowledge Graph RAG Model
5. Evaluations

This code is adapted from the implementation published by Tomaz Bratanic: https://medium.com/neo4j/enhancing-the-accuracy-of-rag-applications-with-knowledge-graphs-ad5e2ffab663

### 1. Imports and Configurations

Imports

In [None]:
# === Standard Library Imports ===
import os
import sys
import re
import json
import time
import hashlib
import logging
import threading
from uuid import uuid4
from typing import List, Dict
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
import tiktoken

# === Third-Party Library Imports ===
import openai
from openai import OpenAI
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from tqdm import tqdm

# === LangChain Core Modules ===
from langchain.text_splitter import TokenTextSplitter
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
)

# === LangChain Models and Tools ===
from langchain.chat_models import ChatOpenAI
from langchain.tools import Tool
from langchain_openai import OpenAI

# === LangChain Integrations ===
from langchain.graphs import Neo4jGraph

# === Project-Specific Modules ===
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from ipynb_notebooks.baseline.rag_utils.baseline_rag import load_documents
from ipynb_notebooks.evaluation_datasets.retrieval_eval.eval_graph_dataset_generator import generate_graph_evalset
from ipynb_notebooks.evaluation_datasets.retrieval_eval.retrieval_metrics import run_retrieval_evaluation
from ipynb_notebooks.evaluation_datasets.generation_eval.generation_metrics import run_generation_evaluation


Configurations

In [None]:
# Resolve the repository root, two levels above the notebook's working directory
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Path to the Neo4j-specific environment file in the repository root
env_path = os.path.join(BASE_DIR, ".env.neo4j")

# Load the default .env first, then let .env.neo4j override overlapping keys
load_dotenv()
load_dotenv(env_path, override=True)

# OpenAI credentials; a missing OPENAI_API_KEY fails fast with KeyError
openai.api_key = os.environ['OPENAI_API_KEY']
# Reference the SDK class through the module: the bare name `OpenAI` is rebound
# by the later `from langchain_openai import OpenAI` import, so `OpenAI(...)`
# would build a LangChain LLM wrapper instead of the OpenAI SDK client.
client = openai.OpenAI(api_key=openai.api_key)

NEO4J_AUTH = os.getenv("NEO4J_AUTH")
NEO4J_URI = os.getenv("NEO4J_URI")

# NEO4J_AUTH has the form "<username>/<password>". Split only on the first
# slash so passwords containing "/" stay intact, and report a missing or
# malformed value explicitly instead of failing on None.split().
if not NEO4J_AUTH or "/" not in NEO4J_AUTH:
    raise RuntimeError(
        "NEO4J_AUTH must be set to '<username>/<password>' (checked .env and .env.neo4j)."
    )
NEO4J_USERNAME, NEO4J_PASSWORD = NEO4J_AUTH.split("/", 1)

# Define constants for paths
DATA_PATH = "../../data/laws_and_ordinances.json"  # JSON file with the URLs of the law and ordinance documents
DATA_PATH_SHORT_VERSION = "../../data/laws_and_ordinances_short_version.json"  # Subset of the URLs for testing purposes
CHROMA_PATH = "chroma"  # Directory to save the Chroma vector store

Helper Functions

### 2. Creation of Graph Database

In [None]:
def generate_chunk_id(text: str) -> str:
    """Derive a deterministic 16-hex-character identifier from the chunk text.

    The identifier is the leading 16 characters of the SHA-256 hex digest of
    the UTF-8 encoded text, so identical text always maps to the same ID.
    """
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()[:16]

In [None]:
def extract_relations_from_chunk(text: str) -> list[dict]:
    """Ask the chat model for entity relations contained in a legal text chunk.

    Args:
        text: The chunk of legal text to analyse.

    Returns:
        A list of dicts, each holding non-empty string values under the keys
        "head", "relation" and "tail". Returns an empty list when the model
        answer cannot be decoded; the problem and the raw answer are printed.
    """
    system_prompt = (
        "Du bist ein KI-System für juristische Wissensmodellierung. "
        "Extrahiere alle relevanten Entitäten und ihre Beziehungen aus folgendem Gesetzestext. "
        "Gib das Ergebnis als reine JSON-Liste zurück:\n"
        "[{\"head\": \"...\", \"relation\": \"...\", \"tail\": \"...\"}]\n"
        "Keine weiteren Erklärungen oder Einleitungen."
    )
    user_prompt = f"Text:\n\"\"\"\n{text}\n\"\"\"\n\nExtrahiere jetzt."

    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    # The message content can be None; normalise so parsing never sees a non-string.
    raw = response.choices[0].message.content or ""

    try:
        return _parse_relation_triples(raw)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"Parsing Error: {e}")
        print("GPT-Antwort:", raw)
        return []


def _parse_relation_triples(raw: str) -> list[dict]:
    """Decode the model answer into well-formed head/relation/tail dicts.

    Raises:
        ValueError: If no JSON list can be located or decoded in ``raw``.
    """
    try:
        # Preferred path: the prompt asks for a bare JSON list, which also
        # covers the legitimate "no relations" answer `[]`.
        payload = json.loads(raw.strip())
    except json.JSONDecodeError:
        # Fallback: the model wrapped the list in prose or a code fence.
        json_block = re.search(r"\[\s*{.*?}\s*\]", raw, re.DOTALL)
        if not json_block:
            raise ValueError("No JSON-Block found.") from None
        payload = json.loads(json_block.group())

    if not isinstance(payload, list):
        raise ValueError("JSON payload is not a list.")

    # Keep only complete triples so callers can index the three keys safely.
    required_keys = ("head", "relation", "tail")
    return [
        item
        for item in payload
        if isinstance(item, dict)
        and all(isinstance(item.get(key), str) and item[key].strip() for key in required_keys)
    ]

In [None]:
def sanitize_relation(relation: str) -> str:
    """Turn a free-text relation label into an identifier usable as a Cypher relationship type.

    German umlauts and "ß" are transliterated, "§" becomes "PARAGRAPH_",
    spaces and hyphens become underscores, the text is upper-cased and every
    character outside A-Z, 0-9 and "_" is dropped.

    Args:
        relation: Relation label as produced by the extraction step. ``None``
            is treated like an empty label; other non-string values are
            converted with ``str()``.

    Returns:
        The sanitized identifier. "UNDEFINED_RELATION" when nothing usable
        remains. Results that would start with a digit are prefixed with
        "REL_", because the identifier is interpolated unquoted into Cypher
        and unquoted names must not begin with a number.
    """
    if not isinstance(relation, str):
        relation = "" if relation is None else str(relation)

    replacements = str.maketrans({
        "Ä": "AE", "Ö": "OE", "Ü": "UE",
        "ä": "ae", "ö": "oe", "ü": "ue",
        "ß": "ss",
        "§": "PARAGRAPH_",
        " ": "_",
        "-": "_",
    })

    # Strip first so leading/trailing blanks do not turn into underscores.
    cleaned = relation.strip().translate(replacements).upper()

    # Remove everything except A-Z, 0-9 and _
    cleaned = re.sub(r"[^A-Z0-9_]", "", cleaned)

    if not cleaned:
        return "UNDEFINED_RELATION"

    if cleaned[0].isdigit():
        cleaned = f"REL_{cleaned}"

    return cleaned


In [None]:
def create_interlaw_references(graph: Neo4jGraph):
    """Scan Entity nodes for law abbreviations and link the corresponding Law nodes with :CITES_LAW.

    Every Entity reachable from a Law via HAS_CHUNK/HAS_ENTITY is inspected.
    Each abbreviation-like token in the entity id (a word starting with an
    upper-case letter and containing at least one more upper-case letter,
    e.g. "EnWG", "GEG", "StromNEV") is treated as a candidate target law.
    An edge is merged only when a Law node whose ``title`` equals the
    candidate exactly exists.

    NOTE(review): this assumes Law.title stores the abbreviation itself; if
    titles are full names, no candidate will match. Confirm against the data.

    Args:
        graph: Open Neo4j connection used for reading entities and writing edges.
    """
    print("Searching for inter-law references ...")

    # Get all Entity nodes and their associated Law titles
    rows = graph.query("""
    MATCH (e:Entity)<-[:HAS_ENTITY]-(c:Chunk)<-[:HAS_CHUNK]-(l:Law)
    RETURN e.id AS entity_id, l.title AS source_law
    """)

    # An all-caps-only pattern would miss mixed-case abbreviations such as
    # "EnWG", so accept any word with at least two upper-case letters.
    abbreviation = re.compile(
        r"\b[A-ZÄÖÜ][A-Za-zÄÖÜäöüß]*[A-ZÄÖÜ][A-Za-zÄÖÜäöüß]*\b"
    )

    # A set deduplicates (source, target) pairs while they are collected.
    unique_refs = set()
    for record in rows:
        entity_id = record["entity_id"]
        source_law = record["source_law"]
        if not isinstance(entity_id, str):
            continue

        # Collect every candidate in the entity id, not just the first one.
        for target_law in abbreviation.findall(entity_id):
            # Avoid self-references
            if target_law != source_law:
                unique_refs.add((source_law, target_law))

    # Create citation edges and count how many pairs actually resolved to
    # two existing Law nodes (MERGE also counts edges that already existed).
    linked = 0
    for src, tgt in unique_refs:
        outcome = graph.query("""
        MATCH (a:Law {title: $src}), (b:Law {title: $tgt})
        MERGE (a)-[r:CITES_LAW]->(b)
        RETURN count(r) AS linked
        """, {"src": src, "tgt": tgt})
        if outcome:
            linked += outcome[0]["linked"]

    print(f"Created {linked} inter-law citation links.")

In [None]:
MAX_PARALLEL = 10  # maximum number of chunks processed concurrently
SLEEP_BETWEEN_CALLS = 2  # seconds to pause after each chunk, as a conservative rate-limit guard
lock = threading.Lock()  # serialises all graph.query calls across worker threads


def process_single_chunk(i, doc, graph):
    """Store one text chunk in the graph and attach the relations extracted from it.

    Creates (or reuses) the Law node for the chunk's title and the Chunk node,
    links them via HAS_CHUNK, then asks the LLM for relation triples and
    merges the corresponding Entity nodes and edges.

    Args:
        i: Position of the chunk in the overall chunk list; stored as chunk_index.
        doc: Document-like object exposing ``page_content`` and a mutable
            ``metadata`` dict. Its metadata is updated in place with
            chunk_id, chunk_index and title.
        graph: Neo4j connection exposing ``query(cypher, params)``.

    Extraction or write errors are printed and do not propagate. A malformed
    triple or a failing write skips that relation only, not the whole chunk.
    """
    chunk_text = doc.page_content
    chunk_id = generate_chunk_id(chunk_text)
    title = doc.metadata.get("title", "UnknownLaw")
    source = doc.metadata.get("source", "unknown")

    doc.metadata["chunk_id"] = chunk_id
    doc.metadata["chunk_index"] = i
    doc.metadata["title"] = title

    # Create the law node, the chunk node and their link in a single round
    # trip so the lock is held as briefly as possible.
    with lock:
        graph.query("""
        MERGE (l:Law {title: $title})
        MERGE (c:Chunk {chunk_id: $chunk_id})
        SET c.text = $text, c.chunk_index = $chunk_index, c.title = $title, c.source = $source
        MERGE (l)-[:HAS_CHUNK]->(c)
        """, {
            "chunk_id": chunk_id,
            "text": chunk_text,
            "chunk_index": i,
            "title": title,
            "source": source
        })

    try:
        relations = extract_relations_from_chunk(chunk_text)
    except Exception as e:
        print(f"Error at chunk {chunk_id[:6]}...: {e}")
        return

    for rel in relations:
        if not isinstance(rel, dict):
            continue
        head = rel.get("head")
        tail = rel.get("tail")
        # Entities are keyed by their text; skip triples without usable endpoints.
        if not (isinstance(head, str) and head.strip() and isinstance(tail, str) and tail.strip()):
            continue
        raw_relation = rel.get("relation")
        relation = sanitize_relation(raw_relation if isinstance(raw_relation, str) else "")

        # The relationship type cannot be passed as a Cypher parameter, so the
        # sanitized identifier is interpolated; all values stay parameterised.
        cypher = f"""
        MERGE (h:Entity {{id: $head}})
        MERGE (t:Entity {{id: $tail}})
        MERGE (h)-[:{relation}]->(t)
        WITH h, t
        MATCH (c:Chunk {{chunk_id: $chunk_id}})
        MERGE (c)-[:HAS_ENTITY]->(h)
        MERGE (c)-[:HAS_ENTITY]->(t)
        """
        try:
            with lock:
                graph.query(cypher, {
                    "head": head,
                    "tail": tail,
                    "chunk_id": chunk_id
                })
        except Exception as e:
            # Keep going: one rejected relation must not drop the rest of the chunk.
            print(f"Error at chunk {chunk_id[:6]}...: {e}")

    time.sleep(SLEEP_BETWEEN_CALLS)

In [None]:
def ingest_graph_under_chunks_parallel(datapath, chunk_size=1024, chunk_overlap=128):
    """Load documents, split them into token chunks and ingest them into Neo4j in parallel.

    After all chunks are processed, every pair of distinct chunks sharing the
    same title is linked with SAME_LAW (in both directions), and inter-law
    citation edges are created.

    Args:
        datapath: Path passed to ``load_documents``.
        chunk_size: Chunk size in tokens for the ``TokenTextSplitter``.
        chunk_overlap: Token overlap between consecutive chunks.
    """
    raw_documents = load_documents(datapath)
    splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_documents(raw_documents)

    graph = Neo4jGraph(url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD)

    failed = 0
    with ThreadPoolExecutor(max_workers=MAX_PARALLEL) as executor:
        # Map each future back to its chunk index for error reporting.
        futures = {
            executor.submit(process_single_chunk, i, doc, graph): i
            for i, doc in enumerate(chunks)
        }
        for future in tqdm(as_completed(futures), total=len(futures), desc="Parallel-Ingestion"):
            # Retrieve the result so exceptions raised inside a worker are
            # surfaced instead of being silently discarded with the future.
            try:
                future.result()
            except Exception as e:
                failed += 1
                print(f"Ingestion of chunk index {futures[future]} failed: {e}")

    # NOTE(review): this links all chunk pairs of a law, i.e. a quadratic
    # number of edges per law - confirm this is intended for large laws.
    graph.query("""
    MATCH (a:Chunk), (b:Chunk)
    WHERE a.title = b.title AND a.chunk_id <> b.chunk_id
    MERGE (a)-[:SAME_LAW]->(b)
    """)

    create_interlaw_references(graph)

    if failed:
        print(f"Parallel-Ingestion finished with {failed} failed chunk(s) out of {len(chunks)}.")
    else:
        print("Parallel-Ingestion successful.")

In [None]:
# Ingest the complete law and ordinance collection (DATA_PATH) into the graph
ingest_graph_under_chunks_parallel(
    datapath=DATA_PATH,
    chunk_size=1024,
    chunk_overlap=128,
)

### 3. Querying the Graph Database

In [None]:
# Neo4j connection shared by the retrieval code below
graph = Neo4jGraph(
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
)

# Chat model used for entity extraction and answer generation (temperature 0)
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)


# Schema of the entity-extraction result. Documented with comments only:
# a class docstring would become part of the JSON schema that the parser
# embeds in its format instructions.
class Entities(BaseModel):
    names: List[str] = Field(
        ...,
        description=(
            "Extracted legal entities such as laws, abbreviations, "
            "paragraph references, or authorities."
        ),
    )


# Parser that converts the model answer into an Entities instance
parser = PydanticOutputParser(pydantic_object=Entities)

# Prompt for structured entity extraction; the format instructions are fixed
# up front, so only the question has to be supplied per call
entity_prompt = PromptTemplate(
    template=(
        "Extract legal entities (e.g. laws, abbreviations, paragraph references, "
        "institutions) from the following question:"
        "\n\n{question}\n\n{format_instructions}"
    ),
    input_variables=["question"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

In [None]:
# Stand-alone entity extraction: render the prompt, query the model, parse the answer
def extract_entities(question: str) -> Entities:
    """Extract the legal entities mentioned in a question.

    Renders ``entity_prompt`` for the question, sends the resulting text to
    ``llm`` and parses the reply content with ``parser``.
    """
    rendered_prompt = entity_prompt.format_prompt(question=question).to_string()
    reply = llm.invoke(rendered_prompt)
    return parser.parse(reply.content)

# Expose the extractor as a LangChain tool. The description is what an agent
# reads to decide when to call the tool, so it has to match what the prompt
# actually extracts (legal entities, not generic people/places).
entity_extraction_tool = Tool(
    name="EntityExtractor",
    func=extract_entities,
    description="Extracts legal entities (laws, abbreviations, paragraph references, institutions) from a question."
)

# Runnable pipeline: prompt -> chat model -> Entities parser
entity_chain = entity_prompt | llm | parser

# Build a quoted full-text search phrase from free text
def generate_full_text_query(input: str) -> str:
    """Normalise text and wrap it as a quoted phrase followed by ``~``.

    The text is lower-cased, "§" is spelled out as "paragraph", hyphens
    become spaces, commas are removed and surrounding whitespace is trimmed.
    """
    substitutions = str.maketrans({"§": "paragraph", "-": " ", ",": None})
    phrase = input.lower().translate(substitutions).strip()
    return '"' + phrase + '"~'


# Create a full-text index over the entity nodes if it doesn't exist.
# The ingestion step labels entity nodes `Entity`, so the index must target
# that label; an index on `__Entity__` would not cover any node.
# NOTE(review): because of IF NOT EXISTS, an index named `entity` that was
# already created with the old label has to be dropped once for this to apply.
graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:Entity) ON EACH [e.id, e.title, e.text]")

# Structured retrieval from Neo4j graph
def structured_retriever(question: str, top_k: int = 50, max_context_tokens: int = 3000) -> dict:
    """Perform graph-based retrieval using legal entity and relation context.

    Entities are extracted from the question, their 1-2 hop outgoing
    neighbourhood is fetched from the graph, and the resulting relations are
    narrowed to the most frequently hit chunks and to relations mentioning a
    question entity.

    Args:
        question: The user question.
        top_k: Number of most frequently hit chunk indices to keep.
        max_context_tokens: Token budget for the returned relation strings.

    Returns:
        A dict with the keys "prompt_context" (relation strings within the
        token budget), "retrieved_chunk_ids", "retrieved_chunk_indices",
        "retrieved_chunk_index_counts" and "law_titles". All lists keep
        first-seen order, so repeated calls over the same graph rows yield
        the same output.
    """
    extracted = entity_chain.invoke({"question": question})
    # Drop blank names: an empty string is contained in every id, which would
    # match the whole graph in the CONTAINS filter and disable the pruning
    # step. Deduplicate so the same entity is not queried twice.
    entities = list(dict.fromkeys(name for name in extracted.names if name and name.strip()))

    retrieved_nodes = []

    for entity in entities:
        # Cypher query using the ego-network approach from Hu et al. (2024)
        response = graph.query("""
        MATCH (e:Entity)
        WHERE toLower(e.id) CONTAINS toLower($query)
        CALL apoc.path.expand(e, '>', 'Entity', 1, 2) YIELD path
        UNWIND relationships(path) AS r
        WITH startNode(r) AS s, endNode(r) AS t, type(r) AS rel
        OPTIONAL MATCH (s)<-[:HAS_ENTITY]-(c:Chunk)
        RETURN DISTINCT
            s.id AS head,
            rel AS relation,
            t.id AS tail,
            c.chunk_id AS chunk_id,
            c.chunk_index AS chunk_index,
            c.title AS law_title
        """, {"query": entity})

        for row in response:
            # OPTIONAL MATCH yields null chunk fields for heads without a chunk.
            if not row["chunk_id"]:
                continue

            retrieved_nodes.append({
                "chunk_id": row["chunk_id"],
                "chunk_index": row["chunk_index"],
                "law_title": row["law_title"],
                "context": f'{row["head"]} - {row["relation"]} -> {row["tail"]}'
            })

    # Keep only relations coming from the top-k most frequently hit chunks
    hit_counts = Counter(entry["chunk_index"] for entry in retrieved_nodes)
    top_indices = {idx for idx, _ in hit_counts.most_common(top_k)}
    filtered_nodes = [entry for entry in retrieved_nodes if entry["chunk_index"] in top_indices]

    # Soft pruning by Hu et al. (2024): only keep relations that mention an
    # entity from the question. dict.fromkeys deduplicates while preserving
    # order; a set would make the truncation below depend on hash order.
    lowered_entities = [ent.lower() for ent in entities]
    contexts = list(dict.fromkeys(
        entry["context"]
        for entry in filtered_nodes
        if any(ent in entry["context"].lower() for ent in lowered_entities)
    ))

    chunk_ids = [entry["chunk_id"] for entry in filtered_nodes]
    chunk_indices = [entry["chunk_index"] for entry in filtered_nodes]
    index_counts = dict(Counter(chunk_indices))
    law_titles = list(dict.fromkeys(entry["law_title"] for entry in filtered_nodes))

    # Truncate the context to stay within the token budget
    enc = tiktoken.encoding_for_model("gpt-4o-mini")

    final_contexts = []
    token_count = 0

    for ctx in contexts:
        tokens = len(enc.encode(ctx))
        if token_count + tokens > max_context_tokens:
            break
        final_contexts.append(ctx)
        token_count += tokens

    return {
        "prompt_context": final_contexts,
        "retrieved_chunk_ids": list(dict.fromkeys(chunk_ids)),
        "retrieved_chunk_indices": list(dict.fromkeys(chunk_indices)),
        "retrieved_chunk_index_counts": index_counts,
        "law_titles": law_titles
    }


def retriever(question: str):
    """Retrieve structured graph context for a legal question.

    Delegates to ``structured_retriever`` and re-exposes its result under
    the key names the answer chain reads; the law titles are returned as
    "retrieved_law_titles".

    Args:
        question: The user question.

    Returns:
        A dict with the keys "prompt_context", "retrieved_chunk_ids",
        "retrieved_chunk_indices", "retrieved_chunk_index_counts" and
        "retrieved_law_titles".
    """
    structured_data = structured_retriever(question)

    return {
        "prompt_context": structured_data.get("prompt_context", []),
        "retrieved_chunk_ids": structured_data.get("retrieved_chunk_ids", []),
        "retrieved_chunk_indices": structured_data.get("retrieved_chunk_indices", []),
        # The counts are a mapping (chunk index -> hits), so fall back to an
        # empty dict rather than an empty list.
        "retrieved_chunk_index_counts": structured_data.get("retrieved_chunk_index_counts", {}),
        "retrieved_law_titles": structured_data.get("law_titles", []),
    }


In [None]:
def format_hard_prompt(contexts: List[str], law_titles: List[str]) -> str:
    """Render law titles and relation strings as the context section of the prompt.

    Args:
        contexts: Relation strings, one bullet line each.
        law_titles: Law titles; duplicates are dropped and first-seen order
            is kept, so the same input always produces the same prompt.

    Returns:
        The "Gesetz: ..." lines, a blank line, and the "Relationen:" list.
    """
    # dict.fromkeys deduplicates in insertion order; set() ordering of strings
    # varies between interpreter runs and would make prompts irreproducible.
    unique_titles = dict.fromkeys(law_titles)
    title_section = "\n".join(f"Gesetz: {t}" for t in unique_titles)
    relation_section = "\n".join(f"- {c}" for c in contexts)
    return f"{title_section}\n\nRelationen:\n{relation_section}"

### 4. Output of Knowledge Graph RAG Model

In [None]:
# Define the prompt for the main task (answering the question based on context)
template = """
Du bist ein hilfreicher, juristischer KI-Assistent für Gesetzestexte im deutschen Energie- und Versorgungsbereich. 
Beantworte folgende Frage in der Sprache, in der sie gestellt wurde, und generiere eine kurze, präzise, konsistente und vollständige Antwort von max. 200 Tokens basierend auf folgenden kontextbasierten Relationen: 

Frage: {question}

Kontext:
{context}
"""

prompt = ChatPromptTemplate.from_template(template)


def _question_text(chain_input):
    """Return the plain question string from the chain input.

    The chain is invoked with ``{"question": "..."}``. Without this step the
    whole dict would be forwarded, so both the retriever and the final prompt
    would receive the dict's repr instead of the question text. A bare string
    is passed through unchanged.
    """
    if isinstance(chain_input, dict):
        return chain_input["question"]
    return chain_input


# Pipeline: retrieve graph context for the question, flatten the retrieval
# result into top-level keys, format the context, render the prompt and
# generate the answer. Each step adds keys and keeps the previous ones.
graph_rag_chain = (
    RunnableParallel({
        "retrieval": RunnablePassthrough() | _question_text | retriever,
        "question": RunnablePassthrough() | _question_text
    })
    | (lambda inputs: {
        **inputs,
        "context": inputs["retrieval"]["prompt_context"],
        "retrieved_chunk_ids": inputs["retrieval"]["retrieved_chunk_ids"],
        "retrieved_chunk_indices": inputs["retrieval"]["retrieved_chunk_indices"],
        "retrieved_chunk_index_counts": inputs["retrieval"]["retrieved_chunk_index_counts"],
        # Order-preserving dedup keeps the result reproducible between runs.
        "law_titles": list(dict.fromkeys(inputs["retrieval"].get("retrieved_law_titles", [])))
    })
    | (lambda inputs: {
        **inputs,
        "formatted_context": format_hard_prompt(inputs["context"], inputs.get("law_titles", ["Unbekanntes Gesetz"]))
    })
    | (lambda inputs: {
        **inputs,
        "prompt_text": prompt.format(context=inputs["formatted_context"], question=inputs["question"])
    })
    | (lambda inputs: {
        **inputs,
        "generated_response": llm.invoke(inputs["prompt_text"]).content
    })
)


In [None]:
# Run the chain once on a sample question and print each part of the result
results = graph_rag_chain.invoke({"question": "Welchen Anwendungsbereich umfasst §1 des Energiewirtschaftsgesetzes - EnWG?"})

report_sections = (
    ("Antwort:", "generated_response"),
    ("Context:", "context"),
    ("Chunk IDs:", "retrieved_chunk_ids"),
    ("Chunk Indices:", "retrieved_chunk_indices"),
    ("Chunk Counts:", "retrieved_chunk_index_counts"),
    ("Law Titles:", "law_titles"),
)
for heading, result_key in report_sections:
    print(heading)
    print(results[result_key])


### 5. Evaluations

#### 5.1 Generate Evaluation Dataset

In [None]:
# Generate 50 evaluation queries: 60 % "single" and 40 % "multi_specific"
generated_eval_dataset_graph_rag = generate_graph_evalset(
    test_set_size=50,
    query_distribution={
        "single": 0.6,
        "multi_specific": 0.4,
    },
)

#### 5.2 Enrich Evaluation Dataset

In [None]:
def enrich_eval_dataset_with_graph_rag_chain(eval_dataset):
    """Enrich an evaluation dataset with GraphRAG answers and retrieval results.

    Each entry's "query" is run through ``graph_rag_chain``; the generated
    response, the retrieved contexts, chunk ids and chunk indices are stored
    on the entry. Entries whose chain call fails are logged and left out.

    Args:
        eval_dataset: Path to the JSON file holding a list of entries, each
            with a "query" key.

    Returns:
        Path of the written file: the input path with its extension replaced
        by "_enriched.json" (e.g. "set.json" -> "set_enriched.json").
    """
    with open(eval_dataset, "r", encoding="utf-8") as f:
        eval_dataset_json = json.load(f)

    enriched_dataset = []

    for entry in tqdm(eval_dataset_json, desc="Processing GraphRAG responses"):
        query = entry["query"]

        try:
            results = graph_rag_chain.invoke({"question": query})
        except Exception as e:
            logging.warning("GraphRAG Error for query '%s': %s", query, e)
            continue

        entry["generated_response"] = results.get("generated_response", "")
        entry["retrieved_chunk_contexts"] = results.get("context", [])
        entry["retrieved_chunk_ids"] = results.get("retrieved_chunk_ids", [])
        entry["retrieved_chunk_indices"] = results.get("retrieved_chunk_indices", [])

        enriched_dataset.append(entry)

    skipped = len(eval_dataset_json) - len(enriched_dataset)
    if skipped:
        # Make the shrinkage visible: metrics are computed on fewer entries.
        logging.warning("%d of %d entries were skipped due to errors.", skipped, len(eval_dataset_json))

    # Strip only the real extension. Replacing the substring "json" would
    # leave a stray dot ("name._enriched.json") and would also alter any
    # directory or file name that happens to contain "json".
    base_path, _ = os.path.splitext(eval_dataset)
    output_path = f"{base_path}_enriched.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(enriched_dataset, f, indent=2, ensure_ascii=False)

    return output_path


In [None]:
# Run every evaluation query through the GraphRAG chain; keeps the path of the enriched file
eval_dataset_graph_rag_enriched = enrich_eval_dataset_with_graph_rag_chain(
    eval_dataset=generated_eval_dataset_graph_rag,
)

#### 5.3 Evaluation of Graph RAG Retrieval

In [None]:
# The evaluation helpers receive the enriched file's name under "2_graph_database/"
enriched_basename = eval_dataset_graph_rag_enriched.rsplit("/", 1)[-1]
json_filename = "2_graph_database/" + enriched_basename
model_name = "graph_rag"

retrieval_result = run_retrieval_evaluation(
    json_filename=json_filename,
    model_name=model_name,
)
display(retrieval_result)

#### 5.4 Evaluation of Graph RAG Generation

In [None]:
# Score the generated answers of the same enriched dataset
generation_results = run_generation_evaluation(
    json_filename=json_filename,
    model_name=model_name,
)
display(generation_results)

In [None]:
def read_last_line_from_json(filepath):
    """Return the final line of a UTF-8 text file, stripped of surrounding whitespace.

    Returns ``None`` when the file contains no lines at all.
    """
    final_line = None
    with open(filepath, "r", encoding="utf-8") as handle:
        # Walk the file and keep only the most recent line.
        for final_line in handle:
            pass
    return None if final_line is None else final_line.strip()

# Example call: show the last line of the exported graph file
file_path = "../../neo4j_db/import/graph_export.json"
last_line = read_last_line_from_json(file_path)

if not last_line:
    print("⚠️ Die Datei ist leer.")
else:
    print("📄 Letzte Zeile in der Datei:")
    print(last_line)
