# L7: Conversational RAG

<p style="background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>

In [None]:
import warnings
# Suppress all Python warnings for the rest of the session so they do not
# clutter the notebook's cell outputs.
warnings.filterwarnings('ignore')

## Import libraries

In [None]:
from ai21 import AI21Client
from ai21.models.chat import ChatMessage
import uuid
import time

<div style="background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px">
<p> 💻 &nbsp; <b>Access <code>requirements.txt</code> and <code>utils.py</code> files:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Open"</em>.</p>

<p> ⬇ &nbsp; <b>Download Notebooks:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Download as"</em> and select <em>"Notebook (.ipynb)"</em>.</p>

<p> 📒 &nbsp; For more help, please see the <em>"Appendix – Tips, Help, and Download"</em> Lesson.</p>
</div>

## Load API key and create AI21Client

In [None]:
from utils import get_ai21_api_key
# Obtain the AI21 API key through the helper in utils.py.
# NOTE(review): presumably read from the environment / a .env file — confirm in utils.py.
ai21_api_key = get_ai21_api_key()
# Client handle that is passed to call_convrag for every Conversational RAG request.
client = AI21Client(api_key=ai21_api_key)

In [None]:
from utils import call_convrag

# Running transcript of the chat. Every successful call to convrag_response
# appends the user turn and the assistant turn, so follow-up questions are
# answered with the earlier turns as context.
conversation_history = []


def convrag_response(message):
    """Send one user message to AI21 Conversational RAG and return the answer.

    The message is appended to the module-level ``conversation_history`` and
    the full history is passed to ``call_convrag``. On success the assistant's
    reply is appended to the history as well.

    Args:
        message: The user's question as plain text.

    Returns:
        The content of the first choice in the Conversational RAG response.

    Raises:
        Exception: Any error raised while calling the service or reading the
            first choice is re-raised unchanged, after the pending user turn
            has been removed from ``conversation_history`` so the transcript
            never ends with an unanswered user message.
    """
    conversation_history.append(ChatMessage(content=message, role="user"))
    try:
        chat_response = call_convrag(client, conversation_history)
        # the LLM response to user query
        response = chat_response.choices[0].content
    except Exception:
        # Roll back the user turn; otherwise the next request would carry a
        # dangling, unanswered user message.
        conversation_history.pop()
        raise

    # The retrieved evidence is available on chat_response.sources: each entry
    # exposes .text (the retrieved segment) and .file_name (the file holding
    # it). The list can be empty when nothing relevant was retrieved, so it
    # must not be indexed unconditionally (sources[0] would raise IndexError).
    conversation_history.append(ChatMessage(content=response, role="assistant"))
    return response

## Prompt the Conversational RAG

<p style="background-color:#f7fff8; padding:15px; border-width:3px; border-color:#e0f0e0; border-style:solid; border-radius:6px"> 🚨
&nbsp; <b>Different Run Results:</b> The output generated by AI chat models can vary with each execution due to their probabilistic nature. Don't be surprised if your results differ from those shown in the video.</p>

In [None]:
# First turn of the conversation: ask for a summary of the annual report.
message = "You are a financial analyst and what is the summary with Nvidia annual earnings report?"

# convrag_response also records this turn in conversation_history.
response = convrag_response(message)

print(response)

In [None]:
# Follow-up turn: sent together with the earlier turns held in
# conversation_history, so "the period" refers back to the previous exchange.
message = "How much did the Nvidia's revenue increase in the period?"

response = convrag_response(message)

print(response)

In [None]:
# Third turn in the same conversation: an investment-advice style question.
message = "Should I buy Nvidia stock now?"

response = convrag_response(message)

print(response)

## Create a Gradio chat app

In [None]:
import gradio as gr

# Simple question/answer form around convrag_response.
# Note: convrag_response appends to the shared module-level
# conversation_history, so every submission from this UI extends the same
# conversation (including the turns already sent from the cells above).
demo = gr.Interface(
    fn=convrag_response,
    inputs=[gr.Textbox(label="Your questions:", lines=2)],
    outputs=[gr.Textbox(label="AI21 Conversational RAG answer:", lines=2)],
    # Clickable sample questions shown under the form.
    examples=[
        "How have revenue, gross margin, and net income trended over the past year?",
        "What are the actions taken by the company about sustainability?",
        "What are the main risks of the company?",
        "How is the company allocating its capital (e.g., dividends, share repurchases, acquisitions)?",
        "Are there any concerning trends in operating cash flow?",
    ],
    title="Nvidia 10-K Q&A",
    # Fixed user-facing typo: "to retrieval insights" -> "to retrieve insights".
    description="Use AI21 Conversational RAG to retrieve insights from SEC filings",
    allow_flagging="never",
)


