# Distributed Data Processing using SageMaker Processing and DJL Spark Container

This example notebook demonstrates how to use Amazon SageMaker Processing with the DJL Spark Docker image to run distributed deep learning inference on large datasets. The DJL Spark Docker image is a pre-built image that includes the Deep Java Library (DJL) and the other dependencies needed to run distributed data processing jobs on Amazon SageMaker. DJL is an open-source, Java-based deep learning library designed to be easy to use and compatible with existing deep learning ecosystems.

Combining SageMaker Processing with the DJL Spark image lets you run distributed deep learning inference on large datasets and scale out across instances without managing a Spark cluster yourself.

## Contents

1. [Setup](#Setup)
1. [Push the Container to ECR](#Push-the-Container-to-ECR)
1. [Prepare Processing Script](#Prepare-Processing-Script)
1. [Run the SageMaker Processing Job](#Run-the-SageMaker-Processing-Job)
1. [Monitor and Analyze Your Job](#Monitor-and-Analyze-Your-Job)
1. [Validate Data Processing Results](#Validate-Data-Processing-Results)

## Setup

### Install the SageMaker Python SDK

This notebook requires the SageMaker Python SDK.

In [None]:
!pip install sagemaker

### Set up the account and role

Next, you'll define the account, region and role that will be used to run the SageMaker Processing job.

In [None]:
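import os
from time import gmtime, strftime

import sagemaker

sagemaker_session = sagemaker.Session()
account_id = sagemaker_session.account_id()
region = sagemaker_session.boto_region_name
role = sagemaker.get_execution_role()

# %%writefile does not create missing directories, so create ./code for the processing script
os.makedirs("./code", exist_ok=True)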

## Push the Container to ECR

The following cell pulls the DJL Spark Docker image from Docker Hub and pushes it to a repository in your Amazon ECR registry.

In [None]:
docker_registry = "deepjavalibrary"
repository_name = "djl-spark"
tag = "0.25.0-cpu"
ecr_registry = "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region)

# Pull the DJL Spark image from Docker Hub
!docker pull $docker_registry/$repository_name:$tag

# Authenticate Docker to your ECR registry (`aws ecr get-login` was removed in AWS CLI v2)
!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $ecr_registry

# Create the ECR repository only if it does not already exist
!aws ecr describe-repositories --region $region --repository-names $repository_name > /dev/null 2>&1 || aws ecr create-repository --region $region --repository-name $repository_name

# Tag the image for your ECR repository and push it
!docker tag $docker_registry/$repository_name:$tag $ecr_registry/$repository_name:$tag
!docker push $ecr_registry/$repository_name:$tag
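If you prefer to manage the repository from Python instead of relying on shell exit codes, the create-if-missing step can be sketched with boto3's ECR client. This is a minimal sketch, assuming boto3 is installed and AWS credentials are configured; `ensure_ecr_repository` is a hypothetical helper name, not part of boto3 or the SageMaker SDK.

```python
def ensure_ecr_repository(ecr_client, repository_name):
    """Return the repository URI, creating the ECR repository if it does not exist.

    `ecr_client` is assumed to be a boto3 ECR client, for example
    boto3.client("ecr", region_name=region). Hypothetical helper for illustration.
    """
    try:
        # describe_repositories raises RepositoryNotFoundException for unknown names
        response = ecr_client.describe_repositories(repositoryNames=[repository_name])
        return response["repositories"][0]["repositoryUri"]
    except ecr_client.exceptions.RepositoryNotFoundException:
        # The repository is missing, so create it and return the new URI
        response = ecr_client.create_repository(repositoryName=repository_name)
        return response["repository"]["repositoryUri"]
```

For example, `ensure_ecr_repository(boto3.client("ecr", region_name=region), repository_name)` returns the URI you can then tag and push to.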

## Prepare Processing Script

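The cell below contains the source for the processing script; the `%%writefile` magic saves it locally as `./code/process.py`. `%%writefile` does not create missing directories, so the `./code` directory must exist before this cell runs. The script reads the images with Spark's image data source, keeps only 3-channel images, classifies them with a pretrained ResNet model from the DJL model zoo (`djl://ai.djl.pytorch/resnet`) on the PyTorch engine, and writes the top-2 predictions for each image to S3 as Parquet.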

In [None]:
%%writefile ./code/process.py
import argparse
import os

from pyspark.sql import SparkSession
from djl_spark.task.vision import ImageClassifier


def main():
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_input_bucket", type=str, help="s3 input bucket")
    parser.add_argument("--s3_input_key_prefix", type=str, help="s3 input key prefix")
    parser.add_argument("--s3_output_bucket", type=str, help="s3 output bucket")
    parser.add_argument("--s3_output_key_prefix", type=str, help="s3 output key prefix")
    args = parser.parse_args()

    spark = SparkSession.builder.appName("sm-spark-djl-image-classification").getOrCreate()

    input_path = "s3://" + os.path.join(args.s3_input_bucket, args.s3_input_key_prefix)
    output_path = "s3://" + os.path.join(args.s3_output_bucket, args.s3_output_key_prefix)

    # Load images with Spark's image data source, dropping files that cannot be decoded
    df = spark.read.format("image").option("dropInvalid", True).load(input_path)
    df = df.select("image.*").filter("nChannels=3")  # The model expects 3-channel (color) images

    # Image classification with a pretrained ResNet from the DJL model zoo, keeping the top 2 classes
    classifier = ImageClassifier(input_cols=["origin", "height", "width", "nChannels", "mode", "data"],
                                 output_col="prediction",
                                 engine="PyTorch",
                                 model_url="djl://ai.djl.pytorch/resnet",
                                 top_k=2)
    output_df = classifier.classify(df).select("origin", "prediction.top_k")

    # Write one row per image (source path and top-k predictions) as Parquet
    output_df.write.mode("overwrite").parquet(output_path)


if __name__ == "__main__":
    main()
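With `top_k=2`, each output row's `top_k` field holds the two most likely classes for that image. As a rough mental model, assuming the classifier turns the model's raw scores into probabilities with a softmax before ranking (an assumption, not something this notebook confirms), the selection resembles the plain-Python illustration below. It is not DJL's implementation, and the labels and scores are made up.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1 (shifted by the max for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(labels, logits, k=2):
    """Return the k (label, probability) pairs with the highest probability."""
    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up labels and scores for a single image, for illustration only
labels = ["tabby cat", "golden retriever", "sports car"]
logits = [2.0, 1.0, 0.1]
print(top_k(labels, logits, k=2))
```

Here the two highest-scoring labels, "tabby cat" and "golden retriever", are kept and "sports car" is dropped, which mirrors how a smaller `top_k` shrinks each output row.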

## Run the SageMaker Processing Job

Next, we'll create a `PySparkProcessor` with the following parameters:

* `base_job_name`: Prefix for the processing job name, set to "sm-spark-djl".
* `image_uri`: URI of the DJL Spark image pushed to ECR above.
* `role`: AWS IAM role that the processing job runs under.
* `instance_count`: Number of instances to run the processing job on, set to 2.
* `instance_type`: EC2 instance type used for processing, set to "ml.m5.2xlarge".

We also set a Spark configuration:

* `spark.executor.memory`: Set the amount of memory to use per executor process to 2g.
* `spark.executor.cores`: Set the number of cores to use on each executor to 2.

Then, the code calls the processor's `run` method to start the processing job, passing the following parameters:

* `submit_app`: Path of the processing script to submit to Spark.
* `arguments`: List of string arguments passed to the processing script, including the input and output locations. The input dataset is 300 images from the [COCO](https://cocodataset.org/#download) dataset.
* `configuration`: The Spark configuration defined above.
* `spark_event_logs_s3_uri`: S3 URI where the Spark application event logs are published.
* `logs`: Whether to show the logs produced by the job; set to False.
* `wait`: Whether to block until the job completes; set to True.
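Because `arguments` is forwarded to `process.py` as command-line flags, you can sanity-check the list locally before launching any instances. The sketch below re-declares the script's four flags with `argparse` and rebuilds the S3 URIs the way the script does. `build_parser`, `resolve_paths`, and the bucket name `my-example-bucket` are hypothetical stand-ins; `posixpath.join` is used because it matches `os.path.join` inside the Linux processing container.

```python
import argparse
import posixpath


def build_parser():
    # Mirrors the flags defined in ./code/process.py
    parser = argparse.ArgumentParser(description="app inputs and outputs")
    parser.add_argument("--s3_input_bucket", type=str)
    parser.add_argument("--s3_input_key_prefix", type=str)
    parser.add_argument("--s3_output_bucket", type=str)
    parser.add_argument("--s3_output_key_prefix", type=str)
    return parser


def resolve_paths(arguments):
    """Parse a run(arguments=...) list and return the S3 input and output URIs the script builds."""
    args = build_parser().parse_args(arguments)
    input_path = "s3://" + posixpath.join(args.s3_input_bucket, args.s3_input_key_prefix)
    output_path = "s3://" + posixpath.join(args.s3_output_bucket, args.s3_output_key_prefix)
    return input_path, output_path


# Placeholder output bucket; in the notebook it comes from sagemaker_session.default_bucket()
arguments = [
    "--s3_input_bucket", "alpha-djl-demos",
    "--s3_input_key_prefix", "dataset/cv/coco",
    "--s3_output_bucket", "my-example-bucket",
    "--s3_output_key_prefix", "sagemaker/spark-processing-djl-demo/output",
]
print(resolve_paths(arguments))
```

A typo in a flag name makes `parse_args` exit with an error here, locally, rather than after the cluster has started.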

In [None]:
from sagemaker.spark.processing import PySparkProcessor

input_bucket = "alpha-djl-demos"
input_prefix = "dataset/cv/coco"

timestamp_prefix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
prefix = "sagemaker/spark-processing-djl-demo/{}".format(timestamp_prefix)
output_bucket = sagemaker_session.default_bucket()
output_prefix = f"{prefix}/output"

image_uri = "{}/{}:{}".format(ecr_registry, repository_name, tag)

# Run the processing job
spark_processor = PySparkProcessor(
    base_job_name="sm-spark-djl",
    image_uri=image_uri,
    role=role,
    instance_count=2,
    instance_type="ml.m5.2xlarge"
)

configuration = [
    {
        "Classification": "spark-defaults",
        "Properties": {"spark.executor.memory": "2g", "spark.executor.cores": "2"}
    }
]

spark_processor.run(
    submit_app="./code/process.py",
    arguments=[
        "--s3_input_bucket", input_bucket,
        "--s3_input_key_prefix", input_prefix,
        "--s3_output_bucket", output_bucket,
        "--s3_output_key_prefix", output_prefix,
    ],
    configuration=configuration,
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(output_bucket, prefix),
    logs=False, # Do not show the logs produced by the job
    wait=True # Wait until the job completes
)

## Monitor and Analyze Your Job

Next, call `start_history_server()` to start a Spark history server and open the Spark UI for the application. The history server reads the event logs published to `spark_event_logs_s3_uri`. The UI is useful for debugging and troubleshooting, as well as for monitoring the performance and behavior of your Spark processing job.

In [None]:
spark_processor.start_history_server()

After viewing the Spark UI, you can terminate the history server before proceeding.

In [None]:
spark_processor.terminate_history_server()

## Validate Data Processing Results

Finally, validate the job output by confirming that the output prefix contains the Spark `_SUCCESS` marker file along with the Parquet output files.

In [None]:
print("Output files in s3://{}/{}/".format(output_bucket, output_prefix))
!aws s3 ls s3://$output_bucket/$output_prefix/
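The same check can be done programmatically, which is handy in automated pipelines. This is a minimal sketch, assuming a boto3 S3 client; `summarize_spark_output` is a hypothetical helper, and it relies on the Spark convention that a successful write leaves a `_SUCCESS` file next to the `part-*.parquet` files.

```python
def summarize_spark_output(s3_client, bucket, prefix):
    """Report whether a Spark output prefix has a _SUCCESS marker and list its Parquet part files.

    `s3_client` is assumed to be a boto3 S3 client. Hypothetical helper for illustration.
    """
    prefix = prefix.rstrip("/") + "/"
    keys = []
    # list_objects_v2 returns at most 1,000 keys per call, so page through the results
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    has_success = (prefix + "_SUCCESS") in keys
    parquet_files = [key for key in keys if key.endswith(".parquet")]
    return has_success, parquet_files
```

For example, `summarize_spark_output(boto3.client("s3"), output_bucket, output_prefix)` should return `True` and a non-empty file list after a successful run.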