In [1]:
import os

# Arguments handed to the PySpark launcher through the PYSPARK_SUBMIT_ARGS
# environment variable; they request the spark-deep-learning package.
# This must be set before the Spark session is created further down.
SUBMIT_ARGS = (
    "--packages databricks:spark-deep-learning:1.0.0-spark2.3-s_2.11 pyspark-shell"
)
os.environ.update(PYSPARK_SUBMIT_ARGS=SUBMIT_ARGS)

In [2]:
#Import PySpark libraries
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext

# Create (or reuse) the session named 'Spark' with explicit memory settings:
# 16g each for driver and executor, plus off-heap memory enabled at 128g.
spark = (
    SparkSession.builder
    .appName('Spark')
    .config("spark.driver.memory", "16g")
    .config("spark.executor.memory", "16g")
    .config("spark.memory.offHeap.enabled", True)
    .config("spark.memory.offHeap.size", "128g")
    .getOrCreate()
)
# Bare expression so the notebook renders the session as the cell output.
spark


In [3]:
from pyspark.ml.image import ImageSchema
from pyspark.sql.functions import lit
from functools import reduce

# Number of class folders to read: images/0 ... images/24, where the folder
# name is used as the integer label of every image inside it.
NUM_CLASSES = 25
# Fixed seed so the 70/30 split (and the accuracy reported later) can be
# reproduced for the same input data and partitioning.
SPLIT_SEED = 42

# One DataFrame per class folder, each tagged with its label.
dataframes = [
    ImageSchema.readImages("images/" + str(label)).withColumn("label", lit(label))
    for label in range(NUM_CLASSES)
]

# Stack the per-class frames into a single dataset.
df = reduce(lambda first, second: first.union(second), dataframes)
df = df.repartition(200)
# Without a seed randomSplit draws a different split on every run.
train, test = df.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [4]:
import time
# Start of the wall-clock measurement; the final cell subtracts this value
# from a later time.time() reading to report the elapsed seconds.
start = time.time()

In [5]:
import sys, glob, os
# Put every jar found under ~/.ivy2/jars on sys.path. This runs before the
# `sparkdl` import in the next cell -- presumably so the Python code bundled
# in the downloaded package jar becomes importable (confirm).
ivy_jar_pattern = os.path.join(os.path.expanduser("~"), ".ivy2/jars/*.jar")
ivy_jars = glob.glob(ivy_jar_pattern)
sys.path.extend(ivy_jars)

In [6]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from sparkdl import DeepImageFeaturizer 

# Stage 1 -- feature extraction: the featurizer is configured with the
# InceptionV3 model, reads the "image" column and writes "features".
featurizer = DeepImageFeaturizer(
    modelName="InceptionV3",
    inputCol="image",
    outputCol="features",
)

# Stage 2 -- multi-class classifier trained against the "label" column.
# NOTE(review): presumably it picks up the featurizer output through its
# default features column -- confirm.
lr = LogisticRegression(
    labelCol="label",
    maxIter=10,
    regParam=0.03,
    elasticNetParam=0.5,
)

# Chain both stages into one pipeline and fit it on the training split.
sparkdn = Pipeline(stages=[featurizer, lr])
spark_model = sparkdn.fit(train)

In [7]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Apply the fitted pipeline to the held-out test split.
transform_test = spark_model.transform(test)

# Evaluator built with its default parameters; the metric to report is
# supplied later, at evaluation time.
evaluator = MulticlassClassificationEvaluator()

In [8]:
# Score the test-set predictions, overriding the evaluator's metric with
# 'accuracy' for this call only.
print(
    'Accuracy ',
    evaluator.evaluate(
        dataset=transform_test,
        params={evaluator.metricName: 'accuracy'},
    ),
)

Accuracy  0.916514764855997


In [9]:
# Stop the timer started earlier in the notebook and report the elapsed
# wall-clock time; %d truncates the value to whole seconds.
end = time.time()
elapsed_secs = end - start
print("%d secs" % elapsed_secs)

3288 secs
