# Libraries

## Import

In [1]:
# Libraries

import sys
sys.path.append('../data/ma-bench/')
sys.path.append('../data/tau-bench/')

import os
import json
import importlib
import argparse

import boto3
from botocore.config import Config

# Strands imports
from strands import Agent, tool
from strands.models import BedrockModel
from strands.multiagent import GraphBuilder

# Parameters

In [None]:
# boto3 client configuration handed to the Bedrock model below: a fixed
# region, SigV4 signing, and "standard"-mode retries with up to 50 attempts
# (a generous budget, presumably to ride out throttling — confirm).
region_name = "us-west-2"
my_config = Config(
    region_name=region_name,
    signature_version='v4',
    retries={'max_attempts': 50, 'mode': 'standard'},
)

# Domain whose wiki, tasks and tools are loaded (e.g. "airline", "retail").
domain = "airline"

# Script-mode alternative (disabled in the notebook) to pick the domain from
# the command line:
# parser = argparse.ArgumentParser(description='Run agent with specified domain')
# parser.add_argument('--domain', type=str, default=domain,
#                     help='Domain to use (e.g., "airline", "retail")')
# domain = parser.parse_args().domain

# Utils

In [None]:
def import_domain_tools(domain):
    """
    Collect the tool objects exposed by a domain's ``tools_strands`` package.

    Every name reported by ``dir()`` on
    ``mabench.environments.<domain>.tools_strands`` (dunder names excluded) is
    treated as a candidate. The function tries to import a submodule with that
    name. If the import succeeds and the submodule defines an attribute with
    the same name, that attribute is recorded.

    Args:
        domain: Domain name, e.g. "airline" or "retail".

    Returns:
        dict mapping each tool name to the object found on its submodule.

    Raises:
        ImportError: if the ``tools_strands`` package itself cannot be imported.
    """
    package_name = f'mabench.environments.{domain}.tools_strands'
    package = importlib.import_module(package_name)

    found = {}
    candidates = [name for name in dir(package) if not name.startswith('__')]
    for name in candidates:
        # Candidates that are not submodules (e.g. names the package merely
        # re-exports) fail to import and are skipped.
        try:
            submodule = importlib.import_module(f'{package_name}.{name}')
        except (ImportError, AttributeError):
            continue
        if hasattr(submodule, name):
            found[name] = getattr(submodule, name)
    return found

In [None]:
# Import domain-specific modules: the policy wiki, data, tasks and tools.
# Exits the process (status 1) with a diagnostic on stderr if anything is
# missing for the selected domain.
try:
    # Policy text that agent_prompt() embeds in the system prompt.
    wiki_module = importlib.import_module(f'tau_bench.envs.{domain}.wiki')
    WIKI = wiki_module.WIKI

    # Return values are unused; presumably imported to verify the domain is
    # complete (or for import-time side effects) — confirm.
    importlib.import_module(f'tau_bench.envs.{domain}.data')
    importlib.import_module(f'tau_bench.envs.{domain}.tasks')

    # Tool objects keyed by tool name.
    domain_tools = import_domain_tools(domain)
except (ImportError, AttributeError) as e:
    # AttributeError covers a wiki module that does not define WIKI, which
    # previously escaped this handler as an uncaught traceback.
    print(f"Error: Could not import modules for domain '{domain}'. Error: {e}", file=sys.stderr)
    print("Available domains may include: airline, retail", file=sys.stderr)
    sys.exit(1)
else:
    print(f"Successfully loaded modules for domain: {domain}")

Successfully loaded modules for domain: airline


# Agent

In [None]:
# Flat list of the domain's tool objects, in the insertion order of domain_tools.
tools = [*domain_tools.values()]

def agent_prompt(policy=None):
    """
    Build the agent's system prompt.

    The prompt combines a short instruction section with the domain policy,
    each wrapped in XML-style tags.

    Args:
        policy: Policy text placed between the ``<policy>`` tags. When omitted
            (``None``), the domain ``WIKI`` loaded earlier in the notebook is
            used.

    Returns:
        The formatted system prompt string.
    """
    if policy is None:
        policy = WIKI

    system_prompt_template = """
You are a helpful assistant for a travel website. Help the user answer any questions.

<instructions>
- Remember to check if the airport city is in the state mentioned by the user. For example, Houston is in Texas.
- Infer the U.S. state in which the airport city resides. For example, Houston is in Texas.
- You should not use made-up or placeholder arguments.
</instructions>

<policy>
{policy}
</policy>
"""

    # str.format substitutes the policy text verbatim; braces inside the
    # policy itself are not re-interpreted.
    return system_prompt_template.format(policy=policy)


def agent_model(model_id="anthropic.claude-3-sonnet-20240229-v1:0"):
    """
    Build the Bedrock model used by the agent.

    Args:
        model_id: Bedrock model identifier. Defaults to Claude 3 Sonnet.
            Other IDs previously noted as options here are
            "anthropic.claude-3-5-sonnet-20240620-v1:0" and
            "us.anthropic.claude-3-5-sonnet-20241022-v2:0".

    Returns:
        A ``BedrockModel`` using the module-level ``region_name`` and retrying
        ``my_config`` client config, capped at 1024 output tokens, with
        temperature 0.0 and top_p 1.
    """
    return BedrockModel(
        model_id=model_id,
        region_name=region_name,
        max_tokens=1024,
        temperature=0.0,
        top_p=1,
        boto_client_config=my_config,
    )


def react_agent(tools):
    """
    Create a new Strands ``Agent`` for one conversation.

    The agent uses the system prompt from ``agent_prompt()``, the model from
    ``agent_model()``, and the given tools.

    Args:
        tools: Tool objects to expose to the agent.

    Returns:
        A freshly constructed ``Agent``.
    """
    # Keyword arguments are evaluated left to right, so the prompt is still
    # built before the model.
    return Agent(
        system_prompt=agent_prompt(),
        model=agent_model(),
        tools=tools,
    )

# Run

In [6]:
# Load the single-turn task list for the selected domain.
# NOTE: despite its name, output_path is the file being *read* here.
output_path = os.path.join("..", "data", "tau-bench", "tau_bench", "envs", f"{domain}", "tasks_singleturn.json")
# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open(output_path, "r", encoding="utf-8") as file:
    tasks = json.load(file)


In [None]:
for index, task in enumerate(tasks):
    # Banner: the task index framed by equal runs of '#', about 50 chars wide.
    pad = '#' * ((50 - len(str(index)) - 9) // 2)
    print(f"\n{pad} Index:{index} {pad}\n")

    question = task['question']
    print(f"Processing question: {question}")

    # A new agent is constructed for every task.
    agent = react_agent(tools)
    messages = agent(question)
    print(messages)

    # Only the first task is processed; remove this `break` to run every task.
    break