#Build a RAG agent that can run on llama

%md
![Llama RAG implementation.png](./Llama RAG implementation.png "Llama RAG implementation.png")

In [0]:
from langchain_core.messages import HumanMessage, SystemMessage
from databricks_langchain import ChatDatabricks

In [0]:
from  dotenv import load_dotenv
_ = load_dotenv()

In [0]:
# Both clients talk to the same Databricks serving endpoint with identical
# sampling settings; the only difference is the extra return_json flag.
_shared_chat_kwargs = dict(endpoint='otc-lama-poc', temperature=0, max_tokens=250)

# Plain client (no JSON flag).
chat_model = ChatDatabricks(**_shared_chat_kwargs)

# Variant that additionally passes return_json=True.
# NOTE(review): presumably requests JSON-only output -- confirm that
# ChatDatabricks actually honours this argument.
chat_model_json = ChatDatabricks(**_shared_chat_kwargs, return_json=True)

In [0]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_nomic.embeddings import NomicEmbeddings

In [0]:
# Web pages that are downloaded, split and indexed into the vector store below.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

In [0]:
#load documents: one loader call per URL, each returning a list of documents
documents = [WebBaseLoader(page_url).load() for page_url in urls]

# Flatten the per-URL lists into a single list of documents.
doc_list = []
for loaded_pages in documents:
    doc_list.extend(loaded_pages)

In [0]:
#split documents into 1000-character chunks that overlap by 200 characters
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
doc_splits = text_splitter.split_documents(doc_list)

In [0]:
#load vector store: embed every chunk with Nomic (inference_mode="local")
# and index the result in an SKLearnVectorStore
embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
vectorstore = SKLearnVectorStore.from_documents(
    documents=doc_splits,
    embedding=embedding_model,
)

In [0]:
# ROUTER 
import json
from langchain_core.messages import HumanMessage, SystemMessage

#Prompt: system message asking the model to choose a datasource and reply as JSON
# with a single "datasource" key ('websearch' or 'vectorstore').
router_instructions = '''You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. For all else, and especially for current events, use web-search.
Return JSON format with single key, datasource, that is 'websearch' or 'vectorstore' depending on the question. 
No formatting or comments required. Pure json format.'''

#Test: send one current-events question and one on-topic question through the router
test_websearch = chat_model_json.invoke([SystemMessage(content=router_instructions), 
                                         HumanMessage(content="who won champions tropy 2025?")])
test_vectorstore = chat_model_json.invoke([SystemMessage(content=router_instructions), HumanMessage(content="What is prompt engineering?")])
# json.loads raises if the model wrapped its reply in anything other than pure JSON
print (json.loads(test_websearch.content))
print (json.loads(test_vectorstore.content))                                    

In [0]:
#RETRIEVE DOCUMENTS
# The number of results is a search parameter, so it has to be passed through
# search_kwargs; a bare k=3 keyword is not where the retriever reads its
# result count from, which leaves the vector store's default in effect.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# Smoke test: fetch the chunks closest to a sample query.
retriever.invoke("Agent memory")

In [0]:
#GRADE DOCUMENTS
#Doc Grade Instructions: system message for the relevance grader
doc_grade_instructions = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant."""

#Doc Grade Prompt: user message template; expects {document} and {question}
# and asks for JSON with a single "binary_score" key ('yes' or 'no').
doc_grade_prompt = """Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}. 

This carefully and objectively assess whether the document contains at least some information that is relevant to the question.

Return only JSON with single key, binary_score, that is 'yes' or 'no' score to indicate whether the document contains at least some information that is relevant to the question. No formatting or comments required. Pure json format."""

# Smoke test: grade the second retrieved document (index 1) for a sample question.
question = "What is chain of thought prompting?"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
doc_grade_prompt_formatted = doc_grade_prompt.format(document=doc_txt, question=question)

results = chat_model_json.invoke([SystemMessage(content=doc_grade_instructions), HumanMessage(content=doc_grade_prompt_formatted)])
# The reply must be pure JSON for json.loads to succeed.
print(json.loads(results.content))






In [0]:
#GENERATE ANSWER
# Prompt template for answer generation; expects {context} (the retrieved text)
# and {question}, and asks for an answer of at most three sentences.
rag_prompt = """You are an assistant for question-answering tasks. 
Here is the context to use to answer the question:

