# In [None]:
import os
import uuid
from datetime import datetime
from typing import List, Dict, Optional
import re

import chromadb
from chromadb.utils import embedding_functions
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
from langchain.vectorstores import Chroma
from pymongo import MongoClient
from dotenv import load_dotenv
from sentence_transformers import CrossEncoder

# Load variables from a .env file into the environment; the classes below
# read MONGO_URL, MONGO_DB and USE_CUDA through os.getenv().
load_dotenv()

# New cross-encoder for re-ranking
# Created once at import time and used as a module-level global by
# EnhancedVectorStore.query to score (query, document) pairs.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

class MongoConnectionHandler:
    """Read orders together with their client and order-item documents.

    Connection settings are taken from the ``MONGO_URL`` and ``MONGO_DB``
    environment variables.
    """

    def __init__(self):
        self.client = MongoClient(os.getenv('MONGO_URL'))
        self.db = self.client.get_database(os.getenv('MONGO_DB'))

    def get_related_documents(self) -> List[Dict]:
        """Return every order joined with its client and its items.

        Returns:
            One dict per order with the keys ``"order"`` (the projected order
            document), ``"client"`` (the matching client document, or ``{}``
            when none is found) and ``"items"`` (the matching order-item
            documents; unknown item ids are skipped).
        """
        # Projection to reduce data transfer
        orders = list(self.db.orders.find({}, {
            'status': 1,
            'createdAt': 1,
            'orderContractType': 1,
            'orderCreatedBy': 1,
            'orderItems': 1
        }))

        # Orders may lack a creator or an item list (or store null there);
        # tolerate that instead of failing the whole export with a KeyError.
        client_ids = [
            order['orderCreatedBy'] for order in orders
            if order.get('orderCreatedBy') is not None
        ]
        item_ids = [
            item_id
            for order in orders
            for item_id in (order.get('orderItems') or [])
        ]

        # Batch processing with projection; documents are keyed by the string
        # form of their _id so lookups below do not depend on the id type.
        clients = {str(c['_id']): c for c in self.db.clients.find(
            {'_id': {'$in': client_ids}},
            {'legalNameOfCompany': 1, 'contactFirstName': 1,
             'contactLastName': 1, 'contactEmail': 1,
             'physicalAddressOfCompany': 1, 'preferredCoffeeTypes': 1}
        )}

        items = {str(i['_id']): i for i in self.db.order_items.find(
            {'_id': {'$in': item_ids}},
            {'r_id': 1, 'totalAmount': 1, 'price': 1, 'status': 1, 'updatedAt': 1}
        )}

        return [{
            "order": order,
            "client": clients.get(str(order.get('orderCreatedBy')), {}),
            "items": [
                items[str(item_id)]
                for item_id in (order.get('orderItems') or [])
                if str(item_id) in items
            ]
        } for order in orders]

