# 4 - Endpoint <a class="anchor" id="top"></a>
* [Introduction](#intro)
* [Setup](#setup)

## Introduction <a class="anchor" id="intro"></a>
In this last section, we create a Sagemaker endpoint to allow for real-time predictions using our trained models.
After creating the endpoint, we will test a simple application that takes in basic flight information and returns 
the models prediction.

## Setup <a class="anchor" id="setup"></a>
First, we import Sageamker SDK dependencies as well as modules used in application below.
We also get relevant sessions and read in local environment data.

In [12]:
import xml
import json
import uuid
import boto3
import random
import requests
import numpy as np
import pandas as pd
import datetime as dt
import dateutil.parser
import sagemaker as sm
import sagemaker.sparkml as sparkml

In [13]:
# Get relevant sessions.
sm_session = sm.Session()
role = sm.get_execution_role()
boto3_session = boto3.session.Session()
now = dt.datetime.now().strftime(r"%Y%m%dT%H%M%S")

In [14]:
# Get boto3 session attributes.
account = boto3_session.client("sts").get_caller_identity()["Account"]
region = boto3_session.region_name
s3_resource = boto3_session.resource("s3")

In [15]:
# Retrieve data bucket name.
with open("/home/ec2-user/.aiml-bb/stack-data.json", "r") as f:
    data = json.load(f)
    data_bucket = data["data_bucket"]
    model_bucket = data["model_bucket"]

## Define model
To allow for a complete inference pipeline, we chain together the preprocessing, model inference/evaluation, and postprocessing.
We will define each of these stages as a Sagemaker `Model` object, then chain them together into an inference pipeline.

In [16]:
# Required schema for input into preprocessing step.
preprocess_schema_json = json.dumps({
    "input": [
        {"name": "day_of_week", "type": "int"},
        {"name": "month", "type": "int"},
        {"name": "op_carrier", "type": "string"},
        {"name": "origin", "type": "string"}, 
        {"name": "origin_latitude", "type": "double"}, 
        {"name": "origin_longitude", "type": "double"},
        {"name": "dest", "type": "string"}, 
        {"name": "dest_latitude", "type": "double"}, 
        {"name": "dest_longitude", "type": "double"},
        {"name": "origin_tmax", "type": "double"}, 
        {"name": "origin_tmin", "type": "double"}, 
        {"name": "origin_prcp", "type": "double"}, 
        {"name": "origin_snow", "type": "double"}, 
        {"name": "origin_snwd", "type": "double"},
        {"name": "dest_tmax", "type": "double"}, 
        {"name": "dest_tmin", "type": "double"}, 
        {"name": "dest_prcp", "type": "double"}, 
        {"name": "dest_snow", "type": "double"}, 
        {"name": "dest_snwd", "type": "double"}
    ],
     "output": {"name": "features", "type": "double", "struct": "vector"}
})

In [17]:
# Define the preprocessing model.
preprocess_model = sparkml.model.SparkMLModel(
    name=f"spark-preprocessor-{now}",
    model_data=f"s3://{model_bucket}/spark-preprocessor/model.tar.gz",
    spark_version="2.4",
    sagemaker_session=sm_session,
    env={"SAGEMAKER_SPARKML_SCHEMA": preprocess_schema_json}
)

In [18]:
# Define inference model.
xgb_container_image = sm.image_uris.retrieve("xgboost", region, "latest")
inference_model = sm.model.Model(
    image_uri=xgb_container_image,
    model_data=f"s3://{model_bucket}/sagemaker-xgboost-tuned/model.tar.gz"
)

In [19]:
# Define complete inference pipeline model and deploy.
pipeline_model = sm.pipeline.PipelineModel(
    name=f"sm-pipeline-{now}",
    role=role,
    models=[
        preprocess_model,
        inference_model
    ]
)
endpoint_name = f"pipeline-endpoint-{now}"
pipeline_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    endpoint_name=endpoint_name
)

-----------!

In [20]:
# Connect a predictor to the endpoint for inference.
pipeline_predictor = sm.predictor.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sm_session,
    serializer=sm.serializers.JSONSerializer()
)

## Demo testing endpoint
Test the endpoint in a simple application where the flight information is inputted, and a prediction is returned.

In [35]:
# User inputted features.
origin = "JFK"
dest = "LAX"
carrier = "B6"
fl_date = "2022-01-28"

All code below would be abstracted away from the user.

In [36]:
# Get date attributes.
today = dt.datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
fl_datetime = dt.datetime.strptime(fl_date, r"%Y-%m-%d")
day_of_week = fl_datetime.weekday() + 1
month = fl_datetime.month

