# In [None]:  (Jupyter notebook cell marker left over from the export)
import boto3
import json
import time
import uuid
from datetime import datetime


# Constants
# The resource identifiers below are empty placeholders in this file; they
# presumably need to be filled in for the target AWS account before running.
EVENT_BUS_NAME = ""  # passed as EventBusName on every EventBridge entry (Step 2)
EVENT_EMITTER_FUNCTION = ""  # Lambda FunctionName invoked in Step 1
SEND_EVENTS_TO_SAGEMAKER_FUNCTION = ""  # Lambda FunctionName invoked in Step 4
STATE_MACHINE_ARN = ""  # state machine started in Step 5
WAIT_TIME = 10  # seconds handed to time.sleep (Step 3 wait and Step 5 polling interval)
DYNAMODB_TABLE = ''  # TableName used by log_to_dynamodb
PRIMARY_SQS_QUEUE_URL = ''  # queue that mismatched events are sent to and later read from
DLQ_NAME = ''  # name of the dead-letter queue created by the __main__ block

MAX_RETRIES = 5  # NOTE(review): not referenced anywhere in the visible code

# Mapping of each source to the event names that belong to it.
EVENT_SOURCES = {
    "myapp.ec2": [
        "RunInstances",
        "TerminateInstances",
        "DescribeInstances",
        "CreateSecurityGroup",
        "DeleteSecurityGroup",
        "AuthorizeSecurityGroupIngress",
        "RevokeSecurityGroupIngress",
        "RevokeSecurityGroupEgress",
        "DescribeInstanceTypes",
        "DescribeImages",
        "DescribeSnapshots",
        "DescribeVpcs",
        "DescribeSubnets",
        "DescribeSecurityGroups",
        "DescribeAddresses",
        "DescribeAvailabilityZones",
        "DescribeLaunchTemplates",
        "DescribePlacementGroups",
        "AutomatedDefaultVpcCreation",
        "DescribeRegions",
    ],
    "myapp.cloudtrail": [
        "LookupEvents",
        "GetEventSelectors",
        "ListTags",
        "GetInsightSelectors",
        "GetTrailStatus",
        "DescribeTrails",
        "ListTrails",
        "StartLogging",
        "CreateTrail",
        "PutEventSelectors",
        "ListEventDataStores",
        # NOTE(review): this name is grouped under cloudtrail here -- confirm
        # the grouping is intentional.
        "GetEbsEncryptionByDefault",
    ],
    "myapp.cloudshell": [
        "CreateSession",
        "DeleteSession",
        "RedeemCode",
        "PutCredentials",
        "GetEnvironmentStatus",
        "SendHeartBeat",
        "DescribeEnvironments",
    ],
}

# Reverse lookup: event name -> owning source. Built with a comprehension so
# the iteration variables do not linger as module-level globals. If an event
# name were listed under several sources, the last source listed would win.
EVENT_NAME_TO_SOURCE = {
    event_name: source_name
    for source_name, event_names in EVENT_SOURCES.items()
    for event_name in event_names
}

def _to_dynamodb_string(value):
    """Wrap *value* as a DynamoDB string attribute, preferring its JSON form."""
    try:
        text = json.dumps(value)
    except (TypeError, ValueError):
        # Not JSON-encodable: store the plain string form instead of failing.
        text = str(value)
    return {'S': text}


