Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions python/sparkdl/estimators/keras_image_file_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import threading
import numpy as np

import pyspark
from pyspark.ml import Estimator
import pyspark.ml.linalg as spla

from sparkdl.image.imageIO import imageStructToArray
from sparkdl.param import (
keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
HasInputCol, HasLabelCol, HasOutputCol)
from sparkdl.transformers.keras_image import KerasImageFileTransformer
import sparkdl.utils.jvmapi as JVMAPI
import sparkdl.utils.keras_model as kmutil
Expand Down Expand Up @@ -74,10 +73,8 @@ def next(self):
return self.__next__()


class KerasImageFileEstimator(Estimator, HasInputCol, HasInputImageNodeName,
HasOutputCol, HasOutputNodeName, HasLabelCol,
HasKerasModel, HasKerasOptimizer, HasKerasLoss,
CanLoadImage, HasOutputMode):
class KerasImageFileEstimator(Estimator, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel,
HasKerasOptimizer, HasKerasLoss, CanLoadImage, HasOutputMode):
"""
Build a Estimator from a Keras model.

Expand Down Expand Up @@ -138,13 +135,11 @@ def load_image_and_process(uri):
"""

@keyword_only
def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
outputNodeName=None, outputMode="vector", labelCol=None,
def __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
kerasFitParams=None):
"""
__init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
outputNodeName=None, outputMode="vector", labelCol=None,
__init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
kerasFitParams=None)
"""
Expand All @@ -155,13 +150,11 @@ def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
outputNodeName=None, outputMode="vector", labelCol=None,
def setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
kerasFitParams=None):
"""
setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
outputNodeName=None, outputMode="vector", labelCol=None,
setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
kerasFitParams=None)
"""
Expand All @@ -174,12 +167,23 @@ def _validateParams(self, paramMap):
:param paramMap: Dict[pyspark.ml.param.Param, object]
:return: True if parameters are valid
"""
if not self.isDefined(self.inputCol):
raise ValueError("Input column must be defined")
if not self.isDefined(self.outputCol):
raise ValueError("Output column must be defined")
if self.inputCol in paramMap:
raise ValueError("Input column can not be fine tuned")
model_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge model_params and output_params into tunable_params for clarity

output_params = [self.outputCol, self.outputMode]

params = self.params
undefined = set([p for p in params if not self.isDefined(p)])
undefined_tunable = undefined.intersection(model_params + output_params)
failed_define = [p.name for p in undefined.difference(undefined_tunable)]
failed_tune = [p.name for p in undefined_tunable if p not in paramMap]

if failed_define or failed_tune:
msg = "Following Params must be"
if failed_define:
msg += " defined: [" + ", ".join(failed_define) + "]"
if failed_tune:
msg += " defined or tuned: [" + ", ".join(failed_tune) + "]"
raise ValueError(msg)

return True

def _getNumpyFeaturesAndLabels(self, dataset):
Expand Down Expand Up @@ -236,8 +240,7 @@ def _collectModels(self, kerasModelBytesRDD):
"""
Collect Keras models on workers to MLlib Models on the driver.
:param kerasModelBytesRDD: RDD of (param_map, model_bytes) tuples
:param paramMaps: list of ParamMaps matching the maps in `kerasModelsRDD`
:return: list of MLlib models
:return: generator of (index, MLlib model) tuples
"""
for (i, param_map, model_bytes) in kerasModelBytesRDD.collect():
model_filename = kmutil.bytes_to_h5file(model_bytes)
Expand All @@ -264,7 +267,6 @@ def _name_value_map(paramMap):
"""takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
return {param.name: val for param, val in paramMap.items()}


