<a href="https://colab.research.google.com/github/hanhanwu/Hanhan_LangGraph_Exercise/blob/main/RAG_Chatbot/try_corrective_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture --no-stderr
# Colab setup: LangChain / LangGraph, Chroma, Playwright (used to load the web pages) and the Tavily client.
# NOTE(review): packages are installed unpinned with -U, so a later re-run may pull in breaking releases.
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters playwright unstructured tavily-python
# Download the browser binaries that PlaywrightURLLoader drives.
!playwright install
# Backends for the ensemble retriever: BM25 keyword search (rank_bm25) and FAISS vector search.
%pip install -U --quiet rank_bm25 faiss-cpu

In [3]:
from google.colab import userdata

# API keys live in Colab's "Secrets" panel rather than in the notebook text.
# Unpack them into the three module-level names the later cells rely on,
# fetching each secret once, in the listed order.
OPENAI_API_KEY, LANGSMITH_API_KEY, TAVILY_API_KEY = (
    userdata.get(secret_name)
    for secret_name in ("OPENAI_API_KEY", "LANGSMITH_API_KEY", "TAVILY_API_KEY")
)

## Retriever

In [4]:
from langchain_community.document_loaders import PlaywrightURLLoader
from langchain_community.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool


urls = [
   "https://www.ana.co.jp/en/us/travel-information/customers-with-disabilities/walking-disabilities/",
   "https://www.ana.co.jp/en/us/travel-information/customers-with-disabilities/visual-disabilities/",
]

loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
docs = await loader.aload()  # returns "Document" type

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=200, chunk_overlap=80
)
doc_splits = text_splitter.split_documents(docs)  # split into chunks with overlap

# choose retriever type based on the number of chunks
chunks_ct = len(doc_splits)
if chunks_ct < 30:
  print(chunks_ct, 'choose vectorstore based retriever')
  # use Vectorstore-backed retriever (the simplest retriever in LangChain)
  vectorstore = Chroma.from_documents(
      documents=doc_splits,
      collection_name="rag-chroma",
      embedding=OpenAIEmbeddings(api_key=OPENAI_API_KEY),
  )
  retriever = vectorstore.as_retriever()
else:
  print(chunks_ct, 'choose ensemble retriever')
  # use emsemble retriever
  # initialize the bm25 retriever and faiss retriever
  bm25_retriever = BM25Retriever.from_texts(
      [doc.page_content for doc in doc_splits], metadatas=[{"source": 1}] * len(doc_splits)
  )
  bm25_retriever.k = 2
  embedding = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
  faiss_vectorstore = FAISS.from_texts(
      [doc.page_content for doc in doc_splits], embedding, metadatas=[{"source": 2}] * len(doc_splits)
  )
  faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 2})
  # initialize the ensemble retriever
  retriever = EnsembleRetriever(
      retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
  )

44 choose ensemble retriever


## Graph State

In [9]:
from typing import List, Literal
from typing_extensions import TypedDict

from langchain_core.documents import Document


class GraphState(TypedDict):
    """State shared between the corrective-RAG graph nodes.

    Nodes return partial dicts whose keys must match the field names below
    exactly; keys not declared here are not part of the graph state.
    """
    question: str  # the user's question
    generated_answer: str  # answer text (no node visible in this section writes it yet)
    needs_web_search: Literal["yes", "no"]  # "yes" decides that a web search is needed
    docs: List[Document]  # retrieved / relevance-filtered langchain Documents (not plain strings)

## Nodes

In [6]:
from typing_extensions import TypedDict
from pydantic import BaseModel, Field

from langchain import hub
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI


# Dated OpenAI model snapshot used by the LLM-backed nodes (e.g. grade_documents).
model_str: str = "gpt-4o-mini-2024-07-18"

In [None]:
def retrieve(state):
    """
    Retrieve documents for the current question.

    Args:
        state (dict): The current graph state; must contain "question".
    Returns:
        dict: State update with the retrieved documents stored under "docs"
        (the key GraphState declares and grade_documents reads); the question
        is passed through unchanged.
    """
    print("---RETRIEVE---")
    question = state["question"]

    # `retriever` is the module-level retriever built in the Retriever section.
    documents = retriever.invoke(question)
    # Write to "docs", not "documents": the state schema only declares "docs",
    # and the next node reads state["docs"].
    return {"docs": documents, "question": question}


def _build_relevance_grader():
    """Build the prompt | LLM chain that labels one document 'yes'/'no' for relevance."""

    # Data model for the structured LLM output.
    class grade(BaseModel):
        """Binary score for relevance check."""
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # temperature=0 for deterministic grading; no streaming, because only the
    # final structured result is consumed.
    model = ChatOpenAI(temperature=0, api_key=OPENAI_API_KEY, model=model_str)
    llm_with_tool = model.with_structured_output(grade)

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    return prompt | llm_with_tool


def grade_documents(state):
    """
    Grade each retrieved document for relevance and keep only the relevant ones.

    "needs_web_search" is set to "yes" when any retrieved document is judged
    irrelevant, or when there is no relevant document at all (including an
    empty retrieval); otherwise it is "no".

    Args:
        state (dict): The current graph state; must contain "question" and "docs"
            (objects exposing `.page_content`).
    Returns:
        dict: State update with the relevant documents under "docs", the
        "needs_web_search" decision, and the unchanged question.
    """
    print("---CHECK RELEVANCE---")

    grader = _build_relevance_grader()

    question = state['question']
    retrieved_docs = state['docs']

    relevant_docs = []
    for doc in retrieved_docs:
        # The prompt placeholder is {context}; passing the text under any other
        # key makes PromptTemplate fail with a missing-variable error.
        scored_result = grader.invoke({'question': question, 'context': doc.page_content})
        # Normalise so answers like "Yes" or " yes" still count as relevant.
        if scored_result.binary_score.strip().lower() == 'yes':
            relevant_docs.append(doc)

    # Fall back to web search if anything was filtered out, or if nothing usable remains.
    if not relevant_docs or len(relevant_docs) < len(retrieved_docs):
        needs_web_search = 'yes'
    else:
        needs_web_search = 'no'

    # "docs" is the state key that holds documents (and the one this node reads).
    return {'docs': relevant_docs,
            'question': question,
            'needs_web_search': needs_web_search}