In [1]:
from langchain_openai import ChatOpenAI
# Chat model used by the order-creation router below. No model name is passed,
# so the langchain_openai default model is used.
llm: ChatOpenAI = ChatOpenAI()

## order_create_router

In [36]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Routing rules for the order-creation router. In English:
# check whether the product list was already shown in the conversation; if not,
# output fetch_product_list; if the customer named the items to order, output
# TodOrderRequestConfirmation; if the customer agreed to proceed, output create_order.
# The same text is reused by the judge prompt via .partial(routing_criteria=...).
routing_criteria = """
    이전 대화에서 판매 중인 상품 목록을 제시했는지 확인해.
    만약 판매중인 상품 목록을 제시하지 않았다면 fetch_product_list를 출력해.
    고객이 주문할 품목을 말했다면 TodOrderRequestConfirmation를 출력해.
    고객이 주문을 진행하려는 내용에 동의했다면 create_order를 출력해.
    """

# System prompt (English gist): "You are a capable order bot that creates orders.
# Handle the customer's request with the appropriate tool. When you need the
# customer's reply, don't use a tool - ask the customer. Never claim an order was
# placed before create_order has actually created it."
# Fix: "고객이 응답이 필요할 때" ("when the customer needs a reply") was the wrong
# subject; the intended condition is "고객의 응답이 필요할 때" ("when the customer's
# reply is needed").
order_create_router_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """
            너는 주문을 생성하는 유능한 주문봇이야.
            적절한 도구를 사용해 고객의 요청을 처리해.
            고객의 응답이 필요할 때는 도구를 사용하지 말고 고객에게 응답을 부탁해.
            create_order를 사용해 실제로 주문을 생성하기 전에는 주문을 했다는 거짓말을 하면 안돼.

            {routing_criteria}

            user_id: {user_info},
            """,
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
).partial(routing_criteria=routing_criteria)

# NOTE(review): no tools are bound to `llm` here, so the "tools" named in the prompt
# can only come back as plain text (the route name). Confirm this is intended.
order_create_router_runnable = order_create_router_prompt | llm

## routing tester

In [37]:
from langchain_core.pydantic_v1 import BaseModel, Field

# Tool schema the judge model calls to log a wrong routing decision.
# NOTE: this class is passed to bind_tools (check_routing_runnable), so its name is the
# tool name the model calls (see the recorded outputs). The docstring and Field
# descriptions presumably become the tool/argument descriptions, so treat them as
# runtime data. The misspelling "MisCalssification" is kept for that reason.
class RecordMisCalssification(BaseModel):
    """Records misclassification of router"""
    # The message that was given to the router.
    input: str = Field(description="라우터에 입력된 메시지")
    # The route the router produced.
    classification_of_router: str = Field(description="라우터의 분류 결과")
    # Why the router's route is wrong (note: the description string has a leading space).
    reason_for_misclassification: str = Field(description=" 라우터의 분류 결과가 잘못된 이유")

In [38]:
from langchain_core.prompts import ChatPromptTemplate

# Judge prompt: gives a stronger model the routing rules, the router's input and the
# router's answer. The model must reply "good" or call RecordMisCalssification.
_check_routing_system = """
            너는 라우팅 결과가 올바른지 꼼꼼하게 점검해야 해.
            라우팅 기준: {routing_criteria}
            입력 메시지: {input_message}
            모델의 분류 결과: {model_classification}

            라우팅 결과가 올바르다면 good을 출력해.
            라우팅 결과가 올바르지 않다면 도구를 사용해 문제점을 기록해.            
            
            """

check_routing_prompt = ChatPromptTemplate.from_messages(
    [("system", _check_routing_system), MessagesPlaceholder(variable_name="messages")]
).partial(routing_criteria=routing_criteria)

# The judge runs on gpt-4o with the misclassification-recording tool bound.
model = "gpt-4o"
gpt4o = ChatOpenAI(model=model)
_judge_with_tool = gpt4o.bind_tools([RecordMisCalssification])
check_routing_runnable = check_routing_prompt | _judge_with_tool