sc = JVMAPI._curr_sc()
paramNameMaps = list(enumerate(map(_name_value_map, paramMaps)))
num_models = len(paramNameMaps)
Expand Down
4 changes: 2 additions & 2 deletions python/sparkdl/param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# TFTransformer Params
HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams,
# Keras Estimator Params
HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName)
HasKerasModel, HasKerasLoss, HasKerasOptimizer)
from sparkdl.param.converters import SparkDLTypeConverters
from sparkdl.param.image_params import (
CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES)
CanLoadImage, HasOutputMode, OUTPUT_MODES)
13 changes: 0 additions & 13 deletions python/sparkdl/param/image_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,6 @@
OUTPUT_MODES = ["vector", "image"]


class HasInputImageNodeName(Params):
# TODO: docs
inputImageNodeName = Param(Params._dummy(), "inputImageNodeName",
"name of the graph element/node corresponding to the input",
typeConverter=TypeConverters.toString)

def setInputImageNodeName(self, value):
return self._set(inputImageNodeName=value)

def getInputImageNodeName(self):
return self.getOrDefault(self.inputImageNodeName)


class CanLoadImage(Params):
"""
In standard Keras workflow, we use provides an image loading function
Expand Down
13 changes: 0 additions & 13 deletions python/sparkdl/param/shared_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,6 @@ def getOutputCol(self):
########################################################


class HasOutputNodeName(Params):
# TODO: docs
outputNodeName = Param(Params._dummy(), "outputNodeName",
"name of the graph element/node corresponding to the output",
typeConverter=TypeConverters.toString)

def setOutputNodeName(self, value):
return self._set(outputNodeName=value)

def getOutputNodeName(self):
return self.getOrDefault(self.outputNodeName)


class HasLabelCol(Params):
"""
When training Keras image models in a supervised learning setting,
Expand Down
26 changes: 23 additions & 3 deletions python/tests/estimators/test_keras_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def _get_model(self, label_cardinality):
return model

def _get_estimator(self, model):
"""
Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model
"""
"""Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model"""
_random_filename_suffix = str(uuid.uuid4())
model_filename = os.path.join(self.temp_dir, 'model-{}.h5'.format(_random_filename_suffix))
model.save(model_filename)
Expand All @@ -105,7 +103,27 @@ def setUp(self):
def tearDown(self):
shutil.rmtree(self.temp_dir, ignore_errors=True)

def test_validate_params(self):
"""Test that `KerasImageFileEstimator._validateParams` method works as expected"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably could be more thorough in making sure all the intersect/diff logic works but I think it's good enough.

kifest = KerasImageFileEstimator()

# should raise an error to define required parameters
# assuming at least one param without default value
self.assertRaisesRegexp(ValueError, 'defined', kifest._validateParams, {})
kifest.setParams(imageLoader=_load_image_from_uri, inputCol='c1', labelCol='c2')
kifest.setParams(modelFile='/path/to/file.ext')

# should raise an error to define or tune parameters
# assuming at least one tunable param without default value
self.assertRaisesRegexp(ValueError, 'tuned', kifest._validateParams, {})
kifest.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={})
kifest.setParams(outputCol='c3', outputMode='vector')

# should pass test on supplying all parameters
self.assertTrue(kifest._validateParams({}))

def test_single_training(self):
"""Test that single model fitting works well"""
# Create image URI dataframe
label_cardinality = 10
image_uri_df = self._create_train_image_uris_and_labels(repeat_factor=3,
Expand All @@ -123,6 +141,7 @@ def test_single_training(self):
str(transformer.getOrDefault(p)))

def test_tuning(self):
"""Test that multiple model fitting using `CrossValidator` works well"""
# Create image URI dataframe
label_cardinality = 2
image_uri_df = self._create_train_image_uris_and_labels(repeat_factor=3,
Expand Down Expand Up @@ -150,6 +169,7 @@ def test_tuning(self):
"fit params must be copied")

def test_keras_training_utils(self):
"""Test some Keras training utils"""
self.assertTrue(kmutil.is_valid_optimizer('adam'))
self.assertFalse(kmutil.is_valid_optimizer('noSuchOptimizer'))
self.assertTrue(kmutil.is_valid_loss_function('mse'))
Expand Down