In [1]:
# Load images with Spark's built-in "image" data source, see
# https://sparkbyexamples.com/spark/spark-read-binary-file-into-dataframe/
# Only the prefixes 000-009 are read here; a wider sample is "{00*,01*}".
IMAGES_PATH = "s3://multimedia-commons/data/images/{00[0-9]}"

image_reader = spark.read.format("image").option("recursiveFileLookup", True)
# Cached because every later cell re-queries this DataFrame.
df = image_reader.load(IMAGES_PATH).persist()

df.printSchema()
df.count()

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
8,application_1650412208045_0010,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)

887

## Use the built-in `ImageSchema.toNDArray` function

In [2]:
# https://stackoverflow.com/a/69215982/11262633
import pyspark.sql.functions as F
from pyspark.ml.image import ImageSchema
from pyspark.ml.linalg import DenseVector, VectorUDT

@F.udf(returnType=VectorUDT())
def img2vec(x):
    """Flatten one Spark image struct into a DenseVector (PySpark UDF).

    Applied to the ``image`` column produced by ``spark.read.format("image")``.

    Args:
        x: the ``image`` struct for one row (fields listed in
            ``ImageSchema.imageFields``), or None.

    Returns:
        ``DenseVector`` built from ``ImageSchema.toNDArray(x).flatten()``,
        or None when the struct cannot be converted.  One such case is the
        placeholder row emitted for a file the image source could not read:
        height/width/nChannels/mode are -1 and ``data`` is empty.
    """
    if x is None:
        # Null struct: nothing to convert.
        return None
    try:
        return DenseVector(ImageSchema.toNDArray(x).flatten())
    except Exception:
        # Deliberate best-effort: an unconvertible image becomes a null vector
        # (later cells look for `vecs` IS NULL) rather than failing the task.
        # `except Exception` instead of a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        return None

# Fields that ImageSchema.toNDArray reads from each image struct.
print(f'Image fields = {ImageSchema.imageFields}')

# One flattened pixel vector per image; cached because the following cells
# filter and re-query this DataFrame.
vecs_col = img2vec('image')
df_new = df.withColumn('vecs', vecs_col).persist()
df_new.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Image fields = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
+--------------------+--------------------+
|               image|                vecs|
+--------------------+--------------------+
|{s3://multimedia-...|[44.0,181.0,150.0...|
|{s3://multimedia-...|[45.0,103.0,108.0...|
|{s3://multimedia-...|[187.0,138.0,154....|
|{s3://multimedia-...|[0.0,46.0,221.0,0...|
|{s3://multimedia-...|[30.0,54.0,90.0,1...|
|{s3://multimedia-...|[54.0,73.0,81.0,5...|
|{s3://multimedia-...|[27.0,25.0,24.0,3...|
|{s3://multimedia-...|[182.0,179.0,158....|
|{s3://multimedia-...|[51.0,86.0,89.0,1...|
|{s3://multimedia-...|[234.0,178.0,96.0...|
|{s3://multimedia-...|[21.0,43.0,18.0,1...|
|{s3://multimedia-...|[7.0,73.0,114.0,1...|
|{s3://multimedia-...|[60.0,183.0,125.0...|
|{s3://multimedia-...|[41.0,37.0,32.0,3...|
|{s3://multimedia-...|[143.0,186.0,207....|
|{s3://multimedia-...|[9.0,79.0,42.0,49...|
|{s3://multimedia-...|[222.0,200.0,188....|
|{s3://multimedia-...|[205.0,175.0,128....|
|{

### Q: Why did this cell take more than 3 minutes?

In [3]:
# Origins of the images img2vec could not convert (their `vecs` is null).
null_vec_rows = df_new.filter(F.col('vecs').isNull())
df_null = null_vec_rows.select('image.origin')
df_null.show(truncate=False)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------------------------------------------------------------------------+
|origin                                                                      |
+----------------------------------------------------------------------------+
|s3://multimedia-commons/data/images/002/7ce/0027ce88ae5dc31b18e3245743e3.jpg|
+----------------------------------------------------------------------------+

In [6]:
from py4j.java_gateway import java_import
# NOTE(review): this kernel is a remote YARN-hosted session (see the session
# table in cell 1). The explicit import presumably makes the
# org.apache.spark.sql.api.python classes visible on the py4j JVM view so that
# explain() can resolve them -- TODO confirm it is still required.
jvm_view = spark._sc._jvm
java_import(jvm_view, "org.apache.spark.sql.api.python.*")

df_null.explain()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

== Physical Plan ==
*(1) Project [image#0.origin AS origin#123]
+- *(1) Filter isnull(vecs#40)
   +- InMemoryTableScan [image#0, vecs#40], [isnull(vecs#40)]
         +- InMemoryRelation [image#0, vecs#40], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [image#0, pythonUDF0#53 AS vecs#40]
                  +- BatchEvalPython [img2vec(image#0)], [pythonUDF0#53]
                     +- InMemoryTableScan [image#0]
                           +- InMemoryRelation [image#0], StorageLevel(disk, memory, deserialized, 1 replicas)
                                 +- FileScan image [image#0] Batched: false, DataFilters: [], Format: org.apache.spark.ml.source.image.ImageFileFormat@441229d6, Location: InMemoryFileIndex[s3://multimedia-commons/data/images/000, s3://multimedia-commons/data/images/00..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<image:struct<origin:string,height:int,width:int,nChannels:int,mode:int,data:binary>>

In [7]:
# Isolate the image whose vector came back null in the df_null query.
BAD_IMAGE_ORIGIN = 's3://multimedia-commons/data/images/002/7ce/0027ce88ae5dc31b18e3245743e3.jpg'
image = df_new.filter(F.col('image.origin') == BAD_IMAGE_ORIGIN)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

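## The offending image has 0 bytes of data in the DataFrame

Spark's image source returned a placeholder row for this file: `height`, `width`, `nChannels` and `mode` are all -1 and `data` is empty (see the full row below). As a result, `toNDArray` could not convert it and `img2vec` returned null.

However, the file itself is not empty: when downloaded, it is 99,324 bytes. Note that `data` normally holds the *decoded* pixel bytes rather than the raw file, so the two sizes are not directly comparable. The empty buffer presumably means the decoder could not read this particular file.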

In [8]:
# Fetch the single matching row and report the size of its `data` buffer.
image_row = image.first()
# first() returns None when nothing matched; fail with a clear message
# instead of an AttributeError on the next line.
assert image_row is not None, 'no row matched the offending image origin'
# Row fields support attribute access (as used for `.image.origin` elsewhere),
# which replaces the nested asDict()[...] lookups; the stray no-op
# `len(image_data)` expression has been dropped.
image_data = image_row.image.data
print(f'Image size={len(image_data)}')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Image size=0

In [9]:
# Show the whole Row: for this image every numeric field is -1 and `data` is
# empty, and `vecs` is None.
image_row

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Row(image=Row(origin='s3://multimedia-commons/data/images/002/7ce/0027ce88ae5dc31b18e3245743e3.jpg', height=-1, width=-1, nChannels=-1, mode=-1, data=bytearray(b'')), vecs=None)

In [11]:
# For comparison, report the `data` size of the first row of the DataFrame.
first_image = df_new.first()
first_struct = first_image.image
first_data = first_struct.data
print(f'Image {first_struct.origin} size={len(first_data)}')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Image s3://multimedia-commons/data/images/007/1c6/0071c6b1b7a8e129824b3de084d4c.jpg size=516000