## test

In [39]:
# Case: "I'll order" with no prior history; the router answered fetch_product_list.
# Ask the judge whether that route is correct.
input_message = "주문할게요"
model_classification = "fetch_product_list"
messages = []

check_inputs = {
    "input_message": input_message,
    "model_classification": model_classification,
    "messages": messages,
}
output = check_routing_runnable.invoke(check_inputs)
output

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_vVNdZmOxpGCm53tMDkRtsdf7', 'function': {'arguments': '{"input":"주문할게요","classification_of_router":"fetch_product_list","reason_for_misclassification":"고객이 주문을 진행하려는 내용에 동의했습니다. create_order를 출력해야 합니다."}', 'name': 'RecordMisCalssification'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 57, 'prompt_tokens': 256, 'total_tokens': 313}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_dd932ca5d1', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-079aa041-1f62-40c0-9ddb-6c1cc1421cc6-0', tool_calls=[{'name': 'RecordMisCalssification', 'args': {'input': '주문할게요', 'classification_of_router': 'fetch_product_list', 'reason_for_misclassification': '고객이 주문을 진행하려는 내용에 동의했습니다. create_order를 출력해야 합니다.'}, 'id': 'call_vVNdZmOxpGCm53tMDkRtsdf7'}])

In [32]:
# Case: same input, but the router answered TodOrderRequestConfirmation.
# Ask the judge whether that route is correct.
input_message = "주문할게요"
model_classification = "TodOrderRequestConfirmation"
messages = []

check_inputs = {
    "input_message": input_message,
    "model_classification": model_classification,
    "messages": messages,
}
output = check_routing_runnable.invoke(check_inputs)
output

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_ttbKuzZoAOQD3G5aqJqQK9Gq', 'function': {'arguments': '{"input":"주문할게요","classification_of_router":"TodOrderRequestConfirmation","reason_for_misclassification":"The router misclassified the input as TodOrderRequestConfirmation instead of fetch_product_list."}', 'name': 'RecordMisCalssification'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 55, 'prompt_tokens': 355, 'total_tokens': 410}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-4fa71b55-e779-47c6-8f1c-ed51fcd47e6d-0', tool_calls=[{'name': 'RecordMisCalssification', 'args': {'input': '주문할게요', 'classification_of_router': 'TodOrderRequestConfirmation', 'reason_for_misclassification': 'The router misclassified the input as TodOrderRequestConfirmation instead of fetch_product_list.'}, 'id': 'call_ttbKuzZoAOQD3G5aqJqQK9Gq'}])

In [13]:
# The judge either calls RecordMisCalssification or replies "good" with no tool call.
# Show the recorded arguments in the first case and the plain reply otherwise
# (indexing tool_calls[0] unconditionally raised IndexError on a "good" reply).
output.tool_calls[0]["args"] if output.tool_calls else output.content

{'input': '주문할게요',
 'classification_of_router': 'TodOrderRequestConfirmation',
 'reason_for_misclassification': '라우팅 결과가 올바르지 않음. 주문할게요는 fetch_product_list를 출력해야 함.'}

In [19]:
# Display the full judge response (this cell was a truncated `output.`, a SyntaxError).
output

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_3AV4M4oA9jwmWyHqxkLLGX4A', 'function': {'arguments': '{"input":"주문할게요","classification_of_router":"TodOrderRequestConfirmation","reason_for_misclassification":"라우팅 결과가 올바르지 않음. 주문할게요는 fetch_product_list를 출력해야 함."}', 'name': 'RecordMisCalssification'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 69, 'prompt_tokens': 300, 'total_tokens': 369}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8b6c2500-a48f-44b8-80af-6b7885f60616-0', tool_calls=[{'name': 'RecordMisCalssification', 'args': {'input': '주문할게요', 'classification_of_router': 'TodOrderRequestConfirmation', 'reason_for_misclassification': '라우팅 결과가 올바르지 않음. 주문할게요는 fetch_product_list를 출력해야 함.'}, 'id': 'call_3AV4M4oA9jwmWyHqxkLLGX4A'}])

