In [None]:
import os
import pandas as pd
import pickle
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
import warnings
from pathlib import Path
from pprint import pprint

def safe_load_pickle(file_path: str) -> Any:
    """Load a pickle file, retrying with pandas when plain unpickling fails.

    The file is first read with ``pickle.load``. If that raises ``TypeError``,
    ``AttributeError``, ``ImportError`` or ``pickle.UnpicklingError`` the load
    is retried with ``pd.read_pickle``. Warnings raised during either attempt
    are suppressed.

    SECURITY: unpickling can execute arbitrary code stored in the file. Only
    call this on files that come from a trusted source.

    Args:
        file_path: Location of the pickle file.

    Returns:
        The unpickled object.

    Raises:
        ValueError: If the pandas fallback fails as well. The underlying
            exception is chained as ``__cause__``.
        OSError: If the file cannot be opened (not intercepted here).
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # First try standard pickle load
            with open(file_path, "rb") as f:
                return pickle.load(f)
        except (TypeError, AttributeError, ImportError, pickle.UnpicklingError):
            # ImportError (incl. ModuleNotFoundError) is typical for pickles
            # that reference modules which have since moved; let pandas try.
            pass

        try:
            # Fallback to pandas read_pickle
            return pd.read_pickle(file_path)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Failed to load {file_path}: {str(e)}") from e

def inspect_file(file_path: str) -> None:
    """Print a short structural summary of an Excel or pickle file.

    Intended as a debugging aid: nothing is returned and any exception raised
    while inspecting is caught and reported as ``Inspection failed: ...``.

    Args:
        file_path: Path to a ``.xlsx``/``.xls`` or ``.pkl`` file. Both ``str``
            and ``os.PathLike`` values are accepted and the extension is
            matched case-insensitively.
    """
    print(f"\n=== Inspecting {file_path} ===")
    try:
        # os.fspath handles str and pathlib.Path alike; lower-casing makes the
        # extension check independent of how the file name is capitalised.
        lowered = os.fspath(file_path).lower()
        if lowered.endswith(('.xlsx', '.xls')):
            df = pd.read_excel(file_path)
            print("Excel file detected")
            _print_frame_summary(df)
        elif lowered.endswith('.pkl'):
            data = safe_load_pickle(file_path)
            print(f"Type: {type(data)}")
            if isinstance(data, pd.DataFrame):
                print("Pandas DataFrame detected")
                _print_frame_summary(data)
            elif isinstance(data, list):
                print(f"List of {len(data)} items")
                if data:
                    print("\nFirst item type:", type(data[0]))
                    if isinstance(data[0], dict):
                        print("Keys in first item:", data[0].keys())
            elif isinstance(data, dict):
                print("Dictionary with keys:", data.keys())
            else:
                print("Content sample:", str(data)[:200] + "...")
        else:
            raise ValueError("Unsupported file format")
    except Exception as e:
        # Best-effort debug helper: report the problem instead of raising.
        print(f"Inspection failed: {str(e)}")


def _print_frame_summary(frame) -> None:
    """Print shape, column names and, when present, the first row of a DataFrame."""
    print(f"Shape: {frame.shape}")
    print("Columns:", frame.columns.tolist())
    if len(frame) == 0:
        # iloc[0] raises IndexError on a frame without rows.
        print("\n(no rows)")
        return
    print("\nFirst row:")
    pprint(frame.iloc[0].to_dict())

def load_buffet_qna_xlsx(file_path: str) -> List[Document]:
    """Turn each row of the Q&A workbook into one Document.

    The sheet must provide the columns ``Section``, ``Questions`` and
    ``Answers``.

    Args:
        file_path: Excel file to read with ``pd.read_excel``.

    Returns:
        One Document per row. ``page_content`` is a three-line
        question/answer/section text; ``metadata`` carries the same three
        values plus ``source="buffet_qna"``.

    Raises:
        ValueError: If any of the required columns is absent.
    """
    frame = pd.read_excel(file_path)

    # Fail early when the sheet does not have the expected layout.
    absent = {'Section', 'Questions', 'Answers'} - set(frame.columns)
    if absent:
        raise ValueError(f"Missing columns in Q&A file: {absent}")

    def to_document(record) -> Document:
        section = record["Section"]
        question = record["Questions"]
        answer = record["Answers"]
        return Document(
            # Human-readable text that gets embedded / searched.
            page_content=(
                f"Question: {question}\nAnswer: {answer}\nSection: {section}"
            ),
            # Structured copy of the same fields for filtering.
            metadata={
                "section": section,
                "question": question,
                "answer": answer,
                "source": "buffet_qna",
            },
        )

    return [to_document(record) for _, record in frame.iterrows()]

def load_brka_trades_xlsx(file_path: str) -> List[Document]:
    """Turn each row of the trades workbook into one Document.

    The sheet must provide the columns ``RIC``, ``Security Name``, ``Date``,
    ``Position`` and ``Position Change``.

    Args:
        file_path: Excel file to read with ``pd.read_excel``.

    Returns:
        One Document per row. ``page_content`` is a four-line summary of the
        trade; ``metadata`` is the full row as a dict plus
        ``source="brka_trades"``.

    Raises:
        ValueError: If any of the required columns is absent.
    """
    frame = pd.read_excel(file_path)

    expected = {'RIC', 'Security Name', 'Date', 'Position', 'Position Change'}
    absent = expected - set(frame.columns)
    if absent:
        raise ValueError(f"Missing columns in trades file: {absent}")

    trade_docs = []
    for _, trade in frame.iterrows():
        # ``:,`` adds thousands separators; ``:+,`` additionally forces a sign.
        summary_lines = [
            f"Security: {trade['Security Name']} ({trade['RIC']})",
            f"Date: {trade['Date']}",
            f"Position: {trade['Position']:,} shares",
            f"Change: {trade['Position Change']:+,}",
        ]
        # Keep every raw column; "source" is set last so it wins on collision.
        raw_fields = {**trade.to_dict(), "source": "brka_trades"}
        trade_docs.append(
            Document(page_content="\n".join(summary_lines), metadata=raw_fields)
        )

    return trade_docs

def load_news_pickle(file_path: str) -> List[Document]:
    """Read a news pickle and convert whatever it holds into Documents.

    Supported payloads:
    - a DataFrame (converted to a list of row dicts first),
    - a list consisting only of dicts (one Document per dict),
    - a single dict (one Document),
    - anything else (one Document holding ``str(payload)``).

    For dict records the text comes from the ``text`` key, else ``content``,
    else the stringified record; all remaining keys become metadata.
    """
    payload = safe_load_pickle(file_path)

    def as_document(record) -> Document:
        body = record.get('text', record.get('content', str(record)))
        extras = {
            key: value
            for key, value in record.items()
            if key not in ['text', 'content']
        }
        return Document(page_content=body, metadata=extras)

    # DataFrames are handled through the list-of-dicts path below.
    if isinstance(payload, pd.DataFrame):
        payload = payload.to_dict('records')

    if isinstance(payload, dict):
        return [as_document(payload)]

    if isinstance(payload, list) and all(isinstance(entry, dict) for entry in payload):
        return [as_document(entry) for entry in payload]

    # Unknown structure: keep it as a single stringified document.
    return [Document(page_content=str(payload))]

def load_shareholder_letters_pickle(file_path: str) -> List[Document]:
    """Load shareholder letters stored as a ``{year: letter_text}`` mapping.

    Args:
        file_path: Pickle file to read via ``safe_load_pickle``.

    Returns:
        One Document per mapping entry, with the key stored as
        ``metadata["year"]`` and ``source="shareholder_letter"``. When the
        payload is not a dict of str/int keys to string values, the result of
        ``load_news_pickle(file_path)`` is returned instead.
    """
    data = safe_load_pickle(file_path)

    # Only treat the payload as year -> text when the values really are text;
    # Document.page_content must be a string, so a dict holding lists/dicts
    # is routed to the generic loader instead of failing during construction.
    is_year_mapping = (
        isinstance(data, dict)
        and all(isinstance(year, (str, int)) for year in data)
        and all(isinstance(text, str) for text in data.values())
    )
    if is_year_mapping:
        return [
            Document(
                page_content=content,
                metadata={"year": year, "source": "shareholder_letter"}
            )
            for year, content in data.items()
        ]

    # Fallback for other formats (note: this re-reads the file from disk,
    # because load_news_pickle only accepts a path).
    return load_news_pickle(file_path)

def load_documents(file_path: str) -> List[Document]:
    """Load one supported data file, choosing the loader from its path.

    Dispatch is based on the lower-cased path: the extension selects the file
    family (``.xlsx`` or ``.pkl``) and a keyword anywhere in the path
    (``buffet_qna``, ``brka_trades``, ``news``, ``shareholder``) selects the
    concrete loader.

    Args:
        file_path: Path of the file to load; ``str`` and ``os.PathLike`` are
            both accepted.

    Returns:
        The Documents produced by the matching loader. This function does not
        raise: on any failure (including an unrecognised path) it prints the
        error and returns a single Document describing it, with
        ``metadata["source"] == "error"``.
    """
    try:
        # Normalise once so Path objects work and "DATA.XLSX" is recognised
        # the same way the keyword checks already ignore case.
        lowered = os.fspath(file_path).lower()

        if lowered.endswith('.xlsx'):
            if 'buffet_qna' in lowered:
                return load_buffet_qna_xlsx(file_path)
            if 'brka_trades' in lowered:
                return load_brka_trades_xlsx(file_path)

        elif lowered.endswith('.pkl'):
            if 'news' in lowered:
                return load_news_pickle(file_path)
            if 'shareholder' in lowered:
                return load_shareholder_letters_pickle(file_path)

        raise ValueError(f"Unrecognized file type: {file_path}")

    except Exception as e:
        # Deliberate best-effort: surface the failure as a document so a
        # batch of loads is not aborted by one bad file.
        print(f"Error loading {file_path}: {str(e)}")
        return [Document(
            page_content=f"Error loading document: {str(e)}",
            # str() keeps the metadata a plain string, as load_all_documents does.
            metadata={"source": "error", "file_path": str(file_path)}
        )]

def load_all_documents(
    base_path: str = "/data",
    buffet_qna_path: Optional[str] = None,
    brka_trades_path: Optional[str] = None,
    news_path: Optional[str] = None,
    shareholder_letters_path: Optional[str] = None,
    verbose: bool = True
) -> List[Document]:
    """
    Load the four standard sources and concatenate their Documents.

    Args:
        base_path: Directory searched for the default file names
            (buffet_qna.xlsx, brka_trades.xlsx, news.pkl,
            shareholder_letters.pkl) when no explicit path is given.
        buffet_qna_path, brka_trades_path, news_path,
        shareholder_letters_path: Explicit locations that replace the
            corresponding default.
        verbose: When true, print progress, errors and the final count.

    Returns:
        Documents of all sources in the order listed above. A source that
        fails to load contributes one error Document
        (``metadata["source"] == "error"``) instead of raising.
    """
    def resolve(override, default_name):
        # The default is only built when no override was supplied.
        return override or Path(base_path) / default_name

    sources = [
        (resolve(buffet_qna_path, "buffet_qna.xlsx"), load_buffet_qna_xlsx),
        (resolve(brka_trades_path, "brka_trades.xlsx"), load_brka_trades_xlsx),
        (resolve(news_path, "news.pkl"), load_news_pickle),
        (
            resolve(shareholder_letters_path, "shareholder_letters.pkl"),
            load_shareholder_letters_pickle,
        ),
    ]

    collected = []
    for source_path, load in sources:
        try:
            if verbose:
                print(f"Loading {source_path}...")
            loaded = load(source_path)
            collected.extend(loaded)
            if verbose:
                print(f"Loaded {len(loaded)} documents")
        except Exception as exc:
            # One broken source must not prevent the others from loading.
            if verbose:
                print(f"Error loading {source_path}: {str(exc)}")
            collected.append(Document(
                page_content=f"Error loading {source_path}: {str(exc)}",
                metadata={"source": "error", "file_path": str(source_path)}
            ))

    if verbose:
        print(f"\nTotal documents loaded: {len(collected)}")

    return collected

# print(f"Script running from: {os.getcwd()}")

# print(load_documents("../data/news.pkl"))
# print(load_all_documents("../data"))

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter


# Splitter settings used by split_documents(): CHUNK_SIZE is passed as
# `chunk_size` and CHUNK_OVERLAP as `chunk_overlap` to
# RecursiveCharacterTextSplitter.
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 20

def split_documents(documents, chunk_size=None, chunk_overlap=None):
    """
    Split LangChain documents into smaller chunk documents.

    Args:
        documents: LangChain ``Document`` objects (for example the output of
            the loaders in this file). Plain dicts are not supported; the
            input is handed unchanged to
            ``RecursiveCharacterTextSplitter.split_documents``.
        chunk_size: Maximum chunk size; defaults to ``CHUNK_SIZE``.
        chunk_overlap: Overlap between consecutive chunks; defaults to
            ``CHUNK_OVERLAP``.

    Returns:
        The chunk documents produced by the splitter.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE if chunk_size is None else chunk_size,
        chunk_overlap=CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap,
    )
    return splitter.split_documents(documents)


