# Image Classification: run inference on the endpoint you created

## Download example images

In [1]:
import html

import boto3
from IPython.core.display import HTML

# Resolve the region from the default boto3 session; the asset bucket name
# below is built from it, so it must be known.
region = boto3.Session().region_name
if region is None:
    # Without a configured region the bucket name would silently become
    # "jumpstart-cache-prod-None" and the download would fail with an
    # unrelated-looking S3 error.
    raise ValueError(
        "No default AWS region is configured; set AWS_DEFAULT_REGION or "
        "configure a region for boto3 before running this notebook."
    )
s3_bucket = f"jumpstart-cache-prod-{region}"
key_prefix = "inference-notebook-assets"  # key prefix prepended to every downloaded filename
s3 = boto3.client("s3")

def download_from_s3(key_filenames):
    """Download each named object from the asset bucket into the working directory.

    Args:
        key_filenames: Iterable of filenames. Each one is appended to
            ``key_prefix`` to form the S3 object key and is also used as the
            local destination path.
    """
    for name in key_filenames:
        object_key = f"{key_prefix}/{name}"
        s3.download_file(s3_bucket, object_key, name)

# Filenames of the two example images; each is used both as the S3 key suffix
# and as the local path by download_from_s3.
cat_jpg = "cat.jpg"
dog_jpg = "dog.jpg"
download_from_s3(key_filenames=[cat_jpg, dog_jpg])

# Preview both images side by side; as the last expression of the cell, the
# HTML object is rendered as the cell output.
HTML(
    '<table><tr>'
    '<td> <img src="cat.jpg" alt="cat" style="height: 250px;"/> <figcaption>cat.jpg</figcaption></td>'
    '<td> <img src="dog.jpg" alt="dog" style="height: 250px;"/> <figcaption>dog.jpg</figcaption></td>'
    '</tr></table>'
)

0,1
cat.jpg,dog.jpg


## Open the downloaded images and load them into memory

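You can also upload any image from your local computer to this directory and open it here. Note that the cell below also opens `cat2.png`, which is not one of the files downloaded above.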

In [3]:
# Read each image's raw bytes into memory, keyed by filename.
images = {}
for image_path in (cat_jpg, dog_jpg):
    with open(image_path, 'rb') as file:
        images[image_path] = file.read()

# "cat2.png" is not one of the files downloaded above, so it only exists if it
# was placed in the working directory manually. Skip it when absent so the
# notebook still runs end-to-end on a fresh kernel.
try:
    with open("cat2.png", 'rb') as file:
        images["cat2.png"] = file.read()
except FileNotFoundError:
    print("Optional image 'cat2.png' not found; skipping it.")


## Query the endpoint you created with the opened images and parse the predictions

Note: The backend scripts and the notebooks were updated in Jan '22. This notebook will not work with endpoints launched before that update. If you experience an error, please launch the endpoint again.

In [6]:
import json

def query_endpoint(img, endpoint_name='jumpstart-dft-tf-ic-imagenet-mobilenet-v2-100-224-clas', client=None):
    """Send image bytes to a SageMaker endpoint and return the raw response.

    Args:
        img: Image file contents, sent as the request body with content type
            ``application/x-image``.
        endpoint_name: Name of the endpoint to invoke. Defaults to the name
            this notebook was generated with.
        client: Optional runtime client exposing ``invoke_endpoint``. Pass one
            to reuse it across calls; when omitted a new boto3 client is
            created for this call.

    Returns:
        The response object returned by ``invoke_endpoint``, unmodified.
    """
    if client is None:
        client = boto3.client('runtime.sagemaker')
    # The ';verbose' accept type is what the parsing step relies on to get
    # 'labels' and 'predicted_label' alongside 'probabilities' —
    # NOTE(review): confirm against the deployed model's inference script.
    response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image',
                                      Body=img, Accept='application/json;verbose')
    return response
    

def parse_prediction(query_response):
    """Decode an endpoint response into its prediction fields.

    Args:
        query_response: Mapping whose ``'Body'`` entry has a ``read()`` method
            returning a JSON document with the keys ``'predicted_label'``,
            ``'labels'`` and ``'probabilities'``.

    Returns:
        A tuple ``(predicted_label, probabilities, labels)``.
    """
    raw_payload = query_response['Body'].read()
    decoded = json.loads(raw_payload)
    # Keys are looked up in this order; note the return order differs.
    top_label, class_names, scores = (
        decoded[field] for field in ('predicted_label', 'labels', 'probabilities')
    )
    return top_label, scores, class_names

# Query the endpoint once per loaded image and render the image next to its
# predicted label.
for filename, img in images.items():
    query_response = query_endpoint(img)
    predicted_label, probabilities, labels = parse_prediction(query_response)
    # Quote and escape the interpolated values: an unquoted attribute breaks on
    # filenames containing spaces, and unescaped text can corrupt the markup.
    safe_name = html.escape(filename, quote=True)
    safe_label = html.escape(str(predicted_label))
    display(HTML(f'<img src="{safe_name}" alt="{safe_name}" align="left" style="width: 250px;"/>'
                 f'<figcaption>Predicted Label is : {safe_label}</figcaption>'
                ))