# Reflexion Pattern Implementation with Strands Multiagent Graph

## Overview

This implementation demonstrates the **Reflexion pattern** using the Strands multiagent graph framework. Reflexion improves a response through reflection and revision: the system generates an initial response, critically evaluates it, and revises it until the evaluation accepts the response or an iteration cap is reached.

## Architecture

The system consists of two specialized agents connected in a sequential graph with iterative feedback:

```
User Query → [Draft Agent] → [Revisor Agent] → Final Response
                                    ↓
                              Internal Loop
                              (max 5 iterations)
```

### 1. Draft Agent
- **Purpose**: Generates initial flight responses with built-in self-reflection
- **Input**: User query (e.g., "Change my flight to earlier time")
- **Output**: Initial response + reflection analysis + revision decision
- **Implementation**: Custom agent class using `generate_initial_answer` tool

### 2. Revisor Agent
- **Purpose**: Iteratively improves responses based on reflection feedback
- **Input**: Draft response + reflection analysis + original user query
- **Process**: Uses an internal loop to revise the response until the reflection accepts it or the maximum number of iterations (5) is reached
- **Output**: Final improved response after revision cycles
- **Implementation**: Custom agent class with state management and `generate_revised_answer` tool
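
The control flow described above can be sketched in plain Python, independent of Strands. This is a minimal illustration, not the notebook's implementation: `run_reflexion`, `LoopState`, `fake_draft`, and `fake_revise` are hypothetical names, and the two `fake_*` functions stand in for the model-backed draft and revision steps.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

# (response, reflection, needs_revision)
StepResult = Tuple[str, str, bool]


@dataclass
class LoopState:
    response: str
    reflection: str
    needs_revision: bool
    revision_count: int = 0


def run_reflexion(
    query: str,
    draft: Callable[[str], StepResult],
    revise: Callable[[str, str, str], StepResult],
    max_iterations: int = 5,
) -> LoopState:
    """Draft once, then revise until accepted or the iteration cap is reached."""
    state = LoopState(*draft(query))
    while state.needs_revision and state.revision_count < max_iterations:
        state.revision_count += 1
        state.response, state.reflection, state.needs_revision = revise(
            query, state.response, state.reflection
        )
    return state


# Hypothetical stand-ins for the model-backed steps
def fake_draft(query: str) -> StepResult:
    return ("draft answer", "Decision: REVISE", True)


def fake_revise(query: str, response: str, reflection: str) -> StepResult:
    improved = response + " +fix"
    accepted = improved.count("+fix") >= 2  # accept after two fixes
    decision = "Decision: ACCEPT" if accepted else "Decision: REVISE"
    return (improved, decision, not accepted)


final = run_reflexion("Change my flight to earlier time", fake_draft, fake_revise)
print(final.revision_count, final.response)
```

The draft step runs exactly once; only the revision step repeats, which mirrors the split between the two agents.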

## Key Technical Details

### Custom Agent Classes
Due to multiagent graph requirements, each agent extends the base `Agent` class with custom `stream_async()` methods:

```python
class DraftAgent(Agent):
    async def stream_async(self, prompt: str):
        result = self.tool.generate_initial_answer(query=prompt)
        message = Message(content=[{"text": str(result)}])
        agent_result = AgentResult(
            stop_reason="end_turn",
            message=message,
            metrics=EventLoopMetrics(),
            state=None
        )
        yield {"result": agent_result}

class RevisorAgent(Agent):
    async def stream_async(self, input_data):
        # Abridged: the full cell parses the draft output from input_data into the state
        state = ReflexionState(...)

        # Internal revision loop
        while state.needs_revision and state.revision_count < state.max_iterations:
            state.revision_count += 1
            result = self.tool.generate_revised_answer(...)
            # Update state with new response and reflection

        # Wrap the final text in an AgentResult, as DraftAgent does above
        yield {"result": agent_result}
```

### Quality Assessment
- **Multi-dimensional evaluation**: Completeness, clarity, actionability, user experience
- **Query improvement**: Enhances prompts based on reflection feedback
- **Error recovery**: Identifies and corrects incomplete responses
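
The notebook decides whether to revise by checking whether the substring `REVISE` appears anywhere in the reflection text. As a hedged alternative, the sketch below reads the explicit `Decision` field first and only falls back to the substring check when no such field is found. `parse_decision` and `DECISION_RE` are hypothetical names that are not part of the notebook.

```python
import re

# Matches e.g. "**Decision**: REVISE" or "Decision - accept"
DECISION_RE = re.compile(r"decision\W*(REVISE|ACCEPT)\b", re.IGNORECASE)


def parse_decision(reflection_text: str) -> bool:
    """Return True when the reflection asks for a revision."""
    match = DECISION_RE.search(reflection_text)
    if match is None:
        # No explicit decision field: fall back to the plain substring check
        return "REVISE" in reflection_text.upper()
    return match.group(1).upper() == "REVISE"


accepted = "6. **Decision**: ACCEPT\n7. **Reason**: No need to revise the answer."
rejected = "6. **Decision**: REVISE\n7. **Reason**: Fare difference is missing."

print(parse_decision(accepted))  # False, although the text contains "revise"
print(parse_decision(rejected))  # True
```

The first example shows the case the plain substring check gets wrong: an accepted response whose reason happens to mention the word "revise".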

## Usage

```python
# Create the reflexion multiagent graph
reflexion_graph = create_reflexion_graph()

# Execute with user query
result = reflexion_graph(
    "I am Anya Garcia (ID: anya_garcia_5901). I booked flight 3RK2T9 and want to change passenger name from Mei Lee to Mei Garcia."
)

# Access results from each agent
draft_result = result.results['draft']    # Initial response + reflection
revisor_result = result.results['revisor']  # Final improved response
```

This implementation showcases how the Strands multiagent graph framework can implement self-improving AI systems that prioritize response quality through systematic reflection and iterative refinement.

In [None]:
!pip3 install -r ./requirements.txt --quiet --upgrade
!pip3 install strands-agents strands-agents-tools --quiet


In [None]:
import time
import boto3
import ipywidgets as widgets
import uuid
import pandas as pd
import numpy as np
import os
import shutil
import sqlite3
import functools
import requests
import pytz
import warnings
from IPython.display import Image, display
from botocore.config import Config
from typing import Annotated, Literal, Optional, Union
from typing_extensions import TypedDict
from bs4 import BeautifulSoup
from datetime import date, datetime

from typing import List, Dict, Any
import re
import json


from strands import Agent
from strands import tool
from strands.models import BedrockModel
from strands.agent.conversation_manager import SlidingWindowConversationManager

from strands.multiagent.graph import GraphBuilder
from strands.agent import AgentResult
from strands.types.content import Message
from strands.types.streaming import StopReason
from strands.telemetry.metrics import EventLoopMetrics

import logging

## Get MAbench and Taubench tools

In [None]:
import sys
sys.path.append('../data/ma-bench/')
sys.path.append('../data/tau-bench/')

from mabench.environments.airline.tools.book_reservation import book_reservation
from mabench.environments.airline.tools.calculate import calculate
from mabench.environments.airline.tools.cancel_reservation import cancel_reservation
from mabench.environments.airline.tools.get_reservation_details import get_reservation_details
from mabench.environments.airline.tools.get_user_details import get_user_details
from mabench.environments.airline.tools.list_all_airports import list_all_airports
from mabench.environments.airline.tools.search_direct_flight import search_direct_flight
from mabench.environments.airline.tools.search_onestop_flight import search_onestop_flight
from mabench.environments.airline.tools.send_certificate import send_certificate
from mabench.environments.airline.tools.think import think
from mabench.environments.airline.tools.transfer_to_human_agents import transfer_to_human_agents
from mabench.environments.airline.tools.update_reservation_baggages import update_reservation_baggages
from mabench.environments.airline.tools.update_reservation_flights import update_reservation_flights
from mabench.environments.airline.tools.update_reservation_passengers import update_reservation_passengers

domain = "airline"

from tau_bench.envs.airline.data import *
from tau_bench.envs.airline.tasks import *
from tau_bench.envs.airline.wiki import WIKI

### Use Strands

In [None]:
region = "us-east-1"


bedrock_model_taubench = BedrockModel(region_name=region)
conv_manager = SlidingWindowConversationManager(window_size=10)

# Silence log output below CRITICAL for the root logger and the Strands-related loggers
logging.basicConfig(level=logging.CRITICAL)
for logger_name in ["strands", "graph", "event_loop", "registry", "sliding_window_conversation_manager", "bedrock", "streaming"]:
    logging.getLogger(logger_name).setLevel(logging.CRITICAL)

## Reflexion with State Management

## Define the Reflexion State


In [None]:
from dataclasses import dataclass

@dataclass
class ReflexionState:
    """State for reflexion workflow"""
    user_query: str = ""
    response: str = ""
    reflection: str = ""
    needs_revision: bool = False
    max_iterations: int = 5
    revision_count: int = 0
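
Because `ReflexionState` is a plain dataclass, it can be turned into a dictionary and rebuilt from one, which is how the Revisor agent later passes it along as `state.__dict__`. A minimal sketch of that round trip follows; the class is repeated so the block runs on its own, and the use of `dataclasses.asdict` in place of `__dict__` is an illustrative choice.

```python
from dataclasses import asdict, dataclass


@dataclass
class ReflexionState:
    """State for reflexion workflow"""
    user_query: str = ""
    response: str = ""
    reflection: str = ""
    needs_revision: bool = False
    max_iterations: int = 5
    revision_count: int = 0


state = ReflexionState(user_query="Change my flight", needs_revision=True)
state.revision_count += 1

payload = asdict(state)               # plain dict of field names to values
restored = ReflexionState(**payload)  # rebuild the state from the dict

print(payload["revision_count"], restored == state)
```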

In [None]:
def normalize_prompt(prompt) -> str:
    if isinstance(prompt, list):
        # assume list of dicts like [{"text": "..."}]
        texts = [p.get("text", "") for p in prompt if isinstance(p, dict)]
        return " ".join(t.strip() for t in texts if t)
    if isinstance(prompt, dict) and "text" in prompt:
        return prompt["text"].strip()
    if isinstance(prompt, str):
        return prompt.strip()
    return str(prompt).strip()
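
The following sketch shows what `normalize_prompt` returns for each input shape it handles. The function is copied from the cell above so the block is self-contained; the sample inputs are made up for illustration.

```python
def normalize_prompt(prompt) -> str:
    if isinstance(prompt, list):
        # assume list of dicts like [{"text": "..."}]
        texts = [p.get("text", "") for p in prompt if isinstance(p, dict)]
        return " ".join(t.strip() for t in texts if t)
    if isinstance(prompt, dict) and "text" in prompt:
        return prompt["text"].strip()
    if isinstance(prompt, str):
        return prompt.strip()
    return str(prompt).strip()


from_str = normalize_prompt("  Change my flight  ")
from_dict = normalize_prompt({"text": " Change my flight "})
# Non-dict items and blocks without text are skipped
from_list = normalize_prompt(
    [{"text": "Original Task: change my flight"}, {"image": "..."}, "ignored", {"text": "\nFrom draft:"}]
)
from_other = normalize_prompt(42)

print(from_list)
```

A list of content blocks is flattened into one space-separated string, which is why the Revisor agent can search the result for a text marker.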

## Flight Tool Executor

In [None]:
system_prompt_template = """
You are a helpful assistant for a travel website. Help the user answer any questions.

<instructions>
- You MUST refer to the <policy> and follow its guidelines to answer the user's question accurately.
- Remember to check that the airport city is in the state mentioned by the user. For example, Houston is in Texas.
- Infer the U.S. state in which the airport city resides. For example, Houston is in Texas.
- You should not use made-up or placeholder arguments.
</instructions>

<policy>
{policy}
</policy>
"""

prompt = system_prompt_template.replace("{policy}", WIKI)


class FlightToolExecutor(Agent):
    def __init__(self):
        super().__init__(
            model=bedrock_model_taubench,
            system_prompt=prompt,
            tools=[
                book_reservation, calculate, cancel_reservation, get_reservation_details,
                get_user_details, list_all_airports, search_direct_flight, search_onestop_flight,
                send_certificate, think, transfer_to_human_agents, update_reservation_baggages,
                update_reservation_flights, update_reservation_passengers
            ]
        )

flight_executor = FlightToolExecutor()

### Draft Node with State Management

In [None]:
reflection_system_prompt="""You are analyzing a flight assistant's response that uses real flight database tools.
        
IMPORTANT: The flight data comes from real database queries, NOT hallucination.
        
Analyze the response quality on these dimensions:
1. **Completeness**: Does it address all parts of the user's query?
2. **Final Answer**: If the user query clearly states the final goal and it can be fulfilled as per the policy, does the response show that?
3. **Clarity**: Is the information presented clearly and logically?
4. **Actionability**: Are next steps or options clearly presented?
5. **User Experience**: Is the tone helpful and appropriate?
6. **Missing Information**: What important details are missing?
7. **Decision**: REVISE or ACCEPT
8. **Reason**: Why this decision was made


"""

In [None]:
@tool
def generate_initial_answer(query: str) -> str:
    """Generate initial answer, reflect, and decide if revision needed"""
    
    flight_response = flight_executor(query)
    answer_text = str(flight_response)
    
    reflection_agent = Agent(
        model=bedrock_model_taubench,
        system_prompt=reflection_system_prompt
    )
    
    reflection_prompt = f"""
Original Query: {query}
Flight Assistant's Answer: {answer_text}

Remember: The flight data comes from real database queries.
Please provide a critical reflection of this answer:"""
    
    reflection_response = reflection_agent(reflection_prompt)
    reflection_text = str(reflection_response)
    
    needs_revision = "REVISE" in reflection_text.upper()
    
    return f"**Answer**: {answer_text}\n**Self-Reflection**: {reflection_text}\n**Needs Revision**: {needs_revision}"

class DraftAgent(Agent):
    def __init__(self):
        super().__init__(
            model=bedrock_model_taubench,
            tools=[generate_initial_answer],
            name="draft",
            description="Generates flight assistance answers with self-reflection"
        )
    
    async def stream_async(self, prompt: str):

        prompt=normalize_prompt(prompt)        
        result = self.tool.generate_initial_answer(query=prompt)        
        extracted = extract_answer_reflection_revision(result)
        
        message = Message(content=[{"text": str(result)}])
        print("DEBUG: DRAFT AGENT RESULT: \n", json.dumps(message), "\n")
        agent_result = AgentResult(
            stop_reason="end_turn",
            message=message,
            metrics=EventLoopMetrics(),
            state=None #state.__dict__
        )
        yield {"result": agent_result}



## Revisor Agent with State Management

In [None]:
query_improver_system_prompt="""You are a query improvement specialist. Based on reflection analysis, improve the original user query to address identified issues and guide better responses.

Examples:
Original: "Book me a flight from NYC to LA tomorrow"
Issue: "Agent booked immediately without showing options"
Improved: "Please SEARCH and SHOW ME available flight options from NYC to LA tomorrow. I want to see different times, prices, and airlines before deciding. DO NOT book anything until I confirm."

Now improve the provided query based on the specific reflection issues identified."""

In [None]:
import re

def extract_answer_reflection_revision(tool_result):
    content_text = tool_result["content"][0]["text"]

    answer_match = re.search(r"\*\*Answer\*\*:(.*?)(?=\*\*Self-Reflection\*\*:)", content_text, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else "[Not found]"

    reflection_match = re.search(r"\*\*Self-Reflection\*\*:(.*?)(?=\*\*Needs Revision\*\*:|$)", content_text, re.DOTALL)
    self_reflection = reflection_match.group(1).strip() if reflection_match else "[Not found]"

    revision_match = re.search(r"\*\*Needs Revision\*\*:\s*(True|False)", content_text)
    needs_revision = revision_match.group(1) == "True" if revision_match else False

    query_match = re.search(r"\*\*User-Query\*\*:\s*(.*?)(?=\n|$)", content_text)
    query = query_match.group(1).strip() if query_match else "[Not found]"

    return {
        "answer": answer,
        "self_reflection": self_reflection,
        "needs_revision": needs_revision,
        "user_query": query
    }
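
To make the parsing concrete, the sketch below feeds the parser a string in the same `**Label**: value` layout that the draft and revision tools return. The function is copied from the cell above so the block runs on its own; the sample text is invented for illustration.

```python
import re


def extract_answer_reflection_revision(tool_result):
    content_text = tool_result["content"][0]["text"]

    answer_match = re.search(r"\*\*Answer\*\*:(.*?)(?=\*\*Self-Reflection\*\*:)", content_text, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else "[Not found]"

    reflection_match = re.search(r"\*\*Self-Reflection\*\*:(.*?)(?=\*\*Needs Revision\*\*:|$)", content_text, re.DOTALL)
    self_reflection = reflection_match.group(1).strip() if reflection_match else "[Not found]"

    revision_match = re.search(r"\*\*Needs Revision\*\*:\s*(True|False)", content_text)
    needs_revision = revision_match.group(1) == "True" if revision_match else False

    query_match = re.search(r"\*\*User-Query\*\*:\s*(.*?)(?=\n|$)", content_text)
    query = query_match.group(1).strip() if query_match else "[Not found]"

    return {
        "answer": answer,
        "self_reflection": self_reflection,
        "needs_revision": needs_revision,
        "user_query": query,
    }


# Invented sample in the layout produced by the revision tool
revised_text = (
    "\n**User-Query**: Change my flight\n"
    "**Answer**: Your flight is updated.\n"
    "**Self-Reflection**: Decision: ACCEPT\n"
    "**Needs Revision**: False"
)
parsed = extract_answer_reflection_revision({"content": [{"text": revised_text}]})

# The draft tool's layout has no User-Query field, so that key falls back to "[Not found]"
draft_text = "**Answer**: Options found.\n**Self-Reflection**: Decision: REVISE\n**Needs Revision**: True"
parsed_draft = extract_answer_reflection_revision({"content": [{"text": draft_text}]})

print(parsed)
print(parsed_draft["user_query"])
```

Any field the regular expressions cannot find is reported as `"[Not found]"` (or `False` for the revision flag) instead of raising, so callers should check for that placeholder.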

In [None]:
@tool
def generate_revised_answer(current_user_query, current_response="", current_reflection="") -> str:
    """Generate revised answer, reflect, and decide if further revision needed"""
    
    query_improver = Agent(
        model=bedrock_model_taubench,
        system_prompt=query_improver_system_prompt
    )
    
    improved_query = query_improver(f"Current user query: {current_user_query}\nCurrent Response: {current_response}\nReflection: {current_reflection}\nCreate better query:")
    
    flight_response = flight_executor(str(improved_query))
    revised_answer = str(flight_response)
    
    revision_reflection_agent = Agent(
        model=bedrock_model_taubench,
        system_prompt=reflection_system_prompt
    )
    
    new_reflection = revision_reflection_agent(f"Task: {current_user_query} Revised Answer: {revised_answer} Analyze and decide:")
    reflection_text = str(new_reflection)
    
    needs_revision = "REVISE" in reflection_text.upper()
    
    return f"\n**User-Query**: {current_user_query}\n**Answer**: {revised_answer}\n**Self-Reflection**: {reflection_text}\n**Needs Revision**: {needs_revision}"

class RevisorAgent(Agent):
    def __init__(self):
        super().__init__(
            model=bedrock_model_taubench,
            tools=[generate_revised_answer],
            name="revisor",
            description="Revises flight responses"
        )

    async def stream_async(self, input_data):
        # The graph passes the original task plus upstream node output; flatten it to text
        input_data = normalize_prompt(input_data)

        # Defaults used when the input carries no draft output
        user_query = input_data
        draft_response = ""
        draft_reflection = ""
        needs_revision = False

        # Extract draft agent result
        draft_start = input_data.find('From draft:')
        if draft_start != -1:
            # Assumption: the text before the 'From draft:' marker is the original task
            user_query = input_data[:draft_start].strip()
            draft_content = input_data[draft_start + len('From draft:'):].strip()
            # Parse the draft result
            extracted = extract_answer_reflection_revision({'content': [{'text': draft_content}]})
            draft_response = extracted['answer']
            draft_reflection = extracted['self_reflection']
            needs_revision = extracted['needs_revision']

        state = ReflexionState(
            user_query=user_query,
            response=draft_response,
            reflection=draft_reflection,
            needs_revision=needs_revision,
            revision_count=0,
            max_iterations=5
        )
        
        print(f"Revisor starting: revision_count={state.revision_count}, needs_revision={state.needs_revision}")
        
        # Internal revision loop: revise until the reflection accepts or the cap is reached
        while state.needs_revision and state.revision_count < state.max_iterations:
            state.revision_count += 1
            
            result = self.tool.generate_revised_answer(
                current_user_query=state.user_query,
                current_response=state.response,
                current_reflection=state.reflection
            )
            
            extracted = extract_answer_reflection_revision(result)
            state.response = extracted["answer"]
            state.reflection = extracted["self_reflection"]
            state.needs_revision = extracted["needs_revision"]

        # Summarize the final state, whether or not any revision ran
        result = f"**Answer**: {state.response}\n**Final Reflection**: {state.reflection}\n**Revision Complete**: After {state.revision_count} revisions"
        
        message = Message(content=[{"text": str(result)}])
        print("DEBUG: REVISOR AGENT RESULT: \n", json.dumps(message), "\n")
        agent_result = AgentResult(
            stop_reason="end_turn",
            message=message,
            metrics=EventLoopMetrics(),
            state=state.__dict__
        )
        yield {"result": agent_result}

## Build Reflexion Graph

In [None]:
from strands.multiagent.graph import GraphBuilder, GraphState

def create_reflexion_graph():
    """Create reflexion graph with state management"""
    
    draft_agent = DraftAgent()
    revisor_agent = RevisorAgent()
    
    builder = GraphBuilder()
    
    draft_node = builder.add_node(draft_agent, "draft")
    revisor_node = builder.add_node(revisor_agent, "revisor")
    
    builder.add_edge(draft_node, revisor_node)
    builder.set_entry_point("draft")
    
    return builder.build()

reflexion_graph = create_reflexion_graph()

## Execute Graph

In [None]:
output_path = os.path.join("..", "data", "tau-bench", "tau_bench", "envs", domain, "tasks_singleturn.json")
with open(output_path, "r") as file:
    tasks = json.load(file)

In [None]:
def test_reflexion_graph(user_query):
    reflexion_graph = create_reflexion_graph()
    
    print("=== Testing Reflexion Graph ===")
    print(f"Test Prompt: {user_query}")
    print("\n--- Executing Graph ---")
    start= time.time()
    result = reflexion_graph(user_query)
    exec_time= time.time()-start
    print(f"\n EXEC Time: {exec_time}")
    print(f"\nGraph Status: {result.status}")
    print(f"Total Nodes: {result.total_nodes}")
    print(f"Completed Nodes: {result.completed_nodes}")
    print(f"Execution Order: {[node.node_id for node in result.execution_order]}")
    
    print("\n--- Node Results ---")
    for node_id, node_result in result.results.items():
        print(f"\n{node_id.upper()}:")
        print(f"Status: {node_result.status}")
        if hasattr(node_result, 'result') and node_result.result:
            print(f"Result: {str(node_result.result)}...")
    
    return result

### Test with a particular question_id 

In [None]:
# Test 
question_id = 43
task = tasks[question_id]
user_query = task["question"]
print(user_query)

reflexion_response = test_reflexion_graph(user_query)
# Save to file with question ID suffix; create the output directory if it is missing
os.makedirs("./output", exist_ok=True)
filename = f"./output/reflexion_response_{question_id}.txt"
with open(filename, "w", encoding="utf-8") as f:
    f.write(str(reflexion_response))
