From 87f210e85dc660b7c5ec619e5b138c3e3dd6b5da Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Wed, 12 Mar 2025 08:51:20 -0700 Subject: [PATCH 1/6] done --- .../swebench_agent_run/local_run.ipynb | 73 +++++++++++ src/codegen/agents/code_agent.py | 35 ++++-- src/codegen/agents/utils.py | 8 ++ src/codegen/extensions/langchain/agent.py | 15 ++- src/codegen/extensions/langchain/graph.py | 114 ++++++++++++++++-- src/codegen/extensions/langchain/prompts.py | 32 +++++ 6 files changed, 254 insertions(+), 23 deletions(-) create mode 100644 src/codegen/agents/utils.py diff --git a/codegen-examples/examples/swebench_agent_run/local_run.ipynb b/codegen-examples/examples/swebench_agent_run/local_run.ipynb index 0b212fa40..023522afe 100644 --- a/codegen-examples/examples/swebench_agent_run/local_run.ipynb +++ b/codegen-examples/examples/swebench_agent_run/local_run.ipynb @@ -34,6 +34,79 @@ "source": [ "await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from codegen import CodeAgent, Codebase" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "codebase = Codebase(\"./\")\n", + "agent = CodeAgent(codebase, agent_config={\"keep_first_messages\": 1, \"max_messages\": 4})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"Hello my name is Tawsif\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"What's my name?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"I have a door dash delivery coming up!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.get_state().values[\"messages\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"Whats my name? 
please tell me my name\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.get_state().values" + ] } ], "metadata": { diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 99c01a874..2f3bbb2dc 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -5,6 +5,7 @@ from langchain.tools import BaseTool from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables.config import RunnableConfig +from langgraph.graph.graph import CompiledGraph from langsmith import Client from codegen.extensions.langchain.agent import create_codebase_agent @@ -15,16 +16,20 @@ if TYPE_CHECKING: from codegen import Codebase +from codegen.agents.utils import AgentConfig + class CodeAgent: """Agent for interacting with a codebase.""" codebase: "Codebase" - agent: any + agent: CompiledGraph langsmith_client: Client project_name: str thread_id: str | None = None - config: dict = {} + run_id: str | None = None + instance_id: str | None = None + difficulty: int | None = None def __init__( self, @@ -35,6 +40,8 @@ def __init__( tools: Optional[list[BaseTool]] = None, tags: Optional[list[str]] = [], metadata: Optional[dict] = {}, + agent_config: Optional[AgentConfig] = None, + thread_id: Optional[str] = None, **kwargs, ): """Initialize a CodeAgent. @@ -60,15 +67,28 @@ def __init__( model_name=model_name, memory=memory, additional_tools=tools, + config=agent_config, **kwargs, ) self.model_name = model_name self.langsmith_client = Client() + if thread_id is None: + self.thread_id = str(uuid4()) + else: + self.thread_id = thread_id + # Get project name from environment variable or use a default self.project_name = os.environ.get("LANGCHAIN_PROJECT", "RELACE") print(f"Using LangSmith project: {self.project_name}") + # Store SWEBench metadata if provided + self.run_id = metadata.get("run_id") + self.instance_id = metadata.get("instance_id") + # Extract difficulty value from "difficulty_X" format + difficulty_str = metadata.get("difficulty", "") + self.difficulty = int(difficulty_str.split("_")[1]) if difficulty_str and "_" in difficulty_str else None + # Initialize tags for agent trace self.tags = [*tags, self.model_name] @@ -79,7 +99,7 @@ def __init__( **metadata, } - def run(self, prompt: str, thread_id: Optional[str] = None) -> str: + def run(self, prompt: str) -> str: """Run the agent with a prompt. 
Args: @@ -89,12 +109,9 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str: Returns: The agent's response """ - if thread_id is None: - thread_id = str(uuid4()) - self.thread_id = thread_id self.config = { "configurable": { - "thread_id": thread_id, + "thread_id": self.thread_id, "metadata": {"project": self.project_name}, }, "recursion_limit": 100, @@ -104,7 +121,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str: # see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers input = {"query": prompt} - config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100) + config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200) # we stream the steps instead of invoke because it allows us to access intermediate nodes stream = self.agent.stream(input, config=config, stream_mode="values") @@ -112,7 +129,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str: run_ids = [] for s in stream: - if len(s["messages"]) == 0: + if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage): message = HumanMessage(content=prompt) else: message = s["messages"][-1] diff --git a/src/codegen/agents/utils.py b/src/codegen/agents/utils.py new file mode 100644 index 000000000..e5cdcebaf --- /dev/null +++ b/src/codegen/agents/utils.py @@ -0,0 +1,8 @@ +from typing import TypedDict + + +class AgentConfig(TypedDict, total=False): + """Configuration options for the CodeAgent.""" + + keep_first_messages: int # Number of initial messages to keep during summarization + max_messages: int # Maximum number of messages before triggering summarization diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index fe44594b1..8917daa7f 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -1,12 +1,13 @@ """Demo implementation of an agent with Codegen tools.""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from langchain.tools import BaseTool from langchain_core.messages import SystemMessage from langgraph.checkpoint.memory import MemorySaver from langgraph.graph.graph import CompiledGraph +from codegen.agents.utils import AgentConfig from codegen.extensions.langchain.llm import LLM from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE from codegen.extensions.langchain.tools import ( @@ -38,6 +39,7 @@ def create_codebase_agent( memory: bool = True, debug: bool = False, additional_tools: Optional[list[BaseTool]] = None, + config: Optional[AgentConfig] = None, **kwargs, ) -> CompiledGraph: """Create an agent with all codebase tools. @@ -89,7 +91,7 @@ def create_codebase_agent( memory = MemorySaver() if memory else None - return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug) + return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config) def create_chat_agent( @@ -100,6 +102,7 @@ def create_chat_agent( memory: bool = True, debug: bool = False, additional_tools: Optional[list[BaseTool]] = None, + config: Optional[dict[str, Any]] = None, # over here you can pass in the max length of the number of messages **kwargs, ) -> CompiledGraph: """Create an agent with all codebase tools. 
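
Taken together, the hunks above thread a single `AgentConfig` from `CodeAgent.__init__` through `create_codebase_agent` down to the graph builder, and pin the conversation `thread_id` at construction time instead of per `run()` call. A minimal usage sketch of the resulting API, mirroring the notebook cells at the top of this patch (the explicit `thread_id` string is illustrative, not from the patch):

    from codegen import CodeAgent, Codebase
    from codegen.agents.utils import AgentConfig

    codebase = Codebase("./")

    # Both keys are optional (AgentConfig is declared with total=False); unset keys
    # fall back to the graph defaults of keep_first_messages=1, max_messages=100.
    config: AgentConfig = {
        "keep_first_messages": 1,  # head messages preserved verbatim when summarizing
        "max_messages": 4,         # history length that triggers summarization
    }

    # thread_id defaults to a fresh uuid4 in __init__, so successive run() calls
    # now share one conversation thread instead of minting a new one per prompt.
    agent = CodeAgent(codebase, agent_config=config, thread_id="demo-thread")

    agent.run("Hello my name is Tawsif")
    agent.run("What's my name?")  # answered from the same thread's memory
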
@@ -138,7 +141,7 @@ def create_chat_agent( memory = MemorySaver() if memory else None - return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug) + return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config) def create_codebase_inspector_agent( @@ -148,6 +151,7 @@ def create_codebase_inspector_agent( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = True, + config: Optional[dict[str, Any]] = None, **kwargs, ) -> CompiledGraph: """Create an inspector agent with read-only codebase tools. @@ -175,7 +179,7 @@ def create_codebase_inspector_agent( ] memory = MemorySaver() if memory else None - return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug) + return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config) def create_agent_with_tools( @@ -185,6 +189,7 @@ def create_agent_with_tools( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = True, + config: Optional[dict[str, Any]] = None, **kwargs, ) -> CompiledGraph: """Create an agent with a specific set of tools. @@ -209,4 +214,4 @@ def create_agent_with_tools( memory = MemorySaver() if memory else None - return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug) + return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config) diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index e5116630f..fed38f04e 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -1,6 +1,7 @@ """Demo implementation of an agent with Codegen tools.""" -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional +import uuid +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union import anthropic import openai @@ -8,27 +9,77 @@ from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START -from langgraph.graph.message import add_messages from langgraph.graph.state import CompiledGraph, StateGraph from langgraph.prebuilt import ToolNode from langgraph.pregel import RetryPolicy +from codegen.agents.utils import AgentConfig +from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT + + +def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], dict]) -> list[AnyMessage]: + """Custom reducer for managing message history with summarization. 
+ + Args: + existing: Current list of messages + updates: Either new messages to append or a dict specifying how to update messages + + Returns: + Updated list of messages + """ + if isinstance(updates, list): + # Ensure all messages have IDs + for msg in existing + updates: + if not hasattr(msg, "id") or msg.id is None: + msg.id = str(uuid.uuid4()) + + # Create a map of existing messages by ID + existing_by_id = {msg.id: i for i, msg in enumerate(existing)} + + # Start with copy of existing messages + result = existing.copy() + + # Update or append new messages + for msg in updates: + if msg.id in existing_by_id: + # Update existing message + result[existing_by_id[msg.id]] = msg + else: + # Append new message + result.append(msg) + + return result + + if isinstance(updates, dict): + if updates.get("type") == "summarize": + # Add summary message + summary_msg = AIMessage(content=f"This next few sections contain a summary of the conversation: \n{updates['summary']}") + summary_msg.id = str(uuid.uuid4()) # Ensure summary message has ID + result = updates["head"] + [summary_msg] + updates["tail"] + return result + + return existing + class GraphState(dict[str, Any]): """State of the graph.""" + summary: str query: str final_answer: str - messages: Annotated[list[AnyMessage], add_messages] + messages: Annotated[list[AnyMessage], manage_messages] class AgentGraph: """Main graph class for the agent.""" - def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMessage): + def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMessage, config: AgentConfig | None = None): self.model = model.bind_tools(tools) self.tools = tools self.system_message = system_message + self.config = config + self.max_messages = config.get("max_messages", 100) if config else 100 + self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1 # =================================== NODES ==================================== @@ -36,23 +87,65 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe def reasoner(self, state: GraphState) -> dict[str, Any]: new_turn = len(state["messages"]) == 0 or isinstance(state["messages"][-1], AIMessage) messages = state["messages"] + if new_turn: query = state["query"] messages.append(HumanMessage(content=query)) result = self.model.invoke([self.system_message, *messages]) - if isinstance(result, AIMessage): - return {"messages": [*messages, result], "final_answer": result.content} + updated_messages = [*messages, result] + return {"messages": updated_messages, "final_answer": result.content} - return {"messages": [*messages, result]} + updated_messages = [*messages, result] + return {"messages": updated_messages} + + def summarize_conversation(self, state: GraphState): + """Summarize conversation while preserving key context and recent messages.""" + messages = state["messages"] + keep_first = self.keep_first_messages # Keep system prompt and initial user message + target_size = self.max_messages // 2 + messages_from_tail = target_size - keep_first + + # If we don't have enough messages to require summarization + if len(messages) <= self.max_messages: + return state + + head = messages[:keep_first] # gets first message (human instruction) + tail = messages[-messages_from_tail:] # gets last 48 messages (default implementation with 100 max messages) + to_summarize = messages[keep_first:-messages_from_tail] # gets middle messages to summarize -> len(messages) - (len(tail) + len(head)) + + # Skip if 
nothing to summarize + if not to_summarize: + return state + + summary = state.get("summary", "") + summary_prompt = SUMMARIZE_CONVERSATION_PROMPT + if summary: + summary_prompt += f"\n\nPrevious summary: {summary}\n\nExtend this summary with the new conversation:" + + # Convert messages to string format for summarization + conversation = "\n".join(f"{msg.type}: {msg.content}" for msg in to_summarize) + + messages_for_summary = [SystemMessage(content=summary_prompt), HumanMessage(content="Summarize the following conversation: \n\n" + conversation)] + + response = self.model.invoke(messages_for_summary) + new_summary = response.content + + return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}, "summary": new_summary} # =================================== EDGE CONDITIONS ==================================== - def should_continue(self, state: GraphState) -> Literal["tools", END]: + def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conversation", END]: messages = state["messages"] last_message = messages[-1] + + # If the message count exceeds the limit, summarize before performing tool call + if len(messages) > self.max_messages: + return "summarize_conversation" + if hasattr(last_message, "tool_calls") and last_message.tool_calls: return "tools" + return END # =================================== COMPILE GRAPH ==================================== @@ -74,6 +167,7 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False # Add nodes builder.add_node("reasoner", self.reasoner, retry=retry_policy) builder.add_node("tools", ToolNode(self.tools), retry=retry_policy) + builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy) # Add edges builder.add_edge(START, "reasoner") @@ -82,6 +176,7 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False "reasoner", self.should_continue, ) + builder.add_conditional_edges("summarize_conversation", self.should_continue) return builder.compile(checkpointer=checkpointer, debug=debug) @@ -92,9 +187,10 @@ def create_react_agent( system_message: SystemMessage, checkpointer: Optional[MemorySaver] = None, debug: bool = False, + config: Optional[dict[str, Any]] = None, ) -> CompiledGraph: """Create a reactive agent graph.""" - graph = AgentGraph(model, tools, system_message) + graph = AgentGraph(model, tools, system_message, config=config) return graph.create(checkpointer=checkpointer, debug=debug) diff --git a/src/codegen/extensions/langchain/prompts.py b/src/codegen/extensions/langchain/prompts.py index a33157bcd..28e11d1fc 100644 --- a/src/codegen/extensions/langchain/prompts.py +++ b/src/codegen/extensions/langchain/prompts.py @@ -43,3 +43,35 @@ Ensure if specifiying line numbers, it's chosen with room (around 20 lines before and 20 lines after the edit range) """ + + +SUMMARIZE_CONVERSATION_PROMPT = """ + You are maintaining state history for an LLM-based code agent. + YOU ARE A SUMMARIZER. PLEASE FOLLOW THE FOLLOWING GUIDELINES FOR YOUR SUMMARIZATION TASK. The following messages that you get you will need to summarize! 
+ + Track: + + USER_CONTEXT: (Preserve essential user requirements, problem descriptions, and clarifications in concise form) + + STATE: {File paths, function signatures, data structures} + TESTS: {Failing cases, error messages, outputs} + CHANGES: {Code edits, variable updates} + DEPS: {Dependencies, imports, external calls} + INTENT: {Why changes were made, acceptance criteria} + + PRIORITIZE: + 1. Capture key user requirements and constraints + 2. Maintain critical problem context + 3. Keep all sections concise + + SKIP: {Git clones, build logs, file listings} + + Example history format: + USER_CONTEXT: Fix FITS card float representation - "0.009125" becomes "0.009124999999999999" causing comment truncation. Use Python's str() when possible while maintaining FITS compliance. + + STATE: mod_float() in card.py updated + TESTS: test_format() passed + CHANGES: str(val) replaces f"{val:.16G}" + DEPS: None modified + INTENT: Fix precision while maintaining FITS compliance +""" From f99b33c093ae784543a3145b673efc1b4dba6fb9 Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Thu, 13 Mar 2025 15:37:38 -0700 Subject: [PATCH 2/6] done --- src/codegen/extensions/langchain/graph.py | 98 +++++++++++++------ src/codegen/extensions/langchain/prompts.py | 46 +++++---- .../extensions/langchain/utils/utils.py | 21 ++++ 3 files changed, 116 insertions(+), 49 deletions(-) create mode 100644 src/codegen/extensions/langchain/utils/utils.py diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index 944aa9bea..5d4f11314 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -6,7 +6,8 @@ import anthropic import openai from langchain.tools import BaseTool -from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.prompts import ChatPromptTemplate from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START from langgraph.graph.state import CompiledGraph, StateGraph @@ -16,6 +17,7 @@ from codegen.agents.utils import AgentConfig from codegen.extensions.langchain.llm import LLM from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT +from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], dict]) -> list[AnyMessage]: @@ -53,9 +55,13 @@ def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], if isinstance(updates, dict): if updates.get("type") == "summarize": - # Add summary message - summary_msg = AIMessage(content=f"This next few sections contain a summary of the conversation: \n{updates['summary']}") - summary_msg.id = str(uuid.uuid4()) # Ensure summary message has ID + # Create summary message and mark it with additional_kwargs + summary_msg = AIMessage( + content=f"""Here is a summary of the conversation + from a previous timestep to aid for the continuing conversation: \n{updates["summary"]}\n\n""", + additional_kwargs={"is_summary": True}, # Use additional_kwargs for custom metadata + ) + summary_msg.id = str(uuid.uuid4()) result = updates["head"] + [summary_msg] + updates["tail"] return result @@ -65,7 +71,6 @@ def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], class GraphState(dict[str, Any]): """State of the graph.""" - summary: str query: str final_answer: 
str messages: Annotated[list[AnyMessage], manage_messages] @@ -104,54 +109,90 @@ def reasoner(self, state: GraphState) -> dict[str, Any]: def summarize_conversation(self, state: GraphState): """Summarize conversation while preserving key context and recent messages.""" messages = state["messages"] - keep_first = self.keep_first_messages # Keep system prompt and initial user message + keep_first = self.keep_first_messages target_size = self.max_messages // 2 messages_from_tail = target_size - keep_first - # If we don't have enough messages to require summarization - if len(messages) <= self.max_messages: - return state + head = messages[:keep_first] + tail = messages[-messages_from_tail:] + to_summarize = messages[: len(messages) - messages_from_tail] - head = messages[:keep_first] # gets first message (human instruction) - tail = messages[-messages_from_tail:] # gets last 48 messages (default implementation with 100 max messages) - to_summarize = messages[keep_first:-messages_from_tail] # gets middle messages to summarize -> len(messages) - (len(tail) + len(head)) + # Handle tool message pairing at truncation point + truncation_idx = len(messages) - messages_from_tail + if truncation_idx > 0 and isinstance(messages[truncation_idx], ToolMessage): + # Keep the AI message right before it + tail = [messages[truncation_idx - 1], *tail] # Skip if nothing to summarize if not to_summarize: return state - summary = state.get("summary", "") - summary_prompt = SUMMARIZE_CONVERSATION_PROMPT - if summary: - summary_prompt += f"\n\nPrevious summary: {summary}\n\nExtend this summary with the new conversation:" - - # Convert messages to string format for summarization - conversation = "\n".join(f"{msg.type}: {msg.content}" for msg in to_summarize) + # Define constants + HEADER_WIDTH = 40 + HEADER_TYPES = {"human": "HUMAN", "ai": "AI", "summary": "SUMMARY FROM PREVIOUS TIMESTEP", "tool_call": "TOOL CALL", "tool_response": "TOOL RESPONSE"} + + def format_header(header_type: str) -> str: + """Format message header with consistent padding. 
+ + Args: + header_type: Type of header to format (must be one of HEADER_TYPES) + + Returns: + Formatted header string with padding + """ + header = HEADER_TYPES[header_type] + padding = "=" * ((HEADER_WIDTH - len(header)) // 2) + return f"{padding} {header} {padding}\n" + + # Format messages with appropriate headers + formatted_messages = [] + for msg in to_summarize: # No need for slice when iterating full list + if isinstance(msg, HumanMessage): + formatted_messages.append(format_header("human") + msg.content) + elif isinstance(msg, AIMessage): + # Check for summary message using additional_kwargs + if msg.additional_kwargs.get("is_summary"): + formatted_messages.append(format_header("summary") + msg.content) + elif isinstance(msg.content, list) and len(msg.content) > 0 and isinstance(msg.content[0], dict): + for item in msg.content: # No need for slice when iterating full list + if item.get("type") == "text": + formatted_messages.append(format_header("ai") + item["text"]) + elif item.get("type") == "tool_use": + formatted_messages.append(format_header("tool_call") + f"Tool: {item['name']}\nInput: {item['input']}") + else: + formatted_messages.append(format_header("ai") + msg.content) + elif isinstance(msg, ToolMessage): + formatted_messages.append(format_header("tool_response") + msg.content) - messages_for_summary = [SystemMessage(content=summary_prompt), HumanMessage(content="Summarize the following conversation: \n\n" + conversation)] + conversation = "\n".join(formatted_messages) # No need for slice when joining full list - # Initialize the LLM summary_llm = LLM( model_provider="anthropic", model_name="claude-3-5-sonnet-latest", - temperature=0.2, # Slightly higher temperature for more creative reflection + temperature=0.3, max_tokens=8012, ) - response = summary_llm.invoke(messages_for_summary) - new_summary = response.content - return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}, "summary": new_summary} + chain = ChatPromptTemplate.from_template(SUMMARIZE_CONVERSATION_PROMPT) | summary_llm + new_summary = chain.invoke({"conversation": conversation}).content + + return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}} # =================================== EDGE CONDITIONS ==================================== def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conversation", END]: messages = state["messages"] last_message = messages[-1] - # If the message count exceeds the limit, summarize before performing tool call + # Summarize if the number of messages passed in exceeds the max_messages threshold (default 100) if len(messages) > self.max_messages: return "summarize_conversation" - if hasattr(last_message, "tool_calls") and last_message.tool_calls: + # Summarize if the last message exceeds the max input tokens of the model - 10000 tokens + elif isinstance(last_message, AIMessage) and last_message.usage_metadata["input_tokens"] > get_max_model_input_tokens(self.model) - 10000: + print("here lol") + return "summarize_conversation" + + elif hasattr(last_message, "tool_calls") and last_message.tool_calls: return "tools" return END @@ -164,7 +205,7 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False # the retry policy has an initial interval, a backoff factor, and a max interval of controlling the # amount of time between retries retry_policy = RetryPolicy( - retry_on=[anthropic.RateLimitError, openai.RateLimitError, anthropic.InternalServerError], + 
retry_on=[anthropic.RateLimitError, openai.RateLimitError, anthropic.InternalServerError],
+            retry_on=[anthropic.RateLimitError, openai.RateLimitError, anthropic.InternalServerError, anthropic.BadRequestError],
             max_attempts=10,
             initial_interval=30.0,  # Start with 30 second wait
             backoff_factor=2,  # Double the wait time each retry
@@ -187,6 +228,7 @@ def get_field_descriptions(tool_obj):
             return field_descriptions

         try:
+            # Get all field descriptions from the tool
             schema_cls = tool_obj.args_schema

             # Handle Pydantic v2
diff --git a/src/codegen/extensions/langchain/prompts.py b/src/codegen/extensions/langchain/prompts.py
index 28e11d1fc..3c9ba9744 100644
--- a/src/codegen/extensions/langchain/prompts.py
+++ b/src/codegen/extensions/langchain/prompts.py
@@ -46,32 +46,36 @@


 SUMMARIZE_CONVERSATION_PROMPT = """
-    You are maintaining state history for an LLM-based code agent.
-    YOU ARE A SUMMARIZER. PLEASE FOLLOW THE FOLLOWING GUIDELINES FOR YOUR SUMMARIZATION TASK. The following messages that you get you will need to summarize!
+    You are an expert conversation summarizer. You are given below a conversation between an AI coding agent and a human.
+    It contains the human's request and the agent's thought process,
+    alternating between AIMessage, ToolMessage, HumanMessage, etc.

-    Track:
+    This AI agent is an expert software engineer with deep knowledge of code analysis, refactoring, and development best practices.

-    USER_CONTEXT: (Preserve essential user requirements, problem descriptions, and clarifications in concise form)
+    Your goal as the summarizer is to summarize the conversation between the AI agent and the human in an extremely detailed and comprehensive manner.

-    STATE: {File paths, function signatures, data structures}
-    TESTS: {Failing cases, error messages, outputs}
-    CHANGES: {Code edits, variable updates}
-    DEPS: {Dependencies, imports, external calls}
-    INTENT: {Why changes were made, acceptance criteria}
+    Ensure the summary includes key details of the conversation, such as:
+    - User's request and context
+    - Code changes and their impact
+    - File and directory structure
+    - Dependencies and imports
+    - Any errors or exceptions
+    - User's clarifications and follow-up questions
+    - File modifications and their impact
+    - Any other relevant details

-    PRIORITIZE:
-    1. Capture key user requirements and constraints
-    2. Maintain critical problem context
-    3. Keep all sections concise
+    IMPORTANT: Your summary must be at least 4000 words long to ensure that you have added a lot of useful information to it.
+    Ensure your summary is very detailed and comprehensive. It's important to capture all the context of the conversation.

-    SKIP: {Git clones, build logs, file listings}
+    IMPORTANT: Do not attempt to provide any solutions or any other unnecessary commentary. Your sole job is to summarize the conversation in the most detailed way possible.
+    IMPORTANT: Your summary will be fed back into the LLM to continue the conversation so that it has the context of the conversation instead of having to store the whole history.
+    That's why your summary does not signal the end of the conversation. It will be used by the agent to further inch towards the goal of solving the user's issue.

-    Example history format:
-    USER_CONTEXT: Fix FITS card float representation - "0.009125" becomes "0.009124999999999999" causing comment truncation. Use Python's str() when possible while maintaining FITS compliance.
+    IMPORTANT: The conversation given may include previous summaries generated by you in an earlier time step of the conversation. Use this to your advantage
+    alongside the conversation to generate a more comprehensive summary of the entire conversation.

-    STATE: mod_float() in card.py updated
-    TESTS: test_format() passed
-    CHANGES: str(val) replaces f"{val:.16G}"
-    DEPS: None modified
-    INTENT: Fix precision while maintaining FITS compliance
+    Here is the conversation given below:
+
+    {conversation}
 """
diff --git a/src/codegen/extensions/langchain/utils/utils.py b/src/codegen/extensions/langchain/utils/utils.py
new file mode 100644
index 000000000..1de9316f7
--- /dev/null
+++ b/src/codegen/extensions/langchain/utils/utils.py
@@ -0,0 +1,21 @@
+from langchain_core.language_models import LLM
+
+
+def get_max_model_input_tokens(llm: LLM) -> int:
+    """Get the maximum input tokens for the current model.
+
+    Returns:
+        int: Maximum number of input tokens supported by the model
+    """
+    # For Claude models not explicitly listed, if model name contains "claude", use Claude's limit
+    if "claude" in llm.model.lower():
+        return 50000
+    # For GPT-4 models
+    elif "gpt-4" in llm.model.lower():
+        return 128000
+    # For Grok models
+    elif "grok" in llm.model.lower():
+        return 1000000
+
+    # default to GPT-4's limit as it's the lower bound
+    return 128000

From 5a8f44a78b6ca6e13545724d71d445aaa0db30d6 Mon Sep 17 00:00:00 2001
From: tawsifkamal
Date: Thu, 13 Mar 2025 15:50:56 -0700
Subject: [PATCH 3/6] fixing max messages bug

---
 src/codegen/extensions/langchain/graph.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
index 5d4f11314..c57aadbcc 100644
--- a/src/codegen/extensions/langchain/graph.py
+++ b/src/codegen/extensions/langchain/graph.py
@@ -110,7 +110,7 @@ def summarize_conversation(self, state: GraphState):
         """Summarize conversation while preserving key context and recent messages."""
         messages = state["messages"]
         keep_first = self.keep_first_messages
-        target_size = self.max_messages // 2
+        target_size = len(messages) // 2
         messages_from_tail = target_size - keep_first

From c078827e435cb4d9ffbc0a0006026d4287db7690 Mon Sep 17 00:00:00 2001
From: tawsifkamal
Date: Fri, 14 Mar 2025 09:32:29 -0700
Subject: [PATCH 4/6] fix the error where token usage wasn't updated

---
 src/codegen/extensions/langchain/graph.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
index c57aadbcc..f02bed4c3 100644
--- a/src/codegen/extensions/langchain/graph.py
+++ b/src/codegen/extensions/langchain/graph.py
@@ -62,6 +62,7 @@ def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage]
                 additional_kwargs={"is_summary": True},  # Use additional_kwargs for custom metadata
             )
             summary_msg.id = str(uuid.uuid4())
+            updates["tail"][-1].additional_kwargs["just_summarized"] = True
             result = updates["head"] + [summary_msg] + updates["tail"]
             return result

@@ -182,14 +183,16 @@ def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conversation", END]:
         messages = state["messages"]
         last_message = messages[-1]
+        just_summarized = last_message.additional_kwargs.get("just_summarized")
+        curr_input_tokens = last_message.usage_metadata["input_tokens"]
+        max_input_tokens = get_max_model_input_tokens(self.model)

         # Summarize if the number of messages passed in exceeds the max_messages threshold (default 100)
         if len(messages) > self.max_messages:
             return
"summarize_conversation" # Summarize if the last message exceeds the max input tokens of the model - 10000 tokens - elif isinstance(last_message, AIMessage) and last_message.usage_metadata["input_tokens"] > get_max_model_input_tokens(self.model) - 10000: - print("here lol") + elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 10000): return "summarize_conversation" elif hasattr(last_message, "tool_calls") and last_message.tool_calls: From 045c112a4d07d6c2b2fa58b4b5f74e209036e5a7 Mon Sep 17 00:00:00 2001 From: jemeza-codegen Date: Fri, 14 Mar 2025 11:25:39 -0700 Subject: [PATCH 5/6] chore: sets token limit in llm.py --- src/codegen/agents/code_agent.py | 1 - src/codegen/extensions/langchain/agent.py | 2 +- src/codegen/extensions/langchain/graph.py | 1 - src/codegen/extensions/langchain/llm.py | 6 +++--- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index e7a08b84a..717c507b2 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -68,7 +68,6 @@ def __init__( memory=memory, additional_tools=tools, config=agent_config, - max_tokens=8192, **kwargs, ) self.model_name = model_name diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index dae7eb746..8917daa7f 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -59,7 +59,7 @@ def create_codebase_agent( Returns: Initialized agent with message history """ - llm = LLM(model_provider=model_provider, model_name=model_name, max_tokens=8192, **kwargs) + llm = LLM(model_provider=model_provider, model_name=model_name, **kwargs) # Get all codebase tools tools = [ diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index f02bed4c3..2987f6863 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -171,7 +171,6 @@ def format_header(header_type: str) -> str: model_provider="anthropic", model_name="claude-3-5-sonnet-latest", temperature=0.3, - max_tokens=8012, ) chain = ChatPromptTemplate.from_template(SUMMARIZE_CONVERSATION_PROMPT) | summary_llm diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 54b9a91a2..0d4795740 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -89,19 +89,19 @@ def _get_model(self) -> BaseChatModel: if not os.getenv("ANTHROPIC_API_KEY"): msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables." raise ValueError(msg) - return ChatAnthropic(**self._get_model_kwargs(), max_retries=10, timeout=1000) + return ChatAnthropic(**self._get_model_kwargs(), max_tokens=8192, max_retries=10, timeout=1000) elif self.model_provider == "openai": if not os.getenv("OPENAI_API_KEY"): msg = "OPENAI_API_KEY not found in environment. Please set it in your .env file or environment variables." raise ValueError(msg) - return ChatOpenAI(**self._get_model_kwargs(), max_retries=10, timeout=1000) + return ChatOpenAI(**self._get_model_kwargs(), max_tokens=4096, max_retries=10, timeout=1000) elif self.model_provider == "xai": if not os.getenv("XAI_API_KEY"): msg = "XAI_API_KEY not found in environment. Please set it in your .env file or environment variables." 
raise ValueError(msg) - return ChatXAI(**self._get_model_kwargs()) + return ChatXAI(**self._get_model_kwargs(), max_tokens=8192) msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai" raise ValueError(msg) From 7e14bb83eb378897e2426f4435dde955bb7bb078 Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Fri, 14 Mar 2025 11:37:31 -0700 Subject: [PATCH 6/6] done --- src/codegen/extensions/langchain/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/extensions/langchain/utils/utils.py b/src/codegen/extensions/langchain/utils/utils.py index 1de9316f7..81641c39c 100644 --- a/src/codegen/extensions/langchain/utils/utils.py +++ b/src/codegen/extensions/langchain/utils/utils.py @@ -9,7 +9,7 @@ def get_max_model_input_tokens(llm: LLM) -> int: """ # For Claude models not explicitly listed, if model name contains "claude", use Claude's limit if "claude" in llm.model.lower(): - return 50000 + return 200000 # For GPT-4 models elif "gpt-4" in llm.model.lower(): return 128000
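
Read end to end, the series leaves `should_continue` routing on three conditions: message count, input-token pressure, and pending tool calls. A condensed, standalone sketch of that final routing logic, with the LangGraph message types stubbed out as a plain dataclass (`Msg` and `route` are illustrative stand-ins, not code from the patches; the thresholds mirror the diffs above):

    from dataclasses import dataclass, field

    MAX_MESSAGES = 100            # AgentGraph default when no AgentConfig is given
    TOKEN_HEADROOM = 10_000       # headroom kept below the model's input-token ceiling
    CLAUDE_INPUT_LIMIT = 200_000  # get_max_model_input_tokens for Claude after PATCH 6

    @dataclass
    class Msg:
        role: str                      # "ai", "human", or "tool"
        input_tokens: int = 0          # stands in for usage_metadata["input_tokens"]
        tool_calls: list = field(default_factory=list)
        just_summarized: bool = False  # stands in for additional_kwargs["just_summarized"]

    def route(messages: list[Msg]) -> str:
        """Mirror of should_continue: summarize on message count or token pressure
        (skipped right after a summarization pass), run tools if the last AI
        message requested them, otherwise end."""
        last = messages[-1]
        if len(messages) > MAX_MESSAGES:
            return "summarize_conversation"
        if last.role == "ai" and not last.just_summarized and last.input_tokens > CLAUDE_INPUT_LIMIT - TOKEN_HEADROOM:
            return "summarize_conversation"
        if last.tool_calls:
            return "tools"
        return "END"

    # An AI turn whose prompt already consumed ~195k input tokens triggers a summary
    # even though the message count is still far below MAX_MESSAGES:
    print(route([Msg("ai", input_tokens=195_000)]))  # -> summarize_conversation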