In [None]:
import os
import re
import torch
import numpy as np
import asyncio
import time
import uuid
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import List, Dict, Optional, Set, Tuple

from fastapi import FastAPI, HTTPException, BackgroundTasks, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from rank_bm25 import BM25Okapi
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import uvicorn
import logging
import json

# Configure logging: every record at INFO or above goes both to the file
# "legal_rag.log" (relative to the working directory) and to the console,
# using one shared "time - logger - level - message" layout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("legal_rag.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by every component and endpoint in this file.
logger = logging.getLogger(__name__)

# Initialize FastAPI: the application object all route decorators below
# attach to. Title/description/version feed the generated API docs.
app = FastAPI(
    title="Tunisian Legal RAG API",
    description="A Retrieval-Augmented Generation system for Tunisian legal documents",
    version="2.0.0"
)

# Add CORS middleware: accept requests from any origin, with any method and
# any header.
# NOTE(review): a wildcard origin list together with allow_credentials=True
# is very permissive -- confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ================= Configuration =================
# Catalogue of model identifiers, each value being the string used as the
# model id. Members are plain strings as well as enum members (str mixin).
# NOTE(review): nothing in the visible part of this file references this
# enum; MODEL_CONFIG spells the model id out as a literal instead.
class ModelType(str, Enum):
    FLAN_T5_LARGE = "google/flan-t5-large"
    DEEPSEEK_LLM = "deepseek-ai/deepseek-llm-7b"
    MISTRAL = "mistralai/Mistral-7B-v0.1"

# Settings handed to ModelManager: which checkpoint to load, where to cache
# it, whether to quantize, and the default generation parameters.
MODEL_CONFIG = dict(
    model_id="google/flan-t5-large",
    cache_dir="./model_cache",
    quantize=False,
    # True only when a CUDA device is visible at import time.
    use_gpu=torch.cuda.is_available(),
    max_length=512,
    temperature=0.7,
    top_p=0.95,
)

# One entry per legal field: the PDF to load and how to split it into
# chunks. Both documents currently share the same chunking parameters.
DOC_CONFIG = {
    field_name: {"path": pdf_path, "chunk_size": 512, "chunk_overlap": 100}
    for field_name, pdf_path in (
        ("criminal", "./Tunisia_Crim-Fr.pdf"),
        ("constitution", "./Constitution_fr.pdf"),
    )
}

# Legal concept -> related French terms. QueryProcessor.preprocess_query
# reports a concept when either the concept name itself or any of its
# related terms occurs (as a substring) in the lower-cased query.
LEGAL_TERMS = {
    "imprescriptibilité": ["prescription", "délai", "expiration", "crime", "poursuites"],
    "droit": ["liberté", "garantie", "protection"],
    "justice": ["tribunal", "magistrat", "judiciaire"],
}

# ================= Data Models =================
# Kinds of feedback accepted by POST /feedback. The str mixin lets members
# compare equal to their plain string values (relied on by /stats).
class FeedbackType(str, Enum):
    POSITIVE = "positive"
    NEGATIVE = "negative"
    # Only for this kind is the submitted correction text stored.
    CORRECTION = "correction"

# Request body for POST /query.
class QueryRequest(BaseModel):
    # Free-text question; used for retrieval and embedded in the prompt.
    query: str = Field(..., description="The legal question to ask")
    # Must be a key of DOC_CONFIG; the endpoint rejects anything else.
    field: str = Field(..., description="Legal domain (criminal/constitution)")
    # "fr" selects the French prompt; any other value selects the English one.
    language: str = Field("fr", description="Response language (fr/en)")
    # Number of context documents, constrained to 1..5.
    top_k: int = Field(3, description="Number of documents to retrieve", ge=1, le=5)
    # Generation length budget, constrained to 50..500.
    max_tokens: int = Field(150, description="Max tokens to generate", ge=50, le=500)
    # Sampling temperature, constrained to 0.1..1.0.
    temperature: float = Field(0.7, description="Generation temperature", ge=0.1, le=1.0)
    # When true, the endpoint runs a second generation pass critiquing the answer.
    enable_reflection: bool = Field(True, description="Enable self-reflection")

# One retrieved context passage as returned to the client.
class RetrievedDocument(BaseModel):
    # Chunk or article text.
    text: str
    # Retrieval score: 1.0 for a direct article lookup, otherwise the
    # weighted BM25 + cosine combination computed by QueryProcessor.
    score: float
    # Provenance written by DocumentManager (page, source, id, and "article"
    # for article entries).
    metadata: Dict

# Response body of POST /query.
class QueryResponse(BaseModel):
    # Post-processed generated answer.
    answer: str
    # Context passages the answer was generated from.
    retrieved_documents: List[RetrievedDocument]
    # Wall-clock handling time in milliseconds.
    query_time_ms: float
    # UUID identifying this query for later /feedback and /reflect calls.
    query_id: str
    # Self-critique text; None when reflection was not requested.
    reflection: Optional[str] = None
    # NOTE(review): the visible endpoint code always sends False here.
    cache_hit: bool = False

# Request body for POST /feedback.
class FeedbackRequest(BaseModel):
    # Must match a query_id previously returned by /query.
    query_id: str = Field(..., description="ID of the query being rated")
    feedback_type: FeedbackType = Field(..., description="Type of feedback")
    # Stored only when feedback_type is CORRECTION; ignored otherwise.
    correction_text: Optional[str] = Field(None, description="Corrected answer if type is correction")
    comments: Optional[str] = Field(None, description="Additional feedback comments")

# Request body for POST /reflect.
class ReflectionRequest(BaseModel):
    # Must match a query_id previously returned by /query.
    query_id: str = Field(..., description="ID of query to reflect on")
    # Instruction placed at the top of the reflection prompt.
    reflection_prompt: Optional[str] = Field(
        default="Analyze the quality of this response and suggest improvements:",
        description="Optional custom reflection prompt"
    )

# ================= Core Components =================
class DocumentCache:
    """Process-wide, in-memory store for document indexes and query records.

    Nothing here is persisted: every container starts empty and is filled by
    DocumentManager (indexes) and the API endpoints (query bookkeeping).
    """

    def __init__(self):
        # Retrieval data, each mapping keyed by legal field name.
        self.texts, self.bm25, self.embeddings = {}, {}, {}
        self.articles, self.metadata = {}, {}
        # Time of the last successful document load; None until then.
        self.last_loaded = None
        # Per-query bookkeeping. feedback and query_metadata are keyed by
        # query id; responses is not written anywhere in the visible code.
        self.responses, self.feedback, self.query_metadata = {}, {}, {}


# Single shared instance used by every component below.
cache = DocumentCache()

