Skip to content

Commit

Permalink
docs and various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Aug 14, 2018
1 parent 1b1a554 commit 1b9668d
Showing 1 changed file with 42 additions and 25 deletions.
67 changes: 42 additions & 25 deletions elephas/spark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,48 @@
from .utils import lp_to_simple_rdd
from .utils import model_to_dict
from .mllib import to_matrix, from_matrix, to_vector, from_vector
from .optimizers import SGD as default_optimizer
from .optimizers import SGD
from .worker import AsynchronousSparkWorker, SparkWorker
from .parameter import HttpServer, SocketServer
from .parameter import HttpClient, SocketClient


class SparkModel(object):
"""SparkModel is the main abstraction of elephas. Every other model
should inherit from it.
"""
# TODO: Eliminate Spark context (only used for first broadcast, can be extracted)

def __init__(self, master_network, optimizer=None,
mode='asynchronous', frequency='epoch',
num_workers=4,
master_optimizer="sgd", # TODO: other default
num_workers=None,
master_optimizer="sgd",
master_loss="categorical_crossentropy",
master_metrics=None,
custom_objects=None,
parameter_server='http',
parameter_server_mode='http',
*args, **kwargs):
"""SparkModel
Base class for distributed training on RDDs. Spark model takes a Keras
model as master network, an optimization scheme, a parallelisation mode
and an averaging frequency.
:param master_network: Keras model (not compiled)
:param optimizer: Elephas optimizer
:param mode: String, choose from `asynchronous`, `synchronous` and `hogwild`
:param frequency: String, either `epoch` or `batch`
:param num_workers: int, number of workers used for training (defaults to None)
:param master_optimizer: Keras optimizer for master network
:param master_loss: Keras loss function for master network
:param master_metrics: Keras metrics used for master network
:param custom_objects: Keras custom objects
:param parameter_server_mode: String, either `http` or `socket`
"""

self._master_network = master_network
if custom_objects is None:
custom_objects = {}
if master_metrics is None:
master_metrics = ["accuracy"]
if optimizer is None:
self.optimizer = default_optimizer()
self.optimizer = SGD()
else:
self.optimizer = optimizer
self.mode = mode
Expand All @@ -45,14 +59,18 @@ def __init__(self, master_network, optimizer=None,
self.master_loss = master_loss
self.master_metrics = master_metrics
self.custom_objects = custom_objects
self.parameter_server_mode = parameter_server_mode

# TODO: clients have to be initialized on workers. Only init servers here, clients on workers
if parameter_server == 'http':
# TODO: clients have to be initialized on workers, too.
if self.parameter_server_mode == 'http':
self.parameter_server = HttpServer(self.master_network, self.optimizer, self.mode)
self.connector = HttpClient()
else:
self.client = HttpClient()
elif self.parameter_server_mode == 'socket':
self.parameter_server = SocketServer(model_to_dict(self.master_network))
self.connector = SocketClient()
self.client = SocketClient()
else:
raise ValueError("Parameter server mode has to be either `http` or `socket`, "
"got {}".format(self.parameter_server_mode))

@staticmethod
def get_train_config(nb_epoch, batch_size,
Expand Down Expand Up @@ -104,15 +122,15 @@ def train(self, rdd, nb_epoch=10, batch_size=32,
# TODO: Make dataframe the standard, but support RDDs as well
"""Train an elephas model.
"""
rdd = rdd.repartition(self.num_workers)
if self.num_workers:
rdd = rdd.repartition(self.num_workers)

if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
self._train(rdd, nb_epoch, batch_size, verbose, validation_split)
else:
raise Exception("""Choose from one of the modes: asynchronous, synchronous or hogwild""")
raise Exception("Choose from one of the modes: asynchronous, synchronous or hogwild")

def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
validation_split=0.1):
def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0, validation_split=0.1):
"""Protected train method to make wrapping of modes easier
"""
self.master_network.compile(optimizer=self.master_optimizer,
Expand All @@ -121,15 +139,14 @@ def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0,
if self.mode in ['asynchronous', 'hogwild']:
self.start_server()
yaml = self.master_network.to_yaml()
train_config = self.get_train_config(nb_epoch, batch_size,
verbose, validation_split)
train_config = self.get_train_config(nb_epoch, batch_size, verbose, validation_split)
if self.mode in ['asynchronous', 'hogwild']:
worker = AsynchronousSparkWorker(
yaml, self.connector, train_config, self.frequency,
yaml, self.client, train_config, self.frequency,
self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
)
rdd.mapPartitions(worker.train).collect()
new_parameters = self.connector.get_parameters()
new_parameters = self.client.get_parameters()
elif self.mode == 'synchronous':
parameters = self.master_network.get_weights()
worker = SparkWorker(
Expand All @@ -152,13 +169,13 @@ class SparkMLlibModel(SparkModel):
"""MLlib model takes RDDs of LabeledPoints. Internally we just convert
back to plain old pair RDDs and continue as in SparkModel
"""
def __init__(self, sc, master_network, optimizer=None, mode='asynchronous', frequency='epoch', num_workers=4,
def __init__(self, master_network, optimizer=None, mode='asynchronous', frequency='epoch', num_workers=4,
master_optimizer="adam",
master_loss="categorical_crossentropy",
master_metrics=None,
custom_objects=None):
# TODO signature is wrong
SparkModel.__init__(self, sc, master_network, optimizer, mode, frequency, num_workers,

SparkModel.__init__(self, master_network, optimizer, mode, frequency, num_workers,
master_optimizer=master_optimizer, master_loss=master_loss, master_metrics=master_metrics,
custom_objects=custom_objects)

Expand Down

0 comments on commit 1b9668d

Please sign in to comment.