def log_to_dynamodb(
    execution_id,
    step,
    status_code,
    message,
    predicted_source=None,
    score=None,
    response=None,
    step_function_status=None,
    step_function_output=None,
    lambda_output=None
):
    """Write one audit record to the DynamoDB table named by DYNAMODB_TABLE.

    Logging is best-effort: any exception is printed and swallowed, so this
    function never raises and always returns None.

    Args:
        execution_id: Stored under 'execution_id' as a string attribute.
        step: Label of the workflow step, stored under 'step'.
        status_code: Stored under 'statusCode' after conversion with str().
        message: Human-readable description, stored under 'message'.
        predicted_source: Stored as a string attribute when truthy.
        score: Stored as a number attribute when not None.
        response: Stored as JSON text (or str() fallback) when truthy.
        step_function_status: Stored as a string attribute when truthy.
        step_function_output: Stored as JSON text (or str() fallback) when truthy.
        lambda_output: Stored as JSON text (or str() fallback) when truthy.
    """
    try:
        item = {
            'execution_id': {'S': execution_id},
            'step': {'S': step},
            'statusCode': {'S': str(status_code)},
            'message': {'S': message},
            'timestamp': {'S': datetime.utcnow().isoformat()}
        }
        if predicted_source:
            item['predicted_source'] = {'S': predicted_source}
        if score is not None:
            # The number attribute carries its value as a string.
            item['score'] = {'N': str(score)}
        if step_function_status:
            item['step_function_status'] = {'S': step_function_status}

        # These three payloads share the same "JSON if possible" encoding.
        structured_payloads = (
            ('response', response),
            ('step_function_output', step_function_output),
            ('lambda_output', lambda_output),
        )
        for attribute_name, payload in structured_payloads:
            if payload:
                item[attribute_name] = _to_dynamodb_string(payload)

        # `dynamodb` is expected to be a client object defined elsewhere.
        dynamodb.put_item(TableName=DYNAMODB_TABLE, Item=item)
        print(f"Logged step '{step}' for execution '{execution_id}' to DynamoDB.")
    except Exception as e:
        print(f"Error logging to DynamoDB: {str(e)}")

def create_dlq(dlq_name):
    """Create the SQS queue named *dlq_name* and return its URL.

    If SQS reports QueueNameExists, the URL of the existing queue is looked
    up and returned instead. Any other error is printed and re-raised.
    """
    try:
        queue_url = sqs_client.create_queue(QueueName=dlq_name)['QueueUrl']
        print(f"Created DLQ: {queue_url}")
        return queue_url
    except sqs_client.exceptions.QueueNameExists:
        # Fall back to the queue that already carries this name.
        existing_url = sqs_client.get_queue_url(QueueName=dlq_name)['QueueUrl']
        print(f"DLQ already exists: {existing_url}")
        return existing_url
    except Exception as creation_error:
        print(f"Error creating DLQ: {str(creation_error)}")
        raise

def set_redrive_policy(primary_queue_url, dlq_url, max_receive_count=3):
    """Attach a redrive policy to the primary queue, targeting the DLQ.

    The DLQ's ARN is read from its queue attributes and written, together
    with *max_receive_count*, as the primary queue's 'RedrivePolicy'
    attribute. Errors are printed and re-raised.
    """
    try:
        arn_lookup = sqs_client.get_queue_attributes(
            QueueUrl=dlq_url,
            AttributeNames=['QueueArn']
        )
        policy_document = json.dumps({
            'maxReceiveCount': str(max_receive_count),
            'deadLetterTargetArn': arn_lookup['Attributes']['QueueArn']
        })

        sqs_client.set_queue_attributes(
            QueueUrl=primary_queue_url,
            Attributes={'RedrivePolicy': policy_document}
        )
        print(f"Redrive policy set with DLQ: {dlq_url} and maxReceiveCount: {max_receive_count}")
    except Exception as policy_error:
        print(f"Error setting redrive policy: {str(policy_error)}")
        raise

def purge_sqs_queue(queue_url):
    """Purge every message from *queue_url*, printing the outcome.

    The handlers below print a message and do not re-raise, so a missing
    queue, a purge already in progress, or any other error does not
    interrupt the caller.
    """
    try:
        sqs_client.purge_queue(QueueUrl=queue_url)
        print(f"Purged SQS queue: {queue_url}")
    except sqs_client.exceptions.QueueDoesNotExist:
        print(f"Queue {queue_url} does not exist.")
    except sqs_client.exceptions.PurgeQueueInProgress:
        print(f"Purge already in progress for queue {queue_url}.")
    except Exception as purge_error:
        print(f"Error purging SQS queue {queue_url}: {str(purge_error)}")

