# Adaptive RAG System
## Overview
This system implements an advanced RAG approach that adapts its retrieval strategy based on the type of query. By leveraging LLM at various stages, it aims to provide more accurate, relevant, and context-aware responses to user queries.
## Motivation
Traditional RAG systems often use a one-size-fits-all appraoch to retrieval, which can be suboptimal for different types of queries. Our adaptive system is motivated by the understanding that different typers of questions require different retrieval strategies. For example, a factual query might benefit from precise, focused retrieval, while an analytical query might require a broader, more diverse set of information.
## Key Components
1. Query Classifier: Determines the type of query (Factual, Analytical, Opinon, or Contextual)
2. Adaptive Retrieval Strategies: 4 distinct strategies tailored to different query types:
    - Factual Strategy
    - Analytical Strategy
    - Opinion Strategy
    - Contextual Strategy
3. LLM Integration: LLMs are used throughtout the process to enhance retrieval and ranking
4. LLM: Generates the final response using the retrieved documents as context
## Benefits of This Approach
1. Improved Accuracy: By tailoring the retrieval strategy to the query type, the system can provide more accurate and relevant information
2. Flexibility: The system adapts to different types of queries, handling a wide range of user needs
3. Context-Awareness: Especially for contextual queries, the system can incorporate user-specific information for more personalized responses
4. Diverse Persepectives: For opinion-based queries, the system actively seeks out and presents multiple viewpoints
5. Comprehensive Analysis: The analytical strategy ensures a thorough exploration of complex topics
## Conclusion
This adaptive RAG system represents a significant advancement over traditional RAG approaches. By dynamically adjusting its retrieval strategy and leveraging LLMs throughout the process, it aims to provide more accurate, relevant, and nuanced responses to a wide variety of user quries.

In [2]:
import os
from dotenv import load_dotenv

from langchain_openai.chat_models.azure import AzureChatOpenAI
load_dotenv()
openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
openai_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID")
openai_api_version = os.getenv("AZURE_API_VERSION")

llm = AzureChatOpenAI(
    azure_deployment=openai_deployment,
    api_version="2024-10-01-preview",
    azure_endpoint=f"{openai_endpoint}openai/deployments/{openai_deployment}/chat/completions?api-version=2024-10-01-preview",
    temperature=0,
    logprobs=True,
)

In [3]:
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
openai_embedding = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_ID")

from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.retrievers import BaseRetriever
from typing import Dict, Any
from langchain.docstore.document import Document


In [31]:
class categories_options(BaseModel):
    category: str = Field(
        description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual",
        examples=["Factual", "Analytical", "Opinion", "Contextual"],
    )

class QueryClassifier:
    def __init__(self) -> None:
        self.llm = llm
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual.\nQuery: {query}\nCategory:"
        )
        print("classifier prompt", self.prompt)
        self.chain = self.prompt | self.llm.with_structured_output(categories_options)

    def classify(self, query):
        print("classifying query")
        return self.chain.invoke(query).category

In [6]:
embeddings = AzureOpenAIEmbeddings(
    deployment=openai_embedding,
    model="text-embedding-ada-002",
    chunk_size=16
)

class BaseRetrievalStrategy:
    def __init__(self, texts) -> None:
        self.embeddings = embeddings
        text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        self.documents = text_splitter.create_documents(texts)
        self.db = FAISS.from_documents(self.documents, self.embeddings)
        self.llm = llm
    def retrieve(self, query, k=4):
        return self.db.similarity_search(query, k)

In [7]:
class relevant_score(BaseModel):
    score: float = Field(description="The relevance score of the document to the query", examples=8.0)

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving factual")
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )

        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content
        print(f"enhanced query: {enhanced_query}")

        docs = self.db.similarity_search(enhanced_query, k*2)

        ranking_prompt = PromptTemplate(
            input_variables=["query", "documents"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        print("ranking docs")

        for doc in docs:
            input_data = {
                "query": enhanced_query,
                "doc": doc.page_content
            }
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

In [8]:
from typing import List
class SelectedIndices(BaseModel):
    indices: List[int] = Field(
        description="Indices of selected documents", 
        example=[0, 1, 2, 3]
    )

class SubQueries(BaseModel):
    sub_queries: List[str] = Field(
        description="List of sub-queries for comprehensive analysis", 
        example=["What is the population of New York?", "What is the GDP of New York?"]
    )

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print("retrieving analytical")
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )

        subquery_chain = enhanced_query_prompt | self.llm.with_structured_output(SubQueries)
        input_data = {
            "query": query,
            "k": k
        }
        subqueries = subquery_chain.invoke(input_data).sub_queries
        print(f"sub queries for comprehensive analysis: {subqueries}")

        all_docs = []
        for sub_query in subqueries:
            all_docs.extend(self.db.similarity_search(sub_query, 2))

        deversity_prompt = PromptTemplate(
            input_variables=["query", "documents", "k"],
            template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {documents}\n
            Return only the indices of selected documents as a list of integers."""
        )

        diversity_chain = deversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')
        
        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]

In [9]:
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=3):
        print("retrieving opinion")
        # Use LLM to identify potential viewpoints
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        opinion_prompt = PromptTemplate(
            input_variables=["query", "documents", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )

        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)

        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')
        
        return [all_docs[int(i)] for i in selected_indices.split() if i.isdigit() and int(i) < len(all_docs)]

In [10]:
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4, user_context=None):
        print("retrieving contextual")

        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )

        context_chain = context_prompt | self.llm
        input_data = {"query": query, "context": user_context or "No specific context provided"}
        contextualized_query = context_chain.invoke(input_data).content
        print(f'contextualized query: {contextualized_query}')

        docs = self.db.similarity_search(contextualized_query, k*2)

        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"

        )

        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)
        print("ranking docs")

        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": user_context or "No specific context provided", "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))


        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

In [26]:
class AdaptiveRetriever:
    def __init__(self, texts: List[str]) -> None:
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(texts),
            "Analytical": AnalyticalRetrievalStrategy(texts),
            "Opinion": OpinionRetrievalStrategy(texts),
            "Contextual": ContextualRetrievalStrategy(texts)
        }

    def get_relevant_documents(self, query:str) -> List[Document]:
        category = self.classifier.classify(query)
        print("Category: ", category)
        strategy = self.strategies[category]
        return strategy.retrieve(query)

In [15]:
class PydanticAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True
    def _get_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)
    
    async def _aget_relevant_documents(self, query:str) -> List[Document]:
        return self._get_relevant_documents(query)

In [24]:
class AdaptiveRAG:
    def __init__(self, texts: List[str]):
        adaptive_retriever = AdaptiveRetriever(texts)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = llm
        
        # Create a custom prompt
        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        
        # Create the LLM chain
        self.llm_chain = prompt | self.llm
        
      

    def answer(self, query: str) -> str:
        docs = self.retriever.invoke(query)
        print("docs", docs)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)

In [20]:
texts = [
    "The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
    ]
rag_system = AdaptiveRAG(texts)

In [32]:
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")


classifying query


BadRequestError: Error code: 400 - {'error': {'message': "Invalid schema for function 'categories_options': 'Factual' is not of type 'array'.", 'type': 'invalid_request_error', 'param': 'tools[0].function.parameters', 'code': 'invalid_function_parameters'}}

In [None]:

analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
print(f"Answer: {analytical_result}")

opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
print(f"Answer: {opinion_result}")

contextual_result = rag_system.answer("How does the Earth's position in the Solar System influence its habitability?").content
print(f"Answer: {contextual_result}")