# Batch inference

Apply the model to all the documents.

This notebook demonstrates batch inference by applying the registered model to
every review in the same dataset we have been working with. In a production
setting, the model would instead score new documents as they stream in from an
upstream source.

In [0]:
import mlflow
import mlflow.transformers

import pandas as pd

In [0]:
# Declare the notebook parameters; each defaults to empty and must be supplied.
dbutils.widgets.text("catalog_name", "")
dbutils.widgets.text("schema_name", "")
dbutils.widgets.text("model_name", "")
dbutils.widgets.text("model_alias", "")

catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
model_name = dbutils.widgets.get("model_name")
model_alias = dbutils.widgets.get("model_alias")

# Fail fast on missing parameters. Explicit raises (rather than `assert`) keep
# the checks active even when Python runs with optimizations enabled.
for _param_name, _param_value in (
    ("catalog_name", catalog_name),
    ("schema_name", schema_name),
    ("model_name", model_name),
    ("model_alias", model_alias),
):
    if not _param_value.strip():
        raise ValueError(f"{_param_name} is required")

spark.sql(f"USE CATALOG {catalog_name}")
spark.sql(f"USE SCHEMA {schema_name}")

full_model_name = f"{catalog_name}.{schema_name}.{model_name}"
# Resolve the model version through the alias passed in via the widget, so the
# printed `model_alias` is the one actually used for inference.
model_uri = f"models:/{full_model_name}@{model_alias}"

source_table_name = "yelp_reviews_silver"
target_table_name = "yelp_reviews_model_output"

print(f"catalog_name: {catalog_name}")
print(f"schema_name: {schema_name}")
print(f"model_name: {model_name}")
print(f"model_alias: {model_alias}")
print(f"model_uri: {model_uri}")

print(f"source_table_name: {source_table_name}")
print(f"target_table_name: {target_table_name}")

In [0]:
# Spark SQL struct type returned per row: a string `label` and a double `score`.
prediction_schema = "label string, score double"

# Wrap the registered model as a Spark UDF so it can be applied column-wise
# across the cluster.
classify_review = mlflow.pyfunc.spark_udf(
    spark=spark,
    model_uri=model_uri,
    result_type=prediction_schema,
)

In [0]:
# Score every review's `text` column with the model UDF; the result lands in a
# struct column named `prediction`.
scored_reviews = (
    spark.table(source_table_name)
    .repartition(40)  # more partitions -> more parallel tasks running inference
    .withColumn("prediction", classify_review("text"))
)

# Persist the scored rows, replacing any previous contents of the target table.
(
    scored_reviews
    .repartition(8)  # fewer partitions -> fewer output files on write
    .write.mode("overwrite")
    .saveAsTable(target_table_name)
)

In [0]:
# Render the freshly written output table for a quick visual check.
written_predictions = spark.table(target_table_name)
display(written_predictions)