In [None]:
!pip -q install \
  langfuse \
  qdrant-client \
  sentence-transformers \
  fastembed \
  groq \
  redis \
  langchain\
  langchain-community \
  langchain-google-genai\
  google-generativeai \
  pymongo \
  langchain-redis

In [None]:
import json
import requests
import time
import hashlib

from datetime import datetime, timezone, timedelta
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Union, Any, Callable, Tuple

from qdrant_client import QdrantClient
from qdrant_client.http.models import SparseVector

from fastembed import TextEmbedding, SparseTextEmbedding, LateInteractionTextEmbedding
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

import langfuse
from langfuse import Langfuse
from langfuse import observe, get_client
from uuid import uuid4

from groq import Groq
from pymongo import MongoClient

import google.generativeai as genai
from google.genai.types import HttpOptions, GenerateContentConfig
from google.genai import Client as GeminiClient

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema import messages_from_dict, messages_to_dict

import re
import redis
import gzip
import base64
import numpy as np
from numpy.linalg import norm

# Logging setup
# Configure the root logger at INFO so the pipeline's logger.info(...) progress
# messages are printed; `logger` is the module-level logger used by the
# classes defined in later cells.
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class CacheConfig:
    """Redis connection and caching settings.

    RAGConfig.to_cache_config() builds one of these from a RAGConfig.
    """
    redis_host: str
    redis_port: int
    redis_password: Optional[str]  # may be None (Optional); presumably "no AUTH" — confirm
    redis_db: int  # Redis database number (RAGConfig default: 0)
    cache_ttl: int  # cache entry lifetime; RAGConfig default 3600 (presumably seconds)
    enable_cache: bool  # whether caching is enabled (RAGConfig default: True)

@dataclass
class SearchConfig:
    """Retrieval-tuning settings (reranking and query expansion).

    RAGConfig.to_search_config() builds one of these from a RAGConfig. The
    consumers of these fields are not visible in this part of the notebook.
    """
    enable_reranking: bool  # RAGConfig default: True
    enable_query_expansion: bool  # RAGConfig default: True
    rerank_top_k: int  # RAGConfig default: 20
    final_top_k: int  # RAGConfig default: 5
    query_expansion_count: int  # RAGConfig default: 3
    context_overlap_threshold: float  # RAGConfig default: 0.8

@dataclass
class RAGConfig:
    """Configuration for the RAG pipeline, including caching and enhanced search features.

    Fields whose default is ``...`` (Ellipsis) are placeholders for
    deployment-specific values (URLs, hosts, credentials). Supply real values,
    e.g. as keyword arguments, before the corresponding service is used.
    """

    # Qdrant / Vector Database config
    qdrant_url: str = ...
    qdrant_api_key: Optional[str] = ...
    collection_name: str = "deeplearning_ai_news_embeddings"
    qdrant_top_k: int = 5
    score_threshold: float = 0.5  # minimum Qdrant score for a retrieved point to be kept

    # Embedding model
    embedding_model_name: str = "BAAI/bge-m3"

    # Scoring weights
    dense_score: float = 0.5
    colbert_score: float = 0.3
    sparse_score: float = 0.2

    # Context
    max_context_length: int = 4000  # max characters of formatted context sent to the LLM

    # Gemini LLM config
    gemini_api_key: str = ...
    gemini_model_name: str = "gemini-2.5-flash"
    temperature: float = 0.15
    top_k: int = 3
    top_p: float = 0.85

    # Langfuse observability
    langfuse_public_key: str = ...
    langfuse_secret_key: str = ...
    langfuse_host: str = ...

    # Lakera Guard
    lakera_guard_api_key: Optional[str] = ...

    # Redis caching config
    redis_host: str = ...
    redis_port: int = ...
    redis_password: Optional[str] = ...
    redis_db: int = 0
    cache_ttl: int = 3600
    enable_cache: bool = True

    # MongoDB Config
    mongo_uri: str = ...
    mongo_db_name: str = "deeplearning_ai_news"

    # Search enhancement config
    enable_reranking: bool = True
    enable_query_expansion: bool = True
    rerank_top_k: int = 20
    final_top_k: int = 5
    query_expansion_count: int = 3
    context_overlap_threshold: float = 0.8
    reranker_model_name: str = 'multi-qa-MiniLM-L6-cos-v1'

    # Conversational chatbot
    # Redis URL for chat history. When left as None it is derived from the
    # redis_* fields of *this instance* in __post_init__. A class-level f-string
    # default would be frozen at class-definition time and ignore constructor
    # overrides of redis_host/redis_port/redis_password.
    redis_url: Optional[str] = None
    session_id: str = "default_session"
    max_history_pairs: int = 10
    max_final_contexts: int = 10
    enable_compression: bool = True

    def __post_init__(self) -> None:
        """Fill in redis_url from the instance's connection fields if not given."""
        if self.redis_url is None:
            self.redis_url = self.get_redis_url()

    def get_redis_url(self) -> str:
        """Build a Redis URL from config.

        The password is percent-encoded so characters such as ``@``, ``:`` or
        ``/`` cannot corrupt the URL. When no password is set (None or empty),
        the auth section is omitted instead of emitting ``:None@``.
        """
        from urllib.parse import quote

        auth = f":{quote(str(self.redis_password), safe='')}@" if self.redis_password else ""
        return f"redis://{auth}{self.redis_host}:{self.redis_port}"

    def to_cache_config(self) -> "CacheConfig":
        """Project the caching-related fields into a CacheConfig."""
        return CacheConfig(
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_password=self.redis_password,
            redis_db=self.redis_db,
            cache_ttl=self.cache_ttl,
            enable_cache=self.enable_cache
        )

    def to_search_config(self) -> "SearchConfig":
        """Project the search-enhancement fields into a SearchConfig."""
        return SearchConfig(
            enable_reranking=self.enable_reranking,
            enable_query_expansion=self.enable_query_expansion,
            rerank_top_k=self.rerank_top_k,
            final_top_k=self.final_top_k,
            query_expansion_count=self.query_expansion_count,
            context_overlap_threshold=self.context_overlap_threshold
        )

