<a href="https://colab.research.google.com/github/dungnc-uit/BigDataExcercise/blob/main/Spark_DL_model_pandas_udf_and_predict_batch_udf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488491 sha256=a4affa711081110e8375d49da577651848ec2b9cbda3d4298f52d542ac656d40
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [None]:
import os
import shutil
import subprocess
import time
import pandas as pd
from PIL import Image
import numpy as np
import uuid
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql import SparkSession
# Reuse the active SparkSession if there is one, otherwise start one with default settings.
spark = SparkSession.builder.getOrCreate()
# Handle to the underlying SparkContext; used later to broadcast the model weights.
sc = spark.sparkContext

  self.pid = _posixsubprocess.fork_exec(


# Chuẩn bị dữ liệu để suy luận mô hình học sâu

## Load dữ liệu

In [None]:
import pathlib

# Flower-photo archive published with the TensorFlow tutorials. get_file
# downloads it into the Keras cache, un-tars it and returns the local path.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file(fname='flower_photos', origin=dataset_url, untar=True)
)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [None]:
# Show where the extracted dataset lives on local disk.
print(data_dir)

/root/.keras/datasets/flower_photos


In [None]:
# Count the JPEG files that sit exactly one directory level below the dataset root.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(image_count)

3670


In [None]:
import os

# os.walk yields entries in an arbitrary, filesystem-dependent order. Sort the
# paths so the 128-image sample taken below is the same on every run and machine.
files = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(data_dir)
    for filename in filenames
    if os.path.splitext(filename)[1] == '.jpg'
)
# Keep a small sample so the inference demo finishes quickly.
files = files[:128]
len(files)

128

In [None]:
# Peek at the first selected image path.
files[0]

'/root/.keras/datasets/flower_photos/dandelion/19437578578_6ab1b3c984.jpg'

## Lưu dữ liệu vào file parquet

In [None]:
# Flatten every sampled image to a 224*224*3 float32 vector and store the
# vectors in a single-column Parquet file that Spark can read back.
file_name = "image_data.parquet"
image_data = []
for file in files:
    # The context manager closes the file handle promptly instead of leaking
    # one open handle per image.
    with Image.open(file) as img:
        # Force three channels: a grayscale, RGBA or CMYK JPEG would otherwise
        # not have 224*224*3 values and the reshape below would fail.
        img = img.convert("RGB").resize((224, 224))
        data = np.asarray(img, dtype="float32").reshape([224*224*3])

    image_data.append({"data": data})

pandas_df = pd.DataFrame(image_data, columns=['data'])
pandas_df.to_parquet(file_name)

# Load dữ liệu vào Spark DataFrame

In [None]:
# NOTE: this wildcard import is what brings ArrayType and FloatType into scope
# for the pandas UDF cell further down.
from pyspark.sql.types import *
# Read the Parquet file written by the previous cell into a Spark DataFrame.
df = spark.read.parquet(file_name)
# Sanity check: the row count should equal the number of sampled images.
print(df.count())

128


In [None]:
# Display the DataFrame's schema summary (a single array<float> column).
df

DataFrame[data: array<float>]

In [None]:
from pyspark.sql import functions as f
# Add the array length per row; every row should hold 224*224*3 = 150528 values.
df.withColumn('size',f.size('data')).show()

