# ICC Enhanced RAG System - Production Deployment

**Architecture:**
- 🔍 **Enhanced Vector Search**: Dual-index retrieval over `databricks-gte-large-en` embeddings, with keyword-based routing between the two indexes
- 🧠 **Advanced LLM**: `databricks-meta-llama-3-3-70b-instruct` for legal analysis
- 🚀 **MLflow 3.0**: Production deployment and model management
- ⚖️ **Legal Expertise**: Specialized for ICC defense team research

**Data Sources:**
- **Past Judgments Index**: `past_judgement` (ICTY/ICC case law)
- **Geneva Documentation Index**: `geneva_documentation` (IHL framework)
- **Vector Search Endpoint**: `jgmt` (with databricks-gte-large-en embedding model)

**Key Features:**
- Intelligent routing based on legal topics
- Enhanced retrieval with relevance boosting
- Comprehensive legal analysis generation
- Production-ready MLflow 3.0 deployment


In [0]:
%pip install -U -qqqq mlflow>=3.1.1 langchain databricks-langchain pydantic databricks-agents unitycatalog-langchain[databricks] uv databricks-feature-engineering==0.12.1
dbutils.library.restartPython()


Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.


In [0]:
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
import datetime
import logging
import re

import mlflow
from mlflow.models import infer_signature
from mlflow.models.resources import (
    DatabricksVectorSearchIndex,
    DatabricksServingEndpoint
)

# Vector Search and LLM
from databricks.vector_search.client import VectorSearchClient
from databricks.sdk import WorkspaceClient
from langchain_community.chat_models import ChatDatabricks
from langchain.schema import HumanMessage, SystemMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferWindowMemory

print("✅ Enhanced RAG dependencies loaded")


✅ Enhanced RAG dependencies loaded


## Enhanced Configuration & Legal Topics


In [0]:
# Enhanced Configuration
# Names of the Databricks resources this notebook talks to: the Vector Search
# endpoint, the two fully-qualified index names, and the chat model serving
# endpoint.
VECTOR_SEARCH_ENDPOINT = "jgmt"
PAST_JUDGMENTS_INDEX = "icc_chatbot.search_model.past_judgement"
GENEVA_DOCUMENTATION_INDEX = "icc_chatbot.search_model.geneva_documentation"
LLM_MODEL_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"

# Search parameters
# DEFAULT_TOP_K is the default number of hits requested per index search and
# the default cap on the merged result list built by retrieve_context.
DEFAULT_TOP_K = 10
# NOTE(review): MAX_CONTEXT_LENGTH and SIMILARITY_THRESHOLD are not referenced
# in the part of the notebook visible here -- confirm they are used further
# down before relying on them.
MAX_CONTEXT_LENGTH = 4000
SIMILARITY_THRESHOLD = 0.7
# MAX_TOKENS and TEMPERATURE are passed to the ChatDatabricks constructor.
MAX_TOKENS = 2048
TEMPERATURE = 0.1

# Legal topics for intelligent routing
# Each list holds lowercase phrases that are matched against the lowercased
# question; the number of phrases found per list decides which index is
# preferred. "protected persons" appears in both lists, so a question that
# contains it adds one match to each category.
LEGAL_TOPICS = {
    "judgment_priority": [
        "overall control", "state", "protected persons", "active participation", "direct participation",
        "combatant status", "combatant privilege", "civilian status", "duty to protect",
        "organisation of armed groups", "principle of distinction", "indiscriminate attack",
        "civilian population", "military objectives", "military objects", "rule of proportionality",
        "principle of proportionality", "collateral damage", "military necessity",
        "military imperative", "security of civilians", "imperative military reasons",
        "conduct of hostilities", "means of warfare", "methods of warfare",
        "attacks against protected objects", "religious buildings", "displacement",
        "deportation", "coercion", "cruel treatment", "torture", "outrages against dignity",
        "murder", "self-defense", "causal link", "checkpoints", "roadblocks",
        "icty", "trial chamber", "appeals chamber", "judgment", "applied", "practice"
    ],
    "geneva_priority": [
        "geneva convention", "international humanitarian law", "ihl", "protected persons",
        "wounded and sick", "prisoners of war", "civilians", "medical personnel",
        "religious personnel", "cultural property", "distinctive emblems", "red cross",
        "red crescent", "additional protocol", "grave breaches", "serious violations",
        "customary international law", "treaty law", "convention", "protocol"
    ]
}

print("✅ Enhanced configuration loaded with legal topics")


✅ Enhanced configuration loaded with legal topics


In [0]:
@dataclass
class SearchResult:
    """Enhanced search result with comprehensive metadata.

    One instance represents a single hit returned by one of the two vector
    search indexes. The field notes below describe how the search code
    visible in this notebook populates each field.
    """
    # Text of the hit (read from the "text" column).
    content: str
    # Summary of the hit (read from the "summary" column; "" in fallbacks).
    summary: str
    # Document identifier: "doc_id" for judgments, "doc_name" for Geneva
    # documents, or a generated "Document_N" / "Geneva_Document_N" label.
    source: str
    # Extra values kept alongside the hit; the visible code stores
    # "section_type" and "score" here.
    metadata: Dict[str, Any]
    # Relevance score; results are sorted on it in descending order.
    score: float
    source_type: str  # 'judgment' or 'geneva'
    # First entry of the "pages" column when one is available.
    page_number: Optional[int] = None
    # Used when building citations; never set by the visible search code.
    article: Optional[str] = None
    # Copy of the "section_type" column.
    section: Optional[str] = None
    # Never set by the visible search code.
    document_type: Optional[str] = None

@dataclass
class RetrievalContext:
    """Enhanced retrieval context with routing information.

    Bundles everything one retrieval pass produced for a question: the
    routing decision, the per-index hits and the merged result list.
    """
    # The question as supplied by the caller (before query enhancement).
    question: str
    # Routing outcome: "judgment", "geneva" or "both".
    routing_decision: str
    # Hits from the past-judgments index (empty when that index was skipped).
    judgment_results: "List[SearchResult]"
    # Hits from the Geneva documentation index (empty when skipped).
    geneva_results: "List[SearchResult]"
    # Both lists merged, sorted by score (highest first) and cut to top_k.
    all_results: "List[SearchResult]"
    # Length of all_results.
    total_sources: int
    # Wall-clock duration of the retrieval, in seconds.
    processing_time: float
    