class ModelManager:
    """Loads and serves the generation model and the sentence-embedding model.

    ``config`` is a mapping with the keys used below: ``model_id``,
    ``cache_dir``, ``quantize``, ``max_length``, ``temperature``, ``top_p``
    and, optionally, ``use_gpu``.
    """

    def __init__(self, config):
        self.config = config
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.embedding_model = None
        # Flipped to True only after every model loaded successfully.
        self.is_loaded = False

    async def load_models(self):
        """Load tokenizer, generation model, pipeline and embedding model.

        Raises:
            Exception: whatever the underlying libraries raise; the error is
                logged and re-raised so startup fails loudly.
        """
        model_id = self.config["model_id"]
        logger.info(f"Loading model: {model_id}")

        is_seq2seq = "flan-t5" in model_id.lower()
        quantize = self.config["quantize"]
        # Honour the configured switch instead of probing CUDA unconditionally.
        gpu_allowed = bool(self.config.get("use_gpu", True)) and torch.cuda.is_available()
        # Dynamically quantized modules only provide CPU kernels, so a
        # quantized generation model must stay on the CPU.
        llm_on_gpu = gpu_allowed and not quantize

        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=self.config["cache_dir"]
            )

            # Load model with appropriate class
            if is_seq2seq:
                model_class = AutoModelForSeq2SeqLM
            else:
                # Imported here because the module-level import list does not
                # include it; previously this branch raised NameError.
                from transformers import AutoModelForCausalLM
                model_class = AutoModelForCausalLM

            self.model = model_class.from_pretrained(
                model_id,
                cache_dir=self.config["cache_dir"]
            )

            # Quantization
            if quantize:
                logger.info("Applying quantization...")
                self.model = torch.quantization.quantize_dynamic(
                    self.model,
                    {torch.nn.Linear},
                    dtype=torch.float16
                )

            # GPU support
            if llm_on_gpu:
                logger.info("Moving model to GPU...")
                self.model.to('cuda')

            # Create pipeline (device 0 = first CUDA device, -1 = CPU)
            self.pipeline = pipeline(
                task="text2text-generation" if is_seq2seq else "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if llm_on_gpu else -1
            )

            # Load embedding model
            logger.info("Loading embedding model...")
            self.embedding_model = SentenceTransformer(
                'paraphrase-multilingual-MiniLM-L12-v2',
                device='cuda' if gpu_allowed else 'cpu'
            )

            self.is_loaded = True
            logger.info("Models loaded successfully")
        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            raise

    # NOTE(review): the cache key includes ``self`` and every keyword
    # argument, and generation samples (do_sample=True), so a repeated
    # identical request replays the first sampled answer for this instance.
    @lru_cache(maxsize=1000)
    def generate_response(self, prompt: str, **kwargs) -> str:
        """Generate text for ``prompt``, memoising by (prompt, kwargs).

        Args:
            prompt: Full prompt passed to the generation pipeline.
            **kwargs: Optional ``max_tokens``, ``temperature`` and ``top_p``
                overrides; missing ones fall back to the manager's config.

        Returns:
            The generated text, with the prompt stripped off when the
            pipeline echoes it at the start of its output.

        Raises:
            ValueError: if the models have not been loaded yet.
            Exception: any pipeline failure, logged and re-raised.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded")

        # Default generation parameters
        generation_params = {
            "max_new_tokens": kwargs.get("max_tokens", self.config["max_length"]),
            "temperature": kwargs.get("temperature", self.config["temperature"]),
            "top_p": kwargs.get("top_p", self.config["top_p"]),
            "do_sample": True,
            "repetition_penalty": 1.2,
            "num_beams": 5,
            "early_stopping": True
        }

        try:
            response = self.pipeline(prompt, **generation_params)

            if isinstance(response, list):
                generated_text = response[0]["generated_text"]
            else:
                generated_text = response

            # Remove prompt if it's included in response
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            return generated_text
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    async def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` with the embedding model.

        Returns:
            The encoder output moved to the CPU and converted to a NumPy
            array (one row per input text).

        Raises:
            ValueError: if the embedding model has not been loaded.
            Exception: any encoder failure, logged and re-raised.
        """
        if not self.embedding_model:
            raise ValueError("Embedding model not loaded")

        try:
            embeddings = self.embedding_model.encode(
                texts,
                convert_to_tensor=True,
                show_progress_bar=False
            )
            return embeddings.cpu().numpy()
        except Exception as e:
            logger.error(f"Embedding generation failed: {str(e)}")
            raise

