## Create Data Engineer Agent

In this notebook we will create a data engineer agent with Amazon Bedrock Agents that can:
- read data from a file stored in Amazon S3
- perform semantic type detection on each column of the file
- create a SQL table definition
- create a table in Amazon Athena
- upload the data to the Amazon S3 location that backs that table


Then we will evaluate the agent's performance in different scenarios.


In [3]:
# Import necessary libraries and load environment variables
from dotenv import load_dotenv, find_dotenv, set_key
import os
import sagemaker
import boto3
import json
import pandas as pd

# Load environment variables from the local dotenv file. With override=True the file's
# values win over variables already set in the process. load_dotenv writes them into
# os.environ, so no further copying into os.environ is needed.
local_env_filename = 'dev.env'
load_dotenv(find_dotenv(local_env_filename), override=True)

# Fail fast with a readable message. Assigning a missing variable (None) into os.environ
# would otherwise raise an opaque TypeError.
_required_env_vars = [
    'REGION',
    'S3_BUCKET_NAME',
    'AWS_ACCOUNT',
    'DATAENGINEER_AGENT_PROFILE_ARN',
    'DATAENGINEER_AGENT_EVAL_PROFILE_ARN',
]
_missing_env_vars = [name for name in _required_env_vars if os.getenv(name) is None]
if _missing_env_vars:
    raise EnvironmentError(
        f"Missing required environment variables (expected in {local_env_filename}): "
        + ", ".join(_missing_env_vars)
    )

REGION = os.environ['REGION']
S3_BUCKET_NAME = os.environ['S3_BUCKET_NAME']
AWS_ACCOUNT = os.environ['AWS_ACCOUNT']
DATAENGINEER_AGENT_PROFILE_ARN = os.environ['DATAENGINEER_AGENT_PROFILE_ARN']
DATAENGINEER_AGENT_EVAL_PROFILE_ARN = os.environ['DATAENGINEER_AGENT_EVAL_PROFILE_ARN']

# Bedrock Agents does not yet support application inference profiles, so the agent is
# configured with the base model ID.
MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"



sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/huthmac/Library/Application Support/sagemaker/config.yaml


In [4]:
import botocore.config

# Shared botocore configuration for the Bedrock and Lambda clients:
# 600 s (10 min) connect/read timeouts and at most 3 attempts per request.
config = botocore.config.Config(
    retries={'max_attempts': 3},
    read_timeout=10 * 60,
    connect_timeout=10 * 60,
)

session = boto3.Session(region_name=REGION)

# SageMaker session bound to the same boto3 session (and therefore the same region).
sagemaker_session = sagemaker.Session(boto_session=session)

# Bedrock and Lambda clients, all created with the long-timeout config above.
bedrock_agent_client, bedrock_agent_runtime_client, bedrock_runtime_client, bedrock_client, lambda_client = (
    session.client(service_name, config=config)
    for service_name in ('bedrock-agent', 'bedrock-agent-runtime', 'bedrock-runtime', 'bedrock', 'lambda')
)

# IAM, Athena and S3 use botocore's default client configuration.
iam_resource = session.resource('iam')
iam_client = session.client('iam')
athena_client = session.client('athena')
s3_client = session.client('s3')

## Create Data Engineer agent lambda function

In [None]:
%%writefile ../dataengineer/bedrock_data_engineer_agent.py
import json
import sys
import logging
import pandas as pd
import boto3
import os
import zipfile
from urllib.parse import urlparse
import time
from io import StringIO
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union
from enum import Enum
import re

class RequestType(str, Enum):
    """API paths handled by ``lambda_handler``.

    Each member is compared against ``event['apiPath']`` of the incoming action-group
    event. Because the enum mixes in ``str``, members compare equal to their plain
    string values (e.g. ``RequestType.QUERY_DATA == "/QueryData"``).
    """
    GET_INFORMATION_FOR_SEMANTIC_TYPE_DETECTION = "/GetInformationForSemanticTypeDetection"
    SAVE_SQL_TABLE_DEFINITION = "/SaveSQLTableDefinition"
    CREATE_ATHENA_TABLE = "/CreateAthenaTable"
    QUERY_DATA = "/QueryData"
    GET_DATABASE_SCHEMA = "/GetDatabaseSchema"
    GET_ERM = "/GetERM"
    SAVE_ERM = "/SaveERM"



class APIResponse(BaseModel):
    """Result object returned by the individual request handlers.

    ``lambda_handler`` converts both fields to strings when it builds the
    action-group response body.
    """
    # Human-readable status message.
    message: str
    # Handler-specific payload (e.g. query rows, S3 locations, or an 'error' entry).
    results: Dict[str, Any]

# Configure logging: INFO level on the root logger (getLogger() called without a name).
logger = logging.getLogger()
logger.setLevel(logging.INFO)


# get the environment variables
# Each setting is read from the environment only when the name is not already defined
# in this module's globals. Presumably this lets a value predefined in the executing
# namespace take precedence over the environment; TODO confirm the intended use.
if 'S3_BUCKET_NAME' not in globals():
    S3_BUCKET_NAME = os.getenv('S3_BUCKET_NAME')
    logger.info(f"S3_BUCKET_NAME: {S3_BUCKET_NAME}")

# NOTE(review): MODEL_ID is not referenced elsewhere in the visible part of this module.
if 'MODEL_ID' not in globals():
    MODEL_ID = os.getenv('MODEL_ID')
    logger.info(f"MODEL_ID: {MODEL_ID}")

if 'REGION' not in globals():
    REGION = os.getenv('REGION')
    logger.info(f"REGION: {REGION}")

# Output location passed to Athena for every query started by execute_athena_query.
# NOTE(review): if S3_BUCKET_NAME is unset, this evaluates to 's3://None/athena_results/'.
if 'ATHENA_QUERY_EXECUTION_LOCATION' not in globals():
    ATHENA_QUERY_EXECUTION_LOCATION = f's3://{S3_BUCKET_NAME}/athena_results/'
    logger.info(f"ATHENA_QUERY_EXECUTION_LOCATION: {ATHENA_QUERY_EXECUTION_LOCATION}")

# The clients are created once at import time and shared as module globals by the
# functions below.
session = boto3.Session(region_name=REGION)

s3_client = session.client('s3')
athena_client = session.client('athena')

# Point the 'primary' Athena workgroup's query-result location at this project's bucket
# and enforce it. With EnforceWorkGroupConfiguration=True, workgroup settings override
# client-side result settings. This is best effort: failures are logged, never raised,
# so the module still loads.
# NOTE(review): this runs on every import and changes a shared workgroup; confirm intended.
if not S3_BUCKET_NAME:
    logger.error("S3_BUCKET_NAME is not set; skipping Athena workgroup output-location update")
else:
    try:
        workgroup_configuration = athena_client.get_work_group(WorkGroup='primary')['WorkGroup']['Configuration']
        configuration_updates = {
            'EnforceWorkGroupConfiguration': True,
            'ResultConfigurationUpdates': {'OutputLocation': ATHENA_QUERY_EXECUTION_LOCATION},
        }
        # Re-send the current values of these settings when present. Optional keys that
        # are absent from the GetWorkGroup response are left unchanged rather than
        # aborting the whole update.
        for setting in ('PublishCloudWatchMetricsEnabled', 'EngineVersion', 'RequesterPaysEnabled'):
            if setting in workgroup_configuration:
                configuration_updates[setting] = workgroup_configuration[setting]
        athena_client.update_work_group(
            WorkGroup='primary',
            ConfigurationUpdates=configuration_updates,
            State='ENABLED',
        )
        logger.info(f"athena output location updated to {ATHENA_QUERY_EXECUTION_LOCATION}")
    except Exception as e:
        logger.error(f"Failed to update Athena workgroup 'primary': {str(e)}", exc_info=True)