@dataclass
class LegalAnalysis:
    """Structured legal analysis result.

    Returned by generate_legal_analysis; holds the LLM answer together with
    the sources and the values derived from it.
    """
    # The question that was analysed.
    question: str
    # Text returned by the LLM, or an "Error generating analysis: ..." message
    # when generation failed.
    analysis: str
    # The merged search results that were handed to the LLM as context.
    sources_used: "List[SearchResult]"
    # Lines pulled out of the analysis text that look like findings.
    key_findings: List[str]
    # Citation strings built from the sources (document plus page or article).
    citations: List[str]
    # Heuristic value capped at 1.0; 0.0 when generation failed.
    confidence_score: float
    # Wall-clock duration of the generation step, in seconds.
    processing_time: float

print("✅ Enhanced data structures defined")


✅ Enhanced data structures defined


## Enhanced Data Structures


In [0]:
class EnhancedICCRAGSystem:
    """Enhanced ICC RAG system with intelligent routing and legal expertise.

    The class body defines client setup, index routing and query expansion.
    Search, retrieval and generation methods are attached to the class by
    later notebook cells.
    """

    def __init__(self):
        """Create the Databricks clients, the chat model and empty per-instance state."""
        # Initialize clients
        self.vsc = VectorSearchClient()
        self.w = WorkspaceClient()
        self.llm = ChatDatabricks(
            target_uri="databricks",
            endpoint=LLM_MODEL_ENDPOINT,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS
        )

        # Conversation memory, keyed by conversation id
        self.conversations = {}

        # Legal terminology for query enhancement: when a key occurs in the
        # query, the first two entries of its list are appended to the query.
        self.legal_expansions = {
            "war crimes": ["war crime", "violations of laws of war", "grave breaches"],
            "crimes against humanity": ["crime against humanity", "systematic attack", "persecution"],
            "persecution": ["persecute", "persecuted", "discriminatory acts", "discriminatory intent"],
            "murder": ["kill", "killing", "unlawful killing", "wilful killing"],
            "active participation": ["direct participation", "hostilities", "combatant status"],
            "civilian status": ["protected person", "civilian population", "non-combatant"],
            "combatant status": ["combatant privilege", "armed forces", "military objective"]
        }

    @staticmethod
    def _mentions(text: str, term: str) -> bool:
        """Return True when *term* occurs in *text* starting at a word boundary.

        Anchoring only the start keeps plural and suffixed forms matching
        ("trial" still matches "trials") while rejecting accidental hits in
        the middle of a word ("trial" no longer matches "industrial").
        """
        return re.search(r"\b" + re.escape(term), text) is not None

    def determine_routing_priority(self, question: str) -> str:
        """Determine which index to prioritize based on question content.

        Args:
            question: The user's question; matching is case-insensitive.

        Returns:
            "judgment", "geneva" or "both".
        """
        question_lower = question.lower()

        # Explicit mentions of a source decide the routing outright, with
        # judgment keywords taking precedence over Geneva keywords.
        if any(self._mentions(question_lower, kw) for kw in ("icty", "trial", "appeal")):
            return "judgment"
        if any(self._mentions(question_lower, kw) for kw in ("geneva", "convention")):
            return "geneva"

        # Otherwise compare how many topics of each category are mentioned.
        judgment_matches = sum(
            1 for topic in LEGAL_TOPICS["judgment_priority"]
            if self._mentions(question_lower, topic)
        )
        geneva_matches = sum(
            1 for topic in LEGAL_TOPICS["geneva_priority"]
            if self._mentions(question_lower, topic)
        )

        if judgment_matches > geneva_matches:
            return "judgment"
        if geneva_matches > judgment_matches:
            return "geneva"
        # Tie (including no matches at all): search both indexes.
        return "both"

    def enhance_query(self, query: str) -> str:
        """Enhance query for better retrieval using legal terminology.

        The query is lowercased, and for every key of ``legal_expansions``
        found in it the first two expansions of that key are appended.
        Keys are tested against the original query only, so text appended by
        one expansion can never trigger another one.

        Args:
            query: The raw user query.

        Returns:
            The lowercased query followed by the selected expansions,
            separated by single spaces.
        """
        normalized = query.lower()
        extras = [
            expansion
            for term, expansions in self.legal_expansions.items()
            if term in normalized
            for expansion in expansions[:2]
        ]
        return " ".join([normalized, *extras])

print("✅ Enhanced ICC RAG System core defined")


✅ Enhanced ICC RAG System core defined


