# Document Retrieval with Conversation History (Embedding-Based)

This notebook handles conversation history **without a re-ranker** by:

1. **Query Expansion** - Prepend history context to the query before embedding
2. **Multi-Query Retrieval** - Search with current query + previous queries, merge results
3. **Context-Aware Embedding** - Create a combined embedding from history + current query

## Why This Works:
```
User: "How do I cancel a booking?"     → Finds cancellation docs
User: "What about the refund?"         → Finds refund docs (context: cancellation)
User: "What documents do I need?"      → Finds cancellation/refund docs (not generic docs)
```

In [None]:
import yaml
import json
import os
import pickle
from typing import List, Dict, Optional
import numpy as np
import faiss
import boto3
import uuid
from datetime import datetime

# Load the notebook configuration into the module-level `config` dict.
# The retriever below reads its `models`, `storage` and `retrieval` sections.
with open('config/config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

print("✓ Config loaded")

In [None]:
class ConversationAwareRetriever:
    """
    Document retriever that uses conversation history to improve results.
    Uses embeddings only (no re-ranker).

    Strategies:
    1. Query Expansion - Prepend previous queries to the current query
    2. Multi-Query - Search each query separately, fuse the weighted scores
    3. Embedding Fusion - Blend the current-query embedding with the mean
       embedding of the previous queries

    Each history turn is a dict; only its 'query' key is read by the
    retrieval strategies. Sessions are stored in an in-memory dict on the
    instance and are lost when the instance is discarded.
    """

    # Number of most recent history turns used by the multi-query and
    # embedding-fusion strategies.
    _HISTORY_WINDOW = 3

    def __init__(self, config: dict):
        """
        Create the SageMaker runtime client and load the search indexes.

        Args:
            config: Parsed configuration. Reads
                config['models']['embedding'] (endpoint_name, credentials),
                config['storage'] (faiss_index, bm25_index directories) and,
                at search time, config['retrieval']['hybrid'] weights.
        """
        self.config = config

        # Embedding client
        self.embedding_endpoint_name = config['models']['embedding']['endpoint_name']
        embedding_creds = config['models']['embedding']['credentials']
        self.embedding_client = boto3.client(
            'sagemaker-runtime',
            region_name=embedding_creds['region'],
            aws_access_key_id=embedding_creds['accessKeyId'],
            aws_secret_access_key=embedding_creds['secretAccessKey'],
            aws_session_token=embedding_creds['sessionToken']
        )
        print("✓ Embedding client initialized")

        # Session management: session_id -> session dict
        self.sessions = {}

        self.load_indexes()

    def load_indexes(self):
        """Load the FAISS index, the pickled BM25 index and the chunk metadata."""
        faiss_path = os.path.join(self.config['storage']['faiss_index'], 'faiss.index')
        self.faiss_index = faiss.read_index(faiss_path)

        bm25_path = os.path.join(self.config['storage']['bm25_index'], 'bm25.pkl')
        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only load BM25 artifacts produced by a trusted indexing pipeline.
        with open(bm25_path, 'rb') as f:
            self.bm25_index = pickle.load(f)

        metadata_path = os.path.join(self.config['storage']['faiss_index'], 'chunk_metadata.json')
        with open(metadata_path, 'r') as f:
            self.chunks = json.load(f)

        print(f"✓ Indexes loaded: {len(self.chunks)} chunks")

    def get_embedding(self, text: str) -> np.ndarray:
        """
        Get the embedding for `text` from the configured SageMaker endpoint.

        Sends a JSON body with a single-element "inputs" list and returns the
        first element of the decoded JSON response as a float32 array.
        """
        params = {"inputs": [text], "encoding_format": "float"}
        response = self.embedding_client.invoke_endpoint(
            EndpointName=self.embedding_endpoint_name,
            ContentType='application/json',
            Body=json.dumps(params)
        )
        raw_bytes = response['Body'].read()
        output_data = json.loads(raw_bytes.decode())
        return np.array(output_data[0], dtype='float32')

    @staticmethod
    def _chunk_key(chunk: Dict) -> str:
        """Build the '<doc_id>_<chunk_index>' key used to de-duplicate chunks."""
        return chunk.get('doc_id', '') + '_' + str(chunk.get('chunk_index', 0))

    # =========================================================================
    # STRATEGY 1: Query Expansion
    # =========================================================================
    def expand_query_with_history(self, query: str, history: List[Dict],
                                   max_history: int = 3) -> str:
        """
        Expand the query by prepending the most recent previous queries.

        The previous queries are joined with single spaces, followed by
        ". " and the current query. Turns without a non-empty 'query' are
        skipped; if nothing usable remains, the query is returned unchanged.

        Example:
            History: [{'query': "How do I cancel?"},
                      {'query': "What's the refund policy?"}]
            Query: "What documents do I need?"

            Expanded: "How do I cancel? What's the refund policy?. What documents do I need?"

        Args:
            query: Current user query.
            history: Previous turns, oldest first.
            max_history: Maximum number of most recent turns to use. Values
                <= 0 disable expansion.

        Returns:
            The expanded query string.
        """
        if not history or max_history <= 0:
            # Guard max_history <= 0 explicitly: history[-0:] would select
            # the whole history instead of none of it.
            return query

        recent = history[-max_history:]
        context_terms = [turn.get('query', '') for turn in recent]
        context_terms = [term for term in context_terms if term]
        if not context_terms:
            return query

        context_str = " ".join(context_terms)
        return f"{context_str}. {query}"

    # =========================================================================
    # STRATEGY 2: Multi-Query Retrieval with Fusion
    # =========================================================================
    def multi_query_search(self, query: str, history: List[Dict],
                           entitlement: str, org_id: Optional[str] = None,
                           tags: Optional[List[str]] = None, top_k: int = 20,
                           history_weight: float = 0.3) -> List[Dict]:
        """
        Search with multiple queries (current + history) and fuse results.

        - Current query gets weight: (1 - history_weight)
        - The non-empty queries among the last 3 history turns share
          history_weight equally

        A chunk's fused score is the sum of its weighted hybrid scores over
        every search that returned it.

        Returns:
            Up to `top_k` chunk copies, best first, each carrying a
            'fused_score' key.
        """
        fused = {}  # chunk key -> {'chunk': chunk, 'score': fused score}

        def accumulate(hits: List[Dict], weight: float) -> None:
            for chunk in hits:
                key = self._chunk_key(chunk)
                contribution = chunk['hybrid_score'] * weight
                if key in fused:
                    fused[key]['score'] += contribution
                else:
                    fused[key] = {'chunk': chunk, 'score': contribution}

        # Primary search. Every hit is kept: search results are copies with
        # an extra 'hybrid_score' key, so they never compare equal to the
        # stored chunk dicts and must not be filtered by membership.
        current_hits = self._single_query_search(query, entitlement, org_id, tags, top_k * 2)
        accumulate(current_hits, 1.0 - history_weight)

        # History searches (most recent turns only)
        recent = (history or [])[-self._HISTORY_WINDOW:]
        previous_queries = [turn.get('query', '') for turn in recent]
        previous_queries = [prev for prev in previous_queries if prev]
        if previous_queries:
            per_history_weight = history_weight / len(previous_queries)
            for prev_query in previous_queries:
                hist_hits = self._single_query_search(prev_query, entitlement, org_id, tags, top_k)
                accumulate(hist_hits, per_history_weight)

        # Sort by fused score
        ranked = sorted(fused.values(), key=lambda item: item['score'], reverse=True)

        results = []
        for item in ranked[:top_k]:
            chunk = item['chunk'].copy()
            chunk['fused_score'] = item['score']
            results.append(chunk)

        return results

    # =========================================================================
    # STRATEGY 3: Embedding Fusion
    # =========================================================================
    def get_fused_embedding(self, query: str, history: List[Dict],
                            current_weight: float = 0.7,
                            history_weight: float = 0.3) -> np.ndarray:
        """
        Create a fused embedding from current query + history.

        fused = current_weight * embed(query) + history_weight * mean(embed(history))

        Only the non-empty queries among the last 3 history turns are
        embedded. The fused vector is L2-normalized unless its norm is zero.
        Without usable history the raw current-query embedding is returned.
        """
        current_embedding = self.get_embedding(query)

        if not history:
            return current_embedding

        recent = history[-self._HISTORY_WINDOW:]
        history_embeddings = []
        for turn in recent:
            prev_query = turn.get('query', '')
            if prev_query:
                history_embeddings.append(self.get_embedding(prev_query))

        if not history_embeddings:
            return current_embedding

        # Average history embeddings, then blend with the current one
        history_mean = np.mean(history_embeddings, axis=0)
        fused = current_weight * current_embedding + history_weight * history_mean

        # Normalize; skip when the norm is zero to avoid producing NaNs
        norm = np.linalg.norm(fused)
        if norm > 0:
            fused = fused / norm

        return fused.astype('float32')

    # =========================================================================
    # Core Search Methods
    # =========================================================================
    def _single_query_search(self, query: str, entitlement: str,
                              org_id: Optional[str] = None,
                              tags: Optional[List[str]] = None,
                              top_k: int = 20) -> List[Dict]:
        """Embed `query` and run the hybrid (vector + BM25) search with it."""
        query_embedding = self.get_embedding(query)
        return self._search_with_embedding(query, query_embedding, entitlement, org_id, tags, top_k)

    def _search_with_embedding(self, query: str, query_embedding: np.ndarray,
                                entitlement: str, org_id: Optional[str] = None,
                                tags: Optional[List[str]] = None,
                                top_k: int = 20) -> List[Dict]:
        """
        Hybrid search using a pre-computed embedding.

        The vector scores (FAISS) and BM25 scores (computed from the
        lower-cased, whitespace-split `query`) are each min-max normalized,
        weighted by config['retrieval']['hybrid'] and summed per chunk. Chunks
        are then filtered by entitlement ('universal' is always allowed),
        by `org_id` and by `tags` (any match) when those are given.

        Returns:
            Up to `top_k` chunk copies, best first, each carrying a
            'hybrid_score' float.
        """
        if not self.chunks:
            return []

        query_embedding = query_embedding.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_embedding)

        initial_top_k = min(top_k * 10, len(self.chunks))

        # Vector search
        # NOTE(review): higher vector score is treated as better, which
        # presumes an inner-product/cosine index - confirm the index metric.
        vector_scores, vector_indices = self.faiss_index.search(query_embedding, initial_top_k)
        vector_scores = vector_scores[0]
        vector_indices = vector_indices[0]

        # FAISS pads missing neighbours with label -1; drop them so they
        # neither alias self.chunks[-1] nor distort the normalization.
        valid = vector_indices >= 0
        vector_scores = vector_scores[valid]
        vector_indices = vector_indices[valid]

        # BM25 search
        tokenized_query = query.lower().split()
        bm25_scores = np.asarray(self.bm25_index.get_scores(tokenized_query))

        def normalize(scores):
            """Min-max scale to [0, 1]; all zeros when the scores are constant."""
            if scores.size == 0:
                return scores
            min_s, max_s = scores.min(), scores.max()
            if max_s - min_s < 1e-10:
                return np.zeros_like(scores)
            return (scores - min_s) / (max_s - min_s)

        vector_scores_norm = normalize(vector_scores)
        bm25_scores_norm = normalize(bm25_scores)

        # Hybrid scores
        vector_weight = self.config['retrieval']['hybrid']['vector_weight']
        bm25_weight = self.config['retrieval']['hybrid']['bm25_weight']

        hybrid_scores = {}
        for idx, score in zip(vector_indices, vector_scores_norm):
            hybrid_scores[int(idx)] = score * vector_weight

        for idx, score in enumerate(bm25_scores_norm):
            if idx in hybrid_scores:
                hybrid_scores[idx] += score * bm25_weight
            else:
                hybrid_scores[idx] = score * bm25_weight

        ranked = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)

        # Filter by access, stopping as soon as top_k chunks are collected
        results = []
        for idx, score in ranked:
            if len(results) >= top_k:
                break

            stored = self.chunks[idx]

            chunk_entitlements = stored['entitlement']
            if isinstance(chunk_entitlements, str):
                chunk_entitlements = [chunk_entitlements]

            has_access = 'universal' in chunk_entitlements or entitlement in chunk_entitlements
            if not has_access:
                continue
            if org_id and stored['orgId'] != org_id:
                continue
            if tags and not any(t in stored['metadata']['tags'] for t in tags):
                continue

            # Copy so the score is never written into the shared metadata
            chunk = stored.copy()
            chunk['hybrid_score'] = float(score)
            results.append(chunk)

        return results

    # =========================================================================
    # Main Query Methods
    # =========================================================================
    def query(self, query: str, entitlement: str, org_id: Optional[str] = None,
              tags: Optional[List[str]] = None, top_k: int = 5,
              conversation_history: Optional[List[Dict]] = None,
              strategy: str = 'query_expansion') -> Dict:
        """
        Query with conversation history support.

        Twice as many chunks as requested are retrieved so that, after
        collapsing chunks of the same document, up to `top_k` distinct
        documents can still be returned.

        Args:
            query: Current user query.
            entitlement: Caller's entitlement, used for access filtering.
            org_id: Optional organisation filter.
            tags: Optional tag filter (a chunk matches on any tag).
            top_k: Maximum number of distinct documents to return.
            conversation_history: Previous turns, oldest first (may be None).
            strategy: 'query_expansion', 'multi_query', or 'embedding_fusion'

        Returns:
            Dict with 'query', 'documents' (document_name, doc_id, score),
            'strategy' and 'history_turns_used' (length of the history
            passed in).

        Raises:
            ValueError: If `strategy` is not one of the supported names.
        """
        history = conversation_history or []
        candidate_k = top_k * 2

        if strategy == 'query_expansion':
            # Strategy 1: search with the expanded query, then merge with a
            # search on the original query so it is never drowned out
            expanded_query = self.expand_query_with_history(query, history)
            expanded_hits = self._single_query_search(expanded_query, entitlement, org_id, tags, candidate_k)
            original_hits = self._single_query_search(query, entitlement, org_id, tags, top_k)
            results = self._merge_results(expanded_hits, original_hits, candidate_k)

        elif strategy == 'multi_query':
            # Strategy 2: Multi-query fusion
            results = self.multi_query_search(query, history, entitlement, org_id, tags, candidate_k)

        elif strategy == 'embedding_fusion':
            # Strategy 3: Fused embedding
            fused_embedding = self.get_fused_embedding(query, history)
            results = self._search_with_embedding(query, fused_embedding, entitlement, org_id, tags, candidate_k)

        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        # Build response: one entry per document, represented by its
        # best-ranked chunk, until top_k distinct documents are collected
        documents = []
        seen_docs = set()
        for chunk in results:
            if len(documents) >= top_k:
                break
            doc_id = chunk['doc_id']
            if doc_id in seen_docs:
                continue
            seen_docs.add(doc_id)
            documents.append({
                'document_name': chunk['title'],
                'doc_id': doc_id,
                'score': chunk.get('fused_score', chunk.get('hybrid_score', 0))
            })

        return {
            'query': query,
            'documents': documents,
            'strategy': strategy,
            'history_turns_used': len(history)
        }

    def _merge_results(self, results1: List[Dict], results2: List[Dict], top_k: int) -> List[Dict]:
        """
        Merge two result lists, combining scores for duplicates.

        Chunks from `results1` contribute their full 'hybrid_score'; chunks
        from `results2` contribute half of it. The sum is stored as
        'fused_score' and the best `top_k` chunk copies are returned.
        """
        merged = {}

        for chunk in results1:
            key = self._chunk_key(chunk)
            merged[key] = chunk.copy()
            merged[key]['fused_score'] = chunk.get('hybrid_score', 0)

        for chunk in results2:
            key = self._chunk_key(chunk)
            half_score = chunk.get('hybrid_score', 0) * 0.5
            if key in merged:
                # Boost score if found in both
                merged[key]['fused_score'] += half_score
            else:
                merged[key] = chunk.copy()
                merged[key]['fused_score'] = half_score

        ranked = sorted(merged.values(), key=lambda item: item['fused_score'], reverse=True)
        return ranked[:top_k]

    # =========================================================================
    # Session Management
    # =========================================================================
    def create_session(self, user_id: str, entitlement: str, org_id: Optional[str] = None) -> str:
        """Register a new in-memory session with empty history; return its UUID."""
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'session_id': session_id,
            'user_id': user_id,
            'entitlement': entitlement,
            'org_id': org_id,
            'query_history': []
        }
        print(f"✓ Created session: {session_id}")
        return session_id

    def query_with_session(self, session_id: str, query: str,
                           tags: Optional[List[str]] = None, top_k: int = 5,
                           strategy: str = 'query_expansion',
                           history_limit: int = 3) -> Dict:
        """
        Query using session history.

        The session's entitlement and org_id are applied, the last
        `history_limit` turns are passed as conversation history (a falsy
        limit passes the whole history), and the query is appended to the
        session history afterwards.

        Returns:
            The result of `query()` with an added 'session_id' key.

        Raises:
            ValueError: If `session_id` is unknown.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Get recent history
        if history_limit:
            history = session['query_history'][-history_limit:]
        else:
            history = session['query_history']

        # Perform query
        result = self.query(
            query=query,
            entitlement=session['entitlement'],
            org_id=session['org_id'],
            tags=tags,
            top_k=top_k,
            conversation_history=history,
            strategy=strategy
        )

        # Store in history
        session['query_history'].append({
            'query': query,
            'timestamp': datetime.now().isoformat(),
            'documents_found': [d['document_name'] for d in result['documents']]
        })

        result['session_id'] = session_id
        return result

    def get_session_history(self, session_id: str) -> List[Dict]:
        """Return the session's query history, or [] for an unknown session."""
        session = self.sessions.get(session_id)
        return session['query_history'] if session else []