def send_event_to_sqs(event, actual_source, predicted_source, retry_count=0):
    """Queue a mismatched event on the primary SQS queue and log the outcome.

    The message body carries the event name, both sources, a UTC timestamp
    and *retry_count*. A successful send is logged to DynamoDB with status
    200; a failure is printed and logged with status 500 (not re-raised).

    Args:
        event: Mapping with an "eventName" key and an optional "score" key.
        actual_source: Source the event is known to belong to.
        predicted_source: Source returned by the prediction step.
        retry_count: Retry counter stored in the message body.
    """
    try:
        payload = json.dumps({
            "eventName": event["eventName"],
            "predicted_source": predicted_source,
            "actual_source": actual_source,
            "timestamp": datetime.utcnow().isoformat(),
            "retry_count": retry_count
        })
        send_result = sqs_client.send_message(
            QueueUrl=PRIMARY_SQS_QUEUE_URL,
            MessageBody=payload
        )
        print(f"Sent mismatched event '{event['eventName']}' to SQS. Message ID: {send_result['MessageId']}")
        log_to_dynamodb(
            str(uuid.uuid4()),
            'SQS Send',
            200,
            f"Mismatched event '{event['eventName']}' sent to SQS",
            predicted_source,
            event.get("score"),
            send_result
        )
    except Exception as send_error:
        print(f"Error sending event to SQS: {str(send_error)}")
        log_to_dynamodb(
            str(uuid.uuid4()),
            'SQS Send',
            500,
            f"Failed to send event '{event['eventName']}' to SQS: {str(send_error)}",
            predicted_source,
            event.get("score"),
            None
        )

def _reprocess_sqs_message(message):
    """Re-run one SQS message through the Step Function; return True if it was deleted."""
    body = json.loads(message['Body'])
    event_name = body['eventName']

    predicted_source = EVENT_NAME_TO_SOURCE.get(event_name, "Unknown")
    if predicted_source == "Unknown":
        print(f"Failed to find source for event '{event_name}'.")
        return False

    print(f"Re-predicted source for event '{event_name}': {predicted_source}")
    invoke_step_function(predicted_source, event_name)

    # Delete only after the Step Function call returned without raising.
    sqs_client.delete_message(
        QueueUrl=PRIMARY_SQS_QUEUE_URL,
        ReceiptHandle=message['ReceiptHandle']
    )
    print(f"Deleted SQS message for event '{event_name}'.")
    return True


def process_sqs_queue():
    """Drain the primary SQS queue, re-running each queued event.

    Messages are received in batches of up to 10 with a 10-second long poll.
    For each message the source is looked up in EVENT_NAME_TO_SOURCE, the
    Step Function is invoked, and the message is deleted.

    A message that cannot be processed (unparsable body, unknown event name,
    Step Function error) is left on the queue and does not stop the rest of
    the batch from being handled. The loop ends when the queue returns no
    messages, or when a batch contains only messages this call already gave
    up on, so an undeletable message cannot keep the loop spinning.

    Never raises: errors talking to SQS itself are printed and end the drain.
    """
    # MessageIds this call has already tried and deliberately left on the queue.
    left_on_queue = set()
    try:
        while True:
            response = sqs_client.receive_message(
                QueueUrl=PRIMARY_SQS_QUEUE_URL,
                MaxNumberOfMessages=10,
                WaitTimeSeconds=10
            )

            messages = response.get('Messages', [])
            if not messages:
                print("No more messages in SQS queue.")
                break

            saw_new_message = False
            for message in messages:
                message_id = message.get('MessageId')
                if message_id is not None and message_id in left_on_queue:
                    # Re-delivered after its visibility timeout; do not retry here.
                    continue
                saw_new_message = True

                try:
                    deleted = _reprocess_sqs_message(message)
                except Exception as e:
                    # Isolate the failure so the remaining messages still run.
                    print(f"Error processing SQS message '{message_id}': {str(e)}")
                    deleted = False

                if not deleted and message_id is not None:
                    left_on_queue.add(message_id)

            if not saw_new_message:
                print("Only previously failed messages remain in SQS queue; stopping.")
                break
    except Exception as e:
        print(f"Error processing SQS queue: {str(e)}")

def is_source_matching(actual_source, predicted_source):
    """Return True when both sources refer to the same service.

    Each argument may be given as an internal name ("myapp.ec2"), an AWS
    domain ("ec2.amazonaws.com") or a bare service name ("ec2"). Both sides
    are normalized the same way, so a correct prediction expressed in the
    internal "myapp.*" form is no longer reported as a mismatch.

    Args:
        actual_source: The source the event is known to belong to.
        predicted_source: The source returned by the prediction step.

    Returns:
        bool: True if the first dotted label of the normalized sources matches.
    """
    source_mapping = {
        "myapp.cloudshell": "cloudshell.amazonaws.com",
        "myapp.ec2": "ec2.amazonaws.com",
        "myapp.cloudtrail": "cloudtrail.amazonaws.com"
    }

    def core_service(source):
        # Translate known internal names first, then keep the leading label.
        return source_mapping.get(source, source).split('.')[0]

    return core_service(actual_source) == core_service(predicted_source)

