In [1]:
# Standard imports
import numpy as np
import pandas as pd
import yaml

from sklearn import metrics

# Built-in library
import itertools
import re
import json
import logging
import typing as tp

import warnings

warnings.filterwarnings("error")

from feature_engine.transformation import (
    LogTransformer,
    YeoJohnsonTransformer,
)

# Custom Imports
from src.processing.data_manager import load_data, validate_input, logger
import src.processing.feat_engineering as fe
from src.config.schema import (
    TrainingSchema,
    ValidateTrainingData,
    ModelConfig,
    ConfigVars,
)

# Widen pandas' display limits so large frames are not truncated when rendered
pd.set_option("display.max_rows", 1_000)
pd.set_option("display.max_columns", 1_000)
pd.set_option("display.max_colwidth", 600)

# Black code formatter (Optional)
%load_ext lab_black
# auto reload imports
%load_ext autoreload
%autoreload 2

INFO :: 2022-12-30 17:46:55,863 :: VERSION Loaded ...


## Workflow

1. Load the data and add UUIDs
2. Preprocess data and make predictions
3. Compare the predictions with the actual target (trip_duration)
4. Save the output from step 3 to S3

### Note: Parameterize the entire workflow

In [13]:
# Parameters that select which monthly trip file is loaded.
taxi_type = "yellow"
year = 2022  # {year:04d} i.e year with 4 digits
month = 2  # formatted as {month:02d}, i.e. zero-padded to 2 digits

# S3 URI of the monthly parquet file built from the parameters above.
fp = f"s3://nyc-tlc/trip data/{taxi_type}_tripdata_{year:04d}-{month:02d}.parquet"
# fp = "s3://nyc-tlc/trip data/yellow_tripdata_2022-11.parquet"
# `load_data` is a project helper (src.processing.data_manager).
# NOTE(review): the frame displayed below carries `id` and `trip_duration` columns
# next to the raw trip columns -- presumably added by `load_data`; confirm there.
data = load_data(filename=fp, uri=True)

# Preview the first rows.
data.head()

INFO :: 2022-12-30 19:19:34,633 :: Loading Data ... 


Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee,id,trip_duration
0,1,2022-02-01 00:06:58,2022-02-01 00:19:24,1.0,5.4,1.0,N,138,252,1,17.0,1.75,0.5,3.9,0.0,0.3,23.45,0.0,1.25,194d1390-5e36-42de-9a9b-b0c17a257728,2.597491
1,1,2022-02-01 00:38:22,2022-02-01 00:55:55,1.0,6.4,1.0,N,138,41,2,21.0,1.75,0.5,0.0,6.55,0.3,30.1,0.0,1.25,b905bc88-a571-42b1-862e-ad7671047f2c,2.92047
2,1,2022-02-01 00:03:20,2022-02-01 00:26:59,1.0,12.5,1.0,N,138,200,2,35.5,1.75,0.5,0.0,6.55,0.3,44.6,0.0,1.25,ab579b0a-9f87-4c7a-b042-cfe90428b29e,3.204777
3,2,2022-02-01 00:08:00,2022-02-01 00:28:05,1.0,9.88,1.0,N,239,200,2,28.0,0.5,0.5,0.0,3.0,0.3,34.8,2.5,0.0,86dc2aca-6a57-4f51-9baf-5195e9dfea08,3.048325
4,2,2022-02-01 00:06:48,2022-02-01 00:33:07,1.0,12.16,1.0,N,138,125,1,35.5,0.5,0.5,8.11,0.0,0.3,48.66,2.5,1.25,744f4750-2134-48e1-87cf-296d959f57d6,3.307619


In [14]:
# Sanity check: the parameterised URI next to a hard-coded URI of the same pattern.
fp = f"s3://nyc-tlc/trip data/{taxi_type}_tripdata_{year:04d}-{month:02d}.parquet"
fp2 = "s3://nyc-tlc/trip data/yellow_tripdata_2022-11.parquet"