In [None]:
import numpy as np
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain_openai import OpenAIEmbeddings
from langchain.retrievers import EnsembleRetriever

# Compare 2 different retrieval methods and evaluate

def build_bm25(corpus):
    """
    Builds a BM25 model from a corpus of texts.

    Args:
        corpus: Iterable of documents. Each entry may be a raw string, which
            is split on whitespace here, or an already tokenized sequence of
            terms, which is used as-is.

    Returns:
        A ``BM25Okapi`` model built over the tokenized corpus. Queries scored
        against it should be tokenized the same way (``query.split()`` for
        raw-string corpora).
    """
    # BM25Okapi iterates over each corpus entry to count terms; a raw string
    # would therefore be indexed character by character, not word by word.
    tokenized_corpus = [
        entry.split() if isinstance(entry, str) else entry
        for entry in corpus
    ]
    return BM25Okapi(tokenized_corpus)

def build_faiss_index(documents):
    """
    Embed ``documents`` with ``OpenAIEmbeddings`` and index them in FAISS.

    Args:
        documents: Documents handed to ``FAISS.from_documents``.

    Returns:
        The FAISS vector store created from the documents.
    """
    # A default-configured embedding client is created for every call.
    return FAISS.from_documents(documents, OpenAIEmbeddings())

def create_ensemble_retriever(bm25_retriever, faiss_retriever):
    """
    Wrap a BM25 retriever and a FAISS retriever in one ``EnsembleRetriever``.

    The retrievers are passed in the order (BM25, FAISS); no other
    ``EnsembleRetriever`` options are set here.
    """
    members = [bm25_retriever, faiss_retriever]
    ensemble = EnsembleRetriever(retrievers=members)
    return ensemble


