feat(memory): chat history memory support (#280)
Co-authored-by: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Co-authored-by: Guohao Li <lightaime@gmail.com>
Co-authored-by: lig <guohao.li@kaust.edu.sa>
4 people committed Nov 5, 2023
1 parent 0e446e1 commit da517a8
Showing 33 changed files with 1,175 additions and 224 deletions.
2 changes: 2 additions & 0 deletions camel/__init__.py
@@ -19,6 +19,8 @@
 import camel.typing
 import camel.utils
 import camel.functions
+import camel.memories
+import camel.storages
 
 __version__ = '0.1.0'
 
164 changes: 63 additions & 101 deletions camel/agents/chat_agent.py
@@ -24,35 +24,20 @@
 from camel.agents import BaseAgent
 from camel.configs import BaseConfig, ChatGPTConfig
 from camel.functions import OpenAIFunction
+from camel.memories import (
+    BaseMemory,
+    ChatHistoryMemory,
+    MemoryRecord,
+    ScoreBasedContextCreator,
+)
 from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
 from camel.models import BaseModelBackend, ModelFactory
 from camel.responses import ChatAgentResponse
-from camel.terminators import ResponseTerminator, TokenLimitTerminator
-from camel.typing import ModelType, RoleType
+from camel.terminators import ResponseTerminator
+from camel.typing import ModelType, OpenAIBackendRole, RoleType
 from camel.utils import get_model_encoding, openai_api_key_required
 
 
-@dataclass(frozen=True)
-class ChatRecord:
-    r"""Historical records of who made what message.
-
-    Attributes:
-        role_at_backend (str): Role of the message that mirrors OpenAI
-            message role that may be `system` or `user` or `assistant`.
-        message (BaseMessage): Message payload.
-    """
-    role_at_backend: str
-    message: BaseMessage
-
-    def to_openai_message(self):
-        r"""Converts the payload message to OpenAI-compatible format.
-
-        Returns:
-            OpenAIMessage: OpenAI-compatible message
-        """
-        return self.message.to_openai_message(self.role_at_backend)
-
-
 @dataclass(frozen=True)
 class FunctionCallingRecord:
     r"""Historical records of functions called in the conversation.
@@ -86,11 +71,18 @@ class ChatAgent(BaseAgent):
         system_message (BaseMessage): The system message for the chat agent.
         model (ModelType, optional): The LLM model to use for generating
             responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
-        model_config (Any, optional): Configuration options for the LLM model.
+        model_config (BaseConfig, optional): Configuration options for the
+            LLM model. (default: :obj:`None`)
+        memory (BaseMemory, optional): The agent memory for managing chat
+            messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
+            (default: :obj:`None`)
         message_window_size (int, optional): The maximum number of previous
             messages to include in the context window. If `None`, no windowing
             is performed. (default: :obj:`None`)
+        token_limit (int, optional): The maximum number of tokens in a
+            context. The context will be automatically pruned to satisfy
+            this limit. If `None`, it will be set according to the backend
+            model. (default: :obj:`None`)
         output_language (str, optional): The language to be output by the
             agent. (default: :obj:`None`)
         function_list (List[OpenAIFunction], optional): List of available
@@ -105,7 +97,9 @@ def __init__(
         system_message: BaseMessage,
         model: Optional[ModelType] = None,
         model_config: Optional[BaseConfig] = None,
+        memory: Optional[BaseMemory] = None,
         message_window_size: Optional[int] = None,
+        token_limit: Optional[int] = None,
         output_language: Optional[str] = None,
         function_list: Optional[List[OpenAIFunction]] = None,
         response_terminators: Optional[List[ResponseTerminator]] = None,
@@ -121,7 +115,6 @@ def __init__(

         self.model: ModelType = (model if model is not None else
                                  ModelType.GPT_3_5_TURBO)
-        self.message_window_size: Optional[int] = message_window_size
 
         self.func_dict: Dict[str, Callable] = {}
         if function_list is not None:
@@ -131,13 +124,16 @@ def __init__(

         self.model_backend: BaseModelBackend = ModelFactory.create(
             self.model, self.model_config.__dict__)
-        self.model_token_limit: int = self.model_backend.token_limit
+        self.model_token_limit = token_limit or self.model_backend.token_limit
+        context_creator = ScoreBasedContextCreator(
+            self.model_backend.token_counter,
+            self.model_token_limit,
+        )
+        self.memory: BaseMemory = memory or ChatHistoryMemory(
+            context_creator, window_size=message_window_size)
 
         self.terminated: bool = False
-        self.token_limit_terminator = TokenLimitTerminator(
-            self.model_token_limit)
         self.response_terminators = response_terminators or []
-        self.stored_messages: List[ChatRecord]
         self.init_messages()

     def reset(self):
@@ -149,7 +145,6 @@ def reset(self):
"""
self.terminated = False
self.init_messages()
self.token_limit_terminator.reset()
for terminator in self.response_terminators:
terminator.reset()

@@ -182,6 +177,17 @@ def is_function_calling_enabled(self) -> bool:
"""
return len(self.func_dict) > 0

def update_memory(self, message: BaseMessage,
role: OpenAIBackendRole) -> None:
r"""Updates the agent memory with a new message.
Args:
message (BaseMessage): The new message to add to the stored
messages.
role (OpenAIBackendRole): The backend role type.
"""
self.memory.write_record(MemoryRecord(message, role))

     def set_output_language(self, output_language: str) -> BaseMessage:
         r"""Sets the output language for the system message. This method
         updates the output language for the system message. The output
@@ -232,35 +238,21 @@ def init_messages(self) -> None:
r"""Initializes the stored messages list with the initial system
message.
"""
self.stored_messages = [ChatRecord('system', self.system_message)]
system_record = MemoryRecord(self.system_message,
OpenAIBackendRole.SYSTEM)
self.memory.clear()
self.memory.write_record(system_record)

def update_messages(self, role: str,
message: BaseMessage) -> List[ChatRecord]:
r"""Updates the stored messages list with a new message.
def record_message(self, message: BaseMessage) -> None:
r"""Records the externally provided message into the agent memory as if
it were an answer of the :obj:`ChatAgent` from the backend. Currently,
the choice of the critic is submitted with this method.
Args:
role (str): Role of the message at the backend.
message (BaseMessage): The new message to add to the stored
messages.
Returns:
List[BaseMessage]: The updated stored messages.
message (BaseMessage): An external message to be recorded in the
memory.
"""
if role not in {'system', 'user', 'assistant', 'function'}:
raise ValueError(f"Unsupported role {role}")
self.stored_messages.append(ChatRecord(role, message))
return self.stored_messages

def submit_message(self, message: BaseMessage) -> None:
r"""Submits the externally provided message as if it were an answer of
the chat LLM from the backend. Currently, the choice of the critic is
submitted with this method.
Args:
message (BaseMessage): An external message to be added as an
assistant response.
"""
self.stored_messages.append(ChatRecord('assistant', message))
self.update_memory(message, OpenAIBackendRole.ASSISTANT)

     @retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
     @openai_api_key_required
@@ -282,25 +274,22 @@ def step(
             a boolean indicating whether the chat session has terminated,
             and information about the chat session.
         """
-        messages = self.update_messages('user', input_message)
+        self.update_memory(input_message, OpenAIBackendRole.USER)
 
         output_messages: List[BaseMessage]
         info: Dict[str, Any]
         called_funcs: List[FunctionCallingRecord] = []
         while True:
             # Format messages and get the token number
-            openai_messages: Optional[List[OpenAIMessage]]
-            num_tokens: int
-            openai_messages, num_tokens = self.preprocess_messages(messages)
-
-            # Terminate when number of tokens exceeds the limit
-            self.terminated, termination_reason = \
-                self.token_limit_terminator.is_terminated(num_tokens)
-            if self.terminated and termination_reason is not None:
-                return self.step_token_exceed(num_tokens, called_funcs,
-                                              termination_reason)
+            try:
+                openai_messages, num_tokens = self.memory.get_context()
+            except RuntimeError as e:
+                return self.step_token_exceed(e.args[1], called_funcs,
+                                              "max_tokens_exceeded")
 
-            # Obtain LLM's response and validate it
+            # Obtain the model's response and validate it
             response = self.model_backend.run(openai_messages)
             self.validate_model_response(response)

@@ -311,19 +300,21 @@ def step(
             output_messages, finish_reasons, usage_dict, response_id = (
                 self.handle_stream_response(response, num_tokens))
 
-            if self.is_function_calling_enabled(
-            ) and finish_reasons[0] == 'function_call':
+            if (self.is_function_calling_enabled()
+                    and finish_reasons[0] == 'function_call'):
                 # Do function calling
                 func_assistant_msg, func_result_msg, func_record = (
                     self.step_function_call(response))
 
                 # Update the messages
-                messages = self.update_messages('assistant',
-                                                func_assistant_msg)
-                messages = self.update_messages('function', func_result_msg)
+                self.update_memory(func_assistant_msg,
+                                   OpenAIBackendRole.ASSISTANT)
+                self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
 
                 # Record the function calling
                 called_funcs.append(func_record)
             else:
-                # Function calling disabled or chat stopped
+                # Function calling disabled or response is not a function call
 
                 # Loop over response terminators, get list of termination
                 # tuples with whether the terminator terminates the agent
@@ -352,35 +343,6 @@

         return ChatAgentResponse(output_messages, self.terminated, info)
 
-    def preprocess_messages(
-            self,
-            messages: List[ChatRecord]) -> Tuple[List[OpenAIMessage], int]:
-        r"""Truncate the list of messages if message window is defined and
-        the current length of message list is beyond the window size. Then
-        convert the list of messages to OpenAI's input format and calculate
-        the number of tokens.
-
-        Args:
-            messages (List[ChatRecord]): The list of structs containing
-                information about previous chat messages.
-
-        Returns:
-            tuple: A tuple containing the truncated list of messages in
-                OpenAI's input format and the number of tokens.
-        """
-
-        if (self.message_window_size
-                is not None) and (len(messages) > self.message_window_size):
-            messages = [ChatRecord('system', self.system_message)
-                        ] + messages[-self.message_window_size:]
-
-        openai_messages: List[OpenAIMessage]
-        openai_messages = [record.to_openai_message() for record in messages]
-        num_tokens = self.model_backend.count_tokens_from_messages(
-            openai_messages)
-
-        return openai_messages, num_tokens

     def validate_model_response(self, response: Any) -> None:
         r"""Validate the type of the response returned by the model.
@@ -480,7 +442,7 @@ def step_token_exceed(self, num_tokens: int,
             ChatAgentResponse: The struct containing trivial outputs and
                 information about token number and called functions.
         """
-
+        self.terminated = True
         output_messages: List[BaseMessage] = []
 
         info = self.get_info(
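Taken together, these changes replace the agent's internal stored_messages list with a pluggable memory. A minimal usage sketch of the new constructor arguments, assuming the camel package at this commit and the BaseMessage.make_assistant_message/make_user_message helpers (not shown in this diff); the message contents and token limit are illustrative:

from camel.agents import ChatAgent
from camel.messages import BaseMessage

# Both new arguments are optional: with memory=None the agent builds a
# ChatHistoryMemory around a ScoreBasedContextCreator internally, and with
# token_limit=None it falls back to the backend model's own limit.
agent = ChatAgent(
    system_message=BaseMessage.make_assistant_message(  # assumed helper
        role_name="Assistant", content="You are a helpful assistant."),
    token_limit=2048,  # illustrative; context is pruned to fit this budget
)

user_msg = BaseMessage.make_user_message(  # assumed helper
    role_name="User", content="What changed in this refactor?")
response = agent.step(user_msg)
print(response.msg.content)

If the pruned context still cannot fit, memory.get_context() raises a RuntimeError and step() returns a step_token_exceed response with reason "max_tokens_exceeded", per the step() hunk above.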
6 changes: 4 additions & 2 deletions camel/agents/critic_agent.py
@@ -18,6 +18,7 @@
 from colorama import Fore
 
 from camel.agents import ChatAgent
+from camel.memories import BaseMemory
 from camel.messages import BaseMessage
 from camel.responses import ChatAgentResponse
 from camel.typing import ModelType
@@ -49,13 +50,14 @@ def __init__(
         system_message: BaseMessage,
         model: ModelType = ModelType.GPT_3_5_TURBO,
         model_config: Optional[Any] = None,
+        memory: Optional[BaseMemory] = None,
         message_window_size: int = 6,
         retry_attempts: int = 2,
         verbose: bool = False,
         logger_color: Any = Fore.MAGENTA,
     ) -> None:
         super().__init__(system_message, model=model,
-                         model_config=model_config,
+                         model_config=model_config, memory=memory,
                          message_window_size=message_window_size)
         self.options_dict: Dict[str, str] = dict()
         self.retry_attempts = retry_attempts
@@ -106,7 +108,7 @@ def get_option(self, input_message: BaseMessage) -> str:
             raise RuntimeError("Critic step failed.")
 
         critic_msg = critic_response.msg
-        self.update_messages('assistant', critic_msg)
+        self.record_message(critic_msg)
         if self.verbose:
             print_text_animated(self.logger_color + "\n> Critic response: "
                                 f"\x1b[3m{critic_msg.content}\x1b[0m\n")
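For code built on the old API, the migration is mechanical. A hypothetical before/after, assuming agent is a constructed ChatAgent (or CriticAgent) and msg is a BaseMessage:

from camel.typing import OpenAIBackendRole

# Before: agent.update_messages('assistant', msg) appended a ChatRecord to
# agent.stored_messages, and agent.submit_message(msg) did the same for
# externally chosen (critic) messages.
# After: an externally chosen message is recorded through the memory.
agent.record_message(msg)

# Writes with an explicit role now go through update_memory with a typed
# enum instead of a raw role string.
agent.update_memory(msg, OpenAIBackendRole.USER)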
28 changes: 28 additions & 0 deletions camel/memories/__init__.py
@@ -0,0 +1,28 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+from .records import MemoryRecord, ContextRecord
+from .base import BaseMemory
+from .context_creators.base import BaseContextCreator
+from .context_creators.score_based import ScoreBasedContextCreator
+from .chat_history_memory import ChatHistoryMemory
+
+__all__ = [
+    'MemoryRecord',
+    'ContextRecord',
+    'BaseMemory',
+    'ChatHistoryMemory',
+    'BaseContextCreator',
+    'ScoreBasedContextCreator',
+]