# Module 5: Retrievers for Graph-Enhanced RAG

This notebook implements retrieval patterns for Retrieval-Augmented Generation (RAG) on a graph database, combining vector-based semantic search with graph traversal to retrieve richer context.

## Learning Objectives
- Implement vector-based semantic retrieval in Neo4j
- Design graph traversal retrieval patterns
- Combine vector and graph retrieval strategies
- Optimize retrieval performance and relevance
- Build hybrid retrievers for complex queries

## Prerequisites
- Completion of Module 4: Graph Analytics
- Understanding of vector embeddings and RAG concepts

## Setup and Dependencies

In [None]:
# Install required packages
!pip install neo4j pandas numpy openai sentence-transformers langchain tiktoken faiss-cpu scikit-learn

In [None]:
# Import libraries
import os
import pandas as pd
import numpy as np
import json
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
import openai
from typing import List, Dict, Tuple, Optional
import tiktoken
from datetime import datetime
import re
from dataclasses import dataclass
import time
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore')

# Initialize models
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = tiktoken.get_encoding("cl100k_base")

print("Libraries and models loaded successfully!")

## Neo4j Connection and Data Setup

In [None]:
# Neo4j connection settings
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'password')

# Create Neo4j driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

def run_query(query, parameters=None):
    """Execute a Cypher query and return results"""
    with driver.session() as session:
        result = session.run(query, parameters or {})
        return [record.data() for record in result]

# Test connection
print("Testing Neo4j connection...")
result = run_query("RETURN 'Connected to Neo4j!' as message")
print(result[0]['message'])

In [None]:
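# Clear existing data and create sample knowledge base
# WARNING: this deletes every node and relationship in the connected database;
# run it only against a disposable instance.
run_query("MATCH (n) DETACH DELETE n")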

# Sample knowledge base: AI/ML research papers and concepts
knowledge_base = [
    {
        "id": "paper1",
        "title": "Attention Is All You Need",
        "authors": ["Vaswani", "Shazeer", "Parmar", "Uszkoreit"],
        "year": 2017,
        "venue": "NeurIPS",
        "abstract": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.",
        "concepts": ["Transformer", "Attention Mechanism", "Self-Attention", "Multi-Head Attention", "Neural Machine Translation"],
        "citations": 50000,
        "field": "Natural Language Processing"
    },
    {
        "id": "paper2",
        "title": "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
        "authors": ["Devlin", "Chang", "Lee", "Toutanova"],
        "year": 2018,
        "venue": "NAACL",
        "abstract": "We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models, BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers.",
        "concepts": ["BERT", "Bidirectional Transformer", "Pre-training", "Masked Language Modeling", "Fine-tuning"],
        "citations": 75000,
        "field": "Natural Language Processing"
    },
    {
        "id": "paper3",
        "title": "Language Models are Few-Shot Learners",
        "authors": ["Brown", "Mann", "Ryder", "Subbiah"],
        "year": 2020,
        "venue": "NeurIPS",
        "abstract": "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions.",
        "concepts": ["GPT-3", "Few-Shot Learning", "In-Context Learning", "Large Language Models", "Emergent Abilities"],
        "citations": 25000,
        "field": "Natural Language Processing"
    },
    {
        "id": "paper4",
        "title": "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks",
        "authors": ["Lewis", "Perez", "Piktus", "Petroni"],
        "year": 2020,
        "venue": "NeurIPS",
        "abstract": "Large pre-trained language models have been shown to store factual knowledge in their parameters, and achieve state-of-the-art results when fine-tuned on downstream NLP tasks. However, their ability to access and precisely manipulate knowledge is still limited, and hence on knowledge-intensive tasks, their performance lags behind task-specific architectures. Additionally, providing provenance for their decisions is challenging.",
        "concepts": ["RAG", "Retrieval-Augmented Generation", "Knowledge-Intensive Tasks", "Dense Passage Retrieval", "Hybrid Models"],
        "citations": 8000,
        "field": "Natural Language Processing"
    },
    {
        "id": "paper5",
        "title": "Graph Neural Networks: A Review of Methods and Applications",
        "authors": ["Zhou", "Cui", "Hu", "Zhang"],
        "year": 2020,
        "venue": "AI Open",
        "abstract": "Lots of learning tasks require dealing with graph data which contains rich relation information among elements. Modeling physics systems, learning molecular fingerprints, predicting protein interface, and classifying diseases require a model to learn from graph inputs. In other domains such as learning from non-Euclidean data, graph neural networks (GNNs) have been demonstrated to be effective.",
        "concepts": ["Graph Neural Networks", "GNN", "Graph Convolution", "Message Passing", "Graph Attention"],
        "citations": 15000,
        "field": "Machine Learning"
    }
]

print(f"Loaded {len(knowledge_base)} research papers for the knowledge base")

## Lesson 1: Vector Similarity Search Implementation

In [None]:
# Create knowledge graph (embeddings are added in a later cell)
def create_knowledge_graph(papers):
    """Create Paper, Author, Concept, and Field nodes and their relationships"""
    
    for paper in papers:
        # Create paper node
        paper_query = """
        CREATE (p:Paper {
            id: $id,
            title: $title,
            year: $year,
            venue: $venue,
            abstract: $abstract,
            citations: $citations,
            field: $field
        })
        """
        
        run_query(paper_query, {
            'id': paper['id'],
            'title': paper['title'],
            'year': paper['year'],
            'venue': paper['venue'],
            'abstract': paper['abstract'],
            'citations': paper['citations'],
            'field': paper['field']
        })
        
        # Create author nodes and relationships
        for author in paper['authors']:
            author_query = """
            MERGE (a:Author {name: $author_name})
            WITH a
            MATCH (p:Paper {id: $paper_id})
            CREATE (a)-[:AUTHORED]->(p)
            """
            
            run_query(author_query, {
                'author_name': author,
                'paper_id': paper['id']
            })
        
        # Create concept nodes and relationships
        for concept in paper['concepts']:
            concept_query = """
            MERGE (c:Concept {name: $concept_name})
            WITH c
            MATCH (p:Paper {id: $paper_id})
            CREATE (p)-[:DISCUSSES]->(c)
            """
            
            run_query(concept_query, {
                'concept_name': concept,
                'paper_id': paper['id']
            })
        
        # Create field node and relationship
        field_query = """
        MERGE (f:Field {name: $field_name})
        WITH f
        MATCH (p:Paper {id: $paper_id})
        CREATE (p)-[:BELONGS_TO]->(f)
        """
        
        run_query(field_query, {
            'field_name': paper['field'],
            'paper_id': paper['id']
        })

# Create the knowledge graph
create_knowledge_graph(knowledge_base)
print("Knowledge graph created with papers, authors, concepts, and fields")
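
The graph shape that `create_knowledge_graph` produces can be previewed without a database by flattening a record into (source, relationship, target) triples. This is a minimal sketch: `paper_to_triples` and `toy_paper` are hypothetical names introduced here, and the toy record only mirrors the keys used in `knowledge_base`.

```python
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]


def paper_to_triples(paper: Dict) -> List[Triple]:
    """Flatten one paper record into (source, relationship, target) triples."""
    triples: List[Triple] = []
    # Author -> Paper edges, same direction as the AUTHORED relationship
    for author in paper["authors"]:
        triples.append((author, "AUTHORED", paper["id"]))
    # Paper -> Concept edges
    for concept in paper["concepts"]:
        triples.append((paper["id"], "DISCUSSES", concept))
    # Paper -> Field edge (one per paper)
    triples.append((paper["id"], "BELONGS_TO", paper["field"]))
    return triples


# Hypothetical toy record with the same keys as the knowledge_base entries
toy_paper = {
    "id": "toy1",
    "authors": ["Ada", "Grace"],
    "concepts": ["Transformer"],
    "field": "Natural Language Processing",
}

toy_triples = paper_to_triples(toy_paper)
for triple in toy_triples:
    print(triple)
```

Each paper contributes one triple per author, one per concept, and one for its field, so shared authors or concepts become the hubs that later traversal queries walk through.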

In [None]:
# Generate and store embeddings for papers
def add_embeddings_to_papers():
    """Generate embeddings for paper titles and abstracts"""
    
    # Get all papers
    papers = run_query("MATCH (p:Paper) RETURN p.id as id, p.title as title, p.abstract as abstract")
    
    for paper in papers:
        # Combine title and abstract for embedding
        text_content = f"{paper['title']} {paper['abstract']}"
        
        # Generate embedding
        embedding = embedding_model.encode([text_content])[0]
        
        # Store embedding in Neo4j
        embedding_query = """
        MATCH (p:Paper {id: $paper_id})
        SET p.embedding = $embedding
        """
        
        run_query(embedding_query, {
            'paper_id': paper['id'],
            'embedding': embedding.tolist()
        })
        
        print(f"Generated embedding for: {paper['title'][:50]}...")

add_embeddings_to_papers()
print("\nEmbeddings added to all papers in the knowledge graph")

In [None]:
# Create vector search index
def create_vector_index():
    """Create vector search index for papers"""
    
    index_query = """
    CREATE VECTOR INDEX paper_embeddings IF NOT EXISTS
    FOR (p:Paper) ON (p.embedding)
    OPTIONS {
      indexConfig: {
        `vector.dimensions`: 384,
        `vector.similarity_function`: 'cosine'
      }
    }
    """
    
    try:
        run_query(index_query)
        print("Vector index created successfully")
        
        # Block until the index is online (up to 30 seconds) instead of sleeping for a fixed time
        run_query("CALL db.awaitIndexes(30)")
        
        # Check index status
        status = run_query("SHOW INDEXES YIELD name, state WHERE name = 'paper_embeddings'")
        if status:
            print(f"Index status: {status[0]['state']}")
        
    except Exception as e:
        print(f"Vector index creation failed: {e}")
        print("Proceeding with fallback similarity search")

create_vector_index()
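
The index above declares 384 dimensions, which matches the output size of `all-MiniLM-L6-v2`. If the embedding model is swapped, stored vectors and the index definition can drift apart. The following is a small sketch of a pre-write check; `check_embedding` and `EXPECTED_DIM` are hypothetical helpers, not part of the notebook.

```python
from typing import Sequence

EXPECTED_DIM = 384  # assumed to match `vector.dimensions` in the index definition


def check_embedding(vector: Sequence[float], expected_dim: int = EXPECTED_DIM) -> None:
    """Raise ValueError if the vector does not match the declared index shape."""
    if len(vector) != expected_dim:
        raise ValueError(f"expected {expected_dim} dimensions, got {len(vector)}")
    # embedding.tolist() yields plain Python floats; anything else is suspicious
    if not all(isinstance(x, float) for x in vector):
        raise ValueError("embedding must be a list of Python floats")


check_embedding([0.0] * 384)  # passes silently

try:
    check_embedding([0.0] * 768)  # e.g. a larger model's output
    mismatch_message = ""
except ValueError as err:
    mismatch_message = str(err)

print(mismatch_message)
```

Calling a check like this before `SET p.embedding = $embedding` surfaces a model or index mismatch at write time rather than as empty search results later.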

In [None]:
# Implement vector similarity search
@dataclass
class RetrievalResult:
    """Container for retrieval results"""
    paper_id: str
    title: str
    abstract: str
    score: float
    retrieval_method: str
    additional_info: Optional[Dict] = None

def vector_similarity_search(query: str, top_k: int = 3) -> List[RetrievalResult]:
    """Perform vector similarity search"""
    
    # Generate query embedding
    query_embedding = embedding_model.encode([query])[0]
    
    try:
        # Try using vector index first
        search_query = """
        CALL db.index.vector.queryNodes('paper_embeddings', $top_k, $query_embedding)
        YIELD node AS paper, score
        RETURN paper.id as paper_id,
               paper.title as title,
               paper.abstract as abstract,
               score
        ORDER BY score DESC
        """
        
        results = run_query(search_query, {
            'query_embedding': query_embedding.tolist(),
            'top_k': top_k
        })
        
    except Exception as e:
        print(f"Vector index search failed: {e}")
        print("Falling back to manual similarity calculation")
        
        # Fallback: manual similarity calculation
        papers = run_query("MATCH (p:Paper) RETURN p.id as paper_id, p.title as title, p.abstract as abstract, p.embedding as embedding")
        
        similarities = []
        for paper in papers:
            if paper['embedding']:
                paper_embedding = np.array(paper['embedding'])
                similarity = cosine_similarity([query_embedding], [paper_embedding])[0][0]
                similarities.append({
                    'paper_id': paper['paper_id'],
                    'title': paper['title'],
                    'abstract': paper['abstract'],
                    'score': float(similarity)
                })
        
        # Sort by similarity and take top k
        results = sorted(similarities, key=lambda x: x['score'], reverse=True)[:top_k]
    
    # Convert to RetrievalResult objects
    return [
        RetrievalResult(
            paper_id=r['paper_id'],
            title=r['title'],
            abstract=r['abstract'],
            score=r['score'],
            retrieval_method='vector_similarity'
        )
        for r in results
    ]

# Test vector similarity search
test_query = "transformer models for language understanding"
vector_results = vector_similarity_search(test_query, top_k=3)

print(f"Vector Similarity Search Results for: '{test_query}'\n")
for i, result in enumerate(vector_results, 1):
    print(f"{i}. {result.title} (Score: {result.score:.4f})")
    print(f"   Abstract: {result.abstract[:100]}...\n")
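
The manual fallback in `vector_similarity_search` ranks papers by cosine similarity. A dependency-free sketch of that ranking follows; `cosine`, `top_k_by_cosine`, `to_unit_interval`, and the toy vectors are illustrative names, not part of the notebook. The rescaling helper assumes that the vector index reports cosine scores mapped into [0, 1] as (1 + cos) / 2, which is how the Neo4j documentation describes it; verify this against your server version before comparing index scores with fallback scores.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_cosine(query_vec: Sequence[float],
                    doc_vecs: Dict[str, Sequence[float]],
                    k: int) -> List[Tuple[str, float]]:
    """Score every document against the query and keep the k best."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


def to_unit_interval(cos_value: float) -> float:
    """Assumed index-style rescaling of a raw cosine from [-1, 1] to [0, 1]."""
    return (1.0 + cos_value) / 2.0


# Toy 2-dimensional "embeddings" standing in for the 384-dimensional ones
toy_vectors = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
ranked = top_k_by_cosine([1.0, 0.1], toy_vectors, k=2)

for doc_id, score in ranked:
    print(doc_id, round(score, 4), round(to_unit_interval(score), 4))
```

If the two code paths do report scores on different scales, applying one consistent rescaling before fusing them with graph scores keeps the weights meaningful.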

## Lesson 2: Graph Traversal Retrieval Patterns

In [None]:
# Implement graph traversal retrieval
def graph_traversal_search(query: str, max_depth: int = 2) -> List[RetrievalResult]:
    """Retrieve papers using fixed graph traversal patterns (max_depth is currently unused)"""
    
    # Extract key terms from query
    query_terms = [term.strip().lower() for term in query.split() if len(term.strip()) > 3]
    
    results = []
    
    # Pattern 1: Find papers by concept similarity
    for term in query_terms:
        # Pass the term as a query parameter instead of interpolating it into the Cypher text
        concept_query = """
        MATCH (c:Concept)
        WHERE toLower(c.name) CONTAINS $term
        MATCH (p:Paper)-[:DISCUSSES]->(c)
        RETURN DISTINCT p.id as paper_id,
               p.title as title,
               p.abstract as abstract,
               c.name as matched_concept,
               'concept_match' as match_type
        """
        
        concept_results = run_query(concept_query, {'term': term})
        
        for result in concept_results:
            results.append(RetrievalResult(
                paper_id=result['paper_id'],
                title=result['title'],
                abstract=result['abstract'],
                score=0.8,  # High score for direct concept match
                retrieval_method='graph_traversal',
                additional_info={'matched_concept': result['matched_concept'], 'match_type': result['match_type']}
            ))
    
    # Pattern 2: Find papers by author collaboration
    # Only the first extracted term is used; it is passed as a query parameter
    author_query = """
    MATCH (a:Author)-[:AUTHORED]->(p1:Paper)
    WHERE toLower(p1.title) CONTAINS $term OR toLower(p1.abstract) CONTAINS $term
    MATCH (a)-[:AUTHORED]->(p2:Paper)
    WHERE p1 <> p2
    RETURN DISTINCT p2.id as paper_id,
           p2.title as title,
           p2.abstract as abstract,
           a.name as connecting_author,
           'author_connection' as match_type
    LIMIT 3
    """
    
    # Guard against queries in which no term passed the length filter
    author_results = run_query(author_query, {'term': query_terms[0]}) if query_terms else []
    
    for result in author_results:
        results.append(RetrievalResult(
            paper_id=result['paper_id'],
            title=result['title'],
            abstract=result['abstract'],
            score=0.6,  # Medium score for author connection
            retrieval_method='graph_traversal',
            additional_info={'connecting_author': result['connecting_author'], 'match_type': result['match_type']}
        ))
    
    # Pattern 3: Find papers in the same field with shared concepts
    # Note: 'shared_concepts' lists up to three of p2's own concepts; the query
    # does not require them to overlap with p1's concepts.
    field_query = """
    MATCH (p1:Paper)-[:DISCUSSES]->(c:Concept)
    WHERE toLower(p1.title) CONTAINS $term OR toLower(c.name) CONTAINS $term
    MATCH (p1)-[:BELONGS_TO]->(f:Field)<-[:BELONGS_TO]-(p2:Paper)
    MATCH (p2)-[:DISCUSSES]->(c2:Concept)
    WHERE p1 <> p2
    RETURN DISTINCT p2.id as paper_id,
           p2.title as title,
           p2.abstract as abstract,
           f.name as field,
           collect(DISTINCT c2.name)[0..3] as shared_concepts,
           'field_connection' as match_type
    LIMIT 3
    """
    
    # Guard against queries in which no term passed the length filter
    field_results = run_query(field_query, {'term': query_terms[0]}) if query_terms else []
    
    for result in field_results:
        results.append(RetrievalResult(
            paper_id=result['paper_id'],
            title=result['title'],
            abstract=result['abstract'],
            score=0.7,  # Good score for field connection
            retrieval_method='graph_traversal',
            additional_info={'field': result['field'], 'shared_concepts': result['shared_concepts'], 'match_type': result['match_type']}
        ))
    
    # Remove duplicates and sort by score
    unique_results = {}
    for result in results:
        if result.paper_id not in unique_results or result.score > unique_results[result.paper_id].score:
            unique_results[result.paper_id] = result
    
    return sorted(unique_results.values(), key=lambda x: x.score, reverse=True)

# Test graph traversal search
graph_results = graph_traversal_search("attention transformer neural networks", max_depth=2)

print(f"Graph Traversal Search Results for: 'attention transformer neural networks'\n")
for i, result in enumerate(graph_results[:5], 1):
    print(f"{i}. {result.title} (Score: {result.score:.2f})")
    if result.additional_info:
        info = result.additional_info
        if 'matched_concept' in info:
            print(f"   Matched Concept: {info['matched_concept']}")
        if 'connecting_author' in info:
            print(f"   Connected via Author: {info['connecting_author']}")
        if 'field' in info:
            print(f"   Field: {info['field']}, Concepts: {', '.join(info['shared_concepts'])}")
    print(f"   Abstract: {result.abstract[:80]}...\n")
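
The three traversal patterns can return the same paper more than once with different fixed scores (0.8, 0.7, 0.6). The final step of `graph_traversal_search` keeps only the best-scoring hit per paper. A standalone sketch of that step, using plain tuples instead of `RetrievalResult` (`keep_best_per_id` and `toy_hits` are hypothetical names):

```python
from typing import Dict, List, Tuple

Hit = Tuple[str, float, str]  # (paper_id, score, match_type)


def keep_best_per_id(hits: List[Hit]) -> List[Hit]:
    """Keep the highest-scoring hit for each paper_id, sorted by score descending."""
    best: Dict[str, Hit] = {}
    for hit in hits:
        paper_id, score, _ = hit
        # Replace only when this hit strictly beats the stored one
        if paper_id not in best or score > best[paper_id][1]:
            best[paper_id] = hit
    return sorted(best.values(), key=lambda h: h[1], reverse=True)


toy_hits = [
    ("paper1", 0.8, "concept_match"),
    ("paper2", 0.6, "author_connection"),
    ("paper1", 0.7, "field_connection"),
    ("paper2", 0.7, "field_connection"),
]

deduped = keep_best_per_id(toy_hits)
for hit in deduped:
    print(hit)
```

One consequence of this design is that evidence from several patterns is not accumulated: a paper found by all three patterns scores the same as one found only by a concept match.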

## Lesson 3: Hybrid Retrieval Strategies

In [None]:
# Implement hybrid retrieval combining vector and graph methods
def hybrid_retrieval(query: str, top_k: int = 5, weights: Dict[str, float] = None) -> List[RetrievalResult]:
    """Combine vector similarity and graph traversal for optimal retrieval"""
    
    if weights is None:
        weights = {'vector': 0.6, 'graph': 0.4}
    
    # Get results from both methods
    vector_results = vector_similarity_search(query, top_k=top_k*2)
    graph_results = graph_traversal_search(query, max_depth=2)
    
    # Combine and re-score results
    combined_results = {}
    
    # Add vector results with weighted scores
    for result in vector_results:
        combined_results[result.paper_id] = RetrievalResult(
            paper_id=result.paper_id,
            title=result.title,
            abstract=result.abstract,
            score=result.score * weights['vector'],
            retrieval_method='hybrid',
            additional_info={'vector_score': result.score, 'graph_score': 0.0, 'methods': ['vector']}
        )
    
    # Add or update with graph results
    for result in graph_results:
        if result.paper_id in combined_results:
            # Update existing result
            existing = combined_results[result.paper_id]
            existing.score += result.score * weights['graph']
            existing.additional_info['graph_score'] = result.score
            existing.additional_info['methods'].append('graph')
            if result.additional_info:
                existing.additional_info.update(result.additional_info)
        else:
            # Add new result
            combined_results[result.paper_id] = RetrievalResult(
                paper_id=result.paper_id,
                title=result.title,
                abstract=result.abstract,
                score=result.score * weights['graph'],
                retrieval_method='hybrid',
                additional_info={
                    'vector_score': 0.0, 
                    'graph_score': result.score, 
                    'methods': ['graph'],
                    **(result.additional_info or {})
                }
            )
    
    # Sort by combined score and return top k
    return sorted(combined_results.values(), key=lambda x: x.score, reverse=True)[:top_k]

# Test hybrid retrieval
hybrid_results = hybrid_retrieval("pre-training language models with attention", top_k=4)

print(f"Hybrid Retrieval Results for: 'pre-training language models with attention'\n")
for i, result in enumerate(hybrid_results, 1):
    print(f"{i}. {result.title}")
    print(f"   Combined Score: {result.score:.4f}")
    info = result.additional_info
    print(f"   Vector Score: {info['vector_score']:.4f}, Graph Score: {info['graph_score']:.4f}")
    print(f"   Methods Used: {', '.join(info['methods'])}")
    if 'matched_concept' in info:
        print(f"   Matched Concept: {info['matched_concept']}")
    print(f"   Abstract: {result.abstract[:100]}...\n")
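
The combined score in `hybrid_retrieval` is a weighted sum of the two component scores, with a missing component treated as 0.0. The arithmetic can be checked in isolation; `weighted_score` and the sample scores below are illustrative, and the weights are the notebook's defaults.

```python
from typing import Dict


def weighted_score(vector_score: float, graph_score: float,
                   weights: Dict[str, float]) -> float:
    """Linear fusion: each component is scaled by its weight, then summed."""
    return vector_score * weights["vector"] + graph_score * weights["graph"]


default_weights = {"vector": 0.6, "graph": 0.4}

# Paper found by both methods: 0.75 * 0.6 + 0.8 * 0.4
found_by_both = weighted_score(0.75, 0.8, default_weights)
# Paper found by vector search only: the graph term contributes nothing
vector_only = weighted_score(0.75, 0.0, default_weights)
# Paper found by graph traversal only
graph_only = weighted_score(0.0, 0.8, default_weights)

print(f"both={found_by_both:.2f} vector_only={vector_only:.2f} graph_only={graph_only:.2f}")
```

Because absent components count as zero, papers found by both methods are systematically favored. That is usually the intent, but it also means the weights only behave as expected when both score types live on comparable scales.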

In [None]:
# Advanced hybrid retrieval with query expansion
def advanced_hybrid_retrieval(query: str, top_k: int = 5) -> List[RetrievalResult]:
    """Advanced hybrid retrieval with query expansion and contextual scoring"""
    
    # Step 1: Query expansion using graph structure
    expanded_terms = expand_query_with_graph(query)
    print(f"Query expanded to include: {', '.join(expanded_terms)}")
    
    # Step 2: Multi-hop graph traversal for contextual retrieval
    contextual_results = contextual_graph_retrieval(query, expanded_terms)
    
    # Step 3: Vector similarity with expanded query
    expanded_query = f"{query} {' '.join(expanded_terms)}"
    vector_results = vector_similarity_search(expanded_query, top_k=top_k*2)
    
    # Step 4: Intelligent score fusion
    return fuse_retrieval_results(vector_results, contextual_results, top_k)

def expand_query_with_graph(query: str) -> List[str]:
    """Expand query using related concepts from the knowledge graph"""
    
    # Find concepts related to query terms
    query_terms = [term.strip().lower() for term in query.split() if len(term.strip()) > 3]
    expanded_terms = set()
    
    for term in query_terms:
        # Pass the term as a query parameter instead of interpolating it into the Cypher text
        expansion_query = """
        MATCH (c1:Concept)
        WHERE toLower(c1.name) CONTAINS $term
        MATCH (p:Paper)-[:DISCUSSES]->(c1)
        MATCH (p)-[:DISCUSSES]->(c2:Concept)
        WHERE c1 <> c2
        RETURN DISTINCT c2.name as related_concept
        LIMIT 3
        """
        
        results = run_query(expansion_query, {'term': term})
        for result in results:
            expanded_terms.add(result['related_concept'].lower())
    
    return sorted(expanded_terms)[:5]  # Sort for reproducible output, then limit expansion

def contextual_graph_retrieval(query: str, expanded_terms: List[str]) -> List[RetrievalResult]:
    """Perform contextual retrieval using multi-hop graph patterns"""
    
    # Multi-hop pattern: Paper -> Concept -> Paper -> Author -> Paper
    contextual_query = """
    MATCH (p1:Paper)-[:DISCUSSES]->(c:Concept)<-[:DISCUSSES]-(p2:Paper)
    MATCH (p2)<-[:AUTHORED]-(a:Author)-[:AUTHORED]->(p3:Paper)
    WHERE p1 <> p2 AND p2 <> p3 AND p1 <> p3
      AND (toLower(p1.title) CONTAINS $query_term OR toLower(p1.abstract) CONTAINS $query_term)
    WITH p3, c, a, 
         p3.citations as citations,
         p3.year as year,
         count(*) as connection_strength
    RETURN DISTINCT p3.id as paper_id,
           p3.title as title,
           p3.abstract as abstract,
           connection_strength,
           citations,
           year,
           c.name as connecting_concept,
           a.name as connecting_author
    ORDER BY connection_strength DESC, citations DESC
    LIMIT 5
    """
    
    results = []
    query_terms = [term.strip().lower() for term in query.split() if len(term.strip()) > 3]
    
    for term in query_terms[:2]:  # Use first 2 terms
        contextual_results = run_query(contextual_query, {'query_term': term})
        
        for result in contextual_results:
            # Calculate contextual score based on connection strength and citations
            base_score = min(result['connection_strength'] * 0.2, 1.0)
            citation_bonus = min(result['citations'] / 10000, 0.3)
            recency_bonus = max(0, (result['year'] - 2015) * 0.05)
            
            contextual_score = base_score + citation_bonus + recency_bonus
            
            results.append(RetrievalResult(
                paper_id=result['paper_id'],
                title=result['title'],
                abstract=result['abstract'],
                score=contextual_score,
                retrieval_method='contextual_graph',
                additional_info={
                    'connection_strength': result['connection_strength'],
                    'connecting_concept': result['connecting_concept'],
                    'connecting_author': result['connecting_author'],
                    'citations': result['citations']
                }
            ))
    
    return results

def fuse_retrieval_results(vector_results: List[RetrievalResult], 
                          contextual_results: List[RetrievalResult], 
                          top_k: int) -> List[RetrievalResult]:
    """Intelligently fuse results from different retrieval methods"""
    
    fused_results = {}
    
    # Add vector results
    for result in vector_results:
        fused_results[result.paper_id] = RetrievalResult(
            paper_id=result.paper_id,
            title=result.title,
            abstract=result.abstract,
            score=result.score * 0.7,  # Weight vector similarity
            retrieval_method='advanced_hybrid',
            additional_info={
                'vector_score': result.score,
                'contextual_score': 0.0,
                'methods': ['vector']
            }
        )
    
    # Add contextual results
    for result in contextual_results:
        if result.paper_id in fused_results:
            # Boost existing result
            existing = fused_results[result.paper_id]
            existing.score += result.score * 0.5  # Weight contextual score
            existing.additional_info['contextual_score'] = result.score
            existing.additional_info['methods'].append('contextual')
            existing.additional_info.update(result.additional_info)
        else:
            # Add new result
            fused_results[result.paper_id] = RetrievalResult(
                paper_id=result.paper_id,
                title=result.title,
                abstract=result.abstract,
                score=result.score * 0.5,
                retrieval_method='advanced_hybrid',
                additional_info={
                    'vector_score': 0.0,
                    'contextual_score': result.score,
                    'methods': ['contextual'],
                    **result.additional_info
                }
            )
    
    return sorted(fused_results.values(), key=lambda x: x.score, reverse=True)[:top_k]

# Test advanced hybrid retrieval
print("Testing Advanced Hybrid Retrieval...\n")
advanced_results = advanced_hybrid_retrieval("bidirectional encoder representations", top_k=3)

print(f"\nAdvanced Hybrid Retrieval Results for: 'bidirectional encoder representations'\n")
for i, result in enumerate(advanced_results, 1):
    print(f"{i}. {result.title}")
    print(f"   Final Score: {result.score:.4f}")
    info = result.additional_info
    print(f"   Component Scores - Vector: {info['vector_score']:.4f}, Contextual: {info['contextual_score']:.4f}")
    print(f"   Methods: {', '.join(info['methods'])}")
    if 'connecting_concept' in info:
        print(f"   Connected via: {info['connecting_concept']} (Author: {info['connecting_author']})")
    print(f"   Abstract: {result.abstract[:100]}...\n")
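
The contextual score in `contextual_graph_retrieval` is the sum of three capped or floored terms. The sketch below reproduces that formula as a pure function so the arithmetic can be inspected; `contextual_score` is a hypothetical helper and the input values are made up for illustration.

```python
def contextual_score(connection_strength: int, citations: int, year: int) -> float:
    """Same formula as the notebook: base + citation bonus + recency bonus."""
    base_score = min(connection_strength * 0.2, 1.0)      # capped at 1.0
    citation_bonus = min(citations / 10000, 0.3)          # capped at 0.3 (3,000+ citations)
    recency_bonus = max(0, (year - 2015) * 0.05)          # 0.05 per year after 2015, uncapped
    return base_score + citation_bonus + recency_bonus


# Two connecting paths, heavily cited, published 2018: 0.4 + 0.3 + 0.15
moderate = contextual_score(connection_strength=2, citations=75000, year=2018)

# Many connecting paths, published 2020: 1.0 + 0.3 + 0.25
saturated = contextual_score(connection_strength=10, citations=50000, year=2020)

print(f"moderate={moderate:.2f} saturated={saturated:.2f}")
```

Note that the total is not bounded by 1.0, whereas cosine-based vector scores are. When the two are fused with fixed weights, a strongly connected recent paper can outweigh any vector match, so clamping or normalizing the contextual score may be worth considering.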

## Lesson 4: Retrieval Optimization and Evaluation

In [None]:
# Retrieval evaluation and optimization
from typing import Any

def evaluate_retrieval_methods(test_queries: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Evaluate different retrieval methods on test queries"""
    
    methods = {
        'vector_only': lambda q: vector_similarity_search(q, top_k=3),
        'graph_only': lambda q: graph_traversal_search(q, max_depth=2)[:3],
        'hybrid': lambda q: hybrid_retrieval(q, top_k=3),
        'advanced_hybrid': lambda q: advanced_hybrid_retrieval(q, top_k=3)
    }
    
    results = {}
    
    for method_name, method_func in methods.items():
        print(f"Evaluating {method_name}...")
        
        method_results = {
            'precision_at_1': 0.0,
            'precision_at_3': 0.0,
            'avg_score': 0.0,
            'coverage': 0.0
        }
        
        for query_data in test_queries:
            query = query_data['query']
            relevant_papers = set(query_data['relevant_papers'])
            
            try:
                retrieved = method_func(query)
                retrieved_papers = {r.paper_id for r in retrieved}
                
                # Calculate metrics
                if retrieved:
                    # Precision@1
                    if retrieved[0].paper_id in relevant_papers:
                        method_results['precision_at_1'] += 1.0
                    
                    # Precision@3
                    relevant_in_top3 = len(retrieved_papers.intersection(relevant_papers))
                    method_results['precision_at_3'] += relevant_in_top3 / min(3, len(retrieved))
                    
                    # Average score
                    method_results['avg_score'] += sum(r.score for r in retrieved) / len(retrieved)
                    
                    # Coverage
                    method_results['coverage'] += len(retrieved_papers.intersection(relevant_papers)) / len(relevant_papers)
                    
            except Exception as e:
                print(f"Error with {method_name} on query '{query}': {e}")
        
        # Average the metrics
        num_queries = len(test_queries)
        for metric in method_results:
            method_results[metric] /= num_queries
        
        results[method_name] = method_results
    
    return results

# Define test queries with relevant papers
test_queries = [
    {
        'query': 'transformer attention mechanism',
        'relevant_papers': ['paper1', 'paper2']  # Attention paper and BERT
    },
    {
        'query': 'large language models few shot learning',
        'relevant_papers': ['paper3']  # GPT-3 paper
    },
    {
        'query': 'retrieval augmented generation knowledge',
        'relevant_papers': ['paper4']  # RAG paper
    },
    {
        'query': 'graph neural networks message passing',
        'relevant_papers': ['paper5']  # GNN paper
    }
]

# Run evaluation
evaluation_results = evaluate_retrieval_methods(test_queries)

# Display results
print("\n=== RETRIEVAL METHOD EVALUATION RESULTS ===")
print(f"{'Method':<20} {'P@1':<8} {'P@3':<8} {'Avg Score':<10} {'Coverage':<10}")
print("-" * 60)

for method, metrics in evaluation_results.items():
    print(f"{method:<20} {metrics['precision_at_1']:.3f}    {metrics['precision_at_3']:.3f}    {metrics['avg_score']:.3f}      {metrics['coverage']:.3f}")
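
The metrics in the table are standard ranking measures: P@k is the fraction of the top-k results that are relevant, and the notebook's "coverage" is the fraction of relevant papers that were retrieved (recall). A standalone sketch on a toy ranking, with `precision_at_k` and `recall_at_k` as hypothetical helper names:

```python
from typing import List, Set


def precision_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of the top-k retrieved ids that are relevant."""
    top = retrieved[:k]
    if not top:
        return 0.0
    return sum(1 for paper_id in top if paper_id in relevant) / len(top)


def recall_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant ids that appear in the top-k (the notebook's 'coverage')."""
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]) & relevant) / len(relevant)


# Toy ranking for the first test query, whose relevant set is paper1 and paper2
toy_ranking = ["paper1", "paper3", "paper2"]
toy_relevant = {"paper1", "paper2"}

p_at_1 = precision_at_k(toy_ranking, toy_relevant, k=1)
p_at_3 = precision_at_k(toy_ranking, toy_relevant, k=3)
coverage = recall_at_k(toy_ranking, toy_relevant, k=3)

print(f"P@1={p_at_1:.3f} P@3={p_at_3:.3f} coverage={coverage:.3f}")
```

With only one or two relevant papers per query, P@3 cannot reach 1.0 for most of the test queries, so it is best read comparatively across methods rather than as an absolute quality figure.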

In [None]:
# Token-aware retrieval optimization
def optimize_context_window(query: str, max_tokens: int = 2000) -> Tuple[List[RetrievalResult], int]:
    """Fit retrieval results within a token budget; returns (results, tokens_used)"""
    
    # Get comprehensive results
    results = advanced_hybrid_retrieval(query, top_k=10)
    
    # Calculate token usage for each result
    optimized_results = []
    current_tokens = 0
    
    # Reserve tokens for query and system prompt
    query_tokens = len(tokenizer.encode(query))
    system_tokens = 200  # Estimate for system prompt
    available_tokens = max_tokens - query_tokens - system_tokens
    
    for result in results:
        # Calculate tokens for this result
        content = f"Title: {result.title}\nAbstract: {result.abstract}"
        result_tokens = len(tokenizer.encode(content))
        
        if current_tokens + result_tokens <= available_tokens:
            optimized_results.append(result)
            current_tokens += result_tokens
        else:
            # Try to truncate abstract to fit
            truncated_abstract = result.abstract
            while len(tokenizer.encode(f"Title: {result.title}\nAbstract: {truncated_abstract}")) > available_tokens - current_tokens and len(truncated_abstract) > 50:
                truncated_abstract = truncated_abstract[:-50]
            
            if len(truncated_abstract) > 50:
                truncated_result = RetrievalResult(
                    paper_id=result.paper_id,
                    title=result.title,
                    abstract=truncated_abstract + "...",
                    score=result.score,
                    retrieval_method=result.retrieval_method,
                    additional_info=result.additional_info
                )
                optimized_results.append(truncated_result)
                current_tokens += len(tokenizer.encode(f"Title: {result.title}\nAbstract: {truncated_abstract}..."))
            
            break
    
    return optimized_results, current_tokens

# Test token optimization
optimized_results, token_count = optimize_context_window("transformer models with attention mechanisms", max_tokens=1500)

print(f"Token-Optimized Retrieval Results (Total tokens: {token_count})\n")
for i, result in enumerate(optimized_results, 1):
    content = f"Title: {result.title}\nAbstract: {result.abstract}"
    result_tokens = len(tokenizer.encode(content))
    print(f"{i}. {result.title} ({result_tokens} tokens)")
    print(f"   Score: {result.score:.4f}")
    print(f"   Abstract: {result.abstract[:100]}...\n")
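
The truncation loop above shortens the abstract 50 characters at a time and re-encodes on every step. One alternative, sketched here under assumptions, is a binary search over the prefix length, which needs far fewer tokenizer calls on long abstracts. `truncate_to_token_budget` and `whitespace_tokens` are hypothetical names; the whitespace counter is only a stand-in so the example runs without a tokenizer, and the search assumes the token count does not decrease as the prefix grows.

```python
from typing import Callable

def truncate_to_token_budget(text: str, budget: int,
                             count_tokens: Callable[[str], int],
                             suffix: str = "...") -> str:
    """Return the longest prefix of `text` (plus suffix) that fits within `budget` tokens.

    Hypothetical helper. Assumes count_tokens is non-decreasing as the prefix grows.
    """
    if count_tokens(text) <= budget:
        return text  # already fits, no truncation marker needed
    lo, hi = 0, len(text)  # lo is the longest prefix length known to fit
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if count_tokens(text[:mid] + suffix) <= budget:
            lo = mid
        else:
            hi = mid - 1
    return text[:lo] + suffix if lo > 0 else ""

# Stand-in token counter for illustration: one token per whitespace-separated word
whitespace_tokens = lambda s: len(s.split())

print(truncate_to_token_budget("one two three four five six", 3, whitespace_tokens))
# one two three...
```

In this notebook the counter could be `lambda s: len(tokenizer.encode(s))`, with the budget set to the tokens remaining after the title and earlier results.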

## Hands-on Exercise: Complete GraphRAG Retrieval System

In [None]:
# Complete GraphRAG retrieval system
class GraphRAGRetriever:
    """Complete GraphRAG retrieval system with multiple strategies"""
    
    def __init__(self, neo4j_driver, embedding_model):
        self.driver = neo4j_driver
        self.embedding_model = embedding_model
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
    
    def retrieve(self, query: str, method: str = 'advanced_hybrid',
                top_k: int = 5, max_tokens: int = 2000) -> Dict[str, object]:
        """Main retrieval method with multiple strategies"""
        
        start_time = time.time()
        
        # Choose retrieval method
        if method == 'vector':
            results = vector_similarity_search(query, top_k)
        elif method == 'graph':
            results = graph_traversal_search(query, max_depth=2)[:top_k]
        elif method == 'hybrid':
            results = hybrid_retrieval(query, top_k)
        elif method == 'advanced_hybrid':
            results = advanced_hybrid_retrieval(query, top_k)
        else:
            raise ValueError(f"Unknown method: {method}")
        
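        # Fit the results of the chosen method into the token budget.
        # optimize_context_window() always re-runs advanced_hybrid_retrieval,
        # so calling it here would discard `results` and ignore `method` and `top_k`.
        # Unlike that function, this loop drops results that don't fit instead of truncating.
        system_tokens = 200  # Estimate for system prompt
        available_tokens = max_tokens - len(self.tokenizer.encode(query)) - system_tokens
        optimized_results, token_count = [], 0
        for r in results:
            r_tokens = len(self.tokenizer.encode(f"Title: {r.title}\nAbstract: {r.abstract}"))
            if token_count + r_tokens > available_tokens:
                break
            optimized_results.append(r)
            token_count += r_tokens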
        
        # Prepare context for RAG
        context = self._prepare_context(optimized_results)
        
        retrieval_time = time.time() - start_time
        
        return {
            'query': query,
            'method': method,
            'results': optimized_results,
            'context': context,
            'token_count': token_count,
            'retrieval_time': retrieval_time,
            'metadata': {
                'num_results': len(optimized_results),
                'avg_score': sum(r.score for r in optimized_results) / len(optimized_results) if optimized_results else 0
            }
        }
    
    def _prepare_context(self, results: List[RetrievalResult]) -> str:
        """Prepare formatted context for RAG"""
        
        context_parts = []
        
        for i, result in enumerate(results, 1):
            context_part = f"""[Document {i}]
Title: {result.title}
Abstract: {result.abstract}
Retrieval Score: {result.score:.4f}
Method: {result.retrieval_method}
"""
            
            if result.additional_info:
                info = result.additional_info
                if 'matched_concept' in info:
                    context_part += f"Matched Concept: {info['matched_concept']}\n"
                if 'connecting_author' in info:
                    context_part += f"Connected via Author: {info['connecting_author']}\n"
                if 'field' in info:
                    context_part += f"Field: {info['field']}\n"
            
            context_parts.append(context_part)
        
        return "\n".join(context_parts)
    
    def batch_retrieve(self, queries: List[str], method: str = 'advanced_hybrid') -> List[Dict[str, object]]:
        """Batch retrieval for multiple queries"""
        return [self.retrieve(query, method) for query in queries]
    
    def explain_retrieval(self, query: str, method: str = 'advanced_hybrid') -> Dict[str, object]:
        """Explain how retrieval worked for debugging"""
        
        result = self.retrieve(query, method)
        
        explanation = {
            'query_analysis': {
                'query': query,
                'query_terms': [term.strip().lower() for term in query.split() if len(term.strip()) > 3],
                'query_tokens': len(self.tokenizer.encode(query))
            },
            'retrieval_details': {
                'method_used': method,
                'num_results': len(result['results']),
                'total_tokens': result['token_count'],
                'retrieval_time': result['retrieval_time']
            },
            'result_breakdown': []
        }
        
        for result_obj in result['results']:
            breakdown = {
                'paper_id': result_obj.paper_id,
                'title': result_obj.title,
                'final_score': result_obj.score,
                'retrieval_method': result_obj.retrieval_method
            }
            
            if result_obj.additional_info:
                breakdown['score_components'] = result_obj.additional_info
            
            explanation['result_breakdown'].append(breakdown)
        
        return explanation

# Initialize the complete GraphRAG retriever
graphrag_retriever = GraphRAGRetriever(driver, embedding_model)

# Test the complete system
test_query = "How do attention mechanisms work in transformer models?"
retrieval_result = graphrag_retriever.retrieve(test_query, method='advanced_hybrid', top_k=3)

print("=== COMPLETE GRAPHRAG RETRIEVAL SYSTEM TEST ===")
print(f"Query: {test_query}")
print(f"Method: {retrieval_result['method']}")
print(f"Results: {retrieval_result['metadata']['num_results']}")
print(f"Tokens: {retrieval_result['token_count']}")
print(f"Time: {retrieval_result['retrieval_time']:.3f}s")
print(f"Avg Score: {retrieval_result['metadata']['avg_score']:.4f}")

print("\nRetrieved Context:")
print("=" * 50)
print(retrieval_result['context'])

In [None]:
# Test explanation functionality
explanation = graphrag_retriever.explain_retrieval("bidirectional language model pre-training", method='advanced_hybrid')

print("\n=== RETRIEVAL EXPLANATION ===")
print(f"Query: {explanation['query_analysis']['query']}")
print(f"Query Terms: {', '.join(explanation['query_analysis']['query_terms'])}")
print(f"Query Tokens: {explanation['query_analysis']['query_tokens']}")

print(f"\nRetrieval Method: {explanation['retrieval_details']['method_used']}")
print(f"Results Retrieved: {explanation['retrieval_details']['num_results']}")
print(f"Total Context Tokens: {explanation['retrieval_details']['total_tokens']}")
print(f"Retrieval Time: {explanation['retrieval_details']['retrieval_time']:.3f}s")

print("\nResult Breakdown:")
for i, result in enumerate(explanation['result_breakdown'], 1):
    print(f"\n{i}. {result['title']}")
    print(f"   Paper ID: {result['paper_id']}")
    print(f"   Final Score: {result['final_score']:.4f}")
    print(f"   Method: {result['retrieval_method']}")
    
    if 'score_components' in result:
        components = result['score_components']
        print(f"   Score Components:")
        if 'vector_score' in components:
            print(f"     - Vector Score: {components['vector_score']:.4f}")
        if 'contextual_score' in components:
            print(f"     - Contextual Score: {components['contextual_score']:.4f}")
        if 'methods' in components:
            print(f"     - Methods Used: {', '.join(components['methods'])}")

## Module Summary and Next Steps

In this module, you learned to:
- Implement vector-based semantic retrieval with Neo4j
- Design sophisticated graph traversal retrieval patterns
- Combine multiple retrieval strategies for optimal results
- Optimize retrieval for token constraints and performance
- Build complete GraphRAG systems with explanation capabilities

### Key Takeaways
- **Vector Similarity**: Provides semantic understanding and content-based matching
- **Graph Traversal**: Reveals structural relationships and contextual connections
- **Hybrid Approaches**: Combine strengths of multiple methods for better results
- **Token Optimization**: Essential for real-world LLM applications with context limits
- **Explainability**: Critical for debugging and improving retrieval systems
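
To make the hybrid idea concrete, here is a minimal sketch of reciprocal rank fusion (RRF), one common way to merge ranked lists from different retrievers. This is a swapped-in technique for illustration, not the scoring used by `hybrid_retrieval` or `advanced_hybrid_retrieval` in this notebook. The function name, the example rankings, and the constant `k = 60` (a frequently used default) are assumptions.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[Tuple[str, float]]:
    """Merge several ranked ID lists; each list contributes 1 / (k + rank) per item."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by ID so the output is deterministic
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))

# Illustrative rankings only; not output from the notebook's retrievers
vector_ranking = ['paper1', 'paper2', 'paper4']
graph_ranking = ['paper2', 'paper3', 'paper1']

fused = reciprocal_rank_fusion([vector_ranking, graph_ranking])
print([doc_id for doc_id, _ in fused])
# ['paper2', 'paper1', 'paper3', 'paper4']
```

RRF uses only rank positions, so it sidesteps the problem of vector and graph scores living on different scales. A weighted sum of normalized scores is the usual alternative when the raw scores are trusted.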

### Business Applications
- **Research Discovery**: Find related papers and concepts through graph relationships
- **Document Retrieval**: Combine semantic and structural search for enterprise knowledge
- **Recommendation Systems**: Use graph connections to suggest relevant content
- **Question Answering**: Provide rich context for LLM-powered Q&A systems

### Performance Insights
- Vector similarity excels at semantic matching but may miss structural relationships
- Graph traversal finds unexpected connections but can be computationally expensive
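- Hybrid methods tend to outperform single approaches, though the gain depends on the dataset and how scores are weighted, so confirm it with the evaluation results on your own data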
- Token optimization is crucial for production deployments
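
When comparing the cost of these methods, a single timed call can be noisy because of cold caches and connection setup. A minimal sketch of a more stable measurement follows; `time_call` is a hypothetical helper, and the summed range is only a stand-in workload for a retrieval call.

```python
import statistics
import time
from typing import Callable, Dict

def time_call(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> Dict[str, float]:
    """Run `fn` several times and summarize wall-clock latency (hypothetical helper)."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn()  # warm-up runs are not measured
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        'median_s': statistics.median(samples),
        'min_s': min(samples),
        'max_s': max(samples),
    }

# Stand-in workload; in the notebook this could wrap a retrieval call in a lambda
stats = time_call(lambda: sum(range(100_000)))
print(f"median {stats['median_s']:.6f}s (min {stats['min_s']:.6f}s, max {stats['max_s']:.6f}s)")
```

The median is less sensitive than the mean to one slow outlier, which is why it is reported first here.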

### Next Module
Module 6: Agents - Learn how to build intelligent agents that can reason over graph knowledge and take actions based on retrieved information.

In [None]:
# Final performance benchmark
print("\n=== FINAL PERFORMANCE BENCHMARK ===")

benchmark_queries = [
    "transformer attention mechanisms neural networks",
    "large language models few shot learning",
    "graph neural networks message passing",
    "retrieval augmented generation knowledge",
    "bidirectional encoder representations"
]

methods_to_test = ['vector', 'graph', 'hybrid', 'advanced_hybrid']
performance_data = []

for method in methods_to_test:
    total_time = 0
    total_results = 0
    total_score = 0
    successes = 0  # only successful queries count toward the averages
    
    for query in benchmark_queries:
        try:
            result = graphrag_retriever.retrieve(query, method=method, top_k=3)
            total_time += result['retrieval_time']
            total_results += result['metadata']['num_results']
            total_score += result['metadata']['avg_score']
            successes += 1
        except Exception as e:
            print(f"Error with {method} on '{query}': {e}")
    
    # Average over successful queries so failures don't deflate the numbers
    denom = max(successes, 1)
    avg_time = total_time / denom
    avg_results = total_results / denom
    avg_score = total_score / denom
    
    performance_data.append({
        'method': method,
        'avg_time': avg_time,
        'avg_results': avg_results,
        'avg_score': avg_score
    })

print(f"{'Method':<18} {'Avg Time (s)':<12} {'Avg Results':<12} {'Avg Score':<10}")
print("-" * 55)
for perf in performance_data:
    print(f"{perf['method']:<18} {perf['avg_time']:<12.3f} {perf['avg_results']:<12.1f} {perf['avg_score']:<10.4f}")

print("\n🎉 Module 5: Retrievers completed successfully!")
print("You're now ready to move on to Module 6: Agents")

# Cleanup (optional)
# driver.close()