# Qwen3-Reranker with Correct Response Parsing

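The reranker client in this notebook parses endpoint responses in the following format. Results may arrive in any order; each `index` is the position of the document in the request.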
```json
{
    "id": "",
    "model": "",
    "usage": {"total_tokens": ""},
    "results": [
        {"index": 0, "document": {...}, "relevance_score": 0.95},
        {"index": 1, "document": {...}, "relevance_score": 0.72}
    ]
}
```

In [None]:
import yaml
import json
import os
import pickle
from typing import List, Dict, Optional, Tuple
import numpy as np
import faiss
import boto3
import uuid
from datetime import datetime

# Read the shared YAML configuration: model endpoints and credentials,
# index storage paths, and retrieval settings used by the cells below.
with open('config/config.yaml') as config_file:
    config = yaml.safe_load(config_file)

print("✓ Config loaded")

In [None]:
# =============================================================================
# RERANKER CONFIGURATION
# =============================================================================

# The re-ranker client reuses the embedding model's AWS region and
# credentials; only the endpoint name is specific to the re-ranker.
RERANKER_CONFIG = {
    'endpoint_name': 'your-qwen3-reranker-endpoint',  # <-- UPDATE THIS
    'region': config['models']['embedding']['credentials']['region'],
    'credentials': {
        cred_key: config['models']['embedding']['credentials'][cred_key]
        for cred_key in ('accessKeyId', 'secretAccessKey', 'sessionToken')
    },
}

print(f"Reranker endpoint: {RERANKER_CONFIG['endpoint_name']}")

In [None]:
class Qwen3RerankerClient:
    """
    Client for a Qwen3-Reranker-8B SageMaker endpoint.

    Expected response format:
    {
        "results": [
            {"index": 0, "document": {...}, "relevance_score": 0.95},
            {"index": 1, "document": {...}, "relevance_score": 0.72}
        ]
    }

    ``index`` is the position of the scored document in the request, so
    results may arrive in any order.
    """

    def __init__(self, config: dict):
        """
        Args:
            config: Dict with ``endpoint_name``, ``region`` and
                ``credentials`` (``accessKeyId``, ``secretAccessKey`` and
                optionally ``sessionToken``).
        """
        self.endpoint_name = config['endpoint_name']
        creds = config['credentials']

        self.client = boto3.client(
            'sagemaker-runtime',
            region_name=config['region'],
            aws_access_key_id=creds['accessKeyId'],
            aws_secret_access_key=creds['secretAccessKey'],
            # A session token is only needed for temporary credentials;
            # boto3 accepts None when it is absent.
            aws_session_token=creds.get('sessionToken'),
        )
        print("✓ Qwen3-Reranker client initialized")

    @staticmethod
    def _parse_scores(output_data, num_documents: int) -> List[float]:
        """
        Map the endpoint's ``results`` list back onto request order.

        Args:
            output_data: Decoded JSON response body.
            num_documents: Number of documents sent in the request.

        Returns:
            ``scores`` where ``scores[i]`` is the relevance score of request
            document ``i``.

        Raises:
            ValueError: If there is no ``results`` list, a result is
                malformed or has an out-of-range index, or any document is
                left without a score. Failing loudly here prevents scores
                from being silently attached to the wrong documents.
        """
        results = output_data.get('results') if isinstance(output_data, dict) else None
        if not isinstance(results, list):
            raise ValueError(
                f"Unexpected reranker response (expected a 'results' list): "
                f"{str(output_data)[:500]}"
            )

        scores: List[Optional[float]] = [None] * num_documents
        for result in results:
            try:
                idx = int(result['index'])
                score = float(result['relevance_score'])
            except (KeyError, TypeError, ValueError) as err:
                raise ValueError(f"Malformed reranker result: {result!r}") from err
            if not 0 <= idx < num_documents:
                raise ValueError(
                    f"Reranker returned index {idx} for only {num_documents} documents"
                )
            scores[idx] = score

        missing = [i for i, s in enumerate(scores) if s is None]
        if missing:
            raise ValueError(f"Reranker response has no score for document indices {missing}")
        return scores

    def rerank(self, query: str, documents: List[str], verbose: bool = False) -> List[float]:
        """
        Get relevance scores for query-document pairs.

        Args:
            query: The search query.
            documents: List of document contents.
            verbose: Print debug info.

        Returns:
            One relevance score per document, aligned with ``documents``
            (higher = more relevant). An empty ``documents`` list returns
            ``[]`` without calling the endpoint.

        Raises:
            ValueError: If the response cannot be mapped to exactly one
                score per document.
        """
        if not documents:
            return []

        payload = {"query": query, "documents": documents}

        if verbose:
            print(f"[DEBUG] Query: {query[:100]}...")
            print(f"[DEBUG] Number of documents: {len(documents)}")

        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            ContentType='application/json',
            Body=json.dumps(payload),
        )

        # Read the response body once.
        output_data = json.loads(response['Body'].read().decode('utf-8'))

        if verbose and isinstance(output_data, dict):
            print(f"[DEBUG] Response keys: {output_data.keys()}")
            print(f"[DEBUG] Number of results: {len(output_data.get('results', []))}")

        scores = self._parse_scores(output_data, len(documents))

        if verbose:
            print(f"[DEBUG] Extracted scores: {scores}")

        return scores

print("✓ Qwen3RerankerClient class defined")

In [None]:
class Qwen3ReRankingRetriever:
    """
    Two-stage document retriever.

    Stage 1 is a hybrid FAISS-vector + BM25 search with access filtering.
    Stage 2 re-ranks the candidates with Qwen3-Reranker-8B. The class also
    keeps a simple in-memory session store.
    """

    def __init__(self, config: dict, reranker_config: dict):
        """
        Args:
            config: Parsed app config. Reads ``models.embedding`` (endpoint
                name and credentials), ``storage`` (index paths) and
                ``retrieval.hybrid`` (top_k and score weights).
            reranker_config: Passed to ``Qwen3RerankerClient``.
        """
        self.config = config

        # Qwen embedding endpoint, used to embed the query.
        self.embedding_endpoint_name = config['models']['embedding']['endpoint_name']
        embedding_creds = config['models']['embedding']['credentials']
        self.embedding_client = boto3.client(
            'sagemaker-runtime',
            region_name=embedding_creds['region'],
            aws_access_key_id=embedding_creds['accessKeyId'],
            aws_secret_access_key=embedding_creds['secretAccessKey'],
            # Only needed for temporary credentials; boto3 accepts None.
            aws_session_token=embedding_creds.get('sessionToken'),
        )
        print("✓ Qwen Embedding client initialized")

        self.reranker = Qwen3RerankerClient(reranker_config)

        # session_id -> session dict (see create_session).
        self.sessions = {}

        self.load_indexes()

    def load_indexes(self):
        """Load the FAISS index, optional raw embeddings, BM25 index and chunk metadata."""
        faiss_dir = self.config['storage']['faiss_index']
        bm25_dir = self.config['storage']['bm25_index']

        self.faiss_index = faiss.read_index(os.path.join(faiss_dir, 'faiss.index'))
        print("✓ FAISS index loaded")

        # Raw embeddings are optional. Always define the attribute so callers
        # can test ``is None`` instead of hitting AttributeError.
        embeddings_path = os.path.join(faiss_dir, 'embeddings.npy')
        self.embeddings = np.load(embeddings_path) if os.path.exists(embeddings_path) else None

        # SECURITY: pickle.load can execute arbitrary code stored in the file.
        # Only load BM25 indexes produced by a trusted indexing job.
        with open(os.path.join(bm25_dir, 'bm25.pkl'), 'rb') as f:
            self.bm25_index = pickle.load(f)
        print("✓ BM25 index loaded")

        with open(os.path.join(faiss_dir, 'chunk_metadata.json'), 'r', encoding='utf-8') as f:
            self.chunks = json.load(f)
        print(f"✓ Chunk metadata loaded: {len(self.chunks)} chunks")

    def get_embedding(self, text: str) -> np.ndarray:
        """Return the float32 embedding of ``text`` from the Qwen endpoint."""
        params = {"inputs": [text], "encoding_format": "float"}
        response = self.embedding_client.invoke_endpoint(
            EndpointName=self.embedding_endpoint_name,
            ContentType='application/json',
            Body=json.dumps(params),
        )
        output_data = json.loads(response['Body'].read().decode())
        # One input was sent, so take the first embedding in the response.
        return np.array(output_data[0], dtype='float32')

    @staticmethod
    def _normalize(scores: np.ndarray) -> np.ndarray:
        """Min-max scale ``scores`` to [0, 1]; empty or constant input maps to zeros."""
        scores = np.asarray(scores)
        if scores.size == 0:
            return np.zeros_like(scores)
        min_s, max_s = scores.min(), scores.max()
        if max_s - min_s < 1e-10:
            return np.zeros_like(scores)
        return (scores - min_s) / (max_s - min_s)

    @staticmethod
    def _passes_filters(chunk: Dict, entitlement: str, org_id: Optional[str],
                        tags: Optional[List[str]]) -> bool:
        """
        Return True if ``chunk`` is visible to ``entitlement`` and matches the
        optional ``org_id`` and ``tags`` filters.

        Chunks with no entitlement, org or tags are treated as non-matching
        (fail closed) rather than raising KeyError.
        """
        chunk_entitlements = chunk.get('entitlement') or []
        if isinstance(chunk_entitlements, str):
            chunk_entitlements = [chunk_entitlements]
        if 'universal' not in chunk_entitlements and entitlement not in chunk_entitlements:
            return False

        if org_id and chunk.get('orgId') != org_id:
            return False

        if tags:
            chunk_tags = (chunk.get('metadata') or {}).get('tags') or []
            if not any(t in chunk_tags for t in tags):
                return False

        return True

    def hybrid_search(self, query: str, entitlement: str, org_id: str = None,
                      tags: List[str] = None, top_k: int = None) -> List[Dict]:
        """
        Stage 1: hybrid vector + BM25 search with access filtering.

        Args:
            query: Search text.
            entitlement: Caller's entitlement; chunks tagged 'universal' are always visible.
            org_id: If given, only chunks with this ``orgId`` are returned.
            tags: If given, a chunk must carry at least one of these tags.
            top_k: Max results; defaults to ``retrieval.hybrid.top_k``.

        Returns:
            Up to ``top_k`` shallow copies of chunk dicts, each with an added
            ``hybrid_score``, sorted by that score (descending).
        """
        if top_k is None:
            top_k = self.config['retrieval']['hybrid']['top_k']
        if not self.chunks or top_k <= 0:
            return []

        query_embedding = self.get_embedding(query).reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_embedding)

        # Over-fetch vector hits so enough survive the access filters.
        initial_top_k = min(top_k * 10, len(self.chunks))
        # NOTE(review): assumes an inner-product index over normalised vectors
        # (higher = more similar). An L2-distance index would invert this; confirm.
        vector_scores, vector_indices = self.faiss_index.search(query_embedding, initial_top_k)
        # FAISS pads with index -1 when it finds fewer than k neighbours. Drop
        # those so they neither alias self.chunks[-1] nor distort normalisation.
        valid = vector_indices[0] >= 0
        vector_indices = vector_indices[0][valid]
        vector_scores = vector_scores[0][valid]

        # BM25 scores every chunk.
        bm25_scores = self.bm25_index.get_scores(query.lower().split())

        vector_weight = self.config['retrieval']['hybrid']['vector_weight']
        bm25_weight = self.config['retrieval']['hybrid']['bm25_weight']

        # Weighted sum of normalised scores. Chunks missed by FAISS get only
        # the BM25 contribution.
        hybrid_scores = {
            int(idx): score * vector_weight
            for idx, score in zip(vector_indices, self._normalize(vector_scores))
        }
        for idx, score in enumerate(self._normalize(bm25_scores)):
            hybrid_scores[idx] = hybrid_scores.get(idx, 0.0) + score * bm25_weight

        ranked = sorted(hybrid_scores.items(), key=lambda item: item[1], reverse=True)

        results = []
        for idx, score in ranked:
            chunk = self.chunks[idx]
            if not self._passes_filters(chunk, entitlement, org_id, tags):
                continue
            # Copy only accepted chunks, so adding scores never mutates the index.
            result = chunk.copy()
            result['hybrid_score'] = float(score)
            results.append(result)
            # `ranked` is already in score order, so the first top_k hits are final.
            if len(results) >= top_k:
                break
        return results

    def rerank(self, query: str, candidates: List[Dict], top_k: int = 5,
               verbose: bool = False) -> List[Dict]:
        """
        Stage 2: re-rank ``candidates`` with Qwen3-Reranker-8B.

        Adds ``rerank_score`` to each candidate dict in place and returns the
        ``top_k`` best (higher relevance_score = more relevant).

        Raises:
            ValueError: If the reranker does not return one score per candidate.
        """
        if not candidates:
            return []

        documents = [chunk['content'] for chunk in candidates]
        rerank_scores = self.reranker.rerank(query, documents, verbose=verbose)
        if len(rerank_scores) != len(candidates):
            raise ValueError(
                f"Reranker returned {len(rerank_scores)} scores for {len(candidates)} candidates"
            )

        for chunk, score in zip(candidates, rerank_scores):
            chunk['rerank_score'] = score

        return sorted(candidates, key=lambda c: c['rerank_score'], reverse=True)[:top_k]

    def query(self, query: str, entitlement: str, org_id: str = None,
              tags: List[str] = None, top_k: int = 5,
              candidates_for_rerank: int = 20,
              use_reranker: bool = True,
              verbose: bool = False) -> Dict:
        """
        Full retrieval pipeline.

        Runs hybrid search for ``candidates_for_rerank`` chunks, optionally
        re-ranks them down to ``top_k``, then keeps one entry per ``doc_id``.
        De-duplication happens after truncation, so fewer than ``top_k``
        documents may be returned.

        Returns:
            ``{'query', 'documents', 'reranker_used'}``. Each document has
            ``document_name``, ``doc_id``, ``hybrid_score`` and, when the
            re-ranker ran, ``rerank_score``. If hybrid search finds nothing,
            returns ``{'query', 'documents': [], 'message'}`` instead.
        """
        candidates = self.hybrid_search(
            query=query,
            entitlement=entitlement,
            org_id=org_id,
            tags=tags,
            top_k=candidates_for_rerank,
        )

        if not candidates:
            return {'query': query, 'documents': [], 'message': 'No relevant documents found.'}

        if verbose:
            print("\n[HYBRID SEARCH] Top candidates:")
            for i, c in enumerate(candidates[:5]):
                print(f"  {i+1}. {c['title']} (hybrid: {c['hybrid_score']:.4f})")

        if use_reranker:
            final_results = self.rerank(query, candidates, top_k=top_k, verbose=verbose)

            if verbose:
                print("\n[RE-RANKED] Final results:")
                for i, c in enumerate(final_results):
                    print(f"  {i+1}. {c['title']} (rerank: {c['rerank_score']:.4f}, hybrid: {c['hybrid_score']:.4f})")
        else:
            final_results = candidates[:top_k]

        # Collapse chunks to one entry per source document, keeping the
        # highest-ranked chunk for each doc_id.
        seen_docs = set()
        documents = []
        for chunk in final_results:
            doc_id = chunk['doc_id']
            if doc_id in seen_docs:
                continue
            seen_docs.add(doc_id)
            doc_entry = {
                'document_name': chunk['title'],
                'doc_id': doc_id,
                'hybrid_score': chunk['hybrid_score'],
            }
            if use_reranker:
                doc_entry['rerank_score'] = chunk['rerank_score']
            documents.append(doc_entry)

        return {
            'query': query,
            'documents': documents,
            'reranker_used': use_reranker,
        }

    # Session management
    def create_session(self, user_id: str, entitlement: str, org_id: str = None) -> str:
        """Create an in-memory session holding the user's access filters; return its id."""
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'session_id': session_id,
            'user_id': user_id,
            'entitlement': entitlement,
            'org_id': org_id,
            'query_history': [],
        }
        print(f"✓ Created session: {session_id}")
        return session_id

    def query_with_session(self, session_id: str, query: str, **kwargs) -> Dict:
        """
        Run ``query`` with the session's entitlement and org_id, and record the
        document names found in the session history.

        Extra ``kwargs`` are forwarded to ``query``.

        Raises:
            ValueError: If ``session_id`` is unknown.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        result = self.query(
            query=query,
            entitlement=session['entitlement'],
            org_id=session['org_id'],
            **kwargs,
        )

        session['query_history'].append({
            'query': query,
            'documents_found': [d['document_name'] for d in result['documents']],
        })

        return result

print("✓ Qwen3ReRankingRetriever class defined")

## Initialize and Test

In [None]:
# Build the two-stage retriever. Constructing it creates the embedding and
# re-ranker endpoint clients and loads the FAISS/BM25 indexes from disk.
retriever = Qwen3ReRankingRetriever(config, RERANKER_CONFIG)

In [None]:
# Smoke test: run one query through the full pipeline with debug output.
print("="*70)
print("TEST: Query with Re-Ranking")
print("="*70)

result = retriever.query(
    query='How do I process a cancellation?',
    entitlement='agent_support',
    org_id='org_123',
    top_k=5,
    candidates_for_rerank=20,
    use_reranker=True,
    verbose=True
)

print("\n" + "="*70)
print("FINAL RESULTS:")
print("="*70)
for i, doc in enumerate(result['documents'], 1):
    print(f"{i}. {doc['document_name']}")
    # rerank_score is only present when the re-ranker ran. Format it the
    # same way as the hybrid score, and fall back to 'N/A' when absent.
    rerank_score = doc.get('rerank_score')
    rerank_display = f"{rerank_score:.4f}" if rerank_score is not None else 'N/A'
    print(f"   Rerank Score: {rerank_display}")
    print(f"   Hybrid Score: {doc['hybrid_score']:.4f}")

In [None]:
# Side-by-side comparison: the same query ranked by hybrid search alone,
# then with the Qwen3 re-ranker applied on top.
print("=" * 70)
print("COMPARISON: Hybrid Only vs Re-Ranked")
print("=" * 70)

test_query = "How do I process a cancellation?"

# Baseline: hybrid (vector + BM25) ranking only.
print("\n--- WITHOUT RE-RANKER (Hybrid Only) ---")
result_hybrid = retriever.query(
    query=test_query, entitlement='agent_support', org_id='org_123', use_reranker=False
)
for rank, entry in enumerate(result_hybrid['documents'], start=1):
    hybrid_val = entry['hybrid_score']
    print(f"  {rank}. {entry['document_name']} (hybrid: {hybrid_val:.4f})")

# Same query, re-ranked by the cross-encoder.
print("\n--- WITH RE-RANKER ---")
result_rerank = retriever.query(
    query=test_query, entitlement='agent_support', org_id='org_123', use_reranker=True
)
for rank, entry in enumerate(result_rerank['documents'], start=1):
    rerank_val, hybrid_val = entry['rerank_score'], entry['hybrid_score']
    print(f"  {rank}. {entry['document_name']} (rerank: {rerank_val:.4f}, hybrid: {hybrid_val:.4f})")

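## If the Re-Ranker Still Returns the Wrong Order

The re-ranker is expected to reorder some hybrid-search results, so a different order is not an error by itself. If a document that clearly should rank first (for example, the top hybrid-search hit for an obvious query) still ranks lower after re-ranking, check:

1. **Are the `relevance_score` values what you expect?** Higher should mean more relevant.
2. **Is the request payload in the format the endpoint expects?** This notebook sends separate `query` and `documents` fields. If the endpoint expects query–document `pairs` instead, switch to that format.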

In [None]:
# Sanity-check the re-ranker on its own, using hand-labelled documents
# ordered from most to least relevant to the query.
print("=" * 70)
print("DIRECT RE-RANKER TEST")
print("=" * 70)

test_query = "How do I cancel a booking?"
test_docs = [
    "To cancel a booking, verify customer identity first, then process the cancellation in the system.",  # Most relevant
    "Refunds are processed within 5-7 business days after cancellation.",  # Somewhat relevant
    "To create a new booking, enter customer details and payment information."  # Not relevant
]

scores = retriever.reranker.rerank(test_query, test_docs, verbose=True)

print("\nScores:")
for doc_no, (doc_text, doc_score) in enumerate(zip(test_docs, scores)):
    print(f"  Doc {doc_no}: score={doc_score:.4f} - {doc_text[:50]}...")

print("\nExpected: Doc 0 should have highest score (most relevant to cancellation)")
# max() returns the first maximal position, matching list.index(max(...)).
best_doc = max(range(len(scores)), key=scores.__getitem__)
print(f"Actual highest: Doc {best_doc}")