In [None]:
from dotenv import load_dotenv
from typing import List
from typing_extensions import TypedDict
from IPython.display import Image, display

from langchain.chat_models import ChatOllama
from langchain.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.tools.tavily_search import TavilySearchResults

from langgraph.graph import START, END, StateGraph

import uuid
from langsmith import client
from langchain import hub
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate

In [None]:
# Ollama-backed embedding model; reused further down to index the document chunks.
nomic137M_embeddings = OllamaEmbeddings(model="nomic-embed-text")

# Smoke test: embed a short string and inspect the result. Only the first few
# components are shown so the cell output is not flooded with the whole vector.
sample = nomic137M_embeddings.embed_query("Hello world")
type(sample), len(sample), sample[:5]

In [None]:
# List of URLs to load documents from
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Each loader call returns a list of documents for one URL; flatten them into one list.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split the pages into small overlapping chunks, sized with the tiktoken encoder.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=20
)
doc_splits = text_splitter.split_documents(docs_list)

# Add the chunks to the vector DB, embedded with `nomic137M_embeddings`.
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=nomic137M_embeddings,
)

# The number of hits has to be passed through `search_kwargs`; a bare `k=4`
# keyword is not a retriever setting.
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

In [None]:
# Chat model shared by the chains below: llama3 in JSON output format, temperature 0.
llm = ChatOllama(
    model="llama3",
    format="json",
    temperature=0,
)