In [0]:
# Add missing core methods to the EnhancedICCRAGSystem class
def add_core_methods_to_rag_system():
    """Attach retrieval, generation and post-processing methods to EnhancedICCRAGSystem.

    The methods are written as nested functions and then assigned onto the
    class, so every instance gains them once this function has run.
    """
    import time

    # Number of question/answer exchanges replayed to the LLM per conversation.
    memory_window = 5
    # Maximum number of characters of each source's content shown to the LLM.
    snippet_chars = 1000
    # A list item: a bullet ("•", "-", "*") or a number ("1.", "12)") followed
    # by whitespace. Requiring the whitespace keeps "---" rules out.
    list_item = re.compile(r"^(?:[•\-*]|\d+[.)])\s+")

    def retrieve_context(self, query: str, top_k: int = DEFAULT_TOP_K) -> "RetrievalContext":
        """Retrieve context from both indices with intelligent routing.

        Args:
            query: The user's question.
            top_k: Hits requested per index and cap on the merged list.

        Returns:
            A RetrievalContext with the per-index hits and the merged,
            score-sorted, top_k-limited result list.
        """
        start_time = time.perf_counter()

        # Determine routing priority
        routing_decision = self.determine_routing_priority(query)

        # Enhance query for better retrieval
        enhanced_query = self.enhance_query(query)

        # Search based on routing decision
        judgment_results = []
        geneva_results = []
        if routing_decision in ("judgment", "both"):
            judgment_results = self.search_past_judgments(enhanced_query, top_k)
        if routing_decision in ("geneva", "both"):
            geneva_results = self.search_geneva_documentation(enhanced_query, top_k)

        # Combine, sort by score (best first) and keep the top results
        all_results = sorted(
            judgment_results + geneva_results,
            key=lambda hit: hit.score,
            reverse=True,
        )[:top_k]

        return RetrievalContext(
            question=query,
            routing_decision=routing_decision,
            judgment_results=judgment_results,
            geneva_results=geneva_results,
            all_results=all_results,
            total_sources=len(all_results),
            processing_time=time.perf_counter() - start_time
        )

    def generate_legal_analysis(self, question: str, context: "RetrievalContext", conversation_id: Optional[str] = None) -> "LegalAnalysis":
        """Generate comprehensive legal analysis using the retrieved context.

        Args:
            question: The legal research question.
            context: Retrieval output whose ``all_results`` become the LLM context.
            conversation_id: Optional key; when given, earlier exchanges stored
                under it are replayed to the LLM and this exchange is recorded.

        Returns:
            A LegalAnalysis. If generation raises, the error text is returned
            in ``analysis`` with empty findings/citations and confidence 0.0.
        """
        start_time = time.perf_counter()

        # Prepare context for LLM
        context_text = self._prepare_context_for_llm(context.all_results)

        # Look up (or create) the conversation memory if an id was given
        memory = None
        if conversation_id:
            memory = self.conversations.get(conversation_id)
            if memory is None:
                memory = ConversationBufferWindowMemory(
                    k=memory_window,
                    return_messages=True
                )
                self.conversations[conversation_id] = memory

        # Build prompt
        system_prompt = """You are an expert legal researcher specializing in International Criminal Law and International Humanitarian Law. 
        You have access to comprehensive databases of ICTY/ICC judgments and Geneva Convention documentation.
        
        Your task is to provide thorough, accurate legal analysis based on the retrieved context. Always:
        1. Cite specific sources and page numbers when available
        2. Identify key legal principles and precedents
        3. Highlight relevant case law and treaty provisions
        4. Provide clear, structured analysis
        5. Note any limitations or gaps in the available information
        
        Be precise, professional, and comprehensive in your analysis."""

        human_prompt = f"""Legal Research Question: {question}

Retrieved Context:
{context_text}

Please provide a comprehensive legal analysis addressing the question above. Include:
1. Key findings from the retrieved sources
2. Relevant legal principles and precedents
3. Specific citations to judgments, articles, or sections
4. Analysis of the legal framework
5. Any limitations or areas requiring further research

Format your response with clear headings and bullet points for readability."""

        # Generate analysis
        try:
            # Slicing copies the stored history, so building the prompt never
            # writes the system prompt or the bulky context prompt back into
            # the conversation memory. Two messages are stored per exchange.
            history = []
            if memory is not None:
                history = memory.chat_memory.messages[-2 * memory_window:]

            messages = [
                SystemMessage(content=system_prompt),
                *history,
                HumanMessage(content=human_prompt)
            ]
            response = self.llm.invoke(messages)

            if memory is not None:
                # Record only the bare question and the answer.
                memory.chat_memory.add_message(HumanMessage(content=question))
                memory.chat_memory.add_message(response)

            analysis_text = response.content

            # Extract key findings and citations
            key_findings = self._extract_key_findings(analysis_text)
            citations = self._extract_citations(analysis_text, context.all_results)

            # Calculate confidence score based on source quality and quantity
            confidence_score = self._calculate_confidence_score(context.all_results, len(key_findings))

        except Exception as e:
            # Deliberate best-effort boundary: report the failure in the result
            # instead of raising to the caller.
            analysis_text = f"Error generating analysis: {str(e)}"
            key_findings = []
            citations = []
            confidence_score = 0.0

        return LegalAnalysis(
            question=question,
            analysis=analysis_text,
            sources_used=context.all_results,
            key_findings=key_findings,
            citations=citations,
            confidence_score=confidence_score,
            processing_time=time.perf_counter() - start_time
        )

    def _prepare_context_for_llm(self, results: "List[SearchResult]") -> str:
        """Prepare retrieved results for LLM consumption.

        Each source becomes a short labelled section; content longer than
        ``snippet_chars`` is cut and marked with a trailing "...".
        """
        context_parts = []

        for i, result in enumerate(results, 1):
            content = result.content or ""
            snippet = content[:snippet_chars]
            if len(content) > snippet_chars:
                snippet += "..."  # only flag content that was really shortened

            lines = [
                f"Source {i} ({result.source_type.upper()}):",
                f"Document: {result.source}",
            ]
            if result.section:
                lines.append(f"Section: {result.section}")
            if result.page_number is not None:
                lines.append(f"Page: {result.page_number}")
            lines.append(f"Relevance Score: {result.score:.3f}")
            lines.append(f"Content: {snippet}")
            context_parts.append("\n".join(lines) + "\n")

        return "\n\n".join(context_parts)

    def _extract_key_findings(self, analysis_text: str) -> List[str]:
        """Extract key findings from the analysis text.

        A line counts as a finding when it is a bullet or numbered list item,
        or when it mentions "finding" or "principle". At most 10 lines are
        returned, in their original order.
        """
        findings = []

        for raw_line in analysis_text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            lowered = line.lower()
            if list_item.match(line) or 'finding' in lowered or 'principle' in lowered:
                findings.append(line)

        return findings[:10]  # Limit to top 10 findings

    def _extract_citations(self, analysis_text: str, sources: "List[SearchResult]") -> List[str]:
        """Build citation strings from the sources that were given to the LLM.

        ``analysis_text`` is accepted for interface compatibility but is not
        inspected. Duplicate citations are dropped (first occurrence wins) and
        at most 15 are returned.
        """
        citations = []

        for source in sources:
            if source.page_number is not None:
                citations.append(f"{source.source}, Page {source.page_number}")
            elif source.article:
                citations.append(f"{source.source}, Article {source.article}")
            else:
                citations.append(source.source)

        # dict.fromkeys removes duplicates while preserving order
        return list(dict.fromkeys(citations))[:15]  # Limit to top 15 citations

    def _calculate_confidence_score(self, sources: "List[SearchResult]", findings_count: int) -> float:
        """Calculate confidence score based on source quality and analysis depth.

        The mean source score is raised by up to 0.2 for the number of sources
        and up to 0.2 for the number of findings, then clamped to [0.0, 1.0].
        Returns 0.0 when there are no sources.
        """
        if not sources:
            return 0.0

        # Base score from source quality
        avg_score = sum(s.score for s in sources) / len(sources)

        # Bonus for number of sources
        source_bonus = min(len(sources) / 10.0, 0.2)

        # Bonus for findings
        findings_bonus = min(findings_count / 5.0, 0.2)

        # Combine scores and keep the result inside [0, 1]
        confidence = max(0.0, min(avg_score + source_bonus + findings_bonus, 1.0))

        return round(confidence, 3)

    # Add methods to the class
    EnhancedICCRAGSystem.retrieve_context = retrieve_context
    EnhancedICCRAGSystem.generate_legal_analysis = generate_legal_analysis
    EnhancedICCRAGSystem._prepare_context_for_llm = _prepare_context_for_llm
    EnhancedICCRAGSystem._extract_key_findings = _extract_key_findings
    EnhancedICCRAGSystem._extract_citations = _extract_citations
    EnhancedICCRAGSystem._calculate_confidence_score = _calculate_confidence_score

    print("✅ Core methods added to Enhanced ICC RAG System")