# Label each value with the name of the variable that actually holds it.
print(f"fp={fp}\nfp2={fp2}\n")

fp=s3://nyc-tlc/trip data/yellow_tripdata_2022-02.parquet
 fp1=s3://nyc-tlc/trip data/yellow_tripdata_2022-11.parquet



In [7]:
def get_predictions(*, data: pd.DataFrame, run_id: str) -> np.ndarray:
    """This returns the predicted trip duration using the model
    from the model registry on S3.

    Params:
    -------
    data (Pandas DF): The input data; it is passed unchanged to the model's `predict`.
    run_id (str): The run id associated with the model.

    Returns:
    --------
    pred (ndarray): The predicted trip duration (exponentiated model output,
        rounded to 1 decimal place).
    """
    import mlflow

    # Location of the logged model artifacts for this run
    model_uri = f"s3://mlflow-model-registry-neidu/1/{run_id}/artifacts/model"

    # Load the model from the model registry
    logger.info("Fetching model from registry ...")
    model = mlflow.pyfunc.load_model(model_uri=model_uri)

    logger.info("Making predictions ...")
    log_pred = model.predict(data)

    # Convert from log to minutes (vectorised instead of a per-element Python loop).
    # NOTE(review): the sample rows in this notebook suggest the target may have been
    # built with log1p (exp(trip_duration) is one minute longer than dropoff - pickup);
    # if so, np.expm1 is the matching inverse here and in `compare_predictions` -- confirm.
    return np.round(np.exp(np.asarray(log_pred)), 1)


def get_paths(*, run_date: str, taxi_type: str, run_id: str) -> tp.Tuple:
    """Build the S3 URIs the batch job reads from and writes to.

    The files are keyed on the month that precedes `run_date`.

    Params:
    -------
    run_date (str): The date the script was run in `year-month-day` format. e.g "2022-12-30"
    taxi_type (str): The taxi colour. e.g yellow, green, etc
    run_id (str): The run id associated with the model; used as the output file name.

    Returns:
    --------
    input_file, output_file (S3 URIs): The input and output S3 URIs respectively.
    """
    from datetime import datetime
    from dateutil.relativedelta import relativedelta

    # Step back one calendar month from the run date
    target = datetime.strptime(run_date, "%Y-%m-%d") - relativedelta(months=1)
    year, month = target.year, target.month

    input_file = f"s3://nyc-tlc/trip data/{taxi_type}_tripdata_{year:04d}-{month:02d}.parquet"
    output_file = f"s3://nyc-duration-prediction-neidu/taxi_type={taxi_type}/year={year:04d}/month={month:02d}/{run_id}.parquet"

    return (input_file, output_file)


def compare_predictions(*, data: pd.DataFrame, run_id: str) -> pd.DataFrame:
    """Put the actual and the predicted trip duration side by side.

    Params:
    -------
    data (Pandas DF): DF containing the NYC taxi data.
    run_id (str): The run id associated with the model.

    Returns:
    --------
    result_df (Pandas DF): DF with the trip identifiers, the actual and predicted
        trip duration, their difference and the model run id.
    """
    predicted = get_predictions(data=data, run_id=run_id)

    # Columns copied over unchanged from the input data
    carried_over = [
        "id",
        "tpep_pickup_datetime",
        "trip_distance",
        "PULocationID",
        "DOLocationID",
    ]
    result_df = data[carried_over].copy()

    # Convert to minutes
    # NOTE(review): for the sample rows shown in this notebook the value is one minute
    # longer than dropoff - pickup, which suggests a log1p target; if so, np.expm1 is
    # the matching inverse here and in `get_predictions` -- confirm.
    actual = data["trip_duration"].apply(np.exp)

    result_df["actual_trip_duration"] = actual
    result_df["pred_trip_duration"] = predicted
    result_df["diff"] = result_df["actual_trip_duration"] - result_df["pred_trip_duration"]
    result_df["model_run_id"] = run_id
    return result_df