class DocumentManager:
    """Loads the configured PDFs and fills the shared ``cache`` with indexes.

    ``config`` maps a legal field name to a dict with ``path``,
    ``chunk_size`` and ``chunk_overlap``.
    """

    def __init__(self, config):
        self.config = config

    async def load_documents(self, force_reload: bool = False) -> bool:
        """Load and index every configured document.

        The work is skipped when documents were loaded less than an hour ago
        and the cache is non-empty, unless ``force_reload`` is set.

        Returns:
            True when the cache was reused or at least one document was
            processed successfully; False when nothing could be loaded.
        """
        current_time = datetime.now()

        # Skip if recently loaded
        if not force_reload and cache.last_loaded and \
           (current_time - cache.last_loaded).total_seconds() < 3600 and \
           cache.texts:
            logger.info("Using cached documents")
            return True

        logger.info("Loading and processing documents...")
        success = False

        for field, field_config in self.config.items():
            pdf_path = field_config["path"]

            if not os.path.exists(pdf_path):
                logger.error(f"Document not found: {pdf_path}")
                continue

            try:
                await self._process_document(
                    pdf_path,
                    field,
                    field_config["chunk_size"],
                    field_config["chunk_overlap"]
                )
                success = True
            except Exception as e:
                # Best effort: one broken document must not block the others.
                logger.error(f"Error processing {field}: {str(e)}")

        if success:
            cache.last_loaded = current_time
            logger.info(f"Documents loaded. Available fields: {list(cache.texts.keys())}")
            return True
        return False

    async def _process_document(self, pdf_path: str, field: str,
                              chunk_size: int, chunk_overlap: int):
        """Read one PDF and build its chunk indexes and article table."""
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()

        # Run the two steps one after the other: neither yields to the event
        # loop, so there is no real parallelism to gain, and sequencing them
        # means articles are only published when chunk indexing succeeded.
        await self._process_regular_chunks(pages, field, chunk_size, chunk_overlap)
        await self._process_article_chunks(pages, field)

    async def _process_regular_chunks(self, pages: List, field: str,
                                    chunk_size: int, chunk_overlap: int):
        """Split pages into chunks and publish texts, metadata, BM25, embeddings.

        Raises:
            ValueError: if the document yields no text chunks at all.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(pages)

        # Extract content and metadata
        texts = []
        metadata = []
        for doc in documents:
            texts.append(doc.page_content)
            metadata.append({
                "page": doc.metadata.get("page", 0),
                "source": os.path.basename(doc.metadata.get("source", "")),
                # len(texts) is read after the append, so ids count from 1.
                "id": f"{field}_{doc.metadata.get('page', 0)}_{len(texts)}"
            })

        if not texts:
            # An empty corpus cannot be indexed; fail with a clear message
            # before anything is written to the cache.
            raise ValueError(f"No text chunks extracted for field '{field}'")

        # Build every index first so a failure cannot leave the cache with
        # texts but no matching BM25 index or embeddings.
        bm25_index = BM25Okapi([text.split() for text in texts])
        embeddings = await model_manager.get_embeddings(texts)

        # Store in cache
        cache.texts[field] = texts
        cache.metadata[field] = metadata
        cache.bm25[field] = bm25_index
        cache.embeddings[field] = embeddings

        logger.info(f"Processed {len(texts)} chunks for {field}")

    async def _process_article_chunks(self, pages: List, field: str):
        """Extract "Article N" sections and publish them keyed by number.

        Matching is done page by page, so an article's text ends at the next
        article heading or at the end of its page, whichever comes first.
        When a number occurs more than once, the last occurrence wins.
        """
        article_pattern = re.compile(r'Article\s+(\d+)\s*[:.]', re.IGNORECASE)
        # Collect into a fresh table so a reload drops articles that no
        # longer exist in the document.
        articles = {}

        for page in pages:
            content = page.page_content
            article_matches = list(article_pattern.finditer(content))
            page_number = page.metadata.get("page", 0)
            source_name = os.path.basename(page.metadata.get("source", ""))

            for i, match in enumerate(article_matches):
                article_num = match.group(1)
                start_pos = match.start()
                end_pos = article_matches[i+1].start() if i < len(article_matches)-1 else len(content)

                articles[article_num] = {
                    "text": content[start_pos:end_pos].strip(),
                    "metadata": {
                        "page": page_number,
                        "source": source_name,
                        "id": f"{field}_article_{article_num}",
                        "article": int(article_num)
                    }
                }

        cache.articles[field] = articles

        logger.info(f"Extracted {len(cache.articles[field])} articles from {field}")

class QueryProcessor:
    """Query analysis, hybrid retrieval, prompt building and answer clean-up.

    Retrieval data is read from the shared ``cache``; generation and
    embeddings go through the ``model_manager`` given to the constructor.
    """

    # French function words ignored when extracting query keywords.
    _STOP_WORDS = frozenset(
        {"le", "la", "les", "un", "une", "des", "de", "du", "et", "ou", "à", "en"}
    )

    def __init__(self, model_manager: "ModelManager"):
        self.model_manager = model_manager

    async def preprocess_query(self, query: str) -> Dict:
        """Extract the article number, legal concepts and keywords of a query.

        Returns:
            Dict with ``original_query``, ``article_num`` (digit string or
            None), ``concepts`` (LEGAL_TERMS keys whose name or any related
            term occurs in the query) and ``keywords`` (lower-cased words
            longer than two characters that are not stop words).
        """
        # Extract article number if specified
        article_match = re.search(r'article\s+(\d+)', query, re.IGNORECASE)

        # Extract mentioned legal concepts
        query_lower = query.lower()
        mentioned_concepts = [
            concept
            for concept, related_terms in LEGAL_TERMS.items()
            if concept.lower() in query_lower
            or any(term.lower() in query_lower for term in related_terms)
        ]

        # Extract keywords (excluding stopwords)
        words = [word.lower() for word in re.findall(r'\b\w+\b', query)
                 if word.lower() not in self._STOP_WORDS and len(word) > 2]

        return {
            "original_query": query,
            "article_num": article_match.group(1) if article_match else None,
            "concepts": mentioned_concepts,
            "keywords": words
        }

    async def search(self, query: str, field: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` context documents for ``query``.

        When the query names an article that exists for ``field``, that
        article comes first (score 1.0) and the rest is filled from the
        hybrid search, skipping texts identical to the article.

        Raises:
            ValueError: if ``field`` has no loaded documents.
        """
        # Validate field exists
        if field not in cache.texts:
            raise ValueError(f"Field '{field}' not found in documents")

        # Preprocess query
        query_info = await self.preprocess_query(query)

        # Check for specific article request
        if query_info["article_num"] and field in cache.articles:
            article_result = await self._find_specific_article(
                query_info["article_num"], field
            )
            if article_result:
                general_results = await self._general_search(
                    query, field, top_k-1
                )
                # Combine results ensuring uniqueness
                combined = [article_result]
                seen_texts = {article_result["text"]}
                for res in general_results:
                    if res["text"] not in seen_texts:
                        combined.append(res)
                        seen_texts.add(res["text"])
                return combined[:top_k]

        # Perform general search
        return await self._general_search(query, field, top_k)

    async def _find_specific_article(self, article_num: str, field: str) -> Optional[Dict]:
        """Return the cached article as a result dict, or None if unknown."""
        if field in cache.articles and article_num in cache.articles[field]:
            article_data = cache.articles[field][article_num]
            return {
                "text": article_data["text"],
                "score": 1.0,
                "metadata": article_data["metadata"]
            }
        return None

    async def _general_search(self, query: str, field: str, top_k: int) -> List[Dict]:
        """Combine BM25 (weight 0.6) and cosine similarity (weight 0.4).

        NOTE(review): the BM25 score is used as returned by the index, not
        rescaled to the cosine range, so the 60/40 split is nominal.
        """
        if top_k <= 0:
            # A slice of [-0:] would select the whole corpus; nothing is
            # wanted, so skip both searches.
            return []

        # Run both searches in parallel
        bm25_results, semantic_results = await asyncio.gather(
            self._bm25_search(query, field, top_k),
            self._semantic_search(query, field, top_k)
        )

        # Combine results with weighted scores
        combined_scores = {}

        # BM25 results (60% weight)
        for idx, score in bm25_results:
            combined_scores[idx] = 0.6 * score

        # Semantic results (40% weight)
        for idx, score in semantic_results:
            combined_scores[idx] = combined_scores.get(idx, 0) + 0.4 * score

        # Sort by combined score
        sorted_indices = sorted(combined_scores,
                                key=combined_scores.get,
                                reverse=True)[:top_k]

        # Prepare results with metadata
        field_texts = cache.texts[field]
        field_metadata = cache.metadata[field]
        results = []
        for idx in sorted_indices:
            if idx < len(field_texts):
                results.append({
                    "text": field_texts[idx],
                    "score": float(combined_scores[idx]),
                    "metadata": field_metadata[idx] if idx < len(field_metadata) else {}
                })

        return results

    async def _bm25_search(self, query: str, field: str, top_k: int) -> List[Tuple[int, float]]:
        """Return the ``top_k`` (index, BM25 score) pairs, best first.

        The query is split on whitespace, the same tokenisation used when
        the index was built.
        """
        bm25 = cache.bm25[field]
        scores = bm25.get_scores(query.split())
        top_indices = np.argsort(scores)[-top_k:][::-1]
        # Plain Python numbers keep the result usable as dict keys/JSON.
        return [(int(idx), float(scores[idx])) for idx in top_indices]

    async def _semantic_search(self, query: str, field: str, top_k: int) -> List[Tuple[int, float]]:
        """Return the ``top_k`` (index, cosine similarity) pairs, best first."""
        # Ensure embeddings exist
        if field not in cache.embeddings:
            cache.embeddings[field] = await self.model_manager.get_embeddings(cache.texts[field])

        # Get query embedding (via the injected manager, not the global one)
        query_embedding = (await self.model_manager.get_embeddings([query]))[0]
        doc_embeddings = np.asarray(cache.embeddings[field])
        if doc_embeddings.shape[0] == 0:
            return []

        # Cosine similarity for all documents at once; rows whose norm
        # product is zero keep a similarity of 0.
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.zeros(doc_embeddings.shape[0])
        np.divide(doc_embeddings @ query_embedding, norms,
                  out=similarities, where=norms > 0)

        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [(int(idx), float(similarities[idx])) for idx in top_indices]

    async def create_prompt(self, query: str, retrieved_docs: List[Dict], query_info: Dict, language: str = "fr") -> str:
        """Build the generation prompt from the question and its context.

        Each document is labelled "Document i" plus its article number when
        the metadata carries one. ``language == "fr"`` selects the French
        template; any other value selects the English one.
        """
        # Format context with document references
        context_parts = []
        for i, doc in enumerate(retrieved_docs):
            # Add article info if available
            article_info = ""
            meta = doc.get("metadata", {})
            if "article" in meta:
                article_info = f" (Article {meta['article']})"
            elif "id" in meta and "article" in meta["id"]:
                article_match = re.search(r'article_(\d+)', meta["id"])
                if article_match:
                    article_info = f" (Article {article_match.group(1)})"

            context_parts.append(f"Document {i+1}{article_info}:\n{doc['text']}")

        context = "\n\n".join(context_parts)

        # Language-specific prompt engineering
        if language == "fr":
            # Article-specific guidance
            article_specific = ""
            if query_info["article_num"]:
                article_specific = (
                    f"La question concerne spécifiquement l'Article {query_info['article_num']}. "
                    "Si cet article ne traite pas du sujet demandé, précisez-le clairement."
                )

            # Legal concept guidance
            concept_guidance = ""
            if query_info["concepts"]:
                concepts_list = ", ".join(query_info["concepts"])
                concept_guidance = (
                    f"Concepts juridiques pertinents: {concepts_list}. "
                    "Analysez leur présence dans les documents."
                )

            prompt = (
                f"Question : {query}\n\n"
                f"Contexte juridique :\n{context}\n\n"
                "Instructions :\n"
                "1. Répondez précisément à la question en vous basant sur le contexte\n"
                "2. Citez les articles pertinents\n"
                "3. Soyez clair et concis\n"
                "4. Si l'information est absente, indiquez-le\n\n"
                f"{article_specific}\n"
                f"{concept_guidance}\n"
                "Réponse :"
            )
        else:
            # English version
            article_specific = ""
            if query_info["article_num"]:
                article_specific = (
                    f"The question specifically asks about Article {query_info['article_num']}. "
                    "If this article doesn't address the topic, state this clearly."
                )

            concept_guidance = ""
            if query_info["concepts"]:
                concepts_list = ", ".join(query_info["concepts"])
                concept_guidance = (
                    f"Relevant legal concepts: {concepts_list}. "
                    "Analyze their presence in the documents."
                )

            prompt = (
                f"Question: {query}\n\n"
                f"Legal Context:\n{context}\n\n"
                "Instructions:\n"
                "1. Answer the question precisely based on the context\n"
                "2. Cite relevant articles\n"
                "3. Be clear and concise\n"
                "4. If information is missing, state so\n\n"
                f"{article_specific}\n"
                f"{concept_guidance}\n"
                "Answer:"
            )

        return prompt

    async def post_process_response(self, response: str, query_info: Dict,
                                 retrieved_docs: List[Dict], language: str = "fr") -> str:
        """Append a disclaimer when a requested article was retrieved but not cited.

        The disclaimer is added only if (a) the query named an article,
        (b) exactly that article is among ``retrieved_docs``, (c) the
        response never mentions "Article N" (any letter case), and (d) the
        response does not already say the article lacks the information.

        Returns:
            The response, possibly extended, with surrounding whitespace
            stripped.
        """
        article_num = query_info.get("article_num")
        if not article_num:
            return response.strip()

        article_num = str(article_num)
        article_ref = f"Article {article_num}"
        id_suffix = f"article_{article_num}"

        def is_requested_article(doc: Dict) -> bool:
            meta = doc.get("metadata", {})
            # Metadata stores the number as int while the query yields a
            # digit string, so compare as strings. The id must END with the
            # suffix: a substring test would let "article_1" match
            # "..._article_12".
            return (str(meta.get("article")) == article_num
                    or str(meta.get("id", "")).endswith(id_suffix))

        article_found = any(is_requested_article(doc) for doc in retrieved_docs)
        # Word-bounded, case-insensitive: "l'article 12" counts as a
        # citation of Article 12, "Article 12" does not count for Article 1.
        article_cited = re.search(
            rf"\barticle\s+{re.escape(article_num)}\b", response, re.IGNORECASE
        ) is not None

        if article_found and not article_cited:
            mentioned = query_info.get("concepts")
            concepts = " et ".join(mentioned) if mentioned else "le sujet demandé"

            if language == "fr":
                disclaimer = (
                    f"\n\nNote: L'{article_ref} a été analysé mais ne contient pas "
                    f"de dispositions spécifiques concernant {concepts}."
                )
            else:
                disclaimer = (
                    f"\n\nNote: {article_ref} was analyzed but doesn't contain "
                    f"specific provisions regarding {concepts}."
                )

            response_lower = response.lower()
            if not any(term in response_lower for term in
                       ["ne contient pas", "ne mentionne pas", "ne traite pas",
                        "does not contain", "does not mention", "does not address"]):
                response += disclaimer

        return response.strip()

    async def generate_reflection(self, query_id: str, reflection_prompt: str) -> str:
        """Generate a critique of a previously answered query.

        The stored question, final response and retrieved texts are appended
        to ``reflection_prompt`` and sent to the generation model with a
        300-token budget.

        Raises:
            ValueError: if ``query_id`` has no stored metadata.
        """
        if query_id not in cache.query_metadata:
            raise ValueError("Query ID not found")

        metadata = cache.query_metadata[query_id]

        full_prompt = (
            f"{reflection_prompt}\n\n"
            f"Original Question: {metadata['query_text']}\n"
            f"Generated Response: {metadata['final_response']}\n\n"
            "Context Documents:\n"
            + "\n---\n".join([doc["text"] for doc in metadata["retrieved_documents"]])
            + "\n\nAnalysis:"
        )

        return self.model_manager.generate_response(
            prompt=full_prompt,
            max_tokens=300,
            temperature=0.3  # Lower temperature for more focused reflection
        )