# Execute the function to add methods
add_core_methods_to_rag_system()


✅ Core methods added to Enhanced ICC RAG System


## Enhanced RAG System Core


In [0]:
# Add search methods to the EnhancedICCRAGSystem class
def add_search_methods_to_rag_system():
    """Add search methods to the RAG system class.

    Attaches ``search_past_judgments`` and ``search_geneva_documentation``.
    Both share one implementation that differs only in the index queried, the
    document-name column and the labels used in messages.
    """

    def _hits_from_response(response) -> list:
        """Normalize a similarity_search response into a list of per-hit items.

        The Vector Search client returns a dict with the column names under
        ``manifest.columns`` and the rows under ``result.data_array``; each row
        is turned into a ``{column_name: value}`` dict (the relevance value is
        in the column named "score"). Iterating the response dict directly
        would only yield its key names, not the hits. Any non-dict response is
        passed through as a list.
        """
        if isinstance(response, dict):
            rows = (response.get("result") or {}).get("data_array") or []
            columns = (response.get("manifest") or {}).get("columns") or []
            names = [column.get("name") for column in columns]
            if names:
                return [dict(zip(names, row)) for row in rows]
            return list(rows)
        return list(response or [])

    def _to_search_result(hit, position: int, source_type: str, name_column: str, fallback_prefix: str) -> "SearchResult":
        """Convert one normalized hit into a SearchResult."""
        fallback_source = f"{fallback_prefix}_{position}"

        if isinstance(hit, dict):
            pages = hit.get("pages") or []
            score = hit.get("score") or 0.0
            section_type = hit.get("section_type", "")
            return SearchResult(
                # "or" guards against columns that are present but null
                content=hit.get("text") or "",
                summary=hit.get("summary") or "",
                source=hit.get(name_column) or fallback_source,
                metadata={"section_type": section_type, "score": score},
                score=score,
                source_type=source_type,
                page_number=pages[0] if pages else None,
                section=section_type
            )

        # Strings and other objects carry no metadata; keep them with a
        # neutral score.
        return SearchResult(
            content=hit if isinstance(hit, str) else str(hit),
            summary="",
            source=fallback_source,
            metadata={"score": 0.5},
            score=0.5,
            source_type=source_type,
            page_number=None,
            section=""
        )

    def _search_index(self, index_name: str, query: str, top_k: int, source_type: str,
                      name_column: str, fallback_prefix: str, item_label: str,
                      search_label: str) -> "List[SearchResult]":
        """Query one index and convert the hits; returns [] if the search fails."""
        try:
            # Use columns parameter as it's required by the API
            response = self.vsc.get_index(VECTOR_SEARCH_ENDPOINT, index_name).similarity_search(
                query_text=query,
                columns=["text", "summary", name_column, "section_type", "pages"],
                num_results=top_k
            )

            search_results = []
            for i, hit in enumerate(_hits_from_response(response)):
                try:
                    search_results.append(
                        _to_search_result(hit, i + 1, source_type, name_column, fallback_prefix)
                    )
                except Exception as item_error:
                    print(f"Error processing {item_label} {i}: {item_error}")
                    # Keep a zero-score placeholder so one bad row does not
                    # discard the rest of the results.
                    search_results.append(SearchResult(
                        content=str(hit) if hit else "",
                        summary="",
                        source=f"{fallback_prefix}_{i+1}",
                        metadata={"score": 0.0},
                        score=0.0,
                        source_type=source_type,
                        page_number=None,
                        section=""
                    ))

            return search_results
        except Exception as e:
            # Best effort: report the failure and let the caller continue.
            print(f"Error searching {search_label}: {e}")
            return []

    def search_past_judgments(self, query: str, top_k: int = DEFAULT_TOP_K) -> "List[SearchResult]":
        """Search past judgments using vector search with enhanced metadata.

        Args:
            query: Text to search for.
            top_k: Number of hits to request.

        Returns:
            SearchResult objects with ``source_type="judgment"``; an empty list
            if the search itself fails.
        """
        return _search_index(
            self, PAST_JUDGMENTS_INDEX, query, top_k,
            source_type="judgment",
            name_column="doc_id",
            fallback_prefix="Document",
            item_label="result",
            search_label="past judgments",
        )

    def search_geneva_documentation(self, query: str, top_k: int = DEFAULT_TOP_K) -> "List[SearchResult]":
        """Search Geneva Convention documentation using vector search.

        Args:
            query: Text to search for.
            top_k: Number of hits to request.

        Returns:
            SearchResult objects with ``source_type="geneva"``; an empty list
            if the search itself fails.
        """
        return _search_index(
            self, GENEVA_DOCUMENTATION_INDEX, query, top_k,
            source_type="geneva",
            name_column="doc_name",
            fallback_prefix="Geneva_Document",
            item_label="Geneva result",
            search_label="Geneva documentation",
        )

    # Add methods to the class
    EnhancedICCRAGSystem.search_past_judgments = search_past_judgments
    EnhancedICCRAGSystem.search_geneva_documentation = search_geneva_documentation

    print("✅ Search methods added to Enhanced ICC RAG System")

# Execute the function to add methods
add_search_methods_to_rag_system()


✅ Search methods added to Enhanced ICC RAG System


## Test Legal Research Questions


