# Distributed model inference using TensorFlow Keras
From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html

In [1]:
import os
import shutil
import time
import pandas as pd
from PIL import Image
import numpy as np
import uuid
 
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
 
from pyspark.sql.functions import col, pandas_udf, PandasUDFType

2023-05-04 14:02:10.988573: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-04 14:02:11.014948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
file_name = "image_data.parquet"
output_file_path = "predictions"

### Prepare trained model and data for inference

Load the ResNet-50 Model and broadcast the weights.

In [3]:
model = ResNet50()
bc_model_weights = sc.broadcast(model.get_weights())

Load the data and save the datasets to one Parquet file.

In [4]:
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,
                                   fname='flower_photos',
                                   untar=True)
data_dir = pathlib.Path(data_dir)

In [5]:
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

3670


In [6]:
import os
files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(data_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']
files = files[:2048]
len(files)

2048

In [7]:
print(data_dir)

/home/leey/.keras/datasets/flower_photos


In [8]:
image_data = []
for file in files:
    img = Image.open(file)
    img = img.resize([224, 224])
    data = np.asarray(img, dtype="float32").reshape([224*224*3])

    image_data.append({"data": data})

pandas_df = pd.DataFrame(image_data, columns=['data'])
pandas_df.to_parquet(file_name)
# os.makedirs(dbfs_file_path)
# shutil.copyfile(file_name, dbfs_file_path+file_name)

### Load the data into Spark DataFrames

In [9]:
from pyspark.sql.types import *
df = spark.read.parquet(file_name)
print(df.count())



2048


                                                                                

In [10]:
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [11]:
assert len(df.head()) > 0, "`df` should not be empty" # This line will fail if the vectorized reader runs out of memory

                                                                                

### Model inference via pandas UDF

In [12]:
def parse_image(image_data):
    image = tf.image.convert_image_dtype(
        image_data, dtype=tf.float32) * (2. / 255) - 1
    image = tf.reshape(image, [224, 224, 3])
    return image

In [13]:
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)
def predict_batch_udf(image_batch_iter):
    batch_size = 64
    model = ResNet50(weights=None)
    model.set_weights(bc_model_weights.value)
    for image_batch in image_batch_iter:
        images = np.vstack(image_batch)
        dataset = tf.data.Dataset.from_tensor_slices(images)
        dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(
            5000).batch(batch_size)
        preds = model.predict(dataset)
        yield pd.Series(list(preds))