# ================= API Endpoints =================
@app.on_event("startup")
async def startup_event():
    """Warm the service up: load the models first, then index the documents.

    Any failure is logged and re-raised so the server does not start in a
    half-initialised state.
    """
    try:
        # Log which query ids are already held in memory at startup.
        known_query_ids = list(cache.query_metadata.keys())
        logger.info(f"Current query IDs in cache: {known_query_ids}")

        # Order matters: document indexing needs the embedding model.
        await model_manager.load_models()
        await document_manager.load_documents()
    except Exception as exc:
        logger.error(f"Startup failed: {str(exc)}")
        raise
    else:
        logger.info("API startup completed")

@app.post("/query", response_model=QueryResponse)
async def query_legal_assistant(request: QueryRequest,
                              background_tasks: BackgroundTasks):
    """Answer a legal question with retrieval-augmented generation.

    Steps: validate the field, make sure documents are loaded, retrieve
    context, build the prompt, generate and post-process the answer, record
    the query under a fresh UUID and optionally generate a self-reflection.

    Raises:
        HTTPException: 400 for an unknown field; 500 when documents cannot
            be loaded or any later step fails.
    """
    start_time = time.time()
    query_id = str(uuid.uuid4())

    try:
        # Validate field
        if request.field not in DOC_CONFIG:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid field. Choose from: {list(DOC_CONFIG.keys())}"
            )

        # Ensure documents are loaded
        if not await document_manager.load_documents():
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to load documents"
            )

        # Initialize processor
        processor = QueryProcessor(model_manager)

        # Process query
        query_info = await processor.preprocess_query(request.query)
        retrieved_docs = await processor.search(
            query=request.query,
            field=request.field,
            top_k=request.top_k
        )

        # Handle empty results
        if not retrieved_docs:
            empty_response = {
                "answer": "Aucune information pertinente trouvée dans les documents."
                          if request.language == "fr" else
                          "No relevant information found in documents.",
                "retrieved_documents": [],
                "query_time_ms": (time.time() - start_time) * 1000,
                "query_id": query_id,
                "cache_hit": False
            }
            return empty_response

        # Generate response
        prompt = await processor.create_prompt(
            query=request.query,
            retrieved_docs=retrieved_docs,
            query_info=query_info,
            language=request.language
        )

        raw_response = model_manager.generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )

        processed_response = await processor.post_process_response(
            response=raw_response,
            query_info=query_info,
            retrieved_docs=retrieved_docs,
            language=request.language
        )

        # Store query metadata BEFORE generating reflection, because
        # generate_reflection looks the query up by id in the cache.
        # NOTE(review): entries are never evicted, so this mapping grows
        # with every answered query for the lifetime of the process.
        query_metadata = {
            "query_id": query_id,
            "timestamp": datetime.now().isoformat(),
            "query_text": request.query,
            "field": request.field,
            "language": request.language,
            "prompt": prompt,
            "retrieved_documents": [doc.copy() for doc in retrieved_docs],
            "raw_response": raw_response,
            "final_response": processed_response
        }

        cache.query_metadata[query_id] = query_metadata

        # Generate reflection if enabled
        reflection = None
        if request.enable_reflection:
            reflection = await processor.generate_reflection(
                query_id=query_id,
                reflection_prompt="Analyze this legal response for accuracy and completeness:"
            )
            query_metadata["reflection"] = reflection

        # Prepare response
        response_data = {
            "answer": processed_response,
            "retrieved_documents": retrieved_docs,
            "query_time_ms": (time.time() - start_time) * 1000,
            "query_id": query_id,
            "reflection": reflection,
            "cache_hit": False
        }

        return response_data

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # with their own status code instead of being rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Query failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )

@app.post("/feedback", status_code=status.HTTP_201_CREATED)
async def submit_feedback(feedback: FeedbackRequest):
    """Record feedback for a previously answered query.

    The feedback is stored both in ``cache.feedback`` and on the query's
    metadata record; a later submission for the same query replaces the
    earlier one. The correction text is kept only for CORRECTION feedback.

    Raises:
        HTTPException: 404 when the query id is unknown; 500 on any
            unexpected failure.
    """
    try:
        # Validate query exists
        if feedback.query_id not in cache.query_metadata:
            logger.error(f"Query ID {feedback.query_id} not found in metadata")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Query ID not found in system records"
            )

        # Store feedback
        feedback_data = {
            "type": feedback.feedback_type,
            "timestamp": datetime.now().isoformat(),
            "correction": feedback.correction_text if feedback.feedback_type == FeedbackType.CORRECTION else None,
            "comments": feedback.comments
        }

        cache.feedback[feedback.query_id] = feedback_data
        cache.query_metadata[feedback.query_id]["feedback"] = feedback_data

        # In production: Trigger fine-tuning pipeline here

        return {"status": "success", "message": "Feedback recorded"}

    except HTTPException:
        # Keep the 404 above intact; previously it was rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Feedback failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )

@app.post("/reflect", response_model=Dict)
async def reflect_on_response(request: ReflectionRequest):
    """Generate a self-reflection for a previously answered query.

    Returns:
        Dict with the generated ``reflection`` text and the ``query_id``.

    Raises:
        HTTPException: 404 when the query id is unknown; 500 when
            generation fails.
    """
    try:
        if request.query_id not in cache.query_metadata:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Query ID not found"
            )

        processor = QueryProcessor(model_manager)
        reflection = await processor.generate_reflection(
            query_id=request.query_id,
            reflection_prompt=request.reflection_prompt
        )

        return {
            "reflection": reflection,
            "query_id": request.query_id
        }

    except HTTPException:
        # Keep the 404 above intact; previously it was rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Reflection failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )

@app.get("/stats", response_model=Dict)
async def get_system_stats():
    """Report in-memory counters: chunks per field, queries and feedback."""

    def tally(kind: str) -> int:
        # Feedback types are str-based enum members, so comparing against
        # the plain string value works.
        return sum(1 for entry in cache.feedback.values() if entry["type"] == kind)

    chunks_per_field = {}
    for field, texts in cache.texts.items():
        chunks_per_field[field] = len(texts)

    return {
        "documents_loaded": chunks_per_field,
        "queries_processed": len(cache.query_metadata),
        "feedback_received": len(cache.feedback),
        "feedback_stats": {
            "positive": tally("positive"),
            "negative": tally("negative"),
            "corrections": tally("correction")
        }
    }

@app.post("/reload", status_code=status.HTTP_200_OK)
async def reload_documents():
    """Force a reload of all configured documents, bypassing the 1-hour cache.

    Raises:
        HTTPException: 500 when no document could be loaded or the reload
            itself fails.
    """
    try:
        success = await document_manager.load_documents(force_reload=True)
    except Exception as e:
        logger.error(f"Document reload failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )

    if not success:
        # Raised outside the try block so the blanket handler above cannot
        # catch this error and rewrap it with a mangled detail string.
        logger.error("Document reload failed: Failed to reload documents")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to reload documents"
        )

    return {"status": "success", "message": "Documents reloaded"}

# Initialize components
# Module-level singletons used by the endpoints above: the model manager is
# built from MODEL_CONFIG and the document manager from DOC_CONFIG.
model_manager = ModelManager(MODEL_CONFIG)
document_manager = DocumentManager(DOC_CONFIG)

# Script entry point: start the API with Uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import nest_asyncio
    from uvicorn import Config, Server

    # Apply nest_asyncio to allow nested event loops (needed when an event
    # loop is already running in the host process).
    nest_asyncio.apply()

    # Configure and run the Uvicorn server; run() blocks until the server stops.
    config = Config(app, host="0.0.0.0", port=8000, log_level="info")
    server = Server(config)
    server.run()

        on_event is deprecated, use lifespan event handlers instead.

        Read more about it in the
        [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/).
        
  @app.on_event("startup")
