Skip to content

Commit

Permalink
remove broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Aug 13, 2018
1 parent 9814158 commit 26e2d4e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
27 changes: 12 additions & 15 deletions elephas/spark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SparkModel(object):
should inherit from it.
"""
# TODO: Eliminate Spark context (only used for first broadcast, can be extracted)
def __init__(self, sc, master_network, optimizer=None,
def __init__(self, master_network, optimizer=None,
mode='asynchronous', frequency='epoch',
num_workers=4,
master_optimizer="sgd", # TODO: other default
Expand All @@ -27,7 +27,6 @@ def __init__(self, sc, master_network, optimizer=None,
parameter_server='http',
*args, **kwargs):

self.spark_context = sc
self._master_network = master_network
if custom_objects is None:
custom_objects = {}
Expand All @@ -47,7 +46,7 @@ def __init__(self, sc, master_network, optimizer=None,
self.master_metrics = master_metrics
self.custom_objects = custom_objects

# TODO: connector has to be initialized on workers
# TODO: clients have to be initialized on workers. Only init servers here, clients on workers
if parameter_server == 'http':
self.parameter_server = HttpServer(self.master_network, self.optimizer, self.mode)
self.connector = HttpClient()
Expand Down Expand Up @@ -92,20 +91,20 @@ def stop_server(self):
self.parameter_server.stop()

def predict(self, data):
'''Get prediction probabilities for a numpy array of features
'''
"""Get prediction probabilities for a numpy array of features
"""
return self.master_network.predict(data)

def predict_classes(self, data):
'''Predict classes for a numpy array of features
'''
""" Predict classes for a numpy array of features
"""
return self.master_network.predict_classes(data)

def train(self, rdd, nb_epoch=10, batch_size=32,
verbose=0, validation_split=0.1):
# TODO: Make dataframe the standard, but support RDDs as well
'''Train an elephas model.
'''
"""Train an elephas model.
"""
rdd = rdd.repartition(self.num_workers)

if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
Expand All @@ -115,9 +114,8 @@ def train(self, rdd, nb_epoch=10, batch_size=32,

def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
validation_split=0.1):
'''
Protected train method to make wrapping of modes easier
'''
"""Protected train method to make wrapping of modes easier
"""
self.master_network.compile(optimizer=self.master_optimizer,
loss=self.master_loss,
metrics=self.master_metrics)
Expand All @@ -134,10 +132,9 @@ def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
rdd.mapPartitions(worker.train).collect()
new_parameters = self.connector.get_parameters()
elif self.mode == 'synchronous':
init = self.master_network.get_weights()
parameters = self.spark_context.broadcast(init)
parameters = self.master_network.get_weights()
worker = SparkWorker(
yaml, parameters, train_config,
yaml, parameters, train_config,
self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
)
deltas = rdd.mapPartitions(worker.train).collect()
Expand Down
4 changes: 2 additions & 2 deletions elephas/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def train(self, data_iterator):
model.compile(optimizer=self.master_optimizer,
loss=self.master_loss,
metrics=self.master_metrics)
model.set_weights(self.parameters.value)
model.set_weights(self.parameters)
weights_before_training = model.get_weights()
if x_train.shape[0] > self.train_config.get('batch_size'):
model.fit(x_train, y_train, **self.train_config)
Expand All @@ -42,7 +42,7 @@ def train(self, data_iterator):
class AsynchronousSparkWorker(object):
"""Asynchronous Spark worker. This code will be executed on workers.
"""
def __init__(self, yaml, ps_connector, train_config, frequency,
def __init__(self, yaml, client_mode, train_config, frequency,
master_optimizer, master_loss, master_metrics,
custom_objects):
self.yaml = yaml
Expand Down

0 comments on commit 26e2d4e

Please sign in to comment.