def parse_json(json_string):
    """Parse a (possibly noisy) JSON object produced by an LLM.

    Text before the first ``{`` and after the last ``}`` is discarded, so replies like
    ``'Sure! {...} Hope this helps.'`` still parse. A legacy ``semantic_type`` key is
    renamed to ``semantic_column_name`` unless that key is already present.

    Args:
        json_string: Raw model output.

    Returns:
        dict: The parsed object. A placeholder dict (``semantic_column_name`` and
        ``data_type`` set to ``'unknown'``, empty ``usecases``) is returned instead when
        the input is empty, is not valid JSON, or does not decode to a JSON object.
    """
    def _fallback():
        # Build a fresh dict per call so callers can mutate the result safely.
        return {
            'semantic_column_name': 'unknown',
            'column_description': 'No response from LLM',
            'data_type': 'unknown',
            'usecases': []
        }

    if not json_string:  # Handle None or empty string
        logger.warning("Received empty JSON string")
        return _fallback()

    # Trim surrounding whitespace and any chatter around the outermost braces.
    cleaned = json_string.strip()
    start = cleaned.find('{')
    if start != -1:
        cleaned = cleaned[start:]
    end = cleaned.rfind('}')
    if end != -1:
        cleaned = cleaned[:end + 1]

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON: {e}")
        logger.error(f"Problematic JSON string: {cleaned}")
        return _fallback()

    # A bare number, string or list is valid JSON but not the expected object; the key
    # handling below and the placeholder shape both assume a dict.
    if not isinstance(parsed, dict):
        logger.error(f"Expected a JSON object but got {type(parsed).__name__}: {cleaned}")
        return _fallback()

    # Convert old semantic_type key to semantic_column_name if needed
    if 'semantic_type' in parsed and 'semantic_column_name' not in parsed:
        parsed['semantic_column_name'] = parsed.pop('semantic_type')

    return parsed

def execute_athena_query(database, query, poll_interval_seconds=1.0):
    """Run ``query`` in Athena against ``database`` and wait for it to finish.

    Results are written to ``ATHENA_QUERY_EXECUTION_LOCATION``.

    Args:
        database: Database used as the query execution context.
        query: SQL statement to execute.
        poll_interval_seconds: Delay between status polls while the query is QUEUED or
            RUNNING. Defaults to 1 second.

    Returns:
        tuple: ``('SUCCEEDED', rows)``, where ``rows`` is the output of
        ``extract_result_data``. For any other terminal state the tuple is
        ``(state, reason)``, with Athena's ``StateChangeReason`` or a generic message
        when Athena does not supply one.
    """
    logger.info("Executing Athena query...")
    response = athena_client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={'Database': database},
        ResultConfiguration={'OutputLocation': ATHENA_QUERY_EXECUTION_LOCATION},
    )

    query_execution_id = response['QueryExecutionId']
    logger.info(f"Query Execution ID: {query_execution_id}")

    # Poll until the query leaves QUEUED/RUNNING. Sleeping between polls avoids a tight
    # loop of GetQueryExecution calls, which can trigger API throttling.
    execution = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
    while execution['QueryExecution']['Status']['State'] in ('QUEUED', 'RUNNING'):
        logger.info("Query is still running...")
        time.sleep(poll_interval_seconds)
        execution = athena_client.get_query_execution(QueryExecutionId=query_execution_id)

    status = execution['QueryExecution']['Status']
    state = status['State']
    logger.info(f"Query finished with state {state}: {status}")

    if state == 'SUCCEEDED':
        query_results = athena_client.get_query_results(QueryExecutionId=query_execution_id)
        return state, extract_result_data(query_results)

    # StateChangeReason is optional in the API response, so it is not always present.
    message = status.get('StateChangeReason', f"Query ended in state {state} without a reason")
    return state, message

def extract_result_data(query_results):
    """Convert an Athena ``GetQueryResults`` response into a list of row dicts.

    Args:
        query_results: Response dict from ``athena_client.get_query_results``.

    Returns:
        list[dict]: One ``{column_name: value}`` dict per data row.
        - Values are strings.
        - Cells with no value field (Athena's representation of SQL NULL) become ``None``.
        - Athena includes the column names as the first row of SELECT results; a first
          row identical to the column names is therefore dropped.
    """
    column_info = query_results['ResultSet']['ResultSetMetadata']['ColumnInfo']
    column_names = [column['Name'] for column in column_info]

    def cell_value(item):
        if 'VarCharValue' in item:
            return item['VarCharValue']
        if 'NumberValue' in item:
            return str(item['NumberValue'])  # Convert to string for consistency
        if not item:
            return None  # empty datum == NULL
        return str(item)  # Fallback to string representation

    rows = [[cell_value(item) for item in row['Data']] for row in query_results['ResultSet']['Rows']]

    # Drop the header row (only ever the first row) so it is not mistaken for data.
    if rows and rows[0] == column_names:
        rows = rows[1:]

    return [dict(zip(column_names, values)) for values in rows]


def create_sql_table_definition(sql_table_definition, s3_file_location) -> APIResponse:
    """Upload a SQL table definition to S3.

    Args:
        sql_table_definition: SQL DDL text to store.
        s3_file_location: Destination ``s3://bucket/key`` URI. The last segment of the
            key is used as the local staging file name.

    Returns:
        APIResponse: Echoes the definition and its S3 location.

    Raises:
        ValueError: If ``s3_file_location`` has no bucket or no file name.
        Exception: Errors from writing the staging file or from the upload are logged
            and re-raised.
    """
    logger.info("Creating SQL table definition...")

    try:
        s3_url = urlparse(s3_file_location)
        bucket = s3_url.netloc
        key = s3_url.path.lstrip('/')  # Remove leading slash
        filename = os.path.basename(key)
        if not bucket or not filename:
            raise ValueError(f"Invalid S3 file location for SQL table definition: {s3_file_location}")

        # Stage the file under /tmp with an absolute path. The process-wide working
        # directory is intentionally left untouched.
        temp_dir = "/tmp/sql_table_definition"
        os.makedirs(temp_dir, exist_ok=True)
        local_path = os.path.join(temp_dir, filename)

        with open(local_path, 'w') as f:
            f.write(sql_table_definition)

        s3_client.upload_file(
            local_path,
            Bucket=bucket,
            Key=key
        )

        return APIResponse(
            message="SQL table definition successful",
            results={
                'sql_table_definition': sql_table_definition,
                'sql_table_definition_file_location': s3_file_location
            }
        )

    except Exception as e:
        logger.error(f"Error in create_sql_table_definition function: {str(e)}")
        raise


def get_database_schema() -> str:
    """Retrieve the SQL database schema from S3.

    Reads every object under ``metadata/sql_table_definition`` in ``S3_BUCKET_NAME``.

    Returns:
        str: JSON-encoded list of file contents, one string per file. Unreadable files
        are logged and skipped. ``"[]"`` is returned when no files exist or when
        listing the bucket fails.
    """
    schema_prefix = 'metadata/sql_table_definition'
    logger.info(f"Retrieving database schema from s3://{S3_BUCKET_NAME}/{schema_prefix}")

    sql_database_schema = []
    try:
        # A single ListObjectsV2 call returns at most 1000 keys; the paginator walks all pages.
        paginator = s3_client.get_paginator('list_objects_v2')
        objects = [
            obj
            for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=schema_prefix)
            for obj in page.get('Contents', [])
        ]

        if not objects:
            logger.warning(f"No schema files found in s3://{S3_BUCKET_NAME}/{schema_prefix}")
            return "[]"

        logger.info(f"Found {len(objects)} schema files")

        for item in objects:
            key = item['Key']
            if key.endswith('/'):
                continue  # "folder" placeholder object

            logger.info(f"Reading schema file: {key}")
            try:
                content = s3_client.get_object(
                    Bucket=S3_BUCKET_NAME,
                    Key=key
                )['Body'].read().decode('utf-8')
                sql_database_schema.append(content)
                logger.debug(f"Successfully read schema from {key}")
            except Exception as e:
                logger.error(f"Error reading schema file {key}: {str(e)}")

        logger.info(f"Successfully retrieved {len(sql_database_schema)} schema definitions")
        return json.dumps(sql_database_schema)

    except Exception as e:
        logger.error(f"Error in get_database_schema: {str(e)}", exc_info=True)
        return "[]"