print("✓ ConversationAwareRetriever class defined")

## Initialize Retriever

In [None]:
# Build the retriever: the constructor creates the SageMaker runtime client
# and loads the FAISS index, BM25 index and chunk metadata from disk.
retriever = ConversationAwareRetriever(config=config)

## Test: Multi-Turn Conversation

Watch how the context affects results for ambiguous queries.

In [None]:
# Demo: three consecutive turns in one session. Each call to
# query_with_session() sees the queries stored by the previous calls.
heavy_rule = "=" * 70
light_rule = "-" * 50

print(heavy_rule)
print("TEST: Multi-Turn Conversation with History Context")
print(heavy_rule)

# Open a session; entitlement and org_id are reused for every turn
session_id = retriever.create_session(
    user_id='agent_001',
    entitlement='agent_support',
    org_id='org_123',
)

# --- Turn 1: establishes the cancellation context -------------------------
print(f"\n{light_rule}")
print("TURN 1: Establishing context")
print(light_rule)
r1 = retriever.query_with_session(
    session_id=session_id,
    query="How do I cancel a booking?",
    strategy='query_expansion',
)
r1_top_names = [d['document_name'] for d in r1['documents'][:3]]
print(f"Query: {r1['query']}")
print(f"Strategy: {r1['strategy']}")
print(f"History turns: {r1['history_turns_used']}")
print(f"Results: {r1_top_names}")

