Skip to content

Commit a9b29fa

Browse files
feature: Memory Truncation (CG-10994) (#838)
Truncates conversation history and summarizes it in two cases: 1. The number of messages passed to the LLM exceeds `max_messages`, which defaults to 100. 2. The LLM's context limit is nearly reached (measured by counting tokens and checking whether the current message input tokens exceed `max_input_tokens_for_model - 10000`). Algorithm -> similar to the [OpenHands summarizer](https://github.com/All-Hands-AI/OpenHands/blob/8043612420f0ef7d0c82ab5ab1c584385a8c57f3/openhands/memory/condenser/impl/llm_summarizing_condenser.py) 1. Keep the first two messages (System Prompt + Human Message) 2. Truncate the message list to roughly half its length (`len(state["messages"]) // 2`), leaving the last N messages untouched 3. Summarize the conversation between the kept head (`keep_first_messages`) and the `tail`. Misc - An `agent_config` param can be passed to `CodeAgent` to tune the `max_messages` or `keep_first_messages` properties. <img width="697" alt="image" src="https://github.com/user-attachments/assets/d47450cf-f478-4bcd-b7cc-f8b397656870" /> Sample trace (setting max_messages = 10) -> https://smith.langchain.com/public/60c0eef2-4e46-4506-86d7-9acd48b01cbe/r --------- Co-authored-by: jemeza-codegen <jmeza@codegen.com>
1 parent c68c1a5 commit a9b29fa

File tree

7 files changed

+267
-36
lines changed

7 files changed

+267
-36
lines changed

src/codegen/agents/code_agent.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from langchain.tools import BaseTool
66
from langchain_core.messages import AIMessage, HumanMessage
77
from langchain_core.runnables.config import RunnableConfig
8+
from langgraph.graph.graph import CompiledGraph
89
from langsmith import Client
910

1011
from codegen.extensions.langchain.agent import create_codebase_agent
@@ -15,16 +16,20 @@
1516
if TYPE_CHECKING:
1617
from codegen import Codebase
1718

19+
from codegen.agents.utils import AgentConfig
20+
1821

1922
class CodeAgent:
2023
"""Agent for interacting with a codebase."""
2124

2225
codebase: "Codebase"
23-
agent: any
26+
agent: CompiledGraph
2427
langsmith_client: Client
2528
project_name: str
2629
thread_id: str | None = None
27-
config: dict = {}
30+
run_id: str | None = None
31+
instance_id: str | None = None
32+
difficulty: int | None = None
2833

2934
def __init__(
3035
self,
@@ -35,6 +40,8 @@ def __init__(
3540
tools: Optional[list[BaseTool]] = None,
3641
tags: Optional[list[str]] = [],
3742
metadata: Optional[dict] = {},
43+
agent_config: Optional[AgentConfig] = None,
44+
thread_id: Optional[str] = None,
3845
**kwargs,
3946
):
4047
"""Initialize a CodeAgent.
@@ -60,15 +67,28 @@ def __init__(
6067
model_name=model_name,
6168
memory=memory,
6269
additional_tools=tools,
70+
config=agent_config,
6371
**kwargs,
6472
)
6573
self.model_name = model_name
6674
self.langsmith_client = Client()
6775

76+
if thread_id is None:
77+
self.thread_id = str(uuid4())
78+
else:
79+
self.thread_id = thread_id
80+
6881
# Get project name from environment variable or use a default
6982
self.project_name = os.environ.get("LANGCHAIN_PROJECT", "RELACE")
7083
print(f"Using LangSmith project: {self.project_name}")
7184

85+
# Store SWEBench metadata if provided
86+
self.run_id = metadata.get("run_id")
87+
self.instance_id = metadata.get("instance_id")
88+
# Extract difficulty value from "difficulty_X" format
89+
difficulty_str = metadata.get("difficulty", "")
90+
self.difficulty = int(difficulty_str.split("_")[1]) if difficulty_str and "_" in difficulty_str else None
91+
7292
# Initialize tags for agent trace
7393
self.tags = [*tags, self.model_name]
7494

@@ -79,7 +99,7 @@ def __init__(
7999
**metadata,
80100
}
81101

82-
def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
102+
def run(self, prompt: str) -> str:
83103
"""Run the agent with a prompt.
84104
85105
Args:
@@ -89,12 +109,9 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
89109
Returns:
90110
The agent's response
91111
"""
92-
if thread_id is None:
93-
thread_id = str(uuid4())
94-
self.thread_id = thread_id
95112
self.config = {
96113
"configurable": {
97-
"thread_id": thread_id,
114+
"thread_id": self.thread_id,
98115
"metadata": {"project": self.project_name},
99116
},
100117
"recursion_limit": 100,
@@ -104,15 +121,15 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
104121
# see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
105122
input = {"query": prompt}
106123

107-
config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100)
124+
config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
108125
# we stream the steps instead of invoke because it allows us to access intermediate nodes
109126
stream = self.agent.stream(input, config=config, stream_mode="values")
110127

111128
# Keep track of run IDs from the stream
112129
run_ids = []
113130

114131
for s in stream:
115-
if len(s["messages"]) == 0:
132+
if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
116133
message = HumanMessage(content=prompt)
117134
else:
118135
message = s["messages"][-1]

src/codegen/agents/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import TypedDict
2+
3+
4+
class AgentConfig(TypedDict, total=False):
5+
"""Configuration options for the CodeAgent."""
6+
7+
keep_first_messages: int # Number of initial messages to keep during summarization
8+
max_messages: int # Maximum number of messages before triggering summarization

src/codegen/extensions/langchain/agent.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Demo implementation of an agent with Codegen tools."""
22

3-
from typing import TYPE_CHECKING, Optional
3+
from typing import TYPE_CHECKING, Any, Optional
44

55
from langchain.tools import BaseTool
66
from langchain_core.messages import SystemMessage
77
from langgraph.checkpoint.memory import MemorySaver
88
from langgraph.graph.graph import CompiledGraph
99

10+
from codegen.agents.utils import AgentConfig
1011
from codegen.extensions.langchain.llm import LLM
1112
from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE
1213
from codegen.extensions.langchain.tools import (
@@ -38,6 +39,7 @@ def create_codebase_agent(
3839
memory: bool = True,
3940
debug: bool = False,
4041
additional_tools: Optional[list[BaseTool]] = None,
42+
config: Optional[AgentConfig] = None,
4143
**kwargs,
4244
) -> CompiledGraph:
4345
"""Create an agent with all codebase tools.
@@ -57,7 +59,7 @@ def create_codebase_agent(
5759
Returns:
5860
Initialized agent with message history
5961
"""
60-
llm = LLM(model_provider=model_provider, model_name=model_name, max_tokens=8192, **kwargs)
62+
llm = LLM(model_provider=model_provider, model_name=model_name, **kwargs)
6163

6264
# Get all codebase tools
6365
tools = [
@@ -89,7 +91,7 @@ def create_codebase_agent(
8991

9092
memory = MemorySaver() if memory else None
9193

92-
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
94+
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
9395

9496

9597
def create_chat_agent(
@@ -100,6 +102,7 @@ def create_chat_agent(
100102
memory: bool = True,
101103
debug: bool = False,
102104
additional_tools: Optional[list[BaseTool]] = None,
105+
config: Optional[dict[str, Any]] = None, # over here you can pass in the max length of the number of messages
103106
**kwargs,
104107
) -> CompiledGraph:
105108
"""Create an agent with all codebase tools.
@@ -138,7 +141,7 @@ def create_chat_agent(
138141

139142
memory = MemorySaver() if memory else None
140143

141-
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
144+
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
142145

143146

144147
def create_codebase_inspector_agent(
@@ -148,6 +151,7 @@ def create_codebase_inspector_agent(
148151
system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
149152
memory: bool = True,
150153
debug: bool = True,
154+
config: Optional[dict[str, Any]] = None,
151155
**kwargs,
152156
) -> CompiledGraph:
153157
"""Create an inspector agent with read-only codebase tools.
@@ -175,7 +179,7 @@ def create_codebase_inspector_agent(
175179
]
176180

177181
memory = MemorySaver() if memory else None
178-
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
182+
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
179183

180184

181185
def create_agent_with_tools(
@@ -185,6 +189,7 @@ def create_agent_with_tools(
185189
system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
186190
memory: bool = True,
187191
debug: bool = True,
192+
config: Optional[dict[str, Any]] = None,
188193
**kwargs,
189194
) -> CompiledGraph:
190195
"""Create an agent with a specific set of tools.
@@ -209,4 +214,4 @@ def create_agent_with_tools(
209214

210215
memory = MemorySaver() if memory else None
211216

212-
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
217+
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)

0 commit comments

Comments
 (0)