# Interactive RAG Question Answering System

This notebook provides an interactive interface for asking questions about your documents using:
- **Multilingual Bi-encoder** for initial retrieval
- **HyDE** (Hypothetical Document Embeddings) for improved query understanding
- **Cross-encoder reranking** for better relevance
- **LiteLLM** for flexible LLM provider support

## Setup Instructions:
1. Install required packages
2. Set your API key and file path
3. Run the setup cells
4. Start asking questions!

## 1. Installation (run this first if the packages are not installed yet)

In [None]:
# Uncomment and run if you need to install packages
# !pip install haystack-ai sentence-transformers litellm torch ipywidgets

## 2. Import Libraries and RAG System

In [None]:
import os
import pickle
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

import litellm
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.converters import TextFileToDocument
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.preprocessors import DocumentSplitter
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.rankers import SentenceTransformersSimilarityRanker
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import ComponentDevice

# For interactive widgets
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import time

# Configure logging: INFO level, each record prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the classes defined in the cells below.
logger = logging.getLogger(__name__)

# Confirmation for the notebook user that the import cell ran to completion.
print("✅ Libraries imported successfully!")

## 3. RAG System Classes (copied from the main codebase)

In [None]:
@dataclass
class RAGConfig:
    """Configuration for the RAG pipeline.

    Groups the model identifiers, chunking and retrieval sizes, compute
    device, cache file names and LLM sampling parameters that the other
    classes in this notebook read from a single object.
    """
    # Model configurations
    # Bi-encoder used to embed both document chunks and queries.
    biencoder_model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    # Cross-encoder used to rerank the retrieved chunks.
    crossencoder_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    # LLM identifier in LiteLLM "<provider>/<model>" format.
    llm_model: str = "gemini/gemini-1.5-flash"

    # Processing configurations
    # Chunk length and overlap; DocumentProcessor splits by word.
    chunk_size: int = 512
    chunk_overlap: int = 50
    # Number of chunks fetched by the embedding retriever ...
    top_k_retrieval: int = 20
    # ... and number kept after cross-encoder reranking.
    top_k_reranking: int = 5

    # Device configuration
    # Let Haystack pick the default device instead of hard-coding "cuda:0",
    # so the default config also works on hosts without a CUDA GPU.
    # Evaluated once, when the class is defined.
    device: ComponentDevice = ComponentDevice.resolve_device(None)

    # Storage paths (pickle files, relative to the working directory)
    embeddings_file: str = "document_embeddings.pkl"
    documents_file: str = "documents.pkl"

    # LLM parameters forwarded to the completion call
    temperature: float = 0.1
    max_tokens: int = 1000

