# In [None]:
import asyncio
import os
import time
from typing import List, Optional
from urllib.parse import quote

import openai
from redis import Redis
from redisvl.extensions.cache.llm import SemanticCache
from redisvl.extensions.session import SemanticMessageHistory
from redisvl.index import AsyncSearchIndex
from redisvl.query import VectorQuery
from redisvl.utils.vectorize import HFTextVectorizer

# ============================================
# CONFIGURATION
# ============================================

# Redis connection (each value can be overridden through the environment)
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = os.getenv("REDIS_PORT", "6379")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")
# Percent-encode the password so characters such as "@", ":", "/" or "#"
# cannot be mistaken for URL delimiters. For an empty or purely alphanumeric
# password the resulting URL is identical to the unencoded form.
REDIS_URL = f"redis://:{quote(REDIS_PASSWORD, safe='')}@{REDIS_HOST}:{REDIS_PORT}"

# OpenAI
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# Cache settings
CACHE_TTL = 300  # 5 minutes
CACHE_DISTANCE_THRESHOLD = 0.2

# Session settings
SESSION_TTL = 3600  # 1 hour
NUM_RESULTS = 5

# ============================================
# SEMANTIC CACHE SETUP
# ============================================

def setup_semantic_cache(vectorizer: HFTextVectorizer) -> SemanticCache:
    """
    Build the semantic cache used to store LLM responses.

    The cache is named "llmcache" and is configured from the module-level
    REDIS_URL, CACHE_TTL and CACHE_DISTANCE_THRESHOLD settings, with
    overwrite enabled.

    Args:
        vectorizer: HFTextVectorizer instance handed to the cache

    Returns:
        The configured SemanticCache instance
    """
    print("Setting up semantic cache...")

    # Collect the settings in one place before constructing the cache
    cache_options = {
        "name": "llmcache",
        "vectorizer": vectorizer,
        "redis_url": REDIS_URL,
        "ttl": CACHE_TTL,
        "distance_threshold": CACHE_DISTANCE_THRESHOLD,
        "overwrite": True,
    }
    llm_cache = SemanticCache(**cache_options)

    print(f"Cache created with TTL={CACHE_TTL}s, threshold={CACHE_DISTANCE_THRESHOLD}")
    return llm_cache


# ============================================
# CHATBOT CLASS WITH CACHE AND MEMORY
# ============================================

