# Module 8: 15 - Anatomy of an AI Agent - Enabling ReAct Prompting
----------------------------------------------------------------
In this lesson, we will enhance our AI Agent to handle ReAct prompting using the latest prompt templating capability. We will ensure the execution loop is properly managed, leveraging the OpenAI Chat Completion API parameter "stop" to control the flow. This allows the agent to execute suggested actions/tools, pass the observations back to the prompt, and iterate until the final answer is found.

## Objectives
* Enable ReAct prompting for the AI Agent.
* Use prompt templates to manage the execution loop and handle intermediate actions.
* Integrate the OpenAI Chat Completion API "stop" parameter to control the prompting process.
* Parse responses to identify and process the final answer.

## What this session covers:
* Defining the current agent structure, including the LLM Client and short-term memory.
* Implementing prompt templates to support ReAct prompting.
* Managing the execution loop with the "stop" parameter in the OpenAI Chat Completion API.
* Parsing and validating intermediate and final responses.
* Integrating and testing the enhanced Agent with ReAct prompting capabilities.

## Install Libraries

In [2]:
#! pip install openai

## Define Current Agent Structure

### LLM Client

In [2]:
from typing import Dict, Any, List
import openai

class OpenAIChatCompletion:
    """Interacts with OpenAI's API for chat completions."""
    def __init__(self, model: str, api_key: str = None, base_url: str = None):
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def generate(self, messages: List[str], tools: List[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]:
        """Generates a response from OpenAI's API."""
        params = {'messages': messages, 'model': self.model, 'tools': tools, **kwargs}
        response = self.client.chat.completions.create(**params)
        return response.choices[0].message

### Short-Term Memory

In [3]:
from typing import List, Dict

class ChatMessageMemory:
    """Manages conversation context."""
    
    def __init__(self):
        self.messages = []
    
    def add_message(self, message: Dict):
        """Add a message to memory."""
        self.messages.append(message)
    
    def add_messages(self, messages: List[Dict]):
        """Add multiple messages to memory."""
        for message in messages:
            self.add_message(message)
    
    def get_messages(self) -> List[Dict]:
        """Retrieve all messages."""
        return self.messages.copy()
    
    def reset_memory(self):
        """Clear all messages."""
        self.messages = []

### Agent Tool

In [4]:
from pydantic import BaseModel, ValidationError
from typing import Callable, Type
from inspect import signature

class AgentTool:
    """Encapsulates a Python function with Pydantic validation."""
    def __init__(self, func: Callable, args_model: Type[BaseModel]):
        self.func = func
        self.args_model = args_model
        self.name = func.__name__
        self.description = func.__doc__ or self.args_schema.get('description', '')

    def to_openai_function_call_definition(self) -> dict:
        """Converts the tool to OpenAI Function Calling format."""
        schema_dict = self.args_schema
        description = schema_dict.pop("description", "")
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": description,
                "parameters": schema_dict
            }
        }

    @property
    def args_schema(self) -> dict:
        """Returns the tool's function argument schema as a dictionary."""
        schema = self.args_model.model_json_schema()
        schema.pop("title", None)
        return schema

    def validate_json_args(self, json_string: str) -> bool:
        """Validate JSON string using the Pydantic model."""
        try:
            validated_args = self.args_model.model_validate_json(json_string)
            return isinstance(validated_args, self.args_model)
        except ValidationError:
            return False

    def run(self, *args, **kwargs) -> Any:
        """Execute the function with validated arguments."""
        try:
            # Handle positional arguments by converting them to keyword arguments
            if args:
                sig = signature(self.func)
                arg_names = list(sig.parameters.keys())
                kwargs.update(dict(zip(arg_names, args)))

            # Validate arguments with the provided Pydantic schema
            validated_args = self.args_model(**kwargs)
            return self.func(**validated_args.model_dump())
        except ValidationError as e:
            raise ValueError(f"Argument validation failed for tool '{self.name}': {str(e)}")
        except Exception as e:
            raise ValueError(f"An error occurred during the execution of tool '{self.name}': {str(e)}")

    def __call__(self, *args, **kwargs) -> Any:
        """Allow the AgentTool instance to be called like a regular function."""
        return self.run(*args, **kwargs)

### Agent Tool Decorator

In [5]:
from typing import Callable, Optional, Type
from pydantic import BaseModel