# server_name="0.0.0.0" makes the app listen on all network interfaces.
# NOTE(review): presumably required by the hosted notebook environment; on a
# personal machine this exposes the app to the local network — confirm intent.
demo.launch(server_name="0.0.0.0")

## RAG with the AI21 Jamba model in LangChain

In [None]:
from langchain_ai21 import ChatAI21
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

In [None]:
# Jamba chat model that serves as the generator at the end of the RAG chain.
llm = ChatAI21(
    model="jamba-1.5-large",
    max_tokens=4096,
    temperature=0.4,
    top_p=1,
)

In [None]:
# Load the Nvidia 10-K text file; the path is relative to the notebook's
# working directory.
loader = TextLoader("./Nvidia_10K_20240128.txt")
# The loaded result is handed to the text splitter in the next cell.
doc = loader.load()

In [None]:
# Break the filing into chunks of size 2000 with an overlap of 400 between
# neighbouring chunks, so content cut at a boundary still appears whole in one
# of them. (Sizes use the splitter's default length measure.)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=400)
documents = text_splitter.split_documents(doc)

In [None]:
# Sentence-transformers model used to embed both the chunks and the queries.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Embed every chunk and index it in a Chroma vector store.
# NOTE(review): no collection name or persist directory is given — presumably
# the default in-memory collection; re-running this cell without the final
# delete_collection() cell may index the chunks twice. Confirm.
vectorstore = Chroma.from_documents(documents, embedding=embeddings)

### Prompt template

In [None]:
# Prompt with two placeholders: {context} receives the retrieved text and
# {question} receives the user's question. The template tells the model to
# answer only from the context and to reply "answer not in the document"
# otherwise.
prompt = PromptTemplate.from_template(
    """You are an expert in answering questions based on provided context.
    Answer the question based on the provided context below to the best of your ability.
    The response must be complete, coherent and concise.
    If the answer is not contained in the context, please respond with "answer not in the document"\n
    Here is the context you should use to answer the question: \n
    <context>
    {context}
    </context> \n
    Based on the provided context, answer the following question: {question} \n
    Answer:"""
)

In [None]:
# Retriever over the vector store using maximal marginal relevance ("mmr")
# search, returning k=10 chunks per query.
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})

In [None]:
def format_docs(docs):
    """Join the ``page_content`` of each document into one string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute.

    Returns:
        The contents in input order, separated by a blank line; an empty
        string when ``docs`` is empty.
    """
    contents = [item.page_content for item in docs]
    return "\n\n".join(contents)

# RAG chain: the input question is (a) sent to the retriever, whose results
# are merged into a single string by format_docs to fill {context}, and
# (b) passed through unchanged to fill {question}. The filled prompt is sent
# to the LLM and StrOutputParser turns the model reply into a plain string.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

### Query

In [None]:
# Single question run through the full chain (retrieve -> prompt -> LLM).
q = "How has the company revenue and profit changed from last year?"

response = rag_chain.invoke(q)
print(f"Answer: {response}")

In [None]:
# Run only the retrieval step for the same question to inspect which chunks
# are supplied to the prompt as context.
docs = retriever.invoke(q)
# Bare expression so the notebook displays the retrieved documents.
docs

In [None]:
# Batch of analyst-style questions answered one after another by the chain.
questions = [
    "What are the main business risks for the company?",
    "What are the key financial metrics of the company?",
    "What is the profit growth of the company in the reporting period?",
    "Did the company have a cybersecurity incident based on the following SEC filing document?",
]

for q in questions:
    response = rag_chain.invoke(q)
    # One print call: separator rule, then the question, then the answer.
    print("=" * 80, f"Question: {q}", f"Answer: {response}", sep="\n")

In [None]:
# Clean up: delete the Chroma collection backing `vectorstore`. After this
# call the retriever and rag_chain above can no longer be queried until the
# vector store is rebuilt.
vectorstore.delete_collection()