# AWS Bedrock RAG Infrastructure Setup

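This notebook guides you through creating the AWS infrastructure for a Bedrock RAG implementation.
**VPC creation is handled by an external script (`create_vpc_script.py`).**
The PostgreSQL database and the Bedrock knowledge base are assumed to exist already. This notebook creates the remaining resources step by step: the IAM roles, the database-retrieval Lambda function, and the Bedrock Agent.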

# 1. Imports
Import the necessary libraries.

In [None]:
import boto3
import json
from botocore.exceptions import ClientError
import time
import zipfile
import os
import io

# 2. Configuration
Define configuration variables for the infrastructure.

In [None]:
AGENT_NAME = "AppDemoAgentWKB"  # Choose a unique name for your agent
KNOWLEDGE_BASE_ID = "" # e.g., "ABCDEFGHIJ"
KNOWLEDGE_BASE_ARN = "arn:aws:bedrock:us-west-2::knowledge-base/" # e.g., "arn:aws:bedrock:us-east-1:123456789012:knowledge-base/ABCDEFGHIJ"
AGENT_EXECUTION_ROLE_ARN = "" # e.g., "arn:aws:iam::123456789012:role/MyBedrockAgentRole"
FOUNDATION_MODEL_ARN = "arn:aws:bedrock:us-west-2::inference-profile/us.anthropic.claude-3-5-haiku-20241022-v1:0" # Example: Change to your desired model/region

# ** OPTIONAL: Customize these values **
AGENT_DESCRIPTION = "Agent specialized in construction and heavy machines with access to a knowledge base and data in a relational database"
AGENT_INSTRUCTION = """
Role: You are an AI assistant specialized in providing information about heavy machinery operations, maintenance, alerts, and specifications.

Core Task: Answer user questions accurately and concisely by leveraging information from two primary sources:

A PostgreSQL Database containing operational logs, maintenance records, and alert history.

A Knowledge Base (KB) containing detailed specifications for various machine models.

Available Tools/Information:

SQL Query Tool: You can execute SQL queries against the PostgreSQL database. The database schema is defined as follows:

Machines Table: Basic machine info (machine_id, machine_type, model, serial_number, manufacture_date, acquisition_date).

OperationalData Table: Daily metrics (data_id, machine_id, operation_date, hours_operated, fuel_consumed_liters, idle_time_hours, material_moved_cubic_meters, distance_traveled_km, loads_lifted, weight_lifted_tons, max_load_tons, maintenance_flag). Links to Machines via machine_id.

MaintenanceEvents Table: Maintenance records (event_id, machine_id, event_date, event_type, description, parts_replaced, cost, downtime_hours, technician_name). Links to Machines via machine_id.

AlertEvents Table: Alert history (alert_id, machine_id, alert_date, alert_type, severity, description, resolved, resolution_date, resolution_notes). Links to Machines via machine_id.

(Refer to the detailed schema summary for precise column types and descriptions).

Knowledge Base Query Tool: You can query a knowledge base to retrieve machine specifications (e.g., engine power, weight, capacity, dimensions) based on machine type and model. Assume you have a function like query_knowledge_base(machine_type, model) that returns relevant specifications.

Workflow:

Understand the User Query: Analyze the user's question to determine the specific information needed. Identify keywords related to:

Specific machines (machine_id, serial_number).

Machine types or models (Bulldozer, DT1000).

Timeframes (last week, March 2025, yesterday).

Operational metrics (fuel consumption, hours operated, material moved).

Maintenance (cost, downtime, parts replaced, event type).

Alerts (critical alerts, unresolved warnings, alert type).

Specifications (engine size, lifting capacity, weight).

Identify Information Source(s):

If the query relates to operational history, maintenance records, or alert events for specific machines or timeframes, the PostgreSQL Database is the primary source.

If the query asks for general specifications, capabilities, or technical details of a machine model, the Knowledge Base is the primary source.

If the query requires combining operational data with specifications (e.g., "Which bulldozer model had the highest fuel consumption per hour last month?"), both sources will be needed.

Formulate Queries:

SQL Queries:

Construct precise PostgreSQL queries based on the schema.

Use appropriate WHERE clauses for filtering by machine_id, machine_type, model, dates (operation_date, event_date, alert_date), severity, event_type, etc.

Use aggregate functions (SUM, AVG, COUNT, MAX, MIN) for calculations.

Use JOIN clauses (primarily Machines with other tables on machine_id) when information from multiple tables is needed.

Handle potential NULL values appropriately, especially in calculations (e.g., use COALESCE or check for NULL before division).

Pay attention to data types (e.g., use date functions for date comparisons).

Knowledge Base Queries:

Extract the relevant machine_type and model from the user query or from the database results.

Use the query_knowledge_base(machine_type, model) tool.

Synthesize Information (If necessary): If information from both the database and KB is required, combine the results logically. For example, retrieve operational data from the database, identify the relevant models, query the KB for specs of those models, and then present the combined findings.

Generate Response:

Provide a clear, direct answer to the user's question.

If presenting data, format it readably (e.g., using tables or lists).

If calculations were performed, state the result clearly (e.g., "The average fuel consumption for Bulldozers in March 2025 was X liters per hour.").

If data is unavailable for the specific request (e.g., no records found for the date range), state that clearly.

Do not just output raw SQL results unless specifically asked. Summarize and explain the findings.

Cite the source if relevant (e.g., "Based on operational data..." or "According to the specifications...").
"""
IDLE_SESSION_TTL_SECONDS = 1800  # 30 minutes timeout
AWS_REGION = "us-west-2" # Change if your KB/resources are in a different region
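Agent, role, and knowledge-base ARNs follow the layout `arn:partition:service:region:account-id:resource`. The `KNOWLEDGE_BASE_ARN` placeholder above leaves both the account ID and the knowledge base ID empty. The sketch below shows a small helper, the hypothetical `build_arn`, that assembles such ARNs from their parts and rejects a missing account ID. It assumes the standard `aws` partition. It is only meant for resource types that carry an account ID; foundation-model ARNs, for example, leave that field empty.

```python
def build_arn(service: str, region: str, account_id: str, resource: str,
              partition: str = "aws") -> str:
    """Assemble an ARN of the form arn:partition:service:region:account-id:resource."""
    # Agent and knowledge-base ARNs carry a 12-digit account ID
    if not (account_id.isdigit() and len(account_id) == 12):
        raise ValueError(f"Expected a 12-digit AWS account ID, got {account_id!r}")
    return f"arn:{partition}:{service}:{region}:{account_id}:{resource}"


# Placeholder values only; substitute your own account ID and knowledge base ID
example_kb_arn = build_arn("bedrock", "us-west-2", "123456789012", "knowledge-base/ABCDEFGHIJ")
print(example_kb_arn)  # arn:aws:bedrock:us-west-2:123456789012:knowledge-base/ABCDEFGHIJ
```

In the notebook itself, you can look up the account ID with `boto3.client('sts').get_caller_identity()['Account']`.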


In [None]:
session = boto3.Session()
iam_client = session.client('iam')

role_name = "AppDemoBedrock-execution-role"

assume_role_policy = json.dumps({
    "Version": "2012-10-17",
    "Statement": [{
        "Effect": "Allow",
        "Principal": {"Service": "bedrock.amazonaws.com"},
        "Action": "sts:AssumeRole"
    }]
})

role_response = iam_client.create_role(
    RoleName=role_name,
    AssumeRolePolicyDocument=assume_role_policy,
    Description=f"Execution role for Bedrock Agent {AGENT_NAME}",
    Tags=[{'Key': 'Name', 'Value': role_name}]
)
role_arn = role_response['Role']['Arn']
AGENT_EXECUTION_ROLE_ARN = role_arn

print(f"Created Role ARN: {role_arn}. Waiting 15s for propagation...")
time.sleep(15)  # IAM changes are eventually consistent
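The trust policy above lets any Bedrock principal assume the role. AWS recommends adding `aws:SourceAccount` and `aws:SourceArn` conditions to service-role trust policies to guard against the confused-deputy problem. The sketch below builds such a document for agents in a single account and region. The account ID is a placeholder, and the exact condition values should be checked against the current Bedrock service-role documentation.

```python
import json


def bedrock_agent_trust_policy(account_id: str, region: str) -> str:
    """Trust policy letting only Bedrock agents in one account/region assume the role."""
    policy = {
        "Version": "2012-10-17",
        "Statement": [{
            "Effect": "Allow",
            "Principal": {"Service": "bedrock.amazonaws.com"},
            "Action": "sts:AssumeRole",
            # Restrict which account and which agents can use this role
            "Condition": {
                "StringEquals": {"aws:SourceAccount": account_id},
                "ArnLike": {"aws:SourceArn": f"arn:aws:bedrock:{region}:{account_id}:agent/*"},
            },
        }],
    }
    return json.dumps(policy)


trust_doc = bedrock_agent_trust_policy("123456789012", "us-west-2")  # placeholder account
print(trust_doc)
```

You could pass the returned string to `iam_client.create_role` as `AssumeRolePolicyDocument` in place of `assume_role_policy`.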

In [None]:
bedrockagent_policy_name = f"{AGENT_NAME}-execution-policy-inline"
bedrockagent_document = json.dumps({"Version": "2012-10-17", "Statement": [
        {"Sid": "BedrockAgent", "Effect": "Allow", "Action": "bedrock:*", "Resource": "*"},
    ]})

In [None]:
print(f"Attaching inline policy {bedrockagent_policy_name} to role {role_name}")
iam_client.put_role_policy(RoleName=role_name, PolicyName=bedrockagent_policy_name, PolicyDocument=bedrockagent_document)
# AWSLambdaRole lets the agent role invoke Lambda functions
iam_client.attach_role_policy(RoleName=role_name, PolicyArn="arn:aws:iam::aws:policy/service-role/AWSLambdaRole")
print("Policy attached. Waiting 10s...")
time.sleep(10)  # allow the policy changes to propagate
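The inline policy above grants `bedrock:*` on every resource. A narrower alternative is sketched below. It assumes the agent only needs to invoke its foundation model and retrieve from one knowledge base. The action list shown (`bedrock:InvokeModel`, `bedrock:InvokeModelWithResponseStream`, `bedrock:Retrieve`) is an assumption to verify against the Bedrock agent service-role documentation, especially when the agent uses a cross-region inference profile. In that case, the profile and its underlying model ARNs may all need to be listed.

```python
import json
from typing import List


def scoped_agent_policy(model_arns: List[str], knowledge_base_arn: str) -> str:
    """Inline policy limited to invoking the given models and retrieving from one KB."""
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "InvokeFoundationModel",
                "Effect": "Allow",
                "Action": ["bedrock:InvokeModel", "bedrock:InvokeModelWithResponseStream"],
                "Resource": model_arns,
            },
            {
                "Sid": "RetrieveFromKnowledgeBase",
                "Effect": "Allow",
                "Action": ["bedrock:Retrieve"],
                "Resource": [knowledge_base_arn],
            },
        ],
    }
    return json.dumps(policy)


# Placeholder ARNs for illustration only
scoped_doc = scoped_agent_policy(
    ["arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0"],
    "arn:aws:bedrock:us-west-2:123456789012:knowledge-base/ABCDEFGHIJ",
)
print(scoped_doc)
```

You could pass the result to `put_role_policy` as `PolicyDocument` in place of `bedrockagent_document`.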

In [None]:
try:
    bedrock_agent_client = boto3.client('bedrock-agent', region_name=AWS_REGION)
except Exception as e:
    print(f"Error creating Boto3 client: {e}")
    raise  # stop here: later cells depend on this client

In [None]:
knowledge_base_configuration = [
    {
        'knowledgeBaseId': KNOWLEDGE_BASE_ID,
        'description': "Use this knowledge base to find machine specifications (e.g., engine power, weight, capacity, dimensions) for heavy machinery models."
    }
]

# 3. Create Bedrock Agent
Create a Bedrock Agent that calls a Lambda function for database queries and uses a previously created knowledge base.

In [None]:
FUNCTION_NAME = "ApplicationDemoDBRetrieval" # Choose a unique name for your Lambda function
IAM_ROLE_NAME = "DBRetrieverLambdaExecutionRole" 

BASIC_EXECUTION_POLICY_ARN = "arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
SECRET_RETRIEVAL_POLICY_ARN = "arn:aws:iam::aws:policy/SecretsManagerReadWrite"

# ** OPTIONAL: Customize these values **
FUNCTION_DESCRIPTION = "Lambda function to retrieve database data for Bedrock Agent."
RUNTIME = "python3.12" # Choose a supported Python runtime (e.g., python3.9, python3.10, python3.11, python3.12)
HANDLER = "lambda_function.lambda_handler" # Standard handler for Python: filename.function_name
MEMORY_SIZE = 128 # Memory in MB
TIMEOUT = 30 # Timeout in seconds
AWS_REGION = "us-west-2" # Change to your desired AWS region

In [None]:
assume_role_policy_document = json.dumps({
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "lambda.amazonaws.com"
            },
            "Action": "sts:AssumeRole"
        }
    ]
})

In [None]:
iam_client = boto3.client('iam', region_name=AWS_REGION)

In [None]:
create_role_response = iam_client.create_role(
    RoleName=IAM_ROLE_NAME,
    AssumeRolePolicyDocument=assume_role_policy_document,
    Description=f"IAM role assumed by the {FUNCTION_NAME} Lambda function (CloudWatch Logs and Secrets Manager access)."
    # You can add tags here if needed:
    # Tags=[{'Key': 'Project', 'Value': 'BedrockAgentDemo'}]
)
role_arn = create_role_response['Role']['Arn']
role_id = create_role_response['Role']['RoleId']

LAMBDA_EXECUTION_ROLE_ARN = role_arn

In [None]:
iam_client.attach_role_policy(RoleName=IAM_ROLE_NAME, PolicyArn=BASIC_EXECUTION_POLICY_ARN)
iam_client.attach_role_policy(RoleName=IAM_ROLE_NAME, PolicyArn=SECRET_RETRIEVAL_POLICY_ARN)
# Give IAM time to propagate the new role before Lambda tries to assume it
time.sleep(10)
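IAM is eventually consistent. If `create_function` runs too soon after the role is created, it can fail with an error saying the role cannot be assumed by Lambda. A generic retry helper is one way to absorb that delay instead of relying on a fixed wait, as sketched below. The assumption that this failure surfaces as the Lambda client's `InvalidParameterValueException` should be checked against the error you actually see. The `flaky` function is a hypothetical stand-in so the helper can be exercised offline.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry(fn: Callable[[], T], retry_on: Tuple[Type[BaseException], ...],
          attempts: int = 5, delay_seconds: float = 5.0) -> T:
    """Call fn, retrying on the given exception types with a fixed delay."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retry_on as exc:
            if attempt == attempts:
                raise  # give up and surface the last error
            print(f"Attempt {attempt} failed ({exc}); retrying in {delay_seconds}s...")
            time.sleep(delay_seconds)
    raise AssertionError("unreachable")


# Offline demonstration: a stand-in call that fails twice, then succeeds
calls = {"n": 0}


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("role not assumable yet")
    return "created"


result = retry(flaky, (ConnectionError,), attempts=5, delay_seconds=0)
print(result, calls["n"])  # created 3

# Assumed usage against the real API:
# retry(lambda: lambda_client.create_function(...),
#       (lambda_client.exceptions.InvalidParameterValueException,))
```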

In [None]:
lambda_code = """
import os
import json
import datetime

import pg8000.native
from botocore import session


class DBConnection:
    # Cache the secret across warm invocations of the function
    _secret = None

    @classmethod
    def get_config(cls):
        if not cls._secret:
            client = session.get_session().create_client('secretsmanager')
            response = client.get_secret_value(SecretId=os.environ['SECRET_ARN'])
            cls._secret = json.loads(response['SecretString'])
        return cls._secret


def connect():
    # Open a TLS connection using the credentials stored in Secrets Manager
    config = DBConnection.get_config()
    return pg8000.native.Connection(
        user=config['username'],
        password=config['password'],
        host=config['host'],
        database='postgres',
        ssl_context=True
    )


def default_converter(obj):
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    return str(obj)  # fallback for Decimal and other non-JSON types


def get_tables(schema='public'):
    print(f"Schema: {schema}")
    with connect() as conn:
        # Bind the schema as a parameter instead of formatting it into the SQL
        return conn.run(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = :schema",
            schema=schema
        )


def get_columns(table, schema='public'):
    print(f"Table: {table}, Schema: {schema}")
    with connect() as conn:
        return conn.run(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = :schema AND table_name = :table",
            schema=schema, table=table
        )


def execute_query(query):
    # The agent generates this SQL; use database credentials with read-only access
    print(query)
    with connect() as conn:
        return conn.run(query)


def get_parameter(event, name, default=None):
    # Query parameters arrive as a list of {'name', 'type', 'value'} dicts
    for param in event.get('parameters', []):
        if param.get('name') == name:
            return param.get('value')
    return default


def get_body_property(event, name, default=None):
    # JSON request-body properties use the same list-of-dicts format
    content = event.get('requestBody', {}).get('content', {})
    for prop in content.get('application/json', {}).get('properties', []):
        if prop.get('name') == name:
            return prop.get('value')
    return default


def build_response(event, status_code, body):
    # Response envelope expected by Bedrock Agents for OpenAPI action groups
    return {
        'messageVersion': '1.0',
        'response': {
            'actionGroup': event.get('actionGroup'),
            'apiPath': event.get('apiPath'),
            'httpMethod': event.get('httpMethod'),
            'httpStatusCode': status_code,
            'responseBody': {
                'application/json': {
                    'body': json.dumps(body, default=default_converter)
                }
            }
        }
    }


def lambda_handler(event, context):
    print(event)
    try:
        api_path = event.get('apiPath')

        if api_path == '/tables':
            result = get_tables(get_parameter(event, 'schema', 'public'))
        elif api_path == '/columns':
            table = get_parameter(event, 'table')
            if not table:
                raise ValueError("Missing required parameter 'table'")
            result = get_columns(table, get_parameter(event, 'schema', 'public'))
        elif api_path == '/query':
            query = get_body_property(event, 'query')
            if not query:
                raise ValueError("Missing required property 'query'")
            result = execute_query(query)
        else:
            raise ValueError(f'Invalid apiPath: {api_path}')

        return build_response(event, 200, result)

    except Exception as e:
        print(f"Error: {e}")
        return build_response(event, 500, {'error': str(e)})
"""

In [None]:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Write the lambda code to a file named lambda_function.py within the zip
    # The filename must match the first part of the HANDLER variable
    zipf.writestr('lambda_function.py', lambda_code)

zip_buffer.seek(0) # Rewind the buffer to the beginning
deployment_package = zip_buffer.read()
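Lambda resolves `HANDLER` (`lambda_function.lambda_handler`) by importing `lambda_function.py` from the zip and looking up `lambda_handler`. A syntax error in the embedded code string therefore only shows up at the first invocation. The sketch below checks the package locally before upload. `build_package` and `check_package` are hypothetical helpers that mirror the zip cell above.

```python
import ast
import io
import zipfile


def build_package(source: str, module_name: str = "lambda_function") -> bytes:
    """Zip one Python module in memory, mirroring the cell above."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(f"{module_name}.py", source)
    return buffer.getvalue()


def check_package(package: bytes, handler: str) -> None:
    """Raise ValueError or SyntaxError if the package cannot serve `handler`."""
    module_name, _, function_name = handler.rpartition(".")
    expected_file = f"{module_name}.py"
    with zipfile.ZipFile(io.BytesIO(package)) as zf:
        if expected_file not in zf.namelist():
            raise ValueError(f"{expected_file} not found in {zf.namelist()}")
        source = zf.read(expected_file).decode("utf-8")
    # ast.parse raises SyntaxError here instead of at the first invocation
    tree = ast.parse(source, filename=expected_file)
    top_level = {node.name for node in tree.body if isinstance(node, ast.FunctionDef)}
    if function_name not in top_level:
        raise ValueError(f"{function_name} is not defined at the top level of {expected_file}")


sample_source = "def lambda_handler(event, context):\n    return {'ok': True}\n"
package = build_package(sample_source)
check_package(package, "lambda_function.lambda_handler")  # passes silently
```

In the notebook, `check_package(deployment_package, HANDLER)` would validate the real package before `create_function` is called.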

In [None]:
lambda_client = boto3.client('lambda', region_name=AWS_REGION)

response = lambda_client.create_function(
    FunctionName=FUNCTION_NAME,
    Runtime=RUNTIME,
    Role=LAMBDA_EXECUTION_ROLE_ARN,
    Handler=HANDLER,
    Code={'ZipFile': deployment_package},
    Description=FUNCTION_DESCRIPTION,
    Timeout=TIMEOUT,
    MemorySize=MEMORY_SIZE,
    Publish=True,
    Architectures=['arm64'],
    # Replace with the ARN of a Lambda layer that provides pg8000
    Layers=['arn:aws:lambda:us-west-2::layer:pg8000:2'],
    # Replace 'my_value' with the ARN of the Secrets Manager secret holding the DB credentials
    Environment={'Variables': {'SECRET_ARN': 'my_value'}},
    # To reach a database in a private subnet, add a VPC configuration:
    # VpcConfig={'SubnetIds': ['subnet-xxxxx'], 'SecurityGroupIds': ['sg-yyyyy']}
)



In [None]:
function_arn = response.get('FunctionArn')

print("\n--- Lambda Function Creation Successful ---")
print(f"Function Name: {response.get('FunctionName')}")
print(f"Function ARN: {function_arn}")
print(f"Runtime: {response.get('Runtime')}")
print(f"Handler: {response.get('Handler')}")
print(f"Role ARN: {response.get('Role')}")
print(f"Version: {response.get('Version')}")
print("-----------------------------------------")
print(f"\nUse this ARN for the LAMBDA_FUNCTION_ARN in your Bedrock Agent Action Group script:")
print(function_arn)

In [None]:
LAMBDA_FUNCTION_ARN = function_arn

In [None]:
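# The SourceArn must include the account ID, e.g. arn:aws:bedrock:us-west-2:123456789012:agent/*
account_id = boto3.client('sts').get_caller_identity()['Account']

response = lambda_client.add_permission(
    FunctionName=FUNCTION_NAME,
    StatementId='bedrockAccess',
    Action='lambda:InvokeFunction',
    Principal='bedrock.amazonaws.com',
    SourceArn=f'arn:aws:bedrock:{AWS_REGION}:{account_id}:agent/*'
)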
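With this permission in place, Bedrock invokes the function with an action-group event. Query parameters arrive in a `parameters` list, and the JSON body arrives in `requestBody.content['application/json'].properties`. Each entry is a `name`/`type`/`value` dict, which is the shape `lambda_handler` reads. The sketch below shows two sample events with hypothetical values and a helper, the hypothetical `collect_parameters`, that flattens them into a dict. This lets you exercise the parsing logic locally without a database or an agent.

```python
from typing import Any, Dict


def collect_parameters(event: Dict[str, Any]) -> Dict[str, Any]:
    """Merge query parameters and JSON request-body properties into one dict."""
    values = {p["name"]: p["value"] for p in event.get("parameters", [])}
    body = event.get("requestBody", {}).get("content", {}).get("application/json", {})
    values.update({p["name"]: p["value"] for p in body.get("properties", [])})
    return values


# Hypothetical sample events, shaped the way the handler reads them
columns_event = {
    "actionGroup": "OperationalDataRetrieval",
    "apiPath": "/columns",
    "httpMethod": "GET",
    "parameters": [{"name": "table", "type": "string", "value": "machines"}],
}
query_event = {
    "actionGroup": "OperationalDataRetrieval",
    "apiPath": "/query",
    "httpMethod": "POST",
    "requestBody": {
        "content": {
            "application/json": {
                "properties": [
                    {"name": "query", "type": "string", "value": "SELECT COUNT(*) FROM machines"}
                ]
            }
        }
    },
}

column_params = collect_parameters(columns_event)
print(column_params)                            # {'table': 'machines'}
print(column_params.get("schema", "public"))    # public (default when the agent omits it)
print(collect_parameters(query_event)["query"]) # SELECT COUNT(*) FROM machines
```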

In [None]:
OPENAPI_SCHEMA = {
    "openapi": "3.0.3",
    "info": {
        "title": "Database Connection API",
        "description": "An API for connecting to and querying PostgreSQL databases through AWS Lambda.\nThis API provides endpoints to list tables, retrieve column information, and execute custom queries.\n",
        "version": "1.0.0"
    },
    "paths": {
        "/tables": {
            "get": {
                "summary": "List database tables",
                "description": "Returns a list of all tables in the specified schema",
                "operationId": "getTables",
                "parameters": [
                    {
                        "name": "schema",
                        "in": "query",
                        "description": "Database schema name",
                        "required": False,
                        "schema": {
                            "type": "string",
                            "default": "public"
                        }
                    }
                ],
                "responses": {
                    "200": {
                        "description": "A list of table names",
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "array",
                                    "items": {
                                        "type": "object",
                                        "properties": {
                                            "table_name": {
                                                "type": "string"
                                            }
                                        }
                                    }
                                },
                                "example": [
                                    {
                                        "table_name": "users"
                                    },
                                    {
                                        "table_name": "products"
                                    },
                                    {
                                        "table_name": "orders"
                                    }
                                ]
                            }
                        }
                    },
                    "500": {
                        "$ref": "#/components/responses/Error"
                    }
                },
                "security": [
                    {
                        "ApiKeyAuth": []
                    }
                ]
            }
        },
        "/columns": {
            "get": {
                "summary": "Get column information for a table",
                "description": "Returns column names and data types for the specified table",
                "operationId": "getColumns",
                "parameters": [
                    {
                        "name": "table",
                        "in": "query",
                        "description": "Table name",
                        "required": True,
                        "schema": {
                            "type": "string"
                        }
                    },
                    {
                        "name": "schema",
                        "in": "query",
                        "description": "Database schema name",
                        "required": False,
                        "schema": {
                            "type": "string",
                            "default": "public"
                        }
                    }
                ],
                "responses": {
                    "200": {
                        "description": "A list of column definitions",
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "array",
                                    "items": {
                                        "type": "object",
                                        "properties": {
                                            "column_name": {
                                                "type": "string"
                                            },
                                            "data_type": {
                                                "type": "string"
                                            }
                                        }
                                    }
                                },
                                "example": [
                                    {
                                        "column_name": "id",
                                        "data_type": "integer"
                                    },
                                    {
                                        "column_name": "name",
                                        "data_type": "character varying"
                                    },
                                    {
                                        "column_name": "created_at",
                                        "data_type": "timestamp without time zone"
                                    }
                                ]
                            }
                        }
                    },
                    "500": {
                        "$ref": "#/components/responses/Error"
                    }
                },
                "security": [
                    {
                        "ApiKeyAuth": []
                    }
                ]
            }
        },
        "/query": {
            "post": {
                "summary": "Execute a custom SQL query",
                "description": "Executes a SQL query and returns the results",
                "operationId": "executeQuery",
                "requestBody": {
                    "description": "SQL query to execute",
                    "required": True,
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "required": [
                                    "query"
                                ],
                                "properties": {
                                    "query": {
                                        "type": "string",
                                        "description": "SQL query to execute"
                                    }
                                }
                            },
                            "example": {
                                "query": "SELECT * FROM users LIMIT 10"
                            }
                        }
                    }
                },
                "responses": {
                    "200": {
                        "description": "Query results",
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "array",
                                    "items": {
                                        "type": "object",
                                        "additionalProperties": True
                                    }
                                },
                                "example": [
                                    {
                                        "id": 1,
                                        "name": "John Doe",
                                        "email": "john@example.com",
                                        "created_at": "2023-01-15T10:30:00Z"
                                    },
                                    {
                                        "id": 2,
                                        "name": "Jane Smith",
                                        "email": "jane@example.com",
                                        "created_at": "2023-02-20T14:15:00Z"
                                    }
                                ]
                            }
                        }
                    },
                    "500": {
                        "$ref": "#/components/responses/Error"
                    }
                },
                "security": [
                    {
                        "ApiKeyAuth": []
                    }
                ]
            }
        }
    },
    "components": {
        "securitySchemes": {
            "ApiKeyAuth": {
                "type": "apiKey",
                "name": "x-api-key",
                "in": "header"
            }
        },
        "responses": {
            "Error": {
                "description": "Error response",
                "content": {
                    "application/json": {
                        "schema": {
                            "type": "object",
                            "properties": {
                                "error": {
                                    "type": "string",
                                    "description": "Error message"
                                }
                            }
                        },
                        "example": {
                            "error": "Database connection failed"
                        }
                    }
                }
            }
        }
    },
    "security": [
        {
            "ApiKeyAuth": []
        }
    ]
}
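`lambda_handler` dispatches on `apiPath`, so every path declared in the schema needs a matching branch. A path added to one side but not the other only fails at run time. The sketch below is a small consistency check, run against a trimmed, hypothetical stand-in for `OPENAPI_SCHEMA`. The set of handled paths is copied by hand from the handler's `if`/`elif` chain.

```python
from typing import Dict, List, Set, Tuple

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "head", "options"}


def list_operations(schema: Dict) -> List[Tuple[str, str, str]]:
    """Return (path, METHOD, operationId) for every operation in an OpenAPI dict."""
    ops = []
    for path, item in schema.get("paths", {}).items():
        for method, op in item.items():
            if method.lower() in HTTP_METHODS:
                ops.append((path, method.upper(), op.get("operationId", "")))
    return ops


def unhandled_paths(schema: Dict, handled: Set[str]) -> Set[str]:
    """Paths declared in the schema that the Lambda handler does not route."""
    return {path for path, _, _ in list_operations(schema)} - handled


# Trimmed stand-in for OPENAPI_SCHEMA (illustration only)
mini_schema = {
    "openapi": "3.0.3",
    "paths": {
        "/tables": {"get": {"operationId": "getTables"}},
        "/columns": {"get": {"operationId": "getColumns"}},
        "/query": {"post": {"operationId": "executeQuery"}},
    },
}
print(list_operations(mini_schema))
print(unhandled_paths(mini_schema, {"/tables", "/columns", "/query"}))  # set()
```

Against the real schema, `unhandled_paths(OPENAPI_SCHEMA, {"/tables", "/columns", "/query"})` should likewise return an empty set.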

In [None]:
ACTION_GROUP_CONFIG = [
    {
        'actionGroupName': 'DatabaseRetrieval',
        'description': 'Retrieves Operational data from a postgres database',
        'actionGroupExecutor': {
            'lambda': LAMBDA_FUNCTION_ARN
        },
        'apiSchema': {
            'payload': json.dumps(OPENAPI_SCHEMA)  # Use S3 reference for large schemas
        }
    }
]

In [None]:
response = bedrock_agent_client.create_agent(
    agentName=AGENT_NAME,
    agentResourceRoleArn=AGENT_EXECUTION_ROLE_ARN,
    foundationModel=FOUNDATION_MODEL_ARN,
    instruction=AGENT_INSTRUCTION,
    description=AGENT_DESCRIPTION,
    idleSessionTTLInSeconds=IDLE_SESSION_TTL_SECONDS
)

agent_details = response.get('agent', {})
agent_id = agent_details.get('agentId')
agent_arn = agent_details.get('agentArn')
agent_status = agent_details.get('agentStatus')

print("\n--- Agent Creation Initiated ---")
print(f"Agent Name: {agent_details.get('agentName')}")
print(f"Agent ID: {agent_id}")
print(f"Agent ARN: {agent_arn}")
print(f"Initial Status: {agent_status}")
print("------------------------------")
print("\nAgent creation can take a few minutes.")
print("Check the AWS Bedrock console for the current status.")

In [None]:
# A new agent moves from CREATING to NOT_PREPARED; it only becomes PREPARED
# after prepare_agent is called, so wait here only for creation to finish.
while agent_status == 'CREATING':
    print(f"Current agent status: {agent_status}. Waiting...")
    time.sleep(30)  # Check every 30 seconds
    get_agent_response = bedrock_agent_client.get_agent(agentId=agent_id)
    agent_status = get_agent_response.get('agent', {}).get('agentStatus')

if agent_status == 'FAILED':
    print(f"\nAgent '{agent_id}' FAILED creation. Check the console for details.")
else:
    print(f"\nAgent '{agent_id}' created with status {agent_status}.")



In [None]:
response = bedrock_agent_client.associate_agent_knowledge_base(
    agentId=agent_id,
    agentVersion='DRAFT',
    description=knowledge_base_configuration[0]['description'],
    knowledgeBaseId=KNOWLEDGE_BASE_ID,
    knowledgeBaseState='ENABLED'
)

In [None]:
response = bedrock_agent_client.create_agent_action_group(
    actionGroupExecutor={
        'lambda': function_arn
    },
    actionGroupName='OperationalDataRetrieval',
    actionGroupState='ENABLED',
    agentId=agent_id,
    agentVersion='DRAFT',
    apiSchema={
        'payload': json.dumps(OPENAPI_SCHEMA)
    },
    description='Action group to retrieve operational data for the heavy machinery'
)
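The agent's DRAFT version only picks up the associated knowledge base and action group once it is prepared. None of the cells above call `prepare_agent`, so the agent remains `NOT_PREPARED`. The sketch below prepares the agent and polls `get_agent` until it reaches `PREPARED` or `FAILED`, raising an error if a timeout expires first. `FakeAgentClient` is a hypothetical stub that mimics the two `bedrock-agent` client calls used, so the polling logic can be exercised offline.

```python
import time
from typing import Any, List


def prepare_and_wait(client: Any, agent_id: str, poll_seconds: float = 10.0,
                     timeout_seconds: float = 600.0) -> str:
    """Prepare the agent's DRAFT version and poll get_agent until it settles."""
    client.prepare_agent(agentId=agent_id)
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = client.get_agent(agentId=agent_id)["agent"]["agentStatus"]
        if status in ("PREPARED", "FAILED"):
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Agent {agent_id} still {status} after {timeout_seconds}s")
        time.sleep(poll_seconds)


class FakeAgentClient:
    """Stand-in for the bedrock-agent client, for offline testing only."""

    def __init__(self, statuses: List[str]):
        self._statuses = list(statuses)
        self.prepared = False

    def prepare_agent(self, agentId: str) -> dict:
        self.prepared = True
        return {"agentId": agentId, "agentStatus": "PREPARING"}

    def get_agent(self, agentId: str) -> dict:
        # Return queued statuses in order, then keep repeating the last one
        status = self._statuses.pop(0) if len(self._statuses) > 1 else self._statuses[0]
        return {"agent": {"agentId": agentId, "agentStatus": status}}


fake = FakeAgentClient(["PREPARING", "PREPARING", "PREPARED"])
final_status = prepare_and_wait(fake, "AGENT123", poll_seconds=0)
print(final_status)  # PREPARED
```

Against the real service, the call would be `prepare_and_wait(bedrock_agent_client, agent_id)`.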