# LangChain 1.0 Custom Middleware

Ref:

- https://docs.langchain.com/oss/python/langchain/middleware#custom-middleware

In [0]:
%pip install -U langchain>=1.0.0 langchain_openai>=1.0.0 mlflow tenacity

%restart_python

## モデルとの接続

In [0]:
from langchain.chat_models import init_chat_model
import mlflow
import os

# Enable MLflow's LangChain autologging before any model or agent is created.
mlflow.langchain.autolog()

# Connect to the Databricks serving endpoints through the "openai" provider,
# authenticating with the host credentials that MLflow resolves.
creds = mlflow.utils.databricks_utils.get_databricks_host_creds()
model = init_chat_model(
    model="openai:databricks-qwen3-next-80b-a3b-instruct",
    # model="openai:databricks-gpt-oss-20b",
    api_key=creds.token,
    base_url=creds.host + "/serving-endpoints",
)

# Quick connectivity check:
# model.invoke("Hello")

# The same connection can also be configured through environment variables:
# os.environ["OPENAI_API_KEY"] = creds.token
# os.environ["OPENAI_BASE_URL"] = creds.host + "/serving-endpoints"

# model = init_chat_model("openai:databricks-gpt-oss-20b")
# model.invoke("Hello")

In [0]:

from langchain.tools import tool, ToolRuntime

# NOTE: the docstring is intentionally left in Japanese ("Get the weather for
# the specified city."); @tool presumably exposes it to the model as the tool
# description, so changing it would change runtime behaviour.
@tool
def get_weather(city: str, runtime: ToolRuntime) -> str:
    """指定した都市の天気を取得します。"""
    # Stub implementation: `runtime` is accepted but not used here.
    forecast = f"It's always sunny in {city}!"
    return forecast


## Decorator-based middleware

In [0]:
from langchain.agents.middleware import (
    before_model,
    after_model,
    wrap_model_call,
    before_agent,
    after_agent,
    wrap_tool_call,
    AgentState,
    ModelRequest,
    ModelResponse,
    dynamic_prompt,
)
from langchain.messages import AIMessage
from langchain.agents import create_agent
from langgraph.runtime import Runtime
from typing import Any, Callable


# Node-style hook: logging before the agent run starts
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print a start banner and the current message count.

    Returns:
        Always ``None``; this hook makes no state update.
    """
    print("Agent処理を開始します")
    message_count = len(state["messages"])
    print(
        f"before_agent: {message_count}件のメッセージでエージェントを呼び出そうとしています"
    )
    return None

# Node-style hook: logging after the agent run finishes
@after_agent
def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print an end banner and the final message count.

    Returns:
        Always ``None``; this hook makes no state update.
    """
    print("Agent処理を終了します")
    message_count = len(state["messages"])
    print(
        f"after_agent: {message_count}件のメッセージでエージェント処理を完了しました"
    )
    return None


# Node-style hook: logging before each model call
@before_model
def log_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Print how many messages are in the state right before the model is called.

    Returns:
        Always ``None``; this hook makes no state update.
    """
    message_count = len(state["messages"])
    print(
        f"before_model: {message_count}件のメッセージでモデルを呼び出そうとしています"
    )
    return None


# Node-style hook: validate the model output after each model call
@after_model(can_jump_to=["end"])
def validate_output(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Answer with a refusal and jump to the end when the output is blocked.

    The latest message in ``state["messages"]`` is inspected; if its text
    contains the marker ``"BLOCKED"``, a refusal message is returned together
    with a jump to ``"end"``.

    Returns:
        A state update with the refusal message and ``"jump_to": "end"`` when
        the marker is found, otherwise ``None`` (no update).
    """
    last_message = state["messages"][-1]
    # Use `.text` rather than `.content`: `content` may be a list of content
    # blocks instead of a plain string, in which case `in` would be a list
    # membership test and the marker would never be detected.
    if "BLOCKED" not in last_message.text:
        return None
    return {
        "messages": [AIMessage("そのリクエストには対応できません。")],
        "jump_to": "end",
    }


# Wrap-style hook: retry logic around the model call
@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """Call the model handler, retrying on any exception.

    Up to three attempts are made. After each failed attempt except the last,
    a retry notice is printed; the exception from the final attempt is
    re-raised unchanged.
    """
    max_attempts = 3
    for attempt_number in range(1, max_attempts + 1):
        try:
            return handler(request)
        except Exception as e:
            # Out of attempts: let the last failure propagate to the caller.
            if attempt_number == max_attempts:
                raise
            print(f"エラー発生後にリトライ {attempt_number}/3: {e}")


# Wrap-style hook: retry logic around each tool call
# Imported here so this hook can be annotated with the tool-call types
# instead of the model-call types it was originally copied from.
from langchain.messages import ToolMessage
from langchain.tools.tool_node import ToolCallRequest
from langgraph.types import Command