In [20]:
from langchain_core.messages import AIMessage

# Sanity check: the judge runnable returns an AIMessage.
is_ai_message = isinstance(output, AIMessage)
is_ai_message

True

# graph

In [None]:
from typing import Annotated, Literal, Optional
from typing_extensions import TypedDict

from langgraph.graph.message import AnyMessage, add_messages


def update_dialog_stack(left: list[str], right: Optional[str]) -> list[str]:
    """Reducer that pushes onto, pops from, or keeps the dialog stack.

    Args:
        left: current stack; the last element is the top.
        right: ``None`` keeps the stack, ``"pop"`` drops the top element, and any
            other value is pushed on top.

    Returns:
        The resulting stack. ``left`` itself is never mutated.
    """
    if right is not None:
        return left[:-1] if right == "pop" else left + [right]
    return left


class State(TypedDict):
    """Shared graph state for the order bot.

    NOTE: TypedDict does not apply the ``= value`` defaults below at runtime. A key
    that was never written is simply missing from the state dict, so read optional
    keys with ``state.get(...)``. The defaults only record the intended initial values.
    """

    # Conversation history, merged with langgraph's add_messages reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    user_info: str
    # Stack of active sub-dialogs. update_dialog_stack pushes a value,
    # drops the top on "pop", and ignores None.
    dialog_state: Annotated[
        list[
            Literal[
                "order_inquiry",
                "order_create",
                "order_change",
                "order_cancel",
            ]
        ],
        update_dialog_stack,
    ]
    # Optional[...] because the declared default is None (was annotated as plain int/str).
    order_id: Optional[int] = None
    orders: Optional[str] = None
    selected_order: Optional[str] = None
    product_presentation: bool = False
    request_order_change_message: bool = False
    request_approval_message: bool = False
    task_completed: bool = False

In [None]:
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.messages import AIMessage

class Assistant:
    """Graph node that invokes an LLM runnable until it produces a non-empty reply.

    Args:
        runnable: a prompt | model pipeline invoked with the current graph state.
    """

    def __init__(self, runnable: Runnable):
        self.runnable = runnable

    def __call__(self, state: State, config: RunnableConfig):
        """Invoke the runnable on ``state`` and re-prompt while the reply is empty.

        A reply counts as empty when it has no tool calls and either has no content
        or has list-form content whose first part has no ``"text"``. In that case a
        user turn "Respond with a real output." is appended and the runnable is
        invoked again.

        Returns:
            Every state key except ``dialog_state``, with ``messages`` replaced by the
            runnable's final reply.
        """
        while True:
            result = self.runnable.invoke(state)

            if not result.tool_calls and (
                not result.content
                or isinstance(result.content, list)
                and not result.content[0].get("text")
            ):
                messages = state["messages"] + [("user", "Respond with a real output.")]
                # Keep user_info and the other keys. Prompts such as the router's
                # template need {user_info}, and dropping the keys here would make
                # the retry fail on missing template variables.
                state = {**state, "messages": messages}
            else:
                break

        # dialog_state is excluded: returning it would feed the whole list to
        # update_dialog_stack as a single pushed entry.
        add_state = {k: v for k, v in state.items() if k != "dialog_state"}

        return {**add_state, "messages": result}

In [None]:
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableLambda


def handle_tool_error(state) -> dict:
    """Fallback for the tool node: report a tool failure back to the model.

    Builds one ToolMessage per tool call on the last message. Each message carries
    the repr of ``state["error"]`` (the ``exception_key`` configured in
    ``create_tool_node_with_fallback``) and asks the model to fix its mistake.
    Also prints a debug trace of the incoming state.
    """
    # Debug trace ("handle_tool_error 진입" = "entered handle_tool_error").
    print("-" * 77)
    print("handle_tool_error 진입")
    print("state\n", state)

    error = state.get("error")
    failing_calls = state["messages"][-1].tool_calls

    replies = []
    for call in failing_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}