def check_docstring(func: Callable):
    """Ensure the function has a docstring."""
    if not func.__doc__:
        raise ValueError(f"Function '{func.__name__}' must have a docstring.")

def Tool(func: Optional[Callable] = None, *, args_model: Type[BaseModel]) -> AgentTool:
    """Decorator to wrap a function with an AgentTool instance."""
    def decorator(f: Callable) -> AgentTool:
        check_docstring(f)
        return AgentTool(f, args_model=args_model)
    return decorator(func) if func else decorator

### Agent Tool Executor

In [73]:
from typing import Any, Dict, List, Optional
import json

class AgentToolExecutor:
    """Manages tool registration and execution."""
    
    def __init__(self, tools: Optional[List[AgentTool]] = None):
        self.tools: Dict[str, AgentTool] = {}
        if tools:
            for tool in tools:
                self.register_tool(tool)
    
    def register_tool(self, tool: AgentTool):
        """Registers a tool."""
        if tool.name in self.tools:
            raise ValueError(f"Tool '{tool.name}' is already registered.")
        self.tools[tool.name] = tool
      
    def execute(self, tool_name: str, tool_args: str) -> Any:
        """Executes a tool by name with given arguments."""
        print(f"Checking if {tool_name} tool exists..")
        tool = self.tools.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found.")
        try:
            print(f"Validating {tool_name} suggested args {tool_args}")
            if tool.validate_json_args(tool_args):
                tool_args_dict = json.loads(tool_args)
                print(f"Executing {tool_name} with args: {tool_args}")
                return tool.run(**tool_args_dict)
            else:
                raise ValueError(f"Error validating tool '{tool_name}' arguments.")
        except Exception as e:
            raise ValueError(f"Error executing tool '{tool_name}': {e}") from e
    
    def get_tool_names(self) -> List[str]:
        """Returns a list of all registered tool names."""
        return list(self.tools.keys())
    
    def get_tool_details(self) -> str:
        """Returns details of all registered tools."""
        tools_info = [f"{tool.name}: {tool.description} Args schema: {tool.args_schema['properties']}" for tool in self.tools.values()]
        return '\n'.join(tools_info)

### Prompt Formatter

In [74]:
import re

class StringPromptTemplate:
    """Handles dynamic prompt formatting for an AI agent."""

    def __init__(self, template: str):
        """Initializes the prompt template and extracts variables."""
        self.template = template
        self.variables = {}
        self.required_variables = self.extract_variables()

    def extract_variables(self):
        """Extracts placeholders from the template."""
        return set(re.findall(r'\{(.*?)\}', self.template))

    def update_variables(self, **kwargs):
        """Updates template variables."""
        self.variables.update(kwargs)
        self.required_variables -= set(kwargs.keys())

    def format_prompt(self, **kwargs):
        """Generates a formatted prompt and tracks remaining variables."""
        combined_variables = {**self.variables, **kwargs}
        self.required_variables -= set(kwargs.keys())
        return self.template.format(**combined_variables)

### Agent Base

In [75]:
import logging
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