{context} 

Think carefully about the above context. 
Now, review the user question:

{question}

Provide an answer to this questions using only the above context. 
Use three sentences maximum and keep the answer concise.

Answer:"""
#post processing
def format_docs (docs):
    """Return the page_content of every document in docs, joined by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)

#test: retrieve for the current `question`, fill the RAG prompt, generate once
docs = retriever.invoke(question)
doc_txt = format_docs(docs)
rag_prompt_formated = rag_prompt.format(context=doc_txt, question=question)
rag_messages = [HumanMessage(content=rag_prompt_formated)]
generation = chat_model.invoke(rag_messages)
print (generation.content)

In [0]:
# HALLUCINATION CHECKER
# System message: grade whether a STUDENT ANSWER is grounded in the given FACTS.
hallucination_grader_instructions = """

You are a teacher grading a quiz. 

You will be given FACTS and a STUDENT ANSWER. 

Here is the grade criteria to follow:

(1) Ensure the STUDENT ANSWER is grounded in the FACTS. 

(2) Ensure the STUDENT ANSWER does not contain "hallucinated" information outside the scope of the FACTS.

Score:

A score of yes means that the student's answer meets all of the criteria. This is the highest (best) score. 

A score of no means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.

Explain your reasoning in a step-by-step manner to ensure your reasoning and conclusion are correct. 

Avoid simply stating the correct answer at the outset."""

#Grader prompt: user message template; expects {docs} and {generation} and asks
# for JSON with the keys "binary_score" and "explanation".
# NOTE(review): the template text contains a duplicated word ("two two keys").
hallucination_grader_prompt = """ FACTS: \n\n {docs} \n\n STUDENT ANSWER: \n\n {generation} \n\n Return JSON with two two keys, binary_score is 'yes' or 'no' score to indicate whether the STUDENT ANSWER is grounded in the FACTS. And a key, explanation, that contains an explanation of the score.
Only JSON formated. No formatting or comments required. no ``` in the start or end"""

# Smoke test: reuses doc_txt and generation produced by the generation test above.
hallucination_grader_prompt_formatted = hallucination_grader_prompt.format(docs=doc_txt, generation=generation.content)
result = chat_model_json.invoke([SystemMessage(content=hallucination_grader_instructions), HumanMessage(content=hallucination_grader_prompt_formatted)])
# The reply must be pure JSON for json.loads to succeed.
json.loads(result.content)

In [0]:
# GRADE THE ANSWER
# System message: grade whether a STUDENT ANSWER helps to answer the QUESTION.
answer_grader_instructions = """You are a teacher grading a quiz. 

You will be given a QUESTION and a STUDENT ANSWER. 

Here is the grade criteria to follow:

(1) The STUDENT ANSWER helps to answer the QUESTION

Score:

A score of yes means that the student's answer meets all of the criteria. This is the highest (best) score. 

The student can receive a score of yes if the answer contains extra information that is not explicitly asked for in the question.

A score of no means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.

Explain your reasoning in a step-by-step manner to ensure your reasoning and conclusion are correct. 

Avoid simply stating the correct answer at the outset."""

# User message template; expects {question} and {generation} and asks for JSON
# with the keys "binary_score" and "explanation".
# NOTE(review): the template text contains a duplicated word ("two two keys").
answer_grader_prompt = """QUESTION:\n\n {question} \n\n STUDENT ANSWER: {generation}.
Return JSON with two two keys, binary_score is 'yes' or 'no' score to indicate whether the STUDENT ANSWER meets the criteria. And a key, explanation, that contains an explanation of the score.
only json formated. No formatting or comments required."""