In [37]:
# Get latitude and longitudes of airports.
airport_df = pd.read_parquet(f"s3://{data_bucket}/dl_output/airport_data")
get_iata_geolocation = (
    lambda iata: 
    airport_df.loc[airport_df["iata"]==iata, ["latitude", "longitude"]].iloc[0]
)
origin_lat, origin_lon = get_iata_geolocation(origin)
dest_lat, dest_lon = get_iata_geolocation(dest)

In [38]:
# Grab weather data.
forecast_fqdn = "https://graphical.weather.gov"
get_geolocation_forecast = (
    lambda lat, lon:
    xml.etree.ElementTree.fromstring(
        requests.get(
            f"{forecast_fqdn}/xml/SOAP_server/ndfdXMLclient.php",
            params={
                "lat": lat, "lon": lon,
                "begin": today.isoformat(), 
                "end": (today + dt.timedelta(days=7)).isoformat(),
                "Unit": "m",
                "maxt": "maxt", "mint": "mint",
                "qpf": "qpf", "snow": "snow",
                "product": "time-series",
                "Submit": "Submit"
            }
        ).content
    )
)
origin_forecast = get_geolocation_forecast(origin_lat, origin_lon) 
dest_forecast = get_geolocation_forecast(dest_lat, dest_lon)

In [39]:
# Define function to get averages of date values in XML.
def get_avg_xml_value(xml_tree, field, datetime=fl_datetime):
    # Get date index key.
    layout_key = xml_tree.find(f".//*{field}").attrib["time-layout"]
    
    # Find indices of dates matching date in question.
    idxs = []
    for idx, date in enumerate(xml_tree.findall(f".//*time-layout/start-valid-time")):
        datetime = dateutil.parser.parse(date.text)
        if fl_datetime.strftime("%Y-%m-%d") == datetime.strftime("%Y-%m-%d"):
            idxs.append(idx)
            
    if not idxs:
        raise ValueError("Date invalid, no data found for field. Possibly too far into the future.")
            
    # Data is for different times of day so we take mean.
    # Zero is added so we default in case of no data (e.g. with snow).
    val_sum = 0.0
    for idx, val in enumerate(xml_tree.findall(f".//*{field}/value")):
        if idx in idxs:
            val_sum += float(val.text)
            
    return val_sum / len(idxs)

In [40]:
# Get forecast values and convert to dataset formats.
# In NOAA weather data, all values are scaled by 1/10.
origin_tmax = 0.10 * get_avg_xml_value(origin_forecast, "temperature[@type='maximum']")
origin_tmin = 0.10 * get_avg_xml_value(origin_forecast, "temperature[@type='minimum']")
origin_snwd = 0.10 * get_avg_xml_value(origin_forecast, "precipitation[@type='snow']")
origin_liquid = 0.10 * get_avg_xml_value(origin_forecast, "precipitation[@type='liquid']")

dest_tmax = 0.10 * get_avg_xml_value(dest_forecast, "temperature[@type='maximum']")
dest_tmin = 0.10 * get_avg_xml_value(dest_forecast, "temperature[@type='minimum']")
dest_snwd = 0.10 * get_avg_xml_value(dest_forecast, "precipitation[@type='snow']")
dest_liquid = 0.10 * get_avg_xml_value(dest_forecast, "precipitation[@type='liquid']")

# This snow to liquid ratio is often assumed, however can be inaccurate.
# It is suitable for demonstration purposes, but may need more acccurate 
# inspection in production use cases.
snow_to_liquid_ration = 10.0

origin_avg = (origin_tmax + origin_tmin) / 2
origin_prcp = origin_liquid if origin_avg > 0 else 0
origin_snow = 0 if origin_avg > 0 else snow_to_liquid_ration * origin_liquid

dest_avg = (dest_tmax + dest_tmin) / 2
dest_prcp = dest_liquid if dest_avg > 0 else 0
dest_snow = 0 if dest_avg > 0 else snow_to_liquid_ration * dest_liquid

### Make call to endpoint
Send data to the endpoint and make the prediction.

In [41]:
payload =  {"data": [
    day_of_week, month, 
    carrier, 
    origin, origin_lat, origin_lon, 
    dest, dest_lat, dest_lon,
    origin_tmax, origin_tmin, origin_prcp, origin_snow, origin_snwd,
    dest_tmax, dest_tmin, dest_prcp, dest_snow, dest_snwd
]}
print(pipeline_predictor.predict(payload))

b'0.28944116830825806'


## Cleanup resources
Because this is a temporary project, delete the endpoint.

In [42]:
sm_session.delete_endpoint(endpoint_name)