In [None]:
# Test the fixed search methods
def test_fixed_search_methods():
    """Smoke-test the two index searches and the combined retrieval path.

    For every sample query this calls search_past_judgments,
    search_geneva_documentation and retrieve_context on a freshly created
    EnhancedICCRAGSystem and prints the result counts. An exception raised
    while handling a query is printed and the loop continues with the next
    query.

    Returns:
        Always True, whether or not individual queries failed.
    """
    print("🔧 TESTING FIXED SEARCH METHODS")
    print("=" * 50)

    # Fresh system instance for the whole run
    system = EnhancedICCRAGSystem()

    # One query per routing flavour
    sample_queries = (
        "What is active participation in hostilities?",
        "Geneva Convention protected persons",
        "ICTY trial judgment civilian status",
    )

    for number, text in enumerate(sample_queries, start=1):
        print(f"\n🔍 Test Query {number}: {text}")
        print("-" * 40)

        try:
            # Past judgments index
            print("Testing past judgments search...")
            from_judgments = system.search_past_judgments(text, top_k=3)
            print(f"✅ Past judgments: {len(from_judgments)} results")

            # Geneva documentation index
            print("Testing Geneva documentation search...")
            from_geneva = system.search_geneva_documentation(text, top_k=3)
            print(f"✅ Geneva documentation: {len(from_geneva)} results")

            # Routed retrieval across both
            print("Testing full context retrieval...")
            retrieved = system.retrieve_context(text, top_k=5)
            print(f"✅ Context retrieval: {retrieved.total_sources} total sources")
            print(f"   Routing decision: {retrieved.routing_decision}")
            print(f"   Processing time: {retrieved.processing_time:.2f}s")

        except Exception as err:
            print(f"❌ Error in test {number}: {err}")

    print("\n🎉 Search method testing completed!")
    return True

# Run the test
test_fixed_search_methods()


In [None]:
# Improved search methods with proper column handling
def add_improved_search_methods():
    """Add improved search methods that can handle columns parameter properly.

    Attaches ``search_past_judgments_improved`` and
    ``search_geneva_documentation_improved``. Both share one implementation
    that differs only in the index queried, the document-name column and the
    labels used in messages.
    """

    def _hits_from_response(response) -> list:
        """Normalize a similarity_search response into a list of per-hit items.

        The Vector Search client returns a dict with the column names under
        ``manifest.columns`` and the rows under ``result.data_array``; each row
        is turned into a ``{column_name: value}`` dict (the relevance value is
        in the column named "score"). Iterating the response dict directly
        would only yield its key names, not the hits. Any non-dict response is
        passed through as a list.
        """
        if isinstance(response, dict):
            rows = (response.get("result") or {}).get("data_array") or []
            columns = (response.get("manifest") or {}).get("columns") or []
            names = [column.get("name") for column in columns]
            if names:
                return [dict(zip(names, row)) for row in rows]
            return list(rows)
        return list(response or [])

    def _to_search_result(hit, position: int, source_type: str, name_column: str, fallback_prefix: str) -> "SearchResult":
        """Convert one normalized hit into a SearchResult."""
        fallback_source = f"{fallback_prefix}_{position}"

        if isinstance(hit, dict):
            # Extract fields safely
            pages = hit.get("pages") or []
            score = hit.get("score") or 0.0
            section_type = hit.get("section_type", "")
            return SearchResult(
                # "or" guards against columns that are present but null
                content=hit.get("text") or "",
                summary=hit.get("summary") or "",
                source=hit.get(name_column) or fallback_source,
                metadata={"section_type": section_type, "score": score},
                score=score,
                source_type=source_type,
                page_number=pages[0] if pages else None,
                section=section_type
            )

        # Handle non-dict results: no metadata available, neutral score.
        return SearchResult(
            content=str(hit),
            summary="",
            source=fallback_source,
            metadata={"score": 0.5},
            score=0.5,
            source_type=source_type,
            page_number=None,
            section=""
        )

    def _search_index(self, index_name: str, query: str, top_k: int, source_type: str,
                      name_column: str, fallback_prefix: str, item_label: str,
                      search_label: str) -> "List[SearchResult]":
        """Query one index and convert the hits; returns [] if the search fails."""
        try:
            # Use columns parameter as it's required by the API
            response = self.vsc.get_index(VECTOR_SEARCH_ENDPOINT, index_name).similarity_search(
                query_text=query,
                columns=["text", "summary", name_column, "section_type", "pages"],
                num_results=top_k
            )

            search_results = []
            for i, hit in enumerate(_hits_from_response(response)):
                try:
                    search_results.append(
                        _to_search_result(hit, i + 1, source_type, name_column, fallback_prefix)
                    )
                except Exception as item_error:
                    print(f"Error processing {item_label} {i}: {item_error}")
                    # Keep a zero-score placeholder so one bad row does not
                    # discard the rest of the results.
                    search_results.append(SearchResult(
                        content=str(hit) if hit else "",
                        summary="",
                        source=f"{fallback_prefix}_{i+1}",
                        metadata={"score": 0.0},
                        score=0.0,
                        source_type=source_type,
                        page_number=None,
                        section=""
                    ))

            return search_results
        except Exception as e:
            # Best effort: report the failure and let the caller continue.
            print(f"Error searching {search_label}: {e}")
            return []

    def search_past_judgments_improved(self, query: str, top_k: int = DEFAULT_TOP_K) -> "List[SearchResult]":
        """Improved search past judgments with better column handling.

        Args:
            query: Text to search for.
            top_k: Number of hits to request.

        Returns:
            SearchResult objects with ``source_type="judgment"``; an empty list
            if the search itself fails.
        """
        return _search_index(
            self, PAST_JUDGMENTS_INDEX, query, top_k,
            source_type="judgment",
            name_column="doc_id",
            fallback_prefix="Document",
            item_label="result",
            search_label="past judgments",
        )

    def search_geneva_documentation_improved(self, query: str, top_k: int = DEFAULT_TOP_K) -> "List[SearchResult]":
        """Improved search Geneva documentation with better column handling.

        Args:
            query: Text to search for.
            top_k: Number of hits to request.

        Returns:
            SearchResult objects with ``source_type="geneva"``; an empty list
            if the search itself fails.
        """
        return _search_index(
            self, GENEVA_DOCUMENTATION_INDEX, query, top_k,
            source_type="geneva",
            name_column="doc_name",
            fallback_prefix="Geneva_Document",
            item_label="Geneva result",
            search_label="Geneva documentation",
        )

    # Add improved methods to the class
    EnhancedICCRAGSystem.search_past_judgments_improved = search_past_judgments_improved
    EnhancedICCRAGSystem.search_geneva_documentation_improved = search_geneva_documentation_improved

    print("✅ Improved search methods added to Enhanced ICC RAG System")

# Execute the function to add improved methods
add_improved_search_methods()