# Initialize config
# Module-level configuration built entirely from RAGConfig's defaults; fields
# defaulting to `...` are placeholders that presumably need real values
# (URLs, keys, hosts) before the services are used.
config = RAGConfig()

In [None]:
class GeminiLLM:
    """Wrapper around LangChain's ChatGoogleGenerativeAI with optional Langfuse tracing."""

    def __init__(self, model_name: str, api_key: str, llm_config: Optional[RAGConfig] = None):
        """
        Args:
            model_name: Gemini model id passed to ChatGoogleGenerativeAI.
            api_key: Google API key.
            llm_config: Optional config supplying temperature / top_k / top_p.
                When None, the chat model's own sampling defaults are used
                (previously this crashed with AttributeError).
        """
        self.model_name = model_name
        self.api_key = api_key
        self.config = llm_config
        self.model = self._setup_model()

    def _setup_model(self):
        """Create the chat model, applying sampling parameters from config if present."""
        sampling_kwargs = {}
        if self.config is not None:
            sampling_kwargs = {
                "temperature": self.config.temperature,
                "top_k": self.config.top_k,
                "top_p": self.config.top_p,
            }
        return ChatGoogleGenerativeAI(
            model=self.model_name,
            google_api_key=self.api_key,
            **sampling_kwargs,
        )

    def _invoke(self, messages) -> str:
        """Call the model and return the response text, logging any failure before re-raising."""
        logger.info("Gemini is generating answer...")
        try:
            return self.model.invoke(messages).content
        except Exception as e:
            logger.error(f"Error in llm.generate: {e}")
            raise

    def generate(self, query: str, system_prompt: Optional[str] = None, langfuse_client: Optional[object] = None) -> str:
        """Generate an answer for `query`, optionally preceded by a system prompt.

        Args:
            query: User message content.
            system_prompt: Optional system message placed before the query.
            langfuse_client: When given, the call is wrapped in an
                "llm_generation" span with a nested "llm.generate_content" span.

        Returns:
            The model's response content.

        Raises:
            Any exception raised by the underlying model call (after logging
            and, if tracing, recording it on the span).
        """
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=query))

        if not langfuse_client:
            return self._invoke(messages)

        with langfuse_client.start_as_current_span(
            name="llm_generation",
            metadata={"model": self.model_name}
        ) as root_span:
            with langfuse_client.start_as_current_span(name="llm.generate_content") as llm_span:
                llm_span.update(input={"prompt": query, "system_prompt": system_prompt})
                try:
                    response = self._invoke(messages)
                except Exception as e:
                    llm_span.update(output={"error": str(e)}, status_message="error")
                    raise
                llm_span.update(output={"response": response}, status_message="success")

            root_span.update(output={"status": "completed"})
            return response

In [None]:
# Default system prompt for the RAG assistant. It tells the model to answer
# ML/DL/AI questions only from the supplied context, to reply in Vietnamese as
# a single 4–6 sentence paragraph opening with a polite phrase, and to keep
# domain-specific terms untranslated.
# When passed as the system prompt to RAGSingleVectorSearch.generate_response,
# the retrieved contexts are appended after it under
# "--- Here are the contexts ---:" with entries labelled "Context N:".
# NOTE(review): the prompt refers to content after the word "Context:", which
# only loosely matches that "Context N:" labelling — confirm this is intended.
DEFAULT_SYSTEM_PROMPT = """You are an AI assistant specialized in answering questions related to technology in the fields Machine Learning/ Deep Learning/ Artificial Intelligence.
Your task is to generate clear and accurate answers based on the provided **Context** retrieved from relevant documents.

Instructions:
- Read the "Question" carefully and determine user's intent.
- If the "Question" is written in Vietnamese, translate it into English and treat it as the question to be answered.
- Always answer based on the content given after the word **"Context:"** and ignore the "Question" after processing its language.
- If the provided **Context** is too short (for example, fewer than 30 words), politely ask the user to provide more details. Remember use a gentle and respectful Vietnamese tone.
- If conflicting info appears in the context, prioritize the most recent or clearly stated.

When answering, please focus on:
- The most relevant and accurate information from the context that addresses the user's question.
- Clarify core concepts or definitions that are necessary for the answer.
- Explain any relevant approaches, techniques, applications, or research findings mentioned in the context.
- Maintain logical flow and coherence in the answer.

Your answer must:
- Be written in Vietnamese only.
- Be clear, concise, and accessible to a broad Vietnamese audience.
- Be strictly faithful to the original context without adding personal opinions or unsupported conclusions.
- Contain 4–6 sentences, each under 30 words, unless the context is very short.
- Use a professional, friendly, and approachable tone—like how a knowledgeable Vietnamese communicator explains technical topics naturally.

To improve coherence:
- Use natural connectors like “ngoài ra”, “hơn nữa”, “tuy nhiên”, “bên cạnh đó”, “kết lại” when appropriate.
- Present ideas in a logically connected and smooth-flowing paragraph.

Avoid:
- Including minor or irrelevant details.
- Overusing technical jargon unless explained clearly.

Formatting requirements:
- **Write your answer as a single paragraph.**
- **Begin the answer with a polite Vietnamese opening such as**:
  - “Dạ, câu trả lời là…”
  - “Vâng, theo nội dung thì…”
  - “Dạ em xin trả lời như sau…”
  - “Sau đây là phần giải đáp ạ…”
  - “Dạ, nội dung chính là…”
- **Do not translate domain-specific terms like "summarization", "techniques", "use cases", etc.**
- **Only output Vietnamese text.**
"""