class Agent:
    """Integrates key components and manages tool executions."""
    
    def __init__(self, llm_client, system_message: Dict[str, str], max_iterations: int = 10, tools: Optional[List[AgentTool]] = None, prompt_template: StringPromptTemplate = None):
        self.llm_client = llm_client
        self.executor = AgentToolExecutor()
        self.memory = ChatMessageMemory()
        self.system_message = system_message
        self.max_iterations = max_iterations
        self.tool_history = []
        self.function_calls = None
        self.prompt_template = prompt_template

        if tools:
            for tool in tools:
                self.executor.register_tool(tool)
            self.function_calls = [tool.to_openai_function_call_definition() for tool in tools]

        tool_details = self.executor.get_tool_details()
        tool_names = ' or '.join(self.executor.get_tool_names())
        self.prompt_template.update_variables(
            system_message=self.system_message,
            tool_details=tool_details,
            tool_names=tool_names
        )

    def run(self, task: str):
        """Generates responses, manages tool calls, and updates memory."""
        self.memory.add_message({"role": "user", "content": task})

        for _ in range(self.max_iterations):
            chat_history = self.messages_to_string()
            formatted_message = self.prompt_template.format_prompt(chat_history=chat_history, user_input=task)
            messages = [{"role": "user", "content": formatted_message}]
            response = self.llm_client.generate(messages=messages)
            action_dict = self.parse_response(response)

            if action_dict:
                action_name = action_dict["name"]
                action_arguments = action_dict["args_json"]
                execution_results = self.executor.execute(action_name, action_arguments)
                return execution_results
            else:
                logger.info("Agent is responding directly.")
                self.memory.add_messages(user_message={"role": "user", "content": task}, assistant_message=response)
                return response

    def parse_response(self, response: Dict):
        """Extracts tools or continues the conversation."""
        import regex

        pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')  # Supports nested structures
        message_content = response.content
        # Unescape backslashes
        message_content = message_content.replace('\\\\n', '\\n').replace('\\n', '\n').replace('\\\'', '\'').replace('\\\\', '\\')
        # Replace double curly braces with single curly braces
        message_content = message_content.replace('{{', '{').replace('}}', '}')

        match = pattern.search(message_content)
        if match:
            action_content = match.group()
            try:
                action_dict = json.loads(action_content.strip())
                action_dict['args_json'] = json.dumps(action_dict["arguments"])
                return action_dict
            except json.JSONDecodeError:
                raise ValueError("Invalid JSON in action content")
        return None
    
    def messages_to_string(self) -> str:
        """Converts messages to a string."""
        formatted_messages = []
        for message in self.memory.get_messages():
            formatted_messages.append(f"{message['role'].capitalize()}: {message['content']}")
        return "\n".join(formatted_messages)

## Updating Agent Base

In [94]:
import logging
from typing import Dict, List, Optional
from pydantic import BaseModel
from typing import List, Dict

logger = logging.getLogger(__name__)

class Agent:
    """
    Basic Agent class responsible for integrating key components such as the LLM client, tools, memory, and managing tool executions.
    """
    def __init__(self, llm_client, system_message: Dict[str, str], max_iterations: int = 10, tools: Optional[List[AgentTool]] = None, prompt_template : StringPromptTemplate=None):
        self.llm_client = llm_client
        self.executor = AgentToolExecutor()
        self.memory = ChatMessageMemory()
        self.system_message = system_message
        self.max_iterations = max_iterations
        self.tool_history = []
        self.function_calls = None
        self.prompt_template = prompt_template  # Instance of StringPromptTemplate or similar
        
        # Register each tool passed to the Agent using the executor
        if tools:
            for tool in tools:
                # Register Agent Tools
                self.executor.register_tool(tool)
                # Convert Agent Tools
                self.function_calls = [tool.to_openai_function_call_definition() for tool in tools]
        
        # Pre-fill the prompt template with initial variables
        tool_details=self.executor.get_tool_details(),
        tool_names=' or '.join(self.executor.get_tool_names())
        self.prompt_template.update_variables(
            system_message=self.system_message,
            tool_details=tool_details,
            tool_names=tool_names
        )

    def run(self, task:str):
        # Get Chat History
        chat_history = self.messages_to_string()
        # Initialize ReAct Loop
        react_loop = ""

        # Showing Initial Task
        print(f"Question: {task}")

        for _ in range(self.max_iterations):
            # Generate a dynamic prompt using variables
            formatted_message = self.prompt_template.format_prompt(chat_history=chat_history,user_input=task,react_loop=react_loop)
            # Define everything as a user message
            messages = [{"role":"user", "content": formatted_message}]
            
            # Instruct LLM to choose the right tool and respond with a structured output
            response = self.llm_client.generate(messages=messages,stop=["\nObservation:"],)

            # Parse response and extract tools if any
            action_dict = self.parse_response(response)

            if action_dict:
                current_thought = f"{response.content}"
                print(current_thought)
                
                action_name = action_dict["name"]
                action_arguments = action_dict["args_json"]
                execution_results = self.executor.execute(action_name, action_arguments)
                
                current_observation = f"Observation: {execution_results}"
                print(current_observation)
                react_loop += (current_thought + current_observation)
            else:
                message_content = response.content
                print(message_content)
                if 'final answer' in str(message_content).lower():
                    final_message = str(message_content).lower().split("final answer:")[-1].strip()
                    response = {
                        "role": "assistant",
                        "content": final_message
                    }
                
                logger.info("Agent is responding directly.")
                self.memory.add_messages([{"role":"user","content": task},response])
                return response    

    def parse_response(self, response: Dict):
        """Extracts tools or continues the conversation."""
        import regex

        pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}')  # Supports nested structures
        message_content = response.content
        # Unescape backslashes
        message_content = message_content.replace('\\\\n', '\\n').replace('\\n', '\n').replace('\\\'', '\'').replace('\\\\', '\\')
        # Replace double curly braces with single curly braces
        message_content = message_content.replace('{{', '{').replace('}}', '}')

        match = pattern.search(message_content)
        if match:
            action_content = match.group()
            try:
                action_dict = json.loads(action_content.strip())
                action_dict['args_json'] = json.dumps(action_dict["arguments"])
                return action_dict
            except json.JSONDecodeError:
                raise ValueError("Invalid JSON in action content")
        return None
    
    def messages_to_string(self) -> str:
        """
        Converts a list of message objects or dictionaries into a multi-line string representation.
        """
        formatted_messages = []

        for message in self.memory.get_messages():
            if isinstance(message, BaseModel):
                message = message.model_dump()
            formatted_messages.append(f"{message["role"].capitalize()}: {message["content"]}")
        return "\n".join(formatted_messages)

