# RAG

### This RAG flow consists of the following steps:

1. **Text Chunking**: The source document is split into chunks at its level-2 (`## `) headings, so each chunk covers one section.
2. **Text Embedding**: Each chunk is converted into a vector representation.
3. **Vector Database**: The vectors are stored in an in-memory vector index (a minimal stand-in for a vector database) and compared by cosine distance at query time.
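4. **BM25 (Best Matching 25)**: A keyword-based ranking function (Okapi BM25) scores each chunk by how often the query terms occur in it. Terms are weighted by how rare they are across all chunks, and scores are normalized for chunk length. The "25" is the variant number of the formula, not the number of results.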
5. **Hybrid Search**: The ranked results of the vector search and the BM25 keyword search are merged with Reciprocal Rank Fusion (RRF) to find the most relevant chunks.
6. **Re-ranking**: An LLM (Claude) re-ranks the retrieved chunks, selecting those most relevant to the query.
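7. **Contextual Retrieval & Answer Generation**: Before indexing, each chunk is prefixed with a short LLM-generated snippet that situates it within the whole document. At query time, the final answer is generated from the top-ranked chunks, giving a contextually rich response.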

# dependencies

In [1]:
from google import genai
from anthropic import Anthropic
import re
import os
from collections import Counter
from typing import Callable, Any, List, Dict, Tuple, Optional, Protocol
import math
import json
import random
import string

from dotenv import load_dotenv
load_dotenv()  # load environment variables (API keys) from a .env file, if present

import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Both clients are required by the pipeline (embeddings and chat). Fail fast
# with a logged error if either cannot be created, instead of leaving the name
# undefined and failing later with a confusing NameError.
try:
    gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
    logger.info("Gemini client created successfully.")
except Exception as e:
    logger.error("Failed to create Gemini client: %s", e)
    raise

try:
    anthropic_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    logger.info("Anthropic client created successfully.")
except Exception as e:
    logger.error("Failed to create Anthropic client: %s", e)
    raise


2025-08-08 20:07:06 - __main__ - INFO - Gemini client created successfully.
2025-08-08 20:07:06 - __main__ - INFO - Anthropic client created successfully.


# global config

In [2]:
# Model / sampling settings used as defaults by the helper functions below.
EMBEDDING_MODEL = "gemini-embedding-001"  # default model for generate_embeddings()
ANTHROPIC_MODEL = "claude-3-haiku-20240307"  # default model for chat()
TEMPERATURE = 0.7  # default sampling temperature for chat()

# prompt templates

In [3]:
# Final-answer prompt. Placeholders: {context} (joined retrieved chunks) and
# {query} (the user's question). The misspelled name ("PROMT") is what callers
# reference, so it is kept as-is.
ANSWER_PROMT = """
Based on the following context, please answer the user's question.

Context:
{context}

Question: {query}

Please provide a comprehensive answer based on the context provided.
If the context doesn't contain enough information to answer the question,
please say so clearly.
"""

# Reranking prompt. Placeholders: {k}, {query_text}, {joined_docs}.
# The doubled braces {{ }} become literal braces after str.format(), giving the
# model a JSON template. reranker_fn prefills the assistant turn with "```json"
# and stops at "```", so the model's reply should be the bare JSON object.
RERANK_PROMPT = """
You are about to be given a set of documents, along with an id of each.
Your task is to select the {k} most relevant documents to answer the user's question.

Here is the user's question:
<question>
{query_text}
</question>

Here are the documents to select from:
<documents>
{joined_docs}
</documents>

Respond in the following format:
```json
{{
"document_ids": [] 
}}
```
"""

# Chunk-contextualization prompt used by add_context(). Placeholders:
# {source_text} (the whole document) and {text_chunk} (one chunk).
# The name keeps its original spelling ("PROMT") because callers reference it.
CONTEXT_PROMT = """
Write a short and succinct snippet of text to situate this chunk within the 
overall source document for the purposes of improving search retrieval of the chunk. 

Here is the original source document:
<document> 
{source_text}
</document> 

Here is the chunk we want to situate within the whole document:
<chunk> 
{text_chunk}
</chunk>

Answer only with the succinct context and nothing else. 
"""

# helper functions

In [4]:
def add_user_message(messages, message_content):
    """Append a user turn to *messages* (the list is mutated in place).

    A list is used verbatim as the turn's content blocks; any other value is
    converted with ``str()`` and wrapped in a single text block.
    """
    if isinstance(message_content, list):
        content = message_content
    else:
        content = [{"type": "text", "text": str(message_content)}]

    messages.append({"role": "user", "content": content})


def add_assistant_message(messages, message_content):
    """Append an assistant turn to *messages* (the list is mutated in place).

    *message_content* may be:
      * a list of content blocks, used verbatim;
      * an object exposing a ``.content`` sequence of blocks (such as a chat
        response). Its ``text`` and ``tool_use`` blocks are copied into plain
        dicts; blocks of any other type are dropped;
      * anything else, which is wrapped unchanged as the ``text`` of a single
        text block (no ``str()`` conversion, unlike ``add_user_message``).
    """

    def _as_dict(block):
        # Returns None for block types that are not carried over.
        if block.type == "text":
            return {"type": "text", "text": block.text}
        if block.type == "tool_use":
            return {
                "type": "tool_use",
                "id": block.id,
                "name": block.name,
                "input": block.input,
            }
        return None

    if isinstance(message_content, list):
        content = message_content
    elif hasattr(message_content, "content"):
        converted = (_as_dict(block) for block in message_content.content)
        content = [item for item in converted if item is not None]
    else:
        content = [{"type": "text", "text": message_content}]

    messages.append({"role": "assistant", "content": content})

def chat(messages, model=ANTHROPIC_MODEL, temperature=TEMPERATURE, system=None, stop_sequences=None, tools=None, tool_choice=None, betas=None, max_tokens=1000):
    """Send *messages* to the Anthropic Messages API and return the response.

    Optional arguments are only forwarded when truthy.

    Args:
        messages: Conversation turns as built by add_user_message /
            add_assistant_message.
        model: Model name (defaults to ANTHROPIC_MODEL).
        temperature: Sampling temperature (defaults to TEMPERATURE).
        system: Optional system prompt.
        stop_sequences: Optional list of strings that end generation.
        tools: Optional tool definitions.
        tool_choice: Optional tool-choice setting.
        betas: Optional list of beta feature names. When given, the request is
            sent through the client's ``beta`` namespace, because the standard
            ``messages.create`` does not accept a ``betas`` argument.
        max_tokens: Maximum number of tokens to generate (default 1000).

    Returns:
        The API response message object.

    Raises:
        Whatever the Anthropic client raises; it is logged and re-raised.
    """
    params = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if system:
        params["system"] = system
    if tools:
        params["tools"] = tools
    if tool_choice:
        params["tool_choice"] = tool_choice
    if stop_sequences:
        params["stop_sequences"] = stop_sequences

    try:
        if betas:
            return anthropic_client.beta.messages.create(betas=list(betas), **params)
        return anthropic_client.messages.create(**params)
    except Exception as e:
        logger.error("Chat request failed: %s", e)
        # Bare raise preserves the original traceback.
        raise
    
def text_from_message(message):
    """Return the text of every ``text`` block in ``message.content``, newline-joined.

    Blocks of any other type (e.g. ``tool_use``) are skipped; a message with no
    text blocks yields an empty string.
    """
    texts = []
    for block in message.content:
        if block.type == "text":
            texts.append(block.text)
    return "\n".join(texts)

# 1. text chunking

In [5]:
# A level-2 markdown heading that begins on a new line.
_SECTION_SEPARATOR = re.compile(r"\n## ")


def chunk_by_structure(text):
    """Split markdown *text* at every ``## `` heading that follows a newline.

    The ``"\\n## "`` separator is consumed, so every chunk after the first
    loses its leading ``## `` marker. Text before the first such heading
    (or a heading on the very first line) stays in the first chunk.
    """
    return _SECTION_SEPARATOR.split(text)

# 2. text embedding

In [6]:
def generate_embeddings(text, model=EMBEDDING_MODEL):
    """Embed *text* with the Gemini embeddings endpoint.

    Args:
        text: Passed straight through as ``contents`` to ``embed_content``.
        model: Embedding model name (defaults to EMBEDDING_MODEL).

    Returns:
        ``response.embeddings`` from the Gemini client.

    Raises:
        Any error from the request; it is logged and re-raised unchanged.
    """
    try:
        result = gemini_client.models.embed_content(model=model, contents=text)
        vectors = result.embeddings
        logger.info("Embeddings generated successfully.")
        return vectors
    except Exception as err:
        logger.error("Failed to generate embeddings: %s", err)
        raise

# 3. vector database

## VectorIndex implementation

In [7]:
class VectorIndex:
    """In-memory vector index with exhaustive (brute-force) nearest-neighbour search.

    Every entry pairs a vector (a list of numbers) with a document dict that
    must hold a string ``"content"``. All stored vectors share one dimension,
    fixed by the first vector added. ``search`` returns ``(document, distance)``
    pairs sorted by ascending distance, so smaller values mean closer matches.

    Args:
        distance_metric: ``"cosine"`` (1 - cosine similarity, in [0, 2]) or
            ``"euclidean"``.
        embedding_fn: Optional callable turning text into vectors. It is called
            with a single string by ``add_document`` and by string-query
            ``search`` (must return one vector). ``add_documents`` calls it with
            a list of strings (must return one vector per string, in order).

    Raises:
        ValueError: If ``distance_metric`` is not supported.
    """

    def __init__(
        self,
        distance_metric: str = "cosine",
        embedding_fn=None,
    ):
        self.vectors: List[List[float]] = []
        self.documents: List[Dict[str, Any]] = []
        self._vector_dim: Optional[int] = None
        if distance_metric not in ["cosine", "euclidean"]:
            raise ValueError("distance_metric must be 'cosine' or 'euclidean'")
        self._distance_metric = distance_metric
        self._embedding_fn = embedding_fn

    def add_document(self, document: Dict[str, Any]):
        """Embed ``document["content"]`` and store the vector with the document.

        Raises:
            ValueError: No embedding function, or no ``"content"`` key.
            TypeError: ``document`` is not a dict or its content is not a str.
        """
        if not self._embedding_fn:
            raise ValueError(
                "Embedding function not provided during initialization."
            )
        if not isinstance(document, dict):
            raise TypeError("Document must be a dictionary.")
        if "content" not in document:
            raise ValueError(
                "Document dictionary must contain a 'content' key."
            )

        content = document["content"]
        if not isinstance(content, str):
            raise TypeError("Document 'content' must be a string.")

        vector = self._embedding_fn(content)
        self.add_vector(vector=vector, document=document)

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Embed all document contents with one ``embedding_fn`` call and store them.

        Every document is validated before the embedding call, so a malformed
        entry raises without spending an embedding request.

        Raises:
            ValueError: No embedding function, a document without
                ``"content"``, or ``embedding_fn`` returning a different number
                of vectors than documents.
            TypeError: ``documents`` is not a list, or an entry / its content
                has the wrong type.
        """
        if not self._embedding_fn:
            raise ValueError(
                "Embedding function not provided during initialization."
            )

        if not isinstance(documents, list):
            raise TypeError("Documents must be a list of dictionaries.")

        if not documents:
            return

        contents = []
        for i, doc in enumerate(documents):
            if not isinstance(doc, dict):
                raise TypeError(f"Document at index {i} must be a dictionary.")
            if "content" not in doc:
                raise ValueError(
                    f"Document at index {i} must contain a 'content' key."
                )
            if not isinstance(doc["content"], str):
                raise TypeError(
                    f"Document 'content' at index {i} must be a string."
                )
            contents.append(doc["content"])

        vectors = list(self._embedding_fn(contents))
        if len(vectors) != len(documents):
            # zip() below would otherwise silently drop the unmatched documents.
            raise ValueError(
                f"Embedding function returned {len(vectors)} vectors "
                f"for {len(documents)} documents."
            )

        for vector, document in zip(vectors, documents):
            self.add_vector(vector=vector, document=document)

    def search(
        self, query: Any, k: int = 1
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Return the ``k`` stored documents closest to ``query``.

        Args:
            query: A string (embedded with ``embedding_fn``) or a list of numbers.
            k: Maximum number of results; must be positive.

        Returns:
            Up to ``k`` ``(document, distance)`` pairs, closest first. Empty if
            the index holds no vectors.

        Raises:
            ValueError: ``k`` <= 0, no embedding function for a string query,
                or a query dimension that differs from the stored vectors.
            TypeError: ``query`` is neither a string nor a list of numbers.
        """
        if not self.vectors:
            return []

        # Check k before embedding so an invalid argument doesn't cost an API call.
        if k <= 0:
            raise ValueError("k must be a positive integer.")

        if isinstance(query, str):
            if not self._embedding_fn:
                raise ValueError(
                    "Embedding function not provided for string query."
                )
            query_vector = self._embedding_fn(query)
        elif isinstance(query, list) and all(
            isinstance(x, (int, float)) for x in query
        ):
            query_vector = query
        else:
            raise TypeError(
                "Query must be either a string or a list of numbers."
            )

        # Defensive: only reachable if self.vectors was filled without add_vector().
        if self._vector_dim is None:
            return []

        if len(query_vector) != self._vector_dim:
            raise ValueError(
                f"Query vector dimension mismatch. Expected {self._vector_dim}, got {len(query_vector)}"
            )

        if self._distance_metric == "cosine":
            dist_func = self._cosine_distance
        else:
            dist_func = self._euclidean_distance

        distances = [
            (dist_func(query_vector, stored_vector), document)
            for stored_vector, document in zip(self.vectors, self.documents)
        ]
        # Stable sort: equal distances keep insertion order.
        distances.sort(key=lambda item: item[0])

        return [(doc, dist) for dist, doc in distances[:k]]

    def add_vector(self, vector, document: Dict[str, Any]):
        """Store a precomputed ``vector`` alongside ``document``.

        The first stored vector fixes the index dimension; later vectors must match.

        Raises:
            TypeError: ``vector`` is not a list of numbers or ``document`` is not a dict.
            ValueError: Missing ``"content"`` key or inconsistent vector dimension.
        """
        if not isinstance(vector, list) or not all(
            isinstance(x, (int, float)) for x in vector
        ):
            raise TypeError("Vector must be a list of numbers.")
        if not isinstance(document, dict):
            raise TypeError("Document must be a dictionary.")
        if "content" not in document:
            raise ValueError(
                "Document dictionary must contain a 'content' key."
            )

        if not self.vectors:
            self._vector_dim = len(vector)
        elif len(vector) != self._vector_dim:
            raise ValueError(
                f"Inconsistent vector dimension. Expected {self._vector_dim}, got {len(vector)}"
            )

        # Copy so later mutation of the caller's list cannot corrupt the index.
        self.vectors.append(list(vector))
        self.documents.append(document)

    def _euclidean_distance(
        self, vec1: List[float], vec2: List[float]
    ) -> float:
        """Straight-line (L2) distance between two equal-length vectors."""
        if len(vec1) != len(vec2):
            raise ValueError("Vectors must have the same dimension")
        return math.sqrt(sum((p - q) ** 2 for p, q in zip(vec1, vec2)))

    def _dot_product(self, vec1: List[float], vec2: List[float]) -> float:
        """Dot product of two equal-length vectors."""
        if len(vec1) != len(vec2):
            raise ValueError("Vectors must have the same dimension")
        return sum(p * q for p, q in zip(vec1, vec2))

    def _magnitude(self, vec: List[float]) -> float:
        """Euclidean norm of ``vec``."""
        return math.sqrt(sum(x * x for x in vec))

    def _cosine_distance(self, vec1: List[float], vec2: List[float]) -> float:
        """Return ``1 - cosine_similarity`` (0 = same direction, 2 = opposite).

        Two zero vectors are treated as identical (0.0); a zero vector against a
        non-zero one as orthogonal (1.0).
        """
        if len(vec1) != len(vec2):
            raise ValueError("Vectors must have the same dimension")

        mag1 = self._magnitude(vec1)
        mag2 = self._magnitude(vec2)

        if mag1 == 0 and mag2 == 0:
            return 0.0
        elif mag1 == 0 or mag2 == 0:
            return 1.0

        dot_prod = self._dot_product(vec1, vec2)
        cosine_similarity = dot_prod / (mag1 * mag2)
        # Clamp away floating-point overshoot outside [-1, 1].
        cosine_similarity = max(-1.0, min(1.0, cosine_similarity))

        return 1.0 - cosine_similarity

    def __len__(self) -> int:
        return len(self.vectors)

    def __repr__(self) -> str:
        has_embed_fn = "Yes" if self._embedding_fn else "No"
        return f"VectorIndex(count={len(self)}, dim={self._vector_dim}, metric='{self._distance_metric}', has_embedding_fn='{has_embed_fn}')"

# 4. BM25 (best matching 25)

In [8]:
class BM25Index:
    """In-memory keyword index ranked with the Okapi BM25 scoring function.

    Documents are dicts with a string ``"content"`` that is tokenized on
    insertion. IDF values and the average document length are recomputed
    lazily on the first search after any insertion.

    ``search`` maps each raw BM25 score ``s`` to
    ``exp(-score_normalization_factor * s)`` and returns results in ascending
    order of that value. As with ``VectorIndex`` distances, smaller means a
    better match.

    Args:
        k1: Term-frequency saturation parameter of the BM25 formula.
        b: Document-length normalization strength (0 disables it).
        tokenizer: Callable turning text into a token list. Defaults to
            lowercasing and splitting on runs of non-word characters.
    """

    def __init__(
        self,
        k1: float = 1.5,
        b: float = 0.75,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
    ):
        self.documents: List[Dict[str, Any]] = []
        self._corpus_tokens: List[List[str]] = []
        # Per-document term frequencies, cached at insert time so that each
        # search does not rebuild a Counter for every document.
        self._doc_term_counts: List[Counter] = []
        self._doc_len: List[int] = []
        # Number of documents containing each term.
        self._doc_freqs: Dict[str, int] = {}
        self._avg_doc_len: float = 0.0
        self._idf: Dict[str, float] = {}
        self._index_built: bool = False

        self.k1 = k1
        self.b = b
        self._tokenizer = tokenizer if tokenizer else self._default_tokenizer

    def _default_tokenizer(self, text: str) -> List[str]:
        """Lowercase *text* and split it on non-word characters, dropping empties."""
        text = text.lower()
        tokens = re.split(r"\W+", text)
        return [token for token in tokens if token]

    def _update_stats_add(self, doc_tokens: List[str]):
        """Record length, term counts and document frequencies for one new document."""
        term_counts = Counter(doc_tokens)
        self._doc_term_counts.append(term_counts)
        self._doc_len.append(len(doc_tokens))

        # Document frequency counts each distinct term once per document.
        for token in term_counts:
            self._doc_freqs[token] = self._doc_freqs.get(token, 0) + 1

        self._index_built = False

    def _calculate_idf(self):
        """Recompute the IDF of every known term from the current corpus."""
        N = len(self.documents)
        self._idf = {}
        for term, freq in self._doc_freqs.items():
            # The "+ 1" inside the log keeps IDF positive even for terms that
            # appear in more than half of the documents.
            idf_score = math.log(((N - freq + 0.5) / (freq + 0.5)) + 1)
            self._idf[term] = idf_score

    def _build_index(self):
        """Refresh corpus-level statistics (average length and IDF)."""
        if not self.documents:
            self._avg_doc_len = 0.0
            self._idf = {}
            self._index_built = True
            return

        self._avg_doc_len = sum(self._doc_len) / len(self.documents)
        self._calculate_idf()
        self._index_built = True

    def _add_tokenized(self, document: Dict[str, Any], doc_tokens: List[str]):
        """Append one already-validated document and its tokens to the index."""
        self.documents.append(document)
        self._corpus_tokens.append(doc_tokens)
        self._update_stats_add(doc_tokens)

    def add_document(self, document: Dict[str, Any]):
        """Tokenize and index a single document.

        Raises:
            TypeError: ``document`` is not a dict or its content is not a str.
            ValueError: ``document`` has no ``"content"`` key.
        """
        if not isinstance(document, dict):
            raise TypeError("Document must be a dictionary.")
        if "content" not in document:
            raise ValueError(
                "Document dictionary must contain a 'content' key."
            )

        content = document["content"]
        if not isinstance(content, str):
            raise TypeError("Document 'content' must be a string.")

        self._add_tokenized(document, self._tokenizer(content))

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Tokenize and index a batch of documents.

        Every document is validated and tokenized before any is added, so a
        malformed entry leaves the index unchanged.

        Raises:
            TypeError: ``documents`` is not a list, or an entry / its content
                has the wrong type.
            ValueError: An entry has no ``"content"`` key.
        """
        if not isinstance(documents, list):
            raise TypeError("Documents must be a list of dictionaries.")

        if not documents:
            return

        tokenized: List[List[str]] = []
        for i, doc in enumerate(documents):
            if not isinstance(doc, dict):
                raise TypeError(f"Document at index {i} must be a dictionary.")
            if "content" not in doc:
                raise ValueError(
                    f"Document at index {i} must contain a 'content' key."
                )
            if not isinstance(doc["content"], str):
                raise TypeError(
                    f"Document 'content' at index {i} must be a string."
                )
            tokenized.append(self._tokenizer(doc["content"]))

        for doc, doc_tokens in zip(documents, tokenized):
            self._add_tokenized(doc, doc_tokens)

        self._index_built = False

    def _compute_bm25_score(
        self, query_tokens: List[str], doc_index: int
    ) -> float:
        """Raw BM25 score of document ``doc_index`` for ``query_tokens``.

        Query tokens unknown to the corpus contribute nothing; repeated query
        tokens are counted once per occurrence.
        """
        score = 0.0
        doc_term_counts = self._doc_term_counts[doc_index]
        doc_length = self._doc_len[doc_index]

        for token in query_tokens:
            if token not in self._idf:
                continue

            idf = self._idf[token]
            term_freq = doc_term_counts.get(token, 0)

            numerator = idf * term_freq * (self.k1 + 1)
            denominator = term_freq + self.k1 * (
                1 - self.b + self.b * (doc_length / self._avg_doc_len)
            )
            score += numerator / (denominator + 1e-9)

        return score

    def search(
        self,
        query: Any,
        k: int = 1,
        score_normalization_factor: float = 0.1,
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Return up to ``k`` documents matching ``query``, best first.

        Documents with a (near-)zero raw score, i.e. sharing no query term,
        are omitted, so fewer than ``k`` results may be returned.

        Returns:
            ``(document, normalized_score)`` pairs in ascending normalized
            score (best match first).

        Raises:
            TypeError: ``query`` is not a string.
            ValueError: ``k`` <= 0.
        """
        if not self.documents:
            return []

        if isinstance(query, str):
            query_text = query
        else:
            raise TypeError("Query must be a string for BM25Index.")

        if k <= 0:
            raise ValueError("k must be a positive integer.")

        if not self._index_built:
            self._build_index()

        # Every document is empty: nothing can match (and avoids division by zero).
        if self._avg_doc_len == 0:
            return []

        query_tokens = self._tokenizer(query_text)
        if not query_tokens:
            return []

        raw_scores = []
        for i in range(len(self.documents)):
            raw_score = self._compute_bm25_score(query_tokens, i)
            if raw_score > 1e-9:
                raw_scores.append((raw_score, self.documents[i]))

        raw_scores.sort(key=lambda item: item[0], reverse=True)

        # Map higher raw scores to smaller values in (0, 1] so results rank
        # like distances.
        normalized_results = []
        for raw_score, doc in raw_scores[:k]:
            normalized_score = math.exp(-score_normalization_factor * raw_score)
            normalized_results.append((doc, normalized_score))

        normalized_results.sort(key=lambda item: item[1])

        return normalized_results

    def __len__(self) -> int:
        return len(self.documents)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(count={len(self)}, k1={self.k1}, b={self.b}, index_built={self._index_built})"

# 5. hybrid search

## Retriever implementation

In [9]:
class SearchIndex(Protocol):
    """Structural interface for the indexes combined by ``Retriever``.

    ``Retriever`` uses only the order of ``search`` results (position 0 is the
    best match), not the score values.
    """

    def add_document(self, document: Dict[str, Any]) -> None:
        """Index a single document dict."""
        ...

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Index a batch of documents in one call.

        Batching lets an implementation amortize expensive per-call work, such
        as one embedding request for the whole batch.
        """
        ...

    def search(
        self, query: Any, k: int = 1
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Return up to *k* ``(document, score)`` pairs, best match first."""
        ...


class Retriever:
    """Hybrid retriever that fuses several search indexes with Reciprocal Rank Fusion.

    Every document is added to all indexes. ``search`` queries each index and
    combines the per-index ranks with RRF (score = sum of ``1 / (k_rrf + rank)``
    over the indexes that returned the document). When ``reranker_fn`` is
    given, it then reorders and filters the fused top-k.

    Dict documents without an ``"id"`` are given a random 4-character one
    (the dict is mutated), because the reranker refers to documents by id.

    Args:
        *indexes: One or more objects with ``add_document``, ``add_documents``
            and ``search`` methods (see ``SearchIndex``).
        reranker_fn: Optional ``(docs, query_text, k) -> ids`` callable that
            returns document ids in the preferred order.

    Raises:
        ValueError: If no index is given.
    """

    def __init__(
        self,
        *indexes: "SearchIndex",
        reranker_fn: Optional[
            Callable[[List[Dict[str, Any]], str, int], List[str]]
        ] = None,
    ):
        if len(indexes) == 0:
            raise ValueError("At least one index must be provided")
        self._indexes = list(indexes)
        self._reranker_fn = reranker_fn

    @staticmethod
    def _ensure_id(document: Dict[str, Any]) -> None:
        """Assign a random 4-character alphanumeric ``"id"`` if *document* has none."""
        if "id" not in document:
            document["id"] = "".join(
                random.choices(string.ascii_letters + string.digits, k=4)
            )

    def add_document(self, document: Dict[str, Any]):
        """Add one document to every index, assigning an id first if missing."""
        self._ensure_id(document)
        for index in self._indexes:
            index.add_document(document)

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Add a batch of documents to every index, one call per index.

        Batching lets an index amortize expensive work; for example,
        ``VectorIndex.add_documents`` calls its embedding function once for the
        whole batch. Like ``add_document``, dict documents without an ``"id"``
        are given one. All other validation is left to the indexes.
        """
        if isinstance(documents, list):
            for document in documents:
                if isinstance(document, dict):
                    self._ensure_id(document)
        for index in self._indexes:
            index.add_documents(documents)

    def search(
        self, query_text: str, k: int = 1, k_rrf: int = 60
    ) -> List[Tuple[Dict[str, Any], float]]:
        """Return up to ``k`` ``(document, rrf_score)`` pairs for ``query_text``.

        Each index is asked for ``5 * k`` candidates. Candidates are fused with
        RRF (higher score = better) and the top ``k`` are kept. If a reranker
        is configured and there are candidates, the result follows the order
        of the ids it returns. Unknown or duplicate ids are ignored, and at
        most ``k`` documents are kept; each keeps its RRF score.

        Raises:
            TypeError: ``query_text`` is not a string.
            ValueError: ``k`` <= 0 or ``k_rrf`` < 0.
        """
        if not isinstance(query_text, str):
            raise TypeError("Query text must be a string.")
        if k <= 0:
            raise ValueError("k must be a positive integer.")
        if k_rrf < 0:
            raise ValueError("k_rrf must be non-negative.")

        all_results = [
            index.search(query_text, k=k * 5) for index in self._indexes
        ]

        # Documents are matched across indexes by object identity. This works
        # because every index is given the same dict objects.
        doc_ranks: Dict[int, Dict[str, Any]] = {}
        for idx, results in enumerate(all_results):
            for rank, (doc, _) in enumerate(results):
                doc_id = id(doc)
                if doc_id not in doc_ranks:
                    doc_ranks[doc_id] = {
                        "doc_obj": doc,
                        "ranks": [float("inf")] * len(self._indexes),
                    }
                doc_ranks[doc_id]["ranks"][idx] = rank + 1

        def calc_rrf_score(ranks: List[float]) -> float:
            # Indexes that did not return the document (rank inf) contribute 0.
            return sum(1.0 / (k_rrf + r) for r in ranks if r != float("inf"))

        scored_docs: List[Tuple[Dict[str, Any], float]] = [
            (entry["doc_obj"], calc_rrf_score(entry["ranks"]))
            for entry in doc_ranks.values()
        ]

        filtered_docs = [
            (doc, score) for doc, score in scored_docs if score > 0
        ]
        filtered_docs.sort(key=lambda x: x[1], reverse=True)

        result = filtered_docs[:k]

        # Skip the (potentially paid) reranker call when there is nothing to rank.
        if self._reranker_fn is not None and result:
            docs_only = [doc for doc, _ in result]
            for doc in docs_only:
                self._ensure_id(doc)

            doc_lookup = {doc["id"]: doc for doc in docs_only}
            original_scores = {id(doc): score for doc, score in result}
            reranked_ids = self._reranker_fn(docs_only, query_text, k)

            new_result = []
            seen_ids = set()
            for doc_id in reranked_ids:
                # Ignore ids the reranker invented or repeated.
                if doc_id not in doc_lookup or doc_id in seen_ids:
                    continue
                seen_ids.add(doc_id)
                doc = doc_lookup[doc_id]
                new_result.append((doc, original_scores.get(id(doc), 0.0)))
                if len(new_result) >= k:
                    break

            result = new_result

        return result

# 6. re-ranking

In [10]:
def reranker_fn(docs, query_text, k):
    """Ask the chat model to pick the *k* documents most relevant to *query_text*.

    Each document (needs ``"id"`` and ``"content"``) is rendered into
    RERANK_PROMPT. The assistant turn is prefilled with "```json" and
    generation stops at "```", so the reply should be a bare JSON object with
    a ``document_ids`` list.

    Returns:
        The ids chosen by the model, in its order. If the reply cannot be
        parsed, or ``document_ids`` is not a list, a warning is logged and the
        ids of the first *k* documents are returned unchanged.
    """
    joined_docs = "\n".join(
        [
            f"""
        <document>
        <document_id>{doc["id"]}</document_id>
        <document_content>{doc["content"]}</document_content>
        </document>
        """
            for doc in docs
        ]
    )

    prompt = RERANK_PROMPT.format(
        k=k,
        query_text=query_text,
        joined_docs=joined_docs,
    )

    messages = []
    add_user_message(messages, prompt)
    add_assistant_message(messages, "```json")

    result = chat(messages, stop_sequences=["```"])

    fallback_ids = [doc["id"] for doc in docs[:k]]
    try:
        document_ids = json.loads(text_from_message(result))["document_ids"]
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError. TypeError covers a reply that
        # parses to something other than an object (e.g. a list or string).
        logger.warning("Could not parse reranker response (%s); keeping retrieval order.", e)
        return fallback_ids

    if not isinstance(document_ids, list):
        logger.warning("Reranker returned non-list document_ids; keeping retrieval order.")
        return fallback_ids

    return document_ids

# 7. contextual answer generation

In [11]:
def add_context(text_chunk, source_text):
    """Prefix *text_chunk* with a model-written snippet situating it in *source_text*.

    Sends CONTEXT_PROMT (the whole document plus the chunk) to the chat model.
    Returns the generated text, a newline, then the original chunk.
    """
    conversation = []
    add_user_message(
        conversation,
        CONTEXT_PROMT.format(source_text=source_text, text_chunk=text_chunk),
    )

    situating_text = text_from_message(chat(conversation))
    return situating_text + "\n" + text_chunk

# test function

In [12]:
# Markdown report that test_main() chunks, contextualizes and indexes
# (relative path, resolved against the current working directory).
DOCUMENT_FILE = "./assets/report.md"

In [13]:
def test_main():
    """Smoke-test the full pipeline on a small slice of DOCUMENT_FILE.

    Contextualizes only the first 3 chunks and indexes them in a vector + BM25
    retriever with LLM reranking. It then answers one fixed question from the
    top 2 results and prints the answer. Makes live Anthropic and Gemini API
    calls.
    """
    # Read as UTF-8 explicitly so the result doesn't depend on the platform's
    # default encoding.
    with open(DOCUMENT_FILE, "r", encoding="utf-8") as f:
        source_text = f.read()

    chunks = chunk_by_structure(source_text)
    logger.info("Document split into %d chunks.", len(chunks))

    # Add context to chunks (first 3 only, to limit API calls)
    documents = [
        {"id": f"chunk_{i}", "content": add_context(chunk, source_text)}
        for i, chunk in enumerate(chunks[:3])
    ]

    # Embedding wrapper in the shape VectorIndex expects: a single string
    # yields one vector, a list of strings yields one vector per string.
    def embedding_fn(text_or_texts):
        if isinstance(text_or_texts, str):
            return generate_embeddings(text_or_texts)[0].values
        return [generate_embeddings(text)[0].values for text in text_or_texts]

    # Create retriever
    vector_index = VectorIndex(embedding_fn=embedding_fn)
    bm25_index = BM25Index()
    retriever = Retriever(vector_index, bm25_index, reranker_fn=reranker_fn)
    retriever.add_documents(documents)

    # Test with a sample query
    test_query = "What are the main findings of the report?"
    print(f"Testing with query: {test_query}")

    results = retriever.search(test_query, k=2)
    context = "\n\n".join(doc["content"] for doc, _ in results)

    answer_prompt = ANSWER_PROMT.format(context=context, query=test_query)

    messages = []
    add_user_message(messages, answer_prompt)
    answer = text_from_message(chat(messages))

    print("=" * 50)
    print("TEST RESULT:")
    print("=" * 50)
    print(answer)

In [14]:
# Run the smoke test (first 3 chunks, one fixed query); this makes live
# Anthropic and Gemini API calls.
test_main()

2025-08-08 20:08:06 - __main__ - INFO - Document split into 15 chunks.
2025-08-08 20:08:08 - httpx - INFO - HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"
2025-08-08 20:08:09 - httpx - INFO - HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"
2025-08-08 20:08:10 - httpx - INFO - HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"
2025-08-08 20:08:12 - httpx - INFO - HTTP Request: POST https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents "HTTP/1.1 200 OK"
2025-08-08 20:08:13 - __main__ - INFO - Embeddings generated successfully.
2025-08-08 20:08:13 - httpx - INFO - HTTP Request: POST https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents "HTTP/1.1 200 OK"
2025-08-08 20:08:13 - __main__ - INFO - Embeddings generated successfully.
2025-08-08 20:08:14 - httpx - INFO - HTTP Request: POST https://generativelanguage.googleapis.com/v1bet

Testing with query: What are the main findings of the report?


2025-08-08 20:08:15 - httpx - INFO - HTTP Request: POST https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:batchEmbedContents "HTTP/1.1 200 OK"
2025-08-08 20:08:15 - __main__ - INFO - Embeddings generated successfully.
2025-08-08 20:08:16 - httpx - INFO - HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"
2025-08-08 20:08:19 - httpx - INFO - HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"


TEST RESULT:
Based on the context provided in the Executive Summary and the Table of Contents, the main findings of the Annual Interdisciplinary Research Review report include:

1. Medical Research: Advances in understanding the rare XDR-471 syndrome, leading to new diagnostic insights.

2. Software Engineering: Tackled persistent stability issues in Project Phoenix, implementing key fixes identified through error code analysis.

3. Financial Analysis: Revealed mixed quarterly performance, prompting strategic reviews, particularly concerning resource allocation impacting R&D pipelines.

4. Scientific Experimentation: Characterized novel material properties of Composite XT-5, potentially impacting future product lines.

5. Legal Developments: Navigated complex precedents, particularly in intellectual property related to the Synergy Dynamics case, ensuring compliance and mitigating risk.

6. Product Engineering: Finalized specifications for the next-generation Model Zircon-5, incorporati

# main function

In [15]:
# Markdown report that main() indexes in full (relative path, resolved
# against the current working directory).
DOCUMENT_FILE = "./assets/report.md"

In [16]:
def main():
    """Build the hybrid RAG index over DOCUMENT_FILE and answer questions interactively.

    Steps:
      1. Chunk the document at "## " headings.
      2. Prefix every chunk with LLM-generated context.
      3. Index the chunks in a vector index and a BM25 index behind a
         reranking Retriever.
      4. Loop reading questions from stdin, answering each from the top 3
         retrieved chunks.

    Typing 'quit', 'exit' or 'q', or sending EOF / Ctrl-C at the prompt, ends
    the loop. A failure while answering one question is logged and the loop
    continues.
    """
    # step 1: text chunking (explicit UTF-8 so the platform default doesn't matter)
    with open(DOCUMENT_FILE, "r", encoding="utf-8") as f:
        source_text = f.read()

    chunks = chunk_by_structure(source_text)
    logger.info("Document split into %d chunks.", len(chunks))

    # step 2: add context to chunks (one chat call per chunk)
    logger.info("Adding context to chunks...")
    documents = [
        {"id": f"chunk_{i}", "content": add_context(chunk, source_text)}
        for i, chunk in enumerate(chunks)
    ]
    logger.info("Context added to %d chunks.", len(documents))

    # step 3: embedding wrapper in the shape VectorIndex expects:
    # a single string -> one vector; a list of strings -> one vector per string.
    def embedding_fn(text_or_texts):
        if isinstance(text_or_texts, str):
            return generate_embeddings(text_or_texts)[0].values
        return [generate_embeddings(text)[0].values for text in text_or_texts]

    # step 4: create indexes
    logger.info("Creating vector and BM25 indexes...")
    vector_index = VectorIndex(embedding_fn=embedding_fn)
    bm25_index = BM25Index()

    # step 5: create retriever
    logger.info("Creating retriever...")
    retriever = Retriever(vector_index, bm25_index, reranker_fn=reranker_fn)

    # step 6: add documents to retriever
    logger.info("Adding documents to retriever...")
    retriever.add_documents(documents)
    logger.info("Added %d documents to retriever.", len(documents))

    # step 7: interactive query loop
    print("\n" + "=" * 50)
    print("RAG SYSTEM READY")
    print("=" * 50)
    print("You can now ask questions about the document.")
    print("Type 'quit' to exit.\n")

    while True:
        try:
            query = input("Enter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C at the prompt: exit cleanly instead of crashing.
            print("\nGoodbye!")
            break

        if query.lower() in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break

        if not query:
            continue

        print("\n" + "=" * 50)
        print("USER QUERY:")
        print("=" * 50)
        print(query)
        print()

        # Search for relevant documents (top 3 chunks)
        print("Searching for relevant information...")
        try:
            results = retriever.search(query, k=3)
        except Exception:
            # Interactive boundary: report and keep the session alive.
            logger.exception("Retrieval failed for query: %s", query)
            print("Retrieval failed; please try another question.\n")
            continue

        if not results:
            print("No relevant information found.")
            continue

        # Generate answer using retrieved context
        context = "\n\n".join(doc["content"] for doc, _ in results)
        answer_prompt = ANSWER_PROMT.format(context=context, query=query)

        messages = []
        add_user_message(messages, answer_prompt)

        print("=" * 50)
        print("ASSISTANT RESPONSE:")
        print("=" * 50)

        try:
            response = chat(messages)
        except Exception:
            logger.exception("Answer generation failed for query: %s", query)
            print("Answer generation failed; please try again.\n")
            continue

        print(text_from_message(response))
        print("\n")

In [None]:
# Start the interactive question-answering loop; it blocks on input() until
# the user types 'quit', 'exit' or 'q'.
main()