In [None]:
class RAGSingleVectorSearch:
    """Main RAG Pipeline with Single Vector Search.

    Embeds the query with one embedding model, retrieves matching chunks from
    Qdrant, formats them into a context block and asks the LLM to answer from
    it. Every step can optionally be traced with a Langfuse client.
    """
    def __init__(self, llm: "GeminiLLM", embedding_model: object,
                 qdrant_client: "QdrantClient", config: "RAGConfig"):
        """
        Args:
            llm: Object whose ``generate(query, system_prompt)`` returns the answer text.
            embedding_model: Object exposing ``encode(text)``.
            qdrant_client: Client used for ``query_points``.
            config: Supplies collection name, top-k, score threshold, context
                length and Gemini credentials.
        """
        self.config = config
        self.embedding_model = embedding_model
        self.llm = llm
        # Optional instance-wide system prompt; generate_response uses it (then
        # DEFAULT_SYSTEM_PROMPT) when no prompt is passed explicitly.
        self.system_prompt = None
        self.qdrant_client = qdrant_client
        # google-genai client used only for token counting.
        self.gemini_client = GeminiClient(
            api_key=self.config.gemini_api_key,
            http_options=HttpOptions(api_version='v1')
        )

        # Config for search
        self.top_k = self.config.qdrant_top_k
        self.max_context_length = self.config.max_context_length
        self.score_threshold = self.config.score_threshold

        logger.info("RAG Pipeline with Single Vector Search initialized successfully")

    def _count_tokens(self, text: str) -> int:
        """Count Gemini tokens in `text`.

        Empty text counts as 0 without an API call. If the API call fails, this
        falls back to a rough estimate of one token per 4 characters.
        """
        if not text:
            return 0
        try:
            response = self.gemini_client.models.count_tokens(
                model=self.config.gemini_model_name,
                contents=text
            )
            return response.total_tokens
        except Exception as e:
            logger.warning(f"Failed to count tokens: {e}")
            # Fallback: rough estimation (1 token ≈ 4 characters)
            return len(text) // 4

    def _load_default_system_prompt(self):
        """Return the module-level DEFAULT_SYSTEM_PROMPT."""
        return DEFAULT_SYSTEM_PROMPT

    def _encode_query(self, query: str, langfuse_client: Optional[object] = None) -> object:
        """Encode query with the single embedding model, tracing it when a Langfuse client is given."""
        if not langfuse_client:
            return self._do_encode_query(query)

        with langfuse_client.start_as_current_span(name="rag.encode_query") as encode_span:
            encode_span.update(input={"query": query})
            try:
                vector = self._do_encode_query(query)
            except Exception as e:
                encode_span.update(status_message="error", output={"error": str(e)})
                raise
            encode_span.update(
                output={
                    "vector_dim": len(vector) if hasattr(vector, '__len__') else "unknown",
                    "vector_preview": str(vector[:10]) + "..." if hasattr(vector, '__len__') else "unknown"
                },
                status_message="success"
            )
            return vector

    def _do_encode_query(self, query: str) -> object:
        """Actually encode query with embedding model."""
        return self.embedding_model.encode(query)

    def _search_vector(self, vector, limit: int = None, langfuse_client: Optional[object] = None) -> object:
        """Search Qdrant with one query vector (default limit: ``self.top_k``), optionally traced."""
        if limit is None:
            limit = self.top_k

        if not langfuse_client:
            return self._do_search_vector(vector, limit)

        with langfuse_client.start_as_current_span(name="rag.search_vector") as search_span:
            try:
                results = self._do_search_vector(vector, limit)
            except Exception as e:
                search_span.update(status_message="error", output={"error": str(e)})
                raise
            results_review = [
                {
                    "id": pt.id,
                    "score": round(pt.score, 4),
                    # Payload may be missing on a point; treat it as empty.
                    "text_preview": (pt.payload or {}).get("page_content", "")[:100]  # short preview
                }
                for pt in results.points[:limit]
            ]
            search_span.update(
                input={
                    "limit": limit,
                    "score_threshold": self.score_threshold
                },
                output={
                    "result_count": len(results.points),
                    "results_preview": results_review
                },
                status_message="success"
            )
            return results

    def _do_search_vector(self, vector, limit: int) -> object:
        """Actually perform vector search (points below score_threshold are dropped by Qdrant)."""
        return self.qdrant_client.query_points(
            collection_name=self.config.collection_name,
            query=vector,
            limit=limit,
            with_payload=True,
            score_threshold=self.score_threshold
        )

    def _search_context(self, query: str, langfuse_client: Optional[object] = None) -> List[str]:
        """Single vector search context from Qdrant, optionally wrapped in a Langfuse span."""
        if not langfuse_client:
            return self._do_search_context(query, langfuse_client)

        with langfuse_client.start_as_current_span(name="rag.search_context") as search_span:
            search_span.update(input={"query": query})
            try:
                context_texts = self._do_search_context(query, langfuse_client)
            except Exception as e:
                search_span.update(status_message="error", output={"error": str(e)})
                raise
            search_span.update(
                output={
                    "final_context_count": len(context_texts),
                    "score_threshold": self.score_threshold
                },
                status_message="success"
            )
            return context_texts

    def _do_search_context(self, query: str, langfuse_client: Optional[object] = None) -> List[str]:
        """Encode `query`, search Qdrant and return the stripped, non-empty ``page_content`` texts."""
        query_vector = self._encode_query(query, langfuse_client)
        search_results = self._search_vector(query_vector, limit=self.top_k, langfuse_client=langfuse_client)

        texts = []
        for result in search_results.points:
            page_content = (result.payload or {}).get("page_content")
            if page_content:  # skip points without text
                texts.append(page_content.strip())
        return texts

    def _build_context(self, search_results: List[str]) -> str:
        """Format results as "Context N: ..." lines, truncated to max_context_length characters."""
        if not search_results:
            return ""
        full_context = "".join(
            f"Context {idx}: {result}\n" for idx, result in enumerate(search_results, start=1)
        )
        return full_context[:self.max_context_length]

    def generate_response(self, query: str, context: str, system_prompt: Optional[str] = None,
                          langfuse_client: Optional[object] = None) -> str:
        """Generate response using LLM with context and system prompt.

        Args:
            query: The user question.
            context: Formatted context block (e.g. from _build_context).
            system_prompt: Instructions placed before the context. Falls back
                to ``self.system_prompt`` and then DEFAULT_SYSTEM_PROMPT when empty.
            langfuse_client: When given, records a "rag.llm_generate"
                generation with token usage.

        Returns:
            The LLM's answer text.
        """
        base_prompt = system_prompt or self.system_prompt or self._load_default_system_prompt()
        # Create Prompt
        final_system_prompt = f"""{base_prompt.strip()}
--- Here are the contexts ---:
{context}
"""
        response = self.llm.generate(query, final_system_prompt)

        if langfuse_client:
            # Token counting costs one Gemini API call per text, so only do it
            # when a trace will actually record the numbers.
            query_tokens = self._count_tokens(query)
            prompt_tokens = self._count_tokens(final_system_prompt)
            completion_tokens = self._count_tokens(response)
            # The model's input is the system prompt *and* the query, so both count.
            total_tokens = query_tokens + prompt_tokens + completion_tokens

            with langfuse_client.start_as_current_generation(
                name="rag.llm_generate",
                model=getattr(self.llm, "model_name", "unknown-model")
            ) as generation:
                generation.update(
                    input={"query": query, "system_prompt": final_system_prompt},
                    output={"response": response},
                    usage_details={
                        "query_tokens": query_tokens,
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": total_tokens,
                    },
                )

        return response