# Step 1: Invoking EventEmitterFunction
def invoke_event_emitter(num_events=10):
    """Invoke the event-emitter Lambda synchronously and return its events.

    Args:
        num_events: Value sent to the Lambda as "num_events".

    Returns:
        The "events" value from the Lambda's decoded JSON payload.

    Raises:
        Exception: If the call fails, or the payload does not have
            statusCode 200 together with an "events" key. The error is
            printed before being re-raised.
    """
    print("Step 1: Invoking EventEmitterFunction...")
    try:
        invocation = lambda_client.invoke(
            FunctionName=EVENT_EMITTER_FUNCTION,
            InvocationType='RequestResponse',
            Payload=json.dumps({"num_events": num_events})
        )
        result = json.loads(invocation['Payload'].read())

        print(f"Response from EventEmitterFunction:\n{json.dumps(result, indent=2)}")

        # Guard clause: anything other than a 200 carrying events is rejected.
        if result.get("statusCode") != 200 or "events" not in result:
            raise Exception(f"Invalid response from EventEmitterFunction: {result}")
        return result["events"]
    except Exception as invoke_error:
        print(f"Error invoking EventEmitterFunction: {str(invoke_error)}")
        raise

# Step 2: Send events to EventBridge
def send_events_to_eventbridge(events):
    """Publish *events* to the EventBridge bus named by EVENT_BUS_NAME.

    Entries are sent in batches of at most 10, the per-call limit of
    PutEvents. Each batch is logged to DynamoDB: status 200 when every entry
    was accepted, status 500 when the response reports a non-zero
    FailedEntryCount (PutEvents reports rejected entries in the response
    rather than raising). An empty *events* list is a no-op.

    Args:
        events: Sequence of mappings with "predicted_source" and "eventName"
            keys and an optional "score" key.

    Raises:
        Exception: Any error raised by the PutEvents call itself; it is
            printed and logged with status 500 before being re-raised.
    """
    print(f"Step 2: Sending {len(events)} event(s) to EventBridge...")
    if not events:
        print("No events to send to EventBridge.")
        return

    entries = [
        {
            "EventBusName": EVENT_BUS_NAME,
            "Source": event["predicted_source"],
            "DetailType": "AWS API Call via EventEmitterFunction",
            "Detail": json.dumps({
                "eventName": event["eventName"],
                "score": event.get("score")
            })
        }
        for event in events
    ]

    max_entries_per_call = 10  # PutEvents rejects requests with more entries

    try:
        for start in range(0, len(entries), max_entries_per_call):
            batch = entries[start:start + max_entries_per_call]
            response = eventbridge_client.put_events(
                Entries=batch
            )
            print(f"Response from EventBridge:\n{json.dumps(response, indent=2)}")

            failed_count = response.get('FailedEntryCount', 0)
            execution_id = str(uuid.uuid4())
            if failed_count:
                # Some entries were rejected even though the call succeeded.
                print(f"EventBridge rejected {failed_count} of {len(batch)} event(s).")
                log_to_dynamodb(
                    execution_id,
                    'Step 2',
                    500,
                    f"{failed_count} of {len(batch)} event(s) rejected by EventBridge",
                    None,
                    None,
                    response
                )
            else:
                log_to_dynamodb(
                    execution_id,
                    'Step 2',
                    200,
                    f"{len(batch)} event(s) sent to EventBridge",
                    None,
                    None,
                    response
                )
    except Exception as e:
        print(f"Error sending events to EventBridge: {str(e)}")
        execution_id = str(uuid.uuid4())
        log_to_dynamodb(
            execution_id,
            'Step 2',
            500,
            f"Failed to send events to EventBridge: {str(e)}",
            None,
            None,
            None
        )
        raise

# Step 3: Wait for EventBridge processing
def wait_for_eventbridge_processing():
    """Announce the pause and sleep for WAIT_TIME seconds."""
    pause_seconds = WAIT_TIME
    print(f"Step 3: Waiting {pause_seconds} seconds for EventBridge processing...")
    time.sleep(pause_seconds)