def get_erm_schema() -> str:
    """Retrieve the Entity Relationship Model (ERM) from S3.

    Reads and parses every object under ``metadata/er_diagram`` in ``S3_BUCKET_NAME``.

    Returns:
        str: JSON-encoded list of the parsed ERM documents, one per file. Files that
        are not valid JSON or cannot be read are logged and skipped. ``"[]"`` is
        returned when no files exist or when listing the bucket fails.
    """
    erm_prefix = 'metadata/er_diagram'
    logger.info(f"Retrieving ERM from s3://{S3_BUCKET_NAME}/{erm_prefix}")

    erm_schemas = []
    try:
        # A single ListObjectsV2 call returns at most 1000 keys; the paginator walks all pages.
        paginator = s3_client.get_paginator('list_objects_v2')
        objects = [
            obj
            for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=erm_prefix)
            for obj in page.get('Contents', [])
        ]

        if not objects:
            logger.warning(f"No ERM files found in s3://{S3_BUCKET_NAME}/{erm_prefix}")
            return "[]"

        logger.info(f"Found {len(objects)} ERM files")

        for item in objects:
            key = item['Key']
            if key.endswith('/'):
                continue  # "folder" placeholder object

            logger.info(f"Reading ERM file: {key}")
            try:
                content = s3_client.get_object(
                    Bucket=S3_BUCKET_NAME,
                    Key=key
                )['Body'].read().decode('utf-8')
                # Verify it's valid JSON before adding
                erm_schemas.append(json.loads(content))
                logger.debug(f"Successfully read ERM from {key}")
            except json.JSONDecodeError:
                logger.error(f"File {key} contains invalid JSON")
            except Exception as e:
                logger.error(f"Error reading ERM file {key}: {str(e)}")

        logger.info(f"Successfully retrieved {len(erm_schemas)} ERM schemas")
        return json.dumps(erm_schemas)

    except Exception as e:
        logger.error(f"Error in get_erm_schema: {str(e)}", exc_info=True)
        return "[]"

def save_erm_schema(erm_data) -> APIResponse:
    """Save the Entity Relationship Model (ERM) to S3.

    The data is staged in ``/tmp/erm/erm.json`` and uploaded to
    ``s3://{S3_BUCKET_NAME}/metadata/er_diagram/erm.json``. Because the key is fixed,
    each call overwrites the previously saved ERM.

    Args:
        erm_data: The ERM data to save. A ``str`` is written verbatim and is not
            validated as JSON. Any other value is serialized with ``json.dump``
            (indent=2).

    Returns:
        APIResponse: Response with status and saved ERM location.

    Raises:
        Exception: Any file-system, serialization or upload error is logged and re-raised.
    """
    try:
        # Stage under /tmp using an absolute path. The process working directory is not changed.
        temp_dir = "/tmp/erm"
        os.makedirs(temp_dir, exist_ok=True)

        filename = 'erm.json'
        filepath = os.path.join(temp_dir, filename)

        # Save the ERM data to a file
        with open(filepath, 'w') as f:
            if isinstance(erm_data, str):
                f.write(erm_data)
            else:
                json.dump(erm_data, f, indent=2)

        # Upload the file to S3
        s3_key = f'metadata/er_diagram/{filename}'
        s3_client.upload_file(
            filepath,
            Bucket=S3_BUCKET_NAME,
            Key=s3_key
        )

        s3_location = f's3://{S3_BUCKET_NAME}/{s3_key}'
        logger.info(f"Saved ERM to {s3_location}")

        return APIResponse(
            message="ERM saved successfully",
            results={
                'erm_file_location': s3_location
            }
        )
    except Exception as e:
        logger.error(f"Error in save_erm_schema: {str(e)}", exc_info=True)
        raise


def query_athena_table(athena_database, sql_query) -> APIResponse:
    """Run ``sql_query`` in ``athena_database`` and wrap the outcome in an APIResponse.

    On success ``results['data']`` holds the rows (wrapped as ``{'results': rows}``
    unless already a dict). On failure ``results['error']`` holds the failure reason
    as a string. Unexpected exceptions are logged and re-raised.
    """
    logger.info("Querying Athena table...")
    try:
        status_code, outcome = execute_athena_query(athena_database, sql_query)
        logger.info(f"Athena query execution response: {outcome}")
        logger.info(f"status_code: {status_code}")

        results = {'status': status_code, 'query': sql_query}

        if status_code != 'SUCCEEDED':
            results['error'] = outcome if isinstance(outcome, str) else str(outcome)
            return APIResponse(
                message=f"Query execution failed with status: {status_code}",
                results=results,
            )

        results['data'] = outcome if isinstance(outcome, dict) else {'results': outcome}
        return APIResponse(message=f"Query execution {status_code}", results=results)

    except Exception as e:
        logger.error(f"Error in query_athena_table function: {str(e)}")
        raise

def check_athena_table_exists(database, table_name):
    """Report whether ``table_name`` exists in ``database`` of the AwsDataCatalog.

    Args:
        database (str): The name of the database.
        table_name (str): The name of the table to look up.

    Returns:
        bool: True when the table metadata can be fetched. False when Athena raises
        ``MetadataException``, or on any other error (which is logged).
    """
    logger.info(f"Checking if table {table_name} exists in database {database}")
    try:
        metadata = athena_client.get_table_metadata(
            CatalogName='AwsDataCatalog',
            DatabaseName=database,
            TableName=table_name,
        )
        logger.info(f"get_table_metadata response: {metadata}")
        logger.info(f"Table {table_name} exists in database {database}")
    except athena_client.exceptions.MetadataException:
        logger.info(f"Table {table_name} does not exist in database {database}")
        return False
    except Exception as err:
        # NOTE(review): other failures (e.g. missing permissions) are also reported as "absent".
        logger.error(f"Error checking if table exists: {str(err)}")
        return False
    return True
    

def parse_request_parameters(event):
    """Parse request parameters from the Lambda event.

    Reads ``event['requestBody']['content']['application/json']['properties']``, a list
    of ``{'name': ..., 'value': ...}`` entries, into a ``{name: value}`` dict. When a
    name repeats, the last value wins. Returns an empty dict when any level is missing
    or empty.
    """
    request_body = event.get('requestBody')
    if not (request_body and request_body.get('content')):
        return {}

    content = request_body['content']
    if 'application/json' not in content or 'properties' not in content['application/json']:
        return {}

    return {prop['name']: prop['value'] for prop in content['application/json']['properties']}

# Maximum size, in bytes of UTF-8 encoded JSON, of the response body returned to the agent.
MAX_RESPONSE_SIZE = 22000  # 22KB limit


def _extract_table_name_from_sql(sql_statement):
    """Return the unqualified table name of a CREATE [EXTERNAL] TABLE statement, or None.

    Handles an optional ``IF NOT EXISTS`` clause, an optional ``database.`` qualifier
    and backtick/double-quote identifier quoting.
    """
    if not sql_statement:
        return None
    match = re.search(
        r'CREATE\s+(?:EXTERNAL\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
        r'(?:[`"]?\w+[`"]?\.)?[`"]?(\w+)',
        sql_statement,
        re.IGNORECASE,
    )
    return match.group(1) if match else None


def _reset_directory(path):
    """Create ``path`` if needed and delete everything inside it, including subdirectories."""
    import shutil

    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        # Extracted archives can leave subdirectories behind; os.remove cannot delete those.
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)


def _load_dataframe_from_s3(data_location, data_dir="/tmp/data"):
    """Download the S3 object ``data_location`` into ``data_dir`` and load it into a DataFrame.

    ``.zip`` objects are extracted first. The first non-empty ``.csv``, ``.json`` or
    ``.parquet`` file directly inside ``data_dir`` (in name order) is loaded.

    Raises:
        ValueError: If the location is malformed or no supported file could be loaded.
    """
    # Start from an empty directory so files from a previous invocation are not picked up.
    _reset_directory(data_dir)

    s3_url = urlparse(data_location)
    data_bucket = s3_url.netloc
    data_key = s3_url.path.lstrip('/')  # Remove leading slash
    filename = os.path.basename(data_key)
    logger.info(f"data_bucket: {data_bucket}")
    logger.info(f"data_key: {data_key}")
    if not data_bucket or not filename:
        raise ValueError(f"Invalid DataLocation: {data_location}")

    local_file_path = os.path.join(data_dir, filename)
    logger.info(f"local_file_path: {local_file_path}")
    s3_client.download_file(Bucket=data_bucket, Key=data_key, Filename=local_file_path)

    if data_key.endswith(".zip"):
        logger.info("Unzipping data")
        with zipfile.ZipFile(local_file_path, "r") as zip_ref:
            zip_ref.extractall(data_dir)
        logger.info(f"Unzipped data to {data_dir}")

    logger.info(f"Files in {data_dir}: {os.listdir(data_dir)}")

    readers = {'.csv': pd.read_csv, '.json': pd.read_json, '.parquet': pd.read_parquet}
    for entry in sorted(os.listdir(data_dir)):
        file_path = os.path.join(data_dir, entry)
        file_or_table_name, ext = os.path.splitext(entry)
        reader = readers.get(ext.lower())
        if reader is None or not os.path.isfile(file_path):
            continue  # unsupported type (e.g. the downloaded .zip itself) or a directory
        try:
            df = reader(file_path)
        except Exception as e:
            logger.warning(f"Failed to load {file_path}: {str(e)}")
            continue
        if not df.empty:
            logger.info(f"Successfully loaded data from {file_path}")
            logger.info(f"file_or_table_name: {file_or_table_name}")
            return df

    raise ValueError("Could not load data from any files in the specified location")