class DocumentProcessor:
    """Convert joined order records into text chunks with metadata."""

    def create_chunks(self, data: List[Dict]) -> List[Dict]:
        """Build chunks for every record that has both a client and items.

        Args:
            data: Records with the keys ``"order"``, ``"client"`` and
                ``"items"``.

        Returns:
            A flat list of ``{"text": ..., "metadata": ...}`` dicts. Records
            with an empty client or no items contribute nothing.
        """
        chunks = []
        for rec in data:
            if rec["client"] and rec["items"]:
                chunks.extend(self._create_hierarchical_chunks(rec["order"], rec["client"], rec["items"]))
        return chunks

    def _create_hierarchical_chunks(self, order: Dict, client: Dict, items: List[Dict]) -> List[Dict]:
        """Return one order chunk, one client chunk and one chunk per item.

        Missing or null numeric fields count as 0; a missing or null address
        or coffee-preference list is treated as empty.
        """
        # Computed once and shared by the text and the metadata. Previously
        # the metadata recomputed it with an undefined loop variable, which
        # raised NameError on every call.
        total_value = sum(
            (item.get('price') or 0) * (item.get('totalAmount') or 0)
            for item in items
        )
        coffee_types = client.get('preferredCoffeeTypes') or []
        address = client.get('physicalAddressOfCompany') or {}
        street = address.get("address", "")
        city = address.get("city", "")
        contact_name = f"{client.get('contactFirstName', '')} {client.get('contactLastName', '')}"

        # Parent chunk for order overview
        order_chunk = {
            "text": f"""\
            ## Order Overview
            - ID: {order['_id']}
            - Status: {order.get('status', 'N/A')}
            - Created: {self._format_date(order.get('createdAt'))}
            - Contract: {order.get('orderContractType', 'N/A')}
            - Client ID: {client.get('_id', 'N/A')}
            - Total Value: ${total_value:.2f}
            """,
            "metadata": {
                "chunk_type": "order_summary",
                "order_id": str(order["_id"]),
                "client_id": str(client["_id"]),
                "status": order.get("status", "UNKNOWN"),
                "total_value": total_value,
                "creation_date": order.get("createdAt", ""),
                "coffee_types": coffee_types
            }
        }

        # Client chunk
        client_chunk = {
            "text": f"""\
            ## Client Profile
            - Company: {client.get('legalNameOfCompany', 'N/A')}
            - Contact: {contact_name}
            - Email: {client.get('contactEmail', 'N/A')}
            - Address: {street}, {city}
            - Preferred Coffee: {', '.join(map(str, coffee_types)) or 'None'}
            """,
            "metadata": {
                "chunk_type": "client_profile",
                "client_id": str(client["_id"]),
                "company": client.get('legalNameOfCompany', ''),
                "contact_name": contact_name,
                "coffee_preferences": coffee_types
            }
        }

        # Item chunks
        item_chunks = []
        for item in items:
            price = item.get('price') or 0
            amount = item.get('totalAmount') or 0
            item_chunks.append({
                "text": f"""\
                ## Order Item
                - ID: {item.get('r_id', 'N/A')}
                - Amount: {amount} lbs
                - Price: ${price:.2f}
                - Status: {item.get('status', 'N/A')}
                - Updated: {self._format_date(item.get('updatedAt'))}
                """,
                "metadata": {
                    "chunk_type": "order_item",
                    "item_id": str(item.get('_id', '')),
                    "order_id": str(order["_id"]),
                    "status": item.get("status", "UNKNOWN"),
                    "price": price,
                    "amount": amount
                }
            })

        return [order_chunk, client_chunk] + item_chunks

    def _format_date(self, date_str: str) -> str:
        """Format a timestamp as ``YYYY-MM-DD HH:MM``.

        Accepts an ISO-8601 string or a ``datetime`` instance. Returns
        ``"N/A"`` for empty values and ``"Invalid Date"`` for anything that
        cannot be parsed.
        """
        if not date_str:
            return "N/A"
        if isinstance(date_str, datetime):
            return date_str.strftime("%Y-%m-%d %H:%M")
        try:
            # fromisoformat() only understands a trailing "Z" from Python 3.11
            # on; normalise it to an explicit UTC offset.
            if date_str.endswith("Z"):
                date_str = date_str[:-1] + "+00:00"
            return datetime.fromisoformat(date_str).strftime("%Y-%m-%d %H:%M")
        except (AttributeError, TypeError, ValueError):
            return "Invalid Date"

