# Production RAG System - Model Deployment

**Architecture:**
- 🔍 **BGE Model**: `databricks-bge-large-en` for high-quality embeddings and retrieval
- 🧠 **Llama 3.3 70B**: `databricks-meta-llama-3-3-70b-instruct` for reasoning and response generation
- 🚀 **MLflow 3.x** (≥ 3.1.1): model logging, registration, and serving
- 📊 **Vector Search**: `jgmt` endpoint with `icc.jugement.main_text_summarized` index

**Deployment Target:** Production serving endpoint for ICC judgment Q&A


In [None]:
# Install required packages
%pip install -U -qqqq mlflow>=3.1.1 langchain langgraph databricks-langchain pydantic databricks-agents unitycatalog-langchain[databricks] uv databricks-feature-engineering==0.12.1
dbutils.library.restartPython()


In [None]:
# Load configuration
import sys
# Make the shared config directory importable. This must run before the
# `databricks_config` import below. The path is specific to one workspace user.
sys.path.append('/Workspace/Users/christophe629@gmail.com/icc_rag_backend/databricks-deployment/config')

# Import unified configuration
# NOTE(review): names used later in this notebook but never defined in it
# (e.g. RAG_CONFIG, LLAMA_MODEL_ENDPOINT, BGE_MODEL_ENDPOINT, VECTOR_SEARCH_ENDPOINT,
# VECTOR_SEARCH_INDEX, LEGAL_EXPANSIONS, KEY_ENTITIES, DOCUMENT_INFO,
# get_section_weight, get_databricks_path, print_config_summary) presumably come
# from this star import. Confirm against databricks_config.
from databricks_config import *

# Core imports
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
import datetime
import logging

import mlflow
from mlflow.models import infer_signature
from mlflow.models.resources import (
    DatabricksVectorSearchIndex,
    DatabricksServingEndpoint
)

# Vector Search and LLM
# These come from the databricks-vectorsearch and langchain-community
# distributions respectively; both must be installed in the notebook and in the
# serving environment.
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain.schema import HumanMessage, SystemMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferWindowMemory

print("✅ Configuration and dependencies loaded")
print_config_summary()


## Core Production RAG System


In [None]:
@dataclass
class RetrievalContext:
    """One retrieved judgment chunk together with its retrieval scores.

    In this notebook instances are built by
    ``ProductionRAGSystem.retrieve_contexts`` from a vector search result row.
    Field order is the positional constructor order.
    """

    chunk_id: str  # value of the index's "chunk_id" column
    content: str  # chunk text ("content" column)
    summary: str  # "summary" column
    section_type: str  # "section_type" column; also used to look up a section weight
    page_range: str  # "page_range" column, shown to the LLM as "Pages ..."
    similarity_score: float  # raw score returned by vector search for the row
    relevance_score: float  # similarity_score * section weight, then multiplied by _boost_relevance

