diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py
index e27f485da..717c507b2 100644
--- a/src/codegen/agents/code_agent.py
+++ b/src/codegen/agents/code_agent.py
@@ -5,6 +5,7 @@
 from langchain.tools import BaseTool
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.graph.graph import CompiledGraph
 from langsmith import Client
 
 from codegen.extensions.langchain.agent import create_codebase_agent
@@ -15,16 +16,20 @@
 if TYPE_CHECKING:
     from codegen import Codebase
 
+from codegen.agents.utils import AgentConfig
+
 
 class CodeAgent:
     """Agent for interacting with a codebase."""
 
     codebase: "Codebase"
-    agent: any
+    agent: CompiledGraph
     langsmith_client: Client
     project_name: str
     thread_id: str | None = None
-    config: dict = {}
+    run_id: str | None = None
+    instance_id: str | None = None
+    difficulty: int | None = None
 
     def __init__(
         self,
@@ -35,6 +40,8 @@ def __init__(
         tools: Optional[list[BaseTool]] = None,
         tags: Optional[list[str]] = [],
         metadata: Optional[dict] = {},
+        agent_config: Optional[AgentConfig] = None,
+        thread_id: Optional[str] = None,
         **kwargs,
     ):
         """Initialize a CodeAgent.
@@ -60,15 +67,28 @@ def __init__(
             model_name=model_name,
             memory=memory,
             additional_tools=tools,
+            config=agent_config,
             **kwargs,
         )
         self.model_name = model_name
         self.langsmith_client = Client()
 
+        if thread_id is None:
+            self.thread_id = str(uuid4())
+        else:
+            self.thread_id = thread_id
+
         # Get project name from environment variable or use a default
         self.project_name = os.environ.get("LANGCHAIN_PROJECT", "RELACE")
         print(f"Using LangSmith project: {self.project_name}")
 
+        # Store SWEBench metadata if provided
+        self.run_id = metadata.get("run_id")
+        self.instance_id = metadata.get("instance_id")
+        # Extract difficulty value from "difficulty_X" format
+        difficulty_str = metadata.get("difficulty", "")
+        self.difficulty = int(difficulty_str.split("_")[1]) if difficulty_str and "_" in difficulty_str else None
+
         # Initialize tags for agent trace
         self.tags = [*tags, self.model_name]
 
@@ -79,7 +99,7 @@ def __init__(
             **metadata,
         }
 
-    def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
+    def run(self, prompt: str) -> str:
         """Run the agent with a prompt.
 
         Args:
@@ -89,12 +109,9 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         Returns:
             The agent's response
         """
-        if thread_id is None:
-            thread_id = str(uuid4())
-        self.thread_id = thread_id
         self.config = {
             "configurable": {
-                "thread_id": thread_id,
+                "thread_id": self.thread_id,
                 "metadata": {"project": self.project_name},
             },
             "recursion_limit": 100,
@@ -104,7 +121,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         # see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
         input = {"query": prompt}
 
-        config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100)
+        config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
         # we stream the steps instead of invoke because it allows us to access intermediate nodes
         stream = self.agent.stream(input, config=config, stream_mode="values")
@@ -112,7 +129,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         run_ids = []
 
         for s in stream:
-            if len(s["messages"]) == 0:
+            if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
                 message = HumanMessage(content=prompt)
             else:
                 message = s["messages"][-1]
diff --git a/src/codegen/agents/utils.py b/src/codegen/agents/utils.py
new file mode 100644
index 000000000..e5cdcebaf
--- /dev/null
+++ b/src/codegen/agents/utils.py
@@ -0,0 +1,8 @@
+from typing import TypedDict
+
+
+class AgentConfig(TypedDict, total=False):
+    """Configuration options for the CodeAgent."""
+
+    keep_first_messages: int  # Number of initial messages to keep during summarization
+    max_messages: int  # Maximum number of messages before triggering summarization
diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py
index c1c796b9a..8917daa7f 100644
--- a/src/codegen/extensions/langchain/agent.py
+++ b/src/codegen/extensions/langchain/agent.py
@@ -1,12 +1,13 @@
 """Demo implementation of an agent with Codegen tools."""
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from langchain.tools import BaseTool
 from langchain_core.messages import SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.graph import CompiledGraph
 
+from codegen.agents.utils import AgentConfig
 from codegen.extensions.langchain.llm import LLM
 from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE
 from codegen.extensions.langchain.tools import (
@@ -38,6 +39,7 @@ def create_codebase_agent(
     memory: bool = True,
     debug: bool = False,
     additional_tools: Optional[list[BaseTool]] = None,
+    config: Optional[AgentConfig] = None,
     **kwargs,
 ) -> CompiledGraph:
     """Create an agent with all codebase tools.
@@ -57,7 +59,7 @@ def create_codebase_agent(
     Returns:
         Initialized agent with message history
     """
-    llm = LLM(model_provider=model_provider, model_name=model_name, max_tokens=8192, **kwargs)
+    llm = LLM(model_provider=model_provider, model_name=model_name, **kwargs)
 
     # Get all codebase tools
     tools = [
@@ -89,7 +91,7 @@ def create_codebase_agent(
 
     memory = MemorySaver() if memory else None
 
-    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
 
 
 def create_chat_agent(
@@ -100,6 +102,7 @@ def create_chat_agent(
     memory: bool = True,
     debug: bool = False,
     additional_tools: Optional[list[BaseTool]] = None,
+    config: Optional[dict[str, Any]] = None,  # e.g. the maximum number of messages to keep before triggering summarization
     **kwargs,
 ) -> CompiledGraph:
     """Create an agent with all codebase tools.
@@ -138,7 +141,7 @@ def create_chat_agent(
 
     memory = MemorySaver() if memory else None
 
-    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
 
 
 def create_codebase_inspector_agent(
@@ -148,6 +151,7 @@ def create_codebase_inspector_agent(
     system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
     memory: bool = True,
     debug: bool = True,
+    config: Optional[dict[str, Any]] = None,
     **kwargs,
 ) -> CompiledGraph:
     """Create an inspector agent with read-only codebase tools.
@@ -175,7 +179,7 @@ def create_codebase_inspector_agent(
     ]
 
     memory = MemorySaver() if memory else None
-    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
 
 
 def create_agent_with_tools(
@@ -185,6 +189,7 @@ def create_agent_with_tools(
     system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
     memory: bool = True,
     debug: bool = True,
+    config: Optional[dict[str, Any]] = None,
     **kwargs,
 ) -> CompiledGraph:
     """Create an agent with a specific set of tools.
@@ -209,4 +214,4 @@ def create_agent_with_tools(
 
     memory = MemorySaver() if memory else None
 
-    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
index 3da422560..2987f6863 100644
--- a/src/codegen/extensions/langchain/graph.py
+++ b/src/codegen/extensions/langchain/graph.py
@@ -1,34 +1,92 @@
 """Demo implementation of an agent with Codegen tools."""
 
-from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional
+import uuid
+from typing import Annotated, Any, Literal, Optional, Union
 
 import anthropic
 import openai
 from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.prompts import ChatPromptTemplate
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START
-from langgraph.graph.message import add_messages
 from langgraph.graph.state import CompiledGraph, StateGraph
 from langgraph.prebuilt import ToolNode
 from langgraph.pregel import RetryPolicy
 
+from codegen.agents.utils import AgentConfig
+from codegen.extensions.langchain.llm import LLM
+from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT
+from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens
+
+
+def manage_messages(existing: list[AnyMessage], updates: Union[list[AnyMessage], dict]) -> list[AnyMessage]:
+    """Custom reducer for managing message history with summarization.
+
+    Args:
+        existing: Current list of messages
+        updates: Either new messages to append or a dict specifying how to update messages
+
+    Returns:
+        Updated list of messages
+    """
+    if isinstance(updates, list):
+        # Ensure all messages have IDs
+        for msg in existing + updates:
+            if not hasattr(msg, "id") or msg.id is None:
+                msg.id = str(uuid.uuid4())
+
+        # Create a map of existing messages by ID
+        existing_by_id = {msg.id: i for i, msg in enumerate(existing)}
+
+        # Start with copy of existing messages
+        result = existing.copy()
+
+        # Update or append new messages
+        for msg in updates:
+            if msg.id in existing_by_id:
+                # Update existing message
+                result[existing_by_id[msg.id]] = msg
+            else:
+                # Append new message
+                result.append(msg)
+
+        return result
+
+    if isinstance(updates, dict):
+        if updates.get("type") == "summarize":
+            # Create summary message and mark it with additional_kwargs
+            summary_msg = AIMessage(
+                content=f"""Here is a summary of the conversation
+                from a previous timestep to aid for the continuing conversation: \n{updates["summary"]}\n\n""",
+                additional_kwargs={"is_summary": True},  # Use additional_kwargs for custom metadata
+            )
+            summary_msg.id = str(uuid.uuid4())
+            updates["tail"][-1].additional_kwargs["just_summarized"] = True
+            result = updates["head"] + [summary_msg] + updates["tail"]
+            return result
+
+    return existing
+
 
 class GraphState(dict[str, Any]):
     """State of the graph."""
 
     query: str
     final_answer: str
-    messages: Annotated[list[AnyMessage], add_messages]
+    messages: Annotated[list[AnyMessage], manage_messages]
 
 
 class AgentGraph:
     """Main graph class for the agent."""
 
-    def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMessage):
+    def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMessage, config: AgentConfig | None = None):
         self.model = model.bind_tools(tools)
         self.tools = tools
         self.system_message = system_message
+        self.config = config
+        self.max_messages = config.get("max_messages", 100) if config else 100
+        self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1
 
     # =================================== NODES ====================================
 
@@ -36,23 +94,109 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe
     def reasoner(self, state: GraphState) -> dict[str, Any]:
         new_turn = len(state["messages"]) == 0 or isinstance(state["messages"][-1], AIMessage)
         messages = state["messages"]
+
         if new_turn:
             query = state["query"]
             messages.append(HumanMessage(content=query))
 
         result = self.model.invoke([self.system_message, *messages])
-        if isinstance(result, AIMessage):
-            return {"messages": [*messages, result], "final_answer": result.content}
+        if isinstance(result, AIMessage):
+            updated_messages = [*messages, result]
+            return {"messages": updated_messages, "final_answer": result.content}
+
+        updated_messages = [*messages, result]
+        return {"messages": updated_messages}
+
+    def summarize_conversation(self, state: GraphState):
+        """Summarize conversation while preserving key context and recent messages."""
+        messages = state["messages"]
+        keep_first = self.keep_first_messages
+        target_size = len(messages) // 2
+        messages_from_tail = target_size - keep_first
+
+        head = messages[:keep_first]
+        tail = messages[-messages_from_tail:]
+        to_summarize = messages[: len(messages) - messages_from_tail]
+
+        # Handle tool message pairing at truncation point
+        truncation_idx = len(messages) - messages_from_tail
+        if truncation_idx > 0 and isinstance(messages[truncation_idx], ToolMessage):
+            # Keep the AI message right before it
+            tail = [messages[truncation_idx - 1], *tail]
+
+        # Skip if nothing to summarize
+        if not to_summarize:
+            return state
+
+        # Define constants
+        HEADER_WIDTH = 40
+        HEADER_TYPES = {"human": "HUMAN", "ai": "AI", "summary": "SUMMARY FROM PREVIOUS TIMESTEP", "tool_call": "TOOL CALL", "tool_response": "TOOL RESPONSE"}
+
+        def format_header(header_type: str) -> str:
+            """Format message header with consistent padding.
+
+            Args:
+                header_type: Type of header to format (must be one of HEADER_TYPES)
+
+            Returns:
+                Formatted header string with padding
+            """
+            header = HEADER_TYPES[header_type]
+            padding = "=" * ((HEADER_WIDTH - len(header)) // 2)
+            return f"{padding} {header} {padding}\n"
+
+        # Format messages with appropriate headers
+        formatted_messages = []
+        for msg in to_summarize:  # No need for slice when iterating full list
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append(format_header("human") + msg.content)
+            elif isinstance(msg, AIMessage):
+                # Check for summary message using additional_kwargs
+                if msg.additional_kwargs.get("is_summary"):
+                    formatted_messages.append(format_header("summary") + msg.content)
+                elif isinstance(msg.content, list) and len(msg.content) > 0 and isinstance(msg.content[0], dict):
+                    for item in msg.content:  # No need for slice when iterating full list
+                        if item.get("type") == "text":
+                            formatted_messages.append(format_header("ai") + item["text"])
+                        elif item.get("type") == "tool_use":
+                            formatted_messages.append(format_header("tool_call") + f"Tool: {item['name']}\nInput: {item['input']}")
+                else:
+                    formatted_messages.append(format_header("ai") + msg.content)
+            elif isinstance(msg, ToolMessage):
+                formatted_messages.append(format_header("tool_response") + msg.content)
+
+        conversation = "\n".join(formatted_messages)  # No need for slice when joining full list
+
+        summary_llm = LLM(
+            model_provider="anthropic",
+            model_name="claude-3-5-sonnet-latest",
+            temperature=0.3,
+        )
+
+        chain = ChatPromptTemplate.from_template(SUMMARIZE_CONVERSATION_PROMPT) | summary_llm
+        new_summary = chain.invoke({"conversation": conversation}).content
 
-        return {"messages": [*messages, result]}
+        return {"messages": {"type": "summarize", "summary": new_summary, "tail": tail, "head": head}}
 
     # =================================== EDGE CONDITIONS ====================================
 
-    def should_continue(self, state: GraphState) -> Literal["tools", END]:
+    def should_continue(self, state: GraphState) -> Literal["tools", "summarize_conversation", END]:
         messages = state["messages"]
         last_message = messages[-1]
 
-        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
+        just_summarized = last_message.additional_kwargs.get("just_summarized")
+        curr_input_tokens = last_message.usage_metadata["input_tokens"]
+        max_input_tokens = get_max_model_input_tokens(self.model)
+
+        # Summarize if the number of messages passed in exceeds the max_messages threshold (default 100)
+        if len(messages) > self.max_messages:
+            return "summarize_conversation"
+
+        # Summarize if the last message exceeds the max input tokens of the model - 10000 tokens
+        elif isinstance(last_message, AIMessage) and not just_summarized and curr_input_tokens > (max_input_tokens - 10000):
+            return "summarize_conversation"
+
+        elif hasattr(last_message, "tool_calls") and last_message.tool_calls:
             return "tools"
+
         return END
 
     # =================================== COMPILE GRAPH ====================================
 
@@ -63,7 +207,7 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False
         # the retry policy has an initial interval, a backoff factor, and a max interval of controlling the
         # amount of time between retries
         retry_policy = RetryPolicy(
-            retry_on=[anthropic.RateLimitError, openai.RateLimitError, anthropic.InternalServerError],
+            retry_on=[anthropic.RateLimitError, openai.RateLimitError, anthropic.InternalServerError, anthropic.BadRequestError],
             max_attempts=10,
             initial_interval=30.0,  # Start with 30 second wait
             backoff_factor=2,  # Double the wait time each retry
@@ -86,6 +230,7 @@ def get_field_descriptions(tool_obj):
             return field_descriptions
 
         try:
+            # Get all field descriptions from the tool
             schema_cls = tool_obj.args_schema
 
             # Handle Pydantic v2
@@ -129,8 +274,8 @@ def get_field_descriptions(tool_obj):
                 import json
 
                 tool_input = json.loads(input_str)
-            except:
-                pass
+            except Exception as e:
+                print(f"Failed to parse tool input: {e}")
 
             # Handle validation errors with more helpful messages
             if "validation error" in error_msg.lower():
@@ -315,6 +460,7 @@ def get_field_descriptions(tool_obj):
         # Add nodes
         builder.add_node("reasoner", self.reasoner, retry=retry_policy)
         builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
+        builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy)
 
         # Add edges
         builder.add_edge(START, "reasoner")
@@ -323,6 +469,7 @@ def get_field_descriptions(tool_obj):
             "reasoner",
             self.should_continue,
         )
+        builder.add_conditional_edges("summarize_conversation", self.should_continue)
 
         return builder.compile(checkpointer=checkpointer, debug=debug)
 
@@ -333,11 +480,8 @@ def create_react_agent(
     system_message: SystemMessage,
     checkpointer: Optional[MemorySaver] = None,
     debug: bool = False,
+    config: Optional[dict[str, Any]] = None,
 ) -> CompiledGraph:
     """Create a reactive agent graph."""
-    graph = AgentGraph(model, tools, system_message)
+    graph = AgentGraph(model, tools, system_message, config=config)
     return graph.create(checkpointer=checkpointer, debug=debug)
-
-
-if TYPE_CHECKING:
-    from codegen.extensions.langchain.llm import LLM
diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py
index 54b9a91a2..0d4795740 100644
--- a/src/codegen/extensions/langchain/llm.py
+++ b/src/codegen/extensions/langchain/llm.py
@@ -89,19 +89,19 @@ def _get_model(self) -> BaseChatModel:
             if not os.getenv("ANTHROPIC_API_KEY"):
                 msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
-            return ChatAnthropic(**self._get_model_kwargs(), max_retries=10, timeout=1000)
+            return ChatAnthropic(**self._get_model_kwargs(), max_tokens=8192, max_retries=10, timeout=1000)
 
         elif self.model_provider == "openai":
             if not os.getenv("OPENAI_API_KEY"):
                 msg = "OPENAI_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
-            return ChatOpenAI(**self._get_model_kwargs(), max_retries=10, timeout=1000)
+            return ChatOpenAI(**self._get_model_kwargs(), max_tokens=4096, max_retries=10, timeout=1000)
 
         elif self.model_provider == "xai":
             if not os.getenv("XAI_API_KEY"):
                 msg = "XAI_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
-            return ChatXAI(**self._get_model_kwargs())
+            return ChatXAI(**self._get_model_kwargs(), max_tokens=8192)
 
         msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai"
         raise ValueError(msg)
diff --git a/src/codegen/extensions/langchain/prompts.py b/src/codegen/extensions/langchain/prompts.py
index a33157bcd..3c9ba9744 100644
--- a/src/codegen/extensions/langchain/prompts.py
+++ b/src/codegen/extensions/langchain/prompts.py
@@ -43,3 +43,39 @@
 
 Ensure if specifiying line numbers, it's chosen with room (around 20 lines before and 20 lines after the edit range)
 """
+
+
+SUMMARIZE_CONVERSATION_PROMPT = """
+    You are an expert conversation summarizer. You are given below a conversation between an AI coding agent and a human.
+    It contains a human request and the agent's thought process,
+    alternating between AIMessage, ToolMessage, HumanMessage, etc.
+
+    This AI agent is an expert software engineer with deep knowledge of code analysis, refactoring, and development best practices.
+
+    Your goal as the summarizer is to summarize the conversation between the AI agent and the human in an extremely detailed and comprehensive manner.
+
+    Ensure the summary includes key details of the conversation, such as:
+    - User's request and context
+    - Code changes and their impact
+    - File and directory structure
+    - Dependencies and imports
+    - Any errors or exceptions
+    - User's clarifications and follow-up questions
+    - File modifications and their impact
+    - Any other relevant details
+
+    IMPORTANT: Your summary must be at least 4000 words long to ensure that you have added a lot of useful information to it.
+    Ensure your summary is very detailed and comprehensive. It's important to capture all the context of the conversation.
+
+    IMPORTANT: Do not attempt to provide any solutions or any other unnecessary commentary. Your sole job is to summarize the conversation in the most detailed way possible.
+    IMPORTANT: Your summary will be fed back into the LLM to continue the conversation so that it has the context of the conversation instead of having to store the whole history.
+    That's why your summary does not signal the end of the conversation. It will be used by the agent to further inch towards the goal of solving the user's issue.
+
+    IMPORTANT: The conversation given may include previous summaries generated by you in an earlier time step of the conversation. Use this to your advantage
+    alongside the conversation to generate a more comprehensive summary of the entire conversation.
+
+    Here is the conversation given below:
+
+    {conversation}
+
+"""
diff --git a/src/codegen/extensions/langchain/utils/utils.py b/src/codegen/extensions/langchain/utils/utils.py
new file mode 100644
index 000000000..81641c39c
--- /dev/null
+++ b/src/codegen/extensions/langchain/utils/utils.py
@@ -0,0 +1,21 @@
+from langchain_core.language_models import LLM
+
+
+def get_max_model_input_tokens(llm: LLM) -> int:
+    """Get the maximum input tokens for the current model.
+
+    Returns:
+        int: Maximum number of input tokens supported by the model
+    """
+    # For Claude models: if the model name contains "claude", use Claude's limit
+    if "claude" in llm.model.lower():
+        return 200000
+    # For GPT-4 models
+    elif "gpt-4" in llm.model.lower():
+        return 128000
+    # For Grok models
+    elif "grok" in llm.model.lower():
+        return 1000000
+
+    # Default to the GPT-4 limit as it is the lower bound
+    return 128000
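
Reviewer note: a minimal usage sketch of the new constructor surface introduced above. Only the parameters visible in the hunks (agent_config, thread_id, metadata handling, the new run() signature) are taken from the patch; the Codebase construction, the keyword name codebase=, and the metadata values are illustrative placeholders, and the remaining constructor arguments are assumed to keep their existing defaults.

```python
from uuid import uuid4

from codegen import Codebase  # import path assumed from the TYPE_CHECKING block above
from codegen.agents.code_agent import CodeAgent
from codegen.agents.utils import AgentConfig

# Placeholder: build the Codebase however your project normally does.
codebase = Codebase("path/to/repo")

agent = CodeAgent(
    codebase=codebase,
    # Summarization is triggered after 50 messages; the first 2 are kept verbatim.
    agent_config=AgentConfig(max_messages=50, keep_first_messages=2),
    # The thread id is now fixed at construction time instead of per run() call.
    thread_id=str(uuid4()),
    # SWE-bench style metadata picked up by __init__ (run_id, instance_id, "difficulty_N").
    metadata={"run_id": "local-test", "instance_id": "example__example-1", "difficulty": "difficulty_2"},
)

# run() no longer accepts a thread_id argument.
print(agent.run("Summarize the structure of this repository"))
```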
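Reviewer note: the custom manage_messages reducer has two branches, and the dict branch is what summarize_conversation returns into. Below is a small self-contained check of that branch, assuming codegen.extensions.langchain.graph and its dependencies are importable in your environment; the message contents are made up.

```python
from langchain_core.messages import AIMessage, HumanMessage

from codegen.extensions.langchain.graph import manage_messages

history = [
    HumanMessage(content="Fix the failing test"),
    AIMessage(content="Looking at the test file..."),
]
for i, msg in enumerate(history):
    msg.id = str(i)  # the list branch of the reducer assigns ids itself when appending

# Dict update: keep head and tail, insert a summary message between them.
updated = manage_messages(
    history,
    {
        "type": "summarize",
        "summary": "The agent inspected the failing test.",
        "head": history[:1],
        "tail": history[1:],
    },
)

# The inserted summary is flagged via additional_kwargs, and the last tail
# message is marked so should_continue does not immediately re-summarize.
assert updated[1].additional_kwargs.get("is_summary") is True
assert updated[-1].additional_kwargs.get("just_summarized") is True
print([type(m).__name__ for m in updated])  # ['HumanMessage', 'AIMessage', 'AIMessage']
```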
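Reviewer note: get_max_model_input_tokens only reads the .model attribute of the object it is given, so the limits table (and the headroom used by should_continue, which summarizes once input tokens exceed the limit minus 10000) can be sanity-checked with a stub instead of a real LLM instance. The SimpleNamespace stand-in below is purely illustrative.

```python
from types import SimpleNamespace

from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens

# Stub objects exposing only the .model attribute the helper inspects.
assert get_max_model_input_tokens(SimpleNamespace(model="claude-3-5-sonnet-latest")) == 200000
assert get_max_model_input_tokens(SimpleNamespace(model="gpt-4o")) == 128000
assert get_max_model_input_tokens(SimpleNamespace(model="grok-2")) == 1000000

# should_continue routes to summarize_conversation once the last AI message
# reports more than (limit - 10000) input tokens, e.g. 190000 for Claude.
print(get_max_model_input_tokens(SimpleNamespace(model="claude-3-5-sonnet-latest")) - 10000)
```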