Skip to content

Commit

Permalink
[RELNOTES] Sync Sequential.fit() with Model.fit() (#8192)
Browse the repository at this point in the history
* Sync Sequential.fit() with Model.fit()

* Specify explicit epochs=20

...instead of relying on implicit default value.

* Revert docstring changes

* pep8

* pep8
  • Loading branch information
ozabluda authored and fchollet committed Nov 10, 2017
1 parent 439d847 commit d1ee945
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
3 changes: 2 additions & 1 deletion keras/engine/training.py
Expand Up @@ -1442,7 +1442,8 @@ def _get_deduped_metrics_names(self):
deduped_out_labels.append(new_label)
return deduped_out_labels

def fit(self, x=None,
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
Expand Down
23 changes: 19 additions & 4 deletions keras/models.py
Expand Up @@ -830,9 +830,22 @@ def compile(self, optimizer, loss,
self.sample_weights = self.model.sample_weights
self.total_loss = self.model.total_loss

def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None,
validation_split=0., validation_data=None, shuffle=True,
class_weight=None, sample_weight=None, initial_epoch=0, **kwargs):
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
**kwargs):
"""Trains the model for a fixed number of epochs.
# Arguments
Expand Down Expand Up @@ -904,7 +917,9 @@ def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch)
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)

def evaluate(self, x, y, batch_size=32, verbose=1,
sample_weight=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/keras/test_callbacks.py
Expand Up @@ -247,12 +247,12 @@ def test_EarlyStopping_reuse():
stopper = callbacks.EarlyStopping(monitor='acc', patience=patience)
weights = model.get_weights()

hist = model.fit(data, labels, callbacks=[stopper])
hist = model.fit(data, labels, callbacks=[stopper], epochs=20)
assert len(hist.epoch) >= patience

# This should allow training to go for at least `patience` epochs
model.set_weights(weights)
hist = model.fit(data, labels, callbacks=[stopper])
hist = model.fit(data, labels, callbacks=[stopper], epochs=20)
assert len(hist.epoch) >= patience


Expand Down

0 comments on commit d1ee945

Please sign in to comment.