In [None]:
from typing import Annotated, Sequence, TypedDict
from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, ToolMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode

# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the OpenAI credentials that
# ChatOpenAI reads at construction time — confirm the .env contents.
load_dotenv()

class AgentState(TypedDict):
    """State schema for the StateGraph: the running conversation history."""

    # `add_messages` is the reducer for this key: when a node returns
    # {"messages": [...]}, LangGraph merges those messages into the existing
    # history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]

@tool
def add(a: int, b: int):
    """This is an addition function that adds two numbers together"""
    # `@tool` exposes this docstring to the model as the tool description,
    # so its exact wording is part of runtime behavior — edit deliberately.
    total = a + b
    return total

# Tools the model is allowed to call; also handed to the ToolNode later.
tools = [add]

# Fine-tuned GPT-4.1-nano checkpoint (a therapy-conversation fine-tune,
# judging by the model ID). Binding `tools` lets it emit tool calls.
_FINE_TUNED_MODEL_ID = (
    "ft:gpt-4.1-nano-2025-04-14:yvon-kim:kaggle-therapy-conversations:CANdhqZC"
)
model = ChatOpenAI(model=_FINE_TUNED_MODEL_ID).bind_tools(tools)

# System prompt prepended to every model call. It is built once at definition
# time rather than on every invocation. It is written as adjacent string
# literals so that the text sent to the model carries no stray leading
# newline, source-code indentation or trailing whitespace (which an indented
# triple-quoted literal would include).
# NOTE(review): "CBAT" may be a typo for "CBT" (cognitive behavioral therapy)
# — confirm the intended modality before changing the prompt wording.
THERAPIST_SYSTEM_PROMPT = SystemMessage(
    content=(
        "You are acting like a CBAT or EMDR therapist. You will be maximally "
        "empathetic and supportive. Your goal is to help the user feel better "
        "and guide them through difficult times.\n"
        "You will act as a therapist/counselor and provide guidance and "
        "support to the user."
    )
)


def model_call(state: AgentState) -> AgentState:
    """Invoke the tool-bound chat model on the conversation so far.

    Args:
        state: Current graph state; ``state["messages"]`` is the conversation
            history passed to the model after the system prompt.

    Returns:
        A partial state update ``{"messages": [response]}``. The
        ``add_messages`` reducer declared on ``AgentState`` merges
        ``response`` into the history.
    """
    # Unpack instead of `list + sequence` so any Sequence (e.g. a tuple)
    # works, not only a list.
    response = model.invoke([THERAPIST_SYSTEM_PROMPT, *state["messages"]])
    return {"messages": [response]}


def loop_tool(state: AgentState) -> str:
    """Decide where to route after the agent node runs.

    Args:
        state: Current graph state; only the last entry of
            ``state["messages"]`` is inspected.

    Returns:
        ``"continue"`` when the latest message requested one or more tool
        calls, otherwise ``"end"``. These strings are the keys of the path
        map passed to ``add_conditional_edges``.
    """
    last_message = state["messages"][-1]
    # Only AI messages carry `tool_calls`. Treat any other message type as
    # "no tool calls" instead of raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        return "continue"
    return "end"
    
graph = StateGraph(AgentState)

# --- Nodes: the LLM step and the prebuilt executor for the bound tools.
graph.add_node("our_agent", model_call)
tool_node = ToolNode(tools=tools)
graph.add_node("tools", tool_node)

# --- Edges.
# Execution starts at the agent (equivalent to set_entry_point("our_agent")).
graph.add_edge(START, "our_agent")
# After each agent turn, `loop_tool` picks the next hop: run the requested
# tools, or stop.
graph.add_conditional_edges(
    "our_agent",
    loop_tool,
    {"end": END, "continue": "tools"},
)
# Tool results always flow back to the agent for another turn.
graph.add_edge("tools", "our_agent")

app = graph.compile()

def print_stream(stream):
    """Print the newest message from every state snapshot yielded by *stream*.

    Each item of *stream* must be a mapping with a ``'messages'`` sequence
    (for example, the snapshots from ``app.stream(..., stream_mode="values")``).
    Raw tuples such as ``("user", "text")`` are shown with ``print``. Any
    other message object is rendered with its ``pretty_print()`` method.
    """
    for snapshot in stream:
        latest = snapshot['messages'][-1]
        if not isinstance(latest, tuple):
            latest.pretty_print()
            continue
        print(latest)
    
# Smoke-test the compiled graph with a single user turn. With
# stream_mode="values", each yielded item is the full state after a step.
inputs = {"messages": [("user", "wassup")]}
state_snapshots = app.stream(inputs, stream_mode="values")
print_stream(state_snapshots)


wassup

Hey there! What's poppin'? How are you doing today?
