<a href="https://colab.research.google.com/github/duper203/RAG_Techniques_with_upstage/upstage/21__adaptive_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adaptive Retrieval-Augmented Generation (RAG) System


## Key Components

1. Query Classifier: Determines the type of query (Factual, Analytical, Opinion, or Contextual).

2. Adaptive Retrieval Strategies: Four distinct strategies tailored to different query types:

  * Factual Strategy
  * Analytical Strategy
  * Opinion Strategy
  * Contextual Strategy
3. LLM Integration: LLMs are used throughout the process to enhance retrieval and ranking.

4. OpenAI GPT Model: Generates the final response using the retrieved documents as context.

## Method Details

1. Query Classification
2. Adaptive Retrieval Strategies
3. LLM-Enhanced Ranking
4. Response Generation

In [1]:
# Install (quietly, upgrading if present) the Upstage LangChain integration, LangChain core and community packages, the CPU build of FAISS, and PyMuPDF.
! pip3 install -qU langchain-upstage langchain langchain-community faiss-cpu PyMuPDF

In [2]:
import os
import asyncio
from google.colab import userdata

from langchain_upstage import ChatUpstage, UpstageEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA

from langchain_core.retrievers import BaseRetriever
from typing import Dict, Any
from langchain.docstore.document import Document
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import List

# Copy the Upstage API key from Colab's secret store (google.colab.userdata) into the process environment.
# ChatUpstage / UpstageEmbeddings are constructed later in this notebook without an explicit api_key, so they presumably read it from UPSTAGE_API_KEY.
os.environ["UPSTAGE_API_KEY"] = userdata.get("UPSTAGE_API_KEY")

In [3]:
class categories_options(BaseModel):
    """Category assigned to a user query."""

    # NOTE(review): with_structured_output likely forwards this schema (and its
    # docstring) to the model, so keep the docstring short and model-facing.
    # `examples` is the supported Field parameter; the previous bare `example=`
    # keyword was an arbitrary extra kwarg, which Pydantic v2 deprecates.
    category: str = Field(
        description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual",
        examples=["Factual"],
    )

class QueryClassifier:
    """Label a user query as Factual, Analytical, Opinion, or Contextual.

    The label comes from an Upstage chat model whose reply is parsed into the
    ``categories_options`` schema.
    """

    def __init__(self):
        template = (
            "Classify the following query into one of these categories: "
            "Factual, Analytical, Opinion, or Contextual.\n"
            "Query: {query}\n"
            "Category:"
        )
        self.llm = ChatUpstage()
        self.prompt = PromptTemplate(input_variables=["query"], template=template)
        structured_llm = self.llm.with_structured_output(categories_options)
        # prompt -> LLM -> categories_options instance
        self.chain = self.prompt | structured_llm

    def classify(self, query):
        """Return the category string the model assigns to *query*."""
        print("classifying query")
        parsed = self.chain.invoke(query)
        return parsed.category

In [4]:
class BaseRetrievalStrategy:
    """FAISS-backed retriever shared by the query-type-specific strategies.

    The constructor splits ``texts`` into chunks with ``CharacterTextSplitter``,
    embeds them with an Upstage embedding model, indexes them in a FAISS store
    (``self.db``), and creates a ``ChatUpstage`` client (``self.llm``) that
    subclasses use to rewrite queries and rank results.

    Args:
        texts: Raw text strings to index.
        chunk_size: ``chunk_size`` passed to ``CharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``CharacterTextSplitter``.
        embedding_model: Upstage embedding model name.

    Raises:
        ValueError: If splitting ``texts`` produces no documents (for example
            ``texts`` is empty), since there would be nothing to index.
    """

    def __init__(self, texts, *, chunk_size=800, chunk_overlap=0,
                 embedding_model="solar-embedding-1-large"):
        self.embeddings = UpstageEmbeddings(model=embedding_model)
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.documents = text_splitter.create_documents(texts)
        if not self.documents:
            # Fail early with a clear message rather than handing FAISS an empty document list.
            raise ValueError("BaseRetrievalStrategy needs at least one non-empty text to index")
        self.db = FAISS.from_documents(self.documents, self.embeddings)
        self.llm = ChatUpstage()

    def retrieve(self, query, k=4):
        """Return the ``k`` chunks most similar to ``query`` via plain vector search."""
        return self.db.similarity_search(query, k=k)

In [5]:
class relevant_score(BaseModel):
    """Relevance score of a document with respect to a query."""

    # NOTE(review): with_structured_output likely forwards this schema (and its
    # docstring) to the model, so keep the docstring short and model-facing.
    # `examples` is the supported Field parameter; the previous bare `example=`
    # keyword was an arbitrary extra kwarg, which Pydantic v2 deprecates.
    score: float = Field(
        description="The relevance score of the document to the query",
        examples=[8.0],
    )

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    """Retrieval for factual queries: rewrite the query, over-fetch, then LLM-rerank."""

    def retrieve(self, query, k=4):
        """Return the ``k`` chunks judged most relevant to a factual ``query``.

        Steps:
          1. Ask ``self.llm`` to rewrite ``query`` for better retrieval.
          2. Fetch ``2 * k`` candidates from ``self.db`` using the rewritten query.
          3. Have the LLM score each candidate's relevance; the reply is parsed
             into ``relevant_score``.
          4. Return the top ``k`` candidates by descending score.

        Args:
            query: The user's question.
            k: Number of documents to return.

        Returns:
            At most ``k`` documents, as returned by ``self.db.similarity_search``.
        """
        print("retrieving factual")
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content

        # Over-fetch so the reranker has candidates to discard.
        docs = self.db.similarity_search(enhanced_query, k=k * 2)

        # Chain that scores one (query, doc) pair; structured output yields a
        # relevant_score instance rather than free text.
        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        for doc in docs:
            input_data = {"query": enhanced_query, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # list.sort is stable, so equally scored docs keep their similarity-search order.
        ranked_docs.sort(key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]