def save_data_to_s3(data: pd.DataFrame, output: str) -> None:
    """This saves the parquet data to S3.

    Params:
    -------
    data (Pandas DF): The data to save.
    output (str): The destination path/URI passed to `DataFrame.to_parquet`.

    Raises:
    -------
    Exception: Any error raised while writing is logged (with traceback) and
        re-raised, so that callers do not report success after a failed write.
    """
    logger.info("Saving to S3 ...")
    try:
        data.to_parquet(path=output, index=False)
    except Exception:
        logger.exception("Failed to save the data to %s", output)
        raise

In [11]:
def main(argv: tp.Optional[tp.Sequence[str]] = None) -> None:
    """This is the main function.

    It parses the command line arguments, loads the input data, compares the
    predicted vs actual trip duration and saves the result to S3.

    Params:
    -------
    argv (sequence of str, optional): The arguments to parse. When None (default),
        argparse reads them from `sys.argv`.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(
        prog="Batch predictions",
        description="This is used to make batch predictions of the NYC taxi trip duration.",
    )
    parser.add_argument(
        "--run-id",
        "-r",
        help="The run id associated with the model e.g `98f43706f6184694be1ee10c41c7b69d`",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--run-date",
        "-d",
        help="The date the script was run in `year-month-day format`. e.g '2022-12-30'",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--taxi-type",
        "-t",
        help="The taxi colour. e.g yellow, green, etc",
        type=str,
        required=True,
    )
    args = parser.parse_args(argv)
    # Extract the variables
    run_id, run_date, taxi_type = (args.run_id, args.run_date, args.taxi_type)

    input_file, output_file = get_paths(
        run_date=run_date, taxi_type=taxi_type, run_id=run_id
    )
    data = load_data(filename=input_file, uri=True)
    # `data` is a required keyword-only argument of `compare_predictions`
    result_df = compare_predictions(data=data, run_id=run_id)
    # `output` is a required argument of `save_data_to_s3`
    save_data_to_s3(data=result_df, output=output_file)
    logger.info("Batch Prediction processing done!")

In [5]:
run_id = "98f43706f6184694be1ee10c41c7b69d"
# `compare_predictions` requires the loaded taxi data as the keyword-only `data` argument
result_df = compare_predictions(data=data, run_id=run_id)

result_df.head()

INFO :: 2022-12-30 17:48:23,931 :: Fetching model from registry ...
INFO :: 2022-12-30 17:48:23,969 :: Found credentials in shared credentials file: ~/.aws/credentials
INFO :: 2022-12-30 17:48:36,789 :: Making predictions ...


Unnamed: 0,id,tpep_pickup_datetime,trip_distance,PULocationID,DOLocationID,actual_trip_duration,pred_trip_duration,diff,model_run_id
0,98b25fbb-11f3-411a-b37f-bcfa9c01d757,2022-02-01 00:06:58,5.4,138,252,13.43,17.2,-3.77,98f43706f6184694be1ee10c41c7b69d
1,3326907a-a18c-4271-93c1-04876a1d4ccc,2022-02-01 00:38:22,6.4,138,41,18.55,22.2,-3.65,98f43706f6184694be1ee10c41c7b69d
2,df0b5f96-ef88-4a11-bee1-b5ccdafc2148,2022-02-01 00:03:20,12.5,138,200,24.65,32.4,-7.75,98f43706f6184694be1ee10c41c7b69d
3,e293be56-000d-40f2-ab51-d5825967fcdf,2022-02-01 00:08:00,9.88,239,200,21.08,22.2,-1.12,98f43706f6184694be1ee10c41c7b69d
4,f93afcb0-f2d6-4045-852d-0efba66c0e8c,2022-02-01 00:06:48,12.16,138,125,27.32,28.5,-1.18,98f43706f6184694be1ee10c41c7b69d
