# Training an image classification model using Sagemaker and HLS imagery from Earthdata Cloud (EDC).
Using a set of training data defined [here](https://github.com/nasa-esdswg-ml/edc-notebooks/blob/main/Sagemaker/data-preparation.ipynb) we will train a model using SageMaker's 'image-classification' framework.

Note: For this to work directly against EDC, we would need to grant read access to the AWS user associated with the 'sandbox' profile, which is not currently possible. In this notebook, you should therefore copy the data to an S3 bucket that your user has normal read access to, using the technique outlined [here](https://github.com/nasa-esdswg-ml/edc-notebooks/blob/main/EDC%20Data%20Access/s3-access.ipynb). This is one of the primary findings of our EDC+ML investigation, and recommendations will be made to EOSDIS to make SageMaker direct data access (i.e., usage of EDC data without having to copy it) possible.

In [None]:
import time
import boto3
import re
import sagemaker
from sagemaker import get_execution_role
from sagemaker import image_uris

# Notebook-wide configuration. The angle-bracket strings below are
# placeholders and must be replaced with real values before running the
# remaining cells.

# Execution role returned by the SageMaker SDK for this notebook environment.
# NOTE(review): not referenced again in the visible cells of this notebook.
role = get_execution_role()

# Default S3 bucket of the current SageMaker session.
# NOTE(review): not referenced again in the visible cells of this notebook.
bucket = sagemaker.Session().default_bucket()

# See training.ipynb for job name prefix and model name
# Name of the already-created SageMaker model that the endpoint will serve.
model_name = "<insert model name here>"
# Prefix for the endpoint name; a timestamp is appended when the endpoint is created.
endpoint_prefix = "<insert and endpoint name here>"

# S3 bucket holding the images used to exercise the endpoint.
test_endpoint_image_bucket = "<insert S3 bucket containing cloudy or non-cloudy images to test endpoint with>"
# S3 object keys of the jpg files in that bucket to test the endpoint with
images = ("<jpg test image one>", "<jpg test image two>")

# Configuring a model endpoint

In [None]:
from time import gmtime, strftime

# Create a SageMaker endpoint configuration that serves `model_name` from a
# single ml.m4.xlarge instance receiving all traffic.

# Timestamp suffix keeps the configuration name unique across runs.
timestamp = time.strftime("-%Y-%m-%d-%H-%M-%S", time.gmtime())
# `job_name_prefix` only exists in training.ipynb and is undefined here, so
# the name is derived from this notebook's own `endpoint_prefix` instead.
endpoint_config_name = endpoint_prefix + "-epc" + timestamp

# `sagemaker` is the SageMaker Python SDK *module* (imported above) and has no
# `create_endpoint_config` function; the low-level API lives on the boto3
# "sagemaker" client.
# NOTE(review): this uses the default credentials; the test cell's runtime
# client uses the 'sandbox' profile - confirm both resolve to the same
# account and region.
sm_client = boto3.client("sagemaker")
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.m4.xlarge",
            "InitialInstanceCount": 1,
            "ModelName": model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint configuration name: {}".format(endpoint_config_name))
print("Endpoint configuration arn:  {}".format(endpoint_config_response["EndpointConfigArn"]))

# Creating an image classification endpoint from the model

In [None]:
# Create the real-time inference endpoint from the endpoint configuration
# built in the previous cell, then wait until it is ready to serve requests.

# Timestamp suffix keeps the endpoint name unique across runs.
timestamp = time.strftime("-%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_name = endpoint_prefix + timestamp
print("Endpoint name: {}".format(endpoint_name))

endpoint_params = {
    "EndpointName": endpoint_name,
    "EndpointConfigName": endpoint_config_name,
}
# `sagemaker` is the SageMaker Python SDK *module* and has no
# `create_endpoint` function; the low-level API lives on the boto3
# "sagemaker" client.
# NOTE(review): this uses the default credentials; the test cell's runtime
# client uses the 'sandbox' profile - confirm both resolve to the same
# account and region.
sm_client = boto3.client("sagemaker")
endpoint_response = sm_client.create_endpoint(**endpoint_params)
print("EndpointArn = {}".format(endpoint_response["EndpointArn"]))

# Endpoint creation is asynchronous and typically takes several minutes.
# Block here so the following test cell does not invoke an endpoint that is
# still being created; the waiter raises if creation fails.
print("Waiting for endpoint to be in service...")
sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)
print("Endpoint is in service")

# Testing the image classification endpoint
We can feed the endpoint with a browse image and ask it whether that image is cloudy or clear.
We render the image, feed it to the endpoint and display the result.

In [None]:
from IPython.display import Image
from PIL import Image as Im
import matplotlib.pyplot as plt
import json
import numpy as np

# Class labels, positioned so that the index of the highest probability in
# the endpoint's response can be used to look up the label.
object_categories = ["clear", "cloudy"]

# Runtime client, created from the 'sandbox' profile, used to invoke the endpoint.
session = boto3.Session(profile_name='sandbox')
runtime = session.client("runtime.sagemaker")

# Iterator over the configured test image keys.
imageIterator = iter(images)

def testImage(image):
    """Download one test image from S3, display it, and classify it.

    The image is fetched from ``test_endpoint_image_bucket``, rendered inline
    with matplotlib, and sent to the endpoint named ``endpoint_name``. The raw
    endpoint response and the predicted label are printed.

    Args:
        image: S3 object key of the image inside ``test_endpoint_image_bucket``.

    Returns:
        None. Results are printed.
    """
    import os  # local import: not among the notebook's top-level imports

    # Keep only the final path component so that keys containing a prefix
    # ("folder/img.jpg") do not require a non-existent directory under /tmp.
    file_name = os.path.join("/tmp", os.path.basename(image))
    s3_client = boto3.client("s3")
    s3_client.download_file(
        test_endpoint_image_bucket,
        image,
        file_name,
    )

    # Render the image. A bare `Image(file_name)` expression displays nothing
    # inside a function, so matplotlib alone does the rendering. `plt.show`
    # takes no positional arguments in current matplotlib releases.
    with Im.open(file_name) as img:
        plt.imshow(img)
        plt.show()

    # Read the raw bytes for the request; the context manager closes the file.
    with open(file_name, "rb") as image_file:
        payload = image_file.read()

    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/x-image", Body=payload
    )

    # The response body is JSON; it is parsed and passed to argmax below.
    result = json.loads(response["Body"].read())
    print("Result: " + str(result))
    # the result will output the probabilities for all classes
    # find the class with maximum probability and print the class label
    index = np.argmax(result)
    print("This image is " + object_categories[index])
    
# Classify every configured test image. Iterating over `images` directly
# avoids the StopIteration that a fixed number of `next()` calls raises when
# fewer images are configured than calls are made.
for image in images:
    testImage(image)


# Clean up

In [None]:
# Delete the endpoint so it stops incurring charges. `sagemaker` is the
# SageMaker Python SDK *module* and has no `delete_endpoint` function; the
# low-level API lives on the boto3 "sagemaker" client.
boto3.client("sagemaker").delete_endpoint(EndpointName=endpoint_name)