In [None]:
import json
import logging
import os

import boto3

# Root logger at INFO level; every function in this module logs through it.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# AWS clients, created once at import time so every handler call in this
# process reuses them.
sagemaker_client = boto3.client('sagemaker-runtime')
# NOTE(review): sqs_client is not referenced by the handler code shown here.
sqs_client = boto3.client('sqs')

# SageMaker endpoint name. Read from the environment so a deployment can set it
# without a code change. It defaults to "" (the previous hard-coded value);
# with an empty name, calls to the endpoint fail until it is configured.
SAGEMAKER_ENDPOINT = os.environ.get("SAGEMAKER_ENDPOINT", "")

# Static lookup from an event name (the value carried in ``detail.eventName``
# or ``inputs``) to the AWS service hostname that emits it. handle_sqs_event
# uses it instead of the model. Built one service at a time so each hostname
# appears once; key insertion order follows the listing below.
PREDEFINED_EVENT_SOURCES = {
    **dict.fromkeys(
        (
            "RunInstances",
            "TerminateInstances",
            "DescribeInstances",
            "CreateSecurityGroup",
            "DeleteSecurityGroup",
            "AuthorizeSecurityGroupIngress",
            "RevokeSecurityGroupIngress",
            "RevokeSecurityGroupEgress",
            "DescribeInstanceTypes",
            "DescribeImages",
            "DescribeSnapshots",
            "DescribeVpcs",
            "DescribeSubnets",
            "DescribeSecurityGroups",
            "DescribeAddresses",
            "DescribeAvailabilityZones",
            "DescribeLaunchTemplates",
            "DescribeVolumes",
            "DescribeAccountAttributes",
            "DescribeVolumeStatus",
            "DescribeHosts",
            "DescribeCapacityReservations",
            "DescribeTags",
            "DescribeInstanceStatus",
            "DescribePlacementGroups",
            "AutomatedDefaultVpcCreation",
        ),
        "ec2.amazonaws.com",
    ),
    **dict.fromkeys(
        (
            "LookupEvents",
            "GetEventSelectors",
            "ListTags",
            "GetInsightSelectors",
            "GetTrailStatus",
            "DescribeTrails",
            "ListTrails",
            "StartLogging",
            "CreateTrail",
            "PutEventSelectors",
            "ListEventDataStores",
        ),
        "cloudtrail.amazonaws.com",
    ),
    **dict.fromkeys(
        (
            "CreateSession",
            "DeleteSession",
            "RedeemCode",
            "PutCredentials",
            "GetEnvironmentStatus",
            "SendHeartBeat",
            "DescribeEnvironments",
        ),
        "cloudshell.amazonaws.com",
    ),
}

# Maps the numeric suffix of the model's output label (e.g. the "2" in
# "LABEL_2") to a service source. Defined once at module level instead of
# being rebuilt on every invocation.
_ID2LABEL = {
    "0": "cloudshell.amazonaws.com",
    "1": "cloudtrail.amazonaws.com",
    "2": "ec2.amazonaws.com",
}


def _extract_event_names(event):
    """Return the event names carried by *event*, or ``None`` if its shape is invalid.

    Accepts ``{"inputs": [name, ...]}`` or ``{"detail": {"eventName": name}}``.
    ``inputs`` takes precedence when both are present.
    """
    if not isinstance(event, dict):
        return None
    inputs = event.get('inputs')
    if isinstance(inputs, list):
        return inputs
    detail = event.get('detail')
    # The isinstance check stops a string "detail" from passing a substring
    # test and then failing on indexing.
    if isinstance(detail, dict) and 'eventName' in detail:
        return [detail['eventName']]
    return None


def _top_prediction(prediction_result):
    """Return the top ``{"label": ..., "score": ...}`` entry from a decoded endpoint response.

    Raises:
        ValueError: if the response holds no predictions.
    """
    if not prediction_result:
        raise ValueError("SageMaker returned an empty prediction result.")
    top = prediction_result[0]
    # NOTE(review): presumably when the model returns scores for every label,
    # each input maps to a list of entries ([[{...}, {...}]]). In that case,
    # pick the best-scoring one rather than failing on list indexing.
    if isinstance(top, list):
        if not top:
            raise ValueError("SageMaker returned an empty prediction result.")
        top = max(top, key=lambda entry: entry.get('score') or 0)
    return top