INFO:     Started server process [11072]
INFO:     Waiting for application startup.
2025-03-29 21:59:03,202 - __main__ - INFO - Current query IDs in cache: []
2025-03-29 21:59:03,203 - __main__ - INFO - Loading model: google/flan-t5-large
Device set to use cpu
2025-03-29 21:59:04,937 - __main__ - INFO - Loading embedding model...
2025-03-29 21:59:04,949 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-multilingual-MiniLM-L12-v2
2025-03-29 21:59:07,688 - asyncio - ERROR - Task exception was never retrieved
future: <Task finished name='Task-20' coro=<Server.serve() done, defined at c:\Users\mouni\anaconda3\envs\tun_law_env\lib\site-packages\uvicorn\server.py:68> exception=KeyboardInterrupt()>
Traceback (most r

INFO:     127.0.0.1:52207 - "GET /stats HTTP/1.1" 200 OK
INFO:     127.0.0.1:52211 - "GET /stats HTTP/1.1" 200 OK
INFO:     127.0.0.1:52230 - "GET /stats HTTP/1.1" 200 OK


2025-03-29 22:00:33,369 - __main__ - INFO - Using cached documents
Token indices sequence length is longer than the specified maximum sequence length for this model (535 > 512). Running this sequence through the model will result in indexing errors


INFO:     127.0.0.1:52234 - "POST /query HTTP/1.1" 200 OK
INFO:     127.0.0.1:52376 - "GET /stats HTTP/1.1" 200 OK
INFO:     127.0.0.1:52378 - "GET /stats HTTP/1.1" 200 OK
INFO:     127.0.0.1:52382 - "GET /stats HTTP/1.1" 200 OK


2025-03-29 22:04:39,135 - __main__ - INFO - Using cached documents


INFO:     127.0.0.1:52386 - "POST /query HTTP/1.1" 200 OK
INFO:     127.0.0.1:52404 - "GET /stats HTTP/1.1" 200 OK
INFO:     127.0.0.1:52426 - "GET /stats HTTP/1.1" 200 OK


2025-03-29 22:07:53,571 - __main__ - INFO - Using cached documents


INFO:     127.0.0.1:52428 - "POST /query HTTP/1.1" 200 OK
INFO:     127.0.0.1:52484 - "GET /stats HTTP/1.1" 200 OK


In [None]:
import os
import re
import torch
import numpy as np
import asyncio
import time
from functools import lru_cache
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from rank_bm25 import BM25Okapi
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Set
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import uvicorn
import json
import logging
from datetime import datetime

# Configure logging
# Root logging at INFO level; every record is written both to the file
# "api.log" and to a default StreamHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("api.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by every class and endpoint in this cell.
logger = logging.getLogger(__name__)

# Initialize FastAPI
# The title, description and version below are the application metadata.
app = FastAPI(
    title="Tunisian Legal Assistant API",
    description="An API for querying Tunisian legal documents using AI",
    version="1.1.0"
)

# Add CORS middleware
# Every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the API beyond local
# development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model configuration
#   model_id   - Hugging Face identifier passed to from_pretrained by ModelManager.
#   cache_dir  - directory passed as cache_dir when loading tokenizer and model.
#   quantize   - when True, ModelManager applies dynamic quantization after loading.
#   use_gpu    - whether CUDA is available on this machine.
#   max_length - maximum sequence length setting.
# NOTE(review): ModelManager in this cell does not read "use_gpu" or
# "max_length"; it checks torch.cuda.is_available() directly.
MODEL_CONFIG = {
    "model_id": "google/flan-t5-large",  # Consider upgrading to a better model
    "cache_dir": "./model_cache",
    "quantize": False,
    # is_available() already returns a bool, so no conditional is needed.
    "use_gpu": torch.cuda.is_available(),
    "max_length": 512
}

# Document configuration
# Maps each legal field name to its source PDF and the text-splitter settings
# used when indexing it. Both documents share the same chunking parameters.
DOC_CONFIG = {
    field_name: {"path": pdf_path, "chunk_size": 512, "chunk_overlap": 100}
    for field_name, pdf_path in (
        ("criminal", r"C:\Users\mouni\tun_law_project\Tunisia_Crim-Fr.pdf"),
        ("constitution", r"C:\Users\mouni\tun_law_project\Constitution_fr.pdf"),
    )
}

# Legal terminology and synonyms for improved search
# Maps a legal concept to a list of related terms. QueryProcessor.preprocess_query
# reports a concept as mentioned when the concept itself or any of its related
# terms occurs (case-insensitive substring match) in the query.
LEGAL_TERMS = {
    "imprescriptibilité": ["prescription", "délai", "expiration", "crime", "poursuites"],
    "droit": ["liberté", "garantie", "protection"],
    "justice": ["tribunal", "magistrat", "judiciaire"],
    # Add more legal concepts and related terms
}

# Cache for document data and responses
# Module-level store shared by DocumentManager, QueryProcessor and the endpoints:
#   texts       - field -> list of chunk strings
#   bm25        - field -> BM25Okapi index built over those chunks
#   embeddings  - field -> chunk embeddings returned by ModelManager.get_embeddings
#   articles    - field -> {article number (str) -> {"text", "metadata"}}
#   last_loaded - datetime of the last successful document load, or None
#   responses   - cache key -> result dict produced by the /query endpoint
# A "metadata" key (field -> list of per-chunk metadata dicts) is added lazily
# when documents are processed.
cache = {
    "texts": {},
    "bm25": {},
    "embeddings": {},
    "articles": {},
    "last_loaded": None,
    "responses": {}
}

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

class ModelManager:
    """Holds the generation model, its tokenizer and pipeline, and the embedder.

    Nothing is loaded at construction time; call ``load_models`` first.
    """

    def __init__(self, config):
        self.config = config
        # All model handles stay unset until load_models() has run.
        self.model = self.tokenizer = self.pipeline = self.embedding_model = None

    async def load_models(self):
        """Load tokenizer, seq2seq model, generation pipeline and sentence embedder."""
        cfg = self.config
        logger.info(f"Loading models from {cfg['model_id']}...")

        # Tokenizer and model come from the same checkpoint and cache directory.
        checkpoint = cfg["model_id"]
        load_options = {"cache_dir": cfg["cache_dir"]}
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, **load_options)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, **load_options)

        # Optional dynamic quantization of the linear layers.
        if cfg["quantize"]:
            logger.info("Applying quantization...")
            self.model = torch.quantization.quantize_dynamic(
                self.model,
                {torch.nn.Linear},
                dtype=torch.float16
            )

        # One CUDA check drives model placement, pipeline device and embedder.
        has_cuda = torch.cuda.is_available()
        if has_cuda:
            logger.info("Moving model to GPU...")
            self.model.to('cuda')

        self.pipeline = pipeline(
            task="text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if has_cuda else -1
        )

        # Sentence transformer used for semantic search.
        logger.info("Loading sentence transformer model...")
        self.embedding_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        if has_cuda:
            self.embedding_model = self.embedding_model.to(torch.device("cuda"))

        logger.info("Models loaded successfully")

    @lru_cache(maxsize=100)
    def generate_response(self, prompt, max_tokens=150, temperature=0.7):
        """Run the generation pipeline on *prompt*; identical calls are cached.

        Raises:
            ValueError: if the pipeline has not been initialised.
        """
        if not self.pipeline:
            raise ValueError("Model pipeline not initialized")

        generation_options = {
            "max_new_tokens": max_tokens,
            "num_return_sequences": 1,
            "num_beams": 5,
            "early_stopping": True,
            "temperature": temperature,
            "do_sample": True,
            "top_p": 0.95,
            "repetition_penalty": 1.2,
        }
        outputs = self.pipeline(prompt, **generation_options)
        text = outputs[0]["generated_text"]

        # Drop the prompt when the output echoes it, keeping only the new text.
        return text[len(prompt):].strip() if text.startswith(prompt) else text

    async def get_embeddings(self, texts):
        """Encode *texts* with the sentence transformer and return a NumPy array.

        Raises:
            ValueError: if the embedding model has not been initialised.
        """
        embedder = self.embedding_model
        if not embedder:
            raise ValueError("Embedding model not initialized")

        encoded = embedder.encode(texts, convert_to_tensor=True)
        return encoded.cpu().numpy()