class ProductionRAGSystem:
    """Production RAG system with BGE retrieval and Llama conversation.

    Retrieval queries the Databricks Vector Search index ``VECTOR_SEARCH_INDEX`` on
    endpoint ``VECTOR_SEARCH_ENDPOINT``. Answers are generated by the chat model
    served at ``LLAMA_MODEL_ENDPOINT``. Tunables (temperature, token budget, result
    counts, memory window, source reporting) are read from ``RAG_CONFIG``.
    """

    # Substrings whose presence in a chunk earns a +0.10 relevance boost each.
    _LEGAL_BOOST_TERMS = ("war crime", "murder", "persecution", "command")

    def __init__(self):
        # Clients built from the notebook-level configuration.
        self.vsc = VectorSearchClient()
        self.llm = ChatDatabricks(
            target_uri="databricks",
            endpoint=LLAMA_MODEL_ENDPOINT,
            temperature=RAG_CONFIG["temperature"],
            max_tokens=RAG_CONFIG["max_tokens"],
        )

        # Query-expansion vocabulary and tracked person names from config.
        self.legal_expansions = LEGAL_EXPANSIONS
        self.person_entities = KEY_ENTITIES["persons"]

        # conversation_id -> ConversationBufferWindowMemory, held in this process only.
        # NOTE(review): entries are never evicted, so a long-running server keeps one
        # memory object per distinct conversation_id it has ever seen.
        self.conversations = {}

    @staticmethod
    def _window_messages(messages, window) -> list:
        """Return the last ``window`` exchanges (two messages each) from ``messages``.

        This mirrors the ``k`` window of ``ConversationBufferWindowMemory``. A window
        of zero or less yields no history.
        """
        window = int(window)
        if window <= 0:
            return []
        return list(messages[-2 * window:])

    def enhance_query(self, query: str) -> str:
        """Enhance query for better BGE retrieval.

        The query is lower-cased. For every configured legal term that occurs in it,
        the term's first expansion is appended. For every tracked person named in
        it, ``"defendant <person>"`` is appended.

        Matching is done against the lower-cased *original* query, with the config
        term lower-cased too. As a result, text appended by one expansion cannot
        trigger another expansion, and mixed-case config keys still match.
        """
        query_lower = query.lower()
        additions = []

        for term, expansions in self.legal_expansions.items():
            if expansions and term.lower() in query_lower:
                additions.append(expansions[0])

        for person in self.person_entities:
            if person.lower() in query_lower:
                additions.append(f"defendant {person}")

        return " ".join([query_lower, *additions])

    def retrieve_contexts(self, query: str, num_results: int = None) -> List["RetrievalContext"]:
        """Retrieve contexts using BGE-powered vector search.

        Args:
            query: The user question. It is expanded with ``enhance_query`` before
                searching.
            num_results: How many contexts to return. Defaults to
                ``RAG_CONFIG["default_num_results"]``. Int-like values such as
                ``10.0`` (from a float DataFrame column) are accepted.

        Returns:
            Up to ``num_results`` contexts sorted by descending ``relevance_score``.
            Returns an empty list if retrieval fails; the error is printed.
        """
        if num_results is None:
            num_results = RAG_CONFIG["default_num_results"]

        try:
            # Slicing and the search API need a real int, not e.g. 10.0.
            num_results = int(num_results)
            enhanced_query = self.enhance_query(query)

            index = self.vsc.get_index(VECTOR_SEARCH_ENDPOINT, VECTOR_SEARCH_INDEX)
            results = index.similarity_search(
                query_text=enhanced_query,
                columns=["chunk_id", "content", "section_type", "page_range", "summary"],
                num_results=num_results * 2,  # over-fetch so re-ranking has candidates to drop
            )

            # Tolerate a missing or null payload when nothing matches.
            docs = (results.get("result") or {}).get("data_array") or []

            contexts = []
            for doc in docs:
                # Row layout: the five requested columns, then the similarity score.
                if len(doc) < 6:
                    continue
                chunk_id, content, section_type, page_range, summary, score = doc[:6]
                similarity = float(score)
                contexts.append(
                    RetrievalContext(
                        chunk_id=chunk_id,
                        content=content,
                        summary=summary,
                        section_type=section_type,
                        page_range=page_range,
                        similarity_score=similarity,
                        relevance_score=similarity * get_section_weight(section_type),
                    )
                )

            contexts = self._boost_relevance(contexts, query)
            contexts.sort(key=lambda ctx: ctx.relevance_score, reverse=True)
            return contexts[:num_results]

        except Exception as e:
            print(f"❌ Retrieval error: {e}")
            return []

    def _boost_relevance(self, contexts: List["RetrievalContext"], query: str) -> List["RetrievalContext"]:
        """Apply relevance boosting based on data insights.

        Each context's ``relevance_score`` is multiplied in place by
        ``1 + 0.15 * (tracked persons mentioned) + 0.10 * (boost terms present)``.
        Matching is case-insensitive against the content and summary. ``query`` is
        currently unused but kept for interface compatibility.

        Returns:
            The same list object, for chaining.
        """
        persons_lower = [person.lower() for person in self.person_entities]
        for context in contexts:
            text = f"{context.content} {context.summary}".lower()
            person_hits = sum(person in text for person in persons_lower)
            term_hits = sum(term in text for term in self._LEGAL_BOOST_TERMS)
            context.relevance_score *= 1.0 + 0.15 * person_hits + 0.10 * term_hits
        return contexts

    def generate_response(self, query: str, contexts: List["RetrievalContext"],
                          conversation_id: str = None) -> str:
        """Generate response using Llama with legal expertise.

        Conversation memory is used when ``conversation_id`` is given and
        ``RAG_CONFIG["enable_conversation_memory"]`` is on. In that case only the
        last ``RAG_CONFIG["conversation_memory_window"]`` exchanges of that
        conversation are sent as history, and this exchange is recorded afterwards.

        Returns:
            The model's answer text. If prompt formatting or the model call fails,
            returns an ``"Error generating response: ..."`` string instead.
        """
        memory = None
        history = []
        if conversation_id and RAG_CONFIG["enable_conversation_memory"]:
            window = RAG_CONFIG["conversation_memory_window"]
            memory = self.conversations.get(conversation_id)
            if memory is None:
                memory = ConversationBufferWindowMemory(k=window, return_messages=True)
                self.conversations[conversation_id] = memory
            # chat_memory.messages is the full transcript. Apply the window here so
            # the configured limit is honored and long chats do not overflow the prompt.
            history = self._window_messages(memory.chat_memory.messages, window)

        # Format contexts for Llama
        context_text = self._format_contexts(contexts)

        # Create legal system prompt
        system_prompt = f"""You are an expert legal analyst specializing in International Criminal Court (ICC) proceedings. You are analyzing the judgment in "{DOCUMENT_INFO['case_name']}" ({DOCUMENT_INFO['case_number']}).

Your expertise includes:
- International criminal law (war crimes, crimes against humanity, genocide)
- ICC procedures and legal standards
- Central African Republic conflict analysis
- Legal reasoning and evidence evaluation

Guidelines:
1. Base responses strictly on the provided judgment context
2. Use proper legal terminology and cite specific sections
3. Maintain judicial objectivity
4. Reference page numbers and sections when available
5. Clearly state when information is not available in the context

Structure your response with clear legal reasoning and specific evidence."""

        # Create the prompt
        chat_template = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", """
