logic for parallel prediction
danielenricocahall committed Jan 17, 2021
1 parent 345ceb9 commit dc8f4ed
Showing 2 changed files with 39 additions and 10 deletions.
13 changes: 7 additions & 6 deletions elephas/spark_model.py
@@ -114,10 +114,11 @@ def stop_server(self):
     def predict(self, data: Union[RDD, np.array]):
         """Get prediction probabilities for a numpy array of features
         """
-        if isinstance(data, (RDD, )):
-            return self.predict_rdd(data)
-        elif isinstance(data, (np.ndarray, )):
-            return self._master_network.predict(data)
+        if isinstance(data, (np.ndarray, )):
+            from pyspark.sql import SparkSession
+            sc = SparkSession.builder.getOrCreate().sparkContext
+            data = sc.parallelize(data)
+        return self._predict(data)
 
     def fit(self, rdd: RDD, **kwargs):
         """
@@ -182,14 +183,14 @@ def _fit(self, rdd: RDD, **kwargs):
         if self.mode in ['asynchronous', 'hogwild']:
             self.stop_server()
 
-    def predict_rdd(self, rdd: RDD):
+    def _predict(self, rdd: RDD):
         def _predict(model, model_type, data_iterator):
             model = model_from_yaml(model)
             predict_function = determine_predict_function(model, model_type)
             return predict_function(np.expand_dims(data_iterator, axis=0))
         if self.num_workers:
             rdd = rdd.repartition(self.num_workers)
-        yaml_model = self._master_network.to_yaml()
+        yaml_model = self.master_network.to_yaml()
         model_type = LossModelTypeMapper().get_model_type(self.master_loss)
         predictions = rdd.map(partial(_predict, yaml_model, model_type)).collect()
         return predictions
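
With this change, SparkModel.predict accepts either an RDD or a plain numpy array: arrays are parallelized into an RDD on the driver, and both paths go through the distributed _predict, which rebuilds the master network from its YAML definition on the workers. A minimal usage sketch, assuming an already-fitted spark_model and an active spark_context (the names and the 784-feature shape are illustrative, not part of the commit):

    import numpy as np

    # numpy input: predict() parallelizes it into an RDD internally
    x = np.random.rand(100, 784)
    array_preds = spark_model.predict(x)

    # RDD input: handed straight through to the workers
    rdd_preds = spark_model.predict(spark_context.parallelize(x))

    # both paths return one prediction per input sample
    assert len(array_preds) == len(rdd_preds) == 100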
36 changes: 32 additions & 4 deletions tests/integration/test_prediction.py
@@ -16,15 +16,13 @@ def test_classification_prediction(spark_context, mode, mnist_data, classification_model):
     x_train, y_train, x_test, y_test = mnist_data
     x_train = x_train[:1000]
     y_train = y_train[:1000]
-    x_test = x_test[:1]
-    y_test = y_test[:500]
+    x_test = x_test[:100]
 
     sgd = SGD(lr=0.1)
     classification_model.compile(sgd, 'categorical_crossentropy', ['acc'])
 
     # Build RDD from numpy features and labels
     train_rdd = to_simple_rdd(spark_context, x_train, y_train)
-    test_rdd = spark_context.parallelize(x_test)
 
     # Initialize SparkModel from keras model and Spark context
     spark_model = SparkModel(classification_model, frequency='epoch', mode=mode)
@@ -33,5 +31,35 @@
     spark_model.fit(train_rdd, epochs=epochs, batch_size=batch_size,
                     verbose=0, validation_split=0.1)
 
-    predictions = spark_model.predict(test_rdd)
+    # assert we have as many predictions as samples provided
+    assert len(spark_model.predict(x_test)) == 100
+
+    test_rdd = spark_context.parallelize(x_test)
+    # assert we can supply rdd
+    assert len(spark_model.predict(test_rdd)) == 100
+
+
+@pytest.mark.parametrize('mode', ['synchronous', 'asynchronous', 'hogwild'])
+def test_classification_regression(spark_context, mode, boston_housing_dataset, regression_model):
+    x_train, y_train, x_test, y_test = boston_housing_dataset
+    train_rdd = to_simple_rdd(spark_context, x_train, y_train)
+    x_test = x_test[:100]
+
+    # Define basic parameters
+    batch_size = 64
+    epochs = 10
+    sgd = SGD(lr=0.0000001)
+    regression_model.compile(sgd, 'mse', ['mae'])
+    # Initialize SparkModel from keras model and Spark context
+    spark_model = SparkModel(regression_model, frequency='epoch', mode=mode)
+
+    # Train Spark model
+    spark_model.fit(train_rdd, epochs=epochs, batch_size=batch_size,
+                    verbose=0, validation_split=0.1)
+
+    # assert we have as many predictions as samples provided
+    assert len(spark_model.predict(x_test)) == 100
+
+    test_rdd = spark_context.parallelize(x_test)
+    # assert we can supply rdd
+    assert len(spark_model.predict(test_rdd)) == 100
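
The new regression test relies on the boston_housing_dataset and regression_model fixtures, which live in the test suite's conftest rather than in this diff. A minimal sketch of a compatible regression_model fixture, assuming a small Keras Sequential network over the 13 Boston housing features (the layer sizes and the tensorflow.keras import path are illustrative guesses, not the repository's actual fixture):

    import pytest
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense

    @pytest.fixture
    def regression_model():
        # hypothetical fixture: single scalar output, returned uncompiled
        # so the test can attach its own optimizer and loss (SGD + 'mse')
        model = Sequential()
        model.add(Dense(64, activation='relu', input_shape=(13,)))
        model.add(Dense(1))
        return model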
