# Document Retrieval with BERT/DistilBERT

This notebook replaces LLM-based answer generation with BERT/DistilBERT-based document retrieval. It ranks indexed chunks by a weighted combination of embedding similarity (via FAISS) and BM25 keyword scores. Instead of generating an answer, it returns the names of the documents that contain the best-matching chunks.

## Key Changes from Original:
- Uses `sentence-transformers` with DistilBERT/BERT for embeddings
- No LLM endpoint required
- Returns document names instead of generated answers
- Lighter weight and faster inference (only a sentence-embedding model runs at query time)

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install sentence-transformers faiss-cpu rank_bm25 pyyaml

## 2. Import Libraries

In [None]:
import yaml
import json
import os
import pickle
from typing import List, Dict, Optional
import numpy as np
import faiss
import uuid
from datetime import datetime
from sentence_transformers import SentenceTransformer

print("✓ All libraries imported successfully")

## 3. Configuration

Choose your BERT model variant. Options include:
- `all-MiniLM-L6-v2` - Fast, good quality (recommended for most cases)
- `all-mpnet-base-v2` - Best quality, slower
- `multi-qa-distilbert-cos-v1` - Optimized for Q&A
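- `distilbert-base-nli-stsb-mean-tokens` - Classic DistilBERT (deprecated by sentence-transformers because it produces lower-quality embeddings)
- `bert-base-nli-mean-tokens` - Classic BERT (likewise deprecated)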
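Each model above produces embeddings of a fixed size: 384 dimensions for `all-MiniLM-L6-v2` and 768 for the rest. A FAISS index only accepts query vectors of the size it was built with. The helper below is a hypothetical sanity check. Neither `check_model_matches_index` nor the `EXPECTED_DIMS` table is part of sentence-transformers or FAISS, and the sizes are hard-coded from this notebook's model list rather than read from the models.

```python
# Embedding sizes as listed in this notebook (hypothetical lookup table).
EXPECTED_DIMS = {
    "all-MiniLM-L6-v2": 384,
    "all-mpnet-base-v2": 768,
    "multi-qa-distilbert-cos-v1": 768,
    "distilbert-base-nli-stsb-mean-tokens": 768,
    "bert-base-nli-mean-tokens": 768,
}


def check_model_matches_index(model_name: str, index_dim: int) -> None:
    """Raise if the query model's embedding size differs from the index's vector size."""
    expected = EXPECTED_DIMS.get(model_name)
    if expected is None:
        raise KeyError(f"Unknown model {model_name!r}; add it to EXPECTED_DIMS")
    if expected != index_dim:
        raise ValueError(
            f"{model_name} produces {expected}-dim vectors, but the index stores "
            f"{index_dim}-dim vectors; rebuild the index with the configured model"
        )
```

Once the retriever is loaded, `check_model_matches_index(CONFIG['model']['name'], retriever.faiss_index.d)` would flag a mismatch early. Matching dimensions is necessary but not sufficient. Four of the five models share 768 dimensions, so an index built with one of them will accept queries from another and quietly rank poorly.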

In [None]:
# Configuration - Modify as needed
CONFIG = {
    'model': {
        # Choose your model - uncomment one:
        'name': 'all-MiniLM-L6-v2',  # Fast and efficient (384 dims)
        # 'name': 'all-mpnet-base-v2',  # Higher quality (768 dims)
        # 'name': 'multi-qa-distilbert-cos-v1',  # Q&A optimized (768 dims)
        # 'name': 'distilbert-base-nli-stsb-mean-tokens',  # DistilBERT (768 dims)
        # 'name': 'bert-base-nli-mean-tokens',  # BERT base (768 dims)
    },
    'retrieval': {
        'hybrid': {
            'top_k': 5,
            'vector_weight': 0.7,
            'bm25_weight': 0.3
        }
    },
    'storage': {
        'faiss_index': './indexes/faiss',
        'bm25_index': './indexes/bm25'
    }
}

print(f"Configuration loaded. Using model: {CONFIG['model']['name']}")
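`hybrid_search` rescales both score lists to [0, 1] before weighting them. If the weights sum to 1, the fused score also stays on a 0 to 1 scale. The hypothetical helper below checks that assumption and also validates `top_k`. It is a sketch; the notebook itself does not require the weights to sum to 1.

```python
def validate_hybrid_config(config: dict) -> None:
    """Raise ValueError if the hybrid retrieval settings look inconsistent (hypothetical helper)."""
    hybrid = config["retrieval"]["hybrid"]
    top_k = hybrid["top_k"]
    if not isinstance(top_k, int) or top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")
    vector_w, bm25_w = hybrid["vector_weight"], hybrid["bm25_weight"]
    if vector_w < 0 or bm25_w < 0:
        raise ValueError("vector_weight and bm25_weight must be non-negative")
    # Both score lists are min-max normalized to [0, 1], so weights summing
    # to 1 keep the fused score in [0, 1] as well.
    if abs(vector_w + bm25_w - 1.0) > 1e-9:
        raise ValueError(f"weights sum to {vector_w + bm25_w}, expected 1.0")
```

Calling `validate_hybrid_config(CONFIG)` after the configuration cell passes for the default 0.7 / 0.3 split.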

## 4. BERT-based Document Retriever Class

In [None]:
class BERTDocumentRetriever:
    """
    Document retriever using BERT/DistilBERT for embeddings.
    Returns document names instead of generating answers.
    """
    
    def __init__(self, config: dict):
        self.config = config
        
        # Initialize BERT/DistilBERT model
        model_name = config['model']['name']
        print(f"Loading model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        print(f"✓ Model loaded: {model_name}")
        print(f"  Embedding dimension: {self.model.get_sentence_embedding_dimension()}")
        
        # Session management
        self.sessions = {}
        
        # Load indexes
        self.load_indexes()
    
    def load_indexes(self):
        """Load FAISS and BM25 indexes"""
        # Load FAISS index
        faiss_path = os.path.join(self.config['storage']['faiss_index'], 'faiss.index')
        if not os.path.exists(faiss_path):
            raise FileNotFoundError(f"FAISS index not found at {faiss_path}. Run indexing notebook first.")
        
        self.faiss_index = faiss.read_index(faiss_path)
        print(f"✓ FAISS index loaded: {self.faiss_index.ntotal} vectors")
        
        # Load embeddings (optional)
        embeddings_path = os.path.join(self.config['storage']['faiss_index'], 'embeddings.npy')
        if os.path.exists(embeddings_path):
            self.embeddings = np.load(embeddings_path)
            print(f"✓ Embeddings loaded: shape {self.embeddings.shape}")
        
        # Load BM25 index
        bm25_path = os.path.join(self.config['storage']['bm25_index'], 'bm25.pkl')
        if not os.path.exists(bm25_path):
            raise FileNotFoundError(f"BM25 index not found at {bm25_path}. Run indexing notebook first.")
        
        with open(bm25_path, 'rb') as f:
            self.bm25_index = pickle.load(f)
        print("✓ BM25 index loaded")
        
        # Load chunk metadata
        metadata_path = os.path.join(self.config['storage']['faiss_index'], 'chunk_metadata.json')
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Chunk metadata not found at {metadata_path}. Run indexing notebook first.")
        
        with open(metadata_path, 'r') as f:
            self.chunks = json.load(f)
        print(f"✓ Chunk metadata loaded: {len(self.chunks)} chunks")
        
        print("\n✓ All indexes loaded successfully")
    
    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding using BERT/DistilBERT"""
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.astype('float32')
    
    def hybrid_search(self, query: str, entitlement: str, org_id: str = None,
                      tags: List[str] = None, top_k: int = None) -> List[Dict]:
        """Perform hybrid search with filtering"""
        if top_k is None:
            top_k = self.config['retrieval']['hybrid']['top_k']
        
        # Get query embedding
        query_embedding = self.get_embedding(query)
        query_embedding = query_embedding.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_embedding)
        
        # Retrieve more results initially for filtering
        retrieval_multiplier = 10
        initial_top_k = min(top_k * retrieval_multiplier, len(self.chunks))
        
        # Vector search (FAISS)
        vector_scores, vector_indices = self.faiss_index.search(query_embedding, initial_top_k)
        # FAISS pads results with index -1 when it holds fewer than k vectors; drop those
        # slots so they cannot map to self.chunks[-1] or skew the normalization below
        valid = vector_indices[0] >= 0
        vector_scores = vector_scores[0][valid]
        vector_indices = vector_indices[0][valid]
        
        # Keyword search (BM25)
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25_index.get_scores(tokenized_query)
        
        # Min-max normalize scores to [0, 1]
        def normalize(scores):
            if scores.size == 0:
                return scores
            min_s, max_s = scores.min(), scores.max()
            if max_s - min_s < 1e-10:
                return np.zeros_like(scores)
            return (scores - min_s) / (max_s - min_s)
        
        vector_scores_norm = normalize(vector_scores)
        bm25_scores_norm = normalize(bm25_scores)
        
        # Compute hybrid scores
        vector_weight = self.config['retrieval']['hybrid']['vector_weight']
        bm25_weight = self.config['retrieval']['hybrid']['bm25_weight']
        
        hybrid_scores = {}
        for idx, score in zip(vector_indices, vector_scores_norm):
            hybrid_scores[idx] = score * vector_weight
        
        for idx, score in enumerate(bm25_scores_norm):
            if idx in hybrid_scores:
                hybrid_scores[idx] += score * bm25_weight
            else:
                hybrid_scores[idx] = score * bm25_weight
        
        # Sort by score
        sorted_indices = sorted(hybrid_scores.items(), key=lambda x: x[1], reverse=True)
        
        # Filter and collect results
        accessible_results = []
        
        for idx, score in sorted_indices:
            chunk = self.chunks[idx].copy()
            
            # Apply entitlement filter
            chunk_entitlements = chunk.get('entitlement', ['universal'])
            if isinstance(chunk_entitlements, str):
                chunk_entitlements = [chunk_entitlements]
            
            has_access = (
                'universal' in chunk_entitlements or
                entitlement in chunk_entitlements
            )
            
            if not has_access:
                continue
            
            # Apply org filter
            if org_id and chunk.get('orgId') != org_id:
                continue
            
            # Apply tag filter
            if tags and not any(t in chunk.get('metadata', {}).get('tags', []) for t in tags):
                continue
            
            chunk['score'] = float(score)
            accessible_results.append(chunk)
        
        # Sort and return top K
        accessible_results.sort(key=lambda x: x['score'], reverse=True)
        return accessible_results[:top_k]
    
    def get_relevant_documents(self, query: str, entitlement: str, 
                               org_id: str = None, tags: List[str] = None,
                               top_k: int = 5) -> Dict:
        """
        Main retrieval function - returns document names instead of generated answers.
        
        Returns:
            Dict with 'documents' list containing document info and 'query' string
        """
        chunks = self.hybrid_search(query, entitlement, org_id=org_id, tags=tags, top_k=top_k)
        
        if not chunks:
            return {
                'query': query,
                'documents': [],
                'message': 'No relevant documents found.'
            }
        
        # Extract unique documents (deduplicate by doc_id)
        seen_docs = set()
        documents = []
        
        for chunk in chunks:
            doc_id = chunk.get('doc_id', chunk.get('title', 'Unknown'))
            
            if doc_id not in seen_docs:
                seen_docs.add(doc_id)
                documents.append({
                    'document_name': chunk.get('title', 'Unknown'),
                    'doc_id': doc_id,
                    'score': chunk['score'],
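                    # Append an ellipsis only when the content was actually truncated
                    'chunk_preview': (chunk.get('content', '')[:200]
                                      + ('...' if len(chunk.get('content', '')) > 200 else '')),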
                    'metadata': chunk.get('metadata', {})
                })
        
        return {
            'query': query,
            'documents': documents,
            'total_chunks_found': len(chunks)
        }
    
    # Session management methods
    def create_session(self, user_id: str, entitlement: str, org_id: str = None) -> str:
        """Create a new session"""
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'session_id': session_id,
            'user_id': user_id,
            'entitlement': entitlement,
            'org_id': org_id,
            'created_at': datetime.now().isoformat(),
            'query_history': [],
            'last_activity': datetime.now().isoformat()
        }
        print(f"✓ Created session: {session_id}")
        return session_id
    
    def query_with_session(self, session_id: str, query: str,
                           tags: List[str] = None, top_k: int = 5) -> Dict:
        """Query with session context"""
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        
        session['last_activity'] = datetime.now().isoformat()
        
        result = self.get_relevant_documents(
            query,
            session['entitlement'],
            org_id=session['org_id'],
            tags=tags,
            top_k=top_k
        )
        
        result['session_id'] = session_id
        
        # Store in history
        session['query_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'documents_found': [d['document_name'] for d in result['documents']]
        })
        
        return result
    
    def get_session_history(self, session_id: str) -> List[Dict]:
        """Get query history for a session"""
        session = self.sessions.get(session_id)
        return session['query_history'] if session else []

print("✓ BERTDocumentRetriever class defined")
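The scoring in `hybrid_search` can be reproduced without FAISS or a model, which makes the fusion rule easy to inspect. The sketch below restates it in pure Python; the function names `min_max` and `fuse` are illustrative and not part of the notebook's API. It follows the same steps as `hybrid_search`:

- Vector scores are min-max normalized over the FAISS top-N hits only.
- BM25 scores are normalized over every chunk.
- Each chunk's fused score is `vector_weight * vector + bm25_weight * bm25`.
- Chunks outside the vector top-N receive only the BM25 term.

```python
from typing import Dict, List, Sequence, Tuple


def min_max(scores: Sequence[float]) -> List[float]:
    """Rescale to [0, 1]; a constant (or empty) list maps to zeros, as in hybrid_search."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi - lo < 1e-10:
        return [0.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]


def fuse(vector_hits: Dict[int, float], bm25_scores: Sequence[float],
         vector_weight: float = 0.7, bm25_weight: float = 0.3) -> List[Tuple[int, float]]:
    """Combine FAISS top-N hits (chunk index -> similarity) with BM25 scores for all chunks."""
    ids = list(vector_hits)
    vec_norm = min_max([vector_hits[i] for i in ids])
    fused = {i: s * vector_weight for i, s in zip(ids, vec_norm)}
    for i, s in enumerate(min_max(list(bm25_scores))):
        # Chunks missing from the vector top-N only receive the BM25 term
        fused[i] = fused.get(i, 0.0) + s * bm25_weight
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)


ranking = fuse({2: 0.9, 0: 0.5}, [1.0, 0.0, 3.0])
print(ranking)
```

Here chunk 2 ranks first because it tops both lists. Chunk 0 is the weaker of the two vector hits, so after normalization its vector score is exactly 0, the same as a chunk FAISS never returned. This is a property of min-max scaling over a short list, and one reason some pipelines prefer rank-based fusion instead.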

## 5. Alternative: Create Index with BERT Embeddings

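If you need to create new indexes with a sentence-transformers model (instead of reusing existing indexes from the original pipeline), use this section. Build the indexes with the same model that `CONFIG['model']['name']` specifies. The retriever encodes queries with that model, and query vectors from a different model either have the wrong dimension for the FAISS index or live in an incompatible embedding space.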
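`create_bert_indexes` L2-normalizes the embeddings before adding them to an `IndexFlatIP`, and `hybrid_search` normalizes the query the same way. This works because the inner product of two unit vectors equals their cosine similarity. The pure-Python check below illustrates that identity on toy vectors. It does not call FAISS, and the helper names are illustrative.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (what faiss.normalize_L2 does to each row)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain inner product, the score IndexFlatIP computes."""
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity computed directly from the raw vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


doc_vec, query_vec = [3.0, 4.0], [4.0, 3.0]
ip_of_unit_vectors = dot(l2_normalize(doc_vec), l2_normalize(query_vec))
print(ip_of_unit_vectors, cosine(doc_vec, query_vec))  # both 24/25 = 0.96, up to rounding
```

If either side skips normalization, the inner product also grows with vector length and stops being a cosine score.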

In [None]:
from rank_bm25 import BM25Okapi

def create_bert_indexes(documents: List[Dict], model_name: str = 'all-MiniLM-L6-v2',
                        output_dir: str = './indexes') -> None:
    """
    Create FAISS and BM25 indexes from documents using BERT embeddings.
    
    Args:
        documents: List of dicts with 'content', 'title', 'doc_id', etc.
        model_name: Sentence transformer model name
        output_dir: Directory to save indexes
    """
    # Load model
    print(f"Loading model: {model_name}...")
    model = SentenceTransformer(model_name)
    
    # Create output directories
    faiss_dir = os.path.join(output_dir, 'faiss')
    bm25_dir = os.path.join(output_dir, 'bm25')
    os.makedirs(faiss_dir, exist_ok=True)
    os.makedirs(bm25_dir, exist_ok=True)
    
    # Generate embeddings
    print(f"Generating embeddings for {len(documents)} documents...")
    texts = [doc['content'] for doc in documents]
    embeddings = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    embeddings = embeddings.astype('float32')
    
    # Normalize for cosine similarity
    faiss.normalize_L2(embeddings)
    
    # Create FAISS index
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # Inner product for normalized vectors = cosine similarity
    index.add(embeddings)
    
    # Save FAISS index and embeddings
    faiss.write_index(index, os.path.join(faiss_dir, 'faiss.index'))
    np.save(os.path.join(faiss_dir, 'embeddings.npy'), embeddings)
    print(f"✓ FAISS index saved: {index.ntotal} vectors")
    
    # Save chunk metadata
    with open(os.path.join(faiss_dir, 'chunk_metadata.json'), 'w') as f:
        json.dump(documents, f, indent=2)
    print("✓ Chunk metadata saved")
    
    # Create BM25 index
    tokenized_docs = [doc['content'].lower().split() for doc in documents]
    bm25 = BM25Okapi(tokenized_docs)
    
    with open(os.path.join(bm25_dir, 'bm25.pkl'), 'wb') as f:
        pickle.dump(bm25, f)
    print("✓ BM25 index saved")
    
    print(f"\n✓ All indexes created successfully in {output_dir}")

print("✓ Index creation function defined")
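Both `create_bert_indexes` and `hybrid_search` tokenize for BM25 with `.lower().split()`, which leaves punctuation attached to words. For example, the query "How do I process a cancellation?" yields the token `cancellation?`. That token never equals the `cancellation` or `cancellation,` tokens in the sample documents, so BM25 gets no credit for the most important word.

A regex tokenizer is one possible fix, sketched below on the assumption that simple ASCII alphanumeric tokens suit your documents. If you adopt it, use it in both places and rebuild the BM25 index, because query and document tokens must come from the same function.

```python
import re
from typing import List

# Runs of lowercase ASCII letters and digits; everything else acts as a separator.
_TOKEN_RE = re.compile(r"[a-z0-9]+")


def tokenize(text: str) -> List[str]:
    """Lowercase the text and return its alphanumeric tokens, dropping punctuation."""
    return _TOKEN_RE.findall(text.lower())


query = "How do I process a cancellation?"
print(query.lower().split())  # ['how', 'do', 'i', 'process', 'a', 'cancellation?']
print(tokenize(query))        # ['how', 'do', 'i', 'process', 'a', 'cancellation']
```

This pattern also splits hyphenated terms (`two-factor` becomes `two`, `factor`) and drops non-ASCII letters. Adjust it if your documents need either.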

## 6. Example: Create Sample Documents and Index

If you don't have existing indexes, run this cell to create sample data and indexes. It writes to `./indexes`, overwriting any FAISS and BM25 index files already stored there.
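`create_bert_indexes` strictly requires only a `content` string per document. The retriever falls back to defaults for other missing fields, and some of those fallbacks change behavior silently:

- A missing `title` becomes `'Unknown'`, and a missing `doc_id` falls back to the title.
- A missing `entitlement` is treated as `['universal']`, which makes the document visible to everyone.
- A missing `orgId` means any query that passes `org_id` will never return the document.

A quick check before indexing can catch such records. The `validate_documents` helper below is hypothetical and covers only the fields this notebook reads.

```python
from typing import Dict, List


def validate_documents(documents: List[Dict]) -> List[str]:
    """Return human-readable problems; an empty list means every record looks indexable."""
    problems = []
    for i, doc in enumerate(documents):
        label = doc.get("doc_id", f"#{i}")
        content = doc.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"{label}: missing or empty 'content' (required for embedding)")
        for key in ("doc_id", "title"):
            if key not in doc:
                problems.append(f"{label}: no '{key}', the retriever will fall back to a default")
        if "entitlement" not in doc:
            problems.append(f"{label}: no 'entitlement', the document will be treated as universal")
        else:
            ents = doc["entitlement"]
            ents = [ents] if isinstance(ents, str) else ents
            if not isinstance(ents, list) or not all(isinstance(e, str) for e in ents):
                problems.append(f"{label}: 'entitlement' must be a string or a list of strings")
        if "orgId" not in doc:
            problems.append(f"{label}: no 'orgId', queries filtered by org_id will skip it")
        tags = doc.get("metadata", {}).get("tags", [])
        if not isinstance(tags, list):
            problems.append(f"{label}: metadata.tags should be a list")
    return problems
```

For the sample documents defined in the next cell, `validate_documents(SAMPLE_DOCUMENTS)` returns an empty list.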

In [None]:
# Sample documents for demonstration
SAMPLE_DOCUMENTS = [
    {
        'doc_id': 'doc_001',
        'title': 'Cancellation Policy Guide',
        'content': 'To process a cancellation, first verify the customer identity. Then check the booking status in the system. If eligible for refund, initiate the cancellation workflow and process the refund within 5-7 business days.',
        'entitlement': ['agent_support', 'universal'],
        'orgId': 'org_123',
        'metadata': {'tags': ['cancellation', 'refund', 'policy']}
    },
    {
        'doc_id': 'doc_002',
        'title': 'Booking Creation Process',
        'content': 'To create a new booking, collect customer details including name, contact, and travel dates. Search for availability in the system. Select the appropriate option and confirm with payment.',
        'entitlement': ['agent_sales', 'universal'],
        'orgId': 'org_123',
        'metadata': {'tags': ['booking', 'sales']}
    },
    {
        'doc_id': 'doc_003',
        'title': 'Refund Processing Guidelines',
        'content': 'Refunds must be processed within 48 hours of approval. Required documents include original receipt, cancellation confirmation, and customer ID. Refunds are issued to the original payment method.',
        'entitlement': ['agent_support'],
        'orgId': 'org_123',
        'metadata': {'tags': ['refund', 'cancellation']}
    },
    {
        'doc_id': 'doc_004',
        'title': 'Customer Verification Procedures',
        'content': 'Always verify customer identity using two-factor authentication. Check government ID and booking reference number. For sensitive operations, additional security questions may be required.',
        'entitlement': ['universal'],
        'orgId': 'org_123',
        'metadata': {'tags': ['security', 'verification']}
    },
    {
        'doc_id': 'doc_005',
        'title': 'Sales Commission Structure',
        'content': 'Sales agents earn 5% commission on standard bookings and 7% on premium packages. Commissions are calculated monthly and paid on the 15th of each month.',
        'entitlement': ['agent_sales', 'agent_manager'],
        'orgId': 'org_123',
        'metadata': {'tags': ['sales', 'commission']}
    }
]

# Create indexes
create_bert_indexes(SAMPLE_DOCUMENTS, model_name=CONFIG['model']['name'])

print("\n✓ Sample indexes created successfully")

## 7. Initialize the Retriever

In [None]:
# Initialize the BERT-based retriever
retriever = BERTDocumentRetriever(CONFIG)

## 8. Test Document Retrieval

In [None]:
# Test 1: Simple query
print("="*70)
print("TEST 1: Simple Document Retrieval")
print("="*70)

query = "How do I process a cancellation?"
result = retriever.get_relevant_documents(
    query=query,
    entitlement='agent_support',
    org_id='org_123'
)

print(f"\nQuery: {result['query']}")
print(f"\nRelevant Documents Found:")
print("-"*50)

for i, doc in enumerate(result['documents'], 1):
    print(f"\n{i}. {doc['document_name']}")
    print(f"   Doc ID: {doc['doc_id']}")
    print(f"   Score: {doc['score']:.4f}")
    print(f"   Preview: {doc['chunk_preview'][:100]}...")

In [None]:
# Test 2: Query with session
print("\n" + "="*70)
print("TEST 2: Document Retrieval with Session")
print("="*70)

# Create session
session_id = retriever.create_session(
    user_id='agent_001',
    entitlement='agent_support',
    org_id='org_123'
)

# Query 1
print("\n--- Query 1 ---")
result1 = retriever.query_with_session(
    session_id=session_id,
    query="What documents do I need for a refund?"
)

print(f"Query: {result1['query']}")
print(f"Documents found: {[d['document_name'] for d in result1['documents']]}")

# Query 2
print("\n--- Query 2 ---")
result2 = retriever.query_with_session(
    session_id=session_id,
    query="How do I verify customer identity?"
)

print(f"Query: {result2['query']}")
print(f"Documents found: {[d['document_name'] for d in result2['documents']]}")

# Show session history
print("\n--- Session History ---")
history = retriever.get_session_history(session_id)
for item in history:
    print(f"  [{item['timestamp'][:19]}] {item['query']}")
    print(f"    → Found: {item['documents_found']}")

In [None]:
# Test 3: Entitlement-based filtering
print("\n" + "="*70)
print("TEST 3: Entitlement-Based Access Control")
print("="*70)

same_query = "What are the sales commission rates?"

# Support agent (should not see sales commission doc)
print("\n--- Support Agent ---")
support_result = retriever.get_relevant_documents(
    query=same_query,
    entitlement='agent_support',
    org_id='org_123'
)
print(f"Query: {same_query}")
print(f"Documents found: {[d['document_name'] for d in support_result['documents']]}")

# Sales agent (should see sales commission doc)
print("\n--- Sales Agent ---")
sales_result = retriever.get_relevant_documents(
    query=same_query,
    entitlement='agent_sales',
    org_id='org_123'
)
print(f"Query: {same_query}")
print(f"Documents found: {[d['document_name'] for d in sales_result['documents']]}")
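The access rules this test exercises live inside `hybrid_search`. Pulling them out as a standalone predicate lets you unit-test them against the sample records without loading a model. The sketch below does this; `is_visible` is an illustrative name, not a method of `BERTDocumentRetriever`. The logic matches `hybrid_search`. A chunk passes if all of the following hold:

- It is tagged `universal` or carries the caller's entitlement.
- It belongs to the requested organization, when one is given.
- It shares at least one tag with the requested tags, when any are given.

```python
from typing import Dict, List, Optional


def is_visible(chunk: Dict, entitlement: str, org_id: Optional[str] = None,
               tags: Optional[List[str]] = None) -> bool:
    """Apply the entitlement, organization, and tag filters used by hybrid_search."""
    allowed = chunk.get("entitlement", ["universal"])
    if isinstance(allowed, str):
        allowed = [allowed]
    if "universal" not in allowed and entitlement not in allowed:
        return False
    if org_id and chunk.get("orgId") != org_id:
        return False
    chunk_tags = chunk.get("metadata", {}).get("tags", [])
    if tags and not any(t in chunk_tags for t in tags):
        return False
    return True


commission_doc = {"entitlement": ["agent_sales", "agent_manager"], "orgId": "org_123",
                  "metadata": {"tags": ["sales", "commission"]}}
print(is_visible(commission_doc, "agent_support", "org_123"))  # False
print(is_visible(commission_doc, "agent_sales", "org_123"))    # True
```

Because `universal` grants access to everyone, `Booking Creation Process` (entitled to `agent_sales` and `universal`) stays visible to support agents, while `Sales Commission Structure` does not.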

## 9. Utility Function: Pretty Print Results

In [None]:
def display_results(result: Dict, show_preview: bool = True):
    """
    Pretty print retrieval results.
    """
    print("\n" + "="*60)
    print(f"Query: {result['query']}")
    print("="*60)
    
    if not result['documents']:
        print("\n❌ No relevant documents found.")
        return
    
    print(f"\n✓ Found {len(result['documents'])} relevant document(s):")
    print("-"*60)
    
    for i, doc in enumerate(result['documents'], 1):
        print(f"\n📄 {i}. {doc['document_name']}")
        print(f"   Relevance Score: {doc['score']:.4f}")
        print(f"   Document ID: {doc['doc_id']}")
        
        if show_preview:
            preview = doc['chunk_preview']
            print(f"   Preview: {preview[:150]}{'...' if len(preview) > 150 else ''}")
        
        if doc.get('metadata', {}).get('tags'):
            print(f"   Tags: {', '.join(doc['metadata']['tags'])}")

# Example usage
result = retriever.get_relevant_documents(
    query="How do I create a new booking?",
    entitlement='agent_sales',
    org_id='org_123'
)

display_results(result)
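Previews are cut at a fixed character count, so a preview can end in the middle of a word. If that matters for display, a word-boundary variant such as the hypothetical `make_preview` below could replace the slicing in `get_relevant_documents`. It appends the ellipsis only when text was actually removed.

```python
def make_preview(text: str, limit: int = 200, suffix: str = "...") -> str:
    """Truncate text to at most `limit` characters without splitting a word."""
    if len(text) <= limit:
        return text
    cut = text[:limit]
    # If the cut lands inside a word, back off to the previous space (when there is one)
    if not text[limit].isspace() and " " in cut:
        cut = cut.rsplit(" ", 1)[0]
    return cut.rstrip() + suffix


print(make_preview("Refunds are issued to the original payment method.", limit=20))
# Refunds are issued...
```

The result is at most `limit` characters plus the suffix. A single word longer than `limit` is still cut mid-word, since there is no earlier space to fall back to.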

## 10. Interactive Query Mode

In [None]:
def interactive_query_mode(retriever, entitlement: str = 'agent_support', org_id: str = 'org_123'):
    """
    Interactive mode to query documents.
    Type 'quit' to exit.
    """
    print("\n" + "="*60)
    print("📚 Interactive Document Retrieval")
    print("="*60)
    print(f"Entitlement: {entitlement}")
    print(f"Organization: {org_id}")
    print("Type 'quit' to exit.\n")
    
    while True:
        query = input("\n🔍 Enter your query: ").strip()
        
        if query.lower() == 'quit':
            print("\nExiting interactive mode. Goodbye!")
            break
        
        if not query:
            print("Please enter a valid query.")
            continue
        
        result = retriever.get_relevant_documents(
            query=query,
            entitlement=entitlement,
            org_id=org_id
        )
        
        display_results(result)

# Uncomment to run interactive mode
# interactive_query_mode(retriever, entitlement='agent_support', org_id='org_123')

## 11. Summary

### Key Differences from LLM Approach:

| Feature | LLM Approach | BERT Approach |
|---------|--------------|---------------|
| Output | Generated answer text | Document names/IDs |
| Model | Large LLM (Llama, etc.) | BERT/DistilBERT |
| Compute | High (LLM inference) | Low (embedding only) |
| Latency | Higher | Lower |
| Cost | Higher (LLM endpoint) | Lower (local/smaller model) |
| Use Case | Q&A, chatbots | Document search, retrieval |

### Available Models:

- `all-MiniLM-L6-v2` - Fast, 384 dimensions
- `all-mpnet-base-v2` - Best quality, 768 dimensions
- `multi-qa-distilbert-cos-v1` - Q&A optimized, 768 dimensions
- `distilbert-base-nli-stsb-mean-tokens` - DistilBERT, 768 dimensions
- `bert-base-nli-mean-tokens` - BERT, 768 dimensions

In [None]:
print("\n" + "="*60)
print("✓ NOTEBOOK COMPLETE")
print("="*60)
print("\nThe BERTDocumentRetriever is ready to use!")
print("\nKey methods:")
print("  - get_relevant_documents(query, entitlement, ...) → Returns document names")
print("  - create_session(user_id, entitlement, ...) → Creates a session")
print("  - query_with_session(session_id, query, ...) → Query with session tracking")