class LiteLLMGenerator:
    """Wrapper around ``litellm.completion`` exposing a ``run(prompt)`` method.

    ``run`` returns ``{"replies": [text]}`` and never raises: any failure is
    logged and replaced by a fixed fallback reply.
    """

    # Model-name substrings -> environment variable that receives ``api_key``.
    # Checked in order; the first matching row wins.
    _API_KEY_ENV_VARS = (
        (("gemini",), "GEMINI_API_KEY"),
        (("openai", "gpt"), "OPENAI_API_KEY"),
        (("anthropic", "claude"), "ANTHROPIC_API_KEY"),
    )

    # Reply returned whenever no usable completion could be obtained.
    _FALLBACK_REPLY = "Sorry, I couldn't generate a response."

    def __init__(self, model_name: str = "gemini/gemini-1.5-flash", api_key: Optional[str] = None, **kwargs):
        """Store the model settings, export the API key and probe the model.

        Args:
            model_name: LiteLLM model identifier.
            api_key: Optional key; exported to the environment variable that
                matches the provider detected from ``model_name``.
            **kwargs: Default generation parameters passed to every
                completion call made by ``run``.
        """
        self.model_name = model_name
        self.generation_kwargs = kwargs

        if api_key:
            lowered_name = model_name.lower()
            for markers, env_var in self._API_KEY_ENV_VARS:
                if any(marker in lowered_name for marker in markers):
                    os.environ[env_var] = api_key
                    break
            else:
                # Previously the key was dropped silently for other providers.
                logger.warning(
                    f"No known API key variable for {model_name}; "
                    "the provided api_key was not exported"
                )

        # Test the connection (best effort, never raises)
        self._test_connection()

    def _test_connection(self):
        """Send a tiny request to check the model is reachable.

        A failure is only logged as a warning so that construction succeeds
        even when the provider is unavailable.
        """
        try:
            litellm.completion(
                model=self.model_name,
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=10
            )
            logger.info(f"Successfully connected to {self.model_name}")
        except Exception as e:
            logger.warning(f"Could not test connection to {self.model_name}: {e}")

    def run(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text sent as a single user message.
            **kwargs: Per-call generation parameters; they override the
                defaults given to the constructor.

        Returns:
            ``{"replies": [text]}``. On any error, or when the response
            carries no text content, ``text`` is the fallback reply.
        """
        generation_params = {**self.generation_kwargs, **kwargs}

        try:
            response = litellm.completion(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                **generation_params
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort boundary: callers always get a reply.
            logger.error(f"Error generating response with {self.model_name}: {e}")
            return {"replies": [self._FALLBACK_REPLY]}

        if content is None:
            # A None reply would break callers that slice or print the text.
            logger.error(f"Empty response content from {self.model_name}")
            return {"replies": [self._FALLBACK_REPLY]}

        return {"replies": [content]}

class HyDEGenerator:
    """Produces a hypothetical answer document for a question (HyDE)."""

    def __init__(self, llm_generator: LiteLLMGenerator):
        """Keep the LLM wrapper and the prompt template used for expansion."""
        self.generator = llm_generator
        # ``str.format`` template with a single ``{question}`` placeholder.
        self.hyde_prompt = (
            "Given the following question, write a hypothetical document "
            "that would perfectly answer this question.\n"
            "The document should be detailed, informative, and directly address the question.\n"
            "\n"
            "Question: {question}\n"
            "\n"
            "Hypothetical Document:"
        )

    def run(self, query: str) -> Dict[str, str]:
        """Return ``{"hypothetical_document": text}`` for ``query``.

        The text is the first reply of the LLM; when the reply list is empty
        the original query is returned instead.
        """
        filled_prompt = self.hyde_prompt.format(question=query)
        replies = self.generator.run(filled_prompt)["replies"]
        if replies:
            return {"hypothetical_document": replies[0]}
        return {"hypothetical_document": query}

class DocumentProcessor:
    """Converts a text file into embedded chunks and caches them on disk.

    Chunks are produced by a word-based ``DocumentSplitter`` and embedded
    with the bi-encoder named in the configuration. The cache consists of
    two pickle files whose names come from the configuration.
    """

    def __init__(self, config: RAGConfig):
        """Create the embedder (loaded immediately) and the splitter."""
        self.config = config
        self.document_store = InMemoryDocumentStore()

        # Initialize embedder
        self.document_embedder = SentenceTransformersDocumentEmbedder(
            model=config.biencoder_model,
            device=config.device
        )

        # Load the model now so the first embedding call does not pay for it.
        self.document_embedder.warm_up()

        # Initialize splitter
        self.splitter = DocumentSplitter(
            split_by="word",
            split_length=config.chunk_size,
            split_overlap=config.chunk_overlap
        )

    def process_text_file(self, file_path: str) -> List[Document]:
        """Read ``file_path`` and split it into chunk documents.

        Raises:
            FileNotFoundError: If ``file_path`` is not an existing file.
                Checked up front so that a wrong path cannot end up as an
                empty chunk list that is then embedded and cached.
        """
        logger.info(f"Processing file: {file_path}")

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Text file not found: {file_path}")

        # Convert file to document
        converter = TextFileToDocument()
        documents = converter.run(sources=[file_path])["documents"]

        # Split documents
        split_docs = self.splitter.run(documents=documents)["documents"]

        logger.info(f"Created {len(split_docs)} document chunks")
        return split_docs

    def generate_embeddings(self, documents: List[Document]) -> List[Document]:
        """Return the documents produced by the embedder for ``documents``."""
        logger.info("Generating embeddings...")

        embedded_docs = self.document_embedder.run(documents=documents)["documents"]

        logger.info(f"Generated embeddings for {len(embedded_docs)} documents")
        return embedded_docs

    def save_documents_and_embeddings(self, documents: List[Document]):
        """Pickle ``documents`` and a separate embeddings summary.

        The documents file holds the full ``Document`` objects. The
        embeddings file holds the embedding of every document plus, per
        document, its id and a preview of at most 100 content characters.
        """
        logger.info("Saving documents and embeddings...")

        # Save documents
        with open(self.config.documents_file, 'wb') as f:
            pickle.dump(documents, f)

        # Save embeddings separately
        embeddings_data = {
            'embeddings': [doc.embedding for doc in documents],
            # ``content`` may be None; fall back to an empty preview.
            'metadata': [
                {'id': doc.id, 'content': (doc.content or '')[:100] + '...'}
                for doc in documents
            ]
        }

        with open(self.config.embeddings_file, 'wb') as f:
            pickle.dump(embeddings_data, f)

        logger.info("Documents and embeddings saved successfully")

    def load_documents_and_embeddings(self) -> List[Document]:
        """Load the pickled documents written by the save method.

        Only the documents file is read.

        Raises:
            FileNotFoundError: If the documents file does not exist.
        """
        logger.info("Loading documents and embeddings...")

        if not os.path.exists(self.config.documents_file):
            raise FileNotFoundError("Documents file not found. Please process documents first.")

        # SECURITY: unpickling executes arbitrary code from the file. Only
        # load cache files that this notebook wrote itself.
        with open(self.config.documents_file, 'rb') as f:
            documents = pickle.load(f)

        logger.info(f"Loaded {len(documents)} documents")
        return documents

class RAGPipeline:
    """Question answering over an in-memory store.

    Steps: optional HyDE expansion, bi-encoder retrieval, cross-encoder
    reranking and finally LLM answer generation.
    """

    def __init__(self, config: RAGConfig, api_key: str):
        """Create all pipeline components from ``config``.

        The LLM wrapper is built first; the query embedder and the ranker
        are warmed up once both have been constructed.
        """
        self.config = config
        self.document_store = InMemoryDocumentStore()

        # One LLM wrapper serves both HyDE expansion and the final answer.
        self.llm_generator = LiteLLMGenerator(
            api_key=api_key,
            model_name=config.llm_model,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        )
        self.hyde_generator = HyDEGenerator(self.llm_generator)

        # Embedding retrieval over this pipeline's own document store.
        self.retriever = InMemoryEmbeddingRetriever(
            top_k=config.top_k_retrieval,
            document_store=self.document_store,
        )

        # Query-side embedder, configured with the bi-encoder model.
        self.text_embedder = SentenceTransformersTextEmbedder(
            device=config.device,
            model=config.biencoder_model,
        )

        # Cross-encoder that reorders the retrieved chunks.
        self.ranker = SentenceTransformersSimilarityRanker(
            device=config.device,
            model=config.crossencoder_model,
            top_k=config.top_k_reranking,
        )

        for component in (self.text_embedder, self.ranker):
            component.warm_up()

        # Jinja template rendered with ``documents`` and ``query``.
        self.qa_prompt_template = """Context information is below:
---------------------
{% for doc in documents %}
{{ doc.content }}
---------------------
{% endfor %}

Given the context information above, please answer the following question.
If the answer cannot be found in the context, please say "I cannot find the answer in the provided context."

Question: {{ query }}

Answer:"""

        self.prompt_builder = PromptBuilder(template=self.qa_prompt_template)

    def load_documents(self, documents: List[Document]):
        """Write ``documents`` into the pipeline's document store."""
        logger.info("Loading documents into document store...")
        self.document_store.write_documents(documents)
        logger.info(f"Loaded {len(documents)} documents into document store")

    def answer_question(self, query: str, use_hyde: bool = True) -> Dict[str, Any]:
        """Run the whole pipeline for ``query``.

        Args:
            query: The user question.
            use_hyde: When true, retrieval uses the embedding of an
                LLM-written hypothetical document instead of the question.

        Returns:
            A dict with the keys ``answer``, ``retrieved_documents``,
            ``reranked_documents`` and ``hyde_document`` (``None`` when
            HyDE is disabled).
        """
        logger.info(f"Processing query: {query}")

        # 1. Choose the text to embed: HyDE document or the raw question.
        search_text = query
        if use_hyde:
            logger.info("Generating hypothetical document (HyDE)...")
            search_text = self.hyde_generator.run(query)["hypothetical_document"]
            logger.info(f"HyDE generated document: {search_text[:100]}...")

        # 2. Embed it.
        logger.info("Embedding search query...")
        embedding_output = self.text_embedder.run(text=search_text)

        # 3. Bi-encoder retrieval.
        logger.info("Retrieving documents...")
        retrieval_output = self.retriever.run(
            query_embedding=embedding_output["embedding"],
            top_k=self.config.top_k_retrieval,
        )
        candidates = retrieval_output["documents"]
        logger.info(f"Retrieved {len(candidates)} documents")

        # 4. Cross-encoder reranking, scored against the original question
        #    rather than the HyDE text.
        logger.info("Reranking documents...")
        ranking_output = self.ranker.run(query=query, documents=candidates)
        best_docs = ranking_output["documents"]
        logger.info(f"Reranked to top {len(best_docs)} documents")

        # 5. Render the prompt and ask the LLM.
        logger.info("Generating final answer...")
        rendered = self.prompt_builder.run(query=query, documents=best_docs)
        generation = self.llm_generator.run(rendered["prompt"])

        hyde_text = search_text if use_hyde else None
        return {
            "answer": generation["replies"][0],
            "retrieved_documents": candidates,
            "reranked_documents": best_docs,
            "hyde_document": hyde_text,
        }

class RAGSystem:
    """Facade tying the document processor and the RAG pipeline together."""

    def __init__(self, api_key: str, config: Optional[RAGConfig] = None):
        """Build the processor and the pipeline; a default config is used
        when none is supplied."""
        self.config = config if config else RAGConfig()
        self.api_key = api_key
        self.document_processor = DocumentProcessor(self.config)
        self.rag_pipeline = RAGPipeline(self.config, api_key)

    def setup_from_text_file(self, file_path: str, force_reprocess: bool = False):
        """Index ``file_path``, reusing the on-disk cache when possible.

        The cache is used only if both cache files exist and
        ``force_reprocess`` is false; otherwise the file is processed,
        embedded and saved. The documents are then loaded into the pipeline.
        """
        processor = self.document_processor
        cache_present = (
            os.path.exists(self.config.embeddings_file)
            and os.path.exists(self.config.documents_file)
        )

        if cache_present and not force_reprocess:
            logger.info("Found existing embeddings, loading...")
            documents = processor.load_documents_and_embeddings()
        else:
            logger.info("Processing documents and generating embeddings...")
            chunks = processor.process_text_file(file_path)
            documents = processor.generate_embeddings(chunks)
            # Persist so the next run can take the cached branch.
            processor.save_documents_and_embeddings(documents)

        self.rag_pipeline.load_documents(documents)
        logger.info("System setup complete!")

    def ask(self, question: str, use_hyde: bool = True) -> str:
        """Return only the answer text for ``question``."""
        details = self.rag_pipeline.answer_question(question, use_hyde=use_hyde)
        return details["answer"]

    def ask_detailed(self, question: str, use_hyde: bool = True) -> Dict[str, Any]:
        """Return the full result dict produced by the pipeline."""
        return self.rag_pipeline.answer_question(question, use_hyde=use_hyde)

# Confirmation for the notebook user that the class-definition cell ran.
print("✅ RAG System classes loaded successfully!")

## 4. Configuration and Setup

In [None]:
# 🔧 CONFIGURATION - edit the values in this cell for your environment
from my_secrets import API_KEY

# Text file to index; replace with the path of your own document.
TEXT_FILE_PATH = "/home/fiok/work/rag-test/data/tarczynski.txt"

# Pipeline settings. Other llm_model values tried with this notebook:
#   "openai/gpt-4o"
#   "anthropic/claude-3-sonnet-20240229"
config = RAGConfig(
    llm_model="gemini/gemini-2.0-flash-exp",
    temperature=0.1,
    max_tokens=1000,
    chunk_size=512,
    chunk_overlap=50,
    top_k_retrieval=20,
    top_k_reranking=5,
    device=ComponentDevice.from_str("cuda:0"),  # use "cpu" if there is no GPU
)

# Echo the effective settings, one per line.
print("\n".join((
    "✅ Configuration loaded!",
    f"📋 Model: {config.llm_model}",
    f"🔧 Device: {config.device}",
    f"📄 Document file: {TEXT_FILE_PATH}",
)))

## 5. Initialize and Setup RAG System

In [None]:
# Initialize the RAG system and index the document.
# system_ready is defined before anything can fail, so later cells can always
# test it; it only becomes True when both construction and setup succeeded.
system_ready = False

print("🚀 Initializing RAG System...")
try:
    # Construction loads the models and probes the LLM, so it can fail too
    # and is therefore inside the try block.
    rag_system = RAGSystem(API_KEY, config)

    # Setup from text file (this will process and embed documents if not done before)
    print("📚 Setting up documents...")
    rag_system.setup_from_text_file(TEXT_FILE_PATH, force_reprocess=False)
except Exception as e:
    # Notebook-level boundary: report the problem instead of a traceback.
    print(f"❌ Error setting up system: {e}")
    print("Please check your file path and API key")
else:
    print("✅ RAG System ready for questions!")
    system_ready = True

## 6. Interactive Question-Answer Interface

In [None]:
def ask_simple_question(question, use_hyde=True, preview_chars=200):
    """Ask a question and print the answer with retrieval details.

    Args:
        question: The question passed to the RAG system.
        use_hyde: Whether retrieval should use a HyDE document.
        preview_chars: Maximum number of characters printed for the HyDE
            document and for each document snippet.

    Returns:
        The detailed result dict on success, ``None`` when the system is not
        ready, and ``""`` when processing the question raised an exception.
    """
    if not system_ready:
        print("❌ System not ready. Please check setup above.")
        return

    print(f"🤔 Question: {question}")
    print(f"🔬 Using HyDE: {'Yes' if use_hyde else 'No'}")
    print("⏳ Processing...")

    # Monotonic clock: suited to measuring an elapsed duration.
    start_time = time.perf_counter()

    try:
        # Get detailed results
        result = rag_system.ask_detailed(question, use_hyde=use_hyde)

        processing_time = time.perf_counter() - start_time

        separator = "=" * 80
        print("\n" + separator)
        print("📝 ANSWER:")
        print(separator)
        print(result["answer"])

        print("\n" + separator)
        print("📊 DETAILS:")
        print(separator)
        print(f"⏱️  Processing time: {processing_time:.2f} seconds")
        print(f"📄 Documents retrieved: {len(result['retrieved_documents'])}")
        print(f"🎯 Documents after reranking: {len(result['reranked_documents'])}")

        if result['hyde_document'] and use_hyde:
            # Truncate as the label promises; the full text stays in result.
            print(f"\n🔮 HyDE Generated Document (first {preview_chars} chars):")
            print(f"{result['hyde_document'][:preview_chars]}...")

        print(f"\n📋 Top Retrieved Document Snippets:")
        for i, doc in enumerate(result['reranked_documents'][:3], 1):
            # Content may be None; show an empty snippet in that case.
            snippet = (doc.content or "")[:preview_chars]
            print(f"\n[{i}] Score: {doc.score:.4f}")
            print(f"Content: {snippet}...")

        return result

    except Exception as e:
        # Notebook-level boundary: report instead of raising into the cell.
        print(f"❌ Error processing question: {e}")
        return ""

# Usage hints printed for the notebook user:
print("✅ Simple Q&A function ready!")
print("📝 Use: ask_simple_question('Your question here')")
print("🔬 With HyDE: ask_simple_question('Your question', use_hyde=True)")
print("🚀 Without HyDE: ask_simple_question('Your question', use_hyde=False)")

# %%
# Try asking a question (Polish: "How many employees does the company employ?").
# HyDE is enabled by default. The return value is the result dict on success,
# but None or "" when the system is not ready or the question failed.
result = ask_simple_question("Jaka liczba pracowników jest zatrudniona przez firmę?")

# %%
# Ask the same question without HyDE, to compare with the run above:
result_no_hyde = ask_simple_question("Jaka liczba pracowników jest zatrudniona przez firmę?", use_hyde=False)

In [None]:
# Show the reranked documents of the HyDE run; this only works when the
# call that produced `result` succeeded and returned a dict.
result['reranked_documents']

In [None]:
# Show the reranked documents of the run without HyDE; this only works when
# the call that produced `result_no_hyde` succeeded and returned a dict.
result_no_hyde['reranked_documents']