Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/codegen/agents/code_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph.graph import CompiledGraph
from langsmith import Client

from codegen.extensions.langchain.agent import create_codebase_agent
Expand All @@ -15,16 +16,20 @@
if TYPE_CHECKING:
from codegen import Codebase

from codegen.agents.utils import AgentConfig


class CodeAgent:
"""Agent for interacting with a codebase."""

codebase: "Codebase"
agent: any
agent: CompiledGraph
langsmith_client: Client
project_name: str
thread_id: str | None = None
config: dict = {}
run_id: str | None = None
instance_id: str | None = None
difficulty: int | None = None

def __init__(
self,
Expand All @@ -35,6 +40,8 @@ def __init__(
tools: Optional[list[BaseTool]] = None,
tags: Optional[list[str]] = [],
metadata: Optional[dict] = {},
agent_config: Optional[AgentConfig] = None,
thread_id: Optional[str] = None,
**kwargs,
):
"""Initialize a CodeAgent.
Expand All @@ -60,15 +67,28 @@ def __init__(
model_name=model_name,
memory=memory,
additional_tools=tools,
config=agent_config,
**kwargs,
)
self.model_name = model_name
self.langsmith_client = Client()

if thread_id is None:
self.thread_id = str(uuid4())
else:
self.thread_id = thread_id

# Get project name from environment variable or use a default
self.project_name = os.environ.get("LANGCHAIN_PROJECT", "RELACE")
print(f"Using LangSmith project: {self.project_name}")

# Store SWEBench metadata if provided
self.run_id = metadata.get("run_id")
self.instance_id = metadata.get("instance_id")
# Extract difficulty value from "difficulty_X" format
difficulty_str = metadata.get("difficulty", "")
self.difficulty = int(difficulty_str.split("_")[1]) if difficulty_str and "_" in difficulty_str else None

# Initialize tags for agent trace
self.tags = [*tags, self.model_name]

Expand All @@ -79,7 +99,7 @@ def __init__(
**metadata,
}

def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
def run(self, prompt: str) -> str:
"""Run the agent with a prompt.

Args:
Expand All @@ -89,12 +109,9 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
Returns:
The agent's response
"""
if thread_id is None:
thread_id = str(uuid4())
self.thread_id = thread_id
self.config = {
"configurable": {
"thread_id": thread_id,
"thread_id": self.thread_id,
"metadata": {"project": self.project_name},
},
"recursion_limit": 100,
Expand All @@ -104,15 +121,15 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
# see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
input = {"query": prompt}

config = RunnableConfig(configurable={"thread_id": thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=100)
config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
# we stream the steps instead of invoke because it allows us to access intermediate nodes
stream = self.agent.stream(input, config=config, stream_mode="values")

# Keep track of run IDs from the stream
run_ids = []

for s in stream:
if len(s["messages"]) == 0:
if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
message = HumanMessage(content=prompt)
else:
message = s["messages"][-1]
Expand Down
8 changes: 8 additions & 0 deletions src/codegen/agents/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import TypedDict


class AgentConfig(TypedDict, total=False):
    """Optional tuning knobs for a CodeAgent's message summarization.

    Every key may be omitted (``total=False``); absent keys leave the
    agent's defaults in effect.
    """

    # How many of the earliest messages survive a summarization pass.
    keep_first_messages: int
    # Message-count threshold that triggers summarization.
    max_messages: int
17 changes: 11 additions & 6 deletions src/codegen/extensions/langchain/agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Demo implementation of an agent with Codegen tools."""

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from langchain.tools import BaseTool
from langchain_core.messages import SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.graph import CompiledGraph

from codegen.agents.utils import AgentConfig
from codegen.extensions.langchain.llm import LLM
from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE
from codegen.extensions.langchain.tools import (
Expand Down Expand Up @@ -38,6 +39,7 @@ def create_codebase_agent(
memory: bool = True,
debug: bool = False,
additional_tools: Optional[list[BaseTool]] = None,
config: Optional[AgentConfig] = None,
**kwargs,
) -> CompiledGraph:
"""Create an agent with all codebase tools.
Expand All @@ -57,7 +59,7 @@ def create_codebase_agent(
Returns:
Initialized agent with message history
"""
llm = LLM(model_provider=model_provider, model_name=model_name, max_tokens=8192, **kwargs)
llm = LLM(model_provider=model_provider, model_name=model_name, **kwargs)

# Get all codebase tools
tools = [
Expand Down Expand Up @@ -89,7 +91,7 @@ def create_codebase_agent(

memory = MemorySaver() if memory else None

return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)


def create_chat_agent(
Expand All @@ -100,6 +102,7 @@ def create_chat_agent(
memory: bool = True,
debug: bool = False,
additional_tools: Optional[list[BaseTool]] = None,
config: Optional[dict[str, Any]] = None, # over here you can pass in the max length of the number of messages
**kwargs,
) -> CompiledGraph:
"""Create an agent with all codebase tools.
Expand Down Expand Up @@ -138,7 +141,7 @@ def create_chat_agent(

memory = MemorySaver() if memory else None

return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)


def create_codebase_inspector_agent(
Expand All @@ -148,6 +151,7 @@ def create_codebase_inspector_agent(
system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
memory: bool = True,
debug: bool = True,
config: Optional[dict[str, Any]] = None,
**kwargs,
) -> CompiledGraph:
"""Create an inspector agent with read-only codebase tools.
Expand Down Expand Up @@ -175,7 +179,7 @@ def create_codebase_inspector_agent(
]

memory = MemorySaver() if memory else None
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)


def create_agent_with_tools(
Expand All @@ -185,6 +189,7 @@ def create_agent_with_tools(
system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE),
memory: bool = True,
debug: bool = True,
config: Optional[dict[str, Any]] = None,
**kwargs,
) -> CompiledGraph:
"""Create an agent with a specific set of tools.
Expand All @@ -209,4 +214,4 @@ def create_agent_with_tools(

memory = MemorySaver() if memory else None

return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug, config=config)
Loading
Loading