In [None]:
import os

# PYSPARK_SUBMIT_ARGS has to be set before the SparkSession is created so the
# hadoop-aws package (which provides the s3a:// filesystem) is resolved at startup.
os.environ['PYSPARK_SUBMIT_ARGS'] = "--packages=org.apache.hadoop:hadoop-aws:3.3.2 pyspark-shell"

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split

spark = SparkSession.builder.appName('Fruits').getOrCreate()

hadoop_conf = spark._jsc.hadoopConfiguration()
hadoop_conf.set("fs.s3a.endpoint", "s3-eu-west-3.amazonaws.com")
# NOTE(review): V4 signing is normally enabled through a JVM system property rather
# than a Hadoop configuration key, so this setting may have no effect — confirm.
hadoop_conf.set("com.amazonaws.services.s3a.enableV4", "true")

# Despite its name, `rdd` is a Spark DataFrame; the name is kept because later
# cells reference it.
# NOTE(review): "dropInvalid" looks like an option of the "image" data source;
# the "binaryFile" reader presumably ignores it — confirm before relying on it.
rdd = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("dropInvalid", True)
    .load("s3a://ocr-mxdub/test_data/")
)

# Build the (Image, Label) frame with native column expressions instead of a
# DataFrame -> RDD -> DataFrame round trip, which would push every image's bytes
# through Python workers just to split a string.
# The label is the 5th '/'-separated component of the path (index 4), i.e. the
# folder directly under "test_data/" for the location loaded above. A path with
# fewer components now gives a null label (default non-ANSI mode) instead of
# raising IndexError.
rdd = rdd.select(
    col("content").alias("Image"),
    split(col("path"), "/").getItem(4).alias("Label"),
)

print("PARTITIONS : {}".format(rdd.rdd.getNumPartitions()))
rdd = rdd.coalesce(2)
print("PARTITIONS : {}".format(rdd.rdd.getNumPartitions()))

:: loading settings :: url = jar:file:/home/ubuntu/.local/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/ubuntu/.ivy2/cache
The jars for the packages stored in: /home/ubuntu/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-54c89494-6eac-4b50-a600-718e45e09320;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.3.2 in central
	found com.amazonaws#aws-java-sdk-bundle;1.11.1026 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
:: resolution report :: resolve 730ms :: artifacts dl 17ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.11.1026 from central in [default]
	org.apache.hadoop#hadoop-aws;3.3.2 from central in [default]
	org.wildfly.openssl#wildfly-openssl;1.0.7.Final from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	------------------------------

22/09/05 13:47:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [None]:
import tensorflow as tf
from keras.models import Model
# from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array

# Needed PANDAS_UDF
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from typing import Iterator

import numpy as np
import pandas as pd
import io
from PIL import Image

# Driver-side model: instantiate VGG16 with its default arguments, then build a
# second model that stops at the second-to-last layer so its output is a feature
# vector rather than the final layer's output.
model = VGG16()
model_1 = Model(inputs=model.inputs, outputs=model.layers[-2].output)

# Broadcast the truncated model's weights so model_fn() can rebuild the same
# network from them (it reads bc_model_weights.value).
bc_model_weights = spark.sparkContext.broadcast(model_1.get_weights())

def model_fn():
    """
    Rebuild the truncated VGG16 feature extractor from the broadcast weights.

    The network is created with ``weights=None``, cut at its second-to-last
    layer, and then filled with the values held in ``bc_model_weights``.

    :return: the truncated Keras model with the broadcast weights applied
    """
    backbone = VGG16(weights=None)
    extractor = Model(
        inputs=backbone.inputs,
        outputs=backbone.layers[-2].output,
    )
    extractor.set_weights(bc_model_weights.value)
    return extractor

def preprocess(content):
    """
    Decode raw image bytes into an array ready for the VGG16 model.

    :param content: raw encoded image bytes (anything ``PIL.Image.open`` can read)
    :return: the result of ``preprocess_input`` applied to the 224x224 RGB image array
    """
    # Force three channels before resizing: grayscale, palette, RGBA or CMYK files
    # would otherwise produce arrays with a different channel count, which breaks
    # stacking a batch and does not match the 3-channel input the model expects.
    # For images that are already RGB this is a plain copy.
    img = Image.open(io.BytesIO(content)).convert("RGB").resize((224, 224))
    arr = img_to_array(img)
    return preprocess_input(arr)

def featurize_series(model, content_series):
    """
    Run the model over a batch of raw images and return one feature vector per image.

    :param model: object exposing ``predict`` (the feature extractor)
    :param content_series: pd.Series of raw image bytes
    :return: a pd.Series whose elements are the flattened predictions
    """
    # Preprocess each image, then stack the results into a single batch array.
    batch = np.stack(content_series.map(preprocess))
    predictions = model.predict(batch)
    # Flatten each prediction so every row holds a 1-D feature vector.
    return pd.Series([vector.flatten() for vector in predictions])


@pandas_udf('array<float>')
def featurize_udf(content_series_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    '''
    Scalar Iterator pandas UDF that wraps featurize_series.
    The decorator declares the result as a Spark column of type array<float>.

    :param content_series_iter: iterator over batches of data, where each batch
                              is a pandas Series of image data.
    :return: iterator yielding one pandas Series of feature vectors per input batch.
    '''
    # The model is built once per call to this function and reused for every
    # batch it receives, which amortizes the cost of constructing it.
    extractor = model_fn()
    yield from (featurize_series(extractor, batch) for batch in content_series_iter)
    
# Each Arrow batch handed to the pandas UDF holds at most this many records.
# Large records (e.g. very large images) can cause Out Of Memory (OOM) errors;
# if that happens when the features are computed, lower `maxRecordsPerBatch`.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")

# Keep the label and attach the feature vector computed by the UDF.
features_df = rdd.select(
    col("Label"),
    featurize_udf("Image").alias("features"),
)

In [None]:
# Number of features = size of the truncated model's last layer output.
feats_size_VGG16 = model_1.layers[-1].output.shape[1]

# Expand the "features" array into one column per feature (feats_0, feats_1, ...)
# next to the label. The result is a wide frame, even though it is named
# `features_long`.
features_long = features_df.select(
    [col("Label")]
    + [col("features")[idx].alias("feats_{}".format(idx)) for idx in range(feats_size_VGG16)]
)

# Collapse to a single partition and overwrite the CSV output (with header) on S3.
(
    features_long
    .coalesce(1)
    .write
    .mode("overwrite")
    .option("header", "true")
    .csv("s3a://ocr-mxdub/results/")
)