In [71]:
import requests

# fill the parameters for the prediction
params2 = dict(
  pickup_datetime='2012-10-06 12:10:20',
  pickup_longitude=40.7614327,
  pickup_latitude=-73.9798156,
  dropoff_longitude=40.6413111,
  dropoff_latitude=-73.9797156,
  passenger_count=1
)

# params = dict(
#   pickup_datetime='2012-10-06 12:10:20',
#   pickup_longitude=40.7614327,
#   pickup_latitude=-73.9798156,
#   dropoff_longitude=40.6331166,
#   dropoff_latitude=-73.8874078,    
#   passenger_count=2
# )

# URL
taxifare_api_url = "http://127.0.0.1:8000/predict"

# retrieve the response
response = requests.get(
    taxifare_api_url,
    params=params2
)

if response.status_code == 200:
    print("API call success")
else:
    print("API call error")

response.status_code, response.json().get("prediction", "no prediction"), response.json()


API call success


(200,
 13.634983667920714,
 {'key': '2013-07-06 17:18:00.000000119',
  'pickup_datetime': '2012-10-06 12:10:20',
  'pickup_longitude': 40.7614327,
  'pickup_latitude': -73.9798156,
  'dropoff_longitude': 40.6413111,
  'dropoff_latitude': -73.9797156,
  'passenger_count': 1,
  'prediction': 13.634983667920714})

In [66]:
import pandas as pd
import numpy as np
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime
import pytz
import joblib
from predict import download_model
from google.cloud import storage

PATH_TO_LOCAL_MODEL='../model.joblib'

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

@app.get("/")
def index():
    return {"greeting": "Hello world"}


@app.get("/predict")
def predict(pickup_datetime, pickup_longitude, pickup_latitude,
          dropoff_longitude, dropoff_latitude, passenger_count):

    X_pred_dict = {
        'key': "2013-07-06 17:18:00.000000119",
        'pickup_datetime': pickup_datetime,
        'pickup_longitude': float(pickup_longitude),
        'pickup_latitude': float(pickup_latitude),
        'dropoff_longitude': float(dropoff_longitude),
        'dropoff_latitude': float(dropoff_latitude),
        'passenger_count': int(passenger_count)
    }

    X_pred = pd.DataFrame(X_pred_dict, index=[0])

    # create a datetime object from the user provided datetime
    pickup_datetime = datetime.strptime(pickup_datetime, "%Y-%m-%d %H:%M:%S")

    # localize the user datetime with NYC timezone
    eastern = pytz.timezone("US/Eastern")
    localized_pickup_datetime = eastern.localize(pickup_datetime, is_dst=None)

    # localize the datetime to UTC
    utc_pickup_datetime = localized_pickup_datetime.astimezone(pytz.utc)
    formatted_pickup_datetime = utc_pickup_datetime.strftime("%Y-%m-%d %H:%M:%S UTC")

    X_pred['pickup_datetime'] = formatted_pickup_datetime

#     pipeline = joblib.load(PATH_TO_LOCAL_MODEL)
    pipeline = download_model(rm=False)

    y_pred = pipeline.predict(X_pred)

    X_pred_dict['prediction'] = y_pred[0]

    return X_pred_dict

In [67]:
MLFLOW_URI = "https://mlflow.lewagon.co/"
EXPERIMENT_NAME = "JPN Tokyo lorcanob TFM 0.1"
# PATH_TO_LOCAL_MODEL = 'model.joblib'
AWS_BUCKET_TEST_PATH = "s3://wagon-public-datasets/taxi-fare-test.csv"
BUCKET_NAME = 'wagon-data-728-odufuwabolger'
BUCKET_TRAIN_DATA_PATH = 'data/train_1k.csv'
MODEL_NAME = 'taxifare'
MODEL_VERSION = 'v2'

In [68]:
def download_model(model_directory="v2", bucket=BUCKET_NAME, rm=True):
    client = storage.Client().bucket(bucket)

    storage_location = f'models/{MODEL_NAME}/{model_directory}/model.joblib'
    blob = client.blob(storage_location)
    blob.download_to_filename('model.joblib')
    print("=> pipeline downloaded from storage")
    model = joblib.load('model.joblib')
    if rm:
        os.remove('model.joblib')
    return model

In [69]:
predict(*list(params2.values()))

=> pipeline downloaded from storage


{'key': '2013-07-06 17:18:00.000000119',
 'pickup_datetime': '2012-10-06 12:10:20',
 'pickup_longitude': 40.7614327,
 'pickup_latitude': -73.9798156,
 'dropoff_longitude': 40.6413111,
 'dropoff_latitude': -73.9797156,
 'passenger_count': 1,
 'prediction': 13.634983667920714}