## Update LLM Client

I updated the LLM Client class to use the stop parameter in [OpenAI's Chat Completion API](https://platform.openai.com/docs/api-reference/chat/create?ref=blog.openthreatresearch.com), stopping token generation at \nObservation:

In [95]:
from typing import Dict, Any, List
import openai

class OpenAIChatCompletion:
    """Handles interaction with OpenAI's API for generating chat completions."""
    def __init__(self, model: str, api_key: str = None, base_url: str = None):
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def generate(self, messages: List[str], tools: List[Dict[str, Any]] = None, stop: List[str] = None, **kwargs) -> Dict[str, Any]:
        """Generate a response from OpenAI's API based on input messages."""
        params = {
            'messages': messages,
            'model': self.model,
            'tools': tools,
            'stop': stop,
            **kwargs
        }
        response = self.client.chat.completions.create(**params)
        return response.choices[0].message

## Test ReAct Agent

### Initialize Tools

In [96]:
from pydantic import BaseModel, Field
import random

class GetWeatherSchema(BaseModel):
    """Get weather information based on location."""
    location: str = Field(description="Location to get weather for")

@Tool(args_model=GetWeatherSchema)
def get_weather(location: str) -> str:
    """Gets weather information."""
    temperature = random.randint(60, 80)
    return f"{location}: {temperature}F."

class JumpSchema(BaseModel):
    """Jump a specific distance"""
    distance: str = Field(description="Specific distance to jump for")

@Tool(args_model=JumpSchema)
def jump(distance: str) -> str:
    """Jumps a specific distance."""
    return f"I jumped the following distance {distance}"

tools = [get_weather, jump]

### Import ReAct Template

In [97]:
from ReAct_Template import STRING_PROMPT_TEMPLATE

### Initialize Prompt

In [98]:
prompt_template = StringPromptTemplate(STRING_PROMPT_TEMPLATE)

### Initialize Client

In [99]:
# API from environment variable
# import os
# api_key = os.getenv("OPENAI_API_KEY"))

api_key=""

client = OpenAIChatCompletion(
    base_url='https://api.openai.com/v1',
    model='gpt-4o',
    api_key=api_key
)

### Define System messages

In [100]:
# Define the system message
system_message = {"role": "system", "content": "You are a weather assistant."}

### Initialize Agent

In [101]:
# Initialize the Agent with the LLM client and system message
agent = Agent(llm_client=client, system_message=system_message, tools=tools, prompt_template=prompt_template)

### Send a User Message

In [102]:
agent.run("What is the weather in New York?")

Question: What is the weather in New York?
Thought: The user wants to know the weather in New York. I'll use the get_weather tool to get the required information.
Action:
```
{
    "name": "get_weather",
    "arguments": {
        "location": "New York"
    }
}
```
Checking if get_weather tool exists..
Validating get_weather suggested args {"location": "New York"}
Executing get_weather with args: {"location": "New York"}
Observation: New York: 60F.
Thought: I know what to respond as final answer. Using tool to provide final answer.
Final Answer: The current weather in New York is 60°F.


{'role': 'assistant', 'content': 'the current weather in new york is 60°f.'}