In [None]:
class HistoryGuidedReranker:
    """Rerank retrieved contexts by query similarity blended with chat-history relevance.

    For each context::

        final_score = similarity_weight * cos(query, context)
                      + history_weight * history_relevance(context)

    where history_relevance is a recency-weighted mean cosine similarity
    between the context and each history text.
    """

    def __init__(self, embedding_model, similarity_weight=0.7, history_weight=0.3, max_final_contexts=10):
        """
        Args:
            embedding_model: Object whose ``encode`` accepts a string or a list of strings.
            similarity_weight: Weight of the query/context cosine similarity.
            history_weight: Weight of the history relevance term.
            max_final_contexts: Maximum number of results returned by ``rerank``;
                None or 0 disables the cap.
        """
        self.embedding_model = embedding_model
        self.similarity_weight = similarity_weight
        self.history_weight = history_weight
        self.max_final_contexts = max_final_contexts

    def _normalize(self, vector):
        """Scale `vector` to unit L2 norm; zero vectors are returned unchanged."""
        vector = np.asarray(vector)
        magnitude = np.linalg.norm(vector)
        return vector / magnitude if magnitude > 0 else vector

    def _compute_cosine_similarity(self, vec1, vec2):
        """Cosine similarity of two vectors (0.0 if either is all zeros)."""
        return float(np.dot(self._normalize(vec1), self._normalize(vec2)))

    def _compute_history_relevance(self, context_embedding, history_embeddings):
        """Recency-weighted mean cosine similarity between a context and the history.

        `history_embeddings` is expected oldest-first (chat-history order). The
        most recent (last) entry gets weight 1, the one before it 1/2, then 1/3,
        and so on. Returns 0.0 when there is no history.
        """
        if history_embeddings is None or len(history_embeddings) == 0:
            return 0.0
        count = len(history_embeddings)
        similarities = [
            self._compute_cosine_similarity(context_embedding, hist_emb)
            for hist_emb in history_embeddings
        ]
        # Weight by distance from the END of the list: 1/(i+1) would favour the
        # oldest message, the opposite of recency weighting.
        weights = [1.0 / (count - i) for i in range(count)]
        weighted_sum = sum(sim * w for sim, w in zip(similarities, weights))
        return weighted_sum / sum(weights)

    def rerank(self, query, contexts, history_texts=None):
        """Score and sort `contexts` for `query`, best first.

        Args:
            query: The user query.
            contexts: Candidate context strings.
            history_texts: Optional chat-history strings, oldest first.

        Returns:
            Dicts with keys "context", "query_similarity", "history_relevance"
            and "final_score" (rounded to 4 decimals), sorted by final_score
            descending and truncated to ``max_final_contexts`` when that is set.
        """
        contexts = list(contexts) if contexts else []
        if not contexts:
            return []

        query_embedding = self.embedding_model.encode(query)
        history_embeddings = self.embedding_model.encode(history_texts) if history_texts else []
        # One batched encode call instead of one model call per context.
        context_embeddings = self.embedding_model.encode(contexts)

        results = []
        for context, ctx_emb in zip(contexts, context_embeddings):
            sim = self._compute_cosine_similarity(query_embedding, ctx_emb)
            hist_rel = self._compute_history_relevance(ctx_emb, history_embeddings)
            final_score = self.similarity_weight * sim + self.history_weight * hist_rel
            results.append({
                "context": context,
                "query_similarity": round(sim, 4),
                "history_relevance": round(hist_rel, 4),
                "final_score": round(final_score, 4)
            })

        ranked = sorted(results, key=lambda x: x["final_score"], reverse=True)
        if self.max_final_contexts:
            ranked = ranked[:self.max_final_contexts]
        return ranked

