# NLP Experiment Using LLMs for Question Answering with Retrieval-Augmented Generation (RAG)

## Overview
This notebook compares how a Large Language Model (LLM) answers questions under different configurations: a plain LLM chain with no external data, and retrieval-augmented generation (RAG) chains that supply the model with passages retrieved from an external text file.

### Key Features:
- **Plain LLM Chain:** Implements a basic LLM chain using a Hugging Face model with prompt templates.
- **Retrieval Augmented Generation (RAG):** 
  - Retrieves information from an external file (`data/cats_content.txt`) to support answering questions.
  - Uses the FAISS vector database for document retrieval based on embeddings generated by a sentence-transformer model.
- **Multiple Experiment Methods:**
  - Plain LLM response without external data.
  - RAG chain with source tracking using `RetrievalQA`.
  - Experimental method using `RetrievalQAWithSourcesChain` to provide answers with sources.
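
The retrieval step listed above can be sketched without FAISS or a sentence-transformer. The toy version below is only an illustration under stated assumptions: `embed` is a hypothetical bag-of-words stand-in for the embedding model, the search is brute force rather than a FAISS index, and the query is made up for the example. The chunks are copied from the retrieval output logged later in this notebook.

```python
import math
import re
from collections import Counter
from typing import List


def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (assumption: stands in for the real model)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query: str, chunks: List[str], k: int = 2) -> List[str]:
    """Return the k chunks most similar to the query (brute force, no index)."""
    query_vec = embed(query)
    ranked = sorted(chunks, key=lambda c: cosine(query_vec, embed(c)), reverse=True)
    return ranked[:k]


chunks = [
    "Cats can jump up to six times their length.",
    "Cats' claws all curve downward, which means that they can't climb down trees head-first. Instead,",
    "Cats are nearsighted, but their peripheral vision and night vision are much better than that of",
]
top = retrieve("How far can cats jump?", chunks, k=1)
print(top)
```

A real embedding model matches on meaning rather than shared words, so this sketch shows only the ranking mechanics, not the retrieval quality to expect.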

### Implemented Components:
- **Hugging Face Hub Integration:** Uses a T5 model (`google/flan-t5-large`) for the LLM tasks.
- **Embedding Model:** Utilizes `sentence-transformers/all-MiniLM-L6-v2` to generate embeddings for document retrieval.
- **Prompt Templates:** Customizable prompt templates for both plain LLM and RAG setups.
- **FAISS Vector Store:** Efficient vector-based retrieval for documents to support the RAG chain.
- **Chain Setup Functions:** Includes functions to initialize and configure the plain LLM, RAG, and RAG with sources.
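
A rough sketch of how the two prompt setups differ, assuming the RAG prompt exposes `context` and `question` variables. `RAG_TEMPLATE` is a hypothetical placeholder, not the text of the prompt pulled from the hub, and `build_rag_prompt` is an illustrative helper that mimics what a "stuff"-style chain does: it pastes the retrieved chunks into one context string.

```python
from typing import List

# Plain template, adapted from the one defined in this notebook's main block.
PLAIN_TEMPLATE = (
    "Answer the question. If you don't know the answer, just say that you "
    "don't know, don't try to make up an answer.\n"
    "Question: {question}\n"
    "Helpful Answer:"
)

# Hypothetical RAG template: an assumption for illustration, NOT the hub prompt.
RAG_TEMPLATE = (
    "Use the following context to answer the question.\n"
    "Context:\n{context}\n"
    "Question: {question}\n"
    "Answer:"
)


def build_rag_prompt(question: str, retrieved_chunks: List[str]) -> str:
    """Join the retrieved chunks into one context string and fill the template."""
    context = "\n\n".join(retrieved_chunks)
    return RAG_TEMPLATE.format(context=context, question=question)


question = "Are cats good jumpers?"
plain_prompt = PLAIN_TEMPLATE.format(question=question)
rag_prompt = build_rag_prompt(
    question, ["Cats can jump up to six times their length."]
)
print(plain_prompt)
print(rag_prompt)
```

The only difference the model sees is the extra context: the plain prompt leaves it to rely on what it memorised, while the RAG prompt puts the relevant passage in front of it.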

In [None]:
!pip install -q accelerate==0.25.0 bertopic==0.15.0 datasets==2.14.4
!pip install -q faiss-cpu==1.7.4 langchain==0.0.348 langchainhub==0.1.14
!pip install -q sentence-transformers==2.2.2 sentencepiece==0.1.99 transformers==4.24.0

In [None]:
%env HUGGINGFACEHUB_API_TOKEN=<YOUR_API_TOKEN>

In [None]:
import logging

from langchain import HuggingFaceHub, hub
from langchain.chains import LLMChain, RetrievalQA, RetrievalQAWithSourcesChain
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

In [None]:
def get_logger(name: str = __name__) -> logging.Logger:
    logging.basicConfig(
        format="%(asctime)s:%(module)s:%(funcName)s:%(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    return logger


logger = get_logger(__name__)

In [None]:
class QuestionAnsweringExperiment:
    """
    Sets up an experiment that tests how well a model answers questions with:
    - a plain LLM chain: just an LLM and a prompt, built with LangChain.
    - a retrieval-augmented generation (RAG) chain: a chain with access to the content of an external file.
    """

    def __init__(
        self,
        repo_id: str,
        external_data_path: str,
        embedding_model_name: str,
        plain_lmm_prompt_template: str,
        rag_llm_prompt_template_name: str,
    ):
        """
        Initializes the experiment setup for testing models in answering questions using both plain LLM chain and Retrieval Augmented Generation (RAG) chain

        :param repo_id: The repository ID for the Hugging Face model
        :param external_data_path: Path to the external data file to be used for RAG
        :param embedding_model_name: Name of the model for generating embeddings
        :param plain_lmm_prompt_template: Prompt template for the plain LLM model
        :param rag_llm_prompt_template_name: Name of the RAG prompt template on the LangChain Hub
        """
        self.llm = HuggingFaceHub(
            repo_id=repo_id, model_kwargs={"temperature": 0.05, "max_length": 200}
        )

        self.embedding_model = None
        self.vectorstore = None

        self.init_embedding_model(embedding_model_name)

        self.create_vector_database(
            self.load_documents_from_file(
                external_data_path, chunk_size=100, chunk_overlap=10
            ),
            self.embedding_model,
        )

        self.setup_plain_llm_chain(plain_lmm_prompt_template)

        self.setup_rag_llm_chain(rag_llm_prompt_template_name)

        self.setup_rag_with_source_llm_chain(rag_llm_prompt_template_name)

    def setup_plain_llm_chain(self, prompt_template: str):
        """
        Sets up the plain LLM chain by initializing an LLMChain with the LLM and the provided prompt template

        :param prompt_template: The prompt template to use with the plain LLM chain
        """
        logger.info("Setting up plain LLM chain")
        self.plain_llm_chain = LLMChain(
            llm=self.llm, prompt=PromptTemplate.from_template(prompt_template)
        )

    def init_embedding_model(self, model_name: str):
        """
        Initializes the embedding model from the Hugging Face Hub using the specified model name

        :param model_name: Name of the model to be used for generating embeddings
        """
        logger.info(f"Initializing embedding model: {model_name}")
        if self.embedding_model is None:
            self.embedding_model = HuggingFaceEmbeddings(model_name=model_name)
        else:
            logger.info("Embedding model has already been initialized")

    def create_vector_database(self, documents: list, embedding_model):
        """
        Creates a FAISS vector database, adding documents and their corresponding embeddings generated using the provided embedding model

        :param documents: List of documents to be added to the database
        :param embedding_model: The embedding model used to generate embeddings for the documents
        """
        logger.info("Creating FAISS vector database")
        self.vectorstore = FAISS.from_documents(documents, embedding_model)

    def load_documents_from_file(
        self, file_path: str, chunk_size: int = 100, chunk_overlap: int = 10
    ):
        """
        Loads the text file at the specified path and splits it into overlapping chunks

        :param file_path: Path to the file containing the documents
        :param chunk_size: Maximum size of each text chunk, in characters (default is 100)
        :param chunk_overlap: Maximum overlap between neighbouring chunks, in characters (default is 10)
        :return: A list of documents split into chunks
        """
        logger.info(f"Loading documents from file: {file_path}")
        loader = TextLoader(file_path)
        documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
        )
        docs_splitted = text_splitter.split_documents(documents)
        return docs_splitted

    def download_prompt_from_hub(self, prompt_name):
        """
        Downloads a prompt template by its name from the LangChain Hub.

        :param prompt_name: The name of the prompt template on the LangChain Hub
        :return: The downloaded prompt template
        """
        logger.info(f"Downloading prompt template: {prompt_name}")
        prompt_template = hub.pull(prompt_name)
        return prompt_template

    def setup_rag_llm_chain(
        self, rag_llm_prompt_template_name
    ):
        """
        Sets up the Retrieval Augmented Generation (RAG) LLM chain: downloads the prompt template and builds a RetrievalQA chain on top of the existing FAISS vector store. The documents, embedding model and vector store must already be initialized (this is done in __init__)

        :param rag_llm_prompt_template_name: Name of the RAG prompt template on the LangChain Hub
        """
        prompt_template = self.download_prompt_from_hub(rag_llm_prompt_template_name)

        logger.info("Setting up RAG LLM chain")
        self.rag_llm_chain = RetrievalQA.from_llm(
            llm=self.llm,
            prompt=prompt_template,
            retriever=self.vectorstore.as_retriever(),
        )

    def setup_rag_with_source_llm_chain(
        self, rag_llm_prompt_template_name
    ):
        """
        Sets up two RAG chains that track source documents: a RetrievalQA chain with return_source_documents=True and a RetrievalQAWithSourcesChain. The downloaded prompt is passed only to the first chain

        :param rag_llm_prompt_template_name: Name of the RAG prompt template on the LangChain Hub
        """
        prompt_template = self.download_prompt_from_hub(rag_llm_prompt_template_name)

        logger.info("Setting up simple RAG LLM chain with sources")
        self.simple_rag_with_source_llm_chain = RetrievalQA.from_llm(
            llm=self.llm,
            prompt=prompt_template,
            retriever=self.vectorstore.as_retriever(),
            return_source_documents=True,
        )

        logger.info("Setting up experimental RAG LLM chain with sources")
        self.rag_with_source_llm_chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(),
        )

    def predict_with_llm_chain(self, query):
        """
        Generates a prediction for a given query using the plain LLM chain

        :param query: The query string to be answered
        :return: The generated answer from the plain LLM chain
        :raises AttributeError: If the plain_llm_chain attribute does not exist
        """
        if not hasattr(self, "plain_llm_chain"):
            raise AttributeError("Please set up a chain before calling predict")
        return self.plain_llm_chain.run(query)

    def predict_with_rag_chain(self, query):
        """
        Generates a prediction for a given query using the RAG LLM chain

        :param query: The query string to be answered
        :return: The generated answer from the RAG LLM chain
        :raises AttributeError: If the rag_llm_chain attribute does not exist
        """
        if not hasattr(self, "rag_llm_chain"):
            raise AttributeError(
                "RAG LLM chain is not set up. Please set it up before calling predict"
            )

        response = self.rag_llm_chain.run(query)

        return response

    def predict_rag_with_source_chain(self, query):
        """
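        Answers a query with three RAG variants and returns each answer together with its source documents

        :param query: The query string to be answered
        :return: A dict keyed by method name ("dumb", "simple", "experimental"); each value holds "answer" and "sources"
        :raises AttributeError: If any of the RAG chains has not been set up
        """
        # Fail early with a clear message if a chain is missing
        for chain_name in (
            "rag_llm_chain",
            "simple_rag_with_source_llm_chain",
            "rag_with_source_llm_chain",
        ):
            if not hasattr(self, chain_name):
                raise AttributeError(
                    f"{chain_name} is not set up. Please set it up before calling predict"
                )

        results = {}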

        # Manually getting the docs
        retrieved_docs = self.rag_llm_chain.retriever.invoke(query)
        response = self.rag_llm_chain.run(query)
        results["dumb"] = {"answer": response, "sources": retrieved_docs}

        # The same method but from model (using return_source_documents=True)
        response = self.simple_rag_with_source_llm_chain(query)
        results["simple"] = {
            "answer": response["result"],
            "sources": response["source_documents"],
        }

        # Trying RetrievalQAWithSourcesChain
        response = self.rag_with_source_llm_chain(query)
        results["experimental"] = {
            "answer": response["answer"],
            "sources": response["sources"],
        }

        return results


if __name__ == "__main__":
    """
    Main entrypoint.
    1) Initialises the experiment as an instance of QuestionAnsweringExperiment with the given params.
    2) Initialises a set of test queries with their expected answers.
    3) Iterates over the queries and runs the "predict_with_llm_chain", "predict_with_rag_chain" and "predict_rag_with_source_chain" methods.
    4) Prints the generated answers for each setup alongside the expected answer.
    """

    qa_prompt_template = """Answer the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Question: {question}
    Helpful Answer:"""

    experiment = QuestionAnsweringExperiment(
        repo_id="google/flan-t5-large",
        external_data_path="data/cats_content.txt",
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
        plain_lmm_prompt_template=qa_prompt_template,
        rag_llm_prompt_template_name="rlm/rag-prompt",
    )

    queries = [
        ("Are cats good jumpers?", "yes"),
        ("How far can cat jump in compare with its own length?", "up to 6 times"),
        ("How many bones does cat have?", "230"),
        ("How many toes does each front paw of a cat has?", "five"),
    ]

    print("\nStarting answering questions! \n")
    for query, answer in queries:
        print(f"\n'{query}' - expected answer: {answer} \n")
        plain_llm_answer = experiment.predict_with_llm_chain(query)
        print(f"\tPlain llm answer: {plain_llm_answer} \n")
        rag_llm_answer = experiment.predict_with_rag_chain(query)
        print(f"\tRAG llm answer: {rag_llm_answer} \n")
        print("\tRAG with sources llm answers:\n")

        results = experiment.predict_rag_with_source_chain(query)
        for key in results.keys():
            print(f"\t\t{key.capitalize()} method:")
            print(f"\t\t\tAnswer: {results[key]['answer']}")
            print(f"\t\t\tSources: {results[key]['sources']}\n")


INFO:__main__:Initializing embedding model: sentence-transformers/all-MiniLM-L6-v2
INFO:__main__:Loading documents from file: data/cats_content.txt
INFO:__main__:Creating FAISS vector database
INFO:__main__:Setting up plain LLM chain
INFO:__main__:Downloading prompt template: rlm/rag-prompt
INFO:__main__:Setting up RAG LLM chain
INFO:__main__:Downloading prompt template: rlm/rag-prompt
INFO:__main__:Setting up simple RAG LLM chain with sources
INFO:__main__:Setting up experimental RAG LLM chain with sources



Starting answering questions! 


'Are cats good jumpers?' - expected answer: yes 

	Plain llm answer: no 

	RAG llm answer: yes 

	RAG with sources llm answers:

		Dumb method:
			Answer: yes
			Sources: [Document(page_content='Cats can jump up to six times their length.', metadata={'source': 'data/cats_content.txt', 'start_index': 568}), Document(page_content='Cats’ claws all curve downward, which means that they can’t climb down trees head-first. Instead,', metadata={'source': 'data/cats_content.txt', 'start_index': 612}), Document(page_content='Cats are nearsighted, but their peripheral vision and night vision are much better than that of', metadata={'source': 'data/cats_content.txt', 'start_index': 371}), Document(page_content='Cats’ collarbones don’t connect to their other bones, as these bones are buried in their shoulder', metadata={'source': 'data/cats_content.txt', 'start_index': 744})]

		Simple method:
			Answer: yes
			Sources: [Document(page_content='Cats can jump up to s

Unfortunately, I ran into an unknown problem: the `RetrievalQAWithSourcesChain` chain returns no answer :(  
But it does return the correct source.
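
I have not established the cause. One possibility worth checking, stated as an assumption rather than a diagnosis: if the chain separates the answer from the sources by looking for a `SOURCES:` marker in the raw model output, then a model that emits only the sources line would leave the answer empty while the source still comes out right. The sketch below uses a hypothetical helper, `split_answer_and_sources`, which is not LangChain code, and made-up raw outputs to show that effect.

```python
import re
from typing import Tuple


def split_answer_and_sources(raw_output: str) -> Tuple[str, str]:
    """Hypothetical helper: split raw model text at a 'SOURCES:' marker."""
    if re.search(r"SOURCES?:", raw_output, re.IGNORECASE):
        answer, sources = re.split(
            r"SOURCES?:", raw_output, maxsplit=1, flags=re.IGNORECASE
        )
        # Keep only the first line after the marker as the sources string
        return answer.strip(), sources.split("\n")[0].strip()
    # No marker found: treat the whole text as the answer
    return raw_output.strip(), ""


# Made-up raw outputs, for illustration only
well_formed = split_answer_and_sources("yes\nSOURCES: data/cats_content.txt")
only_sources = split_answer_and_sources("SOURCES: data/cats_content.txt")
print(well_formed)
print(only_sources)
```

A way to test this guess would be to inspect the raw text the LLM produces inside the chain before any post-processing, for example by running the chain with verbose output enabled.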