In [None]:
# Quick test to verify the columns parameter fix
def quick_test_columns_fix():
    """Smoke-test both vector-search paths after the ``columns`` parameter fix.

    Runs one short query through ``search_past_judgments`` and
    ``search_geneva_documentation`` (``top_k=2`` each) and prints the number
    of hits plus the source and score of the first hit.

    The success banner is only printed when *both* searches return at least
    one result. An empty result list is reported as a warning instead of a
    pass, so a search that fails quietly cannot be mistaken for a working one.

    Returns:
        None. All reporting is done through ``print``.
    """

    print("🔧 QUICK TEST - COLUMNS PARAMETER FIX")
    print("=" * 50)

    # Test a simple query
    test_query = "active participation in hostilities"

    # (display label, lower-case label for the progress line, method name)
    searches = (
        ("Past judgments", "past judgments", "search_past_judgments"),
        ("Geneva documentation", "Geneva documentation", "search_geneva_documentation"),
    )

    try:
        # Build the system inside the try so that initialization failures are
        # reported through the same error path as search failures.
        rag_system = EnhancedICCRAGSystem()

        print(f"Testing query: '{test_query}'")
        print("-" * 30)

        empty_searches = []
        for label, progress_label, method_name in searches:
            print(f"Testing {progress_label} search...")
            hits = getattr(rag_system, method_name)(test_query, top_k=2)

            if hits:
                print(f"✅ {label}: {len(hits)} results")
                print(f"   First result source: {hits[0].source}")
                print(f"   First result score: {hits[0].score}")
            else:
                # NOTE(review): the search helpers appear to print the error
                # and return [] on failure, so "no hits" may mean "broken".
                print(f"⚠️ {label}: 0 results")
                empty_searches.append(label)

        if empty_searches:
            print(
                f"\n⚠️ Test inconclusive: no results from {', '.join(empty_searches)}. "
                "Check the messages above for search errors."
            )
        else:
            print("\n🎉 All tests passed! The columns parameter fix is working.")

    except Exception as e:
        print(f"❌ Error during test: {e}")
        import traceback
        traceback.print_exc()

# Run the quick test
quick_test_columns_fix()


In [0]:
# Test the Enhanced RAG System with complex legal research questions
def test_enhanced_rag_system():
    """Run three complex legal research questions end to end and report on each.

    For every question the function retrieves context (``top_k=8``), generates
    a legal analysis under its own ``test_session_<n>`` conversation id, prints
    a routing / analysis report, and finally prints a one-line-per-question
    summary.

    Returns:
        list[dict]: one record per question with the keys ``question_id``,
        ``question``, ``routing_decision``, ``sources_found``,
        ``confidence_score``, ``analysis_length``, ``key_findings_count``,
        ``citations_count`` and ``processing_time`` (retrieval + analysis).
    """

    def _print_numbered(entries, limit):
        # Print at most `limit` entries as a 1-based numbered list.
        for position, entry in enumerate(entries[:limit], 1):
            print(f"{position}. {entry}")

    # Initialize the system
    system = EnhancedICCRAGSystem()

    banner = "#" * 80

    # Complex legal research queries
    research_cases = [
        {
            "question": "Can you please go through all the ICTY trial judgments and appeal judgments and identify where the chamber discusses the status of an individual during the conflict. In particular, please identify all relevant paragraphs where the chamber refers to the active or direct participation of the individual or where the chamber discusses the civilian status or combatant status of an individual. Please provide the direct paragraph in full.",
            "expected_routing": "judgment",
            "key_topics": ["active participation", "direct participation", "civilian status", "combatant status", "ICTY", "trial judgments", "appeal judgments"],
        },
        {
            "question": "Can you please go through all the ICTY trial judgments and appeal judgments and identify which factors the Trial or Appeals Chamber relied on in order to assess whether an individual is actively or directly participating in hostilities at a particular point? Please provide the full paragraph and citations",
            "expected_routing": "judgment",
            "key_topics": ["factors", "assessment", "actively participating", "directly participating", "hostilities", "Trial Chamber", "Appeals Chamber", "citations"],
        },
        {
            "question": "Can you please search through all the ICTY trial judgments and appeal judgments and identify relevant paragraphs which would support the proposition that an individual who has previously joined enemy forces and is armed at the relevant point is considered to have lost their protected status at a particular point? Please determine whether the chamber undertakes a subjective or objective assessment?",
            "expected_routing": "judgment",
            "key_topics": ["enemy forces", "armed", "protected status", "subjective assessment", "objective assessment", "lost status"],
        },
    ]

    print("🧪 TESTING ENHANCED ICC RAG SYSTEM")
    print("=" * 80)

    summary_rows = []

    for number, item in enumerate(research_cases, start=1):
        question = item["question"]
        expected_routing = item["expected_routing"]

        print("\n" + banner)
        print(f"LEGAL RESEARCH QUESTION {number}")
        print(banner)
        print(f"Question: {question[:150]}...")
        print(f"Expected routing: {expected_routing}")
        print("Key topics: " + ", ".join(item["key_topics"]))

        # Retrieve context, then generate the analysis from it
        retrieved = system.retrieve_context(question, top_k=8)
        report = system.generate_legal_analysis(
            question,
            retrieved,
            conversation_id=f"test_session_{number}",
        )

        # Display results
        print("\n📊 ROUTING ANALYSIS:")
        print(f"Expected: {expected_routing}")
        print(f"Actual: {retrieved.routing_decision}")
        print(f"Sources found: {retrieved.total_sources}")
        print(f"Processing time: {retrieved.processing_time:.2f}s")

        print("\n⚖️ LEGAL ANALYSIS:")
        print(f"Confidence score: {report.confidence_score:.3f}")
        print(f"Key findings: {len(report.key_findings)}")
        print(f"Citations: {len(report.citations)}")
        print(f"Analysis length: {len(report.analysis)} characters")

        print("\n📝 ANALYSIS PREVIEW:")
        # Long analyses are cut to 500 characters and marked with an ellipsis.
        if len(report.analysis) > 500:
            print(report.analysis[:500] + "...")
        else:
            print(report.analysis)

        print("\n🔍 KEY FINDINGS:")
        _print_numbered(report.key_findings, 3)

        print("\n📚 CITATIONS:")
        _print_numbered(report.citations, 5)

        summary_rows.append({
            "question_id": number,
            "question": question,
            "routing_decision": retrieved.routing_decision,
            "sources_found": retrieved.total_sources,
            "confidence_score": report.confidence_score,
            "analysis_length": len(report.analysis),
            "key_findings_count": len(report.key_findings),
            "citations_count": len(report.citations),
            "processing_time": retrieved.processing_time + report.processing_time,
        })

        print("\n" + banner + "\n")

    # Summary
    print("📊 TEST SUMMARY")
    print("=" * 50)
    for row in summary_rows:
        print(
            f"Question {row['question_id']}: {row['routing_decision']} routing, "
            f"{row['sources_found']} sources, {row['confidence_score']:.3f} confidence, "
            f"{row['processing_time']:.2f}s"
        )

    return summary_rows