Context from ICC Judgment:
{context}

Legal Question: {query}

Please provide a comprehensive legal analysis based on the judgment context provided.""")
        ])

        try:
            messages = chat_template.format_messages(
                context=context_text,
                query=query,
                history=history,
            )

            # `.invoke` is the supported Runnable entry point; calling the model
            # object directly is deprecated in LangChain.
            response = self.llm.invoke(messages)

            if memory is not None:
                memory.chat_memory.add_user_message(query)
                memory.chat_memory.add_ai_message(response.content)

            return response.content

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def _format_contexts(self, contexts: List["RetrievalContext"]) -> str:
        """Render contexts as numbered source sections for the LLM prompt."""
        if not contexts:
            return "No relevant context found."

        formatted = []
        for i, ctx in enumerate(contexts, 1):
            formatted.append(f"""
**Source {i}** - {ctx.section_type} (Pages {ctx.page_range})
Relevance: {ctx.relevance_score:.3f}

Content: {ctx.content}

Summary: {ctx.summary}
---""")

        return "\n".join(formatted)

    def process_query(self, query: str, num_results: int = None, conversation_id: str = None) -> Dict:
        """Process complete RAG query: retrieve, generate, and report metrics.

        Returns:
            A dict with keys ``response``, ``conversation_id``, ``num_contexts``,
            ``processing_time_seconds`` and ``sources``. ``sources`` lists at most
            five contexts and is empty when ``RAG_CONFIG["include_sources"]`` is falsy.
        """
        start_time = datetime.datetime.now()

        contexts = self.retrieve_contexts(query, num_results)
        response = self.generate_response(query, contexts, conversation_id)

        processing_time = (datetime.datetime.now() - start_time).total_seconds()

        return {
            "response": response,
            "conversation_id": conversation_id,
            "num_contexts": len(contexts),
            "processing_time_seconds": processing_time,
            "sources": [
                {
                    "chunk_id": ctx.chunk_id,
                    "section": ctx.section_type,
                    "pages": ctx.page_range,
                    "relevance": round(ctx.relevance_score, 3)
                }
                for ctx in contexts[:5]  # Top 5 sources
            ] if RAG_CONFIG["include_sources"] else []
        }

# Shared module-level instance; test_rag_system() below calls process_query on it.
rag_system = ProductionRAGSystem()
print("✅ Production RAG System initialized")


## MLflow Model and Deployment


In [None]:
# Smoke-test the RAG system locally before packaging it for MLflow.
def test_rag_system():
    """Run sample questions through ``rag_system`` and print diagnostics.

    Each question gets its own conversation id (``test_1``, ``test_2``, ...), so no
    history is shared between them. Failures are printed rather than raised, so
    one bad question does not stop the rest.
    """
    print("🧪 TESTING PRODUCTION RAG SYSTEM")
    print("=" * 40)

    sample_questions = (
        "What war crimes was Alfred Yekatom found guilty of?",
        "What evidence supported the persecution charges?",
        "What sentence was imposed on Yekatom?",
    )

    for idx, question in enumerate(sample_questions, start=1):
        print(f"\n{idx}. Query: {question}")
        print("-" * 30)

        try:
            outcome = rag_system.process_query(
                query=question,
                num_results=8,
                conversation_id=f"test_{idx}",
            )

            answer = outcome['response']
            print(f"✅ Response length: {len(answer)} chars")
            print(f"📊 Contexts used: {outcome['num_contexts']}")
            print(f"⏱️  Processing time: {outcome['processing_time_seconds']:.2f}s")
            print(f"📝 Preview: {answer[:150]}...")

            top_sources = outcome['sources']
            if top_sources:
                best = top_sources[0]
                print(f"🔍 Top source: {best['section']} (relevance: {best['relevance']})")

        except Exception as e:
            print(f"❌ Error: {e}")

    print("\n✅ RAG system testing complete")


# Run tests
test_rag_system()


In [None]:
# MLflow Model Wrapper
class ProductionRAGModel(mlflow.pyfunc.PythonModel):
    """MLflow 3.0 production model wrapper.

    Input is a DataFrame with a required ``query`` column and optional
    ``num_results`` and ``conversation_id`` columns. The output is one result dict
    per row, shaped like ``ProductionRAGSystem.process_query``'s return value.
    """

    def load_context(self, context):
        """Initialize the production RAG system."""
        self.rag_system = ProductionRAGSystem()

    @staticmethod
    def _column_or_default(model_input: pd.DataFrame, column: str, default) -> list:
        """Return ``column`` as a list, or ``default`` repeated once per row if absent."""
        if column in model_input.columns:
            return model_input[column].tolist()
        return [default] * len(model_input)

    def predict(self, context, model_input: pd.DataFrame) -> List[Dict]:
        """Handle predictions for serving endpoint.

        A failure on one row produces an error result for that row only. A failure
        outside the per-row loop (e.g. a missing ``query`` column) produces an
        ``{"error": ...}`` dict for every row.
        """
        try:
            queries = model_input["query"].tolist()
            default_num_results = RAG_CONFIG["default_num_results"]

            # Optional columns may be absent entirely, not just null per row.
            num_results_list = self._column_or_default(model_input, "num_results", default_num_results)
            conversation_ids = self._column_or_default(model_input, "conversation_id", None)

            results = []
            for query, num_results, conv_id in zip(queries, num_results_list, conversation_ids):
                try:
                    # Missing per-row values arrive as None/NaN.
                    conv_id = conv_id if pd.notna(conv_id) else None
                    # A column containing NaN is float-typed (10 -> 10.0), so cast to int.
                    k = int(num_results) if pd.notna(num_results) else default_num_results

                    result = self.rag_system.process_query(
                        query=query,
                        num_results=k,
                        conversation_id=conv_id,
                    )
                except Exception as e:
                    result = {
                        "response": f"Error processing query: {str(e)}",
                        "conversation_id": conv_id,
                        "num_contexts": 0,
                        "processing_time_seconds": 0,
                        "sources": []
                    }
                results.append(result)

            return results

        except Exception as e:
            message = f"Model error: {str(e)}"
            return [{"error": message} for _ in range(len(model_input))]

# Register the production model
with mlflow.start_run(run_name="ICC_RAG_Production_BGE_Llama") as run:

    # Create model instance
    production_model = ProductionRAGModel()

    # Input example for serving endpoint
    input_example = pd.DataFrame({
        "query": [
            "What role did Alfred Yekatom play in the Anti-Balaka forces?",
            "What evidence was presented about persecution of Muslims?"
        ],
        "num_results": [10, 12],
        "conversation_id": ["session_001", "session_001"]
    })

    # Expected output format
    output_example = [
        {
            "response": "Based on the ICC judgment...",
            "conversation_id": "session_001",
            "num_contexts": 10,
            "processing_time_seconds": 2.5,
            "sources": [
                {
                    "chunk_id": "chunk_0001",
                    "section": "EVIDENTIARY_CONSIDERATIONS",
                    "pages": "150-151",
                    "relevance": 0.95
                }
            ]
        }
    ]

    # NOTE(review): ProductionRAGSystem is defined in this notebook and travels with
    # the pickled model. Helpers it calls that come from `databricks_config`
    # (e.g. get_section_weight) are likely pickled by reference. Confirm that module
    # is importable in the serving environment (for example via `code_paths`).
    # `name=` replaces the `artifact_path=` argument deprecated in MLflow 3.
    model_info = mlflow.pyfunc.log_model(
        name="icc_rag_bge_llama_model",
        python_model=production_model,
        input_example=input_example,
        signature=infer_signature(input_example, output_example),
        resources=[
            DatabricksVectorSearchIndex(index_name=VECTOR_SEARCH_INDEX),
            DatabricksServingEndpoint(endpoint_name=BGE_MODEL_ENDPOINT),
            DatabricksServingEndpoint(endpoint_name=LLAMA_MODEL_ENDPOINT)
        ],
        # Every third-party package the served code imports must be listed here.
        pip_requirements=[
            "mlflow>=3.1.1",
            "langchain",
            "langchain-community",       # ChatDatabricks (langchain_community.chat_models)
            "databricks-langchain",
            "databricks-vectorsearch",   # VectorSearchClient (databricks.vector_search)
            "numpy",
            "pandas"
        ]
    )

    # Register using the URI MLflow reports for the logged model. In MLflow 3 the
    # model is tracked as a LoggedModel, so this URI is authoritative; a hand-built
    # runs:/ path is not.
    model_uri = model_info.model_uri
    registered_model = mlflow.register_model(
        model_uri=model_uri,
        name=get_databricks_path("rag_model")
    )

    print(f"✅ Model logged: {run.info.run_id}")
    print(f"📦 Model registered: {registered_model.name} v{registered_model.version}")

print("\n🎉 PRODUCTION RAG DEPLOYMENT COMPLETE!")
print(f"📦 Model: {get_databricks_path('rag_model')}")
print("🚀 Ready for serving endpoint deployment!")
