# GraphRAG Workflow: Knowledge Graph-Enhanced Retrieval Augmented Generation

This notebook implements a production-ready GraphRAG system that:
1. Extracts entities and relationships from documents
2. Builds a knowledge graph representation
3. Performs graph-based retrieval for context
4. Generates responses using graph-augmented context

## Architecture Overview
- **Document Processing**: Chunk and embed documents
- **Entity Extraction**: LLM-based entity and relationship extraction
- **Graph Construction**: NetworkX for graph storage and traversal
- **Hybrid Retrieval**: Combines vector similarity and graph traversal
- **Augmented Generation**: Context-aware response generation
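
To make the data flow above concrete, here is a minimal, dependency-free sketch of the stages wired together. Everything in it is a stand-in and none of the names come from the notebook's classes: `toy_extract` treats capitalized words as entities instead of calling an LLM, and plain word overlap takes the place of embedding similarity.

```python
import re
from collections import defaultdict

def toy_chunk(text: str, size: int = 12) -> list[str]:
    """Document processing stand-in: fixed-size word windows."""
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]

def toy_extract(chunk: str) -> set[str]:
    """Entity extraction stand-in: capitalized words act as entities."""
    return set(re.findall(r"\b[A-Z][a-zA-Z]+\b", chunk))

def toy_build_graph(chunks: list[str]) -> tuple[dict, dict]:
    """Graph construction stand-in: entities that co-occur in a chunk are linked."""
    neighbors = defaultdict(set)
    entity_chunks = defaultdict(set)
    for idx, chunk in enumerate(chunks):
        entities = toy_extract(chunk)
        for e in entities:
            entity_chunks[e].add(idx)
            neighbors[e].update(entities - {e})
    return neighbors, entity_chunks

def toy_retrieve(query, chunks, neighbors, entity_chunks) -> list[int]:
    """Hybrid retrieval stand-in: lexical overlap plus one graph hop."""
    q_words = set(query.lower().split())
    # "Vector" side: chunks sharing any word with the query
    hits = {i for i, c in enumerate(chunks) if q_words & set(c.lower().split())}
    # Graph side: chunks of entities named in the query and of their neighbors
    for entity in toy_extract(query):
        for related in {entity} | neighbors.get(entity, set()):
            hits |= entity_chunks.get(related, set())
    return sorted(hits)

chunks = toy_chunk(
    "DeepMind built AlphaFold for protein folding. "
    "AlphaFold predictions are used by many biology labs.",
    size=6,
)
neighbors, entity_chunks = toy_build_graph(chunks)
context_ids = toy_retrieve("What did DeepMind build?", chunks, neighbors, entity_chunks)
```

In this toy run the lexical side finds only the first chunk, while the graph hop from `DeepMind` to `AlphaFold` also pulls in the second one, so `context_ids` is `[0, 1]`. That extra hop is the behavior the hybrid retriever aims for.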

## 1. Environment Setup and Dependencies

In [None]:
# Install required packages
!pip install -q langchain langchain-core langchain-openai langchain-community langchain-text-splitters
!pip install -q chromadb networkx pyvis
!pip install -q sentence-transformers tiktoken
!pip install -q python-dotenv tqdm
!pip install -q pandas numpy matplotlib seaborn
!pip install -q spacy
!python -m spacy download en_core_web_sm

In [None]:
import os
import json
import logging
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime
import hashlib
from enum import Enum

import numpy as np
import pandas as pd
import networkx as nx
from pyvis.network import Network
import matplotlib.pyplot as plt
import seaborn as sns

# Import from the split-out LangChain packages; the legacy `langchain.*` paths are deprecated
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma

import spacy
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import tiktoken

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set up environment
from dotenv import load_dotenv
load_dotenv()

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

## 2. Core Data Models and Configuration

In [None]:
@dataclass
class Entity:
    """Represents an entity in the knowledge graph"""
    id: str
    name: str
    type: str
    description: Optional[str] = None
    properties: Dict[str, Any] = field(default_factory=dict)
    embedding: Optional[np.ndarray] = None
    source_chunks: List[str] = field(default_factory=list)
    
    def __hash__(self):
        return hash(self.id)

@dataclass
class Relationship:
    """Represents a relationship between entities"""
    source_id: str
    target_id: str
    type: str
    description: Optional[str] = None
    properties: Dict[str, Any] = field(default_factory=dict)
    confidence: float = 1.0
    source_chunks: List[str] = field(default_factory=list)

@dataclass
class GraphConfig:
    """Configuration for GraphRAG system"""
    chunk_size: int = 1000
    chunk_overlap: int = 200
    max_entities_per_chunk: int = 10
    max_relationships_per_chunk: int = 15
    embedding_model: str = "text-embedding-3-small"
    llm_model: str = "gpt-4o-mini"
    temperature: float = 0.1
    graph_traversal_depth: int = 2
    similarity_threshold: float = 0.7
    top_k_retrieval: int = 5
    
config = GraphConfig()
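
Several of these defaults constrain each other. For example, an overlap as large as the chunk size would leave no room for new text in each chunk. A small sanity check along the following lines can catch such combinations early. This is only a sketch: `DemoConfig` is a hypothetical stand-in that mirrors a few `GraphConfig` fields so the example runs on its own, and `validate_config` is not part of the notebook.

```python
from dataclasses import dataclass

@dataclass
class DemoConfig:
    """Stand-in mirroring a few GraphConfig fields (hypothetical, for illustration)."""
    chunk_size: int = 1000
    chunk_overlap: int = 200
    graph_traversal_depth: int = 2
    similarity_threshold: float = 0.7
    top_k_retrieval: int = 5

def validate_config(cfg) -> list[str]:
    """Return human-readable problems; an empty list means the config looks sane."""
    problems = []
    if cfg.chunk_size <= 0:
        problems.append("chunk_size must be positive")
    if not 0 <= cfg.chunk_overlap < cfg.chunk_size:
        problems.append("chunk_overlap must be >= 0 and smaller than chunk_size")
    if cfg.graph_traversal_depth < 0:
        problems.append("graph_traversal_depth must be >= 0")
    if not -1.0 <= cfg.similarity_threshold <= 1.0:
        problems.append("similarity_threshold must lie in [-1, 1] for cosine similarity")
    if cfg.top_k_retrieval < 1:
        problems.append("top_k_retrieval must be at least 1")
    return problems

ok_problems = validate_config(DemoConfig())
bad_problems = validate_config(DemoConfig(chunk_size=100, chunk_overlap=100))
```

Because the function only reads attributes, the same check could be pointed at the real `config` object defined above.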

## 3. Document Processing and Chunking

In [None]:
class DocumentProcessor:
    """Handles document chunking and preprocessing"""
    
    def __init__(self, config: GraphConfig):
        self.config = config
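        # Build the tokenizer first so the splitter's length function can rely on it
        self.encoding = tiktoken.encoding_for_model("gpt-4")
        # chunk_size and chunk_overlap are measured in tokens because of length_function
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=self._tiktoken_len,
            separators=["\n\n", "\n", ". ", " ", ""]
        )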
        
    def _tiktoken_len(self, text: str) -> int:
        """Calculate token length using tiktoken"""
        return len(self.encoding.encode(text))
    
    def process_documents(self, documents: List[str]) -> List[Document]:
        """Process documents into chunks with metadata"""
        all_chunks = []
        
        for doc_idx, doc_text in enumerate(documents):
            # Create chunks
            chunks = self.text_splitter.split_text(doc_text)
            
            # Add metadata
            for chunk_idx, chunk in enumerate(chunks):
                chunk_id = hashlib.md5(f"{doc_idx}_{chunk_idx}_{chunk[:50]}".encode()).hexdigest()[:12]
                doc = Document(
                    page_content=chunk,
                    metadata={
                        "doc_id": doc_idx,
                        "chunk_id": chunk_id,
                        "chunk_index": chunk_idx,
                        "token_count": self._tiktoken_len(chunk),
                        "timestamp": datetime.now().isoformat()
                    }
                )
                all_chunks.append(doc)
        
        logger.info(f"Processed {len(documents)} documents into {len(all_chunks)} chunks")
        return all_chunks

# Initialize processor
doc_processor = DocumentProcessor(config)
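
The interplay of `chunk_size` and `chunk_overlap` is easier to see on a tiny input. The sketch below is a deliberately simplified sliding window over a token list. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter first breaks text on the configured separators), and `sliding_window_chunks` is a hypothetical helper introduced only for illustration.

```python
def sliding_window_chunks(
    tokens: list[str], chunk_size: int, chunk_overlap: int
) -> list[list[str]]:
    """Split a token list into windows of chunk_size that share chunk_overlap tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end of the text
    return chunks

tokens = [f"t{i}" for i in range(10)]
windows = sliding_window_chunks(tokens, chunk_size=4, chunk_overlap=1)
```

With ten tokens, a window of four, and an overlap of one, each window starts three tokens after the previous one, so the last token of one chunk is repeated as the first token of the next.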

## 4. Entity and Relationship Extraction

In [None]:
class EntityExtractor:
    """Extracts entities and relationships using LLM"""
    
    def __init__(self, config: GraphConfig):
        self.config = config
        self.llm = ChatOpenAI(
            model=config.llm_model,
            temperature=config.temperature
        )
        self.nlp = spacy.load("en_core_web_sm")
        
        # Entity extraction prompt
        self.entity_prompt = ChatPromptTemplate.from_template("""
        Extract entities and relationships from the following text.
        Focus on key concepts, people, organizations, locations, and their relationships.
        
        Text: {text}
        
        Return a JSON with the following structure:
        {{
            "entities": [
                {{
                    "name": "entity name",
                    "type": "PERSON/ORGANIZATION/CONCEPT/LOCATION/OTHER",
                    "description": "brief description",
                    "properties": {{"key": "value"}}
                }}
            ],
            "relationships": [
                {{
                    "source": "source entity name",
                    "target": "target entity name",
                    "type": "relationship type",
                    "description": "relationship description"
                }}
            ]
        }}
        
        Limit to {max_entities} entities and {max_relationships} relationships.
        """)
        
    def extract_from_chunk(self, chunk: Document) -> Tuple[List[Entity], List[Relationship]]:
        """Extract entities and relationships from a single chunk"""
        try:
            # LLM extraction
            response = self.llm.invoke(self.entity_prompt.format_messages(
                text=chunk.page_content,
                max_entities=self.config.max_entities_per_chunk,
                max_relationships=self.config.max_relationships_per_chunk
            ))
            
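            # Parse JSON response; models often wrap JSON in a Markdown code fence, so strip it first
            fence = "`" * 3
            raw = response.content.strip()
            if raw.startswith(fence):
                raw = raw.strip("`").strip()
                if raw.lower().startswith("json"):
                    raw = raw[4:]
            extraction = json.loads(raw)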
            
            # Create Entity objects
            entities = []
            for ent_data in extraction.get("entities", []):
                entity = Entity(
                    id=hashlib.md5(ent_data["name"].lower().encode()).hexdigest()[:12],
                    name=ent_data["name"],
                    type=ent_data.get("type", "OTHER"),
                    description=ent_data.get("description"),
                    properties=ent_data.get("properties", {}),
                    source_chunks=[chunk.metadata["chunk_id"]]
                )
                entities.append(entity)
            
            # Create Relationship objects
            relationships = []
            for rel_data in extraction.get("relationships", []):
                source_id = hashlib.md5(rel_data["source"].lower().encode()).hexdigest()[:12]
                target_id = hashlib.md5(rel_data["target"].lower().encode()).hexdigest()[:12]
                
                relationship = Relationship(
                    source_id=source_id,
                    target_id=target_id,
                    type=rel_data["type"],
                    description=rel_data.get("description"),
                    source_chunks=[chunk.metadata["chunk_id"]]
                )
                relationships.append(relationship)
            
            return entities, relationships
            
        except Exception as e:
            logger.error(f"Error extracting from chunk: {e}")
            return [], []
    
    def extract_from_documents(self, chunks: List[Document]) -> Tuple[Dict[str, Entity], List[Relationship]]:
        """Extract entities and relationships from all chunks"""
        all_entities = {}
        all_relationships = []
        
        for chunk in tqdm(chunks, desc="Extracting entities"):
            entities, relationships = self.extract_from_chunk(chunk)
            
            # Merge entities
            for entity in entities:
                if entity.id in all_entities:
                    # Merge source chunks
                    all_entities[entity.id].source_chunks.extend(entity.source_chunks)
                    # Update properties
                    all_entities[entity.id].properties.update(entity.properties)
                else:
                    all_entities[entity.id] = entity
            
            all_relationships.extend(relationships)
        
        logger.info(f"Extracted {len(all_entities)} unique entities and {len(all_relationships)} relationships")
        return all_entities, all_relationships

# Initialize extractor
entity_extractor = EntityExtractor(config)
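
Entity merging across chunks relies on the ID being a hash of the lowercased name, so the same name extracted from two chunks lands on one record. The following sketch isolates that idea. `entity_id` mirrors the hashing expression used in the extractor, while `merge_mentions` and the plain-dict records are simplified, hypothetical stand-ins for the `Entity` objects.

```python
import hashlib

def entity_id(name: str) -> str:
    """Derive a stable 12-character ID from a case-insensitive entity name."""
    return hashlib.md5(name.lower().encode()).hexdigest()[:12]

def merge_mentions(mentions: list[tuple[str, str]]) -> dict[str, dict]:
    """Merge (entity name, chunk ID) mentions into one record per entity ID."""
    merged: dict[str, dict] = {}
    for name, chunk_id in mentions:
        record = merged.setdefault(entity_id(name), {"name": name, "source_chunks": []})
        if chunk_id not in record["source_chunks"]:  # avoid duplicate chunk references
            record["source_chunks"].append(chunk_id)
    return merged

mentions = [
    ("Stanford University", "c1"),
    ("stanford university", "c2"),
    ("Google", "c2"),
]
merged = merge_mentions(mentions)
```

Note the limit of this scheme: only exact matches after lowercasing are merged. "Stanford" and "Stanford University" would hash to different IDs and remain separate nodes unless an additional alias-resolution step is added.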

## 5. Knowledge Graph Construction

In [None]:
class KnowledgeGraph:
    """Manages the knowledge graph structure and operations"""
    
    def __init__(self, config: GraphConfig):
        self.config = config
        self.graph = nx.DiGraph()
        self.embeddings_model = OpenAIEmbeddings(model=config.embedding_model)
        self.entities = {}
        self.entity_embeddings = {}
        
    def build_graph(self, entities: Dict[str, Entity], relationships: List[Relationship]):
        """Build the knowledge graph from entities and relationships"""
        self.entities = entities
        
        # Add nodes
        for entity_id, entity in entities.items():
            self.graph.add_node(
                entity_id,
                name=entity.name,
                type=entity.type,
                description=entity.description,
                properties=entity.properties,
                source_chunks=entity.source_chunks
            )
        
        # Add edges
        for rel in relationships:
            if rel.source_id in entities and rel.target_id in entities:
                self.graph.add_edge(
                    rel.source_id,
                    rel.target_id,
                    type=rel.type,
                    description=rel.description,
                    properties=rel.properties,
                    confidence=rel.confidence,
                    source_chunks=rel.source_chunks
                )
        
        # Generate embeddings for entities
        self._generate_entity_embeddings()
        
        logger.info(f"Graph built with {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges")
    
    def _generate_entity_embeddings(self):
        """Generate embeddings for all entities"""
        entity_texts = []
        entity_ids = []
        
        for entity_id, entity in self.entities.items():
            text = f"{entity.name} ({entity.type}): {entity.description or 'No description'}"
            entity_texts.append(text)
            entity_ids.append(entity_id)
        
        if entity_texts:
            embeddings = self.embeddings_model.embed_documents(entity_texts)
            for entity_id, embedding in zip(entity_ids, embeddings):
                self.entity_embeddings[entity_id] = np.array(embedding)
                self.entities[entity_id].embedding = np.array(embedding)
    
    def get_subgraph(self, entity_ids: List[str], depth: int = 2) -> nx.DiGraph:
        """Get subgraph around given entities up to specified depth"""
        nodes_to_include = set(entity_ids)
        
        for _ in range(depth):
            new_nodes = set()
            for node in nodes_to_include:
                if node in self.graph:
                    # Add predecessors and successors
                    new_nodes.update(self.graph.predecessors(node))
                    new_nodes.update(self.graph.successors(node))
            nodes_to_include.update(new_nodes)
        
        return self.graph.subgraph(nodes_to_include)
    
    def find_similar_entities(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
        """Find entities similar to query based on embeddings"""
        similarities = []
        
        for entity_id, entity_embedding in self.entity_embeddings.items():
            similarity = np.dot(query_embedding, entity_embedding) / (
                np.linalg.norm(query_embedding) * np.linalg.norm(entity_embedding)
            )
            similarities.append((entity_id, similarity))
        
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]
    
    def get_graph_statistics(self) -> Dict[str, Any]:
        """Get statistics about the graph"""
        stats = {
            "num_nodes": self.graph.number_of_nodes(),
            "num_edges": self.graph.number_of_edges(),
            "density": nx.density(self.graph),
            "num_connected_components": nx.number_weakly_connected_components(self.graph),
            "avg_degree": np.mean([d for n, d in self.graph.degree()]),
            "entity_type_distribution": {}
        }
        
        # Entity type distribution
        for entity_id, entity in self.entities.items():
            entity_type = entity.type
            stats["entity_type_distribution"][entity_type] = \
                stats["entity_type_distribution"].get(entity_type, 0) + 1
        
        return stats

# Initialize knowledge graph
knowledge_graph = KnowledgeGraph(config)
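
The neighborhood expansion in `get_subgraph` follows edges in both directions for a fixed number of rounds. The sketch below shows the same idea without NetworkX, using a frontier so that each round expands only newly discovered nodes. `expand_neighborhood` and the edge list are hypothetical and exist only to illustrate the traversal. As a simplification, nodes without any edges are unknown to this version.

```python
def expand_neighborhood(
    edges: list[tuple[str, str]], seeds: list[str], depth: int
) -> set[str]:
    """Collect every node within `depth` hops of the seeds, ignoring edge direction."""
    # Build an undirected adjacency map from the directed edge list
    adjacency: dict[str, set[str]] = {}
    for source, target in edges:
        adjacency.setdefault(source, set()).add(target)
        adjacency.setdefault(target, set()).add(source)

    visited = {s for s in seeds if s in adjacency}
    frontier = set(visited)
    for _ in range(depth):
        # Only nodes discovered in the previous round need to be expanded
        next_frontier = set()
        for node in frontier:
            next_frontier |= adjacency[node] - visited
        if not next_frontier:
            break  # nothing new to reach
        visited |= next_frontier
        frontier = next_frontier
    return visited

edges = [
    ("ng", "stanford"),
    ("ng", "coursera"),
    ("stanford", "ai_lab"),
    ("ai_lab", "vision"),
]
one_hop = expand_neighborhood(edges, ["ng"], depth=1)
two_hops = expand_neighborhood(edges, ["ng"], depth=2)
```

Each extra hop widens the context: one hop from `ng` reaches `stanford` and `coursera`, and a second hop adds `ai_lab`. On dense graphs the node count can grow quickly with depth, which is worth keeping in mind when raising `graph_traversal_depth`.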

## 6. Hybrid Retrieval System

In [None]:
class HybridRetriever:
    """Combines vector search and graph traversal for retrieval"""
    
    def __init__(self, config: GraphConfig, knowledge_graph: KnowledgeGraph):
        self.config = config
        self.knowledge_graph = knowledge_graph
        self.embeddings_model = OpenAIEmbeddings(model=config.embedding_model)
        self.vector_store = None
        self.chunks_map = {}
        
    def index_chunks(self, chunks: List[Document]):
        """Index document chunks for vector search"""
        # Create vector store
        self.vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings_model,
            collection_name="graphrag_chunks"
        )
        
        # Store chunks map
        for chunk in chunks:
            self.chunks_map[chunk.metadata["chunk_id"]] = chunk
        
        logger.info(f"Indexed {len(chunks)} chunks in vector store")
    
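    def retrieve(self, query: str, top_k: Optional[int] = None) -> Dict[str, Any]:
        """Perform hybrid retrieval combining vector search and graph traversal"""
        # Fall back to the configured default so that config.top_k_retrieval takes effect
        if top_k is None:
            top_k = self.config.top_k_retrieval
        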
        # Generate query embedding
        query_embedding = np.array(self.embeddings_model.embed_query(query))
        
        # 1. Vector search for relevant chunks
        # Note: Chroma reports a distance here, so lower scores mean closer matches
        vector_results = self.vector_store.similarity_search_with_score(query, k=top_k)
        
        # 2. Find similar entities in graph
        similar_entities = self.knowledge_graph.find_similar_entities(
            query_embedding, top_k=top_k
        )
        
        # 3. Get subgraph around similar entities
        entity_ids = [eid for eid, sim in similar_entities if sim > self.config.similarity_threshold]
        subgraph = self.knowledge_graph.get_subgraph(
            entity_ids, 
            depth=self.config.graph_traversal_depth
        )
        
        # 4. Collect all relevant chunks
        relevant_chunks = set()
        
        # From vector search
        for doc, score in vector_results:
            relevant_chunks.add(doc.metadata["chunk_id"])
        
        # From graph entities
        for node_id in subgraph.nodes():
            entity = self.knowledge_graph.entities.get(node_id)
            if entity:
                relevant_chunks.update(entity.source_chunks)
        
        # 5. Prepare context
        context = {
            "query": query,
            "vector_results": [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                    "score": float(score)
                }
                for doc, score in vector_results
            ],
            "graph_entities": [
                {
                    "id": eid,
                    "name": self.knowledge_graph.entities[eid].name,
                    "type": self.knowledge_graph.entities[eid].type,
                    "description": self.knowledge_graph.entities[eid].description,
                    "similarity": float(sim)
                }
                for eid, sim in similar_entities
            ],
            "graph_relationships": [
                {
                    "source": self.knowledge_graph.entities[u].name,
                    "target": self.knowledge_graph.entities[v].name,
                    "type": data.get("type"),
                    "description": data.get("description")
                }
                for u, v, data in subgraph.edges(data=True)
            ],
            "relevant_chunk_ids": list(relevant_chunks),
            "subgraph_stats": {
                "num_nodes": subgraph.number_of_nodes(),
                "num_edges": subgraph.number_of_edges()
            }
        }
        
        return context

# Initialize retriever
retriever = HybridRetriever(config, knowledge_graph)
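
The core of the hybrid step is a union: chunk IDs from vector search are combined with the source chunks of entities whose similarity clears the threshold. The sketch below isolates that merge and leaves out the graph expansion. `collect_chunk_ids` and its inputs are hypothetical. Unlike the `set` used in `retrieve`, this version keeps a stable order with vector hits first, which is one possible design choice rather than the notebook's behavior.

```python
def collect_chunk_ids(
    vector_hits: list[str],
    entity_scores: list[tuple[str, float]],
    entity_chunks: dict[str, list[str]],
    threshold: float,
) -> list[str]:
    """Union chunk IDs from vector search with chunks of sufficiently similar entities."""
    ordered = list(vector_hits)  # vector hits first, in ranked order
    seen = set(ordered)
    for entity, score in entity_scores:
        if score <= threshold:
            continue  # entity not similar enough to seed the graph side
        for chunk_id in entity_chunks.get(entity, []):
            if chunk_id not in seen:
                seen.add(chunk_id)
                ordered.append(chunk_id)
    return ordered

chunk_ids = collect_chunk_ids(
    vector_hits=["c3", "c1"],
    entity_scores=[("deepmind", 0.82), ("sec", 0.41)],
    entity_chunks={"deepmind": ["c1", "c7"], "sec": ["c9"]},
    threshold=0.7,
)
```

Here `deepmind` passes the 0.7 threshold and contributes `c7` (its other chunk, `c1`, was already found by vector search), while `sec` falls below the threshold and contributes nothing.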

## 7. Graph-Augmented Generation

In [None]:
class GraphRAGGenerator:
    """Generates responses using graph-augmented context"""
    
    def __init__(self, config: GraphConfig):
        self.config = config
        self.llm = ChatOpenAI(
            model=config.llm_model,
            temperature=config.temperature
        )
        
        self.generation_prompt = ChatPromptTemplate.from_template("""
        You are an AI assistant with access to a knowledge graph and document chunks.
        Use the provided context to answer the question comprehensively.
        
        Question: {query}
        
        Document Context:
        {document_context}
        
        Relevant Entities:
        {entities_context}
        
        Entity Relationships:
        {relationships_context}
        
        Instructions:
        1. Synthesize information from both documents and the knowledge graph
        2. Cite specific entities and relationships when relevant
        3. Provide a comprehensive answer that leverages the graph structure
        4. If information is missing or unclear, state that explicitly
        
        Answer:
        """)
    
    def generate(self, retrieval_context: Dict[str, Any]) -> Dict[str, Any]:
        """Generate response using retrieved context"""
        
        # Prepare context strings
        document_context = "\n\n".join([
            f"[Chunk {i+1} - Score: {r['score']:.3f}]\n{r['content'][:500]}..."
            for i, r in enumerate(retrieval_context["vector_results"][:3])
        ])
        
        entities_context = "\n".join([
            f"- {e['name']} ({e['type']}): {e['description'] or 'No description'} [Similarity: {e['similarity']:.3f}]"
            for e in retrieval_context["graph_entities"][:5]
        ])
        
        relationships_context = "\n".join([
            f"- {r['source']} --[{r['type']}]--> {r['target']}: {r['description'] or ''}"
            for r in retrieval_context["graph_relationships"][:10]
        ])
        
        # Generate response
        response = self.llm.invoke(self.generation_prompt.format_messages(
            query=retrieval_context["query"],
            document_context=document_context,
            entities_context=entities_context,
            relationships_context=relationships_context
        ))
        
        # Count only what was actually placed in the prompt above
        # (the first 3 chunks, 5 entities, and 10 relationships), not everything retrieved
        return {
            "query": retrieval_context["query"],
            "response": response.content,
            "metadata": {
                "num_chunks_used": len(retrieval_context["vector_results"][:3]),
                "num_entities_used": len(retrieval_context["graph_entities"][:5]),
                "num_relationships_used": len(retrieval_context["graph_relationships"][:10]),
                "subgraph_size": retrieval_context["subgraph_stats"],
                "timestamp": datetime.now().isoformat()
            }
        }

# Initialize generator
generator = GraphRAGGenerator(config)
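
Before graph edges reach the prompt, they are flattened into arrow-style text lines and truncated to a fixed number. The sketch below shows that serialization step on its own and returns how many items were actually used, which can then be reported in the metadata. `format_relationships` is a hypothetical helper, and the sample relationships are illustrative rather than real extractor output.

```python
def format_relationships(relationships: list[dict], limit: int = 10) -> tuple[str, int]:
    """Render up to `limit` relationships as arrow lines and report how many were used."""
    used = relationships[:limit]
    lines = [
        f"- {r['source']} --[{r['type']}]--> {r['target']}: {r.get('description') or ''}".rstrip()
        for r in used
    ]
    return "\n".join(lines), len(used)

sample_relationships = [
    {"source": "Andrew Ng", "type": "FOUNDED", "target": "Coursera", "description": "co-founder"},
    {"source": "Google Research", "type": "DEVELOPED", "target": "Transformer", "description": None},
]
text, used_count = format_relationships(sample_relationships, limit=1)
```

Returning the count alongside the text keeps the reported numbers consistent with what the model actually saw, even when the limit cuts the list short.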

## 8. Complete GraphRAG Pipeline

In [None]:
class GraphRAGPipeline:
    """End-to-end GraphRAG pipeline"""
    
    def __init__(self, config: GraphConfig):
        self.config = config
        self.doc_processor = DocumentProcessor(config)
        self.entity_extractor = EntityExtractor(config)
        self.knowledge_graph = KnowledgeGraph(config)
        self.retriever = HybridRetriever(config, self.knowledge_graph)
        self.generator = GraphRAGGenerator(config)
        self.chunks = []
        self.is_initialized = False
    
    def index_documents(self, documents: List[str]):
        """Index documents and build knowledge graph"""
        logger.info("Starting document indexing...")
        
        # Process documents into chunks
        self.chunks = self.doc_processor.process_documents(documents)
        
        # Extract entities and relationships
        entities, relationships = self.entity_extractor.extract_from_documents(self.chunks)
        
        # Build knowledge graph
        self.knowledge_graph.build_graph(entities, relationships)
        
        # Index chunks for vector search
        self.retriever.index_chunks(self.chunks)
        
        self.is_initialized = True
        logger.info("Document indexing complete")
        
        return self.knowledge_graph.get_graph_statistics()
    
    def query(self, question: str) -> Dict[str, Any]:
        """Query the GraphRAG system"""
        if not self.is_initialized:
            raise ValueError("Pipeline not initialized. Please index documents first.")
        
        logger.info(f"Processing query: {question}")
        
        # Retrieve context
        retrieval_context = self.retriever.retrieve(question)
        
        # Generate response
        result = self.generator.generate(retrieval_context)
        
        return result
    
    def visualize_graph(self, output_file: str = "knowledge_graph.html"):
        """Visualize the knowledge graph"""
        if not self.is_initialized:
            raise ValueError("Pipeline not initialized. Please index documents first.")
        
        net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white")
        
        # Add nodes
        for node_id, node_data in self.knowledge_graph.graph.nodes(data=True):
            entity = self.knowledge_graph.entities[node_id]
            color = {
                "PERSON": "#ff6b6b",
                "ORGANIZATION": "#4ecdc4",
                "LOCATION": "#45b7d1",
                "CONCEPT": "#96ceb4",
                "OTHER": "#dfe6e9"
            }.get(entity.type, "#dfe6e9")
            
            net.add_node(
                node_id,
                label=entity.name,
                title=f"{entity.name}\n{entity.description or 'No description'}",
                color=color,
                # degree() counts incoming and outgoing edges; edges(node_id) on a DiGraph only counts outgoing ones
                size=20 + self.knowledge_graph.graph.degree(node_id) * 3
            )
        
        # Add edges
        for source, target, edge_data in self.knowledge_graph.graph.edges(data=True):
            net.add_edge(
                source, 
                target,
                title=f"{edge_data.get('type', 'RELATED')}\n{edge_data.get('description', '')}",
                label=edge_data.get('type', 'RELATED')
            )
        
        net.save_graph(output_file)
        logger.info(f"Graph visualization saved to {output_file}")

# Initialize pipeline
pipeline = GraphRAGPipeline(config)

## 9. Example Usage and Demonstration

In [None]:
# Sample documents for demonstration
sample_documents = [
    """
    Artificial Intelligence (AI) has revolutionized many industries, particularly healthcare and finance.
    In healthcare, companies like DeepMind have developed AI systems for protein folding prediction,
    while IBM Watson has been used for cancer diagnosis and treatment recommendations.
    Stanford University's AI lab has been pioneering research in computer vision and natural language processing.
    Dr. Andrew Ng, a prominent AI researcher from Stanford, founded Coursera and has been instrumental
    in democratizing AI education. The field has seen rapid advancement with transformer architectures
    developed by Google Research, which led to models like BERT and GPT.
    """,
    
    """
    Machine learning applications in finance include fraud detection, algorithmic trading, and credit scoring.
    JPMorgan Chase has invested heavily in AI for risk assessment and portfolio optimization.
    Their COiN platform uses natural language processing to analyze legal documents.
    Goldman Sachs employs machine learning for market prediction and automated trading strategies.
    The collaboration between financial institutions and tech companies like Google and Microsoft
    has accelerated the adoption of AI in finance. Regulatory bodies like the SEC are developing
    frameworks for AI governance in financial markets.
    """,
    
    """
    The ethical implications of AI have become a central concern for researchers and policymakers.
    The Partnership on AI, founded by major tech companies including Google, Facebook, Amazon, and Microsoft,
    works on establishing best practices for AI development. Dr. Timnit Gebru's research on bias in AI
    has highlighted the importance of diversity in AI development teams. The European Union has proposed
    comprehensive AI regulations focusing on transparency and accountability. Universities like MIT and
    Oxford have established AI ethics research centers to study the societal impact of AI technologies.
    """
]

# Index documents
print("Indexing documents...")
stats = pipeline.index_documents(sample_documents)
print("\nGraph Statistics:")
for key, value in stats.items():
    print(f"  {key}: {value}")

In [None]:
# Example queries
queries = [
    "What companies are working on AI in healthcare?",
    "How is AI being used in finance?",
    "What are the connections between universities and AI research?",
    "Who are the key researchers mentioned and their contributions?"
]

results = []
for query in queries:
    print(f"\nQuery: {query}")
    print("-" * 80)
    
    result = pipeline.query(query)
    results.append(result)
    
    print(f"Response: {result['response']}")
    print(f"\nMetadata:")
    for key, value in result['metadata'].items():
        print(f"  {key}: {value}")

In [None]:
# Visualize the knowledge graph
pipeline.visualize_graph("knowledge_graph.html")
print("Knowledge graph visualization saved to knowledge_graph.html")

# Display graph statistics (matplotlib is already imported as plt in the setup cell)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Entity type distribution
stats = pipeline.knowledge_graph.get_graph_statistics()
entity_types = list(stats['entity_type_distribution'].keys())
entity_counts = list(stats['entity_type_distribution'].values())

axes[0, 0].bar(entity_types, entity_counts, color='skyblue')
axes[0, 0].set_title('Entity Type Distribution')
axes[0, 0].set_xlabel('Entity Type')
axes[0, 0].set_ylabel('Count')
axes[0, 0].tick_params(axis='x', rotation=45)

# Degree distribution
degrees = [d for n, d in pipeline.knowledge_graph.graph.degree()]
axes[0, 1].hist(degrees, bins=20, color='lightgreen', edgecolor='black')
axes[0, 1].set_title('Node Degree Distribution')
axes[0, 1].set_xlabel('Degree')
axes[0, 1].set_ylabel('Frequency')

# Query length vs. response length
query_lengths = [len(q) for q in queries]
response_lengths = [len(r['response']) for r in results]

axes[1, 0].scatter(query_lengths, response_lengths, s=100, alpha=0.6, color='coral')
axes[1, 0].set_title('Query vs Response Length')
axes[1, 0].set_xlabel('Query Length (characters)')
axes[1, 0].set_ylabel('Response Length (characters)')

# Retrieval metrics
metrics = ['Chunks', 'Entities', 'Relationships']
avg_metrics = [
    np.mean([r['metadata']['num_chunks_used'] for r in results]),
    np.mean([r['metadata']['num_entities_used'] for r in results]),
    np.mean([r['metadata']['num_relationships_used'] for r in results])
]

axes[1, 1].bar(metrics, avg_metrics, color=['#ff9999', '#66b3ff', '#99ff99'])
axes[1, 1].set_title('Average Retrieval Metrics per Query')
axes[1, 1].set_ylabel('Average Count')

plt.tight_layout()
plt.show()

## 10. Advanced Features and Production Considerations

In [None]:
class GraphRAGMonitor:
    """Monitoring and observability for GraphRAG system"""
    
    def __init__(self):
        self.metrics = {
            "queries_processed": 0,
            "avg_response_time": [],
            "avg_entities_per_query": [],
            "avg_chunks_per_query": [],
            "error_count": 0,
            "query_history": []
        }
    
    def log_query(self, query: str, result: Dict[str, Any], response_time: float):
        """Log query metrics"""
        self.metrics["queries_processed"] += 1
        self.metrics["avg_response_time"].append(response_time)
        self.metrics["avg_entities_per_query"].append(result['metadata']['num_entities_used'])
        self.metrics["avg_chunks_per_query"].append(result['metadata']['num_chunks_used'])
        
        self.metrics["query_history"].append({
            "query": query,
            "timestamp": datetime.now().isoformat(),
            "response_time": response_time,
            "metadata": result['metadata']
        })
    
    def log_error(self, error: Exception, context: Dict[str, Any]):
        """Log errors"""
        self.metrics["error_count"] += 1
        logger.error(f"GraphRAG Error: {error}, Context: {context}")
    
    def get_summary(self) -> Dict[str, Any]:
        """Get monitoring summary"""
        # queries_processed only counts successful queries, so failures are added to the denominator
        total_attempts = self.metrics["queries_processed"] + self.metrics["error_count"]
        return {
            "total_queries": self.metrics["queries_processed"],
            "avg_response_time": np.mean(self.metrics["avg_response_time"]) if self.metrics["avg_response_time"] else 0,
            "avg_entities_per_query": np.mean(self.metrics["avg_entities_per_query"]) if self.metrics["avg_entities_per_query"] else 0,
            "avg_chunks_per_query": np.mean(self.metrics["avg_chunks_per_query"]) if self.metrics["avg_chunks_per_query"] else 0,
            "error_rate": self.metrics["error_count"] / max(total_attempts, 1),
            "recent_queries": self.metrics["query_history"][-5:]
        }

# Initialize monitor
monitor = GraphRAGMonitor()

# Example monitoring
import time

for query in queries[:2]:
    start_time = time.time()
    try:
        result = pipeline.query(query)
        response_time = time.time() - start_time
        monitor.log_query(query, result, response_time)
    except Exception as e:
        monitor.log_error(e, {"query": query})

print("\nMonitoring Summary:")
for key, value in monitor.get_summary().items():
    if key != "recent_queries":
        print(f"  {key}: {value}")
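
The manual `start_time` bookkeeping in the loop above can also be wrapped in a context manager, so that timing and error counting happen in one place. This is a sketch under assumptions: `LatencyLog` is a hypothetical helper, not part of the notebook, and it uses `time.perf_counter`, a monotonic clock intended for measuring intervals, where the loop above uses `time.time`.

```python
import time
from contextlib import contextmanager

class LatencyLog:
    """Collect per-call latencies and failures (hypothetical helper)."""

    def __init__(self):
        self.latencies: list[float] = []
        self.errors = 0

    @contextmanager
    def track(self):
        start = time.perf_counter()
        try:
            yield
        except Exception:
            self.errors += 1  # count the failure, then let the caller handle it
            raise
        else:
            self.latencies.append(time.perf_counter() - start)

    def error_rate(self) -> float:
        """Failures divided by all attempts, successful or not."""
        attempts = len(self.latencies) + self.errors
        return self.errors / attempts if attempts else 0.0

log = LatencyLog()
with log.track():
    sum(range(1000))  # stand-in for pipeline.query(...)
try:
    with log.track():
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass
```

After one successful and one failing call, the log holds a single latency sample and an error rate of 0.5, because failures are counted in the denominator as well.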

In [None]:
class GraphRAGCache:
    """Caching layer for GraphRAG queries"""
    
    def __init__(self, max_size: int = 100):
        self.cache = {}
        self.max_size = max_size
        self.hit_count = 0
        self.miss_count = 0
    
    def get_cache_key(self, query: str) -> str:
        """Generate cache key for query"""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()
    
    def get(self, query: str) -> Optional[Dict[str, Any]]:
        """Get cached result"""
        key = self.get_cache_key(query)
        if key in self.cache:
            self.hit_count += 1
            logger.info(f"Cache hit for query: {query[:50]}...")
            return self.cache[key]
        self.miss_count += 1
        return None
    
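    def set(self, query: str, result: Dict[str, Any]):
        """Cache result, evicting the oldest entry when the cache is full"""
        key = self.get_cache_key(query)
        # Evict only when adding a new key; overwriting an existing key does not grow the cache
        if key not in self.cache and len(self.cache) >= self.max_size:
            # Remove oldest entry (simple FIFO; dicts preserve insertion order)
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[key] = result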
    
    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        total_requests = self.hit_count + self.miss_count
        hit_rate = self.hit_count / max(total_requests, 1)
        
        return {
            "cache_size": len(self.cache),
            "max_size": self.max_size,
            "hit_count": self.hit_count,
            "miss_count": self.miss_count,
            "hit_rate": hit_rate,
            "total_requests": total_requests
        }

# Initialize cache
cache = GraphRAGCache(max_size=50)

# Test caching
test_query = "What companies are working on AI in healthcare?"

# First query (miss)
cached_result = cache.get(test_query)
if cached_result is None:  # compare with None so an empty result still counts as a hit
    result = pipeline.query(test_query)
    cache.set(test_query, result)
    print("Query executed and cached")

# Second query (hit)
cached_result = cache.get(test_query)
if cached_result is not None:
    print("Retrieved from cache")

print("\nCache Statistics:")
for key, value in cache.get_stats().items():
    print(f"  {key}: {value}")
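
FIFO eviction drops the oldest insertion even if that entry is still being hit frequently. One alternative is least-recently-used (LRU) eviction combined with a time-to-live. The following is a minimal sketch under stated assumptions, not a replacement for `GraphRAGCache`: the class name `LRUTTLCache`, the `ttl_seconds` parameter, and the injectable `clock` are all hypothetical.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Optional, Tuple


class LRUTTLCache:
    """LRU cache with per-entry expiry (hypothetical sketch)."""

    def __init__(self, max_size: int = 100, ttl_seconds: float = 3600.0,
                 clock: Callable[[], float] = time.monotonic):
        self._data: "OrderedDict[str, Tuple[float, Any]]" = OrderedDict()
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self._clock = clock  # injectable so expiry can be tested without sleeping

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self._clock() - stored_at > self.ttl_seconds:
            # Entry is stale: drop it and report a miss
            del self._data[key]
            return None
        # Mark as most recently used
        self._data.move_to_end(key)
        return value

    def set(self, key: str, value: Any) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        elif len(self._data) >= self.max_size:
            # Evict the least recently used entry (front of the OrderedDict)
            self._data.popitem(last=False)
        self._data[key] = (self._clock(), value)


lru = LRUTTLCache(max_size=2, ttl_seconds=60)
lru.set("a", 1)
lru.set("b", 2)
lru.get("a")     # "a" becomes the most recently used entry
lru.set("c", 3)  # evicts "b", the least recently used
print(lru.get("a"), lru.get("b"), lru.get("c"))
```

The keys here are plain strings; in this notebook they could be the hashes produced by `get_cache_key`.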

## 11. Evaluation Metrics

In [None]:
class GraphRAGEvaluator:
    """Evaluation metrics for GraphRAG system"""
    
    def __init__(self, pipeline: GraphRAGPipeline):
        self.pipeline = pipeline
        self.embeddings_model = OpenAIEmbeddings(model=config.embedding_model)
    
    def evaluate_retrieval_relevance(self, query: str, retrieved_chunks: List[str]) -> float:
        """Evaluate relevance of retrieved chunks as mean cosine similarity to the query"""
        if not retrieved_chunks:
            # Nothing retrieved: avoid embedding an empty list and averaging over nothing (NaN)
            return 0.0
        
        query_embedding = np.array(self.embeddings_model.embed_query(query))
        chunk_embeddings = [np.array(e) for e in self.embeddings_model.embed_documents(retrieved_chunks)]
        
        similarities = [
            np.dot(query_embedding, chunk_emb) / (np.linalg.norm(query_embedding) * np.linalg.norm(chunk_emb))
            for chunk_emb in chunk_embeddings
        ]
        
        return float(np.mean(similarities))
    
    def evaluate_graph_coverage(self, query_result: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate how well the graph was utilized"""
        total_entities = len(self.pipeline.knowledge_graph.entities)
        total_relationships = self.pipeline.knowledge_graph.graph.number_of_edges()
        
        used_entities = query_result['metadata']['num_entities_used']
        used_relationships = query_result['metadata']['num_relationships_used']
        
        return {
            "entity_coverage": used_entities / max(total_entities, 1),
            "relationship_coverage": used_relationships / max(total_relationships, 1),
            "graph_utilization_score": (used_entities + used_relationships) / max(total_entities + total_relationships, 1)
        }
    
    def evaluate_response_quality(self, query: str, response: str) -> Dict[str, float]:
        """Evaluate response quality metrics"""
        # Simple heuristics - in production, use more sophisticated metrics
        query_terms = query.lower().split()
        response_lower = response.lower()
        
        return {
            "response_length": len(response),
            "word_count": len(response.split()),
            "sentence_count": len([s for s in response.split('.') if s.strip()]),
            # Fraction of query terms found (as substrings) in the response; max() guards against an empty query
            "query_terms_in_response": sum(1 for term in query_terms if term in response_lower) / max(len(query_terms), 1)
        }
    
    def run_evaluation(self, test_queries: List[str]) -> pd.DataFrame:
        """Run comprehensive evaluation"""
        evaluation_results = []
        
        for query in test_queries:
            result = self.pipeline.query(query)
            retrieval_context = self.pipeline.retriever.retrieve(query)
            
            # Get retrieved chunks
            retrieved_chunks = [r['content'] for r in retrieval_context['vector_results']]
            
            # Calculate metrics
            relevance_score = self.evaluate_retrieval_relevance(query, retrieved_chunks)
            coverage_metrics = self.evaluate_graph_coverage(result)
            quality_metrics = self.evaluate_response_quality(query, result['response'])
            
            evaluation_results.append({
                # Append an ellipsis only when the query was actually truncated
                "query": query[:50] + ("..." if len(query) > 50 else ""),
                "retrieval_relevance": relevance_score,
                **coverage_metrics,
                **quality_metrics
            })
        
        return pd.DataFrame(evaluation_results)

# Run evaluation
evaluator = GraphRAGEvaluator(pipeline)
eval_df = evaluator.run_evaluation(queries)

print("\nEvaluation Results:")
print(eval_df.to_string())

print("\nAverage Metrics:")
numeric_columns = eval_df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
    print(f"  {col}: {eval_df[col].mean():.3f}")
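
The metrics above need no labeled data. If a small set of queries with known relevant chunks is available, rank-based retrieval metrics can be added. A minimal sketch follows, assuming each retrieved chunk has a string ID and that `relevant_ids` comes from manual labeling. The function names and the IDs are hypothetical and not part of `GraphRAGEvaluator`.

```python
from typing import List, Set


def precision_at_k(retrieved_ids: List[str], relevant_ids: Set[str], k: int) -> float:
    """Fraction of the top-k retrieved chunks that are labeled relevant."""
    if k <= 0:
        raise ValueError("k must be positive")
    hits = sum(1 for chunk_id in retrieved_ids[:k] if chunk_id in relevant_ids)
    return hits / k


def recall_at_k(retrieved_ids: List[str], relevant_ids: Set[str], k: int) -> float:
    """Fraction of all labeled-relevant chunks that appear in the top k."""
    if k <= 0:
        raise ValueError("k must be positive")
    if not relevant_ids:
        return 0.0
    hits = len(set(retrieved_ids[:k]) & relevant_ids)
    return hits / len(relevant_ids)


def reciprocal_rank(retrieved_ids: List[str], relevant_ids: Set[str]) -> float:
    """1 / rank of the first relevant chunk, or 0.0 if none was retrieved."""
    for rank, chunk_id in enumerate(retrieved_ids, start=1):
        if chunk_id in relevant_ids:
            return 1.0 / rank
    return 0.0


# Hypothetical ranking and labels for a single query
retrieved = ["c3", "c7", "c1", "c9"]
relevant = {"c1", "c7"}

print(f"precision@3: {precision_at_k(retrieved, relevant, k=3):.3f}")
print(f"recall@3: {recall_at_k(retrieved, relevant, k=3):.3f}")
print(f"reciprocal rank: {reciprocal_rank(retrieved, relevant):.3f}")
```

Averaging the reciprocal rank over all labeled queries gives the mean reciprocal rank (MRR), which could sit alongside `retrieval_relevance` as another column in the evaluation DataFrame.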

## 12. Production Deployment Considerations

In [None]:
# Production configuration template
production_config = {
    "deployment": {
        "environment": "production",
        "api_endpoint": "https://api.graphrag.example.com",
        "max_concurrent_requests": 100,
        "request_timeout": 30,
        "retry_policy": {
            "max_retries": 3,
            "backoff_factor": 2,
            "retry_on_status": [500, 502, 503, 504]
        }
    },
    
    "scaling": {
        "auto_scaling": True,
        "min_instances": 2,
        "max_instances": 10,
        "target_cpu_utilization": 70,
        "scale_up_threshold": 80,
        "scale_down_threshold": 30
    },
    
    "storage": {
        "vector_store": {
            "type": "chroma",
            "persistence_path": "/data/chroma",
            "collection_name": "graphrag_production"
        },
        "graph_store": {
            "type": "neo4j",
            "uri": "bolt://neo4j:7687",
            "database": "graphrag"
        },
        "cache": {
            "type": "redis",
            "host": "redis.example.com",
            "port": 6379,
            "ttl": 3600
        }
    },
    
    "monitoring": {
        "metrics_backend": "prometheus",
        "logging_level": "INFO",
        "log_aggregation": "elasticsearch",
        "alerting": {
            "enabled": True,
            "channels": ["email", "slack"],
            "thresholds": {
                "error_rate": 0.05,
                "response_time_p95": 5000,
                "cpu_utilization": 85
            }
        }
    },
    
    "security": {
        "authentication": "oauth2",
        "rate_limiting": {
            "enabled": True,
            "requests_per_minute": 60,
            "requests_per_hour": 1000
        },
        "encryption": {
            "at_rest": True,
            "in_transit": True,
            "algorithm": "AES-256-GCM"
        }
    },
    
    "optimization": {
        "batch_processing": True,
        "batch_size": 32,
        "async_processing": True,
        "connection_pooling": True,
        "query_optimization": {
            "cache_embeddings": True,
            "precompute_graph_traversals": True,
            "index_hot_paths": True
        }
    }
}

print("Production Configuration Template:")
print(json.dumps(production_config, indent=2))
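
The `retry_policy` block in the template only declares values; it does not say how delays are computed. One common interpretation is exponential backoff, where the delay before the n-th retry (counting from zero) is `backoff_factor ** n` seconds. The sketch below assumes that interpretation. `call_with_retries`, `RetryableError`, and the injectable `sleep` are hypothetical names, and a real client would need to map the status codes in `retry_on_status` to the retryable exception.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


class RetryableError(Exception):
    """Hypothetical marker for failures worth retrying (e.g. HTTP 502/503)."""


def call_with_retries(fn: Callable[[], T], max_retries: int = 3, backoff_factor: float = 2,
                      sleep: Callable[[float], None] = time.sleep) -> T:
    """Call fn, retrying on RetryableError with exponentially growing delays."""
    attempt = 0
    while True:
        try:
            return fn()
        except RetryableError:
            if attempt >= max_retries:
                raise  # retries exhausted: surface the last error
            # Assumed schedule: 1s, 2s, 4s, ... for backoff_factor=2
            sleep(backoff_factor ** attempt)
            attempt += 1


# Demo with a fake sleep that records delays, so the example runs instantly
delays: List[float] = []
calls = {"count": 0}


def flaky() -> str:
    """Fails twice, then succeeds."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise RetryableError("temporary failure")
    return "ok"


result = call_with_retries(flaky, max_retries=3, backoff_factor=2, sleep=delays.append)
print(result, delays)
```

Passing `sleep` as a parameter keeps the retry logic testable; production code would typically also add random jitter so that many clients do not retry at the same instant.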

## Summary

This notebook demonstrates a complete GraphRAG implementation with:

### Core Features
- **Document Processing**: Intelligent chunking with token counting
- **Entity Extraction**: LLM-based entity and relationship extraction
- **Knowledge Graph**: NetworkX-based graph construction and management
- **Hybrid Retrieval**: Combined vector search and graph traversal
- **Augmented Generation**: Context-aware response generation

### Production Features
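- **Monitoring**: In-memory tracking of response time, entities and chunks used per query, and error rate
- **Caching**: Query result caching with FIFO eviction and hit-rate statistics
- **Evaluation**: Retrieval relevance, graph coverage, and heuristic response-quality metrics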
- **Error Handling**: Robust error handling throughout
- **Configuration**: Flexible configuration management

### Next Steps
1. **Scale Testing**: Test with larger document collections
2. **Graph Database**: Migrate to Neo4j for production scale
3. **Advanced Extraction**: Implement more sophisticated entity extraction
4. **Query Understanding**: Add query intent classification
5. **Multi-hop Reasoning**: Implement complex graph traversal strategies
6. **API Development**: Build REST/GraphQL API layer
7. **UI Development**: Create visualization and query interface

The system is designed to be modular and extensible, allowing for easy customization and enhancement based on specific use cases.