In [None]:
class SemanticPromptCache:
    """Redis-backed cache that reuses a stored response for a semantically similar prompt.

    Each entry is a JSON object ``{"prompt", "response", "embedding"}``
    appended to a single Redis list. Lookups scan the whole list, so their cost
    grows linearly with the number of entries. Set ``max_entries`` to bound it.
    """

    def __init__(self, redis_url: str, embedding_model: "SentenceTransformer",
                 cache_key: str = "semantic_cache", max_entries: Optional[int] = None):
        """
        Args:
            redis_url: Redis connection URL.
            embedding_model: Object exposing ``encode(text)``.
            cache_key: Name of the Redis list holding the entries.
            max_entries: If set, only the newest ``max_entries`` entries are
                kept after each insert; None keeps everything (unbounded).
        """
        self.redis = redis.from_url(redis_url, decode_responses=True)
        self.embedding_model = embedding_model
        self.cache_key = cache_key
        self.max_entries = max_entries

    def _cosine_similarity(self, a, b):
        """Cosine similarity of `a` and `b`; 0.0 when either vector has zero norm."""
        a = np.asarray(a)
        b = np.asarray(b)
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def get_cached_response(self, prompt: str, similarity_threshold: float = 0.9) -> Optional[str]:
        """Return the cached response whose prompt is most similar to `prompt`.

        Only entries with similarity >= `similarity_threshold` qualify. Among
        those, the closest one wins rather than the oldest. Malformed entries
        are skipped with a warning. Returns None on a miss.
        """
        new_embedding = self.embedding_model.encode(prompt)

        best_response = None
        best_similarity = None
        for item in self.redis.lrange(self.cache_key, 0, -1):
            try:
                data = json.loads(item)
                sim = self._cosine_similarity(new_embedding, np.array(data["embedding"]))
                cached_response = data["response"]
            except (ValueError, KeyError, TypeError) as e:
                logging.getLogger(__name__).warning(f"Skipping malformed semantic cache entry: {e}")
                continue

            if sim >= similarity_threshold and (best_similarity is None or sim > best_similarity):
                best_similarity = sim
                best_response = cached_response
        return best_response

    def cache_response(self, prompt: str, response: str):
        """Store `response` under the embedding of `prompt`, evicting the oldest entries if capped."""
        embedding = np.asarray(self.embedding_model.encode(prompt))
        cache_entry = {
            "prompt": prompt,
            "response": response,
            "embedding": embedding.tolist()
        }
        self.redis.rpush(self.cache_key, json.dumps(cache_entry))
        if self.max_entries:
            # Keep only the newest max_entries items (the list is append-ordered).
            self.redis.ltrim(self.cache_key, -self.max_entries, -1)