# Run the test
test_results = test_enhanced_rag_system()


[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
🧪 TESTING ENHANCED ICC RAG SYSTEM

################################################################################
LEGAL RESEARCH QUESTION 1
################################################################################
Question: Can you please go through all the ICTY trial judgments and appeal judgments and identify where the chamber discusses the status of an individual durin...
Expected routing: judgment
Key topics: active participation, direct participation, civilian status, combatant status, ICTY, trial judgments, appeal judgments
[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
Error searching past judgments: 'st

## MLflow 3.0 Production Model


In [0]:
class EnhancedICCRAGModel(mlflow.pyfunc.PythonModel):
    """MLflow 3.0 production model wrapper for Enhanced ICC RAG System.

    Input is a DataFrame with a required ``query`` column and two optional
    columns: ``num_results`` (falls back to ``DEFAULT_NUM_RESULTS`` when the
    column or the cell is missing) and ``conversation_id`` (falls back to
    ``None``). Output is one dict per input row; a row that fails is reported
    with ``routing_decision == "error"`` instead of failing the whole batch.
    """

    # Number of documents retrieved when the caller does not say otherwise.
    DEFAULT_NUM_RESULTS = 8
    # Upper bound on the per-answer "sources" list in the response.
    MAX_SOURCES_RETURNED = 10

    def load_context(self, context):
        """Initialize the enhanced RAG system."""
        self.rag_system = EnhancedICCRAGSystem()

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for ``None`` and scalar NaN/NA cells."""
        if value is None:
            return True
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    @staticmethod
    def _column_or_default(model_input: pd.DataFrame, column: str, default) -> list:
        """Return ``column`` as a list, or ``default`` repeated once per row."""
        if column in model_input.columns:
            return model_input[column].tolist()
        # A plain list has no .tolist(); build the fallback directly.
        return [default] * len(model_input)

    def _answer_query(self, query, num_results, conv_id) -> Dict:
        """Run retrieval and analysis for one row and format the response."""
        # NaN is not valid JSON, so missing ids are normalized to None.
        conv_id = None if self._is_missing(conv_id) else conv_id
        try:
            if self._is_missing(num_results):
                top_k = self.DEFAULT_NUM_RESULTS
            else:
                # Columns containing NaN are float-typed; top_k must be an int.
                top_k = int(num_results)

            # Retrieve context (named so it does not shadow predict's `context`)
            retrieved = self.rag_system.retrieve_context(query=query, top_k=top_k)

            # Generate legal analysis
            analysis = self.rag_system.generate_legal_analysis(
                question=query,
                context=retrieved,
                conversation_id=conv_id,
            )

            # Format response
            return {
                "question": query,
                "analysis": analysis.analysis,
                "routing_decision": retrieved.routing_decision,
                "sources_used": len(analysis.sources_used),
                "confidence_score": analysis.confidence_score,
                "key_findings": analysis.key_findings,
                "citations": analysis.citations,
                "processing_time_seconds": retrieved.processing_time + analysis.processing_time,
                "conversation_id": conv_id,
                "sources": [
                    {
                        "source": s.source,
                        "source_type": s.source_type,
                        "section": s.section,
                        "page_number": s.page_number,
                        # NOTE(review): not every SearchResult is built with an
                        # `article` value; fall back to None if it is absent.
                        "article": getattr(s, "article", None),
                        "relevance_score": round(s.score, 3),
                    }
                    for s in analysis.sources_used[:self.MAX_SOURCES_RETURNED]
                ],
            }
        except Exception as e:
            # Handle individual query errors without failing the batch
            return {
                "question": query,
                "analysis": f"Error processing query: {str(e)}",
                "routing_decision": "error",
                "sources_used": 0,
                "confidence_score": 0.0,
                "key_findings": [],
                "citations": [],
                "processing_time_seconds": 0,
                "conversation_id": conv_id,
                "sources": [],
            }

    def predict(self, context, model_input: pd.DataFrame) -> List[Dict]:
        """Handle predictions for serving endpoint.

        Args:
            context: MLflow model context (unused here).
            model_input: DataFrame with a ``query`` column and optional
                ``num_results`` / ``conversation_id`` columns.

        Returns:
            One response dict per input row. If the input itself cannot be
            read (for example the ``query`` column is missing), one
            ``{"error": ...}`` dict per row is returned instead.
        """
        try:
            queries = model_input["query"].tolist()

            # Extract optional parameters
            num_results_list = self._column_or_default(
                model_input, "num_results", self.DEFAULT_NUM_RESULTS
            )
            conversation_ids = self._column_or_default(
                model_input, "conversation_id", None
            )

            return [
                self._answer_query(query, num_results, conv_id)
                for query, num_results, conv_id in zip(
                    queries, num_results_list, conversation_ids
                )
            ]

        except Exception as e:
            # Separate dict per row so callers can mutate entries independently.
            return [{"error": f"Model error: {str(e)}"} for _ in range(len(model_input))]

print("✅ Enhanced ICC RAG Model for MLflow 3.0 defined")




In [0]:
# Register the Enhanced ICC RAG Model in MLflow 3.0
#
# The model is logged under a plain name and then registered from the URI that
# log_model() returns, so the registration always points at the model that was
# just logged. The three-level Unity Catalog name is used only for the
# registered model.
with mlflow.start_run(run_name="Enhanced_ICC_RAG_Production") as run:
    
    # Create model instance
    production_model = EnhancedICCRAGModel()
    
    # Input example for serving endpoint
    input_example = pd.DataFrame({
        "query": [
            "Can you please go through all the ICTY trial judgments and identify where the chamber discusses the status of an individual during the conflict?",
            "What factors did the Trial Chamber rely on to assess active participation in hostilities?"
        ],
        "num_results": [10, 12],
        "conversation_id": ["legal_research_001", "legal_research_001"]
    })
    
    # Expected output format (used only to infer the model signature)
    output_example = [
        {
            "question": "Sample legal question",
            "analysis": "Comprehensive legal analysis based on retrieved context...",
            "routing_decision": "judgment",
            "sources_used": 8,
            "confidence_score": 0.85,
            "key_findings": ["Key legal finding 1", "Key legal finding 2"],
            "citations": ["Article 8", "Page 123", "Section A"],
            "processing_time_seconds": 5.2,
            "conversation_id": "legal_research_001",
            "sources": [
                {
                    "source": "ICTY_Judgment_001.pdf",
                    "source_type": "judgment",
                    "section": "FINDINGS_OF_FACT",
                    "page_number": 123,
                    "article": None,
                    "relevance_score": 0.95
                }
            ]
        }
    ]
    
    # Log the model using MLflow 3.0 syntax.
    # `name` is the logged-model name, not a Unity Catalog path, so it must
    # match what the registration step refers to.
    model_info = mlflow.pyfunc.log_model(
        name="enhanced_icc_rag_model",
        python_model=production_model,
        input_example=input_example,
        signature=infer_signature(input_example, output_example),
        resources=[
            DatabricksVectorSearchIndex(index_name=PAST_JUDGMENTS_INDEX),
            DatabricksVectorSearchIndex(index_name=GENEVA_DOCUMENTATION_INDEX),
            DatabricksServingEndpoint(endpoint_name=LLM_MODEL_ENDPOINT)
        ],
        # Every package imported at the top of this file must be installable
        # in the serving container, not only in the notebook.
        pip_requirements=[
            "mlflow>=3.1.1",
            "langchain",
            "langchain-community",
            "databricks-langchain",
            "databricks-vectorsearch",
            "databricks-sdk",
            "numpy",
            "pandas",
            "pydantic"
        ]
    )
    
    # Register model in Unity Catalog from the URI of the model logged above
    model_uri = model_info.model_uri
    registered_model = mlflow.register_model(
        model_uri=model_uri,
        name="icc_chatbot.search_model.enhanced_icc_rag_legal_research"
    )
    
    print(f"✅ Model logged: {run.info.run_id}")
    print(f"🔗 Model URI: {model_uri}")
    print(f"📦 Model registered: {registered_model.name} v{registered_model.version}")
    # NOTE(review): the host below is specific to one workspace; adjust when
    # running this notebook elsewhere.
    print(f"🌐 View in Unity Catalog: https://dbc-0619d7f5-0bda.cloud.databricks.com/explore/data/models/{registered_model.name}/version/{registered_model.version}")




## Usage Examples & Deployment Instructions


In [0]:
def show_usage_examples():
    """Print the usage guide for the Enhanced ICC RAG System.

    The guide is a title banner followed by six numbered sections (local use,
    conversational use, serving endpoint, batch processing, deployment steps,
    recommended configuration). Each section is a heading line followed by a
    text snippet printed verbatim. Nothing is executed and nothing is
    returned.
    """

    # (heading, snippet) pairs, printed in this order.
    sections = (
        ("1. LOCAL USAGE", """
# Initialize the system
rag_system = EnhancedICCRAGSystem()

# Simple legal research query
question = "What are the elements of crimes against humanity?"
context = rag_system.retrieve_context(question, top_k=8)
analysis = rag_system.generate_legal_analysis(question, context)

print(f"Analysis: {analysis.analysis}")
print(f"Confidence: {analysis.confidence_score}")
print(f"Sources: {len(analysis.sources_used)}")
"""),
        ("2. CONVERSATIONAL USAGE", """
# Multi-turn conversation
conversation_id = "legal_research_session_001"

# First question
question1 = "How has the principle of proportionality been applied in ICTY judgments?"
context1 = rag_system.retrieve_context(question1, top_k=10)
analysis1 = rag_system.generate_legal_analysis(question1, context1, conversation_id)

# Follow-up question (with memory)
question2 = "What factors did the chamber consider in those cases?"
context2 = rag_system.retrieve_context(question2, top_k=8)
analysis2 = rag_system.generate_legal_analysis(question2, context2, conversation_id)
"""),
        ("3. SERVING ENDPOINT USAGE", """
# Deploy to serving endpoint
import requests

endpoint_url = "https://your-workspace.cloud.databricks.com/serving-endpoints/enhanced-icc-rag/invocations"
headers = {"Authorization": "Bearer YOUR_TOKEN"}

# Single query
payload = {
    "dataframe_split": {
        "columns": ["query", "num_results", "conversation_id"],
        "data": [["What are the requirements for combatant status?", 10, "session_001"]]
    }
}

response = requests.post(endpoint_url, headers=headers, json=payload)
result = response.json()["predictions"][0]

print(f"Analysis: {result['analysis']}")
print(f"Routing: {result['routing_decision']}")
print(f"Sources: {result['sources_used']}")
print(f"Confidence: {result['confidence_score']}")
"""),
        ("4. BATCH PROCESSING", """
# Multiple legal research questions
batch_payload = {
    "dataframe_split": {
        "columns": ["query", "num_results", "conversation_id"],
        "data": [
            ["Can you identify all ICTY judgments discussing civilian status?", 12, "batch_001"],
            ["What factors determine active participation in hostilities?", 10, "batch_001"],
            ["How do chambers assess subjective vs objective criteria?", 8, "batch_001"]
        ]
    }
}

response = requests.post(endpoint_url, headers=headers, json=batch_payload)
results = response.json()["predictions"]

for i, result in enumerate(results, 1):
    print(f"Question {i}: {result['routing_decision']} routing, {result['sources_used']} sources")
"""),
        ("5. DEPLOYMENT INSTRUCTIONS", """
# Step 1: Create serving endpoint
# Go to Databricks UI > Serving > Create Endpoint
# Select the registered model: enhanced_icc_rag_legal_research
# Configure compute and scaling

# Step 2: Test endpoint
# Use the test queries provided above
# Monitor performance and adjust scaling as needed

# Step 3: Integration
# Integrate with your legal research workflow
# Use conversation_id for multi-turn research sessions
# Monitor confidence scores for quality assurance
"""),
        ("6. OPTIMAL CONFIGURATION", """
# Query types and recommended num_results:
# - Complex legal research: 10-15
# - Specific case law queries: 8-12  
# - Geneva Convention queries: 6-10
# - Factual questions: 4-8

# Routing decisions:
# - "judgment": ICTY/ICC case law queries
# - "geneva": International humanitarian law queries  
# - "both": Comparative legal analysis

# Confidence scores:
# - >0.8: High confidence, reliable analysis
# - 0.6-0.8: Good confidence, review recommended
# - <0.6: Low confidence, additional research needed
"""),
    )

    print("🚀 ENHANCED ICC RAG SYSTEM - USAGE EXAMPLES")
    print("=" * 60)

    for heading, snippet in sections:
        # Blank line before each heading, then the snippet exactly as written.
        print(f"\n📋 {heading}:")
        print(snippet)

# Show usage examples
show_usage_examples()


