fix serialization
maxpumperla committed Aug 14, 2018
1 parent 7e97c7f commit 69c4217
Showing 2 changed files with 26 additions and 40 deletions.
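The commit replaces YAML-based model shipping with a plain-dict serialization: the master network is converted once in SparkModel.__init__, the dict is handed to the parameter server and to every worker, and each worker rebuilds a live Keras model from it with dict_to_model. For reference, a minimal sketch of what elephas/utils/serialization.py plausibly contains; the field names and signatures here are assumptions, not shown in this diff:

    # Hypothetical sketch of elephas/utils/serialization.py (not part of this commit).
    from keras.models import model_from_json


    def model_to_dict(model):
        # Flatten a Keras model into a picklable dict of architecture plus weights.
        return dict(model=model.to_json(), weights=model.get_weights())


    def dict_to_model(d, custom_objects=None):
        # Rebuild the model; the diff's TODO suggests custom_objects is not wired up yet.
        model = model_from_json(d['model'], custom_objects=custom_objects)
        model.set_weights(d['weights'])
        return model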
elephas/spark_model.py (21 changes: 7 additions & 14 deletions)
@@ -61,11 +61,12 @@ def __init__(self, master_network, optimizer=None,
         self.custom_objects = custom_objects
         self.parameter_server_mode = parameter_server_mode
 
+        self.serialized_model = model_to_dict(self.master_network)
         if self.parameter_server_mode == 'http':
-            self.parameter_server = HttpServer(self.master_network, self.optimizer, self.mode)
+            self.parameter_server = HttpServer(self.serialized_model, self.optimizer, self.mode)
             self.client = HttpClient()
         elif self.parameter_server_mode == 'socket':
-            self.parameter_server = SocketServer(model_to_dict(self.master_network))
+            self.parameter_server = SocketServer(self.serialized_model)
             self.client = SocketClient()
         else:
             raise ValueError("Parameter server mode has to be either `http` or `socket`, "
@@ -145,23 +146,16 @@ def _fit(self, rdd, epochs, batch_size, verbose, validation_split):
                                    metrics=self.master_metrics)
         if self.mode in ['asynchronous', 'hogwild']:
             self.start_server()
-        yaml = self.master_network.to_yaml()
         train_config = self.get_train_config(epochs, batch_size, verbose, validation_split)
         frequency = self.frequency
 
         if self.mode in ['asynchronous', 'hogwild']:
-            worker = AsynchronousSparkWorker(
-                yaml, self.parameter_server_mode, train_config, self.frequency,
-                self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
-            )
+            worker = AsynchronousSparkWorker(self.serialized_model, self.parameter_server_mode, train_config,
+                                             self.frequency, self.master_optimizer, self.master_loss,
+                                             self.master_metrics, self.custom_objects)
             rdd.mapPartitions(worker.train).collect()
             new_parameters = self.client.get_parameters()
         elif self.mode == 'synchronous':
-            parameters = self.master_network.get_weights()
-            worker = SparkWorker(
-                yaml, parameters, train_config,
-                self.master_optimizer, self.master_loss, self.master_metrics, self.custom_objects
-            )
+            worker = SparkWorker(self.serialized_model, train_config, self.master_optimizer, self.master_loss,
+                                 self.master_metrics, self.custom_objects)
             deltas = rdd.mapPartitions(worker.train).collect()
             new_parameters = self.master_network.get_weights()
             for delta in deltas:
@@ -195,7 +189,6 @@ def __init__(self, master_network, optimizer=None, mode='asynchronous', frequenc
         :param custom_objects: Keras custom objects
         :param parameter_server_mode: String, either `http` or `socket`
         """
-
         SparkModel.__init__(self, master_network=master_network, optimizer=optimizer, mode=mode, frequency=frequency,
                             num_workers=num_workers, master_optimizer=master_optimizer, master_loss=master_loss,
                             master_metrics=master_metrics, custom_objects=custom_objects,
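Net effect in _fit: the model is no longer re-serialized to YAML on every fit call, and the weights no longer travel as a separate argument; a single picklable dict rides along in the worker closure that Spark ships to each partition. A quick round-trip check, assuming the serialization sketch above (the toy model is illustrative only):

    import pickle

    from keras.layers import Dense
    from keras.models import Sequential

    from elephas.utils.serialization import dict_to_model, model_to_dict

    model = Sequential()
    model.add(Dense(16, input_dim=4, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    state = model_to_dict(model)               # plain dict: JSON string plus numpy arrays
    state = pickle.loads(pickle.dumps(state))  # survives the trip inside a Spark closure
    clone = dict_to_model(state)               # what each worker now does once, in __init__
    assert all((a == b).all() for a, b in zip(clone.get_weights(), model.get_weights()))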
elephas/worker.py (45 changes: 19 additions & 26 deletions)
@@ -1,23 +1,22 @@
 import numpy as np
 from itertools import tee
-from keras.models import model_from_yaml
 
+from .utils.serialization import dict_to_model
 from .utils import subtract_params
 from .parameter import SocketClient, HttpClient
 
 
 class SparkWorker(object):
     """Synchronous Spark worker. This code will be executed on workers.
     """
-    def __init__(self, yaml, parameters, train_config, master_optimizer,
+    def __init__(self, serialized_model, train_config, master_optimizer,
                  master_loss, master_metrics, custom_objects):
-        self.yaml = yaml
-        self.parameters = parameters
+        # TODO handle custom_objects
+        self.model = dict_to_model(serialized_model)
         self.train_config = train_config
         self.master_optimizer = master_optimizer
         self.master_loss = master_loss
         self.master_metrics = master_metrics
         self.custom_objects = custom_objects
 
     def train(self, data_iterator):
         """Train a keras model on a worker
@@ -26,26 +25,22 @@ def train(self, data_iterator):
         x_train = np.asarray([x for x, y in feature_iterator])
         y_train = np.asarray([y for x, y in label_iterator])
 
-        model = model_from_yaml(self.yaml, self.custom_objects)
-        model.compile(optimizer=self.master_optimizer,
-                      loss=self.master_loss,
-                      metrics=self.master_metrics)
-        model.set_weights(self.parameters)
-        weights_before_training = model.get_weights()
+        self.model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
+        weights_before_training = self.model.get_weights()
         if x_train.shape[0] > self.train_config.get('batch_size'):
-            model.fit(x_train, y_train, **self.train_config)
-            weights_after_training = model.get_weights()
+            self.model.fit(x_train, y_train, **self.train_config)
+            weights_after_training = self.model.get_weights()
         deltas = subtract_params(weights_before_training, weights_after_training)
         yield deltas
 
 
 class AsynchronousSparkWorker(object):
     """Asynchronous Spark worker. This code will be executed on workers.
     """
-    def __init__(self, yaml, parameter_server_mode, train_config, frequency,
-                 master_optimizer, master_loss, master_metrics,
-                 custom_objects):
-        self.yaml = yaml
+    def __init__(self, serialized_model, parameter_server_mode, train_config, frequency,
+                 master_optimizer, master_loss, master_metrics, custom_objects):
+        # TODO handle custom_objects
+        self.model = dict_to_model(serialized_model)
         if parameter_server_mode == 'http':
             self.client = HttpClient()
         elif parameter_server_mode == 'socket':
@@ -60,7 +55,6 @@ def __init__(self, yaml, parameter_server_mode, train_config, frequency,
         self.master_optimizer = master_optimizer
         self.master_loss = master_loss
         self.master_metrics = master_metrics
-        self.custom_objects = custom_objects
 
     def train(self, data_iterator):
         """Train a keras model on a worker and send asynchronous updates
@@ -73,8 +67,7 @@
         if x_train.size == 0:
             return
 
-        model = model_from_yaml(self.yaml, self.custom_objects)
-        model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
+        self.model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)
 
         nb_epoch = self.train_config['nb_epoch']
         batch_size = self.train_config.get('batch_size')
@@ -89,11 +82,11 @@
         if self.frequency == 'epoch':
             for epoch in range(nb_epoch):
                 weights_before_training = self.client.get_parameters()
-                model.set_weights(weights_before_training)
+                self.model.set_weights(weights_before_training)
                 self.train_config['nb_epoch'] = 1
                 if x_train.shape[0] > batch_size:
-                    model.fit(x_train, y_train, **self.train_config)
-                weights_after_training = model.get_weights()
+                    self.model.fit(x_train, y_train, **self.train_config)
+                weights_after_training = self.model.get_weights()
                 deltas = subtract_params(weights_before_training, weights_after_training)
                 self.client.update_parameters(deltas)
         elif self.frequency == 'batch':
@@ -102,12 +95,12 @@
             if x_train.shape[0] > batch_size:
                 for (batch_start, batch_end) in batches:
                     weights_before_training = self.client.get_parameters()
-                    model.set_weights(weights_before_training)
+                    self.model.set_weights(weights_before_training)
                     batch_ids = index_array[batch_start:batch_end]
                     X = slice_X(x_train, batch_ids)
                     y = slice_X(y_train, batch_ids)
-                    model.train_on_batch(X, y)
-                    weights_after_training = model.get_weights()
+                    self.model.train_on_batch(X, y)
+                    weights_after_training = self.model.get_weights()
                     deltas = subtract_params(weights_before_training, weights_after_training)
                     self.client.update_parameters(deltas)
         else:
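Both worker classes report their updates as deltas, weights before training minus weights after, which go back to the master (synchronous mode) or to the parameter server client (asynchronous mode). A minimal sketch of subtract_params from elephas.utils, assuming it is a plain elementwise difference over the per-layer weight arrays:

    # Assumed implementation of elephas.utils.subtract_params.
    def subtract_params(params_before, params_after):
        # Per-layer difference between two lists of numpy weight arrays.
        return [before - after for before, after in zip(params_before, params_after)]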
