Skip to content

Commit

Permalink
Moved fit/predict params to init. Changed test accordingly. This addr…
Browse files Browse the repository at this point in the history
…esses #558 .
  • Loading branch information
JasonTam committed Aug 19, 2015
1 parent 103a3da commit dcbe14c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 61 deletions.
119 changes: 62 additions & 57 deletions keras/wrappers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,36 @@ class BaseWrapper(object):
Base class for the Keras scikit-learn wrapper.
Warning: This class should not be used directly. Use derived classes instead.
Parameters
----------
train_batch_size : int, optional
Number of training samples evaluated at a time.
test_batch_size : int, optional
Number of test samples evaluated at a time.
nb_epochs : int, optional
Number of training epochs.
shuffle : boolean, optional
Whether to shuffle the samples at each epoch.
show_accuracy : boolean, optional
Whether to display class accuracy in the logs at each epoch.
validation_split : float [0, 1], optional
Fraction of the data to use as held-out validation data.
validation_data : tuple (X, y), optional
Data to be used as held-out validation data. Will override validation_split.
callbacks : list, optional
List of callbacks to apply during training.
verbose : int, optional
Verbosity level.
"""
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def __init__(self, model, optimizer, loss):
def __init__(self, model, optimizer, loss,
train_batch_size=128, test_batch_size=128,
nb_epoch=100, shuffle=True, show_accuracy=False,
validation_split=0, validation_data=None, callbacks=None,
verbose=0,):
self.model = model
self.optimizer = optimizer
self.loss = loss
Expand All @@ -24,6 +49,17 @@ def __init__(self, model, optimizer, loss):
self.config_ = []
self.weights_ = []

self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.nb_epoch = nb_epoch
self.shuffle = shuffle
self.show_accuracy = show_accuracy
self.validation_split = validation_split
self.validation_data = validation_data
self.callbacks = [] if callbacks is None else callbacks

self.verbose = verbose

def get_params(self, deep=True):
"""
Get parameters for this estimator.
Expand Down Expand Up @@ -58,8 +94,7 @@ def set_params(self, **params):
setattr(self, parameter, value)
return self

def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True, show_accuracy=False,
validation_split=0, validation_data=None, callbacks=[]):
def fit(self, X, y):
"""
Fit the model according to the given training data.
Expand All @@ -74,22 +109,6 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True, show_
and n_features is the number of features.
y : array-like, shape = (n_samples) or (n_samples, n_outputs)
True labels for X.
batch_size : int, optional
Number of training samples evaluated at a time.
nb_epochs : int, optional
Number of training epochs.
verbose : int, optional
Verbosity level.
shuffle : boolean, optional
Whether to shuffle the samples at each epoch.
show_accuracy : boolean, optional
Whether to display class accuracy in the logs at each epoch.
validation_split : float [0, 1], optional
Fraction of the data to use as held-out validation data.
validation_data : tuple (X, y), optional
Data to be used as held-out validation data. Will override validation_split.
callbacks : list, optional
List of callbacks to apply during training.
Returns
-------
Expand All @@ -105,10 +124,11 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True, show_

self.compiled_model_ = copy.deepcopy(self.model)
self.compiled_model_.compile(optimizer=self.optimizer, loss=self.loss)
history = self.compiled_model_.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose,
shuffle=shuffle, show_accuracy=show_accuracy,
validation_split=validation_split, validation_data=validation_data,
callbacks=callbacks)
history = self.compiled_model_.fit(
X, y, batch_size=self.train_batch_size, nb_epoch=self.nb_epoch, verbose=self.verbose,
shuffle=self.shuffle, show_accuracy=self.show_accuracy,
validation_split=self.validation_split, validation_data=self.validation_data,
callbacks=self.callbacks)

self.config_ = self.model.get_config()
self.weights_ = self.model.get_weights()
Expand All @@ -129,10 +149,10 @@ class KerasClassifier(BaseWrapper):
loss : string
Loss function used by the model during compilation/training.
"""
def __init__(self, model, optimizer='adam', loss='categorical_crossentropy'):
super(KerasClassifier, self).__init__(model, optimizer, loss)
def __init__(self, model, optimizer='adam', loss='categorical_crossentropy', **kwargs):
super(KerasClassifier, self).__init__(model, optimizer, loss, **kwargs)

