<a href="https://colab.research.google.com/github/jenny005/GraphRAG/blob/main/Neo4jGraphRAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: Install & Import

#!pip install -r requirements.txt

!pip install neo4j openai langchain langchain-openai networkx python-louvain sentence-transformers python-dotenv



In [None]:
from dotenv import load_dotenv  # load environment variables from a .env file into your program’s os.environ dictionary
load_dotenv("/content/sample_data/env", override=True)

True

In [None]:
# ============================================================================
# NEO4J GRAPH RAG - FOR NEO4J BROWSER VISUALIZATION
# ============================================================================
import os
import json
from typing import List, Dict, Any
import networkx as nx
from neo4j import GraphDatabase
import openai
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from sentence_transformers import SentenceTransformer
import community as community_louvain
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

In [None]:
# ============================================================================
# STEP 3: SETUP CREDENTIALS
# ============================================================================

class Config:
    """Configuration class for API keys and database credentials"""

    def __init__(self, use_embedded=False):
        # OpenAI API Key - first try to load from environment, then prompt if not found
        self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
        if not self.OPENAI_API_KEY:
            print("⚠️ OPENAI_API_KEY not found in .env file")
            self.OPENAI_API_KEY = input("Enter your OpenAI API Key: ")
        else:
            print("✅ OpenAI API Key loaded from .env file")

        # Database selection
        self.USE_EMBEDDED = use_embedded

        if not use_embedded:
            # Neo4j Cloud Credentials
            self.NEO4J_URI = os.getenv("NEO4J_URI")
            if not self.NEO4J_URI:
                print("\n🌩️ NEO4J CLOUD SETUP REQUIRED:")
                print("Go to https://neo4j.com/cloud/platform/aura-graph-database/")
                print("1. Create free AuraDB instance")
                print("2. Get connection URI (starts with neo4j+s://)")
                print("3. Get username and password")
                print()
                self.NEO4J_URI = input("Enter Neo4j Cloud URI (neo4j+s://...): ")
            else:
                print("✅ Neo4j URI loaded from environment")

            self.NEO4J_USERNAME = os.getenv("NEO4J_USERNAME") or input("Enter Neo4j Username (usually 'neo4j'): ")
            self.NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD") or input("Enter Neo4j Password: ")
        else:
            print("✅ Using embedded NetworkX graph (no Neo4j required)")

        # Set OpenAI API key for the libraries
        openai.api_key = self.OPENAI_API_KEY
        os.environ["OPENAI_API_KEY"] = self.OPENAI_API_KEY

# Initialize configuration
config = Config()
print("✅ Credentials configured successfully!")

✅ OpenAI API Key loaded from .env file
✅ Neo4j URI loaded from environment
✅ Credentials configured successfully!


