In [1]:
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import Runnable, RunnableLambda


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    messages: List[BaseMessage] = Field(default_factory=list)
    save_mode: Optional[str] = Field(default="both")  # "input", "output", "both"
    
    def add_messages(self, messages: List[BaseMessage]) -> None:
        """조건에 따라 메시지를 저장"""
        if self.save_mode == "input":
            input_messages = [msg for msg in messages if isinstance(msg, HumanMessage)]
            self.messages.extend(input_messages)
        elif self.save_mode == "output":
            output_messages = [msg for msg in messages if isinstance(msg, AIMessage)]
            self.messages.extend(output_messages)
        elif self.save_mode == "both":
            self.messages.extend(messages)

    def clear(self) -> None:
        self.messages = []

store = {}
def get_session_history(session_id: str, save_mode: str = "both") -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = InMemoryHistory(save_mode=save_mode)
    return store[session_id]

class RunnableWithMessageHistory(Runnable):
    def __init__(self, runnable: Runnable, get_session_history, input_messages_key: str, history_messages_key: str, context_key: Optional[str] = None):
        self.runnable = runnable
        self.get_session_history = get_session_history
        self.input_messages_key = input_messages_key
        self.history_messages_key = history_messages_key
        self.context_key = context_key
    
    def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Any:
        session_id = config["configurable"]["session_id"]
        save_mode = config["configurable"].get("save_mode", "both")
        history = self.get_session_history(session_id, save_mode)
        
        current_input = input[self.input_messages_key]
        
        if isinstance(current_input, str):
            current_input_message = HumanMessage(content=current_input)
            history.add_messages([current_input_message])
        
        input[self.history_messages_key] = history.messages
        
        result = self.runnable.invoke(input, config)
        
        if isinstance(result, AIMessage):
            if self.context_key and self.context_key in input:
                context = input[self.context_key]
                result_with_context = AIMessage(content=f"{context}\n{result.content}")
                history.add_messages([result_with_context])
            else:
                history.add_messages([result])
        
        return result

In [2]:
def add_memory(runnable, session_id, context="", save_mode="both"):
    runnable_with_memory = RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        context_key="context"
    )
    
    memory_by_session = RunnableLambda(
        lambda input: runnable_with_memory.invoke(
            {**input, "context": context},
            config={"configurable": {"session_id": session_id,
                                     "save_mode": save_mode}}
        )
    )
    return memory_by_session

In [3]:
class FakeRunnable(Runnable):
    def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Any:
        return AIMessage(content=f"Received input: {input['input']}")

In [4]:
runnable = FakeRunnable()
runnable_with_memory = add_memory(runnable, session_id="test1", save_mode="input")
runnable_with_memory.invoke({"input": "Hello"}, config={"configurable": {"session_id": "test1"}})

AIMessage(content='Received input: Hello')

In [5]:
# 출력 데이터만 저장
runnable_with_memory = add_memory(runnable, session_id="test2", save_mode="output")
runnable_with_memory.invoke({"input": "How are you?"}, config={"configurable": {"session_id": "test2"}})


AIMessage(content='Received input: How are you?')

In [6]:
# 입력 출력 데이터 모두 저장
runnable_with_memory = add_memory(runnable, session_id="test3", save_mode="both")
runnable_with_memory.invoke({"input": "What's the weather like?"}, config={"configurable": {"session_id": "test3"}})

AIMessage(content="Received input: What's the weather like?")

In [7]:
print(store["test1"].messages)
print(store["test2"].messages)
print(store["test3"].messages)

[HumanMessage(content='Hello')]
[AIMessage(content='\nReceived input: How are you?')]
[HumanMessage(content="What's the weather like?"), AIMessage(content="\nReceived input: What's the weather like?")]


# 2

In [13]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You're an assistant who's good at {ability}. Respond in 20 words or fewer"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)


In [14]:
class FakeChatOpenAI(Runnable):
    def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Any:
        return AIMessage(content=f"Responding to: {input['input']}")

model = FakeChatOpenAI()

In [21]:
class FakeChatOpenAI(Runnable):
    def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> Any:
        # input은 단일 메시지 문자열을 포함할 수 없기 때문에 명시적으로 처리해야 함
        input_message = input.get('input')
        if isinstance(input_message, str):
            return AIMessage(content=f"Responding to: {input_message}")
        return AIMessage(content="Invalid input")
    
model = FakeChatOpenAI()

In [22]:
runnable = prompt | model

In [23]:
# 입력 데이터만 저장
runnable_with_memory = add_memory(runnable, session_id="test1", save_mode="input")
print(runnable_with_memory.invoke({"input": "Hello", "ability": "korean"}, config={"configurable": {"session_id": "test1"}}))
print(store["test1"].messages)

AttributeError: 'ChatPromptValue' object has no attribute 'get'