In [None]:
import numpy as np
from sentence_transformers import CrossEncoder

# Initialize the pre-trained cross-encoder (adjust model name as needed).
# Created once at module/cell execution time and shared by every
# rerank_documents() call.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank_documents(query, documents, top_k=5):
    """
    Reranks documents by cross-encoder relevance to the query.

    Args:
        query: Query text paired with every document.
        documents: Sequence of documents. Each item is either an object with
            a ``page_content`` attribute (LangChain ``Document``, as returned
            by the retrievers) or a dict with a ``'content'`` key.
        top_k: Maximum number of documents to return; ``None`` returns all.

    Returns:
        Up to ``top_k`` of the input documents, highest score first. An empty
        list is returned for empty input or a non-positive ``top_k``.
    """
    # Nothing to score; also avoids calling the model with an empty batch and
    # the surprising slice a negative top_k would otherwise produce.
    if not documents or (top_k is not None and top_k <= 0):
        return []

    pairs = [
        [query, doc.page_content if hasattr(doc, "page_content") else doc["content"]]
        for doc in documents
    ]
    scores = cross_encoder.predict(pairs)
    # argsort is ascending, so reverse it to put the best score first.
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    return [documents[i] for i in ranked_indices]


In [None]:
def create_prompt(company_name: str, documents):
    """
    Creates the report-generation prompt from the retrieved documents.

    Args:
        company_name: Name inserted into the final "Investment Report for"
            line.
        documents: Iterable of context documents. Each item is either an
            object with a ``page_content`` attribute (LangChain ``Document``)
            or a dict with a ``'content'`` key.

    Returns:
        The prompt text, with the document texts joined by newlines as
        context.
    """
    # Retrievers return Document objects, which are not subscriptable, so
    # read page_content when available and fall back to the dict form.
    context = "\n".join(
        doc.page_content if hasattr(doc, "page_content") else doc["content"]
        for doc in documents
    )
    prompt = f"""
You are a financial analyst tasked with generating an investment report. Use the following context:
{context}

Focus on revenue growth, profitability, and initiatives.
Provide specific numbers and facts where available.
Use a professional, concise tone.
Structure the report with these sections:
- Overview
- Revenue Growth
- Profitability
- ESG Initiatives
- Conclusion

Investment Report for {company_name}:
"""
    return prompt

