In [None]:
from langgraph.graph import Graph
from IPython.display import Image, display
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from langgraph.graph import StateGraph, END



# A simple workflow built with a LangGraph `Graph`

In [None]:
def function_3(input_3):
    """Identity helper: hand `input_3` back unchanged (called by `function_2`)."""
    result = input_3
    return result

In [None]:
def function_1(input_1):
    """First node of `workflow_1`: tag the incoming text with its origin.

    A separating space is inserted so the suffix does not fuse with the last
    word of `input_1` (e.g. "...new day from first function").
    """
    return input_1 + " from first function"

def function_2(input_2):
    """Second node of `workflow_1`: append `function_3`'s text and a suffix.

    Each piece is separated by a single space so words do not run together
    (the suffix previously fused with the preceding text).
    """
    middle = function_3("This is function 3 in between")
    return input_2 + " " + middle + " and this is from second function"

In [None]:
# Build and compile the two-node pipeline: function_1 -> function_2.
workflow_1 = Graph()
workflow_1.add_node("function_1", function_1)
workflow_1.add_node("function_2", function_2)
workflow_1.add_edge("function_1", "function_2")
workflow_1.set_entry_point("function_1")   # first node to execute
workflow_1.set_finish_point("function_2")  # its return value is the graph's result

app_1 = workflow_1.compile()

In [None]:
# Render the compiled graph as a Mermaid PNG. Rendering needs optional extra
# dependencies, so a failure is reported instead of stopping the notebook.
try:
    graph_png = app_1.get_graph().draw_mermaid_png()
    display(Image(graph_png))
except Exception as e:
    print(e)

In [None]:
app_1.invoke("Hi, This is a new day")

In [None]:
input_1 = "Hi, This is a new day"

In [None]:
for output in app_1.stream(input_1):
    for key, value in output.items():
        print(f"Here is output from {key}")
        print("------------")
        print(value)
        print("\n")

In [None]:
model = llm = ChatGoogleGenerativeAI(model="gemini-1.0-pro")

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

In [None]:
model.invoke("hi")


In [None]:
def function_4(input):
    """LLM node: send `input` to a freshly constructed Gemini chat model.

    Returns only the text content of the model's reply.
    """
    reply = ChatGoogleGenerativeAI(model="gemini-1.0-pro").invoke(input)
    return reply.content

In [None]:
def function_5(input):
    """Return `input` converted to upper case."""
    return input.upper()

In [None]:
# Two-node pipeline: the LLM answers, then the answer is upper-cased.
workflow_2 = Graph()
workflow_2.add_node("llm", function_4)
workflow_2.add_node("upper_case", function_5)
workflow_2.add_edge("llm", "upper_case")
workflow_2.set_entry_point("llm")
workflow_2.set_finish_point("upper_case")

# The method is `compile()`; the misspelled `compit()` raised AttributeError.
app_2 = workflow_2.compile()

In [None]:
try:
    display(Image(app_2.get_graph().draw_mermaid_png()))

except Exception as e:
    # This requires some extra dependencies and is optional
    print(e)

In [None]:
input_2 = "What is LangGraph?"

In [None]:
app_2.invoke(input_2)

In [None]:
for output in app_2.stream(input_2):
    for key, value in output.items():
        print(f"Here is output from {key}")
        print("------------")
        print(value)
        print("\n")

### Creating our own output token (word) counter

In [None]:
def function_6(input):
    """Report how many whitespace-separated words `input` contains.

    Note: this is a plain word count, not a model-tokenizer token count.
    """
    return f"Total token number is {len(input.split())}"

In [None]:
# LLM answer -> word counter. The LLM node key is "llm" (letters), not the
# look-alike "11m" (digits), so graph drawings and stream keys read correctly.
workflow3 = Graph()
workflow3.add_node("llm", function_4)
workflow3.add_node("token_counter", function_6)
workflow3.add_edge("llm", "token_counter")
workflow3.set_entry_point("llm")
workflow3.set_finish_point("token_counter")
app3 = workflow3.compile()

In [None]:
try:
    display(Image(app3.get_graph().draw_mermaid_png()))

except Exception as e:
    # This requires some extra dependencies and is optional
    print(e)

In [None]:
app3.invoke(input_2)

# Integrating a RAG pipeline

In [None]:
loader = DirectoryLoader("../data", glob="./*.txt", loader_cls=TextLoader)

In [None]:
docs = loader.load()

In [None]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,
    chunk_overlap=50
)

In [None]:
new_docs = text_splitter.split_documents(documents=docs)
doc_strings = [doc.page_content for doc in new_docs]

In [None]:
db = Chroma.from_documents(documents=new_docs, embedding=embeddings)
retriever = db.as_retriever(search_kwargs={"k": 3})

In [None]:
query = "What is meta llama3?"

In [None]:
docs = retriever.get_relevant_documents(query=query)
print(docs[0].metadata)
print(docs[0].page_content)

