In [1]:
# Import relevant functionality
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from tools.visit_web_page_tool import VisitWebPageSyncTool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langchain_openai import ChatOpenAI 
import os, json
from typing import TypedDict
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from dotenv import load_dotenv
from functions.download_transcripts_func import download_transcripts_func
from common.common import GraphState
from tools.crawl_web_page_tool import CrawlWebPageSyncTool
from tools.query_database_tool import QueryDatabaseTool
from functions.initialize_database import initialize_database

  warn(


In [2]:
load_dotenv()

True

In [3]:
llm = ChatOpenAI(
    temperature=0,
    model="gpt-4o")

In [4]:
tools = [CrawlWebPageSyncTool(),
        QueryDatabaseTool(db_path=os.getenv("DB_PATH"))]

llm_with_tools = llm.bind_tools(tools)

async def tool_calling_llm(state: GraphState) -> dict:
    """Function to call the LLM with tools."""
    response = await llm_with_tools.ainvoke(
        [HumanMessage(content=state["query"])]
    )
    state["messages"] = [response]
    print(f"LLM Response: {response}")
    return state

Collections in the database: [Collection(name=langchain)]


In [5]:
def should_continue(state: GraphState) -> str:
    """Condition to check if the tool call should continue."""
    print(f"should_continue: state: {state}")

    if state["messages"][-1].additional_kwargs["tool_calls"]:
        print(f"Tool calls found in response, continuing with tool node, {state["messages"][-1].additional_kwargs["tool_calls"]}")
        return "tool_node"
    
    print("No tool calls found in response, ending the process.")
    return "END"

In [6]:
def decide_next_node(state: GraphState) -> str:
    """Decide the next node based on the state."""
    print(f"decide_next_node: state: {state}")
    
    messages = state["messages"]
    tool_messages = [msg for msg in messages if isinstance(msg, ToolMessage)]
    
    for msg in tool_messages:
        if msg.name == "crawl_web_page":
            return "download_transcripts_func"
        elif msg.name == "query_database":
            return END

In [7]:
builder = StateGraph(GraphState)
builder.add_node("tool_calling_llm", tool_calling_llm)
builder.add_node("tools", ToolNode(tools))
builder.add_node("download_transcripts_func", download_transcripts_func)
builder.add_node("initialize_database", initialize_database)

builder.add_edge(START, "tool_calling_llm")
builder.add_conditional_edges("tool_calling_llm", tools_condition, ["tools", END])
builder.add_conditional_edges("tools", decide_next_node)
builder.add_edge("download_transcripts_func", "initialize_database")
builder.add_edge("initialize_database", END)

graph = builder.compile()

In [8]:
graph.get_graph().print_ascii()

        +-----------+     
        | __start__ |     
        +-----------+     
              *           
              *           
              *           
    +------------------+  
    | tool_calling_llm |  
    +------------------+  
         ..        ..     
       ..            .    
      .               ..  
+-------+               . 
| tools |             ..  
+-------+            .    
         **        ..     
           **    ..       
             *  .         
         +---------+      
         | __end__ |      
         +---------+      


In [9]:
# result = await graph.ainvoke(
#     GraphState(query="get all blogs from https://www.thecloudcast.net"), debug=False
# )

In [10]:
result = await graph.ainvoke(
    GraphState(query="query all blogs having reference to data management"), debug=False
)

LLM Response: content='' additional_kwargs={'tool_calls': [{'id': 'call_TLh4uMFJlxBswtMftw5DJveY', 'function': {'arguments': '{"query":"data management"}', 'name': 'query_database'}, 'type': 'function'}], 'refusal': None} response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 123, 'total_tokens': 138, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_a288987b44', 'id': 'chatcmpl-Bc4EyXHeB1tPqmIUkZYLZlymT518V', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logprobs': None} id='run--348a2017-f18f-42a2-9b6d-080ef71f5be8-0' tool_calls=[{'name': 'query_database', 'args': {'query': 'data management'}, 'id': 'call_TLh4uMFJlxBswtMftw5DJveY', 'type': 'tool_call'}] usage_metadata={'input_tokens': 123, 'output_tokens': 15, 'total_tokens': 138, 'inpu

In [11]:
# result = await graph.ainvoke(
#     GraphState(query="visit https://www.thecloudcast.net/"), debug=False
# )