In [14]:
%%time
predictions_df = df.select(predict_batch_udf(col("data")).alias("prediction"))
predictions_df.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[1.7195931E-4, 3.2735552E-4, 7.533328E-5, 1.9810574E-4, 6.6188135E-5, 4.304618E-4, 1.3097588E-5, 5.2791635E-5, 2.3571...|
|[7.738322E-5, 5.935413E-4, 1.3075814E-4, 1.8000253E-4, 1.3001164E-4, 2.2790757E-4, 5.0780895E-5, 2.3768713E-4, 3.8676...|
|[2.3994662E-5, 3.4409956E-4, 1.8440963E-4, 1.6311614E-4, 1.4301096E-4, 8.883273E-5, 2.6071444E-5, 2.3609246E-4, 1.481...|
|[7.33013E-5, 2.6747186E-4, 1.2007599E-4, 2.0717822E-4, 8.630799E-5, 2.2318888E-4, 1.3459768E-5, 8.564859E-5, 1.482114...|
|[5.592278E-5, 2.5165555E-4, 1.3106635E-4, 1.5746983E-4, 8.4952095E-5, 2.1269226E-4, 9.83865E-6, 1.6246884E-4, 1.24666...|
|[1.1709497E-4, 

                                                                                

In [15]:
%%time
predictions_df.write.mode("overwrite").parquet(output_file_path)



CPU times: user 92.6 ms, sys: 19.7 ms, total: 112 ms
Wall time: 1min 12s


                                                                                

### Model inference using Spark DL API

In [16]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [17]:
def predict_batch_fn():
    import tensorflow as tf
    from tensorflow.keras.applications.resnet50 import ResNet50
    model = ResNet50()
    def predict(inputs):
        inputs = inputs * (2. / 255) - 1
        return model.predict(inputs)
    return predict

In [18]:
classify = predict_batch_udf(predict_batch_fn,
                             input_tensor_shapes=[[224, 224, 3]],
                             return_type=ArrayType(FloatType()),
                             batch_size=50)

In [19]:
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [20]:
df = spark.read.parquet("image_data.parquet")

In [21]:
%%time
# first pass caches model/fn
predictions = df.select(classify(struct("data")).alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[1.7195931E-4, 3.2735552E-4, 7.533328E-5, 1.9810574E-4, 6.6188135E-5, 4.304618E-4, 1.3097588E-5, 5.2791635E-5, 2.3571...|
|[7.738322E-5, 5.935413E-4, 1.3075814E-4, 1.8000253E-4, 1.3001164E-4, 2.2790757E-4, 5.0780895E-5, 2.3768713E-4, 3.8676...|
|[2.3994662E-5, 3.4409956E-4, 1.8440963E-4, 1.6311614E-4, 1.4301096E-4, 8.883273E-5, 2.6071444E-5, 2.3609246E-4, 1.481...|
|[7.33013E-5, 2.6747186E-4, 1.2007599E-4, 2.0717822E-4, 8.630799E-5, 2.2318888E-4, 1.3459768E-5, 8.564859E-5, 1.482114...|
|[5.592278E-5, 2.5165555E-4, 1.3106635E-4, 1.5746983E-4, 8.4952095E-5, 2.1269226E-4, 9.83865E-6, 1.6246884E-4, 1.24666...|
|[1.1709497E-4, 

                                                                                

In [22]:
%%time
predictions = df.select(classify("data").alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[1.7195931E-4, 3.2735552E-4, 7.533328E-5, 1.9810574E-4, 6.6188135E-5, 4.304618E-4, 1.3097588E-5, 5.2791635E-5, 2.3571...|
|[7.738322E-5, 5.935413E-4, 1.3075814E-4, 1.8000253E-4, 1.3001164E-4, 2.2790757E-4, 5.0780895E-5, 2.3768713E-4, 3.8676...|
|[2.3994662E-5, 3.4409956E-4, 1.8440963E-4, 1.6311614E-4, 1.4301096E-4, 8.883273E-5, 2.6071444E-5, 2.3609246E-4, 1.481...|
|[7.33013E-5, 2.6747186E-4, 1.2007599E-4, 2.0717822E-4, 8.630799E-5, 2.2318888E-4, 1.3459768E-5, 8.564859E-5, 1.482114...|
|[5.592278E-5, 2.5165555E-4, 1.3106635E-4, 1.5746983E-4, 8.4952095E-5, 2.1269226E-4, 9.83865E-6, 1.6246884E-4, 1.24666...|
|[1.1709497E-4, 

                                                                                

In [23]:
%%time
predictions = df.select(classify(col("data")).alias("prediction"))
predictions.write.mode("overwrite").parquet(output_file_path + "_1")



CPU times: user 91 ms, sys: 15.2 ms, total: 106 ms
Wall time: 1min 6s


                                                                                

### Using Triton Server

For optimal performance., you should use a Triton inference server image that is compatible with the TensorFlow version used to train the model, per the [Framework Containers Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)

For demonstration purposes, this notebook just uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments), using a conda-pack environment created as follows:
```
conda create -n tf-gpu -c conda-forge python=3.8
conda activate tf-gpu
pip install tensorflow
pip install conda-pack
conda pack  

cp tf-gpu.tar.gz ../models_tf
```

#### Start Triton Server on each executor

In [24]:
num_executors = 1

nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    import docker
    import time
    import tritonclient.grpc as grpcclient
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container=client.containers.run(
            "nvcr.io/nvidia/tritonserver:22.07-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="512M",
            volumes={
                "/home/leey/devpub/leewyang/sparkext/examples/models": {"bind": "/models", "mode": "ro"},
                "/home/leey/huggingface/cache": {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        client = grpcclient.InferenceServerClient("localhost:8001")
        ready = False
        while not ready:
            try:
                ready = client.is_server_ready()
            except Exception as e:
                time.sleep(5)

    return [True]

nodeRDD.mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [25]:
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [26]:
def triton_fn(triton_uri, model_name):
    import numpy as np
    import tritonclient.grpc as grpcclient
    
    np_types = {
      "BOOL": np.dtype(np.bool8),
      "INT8": np.dtype(np.int8),
      "INT16": np.dtype(np.int16),
      "INT32": np.dtype(np.int32),
      "INT64": np.dtype(np.int64),
      "FP16": np.dtype(np.float16),
      "FP32": np.dtype(np.float32),
      "FP64": np.dtype(np.float64),
      "FP64": np.dtype(np.double),
      "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)
    
    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            inputs = inputs * (2. / 255) - 1  # normalization
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))
        
        response = client.infer(model_name, inputs=request)
        
        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)
        
    return predict

In [27]:
classify = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="resnet_50"),
                             input_tensor_shapes=[[224, 224, 3]],
                             return_type=ArrayType(FloatType()),
                             batch_size=50)

In [28]:
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [29]:
df = spark.read.parquet("image_data.parquet")

In [30]:
%%time
# first pass caches model/fn
predictions = df.select(classify(struct("data")).alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[1.71091E-4, 3.269849E-4, 7.52977E-5, 1.978806E-4, 6.605203E-5, 4.2928016E-4, 1.3055542E-5, 5.279775E-5, 2.3507486E-5...|
|[7.660852E-5, 5.9129955E-4, 1.3076628E-4, 1.7964613E-4, 1.293718E-4, 2.2706043E-4, 5.056295E-5, 2.3749094E-4, 3.86420...|
|[2.3824861E-5, 3.4220197E-4, 1.8398133E-4, 1.626118E-4, 1.4244742E-4, 8.818307E-5, 2.6096091E-5, 2.3705876E-4, 1.4838...|
|[7.2949515E-5, 2.678199E-4, 1.2027239E-4, 2.0707237E-4, 8.63638E-5, 2.2211406E-4, 1.3493031E-5, 8.575731E-5, 1.485200...|
|[5.5637167E-5, 2.522944E-4, 1.3117655E-4, 1.5718519E-4, 8.4829146E-5, 2.1113851E-4, 9.884539E-6, 1.6338409E-4, 1.2518...|
|[1.1688133E-4, 

                                                                                

In [31]:
%%time
predictions = df.select(classify("data").alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[1.711523E-4, 3.26866E-4, 7.52245E-5, 1.9781855E-4, 6.600293E-5, 4.2926118E-4, 1.3044922E-5, 5.2771757E-5, 2.3492035E...|
|[7.663703E-5, 5.9194054E-4, 1.3089886E-4, 1.7965207E-4, 1.2942665E-4, 2.2708319E-4, 5.0572096E-5, 2.3747748E-4, 3.865...|
|[2.3856559E-5, 3.422268E-4, 1.8405495E-4, 1.6263021E-4, 1.424253E-4, 8.8358385E-5, 2.6092113E-5, 2.3685934E-4, 1.4848...|
|[7.303531E-5, 2.6780445E-4, 1.201413E-4, 2.070834E-4, 8.6422646E-5, 2.2223375E-4, 1.3498062E-5, 8.576895E-5, 1.485476...|
|[5.5635584E-5, 2.5240038E-4, 1.3120404E-4, 1.5722473E-4, 8.4889776E-5, 2.1113818E-4, 9.889351E-6, 1.633997E-4, 1.2522...|
|[1.1691067E-4, 

                                                                                

In [32]:
%%time
predictions = df.select(classify(col("data")).alias("prediction"))
predictions.write.mode("overwrite").parquet(output_file_path + "_2")



CPU times: user 24.5 ms, sys: 3.08 ms, total: 27.6 ms
Wall time: 25.6 s


                                                                                

#### Stop Triton Server on each executor

In [33]:
def stop_triton(it):
    import docker
    import time
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        container=containers[0]
        container.stop(timeout=120)

    return [True]

nodeRDD.mapPartitions(stop_triton).collect()

                                                                                

[True]

In [34]:
spark.stop()