In [None]:
for doc in docs:
    print(doc)

In [None]:
def function_1_for_rag(AgentState):
    """LLM node: ask the model for a brief answer to the latest message.

    Parameters
    ----------
    AgentState : dict
        State dict with a "messages" list of strings; the last entry is the
        user query. (The parameter name shadows the `AgentState` class
        defined later in the notebook; it is kept for compatibility.)

    Returns
    -------
    dict
        A new state whose "messages" list is the input list plus the model's
        reply. The caller's dict and list are left untouched, so re-running
        the graph on the same input dict starts from the original question
        instead of the previous answer.
    """
    messages = AgentState["messages"]
    question = messages[-1]

    complete_prompt = (
        "Your task is to provide only the brief based on the user query. "
        "Don't include too much reasoning. Follow user query:  " + question
    )

    response = model.invoke(complete_prompt)

    # Build a fresh state instead of appending to the caller's list in place.
    return {**AgentState, "messages": [*messages, response.content]}

In [None]:
def function_2_for_rag(AgentState):
    """RAG node: answer the original user question from retrieved context.

    Parameters
    ----------
    AgentState : dict
        State dict whose "messages" list starts with the user question.

    Returns
    -------
    str
        The model's answer, generated from the retrieved chunks.
    """
    messages = AgentState["messages"]
    question = messages[0]  # the original user question, not the LLM reply

    template = """Answer the questions based only on the following context
        {context}

        Question: {question}"""

    prompt = ChatPromptTemplate.from_template(template=template)

    # The retriever output fills {context}; the raw question fills {question}.
    # (Repeating the "question" key silently dropped the retriever entry and
    # left {context} unfilled.) Hits are joined as plain text for the prompt.
    retrieval_chain = (
        {
            "context": retriever
            | (lambda hits: "\n\n".join(hit.page_content for hit in hits)),
            "question": RunnablePassthrough(),
        }
        | prompt
        | model
        | StrOutputParser()
    )

    return retrieval_chain.invoke(question)

In [None]:
# Define a LangChain Graph
workflow_for_rag = Graph()
workflow_for_rag.add_node("LLM", function_1_for_rag)
workflow_for_rag.add_node("RAGTool", function_2_for_rag)
workflow_for_rag.add_edge("LLM", "RAGTool")
workflow_for_rag.set_entry_point("LLM")
workflow_for_rag.set_finish_point("RAGTool")

app_for_rag = workflow_for_rag.compile()

In [None]:
try:
    display(Image(app_for_rag.get_graph().draw_mermaid_png()))

except Exception as e:
    # This requires some extra dependencies and is optional
    print(e)

In [None]:
input_for_rag = {"messages":["tell me about llama3 model"]}

In [None]:
for output in app_for_rag.stream(input_for_rag):
    #stream() yields dictionaries with output keyed by node name
    for key, value in output.items():
        print(f"Here is output from node {key}: ")
        print("------------")
        print(value)
        print("\n---\n")

# Trying a different workflow: routing between a plain LLM call and RAG

In [None]:
# Load and chunk the corpus for the routed LLM-vs-RAG workflow.
loader_2 = DirectoryLoader("../data", glob="./*.txt", loader_cls=TextLoader)

# Load through loader_2 itself (this previously called the earlier `loader`).
docs_2 = loader_2.load()

new_docs_2 = text_splitter.split_documents(docs_2)
doc_strings_2 = [doc.page_content for doc in new_docs_2]



In [None]:
# Give this store its own collection. Without an explicit name, both Chroma
# stores in this notebook use the library's default collection name.
# NOTE(review): with an in-process Chroma client shared between stores, that
# would append these chunks to the first store's collection and duplicate
# results; confirm against the installed chromadb version.
db_2 = Chroma.from_documents(new_docs_2, embeddings, collection_name="docs_2")
retriever_2 = db_2.as_retriever(search_kwargs={"k": 3})

In [None]:
query_1 = "Tell me about USA Industrial Growth"
# Store hits separately so `docs_2` keeps holding the loaded corpus.
# `invoke` replaces the deprecated `get_relevant_documents`.
retrieved_docs_2 = retriever_2.invoke(query_1)
if retrieved_docs_2:
    print(retrieved_docs_2[0].metadata)
    print(retrieved_docs_2[0].page_content)
else:
    print(f"No documents retrieved for: {query_1!r}")

for doc in retrieved_docs_2:
    print(doc)

In [None]:
class AgentState(TypedDict):
    """State shared by the nodes of the routed LLM-vs-RAG StateGraph."""

    # Conversation entries. Every node here reads and writes plain strings
    # (user query, router topic, model answers), so the element type is `str`
    # rather than `BaseMessage`. `operator.add` is the reducer: the list a node
    # returns under "messages" is concatenated onto the existing one.
    messages: Annotated[Sequence[str], operator.add]

