# PySpark TensorFlow Inference

## Image classification
Based on: https://www.tensorflow.org/tutorials/keras/save_and_load

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import subprocess
import tensorflow as tf

from tensorflow import keras

# Print the TensorFlow version in use, so the notebook output records it
print(tf.version.VERSION)

### Load and preprocess dataset

In [None]:
# Load the MNIST train/test splits as numpy arrays (images and labels per split)
(
    (train_images, train_labels),
    (test_images, test_labels),
) = keras.datasets.mnist.load_data()
(train_images.shape, test_images.shape)

In [None]:
# Flatten each image to a 784-element row and scale the values by 1/255.
# The same transformation is applied to both splits in one pass.
train_images, test_images = (
    images.reshape(-1, 784) / 255.0 for images in (train_images, test_images)
)

In [None]:
# Confirm both splits are now 2-D with 784 columns after the reshape above
train_images.shape, test_images.shape

### Define a model

In [None]:
def create_model():
    """Build and compile a small fully connected classifier.

    The network takes flattened 784-element inputs, applies one 512-unit ReLU
    layer followed by dropout (rate 0.2), and ends in a 10-unit layer with no
    activation, so it emits logits. The loss is configured with
    ``from_logits=True`` to match.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    hidden = keras.layers.Dense(512, activation='relu', input_shape=(784,))
    dropout = keras.layers.Dropout(0.2)
    logits = keras.layers.Dense(10)
    model = tf.keras.models.Sequential([hidden, dropout, logits])

    loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.metrics.SparseCategoricalAccuracy()
    model.compile(optimizer='adam', loss=loss_fn, metrics=[accuracy])

    return model


# Instantiate the model and print a summary of its layers
model = create_model()
model.summary()

### Train model

In [None]:
# Train for 5 epochs, using the test split as validation data
model.fit(
    x=train_images,
    y=train_labels,
    epochs=5,
    validation_data=(test_images, test_labels),
)

In [None]:
# Keep the first test image as a batch of one (slicing preserves the 2-D shape)
# and run it through the trained model
test_img = test_images[0:1]
model.predict(x=test_img)

In [None]:
# Display the flattened test image as a 28x28 picture
plt.figure()
plt.imshow(np.reshape(test_img, (28, 28)))
plt.show()

### Save Model

In [None]:
# Remove any previously saved model so the save step starts from a clean directory.
# shutil.rmtree is the in-process, portable counterpart of `rm -rf`;
# ignore_errors=True keeps the "-f" behaviour of not failing when the
# directory does not exist.
import shutil

shutil.rmtree("mnist_model", ignore_errors=True)

In [None]:
# Save the trained model under the local "mnist_model" path
model.save('mnist_model')

### Inspect saved model

In [None]:
# List the saved model's directory layout with the external `tree` command
subprocess.call(["tree", "mnist_model"])

In [None]:
# Ask saved_model_cli for the "serving_default" signature of the "serve" tag set
subprocess.call(
    [
        "saved_model_cli",
        "show",
        "--dir",
        "mnist_model",
        "--tag_set",
        "serve",
        "--signature_def",
        "serving_default",
    ]
)

### Load model

In [None]:
# Reload the model from disk into a separate object and print its summary
new_model = keras.models.load_model('mnist_model')
new_model.summary()

### Predict

In [None]:
# Predict on the first test image (kept as a batch of one) with the reloaded model
new_model.predict(test_images[:1])

## PySpark

In [None]:
import pandas as pd

# from pyspark.sql.functions import col, pandas_udf, PandasUDFType
# from pyspark.sql.types import *

### Convert numpy array to Spark DataFrame (via Pandas DataFrame)

In [None]:
# Wrap the flattened test images in a pandas DataFrame (one column per pixel)
test_pdf = pd.DataFrame(data=test_images)
test_pdf.shape

In [None]:
%%time
# Wide layout: 784 numeric columns, one per pixel.
# `spark` is not defined in this notebook; presumably the SparkSession is
# provided by the PySpark kernel/shell — confirm for your environment.
df = spark.createDataFrame(test_pdf)

In [None]:
%%time
# 1 column of array<float>: each row's 784 pixel values become one list in "data".
# A new frame is built instead of adding the column to test_pdf, so re-running
# this cell (or the wide-DataFrame cell) never folds a previously added "data"
# column back into the row values.
pdf = pd.DataFrame({'data': test_pdf.values.tolist()})
pdf.shape

In [None]:
%%time
# Convert the single-column pandas frame to a Spark DataFrame split into 10 partitions
df = spark.createDataFrame(pdf).repartition(10)

### Save the test dataset as parquet files

In [None]:
# Persist the test set as parquet, replacing any existing "mnist_test" output
df.write.parquet("mnist_test", mode="overwrite")

### Check Arrow memory configuration

In [None]:
# Cap the number of records per Arrow batch
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "128")
# Fetching a row makes Spark actually read the data, so this line will fail
# if the vectorized reader runs out of memory.
# head(1) returns a list of rows (empty for an empty DataFrame), so the
# assertion reports the intended message. df.head() without an argument returns
# a single Row or None, which would make len() count columns or raise TypeError.
assert len(df.head(1)) > 0, "`df` should not be empty"

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running on a different node or in a different environment.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import sparkext

In [None]:
# Reload the test set from the "mnist_test" parquet output written earlier
df = spark.read.parquet("mnist_test")

In [None]:
# Wrap the saved "mnist_model" directory in sparkext's TensorFlow Model.
# NOTE(review): presumably a Spark ML-style transformer, given the section title
# and the transform() call that follows — confirm against sparkext.
model = sparkext.tensorflow.Model("mnist_model")

In [None]:
# Apply the model to the test DataFrame; the result is kept as `predictions`
predictions = model.transform(df)

In [None]:
# Persist the predictions as parquet, replacing any existing "mnist_predictions" output
predictions.write.parquet("mnist_predictions", mode="overwrite")

### Check predictions

In [None]:
# Pull one prediction row back to the driver for inspection
predictions.take(1)

In [None]:
# Take the first row's "data" values and reshape them into a 28x28 image array
img = np.asarray(df.take(1)[0]["data"]).reshape((28, 28))

In [None]:
# Display the 28x28 image taken from the first DataFrame row
plt.figure()
plt.imshow(img)
plt.show()

## Inference using Spark DL UDF

In [None]:
from pyspark.sql.functions import col

In [None]:
# Reload the test set from parquet so this section can run on its own
df = spark.read.parquet("mnist_test")

In [None]:
# Inspect the schema of the DataFrame read back from parquet
df.schema

In [None]:
from sparkext.tensorflow import model_udf

In [None]:
# Build a callable from the saved "mnist_model" directory via sparkext's model_udf;
# it is applied to a DataFrame column in the next step
mnist = model_udf("mnist_model")

In [None]:
# Add a "preds" column computed by the model UDF from "data" and display the first rows
df.withColumn("preds", mnist(col("data"))).show()