## Create Supervisor Agent

In this notebook we create a supervisor agent with Amazon Bedrock Agents that orchestrates the data engineer, business analyst, and data scientist agents.

We then evaluate the multi-agent orchestration in several scenarios.


In [None]:
# Import necessary libraries and load environment variables
from dotenv import load_dotenv, find_dotenv, set_key
import os
import json
import boto3
import sagemaker
import pandas as pd
import mlflow
from utils.bedrock import BedrockLLMWrapper

# Load environment variables stored in the local dev.env file
local_env_filename = 'dev.env'
load_dotenv(find_dotenv(local_env_filename), override=True)

# load_dotenv() already populates os.environ; indexing os.environ fails fast
# with a KeyError if a required variable is missing from dev.env
REGION = os.environ['REGION']
S3_BUCKET_NAME = os.environ['S3_BUCKET_NAME']
AWS_ACCOUNT = os.environ['AWS_ACCOUNT']
MAC_EVAL_PROFILE_ARN = os.environ['MAC_EVAL_PROFILE_ARN']
MLFLOW_SERVER_NAME = os.environ['MLFLOW_SERVER_NAME']
USER_FEEDBACK_TABLE_NAME = os.environ['USER_FEEDBACK_TABLE_NAME']
TRACE_TABLE_NAME = os.environ['TRACE_TABLE_NAME']
# May be unset until the supervisor agent is created further below
BEDROCK_AGENT_ID = os.getenv('BEDROCK_AGENT_ID')

# AWS CLI profile used by the `aws ...` shell commands below; assumed to be set as
# SESSION_PROFILE in dev.env, falling back to the CLI's "default" profile
SESSION_PROFILE = os.getenv('SESSION_PROFILE', 'default')

# Bedrock Agents does not yet support application inference profiles
MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"

In [3]:
import botocore.config
config = botocore.config.Config(
    connect_timeout=600,  # 10 minutes
    read_timeout=600,     # 10 minutes
    retries={'max_attempts': 3}
)

session = boto3.Session(region_name=REGION)

# Create a SageMaker session
sagemaker_session = sagemaker.Session(boto_session=session)

bedrock_agent_client = session.client('bedrock-agent', config=config)
bedrock_agent_runtime_client = session.client('bedrock-agent-runtime', config=config)
bedrock_runtime_client = session.client('bedrock-runtime', config=config)
bedrock_client = session.client('bedrock', config=config)
lambda_client = session.client('lambda', config=config)
iam_resource = session.resource('iam')
iam_client = session.client('iam')
athena_client = session.client('athena')
s3_client = session.client('s3')

## Create a Lambda function that can be used to retrieve user feedback and traces from DynamoDB

## Create & Test Bedrock Agent

In [None]:
%%writefile ../supervisor/bedrock_supervisor_agent.py
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError
import os
import logging
import json
from enum import Enum
import time
from datetime import datetime, timedelta
from typing import Dict, Any, List, TypedDict, Optional
from urllib.parse import urlparse
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
import urllib3
from pydantic import BaseModel
from utils.bedrock import BedrockLLMWrapper
from boto3.dynamodb.conditions import Attr
import sys

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

TRACE_TABLE_NAME = os.getenv('TRACE_TABLE_NAME')
USER_FEEDBACK_TABLE_NAME = os.getenv('USER_FEEDBACK_TABLE_NAME')
MODEL_ID = os.getenv('MODEL_ID')

LESSONS_LEARNED_PROMPT_TEMPLATE = '''Analyze the following information:

user feedback:
{USER_FEEDBACK}

traces from past runs:
{PAST_RUNS}

1. Review the user feedback and traces from past runs.

2. Extract the lessons learned with regards to agent orchestration and function/tool calling based on the past runs and user feedback and return them in a list of strings.

Sample output:
- when creating a SQL query that requires functions, ensure you have the correct Athena function names, e.g. DATE_DIFF instead of datediff
- when calling the data scientist agent, ensure you provide the ml dataset location and target column name, otherwise don't call the data scientist agent
'''


class SupervisorTools:
    """Collection of tools for the Supervisor Agent to use"""

    def __init__(self):
        logger.info("Initializing SupervisorTools")
        # The DynamoDB resource API returns plain Python values (not {'S': ...} wrappers)
        # and accepts boto3 condition objects such as Attr in FilterExpression
        self.dynamodb = boto3.resource('dynamodb')
        self.bedrock_llm = BedrockLLMWrapper(model_id=MODEL_ID,
                                             max_token_count=2000,
                                             temperature=0)

    def _scan_all(self, table_name: str, **scan_kwargs) -> List[Dict[str, Any]]:
        """Scan a whole table, following LastEvaluatedKey past the 1 MB page limit"""
        table = self.dynamodb.Table(table_name)
        items: List[Dict[str, Any]] = []
        while True:
            response = table.scan(**scan_kwargs)
            items.extend(response.get('Items', []))
            if 'LastEvaluatedKey' not in response:
                return items
            scan_kwargs['ExclusiveStartKey'] = response['LastEvaluatedKey']

    def get_user_feedback(self, user_id: str, conversation_id: str) -> List[Dict[str, Any]]:
        """Retrieve user feedback for one conversation from DynamoDB"""
        try:
            response = self.dynamodb.Table(USER_FEEDBACK_TABLE_NAME).get_item(
                Key={'user_id': user_id, 'conversation_id': conversation_id}
            )
            return response.get('Item', {}).get('feedback', [])
        except ClientError as e:
            logger.error(f"Error retrieving user feedback: {e}")
            return []

    def get_all_user_feedback(self) -> List[Dict[str, Any]]:
        """Retrieve all user feedback from DynamoDB"""
        try:
            return self._scan_all(USER_FEEDBACK_TABLE_NAME)
        except ClientError as e:
            logger.error(f"Error retrieving all user feedback: {e}")
            return []

    def get_all_traces_for_user(self, user_id: str) -> List[Dict[str, Any]]:
        """Retrieve all traces for a user from DynamoDB"""
        try:
            return self._scan_all(TRACE_TABLE_NAME,
                                  FilterExpression=Attr('user_id').eq(user_id))
        except ClientError as e:
            logger.error(f"Error retrieving all traces for user {user_id}: {e}")
            return []

    def get_all_conversation_traces(self, conversation_id: str) -> List[Dict[str, Any]]:
        """Retrieve all traces for a conversation from DynamoDB"""
        try:
            return self._scan_all(TRACE_TABLE_NAME,
                                  FilterExpression=Attr('conversation_id').eq(conversation_id))
        except ClientError as e:
            logger.error(f"Error retrieving all traces for conversation {conversation_id}: {e}")
            return []

    def get_all_traces(self) -> List[Dict[str, Any]]:
        """Retrieve all traces from DynamoDB"""
        try:
            return self._scan_all(TRACE_TABLE_NAME)
        except ClientError as e:
            logger.error(f"Error retrieving all traces: {e}")
            return []

    def get_lessons_learned_from_past_runs(self):
        """Get lessons learned from past runs and user feedback"""
        try:
            user_feedback = self.get_all_user_feedback()
            past_runs = self.get_all_traces()
            # convert to a string
            past_runs_string = '\n'.join([str(run) for run in past_runs])
            user_feedback_string = '\n'.join([str(feedback) for feedback in user_feedback])

            prompt = LESSONS_LEARNED_PROMPT_TEMPLATE.format(USER_FEEDBACK=user_feedback_string, PAST_RUNS=past_runs_string)
            response = self.bedrock_llm.generate(prompt)
            return response[0]
        except Exception as e:
            logger.error(f"Error getting lessons learned from past runs: {e}", exc_info=True)
            return f'Error encountered while getting lessons learned from past runs and user feedback: {e}'



def truncate_response(data: Any, max_size: int = 20000) -> Any:
    """Truncate response data to stay within size limits"""
    if isinstance(data, dict):
        serialized = json.dumps(data)
        if len(serialized) <= max_size:
            return data
            
        # For dictionary responses, try to preserve structure while reducing content
        truncated = data.copy()
        if 'data' in truncated and isinstance(truncated['data'], list):
            # Calculate approximate size per record
            record_count = len(truncated['data'])
            if record_count > 0:
                avg_record_size = len(json.dumps(truncated['data'])) / record_count
                # Calculate how many records we can keep
                safe_record_count = int((max_size * 0.8) / avg_record_size)  # 80% of max size
                truncated['data'] = truncated['data'][:safe_record_count]
                truncated['truncated'] = True
                truncated['total_records'] = record_count
                truncated['showing_records'] = safe_record_count
                return truncated
                
    elif isinstance(data, list):
        serialized = json.dumps(data)
        if len(serialized) <= max_size:
            return data
            
        # For list responses, truncate the list
        original_length = len(data)
        # Calculate approximate size per item
        if original_length > 0:
            avg_item_size = len(serialized) / original_length
            safe_item_count = int((max_size * 0.8) / avg_item_size)  # 80% of max size
            return {
                'data': data[:safe_item_count],
                'truncated': True,
                'total_items': original_length,
                'showing_items': safe_item_count
            }
    
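    elif isinstance(data, str):
        # Plain-text responses (e.g. the lessons learned summary) are cut to 80% of max size
        if len(json.dumps(data)) > max_size:
            return data[:int(max_size * 0.8)] + '... [truncated]'

    return data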


def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    try:
        logger.info(f"Received event: {json.dumps(event)}")
        
        # # Extract parameters from requestBody if present
        # parameters_dict = {}
        # if 'requestBody' in event and 'content' in event['requestBody']:
        #     content = event['requestBody']['content']
        #     if 'application/json' in content and 'properties' in content['application/json']:
        #         for prop in content['application/json']['properties']:
        #             parameters_dict[prop['name']] = prop['value']
        

        # Initialize tools
        tools = SupervisorTools()
        
        # Extract APIPath from event
        api_path = event.get('apiPath', '').strip('/')
        
        logger.info(f"API Path: {api_path}")

        response_data = None
        
        if api_path == 'GetLessonsLearnedFromPastRuns':
            lessons_learned = tools.get_lessons_learned_from_past_runs()
            response_data = lessons_learned
            
        else:
            raise ValueError(f"Invalid API path: {api_path}")
        
        # Check and truncate response size if needed (UTF-8 size of the JSON payload;
        # sys.getsizeof would also count Python object overhead)
        response_size = len(json.dumps(response_data).encode('utf-8'))
        if response_size > 20000:  # 20 KB limit
            logger.warning(f"Response size {response_size} exceeds limit. Truncating content...")
            response_data = truncate_response(response_data)
                
        response_body = {
            'application/json': {
                'body': response_data
            }
        }
        
        # Set response code based on status if it exists
        response_code = 200
        if isinstance(response_data, dict) and response_data.get('status') in ['FAILED', 'CANCELLED', 'ERROR']:
            response_code = 400

        action_response = {
            'actionGroup': event['actionGroup'],
            'apiPath': event['apiPath'],
            'httpMethod': event['httpMethod'],
            'httpStatusCode': response_code,
            'responseBody': response_body
        }

        return {'messageVersion': '1.0', 'response': action_response}
            
    except ValueError as e:
        # Handle bad request errors (400)
        logger.error(f"Validation error: {str(e)}")
        return {
            "messageVersion": "1.0",
            "response": {
                "actionGroup": event.get('actionGroup'),
                "apiPath": event.get('apiPath'),
                "httpMethod": event.get('httpMethod', 'POST'),
                "httpStatusCode": 400,
                "responseBody": {
                    "application/json": {
                        "body": {
                            "error": str(e)
                        }
                    }
                }
            }
        }
    except Exception as e:
        # Handle internal server errors (500)
        logger.error(f"Internal error: {str(e)}", exc_info=True)
        return {
            "messageVersion": "1.0",
            "response": {
                "actionGroup": event.get('actionGroup'),
                "apiPath": event.get('apiPath'),
                "httpMethod": event.get('httpMethod', 'POST'),
                "httpStatusCode": 500,
                "responseBody": {
                    "application/json": {
                        "body": {
                            "error": f"Internal server error: {str(e)}"
                        }
                    }
                }
            }
        }         


In [None]:
%%writefile ../supervisor/supervisor.Dockerfile
FROM public.ecr.aws/lambda/python:3.11

# Install build dependencies first
RUN yum install libgomp git gcc gcc-c++ make -y \
 && yum clean all -y && rm -rf /var/cache/yum


RUN python3 -m pip --no-cache-dir install --upgrade --trusted-host pypi.org --trusted-host files.pythonhosted.org pip \
 && python3 -m pip --no-cache-dir install --upgrade wheel setuptools \
 && python3 -m pip --no-cache-dir install --upgrade pandas \
 && python3 -m pip --no-cache-dir install --upgrade boto3 \
 && python3 -m pip --no-cache-dir install --upgrade opensearch-py \
 && python3 -m pip --no-cache-dir install --upgrade Pillow \
 && python3 -m pip --no-cache-dir install --upgrade pyarrow \
 && python3 -m pip --no-cache-dir install --upgrade fastparquet \
 && python3 -m pip --no-cache-dir install --upgrade urllib3 \
 && python3 -m pip --no-cache-dir install --upgrade pydantic

# Copy function code
WORKDIR /var/task
# Paths are relative to the build context (the parent directory passed to docker build)
COPY supervisor/bedrock_supervisor_agent.py .
COPY notebooks/utils/ utils/

# Set handler environment variable
ENV _HANDLER="bedrock_supervisor_agent.lambda_handler"

# Let's go back to using the default entrypoint
ENTRYPOINT [ "/lambda-entrypoint.sh" ]
CMD [ "bedrock_supervisor_agent.lambda_handler" ]

In [None]:
# Build the Docker image locally (the build context is the parent directory)
!docker build -t supervisor-lambda -f ../supervisor/supervisor.Dockerfile ..

In [None]:
# Run the container in the background (-d), passing the current session's AWS credentials
credentials = session.get_credentials()
credentials = credentials.get_frozen_credentials()

!docker run -d \
-e AWS_ACCESS_KEY_ID={credentials.access_key} \
-e AWS_SECRET_ACCESS_KEY={credentials.secret_key} \
-e AWS_SESSION_TOKEN={credentials.token} \
-e AWS_DEFAULT_REGION={REGION} \
-e REGION={REGION} \
-e AWS_LAMBDA_FUNCTION_TIMEOUT=900 \
-e MODEL_ID={MODEL_ID} \
-e USER_FEEDBACK_TABLE_NAME={USER_FEEDBACK_TABLE_NAME} \
-e TRACE_TABLE_NAME={TRACE_TABLE_NAME} \
-p 9000:8080 supervisor-lambda

In [None]:
!docker ps --filter ancestor=supervisor-lambda

In [None]:
# Invoke the GetLessonsLearnedFromPastRuns action through the Lambda Runtime Interface Emulator in the container
request_body = {
    "apiPath": "/GetLessonsLearnedFromPastRuns",
    "requestBody": {
        "content": {
            "application/json": {
                "properties": [
                    {
                        "name": "user_id",
                        "type": "string",
                        "value": f"XXX"
                    }
                ]
            }
        }
    },
    "httpMethod": "POST",
    "actionGroup": "BusinessAnalystActions",
}

import requests
response = requests.post("http://localhost:9000/2015-03-31/functions/function/invocations",
                         json=request_body,
                         timeout=900  # 15 minutes timeout
)
print(response.json())


In [None]:
# stop the container
!docker stop $(docker ps -q --filter ancestor=supervisor-lambda)
!docker ps --filter ancestor=supervisor-lambda

In [11]:
## Create ECR repository for the supervisor Lambda image (if not already created in 1_environmentSetup.ipynb)
#!aws ecr create-repository --repository-name automatedinsights/lambda_supervisor --region {REGION} --profile {SESSION_PROFILE}

In [None]:
# Upload docker image to ECR
!aws ecr get-login-password --region {REGION} --profile {SESSION_PROFILE} | docker login --username AWS --password-stdin {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com
!docker tag supervisor-lambda:latest {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_supervisor:latest
!docker push {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_supervisor:latest    

In [None]:
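# Look up the alias ARNs of the DataEngineer, BusinessAnalyst and DataScientist agents (skipping the TSTALIASID test alias)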
response = bedrock_agent_client.list_agents()
for agent in response.get('agentSummaries', []):
    if "DataEngineer" in agent.get('agentName') or "BusinessAnalyst" in agent.get('agentName') or "DataScientist" in agent.get('agentName'):
        print("-" * 40)
        agent_aliases = bedrock_agent_client.list_agent_aliases(agentId=agent.get('agentId'))
        for alias in agent_aliases.get('agentAliasSummaries', []):
            if alias.get('agentAliasId') not in ['TSTALIASID']: 
                if "DataEngineer" in agent.get('agentName'):
                    data_engineer_agent_alias_arn = f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:agent-alias/{agent.get('agentId')}/{alias.get('agentAliasId')}"
                    print(f"data_engineer_agent_alias_arn: {data_engineer_agent_alias_arn}")
                elif "BusinessAnalyst" in agent.get('agentName'):
                    business_analyst_agent_alias_arn = f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:agent-alias/{agent.get('agentId')}/{alias.get('agentAliasId')}"
                    print(f"business_analyst_agent_alias_arn: {business_analyst_agent_alias_arn}")
                elif "DataScientist" in agent.get('agentName'):
                    data_scientist_agent_alias_arn = f"arn:aws:bedrock:{REGION}:{AWS_ACCOUNT}:agent-alias/{agent.get('agentId')}/{alias.get('agentAliasId')}"
                    print(f"data_scientist_agent_alias_arn: {data_scientist_agent_alias_arn}")
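The loop above builds the same alias ARN f-string three times, in the format `arn:aws:bedrock:{region}:{account}:agent-alias/{agent-id}/{alias-id}` noted later in this notebook. A minimal sketch of a helper that builds and validates that format is shown below; `build_agent_alias_arn`, `parse_agent_alias_arn` and `AgentAliasRef` are hypothetical names, and the pattern assumes the standard `aws` partition and 12-digit account IDs.

```python
import re
from typing import NamedTuple

# Assumed shape: arn:aws:bedrock:{region}:{account}:agent-alias/{agent-id}/{alias-id}
_AGENT_ALIAS_ARN = re.compile(
    r"^arn:aws:bedrock:(?P<region>[a-z0-9-]+):(?P<account>\d{12}):"
    r"agent-alias/(?P<agent_id>[A-Za-z0-9]+)/(?P<alias_id>[A-Za-z0-9]+)$"
)


class AgentAliasRef(NamedTuple):
    region: str
    account: str
    agent_id: str
    alias_id: str


def build_agent_alias_arn(region: str, account: str, agent_id: str, alias_id: str) -> str:
    """Build an agent alias ARN and check that it has the expected shape."""
    arn = f"arn:aws:bedrock:{region}:{account}:agent-alias/{agent_id}/{alias_id}"
    if not _AGENT_ALIAS_ARN.match(arn):
        raise ValueError(f"Unexpected agent alias ARN: {arn}")
    return arn


def parse_agent_alias_arn(arn: str) -> AgentAliasRef:
    """Split an agent alias ARN back into its components."""
    match = _AGENT_ALIAS_ARN.match(arn)
    if match is None:
        raise ValueError(f"Not an agent alias ARN: {arn}")
    return AgentAliasRef(**match.groupdict())
```

Inside the loop, each assignment could then read `build_agent_alias_arn(REGION, AWS_ACCOUNT, agent.get('agentId'), alias.get('agentAliasId'))`, so a malformed ID fails immediately instead of surfacing later when the sub-agents are associated.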


In [None]:
import logging
import random
import string
import json
from utils.bedrock_agent import BedrockAgentScenarioWrapper

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
agent_name = "AI_ML_team"
MODEL_ID = "us.anthropic.claude-3-5-sonnet-20240620-v1:0"  # cross-region inference profile ID for the supervisor agent


# s3://{S3_BUCKET_NAME}/uploads/CustomerFeedback.csv,
# s3://{S3_BUCKET_NAME}/uploads/CustomerPreferences.csv, 
# s3://{S3_BUCKET_NAME}/uploads/Customers.csv, 
# s3://{S3_BUCKET_NAME}/uploads/Interactions.csv, 
# s3://{S3_BUCKET_NAME}/uploads/LoyaltyProgram.csv, 
# s3://{S3_BUCKET_NAME}/uploads/MarketingCampaigns.csv, 
# s3://{S3_BUCKET_NAME}/uploads/OrderItems.csv, 
# s3://{S3_BUCKET_NAME}/uploads/Orders.csv, 
# s3://{S3_BUCKET_NAME}/uploads/Products.csv, 
# s3://{S3_BUCKET_NAME}/uploads/SupportTickets.csv

prompt = f"""Inputs:
Here is the Athena database name: "{S3_BUCKET_NAME.replace('-', '_')}" 
And here is the list of all datasets that need to be processed: 
s3://{S3_BUCKET_NAME}/uploads/Orders.csv
 

Return the final response in XML format and nothing else."""

print(f"prompt: {prompt}")

instruction = """

## Core Role Definition
You are a strategic AI Team Manager who orchestrates AI/ML use cases from concept to execution, serving as the central coordinator between business needs, data preparation, and technical implementation.

## Key Responsibilities

### Data Management
- Ensure all ML datasets have clearly defined locations and target column names before proceeding with any modeling
- Verify that complete dataset information (not just SQL queries) is provided to the DataScientistAgent
- Enforce proper storage organization: models in `models/` directory and predictions in `results/` directory
- Validate data quality and completeness before initiating any ML processes

### Workflow Optimization
- Maintain a knowledge base of lessons learned from previous runs and user feedback
- Implement a structured approach to use case identification and execution
- Prioritize use cases based on business value and technical feasibility

### Process Execution
1. **Data Preparation Phase**:
   - If datasets are provided, prepare them in Amazon Athena first
   - Review Athena database schema thoroughly before proceeding
   - Ensure all SQL follows Amazon Athena syntax standards

2. **Use Case Identification Phase**:
   - Analyze available data to identify potential ML use cases
   - If no datasets are specified or all datasets are processed, identify ML use cases from existing Athena database

3. **Model Development Phase**:
   - For each identified use case:
     - Prepare training dataset with appropriate features
     - Train ML model with comprehensive evaluation metrics on training dataset
     - Document model accuracy and feature importance
     - Generate and store predictions on test dataset
     - Capture DataScientist commentary on model performance

### Output Requirements
Format the final response in XML with the following structure:
```xml
<FinalResponse>
<UseCases>
    <UseCase>
        <Name>Use Case Name</Name>
        <Description>Detailed description</Description>
        <ModelDetails>
            <TargetColumn>
                <Name>target_column_name</Name>
                <Definition>Clear definition of what the target represents</Definition>
            </TargetColumn>
            <TrainingDataLocation>Full path to training data</TrainingDataLocation>
            <TestDataLocation>Full path to test data</TestDataLocation>
            <ModelLocation>Full path to stored model</ModelLocation>
            <MLDatasetLocation>Full path to complete ML dataset</MLDatasetLocation>
            <MLDatasetSQLQuery>Complete SQL query used</MLDatasetSQLQuery>
            <Accuracy>Numerical accuracy metric</Accuracy>
            <FeatureImportances>
                <!-- Detailed feature importance list -->
            </FeatureImportances>
            <Sample-PredictedValues>
                <!-- Representative sample of predictions -->
            </Sample-PredictedValues>
        </ModelDetails>
        <PredictionDataLocation>Full path to predictions file</PredictionDataLocation>
        <DataScientistCommentary>Expert analysis of model performance and limitations</DataScientistCommentary>
    </UseCase>
    <!-- Additional use cases as needed -->
</UseCases>
</FinalResponse>
```

If no viable use cases are identified, return: `<UseCases>None</UseCases>`

Always return the complete XML structure with all required elements and nothing else."""

postfix = "".join(
    random.choice(string.ascii_lowercase + "0123456789") for _ in range(8)
)

agent_name = agent_name + "_" + postfix

IMAGE_URI = f'{AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_supervisor:latest'


promptOverrideConfiguration = None

agentCollaboration = 'SUPERVISOR'  # other options: 'SUPERVISOR_ROUTER' | 'DISABLED'
# The ARN should be in format: arn:aws:bedrock:{region}:{account}:agent-alias/{agent-id}/{alias-id}

sub_agents_list = [
    {
        'sub_agent_alias_arn': data_engineer_agent_alias_arn,
        'sub_agent_instruction': """You can invoke the DataEngineerAgent agent when you have been given new datasets that have been staged in Amazon S3 and need to be made available in Amazon Athena as a table so that they can be queried by the business analyst agent. The DataEngineerAgent agent generates semantic type descriptions for each column in the table/file, creates a SQL table definition, and then creates a respective Amazon Athena table, which can be queried.""",
        'sub_agent_association_name': 'DataEngineerAgent',
        'relay_conversation_history': 'TO_COLLABORATOR'
    },
    {
        'sub_agent_alias_arn': business_analyst_agent_alias_arn,
        'sub_agent_instruction': """You can invoke the BusinessAnalystAgent agent to review the data in the Amazon Athena database in order to identify AI/ML use cases that can be performed on the data. It returns any identified ML use cases with details such as the ML use case description, use case business value, and the ML training dataset location and target column name, which can be used by the DataScientistAgent agent.""",
        'sub_agent_association_name': 'BusinessAnalystAgent',
        'relay_conversation_history': 'TO_COLLABORATOR'
    },
    {
        'sub_agent_alias_arn': data_scientist_agent_alias_arn,
        'sub_agent_instruction': """If you have an ML dataset and target column, then you can invoke the DataScientistAgent agent for machine learning tasks such as preparing the data to train a model, training an ML model or using a trained model to generate predictions for a given dataset and target column. The train method of the agent takes in a DataLocation from the BusinessAnalystAgent agent, and returns the trained ML model location, along with details on its accuracy and feature importance. The predict method takes a dataset and target, and generates predictions for it.""",
        'sub_agent_association_name': 'DataScientistAgent',
        'relay_conversation_history': 'TO_COLLABORATOR'
    }
]

lambda_environment_variables = {
    "S3_BUCKET_NAME": S3_BUCKET_NAME,
    "MODEL_ID": MODEL_ID,
    "USER_FEEDBACK_TABLE_NAME": USER_FEEDBACK_TABLE_NAME,
    "TRACE_TABLE_NAME": TRACE_TABLE_NAME
}

scenario = BedrockAgentScenarioWrapper(
    bedrock_agent_client=bedrock_agent_client,
    runtime_client=bedrock_agent_runtime_client,
    lambda_client=lambda_client,
    iam_resource=iam_resource,
    postfix=postfix,
    agent_name=agent_name,
    model_id=MODEL_ID,
    sub_agents_list=sub_agents_list,
    prompt=prompt,
    action_group_schema_path="action_groups/supervisor_open_api_schema.yml",
    lambda_image_uri=IMAGE_URI,
    lambda_environment_variables=lambda_environment_variables,
    instruction=instruction,
    agentCollaboration=agentCollaboration,
    promptOverrideConfiguration=promptOverrideConfiguration
)
try:
    response, trace_events = scenario.run_scenario()
except Exception as e:
    logging.exception(f"Something went wrong: {e}")

In [None]:
# prompt = f"""Inputs:
# Here is the Athena database name: "{S3_BUCKET_NAME.replace('-', '_')}" 
# And here is the list of all datasets that need to be processed: 
# s3://{S3_BUCKET_NAME}/uploads/Orders.csv

# Return the final response in XML format and nothing else."""

# scenario.prompt = prompt
# print(scenario.prompt)
# scenario.chat_with_agent()
# prompt

## Multi-Agent Orchestration Evaluation

Sources:

1) Raphael Shu, Nilaksh Das, Michelle Yuan, Monica Sunkara, and Yi Zhang: [Towards Effective GenAI Multi-Agent Collaboration: Design and Evaluation for Enterprise Applications](https://arxiv.org/abs/2412.05449)

2) AWS Machine Learning Blog: [Unlocking complex problem-solving with multi-agent collaboration on Amazon Bedrock](https://aws.amazon.com/blogs/machine-learning/unlocking-complex-problem-solving-with-multi-agent-collaboration-on-amazon-bedrock/)



### Metric Definitions and Implementation

#### Success Metrics

- Overall GSR: Overall goal success rate covering both the user side and the system side. Implementation: use an LLM to judge the user-side and system-side assertions. For a conversation, the score is 1 if all assertions are True; otherwise 0.

- Supervisor GSR: Goal success rate of the supervisor agent, independent of sub-agent and tool behavior. Implementation: if the overall score for the conversation is 1, or the supervisor agent is judged to have behaved reliably (so that any failure is attributable to sub-agents or tools), the score is 1; otherwise 0.

- User-side GSR: Goal success rate from the user's perspective. Implementation: use an LLM to judge the user-side assertions. For a conversation, the score is 1 if all user-side assertions are True; otherwise 0.

- System-side GSR: Goal success rate from the system developers' perspective. Implementation: use an LLM to judge the system-side assertions. For a conversation, the score is 1 if all system-side assertions are True; otherwise 0.
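Once an LLM judge has returned a True/False verdict per assertion, the four success rates reduce to simple aggregation. The sketch below is an illustration under assumptions, not the paper's or this repository's implementation: `ConversationJudgement`, its fields, and the `supervisor_reliable` flag are hypothetical stand-ins for whatever the judge actually returns.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ConversationJudgement:
    # Hypothetical container for LLM-judge verdicts, one bool per assertion
    user_side: List[bool]
    system_side: List[bool]
    # Hypothetical LLM verdict: did the supervisor itself behave correctly?
    supervisor_reliable: bool = False


def _rate(scores: List[int]) -> float:
    """Fraction of conversations that scored 1."""
    return sum(scores) / len(scores) if scores else 0.0


def goal_success_rates(judgements: List[ConversationJudgement]) -> Dict[str, float]:
    # Per-conversation 0/1 scores; note all([]) is True, so an empty assertion list scores 1
    user = [int(all(j.user_side)) for j in judgements]
    system = [int(all(j.system_side)) for j in judgements]
    # Overall: every user-side and every system-side assertion is True
    overall = [u & s for u, s in zip(user, system)]
    # Supervisor: overall success, or the supervisor was judged reliable
    supervisor = [int(o == 1 or j.supervisor_reliable) for o, j in zip(overall, judgements)]
    return {
        "overall_gsr": _rate(overall),
        "supervisor_gsr": _rate(supervisor),
        "user_side_gsr": _rate(user),
        "system_side_gsr": _rate(system),
    }
```

Because the supervisor score is an OR over the overall score, Supervisor GSR is always at least as high as Overall GSR; the gap between the two indicates how many failures were attributed to sub-agents or tools.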


#### Latency Metrics (TBD: Bedrock Agents does not yet return timestamps or durations in traces)

- Avg. communication overhead per turn: Average number of seconds that the supervisor agent spends communicating with other agents before getting back to the user. This time excludes the time spent by agents other than the supervisor agent.

- Avg. latency per communication: Average number of seconds that the supervisor agent spends delivering each message to another agent.

- Avg. user-perceived turn latency per session: Average number of seconds it takes the supervisor agent to get back to the user. This time includes the time spent by all agents in the system.
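If timing data becomes available, the three latency metrics could be computed along the lines of the sketch below. It assumes a hypothetical `Turn` record with a user-message timestamp, a supervisor-reply timestamp, and the supervisor's own time spent on each message to another agent; none of these fields come from current Bedrock Agents traces, and averaging per session before averaging across sessions is one reading of "per session".

```python
from dataclasses import dataclass, field
from statistics import mean
from typing import Dict, List


@dataclass
class Turn:
    # Hypothetical timestamps in seconds (not currently available in traces)
    user_message_at: float
    supervisor_reply_at: float
    # Seconds the supervisor itself spent producing each message to another agent
    communication_durations: List[float] = field(default_factory=list)


def latency_metrics(sessions: List[List[Turn]]) -> Dict[str, float]:
    """Assumes at least one session and at least one turn per session."""
    turns = [turn for session in sessions for turn in session]
    durations = [d for turn in turns for d in turn.communication_durations]
    return {
        # Supervisor-only communication time per turn (sub-agent runtime excluded)
        "avg_communication_overhead_per_turn": mean(
            sum(t.communication_durations, 0.0) for t in turns
        ),
        "avg_latency_per_communication": mean(durations) if durations else 0.0,
        # End-to-end turn latency (all agents included), averaged per session, then across sessions
        "avg_user_perceived_turn_latency_per_session": mean(
            mean(t.supervisor_reply_at - t.user_message_at for t in session)
            for session in sessions
        ),
    }
```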

#### Operations Metrics

- Avg. communications per session: Average count of messages sent by the supervisor agent over the entire session.

- Avg. output tokens per communication: Average number of total output tokens from the supervisor agent for each message.
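Both operations metrics are plain averages once the supervisor's messages and their output token counts have been extracted from the trace events. A minimal sketch, assuming that extraction has already produced one output-token count per supervisor message; `operations_metrics` is a hypothetical helper, and the session count is passed in explicitly so that sessions without any communication are still counted.

```python
from typing import Dict, List


def operations_metrics(output_tokens: List[int], num_sessions: int) -> Dict[str, float]:
    """output_tokens holds one entry per message the supervisor sent (hypothetical input)."""
    total = len(output_tokens)
    return {
        "avg_communications_per_session": total / num_sessions if num_sessions else 0.0,
        "avg_output_tokens_per_communication": sum(output_tokens) / total if total else 0.0,
    }
```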

#### Cost Metrics

- Avg. cost per turn: Average cost of the supervisor agent for each turn. (TBD)

- Avg. cost per session: Average cost of the supervisor agent for the entire session.
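Average cost per session can be derived from token usage and per-token prices. The sketch below assumes each supervisor model call yields an `(input_tokens, output_tokens)` pair and takes prices per 1,000 tokens as arguments; the function names are hypothetical, and any prices you plug in should come from the current Amazon Bedrock pricing page for the model in use.

```python
from typing import List, Tuple

Usage = Tuple[int, int]  # (input_tokens, output_tokens) for one supervisor model call


def session_cost(usages: List[Usage], input_price_per_1k: float, output_price_per_1k: float) -> float:
    """Cost of one session summed over its supervisor model calls."""
    return sum(
        inp / 1000 * input_price_per_1k + out / 1000 * output_price_per_1k
        for inp, out in usages
    )


def avg_cost_per_session(sessions: List[List[Usage]], input_price_per_1k: float,
                         output_price_per_1k: float) -> float:
    costs = [session_cost(s, input_price_per_1k, output_price_per_1k) for s in sessions]
    return sum(costs) / len(costs) if costs else 0.0
```

Only the supervisor's own calls are counted here, matching the metric definition; sub-agent costs would need the same calculation over their traces.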




## Start the MLflow server in a separate terminal

(Only needed if you are not using a SageMaker MLflow tracking server.)

mlflow server --host 127.0.0.1 --port 5000

In [8]:
# Use the local MLflow server by default. When using a SageMaker MLflow tracking server instead,
# uncomment the two lines after get_presigned_url() and remove the localhost call at the end of this cell.
import json
import subprocess

def get_presigned_url():
    cmd = f"""aws sagemaker create-presigned-mlflow-tracking-server-url \
      --tracking-server-name {MLFLOW_SERVER_NAME} \
      --session-expiration-duration-in-seconds 1800 \
      --expires-in-seconds 300 \
      --region {REGION} --profile {SESSION_PROFILE}"""
    
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode == 0:
        response = json.loads(result.stdout)
        return response["AuthorizedUrl"]
    else:
        raise Exception(f"Failed to get presigned URL: {result.stderr}")

# Get the authorized URL and set as MLflow tracking URI
# MLFLOW_TRACKING_URI = get_presigned_url()
# mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)



mlflow.set_tracking_uri("http://localhost:5000")
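The `get_presigned_url()` helper above shells out to the AWS CLI with `shell=True` and depends on `SESSION_PROFILE`. A minimal sketch of the same request through boto3's SageMaker client method `create_presigned_mlflow_tracking_server_url` (the API behind the CLI command) is shown below. The function name `get_presigned_url_boto3` is hypothetical, and the client is passed in rather than created inside, so the function can be exercised without AWS credentials.

```python
from typing import Any


def get_presigned_url_boto3(sagemaker_client: Any, tracking_server_name: str,
                            session_seconds: int = 1800, expires_seconds: int = 300) -> str:
    """boto3 counterpart of `aws sagemaker create-presigned-mlflow-tracking-server-url`."""
    response = sagemaker_client.create_presigned_mlflow_tracking_server_url(
        TrackingServerName=tracking_server_name,
        SessionExpirationDurationInSeconds=session_seconds,
        ExpiresInSeconds=expires_seconds,
    )
    return response["AuthorizedUrl"]


# Usage (requires AWS credentials):
# url = get_presigned_url_boto3(session.client("sagemaker"), MLFLOW_SERVER_NAME)
```

Reusing the notebook's `session` keeps the region and credentials consistent with the other clients created earlier, instead of relying on a separately configured CLI profile.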

In [None]:
# Create a new experiment for multi-agent collaboration if it does not exist yet
import uuid
experiment_description = "multi-agent-colloboration."

experiment_tags = {
    "project_name": "multi-agent-colloboration",
    "use_case": "multi-agent-colloboration",
    "team": "aws-ai-ml-analytics",
    "source": "multi-agent-colloboration",
    "mlflow.note.content": experiment_description,
}
random_identifier = uuid.uuid4().hex
experiment_name = f"multi-agent-colloboration_{random_identifier}"

# Use search_experiments() to search on the project_name tag key
mlflow_client = mlflow.MlflowClient()
mac_experiments = mlflow_client.search_experiments(
    filter_string="tags.`project_name` = 'multi-agent-colloboration'"
)
print(mac_experiments)
# search_experiments() returns a PagedList (a list subclass); create the experiment if none matched
if len(mac_experiments) == 0:
    # create_experiment() returns only the experiment ID, so fetch the Experiment object
    experiment_id = mlflow_client.create_experiment(name=experiment_name, tags=experiment_tags)
    mac_experiment = mlflow_client.get_experiment(experiment_id)
else:
    mac_experiment = mac_experiments[0]
mlflow.set_experiment(experiment_id=mac_experiment.experiment_id)
    

In [10]:
# prompt_template
PROMPT_TEMPLATE = """Inputs:
Here is the Athena database name: "{ATHENA_DATABASE}" 
And here is the list of all datasets that need to be processed: 
{FILE_NAME}

Return the final response in XML format and nothing else."""


In [None]:
# Check if scenario exists and get agent ID
if 'scenario' in locals() and scenario is not None and hasattr(scenario, 'agent'):
    agent = scenario.agent
    BEDROCK_AGENT_ID = agent.get('agentId')
    # save the agent id to the environment variable and to the dev.env file
    os.environ['BEDROCK_AGENT_ID'] = BEDROCK_AGENT_ID
    set_key(local_env_filename, 'BEDROCK_AGENT_ID', BEDROCK_AGENT_ID)
else:
    # Fall back to the agent ID loaded from dev.env at the top of the notebook
    print("Scenario not found, using fallback agent ID from environment variable")

print(f"AGENT_ID: {BEDROCK_AGENT_ID}")

# Get the agent alias ID (this takes the first alias returned by the API)
agent_aliases = bedrock_agent_client.list_agent_aliases(agentId=BEDROCK_AGENT_ID)

BEDROCK_AGENT_ALIAS_ID = agent_aliases.get('agentAliasSummaries')[0].get('agentAliasId')
print(f"AGENT_ALIAS_ID: {BEDROCK_AGENT_ALIAS_ID}")
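
The cell above takes the first entry of `agentAliasSummaries`. If the agent has more than one alias, that entry may be the built-in test alias (`TSTALIASID`, which points at the DRAFT version) rather than a published alias. Below is a hedged sketch of a more deliberate choice; it assumes each summary is a dict with the `agentAliasId`, `agentAliasStatus` and `updatedAt` keys returned by `list_agent_aliases`, and the function name `pick_alias_id` is hypothetical.

```python
from typing import Any, Dict, List

# Bedrock Agents' built-in test alias, which routes to the DRAFT version.
TEST_ALIAS_ID = "TSTALIASID"


def pick_alias_id(summaries: List[Dict[str, Any]]) -> str:
    """Prefer the most recently updated, prepared, non-test alias."""
    if not summaries:
        raise ValueError("Agent has no aliases")
    candidates = [
        s for s in summaries
        if s.get("agentAliasId") != TEST_ALIAS_ID
        and s.get("agentAliasStatus", "PREPARED") == "PREPARED"
    ]
    if not candidates:
        # Nothing better available: keep the notebook's behaviour (first alias).
        return summaries[0]["agentAliasId"]
    # updatedAt is a datetime in boto3 responses; only sort when every
    # candidate has it, otherwise take the first candidate.
    if all("updatedAt" in s for s in candidates):
        return max(candidates, key=lambda s: s["updatedAt"])["agentAliasId"]
    return candidates[0]["agentAliasId"]
```

With it, the alias lookup could become `pick_alias_id(agent_aliases.get('agentAliasSummaries', []))`.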

In [None]:
import copy
import os
import uuid
from typing import List, Optional

import time
import mlflow
from botocore.config import Config
from mlflow.entities import SpanType
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatResponse, ChatMessage, ChatParams, ChatChoice

from mlflow.models import infer_signature
from dataclasses import dataclass

from utils.bedrock_mlflow_agent import BedrockMultiAgentModel

# Get the Bedrock agent details with boto3 and extract its foundation model ID
agent_response = bedrock_agent_client.get_agent(agentId=BEDROCK_AGENT_ID)
MAC_MODEL_ID = agent_response.get('agent', {}).get('foundationModel')

run_name = f"mac_{BEDROCK_AGENT_ID}_{MAC_MODEL_ID}"

TEMPERATURE = 0
MAXIMUM_LENGTH = 2000
print(f'MAC_MODEL_ID: {MAC_MODEL_ID}')

prompt = PROMPT_TEMPLATE.format(
    ATHENA_DATABASE=S3_BUCKET_NAME.replace('-', '_'),
    FILE_NAME=f"s3://{S3_BUCKET_NAME}/uploads/Orders.csv"
)

input_example = {
    "messages": [
        {
            "role": "user",
            "content": prompt
        }
    ]
}

with mlflow.start_run(experiment_id=mac_experiment.experiment_id, run_name=run_name):

    model_config = {
        "agents": {
            "main": {
                "model": MAC_MODEL_ID,
                "aws_region": REGION,
                "bedrock_agent_id": BEDROCK_AGENT_ID,
                "bedrock_agent_alias_id": BEDROCK_AGENT_ALIAS_ID,
                "instruction": instruction,
                "inference_configuration": {
                    "temperature": TEMPERATURE,
                    "maximumLength": MAXIMUM_LENGTH,
                },
            },  
        },
        "aws_profile": SESSION_PROFILE
    }

    # Log parameters
    mlflow.log_params({
        "model_id": MAC_MODEL_ID,
        "bedrock_agent_id": BEDROCK_AGENT_ID,
        "bedrock_agent_alias_id": BEDROCK_AGENT_ALIAS_ID,
        "aws_region": REGION,
        "instruction": model_config["agents"]["main"]["instruction"],
        "temperature": model_config["agents"]["main"]["inference_configuration"]["temperature"],
        "max_length": model_config["agents"]["main"]["inference_configuration"]["maximumLength"]
    })

    # Log any relevant tags
    mlflow.set_tags({
        "model_type": "bedrock_multi_agent_colloboration",
        "framework": "mlflow_pyfunc",
        "environment": "development"  # or "production", etc.
    })

    # Define the conda environment
    conda_env = {
        'name': 'bedrock_agent_env',
        'channels': ['defaults', 'conda-forge'],
        'dependencies': [
            'python=3.11',
            'pip',
            {'pip': [
                'boto3==1.36.2',
                'mlflow==2.16.2',
                'sagemaker-mlflow'                
            ]}
        ]
    }

    # Log the model to MLflow, then load it back as a pyfunc model
    mac_chain_info = mlflow.pyfunc.log_model(
        python_model=BedrockMultiAgentModel(),
        model_config=model_config,
        artifact_path="mac_agent",  # path within the MLflow run where the model artifacts are stored
        input_example=None,  # skip input validation; pass input_example instead to enable it
        conda_env=conda_env,  # conda environment defined above
        code_path=["./utils/bedrock_mlflow_agent.py"]
    )

    mac_agent = mlflow.pyfunc.load_model(mac_chain_info.model_uri)

    # Define chat parameters
    chat_params = ChatParams(
        temperature=float(model_config["agents"]["main"]["inference_configuration"]["temperature"]),
        max_tokens=int(model_config["agents"]["main"]["inference_configuration"]["maximumLength"]),
    )

    # List the uploaded datasets under the uploads/ prefix in S3
    response = s3_client.list_objects_v2(Bucket=S3_BUCKET_NAME, Prefix="uploads/")
    scenario = {
                "scenario": "Goals: - User needs to know if the given dataset has an AI/ML use case. Prepare the data in Amazon Athena. Then identify ML use cases for the data. Review the identified ML use cases and if you find any, use the returned DataLocation and ModelLocation to train an ML model and return its accuracy and respective feature importance. Lastly, use the trained model to generate predictions for the dataset and return the predictions.",
                "input_problem": "PLACEHOLDER",
                "assertions": [
                    "agent: GetInformationForSemanticTypeDetection is executed to gather information about the data",
                    "agent: SaveSQLTableDefinition is executed to save the SQL table definition to S3",
                    "agent: CreateAthenaTable is executed to create the Athena table",
                    "agent: QueryData is executed to verify the Athena table creation",
                    "agent: GetDatabaseSchema is executed to get the latest database schema",
                    "agent: SaveERM is executed to save the Entity Relationship Diagram in json format to S3",
                    "agent: GetDatabaseSchema is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                    "agent: GetUseCases is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                    "agent: ExecuteQuery is executed to create an ML dataset",
                    "agent: SaveDataset is executed to save the ML dataset to S3",
                    "agent: The AI/ML use cases along with the respective S3 location of the ML dataset(s) and target column name are returned in the final response",
                    "agent: Run exploratory data analysis on the data",
                    "agent: Determine the right HoldoutFrac for the ml model training",
                    "agent: Split the data into train and test sets",
                    "agent: Train a new model with the specified details",
                    "agent: The model is trained successfully",
                    "agent: The model location and feature importance are returned in the final response"
                ]
            }
    scenarios_dict = {"scenarios": []}

    random_identifier = uuid.uuid4().hex
    
    # iterate through the files and create an input for each file with input_example as a base 
    for item in response.get('Contents', []):
        file_key = item['Key']
        print(f"file_key: {file_key}")
        # Skip if it's just the directory itself
        if file_key.endswith('/'):
            continue

        # Only process the Customers and Orders files
        if not file_key.startswith(('uploads/Customers.csv', 'uploads/Orders.csv')):
            continue

        print(f"Processing file: {file_key}")
        prompt = PROMPT_TEMPLATE.format(
            ATHENA_DATABASE=S3_BUCKET_NAME.replace('-', '_'),
            FILE_NAME=f"s3://{S3_BUCKET_NAME}/{file_key}"
        )
        print(f"prompt: {prompt}")

        input_example["messages"][0]["content"] = prompt
        scenario["input_problem"] = prompt
        scenarios_dict["scenarios"].append(scenario.copy())

        # Use a separate name so the S3 listing stored in `response` is not shadowed
        agent_output = mac_agent.predict(input_example, params=chat_params)
        print(agent_output)
        
        path = f"../data/eval_dataset/conversations/{random_identifier}"
        print(f"path: {path}")
        # Create the output directory if it does not exist
        os.makedirs(path, exist_ok=True)
        
        # read all files that start with conversation in this directory and determine the highest index
        files = os.listdir(path)
        highest_index = 0
        for file in files:
            if file.startswith("conversation"):
                index = int(file.split("_")[1].replace(".json",""))
                if index > highest_index:
                    highest_index = index
        
        # write metrics to file
        with open(f"{path}/metrics_{highest_index+1}.json", "w") as f:
            json.dump(mac_agent._model_impl.chat_model._metrics, f)
        
        # log metrics to mlflow
        mlflow.log_table(mac_agent._model_impl.chat_model._metrics, f"metrics_{highest_index+1}.json", run_id=mlflow.active_run().info.run_id)

        # write conversation_dict to file
        with open(f"{path}/conversation_{highest_index+1}.json", "w") as f:
            json.dump(mac_agent._model_impl.chat_model._conversation_json_dict, f)
    
        # overwrite scenarios to json file so that we always have the latest scenarios
        with open(f"{path}/scenarios.json", "w") as f:
            json.dump(scenarios_dict, f, indent=4)

        # add a delay of 1.5 minutes to avoid 424 errors
        time.sleep(90)
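
The per-file loop above does two pieces of bookkeeping inline: choosing which S3 keys to process and picking the next `conversation_N.json` index. A hedged sketch of both as small pure functions follows; the names `select_input_keys` and `next_conversation_index` are hypothetical. Two deliberate differences from the notebook code: keys are matched exactly rather than by prefix (so a key like `uploads/Orders.csv.bak` would be skipped), and file names that do not match `conversation_<N>.json` are ignored, whereas the `split("_")` parsing could fail on names such as `conversation_final.json`.

```python
import re
from typing import Iterable, List

# Matches files written by the loop, e.g. "conversation_3.json".
CONVERSATION_FILE = re.compile(r"^conversation_(\d+)\.json$")


def select_input_keys(keys: Iterable[str], allowed: Iterable[str]) -> List[str]:
    """Keep S3 keys that are real objects (not folder markers) and are allow-listed."""
    allowed_set = set(allowed)
    return [k for k in keys if not k.endswith("/") and k in allowed_set]


def next_conversation_index(filenames: Iterable[str]) -> int:
    """Return 1 + the highest N among conversation_N.json files, or 1 if none exist."""
    indices = [
        int(m.group(1))
        for name in filenames
        if (m := CONVERSATION_FILE.match(name))
    ]
    return max(indices, default=0) + 1
```

In the loop, `select_input_keys([item['Key'] for item in response.get('Contents', [])], ['uploads/Customers.csv', 'uploads/Orders.csv'])` could replace the two `continue` checks, and `next_conversation_index(os.listdir(path))` could stand in for `highest_index + 1`.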

In [26]:
DATA_ENGINEER_AGENT_ID = data_engineer_agent_alias_arn.split('/')[-2]
BUSINESS_ANALYST_AGENT_ID = business_analyst_agent_alias_arn.split('/')[-2]
DATA_SCIENTIST_AGENT_ID = data_scientist_agent_alias_arn.split('/')[-2]

agent_config = {
    "agents": [
        {
            "agent_id": BEDROCK_AGENT_ID,
            "agent_name": "AI_ML_Team_Supervisor",
            "agent_instruction": instruction,
            "tools": [
                {
                    "tool_name": "SupervisorAPI",
                    "name": "SupervisorAPI",
                    "description": "Supervisor agent to analyze feedback and improve agent performance",
                    "actions": [
                        {
                            "name": "GetLessonsLearnedFromPastRuns",
                            "description": "Retrieve lessons learned from user feedback and trace table",
                            "output_schema": {
                                "data_type": "object",
                                "properties": {
                                    "body": {
                                        "data_type": "string",
                                        "description": "List of lessons learned in bullet point format"
                                    }
                                }
                            },
                            "requires_confirmation": False,
                            "meta": {}
                        }
                    ],
                    "tool_type": "Module",
                    "meta": {}
                }
            ],
            "reachable_agents": [
                {
                    "scenario": "You can invoke the DataEngineerAgent agent when you have been given new datasets that have been staged in Amazon S3 and need to be made available in Amazon Athena as a table so that they can be queried by the business analyst agent. The DataEngineerAgent agent generates semantic type descriptions for each column in the table/file, creates a SQL table definition, and then creates a respective Amazon Athena table, which can be queried.",
                    "agent_id": DATA_ENGINEER_AGENT_ID,
                    "context_sharing": True
                },
                {
                    "scenario": "You can invoke the BusinessAnalystAgent agent to review the data in the Amazon Athena database in order to identify AI/ML use cases that can be performed on the data. It returns any identified ML use cases with details such as the ML use case description, use case business value, and the ML training dataset location and target column name which can be used by the DataScientistAgent agent.",
                    "agent_id": BUSINESS_ANALYST_AGENT_ID,
                    "context_sharing": True
                },
                {
                    "scenario": "If you have an ML dataset and target column, then you can invoke the DataScientistAgent agent for machine learning tasks such as preparing the data to train a model, training an ML model or using a trained model to generate predictions for a given dataset and target column. The train method of the agent takes in a DataLocation from the BusinessAnalystAgent agent, and returns the trained ML model location, along with details on its accuracy and feature importance. The predict method takes a dataset and target, and generates predictions for it.",
                    "agent_id": DATA_SCIENTIST_AGENT_ID,
                    "context_sharing": True
                }
            ]
        },
        {
            "agent_id": DATA_ENGINEER_AGENT_ID,
            "agent_name": "DataEngineer",
            "agent_instruction": """Expert Data Engineer Agent
You are an expert data engineer with access to a comprehensive set of data preparation and management tools. Your role is to help users process, analyze, and query data efficiently.

Available Tools
GetDatabaseSchema: Retrieves the current SQL database schema from S3.
GetERM: Retrieves the Entity Relationship Diagram in JSON format.
SaveERM: Saves an Entity Relationship Diagram in JSON format to S3.
GetInformationForSemanticTypeDetection: Analyzes data to help detect semantic types.
SaveSQLTableDefinition: Saves SQL table definition to S3.
CreateAthenaTable: Creates an Athena table based on data and Athena table create definition.
QueryData: Executes SQL queries against Athena databases.

Workflow
1) Data Analysis: When presented with new data, use GetInformationForSemanticTypeDetection to analyze the data structure and content. Identify semantic types for each column and 3 ML use cases where these semantic types could be used.
2) Table Schema Definition: Create appropriate SQL table definition with semantic column names based on the data analysis, considering semantic types and appropriate data types. 
3) Table Creation: Use the Table Schema Definition as a blueprint for the Athena table creation with the tool CreateAthenaTable to make the data available for querying.
4) Data Querying: Verify successful table creation by querying the data using the QueryData tool.
5) Entity Relationship Diagram: Get the latest SQL schema with the GetDatabaseSchema tool.
Then generate a new entity relationship diagram and save it with the tool SaveERM.

Important Guidelines
If you receive multiple datasets as input, process each one methodically and verify that all files have been processed before providing your final response.
When creating table definitions, include SQL comments that explain the semantic type of each column and its potential use cases.
For primary keys, add specific comments explaining the primary key constraint.
Always ensure your SQL follows the appropriate format for the target system (ANSI SQL or Athena SQL).""",
            "tools": [
                        {
                            "tool_name": "DataEngineerAPI",
                            "name": "DataEngineerAPI",
                            "description": "Data Preparation to make data available for AI/ML and Analytics workloads",
                            "actions": [
                                {
                                    "name": "GetDatabaseSchema",
                                    "description": "Retrieve the SQL database schema from S3",
                                    "output_schema": {
                                        "data_type": "array",
                                        "items": {
                                            "type": "string",
                                            "description": "SQL table definition statements"
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "GetERM",
                                    "description": "Retrieve the Entity Relationship Diagram in json format",
                                    "output_schema": {
                                        "data_type": "array",
                                        "items": {
                                            "type": "string",
                                            "description": "Entity Relationship Diagram in json format"
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "SaveERM",
                                    "description": "Save the Entity Relationship Diagram in json format to S3",
                                    "input_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "ERMData": {
                                                "data_type": "string",
                                                "description": "Entity Relationship Diagram in json format"
                                            }
                                        },
                                        "required": [
                                            "ERMData"
                                        ]
                                    },
                                    "output_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "erm_file_location": {
                                                "data_type": "string",
                                                "description": "S3 location where the ERM was saved"
                                            }
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "GetInformationForSemanticTypeDetection",
                                    "description": "Get information for semantic type detection",
                                    "input_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "DataLocation": {
                                                "data_type": "string",
                                                "description": "S3 location of the data to analyze"
                                            }
                                        },
                                        "required": [
                                            "DataLocation"
                                        ]
                                    },
                                    "output_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "column_names": {
                                                "data_type": "array",
                                                "items": {
                                                    "type": "string"
                                                }
                                            },
                                            "data_sample": {
                                                "data_type": "object"
                                            }
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "SaveSQLTableDefinition",
                                    "description": "Save SQL table definition to S3",
                                    "input_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "SQL_Table_Definition": {
                                                "data_type": "string",
                                                "description": "SQL table definition"
                                            },
                                            "TableName": {
                                                "data_type": "string",
                                                "description": "File or table name"
                                            }
                                        },
                                        "required": [
                                            "SQL_Table_Definition"
                                        ]
                                    },
                                    "output_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "sql_table_definition": {
                                                "data_type": "string"
                                            },
                                            "sql_table_definition_file_location": {
                                                "data_type": "string"
                                            }
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "CreateAthenaTable",
                                    "description": "Create an Athena table based on data and semantic types",
                                    "input_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "DataLocation": {
                                                "data_type": "string",
                                                "description": "S3 location of the data"
                                            },
                                            "AthenaDatabase": {
                                                "data_type": "string",
                                                "description": "The Athena database to create the table in"
                                            },
                                            "Athena_Table_Create_SQL_statement": {
                                                "data_type": "string",
                                                "description": "Athena SQL table create statement with LOCATION DataLocation"
                                            },
                                            "TableName": {
                                                "data_type": "string",
                                                "description": "Name of the table to create"
                                            }
                                        },
                                        "required": [
                                            "AthenaDatabase",
                                            "TableName",
                                            "Athena_Table_Create_SQL_statement",
                                            "DataLocation"
                                        ]
                                    },
                                    "output_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "status": {
                                                "data_type": "string"
                                            },
                                            "table_name": {
                                                "data_type": "string"
                                            },
                                            "table_location": {
                                                "data_type": "string"
                                            },
                                            "table_definition": {
                                                "data_type": "string"
                                            },
                                            "query_results": {
                                                "data_type": "object"
                                            }
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                },
                                {
                                    "name": "QueryData",
                                    "description": "Execute SQL query against Athena database",
                                    "input_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "AthenaDatabase": {
                                                "data_type": "string",
                                                "description": "The Athena database to query"
                                            },
                                            "SQLQuery": {
                                                "data_type": "string",
                                                "description": "SQL query to execute"
                                            }
                                        },
                                        "required": [
                                            "AthenaDatabase",
                                            "SQLQuery"
                                        ]
                                    },
                                    "output_schema": {
                                        "data_type": "object",
                                        "properties": {
                                            "status": {
                                                "data_type": "string"
                                            },
                                            "query": {
                                                "data_type": "string"
                                            },
                                            "data": {
                                                "data_type": "object"
                                            }
                                        }
                                    },
                                    "requires_confirmation": False,
                                    "meta": {}
                                }
                            ],
                            "tool_type": "Module",
                            "meta": {}
                        }
                    ],
            "reachable_agents": []
        },
        {
            "agent_id": BUSINESS_ANALYST_AGENT_ID,
            "agent_name": "BusinessAnalyst",
            "agent_instruction": """You are an expert business analyst. 
You have access to a set of tools to review SQL database schema, run SQL queries against the Athena database, and go through a list of ML use cases and determine whether the AI/ML use case can be performed on the data. 
Identify one high-value use case, focusing on classification and regression ML problem statements.
For each viable AI/ML use case you have identified, provide a comprehensive analysis that includes:
- ML Use Case Name
- Description
- Business Justification
- Target Column Specification (e.g., in a customer churn example, a customer who has been inactive for 6 months)
- ML dataset location
- Athena SQL query that successfully generated the ML dataset that includes the target column

Execute the generated Athena SQL query and ensure that it is valid SQL that can be executed in Amazon Athena. 
If you encounter any errors, review the Athena error message and correct the SQL query accordingly. 
If you cannot resolve the Athena error after 3 attempts, eliminate the ML use case and continue with the other use cases.

Before your final response, pause and verify that your final response includes the ML dataset location.""",
            "tools": [
                {
                    "tool_name": "BusinessAnalystAPI",
                    "name": "BusinessAnalystAPI",
                    "description": "API for database analysis and ML dataset preparation",
                    "actions": [
                        {
                            "name": "GetDatabaseSchema",
                            "description": "Retrieve the SQL database schema from S3",
                            "input_schema": {
                                "data_type": "object",
                                "properties": {
                                    "AthenaDatabase": {
                                        "data_type": "string",
                                        "description": "The Athena database name",
                                        "required": []
                                    }
                                },
                                "required": [
                                    "AthenaDatabase"
                                ]
                            },
                            "output_schema": {
                                "data_type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "table_name": {
                                            "type": "string"
                                        },
                                        "columns": {
                                            "type": "array",
                                            "items": {
                                                "type": "object",
                                                "properties": {
                                                    "name": {
                                                        "type": "string"
                                                    },
                                                    "type": {
                                                        "type": "string"
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            },
                            "requires_confirmation": False,
                            "meta": {}
                        },
                        {
                            "name": "GetUseCases",
                            "description": "Retrieve available AI/ML use cases from S3",
                            "output_schema": {
                                "data_type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "name": {
                                            "type": "string",
                                            "description": "Name of the use case"
                                        },
                                        "description": {
                                            "type": "string",
                                            "description": "Detailed description of the use case"
                                        },
                                        "required_columns": {
                                            "type": "array",
                                            "items": {
                                                "type": "string"
                                            },
                                            "description": "Required columns for this use case"
                                        }
                                    }
                                }
                            },
                            "requires_confirmation": False,
                            "meta": {}
                        },
                        {
                            "name": "ExecuteQuery",
                            "description": "Execute an Athena query, save results, and return samples",
                            "input_schema": {
                                "data_type": "object",
                                "properties": {
                                    "AthenaDatabase": {
                                        "data_type": "string",
                                        "description": "The Athena database name",
                                        "required": []
                                    },
                                    "Query": {
                                        "data_type": "string",
                                        "description": "The SQL query to execute",
                                        "required": []
                                    },
                                    "UseCaseName": {
                                        "data_type": "string",
                                        "description": "Optional use case name to save full results",
                                        "required": []
                                    }
                                },
                                "required": [
                                    "AthenaDatabase",
                                    "Query"
                                ]
                            },
                            "output_schema": {
                                "data_type": "object",
                                "properties": {
                                    "status": {
                                        "type": "string",
                                        "enum": ["SUCCEEDED", "FAILED", "CANCELLED", "ERROR"]
                                    },
                                    "total_rows": {
                                        "type": "integer",
                                        "description": "Total number of rows in the full result"
                                    },
                                    "total_columns": {
                                        "type": "integer",
                                        "description": "Total number of columns"
                                    },
                                    "columns": {
                                        "type": "array",
                                        "items": {
                                            "type": "string"
                                        },
                                        "description": "List of column names"
                                    },
                                    "sample_data": {
                                        "type": "array",
                                        "items": {
                                            "type": "object",
                                            "additionalProperties": True
                                        },
                                        "description": "Sample rows from the query results"
                                    },
                                    "dataset_location": {
                                        "type": "string",
                                        "description": "S3 location of saved dataset (if use case provided)"
                                    },
                                    "error": {
                                        "type": "string",
                                        "description": "Detailed error message if query failed"
                                    }
                                }
                            },
                            "requires_confirmation": False,
                            "meta": {}
                        },
                        {
                            "name": "SaveDataset",
                            "description": "Save a dataset to S3",
                            "input_schema": {
                                "data_type": "object",
                                "properties": {
                                    "UseCaseName": {
                                        "data_type": "string",
                                        "description": "Name of the use case for the dataset",
                                        "required": []
                                    },
                                    "Data": {
                                        "data_type": "array",
                                        "items": {
                                            "type": "object",
                                            "additionalProperties": True
                                        },
                                        "description": "Dataset to save as array of records",
                                        "required": []
                                    }
                                },
                                "required": [
                                    "UseCaseName",
                                    "Data"
                                ]
                            },
                            "output_schema": {
                                "data_type": "object",
                                "properties": {
                                    "location": {
                                        "type": "string",
                                        "description": "S3 location where the dataset was saved"
                                    }
                                }
                            },
                            "requires_confirmation": False,
                            "meta": {}
                        }
                    ],
                    "tool_type": "Module",
                    "meta": {}
                }
            ],
            "reachable_agents": []
        },
        {
            "agent_id": DATA_SCIENTIST_AGENT_ID,
            "agent_name": "DataScientist",
            "agent_instruction": """ROLE: Expert Data Scientist with AutoML
CAPABILITIES:
1) Exploratory data analysis with the ExploratoryDataAnalysis function  
ExploratoryDataAnalysis(
APIPath: "ExploratoryDataAnalysis"
DataLocation: [path_to_data]) Returns: { summary: string, description: string, visualizations: object }

2) Create a train test split with the TrainTestSplit function
TrainTestSplit(
APIPath: "TrainTestSplit"
DataLocation: [path_to_data]
HoldoutFrac: [fraction_of_data_to_holdout_for_testing]
) Returns: { train_data_location: string, test_data_location: string }

3) Machine learning model training with the Train function
Train(
APIPath: "Train"
Target: [target_variable]
TrainDataLocation: [train_data_location]
TestDataLocation: [test_data_location]
ModelLocation: [path_to_store_model]
 ) Returns: { model_location: string, accuracy_metrics: object, feature_importance: object }

4) Prediction generation with the Predict function
Predict(
APIPath: "Predict"
DataLocation: [path_to_data]
ModelLocation: [path_to_model] 
ResultDataLocation: [path_to_store_predictions_data]) Returns: sample predictions and location of generated predictions

WORKFLOW:
First verify that you have received an ML dataset location and a target column name. If you only got a SQL query, respond that you do not have access to Athena and that you need the dataset to be stored in S3.
Then use exploratory data analysis to gather more details on the ML dataset.
Use this information to determine the right HoldoutFrac for the ML model training.
Then use the TrainTestSplit function to split the data into train and test sets.

If training:
Ensure that the ModelLocation is constructed in the "models/" subdirectory, has a subfolder that indicates the ML use case, and that the actual filename is "model.zip".
Example ModelLocation: /models/customer_churn/model.zip
Execute Train() function
Return ML model location and ALL feature importance values.

If predicting:
Ensure that the ResultDataLocation is constructed in the "results/" subdirectory, has a filename that indicates the ML use case, and has a CSV file extension.
Example ResultDataLocation: /results/customer_churn.csv
Execute Predict() function
Return sample predictions and predictions dataset location.

ERROR HANDLING:
Report missing or invalid parameters
Alert insufficient data quality/quantity
Notify of model compatibility issues

Ensure that any created ML model is stored in the "models/" subdirectory, and any generated predictions are stored in the "results/" subdirectory.""",
            "tools": [
        {
            "tool_name": "DataScientistAPI",
            "name": "DataScientistAPI",
            "description": "AutoML to train a ML model, use a ML model to make predictions, or get feature importance",
            "actions": [
                {
                    "name": "Train",
                    "description": "Train a machine learning model",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "Target": {
                                "data_type": "string",
                                "title": "Target",
                                "description": "The target column to predict",
                                "required": []
                            },
                            "DataLocation": {
                                "data_type": "string",
                                "title": "DataLocation",
                                "description": "S3 location of the data to train the model on",
                                "required": []
                            },
                            "ModelLocation": {
                                "data_type": "string",
                                "title": "ModelLocation",
                                "description": "S3 location to store the model",
                                "required": []
                            }
                        },
                        "required": [
                            "Target",
                            "DataLocation",
                            "ModelLocation"
                        ]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "title": "200",
                        "description": "Successful operation",
                        "properties": {
                            "description": {
                                "data_type": "string",
                                "required": []
                            },
                            "content": {
                                "data_type": "object",
                                "properties": {
                                    "message": {
                                        "data_type": "string",
                                        "required": []
                                    },
                                    "results": {
                                        "data_type": "string",
                                        "required": []
                                    }
                                }
                            }
                        },
                        "required": []
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "Predict",
                    "description": "Make predictions using a trained model",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "Target": {
                                "data_type": "string",
                                "title": "Target",
                                "description": "The target column to predict",
                                "required": []
                            },
                            "DataLocation": {
                                "data_type": "string",
                                "title": "DataLocation",
                                "description": "S3 location of the input data for the prediction",
                                "required": []
                            },
                            "ResultDataLocation": {
                                "data_type": "string",
                                "title": "ResultDataLocation",
                                "description": "S3 location of the output data for the prediction",
                                "required": []
                            },
                            "ModelLocation": {
                                "data_type": "string",
                                "title": "ModelLocation",
                                "description": "S3 location of the trained model that is used for the predictions",
                                "required": []
                            }
                        },
                        "required": [
                            "Target",
                            "DataLocation",
                            "ModelLocation",
                            "ResultDataLocation"
                        ]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "title": "200",
                        "description": "Successful operation",
                        "properties": {
                            "description": {
                                "data_type": "string",
                                "required": []
                            },
                            "content": {
                                "data_type": "object",
                                "properties": {
                                    "message": {
                                        "data_type": "string",
                                        "required": []
                                    },
                                    "results": {
                                        "data_type": "string",
                                        "required": []
                                    }
                                }
                            }
                        },
                        "required": []
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "FeatureImportance",
                    "description": "Get feature importance from a trained model",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "Target": {
                                "data_type": "string",
                                "title": "Target",
                                "description": "The target column to predict",
                                "required": []
                            },
                            "DataLocation": {
                                "data_type": "string",
                                "title": "DataLocation",
                                "description": "S3 location of the data",
                                "required": []
                            },
                            "ModelLocation": {
                                "data_type": "string",
                                "title": "ModelLocation",
                                "description": "S3 location of the trained model that is used to get the feature importance",
                                "required": []
                            }
                        },
                        "required": [
                            "Target",
                            "DataLocation",
                            "ModelLocation"
                        ]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "title": "200",
                        "description": "Successful operation",
                        "properties": {
                            "description": {
                                "data_type": "string",
                                "required": []
                            },
                            "content": {
                                "data_type": "object",
                                "properties": {
                                    "message": {
                                        "data_type": "string",
                                        "required": []
                                    },
                                    "results": {
                                        "data_type": "string",
                                        "required": []
                                    }
                                }
                            }
                        },
                        "required": []
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "ExploratoryDataAnalysis",
                    "description": "Perform exploratory data analysis on a dataset",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "DataLocation": {
                                "data_type": "string",
                                "title": "DataLocation",
                                "description": "S3 location of the data",
                                "required": []
                            }
                        },
                        "required": [
                            "DataLocation"
                        ]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "title": "200",
                        "description": "Successful operation",
                        "properties": {
                            "description": {
                                "data_type": "string",
                                "required": []
                            },
                            "content": {
                                "data_type": "object",
                                "properties": {
                                    "message": {
                                        "data_type": "string",
                                        "required": []
                                    },
                                    "results": {
                                        "data_type": "string",
                                        "required": []
                                    }
                                }
                            }
                        },
                        "required": []
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "TrainTestSplit",
                    "description": "Perform train test split on a dataset",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "DataLocation": {
                                "data_type": "string",
                                "title": "DataLocation",
                                "description": "S3 location of the data",
                                "required": []
                            },
                            "HoldoutFrac": {
                                "data_type": "number",
                                "title": "HoldoutFrac",
                                "description": "Fraction of data to hold out for testing",
                                "required": []
                            }
                        },
                        "required": [
                            "DataLocation",
                            "HoldoutFrac"
                        ]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "title": "200",
                        "description": "Successful operation",
                        "properties": {
                            "description": {
                                "data_type": "string",
                                "required": []
                            },
                            "content": {
                                "data_type": "object",
                                "properties": {
                                    "message": {
                                        "data_type": "string",
                                        "required": []
                                    },
                                    "results": {
                                        "data_type": "string",
                                        "required": []
                                    }
                                }
                            }
                        },
                        "required": []
                    },
                    "requires_confirmation": False,
                    "meta": {}
                }
            ],
            "tool_type": "Module",
            "meta": {}
        }
    ],
            "reachable_agents": []
        }
    ],
    "primary_agent_id": BEDROCK_AGENT_ID,
    "human_id": "User"
}
# save agent config to json file
with open(f'../data/eval_dataset/conversations/{random_identifier}/agents.json', 'w') as f:
    json.dump(agent_config, f, indent=4)
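
Because the agent configuration is a large hand-written dictionary, an action's `required` list can easily drift out of sync with its `properties`. The following is a minimal consistency check, sketched under the assumption that the agent entries sit under an `agents` key and follow the `tools` → `actions` → `input_schema` layout used above. `find_schema_mismatches` is a hypothetical helper, not part of `utils.benchmark`, and the toy config (including the `DataEngineer` name) is illustrative only.

```python
from typing import Any


def find_schema_mismatches(agent_config: dict[str, Any]) -> list[str]:
    """Return one message per required input field that has no property definition.

    Assumes the layout agents -> tools -> actions -> input_schema; actions
    without an input_schema (such as GetUseCases above) are skipped.
    """
    problems = []
    for agent in agent_config.get("agents", []):
        for tool in agent.get("tools", []):
            for action in tool.get("actions", []):
                schema = action.get("input_schema")
                if not schema:
                    continue
                properties = schema.get("properties", {})
                for field in schema.get("required", []):
                    if field not in properties:
                        problems.append(
                            f"{agent.get('agent_name')}/{action.get('name')}: "
                            f"required field '{field}' has no property definition"
                        )
    return problems


# Toy config with one deliberate mismatch: "Query" is required but not defined.
sample_config = {
    "agents": [
        {
            "agent_name": "DataEngineer",
            "tools": [
                {
                    "actions": [
                        {
                            "name": "ExecuteQuery",
                            "input_schema": {
                                "properties": {"AthenaDatabase": {"data_type": "string"}},
                                "required": ["AthenaDatabase", "Query"],
                            },
                        }
                    ]
                }
            ],
        }
    ]
}

print(find_schema_mismatches(sample_config))
```

If the real configuration does use an `agents` key, calling `find_schema_mismatches(agent_config)` before writing `agents.json` should return an empty list.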

In [27]:
# define different evaluation scenarios
evaluation_scenarios = {
    "scenarios": [
        {
            "scenario": "Goals: - User needs to know if their data supports any AI/ML use cases. Prepare the data in Amazon Athena. Then identify ML use cases for the data. Review the identified ML use cases and if you find any, train an ML model and generate predictions.",
            "input_problem": f"Inputs: AthenaDatabase: {ATHENA_DATABASE} SourceDataLocation: s3://{S3_BUCKET_NAME}/uploads/Customers.csv",
            "assertions": [
                "agent: GetInformationForSemanticTypeDetection is executed to gather information about the data",
                "agent: SaveSQLTableDefinition is executed to save the SQL table definition to S3",
                "agent: CreateAthenaTable is executed to create the Athena table",
                "agent: QueryData is executed to verify the Athena table creation",
                "agent: GetDatabaseSchema is executed to get the latest database schema",
                "agent: SaveERM is executed to save the Entity Relationship Diagram in json format to S3",
                "agent: GetDatabaseSchema is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                "agent: GetUseCases is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                "agent: ExecuteQuery is executed to create a ML dataset",
                "agent: SaveDataset is executed to save the ML dataset to S3",
                "agent: The AI/ML use cases along with the respective S3 location of the ML dataset(s) and target column name are returned in the final response",
                "agent: Run exploratory data analysis on the data",
                "agent: Determine the right HoldoutFrac for the ml model training",
                "agent: Split the data into train and test sets",
                "agent: Train a new model with the specified details",
                "agent: The model is trained successfully",
                "agent: The model location and feature importance are returned in the final response"
            ]
        },
        {
            "scenario": "Goals: - User needs to know if their data supports any AI/ML use cases. Identify ML use cases for the data. Review the identified ML use cases and if you find any, train an ML model and generate predictions.",
            "input_problem": (
                f'Inputs: Here is the Athena database name: "{ATHENA_DATABASE}". '
                'No additional datasets need to be processed. '
                'Return the final response in XML format and nothing else.'
            ),
            "assertions": [
                "agent: GetDatabaseSchema is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                "agent: GetUseCases is executed to detect AI/ML use cases that can be performed on the data in the Athena database",
                "agent: ExecuteQuery is executed to create a ML dataset",
                "agent: SaveDataset is executed to save the ML dataset to S3",
                "agent: The AI/ML use cases along with the respective S3 location of the ML dataset(s) and target column name are returned in the final response",
                "agent: Run exploratory data analysis on the data",
                "agent: Determine the right HoldoutFrac for the ml model training",
                "agent: Split the data into train and test sets",
                "agent: Train a new model with the specified details",
                "agent: The model is trained successfully",
                "agent: The model location and feature importance are returned in the final response"
            ]
        },   
        
    ]
}

# save evaluation scenarios to json file
with open(f'../data/eval_dataset/conversations/{random_identifier}/scenarios.json', 'w') as f:
    json.dump(evaluation_scenarios, f, indent=4)
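
Scenario inputs that use `{...}` placeholders only receive their values when they are f-strings (or are formatted explicitly); a plain string literal writes the braces verbatim into `scenarios.json`, and the agents then see the placeholder instead of the database name. A minimal sanity check, assuming placeholders look like `{NAME}` with a Python-identifier name, could scan the scenarios before the benchmark runs. `find_unresolved_placeholders` and the `example_db` value are hypothetical.

```python
import re
from typing import Any

# Matches "{NAME}"-style template fields such as "{ATHENA_DATABASE}".
PLACEHOLDER = re.compile(r"\{[A-Za-z_][A-Za-z0-9_]*\}")


def find_unresolved_placeholders(scenarios: dict[str, Any]) -> dict[int, list[str]]:
    """Map scenario index -> placeholders still present in its input_problem."""
    unresolved = {}
    for index, scenario in enumerate(scenarios.get("scenarios", [])):
        found = PLACEHOLDER.findall(scenario.get("input_problem", ""))
        if found:
            unresolved[index] = found
    return unresolved


ATHENA_DATABASE = "example_db"  # hypothetical value for illustration
sample = {
    "scenarios": [
        # plain literal: the braces survive unchanged
        {"input_problem": 'Inputs: Here is the Athena database name: "{ATHENA_DATABASE}"'},
        # f-string: the value is substituted
        {"input_problem": f'Inputs: Here is the Athena database name: "{ATHENA_DATABASE}"'},
    ]
}
print(find_unresolved_placeholders(sample))
```

Running `find_unresolved_placeholders(evaluation_scenarios)` and asserting that the result is empty is a cheap guard against silently evaluating the agents on template text.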


In [None]:
print(f"path: {path}")

In [None]:
# Run the benchmark
from IPython.display import display
from utils.benchmark import run_benchmark

results = run_benchmark(
    dataset_dir=path,
    scenario_filename="scenarios.json",
    conversations_dir=path,
    llm_judge_id=MODEL_ID,
    region=REGION,
    session=session
)

# Check if results is not None before proceeding
if results is not None:
    # Create high-level metrics DataFrame
    metrics_df = pd.DataFrame({
        'user_gsr': [results['user_gsr']],
        'system_gsr': [results['system_gsr']],
        'overall_gsr': [results['overall_gsr']],
        'partial_gsr': [results['partial_gsr']],
        'scenario_count': [results['scenario_count']],
        'conversation_count': [results['conversation_count']]
    })

    # Create detailed assertions DataFrame
    assertions_list = []
    for eval_result in results['conversation_evals']:
        trajectory_index = eval_result['trajectory_index']
        for assertion in eval_result['report']:
            assertions_list.append({
                'trajectory_index': trajectory_index,
                'assertion_type': assertion['assertion_type'],
                'assertion': assertion['assertion'],
                'answer': assertion['answer'],
                'evidence': assertion['evidence']
            })

    assertions_df = pd.DataFrame(assertions_list)

    # Display results
    print("High-level Metrics:")
    display(metrics_df)

    print("\nDetailed Assertions:")
    display(assertions_df)

else:
    print("Error: run_benchmark returned None. Please check for errors in the benchmark execution.")
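
To see how the headline numbers relate to the individual assertions, one option is to recompute simple pass rates from `results['conversation_evals']`, using the `trajectory_index` and `report` fields read in the cell above. This is only a sketch: it assumes each assertion's `answer` is the string `"True"` or `"False"`, and `utils.benchmark` may define `overall_gsr` and `partial_gsr` differently, so treat the two rates below as loose analogues (all-assertions-pass per conversation, and share of assertions passed). `summarize_conversation_evals` and the toy data are hypothetical.

```python
from collections import defaultdict
from typing import Any


def summarize_conversation_evals(conversation_evals: list[dict[str, Any]]) -> dict[str, float]:
    """Recompute simple pass rates from per-conversation assertion reports.

    Assumes each assertion's 'answer' is "True" or "False" (case-insensitive);
    verify the actual values produced by utils.benchmark before relying on this.
    """
    passed_by_trajectory = defaultdict(list)
    for eval_result in conversation_evals:
        for assertion in eval_result["report"]:
            passed = str(assertion["answer"]).strip().lower() == "true"
            passed_by_trajectory[eval_result["trajectory_index"]].append(passed)

    all_results = [p for results in passed_by_trajectory.values() for p in results]
    if not all_results:
        return {"all_assertions_pass_rate": 0.0, "assertion_pass_rate": 0.0}
    return {
        # share of conversations in which every assertion passed
        "all_assertions_pass_rate": sum(all(r) for r in passed_by_trajectory.values()) / len(passed_by_trajectory),
        # share of individual assertions that passed, across all conversations
        "assertion_pass_rate": sum(all_results) / len(all_results),
    }


toy_evals = [
    {"trajectory_index": 0, "report": [
        {"assertion_type": "agent", "assertion": "a", "answer": "True", "evidence": ""},
        {"assertion_type": "agent", "assertion": "b", "answer": "True", "evidence": ""},
    ]},
    {"trajectory_index": 1, "report": [
        {"assertion_type": "agent", "assertion": "a", "answer": "True", "evidence": ""},
        {"assertion_type": "agent", "assertion": "b", "answer": "False", "evidence": ""},
    ]},
]
print(summarize_conversation_evals(toy_evals))  # {'all_assertions_pass_rate': 0.5, 'assertion_pass_rate': 0.75}
```

The gap between the two rates is informative: a high assertion pass rate with a low all-pass rate usually means most conversations fail on only one or two steps.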

## Log multi-agent collaboration results to MLflow


In [None]:
# Create an MLflow experiment for the multi-agent collaboration evaluation if it does not exist yet
experiment_description = "multi-agent-colloboration."

experiment_tags = {
    "project_name": "multi-agent-colloboration",
    "use_case": "multi-agent-colloboration",
    "team": "aws-ai-ml-analytics",
    "source": "multi-agent-colloboration",
    "mlflow.note.content": experiment_description,
}

experiment_name = f"multi-agent-colloboration_{random_identifier}"

client = mlflow.MlflowClient()

# Use search_experiments() to search on the project_name tag key
mac_experiments = client.search_experiments(
    filter_string="tags.`project_name` = 'multi-agent-colloboration'"
)
# create_experiment() returns the new experiment's ID (a string), not an Experiment object
if len(mac_experiments) == 0:
    experiment_id = client.create_experiment(name=experiment_name, tags=experiment_tags)
else:
    experiment_id = mac_experiments[0].experiment_id
mlflow.set_experiment(experiment_id=experiment_id)

# search_runs() returns the newest run first by default; reuse it, or create a run if there is none
runs = mlflow.search_runs(experiment_ids=[experiment_id])
if len(runs) > 0:
    run_id = runs.iloc[0].run_id
else:
    run_id = client.create_run(experiment_id).info.run_id
print(f"Run ID: {run_id}")

# log metrics to the MLflow run
mlflow.log_metrics(metrics_df.iloc[0].to_dict(), run_id=run_id)

# log assertions to the MLflow run
mlflow.log_table(assertions_df, "assertions.json", run_id=run_id)

In [None]:
# Read the per-conversation metrics files (metrics_0.json, metrics_1.json, ...) from `path`
# into a pandas DataFrame with one row per conversation.
print(f"path: {path}")
num_files = len([f for f in os.listdir(path) if f.endswith('.json') and 'metrics' in f])
print(f"num_files: {num_files}")

metrics_records = []
for i in range(num_files):
    metrics_file = os.path.join(path, f"metrics_{i}.json")
    if not os.path.exists(metrics_file):
        print(f"skipping missing metrics file: {metrics_file}")
        continue
    print(f"reading metrics file: {metrics_file}")
    with open(metrics_file) as f:
        metrics_records.append(json.load(f))

# build the DataFrame once instead of concatenating inside the loop
metrics_run_df = pd.DataFrame(metrics_records)

# name the index column "conversation_index"
metrics_run_df.index.name = "conversation_index"

# Calculate average total_tokens, num_agent_calls, num_tool_calls and num_kb_lookups, plus the total sum of total_tokens
if not metrics_run_df.empty:
    metrics_summary = {
        'avg_total_tokens': metrics_run_df['total_tokens'].mean() if 'total_tokens' in metrics_run_df.columns else 0,
        'avg_num_agent_calls': metrics_run_df['num_agent_calls'].mean() if 'num_agent_calls' in metrics_run_df.columns else 0,
        'avg_num_tool_calls': metrics_run_df['num_tool_calls'].mean() if 'num_tool_calls' in metrics_run_df.columns else 0,
        'avg_num_kb_lookups': metrics_run_df['num_kb_lookups'].mean() if 'num_kb_lookups' in metrics_run_df.columns else 0,
        'total_tokens': metrics_run_df['total_tokens'].sum() if 'total_tokens' in metrics_run_df.columns else 0
    }
    
    print("\nMetrics Summary:")
    for metric, value in metrics_summary.items():
        print(f"{metric}: {value:.2f}")
    
    # Log metrics to mlflow if run_id exists
    if 'run_id' in locals() and run_id:
        mlflow.log_metrics(metrics_summary, run_id=run_id)
else:
    print("No metrics data available for calculation")
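
Because the summary metrics above are logged per MLflow run, runs that use different models can be compared directly. The following is a minimal sketch that is not part of the original notebook. It assumes each run also logs a quality metric named `assertion_pass_rate` (a hypothetical name, for example the share of assertions that passed) and picks the run with the lowest `avg_total_tokens` among the runs that meet a quality threshold.

```python
import math
import numbers
from typing import Any, Dict, List, Optional


def _is_valid(value: Any) -> bool:
    # search_runs fills metrics a run never logged with NaN; treat those as missing
    return isinstance(value, numbers.Real) and not math.isnan(value)


def pick_cheapest_passing_run(
    runs: List[Dict[str, Any]],
    quality_metric: str = "metrics.assertion_pass_rate",  # hypothetical metric name
    cost_metric: str = "metrics.avg_total_tokens",
    min_quality: float = 0.9,
) -> Optional[Dict[str, Any]]:
    """Among runs whose quality metric meets min_quality, return the one with the lowest cost metric."""
    candidates = [
        run
        for run in runs
        if _is_valid(run.get(quality_metric))
        and _is_valid(run.get(cost_metric))
        and run[quality_metric] >= min_quality
    ]
    if not candidates:
        return None
    return min(candidates, key=lambda run: run[cost_metric])
```

With the tracking server configured as earlier in this notebook, the input could come from `mlflow.search_runs(experiment_ids=[mac_experiment.experiment_id]).to_dict("records")`. `search_runs` prefixes metric columns with `metrics.` and parameter columns with `params.`, and it reports missing values as NaN, which the sketch treats as "not logged".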

## Summary

The Supervisor Agent orchestrated the multi-agent collaboration. It delegated to the data engineer, business analyst, and data scientist agents to identify ML use cases, train an ML model, and generate predictions with the trained model.

We successfully validated the Supervisor Agent's results against the assertions.



## Where to go from here

- Because this notebook focuses on the multi-agent collaboration evaluation framework, we used only a relatively simple scenario and two sample datasets. An actual production use case would need a significantly larger and more varied set of scenarios and datasets.

- Here we used `anthropic.claude-3-5-sonnet-20240620-v1:0` for all agents. You can use this framework in a metric-driven development process to evaluate the performance of your agents and improve their efficiency. As a next step, we could evaluate smaller, faster LLMs for the supervisor or any of the sub-agents and assess whether the results remain good enough.

- In its current form, this use case is a sequential workflow: first uploading a dataset, then identifying ML use cases, then training a model, and finally generating predictions. A predefined workflow would therefore have been more efficient than an autonomous agent. Instead of relying on the supervisor agent to determine the order of execution and orchestrate the agents, we could have used Amazon Bedrock Flows or another orchestration tool to execute these agents in a predefined, deterministic order (see [Amazon Bedrock Flows](https://aws.amazon.com/bedrock/flows/) and [Anthropic's blog on building effective agents](https://www.anthropic.com/research/building-effective-agents)).

- To properly implement the use case of identifying ML use cases and then training a model and generating predictions, we would need to support more AI/ML use cases and algorithms. In addition, the current implementation is only suited to relatively small datasets. It runs synchronously in Lambda and is therefore bound by the maximum Lambda execution time of 15 minutes. A production implementation would need to run the agents and their respective tools asynchronously.
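
The sequential flow described in the list above (upload, identify ML use cases, train, predict) can be implemented as a deterministic pipeline in place of the supervisor's planning. Below is a minimal sketch of that idea as plain-Python prompt chaining. It is not an Amazon Bedrock Flows definition. It calls each sub-agent in a fixed order through the `bedrock-agent-runtime` `invoke_agent` API and feeds each answer into the next prompt. The agent IDs, alias IDs, and prompt templates are hypothetical placeholders. The `invoke` parameter exists only so the chaining logic can be tested without AWS.

```python
import uuid
from typing import Callable, List, Tuple

# (agent_id, agent_alias_id, prompt template with a {previous} placeholder)
Step = Tuple[str, str, str]


def invoke_bedrock_agent(agent_id: str, agent_alias_id: str, prompt: str, session_id: str) -> str:
    """Call one Bedrock agent and concatenate the streamed completion chunks."""
    import boto3  # imported lazily so the pipeline logic can be tested without AWS

    client = boto3.client("bedrock-agent-runtime")
    response = client.invoke_agent(
        agentId=agent_id,
        agentAliasId=agent_alias_id,
        sessionId=session_id,
        inputText=prompt,
    )
    parts = []
    for event in response["completion"]:
        if "chunk" in event:
            parts.append(event["chunk"]["bytes"].decode("utf-8"))
    return "".join(parts)


def run_sequential_workflow(
    steps: List[Step],
    initial_input: str,
    invoke: Callable[[str, str, str, str], str] = invoke_bedrock_agent,
) -> List[str]:
    """Run the agents in a fixed order, feeding each agent's output into the next prompt."""
    outputs = []
    previous = initial_input
    for agent_id, alias_id, template in steps:
        session_id = str(uuid.uuid4())  # fresh session per step
        previous = invoke(agent_id, alias_id, template.format(previous=previous), session_id)
        outputs.append(previous)
    return outputs


# Hypothetical agent/alias IDs; replace them with the IDs of your own sub-agents
EXAMPLE_STEPS: List[Step] = [
    ("DATA_ENGINEER_AGENT_ID", "ALIAS_ID", "Upload this dataset to S3: {previous}"),
    ("BUSINESS_ANALYST_AGENT_ID", "ALIAS_ID", "Identify ML use cases for this dataset: {previous}"),
    ("DATA_SCIENTIST_AGENT_ID", "ALIAS_ID", "Train a model for the most promising use case: {previous}"),
    ("DATA_SCIENTIST_AGENT_ID", "ALIAS_ID", "Generate predictions with the trained model: {previous}"),
]
```

The order of execution is fixed in code, so no LLM call is spent on planning. The trade-off is that the pipeline cannot adapt when a request does not follow this exact sequence.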
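
One common pattern for the asynchronous execution mentioned in the last point is to split a long-running tool into two actions. A "start" action returns a job ID immediately, and a "status" action lets the agent check on the job in a later turn. The sketch below shows only that start/poll contract. It uses an in-process `ThreadPoolExecutor` and an in-memory job registry as stand-ins. Both are assumptions for illustration, because state held in memory would not survive across Lambda invocations. In a real deployment, the work would run in a service without the Lambda time limit (for example, a SageMaker training job), and the job status would be stored durably (for example, in DynamoDB).

```python
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict

# In-memory stand-ins for a durable job store and a long-running compute service (illustration only)
_executor = ThreadPoolExecutor(max_workers=4)
_jobs: Dict[str, Future] = {}


def start_job(task: Callable[..., Any], *args: Any) -> Dict[str, Any]:
    """'Start' tool: submit the long-running work and return a job ID without waiting."""
    job_id = str(uuid.uuid4())
    _jobs[job_id] = _executor.submit(task, *args)
    return {"job_id": job_id, "status": "InProgress"}


def get_job_status(job_id: str) -> Dict[str, Any]:
    """'Status' tool: report whether the job is still running, failed, or completed."""
    future = _jobs.get(job_id)
    if future is None:
        return {"job_id": job_id, "status": "NotFound"}
    if not future.done():
        return {"job_id": job_id, "status": "InProgress"}
    if future.exception() is not None:
        return {"job_id": job_id, "status": "Failed", "error": str(future.exception())}
    return {"job_id": job_id, "status": "Completed", "result": future.result()}


def wait_for_job(job_id: str, timeout_s: float = 10.0, interval_s: float = 0.05) -> Dict[str, Any]:
    """Local helper for testing: poll get_job_status until the job leaves InProgress or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    status = get_job_status(job_id)
    while status["status"] == "InProgress" and time.monotonic() < deadline:
        time.sleep(interval_s)
        status = get_job_status(job_id)
    return status
```

In an agent setting, the polling would happen across conversation turns, not in a blocking loop. `wait_for_job` only simulates that locally.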