In [None]:
# Prompt for the relevance grader: it receives a question and one document chunk
# and is asked to answer with a JSON object holding a single 'score' key.
prompt = PromptTemplate(
    input_variables=["question", "documents"],
    template="""You are a teacher grading a quiz. You will be given:
    1/ a QUESTION
    2/ a FACT provided by the student
    
    You are grading RELEVANCE_RECALL:
    A score 1 means ANY of the statements in the FACT are relevant to the QUESTION.
    A score 0 means that NONE of the statements in the FACT are relevant to the QUESTION.
    1 is the highest (best) score. 0 is the lowest score you can award.
    
    Explain your reasoning in a step-by-step manner. Ensure your reasoning and conclusion are correct.
    
    Avoid simply stating the correct answer at the outset.
    
    Question: {question} \n
    Fact: \n\n {documents} \n\n
    
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
)

# Grader chain: fill the prompt, call the model, parse the reply as JSON.
retrieval_grder = prompt | llm | JsonOutputParser()

In [None]:
# Try the grader on a single retrieved chunk (the second hit for the query).
question = "agent memory"
docs = retriever.invoke(question)
doc_text = docs[1].page_content
print(
    retrieval_grder.invoke(
        {"question": question, "documents": doc_text}
    )
)

In [None]:
# Prompt for the answer-generation step: question plus supporting documents in,
# a short free-text answer out.
rag_prompt = PromptTemplate(
    template = """You are an assistant for question-answering tasks.
    
    Use the following documents to answer the question.
    
    If you don't know the answer, just say that you don't know.
    
    Use three sentences maximum and keep the answer concise:
    Question: {question}\n
    Documents: \n{documents}\n
    Answer:
    """,
    # Declared explicitly, matching the grader prompt.
    input_variables=["question", "documents"],
)

# The shared `llm` is configured with format="json", which suits the grader but not
# this prompt: it asks for a plain-text answer and never mentions JSON. Generation
# therefore gets its own model instance without the JSON output constraint.
generation_llm = ChatOllama(model="llama3", temperature=0)

rag_chain = rag_prompt | generation_llm | StrOutputParser()
generation = rag_chain.invoke({"documents": docs, "question": question})
print(generation)

In [None]:
# Web-search fallback tool. The result limit is the tool's `max_results` field;
# `k` is not a parameter of TavilySearchResults.
web_search_tool = TavilySearchResults(max_results=3)

In [None]:
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question being answered
        generation: LLM generation
        search: "Yes" when a web search should be added, otherwise "No"
        documents: list of documents (retrieved chunks and/or web results)
        steps: names of the steps executed so far, in order
    """
    question: str
    generation: str
    # The nodes store and compare the strings "Yes" / "No" here, not a boolean.
    search: str
    # Holds Document objects (their `page_content` is read when grading).
    documents: List["Document"]
    steps: List[str]


In [None]:
def retrieve(state):
    """
    Fetch documents for the question stored in the graph state.

    Args:
        state (dict): The current graph state; reads "question" and "steps"

    Returns:
        state (dict): The question, the retrieved documents under "documents",
        and the step log with "retrieve_documents" appended
    """
    query = state["question"]
    hits = retriever.invoke(query)

    # The step log is the state's own list; it is extended in place.
    trail = state["steps"]
    trail.append("retrieve_documents")

    return dict(documents=hits, question=query, steps=trail)


def generate(state):
    """
    Produce an answer from the question and the documents in the graph state.

    Args:
        state (dict): The current graph state; reads "question", "documents" and "steps"

    Returns:
        state (dict): The question and documents unchanged, the LLM output under
        "generation", and the step log with "generate_answer" appended
    """
    query = state["question"]
    context_docs = state["documents"]

    answer = rag_chain.invoke({"documents": context_docs, "question": query})

    # The step log is the state's own list; it is extended in place.
    trail = state["steps"]
    trail.append("generate_answer")

    return dict(
        documents=context_docs,
        question=query,
        generation=answer,
        steps=trail,
    )


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Each document is sent to the grader chain. Documents graded relevant are kept;
    as soon as any document is graded not relevant, a web search is requested.

    Args:
        state (dict): The current graph state; reads "question", "documents" and "steps"

    Returns:
        state (dict): Updates documents key with only filtered relevant documents,
        sets "search" to "Yes"/"No", and appends "grade_document_retrieval" to "steps"
    """
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("grade_document_retrieval")
    filtered_docs = []
    search = "No"
    for d in documents:
        score = retrieval_grder.invoke({"question": question, "documents": d.page_content})
        # The grader prompt asks for 'yes'/'no' (lower case) and also describes a
        # 1/0 scale, so normalise the reply before comparing. A reply without a
        # usable 'score' key is treated as "not relevant" instead of raising.
        raw_grade = score.get("score", "") if isinstance(score, dict) else ""
        grade = str(raw_grade).strip().lower()
        if grade in ("yes", "1"):
            filtered_docs.append(d)
        else:
            search = "Yes"
    return {
        "documents": filtered_docs,
        "question": question,
        "search": search,
        "steps": steps
    }


def web_search(state):
    """
    Web search for the question; the results are added to the documents.

    Args:
        state (dict): The current graph state; reads "question", "documents" and "steps"

    Returns:
        state (dict): Updates documents key with appended web results and
        appends "web_search" to "steps"
    """
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("web_search")
    web_results = web_search_tool.invoke({"query": question})
    # Wrap every hit in a Document so web results look like retrieved chunks.
    documents.extend(
        Document(page_content=d["content"], metadata={"url": d["url"]})
        for d in web_results
    )
    # The key must be spelled "documents" to update the state field of that name.
    return {"documents": documents, "question": question, "steps": steps}

def decide_to_generate(state):
    """
    Pick the next node from the "search" flag in the graph state.

    Args:
        state (dict): The current graph state; reads "search"

    Returns:
        str: "Search" when the flag equals "Yes", otherwise "Generate"
    """
    wants_search = state["search"] == "Yes"
    return "Search" if wants_search else "Generate"

In [None]:
# Build the workflow graph on top of the shared GraphState.
workflow = StateGraph(GraphState)

# Register one node per pipeline step.
for node_name, node_fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("generate", generate),
    ("web_search", web_search),
):
    workflow.add_node(node_name, node_fn)

# Wire the steps: retrieve -> grade, then either search the web first or
# generate straight away, depending on what decide_to_generate returns.
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"Search": "web_search", "Generate": "generate"},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile and render the graph as a Mermaid PNG.
crag_graph = workflow.compile()
graph_png = crag_graph.get_graph(xray=True).draw_mermaid_png()
display(Image(graph_png))

In [None]:
def predict_custom_agent_local_answer(example: dict):
    """
    Run the compiled graph on one example and return its answer and step log.

    Args:
        example (dict): Must contain the question under the key "input"

    Returns:
        dict: {"response": <generation from the final state>,
               "steps": <list of step names from the final state>}
    """
    # Per-run settings such as thread_id go under the "configurable" key of the
    # run config; each call gets a fresh thread id.
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}
    state_dict = crag_graph.invoke(
        {
            "question": example["input"],
            "steps": []
        },
        config
    )
    return {"response": state_dict["generation"], "steps": state_dict["steps"]}


In [None]:
def pred(query:str):
    """
    Convenience wrapper: answer a plain question string with the graph.

    Args:
        query (str): The question to ask

    Returns:
        dict: Whatever predict_custom_agent_local_answer returns for {"input": query}
    """
    return predict_custom_agent_local_answer({"input": query})

In [None]:
# Demo query 1: run the full graph and display the returned response/steps dict.
pred("what is automatic prompt design ?")

In [None]:
# Demo query 2: a two-part question, run through the same graph.
pred("what is few shot prompting ? Can you give me an example ?")