In [1]:
%%info

### Fruit Classification with PySpark and TensorFlow on AWS


**Notebook Overview:**

This notebook performs fruit classification with PySpark and TensorFlow in an AWS environment. The primary objective is to classify fruit images into categories using a deep learning model and to store the results in S3 in Parquet format. The notebook covers data loading, preprocessing, and feature extraction with a pretrained MobileNetV2 network, up to the point of obtaining results.

**Tentative Table of Contents:**

1.  **Data Loading and Preprocessing**
    
    *   Reading image data with PySpark from an AWS environment.
    *   Extracting labels from file paths.
    *   Initial data exploration.
2.  **Deep Learning Model Development**
    
    *   Creating a deep learning model using TensorFlow.
    *   Training the model on the reduced data in the AWS environment.
3.  **Fruit Classification and Result Storage in S3**
    
    *   Running fruit classification on the integrated model.
    *   Storing the classification results in S3 with a .parquet file format in the AWS environment.


In [2]:
import pandas as pd
import numpy as np
import io
import os
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras import Model
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, element_at, split

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
0,application_1702059656434_0001,pyspark,idle,Link,Link,✔


SparkSession available as 'spark'.


In [3]:
PATH = 's3://iso-can-91'
PATH_Data = PATH + '/Test'
PATH_Result = PATH + '/Results'
print(f'PATH:        {PATH}\n'
      f'PATH_Data:   {PATH_Data}\n'
      f'PATH_Result: {PATH_Result}')

PATH:        s3://iso-can-91
PATH_Data:   s3://iso-can-91/Test
PATH_Result: s3://iso-can-91/Results

In [4]:
images = spark.read.format("binaryFile") \
  .option("pathGlobFilter", "*.jpg") \
  .option("recursiveFileLookup", "true") \
  .load(PATH_Data)
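
Spark's `binaryFile` source returns one row per matching file with the columns `path`, `modificationTime`, `length` and `content`, as the `show(5)` output below confirms. Purely as a local illustration, assuming a directory on local disk instead of S3, the effect of `recursiveFileLookup` plus a `*.jpg` `pathGlobFilter` can be mimicked in plain Python. The helper `read_binary_files` is hypothetical and is not a Spark API.

```python
import fnmatch
from datetime import datetime
from pathlib import Path


def read_binary_files(root, glob_filter="*.jpg"):
    """Hypothetical local analogue of spark.read.format("binaryFile") with
    recursiveFileLookup=true and pathGlobFilter=glob_filter.

    Returns a list of dicts carrying the same four fields as Spark's rows.
    """
    rows = []
    # rglob("*") walks every sub-directory, like recursiveFileLookup=true
    for path in sorted(Path(root).rglob("*")):
        # pathGlobFilter is matched against the file name only, not the full path
        if path.is_file() and fnmatch.fnmatchcase(path.name, glob_filter):
            stat = path.stat()
            rows.append({
                "path": path.as_posix(),
                "modificationTime": datetime.fromtimestamp(stat.st_mtime),
                "length": stat.st_size,
                "content": path.read_bytes(),
            })
    return rows
```

In Spark the same rows are produced lazily and in parallel across executors; this sketch only shows which files are selected and what each row contains.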

In [5]:
images.show(5)

+--------------------+-------------------+------+--------------------+
|                path|   modificationTime|length|             content|
+--------------------+-------------------+------+--------------------+
|s3://iso-can-91/T...|2023-11-30 12:33:33|  7353|[FF D8 FF E0 00 1...|
|s3://iso-can-91/T...|2023-11-30 12:33:33|  7350|[FF D8 FF E0 00 1...|
|s3://iso-can-91/T...|2023-11-30 12:33:33|  7349|[FF D8 FF E0 00 1...|
|s3://iso-can-91/T...|2023-11-30 12:33:33|  7348|[FF D8 FF E0 00 1...|
|s3://iso-can-91/T...|2023-11-30 12:33:33|  7328|[FF D8 FF E0 00 1...|
+--------------------+-------------------+------+--------------------+
only showing top 5 rows

In [6]:
model = MobileNetV2(weights='imagenet',
                    include_top=True,
                    input_shape=(224, 224, 3))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5

In [7]:
images = images.withColumn('label', element_at(split(images['path'], '/'), -2))
images.printSchema()
images.select('path', 'label').show(5, False)

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)
 |-- label: string (nullable = true)

+---------------------------------------------+----------+
|path                                         |label     |
+---------------------------------------------+----------+
|s3://iso-can-91/Test/Watermelon/r_106_100.jpg|Watermelon|
|s3://iso-can-91/Test/Watermelon/r_109_100.jpg|Watermelon|
|s3://iso-can-91/Test/Watermelon/r_108_100.jpg|Watermelon|
|s3://iso-can-91/Test/Watermelon/r_107_100.jpg|Watermelon|
|s3://iso-can-91/Test/Watermelon/r_95_100.jpg |Watermelon|
+---------------------------------------------+----------+
only showing top 5 rows
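
`split(images['path'], '/')` turns each path into an array of segments, and `element_at(..., -2)` picks the second-to-last one. Spark's `element_at` is 1-based and counts from the end for negative indices, so this is the parent directory name. A plain-Python equivalent, using a hypothetical helper name, makes the indexing explicit:

```python
def label_from_path(path: str) -> str:
    """Return the parent directory name, mirroring
    element_at(split(path, '/'), -2) in Spark SQL."""
    # Python lists are 0-based, but a negative index counts from the end,
    # which matches Spark's element_at with a negative index.
    return path.split('/')[-2]


print(label_from_path('s3://iso-can-91/Test/Watermelon/r_106_100.jpg'))  # Watermelon
```

Multi-word folder names such as `Corn Husk` are kept intact at this stage, which is why a separate `category` column is derived later.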

In [9]:
new_model = Model(inputs=model.input,
                  outputs=model.layers[-2].output)
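
Taking `model.layers[-2].output` drops MobileNetV2's final 1000-class `predictions` layer, so the truncated network returns the output of its global average pooling layer: one 1280-dimensional vector per image, consistent with the `(1280,)` feature shape observed later in this notebook. For a 224×224 input the last convolutional block yields a 7×7×1280 tensor, and global average pooling averages it over the two spatial axes. The NumPy sketch below uses random numbers as a stand-in for real activations:

```python
import numpy as np

# Stand-in for MobileNetV2's last convolutional activation for a batch of
# 2 images of size 224x224: shape (batch, 7, 7, 1280). Random values only.
rng = np.random.default_rng(0)
conv_out = rng.random((2, 7, 7, 1280), dtype=np.float32)

# Global average pooling: mean over the spatial axes (height, width)
features = conv_out.mean(axis=(1, 2))

print(features.shape)  # (2, 1280)
```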

In [10]:
brodcast_weights = sc.broadcast(new_model.get_weights())

In [11]:
new_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
____________________________________________________________________________________________

In [12]:
def model_fn():
    """
    Returns a MobileNetV2 model with the top (classification) layer removed,
    loaded with the broadcast pretrained weights.
    """
    # weights=None builds the architecture without re-downloading the ImageNet
    # weights on every executor; the weights come from the broadcast instead.
    model = MobileNetV2(weights=None,
                        include_top=True,
                        input_shape=(224, 224, 3))
    for layer in model.layers:
        layer.trainable = False
    new_model = Model(inputs=model.input,
                      outputs=model.layers[-2].output)
    new_model.set_weights(brodcast_weights.value)
    return new_model

In [13]:
def preprocess(content):
    """
    Preprocesses raw image bytes for prediction.
    """
    img = Image.open(io.BytesIO(content)).resize([224, 224])
    arr = img_to_array(img)
    return preprocess_input(arr)

def featurize_series(model, content_series):
    """
    Featurize a pd.Series of raw images using the input model.
    :return: a pd.Series of image features
    """
    # Stack the preprocessed images into one (batch, 224, 224, 3) array
    batch = np.stack(content_series.map(preprocess))
    preds = model.predict(batch)
    # For some layers, output features will be multi-dimensional tensors.
    # We flatten the feature tensors to vectors for easier storage in Spark DataFrames.
    output = [p.flatten() for p in preds]
    return pd.Series(output)

@pandas_udf('array<float>', PandasUDFType.SCALAR_ITER)
def featurize_udf(content_series_iter):
    '''
    This method is a Scalar Iterator pandas UDF wrapping our featurization function.
    The decorator specifies that this returns a Spark DataFrame column of type ArrayType(FloatType).

    :param content_series_iter: This argument is an iterator over batches of data, where each batch
                              is a pandas Series of image data.
    '''
    # With Scalar Iterator pandas UDFs, we can load the model once and then re-use it
    # for multiple data batches.  This amortizes the overhead of loading big models.
    model = model_fn()
    for content_series in content_series_iter:
        yield featurize_series(model, content_series)
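
For MobileNetV2, Keras' `preprocess_input` rescales pixel values from [0, 255] to [-1, 1], i.e. `x / 127.5 - 1`. The sketch below reproduces only that arithmetic with NumPy on a synthetic array rather than a decoded JPEG, to make the expected value range explicit. `scale_like_mobilenet_v2` is a hypothetical name; the pipeline above should keep calling `preprocess_input` itself.

```python
import numpy as np


def scale_like_mobilenet_v2(arr):
    """Hypothetical helper: the same arithmetic as
    tf.keras.applications.mobilenet_v2.preprocess_input."""
    arr = np.asarray(arr, dtype=np.float32)
    return arr / 127.5 - 1.0


# Synthetic 224x224 RGB "image": left half black (0), right half white (255)
img = np.zeros((224, 224, 3), dtype=np.float32)
img[:, 112:] = 255.0
x = scale_like_mobilenet_v2(img)

# featurize_series stacks images like this into a (batch, 224, 224, 3) array
batch = np.stack([x, x])
print(x.min(), x.max(), batch.shape)  # -1.0 1.0 (2, 224, 224, 3)
```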

In [14]:
features_df = images.repartition(24).select(col("path"),
                                            col("label"),
                                            featurize_udf("content").alias("features")
                                           )
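
With a Scalar Iterator pandas UDF, Spark invokes `featurize_udf` once per partition and hands it an iterator over that partition's batches, so `model_fn()` (model construction plus `set_weights` from the broadcast) runs roughly once per partition rather than once per batch. Here, `repartition(24)` therefore also determines about how many times the model is rebuilt. The plain-Python sketch below mimics only that control flow with a load counter. `load_model`, the toy "model", and the batch contents are illustrative stand-ins, not Spark internals.

```python
from typing import Iterator, List

LOAD_COUNT = 0


def load_model():
    """Stand-in for model_fn(): expensive setup that should run only once."""
    global LOAD_COUNT
    LOAD_COUNT += 1

    def toy_model(batch):
        # Toy "features": the byte length of each raw image
        return [len(item) for item in batch]

    return toy_model


def featurize_partition(batches: Iterator[List[bytes]]) -> Iterator[List[int]]:
    """Same shape as featurize_udf: load once, then reuse across batches."""
    model = load_model()
    for batch in batches:
        yield model(batch)


# One partition delivered as three batches of raw "images"
partition = iter([[b"ab", b"c"], [b"def"], [b"", b"gh", b"i"]])
results = list(featurize_partition(partition))
print(results)     # [[2, 1], [3], [0, 2, 1]]
print(LOAD_COUNT)  # 1
```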

In [15]:
print(PATH_Result)

s3://iso-can-91/Results

In [16]:
features_df.write.mode("overwrite").parquet(PATH_Result)

In [17]:
df = pd.read_parquet(PATH_Result, engine='pyarrow')

In [18]:
df.head()

                                             path  ...                                           features
0    s3://iso-can-91/Test/Watermelon/r_56_100.jpg  ...  [0.98079145, 0.18333912, 0.0036119935, 0.05772...
1     s3://iso-can-91/Test/Watermelon/296_100.jpg  ...  [0.6964483, 0.001793768, 0.017449293, 0.001894...
2     s3://iso-can-91/Test/Watermelon/275_100.jpg  ...  [0.16781013, 0.07297957, 0.0, 0.0, 1.7403593, ...
3  s3://iso-can-91/Test/Cauliflower/r_247_100.jpg  ...  [0.011194742, 0.12858747, 2.4827878, 0.0, 0.0,...
4  s3://iso-can-91/Test/Cauliflower/r_127_100.jpg  ...  [0.0, 0.8310882, 2.347738, 0.0, 0.0, 0.0, 0.0,...

[5 rows x 3 columns]

In [19]:
df.loc[0,'features'].shape

(1280,)

In [20]:
df.shape

(22688, 3)

In [21]:
df['category'] = df['label'].apply(lambda x: x.split(' ')[0])
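
Keeping only the first space-separated word of the label collapses variety suffixes into a coarser category, as the table below shows (`Corn Husk` becomes `Corn`, `Banana Red` becomes `Banana`). A quick standalone check of the same rule on labels that appear in this dataset's paths, alongside a vectorized equivalent using the pandas string accessor; note that the rule is purely lexical, so any multi-word fruit name would also be cut to its first word.

```python
import pandas as pd

labels = pd.Series(['Watermelon', 'Corn Husk', 'Banana Red',
                    'Cantaloupe 2', 'Peach Flat'])

# Same rule as df['label'].apply(lambda x: x.split(' ')[0])
categories = labels.apply(lambda x: x.split(' ')[0])

# Vectorized equivalent via the .str accessor
categories_vec = labels.str.split(' ').str[0]

print(categories.tolist())  # ['Watermelon', 'Corn', 'Banana', 'Cantaloupe', 'Peach']
```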

In [22]:
df

                                                 path  ...     category
0        s3://iso-can-91/Test/Watermelon/r_56_100.jpg  ...   Watermelon
1         s3://iso-can-91/Test/Watermelon/296_100.jpg  ...   Watermelon
2         s3://iso-can-91/Test/Watermelon/275_100.jpg  ...   Watermelon
3      s3://iso-can-91/Test/Cauliflower/r_247_100.jpg  ...  Cauliflower
4      s3://iso-can-91/Test/Cauliflower/r_127_100.jpg  ...  Cauliflower
...                                               ...  ...          ...
22683      s3://iso-can-91/Test/Corn Husk/247_100.jpg  ...         Corn
22684       s3://iso-can-91/Test/Hazelnut/222_100.jpg  ...     Hazelnut
22685   s3://iso-can-91/Test/Banana Red/r_148_100.jpg  ...       Banana
22686       s3://iso-can-91/Test/Banana/r_319_100.jpg  ...       Banana
22687        s3://iso-can-91/Test/Banana/r_66_100.jpg  ...       Banana

[22688 rows x 4 columns]

In [24]:
import random

all_categories = df['category'].unique()

# Convert the NumPy array to a list
all_categories_list = all_categories.tolist()

# Choose eight random categories from the unique categories
chosen_categories = random.sample(all_categories_list, 8)

# Subsample the dataframe
subsampled_df = df[df['category'].isin(chosen_categories)]

print(subsampled_df)

                                                  path  ...    category
18     s3://iso-can-91/Test/Cantaloupe 2/r_256_100.jpg  ...  Cantaloupe
21      s3://iso-can-91/Test/Cantaloupe 1/r_94_100.jpg  ...  Cantaloupe
22      s3://iso-can-91/Test/Cantaloupe 1/r_61_100.jpg  ...  Cantaloupe
23      s3://iso-can-91/Test/Cantaloupe 1/r_96_100.jpg  ...  Cantaloupe
25     s3://iso-can-91/Test/Cantaloupe 2/r_232_100.jpg  ...  Cantaloupe
...                                                ...  ...         ...
22664     s3://iso-can-91/Test/Chestnut/r2_115_100.jpg  ...    Chestnut
22665     s3://iso-can-91/Test/Chestnut/r2_105_100.jpg  ...    Chestnut
22667       s3://iso-can-91/Test/Peach Flat/58_100.jpg  ...       Peach
22673        s3://iso-can-91/Test/Chestnut/181_100.jpg  ...    Chestnut
22675        s3://iso-can-91/Test/Chestnut/200_100.jpg  ...    Chestnut

[2757 rows x 4 columns]
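
Because `random.sample` is unseeded here, each run draws a different set of categories, so `subsampled_df` (2757 rows in this run) will change between executions. A minimal sketch, assuming reproducibility is wanted, uses a dedicated seeded `random.Random` instance and sorts the unique values first so the draw does not depend on the order in which `unique()` returns them. The helper name `choose_categories`, the seed, and the short category list are illustrative choices.

```python
import random

# Categories that appear in this notebook's outputs (illustrative subset)
categories = ['Watermelon', 'Cauliflower', 'Corn', 'Hazelnut',
              'Banana', 'Cantaloupe', 'Chestnut', 'Peach']


def choose_categories(all_categories, k, seed):
    """Hypothetical helper: pick k distinct categories reproducibly."""
    rng = random.Random(seed)  # independent of the global random state
    return rng.sample(sorted(set(all_categories)), k)


first = choose_categories(categories, 4, seed=42)
second = choose_categories(categories, 4, seed=42)
print(first == second)  # True
```

In the notebook this would read `chosen_categories = choose_categories(all_categories_list, 8, seed=42)`, with the seed value being an arbitrary choice.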