def _build_response_body(message, results_text, max_size=MAX_RESPONSE_SIZE):
    """Build the action-group ``responseBody``, truncating ``results_text`` if it is too large.

    Size is measured in bytes of the UTF-8 encoded JSON document. When the body exceeds
    ``max_size``, ``results_text`` is shortened and suffixed with ``'... [truncated]'``
    until the body fits or the results text is exhausted.
    """
    def make_body(text):
        return {'application/json': {'body': {'message': message, 'results': text}}}

    def encoded_size(body):
        return len(json.dumps(body).encode('utf-8'))

    body = make_body(results_text)
    size = encoded_size(body)
    if size <= max_size:
        return body

    logger.error(f"Response size {size} exceeds limit. Truncating content...")
    marker = '... [truncated]'
    while size > max_size and results_text:
        overflow = size - max_size
        # Every character encodes to at least one byte. Dropping overflow + len(marker)
        # characters therefore makes room for both the overflow and the marker.
        results_text = results_text[:max(0, len(results_text) - overflow - len(marker))]
        body = make_body(results_text + marker)
        size = encoded_size(body)
    return body


def _action_response(event, status_code, response_body):
    """Wrap ``response_body`` in the action-group Lambda response envelope (messageVersion 1.0)."""
    return {
        'messageVersion': '1.0',
        'response': {
            'actionGroup': event.get('actionGroup'),
            'apiPath': event.get('apiPath'),
            'httpMethod': event.get('httpMethod'),
            'httpStatusCode': status_code,
            'responseBody': response_body,
        },
    }


def lambda_handler(event, context):
    """Entry point for the Bedrock Agents action-group Lambda.

    Dispatches on ``event['apiPath']`` (see ``RequestType``) using the parameters parsed
    from ``event['requestBody']``. When a ``DataLocation`` parameter is present, that S3
    object is downloaded and loaded into a DataFrame before dispatching.

    Returns:
        dict: The action-group response envelope. ``httpStatusCode`` is:
        - 200 on success;
        - 400 for ``ValueError`` (missing/invalid parameters, unsupported apiPath);
        - 500 for any other error.
    """
    try:
        # Create base temp directories at the start
        os.makedirs("/tmp/data", exist_ok=True)
        os.makedirs("/tmp/metadata", exist_ok=True)

        logger.info(f"Received event: {json.dumps(event)}")

        parameters = parse_request_parameters(event)
        logger.info(f"parameters: {parameters}")

        request_type = event.get('apiPath')
        logger.info(f"request_type: {request_type}")

        # Any request that names a DataLocation has its data loaded up front.
        data_location = parameters.get('DataLocation')
        df = None
        if data_location:
            logger.info(f"Downloading data from {data_location}")
            df = _load_dataframe_from_s3(data_location)

        s3_base_path = f's3://{S3_BUCKET_NAME}/'

        if request_type == RequestType.GET_DATABASE_SCHEMA:
            schema = json.loads(get_database_schema())  # Convert string to JSON array
            response_data = APIResponse(
                message="Database schema retrieved successfully",
                results={'application/json': {'body': schema}},
            )

        elif request_type == RequestType.GET_INFORMATION_FOR_SEMANTIC_TYPE_DETECTION:
            if df is None:
                raise ValueError("Missing required parameter: DataLocation")
            # Column names plus a small sample are what the agent needs for semantic detection.
            response_data = APIResponse(
                message="Data sample prepared for semantic type detection",
                results={
                    'column_names': list(df.columns),
                    'data_sample': df.head(10).to_dict('records'),
                },
            )

        elif request_type == RequestType.SAVE_SQL_TABLE_DEFINITION:
            sql_table_definition = parameters.get('SQL_Table_Definition')
            if not sql_table_definition:
                response_data = APIResponse(
                    message="Missing SQL table definition",
                    results={'error': "Missing SQL table definition"},
                )
            else:
                table_name = parameters.get('TableName')
                if not table_name:
                    # Fall back to the name in the DDL, then to a timestamped default.
                    table_name = (_extract_table_name_from_sql(sql_table_definition)
                                  or f"table_{int(time.time())}")
                    logger.info(f"Extracted or generated table name: {table_name}")
                s3_file_location = (
                    f'{s3_base_path}metadata/sql_table_definition/{table_name}_sql_table_definition.sql'
                ).lower()
                response_data = create_sql_table_definition(sql_table_definition, s3_file_location)

        elif request_type == RequestType.CREATE_ATHENA_TABLE:
            athena_database = parameters.get('AthenaDatabase')
            # The older parameter names are still accepted as fallbacks.
            athena_table_create_definition = next(
                (parameters[name]
                 for name in ('Athena_Table_Create_SQL_statement', 'TableDefinition', 'Table_Definition')
                 if parameters.get(name) is not None),
                None,
            )
            if athena_table_create_definition is None:
                logger.error("Missing table definition parameter")
                raise ValueError("Missing required parameter: Athena_Table_Create_SQL_statement")
            if not data_location:
                logger.error("Missing DataLocation parameter")
                raise ValueError("Missing required parameter: DataLocation")

            table_name = parameters.get('TableName') or _extract_table_name_from_sql(athena_table_create_definition)
            if not table_name:
                raise ValueError("Missing required parameter: TableName")

            s3_target_file_location = f'{s3_base_path}raw/{table_name}/'.lower()
            logger.info(f"Original data location: {data_location}")
            logger.info(f"Target location: {s3_target_file_location}")

            # Point the table's LOCATION at the target prefix. The matched span is spliced
            # in place so any whitespace between LOCATION and the quoted path is preserved.
            location_match = re.search(r"LOCATION\s+'([^']+)'", athena_table_create_definition)
            if location_match:
                logger.info(f"Current location in SQL: {location_match.group(1)}")
                athena_table_create_definition = (
                    athena_table_create_definition[:location_match.start(1)]
                    + s3_target_file_location
                    + athena_table_create_definition[location_match.end(1):]
                )
                logger.info(f"Updated SQL with new location: {s3_target_file_location}")

            # Prepare and upload the data, then run the DDL in Athena.
            prepare_and_upload_data(df, athena_table_create_definition, s3_target_file_location)
            status_code, query_outcome = execute_athena_query(athena_database, athena_table_create_definition)
            response_data = APIResponse(
                message=f"Athena table creation {status_code}",
                results={
                    'status': status_code,
                    'query_results': query_outcome if isinstance(query_outcome, dict) else {'data': query_outcome},
                },
            )

        elif request_type == RequestType.QUERY_DATA:
            response_data = query_athena_table(parameters.get('AthenaDatabase'), parameters.get('SQLQuery'))

        elif request_type == RequestType.GET_ERM:
            erm = json.loads(get_erm_schema())  # Convert string to JSON array
            response_data = APIResponse(
                message="ERM schema retrieved successfully",
                results={'application/json': {'body': erm}},
            )

        elif request_type == RequestType.SAVE_ERM:
            # ERM data comes from the ERMData parameter, or else from the raw JSON request body.
            erm_data = parameters.get('ERMData')
            request_body = event.get('requestBody')
            if erm_data:
                response_data = save_erm_schema(erm_data)
            elif request_body and request_body.get('content'):
                content = request_body['content']
                if 'application/json' in content and 'body' in content['application/json']:
                    response_data = save_erm_schema(content['application/json']['body'])
                else:
                    response_data = APIResponse(
                        message="Missing ERM data in request body",
                        results={'error': "Missing ERM data in request body"},
                    )
            else:
                response_data = APIResponse(
                    message="Missing ERM data",
                    results={'error': "Missing ERM data parameter or request body"},
                )

        else:
            raise ValueError(f"Unsupported apiPath: {request_type}")

        response_body = _build_response_body(str(response_data.message), str(response_data.results))
        api_response = _action_response(event, 200, response_body)
        logger.info(f"action_response: {api_response}")
        return api_response

    except Exception as e:
        # logger.exception also records the full traceback.
        logger.exception(f"Error in lambda_handler: {str(e)}")
        is_client_error = isinstance(e, ValueError)
        return _action_response(event, 400 if is_client_error else 500, {
            'application/json': {
                'body': {
                    "error": str(e),
                    "errorCode": "400" if is_client_error else "500",
                }
            }
        })