class DocumentManager:
    """Loads the configured PDFs and publishes their search data into ``cache``.

    ``config`` maps a field name to a dict with ``path``, ``chunk_size`` and
    ``chunk_overlap``.
    """

    def __init__(self, config):
        self.config = config

    async def load_documents(self, force_reload=False):
        """Load every configured document into the module-level cache.

        Args:
            force_reload: When False, data loaded less than an hour ago is
                reused instead of being reprocessed.

        Returns:
            True if cached data was reused or at least one document loaded,
            False if no document could be loaded.
        """
        success = False
        current_time = datetime.now()

        # Skip loading if documents were loaded within the last hour
        last_loaded = cache.get("last_loaded")
        cache_is_fresh = (
            last_loaded
            and (current_time - last_loaded).total_seconds() < 3600
            and cache.get("texts")
        )
        if not force_reload and cache_is_fresh:
            logger.info("Using cached document data")
            return True

        logger.info("Loading and processing documents...")

        # Make sure every per-field sub-cache exists before processing.
        for key in ("texts", "bm25", "embeddings", "articles"):
            cache.setdefault(key, {})

        for field, field_config in self.config.items():
            pdf_path = field_config["path"]
            chunk_size = field_config["chunk_size"]
            chunk_overlap = field_config["chunk_overlap"]

            if not os.path.exists(pdf_path):
                logger.error(f"Document file not found: {pdf_path}")
                continue

            try:
                logger.info(f"Loading document for field: {field}")
                await self._process_document(pdf_path, field, chunk_size, chunk_overlap)
                success = True
            except Exception as e:
                # Best effort: keep going with the remaining documents.
                logger.error(f"Error loading document {pdf_path}: {str(e)}")

        if not success:
            logger.error("Failed to load any documents")
            return False

        cache["last_loaded"] = current_time
        logger.info(f"Documents loaded successfully. Available fields: {list(cache.get('texts', {}).keys())}")
        return True

    async def _process_document(self, pdf_path, field, chunk_size, chunk_overlap):
        """Load one PDF and build both its regular chunks and its article index."""
        loader = PyPDFLoader(pdf_path)
        pages = loader.load()

        await asyncio.gather(
            self._process_regular_chunks(pages, field, chunk_size, chunk_overlap),
            self._process_article_chunks(pages, field)
        )

    async def _process_regular_chunks(self, pages, field, chunk_size, chunk_overlap):
        """Split *pages* into chunks and publish texts, metadata, BM25 and embeddings.

        Raises:
            ValueError: if the document yields no text chunks.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(pages)

        texts = []
        metadata = []
        for position, doc in enumerate(documents, start=1):
            page_number = doc.metadata.get("page", 0)
            texts.append(doc.page_content)
            metadata.append({
                "page": page_number,
                "source": os.path.basename(doc.metadata.get("source", "")),
                # 1-based running chunk number keeps ids unique within a field.
                "id": f"{field}_{page_number}_{position}"
            })

        if not texts:
            raise ValueError(f"No text chunks could be extracted for field '{field}'")

        # Build the BM25 index before publishing anything, so the cache never
        # holds texts for a field that has no matching BM25 index.
        bm25_index = BM25Okapi([text.split() for text in texts])

        cache.setdefault("texts", {})[field] = texts
        cache.setdefault("metadata", {})[field] = metadata
        cache.setdefault("bm25", {})[field] = bm25_index

        # Drop embeddings left over from a previous load first: if embedding
        # fails below, stale vectors must not stay paired with the new chunks
        # (semantic search recomputes missing embeddings on demand).
        embeddings_cache = cache.setdefault("embeddings", {})
        embeddings_cache.pop(field, None)
        embeddings_cache[field] = await model_manager.get_embeddings(texts)

        logger.info(f"Loaded {len(texts)} regular chunks from {field}")

    async def _process_article_chunks(self, pages, field):
        """Extract ``Article N :`` sections from *pages* into ``cache["articles"]``.

        The index for *field* is rebuilt from scratch on every call, so articles
        from a previous load do not survive a reload.
        """
        article_pattern = re.compile(r'Article\s+(\d+)\s*:', re.IGNORECASE)

        # NOTE: an article's text ends at the next match or at the end of its
        # page, and a later occurrence of the same number replaces an earlier one.
        articles = {}
        for page in pages:
            content = page.page_content
            matches = list(article_pattern.finditer(content))
            # Each article ends where the next one starts; the last runs to page end.
            end_positions = [m.start() for m in matches[1:]] + [len(content)]

            for match, end_pos in zip(matches, end_positions):
                article_num = match.group(1)
                articles[article_num] = {
                    "text": content[match.start():end_pos].strip(),
                    "metadata": {
                        "page": page.metadata.get("page", 0),
                        "source": os.path.basename(page.metadata.get("source", "")),
                        "id": f"{field}_article_{article_num}",
                        "article": int(article_num)
                    }
                }

        cache.setdefault("articles", {})[field] = articles

        logger.info(f"Extracted {len(cache['articles'][field])} articles from {field}")


class QueryProcessor:
    """Hybrid retrieval plus prompt building and answer clean-up for legal questions.

    Reads the module-level ``cache`` (texts, BM25 indexes, embeddings, articles
    and metadata per field) and uses the injected ``model_manager`` to embed
    text for semantic search.
    """

    def __init__(self, model_manager):
        self.model_manager = model_manager

    async def preprocess_query(self, query):
        """Extract the article number, legal concepts and keywords from *query*.

        Returns:
            dict with ``original_query``, ``article_num`` (digit string or
            None), ``concepts`` (LEGAL_TERMS keys whose name or related terms
            occur in the query) and ``keywords`` (lower-cased words longer than
            two characters that are not in the stop-word list).
        """
        article_match = re.search(r'article\s+(\d+)', query, re.IGNORECASE)
        article_num = article_match.group(1) if article_match else None

        # A concept counts as mentioned when its own name or any related term
        # appears as a substring of the lower-cased query.
        query_lower = query.lower()
        mentioned_concepts = [
            concept
            for concept, related_terms in LEGAL_TERMS.items()
            if any(term.lower() in query_lower for term in (concept, *related_terms))
        ]

        # Extract query keywords (simple approach)
        stop_words = {"le", "la", "les", "un", "une", "des", "de", "du", "et", "ou", "à", "en"}
        words = [word.lower() for word in re.findall(r'\b\w+\b', query)
                 if word.lower() not in stop_words and len(word) > 2]

        return {
            "original_query": query,
            "article_num": article_num,
            "concepts": mentioned_concepts,
            "keywords": words
        }

    async def search(self, query, field, top_k=2):
        """Return up to *top_k* relevant documents for *query* in *field*.

        When the query names an article that was extracted for this field, that
        article is returned first and the remaining slots are filled from the
        hybrid search.

        Raises:
            HTTPException: 500 if no documents are loaded, 400 if *field* is
                not among the loaded documents.
        """
        logger.info(f"Searching in field: {field}")

        if not cache.get("texts"):
            logger.error("No documents loaded in cache")
            raise HTTPException(status_code=500, detail="Document cache is empty. Try reloading documents.")

        if field not in cache.get("texts", {}):
            logger.error(f"Field '{field}' not found in loaded documents. Available fields: {list(cache.get('texts', {}).keys())}")
            raise HTTPException(
                status_code=400,
                detail=f"Field '{field}' not found in loaded documents. Available fields: {list(cache.get('texts', {}).keys())}"
            )

        query_info = await self.preprocess_query(query)

        # Log field info for debugging
        logger.info(f"Field '{field}' contains {len(cache['texts'][field])} text chunks and {len(cache.get('articles', {}).get(field, {}))} articles")

        if query_info["article_num"] and field in cache.get("articles", {}):
            article_result = await self._find_specific_article(
                query_info["article_num"], field
            )

            if article_result:
                # The exact article takes the first slot. With top_k == 1 there
                # is nothing left to fill, so the hybrid search is skipped.
                general_results = []
                if top_k > 1:
                    general_results = await self._general_search(query, field, top_k - 1)

                combined = [article_result] + [
                    res for res in general_results
                    if res["text"] != article_result["text"]
                ]
                return combined[:top_k]

        # General search if no article specified or article not found
        return await self._general_search(query, field, top_k)

    async def _find_specific_article(self, article_num, field):
        """Return the cached article *article_num* of *field*, or None if absent."""
        article_data = cache.get("articles", {}).get(field, {}).get(article_num)
        if article_data is None:
            return None
        return {
            "text": article_data["text"],
            "score": 1.0,  # Highest score for exact match
            "metadata": article_data["metadata"]
        }

    async def _general_search(self, query, field, top_k=3):
        """Combine BM25 (weight 0.6) and semantic (weight 0.4) results for *field*.

        Raises:
            ValueError: if *field* has no cached texts.
        """
        if field not in cache.get("texts", {}):
            logger.error(f"Field '{field}' not found in texts cache. Available fields: {list(cache.get('texts', {}).keys())}")
            raise ValueError(f"Field '{field}' not found in loaded documents")

        texts = cache["texts"][field]

        bm25_results, semantic_results = await asyncio.gather(
            self._bm25_search(query, field, top_k),
            self._semantic_search(query, field, top_k)
        )

        # NOTE(review): raw BM25 scores and cosine similarities are on different
        # scales, so the 0.6/0.4 weights are not true proportions; consider
        # normalising both before mixing.
        combined_scores = {}
        for idx, score in bm25_results:
            combined_scores[idx] = 0.6 * score
        for idx, score in semantic_results:
            combined_scores[idx] = combined_scores.get(idx, 0) + 0.4 * score

        field_metadata = cache.get("metadata", {}).get(field, [])
        retrieved_texts = [
            {
                "text": texts[idx],
                "score": float(score),
                "metadata": field_metadata[idx] if idx < len(field_metadata) else {}
            }
            for idx, score in combined_scores.items()
            if idx < len(texts)
        ]

        retrieved_texts.sort(key=lambda x: x["score"], reverse=True)
        return retrieved_texts[:top_k]

    async def _bm25_search(self, query, field, top_k):
        """Return the *top_k* ``(index, score)`` pairs ranked by BM25.

        Raises:
            ValueError: if no BM25 index is cached for *field*.
        """
        if field not in cache.get("bm25", {}):
            logger.error(f"Field '{field}' not found in BM25 cache. Available fields: {list(cache.get('bm25', {}).keys())}")
            raise ValueError(f"BM25 index not found for field '{field}'")

        # scores[-0:] would select every index, so handle "no results" explicitly.
        if top_k <= 0:
            return []

        bm25_scores = cache["bm25"][field].get_scores(query.split())
        top_indices = np.argsort(bm25_scores)[-top_k:][::-1]

        return [(idx, bm25_scores[idx]) for idx in top_indices]

    async def _semantic_search(self, query, field, top_k):
        """Return the *top_k* ``(index, cosine similarity)`` pairs for *query*.

        Embeddings for *field* are computed on demand when missing.

        Raises:
            ValueError: if embeddings are missing and *field* has no cached texts.
        """
        if field not in cache.get("embeddings", {}):
            logger.warning(f"No embeddings for {field}, calculating now")
            if field not in cache.get("texts", {}):
                logger.error(f"Field '{field}' not found in texts cache. Available fields: {list(cache.get('texts', {}).keys())}")
                raise ValueError(f"Cannot create embeddings for field '{field}' - field not found in text cache")

            cache.setdefault("embeddings", {})
            cache["embeddings"][field] = await self.model_manager.get_embeddings(cache["texts"][field])

        # Same slicing pitfall as in BM25: [-0:] selects everything.
        if top_k <= 0:
            return []

        query_embedding = (await self.model_manager.get_embeddings([query]))[0]
        doc_embeddings = np.asarray(cache["embeddings"][field])
        if len(doc_embeddings) == 0:
            return []

        # Cosine similarity for all documents at once. Zero-norm vectors score
        # 0 instead of NaN, which argsort would otherwise rank as the best match.
        dot_products = doc_embeddings @ query_embedding
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        similarities = np.divide(
            dot_products,
            norms,
            out=np.zeros(len(doc_embeddings)),
            where=norms > 0
        )

        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [(idx, similarities[idx]) for idx in top_indices]

    async def create_prompt(self, query, retrieved_documents, query_info, language="fr"):
        """Build the generation prompt from the query and retrieved documents.

        Args:
            query: The user's question.
            retrieved_documents: Dicts with ``text`` and optional ``metadata``.
            query_info: Result of ``preprocess_query``.
            language: ``"fr"`` for the French prompt; any other value yields
                the English prompt.

        Returns:
            The prompt string, ending with the answer cue.
        """
        # Label each document, adding its article number when one is known.
        context_parts = []
        for i, doc in enumerate(retrieved_documents, start=1):
            article_info = ""
            metadata = doc.get("metadata") or {}
            if "article" in metadata:
                article_info = f" (Article {metadata['article']})"
            elif "article" in metadata.get("id", ""):
                article_match = re.search(r'article_(\d+)', metadata["id"])
                if article_match:
                    article_info = f" (Article {article_match.group(1)})"

            context_parts.append(f"Document {i}{article_info}:\n{doc['text']}")

        context = "\n\n".join(context_parts)

        article_num = query_info["article_num"]
        concepts_list = ", ".join(query_info["concepts"])
        article_specific = ""
        concept_guidance = ""

        if language == "fr":
            if article_num:
                article_specific = (
                    f"La question porte spécifiquement sur l'Article {article_num}. "
                    "Si cet article ne traite pas du sujet demandé, précisez-le clairement."
                )
            if concepts_list:
                concept_guidance = (
                    f"La question porte sur les concepts juridiques suivants: {concepts_list}. "
                    "Analysez si ces concepts sont présents dans les documents fournis."
                )
            header = (
                f"Question : {query}\n\n"
                f"Contexte juridique :\n{context}\n\n"
                "Instructions : répondre a la question d'aprés le contexte.\n\n"
            )
            answer_cue = "Réponse :"
        else:
            # English version: guidance and template are both in English.
            if article_num:
                article_specific = (
                    f"The question specifically asks about Article {article_num}. "
                    "If this article does not address the topic, clearly state this."
                )
            if concepts_list:
                concept_guidance = (
                    f"The question concerns the following legal concepts: {concepts_list}. "
                    "Analyze whether these concepts are present in the provided documents."
                )
            header = (
                f"Question: {query}\n\n"
                f"Legal context:\n{context}\n\n"
                "Instructions: answer the question based on the context.\n\n"
            )
            answer_cue = "Answer:"

        return f"{header}{article_specific}\n{concept_guidance}\n{answer_cue}"

    async def post_process_response(self, response, query_info, retrieved_docs, language="fr"):
        """Append a disclaimer when a requested article was retrieved but not cited.

        The note is added only when the query asked for an article, that
        article is among *retrieved_docs*, the response does not contain
        ``"Article N"`` and it does not already state that the information is
        missing. Otherwise *response* is returned unchanged.
        """
        article_num = query_info.get("article_num")
        if not article_num:
            return response

        def is_requested_article(doc):
            metadata = doc.get("metadata") or {}
            if "article" in metadata and str(metadata["article"]) == article_num:
                return True
            # Match the id suffix exactly so "article_2" does not match "article_21".
            return str(metadata.get("id", "")).endswith(f"article_{article_num}")

        article_found = any(is_requested_article(doc) for doc in retrieved_docs)
        if not article_found or f"Article {article_num}" in response:
            return response

        # Only add the disclaimer if the response doesn't already indicate this.
        already_stated = any(
            term in response.lower()
            for term in ["ne contient pas", "ne mentionne pas", "ne traite pas",
                         "does not contain", "does not mention", "does not address"]
        )
        if already_stated:
            return response

        concepts = query_info.get("concepts") or []
        if language == "fr":
            topic = " et ".join(concepts) if concepts else "le sujet demandé"
            disclaimer = (
                f"\n\nNote: Bien que l'Article {article_num} ait été analysé, "
                f"il ne contient pas de dispositions spécifiques concernant {topic}."
            )
        else:
            # The English note uses an English joiner and fallback wording.
            topic = " and ".join(concepts) if concepts else "the requested topic"
            disclaimer = (
                f"\n\nNote: Although Article {article_num} was analyzed, "
                f"it does not contain specific provisions regarding {topic}."
            )

        return response + disclaimer

# Define API models
class QueryRequest(BaseModel):
    """Request body accepted by the POST /query endpoint.

    Bounds are enforced by the field constraints: ``top_k`` 1-5,
    ``max_tokens`` 50-500, ``temperature`` 0.1-1.0.
    """
    query: str = Field(..., description="The legal question to ask")
    field: str = Field(..., description="The legal domain to query (criminal or constitution)")
    language: str = Field("fr", description="Response language (fr or en)")
    top_k: int = Field(3, description="Number of documents to retrieve", ge=1, le=5)
    max_tokens: int = Field(150, description="Maximum number of tokens to generate", ge=50, le=500)
    temperature: float = Field(0.7, description="Temperature for text generation", ge=0.1, le=1.0)

class RetrievedDocument(BaseModel):
    """One retrieved passage: its text, relevance score and source metadata."""
    text: str
    score: float
    metadata: Dict

class QueryResponse(BaseModel):
    """Response body of POST /query: the answer plus the supporting documents."""
    answer: str
    retrieved_documents: List[RetrievedDocument]
    # Wall-clock handling time in milliseconds (reported as 0.0 on a cache hit).
    query_time_ms: float
    # True when the answer was served from the response cache.
    cache_hit: bool = False

# Initialize components
# Module-level singletons shared by the endpoints below. DocumentManager also
# reads the global model_manager when it computes chunk embeddings.
model_manager = ModelManager(MODEL_CONFIG)
document_manager = DocumentManager(DOC_CONFIG)

# NOTE(review): FastAPI reports on_event as deprecated in favour of lifespan
# handlers (see the warning in the captured server output); migrating requires
# changing how the app object is constructed.
@app.on_event("startup")
async def startup_event():
    """Load models and documents when the application starts.

    Failures are logged rather than raised so the API can still start; the
    /query endpoint retries document loading when the requested field is
    missing from the cache.
    """
    try:
        # Load models
        await model_manager.load_models()

        # Try to load documents
        docs_loaded = await document_manager.load_documents()

        if not docs_loaded:
            logger.warning("Failed to load documents during startup. API will attempt to load on first request.")

        logger.info("API started successfully")
    except Exception as e:
        # The exception is deliberately swallowed, so log it with its traceback:
        # this record is the only diagnostic for a failed startup.
        logger.exception(f"Startup error: {str(e)}")

@app.post("/query", response_model=QueryResponse)
async def query_legal_assistant(request: QueryRequest, background_tasks: BackgroundTasks):
    """Answer a legal question from the indexed documents.

    Serves an identical earlier request from the response cache when possible.
    Otherwise it validates the field, makes sure documents are loaded,
    retrieves relevant passages, generates an answer, post-processes it and
    caches the result.

    Raises:
        HTTPException: 400 for an unknown field; 500 when documents cannot be
            loaded or any other step fails.
    """
    start_time = time.time()

    try:
        # Create cache key
        cache_key = f"{request.query}_{request.field}_{request.language}_{request.top_k}_{request.max_tokens}_{request.temperature}"

        # Check cache for identical query
        cached = cache.get("responses", {}).get(cache_key)
        if cached is not None:
            result = cached.copy()
            result["cache_hit"] = True
            result["query_time_ms"] = 0.0
            logger.info(f"Cache hit for query: {request.query[:30]}...")
            return result

        # Validate field
        if request.field not in DOC_CONFIG:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid field. Choose from: {', '.join(DOC_CONFIG.keys())}"
            )

        # Ensure documents are loaded
        if request.field not in cache.get("texts", {}):
            logger.info("Documents not loaded, attempting to load now")
            await document_manager.load_documents(force_reload=True)

            # Check again after loading attempt
            if request.field not in cache.get("texts", {}):
                raise HTTPException(
                    status_code=500,
                    detail=f"Could not load documents for field: {request.field}. Please check file paths and try again."
                )

        query_processor = QueryProcessor(model_manager)

        # Process query to extract key information
        query_info = await query_processor.preprocess_query(request.query)

        # Retrieve relevant documents
        retrieved_docs = await query_processor.search(
            query=request.query,
            field=request.field,
            top_k=request.top_k
        )

        if not retrieved_docs:
            no_info_message = (
                "Je n'ai pas trouvé d'informations pertinentes dans les documents juridiques disponibles."
                if request.language == "fr" else
                "I could not find relevant information in the available legal documents."
            )

            result = {
                "answer": no_info_message,
                "retrieved_documents": [],
                "query_time_ms": (time.time() - start_time) * 1000,
                "cache_hit": False
            }

            # Look the cache dict up again at store time: /reload may have
            # replaced it while this request was awaiting.
            cache.setdefault("responses", {})[cache_key] = result

            return result

        # Create prompt
        prompt = await query_processor.create_prompt(
            query=request.query,
            retrieved_documents=retrieved_docs,
            query_info=query_info,
            language=request.language
        )

        # Generate response
        raw_response = model_manager.generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )

        # Post-process response
        processed_response = await query_processor.post_process_response(
            response=raw_response,
            query_info=query_info,
            retrieved_docs=retrieved_docs,
            language=request.language
        )

        result = {
            "answer": processed_response,
            "retrieved_documents": retrieved_docs,
            "query_time_ms": (time.time() - start_time) * 1000,
            "cache_hit": False
        }

        # Cache the response
        cache.setdefault("responses", {})[cache_key] = result

        # Schedule cache cleanup in background if it's getting too large
        if len(cache["responses"]) > 1000:
            background_tasks.add_task(clean_response_cache)

        return result

    except HTTPException:
        # Deliberate HTTP errors (400 invalid field, 500 load failure, errors
        # raised by the search step) keep their own status code and detail.
        raise
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

async def clean_response_cache():
    """Trim the response cache down to its 500 most recent entries.

    Does nothing unless ``cache["responses"]`` holds more than 1000 entries.
    Recency is decided by each entry's ``"timestamp"`` value when present.
    Entries without one are treated as timestamp 0, and ties are broken by
    insertion position (dicts preserve insertion order), so the entries
    added last are the ones that survive.

    The surviving entries are stored back in their original insertion order
    so that a later clean-up can still rely on insertion position.
    """
    responses = cache.get("responses", {})
    if len(responses) <= 1000:
        return

    # Newest first: higher timestamp wins; among equal timestamps the entry
    # inserted later wins. Without the positional tie-break, entries lacking
    # a timestamp would all compare equal and the oldest would be kept.
    ranked = sorted(
        enumerate(responses),
        key=lambda item: (responses[item[1]].get("timestamp", 0), item[0]),
        reverse=True,
    )
    keys_to_keep = {key for _, key in ranked[:500]}

    cache["responses"] = {
        key: value for key, value in responses.items() if key in keys_to_keep
    }
    logger.info(f"Cleaned response cache, now contains {len(cache['responses'])} items")



@app.post("/reload")
async def reload_documents():
    """Force a reload of the documents and reset the response cache.

    Returns:
        A dict with ``"status"`` and ``"message"`` keys on success.

    Raises:
        HTTPException: status 500 carrying the error text when the reload
            (or the cache reset) fails; the error is logged first.
    """
    try:
        await document_manager.load_documents(force_reload=True)
        # Cached answers were produced before the reload, so discard them.
        cache["responses"] = {}
    except Exception as exc:
        error_text = str(exc)
        logger.error(f"Reload error: {error_text}")
        raise HTTPException(status_code=500, detail=error_text)

    return {"status": "success", "message": "Documents reloaded successfully"}

@app.get("/stats")
async def get_stats():
    """Report statistics derived from the in-memory cache.

    Returns:
        A dict with:
        - ``total_queries``: number of entries in the response cache.
        - ``documents_loaded``: per-field count of cached texts.
        - ``articles_extracted``: per-field count of cached articles.
        - ``cache_hit_ratio``: share of cached responses whose ``cache_hit``
          flag is truthy (0 when the cache is empty).
    """
    cached_responses = cache.get("responses", {})
    texts_by_field = cache.get("texts", {})
    articles_by_field = cache.get("articles", {})

    # NOTE(review): the cache writes visible in this file store
    # "cache_hit": False, so this count may always be 0 — confirm how the
    # cache-hit path updates stored entries.
    flagged_hits = len(
        [entry for entry in cached_responses.values() if entry.get("cache_hit", False)]
    )
    # Guard the denominator so an empty cache yields 0.0 instead of dividing by zero.
    denominator = max(1, len(cached_responses))

    documents_loaded = {}
    for field, texts in texts_by_field.items():
        documents_loaded[field] = len(texts)

    articles_extracted = {}
    for field, articles in articles_by_field.items():
        articles_extracted[field] = len(articles)

    return {
        "total_queries": len(cached_responses),
        "documents_loaded": documents_loaded,
        "articles_extracted": articles_extracted,
        "cache_hit_ratio": flagged_hits / denominator,
    }

# Sample request payload: a French-language question about Article 21 of the
# Tunisian Constitution, targeting the "constitution" document field.
example_query = dict(
    query="Quels sont les principes relatifs à l'imprescriptibilité selon l'article 21 de la Constitution tunisienne ?",
    field="constitution",
    language="fr",
)

# nest_asyncio patches asyncio so that an already-running event loop can be
# re-entered. Presumably required because this file is executed from a
# notebook, where a loop is already running — TODO confirm.
import nest_asyncio
nest_asyncio.apply()

# Start the API server only when this file is executed directly.
# NOTE(review): host "0.0.0.0" binds to all network interfaces; confirm this
# is intended outside local development.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8001)


INFO:     Started server process [16996]
INFO:     Waiting for application startup.
2025-03-29 20:17:53,663 - __main__ - INFO - Loading models from google/flan-t5-large...
2025-03-29 20:17:56,332 - __main__ - INFO - Loading sentence transformer model...
2025-03-29 20:17:56,337 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: cpu
2025-03-29 20:17:56,340 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-multilingual-MiniLM-L12-v2
2025-03-29 20:18:00,597 - asyncio - ERROR - Task exception was never retrieved
future: <Task finished name='Task-163' coro=<Server.serve() done, defined at c:\Users\mouni\anaconda3\envs\tun_law_env\lib\site-packages\uvicorn\server.py:68> exception=KeyboardInterrupt()>
Traceback (most recent call last):
  File "c:\Users\mouni\anaconda3\envs\tun_law_env\lib\site-packages\uvicorn\main.py", line 579, in run
    server.run()
  File "c:\Users\mouni\anaconda3\envs\tun_law_env\lib\site-packa

INFO:     127.0.0.1:50147 - "POST /query HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [16996]