# --- Turn 2: follow-up that relies on turn 1 ------------------------------
print(f"\n{light_rule}")
print("TURN 2: Related question")
print(light_rule)
r2 = retriever.query_with_session(
    session_id=session_id,
    query="What about the refund?",
    strategy='query_expansion',
)
r2_top_names = [d['document_name'] for d in r2['documents'][:3]]
print(f"Query: {r2['query']}")
print(f"History turns: {r2['history_turns_used']}")
print(f"Results: {r2_top_names}")

# --- Turn 3: ambiguous on its own, disambiguated by turns 1 and 2 ---------
print(f"\n{light_rule}")
print("TURN 3: Ambiguous query (history provides context)")
print(light_rule)
r3 = retriever.query_with_session(
    session_id=session_id,
    query="What documents do I need?",
    strategy='query_expansion',
)
r3_top_names = [d['document_name'] for d in r3['documents'][:3]]
print(f"Query: {r3['query']}")
print(f"History turns: {r3['history_turns_used']}")
print(f"Results: {r3_top_names}")
print("\n→ With context, 'documents' means cancellation/refund documents, not generic docs")

## Compare: With vs Without History Context

In [None]:
# Demo: run the same ambiguous query once without and once with a simulated
# cancellation history, and print the top three documents of each run.
comparison_rule = "=" * 70
print(comparison_rule)
print("COMPARISON: Same Query With vs Without History")
print(comparison_rule)