def create_evaluation_prompt(query: str, generated_report: str, retrieved_documents):
    """
    Creates the prompt that asks an LLM to grade a generated report.

    Args:
        query: The user query the report was written for.
        generated_report: The report text to evaluate.
        retrieved_documents: Iterable of context documents. Each item is
            either an object with a ``page_content`` attribute (LangChain
            ``Document``) or a dict with a ``'content'`` key.

    Returns:
        The evaluation prompt containing the query, the report, the joined
        context and the three scoring criteria.
    """
    # Retrievers return Document objects, which are not subscriptable, so
    # read page_content when available and fall back to the dict form.
    retrieved_context = "\n".join(
        doc.page_content if hasattr(doc, "page_content") else doc["content"]
        for doc in retrieved_documents
    )
    prompt = f"""
You are an expert financial analyst evaluating an investment report.
    
Query:
{query}

Generated Report:
{generated_report}

Retrieved Context:
{retrieved_context}

Evaluation Criteria:
1. Relevance (0-10): Does the report address the query?
2. Accuracy (0-10): Are the facts correct and supported by the retrieved context?
3. Coherence (0-10): Is the report well-structured and easy to understand?

Provide a score for each criterion along with a brief explanation.
"""
    return prompt


In [None]:
import os
from dotenv import load_dotenv
import openai 
from rag.text_splitter import split_documents
from rag.retrieval import build_bm25, build_faiss_index, create_ensemble_retriever
from rag.reranker import rerank_documents
from rag.prompt_engineering import create_prompt, create_evaluation_prompt

