### Corrective RAG implementation with LangChain & LangGraph

In [None]:
# imports

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.documents import Document

import os

# set enviroment variables for tracing and web-search
os.environ["LANGSMITH_TRACING"] = "true"
os.environ['TAVILY_API_KEY'] = "***Insert your API key***"
os.environ["LANGSMITH_API_KEY"] = "***Insert your API key***"

os.environ['TAVILY_API_KEY'] = "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"
os.environ["LANGSMITH_API_KEY"] = "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"

In [None]:
# define local embedding-model
embedding_model = OllamaEmbeddings(model= 'nomic-embed-text:latest')

# define local LLM
llm = ChatOllama(model = "mistral:latest", temperature= 0.001)

In [None]:
# data sources
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs = [item for doc in docs for item in doc]

In [None]:
# splittting & chunking
text_splitter = RecursiveCharacterTextSplitter(chunk_size= 250, chunk_overlap= 20)
doc_chunks = text_splitter.split_documents(docs)

In [None]:
# setting up the vectorstore and indexing the chunks
vectorstore = Chroma.from_documents(documents= doc_chunks, embedding= embedding_model, 
                                    collection_name= "crag")

retriever = vectorstore.as_retriever(k = 3)

In [None]:
# let's build document grader chain

from pydantic import BaseModel, Field

class GraderSchema(BaseModel):
    response : str = Field(description= "Assess if retrieved documents are relevant to the question: 'yes' or 'no'")

system_template = "You are an expert at assessing if the retrieved documents are semantically relevant to the question being asked. Answer your question with a simple 'yes' or 'no'"

prompt = ChatPromptTemplate.from_messages([
                            ("system", system_template),
                            ("human", "Retrieved documents:{context}\n\nQuestion:{question}")])

In [9]:
llm_ = llm.with_structured_output(GraderSchema)
relevance_chain = (prompt | llm_ | StrOutputParser())

In [10]:
relevant_docs = retriever.get_relevant_documents("What is an AI agent?")

  relevant_docs = retriever.get_relevant_documents("What is an AI agent?")


In [11]:
relevant_docs