def predict(self, X, batch_size=128, verbose=0):
def predict(self, X):
"""
Returns the class predictions for the given test data.
Expand All @@ -141,19 +161,16 @@ def predict(self, X, batch_size=128, verbose=0):
X : array-like, shape = (n_samples, n_features)
Test samples where n_samples in the number of samples
and n_features is the number of features.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
preds : array-like, shape = (n_samples)
Class predictions.
"""
return self.compiled_model_.predict_classes(X, batch_size=batch_size, verbose=verbose)
return self.compiled_model_.predict_classes(
X, batch_size=self.test_batch_size, verbose=self.verbose)

def predict_proba(self, X, batch_size=128, verbose=0):
def predict_proba(self, X):
"""
Returns class probability estimates for the given test data.
Expand All @@ -162,19 +179,16 @@ def predict_proba(self, X, batch_size=128, verbose=0):
X : array-like, shape = (n_samples, n_features)
Test samples where n_samples in the number of samples
and n_features is the number of features.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
proba : array-like, shape = (n_samples, n_outputs)
Class probability estimates.
"""
return self.compiled_model_.predict_proba(X, batch_size=batch_size, verbose=verbose)
return self.compiled_model_.predict_proba(
X, batch_size=self.test_batch_size, verbose=self.verbose)

def score(self, X, y, batch_size=128, verbose=0):
def score(self, X, y):
"""
Returns the mean accuracy on the given test data and labels.
Expand All @@ -185,17 +199,14 @@ def score(self, X, y, batch_size=128, verbose=0):
and n_features is the number of features.
y : array-like, shape = (n_samples) or (n_samples, n_outputs)
True labels for X.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
score : float
Mean accuracy of predictions on X wrt. y.
"""
loss, accuracy = self.compiled_model_.evaluate(X, y, batch_size=batch_size, show_accuracy=True, verbose=verbose)
loss, accuracy = self.compiled_model_.evaluate(
X, y, batch_size=self.test_batch_size, show_accuracy=True, verbose=self.verbose)
return accuracy


Expand All @@ -212,10 +223,10 @@ class KerasRegressor(BaseWrapper):
loss : string
Loss function used by the model during compilation/training.
"""
def __init__(self, model, optimizer='adam', loss='mean_squared_error'):
super(KerasRegressor, self).__init__(model, optimizer, loss)
def __init__(self, model, optimizer='adam', loss='mean_squared_error', **kwargs):
super(KerasRegressor, self).__init__(model, optimizer, loss, **kwargs)

def predict(self, X, batch_size=128, verbose=0):
def predict(self, X):
"""
Returns predictions for the given test data.
Expand All @@ -224,19 +235,16 @@ def predict(self, X, batch_size=128, verbose=0):
X : array-like, shape = (n_samples, n_features)
Test samples where n_samples in the number of samples
and n_features is the number of features.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
preds : array-like, shape = (n_samples)
Predictions.
"""
return self.compiled_model_.predict(X, batch_size=batch_size, verbose=verbose).ravel()
return self.compiled_model_.predict(
X, batch_size=self.test_batch_size, verbose=self.verbose).ravel()

def score(self, X, y, batch_size=128, verbose=0):
def score(self, X, y):
"""
Returns the mean accuracy on the given test data and labels.
Expand All @@ -247,15 +255,12 @@ def score(self, X, y, batch_size=128, verbose=0):
and n_features is the number of features.
y : array-like, shape = (n_samples)
True labels for X.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
score : float
Loss from predictions on X wrt. y.
"""
loss = self.compiled_model_.evaluate(X, y, batch_size=batch_size, show_accuracy=False, verbose=verbose)
loss = self.compiled_model_.evaluate(
X, y, batch_size=self.test_batch_size, show_accuracy=False, verbose=self.verbose)
return loss
8 changes: 4 additions & 4 deletions tests/manual/check_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
model.add(Activation('softmax'))

print('Creating wrapper')
classifier = KerasClassifier(model)
classifier = KerasClassifier(model, train_batch_size=batch_size, nb_epoch=nb_epoch)

print('Fitting model')
classifier.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch)
classifier.fit(X_train, Y_train)

print('Testing score function')
score = classifier.score(X_train, Y_train)
Expand Down Expand Up @@ -95,10 +95,10 @@
model.add(Activation('linear'))

print('Creating wrapper')
regressor = KerasRegressor(model)
regressor = KerasRegressor(model, train_batch_size=batch_size, nb_epoch=nb_epoch)

print('Fitting model')
regressor.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch)
regressor.fit(X_train, y_train)

print('Testing score function')
score = regressor.score(X_train, y_train)
Expand Down

0 comments on commit dcbe14c

Please sign in to comment.