# Distributed model inference using TensorFlow Keras
Adapted from: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html

In [None]:
import os
import shutil
import time
import pandas as pd
from PIL import Image
import numpy as np
import uuid
 
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
 
from pyspark.sql.functions import col, pandas_udf, PandasUDFType

In [None]:
# Local Parquet file for the flattened images, and the base path under which
# each inference approach writes its predictions.
file_name, output_file_path = "image_data.parquet", "predictions"

### Prepare trained model and data for inference

Load the ResNet-50 model and broadcast its weights to the Spark executors.

In [None]:
# Build ResNet-50 on the driver and broadcast its weights, so the pandas UDF can
# rebuild the model with `weights=None` and load these weights via `set_weights`.
# `sc` is assumed to be the SparkContext provided by the notebook environment.
model = ResNet50()
bc_model_weights = sc.broadcast(model.get_weights())

Download the flower photos dataset, resize the images, and save them to a single Parquet file.

In [None]:
import pathlib

# Fetch the flower-photos archive with Keras' get_file (untar=True extracts it)
# and keep the extracted directory as a pathlib.Path.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        fname='flower_photos',
        origin=dataset_url,
        untar=True,
    )
)

In [None]:
# Count the .jpg files exactly one directory level below data_dir.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(image_count)

In [None]:
import os

# Every .jpg file anywhere under data_dir. os.walk yields entries in file-system
# order, which is not guaranteed to be stable, so sort the paths to make the
# first-2048 subset reproducible across machines and runs.
files = sorted(
    os.path.join(dirpath, name)
    for dirpath, _, filenames in os.walk(data_dir)
    for name in filenames
    if os.path.splitext(name)[1] == '.jpg'
)
files = files[:2048]
len(files)

In [None]:
# Show where get_file extracted the dataset.
print(data_dir)

In [None]:
def _load_flat_image(path, size=224):
    """Return the image at *path* as a flat float32 RGB array of length size*size*3."""
    # `with` closes the underlying file handle. convert("RGB") guards against
    # grayscale/RGBA/palette files, which would otherwise make the reshape fail;
    # for images that are already RGB it returns an identical copy.
    with Image.open(path) as src:
        rgb = src.convert("RGB").resize((size, size))
        return np.asarray(rgb, dtype="float32").reshape([size * size * 3])


# One row per selected file; the "data" column holds the flattened
# 224*224*3 float32 pixel vector, written to a single Parquet file.
image_data = [{"data": _load_flat_image(path)} for path in files]

pandas_df = pd.DataFrame(image_data, columns=['data'])
pandas_df.to_parquet(file_name)
# Databricks-only: copy the file to DBFS (dbfs_file_path is not defined here).
# os.makedirs(dbfs_file_path)
# shutil.copyfile(file_name, dbfs_file_path+file_name)

### Load the data into a Spark DataFrame

In [None]:
# Import only the types this notebook uses (the pandas UDF return type), rather
# than `import *`, which dumps every Spark type name into the namespace.
from pyspark.sql.types import ArrayType, FloatType

df = spark.read.parquet(file_name)
print(df.count())

In [None]:
# Read Parquet in smaller vectorized-reader batches (1024 rows each) so fewer of
# the large 224*224*3-float rows are held in memory at once.
# The Arrow setting below would instead cap rows per batch sent to pandas UDFs;
# it is left disabled.
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [None]:
# Materialize the first row; this line will also fail if the vectorized reader
# runs out of memory. `df.head()` returns None for an empty DataFrame, so test
# for None: `len(None)` would raise TypeError instead of showing this message,
# and `len(Row)` counts columns, not rows.
assert df.head() is not None, "`df` should not be empty"

### Run model inference via pandas UDF

In [None]:
def parse_image(image_data):
    """Turn one flattened image into a [224, 224, 3] float32 tensor scaled to [-1, 1].

    Args:
        image_data: flat tensor of 224*224*3 pixel values in [0, 255]
            (float32 in this notebook).

    Returns:
        float32 tensor of shape [224, 224, 3].
    """
    # tf.cast rather than tf.image.convert_image_dtype: the latter rescales
    # integer inputs to [0, 1], which, combined with the [0, 255] -> [-1, 1]
    # mapping below, would squash them into ~[-1, -0.992]. For float32 input
    # (the case here) both calls return the data unchanged.
    image = tf.cast(image_data, tf.float32) * (2. / 255) - 1
    # NOTE(review): Keras ResNet50 ImageNet weights are normally paired with
    # tf.keras.applications.resnet50.preprocess_input (caffe-style); this
    # [-1, 1] scaling is kept from the original notebook. Confirm it is intended.
    return tf.reshape(image, [224, 224, 3])