#Test: grade a hard-coded question/answer pair (note: rebinds the global `question`)
question = "What are the vision models released today as part of Llama 3.2?"
answer = "The Llama 3.2 models released today include two vision models: Llama 3.2 11B Vision Instruct and Llama 3.2 90B Vision Instruct, which are available on Azure AI Model Catalog via managed compute. These models are part of Meta's first foray into multimodal AI and rival closed models like Anthropic's Claude 3 Haiku and OpenAI's GPT-4o mini in visual reasoning. They replace the older text-only Llama 3.1 models."

answer_grader_prompt_formatted = answer_grader_prompt.format(question=question, generation=answer)
result = chat_model_json.invoke([SystemMessage(content=answer_grader_instructions), HumanMessage(content=answer_grader_prompt_formatted)])

# The reply must be pure JSON for json.loads to succeed.
json.loads(result.content)



# Websearch tool

In [0]:
#Search
from langchain_community.tools.tavily_search import TavilySearchResults
# The tool's result limit is configured with max_results; `k` is not the field
# it reads, so the intended cap of three results was not being applied.
web_search_tool = TavilySearchResults(max_results=3)

# Langgraph

In [0]:
import operator
from typing_extensions import TypedDict
from typing import List, Annotated

#LangGraph agent state class
class AgentState(TypedDict):
    '''
    Graph state is a dictionary that contains information we want to propagate to, and modify in, each graph node.
    '''
    question: str  # User question
    generation: str  # LLM generation (generate() stores the model reply; readers use .content)
    web_search: str  # "Yes"/"No" decision to run web search
    max_retries: int  # Max number of retries for answer generation
    answers: int  # Number of answers generated
    # Generation attempt counter. The previous annotation was a list literal
    # ([int, operator.add]), which is not a valid type. A plain int is used
    # rather than Annotated[int, operator.add] because generate() already
    # returns loop_step + 1; an additive reducer would count each attempt twice.
    loop_step: int
    documents: List[str]  # List of retrieved documents



In [0]:
from langchain.schema import Document
from langgraph.graph import END


In [0]:
# Retrieve Node
def retrieve(state: AgentState):
    '''
    Look up the user's question in the vector store.

    Args:
    state (dict): The current graph state; its 'question' entry is the query
    Returns:
    state(dict): Update holding the retrieved documents under 'documents'
    '''
    print ("---RETRIEVE---")
    # The retriever hits are published under the "documents" key of the state.
    return {"documents": retriever.invoke(state['question'])}


In [0]:
    def generate(state: AgentState):
        '''
        Answer the question from the documents currently held in the state.

        Args:
            state (dict): The current graph state; reads 'question',
                'documents' and (optionally) 'loop_step'
        Returns:
            state(dict): Update with the model reply under 'generation' and
                the attempt counter 'loop_step' advanced by one
        '''
        print ("entering generate function")
        print ("---GENERATE---")
        user_question, retrieved = state['question'], state['documents']

        # No counter in the state yet means this is the first attempt.
        attempts_so_far = state.get('loop_step', 0)

        #RAG Generation: fill the prompt with the joined document text
        context_block = format_docs(retrieved)
        filled_prompt = rag_prompt.format(context=context_block, question=user_question)
        llm_reply = chat_model.invoke([HumanMessage(content=filled_prompt)])
        return {"generation": llm_reply, "loop_step": attempts_so_far + 1}

