# Batch Inference

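This notebook outlines a workflow for generating model predictions: it loads an MLflow model, scores the records in the inference dataset table that do not yet have a prediction, and writes the updated dataset back to that table.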

#### Import dependencies, define notebook parameters and constants

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, struct
import mlflow.pyfunc

In [None]:
# declare the notebook's text widgets as (name, default value) pairs
for widget_name, widget_default in (
    ("model_uri", "models:/credit-default-uci-sklearn/1"),
    ("inference_dataset_table", "hive_metastore.default.credit_default_uci_inference"),
):
    dbutils.widgets.text(widget_name, widget_default)

#### Run batch inference

In [None]:
# read the current widget values into the variables used by the cells below
inference_dataset_table, model_uri = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("inference_dataset_table", "model_uri")
)

In [None]:
# read inference dataset, casting every selected column to an explicit type
inference_df = spark.read.table(inference_dataset_table).select(
    col("sex").cast("string"),
    col("education").cast("string"),
    col("marriage").cast("string"),
    col("repayment_status_1").cast("string"),
    col("repayment_status_2").cast("string"),
    col("repayment_status_3").cast("string"),
    col("repayment_status_4").cast("string"),
    col("repayment_status_5").cast("string"),
    col("repayment_status_6").cast("string"),
    col("credit_limit").cast("double"),
    col("age").cast("integer"),
    col("bill_amount_1").cast("double"),
    col("bill_amount_2").cast("double"),
    col("bill_amount_3").cast("double"),
    col("bill_amount_4").cast("double"),
    col("bill_amount_5").cast("double"),
    col("bill_amount_6").cast("double"),
    col("payment_amount_1").cast("double"),
    col("payment_amount_2").cast("double"),
    col("payment_amount_3").cast("double"),
    col("payment_amount_4").cast("double"),
    col("payment_amount_5").cast("double"),
    col("payment_amount_6").cast("double"),
    col("prediction").cast("double"),
)

# split the dataset on whether a record already has a prediction
scored_df = inference_df.filter("prediction IS NOT NULL")
batch_df = inference_df.filter("prediction IS NULL")

# model inputs are every selected column except "prediction": that column is
# NULL for all rows in batch_df and must not be passed to the model as a feature
feature_columns = [name for name in batch_df.columns if name != "prediction"]

# create spark user-defined function for model prediction
predict = mlflow.pyfunc.spark_udf(spark, model_uri, result_type="double")

# generate predictions; withColumn replaces the existing (NULL) prediction column
predictions_df = batch_df.withColumn("prediction", predict(struct(*feature_columns)))

# update inference dataset with predictions, matching columns by name so the
# result does not depend on the two frames sharing the same column order
updated_inference_df = scored_df.unionByName(predictions_df)

In [None]:
# persist the combined dataset to the inference dataset table, replacing both
# the existing rows and the existing schema
# NOTE(review): this is the same table the dataset was read from earlier in the
# notebook — confirm the table format supports overwriting a table that is
# being read in the same job
table_writer = updated_inference_df.write.mode("overwrite").option(
    "overwriteSchema", "true"
)
table_writer.saveAsTable(inference_dataset_table)
