# RAG with Knowledge Graphs using Neo4j

This notebook demonstrates how to build a Retrieval-Augmented Generation (RAG) system using:
- **Neo4j** as a knowledge graph database
- **LangChain** for orchestration
- **Groq API** (Llama 3.3 70B) for LLM capabilities
- **Wikipedia** as the data source
- **HuggingFace embeddings** for vector search

## Features
- Graph-based knowledge representation
- Hybrid search (vector + graph traversal)
- Entity extraction and relationship mapping
- Conversational interface with memory

## 1. Installation and Setup

In [None]:
# Install required packages
!pip install -q langchain==0.2.0 \
    langchain-community==0.2.1 \
    langchain-openai==0.1.7 \
    langchain-experimental==0.0.59 \
    neo4j==5.20 \
    wikipedia==1.4.0 \
    tiktoken==0.7.0 \
    sentence-transformers==2.7.0 \
    yfiles-jupyter-graphs-for-neo4j==1.0.0

## 2. Import Libraries

In [None]:
import os
from typing import List, Tuple

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.vectorstores import Neo4jVector
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import TokenTextSplitter
from langchain_experimental.graph_transformers import LLMGraphTransformer

# LangChain core imports
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

# Visualization
from yfiles_jupyter_graphs import GraphWidget
from neo4j import GraphDatabase

## 3. Configuration

**Important:** Create a `config.py` file with your credentials:
```python
GROQ_API_KEY = "your-groq-api-key"
NEO4J_URI = "neo4j+s://your-instance.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your-password"
```

In [None]:
# Option 1: Import from config file (recommended)
try:
    from config import GROQ_API_KEY, NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD
except ImportError:
    # Option 2: Set directly (not recommended for production)
    GROQ_API_KEY = "your-groq-api-key"
    NEO4J_URI = "neo4j+s://your-instance.databases.neo4j.io"
    NEO4J_USERNAME = "neo4j"
    NEO4J_PASSWORD = "your-password"

# Set environment variables
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["NEO4J_URI"] = NEO4J_URI
os.environ["NEO4J_USERNAME"] = NEO4J_USERNAME
os.environ["NEO4J_PASSWORD"] = NEO4J_PASSWORD
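
If `config.py` is missing, the fallback above leaves the placeholder strings in place, and the problem only surfaces later as an authentication error. A minimal sketch of an early check follows, assuming the placeholder values shown in this section. The helper name `find_placeholder_credentials` is hypothetical and not part of any library.

```python
from typing import Dict, List

# Placeholder values as they appear in the config template above.
PLACEHOLDERS = {
    "GROQ_API_KEY": "your-groq-api-key",
    "NEO4J_URI": "neo4j+s://your-instance.databases.neo4j.io",
    "NEO4J_PASSWORD": "your-password",
}


def find_placeholder_credentials(env: Dict[str, str]) -> List[str]:
    """Return the names of settings that are missing or still set to a placeholder."""
    return sorted(
        name
        for name, placeholder in PLACEHOLDERS.items()
        if env.get(name, "") in ("", placeholder)
    )
```

Calling `find_placeholder_credentials(dict(os.environ))` right after the cell above returns an empty list once real credentials are set.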

## 4. Initialize LLM and Graph Database

In [None]:
# Initialize Groq LLM (using OpenAI-compatible endpoint)
llm = ChatOpenAI(
    base_url="https://api.groq.com/openai/v1",
    api_key=os.environ["GROQ_API_KEY"],
    model="llama-3.3-70b-versatile",
    temperature=0,  # deterministic output suits entity and graph extraction
    max_tokens=1024,
)

# Test LLM connection
print("Testing LLM connection...")
response = llm.invoke('Hello')
print(f"LLM Response: {response.content}")

In [None]:
# Initialize Neo4j graph
print("Connecting to Neo4j...")
graph = Neo4jGraph()
print("Successfully connected to Neo4j!")

## 5. Load and Process Data

In [None]:
# Load Wikipedia data
print("Loading Wikipedia data...")
raw_documents = WikipediaLoader(query="Elizabeth I").load()
print(f"Loaded {len(raw_documents)} documents")

# Split documents into chunks.
# Only the first 3 documents are used, which keeps LLM-based graph extraction fast.
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
documents = text_splitter.split_documents(raw_documents[:3])
print(f"Split into {len(documents)} chunks")
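
With `chunk_size=512` and `chunk_overlap=24`, each chunk starts 488 tokens after the previous one, so consecutive chunks share 24 tokens. The sketch below illustrates that windowing on a plain list of integers standing in for token IDs. It is a simplified illustration rather than `TokenTextSplitter`'s actual code, and `sliding_windows` is a hypothetical helper.

```python
from typing import List, Sequence


def sliding_windows(
    tokens: Sequence[int], chunk_size: int, chunk_overlap: int
) -> List[List[int]]:
    """Split tokens into windows of chunk_size that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        windows.append(list(tokens[start:end]))
        if end == len(tokens):
            break  # the last window reached the end of the input
        start += step
    return windows


# 1000 fake token IDs produce two chunks that share tokens 488..511.
chunks = sliding_windows(list(range(1000)), chunk_size=512, chunk_overlap=24)
print([(c[0], c[-1]) for c in chunks])  # [(0, 511), (488, 999)]
```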

## 6. Build Knowledge Graph

In [None]:
# Convert documents to graph
print("Converting documents to knowledge graph...")
llm_transformer = LLMGraphTransformer(llm=llm)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
print(f"Generated {len(graph_documents)} graph documents")

# Add to Neo4j
print("Adding graph documents to Neo4j...")
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True
)
print("Knowledge graph created successfully!")

## 7. Setup Vector Search

In [None]:
# Initialize embeddings
print("Loading embedding model...")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# Create vector index
print("Creating vector index...")
vector_index = Neo4jVector.from_existing_graph(
    embeddings,
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)
print("Vector index created!")

In [None]:
# Create full-text index for entity search
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]"
)
print("Full-text index created!")

## 8. Entity Extraction Pipeline

In [None]:
# Define entity schema
class Entities(BaseModel):
    """Identifying information about entities."""
    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that appear in the text",
    )

# Create entity extraction chain
entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])

entity_chain = entity_prompt | llm.with_structured_output(Entities)

# Test entity extraction
test_entities = entity_chain.invoke({"question": "Who is Elizabeth I?"})
print(f"Extracted entities: {test_entities.names}")

## 9. Structured Retrieval Functions

In [None]:
def generate_full_text_query(input: str) -> str:
    """Generate a fuzzy full-text search query for Neo4j.

    Each word gets a ~2 suffix (up to two character edits allowed),
    and the words are combined with AND.
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        # Nothing searchable is left after stripping Lucene special characters
        return ""
    return " AND ".join(f"{word}~2" for word in words)


def structured_retriever(question: str) -> str:
    """Retrieve structured data from knowledge graph based on entities."""
    results = []
    entities = entity_chain.invoke({"question": question})

    for entity in entities.names:
        query = generate_full_text_query(entity)
        if not query:
            # An empty query string would make the full-text index call fail
            continue
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": query},
        )
        results.extend(el["output"] for el in response)

    # One relationship per line, including across different entities
    return "\n".join(results)


def retriever(question: str) -> str:
    """Hybrid retrieval combining structured graph data and vector search."""
    print(f"Search query: {question}")
    
    # Get structured data from graph
    structured_data = structured_retriever(question)
    
    # Get unstructured data from vector search
    unstructured_data = [
        el.page_content for el in vector_index.similarity_search(question)
    ]
    
    # Combine both data sources
    final_data = f"""Structured data:
{structured_data}

Unstructured data:
{"#Document ".join(unstructured_data)}
    """
    return final_data

In [None]:
# Test structured retrieval
test_result = structured_retriever("Who is Elizabeth I?")
print("Sample structured data:")
print(test_result[:500])
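
To see what the full-text query looks like without a database connection, the sketch below reproduces the query-building step in plain Python. `strip_lucene_chars` is a hypothetical stand-in for `remove_lucene_chars`, under the assumption that it replaces Lucene's reserved characters with spaces. `fuzzy_query` mirrors what `generate_full_text_query` produces for non-empty input.

```python
import re

# Characters with special meaning in Lucene query syntax.
_LUCENE_SPECIAL = re.compile(r'[+\-&|!(){}\[\]^"~*?:\\/]')


def strip_lucene_chars(text: str) -> str:
    """Hypothetical stand-in: replace Lucene special characters with spaces."""
    return _LUCENE_SPECIAL.sub(" ", text)


def fuzzy_query(text: str) -> str:
    """Append ~2 (edit distance of up to 2) to each word and join with AND."""
    words = strip_lucene_chars(text).split()
    return " AND ".join(f"{word}~2" for word in words)


print(fuzzy_query("Elizabeth I"))     # Elizabeth~2 AND I~2
print(fuzzy_query("Who is (Mary)?"))  # Who~2 AND is~2 AND Mary~2
```

The `~2` suffix lets the index match slightly misspelled entity names, while `AND` requires every word of the entity to be present.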

## 10. Conversational RAG Chain

In [None]:
# Chat history formatting
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    """Format chat history into message objects."""
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


# Question condensation for follow-up queries
condense_template = """Given the following conversation and a follow up question, 
rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}

Follow Up Input: {question}

Standalone question:"""

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_template)


_search_query = RunnableBranch(
    # If input includes chat_history, condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | llm
        | StrOutputParser(),
    ),
    # Else, just pass through the question
    RunnableLambda(lambda x: x["question"]),
)
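
The branch above has two routes: with chat history, the follow-up is rewritten into a standalone question; without it, the question goes straight to retrieval. The sketch below traces that routing in plain Python with no LLM call. `build_search_input` is a hypothetical helper, and rendering the history as `Human:` / `AI:` lines is an assumption made for readability, not what the chain itself does.

```python
from typing import Dict, Tuple

CONDENSE_TEMPLATE = """Given the following conversation and a follow up question,
rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}

Follow Up Input: {question}

Standalone question:"""


def build_search_input(inputs: Dict) -> Tuple[str, str]:
    """Return (route, text), where route is 'condense' or 'passthrough'."""
    history = inputs.get("chat_history")
    if not history:
        # No history: the question is already standalone.
        return "passthrough", inputs["question"]
    lines = []
    for human, ai in history:
        lines.append(f"Human: {human}")
        lines.append(f"AI: {ai}")
    prompt = CONDENSE_TEMPLATE.format(
        chat_history="\n".join(lines), question=inputs["question"]
    )
    # In the real chain this prompt is sent to the LLM.
    return "condense", prompt


route, text = build_search_input({
    "question": "When was she born?",
    "chat_history": [("Which house did Elizabeth I belong to?", "The House of Tudor.")],
})
print(route)
print(text)
```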

In [None]:
# Answer generation prompt
answer_template = """Answer the question based only on the following context:
{context}

Question: {question}

Use natural language and be concise.
Answer:"""

answer_prompt = ChatPromptTemplate.from_template(answer_template)

# Build the complete RAG chain
chain = (
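    RunnableParallel(
        {
            "context": _search_query | retriever,
            # Pass only the question text; RunnablePassthrough() would forward the whole input dict
            "question": RunnableLambda(lambda x: x["question"]),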
        }
    )
    | answer_prompt
    | llm
    | StrOutputParser()
)

print("RAG chain created successfully!")

## 11. Query Examples

In [None]:
# Example 1: Simple query
question1 = "Which house did Elizabeth I belong to?"
answer1 = chain.invoke({"question": question1})
print(f"Q: {question1}")
print(f"A: {answer1}\n")

In [None]:
# Example 2: Follow-up query with context
question2 = "When was she born?"
chat_history = [(question1, answer1)]
answer2 = chain.invoke({
    "question": question2,
    "chat_history": chat_history
})
print(f"Q: {question2}")
print(f"A: {answer2}\n")

In [None]:
# Example 3: Complex relationship query
question3 = "Who were Elizabeth I's parents and siblings?"
answer3 = chain.invoke({"question": question3})
print(f"Q: {question3}")
print(f"A: {answer3}")

## 12. Graph Visualization (Optional)

In [None]:
# Enable custom widgets in Colab
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    # Not running in Colab; no extra setup needed
    pass


def show_graph(cypher: str = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"):
    """Visualize Neo4j graph using yFiles."""
    driver = GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
    )
    try:
        # Materialize the result as a graph before the session closes
        with driver.session() as session:
            widget = GraphWidget(graph=session.run(cypher).graph())
    finally:
        driver.close()
    widget.node_label_mapping = 'id'
    display(widget)
    return widget


# Uncomment to visualize the graph
# show_graph()

## 13. Interactive Query Interface

In [None]:
def chat_interface():
    """Simple chat interface with memory."""
    chat_history = []
    print("RAG Knowledge Graph Chatbot (type 'quit' to exit)\n")
    
    while True:
        question = input("You: ").strip()
        
        if question.lower() in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break
        
        if not question:
            continue
        
        try:
            # Get answer with chat history
            answer = chain.invoke({
                "question": question,
                "chat_history": chat_history
            })
            
            print(f"\nBot: {answer}\n")
            
            # Update chat history
            chat_history.append((question, answer))
            
        except Exception as e:
            print(f"Error: {str(e)}\n")


# Uncomment to start interactive chat
# chat_interface()
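
`chat_interface` appends every turn to `chat_history`, so the condensation prompt grows with each exchange. One possible way to bound it is to keep only the most recent turns. This is a sketch under that assumption; `trim_history` and the default of 5 turns are illustrative choices, not part of the notebook.

```python
from typing import List, Tuple


def trim_history(
    chat_history: List[Tuple[str, str]], max_turns: int = 5
) -> List[Tuple[str, str]]:
    """Keep only the most recent max_turns (question, answer) pairs."""
    if max_turns <= 0:
        return []
    return chat_history[-max_turns:]


history = [(f"question {i}", f"answer {i}") for i in range(8)]
print(trim_history(history, max_turns=3))
```

Inside the loop, `chain.invoke` could then receive `trim_history(chat_history)` instead of the full list.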

## Summary

This notebook demonstrates a complete RAG pipeline with:
- ✅ Knowledge graph construction from Wikipedia
- ✅ Hybrid retrieval (graph + vector search)
- ✅ Entity extraction and relationship mapping
- ✅ Conversational interface with memory
- ✅ Integration with Groq's Llama 3.3 70B model

### Next Steps
1. Expand the knowledge base with more documents
2. Fine-tune entity extraction for your domain
3. Add custom relationship types
4. Implement caching for better performance
5. Add evaluation metrics