# Step 4: Invoking SendEventsToSageMaker for predictions
def invoke_sagemaker_lambda(event_names):
    """Invoke the prediction Lambda synchronously for a batch of event names.

    Args:
        event_names: Sent to the Lambda as the "inputs" value.

    Returns:
        The Lambda's decoded JSON payload, which contains a "predictions" key.

    Raises:
        Exception: If the call fails or the payload has no "predictions"
            key. The failure is printed and logged to DynamoDB with status
            500 before being re-raised.
    """
    print(f"Step 4: Invoking SendEventsToSageMaker for {len(event_names)} event(s)...")
    try:
        request_body = json.dumps({
            "inputs": event_names
        })
        invocation = lambda_client.invoke(
            FunctionName=SEND_EVENTS_TO_SAGEMAKER_FUNCTION,
            InvocationType='RequestResponse',
            Payload=request_body
        )
        result = json.loads(invocation['Payload'].read())
        print(f"Response from SendEventsToSageMaker:\n{json.dumps(result, indent=2)}")

        # A batch response is only usable when it carries predictions.
        if "predictions" not in result:
            raise Exception(f"Invalid response from SendEventsToSageMaker: {result}")

        # Record the predictions; the status defaults to 200 when absent.
        log_to_dynamodb(
            str(uuid.uuid4()),
            'Step 4',
            result.get("statusCode", 200),
            f"Prediction successful for {len(result['predictions'])} event(s)",
            response=result
        )
        return result
    except Exception as invoke_error:
        print(f"Error invoking SendEventsToSageMaker Lambda: {str(invoke_error)}")
        log_to_dynamodb(
            str(uuid.uuid4()),
            'Step 4',
            500,
            f"Failed to invoke SendEventsToSageMaker Lambda: {str(invoke_error)}"
        )
        raise

# Helper function to map myapp sources to AWS domains
def map_to_aws_domain(source):
    """Translate an internal "myapp.<service>" name to "<service>.amazonaws.com".

    Only cloudshell, ec2 and cloudtrail are translated; any other value is
    returned unchanged.
    """
    known_services = ("cloudshell", "ec2", "cloudtrail")
    domain_by_source = {
        f"myapp.{service}": f"{service}.amazonaws.com"
        for service in known_services
    }
    try:
        return domain_by_source[source]
    except KeyError:
        return source

# Step 5: Start Step Function execution
def invoke_step_function(predicted_source, event_name):
    """Start a Step Function execution and block until it reaches a final state.

    The predicted source is first translated with map_to_aws_domain. The
    execution is polled every WAIT_TIME seconds until its status is
    SUCCEEDED, FAILED, TIMED_OUT or ABORTED, and the outcome is logged to
    DynamoDB (200 on success, 500 otherwise). A non-successful final status
    is logged but does not raise.

    Args:
        predicted_source: Source name passed to the state machine after mapping.
        event_name: Event name passed to the state machine.

    Raises:
        Exception: Any error while starting or polling the execution; it is
            printed and logged with status 500 before being re-raised.
    """
    aws_domain_source = map_to_aws_domain(predicted_source)
    input_payload = {
        "predicted_source": aws_domain_source,
        "event_name": event_name
    }
    print(f"Step 5: Starting Step Function with input:\n{json.dumps(input_payload, indent=2)}")

    final_states = ('SUCCEEDED', 'FAILED', 'TIMED_OUT', 'ABORTED')
    try:
        started = stepfunctions_client.start_execution(
            stateMachineArn=STATE_MACHINE_ARN,
            input=json.dumps(input_payload)
        )
        execution_arn = started['executionArn']
        print(f"Step Function started with Execution ARN: {execution_arn}")

        # Poll until the execution leaves its in-progress states.
        while True:
            description = stepfunctions_client.describe_execution(
                executionArn=execution_arn
            )
            status = description['status']
            print(f"Current Execution Status: {status}")

            if status not in final_states:
                time.sleep(WAIT_TIME)
                continue
            break

        execution_id = str(uuid.uuid4())
        if status == 'SUCCEEDED':
            output = json.loads(description['output'])
            print("Execution succeeded with output:")
            print(json.dumps(output, indent=2))
            status_code = 200
            log_message = "Step Function succeeded"
        else:
            output = None
            print(f"Execution failed with status: {status}")
            status_code = 500
            log_message = f"Step Function failed: {status}"

        log_to_dynamodb(
            execution_id,
            'Step 5',
            status_code,
            log_message,
            predicted_source=aws_domain_source,
            step_function_status=status,
            step_function_output=output
        )
    except Exception as execution_error:
        print(f"Error invoking Step Function: {str(execution_error)}")
        log_to_dynamodb(
            str(uuid.uuid4()),
            'Step 5',
            500,
            f"Failed to start or monitor Step Function: {str(execution_error)}",
            predicted_source=aws_domain_source
        )
        raise