class ChatBot:
    """
    RAG chatbot with semantic response caching and per-user session memory.

    A question is embedded once; that vector drives both the semantic cache
    lookup and, on a cache miss, the document retrieval. Every answered
    question is appended to the user's message history, and cache/timing
    counters are accumulated in ``self.metrics``.
    """
    
    def __init__(
        self, 
        index: AsyncSearchIndex, 
        vectorizer: HFTextVectorizer,
        cache: SemanticCache,
        user: str
    ):
        """
        Initialize ChatBot
        
        Args:
            index: AsyncSearchIndex for document search
            vectorizer: HFTextVectorizer for embeddings
            cache: SemanticCache for response caching
            user: User identifier for session management
        """
        self.index = index
        self.vectorizer = vectorizer
        self.cache = cache
        self.user = user
        
        # One message history per user, named after the user identifier
        self.session_manager = SemanticMessageHistory(
            name=f"chat_session_{user}",
            redis_url=REDIS_URL,
            ttl=SESSION_TTL
        )
        
        # Performance metrics
        self.metrics = {
            "cache_hits": 0,
            "cache_misses": 0,
            "total_queries": 0,
            "avg_response_time": 0
        }
        # Number of queries that completed and were timed. Kept outside
        # self.metrics so get_metrics() keeps returning the same keys.
        self._timed_queries = 0
        
        print(f"ChatBot initialized for user: {user}")
    
    def embed_query(self, query: str):
        """Convert query to vector"""
        return self.vectorizer.embed(query)
    
    async def retrieve_context(self, query_vector, num_results: int = NUM_RESULTS) -> str:
        """
        Retrieve relevant context from Redis
        
        Args:
            query_vector: Embedded query vector
            num_results: Number of chunks to retrieve
            
        Returns:
            The "content" field of every hit, joined by blank lines
            (an empty string when nothing matched)
        """
        results = await self.index.query(
            VectorQuery(
                vector=query_vector,
                vector_field_name="text_embedding",
                num_results=num_results,
                return_fields=["content"],
                return_score=True
            )
        )
        
        return "\n\n".join(result['content'] for result in results)
    
    @staticmethod
    async def generate_llm_response(
        query: str, 
        context: str, 
        session: Optional[List] = None
    ) -> str:
        """
        Generate LLM response with context and conversation history
        
        Args:
            query: User question
            context: Retrieved context from documents
            session: Conversation history as a list of dicts with "role"
                and "content" keys; None or empty means no history
            
        Returns:
            LLM-generated answer
        """
        # Only the 5 most recent messages are included. A user/assistant
        # exchange is two messages, so this is fewer than 5 full turns.
        conversation = "".join(
            f"{msg.get('role', 'unknown').capitalize()}: {msg.get('content', '')}\n"
            for msg in (session or [])[-5:]
        )
        
        # Construct prompt
        prompt = f'''Use the provided context and conversation history to answer 
        the user's question. If the question refers to something from earlier in 
        the conversation, use that context. If you cannot answer based on the 
        provided information, say so.
        
        Context from documents:
        {context}
        
        Conversation history:
        {conversation if conversation else "No previous conversation"}
        
        Current question: {query}
        
        Answer:'''
        
        # Close the client (and its HTTP connections) once the request is
        # done instead of leaving a new, never-closed client per call.
        async with openai.AsyncClient() as client:
            response = await client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                seed=42
            )
        
        return response.choices[0].message.content
    
    async def clear_history(self):
        """Clear conversation history"""
        self.session_manager.clear()
        print(f"Session history cleared for user: {self.user}")
    
    def _remember_turn(self, query: str, answer: str) -> None:
        """Append one user/assistant exchange to the session history."""
        self.session_manager.add_messages([
            {"role": "user", "content": query},
            {"role": "assistant", "content": answer}
        ])
    
    def _record_response_time(self, elapsed: float) -> None:
        """Fold one completed query's duration into the running average."""
        self._timed_queries += 1
        previous_avg = self.metrics["avg_response_time"]
        # Incremental mean: no need to keep every sample around
        self.metrics["avg_response_time"] = (
            previous_avg + (elapsed - previous_avg) / self._timed_queries
        )
    
    async def answer_question(self, query: str) -> dict:
        """
        Answer question with caching and session memory
        
        Args:
            query: User question
            
        Returns:
            Dictionary with keys "answer", "used_cache", "used_memory",
            "response_time" (seconds) and "source" ("cache" or "llm")
        """
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # or jump if the wall clock is adjusted mid-request
        start_time = time.perf_counter()
        self.metrics["total_queries"] += 1
        
        # Embed the query once; the vector is reused for cache and retrieval
        query_vector = self.embed_query(query)
        
        # NOTE(review): the embed, cache and session calls are not awaited,
        # so they run synchronously on the event loop thread.
        cached_result = self.cache.check(vector=query_vector)
        cache_hit = bool(cached_result)
        
        if cache_hit:
            answer = cached_result[0]['response']
            self.metrics["cache_hits"] += 1
            used_memory = False
        else:
            # Cache miss - run full RAG
            self.metrics["cache_misses"] += 1
            
            # Get conversation history
            session = self.session_manager.messages
            used_memory = bool(session)
            
            context = await self.retrieve_context(query_vector)
            answer = await self.generate_llm_response(query, context, session)
            
            # Store in cache so similar questions can skip the LLM
            self.cache.store(
                prompt=query,
                response=answer,
                vector=query_vector
            )
        
        # The exchange is recorded in the history on cache hits as well
        self._remember_turn(query, answer)
        
        elapsed = time.perf_counter() - start_time
        self._record_response_time(elapsed)
        
        return {
            "answer": answer,
            "used_cache": cache_hit,
            "used_memory": used_memory,
            "response_time": elapsed,
            "source": "cache" if cache_hit else "llm"
        }
    
    def get_metrics(self) -> dict:
        """
        Get performance metrics
        
        Returns:
            A copy of the counters plus "cache_hit_rate", a percentage
            string such as "25.0%" ("0.0%" before any query was made)
        """
        cache_hit_rate = (
            self.metrics["cache_hits"] / self.metrics["total_queries"] * 100
            if self.metrics["total_queries"] > 0 else 0
        )
        
        return {
            **self.metrics,
            "cache_hit_rate": f"{cache_hit_rate:.1f}%"
        }


# ============================================
# DEMO FUNCTIONS
# ============================================

async def demo_semantic_caching(chatbot: ChatBot):
    """Ask three differently worded versions of one question and report whether each was served from the cache."""
    divider = "=" * 60
    print("\n" + divider)
    print("DEMO: Semantic Caching")
    print(divider)
    
    # The second and third questions are rewordings of the first
    similar_questions = (
        "What are Nike's main revenue drivers?",
        "What are Nike's primary sources of income?",
        "Tell me about Nike's biggest revenue streams?",
    )
    
    for number, question in enumerate(similar_questions, start=1):
        print(f"\n--- Query {number} ---")
        print(f"Question: {question}")
        
        outcome = await chatbot.answer_question(question)
        preview = outcome['answer'][:200]  # keep the console output short
        
        print(f"Answer: {preview}...")
        print(f"Source: {outcome['source']}")
        print(f"Cache hit: {outcome['used_cache']}")
        print(f"Response time: {outcome['response_time']:.3f}s")


