In [None]:
from typing import Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages # A reducer function
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from pydantic import BaseModel
from dotenv import load_dotenv
from typing import Annotated
from databricks_langchain import ChatDatabricks
import os
from client import AzureAIClient
from IPython.display import Image, display
import gradio as gr
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
import aiosqlite
from agent_tools import get_agent_tools, get_playwright_tools

# Load variables from a local .env file into os.environ; override=True lets the
# .env values replace variables that are already set in the process environment.
load_dotenv(override=True)

In [None]:
# Graph state schema
class State(BaseModel):
    """State passed between graph nodes: the running list of chat messages."""

    # `add_messages` is the reducer for this field: a node's update to `messages`
    # is merged into the stored list instead of overwriting it, so nodes only
    # need to return the new message(s).
    messages: Annotated[list, add_messages]


# Get llms: one chat model on the `databricks-claude-3-7-sonnet` serving endpoint
# (replies capped at 1000 tokens) and one on the Azure OpenAI `gpt-4o` deployment.
# Only `llm_az` is bound to tools in the cells shown; `llm_db` is not used further here.
llm_db = ChatDatabricks(
    endpoint="databricks-claude-3-7-sonnet",
    max_tokens=1000,
)

# Azure AD token provider from the project's client wrapper (used instead of an API key).
azure_token_provider = AzureAIClient().token_provider
llm_az = AzureChatOpenAI(
    azure_deployment="gpt-4o",
    api_version="2024-12-01-preview",
    azure_ad_token_provider=azure_token_provider,
)

In [None]:
# Get the tools
tools_list = await get_agent_tools()
playwright_tools, browser, playwright = await get_playwright_tools()

agent_tools = tools_list + playwright_tools
print(agent_tools)

# Give the llm the tools
llm_with_tools = llm_az.bind_tools(agent_tools)

In [None]:
def _print_tools(tools) -> None:
    """Print one line per tool: its name followed by the tool object's repr."""
    for tool in tools:
        print(f"Tool Name: {tool.name} == {tool}")


_print_tools(agent_tools)

# Drop the serper tool for testing the playwright tool.
EXCLUDED_TOOL_NAMES = {"google_serper_search_tool"}
agent_tools = [tool for tool in agent_tools if tool.name not in EXCLUDED_TOOL_NAMES]
print("\n\n", agent_tools)
_print_tools(agent_tools)

# Re-bind the filtered list. `llm_with_tools` was bound to the unfiltered tools in the
# previous cell, so without this the model could still request a dropped tool that the
# ToolNode (built from `agent_tools`) has no implementation for.
llm_with_tools = llm_az.bind_tools(agent_tools)

In [None]:
# Memory database
db_path="memory_db/sqlite_memory.db"
conn = sqlite3.connect(db_path, check_same_thread=False)

async def setup_async_db():
    async_conn = await aiosqlite.connect(db_path)
    return async_conn

async_conn = await setup_async_db()

# sql_memory = SqliteSaver(conn)
sql_memory = AsyncSqliteSaver(async_conn)

# memory = MemorySaver()

In [None]:
# Chat node
def chat(state: State) -> State:
    """Graph node: send the conversation so far to the tool-bound LLM.

    Returns a State containing only the model's reply, which may carry tool-call
    requests. The `add_messages` reducer on `State.messages` merges it into the
    stored history rather than replacing that history.
    """
    reply = llm_with_tools.invoke(state.messages)
    return State(messages=[reply])

### Build the graph

In [None]:
# Agent loop: START -> chat; chat -> tools when the model requested a tool call,
# otherwise -> END; tools -> chat so the model can read the tool results.
graph_builder = StateGraph(state_schema=State)

graph_builder.add_node("chat", chat)
graph_builder.add_node("tools", ToolNode(tools=agent_tools))

graph_builder.add_edge(START, "chat")
# tools_condition routes to "tools" or END. Map both explicitly: path_map expects a
# dict or list, not a bare node name.
graph_builder.add_conditional_edges("chat", tools_condition, {"tools": "tools", END: END})
graph_builder.add_edge("tools", "chat")

# Checkpoint every step to the SQLite-backed saver so each thread's history persists.
# (Pass MemorySaver() instead for in-process, non-persistent memory.)
graph = graph_builder.compile(checkpointer=sql_memory)
# NOTE(review): draw_mermaid_png() renders through a remote Mermaid service by default —
# confirm network access if this cell fails offline.
display(Image(graph.get_graph().draw_mermaid_png()))

### Create the Gradio chat function

In [None]:
# Every Gradio turn is checkpointed under this single thread id, so the graph's
# memory (not Gradio's `history`) carries the conversation between turns.
config = {"configurable": {"thread_id": "2"}}


async def gradio_chat(user_input: str, history):
    """Gradio ChatInterface callback: run one user turn through the graph.

    `history` is supplied by Gradio but unused; earlier turns are restored by the
    checkpointer from the thread id in `config`. Prints the full state returned by
    the graph and gives back the text content of its last message.
    """
    turn_input = State(messages=[{"role": "user", "content": user_input}])
    final_state = await graph.ainvoke(turn_input, config=config)
    print(final_state)

    last_message = final_state["messages"][-1]
    return last_message.content

In [None]:
# Chat interface: serve gradio_chat through Gradio's chat UI, using the
# "messages" (role/content dict) history format.
chat_ui = gr.ChatInterface(fn=gradio_chat, type="messages")
chat_ui.launch()