# Full workflow
def test_full_workflow():
    """Run the end-to-end pipeline once and print any error instead of raising.

    Sequence: emit 10 events, publish them to EventBridge, wait, request
    predictions, then for each prediction either start the Step Function
    (prediction matches the known source) or queue the event on SQS (it does
    not). Finally the SQS queue is drained. Predictions with a missing
    predicted or actual source are logged with status 400 and skipped.
    """
    def record_validation_failure(reason, predicted_source, score):
        # One DynamoDB record per prediction that could not be validated.
        log_to_dynamodb(
            str(uuid.uuid4()),
            'Validation',
            400,
            reason,
            predicted_source,
            score,
            None
        )

    try:
        # Steps 1-3: generate events, publish them, let EventBridge route them.
        events = invoke_event_emitter(num_events=10)
        send_events_to_eventbridge(events)
        wait_for_eventbridge_processing()

        # Step 4: one batched prediction request for every emitted event name.
        prediction_response = invoke_sagemaker_lambda(
            [emitted["eventName"] for emitted in events]
        )

        for prediction in prediction_response.get("predictions", []):
            event_name = prediction["event_name"]
            predicted_source = prediction.get("predicted_source")
            score = prediction.get("score")

            if not predicted_source:
                print(f"Predicted source for event '{event_name}' not found. Skipping further processing.")
                record_validation_failure(
                    f"Predicted source for event '{event_name}' not found.",
                    predicted_source,
                    score
                )
                continue

            # The known source comes from the static event-name lookup table.
            actual_source = EVENT_NAME_TO_SOURCE.get(event_name)
            if not actual_source:
                print(f"Actual source for event '{event_name}' not found. Skipping further processing.")
                record_validation_failure(
                    f"Actual source for event '{event_name}' not found.",
                    predicted_source,
                    score
                )
                continue

            print(f"Actual source for event '{event_name}': {actual_source}")
            print(f"Predicted source for event '{event_name}': {predicted_source}")

            if not is_source_matching(actual_source, predicted_source):
                print(f"Source validation failed for event '{event_name}'. Sending to SQS for reprocessing.")
                # Mismatches are parked on SQS with retry_count set to 1.
                mismatched_event = {
                    "eventName": event_name,
                    "score": score
                }
                send_event_to_sqs(mismatched_event, actual_source, predicted_source, retry_count=1)
                continue

            print(f"Source validation passed for event '{event_name}'. Proceeding with Step Function.")
            # Step 5: run the Step Function for the validated prediction.
            invoke_step_function(predicted_source, event_name)

        # Once every prediction is handled, work through the parked events.
        process_sqs_queue()

    except Exception as workflow_error:
        print(f"Error during workflow test: {str(workflow_error)}")

if __name__ == "__main__":
    # Script entry point: prepare the SQS queues, then run the workflow once.
    # Any exception raised during setup is caught below and only printed.
    try:
        # Create the dead-letter queue (or reuse it) and keep its URL in the
        # module-level name DLQ_URL.
        DLQ_URL = create_dlq(DLQ_NAME)

        # Set redrive policy with DLQ: the primary queue is pointed at the
        # DLQ with a maxReceiveCount of 3.
        set_redrive_policy(PRIMARY_SQS_QUEUE_URL, DLQ_URL, max_receive_count=3)

        # Purge the SQS queue before starting the workflow
        purge_sqs_queue(PRIMARY_SQS_QUEUE_URL)

        # Optionally, wait a few seconds to ensure purge is complete
        time.sleep(5)

        # Run the full workflow; it catches and prints its own errors.
        test_full_workflow()
    except Exception as e:
        print(f"Setup failed: {str(e)}")