In [None]:
class ConversationalRAGChatbot:
    """
    Enhanced Conversational RAG Chatbot with History-Guided Reranking
    """

    def __init__(
        self,
        rag_pipeline,
        config: RAGConfig,
        session_id: str = "default_session",
        reranker: Optional[object] = None,
        langfuse_client: Optional[object] = None,
        semantic_cache: Optional[object] = None
    ):
        """Wire the chatbot to its RAG pipeline, Redis-backed history and optional helpers.

        Args:
            rag_pipeline: Pipeline object providing ``_do_search_context`` and
                ``_load_default_system_prompt``.
            config: Supplies history/context limits, compression flag, Redis
                URL and the Lakera Guard API key.
            session_id: Key under which this conversation's history is stored.
            reranker: Optional object with ``rerank(query, contexts, history_texts)``.
            langfuse_client: Optional Langfuse client for tracing.
            semantic_cache: Optional semantic prompt cache.
        """
        # Collaborators supplied by the caller
        self.rag_pipeline = rag_pipeline
        self.reranker = reranker
        self.semantic_cache = semantic_cache
        self.langfuse_client = langfuse_client

        # Settings taken from the config
        self.config = config
        self.session_id = session_id
        self.max_history_pairs = config.max_history_pairs
        self.max_final_contexts = config.max_final_contexts
        self.enable_compression = config.enable_compression

        # Per-session chat history in Redis; ttl=1800 (expiry, presumably seconds)
        self.history = RedisChatMessageHistory(
            session_id=self.session_id,
            url=self.config.redis_url,
            ttl=1800,
        )

        # Lakera Guard: API key plus a reusable HTTP session for the checks
        self.lakera_guard_api_key = config.lakera_guard_api_key
        self.lakera_session = requests.Session()

        self.logger = logging.getLogger(__name__)

    # def _generate_session_hash(self, session_id: str) -> str:
    #     return hashlib.md5(session_id.encode()).hexdigest()

    def _compress_data(self, data: str) -> str:
        """Gzip-compress `data` and return it base64-encoded as text.

        The input is returned unchanged when compression is disabled, or when
        compression fails (the failure is logged as a warning).
        """
        if self.enable_compression:
            try:
                gz_bytes = gzip.compress(data.encode('utf-8'))
                return base64.b64encode(gz_bytes).decode('utf-8')
            except Exception as e:
                self.logger.warning(f"Compression failed: {e}")
        return data

    def _decompress_data(self, data: str) -> str:
        """Reverse `_compress_data`.

        Data is decoded only when compression is enabled and the value looks
        like base64-encoded gzip. Anything else, or a failed decode (logged as
        a warning), is returned unchanged.
        """
        should_decode = self.enable_compression and self._is_base64_gzip(data)
        if not should_decode:
            return data
        try:
            gz_bytes = base64.b64decode(data.encode('utf-8'))
            text = gzip.decompress(gz_bytes).decode('utf-8')
        except Exception as e:
            self.logger.warning(f"Decompression failed: {e}")
            return data
        return text

    def _check_prompt_injection(self, prompt: str, root_span: Optional[object] = None) -> Optional[str]:
        """Check prompt for injection using Lakera Guard API and log via Langfuse.

        Fails open: if no API key is configured, or the request or response
        parsing fails, the prompt is allowed through and None is returned.

        Args:
            prompt: Raw user input to screen.
            root_span: Optional Langfuse root span. When the prompt is flagged
                and tracing is enabled, its output is set to the warning message.

        Returns:
            The warning message if Lakera Guard flags the prompt, otherwise None.
        """
        if not self.lakera_guard_api_key:
            # Nothing to authenticate with: skip an HTTP round-trip (up to the
            # 10 s timeout) that could only fail with "Bearer None".
            self.logger.debug("Lakera Guard API key not set; skipping prompt injection check")
            return None

        try:
            response = self.lakera_session.post(
                "https://api.lakera.ai/v2/guard",
                json={"messages": [{"role": "user", "content": prompt}]},
                headers={"Authorization": f"Bearer {self.lakera_guard_api_key}"},
                timeout=10  # Add timeout to prevent hanging
            )
            response.raise_for_status()
            response_json = response.json()
            # The v2 endpoint reports a top-level "flagged" boolean; the older
            # v1-style body nested it in a "results" list. Accept either shape.
            # Otherwise a v2 verdict is always read as "not flagged".
            flagged = bool(response_json.get("flagged", False)) or any(
                result.get("flagged", False) for result in response_json.get("results", [])
            )
        except Exception as e:
            self.logger.warning(f"Error in LakeraGuard: {e}")
            # Don't block on Lakera Guard failure, just log and continue
            return None

        output_msg = "[PROMPT INJECTION] Lakera Guard identified a prompt injection. No user was harmed by this LLM."

        if not self.langfuse_client:
            return output_msg if flagged else None

        with self.langfuse_client.start_as_current_span(name="prompt_injection_check") as prompt_injection_span:
            prompt_injection_span.update(input={"prompt": prompt})
            if flagged:
                prompt_injection_span.update(
                    output={
                        "lakera_output": output_msg,
                        "lakera_response": response_json
                    },
                    status_message="warning"
                )
                if root_span:
                    root_span.update(
                        output={"final_response": output_msg},
                        status_message="warning"
                    )
                return output_msg

            prompt_injection_span.update(
                output={
                    "lakera_output": "No prompt injection detected.",
                    "lakera_response": response_json
                },
                status_message="success"
            )
        return None

    def _load_default_system_prompt(self):
        """Return the default system prompt, as provided by the underlying RAG pipeline."""
        pipeline = self.rag_pipeline
        default_prompt = pipeline._load_default_system_prompt()
        return default_prompt

    def _validate_inputs(self, query: str) -> Tuple[bool, str]:
        """Check that `query` is non-blank and at most 10000 characters after stripping.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, <reason>)``.
        """
        stripped = query.strip() if query else ""
        if not stripped:
            return False, "Query cannot be empty"
        if len(stripped) > 10000:
            return False, "Query is too long (max 10000 characters)"
        return True, ""

    def _is_base64_gzip(self, data: str) -> bool:
        """Return True if `data` is strict base64 whose decoded bytes start with the gzip magic number.

        ``validate=True`` rejects any character outside the base64 alphabet.
        Without it, ``b64decode`` silently drops such characters, so ordinary
        text (spaces, punctuation, non-ASCII) could be mis-detected as compressed.
        Non-string or undecodable input yields False.
        """
        try:
            decoded = base64.b64decode(data, validate=True)
        except Exception:
            return False
        return decoded[:2] == b'\x1f\x8b'  # Magic bytes of gzip

    def _trim_history(self) -> None:
        """Keep only the most recent ``max_history_pairs`` user/assistant pairs in Redis.

        The retained messages are rewritten after clearing the history. They
        are compressed when compression is enabled, and only Human/AI messages
        are re-added. A non-positive ``max_history_pairs`` keeps no messages.
        Errors are logged rather than raised.
        """
        try:
            messages = self.history.messages
            keep = 2 * max(self.max_history_pairs, 0)
            if len(messages) <= keep:
                return

            # Explicit start index: a negative slice `messages[-0:]` is the whole
            # list, which would keep everything when max_history_pairs == 0.
            messages_to_keep = messages[len(messages) - keep:]
            self.history.clear()

            for message in messages_to_keep:
                content = message.content

                if self.enable_compression and not self._is_base64_gzip(content):
                    content = self._compress_data(content)

                if isinstance(message, HumanMessage):
                    self.history.add_user_message(content)
                elif isinstance(message, AIMessage):
                    self.history.add_ai_message(content)

        except Exception as e:
            self.logger.error(f"Error trimming history: {e}")

    def _get_history_texts(self) -> List[str]:
        """Return the stored conversation as ``["User: ...", "Assistant: ..."]`` strings.

        Contents are decompressed when needed, and message types other than
        Human/AI are skipped. On any error the problem is logged and an empty
        list is returned.
        """
        try:
            lines = []
            for msg in self.history.messages:
                # _decompress_data is a no-op when compression is disabled.
                text = self._decompress_data(msg.content)
                if isinstance(msg, HumanMessage):
                    lines.append(f"User: {text}")
                elif isinstance(msg, AIMessage):
                    lines.append(f"Assistant: {text}")
            return lines
        except Exception as e:
            self.logger.error(f"Error formatting history texts: {e}")
            return []

    def _enhanced_search_with_reranking(self, query: str, root_span: Optional[object] = None) -> List[str]:
        """
        Enhanced search with reranking using history context.

        Retrieves candidate contexts from the RAG pipeline. If a reranker is
        configured and anything was found, the candidates are reordered by
        query similarity blended with conversation history, and at most
        ``self.max_final_contexts`` are kept (None or 0 means no cap).

        Args:
            query: User query
            root_span: Optional Langfuse span. Tracing happens only when this
                and ``self.langfuse_client`` are both set.

        Returns:
            List of (reranked) context strings. On any error this falls back to
            the plain pipeline search, and to ``[]`` if that fails too.
        """
        try:
            if root_span and self.langfuse_client:
                return self._traced_search_with_reranking(query)
            return self._plain_search_with_reranking(query)
        except Exception as e:
            self.logger.error(f"Error in enhanced search with reranking: {e}")
            # Fallback to basic search
            try:
                return self.rag_pipeline._do_search_context(query)
            except Exception as fallback_e:
                self.logger.error(f"Fallback search also failed: {fallback_e}")
                return []

    def _rerank_with_history(self, query: str,
                             contexts: List[str]) -> Tuple[List[str], List[Dict[str, Any]], List[str]]:
        """Rerank `contexts` against query + history; returns (final_contexts, raw_results, history_texts)."""
        history_texts = self._get_history_texts()
        reranked_results = self.reranker.rerank(
            query=query,
            contexts=contexts,
            history_texts=history_texts
        )
        final_contexts = [result["context"] for result in reranked_results]
        if self.max_final_contexts:
            final_contexts = final_contexts[:self.max_final_contexts]
        self.logger.debug(f"Reranked {len(contexts)} contexts to {len(final_contexts)}")
        return final_contexts, reranked_results, history_texts

    def _traced_search_with_reranking(self, query: str) -> List[str]:
        """Search + optional rerank, recorded as Langfuse spans."""
        with self.langfuse_client.start_as_current_span(
            name="rag.enhanced_search_with_reranking"
        ) as search_span:
            search_span.update(input={"query": query})

            # Get initial search results from RAG pipeline
            initial_contexts = self.rag_pipeline._do_search_context(query, self.langfuse_client)

            if not (self.reranker and initial_contexts):
                # Return original contexts if no reranker (or nothing to rerank)
                search_span.update(
                    output={
                        "context_count": len(initial_contexts),
                        "reranker_used": False
                    },
                    status_message="success"
                )
                return initial_contexts

            with self.langfuse_client.start_as_current_span(
                name="conversational_rag.rerank_contexts"
            ) as rerank_span:
                final_contexts, reranked_results, history_texts = self._rerank_with_history(
                    query, initial_contexts
                )
                rerank_span.update(
                    input={
                        'initial_context': initial_contexts,
                        'history_texts': history_texts
                    },
                    output={
                        'final_context': final_contexts,
                        'full_rerank_results': reranked_results
                    },
                    status_message='success'
                )

            search_span.update(
                output={
                    "initial_context_count": len(initial_contexts),
                    "final_context_count": len(final_contexts),
                    "reranker_used": True,
                    "history_texts_count": len(history_texts)
                },
                status_message="success"
            )
            return final_contexts

    def _plain_search_with_reranking(self, query: str) -> List[str]:
        """Search + optional rerank without tracing."""
        initial_contexts = self.rag_pipeline._do_search_context(query)
        if self.reranker and initial_contexts:
            final_contexts, _, _ = self._rerank_with_history(query, initial_contexts)
            return final_contexts
        return initial_contexts

    def chat(self, query: str, system_prompt: Optional[str] = None) -> str:
        """Answer one user turn, wrapping it in a Langfuse root span when tracing is configured.

        Args:
            query: Raw user message; checked with ``self._validate_inputs`` first.
            system_prompt: Optional system prompt forwarded to ``_do_chat``.

        Returns:
            The result of ``_do_chat``, or ``"Error: <reason>"`` when validation fails.
        """
        valid, problem = self._validate_inputs(query)
        if not valid:
            self.logger.warning(f"Invalid input: {problem}")
            return f"Error: {problem}"

        # Without a Langfuse client there is nothing to trace.
        if not self.langfuse_client:
            return self._do_chat(query, system_prompt)

        # The root span is named after the session id.
        with self.langfuse_client.start_as_current_span(name=f"{self.session_id}") as root_span:
            return self._do_chat(query, system_prompt, root_span)

    def _build_query(self, query: str, history_lines: List[str]) -> str:
        """
        Builds final query prompt including history and user question.
        `history_lines` is a list of formatted strings: ["User: ...", "Assistant: ..."]
        """
        prompt_parts = []
        if history_lines:
            trimmed_history = history_lines[-2 * self.max_history_pairs:]
            history_section = "\n".join(trimmed_history)
            prompt_parts.append("--- Conversation History ---\n" + history_section)

        prompt_parts.append("--- Current User Question ---\n" + query.strip())
        return "\n\n".join(prompt_parts)

    def _do_chat(self, query: str, system_prompt: Optional[str] = None,
                 root_span: Optional[object] = None) -> str:
        try:
            # Prompt injection check
            flagged = self._check_prompt_injection(query, root_span)
            if flagged:
                if root_span:
                    with self.langfuse_client.start_as_current_span(
                        name="prompt_injection_check.flagged"
                    ) as flagged_span:
                        flagged_span.update(
                            input={"prompt": query},
                            output={"flagged_response": flagged},
                            status_message='warning'
                        )

                        root_span.update(
                            input={"query": query},
                            output={"final_response": flagged},
                            status_message="warning",
                            tags={"prompt_injection": "flagged"}
                        )
                return flagged

            # Search cache prompts
            cached_response = self.semantic_cache.get_cached_response(query)
            if cached_response:
                if root_span:
                    with self.langfuse_client.start_as_current_span(
                        name="semantic_cache.hit"
                    ) as cache_hit_span:
                        cache_hit_span.update(
                            input={"query": query},
                            output={"cached_response": cached_response},
                            status_message="success"
                        )

                    root_span.update(
                        input={"query": query},
                        output={"final_response": cached_response},
                        status_message="success",
                        tags={"semantic_cache": "hit"}
                    )
                return cached_response

            # Compress query
            compressed_query = self._compress_data(query) if self.enable_compression else query
            self.history.add_user_message(compressed_query)

            # System promt + Rerank contexts
            final_contexts = self.rag_pipeline._build_context(self._enhanced_search_with_reranking(query, root_span))
            self.system_prompt = system_prompt or self._load_default_system_prompt()

            # Query + History chat
            final_query = self._build_query(query, self._get_history_texts())

            # Get response
            response = self.rag_pipeline.generate_response(final_query, final_contexts, self.system_prompt, self.langfuse_client)

            compressed_response = self._compress_data(response) if self.enable_compression else response
            self.history.add_ai_message(compressed_response)
            self._trim_history()

            if root_span:
                root_span.update(
                    input={"query": query, "context": final_contexts},
                    output={"final_response": response},
                    status_message="success",
                )

            return response
        except Exception as e:
            error_msg = f"An error occurred while processing your request: {str(e)}"
            self.logger.error(error_msg, exc_info=True)
            try:
                compressed_error = self._compress_data(error_msg) if self.enable_compression else error_msg
                self.history.add_ai_message(compressed_error)
            except:
                pass
            if root_span:
                root_span.update(status_message="error", output={"error": str(e)})
            return error_msg

    def get_conversation_history(self) -> List[Dict[str, Any]]:
        """Return the stored conversation as a list of plain dicts.

        Each entry has ``type`` (``"human"`` for a ``HumanMessage``, ``"ai"``
        for anything else), ``content`` (passed through ``_decompress_data``
        when compression is enabled) and ``timestamp`` (the message's
        ``timestamp`` attribute, or ``None`` if it has none).
        Any error is logged and yields an empty list.
        """
        try:
            history_entries = []
            for message in self.history.messages:
                text = message.content
                if self.enable_compression:
                    text = self._decompress_data(text)
                history_entries.append({
                    "type": "human" if isinstance(message, HumanMessage) else "ai",
                    "content": text,
                    "timestamp": getattr(message, 'timestamp', None),
                })
            return history_entries
        except Exception as e:
            self.logger.error(f"Error getting conversation history: {e}")
            return []

    def clear_history(self) -> bool:
        """Wipe the stored conversation.

        Returns:
            ``True`` when ``self.history.clear()`` succeeds; ``False`` after
            logging the error otherwise.
        """
        try:
            self.history.clear()
        except Exception as e:
            self.logger.error(f"Error clearing history: {e}")
            return False
        return True

    def get_session_stats(self) -> Dict[str, Any]:
        """Summarize the session configuration and history size.

        Returns:
            A dict with ``session_id``, ``total_messages`` (number of stored
            history messages), ``compression_enabled``, ``reranker_available``
            (whether ``self.reranker`` is set) and ``max_history_pairs``; an
            empty dict (after logging) if any of these cannot be read.
        """
        try:
            stats: Dict[str, Any] = dict(
                session_id=self.session_id,
                total_messages=len(self.history.messages),
                compression_enabled=self.enable_compression,
                reranker_available=self.reranker is not None,
                max_history_pairs=self.max_history_pairs,
            )
        except Exception as e:
            self.logger.error(f"Error getting session stats: {e}")
            return {}
        return stats

