In [3]:
from pyspark.sql import SparkSession
spark = SparkSession.builder \
        .appName("CatDogImage") \
        .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
        .getOrCreate()

In [4]:
folder_cat = "PetImages/Cat"
folder_dog = "PetImages/Dog"
df_cat = spark.read.format("image").option("dropInvalid", True).load(folder_cat)
df_dog = spark.read.format("image").option("dropInvalid", True).load(folder_dog)
df = df_cat.union(df_dog)

In [5]:
df.select("image.origin", "image.width", "image.height").show(truncate=False)

[Stage 0:>                                                          (0 + 1) / 1]

+-------------------------------------------------------------------+-----+------+
|origin                                                             |width|height|
+-------------------------------------------------------------------+-----+------+
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/10073.jpg|498  |479   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/8767.jpg |388  |471   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/7597.jpg |298  |374   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/11083.jpg|480  |469   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/5614.jpg |490  |422   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/7502.jpg |429  |475   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/7845.jpg |496  |372   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/4929.jpg |386  |459   |
|file:///Users/hzy/Desktop/nus/5208/code_try/PetImages/Cat/3967.jpg |480  |360   |
|fil

                                                                                

In [7]:
import cv2
def standardize_channels(img, n_channels):
    if n_channels == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif n_channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    else:
        return img

In [8]:
from pyspark.sql.functions import udf
import numpy as np
from pyspark.ml.linalg import Vectors, VectorUDT
imageUdf = udf(lambda img:
               Vectors.dense(standardize_channels(cv2.resize(
                   np.frombuffer(img.data, dtype = np.uint8).
                   reshape((img.height, img.width, img.nChannels)), (32,32)
                   ), img.nChannels).flatten().tolist()), VectorUDT())

In [9]:
df = df.withColumn("features", imageUdf(df.image))

In [11]:
df.printSchema()

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)
 |-- features: vector (nullable = true)



In [10]:
df.show()

                                                                                

+--------------------+--------------------+
|               image|            features|
+--------------------+--------------------+
|{file:///Users/hz...|[255.0,255.0,255....|
|{file:///Users/hz...|[13.0,33.0,40.0,1...|
|{file:///Users/hz...|[0.0,21.0,20.0,46...|
|{file:///Users/hz...|[26.0,28.0,28.0,6...|
|{file:///Users/hz...|[87.0,92.0,110.0,...|
|{file:///Users/hz...|[148.0,168.0,152....|
|{file:///Users/hz...|[73.0,81.0,95.0,1...|
|{file:///Users/hz...|[23.0,133.0,199.0...|
|{file:///Users/hz...|[36.0,62.0,69.0,1...|
|{file:///Users/hz...|[85.0,93.0,110.0,...|
|{file:///Users/hz...|[92.0,89.0,74.0,1...|
|{file:///Users/hz...|[18.0,25.0,58.0,5...|
|{file:///Users/hz...|[255.0,255.0,255....|
|{file:///Users/hz...|[156.0,166.0,188....|
|{file:///Users/hz...|[96.0,122.0,160.0...|
|{file:///Users/hz...|[87.0,147.0,196.0...|
|{file:///Users/hz...|[41.0,100.0,126.0...|
|{file:///Users/hz...|[48.0,45.0,40.0,3...|
|{file:///Users/hz...|[168.0,213.0,206....|
|{file:///Users/hz...|[172.0,155