Skip to content

Commit

Permalink
Added base class for classifer to inherit from.
Browse files Browse the repository at this point in the history
  • Loading branch information
jdwittenauer committed Aug 16, 2015
1 parent a6aa794 commit dbe948e
Showing 1 changed file with 63 additions and 36 deletions.
99 changes: 63 additions & 36 deletions keras/wrappers/scikit_learn.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
from __future__ import absolute_import
import abc
import copy
import numpy as np

from ..utils.np_utils import to_categorical

class KerasClassifier(object):

class BaseWrapper(object):
"""
Implementation of the scikit-learn classifier API for Keras.
Base class for the Keras scikit-learn wrapper.
Parameters
----------
model : object
An un-compiled Keras model object is required to use the scikit-learn wrapper.
optimizer : string, optional
Optimization method used by the model during compilation/training.
loss : string, optional
Loss function used by the model during compilation/training.
Warning: This class should not be used directly. Use derived classes instead.
"""
def __init__(self, model, optimizer='adam', loss='categorical_crossentropy'):
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def __init__(self, model, optimizer, loss):
self.model = model
self.optimizer = optimizer
self.loss = loss
Expand Down Expand Up @@ -60,7 +58,8 @@ def set_params(self, **params):
setattr(self, parameter, value)
return self

def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):
def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True, show_accuracy=False,
validation_split=0, validation_data=None, callbacks=[]):
"""
Fit the model according to the given training data.
Expand All @@ -82,12 +81,20 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):
verbose : int, optional
Verbosity level.
shuffle : boolean, optional
Indicator to shuffle the training data.
Whether to shuffle the samples at each epoch.
show_accuracy : boolean, optional
Whether to display class accuracy in the logs at each epoch.
validation_split : float [0, 1], optional
Fraction of the data to use as held-out validation data.
validation_data : tuple (X, y), optional
Data to be used as held-out validation data. Will override validation_split.
callbacks : list, optional
List of callbacks to apply during training.
Returns
-------
self : object
Returns self.
history : object
Returns details about the training history at each epoch.
"""
if len(y.shape) == 1:
self.classes_ = list(np.unique(y))
Expand All @@ -98,40 +105,57 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=0, shuffle=True):

self.compiled_model_ = copy.deepcopy(self.model)
self.compiled_model_.compile(optimizer=self.optimizer, loss=self.loss)
self.compiled_model_.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose, shuffle=shuffle)
history = self.compiled_model_.fit(X, y, batch_size=batch_size, nb_epoch=nb_epoch, verbose=verbose,
shuffle=shuffle, show_accuracy=show_accuracy,
validation_split=validation_split, validation_data=validation_data,
callbacks=callbacks)

self.config_ = self.model.get_config()
self.weights_ = self.model.get_weights()

return self
return history

def score(self, X, y, batch_size=128, verbose=0):

class KerasClassifier(BaseWrapper):
"""
Implementation of the scikit-learn classifier API for Keras.
Parameters
----------
model : object
An un-compiled Keras model object is required to use the scikit-learn wrapper.
optimizer : string
Optimization method used by the model during compilation/training.
loss : string
Loss function used by the model during compilation/training.
"""
def __init__(self, model, optimizer='adam', loss='categorical_crossentropy'):
super(KerasClassifier, self).__init__(model, optimizer, loss)

def predict(self, X, batch_size=128, verbose=0):
"""
Returns the mean accuracy on the given test data and labels.
Returns the class predictions for the given test data.
Parameters
----------
X : array-like, shape = (n_samples, n_features)
Test samples where n_samples in the number of samples
and n_features is the number of features.
y : array-like, shape = (n_samples) or (n_samples, n_outputs)
True labels for X.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
score : float
Mean accuracy of self.predict(X) wrt. y.
preds : array-like, shape = (n_samples)
Class predictions.
"""
loss, accuracy = self.compiled_model_.evaluate(X, y, batch_size=batch_size,
show_accuracy=True, verbose=verbose)
return accuracy
return self.compiled_model_.predict_classes(X, batch_size=batch_size, verbose=verbose)

def predict(self, X, batch_size=128, verbose=0):
def predict_proba(self, X, batch_size=128, verbose=0):
"""
Returns the class predictions for the given test data.
Returns class probability estimates for the given test data.
Parameters
----------
Expand All @@ -145,28 +169,31 @@ def predict(self, X, batch_size=128, verbose=0):
Returns
-------
preds : array-like, shape = (n_samples)
Class predictions.
proba : array-like, shape = (n_samples, n_outputs)
Class probability estimates.
"""
return self.compiled_model_.predict_classes(X, batch_size=batch_size, verbose=verbose)
return self.compiled_model_.predict_proba(X, batch_size=batch_size, verbose=verbose)

def predict_proba(self, X, batch_size=128, verbose=0):
def score(self, X, y, batch_size=128, verbose=0):
"""
Returns class probability estimates for the given test data.
Returns the mean accuracy on the given test data and labels.
Parameters
----------
X : array-like, shape = (n_samples, n_features)
Test samples where n_samples in the number of samples
and n_features is the number of features.
y : array-like, shape = (n_samples) or (n_samples, n_outputs)
True labels for X.
batch_size : int, optional
Number of test samples evaluated at a time.
verbose : int, optional
Verbosity level.
Returns
-------
proba : array-like, shape = (n_samples, n_outputs)
Class probability estimates.
score : float
Mean accuracy of predictions on X wrt. y.
"""
return self.compiled_model_.predict_proba(X, batch_size=batch_size, verbose=verbose)
loss, accuracy = self.compiled_model_.evaluate(X, y, batch_size=batch_size, show_accuracy=True, verbose=verbose)
return accuracy

0 comments on commit dbe948e

Please sign in to comment.