In [None]:
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
    """Run ResNet-50 over an iterator of pandas Series of flattened images.

    The model is built once per call of this iterator UDF and loaded with the
    broadcast weights in ``bc_model_weights``, then reused for every batch.

    Args:
        image_batch_iter: iterator of pandas Series; each element is a flat
            array of 224*224*3 floats.

    Yields:
        A pandas Series with one prediction vector per input row.
    """
    batch_size = 64
    model = ResNet50(weights=None)
    model.set_weights(bc_model_weights.value)
    for image_batch in image_batch_iter:
        images = np.vstack(image_batch)
        dataset = (
            tf.data.Dataset.from_tensor_slices(images)
            .map(parse_image, num_parallel_calls=8)
            # Batch first, then prefetch. Prefetching before batching (as
            # previously written, with a buffer of 5000) could hold up to 5000
            # individual ~600 KB images (~3 GB); this buffers a few whole batches.
            .batch(batch_size)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )
        preds = model.predict(dataset)
        yield pd.Series(list(preds))

In [None]:
%%time
# Apply the pandas UDF to every row and write the predictions as Parquet.
# Spark evaluates lazily, so the write is what actually runs the UDF; %%time
# therefore covers inference plus the write.
predictions_df = df.select(predict_batch_udf(col("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path)

In [None]:
# Read back the pandas-UDF predictions and print them, truncating each cell to
# 120 characters.
result_df = spark.read.parquet(output_file_path)
result_df.show(truncate=120)

### Model inference using sparkext

In [None]:
from pyspark.sql.functions import struct
from sparkext.tensorflow import model_udf

In [None]:
# ResNet-50 instance to be wrapped by sparkext's TensorFlow model_udf.
model = ResNet50()

In [None]:
# Wrap the Keras model as a Spark UDF via sparkext with batch size 64.
# This rebinds `predict_batch_udf`, replacing the pandas UDF defined earlier.
predict_batch_udf = model_udf(model, batch_size=64)

In [None]:
%%time
# sparkext's UDF takes the input column wrapped in a struct; results go to a
# separate "_sparkext" directory.
predictions_df = df.select(predict_batch_udf(struct("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path + "_sparkext")

In [None]:
# Read back the sparkext predictions just written: the "_sparkext" directory,
# not the pandas-UDF output stored at `output_file_path`.
result_df = spark.read.parquet(output_file_path + "_sparkext")
result_df.show(truncate=120)

### Model inference using MLflow

In [None]:
import mlflow
from mlflow.models.signature import infer_signature  #, ModelSignature
# from mlflow.types.schema import Schema, TensorSpec

In [None]:
# Stack (up to) the first 10 flattened images into an array of shape
# (n, 224, 224, 3) and run a local prediction with the Keras model.
train_images = np.stack(pandas_df['data'].head(10).to_numpy()).reshape(-1, 224, 224, 3)
predictions = model.predict(train_images)

In [None]:
# Infer the MLflow model signature from the sample inputs and the predictions
# already computed for them, instead of running the model a second time.
signature = infer_signature(train_images, predictions)
signature

In [None]:
# Export the Keras model to the "resnet50_model" directory, which the next cell
# packages for MLflow as a TensorFlow SavedModel.
tf.keras.models.save_model(model, "resnet50_model")

In [None]:
# MLflow's save_model refuses to write into an existing path, so remove any copy
# left by a previous run before packaging. This matches the "overwrite" mode
# used for the Parquet outputs.
shutil.rmtree("resnet50_mlflow", ignore_errors=True)
# NOTE(review): the tf_saved_model_dir / tf_meta_graph_tags / tf_signature_def_key
# arguments belong to the MLflow 1.x TensorFlow flavor. Confirm against the
# installed MLflow version.
mlflow.tensorflow.save_model(tf_saved_model_dir="resnet50_model",
                             tf_meta_graph_tags=["serve"],
                             tf_signature_def_key="serving_default",
                             signature=signature,
                             path="resnet50_mlflow")

In [None]:
# Load the packaged model as a Spark UDF returning array<float> per row; this
# rebinds `predict_batch_udf` again.
predict_batch_udf = mlflow.pyfunc.spark_udf(spark, model_uri="resnet50_mlflow", result_type="array<float>")

In [None]:
%%time
# Time MLflow-UDF inference over the whole DataFrame; results go to "_mlflow".
# NOTE(review): the signature was inferred from (n, 224, 224, 3) inputs while
# "data" holds flat 224*224*3 vectors. Confirm the pyfunc UDF reshapes them.
predictions_df = df.select(predict_batch_udf(struct("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path + "_mlflow")

### Model inference using Triton UDF

In [None]:
from sparkext.triton import model_udf

In [None]:
# ResNet-50 instance to be served through sparkext's Triton model_udf.
model = ResNet50()

In [None]:
# Triton-backed UDF (sparkext.triton.model_udf) with batch size 64; rebinds
# `predict_batch_udf`.
predict_batch_udf = model_udf(model, batch_size=64)

In [None]:
%%time
predictions_df = df.select(predict_batch_udf(col("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path + "_sparkext")

In [None]:
result_df = spark.read.parquet(output_file_path)
result_df.show(truncate=120)

### Model inference using Spark DL API

In [1]:
from pyspark.ml.udf import model_udf
from pyspark.sql.functions import struct
from pyspark.sql.types import ArrayType, FloatType

In [2]:
def model_fn():
    """Build ResNet-50 and return a callable that runs it on a batch of inputs.

    TensorFlow is imported inside the function so the model is constructed
    wherever ``model_fn`` is invoked (presumably on each executor by the Spark
    DL ``model_udf``; confirm).
    """
    from tensorflow.keras.applications.resnet50 import ResNet50
    import tensorflow as tf

    resnet = ResNet50()

    def predict(inputs):
        # Delegate straight to Keras' batched prediction.
        return resnet.predict(inputs)

    return predict

In [3]:
# Wrap model_fn as a Spark UDF returning array<float> per row. input_shapes
# presumably tells model_udf to reshape each flat input to (-1, 224, 224, 3);
# confirm against the pyspark.ml.udf implementation.
# NOTE: batch_size=50 differs from the 64 used by the other approaches above.
classify = model_udf(
    model_fn,
    batch_size=50,
    input_shapes=[[-1, 224, 224, 3]],
    return_type=ArrayType(FloatType()),
)

In [4]:
# Same reader setting as earlier, re-applied because this section runs in a
# fresh session (cell counters restart at 1): 1024 rows per vectorized-reader
# batch to limit memory use.
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [5]:
# Reload the images. The path is spelled out because `file_name` may not be
# defined in this fresh session.
df = spark.read.parquet("image_data.parquet")

                                                                                

In [7]:
# Apply the Spark DL UDF and print the first rows, truncating each cell to 120 chars.
predictions = df.select(classify(struct("data")).alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[6.4971744E-9, 1.7927986E-8, 7.304562E-10, 1.12441E-10, 4.790853E-9, 2.277025E-8, 3.220624E-10, 1.1413281E-7, 1.27016...|
|[4.291713E-10, 1.0503789E-9, 3.117148E-11, 1.3786824E-11, 9.42579E-11, 1.034047E-9, 6.296592E-12, 3.2948364E-9, 2.302...|
|[7.158145E-5, 8.0730024E-5, 0.1518421, 8.117428E-4, 8.8960776E-4, 1.0977733E-4, 9.939754E-5, 1.1473E-4, 3.261234E-5, ...|
|[7.696654E-6, 1.7495967E-4, 1.17445E-4, 1.8638202E-4, 4.1152147E-4, 3.5063238E-4, 1.0448097E-4, 1.8669982E-5, 6.10756...|
|[1.923105E-6, 3.6728245E-5, 4.2559903E-5, 1.1439298E-5, 1.7003453E-4, 7.114601E-5, 2.932128E-4, 6.8620243E-6, 2.14682...|
|[9.354073E-6, 1

                                                                                

In [8]:
# Shut down the Spark session and release its resources.
spark.stop()