+--------------------+------+
|                data|  size|
+--------------------+------+
|[24.0, 27.0, 16.0...|150528|
|[15.0, 11.0, 26.0...|150528|
|[26.0, 38.0, 1.0,...|150528|
|[94.0, 64.0, 51.0...|150528|
|[138.0, 112.0, 19...|150528|
|[28.0, 32.0, 32.0...|150528|
|[77.0, 52.0, 26.0...|150528|
|[52.0, 84.0, 65.0...|150528|
|[253.0, 202.0, 17...|150528|
|[28.0, 31.0, 14.0...|150528|
|[108.0, 108.0, 11...|150528|
|[104.0, 136.0, 12...|150528|
|[163.0, 118.0, 94...|150528|
|[0.0, 0.0, 0.0, 0...|150528|
|[114.0, 114.0, 11...|150528|
|[12.0, 18.0, 14.0...|150528|
|[103.0, 160.0, 22...|150528|
|[86.0, 100.0, 44....|150528|
|[215.0, 172.0, 11...|150528|
|[15.0, 15.0, 15.0...|150528|
+--------------------+------+
only showing top 20 rows



# Suy luận mô hình sử dụng Spark Pandas UDF

## Tạo biến broadcast cho trọng số của mô hình
Biến broadcast cho phép giữ một biến chỉ đọc (read-only) trong bộ nhớ đệm trên mỗi máy, thay vì gửi một bản sao của nó kèm theo từng tác vụ. Ví dụ: chúng có thể được dùng để cung cấp cho mỗi nút một bản sao của tập dữ liệu đầu vào lớn một cách hiệu quả.

In [None]:
# Build ResNet50 on the driver with its default weights (downloaded on first use).
model = ResNet50()
# Broadcast the weight arrays; the pandas UDF reads them back via
# bc_model_weights.value and loads them into a model built with weights=None.
bc_model_weights = sc.broadcast(model.get_weights())

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5


## Khai báo hàm chuyển đổi dữ liệu

In [None]:
def parse_image(image_data):
    """Turn one flattened RGB image into a ResNet50-ready tensor.

    Args:
        image_data: 1-D tensor holding 224*224*3 pixel values in the 0-255
            range, as stored in the `data` column of the Parquet file.

    Returns:
        A float32 tensor of shape [224, 224, 3], preprocessed the way the
        Keras ResNet50 ImageNet weights expect (RGB -> BGR, per-channel
        ImageNet mean subtracted, no rescaling).
    """
    image = tf.reshape(tf.cast(image_data, tf.float32), [224, 224, 3])
    # ResNet50 was trained with Caffe-style preprocessing. Scaling pixels to
    # [-1, 1] (the convention of MobileNet/Inception-style models) feeds it
    # out-of-distribution input and yields near-uniform class probabilities.
    return tf.keras.applications.resnet50.preprocess_input(image)

In [None]:
from typing import Iterator
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(image_batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Iterator-style pandas UDF that runs ResNet50 over batches of flattened images.

    Args:
        image_batch_iter: iterator of pandas Series; each element of a Series
            is one flattened 224*224*3 image vector.

    Yields:
        One pandas Series per input batch, holding the model's prediction
        vector for each image in that batch.
    """
    batch_size = 8
    # Build the model once per iterator (not once per batch) and load the
    # weights shipped from the driver through the broadcast variable.
    model = ResNet50(weights=None)
    model.set_weights(bc_model_weights.value)
    for image_batch in image_batch_iter:
        # Stack the per-row vectors into a single [rows, 150528] array.
        images = np.vstack(image_batch)
        dataset = (
            tf.data.Dataset.from_tensor_slices(images)
            .map(parse_image, num_parallel_calls=8)
            .batch(batch_size)
            # Prefetch goes last so that whole batches, not single images,
            # are prepared while the model is busy with the previous batch.
            .prefetch(tf.data.AUTOTUNE)
        )
        # verbose=0 keeps Keras progress bars out of the executor logs.
        preds = model.predict(dataset, verbose=0)
        yield pd.Series(list(preds))

In [None]:
# Limit the vectorized Parquet reader to 16 rows per batch.
# NOTE(review): presumably chosen because each row carries 150528 floats, so
# small batches keep per-batch memory modest -- confirm the intent.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "16")

In [None]:
%%time
# Apply the pandas UDF to the `data` column; this only builds the query plan.
predictions_df = df.select(predict_batch_udf(col("data")).alias("prediction"))
# show() is the action that actually runs inference for the displayed rows.
predictions_df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[6.5045315E-5, 2.536672E-4, 6.166392E-5, 9.524285E-5, 4.4628337E-5, 3.1678902E-4, 4.6867945E-6, 1.02610866E-4, 1.1784...|
|[1.3501532E-4, 2.857018E-4, 7.018257E-5, 1.2589457E-4, 5.772912E-5, 4.1720783E-4, 7.941275E-6, 3.6279824E-5, 1.406256...|
|[8.8571884E-5, 2.3808015E-4, 6.2431776E-5, 1.2254076E-4, 5.0543436E-5, 2.8942715E-4, 5.1670418E-6, 4.669184E-5, 1.097...|
|[1.05079496E-4, 2.4284833E-4, 3.8054786E-5, 8.264326E-5, 3.1239102E-5, 4.0001515E-4, 3.3035858E-6, 3.3604727E-5, 8.57...|
|[5.7685378E-5, 3.0485954E-4, 5.7428577E-5, 9.417652E-5, 3.9902014E-5, 2.0984854E-4, 5.377597E-6, 7.4888805E-5, 1.3330...|
|[6.246951E-5, 2

In [None]:
# Display the schema summary of the prediction DataFrame.
predictions_df

DataFrame[prediction: array<float>]

In [None]:
# Look at the second prediction row. take(2) fetches only the rows needed,
# whereas collect() would run inference on every row and ship all of the
# 1000-element vectors to the driver just to read one of them.
predictions_df.take(2)[1]

Row(prediction=[0.00013501531793735921, 0.0002857017971109599, 7.018257019808516e-05, 0.00012589457037393004, 5.772912118118256e-05, 0.00041720783337950706, 7.941274816403165e-06, 3.6279823689255863e-05, 1.406256797054084e-05, 0.0001362836774205789, 0.0007481772918254137, 0.0001569865271449089, 5.8251684095012024e-05, 8.507569873472676e-05, 1.4413484677788801e-05, 4.1388113459106535e-05, 0.0001397128071403131, 2.678146665857639e-05, 5.4917843954171985e-05, 5.063799108029343e-05, 0.0002114081580657512, 0.0020584629382938147, 0.001051985309459269, 0.0002556239196565002, 8.971147326519713e-05, 9.972453699447215e-05, 0.0003255382471252233, 0.00022802173043601215, 0.00015921343583613634, 0.00019165195408277214, 4.1033068555407226e-05, 0.00036677648313343525, 8.733449794817716e-05, 0.0001443113142158836, 0.0003005076723638922, 1.3675909031007905e-05, 0.00012431168579496443, 9.469547876506113e-06, 0.005958483554422855, 3.172017750330269e-05, 5.005319326301105e-05, 0.00019215275824535638, 0.00

In [None]:
# Fetch the first prediction row to the driver.
# NOTE(review): `fr` is not used in the cells visible here -- confirm it is still needed.
fr = predictions_df.first()

In [None]:
from pyspark.sql import functions as f
# Add the length of each prediction vector as a `size` column and display it.
predictions_df.withColumn('size',f.size('prediction')).show()

+--------------------+----+
|          prediction|size|
+--------------------+----+
|[6.5045315E-5, 2....|1000|
|[1.3501532E-4, 2....|1000|
|[8.8571884E-5, 2....|1000|
|[1.05079496E-4, 2...|1000|
|[5.7685378E-5, 3....|1000|
|[6.246951E-5, 2.5...|1000|
|[1.2112781E-5, 2....|1000|
|[1.1804936E-4, 2....|1000|
|[1.3531194E-4, 1....|1000|
|[9.201045E-5, 1.9...|1000|
|[5.973088E-5, 3.0...|1000|
|[6.614648E-5, 2.0...|1000|
|[9.096371E-5, 2.9...|1000|
|[3.9521132E-5, 2....|1000|
|[1.22014855E-4, 2...|1000|
|[1.05473904E-4, 2...|1000|
|[9.074605E-5, 2.2...|1000|
|[1.10572684E-4, 2...|1000|
|[1.20635406E-4, 2...|1000|
|[1.1510637E-4, 2....|1000|
+--------------------+----+
only showing top 20 rows



In [None]:
%%time
# Output directory for the prediction Parquet files (relative to the working directory).
output_file_path = "predictions"
# Writing is an action: it runs inference over the whole DataFrame and
# replaces any existing output at this path.
predictions_df.write.mode("overwrite").parquet(output_file_path)

CPU times: user 425 ms, sys: 48.2 ms, total: 473 ms
Wall time: 1min 3s


# Suy luận sử dụng Spark DL API

In [None]:
# NOTE: this import rebinds the name `predict_batch_udf`, which an earlier cell
# defined as a pandas UDF. After this cell runs, re-executing the earlier
# `df.select(predict_batch_udf(col("data")))` cell will call this Spark ML
# helper instead of the pandas UDF.
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [None]:
def predict_batch_fn():
    """Build the prediction function handed to Spark's `predict_batch_udf`.

    The imports and the model construction live inside this factory, so they
    execute wherever the factory is called rather than at notebook top level.

    Returns:
        A callable that takes a numeric array of shape [batch, 224, 224, 3]
        holding 0-255 RGB pixel values and returns the ResNet50 prediction
        vectors for that batch.
    """
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
    model = ResNet50()
    def predict(inputs):
        # preprocess_input may write into a float array it is given, so work
        # on a private float32 copy rather than the buffer handed in by Spark.
        batch = np.array(inputs, dtype=np.float32)
        # ResNet50's ImageNet weights expect Caffe-style preprocessing
        # (RGB -> BGR, ImageNet mean subtracted), not scaling to [-1, 1].
        return model.predict(preprocess_input(batch))
    return predict

In [None]:
# Wrap the factory in a Spark UDF. input_tensor_shapes declares the
# [224, 224, 3] shape for the single input column, the UDF returns an array of
# floats per row, and rows are passed to the predict function 8 at a time.
classify = predict_batch_udf(predict_batch_fn,
                             input_tensor_shapes=[[224, 224, 3]],
                             return_type=ArrayType(FloatType()),
                             batch_size=8)

In [None]:
# Same reader setting as in the pandas UDF section: 16 rows per vectorized Parquet batch.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "16")

In [None]:
# Re-read the image Parquet file; this rebinds `df` to a fresh DataFrame over the same data.
df = spark.read.parquet("image_data.parquet")

In [None]:
%%time
# first pass caches model/fn
# struct("data") wraps the column so it is passed to the UDF as a single struct input.
predictions = df.select(classify(struct("data")).alias("prediction"))
# show() is the action that triggers inference for the displayed rows.
predictions.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[6.5045315E-5, 2.536672E-4, 6.166392E-5, 9.524285E-5, 4.4628337E-5, 3.1678902E-4, 4.6867945E-6, 1.02610866E-4, 1.1784...|
|[1.3501532E-4, 2.857018E-4, 7.018257E-5, 1.2589457E-4, 5.772912E-5, 4.1720783E-4, 7.941275E-6, 3.6279824E-5, 1.406256...|
|[8.8571884E-5, 2.3808015E-4, 6.2431776E-5, 1.2254076E-4, 5.0543436E-5, 2.8942715E-4, 5.1670418E-6, 4.669184E-5, 1.097...|
|[1.05079496E-4, 2.4284833E-4, 3.8054786E-5, 8.264326E-5, 3.1239102E-5, 4.0001515E-4, 3.3035858E-6, 3.3604727E-5, 8.57...|
|[5.7685378E-5, 3.0485954E-4, 5.7428577E-5, 9.417652E-5, 3.9902014E-5, 2.0984854E-4, 5.377597E-6, 7.4888805E-5, 1.3330...|
|[6.246951E-5, 2