In [54]:
class Neo4jConnection:
    def __init__(self, uri: str, username: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        self.driver.close()

    def test_connection(self):
        try:
            with self.driver.session() as session:
                result = session.run("RETURN 'Connection successful' as message")
                message = result.single()["message"]
                print(f"✅ Neo4j Connection: {message}")
                return True
        except Exception as e:
            print(f"❌ Neo4j Connection Failed: {e}")
            return False

    def clear_database(self):
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")
            print("🧹 Neo4j database cleared!")

    def get_graph_stats(self):
        """Get basic graph statistics"""
        with self.driver.session() as session:
            # Count nodes and relationships
            node_count = session.run("MATCH (n) RETURN count(n) as count").single()["count"]
            rel_count = session.run("MATCH ()-[r]->() RETURN count(r) as count").single()["count"]

            # Get entity types
            entity_types = session.run(
                "MATCH (e:Entity) RETURN DISTINCT e.type as type"
            ).values("type")

            # Get relationship types
            rel_types = session.run(
                "MATCH ()-[r:RELATES_TO]->() RETURN DISTINCT r.type as type"
            ).values("type")

            return {
                'total_nodes': node_count,
                'total_relationships': rel_count,
                'entity_types': [t for t in entity_types if t],
                'relationship_types': [t for t in rel_types if t]
            }

    def show_sample_data(self, limit=10):
        """Show sample entities and relationships"""
        print(f"\n🏷️ SAMPLE ENTITIES (limit {limit}):")
        print("-" * 50)

        with self.driver.session() as session:
            # Show entities
            result = session.run(f"""
            MATCH (e:Entity)
            RETURN e.name as name, e.type as type, e.description as description
            LIMIT {limit}
            """)

            for record in result:
                print(f"• {record['name']} ({record['type']})")
                if record['description']:
                    print(f"  └─ {record['description'][:80]}...")

        print(f"\n🔗 SAMPLE RELATIONSHIPS (limit {limit}):")
        print("-" * 50)

        with self.driver.session() as session:
            # Show relationships
            result = session.run(f"""
            MATCH (source:Entity)-[r:RELATES_TO]->(target:Entity)
            RETURN source.name as source, target.name as target,
                   r.type as rel_type, r.description as description
            LIMIT {limit}
            """)

            for record in result:
                print(f"• {record['source']} --[{record['rel_type']}]--> {record['target']}")
                if record['description']:
                    print(f"  └─ {record['description'][:80]}...")

    def create_neo4j_browser_queries(self):
        """Generate Cypher queries for Neo4j Browser visualization"""
        print("\n🎨 NEO4J BROWSER VISUALIZATION QUERIES:")
        print("="*60)

        print("\n1️⃣ BASIC GRAPH VIEW:")
        print("Copy this query into Neo4j Browser:")
        print("-" * 40)
        print("MATCH (n)-[r]->(m) RETURN n, r, m LIMIT 50")

        print("\n2️⃣ ENTITIES ONLY:")
        print("Copy this query into Neo4j Browser:")
        print("-" * 40)
        print("MATCH (e:Entity) RETURN e LIMIT 25")

        print("\n3️⃣ PEOPLE AND THEIR CONNECTIONS:")
        print("Copy this query into Neo4j Browser:")
        print("-" * 40)
        print("MATCH (p:Entity {type: 'PERSON'})-[r]->(n) RETURN p, r, n")

        print("\n4️⃣ ORGANIZATIONS AND GROUPS:")
        print("Copy this query into Neo4j Browser:")
        print("-" * 40)
        print("MATCH (o:Entity {type: 'ORGANIZATION'})-[r]->(n) RETURN o, r, n")

        print("\n5️⃣ FULL NETWORK (use carefully with large graphs):")
        print("Copy this query into Neo4j Browser:")
        print("-" * 40)
        print("MATCH (n)-[r]->(m) RETURN n, r, m")

        print(f"\n💡 HOW TO VISUALIZE:")
        print("1. Open Neo4j Browser (usually at localhost:7474 or your cloud URL)")
        print("2. Copy any query above")
        print("3. Paste in the query box and press Enter")
        print("4. Click on the graph visualization")
        print("5. Drag nodes to arrange the layout")

print("✅ Neo4j classes defined!")

# ============================================================================
# Cell 3: Graph RAG Components
# ============================================================================

class DocumentProcessor:
    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap,
            length_function=len, separators=["\n\n", "\n", " ", ""]
        )

    def extract_and_chunk_text(self, documents: List[str]) -> List[Dict]:
        chunks = []
        for doc_id, document in enumerate(documents):
            doc_chunks = self.text_splitter.split_text(document)
            for chunk_id, chunk in enumerate(doc_chunks):
                chunks.append({
                    'doc_id': doc_id, 'chunk_id': chunk_id,
                    'text': chunk, 'chunk_key': f"doc_{doc_id}_chunk_{chunk_id}"
                })
        return chunks

class EntityRelationshipExtractor:
    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        self.llm = ChatOpenAI(model_name=model_name, temperature=0)

    def extract_entities_relationships(self, text: str) -> Dict:
        prompt = f"""You are an expert knowledge graph constructor. Analyze the text and extract:
1. Important entities (people, organizations, concepts, locations, events)
2. Relationships between entities

Format as JSON:
{{"entities": [{{"name": "Entity Name", "type": "PERSON|ORGANIZATION|CONCEPT|LOCATION|EVENT", "description": "Brief description"}}], "relationships": [{{"source": "Entity1", "target": "Entity2", "relationship": "RELATIONSHIP_TYPE", "description": "Brief description"}}]}}

Text: {text}

JSON Response:"""

        try:
            response = self.llm.invoke(prompt)
            result = json.loads(response.content)
            return result
        except Exception as e:
            print(f"Error extracting entities/relationships: {e}")
            return {"entities": [], "relationships": []}

class KnowledgeGraphBuilder:
    def __init__(self, neo4j_conn: Neo4jConnection):
        self.neo4j_conn = neo4j_conn

    def create_entity_node(self, entity: Dict, chunk_key: str):
        with self.neo4j_conn.driver.session() as session:
            query = """
            MERGE (e:Entity {name: $name})
            SET e.type = $type, e.description = $description
            WITH e
            MERGE (c:Chunk {key: $chunk_key})
            MERGE (e)-[:MENTIONED_IN]->(c)
            """
            session.run(query, name=entity['name'], type=entity['type'],
                       description=entity.get('description', ''), chunk_key=chunk_key)

    def create_relationship(self, relationship: Dict, chunk_key: str):
        with self.neo4j_conn.driver.session() as session:
            query = """
            MATCH (source:Entity {name: $source_name})
            MATCH (target:Entity {name: $target_name})
            MERGE (source)-[r:RELATES_TO {type: $rel_type}]->(target)
            SET r.description = $description,
                r.found_in_chunk = $chunk_key
            """
            session.run(query, source_name=relationship['source'],
                       target_name=relationship['target'], rel_type=relationship['relationship'],
                       description=relationship.get('description', ''), chunk_key=chunk_key)

    def build_graph_from_chunks(self, chunks: List[Dict], extractor: EntityRelationshipExtractor):
        print("🔨 Building Knowledge Graph in Neo4j...")
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i+1}/{len(chunks)}: {chunk['chunk_key']}")
            extracted = extractor.extract_entities_relationships(chunk['text'])

            for entity in extracted['entities']:
                self.create_entity_node(entity, chunk['chunk_key'])

            for relationship in extracted['relationships']:
                self.create_relationship(relationship, chunk['chunk_key'])

        print("✅ Knowledge Graph built in Neo4j!")