ambiguous_query = "What documents do I need?"

# Simulated history about cancellations
cancellation_history = [
    {'query': 'How do I cancel a booking?', 'documents_found': ['Cancellation Policy']},
    {'query': 'What is the refund timeline?', 'documents_found': ['Refund Guide']},
]

# Arguments shared by both runs; only the history differs
shared_query_args = {
    'query': ambiguous_query,
    'entitlement': 'agent_support',
    'org_id': 'org_123',
    'strategy': 'query_expansion',
}

# WITHOUT history
print("\n--- WITHOUT HISTORY ---")
result_no_history = retriever.query(conversation_history=None, **shared_query_args)
print(f"Query used: '{ambiguous_query}'")
for rank, doc in enumerate(result_no_history['documents'][:3], start=1):
    print(f"  {rank}. {doc['document_name']}")

# WITH history
print("\n--- WITH CANCELLATION HISTORY ---")
result_with_history = retriever.query(
    conversation_history=cancellation_history,
    **shared_query_args,
)
# Show the exact string that the expansion strategy searches with
expanded = retriever.expand_query_with_history(ambiguous_query, cancellation_history)
print(f"Expanded query: '{expanded}'")
for rank, doc in enumerate(result_with_history['documents'][:3], start=1):
    print(f"  {rank}. {doc['document_name']}")

