# In [None]:  (Jupyter cell prompt left over from the notebook export)
import json
import boto3
import logging
import time
from datetime import datetime
from botocore.exceptions import ClientError

# Root logger, emitting INFO and above
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Region shared by every client below and by bucket creation.
# Hard-coded for now; reading it from an environment variable is an alternative.
REGION = 'us-east-1'

# Module-level AWS clients, created in this order: CloudTrail, then S3.
# No credentials are passed explicitly (the deployment is expected to supply
# them through an IAM role).
cloudtrail_client, s3_client = (
    boto3.client(service_name, region_name=REGION)
    for service_name in ('cloudtrail', 's3')
)

# CloudTrail event name -> name of the module-level function that handles it.
# Values are function *names*; the dispatcher resolves them at call time.
EVENT_HANDLERS = dict(
    LookupEvents="handle_lookup_events",
    DescribeTrails="handle_describe_trails",
    GetTrailStatus="handle_get_trail_status",
    CreateTrail="handle_create_trail",
    StartLogging="handle_start_logging",
    StopLogging="handle_stop_logging",
    UpdateTrail="handle_update_trail",
    DeleteTrail="handle_delete_trail",
)

def json_serial(obj):
    """Fallback encoder intended for ``json.dumps(..., default=json_serial)``.

    Args:
        obj: A value the standard JSON encoder could not serialize.

    Returns:
        The ISO 8601 string for ``obj`` when it is a ``datetime``.

    Raises:
        TypeError: For every other type.
    """
    if not isinstance(obj, datetime):
        raise TypeError("Type not serializable")
    return obj.isoformat()

def create_s3_bucket(bucket_name, region):
    """Make sure the S3 bucket ``bucket_name`` exists, creating it when missing.

    Existence is probed with ``head_bucket``. A ``ClientError`` whose error
    code is ``'404'`` is taken to mean the bucket is absent and triggers
    creation; any other ``ClientError`` is logged and re-raised.

    Args:
        bucket_name: Name of the bucket to check or create.
        region: Region for the new bucket. For ``'us-east-1'`` the bucket is
            created without a ``LocationConstraint``; every other region is
            passed as the ``LocationConstraint``.

    Raises:
        ClientError: When the existence check fails with a non-404 code.
        Exception: Whatever ``create_bucket`` raises (logged, then re-raised).
    """
    try:
        s3_client.head_bucket(Bucket=bucket_name)
        logger.info(f"S3 bucket '{bucket_name}' already exists.")
    except ClientError as e:
        if e.response['Error']['Code'] != '404':
            logger.error(f"Error checking bucket '{bucket_name}': {str(e)}")
            raise

        # Bucket is missing: assemble the create call, adding the location
        # constraint only outside us-east-1.
        create_args = {'Bucket': bucket_name}
        if region != 'us-east-1':
            create_args['CreateBucketConfiguration'] = {'LocationConstraint': region}

        try:
            s3_client.create_bucket(**create_args)
            logger.info(f"Created S3 bucket: {bucket_name}")
        except Exception as create_e:
            logger.error(f"Error creating bucket '{bucket_name}': {str(create_e)}")
            raise

def set_s3_bucket_policy(bucket_name, account_id):
    """Attach the bucket policy that lets CloudTrail deliver logs to the bucket.

    The policy has two statements for the ``cloudtrail.amazonaws.com`` service
    principal: ``s3:GetBucketAcl`` on the bucket itself, and ``s3:PutObject``
    under ``AWSLogs/<account_id>/`` conditioned on the
    ``bucket-owner-full-control`` ACL.

    Args:
        bucket_name: Bucket that receives the policy.
        account_id: AWS account id used in the log prefix of the write
            statement.

    Raises:
        Exception: Any failure while building or applying the policy is
            logged and re-raised.
    """
    try:
        bucket_arn = f"arn:aws:s3:::{bucket_name}"

        # CloudTrail reads the bucket ACL before delivering logs.
        acl_check = {
            "Sid": "AWSCloudTrailAclCheck20150319",
            "Effect": "Allow",
            "Principal": {"Service": "cloudtrail.amazonaws.com"},
            "Action": "s3:GetBucketAcl",
            "Resource": bucket_arn,
        }

        # Log delivery is limited to this account's prefix and must grant the
        # bucket owner full control over each object.
        log_write = {
            "Sid": "AWSCloudTrailWrite20150319",
            "Effect": "Allow",
            "Principal": {"Service": "cloudtrail.amazonaws.com"},
            "Action": "s3:PutObject",
            "Resource": f"arn:aws:s3:::{bucket_name}/AWSLogs/{account_id}/*",
            "Condition": {
                "StringEquals": {
                    "s3:x-amz-acl": "bucket-owner-full-control"
                }
            },
        }

        policy_json = json.dumps({
            "Version": "2012-10-17",
            "Statement": [acl_check, log_write],
        })
        s3_client.put_bucket_policy(Bucket=bucket_name, Policy=policy_json)
        logger.info(f"Set bucket policy for '{bucket_name}'.")
    except Exception as e:
        logger.error(f"Error setting bucket policy for '{bucket_name}': {str(e)}")
        raise