def _extract_sql_column_names(sql_table_definition):
    """
    Return the column names, in declaration order, of a CREATE TABLE statement.

    The column list is scanned character by character so that commas inside
    parameterised or nested types (``VARCHAR(50)``, ``DECIMAL(10,2)``,
    ``STRUCT<a:INT,b:STRING>``), inside quoted strings, or inside ``--`` line
    comments do not split a column definition. The list ends at its own closing
    parenthesis rather than at the first ``)`` seen.

    Args:
        sql_table_definition (str): ``CREATE [EXTERNAL] TABLE [IF NOT EXISTS] name (...)``.

    Returns:
        list[str] | None: Column names (possibly empty), or ``None`` when no CREATE TABLE
        header is found or its column list is never closed.
    """
    if not sql_table_definition:
        return None

    header = re.search(
        r'CREATE\s+(?:EXTERNAL\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[\w.`"]+\s*\(',
        sql_table_definition,
        re.IGNORECASE,
    )
    if not header:
        return None

    body = sql_table_definition[header.end():]
    definitions = []
    current = []
    paren_depth = 0   # nesting of ( ) inside a column definition, e.g. VARCHAR(50)
    angle_depth = 0   # nesting of < > in complex types, e.g. ARRAY<STRUCT<...>>
    quote = None      # active quote character while inside a quoted string/identifier
    i = 0
    while i < len(body):
        ch = body[i]
        if quote is None and body.startswith('--', i):
            # Skip a line comment; its text may contain commas, quotes or parentheses.
            newline = body.find('\n', i)
            if newline == -1:
                return None  # comment runs to end of input: column list never closed
            i = newline
            continue
        if quote is not None:
            if ch == quote:
                quote = None
        elif ch in ('\'', '"', '`'):
            quote = ch
        elif ch == '(':
            paren_depth += 1
        elif ch == ')':
            if paren_depth == 0:
                break  # closing parenthesis of the column list itself
            paren_depth -= 1
        elif ch == '<' and paren_depth == 0:
            angle_depth += 1
        elif ch == '>' and paren_depth == 0 and angle_depth > 0:
            angle_depth -= 1
        elif ch == ',' and paren_depth == 0 and angle_depth == 0:
            definitions.append(''.join(current))
            current = []
            i += 1
            continue
        current.append(ch)
        i += 1
    else:
        return None  # the column list was never closed
    definitions.append(''.join(current))

    names = []
    for definition in definitions:
        definition = definition.strip()
        # Skip empty entries and table-level constraints such as "PRIMARY KEY (id)".
        if not definition or re.match(
            r'(?:PRIMARY|FOREIGN)\s+KEY\b|CONSTRAINT\b|UNIQUE\s*\(|CHECK\s*\(',
            definition,
            re.IGNORECASE,
        ):
            continue
        # The column name is the first token: `quoted`, "quoted", or a bare identifier.
        name_match = re.match(r'`([^`]+)`|"([^"]+)"|(\w+)', definition)
        if name_match:
            names.append(next(g for g in name_match.groups() if g is not None))
    return names


