# LangChain 1.0 Memory

References:

- https://docs.langchain.com/oss/python/langchain/agents#memory
- https://docs.langchain.com/oss/python/langchain/short-term-memory#customizing-agent-memory

In [0]:
%pip install -U langchain>=1.0.0 langchain_openai>=1.0.0 mlflow

%restart_python

## モデルとの接続

In [0]:
from langchain_openai import ChatOpenAI
import mlflow

# Import the submodule explicitly rather than relying on `import mlflow` having
# loaded it as a side effect.
import mlflow.utils.databricks_utils

# Enable MLflow autologging (tracing) for the LangChain calls made below.
mlflow.langchain.autolog()

# Host/token resolved by MLflow for the current Databricks workspace; the model is
# reached through the workspace's serving endpoints via the OpenAI-style client.
creds = mlflow.utils.databricks_utils.get_databricks_host_creds()
model = ChatOpenAI(
    model="databricks-gpt-oss-20b",
    # Strip any trailing slash so the URL never contains "//serving-endpoints".
    base_url=creds.host.rstrip("/") + "/serving-endpoints",
    api_key=creds.token,
)

## エージェントの作成と実行

In [0]:
from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langgraph.checkpoint.memory import InMemorySaver
from typing_extensions import NotRequired

class CustomState(AgentState):
    """Agent state extended with a running count of model calls.

    The extra key is declared ``NotRequired``, so it may be absent from the
    state; readers should use ``state.get("model_call_count", 0)``.
    """

    # Number of model calls made so far in the run (optional key).
    model_call_count: NotRequired[int]

class CallCounterMiddleware(AgentMiddleware[CustomState]):
    """Stop the agent loop once the model has been called too many times.

    ``after_model`` increments ``model_call_count`` in the agent state after
    every model call, and ``before_model`` ends the run (``jump_to: "end"``)
    once that count exceeds ``max_model_calls``. Because the check is a strict
    ``>``, the default of 10 allows at most 11 model calls per run.

    NOTE(review): without a checkpointer the count presumably restarts from 0
    on every invocation, since the input state does not carry it — confirm.
    """

    # Registers CustomState so the graph knows about ``model_call_count``.
    state_schema = CustomState

    # Threshold compared against ``model_call_count`` (strictly greater ends the run).
    max_model_calls: int = 10

    # Class-based hooks must declare where they may jump; without this the
    # graph gets no edge from this hook to "end" and the returned jump_to is
    # not acted on, so the call limit would never trigger.
    @hook_config(can_jump_to=["end"])
    def before_model(self, state: CustomState, runtime) -> dict[str, Any] | None:
        """Return a jump to the end of the run once the call budget is exceeded.

        Returns ``{"jump_to": "end"}`` when ``model_call_count`` is greater
        than ``max_model_calls``; otherwise ``None`` (no state update).
        """
        count = state.get("model_call_count", 0)
        if count > self.max_model_calls:
            return {"jump_to": "end"}
        return None

    def after_model(self, state: CustomState, runtime) -> dict[str, Any] | None:
        """Return a state update that records one more model call."""
        return {"model_call_count": state.get("model_call_count", 0) + 1}

In [0]:
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy
from pprint import pprint


def get_weather(city: str) -> str:
    """指定した都市の天気を取得します。"""
    # The docstring above is kept verbatim: LangChain derives the tool
    # description shown to the model from it.
    # English: "Get the weather for the specified city."
    # Demo stub: always returns the same canned sentence for any city.
    template = "It's always sunny in {}!"
    return template.format(city)


# Build the agent with the weather tool and the model-call limiter.
# NOTE(review): no `checkpointer=` is passed (InMemorySaver is imported but
# unused), so no state is persisted between runs — confirm whether
# `checkpointer=InMemorySaver()` plus a thread_id config was intended.
agent = create_agent(
    model=model,
    tools=[get_weather],
    middleware=[CallCounterMiddleware()],
)

# Named `inputs` (not `input`) so the builtin input() is not shadowed.
inputs = {
    "messages": [{"role": "user", "content": "SFの天気は？"}]
}
# stream_mode="updates" yields the per-step state updates as they happen.
for event in agent.stream(inputs, stream_mode="updates"):
    pprint(event)