Skip to content

Commit

Permalink
Fixed number of epochs model is fitted (Issue #77) (#84)
Browse files Browse the repository at this point in the history
* Fixed number of epochs model is fitted

* Restored nb_epochs to its original value

* Changed 'nb_epochs' to 'nb_epoch'
  • Loading branch information
ivanmontero authored and maxpumperla committed Apr 19, 2018
1 parent 31b1f91 commit 7562adb
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions elephas/spark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,10 @@ def train(self, data_iterator):
weights_before_training = get_server_weights(self.master_url)
model.set_weights(weights_before_training)
self.train_config['epochs'] = 1
self.train_config['nb_epoch'] = 1
if x_train.shape[0] > batch_size:
model.fit(x_train, y_train, **self.train_config)
self.train_config['nb_epoch'] = nb_epoch
weights_after_training = model.get_weights()
deltas = subtract_params(weights_before_training, weights_after_training)
put_deltas_to_server(deltas, self.master_url)
Expand Down

0 comments on commit 7562adb

Please sign in to comment.