async def demo_session_memory(chatbot: ChatBot):
    """Run a short multi-turn conversation and report whether prior history was used for each turn."""
    divider = "=" * 60
    print("\n" + divider)
    print("DEMO: Session Memory")
    print(divider)
    
    # Start from an empty history so earlier demos do not leak in
    await chatbot.clear_history()
    
    # Turns 2 and 3 only make sense in light of the turn before them
    turns = (
        "What was Nike's marketing strategy?",
        "What role do athlete partnerships play in that?",
        "How much do they spend on it?",
    )
    
    for turn_number, user_message in enumerate(turns, start=1):
        print(f"\n--- Turn {turn_number} ---")
        print(f"User: {user_message}")
        
        outcome = await chatbot.answer_question(user_message)
        preview = outcome['answer'][:200]  # keep the console output short
        
        print(f"Assistant: {preview}...")
        print(f"Used memory: {outcome['used_memory']}")
        print(f"Response time: {outcome['response_time']:.3f}s")


async def demo_performance_comparison(chatbot: ChatBot):
    """
    Ask the same question twice and compare the two response times.

    The first call is expected to miss the cache and the second to hit it;
    the actual source of each answer is printed so this can be checked.
    """
    print("\n" + "=" * 60)
    print("DEMO: Performance Comparison")
    print("=" * 60)
    
    # Same query asked twice
    query = "What are Nike's operating segments?"
    
    print(f"\nQuery: {query}")
    
    # First call (cache miss)
    print("\nFirst call (cache miss):")
    result1 = await chatbot.answer_question(query)
    print(f"Response time: {result1['response_time']:.3f}s")
    print(f"Source: {result1['source']}")
    
    # Second call (cache hit)
    print("\nSecond call (cache hit):")
    result2 = await chatbot.answer_question(query)
    print(f"Response time: {result2['response_time']:.3f}s")
    print(f"Source: {result2['source']}")
    
    # Calculate improvement. A measured time of 0 (timer resolution coarser
    # than the call) would otherwise raise ZeroDivisionError.
    first_time = result1['response_time']
    second_time = result2['response_time']
    if second_time > 0:
        print(f"\nSpeedup: {first_time / second_time:.1f}x faster")
    else:
        print("\nSpeedup: n/a (second call too fast to measure)")


# ============================================
# MAIN EXECUTION
# ============================================

async def main():
    """
    Main demo function

    Builds the vectorizer, semantic cache, async index handle and chatbot,
    runs the three demos in order, prints the collected metrics, and closes
    the index connection on the way out.
    
    Note: This assumes you have already:
    1. Processed and chunked your documents
    2. Generated embeddings
    3. Created and loaded a Redis index
    
    See the basic RAG example for those steps.
    """
    print("=" * 60)
    print("Production RAG with Caching and Memory")
    print("=" * 60)
    
    # Initialize components (assumes index already exists)
    vectorizer = HFTextVectorizer(
        model="sentence-transformers/all-MiniLM-L6-v2",
        dims=384
    )
    
    # Setup cache
    cache = setup_semantic_cache(vectorizer)
    
    # Setup async index (assumes schema already created)
    schema = {
        "index": {"name": "redisvl_rag", "prefix": "doc"},
        "fields": [
            {"name": "chunk_id", "type": "numeric"},
            {"name": "content", "type": "text"},
            {
                "name": "text_embedding",
                "type": "vector",
                "attrs": {
                    "dims": 384,
                    "algorithm": "flat",
                    "distance_metric": "cosine"
                }
            }
        ]
    }
    
    # AsyncSearchIndex is already imported at module level
    async_index = AsyncSearchIndex.from_dict(schema, redis_url=REDIS_URL)
    
    try:
        # Initialize chatbot
        chatbot = ChatBot(
            index=async_index,
            vectorizer=vectorizer,
            cache=cache,
            user="demo_user"
        )
        
        # Run demos
        await demo_semantic_caching(chatbot)
        await demo_session_memory(chatbot)
        await demo_performance_comparison(chatbot)
        
        # Print final metrics
        print("\n" + "=" * 60)
        print("Final Metrics")
        print("=" * 60)
        metrics = chatbot.get_metrics()
        for key, value in metrics.items():
            print(f"{key}: {value}")
    finally:
        # Close the async Redis connection while the event loop is still
        # running, even if one of the demos raised
        await async_index.disconnect()

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())