In [4]:
import os
import re
import time
import uuid
import shutil
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from datetime import datetime, timedelta

import chromadb
from dotenv import load_dotenv
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma


class CarrierRAGSystem:
    """
    Complete RAG system for Carrier earnings analysis using LangChain and Ollama
    """
    
    def __init__(self, config: Optional[Dict] = None):
        """Initialize the RAG system with configuration"""
        load_dotenv()
        
        # Default configuration
        default_config = {
            "chroma_path": "chroma_20250806",
            "collection_name": "earnings_markdown",
            "data_path": "/Users/haritha/hari/carrier_earnings_release/markdown_output_20250806_094639",
            "ollama_host": "http://localhost:11434",
            "embed_model": "nomic-embed-text",
            "chat_model": "llama3",
            "chunk_size": 1000,
            "chunk_overlap": 200,
            "batch_size": 32,
            "clear_collection": True
        }
        
        # Merge with provided config
        self.config = {**default_config, **(config or {})}
        
        # Apply environment variable overrides
        for key, env_key in [
            ("chroma_path", "CHROMA_PATH"),
            ("collection_name", "COLLECTION_NAME"),
            ("data_path", "DATA_PATH"),
            ("ollama_host", "OLLAMA_HOST"),
            ("embed_model", "EMBED_MODEL"),
            ("chat_model", "CHAT_MODEL")
        ]:
            if env_key in os.environ:
                self.config[key] = os.environ[env_key]
        
        # Set up paths
        self.base_dir = Path.cwd()
        self.chroma_dir = str(self.base_dir / self.config["chroma_path"])
        
        # Initialize components
        self.embedding_function = None
        self.vectorstore = None
        self.llm = None
        
        # Regex patterns for cleaning
        self._base64_line = re.compile(r"^[A-Za-z0-9+/=]{80,}$")
        self._data_uri = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
        
        # Prompt template for financial analysis
        self.prompt_template = """You are an expert financial analyst specializing in Carrier Corporation earnings.

Answer from the provided context snippets. For forecasting questions, analyze historical patterns and trends from the available data.

Rules:
- Quote figures exactly as written in the source material
- After each figure or claim, include a bracketed citation like [S1] pointing to the snippet ID
- Be concise but accurate
- Do NOT guess or invent numbers not found in the context
- For forecasting questions: analyze historical trends, seasonal patterns, and provide a data-driven estimate with clear reasoning
- If insufficient data for forecasting, explain what data is missing
- Focus on financial metrics, trends, and business performance

Context Snippets:
{context}

---
Question: {question}

Answer:"""

    def log(self, msg: str) -> None:
        """Logging utility"""
        print(f"[CarrierRAG] {msg}")

    def clean_markdown(self, text: str) -> str:
        """Clean markdown text by removing base64 images and overly long lines"""
        cleaned_lines = []
        for line in text.splitlines():
            # Skip base64 images, data URIs, and very long lines
            if (self._data_uri.search(line) or 
                self._base64_line.match(line) or 
                len(line) > 1200):
                continue
            cleaned_lines.append(line)
        return "\n".join(cleaned_lines).strip()

    def ensure_fresh_chroma_dir(self) -> None:
        """Prepare a clean Chroma directory"""
        if os.path.exists(self.chroma_dir):
            shutil.rmtree(self.chroma_dir, ignore_errors=True)
        os.makedirs(self.chroma_dir, exist_ok=True)
        
        # Test write access
        test_file = Path(self.chroma_dir) / ".write_test"
        with open(test_file, "w", encoding="utf-8") as f:
            f.write("ok")
        test_file.unlink(missing_ok=True)
        self.log(f"Persistence directory ready: {self.chroma_dir}")

    def load_documents(self) -> List[Document]:
        """Load and clean markdown documents"""
        data_path = Path(self.config["data_path"])
        if not data_path.exists():
            raise FileNotFoundError(f"Data path not found: {data_path}")
        
        md_files = list(data_path.glob("*.md"))
        if not md_files:
            raise RuntimeError(f"No .md files found in: {data_path}")
        
        documents = []
        for file_path in sorted(md_files):
            try:
                loader = TextLoader(str(file_path), encoding="utf-8")
                docs = loader.load()
                for doc in docs:
                    doc.page_content = self.clean_markdown(doc.page_content)
                    documents.append(doc)
            except Exception as e:
                self.log(f"Warning: Failed to load {file_path}: {e}")
                continue
        
        self.log(f"Loaded {len(documents)} markdown files from {data_path}")
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Chunk ``documents`` with a ``RecursiveCharacterTextSplitter``.

        Chunk size and overlap come from the configuration; length is measured
        with ``len`` and the splitter is created with ``add_start_index=True``.
        A preview of one chunk is printed when at least one chunk is produced.
        """
        cfg = self.config
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=cfg["chunk_size"],
            chunk_overlap=cfg["chunk_overlap"],
            length_function=len,
            add_start_index=True,
        ).split_documents(documents)
        self.log(f"Split {len(documents)} documents into {len(chunks)} chunks")

        if not chunks:
            return chunks

        # Preview the sixth chunk when there is one, otherwise the last chunk.
        sample = chunks[5] if len(chunks) > 5 else chunks[-1]
        print(f"Preview chunk: {sample.page_content[:300]}...")
        print(f"Metadata: {sample.metadata}")
        return chunks

    def get_embedding_function(self):
        """Get the embedding function"""
        if self.embedding_function is None:
            self.embedding_function = OllamaEmbeddings(
                model=self.config["embed_model"],
                base_url=self.config["ollama_host"]
            )
        return self.embedding_function

    def build_vectorstore(self, force_rebuild: bool = False) -> None:
        """Build or load the vector store"""
        if not force_rebuild and os.path.exists(self.chroma_dir):
            self.log("Loading existing vector store...")
            try:
                self.vectorstore = self._load_existing_vectorstore()
                count = self.vectorstore._collection.count()
                self.log(f"Loaded existing vector store with {count} chunks")
                return
            except Exception as e:
                self.log(f"Failed to load existing store: {e}. Rebuilding...")
        
        # Build new vector store
        self.log("Building new vector store...")
        self.ensure_fresh_chroma_dir()
        
        # Load and process documents
        documents = self.load_documents()
        chunks = self.split_documents(documents)
        
        # Create vector store with batched embedding
        self._create_vectorstore_with_batching(chunks)

    def _load_existing_vectorstore(self) -> Chroma:
        """Open the persisted collection and wrap it in a LangChain ``Chroma``.

        A ``chromadb.PersistentClient`` is created on ``self.chroma_dir`` and
        handed to the wrapper together with the configured collection name and
        the shared embedding function.
        """
        persistent_client = chromadb.PersistentClient(path=self.chroma_dir)
        wrapper_kwargs = {
            "client": persistent_client,
            "collection_name": self.config["collection_name"],
            "embedding_function": self.get_embedding_function(),
        }
        return Chroma(**wrapper_kwargs)

    def _create_vectorstore_with_batching(self, chunks: List[Document]) -> None:
        """Create vector store with batched embedding for efficiency"""
        texts = [chunk.page_content for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        ids = [str(uuid.uuid4()) for _ in texts]
        
        total = len(texts)
        batch_size = self.config["batch_size"]
        total_batches = (total + batch_size - 1) // batch_size
        
        # Get embedding function
        embedding_fn = self.get_embedding_function()
        
        # Batch embedding with progress tracking
        self.log(f"Embedding {total} chunks in {total_batches} batches of {batch_size}...")
        embeddings = []
        start_time = datetime.now()
        
        for batch_idx, i in enumerate(range(0, total, batch_size), start=1):
            batch_texts = texts[i:i + batch_size]
            t0 = time.time()
            
            try:
                batch_embeddings = embedding_fn.embed_documents(batch_texts)
                embeddings.extend(batch_embeddings)
            except Exception as e:
                self.log(f"Error in batch {batch_idx}: {e}")
                raise
            
            # Progress tracking
            dt = time.time() - t0
            elapsed = datetime.now() - start_time
            avg_time = elapsed / batch_idx
            eta = avg_time * (total_batches - batch_idx)
            
            print(f"[{batch_idx:>3}/{total_batches}] +{len(batch_texts):>3} in {dt:.2f}s | "
                  f"done {i + len(batch_texts):>5}/{total} | "
                  f"ETA {str(timedelta(seconds=int(eta.total_seconds())))}")
        
        # Create Chroma collection
        client = chromadb.PersistentClient(path=self.chroma_dir)
        collection = client.get_or_create_collection(
            name=self.config["collection_name"],
            metadata={"hnsw:space": "cosine"}
        )
        
        # Clear if requested
        if self.config["clear_collection"]:
            try:
                collection.delete(where={})
                self.log("Cleared existing collection contents")
            except Exception as e:
                self.log(f"Could not clear collection: {e}")
        
        # Batch upsert
        for i in range(0, total, batch_size):
            end_idx = min(i + batch_size, total)
            collection.upsert(
                ids=ids[i:end_idx],
                documents=texts[i:end_idx],
                metadatas=metadatas[i:end_idx],
                embeddings=embeddings[i:end_idx]
            )
        
        # Create LangChain wrapper
        self.vectorstore = Chroma(
            client=client,
            collection_name=self.config["collection_name"],
            embedding_function=embedding_fn
        )
        
        final_count = collection.count()
        self.log(f"✅ Created vector store with {final_count} chunks")

    def get_llm(self):
        """Get the chat LLM"""
        if self.llm is None:
            self.llm = ChatOllama(
                model=self.config["chat_model"],
                base_url=self.config["ollama_host"],
                temperature=0,  # Deterministic for financial analysis
                num_ctx=4096    # Context window
            )
        return self.llm

    def _doc_key(self, doc) -> Tuple[str, int]:
        """Create unique key for document deduplication"""
        source = doc.metadata.get("source", "")
        start_index = doc.metadata.get("start_index", -1)
        return (source, start_index)

    def _filename(self, doc) -> str:
        """Extract filename from document metadata"""
        source = doc.metadata.get("source", "")
        return Path(source).name if source else "unknown.md"

    def _build_context_and_sources(self, docs: List) -> Tuple[str, Dict[str, str]]:
        """Build numbered context with source mapping"""
        context_lines = []
        sources_map = {}
        
        for i, doc in enumerate(docs, start=1):
            source_id = f"S{i}"
            snippet = doc.page_content.strip()
            # Trim very long snippets to reduce noise
            if len(snippet) > 1200:
                snippet = snippet[:1200] + "..."
            
            filename = self._filename(doc)
            sources_map[source_id] = filename
            context_lines.append(f"[{source_id}] {filename}\n{snippet}\n")
        
        return "\n".join(context_lines), sources_map

    def expand_query_for_forecasting(self, query: str) -> List[str]:
        """Generate additional search queries for forecasting questions"""
        base_queries = [query]
        
        # Check if this is a forecasting question
        forecast_indicators = ['forecast', 'predict', 'estimate', 'expect', 'outlook', 'guidance', 'future', 'next', 'upcoming']
        is_forecast = any(indicator in query.lower() for indicator in forecast_indicators)
        
        if is_forecast:
            # Extract key terms
            quarters = ['Q1', 'Q2', 'Q3', 'Q4', 'first quarter', 'second quarter', 'third quarter', 'fourth quarter']
            metrics = ['EPS', 'earnings', 'revenue', 'income', 'profit', 'margin', 'growth']
            
            # Add historical queries
            for quarter in quarters:
                if quarter.lower() in query.lower():
                    base_queries.extend([
                        f"{quarter} earnings historical performance",
                        f"{quarter} EPS previous years",
                        f"{quarter} financial results trends"
                    ])
                    break
            
            for metric in metrics:
                if metric.lower() in query.lower():
                    base_queries.extend([
                        f"{metric} quarterly trends",
                        f"historical {metric} performance",
                        f"{metric} year over year"
                    ])
                    break
        
        return base_queries

    def retrieve_context(
        self, 
        query: str, 
        k_primary: int = 12, 
        k_mmr: int = 8, 
        fetch_k: int = 30, 
        min_score: float = 0.4
    ) -> Tuple[str, Dict[str, str], List]:
        """Enhanced retrieval with query expansion for forecasting"""
        if not self.vectorstore:
            raise RuntimeError("Vector store not initialized. Call build_vectorstore() first.")
        
        # Expand query for better forecasting retrieval
        search_queries = self.expand_query_for_forecasting(query)
        all_docs = []
        seen_keys = set()
        
        # Search with multiple queries
        for search_query in search_queries:
            # Primary retrieval with relevance scores
            try:
                primary_results = self.vectorstore.similarity_search_with_relevance_scores(
                    search_query, k=k_primary
                )
                
                # Add primary results (with score filtering)
                for doc, score in primary_results:
                    if score is None or score >= min_score:
                        key = self._doc_key(doc)
                        if key not in seen_keys:
                            seen_keys.add(key)
                            all_docs.append(doc)
                
                # Diversity retrieval with MMR
                mmr_docs = self.vectorstore.max_marginal_relevance_search(
                    search_query, k=k_mmr, fetch_k=fetch_k
                )
                
                # Add MMR results (for diversity)
                for doc in mmr_docs:
                    key = self._doc_key(doc)
                    if key not in seen_keys:
                        seen_keys.add(key)
                        all_docs.append(doc)
                        
            except Exception as e:
                self.log(f"Warning: Error in retrieval for query '{search_query}': {e}")
                continue
        
        if not all_docs:
            return "", {}, []
        
        # Limit total documents to avoid token overflow
        max_docs = 15
        if len(all_docs) > max_docs:
            all_docs = all_docs[:max_docs]
        
        context_text, sources_map = self._build_context_and_sources(all_docs)
        return context_text, sources_map, all_docs

    def answer_query(self, question: str, **retrieval_kwargs) -> Dict[str, any]:
        """Answer a query using the RAG system"""
        # Retrieve context
        context, sources_map, docs = self.retrieve_context(question, **retrieval_kwargs)
        
        if not context:
            return {
                "answer": "No relevant information found. Try rephrasing your question or adjusting search parameters.",
                "sources": {},
                "context_docs": 0
            }
        
        # Create prompt and get LLM response
        prompt = ChatPromptTemplate.from_template(self.prompt_template)
        messages = prompt.format_messages(context=context, question=question)
        
        llm = self.get_llm()
        response = llm.invoke(messages)
        answer = getattr(response, "content", str(response))
        
        return {
            "answer": answer,
            "sources": sources_map,
            "context_docs": len(docs),
            "question": question
        }

    def interactive_query(self):
        """Interactive query loop"""
        print("\n🔍 Carrier RAG System - Interactive Mode")
        print("Type 'quit' to exit, 'rebuild' to rebuild the index")
        print("-" * 50)
        
        while True:
            try:
                question = input("\n📊 Ask about Carrier's earnings: ").strip()
                
                if question.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye! 👋")
                    break
                
                if question.lower() == 'rebuild':
                    print("Rebuilding vector store...")
                    self.build_vectorstore(force_rebuild=True)
                    continue
                
                if not question:
                    continue
                
                # Get answer
                result = self.answer_query(question)
                
                # Display results
                print(f"\n--- Answer ---")
                print(result["answer"])
                
                print(f"\n--- Sources ({result['context_docs']} documents) ---")
                for source_id, filename in result["sources"].items():
                    print(f"{source_id}: {filename}")
                
            except KeyboardInterrupt:
                print("\n\nGoodbye! 👋")
                break
            except Exception as e:
                print(f"Error: {e}")


# Convenience functions for backward compatibility
def build_database(config: Optional[Dict] = None):
    """Create a ``CarrierRAGSystem``, force a full index rebuild, and return it."""
    system = CarrierRAGSystem(config)
    system.build_vectorstore(force_rebuild=True)
    return system

def query_database(question: str, config: Optional[Dict] = None, **kwargs):
    """Answer ``question`` with a freshly constructed ``CarrierRAGSystem``.

    The vector store is obtained with ``build_vectorstore(force_rebuild=False)``
    and ``kwargs`` are forwarded to ``answer_query``.
    """
    system = CarrierRAGSystem(config)
    system.build_vectorstore(force_rebuild=False)
    result = system.answer_query(question, **kwargs)
    return result


# Example usage and testing
if __name__ == "__main__":
    # Demo run: build (or load) the index, then answer a few sample questions.
    # The names assigned here (custom_config, rag_system, test_questions, ...)
    # are module-level, so they stay available after the script section runs.
    print("🚀 Initializing Carrier RAG System...")
    
    # Optional overrides; these particular values equal the class defaults.
    custom_config = {
        "chunk_size": 1000,
        "chunk_overlap": 200,
        "chat_model": "llama3",  # or "mistral", "codellama", etc.
        "embed_model": "nomic-embed-text"  # or "mxbai-embed-large"
    }
    
    # Initialize system
    rag_system = CarrierRAGSystem(custom_config)
    
    # Build or load vector store (force_rebuild defaults to False)
    rag_system.build_vectorstore()
    
    # Sample questions: factual lookups and forecasting prompts.
    
    test_questions = [
        "What was Carrier's GAAP EPS in Q1 2024?",
        "Using historical data, forecast Carrier's operating margin for Q4 2024",
        """Using historical performance data from all available past quarters, forecast Carrier's GAAP EPS for Q3 2025. Provide:
        1. A point estimate and 80% confidence interval
        2. Your analytical methodology and reasoning
        3. Specific past quarters that most influence your projection
        4. Key assumptions and risk factors that could affect accuracy
        5. Seasonal patterns or trends identified in the historical data
        6. A comparison to any available analyst consensus, if found""",
        "What is a reasonable forecast for Carrier's GAAP EPS in Q2 2025 based on Q2 performance in previous years?",
        "Using historical performance data from all available past quarters, forecast Carrier's GAAP EPS for Q3 2025. Provide your reasoning and reference specific past quarters that influence the projection.",
        "What are the main business segments of Carrier?",
        "Show me Carrier's Q2 EPS results for the last 3-4 years",
        "What was Carrier's Q2 2024 GAAP EPS and adjusted EPS?"
    ]
    
    print("\n📊 Testing with sample questions...")
    # Only the first three questions are run; the rest are kept as a backlog.
    for i, question in enumerate(test_questions[:3], 1):  # Test first 3 questions
        print(f"\n[{i}] {question}")
        result = rag_system.answer_query(question)
        print(f"Answer: {result['answer']}")
        # Source file names in snippet order (S1, S2, ...); repeats are possible
        # because several snippets can come from the same file.
        print(f"Sources: {list(result['sources'].values())}")
    
    # Uncomment the next line for interactive mode
    # rag_system.interactive_query()

🚀 Initializing Carrier RAG System...
[CarrierRAG] Loading existing vector store...
[CarrierRAG] Loaded existing vector store with 1370 chunks

📊 Testing with sample questions...

[1] What was Carrier's GAAP EPS in Q1 2024?
Answer: According to [S1] Q1_2024.md, Carrier's GAAP EPS in Q1 2024 was $0.29.
Sources: ['Q1_2024.md', 'Q4_2021.md', 'Q1_2024.md', 'Q4_2022.md', 'Q2_2024.md', 'Q1_2025.md', 'Q2_2024.md', 'Q4_2023.md', 'Q2_2025.md', 'Q2_2020.md', 'Q3_2023.md', 'Q3_2023.md', 'Q1_2020.md', 'Q4_2022.md', 'Q2_2024.md']

[2] Using historical data, forecast Carrier's operating margin for Q4 2024
Answer: Based on the provided context snippets, I will analyze the historical trends and patterns to forecast Carrier's operating margin for Q4 2024.

From [S1] Q4_2023.md, we can see that Carrier's full-year 2023 sales were $22.1B, with organic sales growth of 3% and a 5% impact from acquisitions and divestitures. Gross margins increased 210 basis points compared to the prior year.

From [S11] Q4_2