diff --git a/keras/engine/training.py b/keras/engine/training.py index cab3b465670..08adcf48dfa 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -1442,7 +1442,8 @@ def _get_deduped_metrics_names(self): deduped_out_labels.append(new_label) return deduped_out_labels - def fit(self, x=None, + def fit(self, + x=None, y=None, batch_size=None, epochs=1, diff --git a/keras/models.py b/keras/models.py index 6b76cd473bb..f248b27c06a 100644 --- a/keras/models.py +++ b/keras/models.py @@ -830,9 +830,22 @@ def compile(self, optimizer, loss, self.sample_weights = self.model.sample_weights self.total_loss = self.model.total_loss - def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None, - validation_split=0., validation_data=None, shuffle=True, - class_weight=None, sample_weight=None, initial_epoch=0, **kwargs): + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): """Trains the model for a fixed number of epochs. # Arguments @@ -904,7 +917,9 @@ def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None, shuffle=shuffle, class_weight=class_weight, sample_weight=sample_weight, - initial_epoch=initial_epoch) + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py index 73b27868068..94237a1dede 100644 --- a/tests/keras/test_callbacks.py +++ b/tests/keras/test_callbacks.py @@ -247,12 +247,12 @@ def test_EarlyStopping_reuse(): stopper = callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() - hist = model.fit(data, labels, callbacks=[stopper]) + hist = model.fit(data, labels, callbacks=[stopper], epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) - hist = model.fit(data, labels, callbacks=[stopper]) + hist = model.fit(data, labels, callbacks=[stopper], epochs=20) assert len(hist.epoch) >= patience