# Handler functions for different CloudTrail events

def handle_lookup_events(response_body):
    """Handle a LookupEvents event by fetching events from CloudTrail.

    Issues one ``lookup_events`` call with no filters and stores the returned
    ``Events`` list (empty when the key is absent) under
    ``response_body["events"]``. Only that single response is used; no
    follow-up calls are made.

    Args:
        response_body: Dict mutated in place with the result, or with an
            ``"error"`` string on failure.

    Raises:
        Exception: Any failure is logged, recorded in ``response_body`` and
            re-raised.
    """
    try:
        found_events = cloudtrail_client.lookup_events().get("Events", [])
        response_body["events"] = found_events
        logger.info(f"LookupEvents processed successfully with {len(response_body['events'])} events.")
    except Exception as e:
        logger.error(f"Error processing LookupEvents: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_describe_trails(trails, response_body):
    """Handle a DescribeTrails event by echoing the already-fetched trail list.

    Args:
        trails: Trail descriptions supplied by the caller; stored unchanged.
        response_body: Dict mutated in place: ``"trails"`` on success,
            ``"error"`` on failure.

    Raises:
        Exception: Any failure is logged, recorded in ``response_body`` and
            re-raised.
    """
    try:
        # Store first, then log: a value without a length still ends up in
        # the response before the error path runs.
        response_body["trails"] = trails
        logger.info(f"DescribeTrails processed successfully with {len(trails)} trails.")
    except Exception as e:
        failure_text = str(e)
        logger.error(f"Error processing DescribeTrails: {str(e)}")
        response_body["error"] = failure_text
        raise

def handle_get_trail_status(trails, response_body):
    """Handle a GetTrailStatus event by collecting the status of every trail.

    Each entry in ``trails`` that has a non-empty ``"Name"`` is queried with
    ``get_trail_status``. A failure for one trail is logged as a warning and
    recorded as ``{"error": <message>}`` for that trail; the remaining trails
    are still processed.

    Args:
        trails: Trail descriptions (dicts); entries without a name are skipped.
        response_body: Dict mutated in place: ``"statuses"`` maps trail name
            to its status response or error entry.

    Raises:
        Exception: Failures outside the per-trail call are logged, recorded
            under ``response_body["error"]`` and re-raised.
    """
    try:
        collected = {}
        for trail in trails:
            trail_name = trail.get("Name")
            if not trail_name:
                continue
            try:
                collected[trail_name] = cloudtrail_client.get_trail_status(Name=trail_name)
                logger.info(f"Retrieved status for trail '{trail_name}'.")
            except Exception as e:
                logger.warning(f"Failed to get status for trail '{trail_name}': {str(e)}")
                collected[trail_name] = {"error": str(e)}
        response_body["statuses"] = collected
        logger.info(f"GetTrailStatus processed successfully.")
    except Exception as e:
        logger.error(f"Error processing GetTrailStatus: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_create_trail(event, context, response_body):
    """Handle a CreateTrail event: provision a bucket, a trail, and start logging.

    The trail and bucket names are derived from the current epoch second (and,
    for the bucket, the first 8 characters of the request id). Steps run in
    order: ensure the bucket exists, apply the CloudTrail bucket policy,
    create the trail with global service events enabled, start logging.

    Args:
        event: The triggering event; not read by this handler.
        context: Lambda context; ``aws_request_id`` and
            ``invoked_function_arn`` are read from it.
        response_body: Dict mutated in place: ``"trail"`` holds the
            ``create_trail`` response, or ``"error"`` the failure message.

    Raises:
        Exception: Any failure is logged, recorded in ``response_body`` and
            re-raised. Resources created by earlier steps are left in place.
    """
    try:
        timestamp = int(time.time())
        trail_name = f"default-trail-{timestamp}"
        # Only 8 characters of the request id are kept so the bucket name
        # stays within the 63-character limit.
        request_id_short = context.aws_request_id[:8]
        s3_bucket_name = f"aws-cloudtrail-logs-{request_id_short}-{timestamp}"

        logger.info(f"Creating trail '{trail_name}' with bucket '{s3_bucket_name}' in region '{REGION}'.")

        # 1) Bucket must exist before a policy can be attached to it.
        create_s3_bucket(s3_bucket_name, REGION)

        # 2) The account id is the fifth colon-separated field of the
        #    function ARN.
        arn_fields = context.invoked_function_arn.split(":")
        account_id = arn_fields[4]
        set_s3_bucket_policy(s3_bucket_name, account_id)

        # 3) Create the trail; global service events are always included.
        created = cloudtrail_client.create_trail(
            Name=trail_name,
            S3BucketName=s3_bucket_name,
            IncludeGlobalServiceEvents=True,
        )

        # 4) Logging is started explicitly after creation.
        cloudtrail_client.start_logging(Name=trail_name)

        response_body["trail"] = created
        logger.info(f"CreateTrail processed successfully for trail '{trail_name}'.")
    except Exception as e:
        logger.error(f"Error processing CreateTrail: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_start_logging(response_body):
    """Handle a StartLogging event by starting logging on every named trail.

    The trail list is fetched with ``describe_trails``. A failure for one
    trail is logged as a warning and recorded under
    ``response_body["errors"][<trail name>]``; the other trails are still
    attempted.

    Args:
        response_body: Dict mutated in place: ``"message"`` always, plus
            ``"errors"`` for per-trail failures or ``"error"`` when the
            handler itself fails.

    Raises:
        Exception: Failures outside the per-trail call are logged, recorded
            and re-raised.
    """
    try:
        trails = cloudtrail_client.describe_trails().get("trailList", [])
        if trails:
            for trail in trails:
                trail_name = trail.get("Name")
                if not trail_name:
                    continue
                try:
                    cloudtrail_client.start_logging(Name=trail_name)
                    logger.info(f"Started logging for trail '{trail_name}'.")
                except Exception as e:
                    logger.warning(f"Failed to start logging for trail '{trail_name}': {str(e)}")
                    failures = response_body.setdefault("errors", {})
                    failures[trail_name] = str(e)
            response_body["message"] = "StartLogging processed successfully."
        else:
            logger.info("No trails found to start logging.")
            response_body["message"] = "No trails found to start logging."
    except Exception as e:
        logger.error(f"Error processing StartLogging: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_stop_logging(response_body):
    """Handle a StopLogging event by stopping logging on every named trail.

    The trail list is fetched with ``describe_trails``. A failure for one
    trail is logged as a warning and recorded under
    ``response_body["errors"][<trail name>]``; the other trails are still
    attempted.

    Args:
        response_body: Dict mutated in place: ``"message"`` always, plus
            ``"errors"`` for per-trail failures or ``"error"`` when the
            handler itself fails.

    Raises:
        Exception: Failures outside the per-trail call are logged, recorded
            and re-raised.
    """
    try:
        described = cloudtrail_client.describe_trails()
        trails = described.get("trailList", [])
        if not trails:
            nothing_to_do = "No trails found to stop logging."
            logger.info(nothing_to_do)
            response_body["message"] = nothing_to_do
            return

        # Names are read lazily, one trail at a time, in list order.
        for trail_name in (trail.get("Name") for trail in trails):
            if not trail_name:
                continue
            try:
                cloudtrail_client.stop_logging(Name=trail_name)
                logger.info(f"Stopped logging for trail '{trail_name}'.")
            except Exception as e:
                logger.warning(f"Failed to stop logging for trail '{trail_name}': {str(e)}")
                response_body.setdefault("errors", {})[trail_name] = str(e)
        response_body["message"] = "StopLogging processed successfully."
    except Exception as e:
        logger.error(f"Error processing StopLogging: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_update_trail(event, response_body):
    """Handle an UpdateTrail event by enabling global service events on all trails.

    The trail list is fetched with ``describe_trails`` and every trail that
    has a non-empty ``"Name"`` is updated with
    ``IncludeGlobalServiceEvents=True``. A failure for one trail is logged as
    a warning and recorded under ``response_body["errors"][<trail name>]``;
    the other trails are still attempted.

    Args:
        event: The triggering event. Not read at present; the update applied
            is fixed rather than taken from the event.
        response_body: Dict mutated in place: ``"message"`` always, plus
            ``"errors"`` for per-trail failures or ``"error"`` when the
            handler itself fails.

    Raises:
        Exception: Failures outside the per-trail call are logged, recorded
            and re-raised.
    """
    try:
        trails_response = cloudtrail_client.describe_trails()
        trails = trails_response.get("trailList", [])
        if not trails:
            logger.info("No trails found to update.")
            response_body["message"] = "No trails found to update."
            return

        # Example: Update each trail to include global service events (could be parameterized)
        for trail in trails:
            trail_name = trail.get("Name")
            if not trail_name:
                # Skip nameless entries, as the start/stop/delete/status
                # handlers do. Without this guard the API is called with
                # Name=None and the failure is recorded under a None key.
                continue
            try:
                cloudtrail_client.update_trail(
                    Name=trail_name,
                    IncludeGlobalServiceEvents=True  # Example update; modify as needed
                )
                logger.info(f"Updated trail '{trail_name}' to include global service events.")
            except Exception as e:
                logger.warning(f"Failed to update trail '{trail_name}': {str(e)}")
                response_body.setdefault("errors", {})[trail_name] = str(e)
        response_body["message"] = "UpdateTrail processed successfully."
    except Exception as e:
        logger.error(f"Error processing UpdateTrail: {str(e)}")
        response_body["error"] = str(e)
        raise

def handle_delete_trail(response_body):
    """Handle a DeleteTrail event by deleting every named trail.

    The trail list is fetched with ``describe_trails``. For each named trail,
    logging is stopped and then the trail is deleted. A failure in either
    step is logged as a warning and recorded under
    ``response_body["errors"][<trail name>]``; the other trails are still
    attempted.

    Args:
        response_body: Dict mutated in place: ``"message"`` always, plus
            ``"errors"`` for per-trail failures or ``"error"`` when the
            handler itself fails.

    Raises:
        Exception: Failures outside the per-trail calls are logged, recorded
            and re-raised.
    """
    try:
        trails = cloudtrail_client.describe_trails().get("trailList", [])
        if trails:
            for trail in trails:
                trail_name = trail.get("Name")
                if not trail_name:
                    continue
                try:
                    # Logging is stopped first; if that fails the trail is
                    # not deleted.
                    cloudtrail_client.stop_logging(Name=trail_name)
                    cloudtrail_client.delete_trail(Name=trail_name)
                    logger.info(f"Deleted trail '{trail_name}'.")
                except Exception as e:
                    logger.warning(f"Failed to delete trail '{trail_name}': {str(e)}")
                    failures = response_body.setdefault("errors", {})
                    failures[trail_name] = str(e)
            response_body["message"] = "DeleteTrail processed successfully."
        else:
            logger.info("No trails found to delete.")
            response_body["message"] = "No trails found to delete."
    except Exception as e:
        logger.error(f"Error processing DeleteTrail: {str(e)}")
        response_body["error"] = str(e)
        raise



def lambda_handler(event, context):
    """Lambda entry point: dispatch a CloudTrail event to its handler.

    The event name is read from ``event["detail"]["eventName"]`` (``'Unknown'``
    when absent) and looked up in ``EVENT_HANDLERS``. Each handler is called
    with the arguments its signature declares:

    * ``handle_describe_trails`` / ``handle_get_trail_status``:
      ``(trails, response_body)``. The trail list is fetched here, and only
      for these two handlers, since no other handler takes it.
    * ``handle_create_trail``: ``(event, context, response_body)``.
    * ``handle_update_trail``: ``(event, response_body)``.
    * every other handler, including ``handle_lookup_events``:
      ``(response_body)``.

    Args:
        event: The incoming event dict.
        context: Lambda context object, passed through to
            ``handle_create_trail``.

    Returns:
        A dict with ``"statusCode"`` and a JSON string ``"body"``: 200 with
        the handler's response body, 200 with an "unsupported event" message
        when no handler is registered, or 500 with ``{"error": ...}`` when
        anything raises.
    """
    try:
        event_name = event.get('detail', {}).get('eventName', 'Unknown')
        logger.info(f"Processing event: {event_name}")

        handler_name = EVENT_HANDLERS.get(event_name)
        if not handler_name:
            # Handle unrecognized events gracefully
            logger.info(f"Received unsupported event: {event_name}. Returning success message.")
            return {
                "statusCode": 200,
                "body": json.dumps({
                    "message": f"Unsupported event: {event_name}. No action performed.",
                    "eventName": event_name
                }, default=json_serial)
            }

        # General response structure; handlers add their results to it.
        response_body = {"message": "Processed successfully"}

        # Handlers are registered by name and resolved at call time.
        handler_function = globals()[handler_name]

        if handler_name in ("handle_describe_trails", "handle_get_trail_status"):
            # Only these two handlers take the trail list as a parameter.
            trails_response = cloudtrail_client.describe_trails()
            trails = trails_response.get("trailList", [])
            handler_function(trails, response_body)
        elif handler_name == "handle_create_trail":
            handler_function(event, context, response_body)
        elif handler_name == "handle_update_trail":
            handler_function(event, response_body)
        else:
            # handle_lookup_events and the start/stop/delete handlers accept
            # only the response body. Passing trails to handle_lookup_events
            # raises TypeError, so it must be dispatched here.
            handler_function(response_body)

        return {
            "statusCode": 200,
            "body": json.dumps(response_body, default=json_serial)
        }

    except Exception as e:
        logger.error(f"Error processing event: {str(e)}", exc_info=True)
        return {
            "statusCode": 500,
            "body": json.dumps({
                "error": str(e)
            }, default=json_serial)
        }