In [None]:
# Load the reranker and embedding SentenceTransformer models named in `config`.
def _load_sentence_model(label, model_name):
    """Announce, load and print a SentenceTransformer model, then return it.

    Args:
        label: Human-readable role shown in the progress line.
        model_name: Model identifier passed to ``SentenceTransformer``.
    """
    print(f'Loading {label} {model_name}...')
    model = SentenceTransformer(model_name)
    print(model)
    return model


reranker_model = _load_sentence_model('reranker model', config.reranker_model_name)
embedding_model = _load_sentence_model('embedding', config.embedding_model_name)

In [None]:
# Wire up the chatbot: tracing, LLM, vector store, retrieval, reranking, cache.

# Reranker tuning knobs (previously inline literals).
RERANK_SIMILARITY_WEIGHT = 0.7
RERANK_HISTORY_WEIGHT = 0.3
RERANK_MAX_FINAL_CONTEXTS = 5
# NOTE(review): hard-coded session id; presumably reused across runs -
# consider a per-run id (e.g. str(uuid4())) if isolated sessions are wanted.
CHAT_SESSION_ID = "thuan phat"

# Langfuse tracing client
langfuse_client = Langfuse(
    host=config.langfuse_host,
    public_key=config.langfuse_public_key,
    secret_key=config.langfuse_secret_key,
)