def create_tool_node_with_fallback(tools: list) -> "Runnable":
    """Wrap ``tools`` in a ToolNode whose failures are reported back to the model.

    With ``exception_key="error"``, an exception raised by the ToolNode is handed to
    the fallback inside its input under ``"error"``, which ``handle_tool_error``
    reads to build ToolMessages instead of crashing the graph.

    Returns:
        A Runnable (the ToolNode with a fallback attached). The previous ``-> dict``
        annotation was incorrect.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )

In [None]:
from langgraph.graph import StateGraph
from langchain_core.messages import ToolMessage, AIMessage

builder = StateGraph(State)


def order_create_router(state: State):
    """Run the order-creation router on the conversation so far.

    Returns:
        ``{"messages": <router reply>}``, merged into the history by the
        ``messages`` reducer.
    """
    messages = state["messages"]

    # Use the state's user_info when present (it was previously ignored). Fall back
    # to the old hard-coded test value 1 so existing runs without it are unchanged.
    output = order_create_router_runnable.invoke(
        {"messages": messages, "user_info": state.get("user_info", 1)}
    )

    return {"messages": output}


builder.add_node("order_create_router", order_create_router)
        
def check_router_result(state: State):
    """Have the judge model verify the router's latest output.

    The router's reply is expected to be the last message. The customer turn it
    answered is taken to be the most recent human message before it; if there is
    none, the message just before the reply is used.

    Returns:
        If the judge called RecordMisCalssification: the judge's AIMessage followed
        by one ToolMessage per tool call. The ToolMessages feed the recorded problem
        back so the router can revise its output. Otherwise ``{}`` (no state change).
    """
    messages = state["messages"]
    if len(messages) < 2:
        # Need at least a customer turn and a router reply to judge anything.
        return {}

    router_output = messages[-1]
    # Skip judge/tool feedback from earlier retries to find the customer's turn.
    customer_turn = next(
        (m for m in reversed(messages[:-1]) if getattr(m, "type", None) == "human"),
        messages[-2],
    )

    # Must be the judge runnable. Only check_routing_prompt defines input_message /
    # model_classification; the router prompt needs user_info instead.
    evaluate_classification = check_routing_runnable.invoke(
        {
            "input_message": customer_turn.content,
            "model_classification": router_output.content,
            "messages": messages,
        }
    )

    # An AIMessage always has a tool_calls list (possibly empty). The old hasattr()
    # test was therefore always true, and tool_calls[0] raised IndexError on a "good" reply.
    if evaluate_classification.tool_calls:
        feedback = [
            ToolMessage(
                content=f"라우터의 분류 결과가 잘못된 것 같습니다. 다음 정보를 활용해 출력을 수정해주세요.{tc['args']}",
                tool_call_id=tc["id"],
            )
            for tc in evaluate_classification.tool_calls
        ]
        # Include the judge's AIMessage. Each ToolMessage must follow the assistant
        # message that issued its tool_call_id, or the next OpenAI call is rejected.
        return {"messages": [evaluate_classification, *feedback]}

    # Judge accepted the route. Return no update: returning the whole state would
    # pass the dialog_state list to update_dialog_stack as one pushed entry.
    return {}
    
def after_check_router_result(state: State):
    """Pick the next node after ``check_router_result`` (for use as a conditional edge).

    Returns:
        "order_change" when the last message is an AIMessage, meaning the judge added
        nothing and the router's reply stands. Otherwise "order_create_router", so the
        router re-runs and sees the judge's ToolMessage feedback.
    """
    last_message = state["messages"][-1]

    if isinstance(last_message, AIMessage):
        # NOTE(review): "order_change" looks copied from another flow; confirm the
        # intended target for an accepted order-create route.
        return "order_change"
    # The old bare `return` produced None, which is not a valid branch destination.
    # Looping back to the router is an assumption - TODO confirm. The loop is bounded
    # only by the graph's recursion limit.
    return "order_create_router"


builder.add_node("check_router_result", check_router_result)
builder.add_edge("order_create_router", "check_router_result")
    

    