In [None]:
class TopicSelectionParser(BaseModel):
    # Structured reply expected from the classifier prompt.
    # `topic` holds the chosen category; `Reasoning` (capitalized) holds the
    # model's justification. These names and descriptions feed the parser's
    # format instructions, so renaming them changes the prompt text too.
    topic: str = Field(description="Selected Topic")
    Reasoning : str = Field(description="Reasoning behind topic selection")

In [None]:
# Output parser that turns the model's reply into a TopicSelectionParser.
parser = PydanticOutputParser(pydantic_object=TopicSelectionParser)

In [None]:
def function_1_for_comparing(state):
    """Classifier node: label the latest user query as "USA" or "Not Related".

    Parameters
    ----------
    state : dict
        Graph state with a "messages" list; the last entry is the query.

    Returns
    -------
    dict
        {"messages": [topic]}: the selected topic string, which the state's
        reducer adds to the conversation for the router to inspect.
    """
    messages = state["messages"]
    question = messages[-1]
    print(question)

    # The reply must follow the parser's JSON format instructions. The old
    # "only respond with the category name" line contradicted them and could
    # produce a bare-text reply the parser cannot read.
    template = """
    Your task is to classify the given user query into one of the following categories: [USA, Not Related].
    Put the category name in the topic field.

    user query: {question}
    {format_instructions}
    """

    prompt = PromptTemplate(
        template=template,
        input_variables=["question"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    chain = prompt | model | parser

    # format_instructions is already bound as a partial variable.
    response = chain.invoke({"question": question})

    print(response)

    # The Pydantic field is `topic`; `response.Topic` raised AttributeError.
    return {"messages": [response.topic]}

In [None]:
def router(state):
    """Conditional-edge function: pick the next node from the topic label.

    Returns "RAG Call" when the latest message names the USA category
    (whole word, any casing), otherwise "LLM Call".
    """
    import re

    print("-> Router ->")

    last_message = state["messages"][-1]
    print(last_message)

    # LLM casing is not guaranteed ("USA", "usa", "Usa"). A whole-word match
    # keeps words that merely contain the letters (e.g. "thousand") out.
    if re.search(r"\bUSA\b", last_message, flags=re.IGNORECASE):
        return "RAG Call"
    return "LLM Call"

In [None]:
def function_for_rag(state):
    """RAG node: answer the original user question from retrieved context.

    Returns {"messages": [answer]} so the answer is added to the graph state.
    """
    print("-> Calling RAG ->")
    messages = state["messages"]
    question = messages[0]  # the user question; later entries are topics/answers
    print(question)

    template = """

    Answer the question based only on the following context:
    {context}

    Question: {question}
    """

    prompt = ChatPromptTemplate.from_template(template=template)

    # The key must be "context" to fill {context} ("Content" left it unset).
    # retriever_2 is the store built for this routed workflow; hits are
    # joined into plain text for the prompt.
    retrieval_chain = (
        {
            "context": retriever_2
            | (lambda hits: "\n\n".join(hit.page_content for hit in hits)),
            "question": RunnablePassthrough(),
        }
        | prompt
        | model
        | StrOutputParser()
    )

    result = retrieval_chain.invoke(question)

    return {"messages": [result]}

In [None]:
def function_for_llm(state):
    """Fallback node: answer the original question from the model's own knowledge.

    Returns {"messages": [answer_text]}.
    """
    print("-> Calling LLM ->")

    user_question = state["messages"][0]  # Fetching the user question

    preamble = """
    Answer the following question with your knowledge of the real world. Following is the user question
    """
    answer = model.invoke(preamble + user_question)

    return {"messages": [answer.content]}

In [None]:
# StateGraph over AgentState: classify the query, then route to RAG or LLM.
workflow_for_llm_ray = StateGraph(AgentState)

# Nodes.
workflow_for_llm_ray.add_node("agent", function_1_for_comparing)
workflow_for_llm_ray.add_node("RAG", function_for_rag)
workflow_for_llm_ray.add_node("LLM", function_for_llm)

# The classifier runs first; router's return value selects the next node.
workflow_for_llm_ray.set_entry_point("agent")
workflow_for_llm_ray.add_conditional_edges(
    "agent",
    router,
    {"RAG Call": "RAG", "LLM Call": "LLM"},
)

# Both branches end the graph.
workflow_for_llm_ray.add_edge("RAG", END)
workflow_for_llm_ray.add_edge("LLM", END)

app_for_llm_rag = workflow_for_llm_ray.compile()


In [None]:
try:
    display(Image(app_for_llm_rag.get_graph().draw_mermaid_png()))

except Exception as e:
    # This requires some extra dependencies and is optional
    print(e)

In [None]:
input_for_llm_rag = {"messages": ["Tell me about USA industrial growth"]}

In [None]:
output_1 = app_for_llm_rag.invoke(input_for_llm_rag)