In [0]:
def grade_documents (state: AgentState):
    '''
    Grade every retrieved document for relevance and keep only the relevant ones.

    Each document is sent, together with the question, to the JSON chat model
    (the same client the other graders use, since the reply is parsed with
    json.loads). Web search is requested only when no document is graded
    relevant.

    Args:
        state (dict): The current graph state; reads 'question' and 'documents'
    Returns:
        state(dict): 'documents' reduced to the relevant ones, and 'web_search'
            set to "Yes" when that list is empty, otherwise "No"
    '''
    print ("----CHECK DOCUMENTS RELEVANCE TO QUESTION----")
    question = state['question']
    documents = state['documents']
    filtered_docs = []
    for d in documents:
        doc_grade_prompt_formatted = doc_grade_prompt.format(document=d.page_content, question=question)
        result = chat_model_json.invoke([SystemMessage(content=doc_grade_instructions),
                                         HumanMessage(content=doc_grade_prompt_formatted)])
        print (result.content)
        grade = json.loads(result.content)["binary_score"]
        if grade.lower() == "yes":
            print ("---- GRADER: DOCUMENT RELEVANT----")
            filtered_docs.append(d)
        else:
            # Irrelevant documents are simply left out of filtered_docs.
            print ("---- GRADER: DOCUMENT NOT RELEVANT----")
    # With no relevant documents left, direct the flow to web search.
    web_search = "No" if filtered_docs else "Yes"
    return {"documents": filtered_docs, "web_search": web_search}
    

In [0]:
def web_search(state: AgentState):
    """
    Run a web search for the question and add the results to the documents.

    The contents of all search hits are joined into one Document, which is
    appended to a copy of the documents already in the state.

    Args:
    state(dict): The current graph state; reads 'question' and, if present,
        'documents'

    Returns:
    state(dict): 'documents' holding the previous documents plus the web result
    """
    print ("---WEB SEARCH---")
    question = state['question']
    # Work on a copy: appending to the list stored in the state would mutate
    # shared graph state in place. A missing or None entry becomes an empty list.
    documents = list(state.get('documents') or [])
    print (f'question: {question}')
    print (f'documents: {documents}')
    #web search
    docs = web_search_tool.invoke({"query": question})
    print (f'web search results docs: {docs}')
    web_results = "\n".join(d['content'] for d in docs)
    documents.append(Document(page_content=web_results))
    return {"documents": documents}
    

In [0]:
#EDGES
def route_question(state: AgentState):
    """
    Route question to web search or RAG

    The router model is asked for JSON with a 'datasource' key. The value is
    compared case-insensitively. Anything other than 'vectorstore' is sent to
    web search, because the graph's entry-point mapping only knows the keys
    "websearch" and "vectorstore"; any other return value would abort the run.

    Args:
        state(dict): The current graph state; reads 'question'
    returns:
        str: Next node to call, "websearch" or "vectorstore"
    """
    print ("---ROUTE QUESTION---")
    # Named router_reply so the local does not shadow this function's own name.
    router_reply = chat_model_json.invoke([SystemMessage(content=router_instructions),
                                           HumanMessage(content=state['question'])])
    source = str(json.loads(router_reply.content).get("datasource", "")).strip().lower()
    if source == "vectorstore":
        print ("----ROUTE QUESTION: RAG----")
        return "vectorstore"
    if source != "websearch":
        # The router prompt says to use web search "for all else", so that is
        # the fallback for a missing or unexpected datasource value.
        print (f"Option invalid: {source!r}, falling back to web search")
    print ("----ROUTE QUESTION: WEB SEARCH----")
    return "websearch"

In [0]:
def decide_to_generate(state: AgentState):
    """
    Determines whether to generate an answer, or add web search

    Args:
        state(dict): The current graph state; reads the 'web_search' flag
    Returns:
        str: "websearch" when the flag is "Yes", otherwise "generate"
    """
    print ("----ASSES GRADED DOCUMENTS----")
    # Test the flag's value. The earlier code compared the literal string
    # "web_search" with "Yes", which is always False, so the web-search branch
    # could never be taken.
    if state.get('web_search', "No") == "Yes":
        print("----DECISION: NOT ALL DOCUMENTS ARE RELEVANT TO QUESTION, INCLUDE WEBSEARCH ---")
        return "websearch"
    #We have relevant documents, so generate answer
    print ("----DECISION: GENERATE----")
    return "generate"
    

