Skip to content

Commit

Permalink
init clients on worker
Browse files: browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Aug 14, 2018
1 parent 1b9668d commit a9812d0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
5 changes: 2 additions & 3 deletions elephas/spark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self, master_network, optimizer=None,
self.custom_objects = custom_objects
self.parameter_server_mode = parameter_server_mode

# TODO: clients have to be initialized on workers, too.
if self.parameter_server_mode == 'http':
self.parameter_server = HttpServer(self.master_network, self.optimizer, self.mode)
self.client = HttpClient()
Expand Down Expand Up @@ -128,7 +127,7 @@ def train(self, rdd, nb_epoch=10, batch_size=32,
if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
self._train(rdd, nb_epoch, batch_size, verbose, validation_split)
else:
raise Exception("Choose from one of the modes: asynchronous, synchronous or hogwild")
raise ValueError("Choose from one of the modes: asynchronous, synchronous or hogwild")

def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0, validation_split=0.1):
"""Protected train method to make wrapping of modes easier
Expand All @@ -142,7 +141,7 @@ def _train(self, rdd, nb_epoch=10, batch_size=32, verbose=0, validation_split=0.
train_config = self.get_train_config(nb_epoch, batch_size, verbose, validation_split)
if self.mode in ['asynchronous', 'hogwild']:
worker = AsynchronousSparkWorker(
yaml, self.client, train_config, self.frequency,
yaml, self.parameter_server_mode, train_config, self.frequency,
self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
)
rdd.mapPartitions(worker.train).collect()
Expand Down
22 changes: 15 additions & 7 deletions elephas/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras.models import model_from_yaml

from .utils import subtract_params
from .parameter import SocketClient
from .parameter import SocketClient, HttpClient


class SparkWorker(object):
Expand Down Expand Up @@ -42,10 +42,19 @@ def train(self, data_iterator):
class AsynchronousSparkWorker(object):
"""Asynchronous Spark worker. This code will be executed on workers.
"""
def __init__(self, yaml, client_mode, train_config, frequency,
def __init__(self, yaml, parameter_server_mode, train_config, frequency,
master_optimizer, master_loss, master_metrics,
custom_objects):
self.yaml = yaml
if parameter_server_mode == 'http':
self.client = HttpClient()
elif parameter_server_mode == 'socket':
self.client = SocketClient()
else:
raise ValueError("Parameter server mode has to be either `http` or `socket`, "
"got {}".format(parameter_server_mode))

self.client = parameter_server_mode
self.train_config = train_config
self.frequency = frequency
self.master_optimizer = master_optimizer
Expand Down Expand Up @@ -76,32 +85,31 @@ def train(self, data_iterator):
(i * batch_size, min(nb_train_sample, (i + 1) * batch_size))
for i in range(0, nb_batch)
]
self.connector = SocketClient()

if self.frequency == 'epoch':
for epoch in range(nb_epoch):
weights_before_training = self.connector.get_parameters()
weights_before_training = self.client.get_parameters()
model.set_weights(weights_before_training)
self.train_config['nb_epoch'] = 1
if x_train.shape[0] > batch_size:
model.fit(x_train, y_train, **self.train_config)
weights_after_training = model.get_weights()
deltas = subtract_params(weights_before_training, weights_after_training)
self.connector.update_parameters(deltas)
self.client.update_parameters(deltas)
elif self.frequency == 'batch':
from keras.engine.training import slice_X
for epoch in range(nb_epoch):
if x_train.shape[0] > batch_size:
for (batch_start, batch_end) in batches:
weights_before_training = self.connector.get_parameters()
weights_before_training = self.client.get_parameters()
model.set_weights(weights_before_training)
batch_ids = index_array[batch_start:batch_end]
X = slice_X(x_train, batch_ids)
y = slice_X(y_train, batch_ids)
model.train_on_batch(X, y)
weights_after_training = model.get_weights()
deltas = subtract_params(weights_before_training, weights_after_training)
self.connector.update_parameters(deltas)
self.client.update_parameters(deltas)
else:
raise ValueError('frequency parameter can be `epoch` or `batch, got {}'.format(self.frequency))
yield []

0 comments on commit a9812d0

Please sign in to comment.