---
title: "Agent Coder"
description: "This blog explores AgentCoder, a multi-agent framework for code generation that integrates a programmer, a test designer, and a test executor. We build it in LangGraph, a Python SDK designed for creating LLM agents capable of reliably managing complex tasks."
author: "Oliver Pfante"
date: "2025-01-02"
categories: [agentic, code]
image: "agent_coder.webp"
toc: true
jupyter: python3
execute: 
  enabled: true
bibliography: ../references.bibtex
---



AgentCoder is a multi-agent framework designed to enhance code generation by leveraging three distinct roles to streamline the coding and validation process. This implementation is directly derived from the paper *"AgentCoder: Multi-Agent Code Generation with Effective Testing and Self-Optimization"*[@huang2024agentcodermultiagentbasedcodegeneration], with some minor modifications to the prompts to optimize its functionality within LangGraph.

### **Key Features of AgentCoder**
1. **Three Roles**: 
   - **Programmer**: Responsible for generating code based on requirements.
   - **Test Designer**: Generates comprehensive, independent test cases without being influenced by the program code. The tests cover basic cases, edge cases, and large-scale cases that probe performance.
   - **Test Executor**: A program (not an LLM persona) that runs the tests and provides feedback for further code refinement.
2. **Independent Test Creation**: Ensures objectivity and enhances the reliability of generated tests.
3. **Iterative Improvement**: Refines code through multiple iterations based on test feedback until all tests pass.
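
The interplay of the three roles can be sketched in plain Python before bringing in LangGraph. This is a minimal illustration under stated assumptions, not the paper's implementation: `stub_programmer` stands in for the LLM, the `tests` list stands in for the Test Designer's output, and `max_rounds` is a hypothetical cap that the paper does not prescribe.

```python
from typing import Callable, List, Tuple


def run_tests(code: str, tests: List[str]) -> List[str]:
    """Test Executor: run every test against the code and collect failure reports."""
    failures = []
    for test in tests:
        try:
            # A fresh namespace per test keeps the runs independent of each other.
            exec(code + "\n" + test, {})
        except Exception as exc:
            failures.append(f"FAILED TEST: {test} ({exc!r})")
    return failures


def refine(
    programmer: Callable[[List[str]], str],
    tests: List[str],
    max_rounds: int = 3,  # hypothetical cap, not part of the paper
) -> Tuple[str, int, bool]:
    """Alternate between Programmer and Test Executor until all tests pass."""
    feedback: List[str] = []
    code = ""
    for round_no in range(1, max_rounds + 1):
        code = programmer(feedback)
        feedback = run_tests(code, tests)
        if not feedback:
            return code, round_no, True
    return code, max_rounds, False


def stub_programmer(feedback: List[str]) -> str:
    """Stand-in for the LLM: buggy first attempt, corrected once feedback arrives."""
    if not feedback:
        return "def double(x):\n    return x + 1"
    return "def double(x):\n    return x * 2"


# Stand-in for the Test Designer's output.
tests = ["assert double(2) == 4", "assert double(0) == 0"]

final_code, rounds, passed = refine(stub_programmer, tests)
print(rounds, passed)  # 2 True
```

The first attempt fails both tests, the failure reports are handed back to the programmer, and the second attempt passes.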

### **Blog Objectives**
- Illustrate the implementation of the AgentCoder framework using LangGraph.
- Showcase the use of **Checkpointer (Memory)** in LangGraph to access the state of an executed graph.
- Derive the implementation directly from the research paper, with only minor modifications to the prompts.

## Setup


In [None]:
from langgraph.graph import StateGraph, END, START, MessagesState
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.output_parsers.string import StrOutputParser
from pydantic import BaseModel, Field
from typing import List, Annotated, Literal, TypedDict
from operator import add, attrgetter
import textwrap
import os
from trustcall import create_extractor

## Prompts

We start by writing down the relevant prompts. They are copied from the appendix of the paper, apart from minor modifications.


In [None]:
# first example from the HumanEval dataset. Used as one-shot example.

human_eval_example = {
    "code_snippet": 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    """ Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    """\n',
    "entry_point": "has_close_elements",
}

In [None]:
# Programmer Prompt
coder_prompt = textwrap.dedent(
    """
    **Role**: You are a software programmer. 
    
    **Task**: As a programmer, you are required to complete the function. Use a Chain-of-Thought approach to break down the problem, create pseudocode, and then write the code in Python language. Ensure that your code is efficient, readable, and well-commented. For example: 
    
    **Input Code Snippet**: 
    ```python 
    {code_snippet}
    # TODO: Implement the logic to determine if any two numbers are closer than the threshold pass 

    # Add your code here to complete the function
    ```

    **Instructions**: 
    1. **Understand and Clarify**: Make sure you understand the task. 
    2. **Algorithm/Method Selection**: Decide on the most efficient way. 
    3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode. 
    4. **Code Generation**: Translate your pseudocode into executable Python code.
    """
).strip()

# one-shot example
response = textwrap.dedent(
    '''
    As a software developer, I'll follow the Chain-of-Thought approach outlined in the prompt to complete the `has_close_elements` function. Let's break it down into the four steps: 
    
    1. **Understand and Clarify**: 
    - The task is to write a function that checks if any two numbers in a given list are closer to each other than a specified threshold. 
    - The function should return `True` if such a pair exists, otherwise `False`. 
    
    2. **Algorithm/Method Selection**: 
    - A straightforward way to solve this would be to compare each number with every other number in the list. 
    - We can use a nested loop where each number is compared to every number that comes after it in the list. 
    - The absolute difference between each pair will be calculated and compared to the threshold. 
    
    3. **Pseudocode Creation**: 
    ```
    function has_close_elements(numbers, threshold): 
        loop through each number in the list (i) 
            for each number, loop through the remaining numbers in the list (j) 
                if the absolute difference between numbers[i] and numbers[j] is less than or equal to the threshold 
                    return True 
    if no such pair is found, return False 
    ``` 
    
    4. **Code Generation**: 
    - Now, I'll translate the pseudocode into Python code. Do not add any test cases in the implementation. Here's the 
      complete function: 
    ```python 
    from typing import List 
    
    def has_close_elements(numbers: List[float], threshold: float) -> bool: 
        """ 
        Check if in given list of numbers, are any two numbers closer to each other than given threshold. 
        
        >>> has_close_elements([1.0, 2.0, 3.0], 0.5) 
        False 
        >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) 
        True 
        """ 
        # Loop through each number in the list 
        for i in range(len(numbers)): 
            # For each number, loop through the remaining numbers 
            for j in range(i + 1, len(numbers)): 
                # If the absolute difference is less than or equal to the threshold, return True 
                if abs(numbers[i] - numbers[j]) <= threshold: 
                    return True 
                    
        # If no such pair is found, return False 
        return False 
    ```
'''
).strip()

# Prompt to structure the output generated by the Programmer
structure_code = textwrap.dedent(
    '''
    Extract or update the items of the 'Coder' Pydantic class from the previous conversation:
                                    
    <convo>
    {conversation}
    </convo>
    '''
).strip()
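
The prompts are ordinary Python strings: `textwrap.dedent` removes the indentation that comes from the source layout, `.strip()` drops the surrounding blank lines, and `str.format` fills the placeholders. The sketch below uses a shortened, hypothetical template and example to show that keyword arguments without a matching placeholder are silently ignored, which is why one dictionary holding both `code_snippet` and `entry_point` can fill prompts that use only one of the two keys.

```python
import textwrap

# Shortened, hypothetical template with a single placeholder.
demo_prompt = textwrap.dedent(
    """
    **Task**: Complete the function.

    {code_snippet}
    """
).strip()

# Hypothetical example with an extra key that the template does not use.
demo_example = {"code_snippet": "def add(a, b):", "entry_point": "add"}

filled = demo_prompt.format(**demo_example)
print(filled)
```

The extra `entry_point` key causes no error; only a missing key would raise a `KeyError`.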

In [None]:
# Test Designer Prompt
tester_prompt = textwrap.dedent(
    """
    **Role**: As a tester, your task is to create comprehensive test cases for the incomplete `{entry_point}` function. These test cases should encompass Basic, Edge, and Large Scale scenarios to ensure the code's robustness, reliability, and scalability. 
    
    **Input Code Snippet**: 
    ```python
    {code_snippet} 
    ```
    **1. Basic Test Cases**: 
    - **Objective**: To verify the fundamental functionality of the `{entry_point}` function under normal conditions. 
    
    **2. Edge Test Cases**: 
    - **Objective**: To evaluate the function's behavior under extreme or unusual conditions. 
    
    **3. Large Scale Test Cases**: 
    - **Objective**: To assess the function’s performance and scalability with large data samples. 
    
    **Instructions**: 
    - Implement a comprehensive set of test cases following the guidelines above. 
    - Ensure each test case is well-documented with comments explaining the scenario it covers. 
    - Pay special attention to edge cases as they often reveal hidden bugs. 
    - For large-scale tests, focus on the function's efficiency and performance under heavy loads.
    """
).strip()

# one-shot example
tester_response = textwrap.dedent(
    """
    **Role**: As a tester, your task is to create and execute a series of test cases for the `{entry_point}` function. These test cases should include Basic, Edge, and Large Scale scenarios to ensure the function's robustness, reliability, and scalability.

    **Input Code Snippet**:
    ```python
    {code_snippet}
    ```

    **1. Basic Test Cases**:
    - **Objective**: Verify the fundamental functionality of the `{entry_point}` function under normal conditions.
    - **Examples**:
        ```python
        # Test 1: Standard case with close elements
        assert (x := {entry_point}([1.0, 2.5, 3.5, 5.0], 1.0)) == True, str(x) + ' differs from the expected output'
        
        # Test 2: Standard case with no close elements
        assert (x := {entry_point}([1.0, 3.0, 5.0, 7.0], 1.5)) == False, str(x) + ' differs from the expected output'
        ```

    **2. Edge Test Cases**:
    - **Objective**: Evaluate the function's behavior under extreme or unusual conditions.
    - **Examples**:
        ```python
        # Test 1: Empty list
        assert (x := {entry_point}([], 1.0)) == False, str(x) + ' differs from the expected output'
        
        # Test 2: List with all identical elements
        assert (x := {entry_point}([3.0, 3.0, 3.0], 0.0)) == True, str(x) + ' differs from the expected output'
        
        # Test 3: Very small threshold
        assert (x := {entry_point}([1.0, 1.01, 2.0], 0.005)) == False, str(x) + ' differs from the expected output'
        
        # Test 4: List with only two elements
        assert (x := {entry_point}([1.0, 2.0], 1.5)) == True, str(x) + ' differs from the expected output'
        ```

    **3. Large Scale Test Cases**:
    - **Objective**: Assess the function’s performance and scalability with large data samples.
    - **Examples**:
        ```python
        # Test 1: Large list with predictable pattern
        large_list = [i * 0.1 for i in range(100000)]  # Creates a list [0, 0.1, 0.2, ..., 9999.9]
        
        # Test with a threshold where no close elements exist
        assert (x := {entry_point}(large_list, 0.05)) == False, str(x) + ' differs from the expected output'
        
        # Test with a larger threshold where adjacent elements are within the threshold
        assert (x := {entry_point}(large_list, 0.15)) == True, str(x) + ' differs from the expected output'
        ```

    **Instructions**:
    - Implement and execute these test cases.
    - Document any errors, inefficiencies, or unexpected behaviors observed during testing.
    """
).strip()

# Prompt to structure the output of the Test Designer
structure_test = textwrap.dedent(
    '''
    Extract the items of the 'Tester' Pydantic class from the previous conversation:
                                    
    <convo>
    {conversation}
    </convo>
    '''
).strip()

## Templates

We require the output of the Programmer and the Test Designer to be structured. To achieve this, we define a Pydantic class for each.

In [None]:
class Coder(BaseModel):
    """
    Pydantic class where each attribute corresponds to a step the Programmer should
    follow to solve the coding challenge.
    """
    clarify: str = Field(description="Make sure you understand the task.")
    algo: str = Field(
        description="Decide on the most efficient implementation.")
    pseudocode: str = Field(
        description="Write down the steps you follow in the implementation in pseudocode."
    )
    code: str = Field(
        description="Translate your pseudocode into executable Python code."
    )


class Tester(BaseModel):
    """
    Pydantic class collecting the basic, edge, and large-scale test cases.
    """
    basic: List[str] = Field(
        description="Basic test cases, each of them in Python code"
    )
    edge: List[str] = Field(
        description="Edge case test cases, each of them in Python code"
    )
    large: List[str] = Field(
        description="Large scale test cases, testing for performance, each of them in Python code"
    )

## Runnables

The LangChain [Runnable interface](https://python.langchain.com/docs/concepts/runnables/) provides a flexible and composable abstraction for building and chaining operations, including models, prompts, and tools. The Runnables will then be executed in the LangGraph nodes.


In [None]:
# Setup Large Language Model (LLM)
LLM = ChatOpenAI(
    model_name="gpt-4o",
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    temperature=0.0,
)

The Programmer’s prompt template includes both the instruction and a one-shot example. Later, the prompt will also include the chat history as context. Additionally, we create a [Trustcall Runnable](https://github.com/hinthornw/trustcall) named `code`. Trustcall is a Python package that makes the extraction of structured, schema-conforming JSON from LLM responses more reliable. Since we anticipate that the LLM will iteratively refine the generated code based on feedback from test executions, we define an [Update Schema](https://github.com/hinthornw/trustcall?tab=readme-ov-file#updating-schemas) by enabling insertions.


In [None]:
coder_template = ChatPromptTemplate(
    messages=[
        ("user", coder_prompt.format(**human_eval_example)),
        ("ai", response),
        ("user", coder_prompt),
        MessagesPlaceholder("messages", optional=True),
    ],
    input_variables=["code_snippet", "messages"],
)

# Trustcall object for structuring the Programmer's output
code = create_extractor(
    LLM, tools=[Coder], tool_choice="Coder", enable_inserts=True)
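
To make the message order of the Programmer's template concrete, the following plain-Python sketch lists what the model receives on the first call and on a retry. It is only an illustration: `render_coder_messages` and the sample strings are hypothetical and do not use the LangChain API.

```python
from typing import List, Sequence, Tuple

Message = Tuple[str, str]  # (role, content)


def render_coder_messages(
    example_request: str,
    example_response: str,
    request: str,
    history: Sequence[Message] = (),
) -> List[Message]:
    """One-shot example first, then the actual task, then the chat history."""
    return [
        ("user", example_request),
        ("ai", example_response),
        ("user", request),
        *history,
    ]


# First call: the history is still empty.
first_call = render_coder_messages("example task", "example solution", "actual task")

# Retry: the previous attempt and the executor's feedback are appended.
retry_call = render_coder_messages(
    "example task",
    "example solution",
    "actual task",
    history=[("ai", "first attempt"), ("user", "FAILED TEST ...")],
)
print(len(first_call), len(retry_call))  # 3 5
```

Because the history comes last, the most recent test feedback is the final thing the model reads before it produces the next attempt.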

In [None]:
# Tester's template including the instruction and a one-shot example.
tester_template = ChatPromptTemplate(
    messages=[
        ("user", tester_prompt.format(**human_eval_example)),
        ("ai", tester_response.format(**human_eval_example)),
        ("user", tester_prompt),
    ],
    input_variables=["code_snippet", "entry_point"],
)

# Trustcall object for structuring the Tester's output
tests = create_extractor(LLM, tools=[Tester], tool_choice="Tester")

## States

We define a separate input state to keep the unchanged input apart from the state channels that are updated during execution: the chat history, the test cases, and the code.


In [None]:
class InputState(TypedDict):
    code_snippet: str  # function header including its doc-string, i.e., the input from HumanEval
    entry_point: str  # name of the function


class OverallState(InputState, MessagesState):
    # Inheriting from MessagesState adds a 'messages' channel that stores the chat history
    tester: Tester  # test cases, basic, edge cases, and large scale test cases
    # accumulates all implementation attempts
    coder: Annotated[List[Coder], add]
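
The annotation `Annotated[List[Coder], add]` attaches `operator.add` as a reducer to the `coder` channel, so updates are appended rather than overwritten. The sketch below imitates this merge rule with the standard library only. It is a simplified illustration, not LangGraph's actual implementation; `DemoState` and `apply_update` are hypothetical names, and strings stand in for `Coder` and `Tester` objects.

```python
from operator import add
from typing import Annotated, Any, Dict, List, TypedDict, get_type_hints


class DemoState(TypedDict):
    tester: str  # no reducer: a new value replaces the old one
    coder: Annotated[List[str], add]  # reducer: new values are appended


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Merge an update into the state, using the reducer of a channel if it has one."""
    hints = get_type_hints(DemoState, include_extras=True)
    merged = dict(state)
    for key, value in update.items():
        # Annotated types expose their extra arguments via __metadata__.
        metadata = getattr(hints[key], "__metadata__", ())
        if metadata and key in merged:
            merged[key] = metadata[0](merged[key], value)
        else:
            merged[key] = value
    return merged


state = {"tester": "tests v1", "coder": ["attempt 1"]}
state = apply_update(state, {"coder": ["attempt 2"]})
state = apply_update(state, {"tester": "tests v2"})
print(state)
# {'tester': 'tests v2', 'coder': ['attempt 1', 'attempt 2']}
```

This is why the number of entries in the `coder` channel equals the number of implementation attempts.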

## Nodes

The nodes of the Graph invoke the previously defined Runnables.


In [None]:
def write_program(state: OverallState) -> OverallState:
    """
    Writes the program as string. Represents the 'Programmer'.

    Args:
        state: The current state of the program, which includes 
            the input context and the current implementation (if it exists)

    Returns:
        The updated state with new messages containing the generated program.
    """
    return {"messages": (coder_template | LLM).invoke(state)}


def get_code(state: OverallState) -> OverallState:
    """
    Extracts the steps the Programmer followed to solve the coding challenge as structured
    output, using a Trustcall runnable. If a previous extraction exists, it is updated.

    Args:
        state: The current state of the program, containing the conversation 
            history and the coder's existing implementation as string.

    Returns:
        The updated state with the structured output containing the code in the "coder" key.
    """
    if state["coder"]:
        structured = code.invoke(
            {
                "messages": [
                    (
                        "user",
                        structure_code.format(
                            conversation=state["messages"][-1].content
                        ),
                    )
                ],
                "existing": {"Coder": state["coder"][-1].model_dump()},
            }
        )
    else:
        structured = code.invoke(
            structure_code.format(conversation=state["messages"][-1].content)
        )
    return {"coder": structured["responses"]}


def write_tests(state: OverallState) -> OverallState:
    """
    Generates test cases for the provided state using the defined templates and parsers.
    Represents the 'Test Designer'

    Args:
        state: The current state containing the context and requirements for generating tests.

    Returns:
        An updated state with the generated test cases stored under the "tester" key.
    """
    structured_tests = (
        tester_template
        | LLM
        | StrOutputParser()
        | PromptTemplate.from_template(structure_test)
        | tests
    ).invoke(state)
    return {"tester": structured_tests["responses"][0]}


def executor(state: OverallState) -> OverallState:
    """
    Executes the generated code against the provided test cases and captures error messages.
    Appends them to the message history. This node represents the 'Executor'.

    Args:
        state: The current state containing the code and test cases.

    Returns:
        Updated state with error messages for failed tests, if any.
    """
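    # Nothing to test as long as no implementation has been extracted.
    if not state.get("coder"):
        return {"messages": []}
    error_messages = []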
    for test_type in ["basic", "edge", "large"]:
        for test in attrgetter(test_type)(state["tester"]):
            try:
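                # Run code and test in one fresh namespace so that imports, helper
                # functions, and recursive calls in the generated code resolve.
                exec(state["coder"][-1].code + "\n" + test, {})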
            except Exception as e:
                error_messages.append(
                    (
                        "user",
                        textwrap.dedent(
                            """
                            FAILED TEST

                            Your solution failed the test 
                                    
                            <test>
                            {test}
                            </test>

                            with the error message:

                            <error>
                            {error}
                            </error>
                            """
                        )
                        .strip()
                        .format(test=test, error=e),
                    )
                )
    return {"messages": error_messages}
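
Running generated code with `exec` is sensitive to the namespace it runs in. The sketch below is a minimal illustration, with `run_single_test` as a hypothetical helper: it executes an implementation and one test in a single fresh dictionary. Because that dictionary serves as the global namespace of the executed code, a module-level import such as `math` remains visible inside the generated function.

```python
from typing import Optional


def run_single_test(code: str, test: str) -> Optional[str]:
    """Return None if the test passes, otherwise a short description of the error."""
    namespace: dict = {}  # fresh globals shared by the code and the test
    try:
        exec(code + "\n" + test, namespace)
    except Exception as exc:
        return f"{type(exc).__name__}: {exc}"
    return None


# Hypothetical implementation that relies on a module-level import.
code = (
    "import math\n"
    "\n"
    "def hypotenuse(a, b):\n"
    "    return math.sqrt(a * a + b * b)\n"
)

print(run_single_test(code, "assert hypotenuse(3, 4) == 5.0"))
# None
print(run_single_test(code, "assert hypotenuse(3, 4) == 6.0, 'unexpected result'"))
# AssertionError: unexpected result
```

Note that `exec` offers no sandboxing, so LLM-generated code should only be run in an environment where arbitrary code execution is acceptable.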

## Conditional Edges


In [None]:
def correct_implementation(state: OverallState) -> Literal["Programmer", "END"]:
    """
    Decides the next workflow step based on test results.

    If any test fails, the workflow returns to "Programmer" for refinement. Otherwise, it proceeds to "END".

    Args:
        state: The current workflow state with messages and test results.

    Returns:
        "Programmer" if tests failed, or "END" if all tests passed.
    """
    if "FAILED TEST" in state["messages"][-1].content:
        return "Programmer"
    else:
        return "END"
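
The routing function above keeps sending the workflow back to the Programmer for as long as tests fail. A possible variation, sketched here as an assumption and not taken from the paper, stops after a fixed number of attempts. `MAX_ATTEMPTS` and `bounded_correct_implementation` are hypothetical names, and `SimpleNamespace` objects stand in for chat messages.

```python
from types import SimpleNamespace
from typing import Literal

MAX_ATTEMPTS = 3  # hypothetical limit on implementation attempts


def bounded_correct_implementation(state: dict) -> Literal["Programmer", "END"]:
    """Retry only while tests fail and the attempt budget is not exhausted."""
    failed = "FAILED TEST" in state["messages"][-1].content
    if failed and len(state["coder"]) < MAX_ATTEMPTS:
        return "Programmer"
    return "END"


failure = SimpleNamespace(content="FAILED TEST\n\nYour solution failed the test ...")
success = SimpleNamespace(content="Here is the implementation ...")

print(bounded_correct_implementation({"messages": [failure], "coder": ["v1"]}))
# Programmer
print(bounded_correct_implementation({"messages": [failure], "coder": ["v1", "v2", "v3"]}))
# END
print(bounded_correct_implementation({"messages": [success], "coder": ["v1"]}))
# END
```

The length of the `coder` channel works as an attempt counter because every extracted implementation is appended to it.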

## Build the Graph

We add [Short-term Memory](https://langchain-ai.github.io/langgraph/concepts/memory/) to the Graph because we want to access the Graph's state after it has been executed. Recording the state updates by enabling checkpoints also helps with debugging.


In [None]:
workflow = StateGraph(OverallState, input=InputState)

# add nodes
workflow.add_node("Programmer", write_program)
workflow.add_node("Get Code", get_code)
workflow.add_node("Test Designer", write_tests)
workflow.add_node("Executor", executor)

# add edges
workflow.add_edge(START, "Programmer")
workflow.add_edge(START, "Test Designer")
workflow.add_edge("Programmer", "Get Code")
workflow.add_edge("Get Code", "Executor")
workflow.add_edge("Test Designer", "Executor")
workflow.add_conditional_edges(
    "Executor",
    correct_implementation,
    {"Programmer": "Programmer", "END": END},
)

# compile the graph

memory = MemorySaver()
coder = workflow.compile(checkpointer=memory)

In [None]:
# View
from IPython.display import Image, display

display(Image(coder.get_graph(xray=1).draw_mermaid_png()))

## Run the Agent

Execute the graph in streaming mode and render the messages. We use the `values` streaming mode, which emits the entire state, including the full chat history, at each step. To focus on updates, only the most recent message is displayed. As input, we pick an example from the [HumanEval](https://github.com/openai/human-eval)[@chen2021codex] test set.


In [None]:
# One sample from the HumanEval test set
problem = {
    "task_id": "HumanEval/1",
    "prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n    \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n    separate those group into separate strings and return the list of those.\n    Separate groups are balanced (each open brace is properly closed) and not nested within each other\n    Ignore any spaces in the input string.\n    >>> separate_paren_groups('( ) (( )) (( )( ))')\n    ['()', '(())', '(()())']\n    \"\"\"\n",
    "canonical_solution": "    result = []\n    current_string = []\n    current_depth = 0\n\n    for c in paren_string:\n        if c == '(':\n            current_depth += 1\n            current_string.append(c)\n        elif c == ')':\n            current_depth -= 1\n            current_string.append(c)\n\n            if current_depth == 0:\n                result.append(''.join(current_string))\n                current_string.clear()\n\n    return result\n",
    "test": "\n\nMETADATA = {\n    'author': 'jt',\n    'dataset': 'test'\n}\n\n\ndef check(candidate):\n    assert candidate('(()()) ((())) () ((())()())') == [\n        '(()())', '((()))', '()', '((())()())'\n    ]\n    assert candidate('() (()) ((())) (((())))') == [\n        '()', '(())', '((()))', '(((())))'\n    ]\n    assert candidate('(()(())((())))') == [\n        '(()(())((())))'\n    ]\n    assert candidate('( ) (( )) (( )( ))') == ['()', '(())', '(()())']\n",
    "entry_point": "separate_paren_groups",
}

# Input
initial_input = {
    "code_snippet": problem["prompt"],
    "entry_point": problem["entry_point"],
}

# Thread
thread = {"configurable": {"thread_id": "1"}}

# Run the agent by streaming the graph
for event in coder.stream(initial_input, thread, stream_mode="values"):
    if event["messages"]:
        event["messages"][-1].pretty_print()

Finally, we access the final implementation and run it against the HumanEval test cases.


In [None]:
# Access the final state of the graph
state = coder.get_state(thread)

# Run tests on the final implementation
exec(
    state.values["coder"][-1].code
    + problem["test"]
    + f"\ncheck({problem['entry_point']})"
)

print(f'Tests ran successfully after {len(state.values["coder"])} iteration(s)!')
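
Before trusting the result, it can help to sanity-check the harness itself. In HumanEval, the `prompt` followed by the `canonical_solution` forms the reference implementation, so the harness should accept that combination and reject a broken variant. The sketch below shows the idea with a hypothetical toy problem in the same format; `passes_check` and `toy_problem` are illustrative names, not part of the dataset.

```python
def passes_check(code: str, test: str, entry_point: str) -> bool:
    """Run a HumanEval-style `check` function against the implementation."""
    namespace: dict = {}  # fresh namespace shared by the code and the test
    try:
        exec(code + "\n" + test + f"\ncheck({entry_point})", namespace)
    except Exception:
        return False
    return True


# Hypothetical toy problem that follows the HumanEval layout.
toy_problem = {
    "prompt": 'def add(a, b):\n    """Return the sum of a and b."""\n',
    "canonical_solution": "    return a + b\n",
    "test": (
        "def check(candidate):\n"
        "    assert candidate(1, 2) == 3\n"
        "    assert candidate(-1, 1) == 0\n"
    ),
    "entry_point": "add",
}

reference = toy_problem["prompt"] + toy_problem["canonical_solution"]
broken = toy_problem["prompt"] + "    return a - b\n"

print(passes_check(reference, toy_problem["test"], toy_problem["entry_point"]))  # True
print(passes_check(broken, toy_problem["test"], toy_problem["entry_point"]))  # False
```

If the reference solution fails such a check, the problem lies in the harness and not in the generated code.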