In [0]:
def grade_generation_v_documents_and_question(state: AgentState):
    """
    Check the generation against the documents first, then against the question.

    Args:
        state(dict): The current graph state; reads 'question', 'documents',
            'generation', 'loop_step' and, optionally, 'max_retries' (default 3)
    Returns:
        str: "useful", "not useful", "not supported" or "max retries"
    """
    print ("----CHECK HALLUCINATION----")
    question = state['question']
    documents = state['documents']
    generation = state['generation']
    max_retries = state.get("max_retries", 3)

    # Step 1: is the answer grounded in the retrieved documents?
    grounding_prompt = hallucination_grader_prompt.format(docs=format_docs(documents), generation=generation.content)
    grounding_reply = chat_model_json.invoke([SystemMessage(content=hallucination_grader_instructions),
                                              HumanMessage(content=grounding_prompt)])
    grounding_verdict = json.loads(grounding_reply.content)["binary_score"]

    if grounding_verdict.lower() != "yes":
        # Not grounded: retry while the attempt counter is within the limit.
        if state['loop_step'] <= max_retries:
            print ("----DECISION: GENERATION NOT GROUNDED IN DOCUMENTS, RETRYING ----")
            return "not supported"
        print ("----DECISION: GENERATION NOT GROUNDED IN DOCUMENTS, RETRY LIMIT REACHED ----")
        return "max retries"

    print ("----DECISION: GENERATION GROUNDED IN DOCUMENTS ----")

    # Step 2: does the grounded answer actually address the question?
    print ("----CHECK GENERATION vs QUESTION----")
    answer_prompt = answer_grader_prompt.format(question=question, generation=generation.content)
    answer_reply = chat_model_json.invoke([SystemMessage(content=answer_grader_instructions),
                                           HumanMessage(content=answer_prompt)])
    answer_verdict = json.loads(answer_reply.content)["binary_score"]

    if answer_verdict.lower() == "yes":
        print ("----DECISION: GENERATION GROUNDED IN DOCUMENTS AND QUESTION ----")
        return "useful"
    if state['loop_step'] <= max_retries:
        print ("----DECISION: GENERATION GROUNDED IN DOCUMENTS BUT NOT QUESTION, RETRYING ----")
        return "not useful"
    print ("----DECISION: GENERATION GROUNDED IN DOCUMENTS BUT NOT QUESTION, RETRY LIMIT REACHED ----")
    return "max retries"
        

#Control flow

In [0]:
from langgraph.graph import StateGraph
from IPython.display import Image, display

In [0]:
graph = StateGraph(AgentState)

#Define the nodes: node name -> function executed when the node runs
node_functions = {
    "websearch": web_search,             # websearch
    "generate": generate,                # generate
    "retrieve": retrieve,                # retrieve documents
    "grade_documents": grade_documents,  # grade documents
}
for node_name, node_fn in node_functions.items():
    graph.add_node(node_name, node_fn)

In [0]:
#build graph
# Each mapping translates the string returned by a routing function into the
# name of the node to run next.
entry_routes = {"websearch": "websearch", "vectorstore": "retrieve"}
after_grading_routes = {"websearch": "websearch", "generate": "generate"}
after_generation_routes = {
    "not supported": "generate",
    "useful": END,
    "not useful": "websearch",
    "max retries": END,
}

# Entry: the router decides between web search and vector-store retrieval.
graph.set_conditional_entry_point(route_question, entry_routes)

# Fixed edges.
graph.add_edge("websearch", "generate")
graph.add_edge("retrieve", "grade_documents")

# Conditional edges.
graph.add_conditional_edges("grade_documents", decide_to_generate, after_grading_routes)
graph.add_conditional_edges("generate",
                            grade_generation_v_documents_and_question,
                            after_generation_routes)

compiled_graph = graph.compile()



In [0]:
# Render the compiled graph as a Mermaid PNG and show it inline.
graph_png = compiled_graph.get_graph().draw_mermaid_png()
display(Image(graph_png))


In [0]:
# Run the agent once, streaming the state values and echoing the question
# from every event that carries one.
inputs = dict(question="Adversarial Attacks on LLMs", max_retries=3)
for event in compiled_graph.stream(inputs, stream_mode="values"):
    if "question" not in event:
        continue
    print (f'Question: {event["question"]}')