@wrap_tool_call
def retry_tool(
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Execute the tool handler, retrying on any exception.

    Up to three attempts are made. After each failed attempt except the last,
    a retry notice is printed; the exception from the final attempt is
    re-raised unchanged.

    Args:
        request: The tool call request passed in by the agent.
        handler: Callable that actually executes the tool call.

    Returns:
        Whatever ``handler`` returns for the first successful attempt.
    """
    for attempt in range(3):
        try:
            return handler(request)
        except Exception as e:
            # Last attempt failed: propagate the original exception.
            if attempt == 2:
                raise
            print(f"エラー発生後にリトライ {attempt + 1}/3: {e}")


# Create an agent that uses the decorator-based custom middleware
agent = create_agent(
    model=model,
    tools=[get_weather],
    middleware=[
        log_before_agent,
        log_after_agent,
        log_before_model,
        validate_output,
        retry_model,
        retry_tool,
    ],
)

# Run the agent in streaming mode and print each streamed item
for stream_mode, chunk in agent.stream(
    {"messages": [{"role": "user", "content": "東京の天気を教えて"}]},
    stream_mode=["updates"],
):
    # One print call emitting the same three outputs as separate prints would.
    print(f"stream_mode: {stream_mode}", f"content: {chunk}", "\n", sep="\n")

## Class-based middleware

In [0]:
from langchain.agents.middleware import AgentMiddleware, AgentState
from langgraph.runtime import Runtime
from typing import Any


class LoggingMiddleware(AgentMiddleware):
    """Class-based middleware that prints a line before and after each model call."""

    def before_model(
        self, state: AgentState, runtime: Runtime
    ) -> dict[str, Any] | None:
        """Print the number of messages in the state; return no state update."""
        pending_count = len(state["messages"])
        print(f"About to call model with {pending_count} messages")
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Print the content of the latest message; return no state update."""
        latest_message = state["messages"][-1]
        print(f"Model returned: {latest_message.content}")
        return None


# Use the class-based logging middleware in an agent
agent = create_agent(
    model=model,
    tools=[get_weather],
    middleware=[LoggingMiddleware()],
)

# Run the agent in streaming mode and print each streamed item
for stream_mode, chunk in agent.stream(
    {"messages": [{"role": "user", "content": "東京の天気を教えて"}]},
    stream_mode=["updates"],
):
    # One print call emitting the same three outputs as separate prints would.
    print(f"stream_mode: {stream_mode}", f"content: {chunk}", "\n", sep="\n")

In [0]:
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
from typing import Callable
from mlflow.entities import SpanType
from tenacity import retry, stop_after_attempt, wait_fixed, wait_exponential, RetryError, after_log
import time
from collections.abc import Awaitable

class RetryMiddleware(AgentMiddleware):
    """Middleware that retries model calls with exponential backoff.

    Every attempt runs inside its own MLflow span named ``model_call`` whose
    inputs are set to the model request.
    """

    def __init__(self, max_retries: int = 3):
        """Store the retry budget.

        Args:
            max_retries: Maximum number of attempts per model call (the value
                is handed to tenacity's ``stop_after_attempt``, so it counts
                the first call as well).
        """
        super().__init__()
        self.max_retries = max_retries

    def _retry_policy(self):
        """Build the tenacity decorator shared by the sync and async paths."""
        return retry(
            # Honour the configured budget instead of a hard-coded 3.
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=4, max=10),
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Run ``handler`` synchronously under the retry policy.

        Args:
            request: The model request to forward to ``handler``.
            handler: Callable that performs the actual model call.

        Returns:
            The response from the first successful attempt.

        Raises:
            tenacity.RetryError: When every attempt fails (tenacity's default
                behaviour when ``reraise`` is not set).
        """

        @self._retry_policy()
        def wrap_handler(request):
            with mlflow.start_span(name="model_call", span_type=SpanType.CHAIN) as span:
                span.set_inputs(request)
                return handler(request)

        return wrap_handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Run ``handler`` asynchronously under the same retry policy.

        Args:
            request: The model request to forward to ``handler``.
            handler: Async callable that performs the actual model call.

        Returns:
            The response from the first successful attempt.

        Raises:
            tenacity.RetryError: When every attempt fails (tenacity's default
                behaviour when ``reraise`` is not set).
        """

        @self._retry_policy()
        async def awrap_handler(request):
            with mlflow.start_span(name="model_call", span_type=SpanType.CHAIN) as span:
                span.set_inputs(request)
                return await handler(request)

        return await awrap_handler(request)


# Use the class-based retry middleware in an agent
agent = create_agent(
    model=model,
    tools=[get_weather],
    middleware=[RetryMiddleware()],
)

# Run the agent in streaming mode and print each streamed item
for stream_mode, chunk in agent.stream(
    {"messages": [{"role": "user", "content": "東京の天気を教えて"}]},
    stream_mode=["updates"],
):
    # One print call emitting the same three outputs as separate prints would.
    print(f"stream_mode: {stream_mode}", f"content: {chunk}", "\n", sep="\n")

In [0]:
# Bare expression as the cell's last statement: presumably here so the
# notebook renders the most recently created agent object as the cell output.
agent
