# Document Retrieval with Qwen3-Reranker-8B API

This notebook uses **Qwen3-Reranker-8B**, deployed behind an inference-engine API, to re-rank retrieved documents.

## Pipeline:
```
Query → Hybrid Search (Qwen Embeddings + BM25) → Candidates → Qwen3-Reranker API → Ranked Documents
```
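
The hybrid search stage, as implemented in `hybrid_search` later in this notebook, min-max normalizes the vector and BM25 scores and combines them with configurable weights. The following is a minimal sketch of that fusion on toy scores. It assumes both score arrays are already aligned by chunk index, which is a simplification of the notebook's code; the function names `min_max` and `fuse_scores` and the example weights are illustrative assumptions, not values from the notebook's configuration.

```python
import numpy as np

def min_max(scores: np.ndarray) -> np.ndarray:
    """Scale scores to [0, 1]; return zeros when all scores are (nearly) equal."""
    lo, hi = scores.min(), scores.max()
    if hi - lo < 1e-10:
        return np.zeros_like(scores, dtype=float)
    return (scores - lo) / (hi - lo)

def fuse_scores(vector_scores, bm25_scores, vector_weight=0.7, bm25_weight=0.3):
    """Weighted sum of min-max normalized scores (the weights are example values)."""
    v = min_max(np.asarray(vector_scores, dtype=float))
    b = min_max(np.asarray(bm25_scores, dtype=float))
    return vector_weight * v + bm25_weight * b

# Toy scores for three chunks, aligned by chunk index
fused = fuse_scores([0.9, 0.5, 0.1], [2.0, 8.0, 0.0])
ranking = np.argsort(-fused)  # chunk indices, best first
print(fused, ranking)
```

Normalizing before mixing keeps either signal from dominating purely because of its raw scale, since BM25 scores are unbounded while similarity scores typically are not.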

## Components:
- **Embeddings**: Qwen (SageMaker endpoint) - unchanged
- **Re-Ranker**: Qwen3-Reranker-8B (your deployed API endpoint)

## 1. Imports and Configuration
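
The cells in this notebook read several nested keys from `config/config.yaml`. As a hedged sketch, a small check like the one below can report missing keys before any client is created. The key paths are taken from the code in this notebook; the helper name `missing_config_keys` and the sample dictionary values are hypothetical.

```python
from typing import Any, Dict, List

# Key paths read by the code in this notebook
REQUIRED_KEYS = [
    "models.embedding.endpoint_name",
    "models.embedding.credentials.region",
    "models.embedding.credentials.accessKeyId",
    "models.embedding.credentials.secretAccessKey",
    "models.embedding.credentials.sessionToken",
    "storage.faiss_index",
    "storage.bm25_index",
    "retrieval.hybrid.top_k",
    "retrieval.hybrid.vector_weight",
    "retrieval.hybrid.bm25_weight",
]

def missing_config_keys(config: Dict[str, Any], required: List[str] = REQUIRED_KEYS) -> List[str]:
    """Return the dotted key paths that are absent from a nested config dict."""
    missing = []
    for path in required:
        node: Any = config
        for part in path.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing

# Placeholder values for illustration only
sample = {
    "models": {"embedding": {"endpoint_name": "example", "credentials": {"region": "us-east-1"}}},
    "storage": {"faiss_index": "indexes/faiss"},
}
print(missing_config_keys(sample))
```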

In [None]:
import yaml
import json
import os
import pickle
from typing import List, Dict, Optional
import numpy as np
import faiss
import boto3
import uuid
from datetime import datetime
import requests

print("✓ All libraries imported successfully")

In [None]:
# Load configuration
with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("✓ Configuration loaded")

## 2. Qwen3-Reranker API Configuration

Configure your deployed Qwen3-Reranker-8B endpoint details below.

In [None]:
# =============================================================================
# CONFIGURE YOUR QWEN3-RERANKER-8B API ENDPOINT HERE
# =============================================================================

# Option 1: If deployed on SageMaker (similar to your embedding endpoint)
RERANKER_CONFIG = {
    'type': 'sagemaker',  # or 'rest_api'
    'endpoint_name': 'your-qwen3-reranker-endpoint',  # Replace with your endpoint name
    'region': config['models']['embedding']['credentials']['region'],
    'credentials': {
        'accessKeyId': config['models']['embedding']['credentials']['accessKeyId'],
        'secretAccessKey': config['models']['embedding']['credentials']['secretAccessKey'],
        'sessionToken': config['models']['embedding']['credentials']['sessionToken']
    }
}

# Option 2: If deployed as REST API
# RERANKER_CONFIG = {
#     'type': 'rest_api',
#     'url': 'https://your-reranker-api-endpoint.com/rerank',
#     'headers': {
#         'Content-Type': 'application/json',
#         'Authorization': 'Bearer YOUR_API_KEY'  # if needed
#     }
# }

print(f"Re-ranker type: {RERANKER_CONFIG['type']}")

## 3. Qwen3 Re-Ranker API Client

In [None]:
class Qwen3RerankerClient:
    """
    Client for Qwen3-Reranker-8B deployed via inference engine.
    Supports both SageMaker endpoint and REST API.
    """
    
    def __init__(self, config: dict):
        self.config = config
        self.api_type = config['type']
        
        if self.api_type == 'sagemaker':
            self._init_sagemaker_client()
        elif self.api_type == 'rest_api':
            self._init_rest_client()
        else:
            raise ValueError(f"Unknown API type: {self.api_type}")
        
        print(f"✓ Qwen3-Reranker client initialized ({self.api_type})")
    
    def _init_sagemaker_client(self):
        """Initialize SageMaker runtime client"""
        self.endpoint_name = self.config['endpoint_name']
        creds = self.config['credentials']
        
        self.client = boto3.client(
            'sagemaker-runtime',
            region_name=self.config['region'],
            aws_access_key_id=creds['accessKeyId'],
            aws_secret_access_key=creds['secretAccessKey'],
            aws_session_token=creds['sessionToken']
        )
    
    def _init_rest_client(self):
        """Initialize REST API client"""
        self.url = self.config['url']
        self.headers = self.config.get('headers', {'Content-Type': 'application/json'})
    
    def rerank(self, query: str, documents: List[str]) -> List[float]:
        """
        Get relevance scores for query-document pairs.
        
        Args:
            query: The search query
            documents: List of document contents (full text)
        
        Returns:
            List of relevance scores (one per document)
        """
        if self.api_type == 'sagemaker':
            return self._rerank_sagemaker(query, documents)
        else:
            return self._rerank_rest_api(query, documents)
    
    def _rerank_sagemaker(self, query: str, documents: List[str]) -> List[float]:
        """
        Call Qwen3-Reranker via SageMaker endpoint.
        
        Adjust the request format based on your deployment configuration.
        """
        # Format request payload
        # Adjust this based on your inference engine's expected format
        payload = {
            "query": query,
            "documents": documents
        }
        
        # Alternative format if your endpoint expects pairs:
        # payload = {
        #     "pairs": [[query, doc] for doc in documents]
        # }
        
        body = json.dumps(payload)
        
        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType='application/json',
            Body=body
        )
        
        output_data = json.loads(response['Body'].read().decode())
        
        # Extract scores from response
        # Adjust based on your endpoint's response format
        if isinstance(output_data, list):
            scores = output_data
        elif 'scores' in output_data:
            scores = output_data['scores']
        elif 'results' in output_data:
            scores = [r['score'] for r in output_data['results']]
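        else:
            # Fail with a clear message instead of trying to convert dict keys to floats
            raise ValueError(f"Unrecognized re-ranker response format: {type(output_data).__name__}")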
        
        return [float(s) for s in scores]
    
    def _rerank_rest_api(self, query: str, documents: List[str]) -> List[float]:
        """
        Call Qwen3-Reranker via REST API.
        """
        payload = {
            "query": query,
            "documents": documents
        }
        
        response = requests.post(
            self.url,
            headers=self.headers,
            json=payload,
            timeout=60  # seconds; avoids hanging indefinitely if the endpoint is unresponsive
        )
        response.raise_for_status()
        
        output_data = response.json()
        
        # Extract scores from response
        if isinstance(output_data, list):
            scores = output_data
        elif 'scores' in output_data:
            scores = output_data['scores']
        elif 'results' in output_data:
            scores = [r['score'] for r in output_data['results']]
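        else:
            # Fail with a clear message instead of trying to convert dict keys to floats
            raise ValueError(f"Unrecognized re-ranker response format: {type(output_data).__name__}")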
        
        return [float(s) for s in scores]

print("✓ Qwen3RerankerClient class defined")

## 4. Document Retriever with Qwen3 Re-Ranking

In [None]:
class Qwen3ReRankingRetriever:
    """
    Two-stage document retriever:
    
    Stage 1: Hybrid search (Qwen Embeddings + BM25) for candidate retrieval
    Stage 2: Qwen3-Reranker-8B API scores query against FULL document content
    
    Returns ranked document names.
    """
    
    def __init__(self, config: dict, reranker_config: dict):
        self.config = config
        
        # ============================================================
        # Initialize Qwen Embedding Client (UNCHANGED from original)
        # ============================================================
        self.embedding_endpoint_name = config['models']['embedding']['endpoint_name']
        embedding_creds = config['models']['embedding']['credentials']
        self.embedding_client = boto3.client(
            'sagemaker-runtime',
            region_name=embedding_creds['region'],
            aws_access_key_id=embedding_creds['accessKeyId'],
            aws_secret_access_key=embedding_creds['secretAccessKey'],
            aws_session_token=embedding_creds['sessionToken']
        )
        print("✓ Qwen Embedding client initialized")
        
        # ============================================================
        # Initialize Qwen3-Reranker-8B API Client (REPLACES Llama LLM)
        # ============================================================
        self.reranker = Qwen3RerankerClient(reranker_config)
        
        # Session management
        self.sessions = {}
        
        # Load indexes
        self.load_indexes()
    
    def load_indexes(self):
        """Load FAISS and BM25 indexes"""
        
        # Load FAISS index
        faiss_path = os.path.join(self.config['storage']['faiss_index'], 'faiss.index')
        if not os.path.exists(faiss_path):
            raise FileNotFoundError(f"FAISS index not found at {faiss_path}")
        self.faiss_index = faiss.read_index(faiss_path)
        print("✓ FAISS index loaded")
        
        # Load embeddings
        embeddings_path = os.path.join(self.config['storage']['faiss_index'], 'embeddings.npy')
        if os.path.exists(embeddings_path):
            self.embeddings = np.load(embeddings_path)
            print(f"✓ Embeddings loaded: shape {self.embeddings.shape}")
        
        # Load BM25 index
        bm25_path = os.path.join(self.config['storage']['bm25_index'], 'bm25.pkl')
        if not os.path.exists(bm25_path):
            raise FileNotFoundError(f"BM25 index not found at {bm25_path}")
        with open(bm25_path, 'rb') as f:
            self.bm25_index = pickle.load(f)
        print("✓ BM25 index loaded")
        
        # Load chunk metadata (contains full content)
        metadata_path = os.path.join(self.config['storage']['faiss_index'], 'chunk_metadata.json')
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Chunk metadata not found at {metadata_path}")
        with open(metadata_path, 'r') as f:
            self.chunks = json.load(f)
        print(f"✓ Chunk metadata loaded: {len(self.chunks)} chunks")
        
        print("\n✓ All indexes loaded successfully")
    
    def get_embedding(self, text: str) -> np.ndarray:
        """Get embedding from Qwen SageMaker endpoint (UNCHANGED)"""
        params = {
            "inputs": [text],
            "encoding_format": "float"
        }
        body = json.dumps(params)
        
        response = self.embedding_client.invoke_endpoint(
            EndpointName=self.embedding_endpoint_name,
            ContentType='application/json',
            Body=body
        )
        output_data = json.loads(response['Body'].read().decode())
        embedding = np.array(output_data[0], dtype='float32')
        return embedding
    
    def hybrid_search(self, query: str, entitlement: str, org_id: str = None,
                      tags: List[str] = None, top_k: int = None) -> List[Dict]:
        """
        Stage 1: Hybrid search to retrieve candidate chunks.
        Returns chunks with FULL CONTENT for re-ranking.
        """
        if top_k is None:
            top_k = self.config['retrieval']['hybrid']['top_k']
        
        # Get query embedding from Qwen
        query_embedding = self.get_embedding(query)
        query_embedding = query_embedding.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_embedding)
        
        # Retrieve more candidates for re-ranking
        retrieval_multiplier = 10
        initial_top_k = min(top_k * retrieval_multiplier, len(self.chunks))
        
        # Vector search (FAISS)
        vector_scores, vector_indices = self.faiss_index.search(query_embedding, initial_top_k)
        vector_scores = vector_scores[0]
        vector_indices = vector_indices[0]
        
        # Keyword search (BM25)
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25_index.get_scores(tokenized_query)
        
        # Normalize scores
        def normalize(scores):
            min_s, max_s = scores.min(), scores.max()
            if max_s - min_s < 1e-10:
                return np.zeros_like(scores)
            return (scores - min_s) / (max_s - min_s)
        
        vector_scores_norm = normalize(vector_scores)
        bm25_scores_norm = normalize(bm25_scores)
        
        # Compute hybrid scores
        vector_weight = self.config['retrieval']['hybrid']['vector_weight']
        bm25_weight = self.config['retrieval']['hybrid']['bm25_weight']
        
        hybrid_scores = {}
        for idx, score in zip(vector_indices, vector_scores_norm):
            if idx < 0:
                # FAISS pads results with -1 when fewer than k neighbors are found
                continue
            hybrid_scores[int(idx)] = score * vector_weight
        
        for idx, score in enumerate(bm25_scores_norm):
            if idx in hybrid_scores:
                hybrid_scores[idx] += score * bm25_weight
            else:
                hybrid_scores[idx] = score * bm25_weight
        
        # Sort by hybrid score
        sorted_indices = sorted(hybrid_scores.items(), key=lambda x: x[1], reverse=True)
        
        # Filter by entitlements and collect results WITH FULL CONTENT
        accessible_results = []
        
        for idx, score in sorted_indices:
            chunk = self.chunks[idx].copy()
            
            # Entitlement filter
            chunk_entitlements = chunk['entitlement']
            if isinstance(chunk_entitlements, str):
                chunk_entitlements = [chunk_entitlements]
            
            has_access = 'universal' in chunk_entitlements or entitlement in chunk_entitlements
            if not has_access:
                continue
            
            # Org filter
            if org_id and chunk['orgId'] != org_id:
                continue
            
            # Tag filter
            if tags and not any(t in chunk['metadata']['tags'] for t in tags):
                continue
            
            chunk['hybrid_score'] = float(score)
            accessible_results.append(chunk)
        
        accessible_results.sort(key=lambda x: x['hybrid_score'], reverse=True)
        return accessible_results[:top_k]
    
    def rerank(self, query: str, candidates: List[Dict], top_k: int = 5) -> List[Dict]:
        """
        Stage 2: Re-rank candidates using Qwen3-Reranker-8B API.
        
        The re-ranker receives FULL DOCUMENT CONTENT, just like the LLM did.
        """
        if not candidates:
            return []
        
        # Extract full document content for re-ranking
        documents = [chunk['content'] for chunk in candidates]
        
        # Call Qwen3-Reranker-8B API with query and full document contents
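        rerank_scores = self.reranker.rerank(query, documents)
        if len(rerank_scores) != len(candidates):
            raise ValueError(
                f"Re-ranker returned {len(rerank_scores)} scores for {len(candidates)} documents"
            )
        
        # Add rerank scores to candidates
        for chunk, score in zip(candidates, rerank_scores):
            chunk['rerank_score'] = score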
        
        # Sort by rerank score (highest first)
        reranked = sorted(candidates, key=lambda x: x['rerank_score'], reverse=True)
        
        return reranked[:top_k]
    
    def query(self, query: str, entitlement: str, org_id: str = None,
              tags: List[str] = None, top_k: int = 5,
              candidates_for_rerank: int = 20) -> Dict:
        """
        Full retrieval pipeline with Qwen3-Reranker-8B.
        
        Args:
            query: User's search query
            entitlement: User's access level
            org_id: Organization filter
            tags: Tag filters
            top_k: Number of final results
            candidates_for_rerank: Number of candidates for re-ranking
        
        Returns:
            Dict with ranked documents
        """
        # Stage 1: Get candidates via hybrid search
        candidates = self.hybrid_search(
            query=query,
            entitlement=entitlement,
            org_id=org_id,
            tags=tags,
            top_k=candidates_for_rerank
        )
        
        if not candidates:
            return {
                'query': query,
                'documents': [],
                'message': 'No relevant documents found.'
            }
        
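        # Stage 2: Re-rank using Qwen3-Reranker-8B with full document content.
        # Keep every re-ranked candidate here so that de-duplication by doc_id
        # below can still return up to top_k distinct documents.
        reranked = self.rerank(query, candidates, top_k=len(candidates))
        
        # Build response with unique documents
        seen_docs = set()
        documents = []
        
        for chunk in reranked:
            if len(documents) >= top_k:
                break
            doc_id = chunk['doc_id']
            if doc_id not in seen_docs: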
                seen_docs.add(doc_id)
                documents.append({
                    'document_name': chunk['title'],
                    'doc_id': doc_id,
                    'rerank_score': chunk['rerank_score'],
                    'hybrid_score': chunk['hybrid_score'],
                    'content_preview': chunk['content'][:200] + '...' if len(chunk['content']) > 200 else chunk['content']
                })
        
        return {
            'query': query,
            'documents': documents,
            'candidates_considered': len(candidates)
        }
    
    # ==================== Session Management ====================
    
    def create_session(self, user_id: str, entitlement: str, org_id: str = None) -> str:
        """Create a new session"""
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'session_id': session_id,
            'user_id': user_id,
            'entitlement': entitlement,
            'org_id': org_id,
            'created_at': datetime.now().isoformat(),
            'query_history': [],
            'last_activity': datetime.now().isoformat()
        }
        print(f"✓ Created session: {session_id}")
        return session_id
    
    def get_session(self, session_id: str) -> Optional[Dict]:
        """Get session data"""
        return self.sessions.get(session_id)
    
    def query_with_session(self, session_id: str, query: str,
                           tags: List[str] = None, top_k: int = 5,
                           candidates_for_rerank: int = 20) -> Dict:
        """Query with session tracking"""
        session = self.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        
        session['last_activity'] = datetime.now().isoformat()
        
        result = self.query(
            query=query,
            entitlement=session['entitlement'],
            org_id=session['org_id'],
            tags=tags,
            top_k=top_k,
            candidates_for_rerank=candidates_for_rerank
        )
        
        result['session_id'] = session_id
        
        # Store in history
        session['query_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': query,
            'documents_found': [d['document_name'] for d in result['documents']]
        })
        
        return result
    
    def get_query_history(self, session_id: str, limit: int = None) -> List[Dict]:
        """Get query history"""
        session = self.get_session(session_id)
        if not session:
            return []
        history = session['query_history']
        return history[-limit:] if limit else history
    
    def clear_session(self, session_id: str):
        """Clear session"""
        if session_id in self.sessions:
            del self.sessions[session_id]
            print(f"✓ Session {session_id} cleared")
    
    def export_session(self, session_id: str, filepath: str):
        """Export session to JSON"""
        session = self.get_session(session_id)
        if session:
            with open(filepath, 'w') as f:
                json.dump(session, f, indent=2)
            print(f"✓ Session exported to: {filepath}")

In [None]:
print("✓ Qwen3ReRankingRetriever class defined")

## 5. Initialize Retriever

In [None]:
# Initialize the retriever with Qwen3-Reranker-8B API
retriever = Qwen3ReRankingRetriever(
    config=config,
    reranker_config=RERANKER_CONFIG
)

## 6. Test: Basic Query

In [None]:
print("="*70)
print("TEST 1: Basic Query with Qwen3-Reranker-8B")
print("="*70)

result = retriever.query(
    query='How do I process a cancellation?',
    entitlement='agent_support',
    org_id='org_123',
    tags=['cancellation'],
    top_k=5,
    candidates_for_rerank=20
)

print(f"\nQuery: {result['query']}")
print(f"Candidates considered: {result.get('candidates_considered', 'N/A')}")
print(f"\nRanked Documents:")
print("-"*50)

for i, doc in enumerate(result['documents'], 1):
    print(f"\n{i}. {doc['document_name']}")
    print(f"   Re-rank Score: {doc['rerank_score']:.4f}")
    print(f"   Hybrid Score:  {doc['hybrid_score']:.4f}")

## 7. Test: Compare Hybrid vs Re-Ranked
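
When comparing the two stages, it can help to quantify how far each document moved. The sketch below computes, for titles present in both lists, the change in 1-based rank between the hybrid-only ordering and the re-ranked ordering. The helper `rank_shifts` and the sample titles are hypothetical and are not part of the retriever.

```python
from typing import Dict, List

def rank_shifts(before: List[str], after: List[str]) -> Dict[str, int]:
    """Map each shared title to (rank before - rank after); positive means it moved up."""
    before_rank: Dict[str, int] = {}
    for rank, title in enumerate(before, 1):
        # Keep the first (best) rank when several chunks share a title
        before_rank.setdefault(title, rank)
    shifts: Dict[str, int] = {}
    for rank, title in enumerate(after, 1):
        if title in before_rank and title not in shifts:
            shifts[title] = before_rank[title] - rank
    return shifts

before = ["Refund Policy", "Cancellation Guide", "Booking FAQ"]
after = ["Cancellation Guide", "Refund Policy", "Pricing Sheet"]
print(rank_shifts(before, after))
```

With the variables from the test cell in this section, `before` could be built as `[c['title'] for c in hybrid_only]` and `after` as `[d['document_name'] for d in reranked['documents']]`.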

In [None]:
print("\n" + "="*70)
print("TEST 2: Compare Hybrid Search vs Qwen3 Re-Ranked Results")
print("="*70)

test_query = "What documents are needed for a refund?"

# Get hybrid search results (Stage 1 only)
hybrid_only = retriever.hybrid_search(
    query=test_query,
    entitlement='agent_support',
    org_id='org_123',
    top_k=5
)

# Get full re-ranked results
reranked = retriever.query(
    query=test_query,
    entitlement='agent_support',
    org_id='org_123',
    top_k=5,
    candidates_for_rerank=20
)

print(f"\nQuery: {test_query}")

print(f"\n{'BEFORE RE-RANKING (Hybrid Only)':^50}")
print("-"*50)
for i, chunk in enumerate(hybrid_only, 1):
    print(f"  {i}. {chunk['title']} (hybrid: {chunk['hybrid_score']:.4f})")

print(f"\n{'AFTER QWEN3 RE-RANKING':^50}")
print("-"*50)
for i, doc in enumerate(reranked['documents'], 1):
    print(f"  {i}. {doc['document_name']} (rerank: {doc['rerank_score']:.4f})")

## 8. Test: Session-Based Queries

In [None]:
print("\n" + "="*70)
print("TEST 3: Session-Based Queries")
print("="*70)

# Create session
session_id = retriever.create_session(
    user_id='agent_001',
    entitlement='agent_support',
    org_id='org_123'
)

# Query 1
print("\n--- Query 1 ---")
r1 = retriever.query_with_session(session_id=session_id, query="How do I cancel a booking?")
print(f"Query: {r1['query']}")
print(f"Top Document: {r1['documents'][0]['document_name'] if r1['documents'] else 'None'}")

# Query 2
print("\n--- Query 2 ---")
r2 = retriever.query_with_session(session_id=session_id, query="What is the refund timeline?")
print(f"Query: {r2['query']}")
print(f"Top Document: {r2['documents'][0]['document_name'] if r2['documents'] else 'None'}")

# Show history
print("\n--- Session History ---")
history = retriever.get_query_history(session_id)
for i, h in enumerate(history, 1):
    print(f"  {i}. {h['query']} → {h['documents_found']}")

## 9. Test: Entitlement Filtering
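
The access check exercised by this test, as written in `hybrid_search`, grants access when a chunk's entitlement list contains either `'universal'` or the caller's entitlement. A standalone sketch of that rule follows; the function name `has_access` is an illustrative assumption.

```python
from typing import List, Union

def has_access(chunk_entitlement: Union[str, List[str]], user_entitlement: str) -> bool:
    """Return True if the chunk is universal or lists the user's entitlement."""
    # A single entitlement may be stored as a bare string; treat it as a one-item list
    if isinstance(chunk_entitlement, str):
        chunk_entitlement = [chunk_entitlement]
    return "universal" in chunk_entitlement or user_entitlement in chunk_entitlement

print(has_access("universal", "agent_sales"))
print(has_access(["agent_support"], "agent_sales"))
```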

In [None]:
print("\n" + "="*70)
print("TEST 4: Entitlement-Based Access Control")
print("="*70)

query = "What are the booking procedures?"

# Support agent
support_result = retriever.query(query=query, entitlement='agent_support', org_id='org_123')

# Sales agent
sales_result = retriever.query(query=query, entitlement='agent_sales', org_id='org_123')

print(f"\nQuery: {query}")

print(f"\nSupport Agent Results ({len(support_result['documents'])} docs):")
for doc in support_result['documents']:
    print(f"  📄 {doc['document_name']} (score: {doc['rerank_score']:.4f})")

print(f"\nSales Agent Results ({len(sales_result['documents'])} docs):")
for doc in sales_result['documents']:
    print(f"  📄 {doc['document_name']} (score: {doc['rerank_score']:.4f})")

## 10. Interactive Mode

In [None]:
USER_PROFILES = {
    '1': {'user_id': 'agent_001', 'name': 'Alice (Support)', 'entitlement': 'agent_support', 'org_id': 'org_123'},
    '2': {'user_id': 'agent_002', 'name': 'Bob (Sales)', 'entitlement': 'agent_sales', 'org_id': 'org_123'},
    '3': {'user_id': 'manager_001', 'name': 'Carol (Manager)', 'entitlement': 'agent_manager', 'org_id': 'org_123'}
}

def interactive_mode():
    """Interactive query mode"""
    print("\n" + "="*50)
    print("Interactive Document Retrieval (Qwen3-Reranker-8B)")
    print("="*50)
    
    print("\nSelect User Profile:")
    for key, profile in USER_PROFILES.items():
        print(f"  {key}. {profile['name']}")
    
    choice = input("\nChoice (1-3): ").strip()
    if choice not in USER_PROFILES:
        print("Invalid choice")
        return
    
    profile = USER_PROFILES[choice]
    session_id = retriever.create_session(
        user_id=profile['user_id'],
        entitlement=profile['entitlement'],
        org_id=profile['org_id']
    )
    
    print(f"\nLogged in as: {profile['name']}")
    print("Commands: 'quit' to exit, 'history' to see past queries\n")
    
    while True:
        query = input("You: ").strip()
        
        if query.lower() == 'quit':
            print("Goodbye!")
            break
        
        if query.lower() == 'history':
            history = retriever.get_query_history(session_id)
            for h in history:
                print(f"  [{h['timestamp'][:19]}] {h['query']} → {h['documents_found']}")
            print()
            continue
        
        if not query:
            continue
        
        result = retriever.query_with_session(session_id=session_id, query=query)
        
        print("\nRelevant Documents:")
        if result['documents']:
            for i, doc in enumerate(result['documents'], 1):
                print(f"  {i}. {doc['document_name']} (score: {doc['rerank_score']:.4f})")
        else:
            print("  No relevant documents found.")
        print()

# Uncomment to run:
# interactive_mode()

## 11. Summary

### Pipeline:
```
Query
  ↓
Hybrid Search (Qwen Embeddings + BM25)
  ↓
20 Candidate Chunks (with full content)
  ↓
Qwen3-Reranker-8B API (scores query vs full content)
  ↓
Top 5 Ranked Documents
```
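
Because several candidate chunks can belong to the same document, the last step collapses ranked chunks into unique documents. A minimal sketch of that step, assuming each chunk carries `doc_id`, `title`, and `rerank_score` fields as in the notebook's chunk metadata (the helper name `top_documents` and the sample data are hypothetical):

```python
from typing import Dict, List

def top_documents(chunks: List[Dict], top_k: int = 5) -> List[Dict]:
    """Keep the best-scoring chunk per doc_id and return up to top_k documents."""
    ranked = sorted(chunks, key=lambda c: c["rerank_score"], reverse=True)
    seen = set()
    documents: List[Dict] = []
    for chunk in ranked:
        if len(documents) >= top_k:
            break
        if chunk["doc_id"] in seen:
            continue  # a higher-scoring chunk of this document was already kept
        seen.add(chunk["doc_id"])
        documents.append({
            "doc_id": chunk["doc_id"],
            "title": chunk["title"],
            "rerank_score": chunk["rerank_score"],
        })
    return documents

chunks = [
    {"doc_id": "d1", "title": "Refund Policy", "rerank_score": 0.71},
    {"doc_id": "d2", "title": "Cancellation Guide", "rerank_score": 0.92},
    {"doc_id": "d1", "title": "Refund Policy", "rerank_score": 0.85},
]
docs = top_documents(chunks, top_k=2)
print(docs)
```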

### API Request Format:
```python
# Request to Qwen3-Reranker-8B API:
{
    "query": "How do I process a cancellation?",
    "documents": [
        "Full content of document 1...",
        "Full content of document 2...",
        ...
    ]
}

# Response:
{
    "scores": [0.92, 0.85, 0.71, ...]
}
```

### Note:
Adjust the request/response format in `Qwen3RerankerClient` based on your inference engine's API specification.
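
As one way to make that adjustment in a single place, the sketch below normalizes a few plausible response shapes into a list of scores aligned with the input documents. The shapes handled here (a bare list of numbers, a `scores` field, and a `results` list whose items may carry an `index` field) are assumptions about what an inference engine might return, not a specification of any particular engine, and `extract_scores` is a hypothetical helper.

```python
from typing import Any, List

def extract_scores(output: Any, n_docs: int) -> List[float]:
    """Return one float score per input document, in input order."""
    if isinstance(output, dict) and "scores" in output:
        items = output["scores"]
    elif isinstance(output, dict) and "results" in output:
        items = output["results"]
    elif isinstance(output, list):
        items = output
    else:
        raise ValueError(f"Unrecognized response shape: {type(output).__name__}")

    if len(items) != n_docs:
        raise ValueError(f"Expected {n_docs} scores, got {len(items)}")

    scores: List[Any] = [None] * n_docs
    for position, item in enumerate(items):
        if isinstance(item, dict):
            # Assumed shape: results may be sorted by score, with the
            # original document position stored under 'index'
            idx = item.get("index", position)
            value = item["score"]
        else:
            idx, value = position, item
        if not 0 <= idx < n_docs:
            raise ValueError(f"Result index {idx} is out of range")
        scores[idx] = float(value)

    if any(s is None for s in scores):
        raise ValueError("Response has duplicate or missing result indexes")
    return scores

print(extract_scores({"results": [{"index": 1, "score": 0.9}, {"index": 0, "score": 0.2}]}, 2))
```

A helper along these lines could stand in for the duplicated `if`/`elif` chains in `_rerank_sagemaker` and `_rerank_rest_api`, so that a format change only needs to be handled once.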

In [None]:
print("\n" + "="*70)
print("NOTEBOOK COMPLETE")
print("="*70)
print("\nQwen3-Reranker-8B API configured.")
print("The re-ranker receives FULL document content (like the LLM did).")
print("It outputs relevance scores to rank documents.")