from langchain_community.retrievers import BM25Retriever

# load env variables from .env
load_dotenv()

# set OpenAI API key
# Fail fast: this runs when the module/cell is executed and raises if
# OPENAI_API_KEY is missing or empty.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("API key is not set")

# call custom LLM
def call_custom_llm(prompt: str, context: str = "") -> str:
    """
    Stand-in for the real LLM call.

    TODO(@ethan, @yucai): replace the body with the actual model/API call.

    Args:
        prompt: Prompt text for the model.
        context: Optional text placed on the lines before the prompt.

    Returns:
        A fixed header followed by the context (when non-empty) and the
        prompt; no model is invoked.
    """
    if context:
        llm_input = "\n".join((context, prompt))
    else:
        llm_input = prompt
    # Echo the input so the rest of the pipeline can be exercised end to end.
    return "".join(("Custom LLM Response based on:\n", llm_input))

class RAGPipeline:
    """Retrieval-augmented report generation over a fixed document set.

    The pipeline chunks the input documents, builds a BM25 retriever, a FAISS
    retriever and an ensemble of both, and keeps a running text history of
    queries and generated reports that is fed back into later generations.
    """

    def __init__(self, raw_documents, k: int = 5):
        """
        Initializes the pipeline:
         - Splits raw documents into text chunks.
         - Builds BM25 and FAISS retrievers.
         - Combines them into an ensemble retriever.
         - Initializes a context memory for feedback loop.

        Args:
            raw_documents: List of LangChain ``Document`` objects (for
                example the output of the loaders), not plain dicts.
            k: Number of documents each base retriever returns per query.
        """
        # Split documents into chunks.
        self.documents = split_documents(raw_documents)

        # Build BM25 retriever. BM25Retriever builds its own index from the
        # chunks, so no separate BM25 model is constructed here.
        self.bm25_retriever = BM25Retriever.from_documents(self.documents)
        self.bm25_retriever.k = k

        # Build FAISS retriever.
        faiss_vectorstore = build_faiss_index(self.documents)
        self.faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})

        # Create an ensemble retriever.
        self.ensemble_retriever = create_ensemble_retriever(self.bm25_retriever, self.faiss_retriever)

        # Initialize context memory for feedback loop (stores conversation history).
        self.context_memory = []

    def retrieve_and_rerank(self, query: str, retriever, top_k: int = 5):
        """
        Retrieves documents using the provided retriever and then reranks them.

        Args:
            query: Search query.
            retriever: Retriever exposing ``invoke(query)`` or the older
                ``get_relevant_documents(query)``.
            top_k: Number of documents kept after reranking.

        Returns:
            The reranked documents as returned by ``rerank_documents``.
        """
        # Prefer the Runnable-style entry point; fall back to the legacy
        # method for retrievers that do not provide it.
        fetch = getattr(retriever, "invoke", None)
        if fetch is None:
            fetch = retriever.get_relevant_documents
        docs = fetch(query)
        return rerank_documents(query, docs, top_k=top_k)

    def generate_report(self, query: str, company_name: str, retriever):
        """
        Retrieves documents, creates a prompt, adds context from previous
        interactions, and uses the custom LLM to generate a report.

        Returns:
            A ``(report, retrieved_docs)`` tuple.
        """
        # Retrieve and rerank documents.
        retrieved_docs = self.retrieve_and_rerank(query, retriever)

        # Create the base prompt using the retrieved documents.
        prompt = create_prompt(company_name, retrieved_docs)

        # Incorporate feedback context (if any) into the prompt.
        # Here we simply concatenate the context memory into a single string.
        context_text = "\n".join(self.context_memory)

        # Call your custom LLM with the prompt and context.
        report = call_custom_llm(prompt, context=context_text)

        # Update context memory with the query and the generated report.
        # NOTE: the history is never trimmed, so it grows with every call.
        self.context_memory.append(f"User Query: {query}")
        self.context_memory.append(f"LLM Report: {report}")

        return report, retrieved_docs

    def evaluate_report(self, query: str, report: str, retrieved_docs):
        """
        Builds the evaluation prompt for a generated report and returns the
        LLM's answer to it (currently the placeholder ``call_custom_llm``).
        """
        eval_prompt = create_evaluation_prompt(query, report, retrieved_docs)
        # You could also use your custom LLM here, but for now, we call the placeholder.
        evaluation = call_custom_llm(eval_prompt)
        return evaluation

    def process_query(self, query: str, company_name: str, method: str = "ensemble"):
        """
        Processes the query using one of the retrieval methods:
          - "bm25": BM25 only.
          - "faiss": FAISS only.
          - "ensemble": Ensemble (BM25 + FAISS).
        Any other value also selects the ensemble retriever.

        Returns:
            A ``(report, evaluation)`` tuple.
        """
        if method == "bm25":
            retriever = self.bm25_retriever
        elif method == "faiss":
            retriever = self.faiss_retriever
        else:
            retriever = self.ensemble_retriever

        report, retrieved_docs = self.generate_report(query, company_name, retriever)
        evaluation = self.evaluate_report(query, report, retrieved_docs)
        return report, evaluation


In [None]:
# Document loading
# NOTE(review): `args` is not defined in this cell; presumably an argparse
# namespace with load_all, data_path, query, company and method created
# elsewhere — confirm before running.
if args.load_all:
    # The loader is pointed at "../data" (relative to the working directory).
    print("Loading all standard documents from data/ directory...")
    raw_docs = load_all_documents(base_path="../data")
else:
    raw_docs = load_documents(args.data_path)

# Initialize and run pipeline
# `report` and `evaluation` are only assigned here; nothing prints or saves them.
pipeline = RAGPipeline(raw_docs)
report, evaluation = pipeline.process_query(
    args.query, args.company, method=args.method
)