# Gemini LLM wrapper
gemini = GeminiLLM(
    llm_config=config,
    model_name=config.gemini_model_name,
    api_key=config.gemini_api_key,
)

# Qdrant vector database client
qdrant_client = QdrantClient(api_key=config.qdrant_api_key, url=config.qdrant_url)

# Retrieval + generation pipeline
rag_pipeline = RAGSingleVectorSearch(
    config=config,
    llm=gemini,
    qdrant_client=qdrant_client,
    embedding_model=embedding_model,
)

# History-guided reranker
reranker = HistoryGuidedReranker(
    embedding_model=reranker_model,
    similarity_weight=RERANK_SIMILARITY_WEIGHT,
    history_weight=RERANK_HISTORY_WEIGHT,
    max_final_contexts=RERANK_MAX_FINAL_CONTEXTS,
)

# Semantic prompt cache (shares the embedding model with retrieval)
semantic_cache = SemanticPromptCache(embedding_model=embedding_model, redis_url=config.redis_url)

# Chatbot under test
chatbot = ConversationalRAGChatbot(
    session_id=CHAT_SESSION_ID,
    config=config,
    rag_pipeline=rag_pipeline,
    reranker=reranker,
    semantic_cache=semantic_cache,
    langfuse_client=langfuse_client,
)

In [None]:
# Interactive test loop: type "quit", "exit" or "bye" (any case, surrounding
# whitespace ignored) to end the session.
EXIT_COMMANDS = {"quit", "exit", "bye"}

while True:
    try:
        query = input("User: ")
    except (EOFError, KeyboardInterrupt):
        # stdin closed (EOFError) or the user interrupted: end the session
        # cleanly instead of raising out of the cell.
        print("\nAssistant: Goodbye!")
        break

    if query.strip().lower() in EXIT_COMMANDS:
        print("Assistant: Goodbye!")
        break

    # Any exception from chat() propagates, as before (the old
    # `except Exception as e: raise e` wrapper was a no-op).
    print("Assistant: ", chatbot.chat(query))
    print('\n\n')

User: Hello
Assistant:  Xin chào bạn! Tôi là chatbot assistant của trang The Doms, tôi có thể giúp gì cho bạn?



User: Hãy nêu một số điểm mới của LLMs DeepSeekR1 so với các LLMs khác hiện tại như GPT hay Gemini
Assistant:  Dạ, nội dung chính là: DeepSeek đã nâng cấp LLM DeepSeek-R1, đạt hiệu suất cạnh tranh với OpenAI o3 và Google Gemini 2.5 Pro. Bản cập nhật DeepSeek-R1-0528 cải thiện đáng kể ở các nhiệm vụ toán học, lập trình và logic. DeepSeek-R1-0528-Qwen3-8B có kích thước nhỏ hơn, có thể chạy trên một GPU duy nhất với VRAM chỉ 40GB. DeepSeek tuyên bố cải thiện khả năng suy luận, quản lý các tác vụ phức tạp, viết và chỉnh sửa văn bản dài, đồng thời giảm 50% ảo giác khi viết lại và tóm tắt. Mã nguồn và trọng số của DeepSeek-R1 được cấp phép tự do cho mục đích thương mại và cá nhân.



User: Đạt hiệu suất cạnh tranh như thế nào? Có số liệu thực tế không?
Assistant:  Dạ, theo nội dung thì GPT-3 chỉ đạt độ chính xác 68% khi trả lời các câu hỏi đố vui trong one-shot TriviaQA, trong kh