class GraphRAGPipeline:
    def __init__(self, neo4j_conn: Neo4jConnection):
        self.neo4j_conn = neo4j_conn
        self.doc_processor = DocumentProcessor()
        self.extractor = EntityRelationshipExtractor()
        self.graph_builder = KnowledgeGraphBuilder(neo4j_conn)

    def build_knowledge_graph(self, documents: List[str]):
        print("🚀 Starting Graph RAG Pipeline...")

        # Process documents
        chunks = self.doc_processor.extract_and_chunk_text(documents)
        print(f"Created {len(chunks)} chunks from {len(documents)} documents")

        # Build graph
        self.graph_builder.build_graph_from_chunks(chunks, self.extractor)

        # Show results
        stats = self.neo4j_conn.get_graph_stats()
        print(f"\n📊 GRAPH STATISTICS:")
        print(f"  • Total nodes: {stats['total_nodes']}")
        print(f"  • Total relationships: {stats['total_relationships']}")
        print(f"  • Entity types: {stats['entity_types']}")
        print(f"  • Relationship types: {stats['relationship_types']}")

        # Show sample data
        self.neo4j_conn.show_sample_data()

        # Generate visualization queries
        self.neo4j_conn.create_neo4j_browser_queries()

        return stats

print("✅ Graph RAG components defined!")

# ============================================================================
# Cell 4: Setup Neo4j Connection
# ============================================================================

config = Config()
neo4j_conn = Neo4jConnection(config.NEO4J_URI, config.NEO4J_USERNAME, config.NEO4J_PASSWORD)

if neo4j_conn.test_connection():
    print("✅ Neo4j ready for graph visualization!")
else:
    print("❌ Fix Neo4j connection before proceeding")

# ============================================================================
# Cell 5: Build Graph from Alice Text
# ============================================================================

graph_text = """
Sophia's title is Principal Data Architect in the Cloud Infrastructure Group.
Sophia works with Ethan whose title is Senior DevOps Engineer.
Ethan works in the Platform Engineering Group.
Ethan works with Mia whose title is Lead Security Analyst.
Mia works in the Cybersecurity Group.
Mia collaborates with Oliver whose title is Vice President of Risk Management.
Oliver works in the Executive Group.
Oliver also works with Sophia on compliance initiatives.

Meanwhile, Sophia works with Isabella whose title is Director of Product Innovation.
Isabella works in the Research & Development Group.
Isabella works with Noah whose title is Senior Product Manager.
Noah works in the Product Strategy Group.
Noah collaborates with Liam whose title is Chief Operating Officer.
Liam works in the Executive Group.

Additionally, Sophia and Isabella both work with Ava whose title is Head of Data Governance.
Ava works in the Enterprise Analytics Group.
Ava collaborates with David whose title is Chief Technology Officer.
David works in the Executive Group and is responsible for both the Technology Group and the Innovation Council.
"""

# Clear previous data and build new graph
neo4j_conn.clear_database()
pipeline = GraphRAGPipeline(neo4j_conn)
stats = pipeline.build_knowledge_graph([graph_text])

print(f"\n🎯 GRAPH CREATED SUCCESSFULLY!")
print(f"📊 Total entities: {stats['total_nodes']}")
print(f"🔗 Total relationships: {stats['total_relationships']}")
print(f"\n💡 Now open Neo4j Browser and use the Cypher queries above to visualize your graph!")

✅ Neo4j classes defined!
✅ Graph RAG components defined!
✅ OpenAI API Key loaded from .env file
✅ Neo4j URI loaded from environment
✅ Neo4j Connection: Connection successful
✅ Neo4j ready for graph visualization!
🧹 Neo4j database cleared!
🚀 Starting Graph RAG Pipeline...
Created 2 chunks from 1 documents
🔨 Building Knowledge Graph in Neo4j...
Processing chunk 1/2: doc_0_chunk_0
Processing chunk 2/2: doc_0_chunk_1
✅ Knowledge Graph built in Neo4j!

📊 GRAPH STATISTICS:
  • Total nodes: 11
  • Total relationships: 21
  • Entity types: [['PERSON']]
  • Relationship types: [['WORKS_WITH'], ['COLLABORATES_WITH']]

🏷️ SAMPLE ENTITIES (limit 10):
--------------------------------------------------
• Sophia (PERSON)
  └─ Works with Ava...
• Ethan (PERSON)
  └─ Senior DevOps Engineer in the Platform Engineering Group...
• Mia (PERSON)
  └─ Lead Security Analyst in the Cybersecurity Group...
• Oliver (PERSON)
  └─ Vice President of Risk Management in the Executive Group...
• Isabella (PERSON)
  └─