class EnhancedVectorStore:
    """Persistent Chroma collection with cross-encoder re-ranking on query."""

    def __init__(self):
        # Better embedding model
        self.embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="BAAI/bge-base-en-v1.5",
            device="cuda" if os.getenv('USE_CUDA', 'false').lower() == 'true' else "cpu"
        )
        self.client = chromadb.PersistentClient(path="./chroma_dbs")
        self.collection = self.client.get_or_create_collection(
            name="coffee_orders_processing",
            embedding_function=self.embedder,
            metadata={"hnsw:space": "cosine"}  # Better similarity metric
        )

    def index_documents(self, chunks: List[Dict]):
        """Upsert chunks into the collection under stable ids.

        Ids are derived from each chunk's identity (see ``_chunk_id``), so
        indexing the same data again updates the stored entries instead of
        adding copies. Chunks sharing an id within one call are collapsed,
        the last one winning. Entries stored earlier under other ids are not
        touched.

        Args:
            chunks: Dicts with a ``"text"`` string and a ``"metadata"`` dict.
        """
        if not chunks:
            # Nothing to index; avoid sending an empty batch to the store.
            return

        unique = {}
        for chunk in chunks:
            document = self._clean_text(chunk['text'])
            metadata = self._sanitize_metadata(chunk['metadata'])
            unique[self._chunk_id(document, metadata)] = (document, metadata)

        self.collection.upsert(
            ids=list(unique),
            documents=[document for document, _ in unique.values()],
            metadatas=[metadata for _, metadata in unique.values()]
        )

    def query(self, query: str, filters: Optional[Dict] = None, n_results: int = 20) -> List[Dict]:
        """Retrieve candidates by embedding similarity, then re-rank them.

        Args:
            query: Natural-language query text.
            filters: Optional metadata ``where`` filter passed to the store.
            n_results: Maximum number of results to return.

        Returns:
            Up to ``n_results`` dicts with ``"text"``, ``"metadata"`` and
            ``"score"`` (cross-encoder score, higher is better), best first.
            An empty list when the first stage finds nothing.
        """
        # First-stage retrieval
        results = self.collection.query(
            query_texts=[query],
            where=filters,
            n_results=min(n_results*3, 100),  # Cast wider net
            include=["metadatas", "documents", "distances"]
        )
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        if not documents:
            return []

        # Re-rank with cross-encoder
        scores = cross_encoder.predict([(query, doc) for doc in documents])

        # Combine and sort
        ranked = sorted(zip(documents, metadatas, scores),
                        key=lambda x: x[2], reverse=True)

        return [{
            "text": doc,
            "metadata": meta,
            "score": float(score)  # plain float rather than a model-specific scalar
        } for doc, meta, score in ranked[:n_results]]

    def _clean_text(self, text: str) -> str:
        """Collapse every run of whitespace (newlines included) to one space."""
        # A second pass that joined single newlines was removed: no newline
        # survives this substitution, so it could never match.
        return re.sub(r'\s+', ' ', text).strip()

    @staticmethod
    def _sanitize_metadata(metadata: Dict) -> Dict:
        """Return a copy of ``metadata`` holding only scalar values.

        ``None`` entries are dropped, lists and tuples are joined with
        ``", "``, datetimes become ISO-8601 strings and any other non-scalar
        value is stringified.
        NOTE(review): assumes the store only accepts str/int/float/bool
        metadata values - confirm against the installed chromadb version.
        """
        clean = {}
        for key, value in metadata.items():
            if value is None:
                continue
            if isinstance(value, (str, int, float, bool)):
                clean[key] = value
            elif isinstance(value, datetime):
                clean[key] = value.isoformat()
            elif isinstance(value, (list, tuple)):
                clean[key] = ", ".join(str(v) for v in value)
            else:
                clean[key] = str(value)
        return clean

    @staticmethod
    def _chunk_id(document: str, metadata: Dict) -> str:
        """Derive a deterministic id for a chunk.

        The id is built from ``chunk_type`` plus every ``*_id`` metadata
        field; chunks without such fields fall back to their text.
        """
        identity = [f"{key}={metadata[key]}" for key in sorted(metadata) if key.endswith("_id")]
        parts = [str(metadata.get("chunk_type", ""))] + (identity or [document])
        return str(uuid.uuid5(uuid.NAMESPACE_OID, "|".join(parts)))