[Document(metadata={'title': "LLM Powered Autonomous Agents | Lil'Log", 'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\nAgent System Overview\nIn a LLM-powered autonomous agent system, LLM functions as the agent’s brain, complemented by several key components:\n\nPlanning\n\nSubgoal and decomposition: The agent breaks down large tasks into smaller, manageable subgoals, enabling efficient handling of complex tasks.\nReflection and refinement: The agent can do self-criticism and self-reflection over past actions, learn from mistakes and refine them for future steps, thereby improving the quality of final results.\n\n\nMemory\n\nShort-term memory: I would consider all the in

In [13]:
relevance_chain.invoke({"context": relevant_docs[0].page_content, "question": "what is theory of relativity?"})

GraderSchema(response='No')

In [14]:
# question rewriter
template = "You are an expert at rewriting questions to make web search easier.Simply output the rephrased question.\nQuestion: {question}"
prompt = ChatPromptTemplate.from_template(template)

rewriting_chain = (prompt | llm | StrOutputParser())
rewriting_chain.invoke({'question': "Explain model distillation in llm"})


' "What is the concept of model distillation in the context of Language Model Learning (LLM)?"'

In [16]:
# define web-search tool
from langchain_community.tools.tavily_search import TavilySearchResults
web_search = TavilySearchResults(max_results= 3)

In [30]:
results = web_search.invoke({'query': "What does AWS do?"})

In [31]:
results

[{'title': 'Amazon Web Services - Wikipedia',
  'url': 'https://en.wikipedia.org/wiki/Amazon_Web_Services',
  'content': '**Amazon Web Services, Inc.** (**AWS**) is a subsidiary of [Amazon](/wiki/Amazon.com "Amazon.com") that provides [on-demand](/wiki/Software_as_a_service "Software as a service") [cloud computing](/wiki/Cloud_computing "Cloud computing") [platforms](/wiki/Computing_platform "Computing platform") and [APIs](/wiki/Application_programming_interface "Application programming interface") to individuals, companies, and governments, on a metered, pay-as-you-go basis. Clients will often use this in',
  'score': 0.81418616},
 {'title': 'What Is Amazon Web Services, and Why Is It So Successful?',
  'url': 'https://www.investopedia.com/articles/investing/011316/what-amazon-web-services-and-why-it-so-successful.asp',
  'content': 'AWS is made up of many different cloud computing products and services. They provide servers, storage, networking, remote computing, email, mobile deve

##### Define Graph State - for LangGraph

In [None]:
from pydantic import BaseModel
from typing import List, Optional

# a pydantic base class
class StateSchema(BaseModel):

    question: str
    documents: Optional[List[Document]] = list()
    generation: Optional[str] = None
    web_search: Optional[str] = None


#### Define Graph Nodes and their functions

In [None]:
def retrieve(state):

    """Retrieves relevant documents from vectorstore and returns the updated state"""

    question = state.question

    retrieved_docs = retriever.get_relevant_documents(question)

    state.documents.extend(retrieved_docs)

    return state

In [None]:
def generate(state):

    """Generates output based on the context and query given to the LLM"""

    question = state.question
    retrieved_docs = state.documents

    template = """You are an expert at directly answering questions based on the given context.\n\nContext: {context}\n\nQuestion:{question}"""
    prompt = ChatPromptTemplate.from_template(template)
    generation_chain = (prompt | llm | StrOutputParser())

    state.generation = generation_chain.invoke({"context": retrieved_docs, "question": question})

    return state

In [None]:
def grade_documents(state):

    """Assigns relevancy label to a specific retrieved document. If any retrieved doc is irrelevant we use web-search"""

    question = state.question
    documents = state.documents
    relevant_docs = list()
    web_search = "no"

    for doc in documents:

        ai_output = relevance_chain.invoke({'question':question, "context": doc.page_content})

        if ai_output.response.lower() == "yes":
            relevant_docs.append(doc)
        
        else:
            web_search = "yes"
        
    state.web_search = web_search
    state.documents = relevant_docs
    return state

In [None]:
def transform_query(state):

    """If web-search is requested, then it transforms the query into a better structure suitable for web-search"""

    question = state.question

    better_question = rewriting_chain.invoke({"question": question})

    state.question = better_question

    return state


In [None]:
def to_generate(state):

    """Conditional-edge in the LangGraph structure that decide whether to proceed with generation or transform the query and perform web-search"""

    web_search = state.web_search

    if web_search.lower() == "yes":
        return "transform"
    else:
        return "generate"

In [None]:
def internet_search(state):

    """Perofrm Tavily web search and add the top 3 search result contents as Documents"""

    web_search = state.web_search
    query = state.question
    relevant_docs = list()
    

    if web_search.lower().strip() == "yes":
        search_tool = TavilySearchResults(max_results= 3)
        search_results = search_tool.invoke({"query": query})

        for source in search_results:
            content = source['content']
            doc = Document(page_content= content)
            relevant_docs.append(doc)

    state.documents = relevant_docs
    return state

#### Define LangGraph graph

In [None]:
# State graph is a graph where the nodes communicate with each other by modifying a certain state
from langgraph.graph import StateGraph, START, END 

graph = StateGraph(StateSchema)

graph.add_node("retrieve", retrieve)
graph.add_node("generate", generate)
graph.add_node("transform_query", transform_query)
graph.add_node("grade_documents", grade_documents)
graph.add_node("internet_search", internet_search)

graph.add_edge(START, "retrieve")
graph.add_edge("retrieve", "grade_documents")
graph.add_conditional_edges("grade_documents", 
                            to_generate, 
                            {"transform": "transform_query",
                             "generate":"generate"})

graph.add_edge("transform_query", "internet_search")
graph.add_edge("internet_search", "generate")
graph.add_edge("generate", END)

workflow = graph.compile()

In [133]:
results = workflow.invoke({"question": "Who is the ICC Champion in 2023?"})
results['generation']

' The winning team of the 2023 ICC World Cup is Australia.'

In [134]:
results

{'question': '2023 ICC World Cup Champions: Identify the winning team.',
 'documents': [Document(metadata={}, page_content='Now, the ICC ODI Cricket World Cup is concluding with the final scheduled for November 19, 2023. Now that the league stage is over, **India, Australia, South Africa, and New Zealand** have topped the table. The first semi-final took place between India and New Zealand, with India emerging victorious by 70 runs. The second semi-final between Australia and South Africa will determine the finalists for the World Cup 2023 final. [...] As you can see, Australia is the most successful team in the history of the \xa0ICC men’s ODI Cricket World Cup 2023, having won the tournament 5 times. India and West Indies are the only other countries to have won the World Cup more than once, with two victories each. England won the 2019 World Cup, their first-ever victory in the tournament.\n\nODI Cricket World Cup Winners list, Country Wise\n-----------------------------------------