def _predict_source(event_name):
    """Invoke the SageMaker endpoint for one event name and map the result to a source."""
    payload = json.dumps({"inputs": [event_name]})
    logger.info("Payload sent to SageMaker for event '%s': %s", event_name, payload)

    response = sagemaker_client.invoke_endpoint(
        EndpointName=SAGEMAKER_ENDPOINT,
        ContentType="application/json",
        Body=payload,
    )

    prediction_result = json.loads(response['Body'].read().decode("utf-8"))
    logger.info(
        "SageMaker response for event '%s': %s", event_name, json.dumps(prediction_result)
    )

    top_prediction = _top_prediction(prediction_result)
    # Labels are expected to end in "_<index>"; the index keys into _ID2LABEL.
    label = str(top_prediction['label']).split("_")[-1]
    predicted_source = _ID2LABEL.get(label, "Unknown")

    logger.info("Predicted source for event '%s': %s", event_name, predicted_source)

    return {
        "event_name": event_name,
        "predicted_source": predicted_source,
        "score": top_prediction.get('score'),  # None if the endpoint omits a score
    }


def lambda_handler(event, context):
    """Lambda entry point: predict the AWS service source for one or more event names.

    Args:
        event: dict containing either ``inputs`` (a list of event names) or
            ``detail.eventName``. If ``source == 'aws.sqs'``, the names are
            resolved with the static table in ``handle_sqs_event``, not the model.
        context: Lambda context object (unused).

    Returns:
        dict whose ``statusCode`` is:
        - 200, with ``predictions`` and ``message``, on success;
        - 400, with ``error``, for an invalid input shape;
        - 500, with ``error``, for any unexpected failure.
        The SQS path returns whatever ``handle_sqs_event`` returns.
    """
    try:
        logger.info("Received event: %s", json.dumps(event))

        event_names = _extract_event_names(event)
        if event_names is None:
            logger.error("Invalid input structure: %s", json.dumps(event))
            return {
                "statusCode": 400,
                "error": "Invalid input structure. Provide 'inputs' or 'detail.eventName'."
            }

        logger.info("Processing %d event(s).", len(event_names))

        # SQS-flagged invocations bypass the model and use the predefined table.
        if event.get('source') == 'aws.sqs':
            logger.info("Triggered by SQS. Attempting to re-predict using predefined list.")
            return handle_sqs_event(event_names)

        predictions = [_predict_source(event_name) for event_name in event_names]

        return {
            "statusCode": 200,
            "predictions": predictions,
            "message": f"Prediction successful for {len(event_names)} event(s)."
        }

    except Exception as e:
        # Top-level boundary: turn any failure into a 500 response.
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error in Lambda function: %s", e)
        return {
            "statusCode": 500,
            "error": str(e)
        }

def handle_sqs_event(event_names):
    """Resolve event sources from PREDEFINED_EVENT_SOURCES instead of the model.

    The handler calls this for invocations flagged with ``source == 'aws.sqs'``.
    Resolution is all-or-nothing: if any name is missing from the table, the
    whole batch is rejected.

    Args:
        event_names: iterable of event-name strings.

    Returns:
        dict: ``statusCode`` 200 with one ``predictions`` entry per name (score
        ``None``, since no model is involved) and a ``message``; or
        ``statusCode`` 400 with an ``error`` naming the missing events.
    """
    resolved = []
    missing = []

    for name in event_names:
        source = PREDEFINED_EVENT_SOURCES.get(name, "Unknown")
        if source == "Unknown":
            logger.warning(f"Event '{name}' not found in predefined list.")
            missing.append(name)
            continue

        logger.info(f"Successfully re-predicted source for event '{name}': {source}")
        resolved.append({"event_name": name, "predicted_source": source, "score": None})

    if not missing:
        return {
            "statusCode": 200,
            "predictions": resolved,
            "message": f"Re-prediction successful for {len(resolved)} event(s)."
        }

    return {
        "statusCode": 400,
        "error": f"Events {missing} not found in predefined list."
    }
