# Prepare-Rewrite-Retrieve RAG flow

In this notebook we explore the retrieval and generation stages of the **prepare-then-rewrite-then-retrieve-then-read** framework, proposed by the authors of ["Meta Knowledge for Retrieval Augmented Large Language Models"](https://www.amazon.science/publications/meta-knowledge-for-retrieval-augmented-large-language-models) to build more accurate and richer RAG workflows.
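The *prepare* stage (building Q&A pairs, metadata and meta-knowledge summaries) is assumed to run offline in the DataIndexing.ipynb prerequisite. This notebook walks through the remaining *rewrite*, *retrieve* and *read* stages. The framework-free skeleton below is a minimal sketch of how those three stages chain together, assuming each stage is a plain callable. The `toy_*` functions are hypothetical stand-ins for the Bedrock LLM and FAISS calls used later.

```python
from typing import Callable, List, Tuple

QAPair = Tuple[str, str]  # (question, answer)


def rewrite_retrieve_read(
    query: str,
    mk_summary: str,
    rewrite: Callable[[str, str], List[str]],
    retrieve: Callable[[str], List[QAPair]],
    read: Callable[[str, str], str],
) -> str:
    """Run rewrite -> retrieve -> read; the prepare stage is assumed to have run offline."""
    # Rewrite: turn the user query into several targeted queries, guided by the summary
    queries = rewrite(query, mk_summary)
    # Retrieve: gather Q&A pairs for every rewritten query
    pairs: List[QAPair] = []
    for q in queries:
        pairs.extend(retrieve(q))
    # Read: answer the ORIGINAL query with the retrieved pairs as context
    context = "".join(f"Question:{q}\nAnswer:{a}\n\n" for q, a in pairs)
    return read(query, context)


# Hypothetical toy stages standing in for the LLM and vector store calls
def toy_rewrite(query: str, summary: str) -> List[str]:
    return [f"{query} (goal)", f"{query} (scope)"]


def toy_retrieve(query: str) -> List[QAPair]:
    return [(query, "toy answer")]


def toy_read(query: str, context: str) -> str:
    return f"{context.count('Question:')} pairs used to answer: {query}"


print(rewrite_retrieve_read("Why?", "toy summary", toy_rewrite, toy_retrieve, toy_read))
```

Note that the final answer is generated for the original query, while the rewritten queries are only used to collect context.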

## Pre-requisites

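To run this notebook you need:

* An IAM role executing the notebook with permissions to invoke Amazon Bedrock models
* Model access to Amazon Nova Pro and Amazon Titan Text Embeddings V2 in Amazon Bedrock
* To have executed the [DataIndexing.ipynb](./DataIndexing.ipynb) notebook (this notebook loads its outputs from `./data/`)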

Additionally, the following Python packages are required:

In [None]:
!pip install -U boto3 langchain langchain-aws langchain-community python-dotenv faiss-cpu

In [2]:
import json
import logging
from enum import Enum

import faiss
import langchain_core
from dotenv import load_dotenv

from langchain_aws import ChatBedrockConverse
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import FAISS

# Project-local prompt and structured-output modules
from prompts.dataRetrieval.generate_query_augmentation_prompts import get_query_augmentation_prompt_selector, get_structured_questions_prompt_selector
from prompts.dataRetrieval.generate_qa_kb_prompts import get_kb_qa_prompt_selector
from structured_output.questions import Questions
from structured_output.answers import Answer

In [None]:
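from langchain_core.globals import set_debug

logger = logging.getLogger(__name__)
set_debug(False)  # set to True to trace every LangChain call
load_dotenv()  # load environment variables (e.g. AWS settings) from a local .env file, if present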

In [4]:
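BEDROCK_MODEL_ID = "us.amazon.nova-pro-v1:0"  # Nova Pro through the US cross-region inference profile
EMBEDDINGS_MODEL_ID = "amazon.titan-embed-text-v2:0"
EMBEDDING_SIZE = 1024  # must match the dimension of the vectors in the saved FAISS index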

### Meta-KB-Summaries and VectorStore

The knowledge base used for retrieval has two components:

* A vector store: as in any other knowledge base, the main component is a vector store that persists the embeddings. This implementation uses [FAISS](https://faiss.ai/index.html) as the vector store.
* A meta-knowledge base: the summaries of each partition of the knowledge base, one per persona-perspective combination. This implementation stores it in a plain Python dictionary.
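A minimal sketch of how the two components line up, assuming the summaries are keyed as `"{persona}-{perspective}"` (the format this notebook uses to read `meta_kb_summaries`) and that each indexed Q&A record stores `persona`, `perspective`, `question` and `answer` in its metadata, as the retrieval code in this notebook reads them. The personas, summaries and records are hypothetical toy values.

```python
from typing import Dict, List

# Hypothetical persona -> perspectives mapping (same shape as AnalysisPersonas in this notebook)
personas: Dict[str, List[str]] = {
    "solutions architect": ["systems resiliency", "data governance"],
    "software developer": ["systems operations"],
}


def partition_key(persona: str, perspective: str) -> str:
    """Key of one knowledge-base partition, in the "{persona}-{perspective}" format."""
    return f"{persona}-{perspective}"


# Meta-knowledge base: one summary per partition (toy text)
meta_kb: Dict[str, str] = {
    partition_key(persona, view): f"Summary of {view} content for a {persona}"
    for persona, views in personas.items()
    for view in views
}

# Vector store records: each Q&A pair carries its partition in the metadata,
# so a search can be restricted to the partition whose summary guided the rewrite
records = [
    {"persona": "solutions architect", "perspective": "systems resiliency",
     "question": "How is failover handled?", "answer": "Toy answer."},
]

# Every record must belong to a partition that has a summary
for record in records:
    key = partition_key(record["persona"], record["perspective"])
    assert key in meta_kb, f"no meta-knowledge summary for partition {key!r}"

print(sorted(meta_kb))
```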

In [5]:
# The meta-knowledge base is a plain Python dictionary keyed by "{persona}-{perspective}"
with open("./data/meta_kb_summaries.json", "r") as f:
    meta_kb_summaries = json.load(f)

# The knowledge base is a FAISS vector store saved by the indexing notebook
embeddings_model = BedrockEmbeddings(
    model_id=EMBEDDINGS_MODEL_ID,
    model_kwargs={"dimensions": EMBEDDING_SIZE},
    region_name="us-east-1"
)

# load_local restores the saved index, docstore and document metadata, so no empty
# faiss index has to be created here. allow_dangerous_deserialization is required
# because the docstore is pickled: only load files you created yourself.
vector_store = FAISS.load_local(
    folder_path="./data/faiss_index",
    embeddings=embeddings_model,
    allow_dangerous_deserialization=True
)
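LangChain's `FAISS.similarity_search` applies a metadata `filter` after the vector search: it first fetches `fetch_k` nearest neighbours (20 by default) and then keeps the matching ones, truncated to `k`. The pure-Python sketch below mimics that behaviour with exact L2 search on toy 2-D vectors, assuming the saved index is a flat L2 index. It is an illustration of the post-filtering effect, not LangChain's code, and shows why a small persona-perspective partition can return fewer than `k` results.

```python
import math
from typing import Dict, List, Optional, Tuple

# Toy corpus of (vector, metadata) pairs; 2-D vectors keep the distances readable
corpus: List[Tuple[List[float], Dict[str, str]]] = [
    ([0.0, 0.1], {"persona": "developer"}),
    ([0.0, 0.2], {"persona": "developer"}),
    ([0.0, 0.3], {"persona": "architect"}),
    ([0.0, 0.4], {"persona": "developer"}),
    ([5.0, 5.0], {"persona": "architect"}),
]


def search(
    query: List[float],
    k: int,
    flt: Optional[Dict[str, str]] = None,
    fetch_k: int = 20,
) -> List[Dict[str, str]]:
    """Exact L2 search; with a filter, fetch fetch_k candidates first, then filter and keep k."""
    n = k if flt is None else fetch_k
    ranked = sorted(corpus, key=lambda item: math.dist(query, item[0]))[:n]
    hits = [
        meta for _, meta in ranked
        if flt is None or all(meta.get(key) == value for key, value in flt.items())
    ]
    return hits[:k]


# Only the 3 nearest vectors are considered, so a single "architect" record survives the filter
print(len(search([0.0, 0.0], k=2, flt={"persona": "architect"}, fetch_k=3)))
# A larger fetch_k lets the filter reach the second, more distant "architect" record
print(len(search([0.0, 0.0], k=2, flt={"persona": "architect"}, fetch_k=5)))
```

When the index holds many partitions, passing a larger `fetch_k` to `vector_store.similarity_search(...)` is one way to make it more likely that `k=5` filtered results come back.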

### Type definitions

In [6]:
# Customize according to the types of documents processed by the application
class DocumentTypes(Enum):
    SYSTEM_ARCHITECTURE = "systems architecture"
    SECURITY = "information technology security"
    DATA_GOVERNANCE = "data governance"
    TECH_STRATEGY = "tech strategy"
    MANAGEMENT = "management"

class AnalysisPerspectives(Enum):
    SECURITY = "software security engineer"
    DATA_GOVERNANCE = "data governance"
    RESILIENCY = "systems resiliency"
    SYS_OPS = "systems operations"

# Persona definition for generating and answering QA
AnalysisPersonas = {
    "software security engineer": {
        "description": "It is responsible for ensuring that workloads have the necessary security controls in place",
        "perspectives": [AnalysisPerspectives.SECURITY.value, AnalysisPerspectives.DATA_GOVERNANCE.value]
    },
    "solutions architect": {
        "description": "It is responsible for designing scalable and cost-efficient software solutions",
        "perspectives": [AnalysisPerspectives.RESILIENCY.value, AnalysisPerspectives.DATA_GOVERNANCE.value, AnalysisPerspectives.SECURITY.value]
    },
    "software developer": {
        "description": "Implements the system functionalities",
        "perspectives": [AnalysisPerspectives.SYS_OPS.value, AnalysisPerspectives.RESILIENCY.value, AnalysisPerspectives.SECURITY.value]
    }
}

In [None]:
ANALYSIS_PERSONNA = "solutions architect"
ANALYSIS_PERSPECTIVE = AnalysisPersonas[ANALYSIS_PERSONNA]["perspectives"][0]

print(f"Using persona: {ANALYSIS_PERSONNA}")
print(f"Using perspective: {ANALYSIS_PERSPECTIVE}")

## Helper functions

In [8]:
def qa_chatbot_answer(
    user_query,
    role,
    perspective,
    context
):
    """Answer `user_query` from a role/perspective point of view, grounded on `context`.

    Returns an `Answer` structured-output object (from `structured_output.answers`).
    """
    rag_llm = ChatBedrockConverse(
        model=BEDROCK_MODEL_ID,
        temperature=0.4,
        max_tokens=1000,
    )

    # Load the English knowledge-base Q&A prompt for this model ID
    LLM_KB_QA_PROMPT_SELECTOR = get_kb_qa_prompt_selector(lang="en")
    gen_kb_qa_prompt = LLM_KB_QA_PROMPT_SELECTOR.get_prompt(BEDROCK_MODEL_ID)

    # prompt -> LLM constrained to the Answer schema
    kb_qa_generate = gen_kb_qa_prompt | rag_llm.with_structured_output(Answer)

    return kb_qa_generate.invoke(
        {
            "question": user_query,
            "role": role,
            "perspective": perspective,
            "context": context
        }
    )

def augment_user_query(
    role,
    user_query,
    mk_summary,
):
    """Rewrite `user_query` into several queries guided by the meta-knowledge summary.

    Returns a `Questions` structured-output object whose `questions` attribute
    holds the augmented queries.
    """
    query_augmentation_llm = ChatBedrockConverse(
        model=BEDROCK_MODEL_ID,
        temperature=0.4,
        max_tokens=2000,
    )

    # Load the English query-augmentation prompt for this model ID
    LLM_AUGMENT_QUERY_PROMPT_SELECTOR = get_query_augmentation_prompt_selector(lang="en")
    gen_queries_prompt = LLM_AUGMENT_QUERY_PROMPT_SELECTOR.get_prompt(BEDROCK_MODEL_ID)

    # prompt -> LLM constrained to the Questions schema
    structured_queries_generate = gen_queries_prompt | query_augmentation_llm.with_structured_output(Questions)

    return structured_queries_generate.invoke(
        {
            "role": role,
            "mk_summary": mk_summary,
            "user_query": user_query
        }
    )


def prepare_rewrite_retrieve_rag_qa(
    query,
    persona,
    perspective
):
    """Answer a user query with the prepare-rewrite-retrieve framework from a persona-perspective point of view."""
    print(query)

    # Rewrite: augment the user query with the meta-knowledge summary of this partition
    augmented_queries = augment_user_query(
        role=persona,
        user_query=query,
        mk_summary=meta_kb_summaries[f"{persona}-{perspective}"]
    )
    print(augmented_queries)

    # Retrieve: search once per augmented query, using the persona and
    # perspective passed to this function as metadata partition keys
    qa_pairs = []
    for question in augmented_queries.questions:
        results = vector_store.similarity_search(
            query=question,
            k=5,
            filter={"persona": persona, "perspective": perspective}
        )
        qa_pairs.extend(
            (result.metadata["question"], result.metadata["answer"]) for result in results
        )

    qa_str = "".join(f"Question:{q}\nAnswer:{a}\n\n" for q, a in qa_pairs)
    print(qa_str)

    # Read: answer the original query with the retrieved Q&A pairs as context
    return qa_chatbot_answer(
        user_query=query,
        role=persona,
        perspective=perspective,
        context=qa_str
    )



## Question answering workflow with RAG

In [9]:
QUERY = "What is the purpose of the multi-agent compliance analysis project?"

### Query augmentation

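In this step the original query is augmented using the meta-knowledge summary of the selected persona-perspective combination: the LLM rewrites it into several more specific queries that target the content of that partition.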
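The real prompt is loaded from `prompts.dataRetrieval.generate_query_augmentation_prompts` and is not shown here. As a rough illustration only, the sketch below fills a hypothetical template with the same three variables the notebook passes to `invoke()` (`role`, `mk_summary`, `user_query`). `AUGMENT_TEMPLATE` and `build_augmentation_prompt` are assumptions, not the project's prompt.

```python
# Hypothetical template for illustration only: the notebook's real prompt
# lives in prompts.dataRetrieval.generate_query_augmentation_prompts and may differ.
AUGMENT_TEMPLATE = (
    "You are helping a {role}.\n"
    "The part of the knowledge base relevant to them is summarised as:\n"
    "{mk_summary}\n\n"
    "Rewrite the question below into several specific questions that this part of "
    "the knowledge base can answer.\n"
    "Question: {user_query}"
)


def build_augmentation_prompt(role: str, mk_summary: str, user_query: str) -> str:
    """Fill the template with the same three variables the notebook passes to invoke()."""
    return AUGMENT_TEMPLATE.format(role=role, mk_summary=mk_summary, user_query=user_query)


prompt = build_augmentation_prompt(
    role="solutions architect",
    mk_summary="Toy summary: resiliency patterns, failover, backups.",
    user_query="What is the purpose of the project?",
)
print(prompt)
```

Conditioning the rewrite on the summary is what lets the LLM produce queries phrased in terms the partition actually covers.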

In [10]:
meta_kb_summary = meta_kb_summaries[f"{ANALYSIS_PERSONNA}-{ANALYSIS_PERSPECTIVE}"]

In [11]:
augmented_queries = augment_user_query(
    role=ANALYSIS_PERSONNA,
    user_query=QUERY,
    mk_summary=meta_kb_summary
)

In [None]:
print("Augmenting the query:")
print(QUERY)

print("\n\nExisting summary:")
print(meta_kb_summary)

print("\n\nResulting queries:")
for question in augmented_queries.questions:
    print(question)

### Question answering

We can now retrieve information from the knowledge base using the augmented queries instead of the original query. The retrieved Q&A pairs will be passed to the LLM as context for question answering.
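The loop below issues one search per augmented query and concatenates the results, so a Q&A pair returned by several queries appears several times in the context. A small sketch of the same fan-out with order-preserving deduplication, an optional refinement rather than something the notebook does; the toy retriever and its data are hypothetical stand-ins for `vector_store.similarity_search`.

```python
from typing import Callable, Dict, List, Tuple

QAPair = Tuple[str, str]  # (question, answer)


def retrieve_for_queries(
    queries: List[str], retrieve: Callable[[str], List[QAPair]]
) -> List[QAPair]:
    """Run one search per rewritten query and merge the hits, dropping exact duplicates.

    dict.fromkeys keeps the first occurrence of each pair and preserves retrieval order.
    """
    merged: List[QAPair] = []
    for query in queries:
        merged.extend(retrieve(query))
    return list(dict.fromkeys(merged))


# Hypothetical toy index standing in for the FAISS vector store
toy_index: Dict[str, List[QAPair]] = {
    "goal": [("What is the goal?", "Compliance checks."), ("Who uses it?", "Auditors.")],
    "users": [("Who uses it?", "Auditors.")],
}

pairs = retrieve_for_queries(["goal", "users"], lambda q: toy_index.get(q, []))
print(pairs)
```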

In [13]:
qa_str = ""
qa_pairs = []

# Retrieve context from knowledge base using metadata as partition keys
for question in augmented_queries.questions:
    
    results = vector_store.similarity_search(
        query=question,
        k=5,
        filter={"persona": ANALYSIS_PERSONNA, "perspective": ANALYSIS_PERSPECTIVE}
    )
    retrieved_qa_pairs = [(result.metadata["question"], result.metadata["answer"]) for result in results]

    qa_pairs.extend(retrieved_qa_pairs)

Looking at the results, we can observe that the retrieved information is indeed more comprehensive, and also more fine-grained, because the knowledge base indexes Q&A pairs rather than document chunks.

In [None]:
for qa_pair in qa_pairs:
    print(f"Question: {qa_pair[0]}")
    print(f"Answer: {qa_pair[1]}\n\n")

### Answering the original query with the context from augmented queries

We use the information retrieved for the augmented queries as context to answer the original question.
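The context passed to the LLM is a plain string of `Question:`/`Answer:` blocks. The sketch below shows that formatting and why it should join on an empty string: `str.join` inserts the string it is called on between the items, so calling it on a non-empty string (for example `qa_str` after the cell has already run once) would splice the old context between every pair.

```python
from typing import List, Tuple


def format_context(qa_pairs: List[Tuple[str, str]]) -> str:
    """Render Q&A pairs as consecutive Question/Answer blocks separated by blank lines."""
    return "".join(f"Question:{q}\nAnswer:{a}\n\n" for q, a in qa_pairs)


pairs = [("What is X?", "A tool."), ("Who owns it?", "Team A.")]
ctx = format_context(pairs)
print(repr(ctx))

# Pitfall: the string join() is called on becomes the SEPARATOR between items
print("OLD".join(["a", "b"]))
```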

In [None]:
qa_str = "".join(f"Question:{q}\nAnswer:{a}\n\n" for q, a in qa_pairs)

answer = qa_chatbot_answer(
    user_query=QUERY,
    role=ANALYSIS_PERSONNA,
    perspective=ANALYSIS_PERSPECTIVE,
    context=qa_str
)

In [None]:
answer

In [None]:
print(QUERY)
print(qa_str)
print(answer.answer)

## Answer multiple questions

In [18]:
questions = [
    "What is the purpose of the multi-agent compliance analysis project?",
    "What security measures are in place to protect the data?",
    "How is the multiagent orchestration done?"
]

In [None]:
answer = prepare_rewrite_retrieve_rag_qa(
    questions[0],
    ANALYSIS_PERSONNA,
    ANALYSIS_PERSPECTIVE
)

In [None]:
print(f"Q:{questions[0]}")
print(f"A:{answer.answer}")

In [None]:
answer
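The cells above define three questions but only answer `questions[0]`. A minimal sketch of answering every question in turn, keeping the results keyed by question; `answer_all` is a hypothetical helper and the lambda is a toy stand-in for the notebook's pipeline.

```python
from typing import Callable, Dict, List, TypeVar

T = TypeVar("T")


def answer_all(questions: List[str], qa_fn: Callable[[str], T]) -> Dict[str, T]:
    """Answer each question independently and return the results keyed by question."""
    return {question: qa_fn(question) for question in questions}


# Toy stand-in for prepare_rewrite_retrieve_rag_qa (hypothetical)
results = answer_all(["Q1?", "Q2?"], lambda q: f"answer to {q}")
for question, result in results.items():
    print(f"Q:{question}\nA:{result}\n")
```

In the notebook, `qa_fn` could be `lambda q: prepare_rewrite_retrieve_rag_qa(q, ANALYSIS_PERSONNA, ANALYSIS_PERSPECTIVE)`, with `result.answer` printed instead of `result`. Each question then triggers one query-augmentation call, one embedding call per augmented query, and one answering call to Bedrock.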