## Compare All Three Strategies

In [None]:
# Demo: run one query with the same history through every strategy and print
# the ranked documents with their scores.
strategy_rule = "=" * 70
print(strategy_rule)
print("COMPARE: All Three History Strategies")
print(strategy_rule)

test_query = "What documents do I need?"
test_history = [
    {'query': 'How do I cancel a booking?'},
    {'query': 'What is the refund policy?'},
]

strategy_names = ('query_expansion', 'multi_query', 'embedding_fusion')
for strategy in strategy_names:
    print(f"\n--- Strategy: {strategy} ---")
    result = retriever.query(
        query=test_query,
        entitlement='agent_support',
        org_id='org_123',
        conversation_history=test_history,
        strategy=strategy,
        top_k=3,
    )
    for rank, doc in enumerate(result['documents'], start=1):
        name, score = doc['document_name'], doc['score']
        print(f"  {rank}. {name} (score: {score:.4f})")

## Summary: Three Strategies

| Strategy | How It Works | Best For |
|----------|--------------|----------|
| `query_expansion` | Prepends history queries to current query | Simple, fast |
| `multi_query` | Searches each query separately, merges results | Diverse results |
| `embedding_fusion` | Weighted blend of the current-query embedding and the mean history embedding | Smooth context blending |

### Recommended: `query_expansion`
- Simplest approach
- Only 2 embedding API calls per query (expanded + original query), regardless of history length
- Works well for follow-up questions

In [None]:
# Closing banner for the notebook run
closing_rule = "=" * 70
print(f"\n{closing_rule}")
print("NOTEBOOK COMPLETE")
print(closing_rule)
print("\nYou now have conversation history support WITHOUT a re-ranker!")
print("\nRecommended strategy: 'query_expansion'")