def prepare_and_upload_data(df, sql_table_definition, s3_target_file_location):
    """
    Rename DataFrame columns to match a SQL table definition and upload the result to S3 as CSV.

    Columns are renamed positionally: the i-th DataFrame column receives the i-th column
    name declared in ``sql_table_definition``. If the counts differ, a warning is logged
    and only the overlapping columns are renamed.

    Args:
        df (pandas.DataFrame): The DataFrame to process. The caller's frame is not modified
            (``rename`` returns a new frame).
        sql_table_definition (str): ``CREATE [EXTERNAL] TABLE [IF NOT EXISTS] name (...)``
            statement. ``--`` comments, parameterised types such as ``VARCHAR(50)`` and
            ``DECIMAL(10,2)``, complex types such as ``STRUCT<a:INT,b:STRING>`` and quoted
            identifiers are supported.
        s3_target_file_location (str): ``s3://bucket/key`` target. If it ends with ``/`` (or
            has no key), it is treated as a prefix and a timestamped
            ``<prefix-name>_<epoch>.csv`` file name is appended.

    Returns:
        str: The ``s3://bucket/key`` URI the file was uploaded to.

    Raises:
        ValueError: If the SQL definition cannot be parsed, declares no columns, or the
            S3 location has no bucket.
    """
    local_file_path = None
    try:
        logger.info(f"Preparing data for upload to {s3_target_file_location}")

        columns = _extract_sql_column_names(sql_table_definition)
        if columns is None:
            logger.error("Could not extract column definitions from SQL statement")
            logger.error(f"SQL statement: {sql_table_definition}")
            raise ValueError("Invalid SQL table definition format")

        logger.info(f"Extracted columns from SQL definition: {columns}")

        if not columns:
            logger.error("No columns extracted from SQL definition")
            raise ValueError("No columns found in SQL table definition")

        if len(columns) != len(df.columns):
            logger.warning(f"Column count mismatch: SQL has {len(columns)}, DataFrame has {len(df.columns)}")
            # Proceed anyway and rename only the columns both sides have.

        # Positional mapping: i-th DataFrame column -> i-th SQL column. zip() stops at the
        # shorter sequence, so a count mismatch cannot raise IndexError.
        df = df.rename(columns=dict(zip(df.columns, columns)))
        logger.info(f"Renamed DataFrame columns to: {list(df.columns)}")

        # Parse S3 URL
        s3_url = urlparse(s3_target_file_location)
        bucket = s3_url.netloc
        key = s3_url.path.lstrip('/')
        if not bucket:
            raise ValueError(f"Invalid S3 location, no bucket found: {s3_target_file_location}")

        temp_dir = "/tmp/processed_data"
        os.makedirs(temp_dir, exist_ok=True)

        # A trailing "/" (or no key at all) denotes a prefix: append a timestamped file name.
        # Without this, an empty key would make the local path the temp directory itself.
        if s3_target_file_location.endswith('/') or not key:
            filename = f"{os.path.basename(key.rstrip('/'))}_{int(time.time())}.csv"
            key = f"{key}{filename}"
            local_file_path = os.path.join(temp_dir, filename)
        else:
            local_file_path = os.path.join(temp_dir, os.path.basename(key))

        df.to_csv(local_file_path, index=False)
        logger.info(f"Saved processed data to {local_file_path}")

        s3_client.upload_file(
            local_file_path,
            Bucket=bucket,
            Key=key
        )
        logger.info(f"Uploaded processed data to s3://{bucket}/{key}")

        return f"s3://{bucket}/{key}"

    except Exception as e:
        logger.error(f"Error in prepare_and_upload_data: {str(e)}")
        raise
    finally:
        # Lambda reuses /tmp across warm invocations, so remove the staging file
        # instead of letting processed CSVs accumulate there.
        if local_file_path and os.path.exists(local_file_path):
            try:
                os.remove(local_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {local_file_path}: {cleanup_error}")


In [None]:
%%writefile ../dataengineer/dataengineer.Dockerfile
FROM public.ecr.aws/lambda/python:3.11

# Build toolchain (gcc/g++/make), git and libgomp (OpenMP runtime), installed before
# the Python packages.
RUN yum install libgomp git gcc gcc-c++ make -y \
 && yum clean all -y && rm -rf /var/cache/yum

# Upgrade the packaging tools first, then install all runtime dependencies in ONE pip
# invocation so the resolver sees the whole requirement set at once (sequential installs
# can leave shared dependencies at versions picked for an earlier package).
# NOTE(review): --trusted-host relaxes HTTPS certificate checks for PyPI; drop it unless
# a TLS-intercepting proxy requires it.
RUN python3 -m pip --no-cache-dir install --upgrade --trusted-host pypi.org --trusted-host files.pythonhosted.org pip \
 && python3 -m pip --no-cache-dir install --upgrade wheel setuptools \
 && python3 -m pip --no-cache-dir install --upgrade \
        pandas \
        boto3 \
        opensearch-py \
        Pillow \
        pyarrow \
        fastparquet \
        urllib3 \
        pydantic

# Copy function code.
# COPY sources are resolved relative to the build context. The notebook builds with
# `docker build ... ..`, i.e. the repository root, so the paths are written from there
# and must not start with "../" (that would point outside the context).
WORKDIR /var/task
COPY dataengineer/bedrock_data_engineer_agent.py .
COPY notebooks/utils/ utils/

# Set handler environment variable
ENV _HANDLER="bedrock_data_engineer_agent.lambda_handler"

# Use the Lambda base image's default entrypoint; CMD names the handler it should load.
ENTRYPOINT [ "/lambda-entrypoint.sh" ]
CMD [ "bedrock_data_engineer_agent.lambda_handler" ]

## Build and run a local Docker container to test the dataengineer-lambda function

In [None]:
# Build the Lambda container image (tag: dataengineer-lambda). The build context is the
# parent directory (".."), so COPY paths in the Dockerfile are resolved relative to it.
# The container itself is started in the next cell.
!docker build -t dataengineer-lambda -f ../dataengineer/dataengineer.Dockerfile ..

In [None]:
import subprocess

# Start the Lambda container locally in the background (-d); container port 8080 is
# published as localhost:9000, which the test cells below post to.
session_credentials = session.get_credentials()
if session_credentials is None:
    raise RuntimeError("No AWS credentials found for the boto3 session")
credentials = session_credentials.get_frozen_credentials()

container_env = {
    "AWS_ACCESS_KEY_ID": credentials.access_key,
    "AWS_SECRET_ACCESS_KEY": credentials.secret_key,
    "AWS_DEFAULT_REGION": REGION,
    "REGION": REGION,
    "AWS_LAMBDA_FUNCTION_TIMEOUT": "900",
    "S3_BUCKET_NAME": S3_BUCKET_NAME,
}
# Long-term (non-STS) credentials carry no session token. Interpolating None would pass
# the literal string "None" and make every AWS call from the container fail.
if credentials.token:
    container_env["AWS_SESSION_TOKEN"] = credentials.token

# "-e NAME" without "=value" makes docker read NAME from its own environment, so the
# secrets never appear on the command line (and thus not in `ps` output).
docker_command = ["docker", "run", "-d"]
for variable_name in container_env:
    docker_command += ["-e", variable_name]
docker_command += ["-p", "9000:8080", "dataengineer-lambda"]

docker_result = subprocess.run(
    docker_command,
    env={**os.environ, **container_env},
    capture_output=True,
    text=True,
)
if docker_result.returncode != 0:
    raise RuntimeError(f"docker run failed: {docker_result.stderr.strip()}")
print(f"Started container {docker_result.stdout.strip()}")

In [None]:
# List containers started from the dataengineer-lambda image to confirm the local endpoint is up.
!docker ps --filter ancestor=dataengineer-lambda

In [None]:
import requests

# Smoke-test the GetInformationForSemanticTypeDetection action against the Lambda
# container running locally (published on localhost:9000 by the docker run cell).
LOCAL_LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"

semantic_type_properties = [
    {
        "name": "DataLocation",
        "type": "string",
        "value": f"s3://{S3_BUCKET_NAME}/uploads/Customers.csv",
    },
]

request_body = {
    "apiPath": "/GetInformationForSemanticTypeDetection",
    "requestBody": {"content": {"application/json": {"properties": semantic_type_properties}}},
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions",
}

response = requests.post(LOCAL_LAMBDA_URL, json=request_body, timeout=900)  # 15-minute timeout
print(response.json())

In [None]:
import requests

# Run a sample query through the QueryData action of the local Lambda container.
# NOTE(review): this queries `customers`, while later cells create `customer_data` —
# confirm which table is intended.
LOCAL_LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"

query_properties = [
    {"name": "SQLQuery", "type": "string", "value": "SELECT * FROM customers limit 10;"},
    {"name": "AthenaDatabase", "type": "string", "value": f"{S3_BUCKET_NAME.replace('-', '_')}"},
]

request_body = {
    "apiPath": "/QueryData",
    "requestBody": {"content": {"application/json": {"properties": query_properties}}},
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions",
}

response = requests.post(LOCAL_LAMBDA_URL, json=request_body, timeout=900)  # 15-minute timeout
print(response.json())

In [None]:
import requests

# Save a hand-written SQL table definition (with semantic-type comments) through the
# SaveSQLTableDefinition action of the local Lambda container.
LOCAL_LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"

customer_table_definition = '''CREATE TABLE customer_data (
                                    CustomerID INT PRIMARY KEY, -- customer_id: Used as primary key for customer identification and tracking in various analyses
                                    FirstName VARCHAR(50), -- first_name: Customer's first name for personalized marketing and customer identity verification
                                    LastName VARCHAR(50), -- customer_lastname: Customer's last name for personalized marketing and fraud detection
                                    Email VARCHAR(100), -- email_address: Unique identifier for communication, used in email marketing and customer segmentation
                                    Phone VARCHAR(20), -- phone_number: Contact information for customer service and marketing campaign targeting
                                    Address VARCHAR(100), -- street_address: Used for geocoding, address validation, and targeted marketing
                                    City VARCHAR(50), -- customer_city: Location data for geographic analysis and targeted marketing campaigns
                                    State VARCHAR(50), -- state_name: Used for regional sales forecasting and geographic segmentation
                                    Country VARCHAR(50), -- customer_country: Used for market targeting and international expansion planning
                                    PostalCode INT, -- postal_code: Used for geographic segmentation and logistics optimization
                                    DateOfBirth DATETIME, -- date_of_birth: Used for age calculation, demographic analysis, and personalized marketing
                                    Gender VARCHAR(20), -- gender_identity: Used for customer segmentation and diversity analysis in marketing
                                    CreatedDate DATETIME, -- customer_creation_datetime: Timestamp for customer acquisition trend analysis and cohort analysis
                                    LastUpdated DATETIME -- last_updated_timestamp: Used for data quality monitoring and customer engagement analysis
                                );'''

save_definition_properties = [
    {"name": "SQL_Table_Definition", "type": "string", "value": customer_table_definition},
    {"name": "TableName", "type": "string", "value": "customer_data"},
]

request_body = {
    "apiPath": "/SaveSQLTableDefinition",
    "requestBody": {"content": {"application/json": {"properties": save_definition_properties}}},
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions",
}

response = requests.post(LOCAL_LAMBDA_URL, json=request_body, timeout=900)  # 15-minute timeout
print(response.json())


In [None]:
# Create the customer_data Athena table through the CreateAthenaTable action of the
# local Lambda container.
# LOCATION is built from S3_BUCKET_NAME (the same bucket as DataLocation below), so the
# table points at this environment's bucket rather than a hard-coded one.
request_body = {
    "apiPath": "/CreateAthenaTable",
    "requestBody": {
        "content": {
            "application/json": {
                "properties": [
                    {
                        "name": "Athena_Table_Create_SQL_statement",
                        "type": "string",
                        "value": f'''CREATE EXTERNAL TABLE customer_data (
                                    customer_id INT, 
                                    first_name STRING, 
                                    customer_lastname STRING, 
                                    email_address STRING, 
                                    phone_number STRING, 
                                    street_address STRING, 
                                    customer_city STRING, 
                                    state_name STRING, 
                                    customer_country STRING, 
                                    postal_code INT, 
                                    date_of_birth TIMESTAMP, 
                                    gender_identity STRING, 
                                    customer_creation_datetime TIMESTAMP, 
                                    last_updated_timestamp TIMESTAMP
                                )
                                ROW FORMAT DELIMITED
                                FIELDS TERMINATED BY ','
                                STORED AS TEXTFILE
                                LOCATION 's3://{S3_BUCKET_NAME}/uploads/'
                                TBLPROPERTIES ('skip.header.line.count'='1');
                            '''
                    },
                    {
                        "name": "TableName",
                        "type": "string",
                        "value": "customer_data"
                    },
                    {
                        "name": "DataLocation",
                        "type": "string",
                        "value": f"s3://{S3_BUCKET_NAME}/uploads/Customers.csv"
                    },
                    {
                        "name": "AthenaDatabase",
                        "type": "string",
                        "value": f"{S3_BUCKET_NAME.replace('-', '_')}"
                    }
                ]
            }
        }
    },
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions"
}

import requests
response = requests.post("http://localhost:9000/2015-03-31/functions/function/invocations",
                         json=request_body,
                         timeout=900  # 15 minutes timeout
)
print(response.json())


In [None]:
import requests

# Fetch the current database schema through the GetDatabaseSchema action of the local
# Lambda container.
LOCAL_LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"

schema_properties = [
    {"name": "AthenaDatabase", "type": "string", "value": f"{S3_BUCKET_NAME.replace('-', '_')}"},
]

request_body = {
    "apiPath": "/GetDatabaseSchema",
    "requestBody": {"content": {"application/json": {"properties": schema_properties}}},
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions",
}

response = requests.post(LOCAL_LAMBDA_URL, json=request_body, timeout=900)  # 15-minute timeout
print(response.json())


In [None]:
import requests

# Fetch the entity relationship diagram (JSON) through the GetERM action of the local
# Lambda container.
LOCAL_LAMBDA_URL = "http://localhost:9000/2015-03-31/functions/function/invocations"

erm_properties = [
    {"name": "AthenaDatabase", "type": "string", "value": f"{S3_BUCKET_NAME.replace('-', '_')}"},
]

request_body = {
    "apiPath": "/GetERM",
    "requestBody": {"content": {"application/json": {"properties": erm_properties}}},
    "httpMethod": "POST",
    "actionGroup": "DataEngineerActions",
}

response = requests.post(LOCAL_LAMBDA_URL, json=request_body, timeout=900)  # 15-minute timeout
print(response.json())


In [None]:
# Stop the local container(s) started from the dataengineer-lambda image, then list
# matching containers to confirm none are still running.
# NOTE(review): `docker stop` errors ("requires at least 1 argument") when no matching
# container is running.
!docker stop $(docker ps -q --filter ancestor=dataengineer-lambda)
!docker ps --filter ancestor=dataengineer-lambda


## Upload docker image to ECR

In [16]:
## Create ECR repository for dataengineer-lambda (if not already created in 1_environmentSetup.ipynb)
#!aws ecr create-repository --repository-name automatedinsights/lambda_dataengineer --region {REGION} --profile {SESSION_PROFILE}

In [None]:
# Upload docker image to ECR: authenticate Docker against the account's ECR registry,
# tag the locally built image for the automatedinsights/lambda_dataengineer repository,
# and push it.
# NOTE(review): SESSION_PROFILE is not defined in this part of the notebook — confirm it
# is set in an earlier cell.
!aws ecr get-login-password --region {REGION} --profile {SESSION_PROFILE} | docker login --username AWS --password-stdin {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com
!docker tag dataengineer-lambda:latest {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_dataengineer:latest
!docker push {AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_dataengineer:latest

## Create & Test Bedrock Agent

In [None]:
import logging
import random
import string
from utils.bedrock_agent import BedrockAgentScenarioWrapper


logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Input for the agent's first run: the dataset to prepare and the target Athena database
# (the bucket name with '-' replaced by '_').
request_body = json.dumps([
    {
        "name": "DataLocation",
        "type": "string",
        "value": f"s3://{S3_BUCKET_NAME}/uploads/Customers.csv"
    },
    {
        "name": "AthenaDatabase",
        "type": "string",
        "value": f"{S3_BUCKET_NAME.replace('-', '_')}"
    }
], ensure_ascii=False)

# request_body is already a JSON string. Separate it from the instruction with a space,
# as the evaluation scenario's input_problem does.
prompt = f"Do the data preparation for the below details. {request_body}"

# Step 2 explicitly asks for SaveSQLTableDefinition, because the evaluation scenario
# asserts that this tool is executed.
instruction = """Expert Data Engineer Agent
You are an expert data engineer with access to a comprehensive set of data preparation and management tools. Your role is to help users process, analyze, and query data efficiently.

Available Tools
GetDatabaseSchema: Retrieves the current SQL database schema from S3.
GetERM: Retrieves the Entity Relationship Diagram in JSON format.
SaveERM: Saves an Entity Relationship Diagram in JSON format to S3.
GetInformationForSemanticTypeDetection: Analyzes data to help detect semantic types.
SaveSQLTableDefinition: Saves SQL table definition to S3.
CreateAthenaTable: Creates an Athena table based on data and Athena table create definition.
QueryData: Executes SQL queries against Athena databases.

Workflow
1) Data Analysis: When presented with new data, use GetInformationForSemanticTypeDetection to analyze the data structure and content. Identify semantic types for each column and 3 ML use cases where these semantic types could be used in.
2) Table Schema Definition: Create appropriate SQL table definition with semantic column names based on the data analysis, considering semantic types and appropriate data types. Save the table definition with the tool SaveSQLTableDefinition.
3) Table Creation: Use the Table Schema Definition as a blueprint for the Athena table creation with the tool CreateAthenaTable to make the data available for querying.
4) Data Querying: Verify successful table creation by querying the data using the QueryData tool.
5) Entity RelationShip Diagram: Get the latest SQL schema with the GetDatabaseSchema tool.
Then generate a new entity relationship diagram and save it with the tool SaveERM.

Important Guidelines
If you receive multiple datasets as input, process each one methodically and verify that all files have been processed before providing your final response.
When creating table definitions, include SQL comments that explain the semantic type of each column and its potential use cases.
For primary keys, add specific comments explaining the primary key constraint.
Always ensure your SQL follows the appropriate format for the target system (ANSI SQL or Athena SQL)."""

# Random 8-character suffix keeps the agent name unique across runs.
postfix = "".join(
    random.choice(string.ascii_lowercase + "0123456789") for _ in range(8)
)

agent_name = "DataEngineer" + "_" + postfix

IMAGE_URI = f'{AWS_ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/automatedinsights/lambda_dataengineer:latest'

agentCollaboration = 'DISABLED' #'SUPERVISOR' #|'SUPERVISOR_ROUTER'|'DISABLED'

sub_agents_list = []
promptOverrideConfiguration = None

lambda_environment_variables = {
    "S3_BUCKET_NAME": S3_BUCKET_NAME,
    "MODEL_ID": MODEL_ID
}

scenario = BedrockAgentScenarioWrapper(
    bedrock_agent_client=bedrock_agent_client,
    runtime_client=bedrock_agent_runtime_client,
    lambda_client=lambda_client,
    iam_resource=iam_resource,
    postfix=postfix,
    agent_name=agent_name,
    model_id=MODEL_ID,
    sub_agents_list=sub_agents_list,
    prompt=prompt,
    lambda_image_uri=IMAGE_URI,
    lambda_environment_variables=lambda_environment_variables,
    action_group_schema_path="action_groups/dataengineer_open_api_schema.yml",
    instruction=instruction,
    agentCollaboration=agentCollaboration,
    promptOverrideConfiguration=promptOverrideConfiguration
)
try:
    scenario.run_scenario()
except Exception as e:
    logging.exception(f"Something went wrong: {e}")

In [6]:
# Prompt variant that additionally asks the agent to report the final Athena table name
# and SQL table definition. To send it to the agent created above, uncomment the two
# `scenario` lines at the end of this cell.
prompt = "Do the data preparation for the below details. Return the final Athena table name and the SQL table definition in the final response."

request_body = json.dumps([
    {
        "name": "DataLocation",
        "type": "string",
        "value": f"s3://{S3_BUCKET_NAME}/uploads/Customers.csv"
    },
    {
        "name": "AthenaDatabase",
        "type": "string",
        "value": f"{S3_BUCKET_NAME.replace('-', '_')}"
    }
], ensure_ascii=False)

# request_body is already a JSON string. Separate it from the instruction with a space,
# as the evaluation scenario's input_problem does.
prompt = f"{prompt} {request_body}"
prompt
# scenario.prompt = prompt

# scenario.chat_with_agent()


'Do the data preparation for the below details. Return the final Athena table name and the SQL table definition in the final response.[{"name": "DataLocation", "type": "string", "value": "s3://huthmac-automatedinsights/uploads/Customers.csv"}, {"name": "AthenaDatabase", "type": "string", "value": "huthmac_automatedinsights"}]'

## Agent Evaluation

In [None]:
# Look up the IDs of the agent created above; they are recorded in the evaluation agent
# config (agent_id / agent_alias_id) written in the next cell.
agent = scenario.agent

AGENT_ID = agent.get('agentId')

# get agent alias id
agent_aliases = bedrock_agent_client.list_agent_aliases(agentId=AGENT_ID)
alias_summaries = agent_aliases.get('agentAliasSummaries') or []
if not alias_summaries:
    raise RuntimeError(f"Agent {AGENT_ID} has no aliases; create one before running the evaluation.")

# Do not rely on list order. Prefer an explicitly created alias over the built-in test
# alias (TSTALIASID, which targets the DRAFT version), and fall back to the first summary
# when that is the only alias.
created_aliases = [summary for summary in alias_summaries if summary.get('agentAliasId') != 'TSTALIASID']
AGENT_ALIAS_ID = (created_aliases or alias_summaries)[0].get('agentAliasId')
print(f"AGENT_ID: {AGENT_ID}")
print(f"AGENT_ALIAS_ID: {AGENT_ALIAS_ID}")

In [45]:
# Agent description for the evaluation harness (utils.benchmark.run_agent_evaluation),
# written to ../dataengineer/agent.json at the end of this cell. It records the deployed
# agent/alias IDs, the agent instruction, and the actions of the DataEngineerAPI tool
# together with their input/output schemas.
# NOTE(review): "agent_name" is the base name "DataEngineer", not the suffixed name the
# agent was created with — confirm the evaluator does not need the deployed name.
agent_config = {
    "agent_id": AGENT_ID,
    "agent_alias_id": AGENT_ALIAS_ID,
    "human_id": "User",
    "agent_name": "DataEngineer",
    "agent_instruction": instruction,
    "tools": [
        {
            "tool_name": "DataEngineerAPI",
            "name": "DataEngineerAPI",
            "description": "Data Preparation to make data available for AI/ML and Analytics workloads",
            "actions": [
                # NOTE(review): GetDatabaseSchema and GetERM declare no input_schema, although
                # the local test cells send AthenaDatabase to them — confirm against the
                # action group's OpenAPI schema.
                {
                    "name": "GetDatabaseSchema",
                    "description": "Retrieve the SQL database schema from S3",
                    "output_schema": {
                        "data_type": "array",
                        "items": {
                            "type": "string",
                            "description": "SQL table definition statements"
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "GetERM",
                    "description": "Retrieve the Entity Relationship Diagram in json format",
                    "output_schema": {
                        "data_type": "array",
                        "items": {
                            "type": "string",
                            "description": "Entity Relationship Diagram in json format"
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "SaveERM",
                    "description": "Save the Entity Relationship Diagram in json format to S3",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "ERMData": {
                                "data_type": "string",
                                "description": "Entity Relationship Diagram in json format"
                            }
                        },
                        "required": ["ERMData"]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "properties": {
                            "erm_file_location": {
                                "data_type": "string",
                                "description": "S3 location where the ERM was saved"
                            }
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "GetInformationForSemanticTypeDetection",
                    "description": "Get information for semantic type detection",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "DataLocation": {
                                "data_type": "string",
                                "description": "S3 location of the data to analyze"
                            }
                        },
                        "required": ["DataLocation"]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "properties": {
                            "column_names": {
                                "data_type": "array",
                                "items": {
                                    "type": "string"
                                }
                            },
                            "data_sample": {
                                "data_type": "object"
                            }
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "SaveSQLTableDefinition",
                    "description": "Save SQL table definition to S3",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "SQL_Table_Definition": {
                                "data_type": "string",
                                "description": "SQL table definition"
                            },
                            "TableName": {
                                "data_type": "string",
                                "description": "File or table name"
                            }
                        },
                        # TableName is optional for this action; only the definition is required.
                        "required": ["SQL_Table_Definition"]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "properties": {
                            "sql_table_definition": {
                                "data_type": "string"
                            },
                            "sql_table_definition_file_location": {
                                "data_type": "string"
                            }
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "CreateAthenaTable",
                    "description": "Create an Athena table based on data and semantic types",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "DataLocation": {
                                "data_type": "string",
                                "description": "S3 location of the data"
                            },
                            "AthenaDatabase": {
                                "data_type": "string",
                                "description": "The Athena database to create the table in"
                            },
                            "Athena_Table_Create_SQL_statement": {
                                "data_type": "string",
                                "description": "Athena SQL table create statement with LOCATION DataLocation"
                            },
                            "TableName": {
                                "data_type": "string",
                                "description": "Name of the table to create"
                            }
                        },
                        "required": ["AthenaDatabase", "TableName", "Athena_Table_Create_SQL_statement", "DataLocation"]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "properties": {
                            "status": {
                                "data_type": "string"
                            },
                            "table_name": {
                                "data_type": "string"
                            },
                            "table_location": {
                                "data_type": "string"
                            },
                            "table_definition": {
                                "data_type": "string"
                            },
                            "query_results": {
                                "data_type": "object"
                            }
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                },
                {
                    "name": "QueryData",
                    "description": "Execute SQL query against Athena database",
                    "input_schema": {
                        "data_type": "object",
                        "properties": {
                            "AthenaDatabase": {
                                "data_type": "string",
                                "description": "The Athena database to query"
                            },
                            "SQLQuery": {
                                "data_type": "string",
                                "description": "SQL query to execute"
                            }
                        },
                        "required": ["AthenaDatabase", "SQLQuery"]
                    },
                    "output_schema": {
                        "data_type": "object",
                        "properties": {
                            "status": {
                                "data_type": "string"
                            },
                            "query": {
                                "data_type": "string"
                            },
                            "data": {
                                "data_type": "object"
                            }
                        }
                    },
                    "requires_confirmation": False,
                    "meta": {}
                }
            ],
            "tool_type": "Module",
            "meta": {}
        }
    ],
    # Single-agent setup: no collaborator agents are reachable.
    "reachable_agents": []
}
# save agent config to json file
with open('../dataengineer/agent.json', 'w') as f:
    json.dump(agent_config, f, indent=4)


In [40]:
# Evaluation scenarios for run_agent_evaluation: each scenario pairs a user request
# (input_problem) with assertions about which tools the agent is expected to execute.
data_preparation_request = (
    'Do the data preparation with the following details. Return the final Athena table name and the SQL table definition in the final response. '
    + json.dumps([
        {"name": "DataLocation", "type": "string", "value": f"s3://{S3_BUCKET_NAME}/uploads/Customers.csv"},
        {"name": "AthenaDatabase", "type": "string", "value": f"{S3_BUCKET_NAME.replace('-', '_')}"},
    ])
)

# (tool, purpose) pairs, in the order the instruction's workflow calls them.
expected_tool_calls = [
    ("GetInformationForSemanticTypeDetection", "to gather information about the data"),
    ("SaveSQLTableDefinition", "to save the SQL table definition to S3"),
    ("CreateAthenaTable", "to create the Athena table"),
    ("QueryData", "to verify the Athena table creation"),
    ("GetDatabaseSchema", "to get the latest database schema"),
    ("SaveERM", "to save the Entity Relationship Diagram in json format to S3"),
]
data_preparation_assertions = [
    f"agent: {tool} is executed {purpose}" for tool, purpose in expected_tool_calls
]

evaluation_scenarios = {
    "scenarios": [
        {
            "scenario": "DataPreparation",
            "input_problem": data_preparation_request,
            "assertions": data_preparation_assertions,
        }
    ]
}

# Persist next to agent.json so the evaluation cell can load both files.
with open('../dataengineer/scenarios.json', 'w') as scenarios_file:
    json.dump(evaluation_scenarios, scenarios_file, indent=4)


In [None]:
from utils.benchmark import run_agent_evaluation

# Run the benchmark over the scenario and agent configs written above, then show the
# headline goal-success metrics and the per-assertion verdicts.
dataset_dir = "../dataengineer"
results = run_agent_evaluation(
    scenario_filepath=f"{dataset_dir}/scenarios.json",
    agent_filepath=f"{dataset_dir}/agent.json",
    llm_judge_id=DATAENGINEER_AGENT_EVAL_PROFILE_ARN,
    region=REGION,
    session=session,
)

if results is None:
    print("Error: Please check for errors in the evaluation.")
else:
    # One row of summary metrics.
    metric_names = ['user_gsr', 'system_gsr', 'overall_gsr', 'partial_gsr', 'scenario_count', 'conversation_count']
    metrics_df = pd.DataFrame({name: [results[name]] for name in metric_names})

    # One row per assertion, tagged with the trajectory (conversation) it belongs to.
    assertion_fields = ['assertion_type', 'assertion', 'answer', 'evidence']
    assertions_df = pd.DataFrame([
        {
            'trajectory_index': conversation['trajectory_index'],
            **{field: verdict[field] for field in assertion_fields},
        }
        for conversation in results['conversation_evals']
        for verdict in conversation['report']
    ])

    print("High-level Metrics:")
    display(metrics_df)

    print("\nDetailed Assertions:")
    display(assertions_df)

## Summary

- We first created a Docker container that bundles all of the tools/functions the agent can use.
- Next, we created a Bedrock Agent with an Action Group that uses this Docker container in a Lambda function as its execution environment.
- We then defined an evaluation scenario (`DataPreparation`) whose assertions check that each step of the agent's data-preparation workflow is executed.
- Finally, we ran the agent evaluation and reviewed the high-level metrics and the per-assertion results.