class CoffeeRAG:
    """Retrieval-augmented question answering over coffee order data."""

    def __init__(self):
        self.mongo = MongoConnectionHandler()
        self.processor = DocumentProcessor()
        self.vector_store = EnhancedVectorStore()
        self.llm = OllamaLLM(model="llama3")  # More capable model

        self.prompt = ChatPromptTemplate.from_template(
            """You are an expert coffee order analyst. Use the following context to answer the question.
            Today's Date: {current_date}
            
            Context:
            {context}
            
            Question: {question}
            
            Guidelines:
            1. Be precise with numbers and dates
            2. If asking about client info, verify across multiple orders
            3. For status inquiries, check both order and item statuses
            4. When unsure, ask for clarification
            5. Use markdown for structured responses
            6. Always reference source metadata IDs
            
            Format your answer as:
            ### Analysis Result
            [Your detailed answer here]
            
            ### Sources
            - Metadata IDs: {metadata_ids}
            
            If no relevant information is found, respond:
            "I couldn't find sufficient information to answer this question. Please contact support@coffee.example.com for assistance." """
        )

    def initialize(self):
        """Load and index data"""
        data = self.mongo.get_related_documents()
        chunks = self.processor.create_chunks(data)
        self.vector_store.index_documents(chunks)

    def retrieve(self, query: str) -> List[Dict]:
        """Return the top 8 re-ranked chunks for ``query``.

        A metadata filter is derived from the query text when it mentions a
        status or a client.
        """
        # Add query analysis for metadata filtering
        filters = self._analyze_query_for_filters(query)
        return self.vector_store.query(query, filters=filters, n_results=8)

    def generate(self, query: str) -> str:
        """Answer ``query`` with the LLM, grounded in the retrieved chunks.

        Returns:
            The model's answer, or a fixed fallback message when retrieval
            finds nothing.
        """
        context = self.retrieve(query)
        if not context:
            return "No relevant information found. Please contact support."

        context_str = "\n\n---\n\n".join(r['text'] for r in context)
        metadata_ids = ", ".join(self._source_ids(context))

        chain = self.prompt | self.llm | StrOutputParser()
        # Every prompt variable is already known here, so pass them directly
        # rather than wrapping constants in per-key callables.
        return chain.invoke({
            "context": context_str,
            "question": query,
            "current_date": datetime.now().strftime("%Y-%m-%d"),
            "metadata_ids": metadata_ids
        })

    @staticmethod
    def _source_ids(context: List[Dict]) -> List[str]:
        """Return the distinct source ids of ``context`` in first-seen order.

        Each chunk contributes its ``order_id``, falling back to its
        ``client_id``; chunks with neither are skipped. Keeping the input
        order makes the prompt reproducible, unlike iterating a set.
        """
        ordered = {}
        for result in context:
            metadata = result.get('metadata') or {}
            source = metadata.get('order_id') or metadata.get('client_id')
            if source:
                ordered.setdefault(str(source), None)
        return list(ordered)

    def _analyze_query_for_filters(self, query: str) -> Optional[Dict]:
        """Derive a metadata filter from the wording of ``query``.

        Returns:
            A status filter when the query contains the word "status"
            followed by completed/pending/cancelled, a chunk-type filter when
            it mentions client, company or contact, otherwise ``None``.
        """
        # Simple pattern matching for status queries
        status_match = re.search(r'\b(status)\b.*?\b(completed|pending|cancelled)\b', query, re.I)
        if status_match:
            return {"status": {"$eq": status_match.group(2).upper()}}

        # Detect client-related queries
        if any(word in query.lower() for word in ["client", "company", "contact"]):
            return {"chunk_type": {"$in": ["client_profile", "order_summary"]}}

        return None

if __name__ == "__main__":
    rag = CoffeeRAG()
    rag.initialize()  # Initial indexing

    # Interactive loop: answer queries until the user types 'exit' or closes
    # the input stream (Ctrl-D / Ctrl-C).
    while True:
        try:
            query = input("\nEnter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if query.lower() == 'exit':
            break
        if not query:
            # Nothing typed; prompt again rather than querying with no text.
            continue
        print("\nProcessing...\n")
        print(rag.generate(query))